apache-airflow-providers-amazon 8.19.0rc1__py3-none-any.whl → 8.20.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (77)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +4 -2
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +22 -7
  4. airflow/providers/amazon/aws/auth_manager/{cli → avp}/schema.json +34 -2
  5. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +91 -170
  6. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +7 -32
  7. airflow/providers/amazon/aws/auth_manager/cli/definition.py +1 -1
  8. airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +1 -0
  9. airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
  10. airflow/providers/amazon/aws/executors/batch/__init__.py +16 -0
  11. airflow/providers/amazon/aws/executors/batch/batch_executor.py +420 -0
  12. airflow/providers/amazon/aws/executors/batch/batch_executor_config.py +87 -0
  13. airflow/providers/amazon/aws/executors/batch/boto_schema.py +67 -0
  14. airflow/providers/amazon/aws/executors/batch/utils.py +160 -0
  15. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +61 -18
  16. airflow/providers/amazon/aws/executors/ecs/utils.py +8 -13
  17. airflow/providers/amazon/aws/executors/utils/base_config_keys.py +25 -0
  18. airflow/providers/amazon/aws/hooks/athena.py +1 -0
  19. airflow/providers/amazon/aws/hooks/base_aws.py +1 -0
  20. airflow/providers/amazon/aws/hooks/batch_client.py +4 -3
  21. airflow/providers/amazon/aws/hooks/batch_waiters.py +1 -0
  22. airflow/providers/amazon/aws/hooks/bedrock.py +59 -0
  23. airflow/providers/amazon/aws/hooks/chime.py +1 -0
  24. airflow/providers/amazon/aws/hooks/cloud_formation.py +1 -0
  25. airflow/providers/amazon/aws/hooks/datasync.py +1 -0
  26. airflow/providers/amazon/aws/hooks/dynamodb.py +1 -0
  27. airflow/providers/amazon/aws/hooks/eks.py +1 -0
  28. airflow/providers/amazon/aws/hooks/glue.py +13 -5
  29. airflow/providers/amazon/aws/hooks/glue_catalog.py +1 -0
  30. airflow/providers/amazon/aws/hooks/kinesis.py +1 -0
  31. airflow/providers/amazon/aws/hooks/lambda_function.py +1 -0
  32. airflow/providers/amazon/aws/hooks/rds.py +1 -0
  33. airflow/providers/amazon/aws/hooks/s3.py +24 -30
  34. airflow/providers/amazon/aws/hooks/ses.py +1 -0
  35. airflow/providers/amazon/aws/hooks/sns.py +1 -0
  36. airflow/providers/amazon/aws/hooks/sqs.py +1 -0
  37. airflow/providers/amazon/aws/operators/athena.py +2 -2
  38. airflow/providers/amazon/aws/operators/base_aws.py +4 -1
  39. airflow/providers/amazon/aws/operators/batch.py +4 -2
  40. airflow/providers/amazon/aws/operators/bedrock.py +252 -0
  41. airflow/providers/amazon/aws/operators/cloud_formation.py +1 -0
  42. airflow/providers/amazon/aws/operators/datasync.py +1 -0
  43. airflow/providers/amazon/aws/operators/ecs.py +9 -10
  44. airflow/providers/amazon/aws/operators/eks.py +1 -0
  45. airflow/providers/amazon/aws/operators/emr.py +57 -7
  46. airflow/providers/amazon/aws/operators/s3.py +1 -0
  47. airflow/providers/amazon/aws/operators/sns.py +1 -0
  48. airflow/providers/amazon/aws/operators/sqs.py +1 -0
  49. airflow/providers/amazon/aws/secrets/secrets_manager.py +1 -0
  50. airflow/providers/amazon/aws/secrets/systems_manager.py +1 -0
  51. airflow/providers/amazon/aws/sensors/base_aws.py +4 -1
  52. airflow/providers/amazon/aws/sensors/bedrock.py +110 -0
  53. airflow/providers/amazon/aws/sensors/cloud_formation.py +1 -0
  54. airflow/providers/amazon/aws/sensors/eks.py +3 -4
  55. airflow/providers/amazon/aws/sensors/sqs.py +2 -1
  56. airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +4 -2
  57. airflow/providers/amazon/aws/transfers/base.py +1 -0
  58. airflow/providers/amazon/aws/transfers/exasol_to_s3.py +1 -0
  59. airflow/providers/amazon/aws/transfers/gcs_to_s3.py +1 -0
  60. airflow/providers/amazon/aws/transfers/google_api_to_s3.py +1 -0
  61. airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +1 -0
  62. airflow/providers/amazon/aws/transfers/http_to_s3.py +1 -0
  63. airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +1 -0
  64. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +21 -19
  65. airflow/providers/amazon/aws/triggers/bedrock.py +61 -0
  66. airflow/providers/amazon/aws/triggers/eks.py +1 -1
  67. airflow/providers/amazon/aws/triggers/redshift_cluster.py +1 -0
  68. airflow/providers/amazon/aws/triggers/s3.py +4 -2
  69. airflow/providers/amazon/aws/triggers/sagemaker.py +6 -4
  70. airflow/providers/amazon/aws/utils/emailer.py +1 -0
  71. airflow/providers/amazon/aws/waiters/bedrock.json +42 -0
  72. airflow/providers/amazon/get_provider_info.py +86 -1
  73. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/METADATA +10 -9
  74. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/RECORD +77 -66
  75. /airflow/providers/amazon/aws/executors/{ecs/Dockerfile → Dockerfile} +0 -0
  76. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/WHEEL +0 -0
  77. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/entry_points.txt +0 -0
@@ -576,7 +576,7 @@ class EmrContainerOperator(BaseOperator):
                 stacklevel=2,
             )
             if max_polling_attempts and max_polling_attempts != max_tries:
-                raise Exception("max_polling_attempts must be the same value as max_tries")
+                raise ValueError("max_polling_attempts must be the same value as max_tries")
             else:
                 self.max_polling_attempts = max_tries
 
@@ -1253,27 +1253,77 @@ class EmrServerlessStartJobOperator(BaseOperator):
         op_extra_links = []
 
         if isinstance(self, MappedOperator):
+            operator_class = self.operator_class
             enable_application_ui_links = self.partial_kwargs.get(
                 "enable_application_ui_links"
             ) or self.expand_input.value.get("enable_application_ui_links")
-            job_driver = self.partial_kwargs.get("job_driver") or self.expand_input.value.get("job_driver")
+            job_driver = self.partial_kwargs.get("job_driver", {}) or self.expand_input.value.get(
+                "job_driver", {}
+            )
             configuration_overrides = self.partial_kwargs.get(
                 "configuration_overrides"
             ) or self.expand_input.value.get("configuration_overrides")
 
+            # Configuration overrides can either be a list or a dictionary, depending on whether it's passed in as partial or expand.
+            if isinstance(configuration_overrides, list):
+                if any(
+                    [
+                        operator_class.is_monitoring_in_job_override(
+                            self=operator_class,
+                            config_key="s3MonitoringConfiguration",
+                            job_override=job_override,
+                        )
+                        for job_override in configuration_overrides
+                    ]
+                ):
+                    op_extra_links.extend([EmrServerlessS3LogsLink()])
+                if any(
+                    [
+                        operator_class.is_monitoring_in_job_override(
+                            self=operator_class,
+                            config_key="cloudWatchLoggingConfiguration",
+                            job_override=job_override,
+                        )
+                        for job_override in configuration_overrides
+                    ]
+                ):
+                    op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+            else:
+                if operator_class.is_monitoring_in_job_override(
+                    self=operator_class,
+                    config_key="s3MonitoringConfiguration",
+                    job_override=configuration_overrides,
+                ):
+                    op_extra_links.extend([EmrServerlessS3LogsLink()])
+                if operator_class.is_monitoring_in_job_override(
+                    self=operator_class,
+                    config_key="cloudWatchLoggingConfiguration",
+                    job_override=configuration_overrides,
+                ):
+                    op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+
         else:
+            operator_class = self
             enable_application_ui_links = self.enable_application_ui_links
             configuration_overrides = self.configuration_overrides
             job_driver = self.job_driver
 
+            if operator_class.is_monitoring_in_job_override(
+                "s3MonitoringConfiguration", configuration_overrides
+            ):
+                op_extra_links.extend([EmrServerlessS3LogsLink()])
+            if operator_class.is_monitoring_in_job_override(
+                "cloudWatchLoggingConfiguration", configuration_overrides
+            ):
+                op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+
         if enable_application_ui_links:
             op_extra_links.extend([EmrServerlessDashboardLink()])
-            if "sparkSubmit" in job_driver:
+            if isinstance(job_driver, list):
+                if any("sparkSubmit" in ind_job_driver for ind_job_driver in job_driver):
+                    op_extra_links.extend([EmrServerlessLogsLink()])
+            elif "sparkSubmit" in job_driver:
                 op_extra_links.extend([EmrServerlessLogsLink()])
-        if self.is_monitoring_in_job_override("s3MonitoringConfiguration", configuration_overrides):
-            op_extra_links.extend([EmrServerlessS3LogsLink()])
-        if self.is_monitoring_in_job_override("cloudWatchLoggingConfiguration", configuration_overrides):
-            op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
 
         return tuple(op_extra_links)
 
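Note: the branching added above distinguishes mapped uses of EmrServerlessStartJobOperator (via `.partial(...).expand(...)`, where `configuration_overrides` arrives as a list) from plain instantiation (a single dict). A minimal sketch of the mapped case, with placeholder application, role, and bucket values, that would now resolve both the S3 and CloudWatch log links:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.emr import EmrServerlessStartJobOperator

    with DAG("emr_serverless_mapped_links", start_date=datetime(2024, 1, 1), schedule=None):
        # Mapped: configuration_overrides is expanded, so the operator sees a list of dicts and
        # adds the S3/CloudWatch log links if any element configures that monitoring type.
        EmrServerlessStartJobOperator.partial(
            task_id="start_job",
            application_id="my-application-id",  # placeholder
            execution_role_arn="arn:aws:iam::123456789012:role/emr-serverless-role",  # placeholder
            job_driver={"sparkSubmit": {"entryPoint": "s3://my-bucket/scripts/job.py"}},
            enable_application_ui_links=True,
        ).expand(
            configuration_overrides=[
                {"monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": "s3://my-bucket/logs/"}}},
                {"monitoringConfiguration": {"cloudWatchLoggingConfiguration": {"enabled": True}}},
            ]
        )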
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module contains AWS S3 operators."""
+
 from __future__ import annotations
 
 import subprocess
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Publish message to SNS queue."""
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Publish message to SQS queue."""
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Objects relating to sourcing secrets from AWS Secrets Manager."""
+
 from __future__ import annotations
 
 import json
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Objects relating to sourcing connections from AWS SSM Parameter Store."""
+
 from __future__ import annotations
 
 import re
@@ -26,6 +26,7 @@ from airflow.providers.amazon.aws.utils.mixins import (
     aws_template_fields,
 )
 from airflow.sensors.base import BaseSensorOperator
+from airflow.utils.types import NOTSET, ArgNotSet
 
 
 class AwsBaseSensor(BaseSensorOperator, AwsBaseHookMixin[AwsHookType]):
@@ -84,10 +85,12 @@ class AwsBaseSensor(BaseSensorOperator, AwsBaseHookMixin[AwsHookType]):
         region_name: str | None = None,
         verify: bool | str | None = None,
         botocore_config: dict | None = None,
+        region: str | None | ArgNotSet = NOTSET,  # Required for `.partial` signature check
         **kwargs,
     ):
+        additional_params = {} if region is NOTSET else {"region": region}
         hook_params = AwsHookParams.from_constructor(
-            aws_conn_id, region_name, verify, botocore_config, additional_params=kwargs
+            aws_conn_id, region_name, verify, botocore_config, additional_params=additional_params
        )
         super().__init__(**kwargs)
         self.aws_conn_id = hook_params.aws_conn_id
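Note: the extra `region` parameter above appears to exist only so that `.partial()` signature validation keeps accepting the deprecated `region` alias, which is then forwarded to the hook through `additional_params`. A hedged sketch of the call pattern this is meant to keep working (queue URL and values are illustrative):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.sensors.sqs import SqsSensor

    with DAG("sqs_sensor_partial_region", start_date=datetime(2024, 1, 1), schedule=None):
        # `region` (legacy alias for region_name) is forwarded to the hook via additional_params
        # instead of leaking into BaseSensorOperator's **kwargs, so `.partial()` accepts it.
        SqsSensor.partial(
            task_id="wait_for_message",
            sqs_queue="https://sqs.us-east-1.amazonaws.com/123456789012/my-queue",  # placeholder
            region="us-east-1",
        ).expand(max_messages=[1, 5])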
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+
+
+class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]):
+    """
+    Poll the state of the model customization job until it reaches a terminal state; fails if the job fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:BedrockCustomizeModelCompletedSensor`
+
+
+    :param job_name: The name of the Bedrock model customization job.
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES = ("InProgress",)
+    FAILURE_STATES = ("Failed", "Stopping", "Stopped")
+    SUCCESS_STATES = ("Completed",)
+    FAILURE_MESSAGE = "Bedrock model customization job sensor failed."
+
+    aws_hook_class = BedrockHook
+    template_fields: Sequence[str] = aws_template_fields("job_name")
+    ui_color = "#66c3ff"
+
+    def __init__(
+        self,
+        *,
+        job_name: str,
+        max_retries: int = 75,
+        poke_interval: int = 120,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.job_name = job_name
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=BedrockCustomizeModelCompletedTrigger(
+                    job_name=self.job_name,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
+
+    def poke(self, context: Context) -> bool:
+        state = self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"]
+        self.log.info("Job '%s' state: %s", self.job_name, state)
+
+        if state in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        return state not in self.INTERMEDIATE_STATES
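Note: a hedged usage sketch of the new sensor (job name and values are placeholders); in deferrable mode it hands off to the BedrockCustomizeModelCompletedTrigger added further below:

    from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor

    # Waits until the customization job leaves "InProgress"; fails (or skips, when soft_fail is set)
    # on Failed/Stopping/Stopped.
    wait_for_customization = BedrockCustomizeModelCompletedSensor(
        task_id="wait_for_model_customization",
        job_name="my-customization-job",  # placeholder
        deferrable=True,
        poke_interval=120,
        max_retries=75,
    )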
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module contains sensors for AWS CloudFormation."""
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Tracking the state of Amazon EKS Clusters, Amazon EKS managed node groups, and AWS Fargate profiles."""
+
 from __future__ import annotations
 
 from abc import abstractmethod
@@ -114,12 +115,10 @@ class EksBaseSensor(BaseSensorOperator):
         return state == self.target_state
 
     @abstractmethod
-    def get_state(self) -> ClusterStates | NodegroupStates | FargateProfileStates:
-        ...
+    def get_state(self) -> ClusterStates | NodegroupStates | FargateProfileStates: ...
 
     @abstractmethod
-    def get_terminal_states(self) -> frozenset:
-        ...
+    def get_terminal_states(self) -> frozenset: ...
 
 
 class EksClusterStateSensor(EksBaseSensor):
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Reads and then deletes the message from SQS queue."""
+
 from __future__ import annotations
 
 from datetime import timedelta
@@ -72,7 +73,7 @@ class SqsSensor(AwsBaseSensor[SqsHook]):
     :param delete_message_on_reception: Default to `True`, the messages are deleted from the queue
         as soon as being consumed. Otherwise, the messages remain in the queue after consumption and
         should be deleted manually.
-    :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
         module to be installed.
         (default: False, but can be overridden in config file by setting default_deferrable to True)
     :param aws_conn_id: The Airflow connection used for AWS credentials.
@@ -120,8 +120,10 @@ class AzureBlobStorageToS3Operator(BaseOperator):
         )
 
         self.log.info(
-            f"Getting list of the files in Container: {self.container_name}; "
-            f"Prefix: {self.prefix}; Delimiter: {self.delimiter};"
+            "Getting list of the files in Container: %r; Prefix: %r; Delimiter: %r.",
+            self.container_name,
+            self.prefix,
+            self.delimiter,
         )
 
         files = wasb_hook.get_blobs_list_recursive(
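Note: the change above swaps eager f-string interpolation for logging's lazy %-style formatting, so arguments are only rendered when the INFO record is actually emitted. A generic illustration of the two styles:

    import logging

    log = logging.getLogger(__name__)
    container_name, prefix, delimiter = "my-container", "data/", "/"

    # Eager: the f-string is always built, even if INFO is disabled.
    log.info(f"Getting list of the files in Container: {container_name}")

    # Lazy: formatting is deferred to the logging framework (the style adopted above).
    log.info(
        "Getting list of the files in Container: %r; Prefix: %r; Delimiter: %r.",
        container_name,
        prefix,
        delimiter,
    )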
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module contains base AWS to AWS transfer operator."""
+
 from __future__ import annotations
 
 import warnings
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Transfers data from Exasol database into a S3 Bucket."""
+
 from __future__ import annotations
 
 from tempfile import NamedTemporaryFile
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module contains Google Cloud Storage to S3 operator."""
+
 from __future__ import annotations
 
 import os
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module allows you to transfer data from any Google API endpoint into a S3 Bucket."""
+
 from __future__ import annotations
 
 import json
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module contains operator to move data from Hive to DynamoDB."""
+
 from __future__ import annotations
 
 import json
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module contains operator to move data from HTTP endpoint to S3."""
+
 from __future__ import annotations
 
 from functools import cached_property
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module allows you to transfer mail attachments from a mail server into s3 bucket."""
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Transfers data from AWS Redshift into a S3 Bucket."""
+
 from __future__ import annotations
 
 import re
@@ -109,35 +110,19 @@ class RedshiftToS3Operator(BaseOperator):
     ) -> None:
         super().__init__(**kwargs)
         self.s3_bucket = s3_bucket
-        self.s3_key = f"{s3_key}/{table}_" if (table and table_as_file_name) else s3_key
+        self.s3_key = s3_key
         self.schema = schema
         self.table = table
         self.redshift_conn_id = redshift_conn_id
         self.aws_conn_id = aws_conn_id
         self.verify = verify
-        self.unload_options: list = unload_options or []
+        self.unload_options = unload_options or []
         self.autocommit = autocommit
         self.include_header = include_header
         self.parameters = parameters
         self.table_as_file_name = table_as_file_name
         self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}
-
-        if select_query:
-            self.select_query = select_query
-        elif self.schema and self.table:
-            self.select_query = f"SELECT * FROM {self.schema}.{self.table}"
-        else:
-            raise ValueError(
-                "Please provide both `schema` and `table` params or `select_query` to fetch the data."
-            )
-
-        if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
-            self.unload_options = [*self.unload_options, "HEADER"]
-
-        if self.redshift_data_api_kwargs:
-            for arg in ["sql", "parameters"]:
-                if arg in self.redshift_data_api_kwargs:
-                    raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
+        self.select_query = select_query
 
     def _build_unload_query(
         self, credentials_block: str, select_query: str, s3_key: str, unload_options: str
@@ -153,9 +138,26 @@ class RedshiftToS3Operator(BaseOperator):
         """
 
     def execute(self, context: Context) -> None:
+        if self.table and self.table_as_file_name:
+            self.s3_key = f"{self.s3_key}/{self.table}_"
+
+        if self.schema and self.table:
+            self.select_query = f"SELECT * FROM {self.schema}.{self.table}"
+
+        if self.select_query is None:
+            raise ValueError(
+                "Please provide both `schema` and `table` params or `select_query` to fetch the data."
+            )
+
+        if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
+            self.unload_options = [*self.unload_options, "HEADER"]
+
         redshift_hook: RedshiftDataHook | RedshiftSQLHook
         if self.redshift_data_api_kwargs:
             redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
+            for arg in ["sql", "parameters"]:
+                if arg in self.redshift_data_api_kwargs:
+                    raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
         else:
             redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
         conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
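Note: with validation and query construction moved from `__init__` into `execute()`, they now run after template rendering rather than at DAG-parse time. A hedged usage sketch (bucket, schema, and connection ids are placeholders):

    from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

    # Either pass schema + table (a SELECT * query is built in execute(), and the key
    # becomes "exports/orders_" because table_as_file_name defaults to True) ...
    unload_table = RedshiftToS3Operator(
        task_id="unload_table",
        s3_bucket="my-bucket",  # placeholder
        s3_key="exports",
        schema="public",
        table="orders",
        redshift_conn_id="redshift_default",
    )

    # ... or pass an explicit select_query; providing neither now raises ValueError at execute time.
    unload_query = RedshiftToS3Operator(
        task_id="unload_query",
        s3_bucket="my-bucket",
        s3_key="exports/recent_orders",
        select_query="SELECT * FROM public.orders WHERE order_date > '2024-01-01'",
        table_as_file_name=False,
        redshift_conn_id="redshift_default",
    )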
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class BedrockCustomizeModelCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Bedrock model customization job is complete.
+
+    :param job_name: The name of the Bedrock model customization job.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        job_name: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"job_name": job_name},
+            waiter_name="model_customization_job_complete",
+            waiter_args={"jobIdentifier": job_name},
+            failure_message="Bedrock model customization failed.",
+            status_message="Status of Bedrock model customization job is",
+            status_queries=["status"],
+            return_key="job_name",
+            return_value=job_name,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return BedrockHook(aws_conn_id=self.aws_conn_id)
@@ -214,7 +214,7 @@ class EksDeleteClusterTrigger(AwsBaseWaiterTrigger):
                 )
             self.log.info("All Fargate profiles deleted")
         else:
-            self.log.info(f"No Fargate profiles associated with cluster {self.cluster_name}")
+            self.log.info("No Fargate profiles associated with cluster %s", self.cluster_name)
 
 
 class EksCreateFargateProfileTrigger(AwsBaseWaiterTrigger):
@@ -311,6 +311,7 @@ class RedshiftClusterTrigger(BaseTrigger):
                         "status"
                     ] == "error":
                         yield TriggerEvent(res)
+                        return
                     await asyncio.sleep(self.poke_interval)
             except Exception as e:
                 yield TriggerEvent({"status": "error", "message": str(e)})
@@ -98,8 +98,9 @@ class S3KeyTrigger(BaseTrigger):
                         )
                         await asyncio.sleep(self.poke_interval)
                         yield TriggerEvent({"status": "running", "files": s3_objects})
-
-                    yield TriggerEvent({"status": "success"})
+                    else:
+                        yield TriggerEvent({"status": "success"})
+                        return
 
                 self.log.info("Sleeping for %s seconds", self.poke_interval)
                 await asyncio.sleep(self.poke_interval)
@@ -204,6 +205,7 @@ class S3KeysUnchangedTrigger(BaseTrigger):
                 )
                 if result.get("status") in ("success", "error"):
                     yield TriggerEvent(result)
+                    return
                 elif result.get("status") == "pending":
                     self.previous_objects = result.get("previous_objects", set())
                     self.last_activity_time = result.get("last_activity_time")
@@ -245,8 +245,8 @@ class SageMakerTrainingPrintLogTrigger(BaseTrigger):
         job_already_completed = status not in self.hook.non_terminal_states
         state = LogState.COMPLETE if job_already_completed else LogState.TAILING
         last_describe_job_call = time.time()
-        while True:
-            try:
+        try:
+            while True:
                 (
                     state,
                     last_description,
@@ -267,6 +267,7 @@ class SageMakerTrainingPrintLogTrigger(BaseTrigger):
                     reason = last_description.get("FailureReason", "(No reason provided)")
                     error_message = f"SageMaker job failed because {reason}"
                    yield TriggerEvent({"status": "error", "message": error_message})
+                    return
                 else:
                     billable_seconds = SageMakerHook.count_billable_seconds(
                         training_start_time=last_description["TrainingStartTime"],
@@ -275,5 +276,6 @@ class SageMakerTrainingPrintLogTrigger(BaseTrigger):
                     )
                     self.log.info("Billable seconds: %d", billable_seconds)
                     yield TriggerEvent({"status": "success", "message": last_description})
-            except Exception as e:
-                yield TriggerEvent({"status": "error", "message": str(e)})
+                    return
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Airflow module for email backend using AWS SES."""
+
 from __future__ import annotations
 
 from typing import Any
@@ -0,0 +1,42 @@
+{
+    "version": 2,
+    "waiters": {
+        "model_customization_job_complete": {
+            "delay": 120,
+            "maxAttempts": 75,
+            "operation": "GetModelCustomizationJob",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "InProgress",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Completed",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Stopping",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "status",
+                    "expected": "Stopped",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
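Note: this custom waiter definition is what the `model_customization_job_complete` name used by BedrockCustomizeModelCompletedTrigger resolves to. As a hedged illustration, the same waiter can also be driven directly from the hook (job identifier is a placeholder):

    from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook

    hook = BedrockHook(aws_conn_id="aws_default")
    # get_waiter() loads custom waiter definitions such as this bedrock.json shipped with the provider.
    waiter = hook.get_waiter("model_customization_job_complete")
    waiter.wait(
        jobIdentifier="my-customization-job",  # placeholder
        WaiterConfig={"Delay": 120, "MaxAttempts": 75},
    )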