apache-airflow-providers-amazon 8.25.0rc1__py3-none-any.whl → 8.26.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/hooks/athena.py +18 -9
  3. airflow/providers/amazon/aws/hooks/athena_sql.py +2 -1
  4. airflow/providers/amazon/aws/hooks/base_aws.py +34 -10
  5. airflow/providers/amazon/aws/hooks/chime.py +2 -1
  6. airflow/providers/amazon/aws/hooks/datasync.py +6 -3
  7. airflow/providers/amazon/aws/hooks/ecr.py +2 -1
  8. airflow/providers/amazon/aws/hooks/ecs.py +12 -6
  9. airflow/providers/amazon/aws/hooks/glacier.py +8 -4
  10. airflow/providers/amazon/aws/hooks/kinesis.py +2 -1
  11. airflow/providers/amazon/aws/hooks/logs.py +4 -2
  12. airflow/providers/amazon/aws/hooks/redshift_cluster.py +24 -12
  13. airflow/providers/amazon/aws/hooks/redshift_data.py +4 -2
  14. airflow/providers/amazon/aws/hooks/redshift_sql.py +6 -3
  15. airflow/providers/amazon/aws/hooks/s3.py +70 -53
  16. airflow/providers/amazon/aws/hooks/sagemaker.py +82 -41
  17. airflow/providers/amazon/aws/hooks/secrets_manager.py +6 -3
  18. airflow/providers/amazon/aws/hooks/sts.py +2 -1
  19. airflow/providers/amazon/aws/operators/athena.py +21 -8
  20. airflow/providers/amazon/aws/operators/batch.py +12 -6
  21. airflow/providers/amazon/aws/operators/datasync.py +2 -1
  22. airflow/providers/amazon/aws/operators/ecs.py +1 -0
  23. airflow/providers/amazon/aws/operators/emr.py +6 -86
  24. airflow/providers/amazon/aws/operators/glue.py +4 -2
  25. airflow/providers/amazon/aws/operators/glue_crawler.py +22 -19
  26. airflow/providers/amazon/aws/operators/neptune.py +2 -1
  27. airflow/providers/amazon/aws/operators/redshift_cluster.py +2 -1
  28. airflow/providers/amazon/aws/operators/sagemaker.py +2 -1
  29. airflow/providers/amazon/aws/sensors/base_aws.py +2 -1
  30. airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +25 -17
  31. airflow/providers/amazon/aws/sensors/glue_crawler.py +16 -12
  32. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +6 -3
  33. airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +2 -1
  34. airflow/providers/amazon/aws/transfers/s3_to_sql.py +2 -1
  35. airflow/providers/amazon/aws/triggers/ecs.py +3 -1
  36. airflow/providers/amazon/aws/triggers/glue.py +15 -3
  37. airflow/providers/amazon/aws/triggers/glue_crawler.py +8 -1
  38. airflow/providers/amazon/aws/utils/connection_wrapper.py +10 -5
  39. airflow/providers/amazon/aws/utils/mixins.py +2 -1
  40. airflow/providers/amazon/aws/utils/redshift.py +2 -1
  41. airflow/providers/amazon/get_provider_info.py +2 -1
  42. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0rc1.dist-info}/METADATA +6 -6
  43. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0rc1.dist-info}/RECORD +45 -45
  44. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0rc1.dist-info}/WHEEL +0 -0
  45. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0rc1.dist-info}/entry_points.txt +0 -0
@@ -22,6 +22,7 @@ from __future__ import annotations
22
22
  import asyncio
23
23
  import fnmatch
24
24
  import gzip as gz
25
+ import inspect
25
26
  import logging
26
27
  import os
27
28
  import re
@@ -36,7 +37,7 @@ from inspect import signature
36
37
  from io import BytesIO
37
38
  from pathlib import Path
38
39
  from tempfile import NamedTemporaryFile, gettempdir
39
- from typing import TYPE_CHECKING, Any, Callable
40
+ from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
40
41
  from urllib.parse import urlsplit
41
42
  from uuid import uuid4
42
43
 
@@ -61,42 +62,24 @@ from airflow.utils.helpers import chunks
61
62
  logger = logging.getLogger(__name__)
