apache-airflow-providers-amazon 9.7.0rc2__py3-none-any.whl → 9.8.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/operators/sagemaker.py +180 -103
- airflow/providers/amazon/aws/sensors/s3.py +3 -2
- airflow/providers/amazon/aws/sensors/sagemaker.py +26 -16
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +24 -10
- {apache_airflow_providers_amazon-9.7.0rc2.dist-info → apache_airflow_providers_amazon-9.8.0.dist-info}/METADATA +12 -12
- {apache_airflow_providers_amazon-9.7.0rc2.dist-info → apache_airflow_providers_amazon-9.8.0.dist-info}/RECORD +9 -9
- {apache_airflow_providers_amazon-9.7.0rc2.dist-info → apache_airflow_providers_amazon-9.8.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.7.0rc2.dist-info → apache_airflow_providers_amazon-9.8.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/__init__.py

@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version

 __all__ = ["__version__"]

-__version__ = "9.…
+__version__ = "9.8.0"

 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.10.0"
airflow/providers/amazon/aws/operators/sagemaker.py

@@ -21,14 +21,12 @@ import json
 import time
 import urllib
 from collections.abc import Sequence
-from functools import cached_property
 from typing import TYPE_CHECKING, Any, Callable, ClassVar

 from botocore.exceptions import ClientError

 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.sagemaker import (
     LogState,
@@ -36,11 +34,13 @@ from airflow.providers.amazon.aws.hooks.sagemaker import (
     secondary_training_status_message,
 )
 from airflow.providers.amazon.aws.links.sagemaker import SageMakerTransformJobLink
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.sagemaker import (
     SageMakerPipelineTrigger,
     SageMakerTrigger,
 )
 from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.utils.helpers import prune_dict
@@ -50,7 +50,7 @@ if TYPE_CHECKING:
     from airflow.providers.openlineage.extractors.base import OperatorLineage
     from airflow.utils.context import Context

-DEFAULT_CONN_ID: str = "aws_default"
+# DEFAULT_CONN_ID: str = "aws_default"
 CHECK_INTERVAL_SECOND: int = 30


@@ -58,23 +58,33 @@ def serialize(result: dict) -> dict:
     return json.loads(json.dumps(result, default=repr))


-class SageMakerBaseOperator(BaseOperator):
+class SageMakerBaseOperator(AwsBaseOperator[SageMakerHook]):
     """
     This is the base operator for all SageMaker operators.

+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param config: The configuration necessary to start a training job (templated)
     """

-    template_fields: Sequence[str] = ("config",)
+    aws_hook_class = SageMakerHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "config",
+    )
     template_ext: Sequence[str] = ()
     template_fields_renderers: ClassVar[dict] = {"config": "json"}
     ui_color: str = "#ededed"
     integer_fields: list[list[Any]] = []

-    def __init__(self, *, config: dict, aws_conn_id: str | None = DEFAULT_CONN_ID, **kwargs):
+    def __init__(self, *, config: dict, **kwargs):
         super().__init__(**kwargs)
         self.config = config
-        self.aws_conn_id = aws_conn_id

     def parse_integer(self, config: dict, field: list[str] | str) -> None:
         """Recursive method for parsing string fields holding integer values to integers."""
@@ -199,11 +209,6 @@ class SageMakerBaseOperator(BaseOperator):
     def execute(self, context: Context):
         raise NotImplementedError("Please implement execute() in sub class!")

-    @cached_property
-    def hook(self):
-        """Return SageMakerHook."""
-        return SageMakerHook(aws_conn_id=self.aws_conn_id)
-
     @staticmethod
     def path_to_s3_dataset(path) -> Dataset:
         from airflow.providers.common.compat.openlineage.facet import Dataset
@@ -227,7 +232,14 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):

     :param config: The configuration necessary to start a processing job (templated).
         For details of the configuration parameter see :py:meth:`SageMaker.Client.create_processing_job`
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param wait_for_completion: If wait is set to True, the time interval, in seconds,
         that the operation waits to check the status of the processing job.
     :param print_log: if the operator should print the cloudwatch log during processing
@@ -249,7 +261,6 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         self,
         *,
         config: dict,
-        aws_conn_id: str | None = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         print_log: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
@@ -259,7 +270,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(config=config, **kwargs)
         if action_if_job_exists not in ("fail", "timestamp"):
             raise AirflowException(
                 f"Argument action_if_job_exists accepts only 'timestamp' and 'fail'. \
@@ -403,7 +414,14 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
     :param config: The configuration necessary to create an endpoint config.

         For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config`
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :return Dict: Returns The ARN of the endpoint config created in Amazon SageMaker.
     """

@@ -411,10 +429,9 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
         self,
         *,
         config: dict,
-        aws_conn_id: str | None = DEFAULT_CONN_ID,
         **kwargs,
     ):
-        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(config=config, **kwargs)

     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -476,7 +493,14 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
     :param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
         finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
     :param operation: Whether to create an endpoint or update an endpoint. Must be either 'create or 'update'.
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param deferrable: Will wait asynchronously for completion.
     :return Dict: Returns The ARN of the endpoint created in Amazon SageMaker.
     """
@@ -485,7 +509,6 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
         self,
         *,
         config: dict,
-        aws_conn_id: str | None = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
         max_ingestion_time: int | None = None,
@@ -493,7 +516,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(config=config, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
         self.max_ingestion_time = max_ingestion_time or 3600 * 10
@@ -634,7 +657,14 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         For details of the configuration parameter of model_config, See:
         :py:meth:`SageMaker.Client.create_model`

-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param wait_for_completion: Set to True to wait until the transform job finishes.
     :param check_interval: If wait is set to True, the time interval, in seconds,
         that this operation waits to check the status of the transform job.
@@ -657,7 +687,6 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         self,
         *,
         config: dict,
-        aws_conn_id: str | None = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
         max_attempts: int | None = None,
@@ -669,7 +698,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(config=config, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
         self.max_attempts = max_attempts or 60
@@ -898,7 +927,14 @@ class SageMakerTuningOperator(SageMakerBaseOperator):

         For details of the configuration parameter see
         :py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param wait_for_completion: Set to True to wait until the tuning job finishes.
     :param check_interval: If wait is set to True, the time interval, in seconds,
         that this operation waits to check the status of the tuning job.
@@ -913,14 +949,13 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
         self,
         *,
         config: dict,
-        aws_conn_id: str | None = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
         max_ingestion_time: int | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(config=config, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
         self.max_ingestion_time = max_ingestion_time
@@ -1016,8 +1051,8 @@ class SageMakerModelOperator(SageMakerBaseOperator):
     :return Dict: Returns The ARN of the model created in Amazon SageMaker.
     """

-    def __init__(self, *, config: dict, aws_conn_id: str | None = DEFAULT_CONN_ID, **kwargs):
-        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+    def __init__(self, *, config: dict, **kwargs):
+        super().__init__(config=config, **kwargs)

     def expand_role(self) -> None:
         """Expand an IAM role name into an ARN."""
@@ -1048,7 +1083,14 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
     :param config: The configuration necessary to start a training job (templated).

         For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job`
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param wait_for_completion: If wait is set to True, the time interval, in seconds,
         that the operation waits to check the status of the training job.
     :param print_log: if the operator should print the cloudwatch log during training
@@ -1073,7 +1115,6 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
         self,
         *,
         config: dict,
-        aws_conn_id: str | None = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         print_log: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
@@ -1084,7 +1125,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(config=config, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.print_log = print_log
         self.check_interval = check_interval
@@ -1243,15 +1284,22 @@ class SageMakerDeleteModelOperator(SageMakerBaseOperator):

     :param config: The configuration necessary to delete the model.
         For details of the configuration parameter see :py:meth:`SageMaker.Client.delete_model`
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     """

-    def __init__(self, *, config: dict, aws_conn_id: str | None = DEFAULT_CONN_ID, **kwargs):
-        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+    def __init__(self, *, config: dict, **kwargs):
+        super().__init__(config=config, **kwargs)

     def execute(self, context: Context) -> Any:
-        sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
-        sagemaker_hook.delete_model(model_name=self.config["ModelName"])
+        # sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
+        self.hook.delete_model(model_name=self.config["ModelName"])
         self.log.info("Model %s deleted successfully.", self.config["ModelName"])


@@ -1264,7 +1312,14 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
     :ref:`howto/operator:SageMakerStartPipelineOperator`

     :param config: The configuration to start the pipeline execution.
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param pipeline_name: Name of the pipeline to start.
     :param display_name: The name this pipeline execution will have in the UI. Doesn't need to be unique.
     :param pipeline_params: Optional parameters for the pipeline.
@@ -1279,8 +1334,7 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
     :return str: Returns The ARN of the pipeline execution created in Amazon SageMaker.
     """

-    template_fields: Sequence[str] = (
-        "aws_conn_id",
+    template_fields: Sequence[str] = aws_template_fields(
         "pipeline_name",
         "display_name",
         "pipeline_params",
@@ -1289,7 +1343,6 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
     def __init__(
         self,
         *,
-        aws_conn_id: str | None = DEFAULT_CONN_ID,
         pipeline_name: str,
         display_name: str = "airflow-triggered-execution",
         pipeline_params: dict | None = None,
@@ -1300,7 +1353,7 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(config={}, **kwargs)
         self.pipeline_name = pipeline_name
         self.display_name = display_name
         self.pipeline_params = pipeline_params
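
Since `pipeline_name`, `display_name`, and `pipeline_params` now come from `aws_template_fields(...)`, they remain Jinja-templatable alongside the injected connection fields. A hedged sketch (names are placeholders):

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerStartPipelineOperator

start_pipeline = SageMakerStartPipelineOperator(
    task_id="start_pipeline",
    pipeline_name="my-pipeline",                # placeholder pipeline name
    pipeline_params={"run_date": "{{ ds }}"},   # rendered because it is a template field
    wait_for_completion=False,
)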
@@ -1358,7 +1411,14 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
     :ref:`howto/operator:SageMakerStopPipelineOperator`

     :param config: The configuration to start the pipeline execution.
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param pipeline_exec_arn: Amazon Resource Name of the pipeline execution to stop.
     :param wait_for_completion: If true, this operator will only complete once the pipeline is fully stopped.
     :param check_interval: How long to wait between checks for pipeline status when waiting for completion.
@@ -1370,15 +1430,13 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
     :return str: Returns the status of the pipeline execution after the operation has been done.
     """

-    template_fields: Sequence[str] = (
-        "aws_conn_id",
+    template_fields: Sequence[str] = aws_template_fields(
         "pipeline_exec_arn",
     )

     def __init__(
         self,
         *,
-        aws_conn_id: str | None = DEFAULT_CONN_ID,
         pipeline_exec_arn: str,
         wait_for_completion: bool = False,
         check_interval: int = CHECK_INTERVAL_SECOND,
@@ -1388,7 +1446,7 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(config={}, **kwargs)
         self.pipeline_exec_arn = pipeline_exec_arn
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
@@ -1469,11 +1527,18 @@ class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
     :param extras: Can contain extra parameters for the boto call to create_model_package, and/or overrides
         for any parameter defined above. For a complete list of available parameters, see
         https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model_package
-
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :return str: Returns the ARN of the model package created.
     """

-    template_fields: Sequence[str] = (
+    template_fields: Sequence[str] = aws_template_fields(
         "image_uri",
         "model_url",
         "package_group_name",
@@ -1492,11 +1557,10 @@ class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
         package_desc: str = "",
         model_approval: ApprovalStatus = ApprovalStatus.PENDING_MANUAL_APPROVAL,
         extras: dict | None = None,
-        aws_conn_id: str | None = DEFAULT_CONN_ID,
         config: dict | None = None,
         **kwargs,
     ):
-        super().__init__(config=config or {}, aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(config=config or {}, **kwargs)
         self.image_uri = image_uri
         self.model_url = model_url
         self.package_group_name = package_group_name
@@ -1563,13 +1627,20 @@ class SageMakerAutoMLOperator(SageMakerBaseOperator):
         https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_auto_ml_job
     :param wait_for_completion: Whether to wait for the job to finish before returning. Defaults to True.
     :param check_interval: Interval in seconds between 2 status checks when waiting for completion.
-
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :returns: Only if waiting for completion, a dictionary detailing the best model. The structure is that of
         the "BestCandidate" key in:
         https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
     """

-    template_fields: Sequence[str] = (
+    template_fields: Sequence[str] = aws_template_fields(
         "job_name",
         "s3_input",
         "target_attribute",
@@ -1595,11 +1666,10 @@ class SageMakerAutoMLOperator(SageMakerBaseOperator):
         extras: dict | None = None,
         wait_for_completion: bool = True,
         check_interval: int = 30,
-        aws_conn_id: str | None = DEFAULT_CONN_ID,
         config: dict | None = None,
         **kwargs,
     ):
-        super().__init__(config=config or {}, aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(config=config or {}, **kwargs)
         self.job_name = job_name
         self.s3_input = s3_input
         self.target_attribute = target_attribute
@@ -1640,12 +1710,19 @@ class SageMakerCreateExperimentOperator(SageMakerBaseOperator):
     :param name: name of the experiment, must be unique within the AWS account
     :param description: description of the experiment, optional
     :param tags: tags to attach to the experiment, optional
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html

     :returns: the ARN of the experiment created, though experiments are referred to by name
     """

-    template_fields: Sequence[str] = (
+    template_fields: Sequence[str] = aws_template_fields(
         "name",
         "description",
         "tags",
@@ -1657,28 +1734,26 @@ class SageMakerCreateExperimentOperator(SageMakerBaseOperator):
         name: str,
         description: str | None = None,
         tags: dict | None = None,
-        aws_conn_id: str | None = DEFAULT_CONN_ID,
         **kwargs,
     ):
-        super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(config={}, **kwargs)
         self.name = name
         self.description = description
         self.tags = tags or {}

     def execute(self, context: Context) -> str:
-        sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
         params = {
             "ExperimentName": self.name,
             "Description": self.description,
             "Tags": format_tags(self.tags),
         }
-        ans = sagemaker_hook.conn.create_experiment(**trim_none_values(params))
+        ans = self.hook.conn.create_experiment(**trim_none_values(params))
         arn = ans["ExperimentArn"]
         self.log.info("Experiment %s created successfully with ARN %s.", self.name, arn)
         return arn


-class SageMakerCreateNotebookOperator(BaseOperator):
+class SageMakerCreateNotebookOperator(AwsBaseOperator[SageMakerHook]):
     """
     Create a SageMaker notebook.

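
With the standalone `SageMakerHook` gone from `execute`, the experiment call now goes through the inherited `self.hook`. A hedged usage sketch (names and tags are placeholders):

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerCreateExperimentOperator

create_experiment = SageMakerCreateExperimentOperator(
    task_id="create_experiment",
    name="churn-model-experiment",    # must be unique within the AWS account
    description="weekly retraining",  # optional
    tags={"team": "ml"},              # passed through format_tags() internally
)
# execute() returns the ExperimentArn via self.hook.conn.create_experiment(...).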
@@ -1699,12 +1774,19 @@ class SageMakerCreateNotebookOperator(BaseOperator):
     :param root_access: Whether to give the notebook instance root access to the Amazon S3 bucket.
     :param wait_for_completion: Whether or not to wait for the notebook to be InService before returning
     :param create_instance_kwargs: Additional configuration options for the create call.
-    :param aws_conn_id: The …
-
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :return: The ARN of the created notebook.
     """

-    template_fields: Sequence[str] = (
+    aws_hook_class = SageMakerHook
+    template_fields: Sequence[str] = aws_template_fields(
         "instance_name",
         "instance_type",
         "role_arn",
@@ -1732,7 +1814,6 @@ class SageMakerCreateNotebookOperator(BaseOperator):
         root_access: str | None = None,
         create_instance_kwargs: dict[str, Any] | None = None,
         wait_for_completion: bool = True,
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -1745,17 +1826,11 @@ class SageMakerCreateNotebookOperator(BaseOperator):
         self.direct_internet_access = direct_internet_access
         self.root_access = root_access
         self.wait_for_completion = wait_for_completion
-        self.aws_conn_id = aws_conn_id
         self.create_instance_kwargs = create_instance_kwargs or {}

         if self.create_instance_kwargs.get("tags") is not None:
             self.create_instance_kwargs["tags"] = format_tags(self.create_instance_kwargs["tags"])

-    @cached_property
-    def hook(self) -> SageMakerHook:
-        """Create and return SageMakerHook."""
-        return SageMakerHook(aws_conn_id=self.aws_conn_id)
-
     def execute(self, context: Context):
         create_notebook_instance_kwargs = {
             "NotebookInstanceName": self.instance_name,
@@ -1783,7 +1858,7 @@ class SageMakerCreateNotebookOperator(BaseOperator):
         return response["NotebookInstanceArn"]


-class SageMakerStopNotebookOperator(BaseOperator):
+class SageMakerStopNotebookOperator(AwsBaseOperator[SageMakerHook]):
     """
     Stop a notebook instance.

@@ -1793,10 +1868,18 @@ class SageMakerStopNotebookOperator(BaseOperator):

     :param instance_name: The name of the notebook instance to stop.
     :param wait_for_completion: Whether or not to wait for the notebook to be stopped before returning
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     """

-    template_fields: Sequence[str] = ("instance_name", "wait_for_completion")
+    aws_hook_class = SageMakerHook
+    template_fields: Sequence[str] = aws_template_fields("instance_name", "wait_for_completion")

     ui_color = "#ff7300"

@@ -1804,18 +1887,11 @@ class SageMakerStopNotebookOperator(BaseOperator):
         self,
         instance_name: str,
         wait_for_completion: bool = True,
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.instance_name = instance_name
         self.wait_for_completion = wait_for_completion
-        self.aws_conn_id = aws_conn_id
-
-    @cached_property
-    def hook(self) -> SageMakerHook:
-        """Create and return SageMakerHook."""
-        return SageMakerHook(aws_conn_id=self.aws_conn_id)

     def execute(self, context):
         self.log.info("Stopping SageMaker notebook %s.", self.instance_name)
@@ -1828,7 +1904,7 @@ class SageMakerStopNotebookOperator(BaseOperator):
         )


-class SageMakerDeleteNotebookOperator(BaseOperator):
+class SageMakerDeleteNotebookOperator(AwsBaseOperator[SageMakerHook]):
     """
     Delete a notebook instance.

@@ -1838,30 +1914,30 @@ class SageMakerDeleteNotebookOperator(BaseOperator):

     :param instance_name: The name of the notebook instance to delete.
     :param wait_for_completion: Whether or not to wait for the notebook to delete before returning.
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     """

-    template_fields: Sequence[str] = ("instance_name", "wait_for_completion")
-
+    template_fields: Sequence[str] = aws_template_fields("instance_name", "wait_for_completion")
+    aws_hook_class = SageMakerHook
     ui_color = "#ff7300"

     def __init__(
         self,
         instance_name: str,
         wait_for_completion: bool = True,
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.instance_name = instance_name
-        self.aws_conn_id = aws_conn_id
         self.wait_for_completion = wait_for_completion

-    @cached_property
-    def hook(self) -> SageMakerHook:
-        """Create and return SageMakerHook."""
-        return SageMakerHook(aws_conn_id=self.aws_conn_id)
-
     def execute(self, context):
         self.log.info("Deleting SageMaker notebook %s....", self.instance_name)
         self.hook.conn.delete_notebook_instance(NotebookInstanceName=self.instance_name)
@@ -1873,7 +1949,7 @@ class SageMakerDeleteNotebookOperator(BaseOperator):
         )


-class SageMakerStartNoteBookOperator(BaseOperator):
+class SageMakerStartNoteBookOperator(AwsBaseOperator[SageMakerHook]):
     """
     Start a notebook instance.

@@ -1883,10 +1959,18 @@ class SageMakerStartNoteBookOperator(BaseOperator):

     :param instance_name: The name of the notebook instance to start.
     :param wait_for_completion: Whether or not to wait for notebook to be InService before returning
-    :param aws_conn_id: The …
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     """

-    template_fields: Sequence[str] = ("instance_name", "wait_for_completion")
+    template_fields: Sequence[str] = aws_template_fields("instance_name", "wait_for_completion")
+    aws_hook_class = SageMakerHook

     ui_color = "#ff7300"

@@ -1894,19 +1978,12 @@ class SageMakerStartNoteBookOperator(BaseOperator):
         self,
         instance_name: str,
         wait_for_completion: bool = True,
-        aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.instance_name = instance_name
-        self.aws_conn_id = aws_conn_id
         self.wait_for_completion = wait_for_completion

-    @cached_property
-    def hook(self) -> SageMakerHook:
-        """Create and return SageMakerHook."""
-        return SageMakerHook(aws_conn_id=self.aws_conn_id)
-
     def execute(self, context):
         self.log.info("Starting SageMaker notebook %s....", self.instance_name)
         self.hook.conn.start_notebook_instance(NotebookInstanceName=self.instance_name)
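
All four notebook operators now share the same `AwsBaseOperator[SageMakerHook]` base, so a lifecycle chain needs no per-operator `aws_conn_id` plumbing. A hedged sketch (instance name, type, and role ARN are placeholders):

from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerCreateNotebookOperator,
    SageMakerStopNotebookOperator,
    SageMakerDeleteNotebookOperator,
)

create = SageMakerCreateNotebookOperator(
    task_id="create_notebook",
    instance_name="demo-notebook",
    instance_type="ml.t3.medium",
    role_arn="arn:aws:iam::123456789012:role/sagemaker-role",  # placeholder ARN
)
stop = SageMakerStopNotebookOperator(task_id="stop_notebook", instance_name="demo-notebook")
delete = SageMakerDeleteNotebookOperator(task_id="delete_notebook", instance_name="demo-notebook")

create >> stop >> delete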
airflow/providers/amazon/aws/sensors/s3.py

@@ -107,7 +107,7 @@ class S3KeySensor(AwsBaseSensor[S3Hook]):
         self.verify = verify
         self.deferrable = deferrable
         self.use_regex = use_regex
-        self.metadata_keys = metadata_keys if metadata_keys else ["Size"]
+        self.metadata_keys = metadata_keys if metadata_keys else ["Size", "Key"]

     def _check_key(self, key, context: Context):
         bucket_name, key = self.hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
@@ -116,7 +116,8 @@ class S3KeySensor(AwsBaseSensor[S3Hook]):
         """
         Set variable `files` which contains a list of dict which contains attributes defined by the user
         Format: [{
-            'Size': int
+            'Size': int,
+            'Key': str,
         }]
         """
         if self.wildcard_match:
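
Because 'Key' is now part of the default `metadata_keys`, a `check_fn` can distinguish which matched object it is inspecting without opting in explicitly. A hedged sketch (bucket and key pattern are placeholders):

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

def big_csvs_only(files: list[dict]) -> bool:
    # Each entry now carries both 'Size' and 'Key' by default.
    return all(f["Size"] > 1024 for f in files if f["Key"].endswith(".csv"))

wait_for_files = S3KeySensor(
    task_id="wait_for_files",
    bucket_name="my-bucket",
    bucket_key="incoming/*.csv",
    wildcard_match=True,
    check_fn=big_csvs_only,
)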
airflow/providers/amazon/aws/sensors/sagemaker.py

@@ -18,18 +18,18 @@ from __future__ import annotations

 import time
 from collections.abc import Sequence
-from functools import cached_property
 from typing import TYPE_CHECKING

 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

 if TYPE_CHECKING:
     from airflow.utils.context import Context


-class SageMakerBaseSensor(BaseSensorOperator):
+class SageMakerBaseSensor(AwsBaseSensor[SageMakerHook]):
     """
     Contains general sensor behavior for SageMaker.

@@ -37,17 +37,13 @@ class SageMakerBaseSensor(BaseSensorOperator):
     Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods.
     """

+    aws_hook_class = SageMakerHook
     ui_color = "#ededed"

-    def __init__(self, *, aws_conn_id: str | None = "aws_default", resource_type: str = "job", **kwargs):
+    def __init__(self, *, resource_type: str = "job", **kwargs):
         super().__init__(**kwargs)
-        self.aws_conn_id = aws_conn_id
         self.resource_type = resource_type  # only used for logs, to say what kind of resource we are sensing

-    @cached_property
-    def hook(self) -> SageMakerHook:
-        return SageMakerHook(aws_conn_id=self.aws_conn_id)
-
     def poke(self, context: Context):
         response = self.get_sagemaker_response()
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
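
The contract for subclasses is unchanged: implement the response/state methods and the inherited `poke` drives the polling. A minimal sketch with a hypothetical sensor (not part of the provider):

from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerBaseSensor


class MyTrainingJobSensor(SageMakerBaseSensor):
    def __init__(self, *, job_name: str, **kwargs):
        super().__init__(**kwargs)
        self.job_name = job_name

    def get_sagemaker_response(self):
        # self.hook now comes from aws_hook_class on the base sensor.
        return self.hook.conn.describe_training_job(TrainingJobName=self.job_name)

    def state_from_response(self, response):
        return response["TrainingJobStatus"]

    def non_terminal_states(self):
        return {"InProgress", "Stopping"}

    def failed_states(self):
        return {"Failed"}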
@@ -96,7 +92,9 @@ class SageMakerEndpointSensor(SageMakerBaseSensor):
     :param endpoint_name: Name of the endpoint instance to watch.
     """

-    template_fields: Sequence[str] = ("endpoint_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "endpoint_name",
+    )
     template_ext: Sequence[str] = ()

     def __init__(self, *, endpoint_name, **kwargs):
@@ -131,7 +129,9 @@ class SageMakerTransformSensor(SageMakerBaseSensor):
     :param job_name: Name of the transform job to watch.
     """

-    template_fields: Sequence[str] = ("job_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+    )
     template_ext: Sequence[str] = ()

     def __init__(self, *, job_name: str, **kwargs):
@@ -166,7 +166,9 @@ class SageMakerTuningSensor(SageMakerBaseSensor):
     :param job_name: Name of the tuning instance to watch.
     """

-    template_fields: Sequence[str] = ("job_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+    )
     template_ext: Sequence[str] = ()

     def __init__(self, *, job_name: str, **kwargs):
@@ -202,7 +204,9 @@ class SageMakerTrainingSensor(SageMakerBaseSensor):
     :param print_log: Prints the cloudwatch log if True; Defaults to True.
     """

-    template_fields: Sequence[str] = ("job_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+    )
     template_ext: Sequence[str] = ()

     def __init__(self, *, job_name, print_log=True, **kwargs):
@@ -281,7 +285,9 @@ class SageMakerPipelineSensor(SageMakerBaseSensor):
         Defaults to true, consider turning off for pipelines that have thousands of steps.
     """

-    template_fields: Sequence[str] = ("pipeline_exec_arn",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "pipeline_exec_arn",
+    )

     def __init__(self, *, pipeline_exec_arn: str, verbose: bool = True, **kwargs):
         super().__init__(resource_type="pipeline", **kwargs)
@@ -313,7 +319,9 @@ class SageMakerAutoMLSensor(SageMakerBaseSensor):
     :param job_name: unique name of the AutoML job to watch.
     """

-    template_fields: Sequence[str] = ("job_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+    )

     def __init__(self, *, job_name: str, **kwargs):
         super().__init__(resource_type="autoML job", **kwargs)
@@ -344,7 +352,9 @@ class SageMakerProcessingSensor(SageMakerBaseSensor):
     :param job_name: Name of the processing job to watch.
     """

-    template_fields: Sequence[str] = ("job_name",)
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+    )
     template_ext: Sequence[str] = ()

     def __init__(self, *, job_name: str, **kwargs):
airflow/providers/amazon/aws/transfers/sql_to_s3.py

@@ -18,9 +18,10 @@
 from __future__ import annotations

 import enum
+import gzip
+import io
 from collections import namedtuple
 from collections.abc import Iterable, Mapping, Sequence
-from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, cast

 from typing_extensions import Literal
@@ -191,16 +192,29 @@ class SqlToS3Operator(BaseOperator):
         self.log.info("Data from SQL obtained")
         self._fix_dtypes(data_df, self.file_format)
         file_options = FILE_OPTIONS_MAP[self.file_format]
+
         for group_name, df in self._partition_dataframe(df=data_df):
-            … (nine removed lines — the old NamedTemporaryFile upload path — are not legible in the source diff)
+            buf = io.BytesIO()
+            self.log.info("Writing data to in-memory buffer")
+            object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key
+
+            if self.pd_kwargs.get("compression") == "gzip":
+                pd_kwargs = {k: v for k, v in self.pd_kwargs.items() if k != "compression"}
+                with gzip.GzipFile(fileobj=buf, mode="wb", filename=object_key) as gz:
+                    getattr(df, file_options.function)(gz, **pd_kwargs)
+            else:
+                if self.file_format == FILE_FORMAT.PARQUET:
+                    getattr(df, file_options.function)(buf, **self.pd_kwargs)
+                else:
+                    text_buf = io.TextIOWrapper(buf, encoding="utf-8", write_through=True)
+                    getattr(df, file_options.function)(text_buf, **self.pd_kwargs)
+                    text_buf.flush()
+            buf.seek(0)
+
+            self.log.info("Uploading data to S3")
+            s3_conn.load_file_obj(
+                file_obj=buf, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
+            )

     def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
         """Partition dataframe using pandas groupby() method."""
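
The temp-file upload path is replaced by an in-memory BytesIO buffer, with a dedicated gzip branch. From the DAG side, gzip output is requested purely through `pd_kwargs`. A hedged usage sketch (connection ID, bucket, and query are placeholders):

from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

export_orders = SqlToS3Operator(
    task_id="export_orders",
    sql_conn_id="postgres_default",   # placeholder source connection
    query="SELECT * FROM orders",
    s3_bucket="my-bucket",
    s3_key="exports/orders.csv.gz",
    file_format="csv",
    pd_kwargs={"compression": "gzip", "index": False},  # routed to the GzipFile branch
    replace=True,
)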
{apache_airflow_providers_amazon-9.7.0rc2.dist-info → apache_airflow_providers_amazon-9.8.0.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: apache-airflow-providers-amazon
-Version: 9.7.0rc2
+Version: 9.8.0
 Summary: Provider package apache-airflow-providers-amazon for Apache Airflow
 Keywords: airflow-provider,amazon,airflow,integration
 Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -20,9 +20,9 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Topic :: System :: Monitoring
-Requires-Dist: apache-airflow>=2.10.…
-Requires-Dist: apache-airflow-providers-common-compat>=1.6.…
-Requires-Dist: apache-airflow-providers-common-sql>=1.27.…
+Requires-Dist: apache-airflow>=2.10.0
+Requires-Dist: apache-airflow-providers-common-compat>=1.6.1
+Requires-Dist: apache-airflow-providers-common-sql>=1.27.0
 Requires-Dist: apache-airflow-providers-http
 Requires-Dist: boto3>=1.37.0
 Requires-Dist: botocore>=1.37.0
@@ -40,8 +40,8 @@ Requires-Dist: sagemaker-studio>=1.0.9
 Requires-Dist: marshmallow>=3
 Requires-Dist: aiobotocore[boto3]>=2.21.1 ; extra == "aiobotocore"
 Requires-Dist: apache-airflow-providers-apache-hive ; extra == "apache-hive"
-Requires-Dist: apache-airflow-providers-cncf-kubernetes>=7.2.…
-Requires-Dist: apache-airflow-providers-common-messaging>=1.0.…
+Requires-Dist: apache-airflow-providers-cncf-kubernetes>=7.2.0 ; extra == "cncf-kubernetes"
+Requires-Dist: apache-airflow-providers-common-messaging>=1.0.1 ; extra == "common-messaging"
 Requires-Dist: apache-airflow-providers-exasol ; extra == "exasol"
 Requires-Dist: apache-airflow-providers-fab ; extra == "fab"
 Requires-Dist: apache-airflow-providers-ftp ; extra == "ftp"
@@ -49,15 +49,15 @@ Requires-Dist: apache-airflow-providers-google ; extra == "google"
 Requires-Dist: apache-airflow-providers-imap ; extra == "imap"
 Requires-Dist: apache-airflow-providers-microsoft-azure ; extra == "microsoft-azure"
 Requires-Dist: apache-airflow-providers-mongo ; extra == "mongo"
-Requires-Dist: apache-airflow-providers-openlineage>=2.3.…
+Requires-Dist: apache-airflow-providers-openlineage>=2.3.0 ; extra == "openlineage"
 Requires-Dist: python3-saml>=1.16.0 ; extra == "python3-saml"
 Requires-Dist: s3fs>=2023.10.0 ; extra == "s3fs"
 Requires-Dist: apache-airflow-providers-salesforce ; extra == "salesforce"
 Requires-Dist: apache-airflow-providers-ssh ; extra == "ssh"
 Requires-Dist: apache-airflow-providers-standard ; extra == "standard"
 Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
-Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.…
-Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.…
+Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.8.0/changelog.html
+Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.8.0
 Project-URL: Mastodon, https://fosstodon.org/@airflow
 Project-URL: Slack Chat, https://s.apache.org/airflow-slack
 Project-URL: Source Code, https://github.com/apache/airflow
@@ -105,7 +105,7 @@ Provides-Extra: standard

 Package ``apache-airflow-providers-amazon``

-Release: ``9.…
+Release: ``9.8.0``


 Amazon integration (including `Amazon Web Services (AWS) <https://aws.amazon.com/>`__).
@@ -118,7 +118,7 @@ This is a provider package for ``amazon`` provider. All classes for this provide
 are in ``airflow.providers.amazon`` python package.

 You can find package information and changelog for the provider
-in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.…
+in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.8.0/>`_.

 Installation
 ------------
@@ -189,5 +189,5 @@ Dependent package
 ======================================================================================================================== ====================

 The changelog for the provider package can be found in the
-`changelog <https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.…
+`changelog <https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.8.0/changelog.html>`_.

{apache_airflow_providers_amazon-9.7.0rc2.dist-info → apache_airflow_providers_amazon-9.8.0.dist-info}/RECORD

@@ -1,5 +1,5 @@
 airflow/providers/amazon/LICENSE,sha256=gXPVwptPlW1TJ4HSuG5OMPg-a3h43OGMkZRR1rpwfJA,10850
-airflow/providers/amazon/__init__.py,sha256=…
+airflow/providers/amazon/__init__.py,sha256=_CZ_wEd1ln7ijiK7yeMFFxlJFhkSQb_T-U1QhAjBiH8,1495
 airflow/providers/amazon/get_provider_info.py,sha256=iXOUQZQkWSX6JDGZnqaQp7B7RzYyyW0RoLgS8qXFRl0,68490
 airflow/providers/amazon/version_compat.py,sha256=j5PCtXvZ71aBjixu-EFTNtVDPsngzzs7os0ZQDgFVDk,1536
 airflow/providers/amazon/aws/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
@@ -136,7 +136,7 @@ airflow/providers/amazon/aws/operators/rds.py,sha256=cib-k2aHpa-DeYL92HqRMa5x9O8
 airflow/providers/amazon/aws/operators/redshift_cluster.py,sha256=dXakMyZV5vvEh0-20FOolMU5xEDugnMrvwwbLvuYc3o,37168
 airflow/providers/amazon/aws/operators/redshift_data.py,sha256=dcPEYGgXn9M1zS2XP7szm_kYO2xWD2CiUqsmlygf0_k,10854
 airflow/providers/amazon/aws/operators/s3.py,sha256=Imd3siCtmtaPWRmmSd382dJHhr49WRd-_aP6Tx5T7ac,38389
-airflow/providers/amazon/aws/operators/sagemaker.py,sha256=…
+airflow/providers/amazon/aws/operators/sagemaker.py,sha256=54RsCEJ735MsAmXh3Z0hWVQFOP4X98kSeeKKXdWijaU,91293
 airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py,sha256=J-huObn3pZ_fg0gEy-BLsX298CX_n7qWV2YwjfpFnrw,6867
 airflow/providers/amazon/aws/operators/sns.py,sha256=uVcSJBbqy7YCOeiCrMvFFn9F9xTzMRpfrEygqEIhWEM,3757
 airflow/providers/amazon/aws/operators/sqs.py,sha256=o9rH2Pm5DNmccLh5I2wr96hZiuxOPi6YGZ2QluOeVb0,4764
@@ -170,8 +170,8 @@ airflow/providers/amazon/aws/sensors/opensearch_serverless.py,sha256=cSaZvCvAC7z
 airflow/providers/amazon/aws/sensors/quicksight.py,sha256=lm1omzh01BKh0KHU3g2I1yH9LAXtddUDiuIS3uIeOrE,3575
 airflow/providers/amazon/aws/sensors/rds.py,sha256=HWYQOQ7n9s48Ci2WxBOtrAp17aB-at5werAljq3NDYE,7420
 airflow/providers/amazon/aws/sensors/redshift_cluster.py,sha256=8JxB23ifahmqrso6j8JPmiqYLcHZmFznE79aSeSHrJs,4086
-airflow/providers/amazon/aws/sensors/s3.py,sha256=…
-airflow/providers/amazon/aws/sensors/sagemaker.py,sha256=…
+airflow/providers/amazon/aws/sensors/s3.py,sha256=y118-NM72oMrUQhSO9shdemqIIgBv-qJJa8BDXqHzsA,17344
+airflow/providers/amazon/aws/sensors/sagemaker.py,sha256=dVQntJNRyUYCLQ7cIkeHesgZxf-1yS_BBAiVBzCwaHI,13795
 airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py,sha256=REPtB8BXcGu5xX-n-y_IAZXYpZ9d1nKiJAVqnMW3ayY,2865
 airflow/providers/amazon/aws/sensors/sqs.py,sha256=V3d05xb2VuxdWimpDVJy_SOKX7N0ok9TBbEYO-9o3v4,10672
 airflow/providers/amazon/aws/sensors/step_function.py,sha256=gaklKHdfmE-9avKSmyuGYvv9CuSklpjPz4KXZI8wXnY,3607
@@ -197,7 +197,7 @@ airflow/providers/amazon/aws/transfers/s3_to_sftp.py,sha256=gYT0iG9pbL36ObnjzKm_
 airflow/providers/amazon/aws/transfers/s3_to_sql.py,sha256=shkNpAoNgKJ3fBEBg7ZQaU8zMa5WjZuZV4eihxh2uLQ,4976
 airflow/providers/amazon/aws/transfers/salesforce_to_s3.py,sha256=WbCZUa9gfQB1SjDfUfPw5QO8lZ8Q-vSLriTnpXLhvxs,5713
 airflow/providers/amazon/aws/transfers/sftp_to_s3.py,sha256=-D5AR306Q8710e4dHao75CMGS7InHernCH_aZsE6Je4,4209
-airflow/providers/amazon/aws/transfers/sql_to_s3.py,sha256=…
+airflow/providers/amazon/aws/transfers/sql_to_s3.py,sha256=KbmMAVkIgC1jcOul70fzQnrx6V6vF7uXdO9I7JYA9oI,11118
 airflow/providers/amazon/aws/triggers/README.md,sha256=ax2F0w2CuQSDN4ghJADozrrv5W4OeCDPA8Vzp00BXOU,10919
 airflow/providers/amazon/aws/triggers/__init__.py,sha256=mlJxuZLkd5x-iq2SBwD3mvRQpt3YR7wjz_nceyF1IaI,787
 airflow/providers/amazon/aws/triggers/athena.py,sha256=62ty40zejcm5Y0d1rTQZuYzSjq3hUkmAs0d_zxM_Kjw,2596
@@ -269,7 +269,7 @@ airflow/providers/amazon/aws/waiters/rds.json,sha256=HNmNQm5J-VaFHzjWb1pE5P7-Ix-
 airflow/providers/amazon/aws/waiters/redshift.json,sha256=jOBotCgbkko1b_CHcGEbhhRvusgt0YSzVuFiZrqVP30,1742
 airflow/providers/amazon/aws/waiters/sagemaker.json,sha256=JPHuQtUFZ1B7EMLfVmCRevNZ9jgpB71LM0dva8ZEO9A,5254
 airflow/providers/amazon/aws/waiters/stepfunctions.json,sha256=GsOH-emGerKGBAUFmI5lpMfNGH4c0ol_PSiea25DCEY,1033
-apache_airflow_providers_amazon-9.7.0rc2.dist-info/entry_points.txt,sha256=…
-apache_airflow_providers_amazon-9.7.0rc2.dist-info/WHEEL,sha256=…
-apache_airflow_providers_amazon-9.7.0rc2.dist-info/METADATA,sha256=…
-apache_airflow_providers_amazon-9.7.0rc2.dist-info/RECORD,,
+apache_airflow_providers_amazon-9.8.0.dist-info/entry_points.txt,sha256=vlc0ZzhBkMrav1maTRofgksnAw4SwoQLFX9cmnTgktk,102
+apache_airflow_providers_amazon-9.8.0.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+apache_airflow_providers_amazon-9.8.0.dist-info/METADATA,sha256=7ofNSfONmOGy5QURfFZDNobHBnDnvz9skVVmtv_v0RY,10185
+apache_airflow_providers_amazon-9.8.0.dist-info/RECORD,,

WHEEL and entry_points.txt: file without changes (only the dist-info directory was renamed).
|