apache-airflow-providers-amazon 8.25.0__py3-none-any.whl → 8.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +10 -0
  3. airflow/providers/amazon/aws/executors/batch/batch_executor.py +19 -16
  4. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +22 -15
  5. airflow/providers/amazon/aws/hooks/athena.py +18 -9
  6. airflow/providers/amazon/aws/hooks/athena_sql.py +2 -1
  7. airflow/providers/amazon/aws/hooks/base_aws.py +34 -10
  8. airflow/providers/amazon/aws/hooks/chime.py +2 -1
  9. airflow/providers/amazon/aws/hooks/datasync.py +6 -3
  10. airflow/providers/amazon/aws/hooks/ecr.py +2 -1
  11. airflow/providers/amazon/aws/hooks/ecs.py +12 -6
  12. airflow/providers/amazon/aws/hooks/glacier.py +8 -4
  13. airflow/providers/amazon/aws/hooks/kinesis.py +2 -1
  14. airflow/providers/amazon/aws/hooks/logs.py +4 -2
  15. airflow/providers/amazon/aws/hooks/redshift_cluster.py +24 -12
  16. airflow/providers/amazon/aws/hooks/redshift_data.py +4 -2
  17. airflow/providers/amazon/aws/hooks/redshift_sql.py +6 -3
  18. airflow/providers/amazon/aws/hooks/s3.py +70 -53
  19. airflow/providers/amazon/aws/hooks/sagemaker.py +82 -41
  20. airflow/providers/amazon/aws/hooks/secrets_manager.py +6 -3
  21. airflow/providers/amazon/aws/hooks/sts.py +2 -1
  22. airflow/providers/amazon/aws/operators/athena.py +21 -8
  23. airflow/providers/amazon/aws/operators/batch.py +12 -6
  24. airflow/providers/amazon/aws/operators/datasync.py +2 -1
  25. airflow/providers/amazon/aws/operators/ecs.py +1 -0
  26. airflow/providers/amazon/aws/operators/emr.py +6 -86
  27. airflow/providers/amazon/aws/operators/glue.py +4 -2
  28. airflow/providers/amazon/aws/operators/glue_crawler.py +22 -19
  29. airflow/providers/amazon/aws/operators/neptune.py +2 -1
  30. airflow/providers/amazon/aws/operators/redshift_cluster.py +2 -1
  31. airflow/providers/amazon/aws/operators/s3.py +11 -1
  32. airflow/providers/amazon/aws/operators/sagemaker.py +8 -10
  33. airflow/providers/amazon/aws/sensors/base_aws.py +2 -1
  34. airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +25 -17
  35. airflow/providers/amazon/aws/sensors/glue_crawler.py +16 -12
  36. airflow/providers/amazon/aws/sensors/s3.py +11 -5
  37. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +6 -3
  38. airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +2 -1
  39. airflow/providers/amazon/aws/transfers/s3_to_sql.py +2 -1
  40. airflow/providers/amazon/aws/triggers/ecs.py +3 -1
  41. airflow/providers/amazon/aws/triggers/glue.py +15 -3
  42. airflow/providers/amazon/aws/triggers/glue_crawler.py +8 -1
  43. airflow/providers/amazon/aws/utils/connection_wrapper.py +10 -5
  44. airflow/providers/amazon/aws/utils/mixins.py +2 -1
  45. airflow/providers/amazon/aws/utils/redshift.py +2 -1
  46. airflow/providers/amazon/get_provider_info.py +2 -1
  47. {apache_airflow_providers_amazon-8.25.0.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/METADATA +6 -6
  48. {apache_airflow_providers_amazon-8.25.0.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/RECORD +50 -50
  49. {apache_airflow_providers_amazon-8.25.0.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/WHEEL +0 -0
  50. {apache_airflow_providers_amazon-8.25.0.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/hooks/sagemaker.py

@@ -40,7 +40,8 @@ from airflow.utils import timezone
 
 
  class LogState:
- """Enum-style class holding all possible states of CloudWatch log streams.
+ """
+ Enum-style class holding all possible states of CloudWatch log streams.
 
  https://sagemaker.readthedocs.io/en/stable/session.html#sagemaker.session.LogState
  """
@@ -58,7 +59,8 @@ Position = namedtuple("Position", ["timestamp", "skip"])
 
 
  def argmin(arr, f: Callable) -> int | None:
- """Given callable ``f``, find index in ``arr`` to minimize ``f(arr[i])``.
+ """
+ Given callable ``f``, find index in ``arr`` to minimize ``f(arr[i])``.
 
  None is returned if ``arr`` is empty.
  """
@@ -73,7 +75,8 @@ def argmin(arr, f: Callable) -> int | None:
 
 
  def secondary_training_status_changed(current_job_description: dict, prev_job_description: dict) -> bool:
- """Check if training job's secondary status message has changed.
+ """
+ Check if training job's secondary status message has changed.
 
  :param current_job_description: Current job description, returned from DescribeTrainingJob call.
  :param prev_job_description: Previous job description, returned from DescribeTrainingJob call.
@@ -102,7 +105,8 @@ def secondary_training_status_changed(current_job_description: dict, prev_job_de
  def secondary_training_status_message(
  job_description: dict[str, list[Any]], prev_description: dict | None
  ) -> str:
- """Format string containing start time and the secondary training job status message.
+ """
+ Format string containing start time and the secondary training job status message.
 
  :param job_description: Returned response from DescribeTrainingJob call
  :param prev_description: Previous job description from DescribeTrainingJob call
@@ -134,7 +138,8 @@ def secondary_training_status_message(
 
 
  class SageMakerHook(AwsBaseHook):
- """Interact with Amazon SageMaker.
+ """
+ Interact with Amazon SageMaker.
 
  Provide thick wrapper around
  :external+boto3:py:class:`boto3.client("sagemaker") <SageMaker.Client>`.
@@ -157,7 +162,8 @@ class SageMakerHook(AwsBaseHook):
  self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id)
 
  def tar_and_s3_upload(self, path: str, key: str, bucket: str) -> None:
- """Tar the local file or directory and upload to s3.
+ """
+ Tar the local file or directory and upload to s3.
 
  :param path: local file or directory
  :param key: s3 key
@@ -175,7 +181,8 @@ class SageMakerHook(AwsBaseHook):
  self.s3_hook.load_file_obj(temp_file, key, bucket, replace=True)
 
  def configure_s3_resources(self, config: dict) -> None:
- """Extract the S3 operations from the configuration and execute them.
+ """
+ Extract the S3 operations from the configuration and execute them.
 
  :param config: config of SageMaker operation
  """
@@ -193,7 +200,8 @@ class SageMakerHook(AwsBaseHook):
  self.s3_hook.load_file(op["Path"], op["Key"], op["Bucket"])
 
  def check_s3_url(self, s3url: str) -> bool:
- """Check if an S3 URL exists.
+ """
+ Check if an S3 URL exists.
 
  :param s3url: S3 url
  """
@@ -214,7 +222,8 @@ class SageMakerHook(AwsBaseHook):
  return True
 
  def check_training_config(self, training_config: dict) -> None:
- """Check if a training configuration is valid.
+ """
+ Check if a training configuration is valid.
 
  :param training_config: training_config
  """
@@ -224,7 +233,8 @@ class SageMakerHook(AwsBaseHook):
  self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"])
 
  def check_tuning_config(self, tuning_config: dict) -> None:
- """Check if a tuning configuration is valid.
+ """
+ Check if a tuning configuration is valid.
 
  :param tuning_config: tuning_config
  """
@@ -233,7 +243,8 @@ class SageMakerHook(AwsBaseHook):
  self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"])
 
  def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Generator:
- """Iterate over the available events.
+ """
+ Iterate over the available events.
 
  The events coming from a set of log streams in a single log group
  interleaving the events from each stream so they're yielded in timestamp order.
@@ -276,7 +287,8 @@ class SageMakerHook(AwsBaseHook):
  check_interval: int = 30,
  max_ingestion_time: int | None = None,
  ):
- """Start a model training job.
+ """
+ Start a model training job.
 
  After training completes, Amazon SageMaker saves the resulting model
  artifacts to an Amazon S3 location that you specify.
@@ -327,7 +339,8 @@ class SageMakerHook(AwsBaseHook):
  check_interval: int = 30,
  max_ingestion_time: int | None = None,
  ):
- """Start a hyperparameter tuning job.
+ """
+ Start a hyperparameter tuning job.
 
  A hyperparameter tuning job finds the best version of a model by running
  many training jobs on your dataset using the algorithm you choose and
@@ -364,7 +377,8 @@ class SageMakerHook(AwsBaseHook):
  check_interval: int = 30,
  max_ingestion_time: int | None = None,
  ):
- """Start a transform job.
+ """
+ Start a transform job.
 
  A transform job uses a trained model to get inferences on a dataset and
  saves these results to an Amazon S3 location that you specify.
@@ -402,7 +416,8 @@ class SageMakerHook(AwsBaseHook):
  check_interval: int = 30,
  max_ingestion_time: int | None = None,
  ):
- """Use Amazon SageMaker Processing to analyze data and evaluate models.
+ """
+ Use Amazon SageMaker Processing to analyze data and evaluate models.
 
  With Processing, you can use a simplified, managed experience on
  SageMaker to run your data processing workloads, such as feature
@@ -433,7 +448,8 @@ class SageMakerHook(AwsBaseHook):
  return response
 
  def create_model(self, config: dict):
- """Create a model in Amazon SageMaker.
+ """
+ Create a model in Amazon SageMaker.
 
  In the request, you name the model and describe a primary container. For
  the primary container, you specify the Docker image that contains
@@ -450,7 +466,8 @@ class SageMakerHook(AwsBaseHook):
  return self.get_conn().create_model(**config)
 
  def create_endpoint_config(self, config: dict):
- """Create an endpoint configuration to deploy models.
+ """
+ Create an endpoint configuration to deploy models.
 
  In the configuration, you identify one or more models, created using the
  CreateModel API, to deploy and the resources that you want Amazon
@@ -473,7 +490,8 @@ class SageMakerHook(AwsBaseHook):
  check_interval: int = 30,
  max_ingestion_time: int | None = None,
  ):
- """Create an endpoint from configuration.
+ """
+ Create an endpoint from configuration.
 
  When you create a serverless endpoint, SageMaker provisions and manages
  the compute resources for you. Then, you can make inference requests to
@@ -512,7 +530,8 @@ class SageMakerHook(AwsBaseHook):
  check_interval: int = 30,
  max_ingestion_time: int | None = None,
  ):
- """Deploy the config in the request and switch to using the new endpoint.
+ """
+ Deploy the config in the request and switch to using the new endpoint.
 
  Resources provisioned for the endpoint using the previous EndpointConfig
  are deleted (there is no availability loss).
@@ -542,7 +561,8 @@ class SageMakerHook(AwsBaseHook):
  return response
 
  def describe_training_job(self, name: str):
- """Get the training job info associated with the name.
+ """
+ Get the training job info associated with the name.
 
  .. seealso::
  - :external+boto3:py:meth:`SageMaker.Client.describe_training_job`
@@ -614,7 +634,8 @@ class SageMakerHook(AwsBaseHook):
  return state, last_description, last_describe_job_call
 
  def describe_tuning_job(self, name: str) -> dict:
- """Get the tuning job info associated with the name.
+ """
+ Get the tuning job info associated with the name.
 
  .. seealso::
  - :external+boto3:py:meth:`SageMaker.Client.describe_hyper_parameter_tuning_job`
@@ -625,7 +646,8 @@ class SageMakerHook(AwsBaseHook):
  return self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)
 
  def describe_model(self, name: str) -> dict:
- """Get the SageMaker model info associated with the name.
+ """
+ Get the SageMaker model info associated with the name.
 
  :param name: the name of the SageMaker model
  :return: A dict contains all the model info
@@ -633,7 +655,8 @@ class SageMakerHook(AwsBaseHook):
  return self.get_conn().describe_model(ModelName=name)
 
  def describe_transform_job(self, name: str) -> dict:
- """Get the transform job info associated with the name.
+ """
+ Get the transform job info associated with the name.
 
  .. seealso::
  - :external+boto3:py:meth:`SageMaker.Client.describe_transform_job`
@@ -644,7 +667,8 @@ class SageMakerHook(AwsBaseHook):
  return self.get_conn().describe_transform_job(TransformJobName=name)
 
  def describe_processing_job(self, name: str) -> dict:
- """Get the processing job info associated with the name.
+ """
+ Get the processing job info associated with the name.
 
  .. seealso::
  - :external+boto3:py:meth:`SageMaker.Client.describe_processing_job`
@@ -655,7 +679,8 @@ class SageMakerHook(AwsBaseHook):
  return self.get_conn().describe_processing_job(ProcessingJobName=name)
 
  def describe_endpoint_config(self, name: str) -> dict:
- """Get the endpoint config info associated with the name.
+ """
+ Get the endpoint config info associated with the name.
 
  .. seealso::
  - :external+boto3:py:meth:`SageMaker.Client.describe_endpoint_config`
@@ -666,7 +691,8 @@ class SageMakerHook(AwsBaseHook):
  return self.get_conn().describe_endpoint_config(EndpointConfigName=name)
 
  def describe_endpoint(self, name: str) -> dict:
- """Get the description of an endpoint.
+ """
+ Get the description of an endpoint.
 
  .. seealso::
  - :external+boto3:py:meth:`SageMaker.Client.describe_endpoint`
@@ -685,7 +711,8 @@ class SageMakerHook(AwsBaseHook):
  max_ingestion_time: int | None = None,
  non_terminal_states: set | None = None,
  ) -> dict:
- """Check status of a SageMaker resource.
+ """
+ Check status of a SageMaker resource.
 
  :param job_name: name of the resource to check status, can be a job but
  also pipeline for instance.
@@ -739,7 +766,8 @@ class SageMakerHook(AwsBaseHook):
  check_interval: int,
  max_ingestion_time: int | None = None,
  ):
- """Display logs for a given training job.
+ """
+ Display logs for a given training job.
 
  Optionally tailing them until the job is complete.
 
@@ -824,7 +852,8 @@ class SageMakerHook(AwsBaseHook):
  def list_training_jobs(
  self, name_contains: str | None = None, max_results: int | None = None, **kwargs
  ) -> list[dict]:
- """Call boto3's ``list_training_jobs``.
+ """
+ Call boto3's ``list_training_jobs``.
 
  The training job name and max results are configurable via arguments.
  Other arguments are not, and should be provided via kwargs. Note that
@@ -852,7 +881,8 @@ class SageMakerHook(AwsBaseHook):
  def list_transform_jobs(
  self, name_contains: str | None = None, max_results: int | None = None, **kwargs
  ) -> list[dict]:
- """Call boto3's ``list_transform_jobs``.
+ """
+ Call boto3's ``list_transform_jobs``.
 
  The transform job name and max results are configurable via arguments.
  Other arguments are not, and should be provided via kwargs. Note that
@@ -879,7 +909,8 @@ class SageMakerHook(AwsBaseHook):
  return results
 
  def list_processing_jobs(self, **kwargs) -> list[dict]:
- """Call boto3's `list_processing_jobs`.
+ """
+ Call boto3's `list_processing_jobs`.
 
  All arguments should be provided via kwargs. Note that boto3 expects
  these in CamelCase, for example:
@@ -903,7 +934,8 @@ class SageMakerHook(AwsBaseHook):
  def _preprocess_list_request_args(
  self, name_contains: str | None = None, max_results: int | None = None, **kwargs
  ) -> tuple[dict[str, Any], int | None]:
- """Preprocess arguments for boto3's ``list_*`` methods.
+ """
+ Preprocess arguments for boto3's ``list_*`` methods.
 
  It will turn arguments name_contains and max_results as boto3 compliant
  CamelCase format. This method also makes sure that these two arguments
@@ -936,7 +968,8 @@ class SageMakerHook(AwsBaseHook):
  def _list_request(
  self, partial_func: Callable, result_key: str, max_results: int | None = None
  ) -> list[dict]:
- """Process a list request to produce results.
+ """
+ Process a list request to produce results.
 
  All AWS boto3 ``list_*`` requests return results in batches, and if the
  key "NextToken" is contained in the result, there are more results to
@@ -992,7 +1025,8 @@ class SageMakerHook(AwsBaseHook):
  throttle_retry_delay: int = 2,
  retries: int = 3,
  ) -> int:
- """Get the number of processing jobs found with the provided name prefix.
+ """
+ Get the number of processing jobs found with the provided name prefix.
 
  :param processing_job_name: The prefix to look for.
  :param job_name_suffix: The optional suffix which may be appended to deduplicate an existing job name.
@@ -1022,7 +1056,8 @@ class SageMakerHook(AwsBaseHook):
  raise
 
  def delete_model(self, model_name: str):
- """Delete a SageMaker model.
+ """
+ Delete a SageMaker model.
 
  .. seealso::
  - :external+boto3:py:meth:`SageMaker.Client.delete_model`
@@ -1036,7 +1071,8 @@ class SageMakerHook(AwsBaseHook):
  raise
 
  def describe_pipeline_exec(self, pipeline_exec_arn: str, verbose: bool = False):
- """Get info about a SageMaker pipeline execution.
+ """
+ Get info about a SageMaker pipeline execution.
 
  .. seealso::
  - :external+boto3:py:meth:`SageMaker.Client.describe_pipeline_execution`
@@ -1065,7 +1101,8 @@ class SageMakerHook(AwsBaseHook):
  check_interval: int | None = None,
  verbose: bool = True,
  ) -> str:
- """Start a new execution for a SageMaker pipeline.
+ """
+ Start a new execution for a SageMaker pipeline.
 
  .. seealso::
  - :external+boto3:py:meth:`SageMaker.Client.start_pipeline_execution`
@@ -1118,7 +1155,8 @@ class SageMakerHook(AwsBaseHook):
  verbose: bool = True,
  fail_if_not_running: bool = False,
  ) -> str:
- """Stop SageMaker pipeline execution.
+ """
+ Stop SageMaker pipeline execution.
 
  .. seealso::
  - :external+boto3:py:meth:`SageMaker.Client.stop_pipeline_execution`
@@ -1186,7 +1224,8 @@ class SageMakerHook(AwsBaseHook):
  return res["PipelineExecutionStatus"]
 
  def create_model_package_group(self, package_group_name: str, package_group_desc: str = "") -> bool:
- """Create a Model Package Group if it does not already exist.
+ """
+ Create a Model Package Group if it does not already exist.
 
  .. seealso::
  - :external+boto3:py:meth:`SageMaker.Client.create_model_package_group`
@@ -1239,7 +1278,8 @@ class SageMakerHook(AwsBaseHook):
  wait_for_completion: bool = True,
  check_interval: int = 30,
  ) -> dict | None:
- """Create an auto ML job to predict the given column.
+ """
+ Create an auto ML job to predict the given column.
 
  The learning input is based on data provided through S3 , and the output
  is written to the specified S3 location.
@@ -1393,7 +1433,8 @@ class SageMakerHook(AwsBaseHook):
  async def get_multi_stream(
  self, log_group: str, streams: list[str], positions: dict[str, Any]
  ) -> AsyncGenerator[Any, tuple[int, Any | None]]:
- """Iterate over the available events coming and interleaving the events from each stream so they're yielded in timestamp order.
+ """
+ Iterate over the available events coming and interleaving the events from each stream so they're yielded in timestamp order.
 
  :param log_group: The name of the log group.
  :param streams: A list of the log stream names. The position of the stream in this list is
airflow/providers/amazon/aws/hooks/secrets_manager.py

@@ -24,7 +24,8 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
 
  class SecretsManagerHook(AwsBaseHook):
- """Interact with Amazon SecretsManager Service.
+ """
+ Interact with Amazon SecretsManager Service.
 
  Provide thin wrapper around
  :external+boto3:py:class:`boto3.client("secretsmanager") <SecretsManager.Client>`.
@@ -40,7 +41,8 @@ class SecretsManagerHook(AwsBaseHook):
  super().__init__(client_type="secretsmanager", *args, **kwargs)
 
  def get_secret(self, secret_name: str) -> str | bytes:
- """Retrieve secret value from AWS Secrets Manager as a str or bytes.
+ """
+ Retrieve secret value from AWS Secrets Manager as a str or bytes.
 
  The value reflects format it stored in the AWS Secrets Manager.
 
@@ -60,7 +62,8 @@ class SecretsManagerHook(AwsBaseHook):
  return secret
 
  def get_secret_as_dict(self, secret_name: str) -> dict:
- """Retrieve secret value from AWS Secrets Manager as a dict.
+ """
+ Retrieve secret value from AWS Secrets Manager as a dict.
 
  :param secret_name: name of the secrets.
  :return: dict with the information about the secrets
airflow/providers/amazon/aws/hooks/sts.py

@@ -36,7 +36,8 @@ class StsHook(AwsBaseHook):
  super().__init__(client_type="sts", *args, **kwargs)
 
  def get_account_number(self) -> str:
- """Get the account Number.
+ """
+ Get the account Number.
 
  .. seealso::
  - :external+boto3:py:meth:`STS.Client.get_caller_identity`
airflow/providers/amazon/aws/operators/athena.py

@@ -175,9 +175,6 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
  f"query_execution_id is {self.query_execution_id}."
  )
 
- # Save output location from API response for later use in OpenLineage.
- self.output_location = self.hook.get_output_location(self.query_execution_id)
-
  return self.query_execution_id
 
  def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
@@ -185,6 +182,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
 
  if event["status"] != "success":
  raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
+
+ # Save query_execution_id to be later used by listeners
+ self.query_execution_id = event["value"]
  return event["value"]
 
  def on_kill(self) -> None:
@@ -208,13 +208,21 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
  )
  self.hook.poll_query_status(self.query_execution_id, sleep_time=self.sleep_time)
 
- def get_openlineage_facets_on_start(self) -> OperatorLineage:
- """Retrieve OpenLineage data by parsing SQL queries and enriching them with Athena API.
+ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
+ """
+ Retrieve OpenLineage data by parsing SQL queries and enriching them with Athena API.
 
  In addition to CTAS query, query and calculation results are stored in S3 location.
- For that reason additional output is attached with this location.
+ For that reason additional output is attached with this location. Instead of using the complete
+ path where the results are saved (user's prefix + some UUID), we are creating a dataset with the
+ user-provided path only. This should make it easier to match this dataset across different processes.
  """
- from openlineage.client.facet import ExtractionError, ExtractionErrorRunFacet, SqlJobFacet
+ from openlineage.client.facet import (
+ ExternalQueryRunFacet,
+ ExtractionError,
+ ExtractionErrorRunFacet,
+ SqlJobFacet,
+ )
  from openlineage.client.run import Dataset
 
  from airflow.providers.openlineage.extractors.base import OperatorLineage
@@ -264,6 +272,11 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
  )
  )
 
+ if self.query_execution_id:
+ run_facets["externalQuery"] = ExternalQueryRunFacet(
+ externalQueryId=self.query_execution_id, source="awsathena"
+ )
+
  if self.output_location:
  parsed = urlparse(self.output_location)
  outputs.append(Dataset(namespace=f"{parsed.scheme}://{parsed.netloc}", name=parsed.path or "/"))
@@ -300,7 +313,7 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
  )
  }
  fields = [
- SchemaField(name=column["Name"], type=column["Type"], description=column["Comment"])
+ SchemaField(name=column["Name"], type=column["Type"], description=column.get("Comment"))
  for column in table_metadata["TableMetadata"]["Columns"]
  ]
  if fields:
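
Taken together, the AthenaOperator hunks above move OpenLineage collection from get_openlineage_facets_on_start to get_openlineage_facets_on_complete, record the query execution id when the deferrable path resumes, and attach it as an "externalQuery" run facet. A rough usage sketch, assuming typical AthenaOperator arguments (the task id, query, database and bucket below are illustrative, not taken from this diff):

    from airflow.providers.amazon.aws.operators.athena import AthenaOperator

    run_query = AthenaOperator(
        task_id="run_athena_query",
        query="SELECT * FROM example_table",
        database="example_database",
        output_location="s3://example-bucket/athena-results/",
    )

    # With 8.26.0, once the task finishes (execute_complete stores event["value"]
    # back into query_execution_id on the deferrable path), the lineage emitted on
    # completion carries run_facets["externalQuery"] with that execution id and
    # source "awsathena", and the output dataset points at the user-provided
    # output_location prefix rather than the per-query result path.
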
airflow/providers/amazon/aws/operators/batch.py

@@ -14,7 +14,8 @@
  # KIND, either express or implied. See the License for the
  # specific language governing permissions and limitations
  # under the License.
- """AWS Batch services.
+ """
+ AWS Batch services.
 
  .. seealso::
 
@@ -54,7 +55,8 @@ if TYPE_CHECKING:
 
 
  class BatchOperator(BaseOperator):
- """Execute a job on AWS Batch.
+ """
+ Execute a job on AWS Batch.
 
  .. seealso::
  For more information on how to use this operator, take a look at the guide:
@@ -236,7 +238,8 @@ class BatchOperator(BaseOperator):
  )
 
  def execute(self, context: Context) -> str | None:
- """Submit and monitor an AWS Batch job.
+ """
+ Submit and monitor an AWS Batch job.
 
  :raises: AirflowException
  """
@@ -287,7 +290,8 @@ class BatchOperator(BaseOperator):
  self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, response)
 
  def submit_job(self, context: Context):
- """Submit an AWS Batch job.
+ """
+ Submit an AWS Batch job.
 
  :raises: AirflowException
  """
@@ -342,7 +346,8 @@ class BatchOperator(BaseOperator):
  )
 
  def monitor_job(self, context: Context):
- """Monitor an AWS Batch job.
+ """
+ Monitor an AWS Batch job.
 
  This can raise an exception or an AirflowTaskTimeout if the task was
  created with ``execution_timeout``.
@@ -434,7 +439,8 @@ class BatchOperator(BaseOperator):
 
 
  class BatchCreateComputeEnvironmentOperator(BaseOperator):
- """Create an AWS Batch compute environment.
+ """
+ Create an AWS Batch compute environment.
 
  .. seealso::
  For more information on how to use this operator, take a look at the guide:
airflow/providers/amazon/aws/operators/datasync.py

@@ -34,7 +34,8 @@ if TYPE_CHECKING:
 
 
  class DataSyncOperator(AwsBaseOperator[DataSyncHook]):
- """Find, Create, Update, Execute and Delete AWS DataSync Tasks.
+ """
+ Find, Create, Update, Execute and Delete AWS DataSync Tasks.
 
  If ``do_xcom_push`` is True, then the DataSync TaskArn and TaskExecutionArn
  which were executed will be pushed to an XCom.
airflow/providers/amazon/aws/operators/ecs.py

@@ -586,6 +586,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
  if event["status"] != "success":
  raise AirflowException(f"Error in task execution: {event}")
  self.arn = event["task_arn"] # restore arn to its updated value, needed for next steps
+ self.cluster = event["cluster"]
  self._after_execution()
  if self._aws_logs_enabled():
  # same behavior as non-deferrable mode, return last line of logs of the task.
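
For the EcsRunTaskOperator change above, the deferrable path now restores the cluster name as well as the task ARN from the trigger's success event before the post-execution steps run. A sketch of the event shape execute_complete consumes, inferred from the keys read in this hunk (values are placeholders):

    # Keys "status", "task_arn" and "cluster" come from the hunk above; the
    # values here are illustrative only.
    event = {
        "status": "success",
        "task_arn": "arn:aws:ecs:eu-west-1:111122223333:task/example-cluster/abc123",
        "cluster": "example-cluster",
    }
    # execute_complete sets self.arn = event["task_arn"] and, as of 8.26.0,
    # self.cluster = event["cluster"], so the log-fetching step that follows
    # looks at the correct cluster.
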