62
63
 
63
64
 
65
+ # Explicit value that would remove ACLs from a copy
66
+ # No conflicts with Canned ACLs:
67
+ # https://docs.aws.amazon.com/AmazonS3/latest/userguide/acl-overview.html#canned-acl
68
+ NO_ACL = "no-acl"
69
+
70
+
64
71
  def provide_bucket_name(func: Callable) -> Callable:
65
72
  """Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
66
73
  if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
67
74
  logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
68
- function_signature = signature(func)
69
-
70
- @wraps(func)
71
- def wrapper(*args, **kwargs) -> Callable:
72
- bound_args = function_signature.bind(*args, **kwargs)
73
-
74
- if "bucket_name" not in bound_args.arguments:
75
- self = args[0]
76
-
77
- if "bucket_name" in self.service_config:
78
- bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
79
- elif self.conn_config and self.conn_config.schema:
80
- warnings.warn(
81
- "s3 conn_type, and the associated schema field, is deprecated. "
82
- "Please use aws conn_type instead, and specify `bucket_name` "
83
- "in `service_config.s3` within `extras`.",
84
- AirflowProviderDeprecationWarning,
85
- stacklevel=2,
86
- )
87
- bound_args.arguments["bucket_name"] = self.conn_config.schema
88
-
89
- return func(*bound_args.args, **bound_args.kwargs)
90
75
 
91
- return wrapper
92
-
93
-
94
- def provide_bucket_name_async(func: Callable) -> Callable:
95
- """Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
96
76
  function_signature = signature(func)
77
+ if "bucket_name" not in function_signature.parameters:
78
+ raise RuntimeError(
79
+ "Decorator provide_bucket_name should only wrap a function with param 'bucket_name'."
80
+ )
97
81
 
98
- @wraps(func)
99
- async def wrapper(*args: Any, **kwargs: Any) -> Any:
82
+ async def maybe_add_bucket_name(*args, **kwargs):
100
83
  bound_args = function_signature.bind(*args, **kwargs)
101
84
 
102
85
  if "bucket_name" not in bound_args.arguments:
@@ -105,8 +88,46 @@ def provide_bucket_name_async(func: Callable) -> Callable:
105
88
  connection = await sync_to_async(self.get_connection)(self.aws_conn_id)
106
89
  if connection.schema:
107
90
  bound_args.arguments["bucket_name"] = connection.schema
91
+ return bound_args
92
+
93
+ if inspect.iscoroutinefunction(func):
94
+
95
+ @wraps(func)
96
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
97
+ bound_args = await maybe_add_bucket_name(*args, **kwargs)
98
+ print(f"invoking async function {func=}")
99
+ return await func(*bound_args.args, **bound_args.kwargs)
100
+
101
+ elif inspect.isasyncgenfunction(func):
102
+
103
+ @wraps(func)
104
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
105
+ bound_args = await maybe_add_bucket_name(*args, **kwargs)
106
+ async for thing in func(*bound_args.args, **bound_args.kwargs):
107
+ yield thing
108
+
109
+ else:
110
+
111
+ @wraps(func)
112
+ def wrapper(*args, **kwargs) -> Callable:
113
+ bound_args = function_signature.bind(*args, **kwargs)
114
+
115
+ if "bucket_name" not in bound_args.arguments:
116
+ self = args[0]
117
+
118
+ if "bucket_name" in self.service_config:
119
+ bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
120
+ elif self.conn_config and self.conn_config.schema:
121
+ warnings.warn(
122
+ "s3 conn_type, and the associated schema field, is deprecated. "
123
+ "Please use aws conn_type instead, and specify `bucket_name` "
124
+ "in `service_config.s3` within `extras`.",
125
+ AirflowProviderDeprecationWarning,
126
+ stacklevel=2,
127
+ )
128
+ bound_args.arguments["bucket_name"] = self.conn_config.schema
108
129
 
109
- return await func(*bound_args.args, **bound_args.kwargs)
130
+ return func(*bound_args.args, **bound_args.kwargs)
110
131
 
111
132
  return wrapper
112
133
 
