apache-airflow-providers-amazon 8.26.0rc2__py3-none-any.whl → 8.27.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (30)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/datasets/__init__.py +16 -0
  3. airflow/providers/amazon/aws/datasets/s3.py +45 -0
  4. airflow/providers/amazon/aws/executors/batch/batch_executor.py +20 -13
  5. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +24 -13
  6. airflow/providers/amazon/aws/hooks/kinesis_analytics.py +65 -0
  7. airflow/providers/amazon/aws/hooks/rds.py +3 -3
  8. airflow/providers/amazon/aws/hooks/s3.py +26 -1
  9. airflow/providers/amazon/aws/hooks/step_function.py +18 -0
  10. airflow/providers/amazon/aws/operators/athena.py +16 -17
  11. airflow/providers/amazon/aws/operators/emr.py +23 -23
  12. airflow/providers/amazon/aws/operators/kinesis_analytics.py +348 -0
  13. airflow/providers/amazon/aws/operators/rds.py +17 -20
  14. airflow/providers/amazon/aws/operators/redshift_cluster.py +71 -53
  15. airflow/providers/amazon/aws/operators/s3.py +7 -11
  16. airflow/providers/amazon/aws/operators/sagemaker.py +6 -18
  17. airflow/providers/amazon/aws/operators/step_function.py +12 -2
  18. airflow/providers/amazon/aws/sensors/kinesis_analytics.py +234 -0
  19. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -0
  20. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
  21. airflow/providers/amazon/aws/triggers/emr.py +3 -1
  22. airflow/providers/amazon/aws/triggers/kinesis_analytics.py +69 -0
  23. airflow/providers/amazon/aws/triggers/sagemaker.py +9 -1
  24. airflow/providers/amazon/aws/waiters/kinesisanalyticsv2.json +151 -0
  25. airflow/providers/amazon/aws/waiters/rds.json +253 -0
  26. airflow/providers/amazon/get_provider_info.py +35 -2
  27. {apache_airflow_providers_amazon-8.26.0rc2.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/METADATA +32 -25
  28. {apache_airflow_providers_amazon-8.26.0rc2.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/RECORD +30 -22
  29. {apache_airflow_providers_amazon-8.26.0rc2.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/WHEEL +0 -0
  30. {apache_airflow_providers_amazon-8.26.0rc2.dist-info → apache_airflow_providers_amazon-8.27.0.dist-info}/entry_points.txt +0 -0

airflow/providers/amazon/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "8.26.0"
+__version__ = "8.27.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.7.0"

airflow/providers/amazon/aws/datasets/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.

airflow/providers/amazon/aws/datasets/s3.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.datasets import Dataset
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+
+if TYPE_CHECKING:
+    from urllib.parse import SplitResult
+
+    from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
+
+
+def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset:
+    return Dataset(uri=f"s3://{bucket}/{key}", extra=extra)
+
+
+def sanitize_uri(uri: SplitResult) -> SplitResult:
+    if not uri.netloc:
+        raise ValueError("URI format s3:// must contain a bucket name")
+    return uri
+
+
+def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset:
+    """Translate Dataset with valid AIP-60 uri to OpenLineage with assistance from the hook."""
+    from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
+
+    bucket, key = S3Hook.parse_s3_url(dataset.uri)
+    return OpenLineageDataset(namespace=f"s3://{bucket}", name=key if key else "/")
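
Note: the functions above back AIP-60 style s3:// dataset URIs. A minimal usage sketch, assuming the provider is installed; the bucket and key below are placeholders:

    from urllib.parse import urlsplit

    from airflow.providers.amazon.aws.datasets.s3 import create_dataset, sanitize_uri

    # create_dataset builds an Airflow Dataset whose URI is "s3://my-bucket/data/file.csv".
    dataset = create_dataset(bucket="my-bucket", key="data/file.csv")

    # sanitize_uri rejects URIs without a bucket (empty netloc), e.g. "s3:///data/file.csv".
    sanitize_uri(urlsplit(dataset.uri))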

airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -19,10 +19,9 @@
 
 from __future__ import annotations
 
-import contextlib
-import logging
 import time
 from collections import deque
+from contextlib import suppress
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Dict, List, Sequence
 
@@ -292,12 +291,19 @@ class AwsBatchExecutor(BaseExecutor):
 
             if failure_reason:
                 if attempt_number >= int(self.__class__.MAX_SUBMIT_JOB_ATTEMPTS):
-                    self.send_message_to_task_logs(
-                        logging.ERROR,
-                        "This job has been unsuccessfully attempted too many times (%s). Dropping the task. Reason: %s",
+                    self.log.error(
+                        (
+                            "This job has been unsuccessfully attempted too many times (%s). "
+                            "Dropping the task. Reason: %s"
+                        ),
                         attempt_number,
                         failure_reason,
-                        ti=key,
+                    )
+                    self.log_task_event(
+                        event="batch job submit failure",
+                        extra=f"This job has been unsuccessfully attempted too many times ({attempt_number}). "
+                        f"Dropping the task. Reason: {failure_reason}",
+                        ti_key=key,
                     )
                     self.fail(key=key)
                 else:
@@ -317,7 +323,7 @@ class AwsBatchExecutor(BaseExecutor):
                     exec_config=exec_config,
                     attempt_number=attempt_number,
                 )
-                with contextlib.suppress(AttributeError):
+                with suppress(AttributeError):
                     # TODO: Remove this when min_airflow_version is 2.10.0 or higher in Amazon provider.
                     # running_state is added in Airflow 2.10 and only needed to support task adoption
                     # (an optional executor feature).
@@ -458,10 +464,11 @@ class AwsBatchExecutor(BaseExecutor):
         not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
         return not_adopted_tis
 
-    def send_message_to_task_logs(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
+    def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey):
         # TODO: remove this method when min_airflow_version is set to higher than 2.10.0
-        try:
-            super().send_message_to_task_logs(level, msg, *args, ti=ti)
-        except AttributeError:
-            # ``send_message_to_task_logs`` is added in 2.10.0
-            self.log.error(msg, *args)
+        with suppress(AttributeError):
+            super().log_task_event(
+                event=event,
+                extra=extra,
+                ti_key=ti_key,
+            )
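
Note: the log_task_event override above is a backward-compatibility shim; the base-class method only exists on Airflow >= 2.10, so the call is wrapped in suppress(AttributeError). A standalone sketch of the pattern (class names here are illustrative, not part of the provider):

    from contextlib import suppress


    class LegacyBase:
        """Stands in for a base executor that predates log_task_event."""


    class CompatExecutor(LegacyBase):
        def log_task_event(self, *, event: str, extra: str, ti_key) -> None:
            # On a newer base class the parent method runs; on an older one the
            # AttributeError from the missing attribute is silently ignored.
            with suppress(AttributeError):
                super().log_task_event(event=event, extra=extra, ti_key=ti_key)


    CompatExecutor().log_task_event(event="demo", extra="no-op on old base", ti_key=None)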

airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -23,9 +23,9 @@ Each Airflow task gets delegated out to an Amazon ECS Task.
 
 from __future__ import annotations
 
-import logging
 import time
 from collections import defaultdict, deque
+from contextlib import suppress
 from copy import deepcopy
 from typing import TYPE_CHECKING, Sequence
 
@@ -385,18 +385,28 @@ class AwsEcsExecutor(BaseExecutor):
                     )
                     self.pending_tasks.append(ecs_task)
                 else:
-                    self.send_message_to_task_logs(
-                        logging.ERROR,
+                    reasons_str = ", ".join(failure_reasons)
+                    self.log.error(
                         "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
                         task_key,
                         attempt_number,
-                        ", ".join(failure_reasons),
-                        ti=task_key,
+                        reasons_str,
+                    )
+                    self.log_task_event(
+                        event="ecs task submit failure",
+                        ti_key=task_key,
+                        extra=(
+                            f"Task could not be queued after {attempt_number} attempts. "
+                            f"Marking as failed. Reasons: {reasons_str}"
+                        ),
                     )
                     self.fail(task_key)
             elif not run_task_response["tasks"]:
-                self.send_message_to_task_logs(
-                    logging.ERROR, "ECS RunTask Response: %s", run_task_response, ti=task_key
+                self.log.error("ECS RunTask Response: %s", run_task_response)
+                self.log_task_event(
+                    event="ecs task submit failure",
+                    extra=f"ECS RunTask Response: {run_task_response}",
+                    ti_key=task_key,
                 )
                 raise EcsExecutorException(
                     "No failures and no ECS tasks provided in response. This should never happen."
@@ -543,10 +553,11 @@ class AwsEcsExecutor(BaseExecutor):
         not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
         return not_adopted_tis
 
-    def send_message_to_task_logs(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
+    def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey):
         # TODO: remove this method when min_airflow_version is set to higher than 2.10.0
-        try:
-            super().send_message_to_task_logs(level, msg, *args, ti=ti)
-        except AttributeError:
-            # ``send_message_to_task_logs`` is added in 2.10.0
-            self.log.error(msg, *args)
+        with suppress(AttributeError):
+            super().log_task_event(
+                event=event,
+                extra=extra,
+                ti_key=ti_key,
+            )

airflow/providers/amazon/aws/hooks/kinesis_analytics.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class KinesisAnalyticsV2Hook(AwsBaseHook):
+    """
+    Interact with Amazon Kinesis Analytics V2.
+
+    Provide thin wrapper around :external+boto3:py:class:`boto3.client("kinesisanalyticsv2") <KinesisAnalyticsV2.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    APPLICATION_START_INTERMEDIATE_STATES: tuple[str, ...] = ("STARTING", "UPDATING", "AUTOSCALING")
+    APPLICATION_START_FAILURE_STATES: tuple[str, ...] = (
+        "DELETING",
+        "STOPPING",
+        "READY",
+        "FORCE_STOPPING",
+        "ROLLING_BACK",
+        "MAINTENANCE",
+        "ROLLED_BACK",
+    )
+    APPLICATION_START_SUCCESS_STATES: tuple[str, ...] = ("RUNNING",)
+
+    APPLICATION_STOP_INTERMEDIATE_STATES: tuple[str, ...] = (
+        "STARTING",
+        "UPDATING",
+        "AUTOSCALING",
+        "RUNNING",
+        "STOPPING",
+        "FORCE_STOPPING",
+    )
+    APPLICATION_STOP_FAILURE_STATES: tuple[str, ...] = (
+        "DELETING",
+        "ROLLING_BACK",
+        "MAINTENANCE",
+        "ROLLED_BACK",
+    )
+    APPLICATION_STOP_SUCCESS_STATES: tuple[str, ...] = ("READY",)
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = "kinesisanalyticsv2"
+        super().__init__(*args, **kwargs)
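
Note: a minimal sketch of using the new hook; the connection id and application name are placeholders, and the call goes straight to the underlying boto3 client:

    from airflow.providers.amazon.aws.hooks.kinesis_analytics import KinesisAnalyticsV2Hook

    hook = KinesisAnalyticsV2Hook(aws_conn_id="aws_default")

    # hook.conn is the cached boto3 "kinesisanalyticsv2" client, so any client API is available.
    response = hook.conn.describe_application(ApplicationName="my-flink-app")
    print(response["ApplicationDetail"]["ApplicationStatus"])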

airflow/providers/amazon/aws/hooks/rds.py
@@ -259,7 +259,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
             return self.get_db_instance_state(db_instance_id)
 
         target_state = target_state.lower()
-        if target_state in ("available", "deleted"):
+        if target_state in ("available", "deleted", "stopped"):
             waiter = self.conn.get_waiter(f"db_instance_{target_state}")  # type: ignore
             wait(
                 waiter=waiter,
@@ -272,7 +272,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
             )
         else:
             self._wait_for_state(poke, target_state, check_interval, max_attempts)
-            self.log.info("DB cluster snapshot '%s' reached the '%s' state", db_instance_id, target_state)
+            self.log.info("DB cluster '%s' reached the '%s' state", db_instance_id, target_state)
 
     def get_db_cluster_state(self, db_cluster_id: str) -> str:
         """
@@ -310,7 +310,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
             return self.get_db_cluster_state(db_cluster_id)
 
         target_state = target_state.lower()
-        if target_state in ("available", "deleted"):
+        if target_state in ("available", "deleted", "stopped"):
             waiter = self.conn.get_waiter(f"db_cluster_{target_state}")  # type: ignore
             waiter.wait(
                 DBClusterIdentifier=db_cluster_id,
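
Note: with "stopped" added to the waiter-backed states (see also the new rds.json waiters), waiting for a stopped instance or cluster no longer falls back to manual polling. A minimal sketch, assuming the hook signature shown above; the instance identifier is a placeholder:

    from airflow.providers.amazon.aws.hooks.rds import RdsHook

    hook = RdsHook(aws_conn_id="aws_default")

    # Blocks until the instance reports "stopped" or the waiter attempts run out.
    hook.wait_for_db_instance_state(
        "my-rds-instance", target_state="stopped", check_interval=30, max_attempts=40
    )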

airflow/providers/amazon/aws/hooks/s3.py
@@ -41,6 +41,8 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
 from urllib.parse import urlsplit
 from uuid import uuid4
 
+from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
+
 if TYPE_CHECKING:
     from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
 
@@ -1111,6 +1113,12 @@
 
         client = self.get_conn()
         client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="file", dataset_kwargs={"path": filename}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     @unify_bucket_name_and_key
     @provide_bucket_name
@@ -1251,6 +1259,10 @@
             ExtraArgs=extra_args,
             Config=self.transfer_config,
         )
+        # No input because file_obj can be anything - handle in calling function if possible
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     def copy_object(
         self,
@@ -1306,6 +1318,12 @@
         response = self.get_conn().copy_object(
             Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
         )
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key}
+        )
         return response
 
     @provide_bucket_name
@@ -1425,6 +1443,11 @@
 
             file_path.parent.mkdir(exist_ok=True, parents=True)
 
+            get_hook_lineage_collector().add_output_dataset(
+                context=self,
+                scheme="file",
+                dataset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
+            )
             file = open(file_path, "wb")
         else:
             file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False)  # type: ignore
@@ -1435,7 +1458,9 @@
                 ExtraArgs=self.extra_args,
                 Config=self.transfer_config,
             )
-
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
         return file.name
 
     def generate_presigned_url(
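
Note: the calls above mean that routine S3Hook operations now report datasets to the hook-level lineage collector when one is active. A minimal sketch; the bucket, key, path, and connection id are placeholders:

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")

    # load_file now also registers a "file" input dataset and an "s3" output dataset
    # with the collector returned by get_hook_lineage_collector().
    hook.load_file(filename="/tmp/report.csv", key="reports/report.csv", bucket_name="my-bucket")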

airflow/providers/amazon/aws/hooks/step_function.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 import json
 
+from airflow.exceptions import AirflowFailException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
 
@@ -43,6 +44,7 @@ class StepFunctionHook(AwsBaseHook):
         state_machine_arn: str,
         name: str | None = None,
         state_machine_input: dict | str | None = None,
+        is_redrive_execution: bool = False,
     ) -> str:
         """
         Start Execution of the State Machine.
@@ -51,10 +53,26 @@
             - :external+boto3:py:meth:`SFN.Client.start_execution`
 
         :param state_machine_arn: AWS Step Function State Machine ARN.
+        :param is_redrive_execution: Restarts unsuccessful executions of Standard workflows that did not
+            complete successfully in the last 14 days.
         :param name: The name of the execution.
        :param state_machine_input: JSON data input to pass to the State Machine.
         :return: Execution ARN.
         """
+        if is_redrive_execution:
+            if not name:
+                raise AirflowFailException(
+                    "Execution name is required to start RedriveExecution for %s.", state_machine_arn
+                )
+            elements = state_machine_arn.split(":stateMachine:")
+            execution_arn = f"{elements[0]}:execution:{elements[1]}:{name}"
+            self.conn.redrive_execution(executionArn=execution_arn)
+            self.log.info(
+                "Successfully started RedriveExecution for Step Function State Machine: %s.",
+                state_machine_arn,
+            )
+            return execution_arn
+
         execution_args = {"stateMachineArn": state_machine_arn}
         if name is not None:
             execution_args["name"] = name
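
Note: when is_redrive_execution=True, the hook derives the existing execution ARN from the state machine ARN plus the (required) execution name and calls redrive_execution instead of starting a new execution. A minimal sketch; the ARN and execution name are placeholders:

    from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook

    hook = StepFunctionHook(aws_conn_id="aws_default")

    execution_arn = hook.start_execution(
        state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:my-machine",
        name="previously-failed-execution",
        is_redrive_execution=True,
    )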

airflow/providers/amazon/aws/operators/athena.py
@@ -30,9 +30,7 @@ from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
-    from openlineage.client.facet import BaseFacet
-    from openlineage.client.run import Dataset
-
+    from airflow.providers.common.compat.openlineage.facet import BaseFacet, Dataset, DatasetFacet
     from airflow.providers.openlineage.extractors.base import OperatorLineage
     from airflow.utils.context import Context
 
@@ -217,20 +215,19 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
         path where the results are saved (user's prefix + some UUID), we are creating a dataset with the
         user-provided path only. This should make it easier to match this dataset across different processes.
         """
-        from openlineage.client.facet import (
+        from airflow.providers.common.compat.openlineage.facet import (
+            Dataset,
+            Error,
             ExternalQueryRunFacet,
-            ExtractionError,
             ExtractionErrorRunFacet,
-            SqlJobFacet,
+            SQLJobFacet,
         )
-        from openlineage.client.run import Dataset
-
         from airflow.providers.openlineage.extractors.base import OperatorLineage
         from airflow.providers.openlineage.sqlparser import SQLParser
 
         sql_parser = SQLParser(dialect="generic")
 
-        job_facets: dict[str, BaseFacet] = {"sql": SqlJobFacet(query=sql_parser.normalize_sql(self.query))}
+        job_facets: dict[str, BaseFacet] = {"sql": SQLJobFacet(query=sql_parser.normalize_sql(self.query))}
         parse_result = sql_parser.parse(sql=self.query)
 
         if not parse_result:
@@ -242,7 +239,7 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
                 totalTasks=len(self.query) if isinstance(self.query, list) else 1,
                 failedTasks=len(parse_result.errors),
                 errors=[
-                    ExtractionError(
+                    Error(
                         errorMessage=error.message,
                         stackTrace=None,
                         task=error.origin_statement,
@@ -284,13 +281,13 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
         return OperatorLineage(job_facets=job_facets, run_facets=run_facets, inputs=inputs, outputs=outputs)
 
     def get_openlineage_dataset(self, database, table) -> Dataset | None:
-        from openlineage.client.facet import (
+        from airflow.providers.common.compat.openlineage.facet import (
+            Dataset,
+            Identifier,
             SchemaDatasetFacet,
-            SchemaField,
+            SchemaDatasetFacetFields,
             SymlinksDatasetFacet,
-            SymlinksDatasetFacetIdentifiers,
         )
-        from openlineage.client.run import Dataset
 
         client = self.hook.get_conn()
         try:
@@ -301,10 +298,10 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
             # Dataset has also its' physical location which we can add in symlink facet.
             s3_location = table_metadata["TableMetadata"]["Parameters"]["location"]
             parsed_path = urlparse(s3_location)
-            facets: dict[str, BaseFacet] = {
+            facets: dict[str, DatasetFacet] = {
                 "symlinks": SymlinksDatasetFacet(
                     identifiers=[
-                        SymlinksDatasetFacetIdentifiers(
+                        Identifier(
                             namespace=f"{parsed_path.scheme}://{parsed_path.netloc}",
                             name=str(parsed_path.path),
                             type="TABLE",
@@ -313,7 +310,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
                 )
             }
             fields = [
-                SchemaField(name=column["Name"], type=column["Type"], description=column.get("Comment"))
+                SchemaDatasetFacetFields(
+                    name=column["Name"], type=column["Type"], description=column["Comment"]
+                )
                 for column in table_metadata["TableMetadata"]["Columns"]
             ]
             if fields:
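
Note: the operator now imports OpenLineage facet classes from the common compat shim instead of openlineage.client directly, picking up the renames (SqlJobFacet -> SQLJobFacet, SchemaField -> SchemaDatasetFacetFields, ExtractionError -> Error, SymlinksDatasetFacetIdentifiers -> Identifier). A small sketch of the new import path, assuming apache-airflow-providers-common-compat is installed; the namespace, table, and column values are placeholders:

    from airflow.providers.common.compat.openlineage.facet import (
        Dataset,
        SchemaDatasetFacet,
        SchemaDatasetFacetFields,
        SQLJobFacet,
    )

    table = Dataset(
        namespace="awsathena://athena.us-east-1.amazonaws.com",
        name="AwsDataCatalog.my_database.my_table",
        facets={
            "schema": SchemaDatasetFacet(
                fields=[SchemaDatasetFacetFields(name="id", type="bigint", description=None)]
            )
        },
    )
    job_facet = SQLJobFacet(query="SELECT * FROM my_table")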

airflow/providers/amazon/aws/operators/emr.py
@@ -1382,30 +1382,30 @@ class EmrServerlessStartJobOperator(BaseOperator):
 
         self.persist_links(context)
 
-        if self.deferrable:
-            self.defer(
-                trigger=EmrServerlessStartJobTrigger(
-                    application_id=self.application_id,
-                    job_id=self.job_id,
-                    waiter_delay=self.waiter_delay,
-                    waiter_max_attempts=self.waiter_max_attempts,
-                    aws_conn_id=self.aws_conn_id,
-                ),
-                method_name="execute_complete",
-                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
-            )
-
         if self.wait_for_completion:
-            waiter = self.hook.get_waiter("serverless_job_completed")
-            wait(
-                waiter=waiter,
-                waiter_max_attempts=self.waiter_max_attempts,
-                waiter_delay=self.waiter_delay,
-                args={"applicationId": self.application_id, "jobRunId": self.job_id},
-                failure_message="Serverless Job failed",
-                status_message="Serverless Job status is",
-                status_args=["jobRun.state", "jobRun.stateDetails"],
-            )
+            if self.deferrable:
+                self.defer(
+                    trigger=EmrServerlessStartJobTrigger(
+                        application_id=self.application_id,
+                        job_id=self.job_id,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                    ),
+                    method_name="execute_complete",
+                    timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+                )
+            else:
+                waiter = self.hook.get_waiter("serverless_job_completed")
+                wait(
+                    waiter=waiter,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    waiter_delay=self.waiter_delay,
+                    args={"applicationId": self.application_id, "jobRunId": self.job_id},
+                    failure_message="Serverless Job failed",
+                    status_message="Serverless Job status is",
+                    status_args=["jobRun.state", "jobRun.stateDetails"],
+                )
 
         return self.job_id
 
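
Note: after this change the operator defers only when a wait is actually requested, i.e. deferrable=True now takes effect inside the wait_for_completion branch. A minimal usage sketch; the application id, role ARN, and S3 path are placeholders:

    from airflow.providers.amazon.aws.operators.emr import EmrServerlessStartJobOperator

    start_job = EmrServerlessStartJobOperator(
        task_id="start_emr_serverless_job",
        application_id="00fabcdef1example",
        execution_role_arn="arn:aws:iam::123456789012:role/emr-serverless-job-role",
        job_driver={"sparkSubmit": {"entryPoint": "s3://my-bucket/scripts/job.py"}},
        configuration_overrides={},
        wait_for_completion=True,
        deferrable=True,  # deferral now happens only together with wait_for_completion
    )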