apache-airflow-providers-amazon 8.25.0rc1__py3-none-any.whl → 8.26.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/hooks/athena.py +18 -9
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -1
- airflow/providers/amazon/aws/hooks/base_aws.py +34 -10
- airflow/providers/amazon/aws/hooks/chime.py +2 -1
- airflow/providers/amazon/aws/hooks/datasync.py +6 -3
- airflow/providers/amazon/aws/hooks/ecr.py +2 -1
- airflow/providers/amazon/aws/hooks/ecs.py +12 -6
- airflow/providers/amazon/aws/hooks/glacier.py +8 -4
- airflow/providers/amazon/aws/hooks/kinesis.py +2 -1
- airflow/providers/amazon/aws/hooks/logs.py +4 -2
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +24 -12
- airflow/providers/amazon/aws/hooks/redshift_data.py +4 -2
- airflow/providers/amazon/aws/hooks/redshift_sql.py +6 -3
- airflow/providers/amazon/aws/hooks/s3.py +70 -53
- airflow/providers/amazon/aws/hooks/sagemaker.py +82 -41
- airflow/providers/amazon/aws/hooks/secrets_manager.py +6 -3
- airflow/providers/amazon/aws/hooks/sts.py +2 -1
- airflow/providers/amazon/aws/operators/athena.py +21 -8
- airflow/providers/amazon/aws/operators/batch.py +12 -6
- airflow/providers/amazon/aws/operators/datasync.py +2 -1
- airflow/providers/amazon/aws/operators/ecs.py +1 -0
- airflow/providers/amazon/aws/operators/emr.py +6 -86
- airflow/providers/amazon/aws/operators/glue.py +4 -2
- airflow/providers/amazon/aws/operators/glue_crawler.py +22 -19
- airflow/providers/amazon/aws/operators/neptune.py +2 -1
- airflow/providers/amazon/aws/operators/redshift_cluster.py +2 -1
- airflow/providers/amazon/aws/operators/sagemaker.py +2 -1
- airflow/providers/amazon/aws/sensors/base_aws.py +2 -1
- airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +25 -17
- airflow/providers/amazon/aws/sensors/glue_crawler.py +16 -12
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +6 -3
- airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +2 -1
- airflow/providers/amazon/aws/transfers/s3_to_sql.py +2 -1
- airflow/providers/amazon/aws/triggers/ecs.py +3 -1
- airflow/providers/amazon/aws/triggers/glue.py +15 -3
- airflow/providers/amazon/aws/triggers/glue_crawler.py +8 -1
- airflow/providers/amazon/aws/utils/connection_wrapper.py +10 -5
- airflow/providers/amazon/aws/utils/mixins.py +2 -1
- airflow/providers/amazon/aws/utils/redshift.py +2 -1
- airflow/providers/amazon/get_provider_info.py +2 -1
- {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0rc1.dist-info}/METADATA +6 -6
- {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0rc1.dist-info}/RECORD +45 -45
- {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0rc1.dist-info}/entry_points.txt +0 -0

airflow/providers/amazon/aws/hooks/s3.py

@@ -22,6 +22,7 @@ from __future__ import annotations
 import asyncio
 import fnmatch
 import gzip as gz
+import inspect
 import logging
 import os
 import re
@@ -36,7 +37,7 @@ from inspect import signature
 from io import BytesIO
 from pathlib import Path
 from tempfile import NamedTemporaryFile, gettempdir
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
 from urllib.parse import urlsplit
 from uuid import uuid4
 
@@ -61,42 +62,24 @@ from airflow.utils.helpers import chunks
 logger = logging.getLogger(__name__)
 
 
+# Explicit value that would remove ACLs from a copy
+# No conflicts with Canned ACLs:
+# https://docs.aws.amazon.com/AmazonS3/latest/userguide/acl-overview.html#canned-acl
+NO_ACL = "no-acl"
+
+
 def provide_bucket_name(func: Callable) -> Callable:
     """Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
     if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
         logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
-    function_signature = signature(func)
-
-    @wraps(func)
-    def wrapper(*args, **kwargs) -> Callable:
-        bound_args = function_signature.bind(*args, **kwargs)
-
-        if "bucket_name" not in bound_args.arguments:
-            self = args[0]
-
-            if "bucket_name" in self.service_config:
-                bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
-            elif self.conn_config and self.conn_config.schema:
-                warnings.warn(
-                    "s3 conn_type, and the associated schema field, is deprecated. "
-                    "Please use aws conn_type instead, and specify `bucket_name` "
-                    "in `service_config.s3` within `extras`.",
-                    AirflowProviderDeprecationWarning,
-                    stacklevel=2,
-                )
-                bound_args.arguments["bucket_name"] = self.conn_config.schema
-
-        return func(*bound_args.args, **bound_args.kwargs)
 
-    return wrapper
-
-
-def provide_bucket_name_async(func: Callable) -> Callable:
-    """Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
     function_signature = signature(func)
+    if "bucket_name" not in function_signature.parameters:
+        raise RuntimeError(
+            "Decorator provide_bucket_name should only wrap a function with param 'bucket_name'."
+        )
 
-
-    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+    async def maybe_add_bucket_name(*args, **kwargs):
         bound_args = function_signature.bind(*args, **kwargs)
 
         if "bucket_name" not in bound_args.arguments:
@@ -105,8 +88,46 @@ def provide_bucket_name_async(func: Callable) -> Callable:
             connection = await sync_to_async(self.get_connection)(self.aws_conn_id)
             if connection.schema:
                 bound_args.arguments["bucket_name"] = connection.schema
+        return bound_args
+
+    if inspect.iscoroutinefunction(func):
+
+        @wraps(func)
+        async def wrapper(*args: Any, **kwargs: Any) -> Any:
+            bound_args = await maybe_add_bucket_name(*args, **kwargs)
+            print(f"invoking async function {func=}")
+            return await func(*bound_args.args, **bound_args.kwargs)
+
+    elif inspect.isasyncgenfunction(func):
+
+        @wraps(func)
+        async def wrapper(*args: Any, **kwargs: Any) -> Any:
+            bound_args = await maybe_add_bucket_name(*args, **kwargs)
+            async for thing in func(*bound_args.args, **bound_args.kwargs):
+                yield thing
+
+    else:
+
+        @wraps(func)
+        def wrapper(*args, **kwargs) -> Callable:
+            bound_args = function_signature.bind(*args, **kwargs)
+
+            if "bucket_name" not in bound_args.arguments:
+                self = args[0]
+
+                if "bucket_name" in self.service_config:
+                    bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
+                elif self.conn_config and self.conn_config.schema:
+                    warnings.warn(
+                        "s3 conn_type, and the associated schema field, is deprecated. "
+                        "Please use aws conn_type instead, and specify `bucket_name` "
+                        "in `service_config.s3` within `extras`.",
+                        AirflowProviderDeprecationWarning,
+                        stacklevel=2,
+                    )
+                    bound_args.arguments["bucket_name"] = self.conn_config.schema
 
-
+            return func(*bound_args.args, **bound_args.kwargs)
 
     return wrapper
 
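The rewrite above folds `provide_bucket_name_async` into `provide_bucket_name`: the decorator now raises at decoration time if the wrapped callable has no `bucket_name` parameter, then builds a wrapper matched to what it wraps (coroutine function, async generator function, or plain function). A self-contained sketch of that dispatch pattern; the decorator name and print messages are illustrative, not provider code:

```python
import inspect
from functools import wraps


def log_calls(func):
    """Wrap sync functions, coroutine functions, and async generators alike."""
    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def wrapper(*args, **kwargs):
            print(f"awaiting {func.__name__}")
            return await func(*args, **kwargs)

    elif inspect.isasyncgenfunction(func):

        @wraps(func)
        async def wrapper(*args, **kwargs):
            # A wrapper containing `yield` is itself an async generator.
            async for item in func(*args, **kwargs):
                yield item

    else:

        @wraps(func)
        def wrapper(*args, **kwargs):
            print(f"calling {func.__name__}")
            return func(*args, **kwargs)

    return wrapper
```

Note that the coroutine branch in the hunk above still carries a `print(f"invoking async function {func=}")` call, so async hook methods emit that line on every invocation.
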
@@ -400,8 +421,8 @@ class S3Hook(AwsBaseHook):
 
         return prefixes
 
-    @provide_bucket_name_async
     @unify_bucket_name_and_key
+    @provide_bucket_name
     async def get_head_object_async(
         self, client: AioBaseClient, key: str, bucket_name: str | None = None
     ) -> dict[str, Any] | None:
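The decorator order flips because Python applies decorators bottom-up: `@provide_bucket_name` now wraps the method directly and `@unify_bucket_name_and_key` wraps the result, which is exactly what the warning in `provide_bucket_name` asks for. A minimal sketch of the ordering (names illustrative):

```python
def outer(func):
    def wrapper(*args, **kwargs):
        print("outer runs first at call time")
        return func(*args, **kwargs)
    return wrapper


def inner(func):
    def wrapper(*args, **kwargs):
        print("inner runs second")
        return func(*args, **kwargs)
    return wrapper


@outer  # applied second: wraps inner(target)
@inner  # applied first: wraps target directly
def target():
    print("target body")


target()
# outer runs first at call time
# inner runs second
# target body
```
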
@@ -462,10 +483,10 @@ class S3Hook(AwsBaseHook):
 
         return prefixes
 
-    @provide_bucket_name_async
+    @provide_bucket_name
     async def get_file_metadata_async(
         self, client: AioBaseClient, bucket_name: str, key: str | None = None
-    ) -> list[Any]:
+    ) -> AsyncIterator[Any]:
         """
         Get a list of files that a key matching a wildcard expression exists in a bucket asynchronously.
 
@@ -477,11 +498,10 @@ class S3Hook(AwsBaseHook):
         delimiter = ""
         paginator = client.get_paginator("list_objects_v2")
         response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
-        files = []
         async for page in response:
             if "Contents" in page:
-                files += page["Contents"]
-        return files
+                for row in page["Contents"]:
+                    yield row
 
     async def _check_key_async(
         self,
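With the `AsyncIterator` annotation and the `yield` rewrite above, `get_file_metadata_async` is now an async generator: object metadata streams page by page instead of being buffered into a list. A hedged caller-side sketch, assuming `hook` is an `S3Hook` and `client` an aiobotocore client obtained from it; the bucket name and prefix are invented:

```python
async def first_csv_key(hook, client, bucket_name: str) -> str | None:
    # Objects arrive as pages are fetched; we can return on the first match
    # without waiting for the full listing.
    async for obj in hook.get_file_metadata_async(client, bucket_name, "data/part"):
        if obj["Key"].endswith(".csv"):
            return obj["Key"]
    return None
```
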
@@ -506,21 +526,16 @@ class S3Hook(AwsBaseHook):
         """
         bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key")
         if wildcard_match:
-            keys = await self.get_file_metadata_async(client, bucket_name, key)
-            key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
-            if not key_matches:
-                return False
-        elif use_regex:
-            keys = await self.get_file_metadata_async(client, bucket_name)
-            key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
-            if not key_matches:
-                return False
-        else:
-            obj = await self.get_head_object_async(client, key, bucket_name)
-            if obj is None:
-                return False
-
-        return True
+            async for k in self.get_file_metadata_async(client, bucket_name, key):
+                if fnmatch.fnmatch(k["Key"], key):
+                    return True
+            return False
+        if use_regex:
+            async for k in self.get_file_metadata_async(client, bucket_name):
+                if re.match(pattern=key, string=k["Key"]):
+                    return True
+            return False
+        return bool(await self.get_head_object_async(client, key, bucket_name))
 
     async def check_key_async(
         self,
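The wildcard and regex branches now short-circuit, returning `True` on the first matching key instead of collecting all matches. The matching semantics are unchanged: shell-style globbing via `fnmatch` for `wildcard_match`, start-anchored `re.match` for `use_regex`. A standalone sketch of the two checks (key names invented):

```python
import fnmatch
import re

keys = ["logs/2024/run_1.json", "logs/2024/run_2.json"]

# wildcard_match branch: Unix shell-style globbing
print(any(fnmatch.fnmatch(k, "logs/*/run_*.json") for k in keys))  # True

# use_regex branch: re.match anchors the pattern at the start of the key
print(any(re.match(r"logs/\d{4}/run_\d+\.json", k) for k in keys))  # True
```
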
@@ -1276,6 +1291,8 @@ class S3Hook(AwsBaseHook):
             object to be copied which is private by default.
         """
         acl_policy = acl_policy or "private"
+        if acl_policy != NO_ACL:
+            kwargs["ACL"] = acl_policy
 
         dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
             dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
@@ -1287,7 +1304,7 @@ class S3Hook(AwsBaseHook):
 
         copy_source = {"Bucket": source_bucket_name, "Key": source_bucket_key, "VersionId": source_version_id}
         response = self.get_conn().copy_object(
-            Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, ACL=acl_policy
+            Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
         )
         return response
 
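Taken together, the two hunks above let callers opt out of sending an ACL at all: `ACL` only lands in `kwargs` when `acl_policy` differs from the new `NO_ACL` sentinel, and `copy_object` forwards those kwargs to boto3. This matters for buckets with ACLs disabled (Object Ownership set to "bucket owner enforced"), which reject requests that carry an ACL parameter. A usage sketch; the connection id and object keys are invented:

```python
from airflow.providers.amazon.aws.hooks.s3 import NO_ACL, S3Hook

hook = S3Hook(aws_conn_id="aws_default")

# No ACL header is sent at all, so the copy succeeds even on buckets
# that enforce "bucket owner enforced" object ownership.
hook.copy_object(
    source_bucket_key="s3://src-bucket/report.csv",
    dest_bucket_key="s3://dst-bucket/report.csv",
    acl_policy=NO_ACL,
)
```
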
airflow/providers/amazon/aws/hooks/sagemaker.py

@@ -40,7 +40,8 @@ from airflow.utils import timezone
 
 
 class LogState:
-    """Enum-style class holding all possible states of CloudWatch log streams.
+    """
+    Enum-style class holding all possible states of CloudWatch log streams.
 
     https://sagemaker.readthedocs.io/en/stable/session.html#sagemaker.session.LogState
     """
@@ -58,7 +59,8 @@ Position = namedtuple("Position", ["timestamp", "skip"])
 
 
 def argmin(arr, f: Callable) -> int | None:
-    """Given callable ``f``, find index in ``arr`` to minimize ``f(arr[i])``.
+    """
+    Given callable ``f``, find index in ``arr`` to minimize ``f(arr[i])``.
 
     None is returned if ``arr`` is empty.
     """
@@ -73,7 +75,8 @@ def argmin(arr, f: Callable) -> int | None:
 
 
 def secondary_training_status_changed(current_job_description: dict, prev_job_description: dict) -> bool:
-    """Check if training job's secondary status message has changed.
+    """
+    Check if training job's secondary status message has changed.
 
     :param current_job_description: Current job description, returned from DescribeTrainingJob call.
     :param prev_job_description: Previous job description, returned from DescribeTrainingJob call.
@@ -102,7 +105,8 @@ def secondary_training_status_changed(current_job_description: dict, prev_job_de
 def secondary_training_status_message(
     job_description: dict[str, list[Any]], prev_description: dict | None
 ) -> str:
-    """Format string containing start time and the secondary training job status message.
+    """
+    Format string containing start time and the secondary training job status message.
 
     :param job_description: Returned response from DescribeTrainingJob call
     :param prev_description: Previous job description from DescribeTrainingJob call
@@ -134,7 +138,8 @@ def secondary_training_status_message(
 
 
 class SageMakerHook(AwsBaseHook):
-    """Interact with Amazon SageMaker.
+    """
+    Interact with Amazon SageMaker.
 
     Provide thick wrapper around
     :external+boto3:py:class:`boto3.client("sagemaker") <SageMaker.Client>`.
@@ -157,7 +162,8 @@ class SageMakerHook(AwsBaseHook):
         self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id)
 
     def tar_and_s3_upload(self, path: str, key: str, bucket: str) -> None:
-        """Tar the local file or directory and upload to s3.
+        """
+        Tar the local file or directory and upload to s3.
 
         :param path: local file or directory
         :param key: s3 key
@@ -175,7 +181,8 @@ class SageMakerHook(AwsBaseHook):
             self.s3_hook.load_file_obj(temp_file, key, bucket, replace=True)
 
     def configure_s3_resources(self, config: dict) -> None:
-        """Extract the S3 operations from the configuration and execute them.
+        """
+        Extract the S3 operations from the configuration and execute them.
 
         :param config: config of SageMaker operation
         """
@@ -193,7 +200,8 @@ class SageMakerHook(AwsBaseHook):
                 self.s3_hook.load_file(op["Path"], op["Key"], op["Bucket"])
 
     def check_s3_url(self, s3url: str) -> bool:
-        """Check if an S3 URL exists.
+        """
+        Check if an S3 URL exists.
 
         :param s3url: S3 url
         """
@@ -214,7 +222,8 @@ class SageMakerHook(AwsBaseHook):
         return True
 
     def check_training_config(self, training_config: dict) -> None:
-        """Check if a training configuration is valid.
+        """
+        Check if a training configuration is valid.
 
         :param training_config: training_config
         """
@@ -224,7 +233,8 @@ class SageMakerHook(AwsBaseHook):
                     self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"])
 
     def check_tuning_config(self, tuning_config: dict) -> None:
-        """Check if a tuning configuration is valid.
+        """
+        Check if a tuning configuration is valid.
 
         :param tuning_config: tuning_config
         """
@@ -233,7 +243,8 @@ class SageMakerHook(AwsBaseHook):
             self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"])
 
     def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Generator:
-        """Iterate over the available events.
+        """
+        Iterate over the available events.
 
         The events coming from a set of log streams in a single log group
         interleaving the events from each stream so they're yielded in timestamp order.
@@ -276,7 +287,8 @@ class SageMakerHook(AwsBaseHook):
         check_interval: int = 30,
         max_ingestion_time: int | None = None,
     ):
-        """Start a model training job.
+        """
+        Start a model training job.
 
         After training completes, Amazon SageMaker saves the resulting model
         artifacts to an Amazon S3 location that you specify.
@@ -327,7 +339,8 @@ class SageMakerHook(AwsBaseHook):
         check_interval: int = 30,
         max_ingestion_time: int | None = None,
    ):
-        """Start a hyperparameter tuning job.
+        """
+        Start a hyperparameter tuning job.
 
         A hyperparameter tuning job finds the best version of a model by running
         many training jobs on your dataset using the algorithm you choose and
@@ -364,7 +377,8 @@ class SageMakerHook(AwsBaseHook):
         check_interval: int = 30,
         max_ingestion_time: int | None = None,
     ):
-        """Start a transform job.
+        """
+        Start a transform job.
 
         A transform job uses a trained model to get inferences on a dataset and
         saves these results to an Amazon S3 location that you specify.
@@ -402,7 +416,8 @@ class SageMakerHook(AwsBaseHook):
         check_interval: int = 30,
         max_ingestion_time: int | None = None,
     ):
-        """Use Amazon SageMaker Processing to analyze data and evaluate models.
+        """
+        Use Amazon SageMaker Processing to analyze data and evaluate models.
 
         With Processing, you can use a simplified, managed experience on
         SageMaker to run your data processing workloads, such as feature
@@ -433,7 +448,8 @@ class SageMakerHook(AwsBaseHook):
         return response
 
     def create_model(self, config: dict):
-        """Create a model in Amazon SageMaker.
+        """
+        Create a model in Amazon SageMaker.
 
         In the request, you name the model and describe a primary container. For
         the primary container, you specify the Docker image that contains
@@ -450,7 +466,8 @@ class SageMakerHook(AwsBaseHook):
         return self.get_conn().create_model(**config)
 
     def create_endpoint_config(self, config: dict):
-        """Create an endpoint configuration to deploy models.
+        """
+        Create an endpoint configuration to deploy models.
 
         In the configuration, you identify one or more models, created using the
         CreateModel API, to deploy and the resources that you want Amazon
@@ -473,7 +490,8 @@ class SageMakerHook(AwsBaseHook):
         check_interval: int = 30,
         max_ingestion_time: int | None = None,
     ):
-        """Create an endpoint from configuration.
+        """
+        Create an endpoint from configuration.
 
         When you create a serverless endpoint, SageMaker provisions and manages
         the compute resources for you. Then, you can make inference requests to
@@ -512,7 +530,8 @@ class SageMakerHook(AwsBaseHook):
         check_interval: int = 30,
         max_ingestion_time: int | None = None,
     ):
-        """Deploy the config in the request and switch to using the new endpoint.
+        """
+        Deploy the config in the request and switch to using the new endpoint.
 
         Resources provisioned for the endpoint using the previous EndpointConfig
         are deleted (there is no availability loss).
@@ -542,7 +561,8 @@ class SageMakerHook(AwsBaseHook):
         return response
 
     def describe_training_job(self, name: str):
-        """Get the training job info associated with the name.
+        """
+        Get the training job info associated with the name.
 
         .. seealso::
             - :external+boto3:py:meth:`SageMaker.Client.describe_training_job`
@@ -614,7 +634,8 @@ class SageMakerHook(AwsBaseHook):
         return state, last_description, last_describe_job_call
 
     def describe_tuning_job(self, name: str) -> dict:
-        """Get the tuning job info associated with the name.
+        """
+        Get the tuning job info associated with the name.
 
         .. seealso::
             - :external+boto3:py:meth:`SageMaker.Client.describe_hyper_parameter_tuning_job`
@@ -625,7 +646,8 @@ class SageMakerHook(AwsBaseHook):
         return self.get_conn().describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)
 
     def describe_model(self, name: str) -> dict:
-        """Get the SageMaker model info associated with the name.
+        """
+        Get the SageMaker model info associated with the name.
 
         :param name: the name of the SageMaker model
         :return: A dict contains all the model info
@@ -633,7 +655,8 @@ class SageMakerHook(AwsBaseHook):
         return self.get_conn().describe_model(ModelName=name)
 
     def describe_transform_job(self, name: str) -> dict:
-        """Get the transform job info associated with the name.
+        """
+        Get the transform job info associated with the name.
 
         .. seealso::
             - :external+boto3:py:meth:`SageMaker.Client.describe_transform_job`
@@ -644,7 +667,8 @@ class SageMakerHook(AwsBaseHook):
         return self.get_conn().describe_transform_job(TransformJobName=name)
 
     def describe_processing_job(self, name: str) -> dict:
-        """Get the processing job info associated with the name.
+        """
+        Get the processing job info associated with the name.
 
         .. seealso::
             - :external+boto3:py:meth:`SageMaker.Client.describe_processing_job`
@@ -655,7 +679,8 @@ class SageMakerHook(AwsBaseHook):
         return self.get_conn().describe_processing_job(ProcessingJobName=name)
 
     def describe_endpoint_config(self, name: str) -> dict:
-        """Get the endpoint config info associated with the name.
+        """
+        Get the endpoint config info associated with the name.
 
         .. seealso::
             - :external+boto3:py:meth:`SageMaker.Client.describe_endpoint_config`
@@ -666,7 +691,8 @@ class SageMakerHook(AwsBaseHook):
         return self.get_conn().describe_endpoint_config(EndpointConfigName=name)
 
     def describe_endpoint(self, name: str) -> dict:
-        """Get the description of an endpoint.
+        """
+        Get the description of an endpoint.
 
         .. seealso::
             - :external+boto3:py:meth:`SageMaker.Client.describe_endpoint`
@@ -685,7 +711,8 @@ class SageMakerHook(AwsBaseHook):
         max_ingestion_time: int | None = None,
         non_terminal_states: set | None = None,
     ) -> dict:
-        """Check status of a SageMaker resource.
+        """
+        Check status of a SageMaker resource.
 
         :param job_name: name of the resource to check status, can be a job but
             also pipeline for instance.
@@ -739,7 +766,8 @@ class SageMakerHook(AwsBaseHook):
         check_interval: int,
         max_ingestion_time: int | None = None,
     ):
-        """Display logs for a given training job.
+        """
+        Display logs for a given training job.
 
         Optionally tailing them until the job is complete.
 
@@ -824,7 +852,8 @@ class SageMakerHook(AwsBaseHook):
     def list_training_jobs(
         self, name_contains: str | None = None, max_results: int | None = None, **kwargs
     ) -> list[dict]:
-        """Call boto3's ``list_training_jobs``.
+        """
+        Call boto3's ``list_training_jobs``.
 
         The training job name and max results are configurable via arguments.
         Other arguments are not, and should be provided via kwargs. Note that
@@ -852,7 +881,8 @@ class SageMakerHook(AwsBaseHook):
     def list_transform_jobs(
         self, name_contains: str | None = None, max_results: int | None = None, **kwargs
     ) -> list[dict]:
-        """Call boto3's ``list_transform_jobs``.
+        """
+        Call boto3's ``list_transform_jobs``.
 
         The transform job name and max results are configurable via arguments.
         Other arguments are not, and should be provided via kwargs. Note that
@@ -879,7 +909,8 @@ class SageMakerHook(AwsBaseHook):
         return results
 
     def list_processing_jobs(self, **kwargs) -> list[dict]:
-        """Call boto3's `list_processing_jobs`.
+        """
+        Call boto3's `list_processing_jobs`.
 
         All arguments should be provided via kwargs. Note that boto3 expects
         these in CamelCase, for example:
@@ -903,7 +934,8 @@ class SageMakerHook(AwsBaseHook):
     def _preprocess_list_request_args(
         self, name_contains: str | None = None, max_results: int | None = None, **kwargs
     ) -> tuple[dict[str, Any], int | None]:
-        """Preprocess arguments for boto3's ``list_*`` methods.
+        """
+        Preprocess arguments for boto3's ``list_*`` methods.
 
         It will turn arguments name_contains and max_results as boto3 compliant
         CamelCase format. This method also makes sure that these two arguments
@@ -936,7 +968,8 @@ class SageMakerHook(AwsBaseHook):
     def _list_request(
         self, partial_func: Callable, result_key: str, max_results: int | None = None
     ) -> list[dict]:
-        """Process a list request to produce results.
+        """
+        Process a list request to produce results.
 
         All AWS boto3 ``list_*`` requests return results in batches, and if the
         key "NextToken" is contained in the result, there are more results to
@@ -992,7 +1025,8 @@ class SageMakerHook(AwsBaseHook):
         throttle_retry_delay: int = 2,
         retries: int = 3,
     ) -> int:
-        """Get the number of processing jobs found with the provided name prefix.
+        """
+        Get the number of processing jobs found with the provided name prefix.
 
         :param processing_job_name: The prefix to look for.
         :param job_name_suffix: The optional suffix which may be appended to deduplicate an existing job name.
@@ -1022,7 +1056,8 @@ class SageMakerHook(AwsBaseHook):
             raise
 
     def delete_model(self, model_name: str):
-        """Delete a SageMaker model.
+        """
+        Delete a SageMaker model.
 
         .. seealso::
             - :external+boto3:py:meth:`SageMaker.Client.delete_model`
@@ -1036,7 +1071,8 @@ class SageMakerHook(AwsBaseHook):
             raise
 
     def describe_pipeline_exec(self, pipeline_exec_arn: str, verbose: bool = False):
-        """Get info about a SageMaker pipeline execution.
+        """
+        Get info about a SageMaker pipeline execution.
 
         .. seealso::
             - :external+boto3:py:meth:`SageMaker.Client.describe_pipeline_execution`
@@ -1065,7 +1101,8 @@ class SageMakerHook(AwsBaseHook):
         check_interval: int | None = None,
         verbose: bool = True,
     ) -> str:
-        """Start a new execution for a SageMaker pipeline.
+        """
+        Start a new execution for a SageMaker pipeline.
 
         .. seealso::
             - :external+boto3:py:meth:`SageMaker.Client.start_pipeline_execution`
@@ -1118,7 +1155,8 @@ class SageMakerHook(AwsBaseHook):
         verbose: bool = True,
         fail_if_not_running: bool = False,
     ) -> str:
-        """Stop SageMaker pipeline execution.
+        """
+        Stop SageMaker pipeline execution.
 
         .. seealso::
             - :external+boto3:py:meth:`SageMaker.Client.stop_pipeline_execution`
@@ -1186,7 +1224,8 @@ class SageMakerHook(AwsBaseHook):
         return res["PipelineExecutionStatus"]
 
     def create_model_package_group(self, package_group_name: str, package_group_desc: str = "") -> bool:
-        """Create a Model Package Group if it does not already exist.
+        """
+        Create a Model Package Group if it does not already exist.
 
         .. seealso::
            - :external+boto3:py:meth:`SageMaker.Client.create_model_package_group`
@@ -1239,7 +1278,8 @@ class SageMakerHook(AwsBaseHook):
         wait_for_completion: bool = True,
         check_interval: int = 30,
     ) -> dict | None:
-        """Create an auto ML job to predict the given column.
+        """
+        Create an auto ML job to predict the given column.
 
         The learning input is based on data provided through S3 , and the output
         is written to the specified S3 location.
@@ -1393,7 +1433,8 @@ class SageMakerHook(AwsBaseHook):
     async def get_multi_stream(
         self, log_group: str, streams: list[str], positions: dict[str, Any]
     ) -> AsyncGenerator[Any, tuple[int, Any | None]]:
-        """Iterate over the available events coming and interleaving the events from each stream so they're yielded in timestamp order.
+        """
+        Iterate over the available events coming and interleaving the events from each stream so they're yielded in timestamp order.
 
         :param log_group: The name of the log group.
         :param streams: A list of the log stream names. The position of the stream in this list is

airflow/providers/amazon/aws/hooks/secrets_manager.py

@@ -24,7 +24,8 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
 
 class SecretsManagerHook(AwsBaseHook):
-    """Interact with Amazon SecretsManager Service.
+    """
+    Interact with Amazon SecretsManager Service.
 
     Provide thin wrapper around
     :external+boto3:py:class:`boto3.client("secretsmanager") <SecretsManager.Client>`.
@@ -40,7 +41,8 @@ class SecretsManagerHook(AwsBaseHook):
         super().__init__(client_type="secretsmanager", *args, **kwargs)
 
     def get_secret(self, secret_name: str) -> str | bytes:
-        """Retrieve secret value from AWS Secrets Manager as a str or bytes.
+        """
+        Retrieve secret value from AWS Secrets Manager as a str or bytes.
 
         The value reflects format it stored in the AWS Secrets Manager.
 
@@ -60,7 +62,8 @@ class SecretsManagerHook(AwsBaseHook):
         return secret
 
     def get_secret_as_dict(self, secret_name: str) -> dict:
-        """Retrieve secret value from AWS Secrets Manager as a dict.
+        """
+        Retrieve secret value from AWS Secrets Manager as a dict.
 
         :param secret_name: name of the secrets.
         :return: dict with the information about the secrets
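A usage sketch for the two accessors above; the connection id and secret name are invented, and `get_secret_as_dict` assumes the stored `SecretString` is JSON:

```python
from airflow.providers.amazon.aws.hooks.secrets_manager import SecretsManagerHook

hook = SecretsManagerHook(aws_conn_id="aws_default")

# Raw value, str or bytes, exactly as stored in Secrets Manager
raw = hook.get_secret("prod/db/credentials")

# Parsed into a dict; only sensible for JSON-formatted secrets
creds = hook.get_secret_as_dict("prod/db/credentials")
password = creds["password"]
```
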
airflow/providers/amazon/aws/hooks/sts.py

@@ -36,7 +36,8 @@ class StsHook(AwsBaseHook):
         super().__init__(client_type="sts", *args, **kwargs)
 
     def get_account_number(self) -> str:
-        """Get the account Number.
+        """
+        Get the account Number.
 
         .. seealso::
             - :external+boto3:py:meth:`STS.Client.get_caller_identity`
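
A usage sketch for `get_account_number`; the connection id is invented, and per the seealso above the method delegates to `STS.Client.get_caller_identity`:

```python
from airflow.providers.amazon.aws.hooks.sts import StsHook

hook = StsHook(aws_conn_id="aws_default")

# Account id of the credentials behind the connection, as a string
account_id = hook.get_account_number()
print(account_id)
```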
|