@@ -400,8 +421,8 @@ class S3Hook(AwsBaseHook):
400
421
 
401
422
  return prefixes
402
423
 
403
- @provide_bucket_name_async
404
424
  @unify_bucket_name_and_key
425
+ @provide_bucket_name
405
426
  async def get_head_object_async(
406
427
  self, client: AioBaseClient, key: str, bucket_name: str | None = None
407
428
  ) -> dict[str, Any] | None:
@@ -462,10 +483,10 @@ class S3Hook(AwsBaseHook):
462
483
 
463
484
  return prefixes
464
485
 
465
- @provide_bucket_name_async
486
+ @provide_bucket_name
466
487
  async def get_file_metadata_async(
467
488
  self, client: AioBaseClient, bucket_name: str, key: str | None = None
468
- ) -> list[Any]:
489
+ ) -> AsyncIterator[Any]:
469
490
  """
470
491
  Get a list of files that a key matching a wildcard expression exists in a bucket asynchronously.
471
492
 
@@ -477,11 +498,10 @@ class S3Hook(AwsBaseHook):
477
498
  delimiter = ""
478
499
  paginator = client.get_paginator("list_objects_v2")
479
500
  response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
480
- files = []
481
501
  async for page in response:
482
502
  if "Contents" in page:
483
- files += page["Contents"]
484
- return files
503
+ for row in page["Contents"]:
504
+ yield row
485
505
 
486
506
  async def _check_key_async(
487
507
  self,
@@ -506,21 +526,16 @@ class S3Hook(AwsBaseHook):
506
526
  """
507
527
  bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key")
508
528
  if wildcard_match:
509
- keys = await self.get_file_metadata_async(client, bucket_name, key)
510
- key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
511
- if not key_matches:
512
- return False
513
- elif use_regex:
514
- keys = await self.get_file_metadata_async(client, bucket_name)
515
- key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
516
- if not key_matches:
517
- return False
518
- else:
519
- obj = await self.get_head_object_async(client, key, bucket_name)
520
- if obj is None:
521
- return False
522
-
523
- return True
529
+ async for k in self.get_file_metadata_async(client, bucket_name, key):
530
+ if fnmatch.fnmatch(k["Key"], key):
531
+ return True
532
+ return False
533
+ if use_regex:
534
+ async for k in self.get_file_metadata_async(client, bucket_name):
535
+ if re.match(pattern=key, string=k["Key"]):
536
+ return True
537
+ return False
538
+ return bool(await self.get_head_object_async(client, key, bucket_name))
524
539
 
