apache-airflow-providers-amazon 8.26.0rc1__py3-none-any.whl → 8.27.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (32)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +10 -0
  3. airflow/providers/amazon/aws/datasets/__init__.py +16 -0
  4. airflow/providers/amazon/aws/datasets/s3.py +45 -0
  5. airflow/providers/amazon/aws/executors/batch/batch_executor.py +27 -17
  6. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +31 -13
  7. airflow/providers/amazon/aws/hooks/kinesis_analytics.py +65 -0
  8. airflow/providers/amazon/aws/hooks/rds.py +3 -3
  9. airflow/providers/amazon/aws/hooks/s3.py +26 -1
  10. airflow/providers/amazon/aws/hooks/step_function.py +18 -0
  11. airflow/providers/amazon/aws/operators/athena.py +16 -17
  12. airflow/providers/amazon/aws/operators/emr.py +23 -23
  13. airflow/providers/amazon/aws/operators/kinesis_analytics.py +348 -0
  14. airflow/providers/amazon/aws/operators/rds.py +17 -20
  15. airflow/providers/amazon/aws/operators/redshift_cluster.py +71 -53
  16. airflow/providers/amazon/aws/operators/s3.py +18 -12
  17. airflow/providers/amazon/aws/operators/sagemaker.py +12 -27
  18. airflow/providers/amazon/aws/operators/step_function.py +12 -2
  19. airflow/providers/amazon/aws/sensors/kinesis_analytics.py +234 -0
  20. airflow/providers/amazon/aws/sensors/s3.py +11 -5
  21. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -0
  22. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
  23. airflow/providers/amazon/aws/triggers/emr.py +3 -1
  24. airflow/providers/amazon/aws/triggers/kinesis_analytics.py +69 -0
  25. airflow/providers/amazon/aws/triggers/sagemaker.py +9 -1
  26. airflow/providers/amazon/aws/waiters/kinesisanalyticsv2.json +151 -0
  27. airflow/providers/amazon/aws/waiters/rds.json +253 -0
  28. airflow/providers/amazon/get_provider_info.py +35 -2
  29. {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/METADATA +32 -25
  30. {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/RECORD +32 -24
  31. {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/WHEEL +0 -0
  32. {apache_airflow_providers_amazon-8.26.0rc1.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/entry_points.txt +0 -0
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "8.26.0"
+__version__ = "8.27.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.7.0"
@@ -81,6 +81,16 @@ class AwsAuthManager(BaseAuthManager):
     """
 
     def __init__(self, appbuilder: AirflowAppBuilder) -> None:
+        from packaging.version import Version
+
+        from airflow.version import version
+
+        # TODO: remove this if block when min_airflow_version is set to higher than 2.9.0
+        if Version(version) < Version("2.9"):
+            raise AirflowOptionalProviderFeatureException(
+                "``AwsAuthManager`` is compatible with Airflow versions >= 2.9."
+            )
+
         super().__init__(appbuilder)
         self._check_avp_schema_version()
 
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.datasets import Dataset
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+
+if TYPE_CHECKING:
+    from urllib.parse import SplitResult
+
+    from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
+
+
+def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset:
+    return Dataset(uri=f"s3://{bucket}/{key}", extra=extra)
+
+
+def sanitize_uri(uri: SplitResult) -> SplitResult:
+    if not uri.netloc:
+        raise ValueError("URI format s3:// must contain a bucket name")
+    return uri
+
+
+def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset:
+    """Translate Dataset with valid AIP-60 uri to OpenLineage with assistance from the hook."""
+    from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
+
+    bucket, key = S3Hook.parse_s3_url(dataset.uri)
+    return OpenLineageDataset(namespace=f"s3://{bucket}", name=key if key else "/")
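The new S3 dataset helpers shown above can be exercised directly. A minimal sketch, assuming only what is visible in this hunk (the bucket and key values are illustrative, and `convert_dataset_to_openlineage` needs the common-compat provider installed):

# Illustrative only: exercises the helpers added in datasets/s3.py; values are made up.
from airflow.providers.amazon.aws.datasets.s3 import (
    convert_dataset_to_openlineage,
    create_dataset,
)

# Build an AIP-60 style S3 Dataset URI from a bucket and key.
dataset = create_dataset(bucket="my-bucket", key="raw/2024/07/data.parquet")
assert dataset.uri == "s3://my-bucket/raw/2024/07/data.parquet"

# Translate it to an OpenLineage dataset (namespace = bucket, name = key).
ol_dataset = convert_dataset_to_openlineage(dataset, lineage_context=None)
print(ol_dataset.namespace, ol_dataset.name)  # s3://my-bucket raw/2024/07/data.parquet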
@@ -19,9 +19,9 @@
 
 from __future__ import annotations
 
-import contextlib
 import time
-from collections import defaultdict, deque
+from collections import deque
+from contextlib import suppress
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Dict, List, Sequence
 
@@ -264,7 +264,6 @@ class AwsBatchExecutor(BaseExecutor):
         in the next iteration of the sync() method, unless it has exceeded the maximum number of
         attempts. If a job exceeds the maximum number of attempts, it is removed from the queue.
         """
-        failure_reasons = defaultdict(int)
         for _ in range(len(self.pending_jobs)):
             batch_job = self.pending_jobs.popleft()
             key = batch_job.key
@@ -272,7 +271,7 @@ class AwsBatchExecutor(BaseExecutor):
             queue = batch_job.queue
             exec_config = batch_job.executor_config
             attempt_number = batch_job.attempt_number
-            _failure_reason = []
+            failure_reason: str | None = None
             if timezone.utcnow() < batch_job.next_attempt_time:
                 self.pending_jobs.append(batch_job)
                 continue
@@ -286,18 +285,25 @@ class AwsBatchExecutor(BaseExecutor):
                 if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
                     self.pending_jobs.append(batch_job)
                     raise
-                _failure_reason.append(str(e))
+                failure_reason = str(e)
             except Exception as e:
-                _failure_reason.append(str(e))
-
-            if _failure_reason:
-                for reason in _failure_reason:
-                    failure_reasons[reason] += 1
+                failure_reason = str(e)
 
+            if failure_reason:
                 if attempt_number >= int(self.__class__.MAX_SUBMIT_JOB_ATTEMPTS):
                     self.log.error(
-                        "This job has been unsuccessfully attempted too many times (%s). Dropping the task.",
+                        (
+                            "This job has been unsuccessfully attempted too many times (%s). "
+                            "Dropping the task. Reason: %s"
+                        ),
                         attempt_number,
+                        failure_reason,
+                    )
+                    self.log_task_event(
+                        event="batch job submit failure",
+                        extra=f"This job has been unsuccessfully attempted too many times ({attempt_number}). "
+                        f"Dropping the task. Reason: {failure_reason}",
+                        ti_key=key,
                     )
                     self.fail(key=key)
                 else:
@@ -317,16 +323,11 @@ class AwsBatchExecutor(BaseExecutor):
                     exec_config=exec_config,
                     attempt_number=attempt_number,
                 )
-                with contextlib.suppress(AttributeError):
+                with suppress(AttributeError):
                     # TODO: Remove this when min_airflow_version is 2.10.0 or higher in Amazon provider.
                     # running_state is added in Airflow 2.10 and only needed to support task adoption
                     # (an optional executor feature).
                     self.running_state(key, job_id)
-        if failure_reasons:
-            self.log.error(
-                "Pending Batch jobs failed to launch for the following reasons: %s. Retrying later.",
-                dict(failure_reasons),
-            )
 
     def _describe_jobs(self, job_ids) -> list[BatchJob]:
         all_jobs = []
@@ -462,3 +463,12 @@ class AwsBatchExecutor(BaseExecutor):
 
         not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
         return not_adopted_tis
+
+    def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey):
+        # TODO: remove this method when min_airflow_version is set to higher than 2.10.0
+        with suppress(AttributeError):
+            super().log_task_event(
+                event=event,
+                extra=extra,
+                ti_key=ti_key,
+            )
@@ -25,6 +25,7 @@ from __future__ import annotations
 
 import time
 from collections import defaultdict, deque
+from contextlib import suppress
 from copy import deepcopy
 from typing import TYPE_CHECKING, Sequence
 
@@ -347,7 +348,7 @@ class AwsEcsExecutor(BaseExecutor):
             queue = ecs_task.queue
             exec_config = ecs_task.executor_config
             attempt_number = ecs_task.attempt_number
-            _failure_reasons = []
+            failure_reasons = []
             if timezone.utcnow() < ecs_task.next_attempt_time:
                 self.pending_tasks.append(ecs_task)
                 continue
@@ -361,23 +362,21 @@ class AwsEcsExecutor(BaseExecutor):
                 if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
                     self.pending_tasks.append(ecs_task)
                     raise
-                _failure_reasons.append(str(e))
+                failure_reasons.append(str(e))
             except Exception as e:
                 # Failed to even get a response back from the Boto3 API or something else went
                 # wrong. For any possible failure we want to add the exception reasons to the
                 # failure list so that it is logged to the user and most importantly the task is
                 # added back to the pending list to be retried later.
-                _failure_reasons.append(str(e))
+                failure_reasons.append(str(e))
             else:
                 # We got a response back, check if there were failures. If so, add them to the
                 # failures list so that it is logged to the user and most importantly the task
                 # is added back to the pending list to be retried later.
                 if run_task_response["failures"]:
-                    _failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])
+                    failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])
 
-            if _failure_reasons:
-                for reason in _failure_reasons:
-                    failure_reasons[reason] += 1
+            if failure_reasons:
                 # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
                 if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
                     ecs_task.attempt_number += 1
@@ -386,14 +385,29 @@ class AwsEcsExecutor(BaseExecutor):
                     )
                     self.pending_tasks.append(ecs_task)
                 else:
+                    reasons_str = ", ".join(failure_reasons)
                     self.log.error(
-                        "ECS task %s has failed a maximum of %s times. Marking as failed",
+                        "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
                         task_key,
                         attempt_number,
+                        reasons_str,
+                    )
+                    self.log_task_event(
+                        event="ecs task submit failure",
+                        ti_key=task_key,
+                        extra=(
+                            f"Task could not be queued after {attempt_number} attempts. "
+                            f"Marking as failed. Reasons: {reasons_str}"
+                        ),
                     )
                     self.fail(task_key)
             elif not run_task_response["tasks"]:
                 self.log.error("ECS RunTask Response: %s", run_task_response)
+                self.log_task_event(
+                    event="ecs task submit failure",
+                    extra=f"ECS RunTask Response: {run_task_response}",
+                    ti_key=task_key,
+                )
                 raise EcsExecutorException(
                     "No failures and no ECS tasks provided in response. This should never happen."
                 )
@@ -407,11 +421,6 @@ class AwsEcsExecutor(BaseExecutor):
                     # executor feature).
                     # TODO: remove when min airflow version >= 2.9.2
                     pass
-        if failure_reasons:
-            self.log.error(
-                "Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
-                dict(failure_reasons),
-            )
 
     def _run_task(
         self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
@@ -543,3 +552,12 @@ class AwsEcsExecutor(BaseExecutor):
 
         not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
         return not_adopted_tis
+
+    def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey):
+        # TODO: remove this method when min_airflow_version is set to higher than 2.10.0
+        with suppress(AttributeError):
+            super().log_task_event(
+                event=event,
+                extra=extra,
+                ti_key=ti_key,
+            )
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class KinesisAnalyticsV2Hook(AwsBaseHook):
+    """
+    Interact with Amazon Kinesis Analytics V2.
+
+    Provide thin wrapper around :external+boto3:py:class:`boto3.client("kinesisanalyticsv2") <KinesisAnalyticsV2.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    APPLICATION_START_INTERMEDIATE_STATES: tuple[str, ...] = ("STARTING", "UPDATING", "AUTOSCALING")
+    APPLICATION_START_FAILURE_STATES: tuple[str, ...] = (
+        "DELETING",
+        "STOPPING",
+        "READY",
+        "FORCE_STOPPING",
+        "ROLLING_BACK",
+        "MAINTENANCE",
+        "ROLLED_BACK",
+    )
+    APPLICATION_START_SUCCESS_STATES: tuple[str, ...] = ("RUNNING",)
+
+    APPLICATION_STOP_INTERMEDIATE_STATES: tuple[str, ...] = (
+        "STARTING",
+        "UPDATING",
+        "AUTOSCALING",
+        "RUNNING",
+        "STOPPING",
+        "FORCE_STOPPING",
+    )
+    APPLICATION_STOP_FAILURE_STATES: tuple[str, ...] = (
+        "DELETING",
+        "ROLLING_BACK",
+        "MAINTENANCE",
+        "ROLLED_BACK",
+    )
+    APPLICATION_STOP_SUCCESS_STATES: tuple[str, ...] = ("READY",)
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = "kinesisanalyticsv2"
+        super().__init__(*args, **kwargs)
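A hedged usage sketch of the new hook: like the other AWS hooks it exposes the boto3 client through `.conn`, so an existing Managed Service for Apache Flink application could be inspected roughly as below. The application name and connection id are placeholders, and `describe_application` is the standard boto3 `kinesisanalyticsv2` call, not something added by this provider.

# Sketch only: placeholder application name and connection id.
from airflow.providers.amazon.aws.hooks.kinesis_analytics import KinesisAnalyticsV2Hook

hook = KinesisAnalyticsV2Hook(aws_conn_id="aws_default")

# .conn is the underlying boto3 "kinesisanalyticsv2" client configured in __init__ above.
response = hook.conn.describe_application(ApplicationName="my-flink-app")
status = response["ApplicationDetail"]["ApplicationStatus"]

# The class-level tuples added in this release classify that status for sensors/waiters.
if status in KinesisAnalyticsV2Hook.APPLICATION_START_SUCCESS_STATES:
    print("application is RUNNING")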
@@ -259,7 +259,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
             return self.get_db_instance_state(db_instance_id)
 
         target_state = target_state.lower()
-        if target_state in ("available", "deleted"):
+        if target_state in ("available", "deleted", "stopped"):
             waiter = self.conn.get_waiter(f"db_instance_{target_state}")  # type: ignore
             wait(
                 waiter=waiter,
@@ -272,7 +272,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
             )
         else:
             self._wait_for_state(poke, target_state, check_interval, max_attempts)
-        self.log.info("DB cluster snapshot '%s' reached the '%s' state", db_instance_id, target_state)
+        self.log.info("DB cluster '%s' reached the '%s' state", db_instance_id, target_state)
 
     def get_db_cluster_state(self, db_cluster_id: str) -> str:
         """
@@ -310,7 +310,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
             return self.get_db_cluster_state(db_cluster_id)
 
         target_state = target_state.lower()
-        if target_state in ("available", "deleted"):
+        if target_state in ("available", "deleted", "stopped"):
             waiter = self.conn.get_waiter(f"db_cluster_{target_state}")  # type: ignore
             waiter.wait(
                 DBClusterIdentifier=db_cluster_id,
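The two hunks above route the `stopped` target state through a waiter (presumably supplied by the new `waiters/rds.json` added in this release) instead of the generic polling path. Assuming the surrounding methods are the hook's `wait_for_db_instance_state` and `wait_for_db_cluster_state` (their names are not visible in this excerpt; the parameter names below do appear in the hunks), usage would look roughly like:

# Sketch only: the method name is an assumption from context; the identifier is a placeholder.
from airflow.providers.amazon.aws.hooks.rds import RdsHook

hook = RdsHook(aws_conn_id="aws_default")

# With this change, target_state="stopped" is serviced by a waiter rather than manual polling.
hook.wait_for_db_instance_state(
    "my-db-instance",
    target_state="stopped",
    check_interval=30,
    max_attempts=40,
)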
@@ -41,6 +41,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
 from urllib.parse import urlsplit
 from uuid import uuid4
 
+from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
+
 if TYPE_CHECKING:
     from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
 
@@ -1111,6 +1113,12 @@ class S3Hook(AwsBaseHook):
 
         client = self.get_conn()
         client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="file", dataset_kwargs={"path": filename}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     @unify_bucket_name_and_key
     @provide_bucket_name
@@ -1251,6 +1259,10 @@ class S3Hook(AwsBaseHook):
             ExtraArgs=extra_args,
             Config=self.transfer_config,
         )
+        # No input because file_obj can be anything - handle in calling function if possible
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     def copy_object(
         self,
@@ -1306,6 +1318,12 @@ class S3Hook(AwsBaseHook):
         response = self.get_conn().copy_object(
             Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
         )
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key}
+        )
         return response
 
     @provide_bucket_name
@@ -1425,6 +1443,11 @@ class S3Hook(AwsBaseHook):
 
             file_path.parent.mkdir(exist_ok=True, parents=True)
 
+            get_hook_lineage_collector().add_output_dataset(
+                context=self,
+                scheme="file",
+                dataset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
+            )
             file = open(file_path, "wb")
         else:
             file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False)  # type: ignore
@@ -1435,7 +1458,9 @@ class S3Hook(AwsBaseHook):
                 ExtraArgs=self.extra_args,
                 Config=self.transfer_config,
             )
-
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
         return file.name
 
     def generate_presigned_url(
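All of the S3Hook hunks above follow one pattern: after a successful transfer, the hook reports its inputs and outputs to the hook-level lineage collector imported from `airflow.providers.common.compat.lineage.hook`. Nothing changes in how the hook is called; a minimal sketch with illustrative bucket, key, and file names:

# Sketch only: illustrative bucket/key/paths; behavior as described in the hunks above.
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")

# Uploading a local file now also records a "file" input dataset and an "s3" output
# dataset with the hook lineage collector, in addition to performing the upload.
hook.load_file(
    filename="/tmp/report.csv",
    key="reports/2024/report.csv",
    bucket_name="my-bucket",
)

# copy_object and download_file register the corresponding s3/file datasets as well.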
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 import json
 
+from airflow.exceptions import AirflowFailException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
 
@@ -43,6 +44,7 @@ class StepFunctionHook(AwsBaseHook):
         state_machine_arn: str,
         name: str | None = None,
         state_machine_input: dict | str | None = None,
+        is_redrive_execution: bool = False,
     ) -> str:
         """
         Start Execution of the State Machine.
@@ -51,10 +53,26 @@ class StepFunctionHook(AwsBaseHook):
             - :external+boto3:py:meth:`SFN.Client.start_execution`
 
         :param state_machine_arn: AWS Step Function State Machine ARN.
+        :param is_redrive_execution: Restarts unsuccessful executions of Standard workflows that did not
+            complete successfully in the last 14 days.
         :param name: The name of the execution.
         :param state_machine_input: JSON data input to pass to the State Machine.
         :return: Execution ARN.
         """
+        if is_redrive_execution:
+            if not name:
+                raise AirflowFailException(
+                    "Execution name is required to start RedriveExecution for %s.", state_machine_arn
+                )
+            elements = state_machine_arn.split(":stateMachine:")
+            execution_arn = f"{elements[0]}:execution:{elements[1]}:{name}"
+            self.conn.redrive_execution(executionArn=execution_arn)
+            self.log.info(
+                "Successfully started RedriveExecution for Step Function State Machine: %s.",
+                state_machine_arn,
+            )
+            return execution_arn
+
         execution_args = {"stateMachineArn": state_machine_arn}
         if name is not None:
             execution_args["name"] = name
@@ -30,9 +30,7 @@ from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
-    from openlineage.client.facet import BaseFacet
-    from openlineage.client.run import Dataset
-
+    from airflow.providers.common.compat.openlineage.facet import BaseFacet, Dataset, DatasetFacet
     from airflow.providers.openlineage.extractors.base import OperatorLineage
     from airflow.utils.context import Context
 
@@ -217,20 +215,19 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
         path where the results are saved (user's prefix + some UUID), we are creating a dataset with the
         user-provided path only. This should make it easier to match this dataset across different processes.
         """
-        from openlineage.client.facet import (
+        from airflow.providers.common.compat.openlineage.facet import (
+            Dataset,
+            Error,
             ExternalQueryRunFacet,
-            ExtractionError,
             ExtractionErrorRunFacet,
-            SqlJobFacet,
+            SQLJobFacet,
         )
-        from openlineage.client.run import Dataset
-
         from airflow.providers.openlineage.extractors.base import OperatorLineage
         from airflow.providers.openlineage.sqlparser import SQLParser
 
         sql_parser = SQLParser(dialect="generic")
 
-        job_facets: dict[str, BaseFacet] = {"sql": SqlJobFacet(query=sql_parser.normalize_sql(self.query))}
+        job_facets: dict[str, BaseFacet] = {"sql": SQLJobFacet(query=sql_parser.normalize_sql(self.query))}
         parse_result = sql_parser.parse(sql=self.query)
 
         if not parse_result:
@@ -242,7 +239,7 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
                 totalTasks=len(self.query) if isinstance(self.query, list) else 1,
                 failedTasks=len(parse_result.errors),
                 errors=[
-                    ExtractionError(
+                    Error(
                         errorMessage=error.message,
                         stackTrace=None,
                         task=error.origin_statement,
@@ -284,13 +281,13 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
         return OperatorLineage(job_facets=job_facets, run_facets=run_facets, inputs=inputs, outputs=outputs)
 
     def get_openlineage_dataset(self, database, table) -> Dataset | None:
-        from openlineage.client.facet import (
+        from airflow.providers.common.compat.openlineage.facet import (
+            Dataset,
+            Identifier,
             SchemaDatasetFacet,
-            SchemaField,
+            SchemaDatasetFacetFields,
             SymlinksDatasetFacet,
-            SymlinksDatasetFacetIdentifiers,
         )
-        from openlineage.client.run import Dataset
 
         client = self.hook.get_conn()
         try:
@@ -301,10 +298,10 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
             # Dataset has also its' physical location which we can add in symlink facet.
             s3_location = table_metadata["TableMetadata"]["Parameters"]["location"]
             parsed_path = urlparse(s3_location)
-            facets: dict[str, BaseFacet] = {
+            facets: dict[str, DatasetFacet] = {
                 "symlinks": SymlinksDatasetFacet(
                     identifiers=[
-                        SymlinksDatasetFacetIdentifiers(
+                        Identifier(
                             namespace=f"{parsed_path.scheme}://{parsed_path.netloc}",
                             name=str(parsed_path.path),
                             type="TABLE",
@@ -313,7 +310,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
                 )
             }
             fields = [
-                SchemaField(name=column["Name"], type=column["Type"], description=column.get("Comment"))
+                SchemaDatasetFacetFields(
+                    name=column["Name"], type=column["Type"], description=column["Comment"]
+                )
                 for column in table_metadata["TableMetadata"]["Columns"]
             ]
             if fields:
@@ -1382,30 +1382,30 @@ class EmrServerlessStartJobOperator(BaseOperator):
 
         self.persist_links(context)
 
-        if self.deferrable:
-            self.defer(
-                trigger=EmrServerlessStartJobTrigger(
-                    application_id=self.application_id,
-                    job_id=self.job_id,
-                    waiter_delay=self.waiter_delay,
-                    waiter_max_attempts=self.waiter_max_attempts,
-                    aws_conn_id=self.aws_conn_id,
-                ),
-                method_name="execute_complete",
-                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
-            )
-
         if self.wait_for_completion:
-            waiter = self.hook.get_waiter("serverless_job_completed")
-            wait(
-                waiter=waiter,
-                waiter_max_attempts=self.waiter_max_attempts,
-                waiter_delay=self.waiter_delay,
-                args={"applicationId": self.application_id, "jobRunId": self.job_id},
-                failure_message="Serverless Job failed",
-                status_message="Serverless Job status is",
-                status_args=["jobRun.state", "jobRun.stateDetails"],
-            )
+            if self.deferrable:
+                self.defer(
+                    trigger=EmrServerlessStartJobTrigger(
+                        application_id=self.application_id,
+                        job_id=self.job_id,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                    ),
+                    method_name="execute_complete",
+                    timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+                )
+            else:
+                waiter = self.hook.get_waiter("serverless_job_completed")
+                wait(
+                    waiter=waiter,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    waiter_delay=self.waiter_delay,
+                    args={"applicationId": self.application_id, "jobRunId": self.job_id},
+                    failure_message="Serverless Job failed",
+                    status_message="Serverless Job status is",
+                    status_args=["jobRun.state", "jobRun.stateDetails"],
+                )
 
         return self.job_id
 
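The hunk above changes when `EmrServerlessStartJobOperator` waits: deferral (or the synchronous waiter) now only happens when `wait_for_completion` is set, instead of deferring unconditionally whenever `deferrable=True`. A hedged sketch of the combination that exercises the new nesting; arguments other than the two flags are placeholders that follow the operator's usual parameters:

# Sketch only: placeholder ids/ARNs; the two flags are the ones whose interaction changed.
from airflow.providers.amazon.aws.operators.emr import EmrServerlessStartJobOperator

start_job = EmrServerlessStartJobOperator(
    task_id="start_emr_serverless_job",
    application_id="00example1234",  # placeholder EMR Serverless application id
    execution_role_arn="arn:aws:iam::123456789012:role/emr-serverless-role",
    job_driver={"sparkSubmit": {"entryPoint": "s3://my-bucket/scripts/job.py"}},
    wait_for_completion=True,  # without this, the operator now returns right after submission
    deferrable=True,           # with wait_for_completion, waiting happens via the trigger
)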