apache-airflow-providers-amazon 9.4.0rc1__py3-none-any.whl → 9.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +3 -1
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
  4. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +69 -97
  5. airflow/providers/amazon/aws/auth_manager/router/login.py +9 -4
  6. airflow/providers/amazon/aws/auth_manager/user.py +7 -4
  7. airflow/providers/amazon/aws/hooks/appflow.py +5 -15
  8. airflow/providers/amazon/aws/hooks/base_aws.py +34 -1
  9. airflow/providers/amazon/aws/hooks/ec2.py +1 -1
  10. airflow/providers/amazon/aws/hooks/eks.py +3 -6
  11. airflow/providers/amazon/aws/hooks/glue.py +6 -2
  12. airflow/providers/amazon/aws/hooks/logs.py +2 -2
  13. airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
  14. airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -1
  15. airflow/providers/amazon/aws/hooks/redshift_data.py +2 -2
  16. airflow/providers/amazon/aws/hooks/s3.py +3 -1
  17. airflow/providers/amazon/aws/hooks/sagemaker.py +1 -1
  18. airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
  19. airflow/providers/amazon/aws/links/base_aws.py +8 -1
  20. airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
  21. airflow/providers/amazon/aws/log/s3_task_handler.py +22 -7
  22. airflow/providers/amazon/aws/notifications/chime.py +1 -2
  23. airflow/providers/amazon/aws/notifications/sns.py +1 -1
  24. airflow/providers/amazon/aws/notifications/sqs.py +1 -1
  25. airflow/providers/amazon/aws/operators/ec2.py +91 -83
  26. airflow/providers/amazon/aws/operators/mwaa.py +73 -2
  27. airflow/providers/amazon/aws/operators/s3.py +147 -157
  28. airflow/providers/amazon/aws/operators/sagemaker.py +1 -2
  29. airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
  30. airflow/providers/amazon/aws/sensors/ec2.py +5 -12
  31. airflow/providers/amazon/aws/sensors/emr.py +1 -1
  32. airflow/providers/amazon/aws/sensors/mwaa.py +160 -0
  33. airflow/providers/amazon/aws/sensors/rds.py +10 -5
  34. airflow/providers/amazon/aws/sensors/s3.py +31 -42
  35. airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
  36. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
  37. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
  38. airflow/providers/amazon/aws/triggers/README.md +4 -4
  39. airflow/providers/amazon/aws/triggers/base.py +11 -2
  40. airflow/providers/amazon/aws/triggers/ecs.py +6 -2
  41. airflow/providers/amazon/aws/triggers/eks.py +2 -2
  42. airflow/providers/amazon/aws/triggers/glue.py +1 -1
  43. airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
  44. airflow/providers/amazon/aws/triggers/s3.py +31 -6
  45. airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
  46. airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
  47. airflow/providers/amazon/aws/triggers/sqs.py +11 -3
  48. airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
  49. airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
  50. airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
  51. airflow/providers/amazon/get_provider_info.py +45 -4
  52. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/METADATA +38 -31
  53. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/RECORD +55 -48
  54. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/WHEEL +1 -1
  55. airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
  56. {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0rc2.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook sensor."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import (
+    SageMakerNotebookHook,
+)
+from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class SageMakerNotebookSensor(BaseSensorOperator):
+    """
+    Waits for a Sagemaker Workflows Notebook execution to reach any of the status below.
+
+    'FAILED', 'STOPPED', 'COMPLETED'
+
+    :param execution_id: The Sagemaker Workflows Notebook running execution identifier
+    :param execution_name: The Sagemaker Workflows Notebook unique execution name
+    """
+
+    def __init__(self, *, execution_id: str, execution_name: str, **kwargs):
+        super().__init__(**kwargs)
+        self.execution_id = execution_id
+        self.execution_name = execution_name
+        self.success_state = ["COMPLETED"]
+        self.in_progress_states = ["PENDING", "RUNNING"]
+
+    def hook(self):
+        return SageMakerNotebookHook(execution_name=self.execution_name)
+
+    # override from base sensor
+    def poke(self, context=None):
+        status = self.hook().get_execution_status(execution_id=self.execution_id)
+
+        if status in self.success_state:
+            log_info_message = f"Exiting Execution {self.execution_id} State: {status}"
+            self.log.info(log_info_message)
+            return True
+        elif status in self.in_progress_states:
+            return False
+        else:
+            error_message = f"Exiting Execution {self.execution_id} State: {status}"
+            self.log.info(error_message)
+            raise AirflowException(error_message)
+
+    def execute(self, context: Context):
+        # This will invoke poke method in the base sensor
+        log_info_message = f"Polling Sagemaker Workflows Artifact execution: {self.execution_name} and execution id: {self.execution_id}"
+        self.log.info(log_info_message)
+        super().execute(context=context)
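
The sensor above (added as airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py, item 35 in the file list) polls a notebook execution until it leaves the in-progress states. A minimal usage sketch follows; the DAG id, task id, and execution values are placeholders, not taken from this release:

```python
# Hypothetical usage sketch; execution_id/execution_name values are placeholders.
from airflow import DAG
from airflow.providers.amazon.aws.sensors.sagemaker_unified_studio import SageMakerNotebookSensor

with DAG(dag_id="example_sagemaker_unified_studio_sensor") as dag:
    wait_for_notebook = SageMakerNotebookSensor(
        task_id="wait_for_notebook_execution",
        execution_id="notebook-execution-1234",  # placeholder execution id
        execution_name="my-notebook-execution",  # placeholder execution name
        poke_interval=30,  # standard BaseSensorOperator argument
    )
```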
@@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+from airflow.utils.types import NOTSET, ArgNotSet

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -102,7 +103,7 @@ class RedshiftToS3Operator(BaseOperator):
         table: str | None = None,
         select_query: str | None = None,
         redshift_conn_id: str = "redshift_default",
-        aws_conn_id: str | None = "aws_default",
+        aws_conn_id: str | None | ArgNotSet = NOTSET,
         verify: bool | str | None = None,
         unload_options: list | None = None,
         autocommit: bool = False,
@@ -118,7 +119,6 @@
         self.schema = schema
         self.table = table
         self.redshift_conn_id = redshift_conn_id
-        self.aws_conn_id = aws_conn_id
         self.verify = verify
         self.unload_options = unload_options or []
         self.autocommit = autocommit
@@ -127,6 +127,16 @@
         self.table_as_file_name = table_as_file_name
         self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}
         self.select_query = select_query
+        # In execute() we attempt to fetch this aws connection to check for extras. If the user didn't
+        # actually provide a connection note that, because we don't want to let the exception bubble up in
+        # that case (since we're silently injecting a connection on their behalf).
+        self._aws_conn_id: str | None
+        if isinstance(aws_conn_id, ArgNotSet):
+            self.conn_set = False
+            self._aws_conn_id = "aws_default"
+        else:
+            self.conn_set = True
+            self._aws_conn_id = aws_conn_id

     def _build_unload_query(
         self, credentials_block: str, select_query: str, s3_key: str, unload_options: str
@@ -176,11 +186,16 @@
                     raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
         else:
             redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
-        conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
+        conn = (
+            S3Hook.get_connection(conn_id=self._aws_conn_id)
+            # Only fetch the connection if it was set by the user and it is not None
+            if self.conn_set and self._aws_conn_id
+            else None
+        )
         if conn and conn.extra_dejson.get("role_arn", False):
             credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
         else:
-            s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+            s3_hook = S3Hook(aws_conn_id=self._aws_conn_id, verify=self.verify)
             credentials = s3_hook.get_credentials()
             credentials_block = build_credentials_block(credentials)

@@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+from airflow.utils.types import NOTSET, ArgNotSet

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -93,7 +94,7 @@ class S3ToRedshiftOperator(BaseOperator):
         s3_key: str,
         schema: str | None = None,
         redshift_conn_id: str = "redshift_default",
-        aws_conn_id: str | None = "aws_default",
+        aws_conn_id: str | None | ArgNotSet = NOTSET,
         verify: bool | str | None = None,
         column_list: list[str] | None = None,
         copy_options: list | None = None,
@@ -117,6 +118,16 @@
         self.method = method
         self.upsert_keys = upsert_keys
         self.redshift_data_api_kwargs = redshift_data_api_kwargs or {}
+        # In execute() we attempt to fetch this aws connection to check for extras. If the user didn't
+        # actually provide a connection note that, because we don't want to let the exception bubble up in
+        # that case (since we're silently injecting a connection on their behalf).
+        self._aws_conn_id: str | None
+        if isinstance(aws_conn_id, ArgNotSet):
+            self.conn_set = False
+            self._aws_conn_id = "aws_default"
+        else:
+            self.conn_set = True
+            self._aws_conn_id = aws_conn_id

         if self.redshift_data_api_kwargs:
             for arg in ["sql", "parameters"]:
@@ -149,14 +160,19 @@
         else:
             redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)

-        conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
+        conn = (
+            S3Hook.get_connection(conn_id=self._aws_conn_id)
+            # Only fetch the connection if it was set by the user and it is not None
+            if self.conn_set and self._aws_conn_id
+            else None
+        )
         region_info = ""
         if conn and conn.extra_dejson.get("region", False):
             region_info = f"region '{conn.extra_dejson['region']}'"
         if conn and conn.extra_dejson.get("role_arn", False):
             credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
         else:
-            s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+            s3_hook = S3Hook(aws_conn_id=self._aws_conn_id, verify=self.verify)
             credentials = s3_hook.get_credentials()
             credentials_block = build_credentials_block(credentials)

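
Both transfer operators above now use the same sentinel default: `NOTSET` lets the operator distinguish "the user never passed aws_conn_id" from an explicit value (including None), so the silently injected `aws_default` connection is not looked up, and a missing connection only raises when the user actually asked for one. A minimal, standalone sketch of the pattern follows; the stand-in classes below are illustrative, not Airflow code:

```python
# Minimal sketch of the sentinel-default pattern used above (stand-in classes, not Airflow code).
from __future__ import annotations


class ArgNotSet:
    """Sentinel type marking 'argument was not supplied'."""


NOTSET = ArgNotSet()


class Example:
    def __init__(self, aws_conn_id: str | None | ArgNotSet = NOTSET):
        if isinstance(aws_conn_id, ArgNotSet):
            # Caller said nothing: fall back silently and remember that it was our choice.
            self.conn_set = False
            self._aws_conn_id = "aws_default"
        else:
            # Caller passed something explicitly, even if it was None.
            self.conn_set = True
            self._aws_conn_id = aws_conn_id


assert Example().conn_set is False                 # default injected on the caller's behalf
assert Example(aws_conn_id=None).conn_set is True  # explicit None is respected
```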
@@ -65,10 +65,10 @@ To call the asynchronous `wait` function, first create a hook for the particular
 self.redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
 ```

-With this hook, we can use the async_conn property to get access to the aiobotocore client:
+With this hook, we can use the asynchronous get_async_conn method to get access to the aiobotocore client:

 ```python
-async with self.redshift_hook.async_conn as client:
+async with await self.redshift_hook.get_async_conn() as client:
     await client.get_waiter("cluster_available").wait(
         ClusterIdentifier=self.cluster_identifier,
         WaiterConfig={
@@ -81,7 +81,7 @@ async with self.redshift_hook.async_conn as client:
 In this case, we are using the built-in cluster_available waiter. If we wanted to use a custom waiter, we would change the code slightly to use the `get_waiter` function from the hook, rather than the aiobotocore client:

 ```python
-async with self.redshift_hook.async_conn as client:
+async with await self.redshift_hook.get_async_conn() as client:
     waiter = self.redshift_hook.get_waiter("cluster_paused", deferrable=True, client=client)
     await waiter.wait(
         ClusterIdentifier=self.cluster_identifier,
@@ -131,7 +131,7 @@ For more information about writing custom waiter, see the [README.md](https://gi
 In some cases, a built-in or custom waiter may not be able to solve the problem. In such cases, the asynchronous method used to poll the boto3 API would need to be defined in the hook of the service being used. This method is essentially the same as the synchronous version of the method, except that it will use the aiobotocore client, and will be awaited. For the Redshift example, the async `describe_clusters` method would look as follows:

 ```python
-async with self.async_conn as client:
+async with await self.get_async_conn() as client:
     response = client.describe_clusters(ClusterIdentifier=self.cluster_identifier)
 ```

@@ -55,6 +55,8 @@ class AwsBaseWaiterTrigger(BaseTrigger):

     :param waiter_delay: The amount of time in seconds to wait between attempts.
     :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param waiter_config_overrides: A dict to update waiter's default configuration. Only specified keys will
+        be updated.
     :param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook.
     :param region_name: The AWS region where the resources to watch are. To be used to build the hook.
     :param verify: Whether or not to verify SSL certificates. To be used to build the hook.
@@ -77,6 +79,7 @@ class AwsBaseWaiterTrigger(BaseTrigger):
         return_value: Any,
         waiter_delay: int,
         waiter_max_attempts: int,
+        waiter_config_overrides: dict[str, Any] | None = None,
         aws_conn_id: str | None,
         region_name: str | None = None,
         verify: bool | str | None = None,
@@ -91,6 +94,7 @@ class AwsBaseWaiterTrigger(BaseTrigger):
         self.failure_message = failure_message
         self.status_message = status_message
         self.status_queries = status_queries
+        self.waiter_config_overrides = waiter_config_overrides

         self.return_key = return_key
         self.return_value = return_value
@@ -139,8 +143,13 @@ class AwsBaseWaiterTrigger(BaseTrigger):

     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = self.hook()
-        async with hook.async_conn as client:
-            waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client)
+        async with await hook.get_async_conn() as client:
+            waiter = hook.get_waiter(
+                self.waiter_name,
+                deferrable=True,
+                client=client,
+                config_overrides=self.waiter_config_overrides,
+            )
             await async_wait(
                 waiter,
                 self.waiter_delay,
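
With the new `waiter_config_overrides` argument, a subclass can tune the generated botocore waiter (for example, swap in its own acceptors) without redefining the whole waiter; only the keys it supplies replace the defaults. A hedged sketch of how a subclass might pass custom acceptors through `super().__init__` follows; the waiter name, states, and hook choice are illustrative and do not ship with the provider:

```python
# Illustrative subclass only: "my_resource_ready" and the acceptor states are made-up examples.
from __future__ import annotations

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook, AwsGenericHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger


class MyResourceReadyTrigger(AwsBaseWaiterTrigger):
    """Waits until a hypothetical resource reports the READY state."""

    def __init__(self, resource_id: str, aws_conn_id: str | None = None) -> None:
        super().__init__(
            serialized_fields={"resource_id": resource_id},
            waiter_name="my_resource_ready",
            waiter_args={"ResourceId": resource_id},
            failure_message="Resource provisioning failed",
            status_message="Resource status is",
            status_queries=["Resource.Status"],
            return_value=resource_id,
            waiter_delay=30,
            waiter_max_attempts=60,
            aws_conn_id=aws_conn_id,
            # Only the "acceptors" key of the waiter's default configuration is replaced.
            waiter_config_overrides={
                "acceptors": [
                    {"matcher": "path", "argument": "Resource.Status", "expected": "READY", "state": "success"},
                    {"matcher": "path", "argument": "Resource.Status", "expected": "FAILED", "state": "failure"},
                ]
            },
        )

    def hook(self) -> AwsGenericHook:
        # Placeholder hook; a real trigger would return the service-specific hook here.
        return AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type="sts")
```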
@@ -167,8 +167,12 @@ class TaskDoneTrigger(BaseTrigger):

     async def run(self) -> AsyncIterator[TriggerEvent]:
         async with (
-            EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as ecs_client,
-            AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as logs_client,
+            await EcsHook(
+                aws_conn_id=self.aws_conn_id, region_name=self.region
+            ).get_async_conn() as ecs_client,
+            await AwsLogsHook(
+                aws_conn_id=self.aws_conn_id, region_name=self.region
+            ).get_async_conn() as logs_client,
         ):
             waiter = ecs_client.get_waiter("tasks_stopped")
             logs_token = None
@@ -70,7 +70,7 @@ class EksCreateClusterTrigger(AwsBaseWaiterTrigger):
         return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

     async def run(self):
-        async with self.hook().async_conn as client:
+        async with await self.hook().get_async_conn() as client:
             waiter = client.get_waiter(self.waiter_name)
             try:
                 await async_wait(
@@ -140,7 +140,7 @@ class EksDeleteClusterTrigger(AwsBaseWaiterTrigger):
         return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

     async def run(self):
-        async with self.hook().async_conn as client:
+        async with await self.hook().get_async_conn() as client:
             waiter = client.get_waiter("cluster_deleted")
             if self.force_delete_compute:
                 await self.delete_any_nodegroups(client=client)
@@ -157,7 +157,7 @@ class GlueCatalogPartitionTrigger(BaseTrigger):
         return bool(partitions)

     async def run(self) -> AsyncIterator[TriggerEvent]:
-        async with self.hook.async_conn as client:
+        async with await self.hook.get_async_conn() as client:
             while True:
                 result = await self.poke(client=client)
                 if result:
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Collection
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.utils.state import DagRunState
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class MwaaDagRunCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when an MWAA Dag Run is complete.
+
+    :param external_env_name: The external MWAA environment name that contains the DAG Run you want to wait for
+        (templated)
+    :param external_dag_id: The DAG ID in the external MWAA environment that contains the DAG Run you want to wait for
+        (templated)
+    :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
+    :param success_states: Collection of DAG Run states that would make this task marked as successful, default is
+        ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
+    :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
+        AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 720)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        external_env_name: str,
+        external_dag_id: str,
+        external_dag_run_id: str,
+        success_states: Collection[str] | None = None,
+        failure_states: Collection[str] | None = None,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 720,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
+        self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}
+
+        if len(self.success_states & self.failure_states):
+            raise ValueError("success_states and failure_states must not have any values in common")
+
+        in_progress_states = {s.value for s in DagRunState} - self.success_states - self.failure_states
+
+        super().__init__(
+            serialized_fields={
+                "external_env_name": external_env_name,
+                "external_dag_id": external_dag_id,
+                "external_dag_run_id": external_dag_run_id,
+                "success_states": success_states,
+                "failure_states": failure_states,
+            },
+            waiter_name="mwaa_dag_run_complete",
+            waiter_args={
+                "Name": external_env_name,
+                "Path": f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}",
+                "Method": "GET",
+            },
+            failure_message=f"The DAG run {external_dag_run_id} of DAG {external_dag_id} in MWAA environment {external_env_name} failed with state",
+            status_message="State of DAG run",
+            status_queries=["RestApiResponse.state"],
+            return_key="dag_run_id",
+            return_value=external_dag_run_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            waiter_config_overrides={
+                "acceptors": _build_waiter_acceptors(
+                    success_states=self.success_states,
+                    failure_states=self.failure_states,
+                    in_progress_states=in_progress_states,
+                )
+            },
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return MwaaHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
+
+
+def _build_waiter_acceptors(
+    success_states: set[str], failure_states: set[str], in_progress_states: set[str]
+) -> list:
+    acceptors = []
+    for state_set, state_waiter_category in (
+        (success_states, "success"),
+        (failure_states, "failure"),
+        (in_progress_states, "retry"),
+    ):
+        for dag_run_state in state_set:
+            acceptors.append(
+                {
+                    "matcher": "path",
+                    "argument": "RestApiResponse.state",
+                    "expected": dag_run_state,
+                    "state": state_waiter_category,
+                }
+            )
+
+    return acceptors
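
The new trigger builds its waiter acceptors from the success/failure state sets and watches the external DAG run through the MWAA REST API. A hedged instantiation sketch follows; the environment, DAG, and run identifiers are placeholders, and the commented defer() call only indicates where a deferrable operator would typically hand control to the trigger:

```python
# Illustrative instantiation only; environment, DAG, and run ids are placeholders.
from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger

trigger = MwaaDagRunCompletedTrigger(
    external_env_name="my-mwaa-environment",
    external_dag_id="downstream_dag",
    external_dag_run_id="manual__2025-01-01T00:00:00",
    success_states=["success"],  # optional; defaults to {DagRunState.SUCCESS}
    failure_states=["failed"],   # optional; defaults to {DagRunState.FAILED}
    waiter_delay=60,
    waiter_max_attempts=720,
)
# Inside a deferrable operator this would usually be passed to self.defer(...):
# self.defer(trigger=trigger, method_name="execute_complete")
```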
@@ -53,6 +53,9 @@ class S3KeyTrigger(BaseTrigger):
         poke_interval: float = 5.0,
         should_check_fn: bool = False,
         use_regex: bool = False,
+        region_name: str | None = None,
+        verify: bool | str | None = None,
+        botocore_config: dict | None = None,
         **hook_params: Any,
     ):
         super().__init__()
@@ -64,6 +67,9 @@ class S3KeyTrigger(BaseTrigger):
         self.poke_interval = poke_interval
         self.should_check_fn = should_check_fn
         self.use_regex = use_regex
+        self.region_name = region_name
+        self.verify = verify
+        self.botocore_config = botocore_config

     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize S3KeyTrigger arguments and classpath."""
@@ -78,17 +84,25 @@ class S3KeyTrigger(BaseTrigger):
                 "poke_interval": self.poke_interval,
                 "should_check_fn": self.should_check_fn,
                 "use_regex": self.use_regex,
+                "region_name": self.region_name,
+                "verify": self.verify,
+                "botocore_config": self.botocore_config,
             },
         )

     @cached_property
     def hook(self) -> S3Hook:
-        return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))
+        return S3Hook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )

     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Make an asynchronous connection using S3HookAsync."""
         try:
-            async with self.hook.async_conn as client:
+            async with await self.hook.get_async_conn() as client:
                 while True:
                     if await self.hook.check_key_async(
                         client, self.bucket_name, self.bucket_key, self.wildcard_match, self.use_regex
@@ -143,7 +157,9 @@ class S3KeysUnchangedTrigger(BaseTrigger):
         allow_delete: bool = True,
         aws_conn_id: str | None = "aws_default",
         last_activity_time: datetime | None = None,
+        region_name: str | None = None,
         verify: bool | str | None = None,
+        botocore_config: dict | None = None,
         **hook_params: Any,
     ):
         super().__init__()
@@ -160,8 +176,10 @@ class S3KeysUnchangedTrigger(BaseTrigger):
         self.allow_delete = allow_delete
         self.aws_conn_id = aws_conn_id
         self.last_activity_time = last_activity_time
-        self.verify = verify
         self.polling_period_seconds = 0
+        self.region_name = region_name
+        self.verify = verify
+        self.botocore_config = botocore_config
         self.hook_params = hook_params

     def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -179,19 +197,26 @@ class S3KeysUnchangedTrigger(BaseTrigger):
                 "aws_conn_id": self.aws_conn_id,
                 "last_activity_time": self.last_activity_time,
                 "hook_params": self.hook_params,
-                "verify": self.verify,
                 "polling_period_seconds": self.polling_period_seconds,
+                "region_name": self.region_name,
+                "verify": self.verify,
+                "botocore_config": self.botocore_config,
             },
         )

     @cached_property
     def hook(self) -> S3Hook:
-        return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))
+        return S3Hook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )

     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Make an asynchronous connection using S3Hook."""
         try:
-            async with self.hook.async_conn as client:
+            async with await self.hook.get_async_conn() as client:
                 while True:
                     result = await self.hook.is_keys_unchanged_async(
                         client=client,
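
Because the new region_name, verify, and botocore_config arguments are now serialized with the trigger, the S3Hook rebuilt on the triggerer side is created with the same connection settings as on the worker. A brief construction sketch follows; the bucket, key, region, and retry values are placeholders:

```python
# Placeholder bucket/key/region values; shows only the new connection-related arguments.
from airflow.providers.amazon.aws.triggers.s3 import S3KeyTrigger

trigger = S3KeyTrigger(
    bucket_name="my-bucket",
    bucket_key=["data/input.csv"],
    aws_conn_id="aws_default",
    region_name="eu-west-1",
    verify=True,
    botocore_config={"retries": {"max_attempts": 3}},
)
```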
@@ -108,7 +108,7 @@ class SageMakerTrigger(BaseTrigger):

     async def run(self):
         self.log.info("job name is %s and job type is %s", self.job_name, self.job_type)
-        async with self.hook.async_conn as client:
+        async with await self.hook.get_async_conn() as client:
             waiter = self.hook.get_waiter(
                 self._get_job_type_waiter(self.job_type), deferrable=True, client=client
             )
@@ -166,7 +166,7 @@ class SageMakerPipelineTrigger(BaseTrigger):

     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
-        async with hook.async_conn as conn:
+        async with await hook.get_async_conn() as conn:
             waiter = hook.get_waiter(self._waiter_name[self.waiter_type], deferrable=True, client=conn)
             for _ in range(self.waiter_max_attempts):
                 try:
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""This module contains the Amazon SageMaker Unified Studio Notebook job trigger."""
+
+from __future__ import annotations
+
+from airflow.triggers.base import BaseTrigger
+
+
+class SageMakerNotebookJobTrigger(BaseTrigger):
+    """
+    Watches for a notebook job, triggers when it finishes.
+
+    Examples:
+     .. code-block:: python
+
+        from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import SageMakerNotebookJobTrigger
+
+        notebook_trigger = SageMakerNotebookJobTrigger(
+            execution_id="notebook_job_1234",
+            execution_name="notebook_task",
+            waiter_delay=10,
+            waiter_max_attempts=1440,
+        )
+
+    :param execution_id: A unique, meaningful id for the task.
+    :param execution_name: A unique, meaningful name for the task.
+    :param waiter_delay: Interval in seconds to check the notebook execution status.
+    :param waiter_max_attempts: Number of attempts to wait before returning FAILED.
+    """
+
+    def __init__(self, execution_id, execution_name, waiter_delay, waiter_max_attempts, **kwargs):
+        super().__init__(**kwargs)
+        self.execution_id = execution_id
+        self.execution_name = execution_name
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+
+    def serialize(self):
+        return (
+            # dynamically generate the fully qualified name of the class
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "execution_id": self.execution_id,
+                "execution_name": self.execution_name,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self):
+        pass
@@ -23,14 +23,22 @@ from typing import TYPE_CHECKING, Any
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
 from airflow.providers.amazon.aws.utils.sqs import process_response
-from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.triggers.base import BaseEventTrigger, TriggerEvent
+else:
+    from airflow.triggers.base import (  # type: ignore
+        BaseTrigger as BaseEventTrigger,
+        TriggerEvent,
+    )

 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
     from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType


-class SqsSensorTrigger(BaseTrigger):
+class SqsSensorTrigger(BaseEventTrigger):
     """
     Asynchronously get messages from an Amazon SQS queue and then delete the messages from the queue.

@@ -176,7 +184,7 @@ class SqsSensorTrigger(BaseTrigger):
         while True:
             # This loop will run indefinitely until the timeout, which is set in the self.defer
             # method, is reached.
-            async with self.hook.async_conn as client:
+            async with await self.hook.get_async_conn() as client:
                 result = await self.poke(client=client)
                 if result:
                     yield TriggerEvent({"status": "success", "message_batch": result})
@@ -14,3 +14,15 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+"""This module contains utils for the Amazon SageMaker Unified Studio Notebook plugin."""
+
+from __future__ import annotations
+
+import os
+
+workflows_env_key = "WORKFLOWS_ENV"
+
+
+def is_local_runner():
+    return os.getenv(workflows_env_key, "") == "Local"
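
The helper above only checks the WORKFLOWS_ENV environment variable. A quick usage sketch follows; the branch bodies are placeholders, only the import and the environment-variable check come from this diff:

```python
# Placeholder branching logic around the new helper.
from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner

if is_local_runner():
    print("Running inside the local workflows runner")  # WORKFLOWS_ENV == "Local"
else:
    print("Running in a remote SageMaker Unified Studio environment")
```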