525
540
  async def check_key_async(
526
541
  self,
@@ -1276,6 +1291,8 @@ class S3Hook(AwsBaseHook):
1276
1291
  object to be copied which is private by default.
1277
1292
  """
1278
1293
  acl_policy = acl_policy or "private"
1294
+ if acl_policy != NO_ACL:
1295
+ kwargs["ACL"] = acl_policy
1279
1296
 
1280
1297
  dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
1281
1298
  dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
@@ -1287,7 +1304,7 @@ class S3Hook(AwsBaseHook):
1287
1304
 
1288
1305
  copy_source = {"Bucket": source_bucket_name, "Key": source_bucket_key, "VersionId": source_version_id}
1289
1306
  response = self.get_conn().copy_object(
1290
- Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, ACL=acl_policy, **kwargs
1307
+ Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
1291
1308
  )
1292
1309
  return response
1293
1310
 
@@ -40,7 +40,8 @@ from airflow.utils import timezone
40
40
 
41
41
 
42
42
  class LogState:
43
- """Enum-style class holding all possible states of CloudWatch log streams.
43
+ """
44
+ Enum-style class holding all possible states of CloudWatch log streams.
44
45
 
45
46
  https://sagemaker.readthedocs.io/en/stable/session.html#sagemaker.session.LogState
46
47
  """
@@ -58,7 +59,8 @@ Position = namedtuple("Position", ["timestamp", "skip"])
58
59
 
59
60
 
60
61
  def argmin(arr, f: Callable) -> int | None:
61
- """Given callable ``f``, find index in ``arr`` to minimize ``f(arr[i])``.
62
+ """
63
+ Given callable ``f``, find index in ``arr`` to minimize ``f(arr[i])``.
62
64
 
63
65
  None is returned if ``arr`` is empty.
64
66
  """
@@ -73,7 +75,8 @@ def argmin(arr, f: Callable) -> int | None:
73
75
 
74
76
 
75
77
  def secondary_training_status_changed(current_job_description: dict, prev_job_description: dict) -> bool:
76
- """Check if training job's secondary status message has changed.
78
+ """
79
+ Check if training job's secondary status message has changed.
77
80
 
78
81
  :param current_job_description: Current job description, returned from DescribeTrainingJob call.
79
82
  :param prev_job_description: Previous job description, returned from DescribeTrainingJob call.
@@ -102,7 +105,8 @@ def secondary_training_status_changed(current_job_description: dict, prev_job_description: dict) -> bool:
102
105
  def secondary_training_status_message(
103
106
  job_description: dict[str, list[Any]], prev_description: dict | None
104
107
  ) -> str:
105
- """Format string containing start time and the secondary training job status message.
108
+ """
109
+ Format string containing start time and the secondary training job status message.
106
110
 
107
111
  :param job_description: Returned response from DescribeTrainingJob call
108
112
  :param prev_description: Previous job description from DescribeTrainingJob call
@@ -134,7 +138,8 @@ def secondary_training_status_message(
134
138
 
135
139
 
136
140
  class SageMakerHook(AwsBaseHook):
137
- """Interact with Amazon SageMaker.
141
+ """
142
+ Interact with Amazon SageMaker.
138
143
 
139
144
  Provide thick wrapper around
140
145
  :external+boto3:py:class:`boto3.client("sagemaker") <SageMaker.Client>`.
@@ -157,7 +162,8 @@ class SageMakerHook(AwsBaseHook):
157
162
  self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id)
158
163
 
159
164
  def tar_and_s3_upload(self, path: str, key: str, bucket: str) -> None:
160
- """Tar the local file or directory and upload to s3.
165
+ """
166
+ Tar the local file or directory and upload to s3.
161
167
 
162
168
  :param path: local file or directory
163
169
  :param key: s3 key
@@ -175,7 +181,8 @@ class SageMakerHook(AwsBaseHook):
175
181
  self.s3_hook.load_file_obj(temp_file, key, bucket, replace=True)
176
182
 
177
183
  def configure_s3_resources(self, config: dict) -> None:
178
- """Extract the S3 operations from the configuration and execute them.
184
+ """
185
+ Extract the S3 operations from the configuration and execute them.
179
186
 
180
187
  :param config: config of SageMaker operation
181
188
  """
@@ -193,7 +200,8 @@ class SageMakerHook(AwsBaseHook):
193
200
  self.s3_hook.load_file(op["Path"], op["Key"], op["Bucket"])
194
201
 
195
202
  def check_s3_url(self, s3url: str) -> bool:
196
- """Check if an S3 URL exists.
203
+ """
204
+ Check if an S3 URL exists.
197
205
 
198
206
  :param s3url: S3 url
199
207
  """
@@ -214,7 +222,8 @@ class SageMakerHook(AwsBaseHook):
214
222
  return True
215
223
 
216
224
  def check_training_config(self, training_config: dict) -> None:
217
- """Check if a training configuration is valid.
225
+ """
226
+ Check if a training configuration is valid.
218
227
 
219
228
  :param training_config: training_config
220
229
  """
@@ -224,7 +233,8 @@ class SageMakerHook(AwsBaseHook):
224
233
  self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"])
225
234
 
226
235
  def check_tuning_config(self, tuning_config: dict) -> None:
227
- """Check if a tuning configuration is valid.
236
+ """
237
+ Check if a tuning configuration is valid.
228
238
 
229
239
  :param tuning_config: tuning_config
230
240
  """
@@ -233,7 +243,8 @@ class SageMakerHook(AwsBaseHook):
233
243
  self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"])
234
244
 
235
245
  def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Generator:
236
- """Iterate over the available events.
246
+ """
247
+ Iterate over the available events.
237
248
 
238
249
  The events coming from a set of log streams in a single log group
239
250
  interleaving the events from each stream so they're yielded in timestamp order.
@@ -276,7 +287,8 @@ class SageMakerHook(AwsBaseHook):
276
287
  check_interval: int = 30,
277
288
  max_ingestion_time: int | None = None,
278
289
  ):
279
- """Start a model training job.
290
+ """
291
+ Start a model training job.
280
292
 
281
293
  After training completes, Amazon SageMaker saves the resulting model
282
294
  artifacts to an Amazon S3 location that you specify.
@@ -327,7 +339,8 @@ class SageMakerHook(AwsBaseHook):
327
339
  check_interval: int = 30,
328
340
  max_ingestion_time: int | None = None,
329
341
  ):
330
- """Start a hyperparameter tuning job.
342
+ """
343
+ Start a hyperparameter tuning job.
331
344
 
332
345
  A hyperparameter tuning job finds the best version of a model by running
333
346
  many training jobs on your dataset using the algorithm you choose and
@@ -364,7 +377,8 @@ class SageMakerHook(AwsBaseHook):
364
377
  check_interval: int = 30,
365
378
  max_ingestion_time: int | None = None,
366
379
  ):
367
- """Start a transform job.
380
+ """
381
+ Start a transform job.
368
382
 
369
383
  A transform job uses a trained model to get inferences on a dataset and
370
384
  saves these results to an Amazon S3 location that you specify.
@@ -402,7 +416,8 @@ class SageMakerHook(AwsBaseHook):
402
416
  check_interval: int = 30,
403
417
  max_ingestion_time: int | None = None,
404
418
  ):
405
- """Use Amazon SageMaker Processing to analyze data and evaluate models.
419
+ """
420
+ Use Amazon SageMaker Processing to analyze data and evaluate models.
406
421
 
407
422
  With Processing, you can use a simplified, managed experience on
408
423
  SageMaker to run your data processing workloads, such as feature
@@ -433,7 +448,8 @@ class SageMakerHook(AwsBaseHook):
433
448
  return response
434
449
 
435
450
  def create_model(self, config: dict):
436
- """Create a model in Amazon SageMaker.
451
+ """
452
+ Create a model in Amazon SageMaker.
437
453
 
438
454
  In the request, you name the model and describe a primary container. For
439
455
  the primary container, you specify the Docker image that contains
@@ -450,7 +466,8 @@ class SageMakerHook(AwsBaseHook):
450
466
  return self.get_conn().create_model(**config)
451
467
 
452
468
  def create_endpoint_config(self, config: dict):
453
- """Create an endpoint configuration to deploy models.
469
+ """
470
+ Create an endpoint configuration to deploy models.
454
471
 
455
472
  In the configuration, you identify one or more models, created using the
456
473
  CreateModel API, to deploy and the resources that you want Amazon
@@ -473,7 +490,8 @@ class SageMakerHook(AwsBaseHook):
473
490
  check_interval: int = 30,
474
491
  max_ingestion_time: int | None = None,
475
492
  ):
476
- """Create an endpoint from configuration.
493
+ """
494
+ Create an endpoint from configuration.
477
495
 
478
496
  When you create a serverless endpoint, SageMaker provisions and manages
479
497
  the compute resources for you. Then, you can make inference requests to
@@ -512,7 +530,8 @@ class SageMakerHook(AwsBaseHook):
512
530
  check_interval: int = 30,
513
531
  max_ingestion_time: int | None = None,
514
532
  ):
515
- """Deploy the config in the request and switch to using the new endpoint.
533
+ """
534
+ Deploy the config in the request and switch to using the new endpoint.
516
535
 
517
536
  Resources provisioned for the endpoint using the previous EndpointConfig
518
537
  are deleted (there is no availability loss).
@@ -542,7 +561,8 @@ class SageMakerHook(AwsBaseHook):
542
561
  return response
543
562
 
544
563
  def describe_training_job(self, name: str):
545
- """Get the training job info associated with the name.
564
+ """
565
+ Get the training job info associated with the name.
546
566
 
547
567
  .. seealso::
548
568
  - :external+boto3:py:meth:`SageMaker.Client.describe_training_job`
@@ -614,7 +634,8 @@ class SageMakerHook(AwsBaseHook):
614
634
  return state, last_description, last_describe_job_call
615
635
 
616
636
  def describe_tuning_job(self, name: str) -> dict:
617
- """Get the tuning job info associated with the name.
637
+ """
638
+ Get the tuning job info associated with the name.
618
639
 
619
640
  .. seealso::
620
641
  - :external+boto3:py:meth:`SageMaker.Client.describe_hyper_parameter_tuning_job`
@@ -625,7 +646,8 @@ class SageMakerHook(AwsBaseHook):
625
646
  return self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)
626
647
 
627
648
  def describe_model(self, name: str) -> dict:
628
- """Get the SageMaker model info associated with the name.
649
+ """
650
+ Get the SageMaker model info associated with the name.
629
651
 
630
652
  :param name: the name of the SageMaker model
631
653
  :return: A dict contains all the model info
@@ -633,7 +655,8 @@ class SageMakerHook(AwsBaseHook):
633
655
  return self.get_conn().describe_model(ModelName=name)
634
656
 
635
657
  def describe_transform_job(self, name: str) -> dict:
636
- """Get the transform job info associated with the name.
658
+ """
659
+ Get the transform job info associated with the name.
637
660
 
638
661
  .. seealso::
639
662
  - :external+boto3:py:meth:`SageMaker.Client.describe_transform_job`
@@ -644,7 +667,8 @@ class SageMakerHook(AwsBaseHook):
644
667
  return self.get_conn().describe_transform_job(TransformJobName=name)
645
668
 
646
669
  def describe_processing_job(self, name: str) -> dict:
647
- """Get the processing job info associated with the name.
670
+ """
671
+ Get the processing job info associated with the name.
648
672
 
649
673
  .. seealso::
650
674
  - :external+boto3:py:meth:`SageMaker.Client.describe_processing_job`
@@ -655,7 +679,8 @@ class SageMakerHook(AwsBaseHook):
655
679
  return self.get_conn().describe_processing_job(ProcessingJobName=name)
656
680
 
657
681
  def describe_endpoint_config(self, name: str) -> dict:
658
- """Get the endpoint config info associated with the name.
682
+ """
683
+ Get the endpoint config info associated with the name.
659
684
 
660
685
  .. seealso::
661
686
  - :external+boto3:py:meth:`SageMaker.Client.describe_endpoint_config`
@@ -666,7 +691,8 @@ class SageMakerHook(AwsBaseHook):
666
691
  return self.get_conn().describe_endpoint_config(EndpointConfigName=name)
667
692
 
668
693
  def describe_endpoint(self, name: str) -> dict:
669
- """Get the description of an endpoint.
694
+ """
695
+ Get the description of an endpoint.
670
696
 
671
697
  .. seealso::
672
698
  - :external+boto3:py:meth:`SageMaker.Client.describe_endpoint`
@@ -685,7 +711,8 @@ class SageMakerHook(AwsBaseHook):
685
711
  max_ingestion_time: int | None = None,
686
712
  non_terminal_states: set | None = None,
687
713
  ) -> dict:
688
- """Check status of a SageMaker resource.
714
+ """
715
+ Check status of a SageMaker resource.
689
716
 
690
717
  :param job_name: name of the resource to check status, can be a job but
691
718
  also pipeline for instance.
@@ -739,7 +766,8 @@ class SageMakerHook(AwsBaseHook):
739
766
  check_interval: int,
740
767
  max_ingestion_time: int | None = None,
741
768
  ):
742
- """Display logs for a given training job.
769
+ """
770
+ Display logs for a given training job.
743
771
 
744
772
  Optionally tailing them until the job is complete.
745
773
 
@@ -824,7 +852,8 @@ class SageMakerHook(AwsBaseHook):
824
852
  def list_training_jobs(
825
853
  self, name_contains: str | None = None, max_results: int | None = None, **kwargs
826
854
  ) -> list[dict]:
827
- """Call boto3's ``list_training_jobs``.
855
+ """
856
+ Call boto3's ``list_training_jobs``.
828
857
 
829
858
  The training job name and max results are configurable via arguments.
830
859
  Other arguments are not, and should be provided via kwargs. Note that
@@ -852,7 +881,8 @@ class SageMakerHook(AwsBaseHook):
852
881
  def list_transform_jobs(
853
882
  self, name_contains: str | None = None, max_results: int | None = None, **kwargs
854
883
  ) -> list[dict]:
855
- """Call boto3's ``list_transform_jobs``.
884
+ """
885
+ Call boto3's ``list_transform_jobs``.
856
886
 
857
887
  The transform job name and max results are configurable via arguments.
858
888
  Other arguments are not, and should be provided via kwargs. Note that
@@ -879,7 +909,8 @@ class SageMakerHook(AwsBaseHook):
879
909
  return results
880
910
 
881
911
  def list_processing_jobs(self, **kwargs) -> list[dict]:
882
- """Call boto3's `list_processing_jobs`.
912
+ """
913
+ Call boto3's `list_processing_jobs`.
883
914
 
884
915
  All arguments should be provided via kwargs. Note that boto3 expects
885
916
  these in CamelCase, for example:
@@ -903,7 +934,8 @@ class SageMakerHook(AwsBaseHook):
903
934
  def _preprocess_list_request_args(
904
935
  self, name_contains: str | None = None, max_results: int | None = None, **kwargs
905
936
  ) -> tuple[dict[str, Any], int | None]:
906
- """Preprocess arguments for boto3's ``list_*`` methods.
937
+ """
938
+ Preprocess arguments for boto3's ``list_*`` methods.
907
939
 
908
940
  It will turn arguments name_contains and max_results as boto3 compliant
909
941
  CamelCase format. This method also makes sure that these two arguments
@@ -936,7 +968,8 @@ class SageMakerHook(AwsBaseHook):
936
968
  def _list_request(
937
969
  self, partial_func: Callable, result_key: str, max_results: int | None = None
938
970
  ) -> list[dict]:
939
- """Process a list request to produce results.
971
+ """
972
+ Process a list request to produce results.
940
973
 
941
974
  All AWS boto3 ``list_*`` requests return results in batches, and if the
942
975
  key "NextToken" is contained in the result, there are more results to
@@ -992,7 +1025,8 @@ class SageMakerHook(AwsBaseHook):
992
1025
  throttle_retry_delay: int = 2,
993
1026
  retries: int = 3,
994
1027
  ) -> int:
995
- """Get the number of processing jobs found with the provided name prefix.
1028
+ """
1029
+ Get the number of processing jobs found with the provided name prefix.
996
1030
 
997
1031
  :param processing_job_name: The prefix to look for.
998
1032
  :param job_name_suffix: The optional suffix which may be appended to deduplicate an existing job name.
@@ -1022,7 +1056,8 @@ class SageMakerHook(AwsBaseHook):
1022
1056
  raise
1023
1057
 
1024
1058
  def delete_model(self, model_name: str):
1025
- """Delete a SageMaker model.
1059
+ """
1060
+ Delete a SageMaker model.
1026
1061
 
1027
1062
  .. seealso::
1028
1063
  - :external+boto3:py:meth:`SageMaker.Client.delete_model`
@@ -1036,7 +1071,8 @@ class SageMakerHook(AwsBaseHook):
1036
1071
  raise
1037
1072
 
1038
1073
  def describe_pipeline_exec(self, pipeline_exec_arn: str, verbose: bool = False):
1039
- """Get info about a SageMaker pipeline execution.
1074
+ """
1075
+ Get info about a SageMaker pipeline execution.
1040
1076
 
1041
1077
  .. seealso::
1042
1078
  - :external+boto3:py:meth:`SageMaker.Client.describe_pipeline_execution`
@@ -1065,7 +1101,8 @@ class SageMakerHook(AwsBaseHook):
1065
1101
  check_interval: int | None = None,
1066
1102
  verbose: bool = True,
1067
1103
  ) -> str:
1068
- """Start a new execution for a SageMaker pipeline.
1104
+ """
1105
+ Start a new execution for a SageMaker pipeline.
1069
1106
 
1070
1107
  .. seealso::
1071
1108
  - :external+boto3:py:meth:`SageMaker.Client.start_pipeline_execution`
@@ -1118,7 +1155,8 @@ class SageMakerHook(AwsBaseHook):
1118
1155
  verbose: bool = True,
1119
1156
  fail_if_not_running: bool = False,
1120
1157
  ) -> str:
1121
- """Stop SageMaker pipeline execution.
1158
+ """
1159
+ Stop SageMaker pipeline execution.
1122
1160
 
1123
1161
  .. seealso::
1124
1162
  - :external+boto3:py:meth:`SageMaker.Client.stop_pipeline_execution`
@@ -1186,7 +1224,8 @@ class SageMakerHook(AwsBaseHook):
1186
1224
  return res["PipelineExecutionStatus"]
1187
1225
 
1188
1226
  def create_model_package_group(self, package_group_name: str, package_group_desc: str = "") -> bool:
1189
- """Create a Model Package Group if it does not already exist.
1227
+ """
1228
+ Create a Model Package Group if it does not already exist.
1190
1229
 
1191
1230
  .. seealso::
1192
1231
  - :external+boto3:py:meth:`SageMaker.Client.create_model_package_group`
@@ -1239,7 +1278,8 @@ class SageMakerHook(AwsBaseHook):
1239
1278
  wait_for_completion: bool = True,
1240
1279
  check_interval: int = 30,
1241
1280
  ) -> dict | None:
1242
- """Create an auto ML job to predict the given column.
1281
+ """
1282
+ Create an auto ML job to predict the given column.
1243
1283
 
1244
1284
  The learning input is based on data provided through S3 , and the output
1245
1285
  is written to the specified S3 location.
@@ -1393,7 +1433,8 @@ class SageMakerHook(AwsBaseHook):
1393
1433
  async def get_multi_stream(
1394
1434
  self, log_group: str, streams: list[str], positions: dict[str, Any]
1395
1435
  ) -> AsyncGenerator[Any, tuple[int, Any | None]]:
1396
- """Iterate over the available events coming and interleaving the events from each stream so they're yielded in timestamp order.
1436
+ """
1437
+ Iterate over the available events coming and interleaving the events from each stream so they're yielded in timestamp order.
1397
1438
 
1398
1439
  :param log_group: The name of the log group.
1399
1440
  :param streams: A list of the log stream names. The position of the stream in this list is
@@ -24,7 +24,8 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
24
24
 
25
25
 
26
26
  class SecretsManagerHook(AwsBaseHook):
27
- """Interact with Amazon SecretsManager Service.
27
+ """
28
+ Interact with Amazon SecretsManager Service.
28
29
 
29
30
  Provide thin wrapper around
30
31
  :external+boto3:py:class:`boto3.client("secretsmanager") <SecretsManager.Client>`.
@@ -40,7 +41,8 @@ class SecretsManagerHook(AwsBaseHook):
40
41
  super().__init__(client_type="secretsmanager", *args, **kwargs)
41
42
 
42
43
  def get_secret(self, secret_name: str) -> str | bytes:
43
- """Retrieve secret value from AWS Secrets Manager as a str or bytes.
44
+ """
45
+ Retrieve secret value from AWS Secrets Manager as a str or bytes.
44
46
 
45
47
  The value reflects format it stored in the AWS Secrets Manager.
46
48
 
@@ -60,7 +62,8 @@ class SecretsManagerHook(AwsBaseHook):
60
62
  return secret
61
63
 
62
64
  def get_secret_as_dict(self, secret_name: str) -> dict:
63
- """Retrieve secret value from AWS Secrets Manager as a dict.
65
+ """
66
+ Retrieve secret value from AWS Secrets Manager as a dict.
64
67
 
65
68
  :param secret_name: name of the secrets.
66
69
  :return: dict with the information about the secrets
@@ -36,7 +36,8 @@ class StsHook(AwsBaseHook):
36
36
  super().__init__(client_type="sts", *args, **kwargs)
37
37
 
38
38
  def get_account_number(self) -> str:
39
- """Get the account Number.
39
+ """
40
+ Get the account Number.
40
41
 
41
42
  .. seealso::
42
43
  - :external+boto3:py:meth:`STS.Client.get_caller_identity`