apache-airflow-providers-amazon 9.7.0rc2__py3-none-any.whl → 9.8.0rc1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
  __all__ = ["__version__"]
 
- __version__ = "9.7.0"
+ __version__ = "9.8.0"
 
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
      "2.10.0"
@@ -21,14 +21,12 @@ import json
  import time
  import urllib
  from collections.abc import Sequence
- from functools import cached_property
  from typing import TYPE_CHECKING, Any, Callable, ClassVar
 
  from botocore.exceptions import ClientError
 
  from airflow.configuration import conf
  from airflow.exceptions import AirflowException
- from airflow.models import BaseOperator
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
  from airflow.providers.amazon.aws.hooks.sagemaker import (
      LogState,
@@ -36,11 +34,13 @@ from airflow.providers.amazon.aws.hooks.sagemaker import (
      secondary_training_status_message,
  )
  from airflow.providers.amazon.aws.links.sagemaker import SageMakerTransformJobLink
+ from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
  from airflow.providers.amazon.aws.triggers.sagemaker import (
      SageMakerPipelineTrigger,
      SageMakerTrigger,
  )
  from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
+ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
  from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
  from airflow.providers.amazon.aws.utils.tags import format_tags
  from airflow.utils.helpers import prune_dict
@@ -50,7 +50,6 @@ if TYPE_CHECKING:
      from airflow.providers.openlineage.extractors.base import OperatorLineage
      from airflow.utils.context import Context
 
- DEFAULT_CONN_ID: str = "aws_default"
  CHECK_INTERVAL_SECOND: int = 30
 
 
@@ -58,23 +58,33 @@ def serialize(result: dict) -> dict:
      return json.loads(json.dumps(result, default=repr))
 
 
- class SageMakerBaseOperator(BaseOperator):
+ class SageMakerBaseOperator(AwsBaseOperator[SageMakerHook]):
      """
      This is the base operator for all SageMaker operators.
 
+     :param aws_conn_id: The Airflow connection used for AWS credentials.
+         If this is ``None`` or empty then the default boto3 behaviour is used. If
+         running Airflow in a distributed manner and aws_conn_id is None or
+         empty, then default boto3 configuration would be used (and must be
+         maintained on each worker node).
+     :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+     :param verify: Whether or not to verify SSL certificates. See:
+         https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
      :param config: The configuration necessary to start a training job (templated)
      """
 
-     template_fields: Sequence[str] = ("config",)
+     aws_hook_class = SageMakerHook
+     template_fields: Sequence[str] = aws_template_fields(
+         "config",
+     )
      template_ext: Sequence[str] = ()
      template_fields_renderers: ClassVar[dict] = {"config": "json"}
      ui_color: str = "#ededed"
      integer_fields: list[list[Any]] = []
 
-     def __init__(self, *, config: dict, aws_conn_id: str | None = DEFAULT_CONN_ID, **kwargs):
+     def __init__(self, *, config: dict, **kwargs):
          super().__init__(**kwargs)
          self.config = config
-         self.aws_conn_id = aws_conn_id
 
      def parse_integer(self, config: dict, field: list[str] | str) -> None:
          """Recursive method for parsing string fields holding integer values to integers."""
@@ -199,11 +209,6 @@ class SageMakerBaseOperator(BaseOperator):
      def execute(self, context: Context):
          raise NotImplementedError("Please implement execute() in sub class!")
 
-     @cached_property
-     def hook(self):
-         """Return SageMakerHook."""
-         return SageMakerHook(aws_conn_id=self.aws_conn_id)
-
      @staticmethod
      def path_to_s3_dataset(path) -> Dataset:
          from airflow.providers.common.compat.openlineage.facet import Dataset
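The two hunks above move SageMakerBaseOperator from BaseOperator onto AwsBaseOperator[SageMakerHook]: the shared base class now supplies the aws_conn_id/region_name/verify parameters and the cached hook, so subclasses only declare their aws_hook_class and template fields. A minimal sketch of the resulting subclass pattern; the operator name and the describe_model call are illustrative, not part of this package:

```python
from __future__ import annotations

from collections.abc import Sequence

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields


class DescribeSageMakerModelOperator(AwsBaseOperator[SageMakerHook]):
    """Hypothetical operator: describe a SageMaker model by name."""

    # the base class builds self.hook from this class attribute
    aws_hook_class = SageMakerHook
    # aws_template_fields() adds aws_conn_id/region_name/verify to the listed fields
    template_fields: Sequence[str] = aws_template_fields("model_name")

    def __init__(self, *, model_name: str, **kwargs):
        # aws_conn_id, region_name and verify are consumed by AwsBaseOperator
        super().__init__(**kwargs)
        self.model_name = model_name

    def execute(self, context):
        # self.hook is the SageMakerHook created by the base class
        return self.hook.conn.describe_model(ModelName=self.model_name)
```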
@@ -227,7 +232,14 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
227
232
 
228
233
  :param config: The configuration necessary to start a processing job (templated).
229
234
  For details of the configuration parameter see :py:meth:`SageMaker.Client.create_processing_job`
230
- :param aws_conn_id: The AWS connection ID to use.
235
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
236
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
237
+ running Airflow in a distributed manner and aws_conn_id is None or
238
+ empty, then default boto3 configuration would be used (and must be
239
+ maintained on each worker node).
240
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
241
+ :param verify: Whether or not to verify SSL certificates. See:
242
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
231
243
  :param wait_for_completion: If wait is set to True, the time interval, in seconds,
232
244
  that the operation waits to check the status of the processing job.
233
245
  :param print_log: if the operator should print the cloudwatch log during processing
@@ -249,7 +261,6 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
249
261
  self,
250
262
  *,
251
263
  config: dict,
252
- aws_conn_id: str | None = DEFAULT_CONN_ID,
253
264
  wait_for_completion: bool = True,
254
265
  print_log: bool = True,
255
266
  check_interval: int = CHECK_INTERVAL_SECOND,
@@ -259,7 +270,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
259
270
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
260
271
  **kwargs,
261
272
  ):
262
- super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
273
+ super().__init__(config=config, **kwargs)
263
274
  if action_if_job_exists not in ("fail", "timestamp"):
264
275
  raise AirflowException(
265
276
  f"Argument action_if_job_exists accepts only 'timestamp' and 'fail'. \
@@ -403,7 +414,14 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
403
414
  :param config: The configuration necessary to create an endpoint config.
404
415
 
405
416
  For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config`
406
- :param aws_conn_id: The AWS connection ID to use.
417
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
418
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
419
+ running Airflow in a distributed manner and aws_conn_id is None or
420
+ empty, then default boto3 configuration would be used (and must be
421
+ maintained on each worker node).
422
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
423
+ :param verify: Whether or not to verify SSL certificates. See:
424
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
407
425
  :return Dict: Returns The ARN of the endpoint config created in Amazon SageMaker.
408
426
  """
409
427
 
@@ -411,10 +429,9 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
411
429
  self,
412
430
  *,
413
431
  config: dict,
414
- aws_conn_id: str | None = DEFAULT_CONN_ID,
415
432
  **kwargs,
416
433
  ):
417
- super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
434
+ super().__init__(config=config, **kwargs)
418
435
 
419
436
  def _create_integer_fields(self) -> None:
420
437
  """Set fields which should be cast to integers."""
@@ -476,7 +493,14 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
476
493
  :param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
477
494
  finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
478
495
  :param operation: Whether to create an endpoint or update an endpoint. Must be either 'create or 'update'.
479
- :param aws_conn_id: The AWS connection ID to use.
496
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
497
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
498
+ running Airflow in a distributed manner and aws_conn_id is None or
499
+ empty, then default boto3 configuration would be used (and must be
500
+ maintained on each worker node).
501
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
502
+ :param verify: Whether or not to verify SSL certificates. See:
503
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
480
504
  :param deferrable: Will wait asynchronously for completion.
481
505
  :return Dict: Returns The ARN of the endpoint created in Amazon SageMaker.
482
506
  """
@@ -485,7 +509,6 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
485
509
  self,
486
510
  *,
487
511
  config: dict,
488
- aws_conn_id: str | None = DEFAULT_CONN_ID,
489
512
  wait_for_completion: bool = True,
490
513
  check_interval: int = CHECK_INTERVAL_SECOND,
491
514
  max_ingestion_time: int | None = None,
@@ -493,7 +516,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
493
516
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
494
517
  **kwargs,
495
518
  ):
496
- super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
519
+ super().__init__(config=config, **kwargs)
497
520
  self.wait_for_completion = wait_for_completion
498
521
  self.check_interval = check_interval
499
522
  self.max_ingestion_time = max_ingestion_time or 3600 * 10
@@ -634,7 +657,14 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
634
657
  For details of the configuration parameter of model_config, See:
635
658
  :py:meth:`SageMaker.Client.create_model`
636
659
 
637
- :param aws_conn_id: The AWS connection ID to use.
660
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
661
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
662
+ running Airflow in a distributed manner and aws_conn_id is None or
663
+ empty, then default boto3 configuration would be used (and must be
664
+ maintained on each worker node).
665
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
666
+ :param verify: Whether or not to verify SSL certificates. See:
667
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
638
668
  :param wait_for_completion: Set to True to wait until the transform job finishes.
639
669
  :param check_interval: If wait is set to True, the time interval, in seconds,
640
670
  that this operation waits to check the status of the transform job.
@@ -657,7 +687,6 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
657
687
  self,
658
688
  *,
659
689
  config: dict,
660
- aws_conn_id: str | None = DEFAULT_CONN_ID,
661
690
  wait_for_completion: bool = True,
662
691
  check_interval: int = CHECK_INTERVAL_SECOND,
663
692
  max_attempts: int | None = None,
@@ -669,7 +698,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
669
698
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
670
699
  **kwargs,
671
700
  ):
672
- super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
701
+ super().__init__(config=config, **kwargs)
673
702
  self.wait_for_completion = wait_for_completion
674
703
  self.check_interval = check_interval
675
704
  self.max_attempts = max_attempts or 60
@@ -898,7 +927,14 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
898
927
 
899
928
  For details of the configuration parameter see
900
929
  :py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
901
- :param aws_conn_id: The AWS connection ID to use.
930
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
931
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
932
+ running Airflow in a distributed manner and aws_conn_id is None or
933
+ empty, then default boto3 configuration would be used (and must be
934
+ maintained on each worker node).
935
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
936
+ :param verify: Whether or not to verify SSL certificates. See:
937
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
902
938
  :param wait_for_completion: Set to True to wait until the tuning job finishes.
903
939
  :param check_interval: If wait is set to True, the time interval, in seconds,
904
940
  that this operation waits to check the status of the tuning job.
@@ -913,14 +949,13 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
913
949
  self,
914
950
  *,
915
951
  config: dict,
916
- aws_conn_id: str | None = DEFAULT_CONN_ID,
917
952
  wait_for_completion: bool = True,
918
953
  check_interval: int = CHECK_INTERVAL_SECOND,
919
954
  max_ingestion_time: int | None = None,
920
955
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
921
956
  **kwargs,
922
957
  ):
923
- super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
958
+ super().__init__(config=config, **kwargs)
924
959
  self.wait_for_completion = wait_for_completion
925
960
  self.check_interval = check_interval
926
961
  self.max_ingestion_time = max_ingestion_time
@@ -1016,8 +1051,8 @@ class SageMakerModelOperator(SageMakerBaseOperator):
1016
1051
  :return Dict: Returns The ARN of the model created in Amazon SageMaker.
1017
1052
  """
1018
1053
 
1019
- def __init__(self, *, config: dict, aws_conn_id: str | None = DEFAULT_CONN_ID, **kwargs):
1020
- super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
1054
+ def __init__(self, *, config: dict, **kwargs):
1055
+ super().__init__(config=config, **kwargs)
1021
1056
 
1022
1057
  def expand_role(self) -> None:
1023
1058
  """Expand an IAM role name into an ARN."""
@@ -1048,7 +1083,14 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
1048
1083
  :param config: The configuration necessary to start a training job (templated).
1049
1084
 
1050
1085
  For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job`
1051
- :param aws_conn_id: The AWS connection ID to use.
1086
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
1087
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
1088
+ running Airflow in a distributed manner and aws_conn_id is None or
1089
+ empty, then default boto3 configuration would be used (and must be
1090
+ maintained on each worker node).
1091
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
1092
+ :param verify: Whether or not to verify SSL certificates. See:
1093
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
1052
1094
  :param wait_for_completion: If wait is set to True, the time interval, in seconds,
1053
1095
  that the operation waits to check the status of the training job.
1054
1096
  :param print_log: if the operator should print the cloudwatch log during training
@@ -1073,7 +1115,6 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
1073
1115
  self,
1074
1116
  *,
1075
1117
  config: dict,
1076
- aws_conn_id: str | None = DEFAULT_CONN_ID,
1077
1118
  wait_for_completion: bool = True,
1078
1119
  print_log: bool = True,
1079
1120
  check_interval: int = CHECK_INTERVAL_SECOND,
@@ -1084,7 +1125,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
1084
1125
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
1085
1126
  **kwargs,
1086
1127
  ):
1087
- super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
1128
+ super().__init__(config=config, **kwargs)
1088
1129
  self.wait_for_completion = wait_for_completion
1089
1130
  self.print_log = print_log
1090
1131
  self.check_interval = check_interval
@@ -1243,15 +1284,21 @@ class SageMakerDeleteModelOperator(SageMakerBaseOperator):
 
      :param config: The configuration necessary to delete the model.
          For details of the configuration parameter see :py:meth:`SageMaker.Client.delete_model`
-     :param aws_conn_id: The AWS connection ID to use.
+     :param aws_conn_id: The Airflow connection used for AWS credentials.
+         If this is ``None`` or empty then the default boto3 behaviour is used. If
+         running Airflow in a distributed manner and aws_conn_id is None or
+         empty, then default boto3 configuration would be used (and must be
+         maintained on each worker node).
+     :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+     :param verify: Whether or not to verify SSL certificates. See:
+         https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
      """
 
-     def __init__(self, *, config: dict, aws_conn_id: str | None = DEFAULT_CONN_ID, **kwargs):
-         super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+     def __init__(self, *, config: dict, **kwargs):
+         super().__init__(config=config, **kwargs)
 
      def execute(self, context: Context) -> Any:
-         sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
-         sagemaker_hook.delete_model(model_name=self.config["ModelName"])
+         self.hook.delete_model(model_name=self.config["ModelName"])
          self.log.info("Model %s deleted successfully.", self.config["ModelName"])
 
 
@@ -1264,7 +1312,14 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
1264
1312
  :ref:`howto/operator:SageMakerStartPipelineOperator`
1265
1313
 
1266
1314
  :param config: The configuration to start the pipeline execution.
1267
- :param aws_conn_id: The AWS connection ID to use.
1315
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
1316
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
1317
+ running Airflow in a distributed manner and aws_conn_id is None or
1318
+ empty, then default boto3 configuration would be used (and must be
1319
+ maintained on each worker node).
1320
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
1321
+ :param verify: Whether or not to verify SSL certificates. See:
1322
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
1268
1323
  :param pipeline_name: Name of the pipeline to start.
1269
1324
  :param display_name: The name this pipeline execution will have in the UI. Doesn't need to be unique.
1270
1325
  :param pipeline_params: Optional parameters for the pipeline.
@@ -1279,8 +1334,7 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
1279
1334
  :return str: Returns The ARN of the pipeline execution created in Amazon SageMaker.
1280
1335
  """
1281
1336
 
1282
- template_fields: Sequence[str] = (
1283
- "aws_conn_id",
1337
+ template_fields: Sequence[str] = aws_template_fields(
1284
1338
  "pipeline_name",
1285
1339
  "display_name",
1286
1340
  "pipeline_params",
@@ -1289,7 +1343,6 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
1289
1343
  def __init__(
1290
1344
  self,
1291
1345
  *,
1292
- aws_conn_id: str | None = DEFAULT_CONN_ID,
1293
1346
  pipeline_name: str,
1294
1347
  display_name: str = "airflow-triggered-execution",
1295
1348
  pipeline_params: dict | None = None,
@@ -1300,7 +1353,7 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
1300
1353
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
1301
1354
  **kwargs,
1302
1355
  ):
1303
- super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
1356
+ super().__init__(config={}, **kwargs)
1304
1357
  self.pipeline_name = pipeline_name
1305
1358
  self.display_name = display_name
1306
1359
  self.pipeline_params = pipeline_params
@@ -1358,7 +1411,14 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
1358
1411
  :ref:`howto/operator:SageMakerStopPipelineOperator`
1359
1412
 
1360
1413
  :param config: The configuration to start the pipeline execution.
1361
- :param aws_conn_id: The AWS connection ID to use.
1414
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
1415
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
1416
+ running Airflow in a distributed manner and aws_conn_id is None or
1417
+ empty, then default boto3 configuration would be used (and must be
1418
+ maintained on each worker node).
1419
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
1420
+ :param verify: Whether or not to verify SSL certificates. See:
1421
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
1362
1422
  :param pipeline_exec_arn: Amazon Resource Name of the pipeline execution to stop.
1363
1423
  :param wait_for_completion: If true, this operator will only complete once the pipeline is fully stopped.
1364
1424
  :param check_interval: How long to wait between checks for pipeline status when waiting for completion.
@@ -1370,15 +1430,13 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
1370
1430
  :return str: Returns the status of the pipeline execution after the operation has been done.
1371
1431
  """
1372
1432
 
1373
- template_fields: Sequence[str] = (
1374
- "aws_conn_id",
1433
+ template_fields: Sequence[str] = aws_template_fields(
1375
1434
  "pipeline_exec_arn",
1376
1435
  )
1377
1436
 
1378
1437
  def __init__(
1379
1438
  self,
1380
1439
  *,
1381
- aws_conn_id: str | None = DEFAULT_CONN_ID,
1382
1440
  pipeline_exec_arn: str,
1383
1441
  wait_for_completion: bool = False,
1384
1442
  check_interval: int = CHECK_INTERVAL_SECOND,
@@ -1388,7 +1446,7 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
1388
1446
  deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
1389
1447
  **kwargs,
1390
1448
  ):
1391
- super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
1449
+ super().__init__(config={}, **kwargs)
1392
1450
  self.pipeline_exec_arn = pipeline_exec_arn
1393
1451
  self.wait_for_completion = wait_for_completion
1394
1452
  self.check_interval = check_interval
@@ -1469,11 +1527,18 @@ class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
1469
1527
  :param extras: Can contain extra parameters for the boto call to create_model_package, and/or overrides
1470
1528
  for any parameter defined above. For a complete list of available parameters, see
1471
1529
  https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model_package
1472
-
1530
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
1531
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
1532
+ running Airflow in a distributed manner and aws_conn_id is None or
1533
+ empty, then default boto3 configuration would be used (and must be
1534
+ maintained on each worker node).
1535
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
1536
+ :param verify: Whether or not to verify SSL certificates. See:
1537
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
1473
1538
  :return str: Returns the ARN of the model package created.
1474
1539
  """
1475
1540
 
1476
- template_fields: Sequence[str] = (
1541
+ template_fields: Sequence[str] = aws_template_fields(
1477
1542
  "image_uri",
1478
1543
  "model_url",
1479
1544
  "package_group_name",
@@ -1492,11 +1557,10 @@ class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
1492
1557
  package_desc: str = "",
1493
1558
  model_approval: ApprovalStatus = ApprovalStatus.PENDING_MANUAL_APPROVAL,
1494
1559
  extras: dict | None = None,
1495
- aws_conn_id: str | None = DEFAULT_CONN_ID,
1496
1560
  config: dict | None = None,
1497
1561
  **kwargs,
1498
1562
  ):
1499
- super().__init__(config=config or {}, aws_conn_id=aws_conn_id, **kwargs)
1563
+ super().__init__(config=config or {}, **kwargs)
1500
1564
  self.image_uri = image_uri
1501
1565
  self.model_url = model_url
1502
1566
  self.package_group_name = package_group_name
@@ -1563,13 +1627,20 @@ class SageMakerAutoMLOperator(SageMakerBaseOperator):
1563
1627
  https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_auto_ml_job
1564
1628
  :param wait_for_completion: Whether to wait for the job to finish before returning. Defaults to True.
1565
1629
  :param check_interval: Interval in seconds between 2 status checks when waiting for completion.
1566
-
1630
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
1631
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
1632
+ running Airflow in a distributed manner and aws_conn_id is None or
1633
+ empty, then default boto3 configuration would be used (and must be
1634
+ maintained on each worker node).
1635
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
1636
+ :param verify: Whether or not to verify SSL certificates. See:
1637
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
1567
1638
  :returns: Only if waiting for completion, a dictionary detailing the best model. The structure is that of
1568
1639
  the "BestCandidate" key in:
1569
1640
  https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
1570
1641
  """
1571
1642
 
1572
- template_fields: Sequence[str] = (
1643
+ template_fields: Sequence[str] = aws_template_fields(
1573
1644
  "job_name",
1574
1645
  "s3_input",
1575
1646
  "target_attribute",
@@ -1595,11 +1666,10 @@ class SageMakerAutoMLOperator(SageMakerBaseOperator):
1595
1666
  extras: dict | None = None,
1596
1667
  wait_for_completion: bool = True,
1597
1668
  check_interval: int = 30,
1598
- aws_conn_id: str | None = DEFAULT_CONN_ID,
1599
1669
  config: dict | None = None,
1600
1670
  **kwargs,
1601
1671
  ):
1602
- super().__init__(config=config or {}, aws_conn_id=aws_conn_id, **kwargs)
1672
+ super().__init__(config=config or {}, **kwargs)
1603
1673
  self.job_name = job_name
1604
1674
  self.s3_input = s3_input
1605
1675
  self.target_attribute = target_attribute
@@ -1640,12 +1710,19 @@ class SageMakerCreateExperimentOperator(SageMakerBaseOperator):
1640
1710
  :param name: name of the experiment, must be unique within the AWS account
1641
1711
  :param description: description of the experiment, optional
1642
1712
  :param tags: tags to attach to the experiment, optional
1643
- :param aws_conn_id: The AWS connection ID to use.
1713
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
1714
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
1715
+ running Airflow in a distributed manner and aws_conn_id is None or
1716
+ empty, then default boto3 configuration would be used (and must be
1717
+ maintained on each worker node).
1718
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
1719
+ :param verify: Whether or not to verify SSL certificates. See:
1720
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
1644
1721
 
1645
1722
  :returns: the ARN of the experiment created, though experiments are referred to by name
1646
1723
  """
1647
1724
 
1648
- template_fields: Sequence[str] = (
1725
+ template_fields: Sequence[str] = aws_template_fields(
1649
1726
  "name",
1650
1727
  "description",
1651
1728
  "tags",
@@ -1657,28 +1734,26 @@ class SageMakerCreateExperimentOperator(SageMakerBaseOperator):
1657
1734
  name: str,
1658
1735
  description: str | None = None,
1659
1736
  tags: dict | None = None,
1660
- aws_conn_id: str | None = DEFAULT_CONN_ID,
1661
1737
  **kwargs,
1662
1738
  ):
1663
- super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
1739
+ super().__init__(config={}, **kwargs)
1664
1740
  self.name = name
1665
1741
  self.description = description
1666
1742
  self.tags = tags or {}
1667
1743
 
1668
1744
  def execute(self, context: Context) -> str:
1669
- sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
1670
1745
  params = {
1671
1746
  "ExperimentName": self.name,
1672
1747
  "Description": self.description,
1673
1748
  "Tags": format_tags(self.tags),
1674
1749
  }
1675
- ans = sagemaker_hook.conn.create_experiment(**trim_none_values(params))
1750
+ ans = self.hook.conn.create_experiment(**trim_none_values(params))
1676
1751
  arn = ans["ExperimentArn"]
1677
1752
  self.log.info("Experiment %s created successfully with ARN %s.", self.name, arn)
1678
1753
  return arn
1679
1754
 
1680
1755
 
1681
- class SageMakerCreateNotebookOperator(BaseOperator):
1756
+ class SageMakerCreateNotebookOperator(AwsBaseOperator[SageMakerHook]):
1682
1757
  """
1683
1758
  Create a SageMaker notebook.
1684
1759
 
@@ -1699,12 +1774,19 @@ class SageMakerCreateNotebookOperator(BaseOperator):
1699
1774
  :param root_access: Whether to give the notebook instance root access to the Amazon S3 bucket.
1700
1775
  :param wait_for_completion: Whether or not to wait for the notebook to be InService before returning
1701
1776
  :param create_instance_kwargs: Additional configuration options for the create call.
1702
- :param aws_conn_id: The AWS connection ID to use.
1703
-
1777
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
1778
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
1779
+ running Airflow in a distributed manner and aws_conn_id is None or
1780
+ empty, then default boto3 configuration would be used (and must be
1781
+ maintained on each worker node).
1782
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
1783
+ :param verify: Whether or not to verify SSL certificates. See:
1784
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
1704
1785
  :return: The ARN of the created notebook.
1705
1786
  """
1706
1787
 
1707
- template_fields: Sequence[str] = (
1788
+ aws_hook_class = SageMakerHook
1789
+ template_fields: Sequence[str] = aws_template_fields(
1708
1790
  "instance_name",
1709
1791
  "instance_type",
1710
1792
  "role_arn",
@@ -1732,7 +1814,6 @@ class SageMakerCreateNotebookOperator(BaseOperator):
1732
1814
  root_access: str | None = None,
1733
1815
  create_instance_kwargs: dict[str, Any] | None = None,
1734
1816
  wait_for_completion: bool = True,
1735
- aws_conn_id: str | None = "aws_default",
1736
1817
  **kwargs,
1737
1818
  ):
1738
1819
  super().__init__(**kwargs)
@@ -1745,17 +1826,11 @@ class SageMakerCreateNotebookOperator(BaseOperator):
1745
1826
  self.direct_internet_access = direct_internet_access
1746
1827
  self.root_access = root_access
1747
1828
  self.wait_for_completion = wait_for_completion
1748
- self.aws_conn_id = aws_conn_id
1749
1829
  self.create_instance_kwargs = create_instance_kwargs or {}
1750
1830
 
1751
1831
  if self.create_instance_kwargs.get("tags") is not None:
1752
1832
  self.create_instance_kwargs["tags"] = format_tags(self.create_instance_kwargs["tags"])
1753
1833
 
1754
- @cached_property
1755
- def hook(self) -> SageMakerHook:
1756
- """Create and return SageMakerHook."""
1757
- return SageMakerHook(aws_conn_id=self.aws_conn_id)
1758
-
1759
1834
  def execute(self, context: Context):
1760
1835
  create_notebook_instance_kwargs = {
1761
1836
  "NotebookInstanceName": self.instance_name,
@@ -1783,7 +1858,7 @@ class SageMakerCreateNotebookOperator(BaseOperator):
1783
1858
  return response["NotebookInstanceArn"]
1784
1859
 
1785
1860
 
1786
- class SageMakerStopNotebookOperator(BaseOperator):
1861
+ class SageMakerStopNotebookOperator(AwsBaseOperator[SageMakerHook]):
1787
1862
  """
1788
1863
  Stop a notebook instance.
1789
1864
 
@@ -1793,10 +1868,18 @@ class SageMakerStopNotebookOperator(BaseOperator):
1793
1868
 
1794
1869
  :param instance_name: The name of the notebook instance to stop.
1795
1870
  :param wait_for_completion: Whether or not to wait for the notebook to be stopped before returning
1796
- :param aws_conn_id: The AWS connection ID to use.
1871
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
1872
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
1873
+ running Airflow in a distributed manner and aws_conn_id is None or
1874
+ empty, then default boto3 configuration would be used (and must be
1875
+ maintained on each worker node).
1876
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
1877
+ :param verify: Whether or not to verify SSL certificates. See:
1878
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
1797
1879
  """
1798
1880
 
1799
- template_fields: Sequence[str] = ("instance_name", "wait_for_completion")
1881
+ aws_hook_class = SageMakerHook
1882
+ template_fields: Sequence[str] = aws_template_fields("instance_name", "wait_for_completion")
1800
1883
 
1801
1884
  ui_color = "#ff7300"
1802
1885
 
@@ -1804,18 +1887,11 @@ class SageMakerStopNotebookOperator(BaseOperator):
1804
1887
  self,
1805
1888
  instance_name: str,
1806
1889
  wait_for_completion: bool = True,
1807
- aws_conn_id: str | None = "aws_default",
1808
1890
  **kwargs,
1809
1891
  ):
1810
1892
  super().__init__(**kwargs)
1811
1893
  self.instance_name = instance_name
1812
1894
  self.wait_for_completion = wait_for_completion
1813
- self.aws_conn_id = aws_conn_id
1814
-
1815
- @cached_property
1816
- def hook(self) -> SageMakerHook:
1817
- """Create and return SageMakerHook."""
1818
- return SageMakerHook(aws_conn_id=self.aws_conn_id)
1819
1895
 
1820
1896
  def execute(self, context):
1821
1897
  self.log.info("Stopping SageMaker notebook %s.", self.instance_name)
@@ -1828,7 +1904,7 @@ class SageMakerStopNotebookOperator(BaseOperator):
1828
1904
  )
1829
1905
 
1830
1906
 
1831
- class SageMakerDeleteNotebookOperator(BaseOperator):
1907
+ class SageMakerDeleteNotebookOperator(AwsBaseOperator[SageMakerHook]):
1832
1908
  """
1833
1909
  Delete a notebook instance.
1834
1910
 
@@ -1838,30 +1914,30 @@ class SageMakerDeleteNotebookOperator(BaseOperator):
1838
1914
 
1839
1915
  :param instance_name: The name of the notebook instance to delete.
1840
1916
  :param wait_for_completion: Whether or not to wait for the notebook to delete before returning.
1841
- :param aws_conn_id: The AWS connection ID to use.
1917
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
1918
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
1919
+ running Airflow in a distributed manner and aws_conn_id is None or
1920
+ empty, then default boto3 configuration would be used (and must be
1921
+ maintained on each worker node).
1922
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
1923
+ :param verify: Whether or not to verify SSL certificates. See:
1924
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
1842
1925
  """
1843
1926
 
1844
- template_fields: Sequence[str] = ("instance_name", "wait_for_completion")
1845
-
1927
+ template_fields: Sequence[str] = aws_template_fields("instance_name", "wait_for_completion")
1928
+ aws_hook_class = SageMakerHook
1846
1929
  ui_color = "#ff7300"
1847
1930
 
1848
1931
  def __init__(
1849
1932
  self,
1850
1933
  instance_name: str,
1851
1934
  wait_for_completion: bool = True,
1852
- aws_conn_id: str | None = "aws_default",
1853
1935
  **kwargs,
1854
1936
  ):
1855
1937
  super().__init__(**kwargs)
1856
1938
  self.instance_name = instance_name
1857
- self.aws_conn_id = aws_conn_id
1858
1939
  self.wait_for_completion = wait_for_completion
1859
1940
 
1860
- @cached_property
1861
- def hook(self) -> SageMakerHook:
1862
- """Create and return SageMakerHook."""
1863
- return SageMakerHook(aws_conn_id=self.aws_conn_id)
1864
-
1865
1941
  def execute(self, context):
1866
1942
  self.log.info("Deleting SageMaker notebook %s....", self.instance_name)
1867
1943
  self.hook.conn.delete_notebook_instance(NotebookInstanceName=self.instance_name)
@@ -1873,7 +1949,7 @@ class SageMakerDeleteNotebookOperator(BaseOperator):
1873
1949
  )
1874
1950
 
1875
1951
 
1876
- class SageMakerStartNoteBookOperator(BaseOperator):
1952
+ class SageMakerStartNoteBookOperator(AwsBaseOperator[SageMakerHook]):
1877
1953
  """
1878
1954
  Start a notebook instance.
1879
1955
 
@@ -1883,10 +1959,18 @@ class SageMakerStartNoteBookOperator(BaseOperator):
1883
1959
 
1884
1960
  :param instance_name: The name of the notebook instance to start.
1885
1961
  :param wait_for_completion: Whether or not to wait for notebook to be InService before returning
1886
- :param aws_conn_id: The AWS connection ID to use.
1962
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
1963
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
1964
+ running Airflow in a distributed manner and aws_conn_id is None or
1965
+ empty, then default boto3 configuration would be used (and must be
1966
+ maintained on each worker node).
1967
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
1968
+ :param verify: Whether or not to verify SSL certificates. See:
1969
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
1887
1970
  """
1888
1971
 
1889
- template_fields: Sequence[str] = ("instance_name", "wait_for_completion")
1972
+ template_fields: Sequence[str] = aws_template_fields("instance_name", "wait_for_completion")
1973
+ aws_hook_class = SageMakerHook
1890
1974
 
1891
1975
  ui_color = "#ff7300"
1892
1976
 
@@ -1894,19 +1978,12 @@ class SageMakerStartNoteBookOperator(BaseOperator):
1894
1978
  self,
1895
1979
  instance_name: str,
1896
1980
  wait_for_completion: bool = True,
1897
- aws_conn_id: str | None = "aws_default",
1898
1981
  **kwargs,
1899
1982
  ):
1900
1983
  super().__init__(**kwargs)
1901
1984
  self.instance_name = instance_name
1902
- self.aws_conn_id = aws_conn_id
1903
1985
  self.wait_for_completion = wait_for_completion
1904
1986
 
1905
- @cached_property
1906
- def hook(self) -> SageMakerHook:
1907
- """Create and return SageMakerHook."""
1908
- return SageMakerHook(aws_conn_id=self.aws_conn_id)
1909
-
1910
1987
  def execute(self, context):
1911
1988
  self.log.info("Starting SageMaker notebook %s....", self.instance_name)
1912
1989
  self.hook.conn.start_notebook_instance(NotebookInstanceName=self.instance_name)
@@ -107,7 +107,7 @@ class S3KeySensor(AwsBaseSensor[S3Hook]):
          self.verify = verify
          self.deferrable = deferrable
          self.use_regex = use_regex
-         self.metadata_keys = metadata_keys if metadata_keys else ["Size"]
+         self.metadata_keys = metadata_keys if metadata_keys else ["Size", "Key"]
 
      def _check_key(self, key, context: Context):
          bucket_name, key = self.hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
@@ -116,7 +116,8 @@
          """
          Set variable `files` which contains a list of dict which contains attributes defined by the user
          Format: [{
-             'Size': int
+             'Size': int,
+             'Key': str,
          }]
          """
          if self.wildcard_match:
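With the new default, S3KeySensor reports both Size and Key for every matched object, so a check_fn can filter on the key name without passing metadata_keys explicitly. A hedged usage sketch; the bucket, key pattern and size threshold are made up:

```python
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor


def has_large_csv(files: list[dict]) -> bool:
    # each entry now looks like {"Size": <int>, "Key": <str>} by default
    return any(f["Key"].endswith(".csv") and f["Size"] > 1_000_000 for f in files)


wait_for_csv = S3KeySensor(
    task_id="wait_for_csv",        # hypothetical task id
    bucket_name="example-bucket",  # hypothetical bucket
    bucket_key="incoming/*.csv",
    wildcard_match=True,
    check_fn=has_large_csv,
)
```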
@@ -18,18 +18,18 @@ from __future__ import annotations
 
  import time
  from collections.abc import Sequence
- from functools import cached_property
  from typing import TYPE_CHECKING
 
  from airflow.exceptions import AirflowException
  from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
- from airflow.sensors.base import BaseSensorOperator
+ from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
  if TYPE_CHECKING:
      from airflow.utils.context import Context
 
 
- class SageMakerBaseSensor(BaseSensorOperator):
+ class SageMakerBaseSensor(AwsBaseSensor[SageMakerHook]):
      """
      Contains general sensor behavior for SageMaker.
 
@@ -37,17 +37,13 @@ class SageMakerBaseSensor(BaseSensorOperator):
      Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods.
      """
 
+     aws_hook_class = SageMakerHook
      ui_color = "#ededed"
 
-     def __init__(self, *, aws_conn_id: str | None = "aws_default", resource_type: str = "job", **kwargs):
+     def __init__(self, *, resource_type: str = "job", **kwargs):
          super().__init__(**kwargs)
-         self.aws_conn_id = aws_conn_id
          self.resource_type = resource_type  # only used for logs, to say what kind of resource we are sensing
 
-     @cached_property
-     def hook(self) -> SageMakerHook:
-         return SageMakerHook(aws_conn_id=self.aws_conn_id)
-
      def poke(self, context: Context):
          response = self.get_sagemaker_response()
          if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
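Because SageMakerBaseSensor now inherits from AwsBaseSensor[SageMakerHook], the aws_conn_id, region_name and verify arguments are accepted through the base class rather than each sensor's own __init__, and self.hook is built from them automatically. A small usage sketch with hypothetical task and job names:

```python
from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTrainingSensor

wait_for_training = SageMakerTrainingSensor(
    task_id="wait_for_training",  # hypothetical task id
    job_name="my-training-job",   # hypothetical training job name
    aws_conn_id="aws_default",    # handled by AwsBaseSensor, not the subclass
    region_name="eu-west-1",      # likewise provided through the base sensor
)
```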
@@ -96,7 +92,9 @@ class SageMakerEndpointSensor(SageMakerBaseSensor):
      :param endpoint_name: Name of the endpoint instance to watch.
      """
 
-     template_fields: Sequence[str] = ("endpoint_name",)
+     template_fields: Sequence[str] = aws_template_fields(
+         "endpoint_name",
+     )
      template_ext: Sequence[str] = ()
 
      def __init__(self, *, endpoint_name, **kwargs):
@@ -131,7 +129,9 @@ class SageMakerTransformSensor(SageMakerBaseSensor):
      :param job_name: Name of the transform job to watch.
      """
 
-     template_fields: Sequence[str] = ("job_name",)
+     template_fields: Sequence[str] = aws_template_fields(
+         "job_name",
+     )
      template_ext: Sequence[str] = ()
 
      def __init__(self, *, job_name: str, **kwargs):
@@ -166,7 +166,9 @@ class SageMakerTuningSensor(SageMakerBaseSensor):
      :param job_name: Name of the tuning instance to watch.
      """
 
-     template_fields: Sequence[str] = ("job_name",)
+     template_fields: Sequence[str] = aws_template_fields(
+         "job_name",
+     )
      template_ext: Sequence[str] = ()
 
      def __init__(self, *, job_name: str, **kwargs):
@@ -202,7 +204,9 @@ class SageMakerTrainingSensor(SageMakerBaseSensor):
      :param print_log: Prints the cloudwatch log if True; Defaults to True.
      """
 
-     template_fields: Sequence[str] = ("job_name",)
+     template_fields: Sequence[str] = aws_template_fields(
+         "job_name",
+     )
      template_ext: Sequence[str] = ()
 
      def __init__(self, *, job_name, print_log=True, **kwargs):
@@ -281,7 +285,9 @@ class SageMakerPipelineSensor(SageMakerBaseSensor):
          Defaults to true, consider turning off for pipelines that have thousands of steps.
      """
 
-     template_fields: Sequence[str] = ("pipeline_exec_arn",)
+     template_fields: Sequence[str] = aws_template_fields(
+         "pipeline_exec_arn",
+     )
 
      def __init__(self, *, pipeline_exec_arn: str, verbose: bool = True, **kwargs):
          super().__init__(resource_type="pipeline", **kwargs)
@@ -313,7 +319,9 @@ class SageMakerAutoMLSensor(SageMakerBaseSensor):
      :param job_name: unique name of the AutoML job to watch.
      """
 
-     template_fields: Sequence[str] = ("job_name",)
+     template_fields: Sequence[str] = aws_template_fields(
+         "job_name",
+     )
 
      def __init__(self, *, job_name: str, **kwargs):
          super().__init__(resource_type="autoML job", **kwargs)
@@ -344,7 +352,9 @@ class SageMakerProcessingSensor(SageMakerBaseSensor):
      :param job_name: Name of the processing job to watch.
      """
 
-     template_fields: Sequence[str] = ("job_name",)
+     template_fields: Sequence[str] = aws_template_fields(
+         "job_name",
+     )
      template_ext: Sequence[str] = ()
 
      def __init__(self, *, job_name: str, **kwargs):
@@ -18,9 +18,10 @@
  from __future__ import annotations
 
  import enum
+ import gzip
+ import io
  from collections import namedtuple
  from collections.abc import Iterable, Mapping, Sequence
- from tempfile import NamedTemporaryFile
  from typing import TYPE_CHECKING, Any, cast
 
  from typing_extensions import Literal
@@ -191,16 +192,29 @@ class SqlToS3Operator(BaseOperator):
          self.log.info("Data from SQL obtained")
          self._fix_dtypes(data_df, self.file_format)
          file_options = FILE_OPTIONS_MAP[self.file_format]
+
          for group_name, df in self._partition_dataframe(df=data_df):
-             with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
-                 self.log.info("Writing data to temp file")
-                 getattr(df, file_options.function)(tmp_file.name, **self.pd_kwargs)
-
-                 self.log.info("Uploading data to S3")
-                 object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key
-                 s3_conn.load_file(
-                     filename=tmp_file.name, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
-                 )
+             buf = io.BytesIO()
+             self.log.info("Writing data to in-memory buffer")
+             object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key
+
+             if self.pd_kwargs.get("compression") == "gzip":
+                 pd_kwargs = {k: v for k, v in self.pd_kwargs.items() if k != "compression"}
+                 with gzip.GzipFile(fileobj=buf, mode="wb", filename=object_key) as gz:
+                     getattr(df, file_options.function)(gz, **pd_kwargs)
+             else:
+                 if self.file_format == FILE_FORMAT.PARQUET:
+                     getattr(df, file_options.function)(buf, **self.pd_kwargs)
+                 else:
+                     text_buf = io.TextIOWrapper(buf, encoding="utf-8", write_through=True)
+                     getattr(df, file_options.function)(text_buf, **self.pd_kwargs)
+                     text_buf.flush()
+             buf.seek(0)
+
+             self.log.info("Uploading data to S3")
+             s3_conn.load_file_obj(
+                 file_obj=buf, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
+             )
 
      def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
          """Partition dataframe using pandas groupby() method."""
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: apache-airflow-providers-amazon
- Version: 9.7.0rc2
+ Version: 9.8.0rc1
  Summary: Provider package apache-airflow-providers-amazon for Apache Airflow
  Keywords: airflow-provider,amazon,airflow,integration
  Author-email: Apache Software Foundation <dev@airflow.apache.org>
@@ -56,8 +56,8 @@ Requires-Dist: apache-airflow-providers-salesforce ; extra == "salesforce"
  Requires-Dist: apache-airflow-providers-ssh ; extra == "ssh"
  Requires-Dist: apache-airflow-providers-standard ; extra == "standard"
  Project-URL: Bug Tracker, https://github.com/apache/airflow/issues
- Project-URL: Changelog, https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.7.0/changelog.html
- Project-URL: Documentation, https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.7.0
+ Project-URL: Changelog, https://airflow.staged.apache.org/docs/apache-airflow-providers-amazon/9.8.0/changelog.html
+ Project-URL: Documentation, https://airflow.staged.apache.org/docs/apache-airflow-providers-amazon/9.8.0
  Project-URL: Mastodon, https://fosstodon.org/@airflow
  Project-URL: Slack Chat, https://s.apache.org/airflow-slack
  Project-URL: Source Code, https://github.com/apache/airflow
@@ -105,7 +105,7 @@ Provides-Extra: standard
 
  Package ``apache-airflow-providers-amazon``
 
- Release: ``9.7.0``
+ Release: ``9.8.0``
 
 
  Amazon integration (including `Amazon Web Services (AWS) <https://aws.amazon.com/>`__).
@@ -118,7 +118,7 @@ This is a provider package for ``amazon`` provider. All classes for this provide
  are in ``airflow.providers.amazon`` python package.
 
  You can find package information and changelog for the provider
- in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.7.0/>`_.
+ in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.8.0/>`_.
 
  Installation
  ------------
@@ -189,5 +189,5 @@ Dependent package
  ======================================================================================================================== ====================
 
  The changelog for the provider package can be found in the
- `changelog <https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.7.0/changelog.html>`_.
+ `changelog <https://airflow.apache.org/docs/apache-airflow-providers-amazon/9.8.0/changelog.html>`_.
 
@@ -1,5 +1,5 @@
  airflow/providers/amazon/LICENSE,sha256=gXPVwptPlW1TJ4HSuG5OMPg-a3h43OGMkZRR1rpwfJA,10850
- airflow/providers/amazon/__init__.py,sha256=WURN_BIK-1PYtgyWYekRiDMEbfkAnPajJpKuhPT6gLQ,1495
+ airflow/providers/amazon/__init__.py,sha256=_CZ_wEd1ln7ijiK7yeMFFxlJFhkSQb_T-U1QhAjBiH8,1495
  airflow/providers/amazon/get_provider_info.py,sha256=iXOUQZQkWSX6JDGZnqaQp7B7RzYyyW0RoLgS8qXFRl0,68490
  airflow/providers/amazon/version_compat.py,sha256=j5PCtXvZ71aBjixu-EFTNtVDPsngzzs7os0ZQDgFVDk,1536
  airflow/providers/amazon/aws/__init__.py,sha256=9hdXHABrVpkbpjZgUft39kOFL2xSGeG4GEua0Hmelus,785
@@ -136,7 +136,7 @@ airflow/providers/amazon/aws/operators/rds.py,sha256=cib-k2aHpa-DeYL92HqRMa5x9O8
  airflow/providers/amazon/aws/operators/redshift_cluster.py,sha256=dXakMyZV5vvEh0-20FOolMU5xEDugnMrvwwbLvuYc3o,37168
  airflow/providers/amazon/aws/operators/redshift_data.py,sha256=dcPEYGgXn9M1zS2XP7szm_kYO2xWD2CiUqsmlygf0_k,10854
  airflow/providers/amazon/aws/operators/s3.py,sha256=Imd3siCtmtaPWRmmSd382dJHhr49WRd-_aP6Tx5T7ac,38389
- airflow/providers/amazon/aws/operators/sagemaker.py,sha256=gZ0y2aBkIq0IZo602q1cOHIi5ATVEitHCngKkcPHjYg,83551
+ airflow/providers/amazon/aws/operators/sagemaker.py,sha256=54RsCEJ735MsAmXh3Z0hWVQFOP4X98kSeeKKXdWijaU,91293
  airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py,sha256=J-huObn3pZ_fg0gEy-BLsX298CX_n7qWV2YwjfpFnrw,6867
  airflow/providers/amazon/aws/operators/sns.py,sha256=uVcSJBbqy7YCOeiCrMvFFn9F9xTzMRpfrEygqEIhWEM,3757
  airflow/providers/amazon/aws/operators/sqs.py,sha256=o9rH2Pm5DNmccLh5I2wr96hZiuxOPi6YGZ2QluOeVb0,4764
@@ -170,8 +170,8 @@ airflow/providers/amazon/aws/sensors/opensearch_serverless.py,sha256=cSaZvCvAC7z
  airflow/providers/amazon/aws/sensors/quicksight.py,sha256=lm1omzh01BKh0KHU3g2I1yH9LAXtddUDiuIS3uIeOrE,3575
  airflow/providers/amazon/aws/sensors/rds.py,sha256=HWYQOQ7n9s48Ci2WxBOtrAp17aB-at5werAljq3NDYE,7420
  airflow/providers/amazon/aws/sensors/redshift_cluster.py,sha256=8JxB23ifahmqrso6j8JPmiqYLcHZmFznE79aSeSHrJs,4086
- airflow/providers/amazon/aws/sensors/s3.py,sha256=kmlYlCHufVAr8kh_dXebAuLEIrDqymrPHdVWFpbU3mY,17312
- airflow/providers/amazon/aws/sensors/sagemaker.py,sha256=nR32E1qKl9X61W52fC5FVB6ZQKb4gZVgxoMDimvXYhQ,13661
+ airflow/providers/amazon/aws/sensors/s3.py,sha256=y118-NM72oMrUQhSO9shdemqIIgBv-qJJa8BDXqHzsA,17344
+ airflow/providers/amazon/aws/sensors/sagemaker.py,sha256=dVQntJNRyUYCLQ7cIkeHesgZxf-1yS_BBAiVBzCwaHI,13795
  airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py,sha256=REPtB8BXcGu5xX-n-y_IAZXYpZ9d1nKiJAVqnMW3ayY,2865
  airflow/providers/amazon/aws/sensors/sqs.py,sha256=V3d05xb2VuxdWimpDVJy_SOKX7N0ok9TBbEYO-9o3v4,10672
  airflow/providers/amazon/aws/sensors/step_function.py,sha256=gaklKHdfmE-9avKSmyuGYvv9CuSklpjPz4KXZI8wXnY,3607
@@ -197,7 +197,7 @@ airflow/providers/amazon/aws/transfers/s3_to_sftp.py,sha256=gYT0iG9pbL36ObnjzKm_
  airflow/providers/amazon/aws/transfers/s3_to_sql.py,sha256=shkNpAoNgKJ3fBEBg7ZQaU8zMa5WjZuZV4eihxh2uLQ,4976
  airflow/providers/amazon/aws/transfers/salesforce_to_s3.py,sha256=WbCZUa9gfQB1SjDfUfPw5QO8lZ8Q-vSLriTnpXLhvxs,5713
  airflow/providers/amazon/aws/transfers/sftp_to_s3.py,sha256=-D5AR306Q8710e4dHao75CMGS7InHernCH_aZsE6Je4,4209
- airflow/providers/amazon/aws/transfers/sql_to_s3.py,sha256=WI3ykSVZSGuS6iffP7d4WL3TvP-SthgwWUKnHoT8FNU,10587
+ airflow/providers/amazon/aws/transfers/sql_to_s3.py,sha256=KbmMAVkIgC1jcOul70fzQnrx6V6vF7uXdO9I7JYA9oI,11118
  airflow/providers/amazon/aws/triggers/README.md,sha256=ax2F0w2CuQSDN4ghJADozrrv5W4OeCDPA8Vzp00BXOU,10919
  airflow/providers/amazon/aws/triggers/__init__.py,sha256=mlJxuZLkd5x-iq2SBwD3mvRQpt3YR7wjz_nceyF1IaI,787
  airflow/providers/amazon/aws/triggers/athena.py,sha256=62ty40zejcm5Y0d1rTQZuYzSjq3hUkmAs0d_zxM_Kjw,2596
@@ -269,7 +269,7 @@ airflow/providers/amazon/aws/waiters/rds.json,sha256=HNmNQm5J-VaFHzjWb1pE5P7-Ix-
  airflow/providers/amazon/aws/waiters/redshift.json,sha256=jOBotCgbkko1b_CHcGEbhhRvusgt0YSzVuFiZrqVP30,1742
  airflow/providers/amazon/aws/waiters/sagemaker.json,sha256=JPHuQtUFZ1B7EMLfVmCRevNZ9jgpB71LM0dva8ZEO9A,5254
  airflow/providers/amazon/aws/waiters/stepfunctions.json,sha256=GsOH-emGerKGBAUFmI5lpMfNGH4c0ol_PSiea25DCEY,1033
- apache_airflow_providers_amazon-9.7.0rc2.dist-info/entry_points.txt,sha256=vlc0ZzhBkMrav1maTRofgksnAw4SwoQLFX9cmnTgktk,102
- apache_airflow_providers_amazon-9.7.0rc2.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
- apache_airflow_providers_amazon-9.7.0rc2.dist-info/METADATA,sha256=zErqGTl6canGc8K_h9vx-ueuZrRgCzbt8niXOUTislQ,10206
- apache_airflow_providers_amazon-9.7.0rc2.dist-info/RECORD,,
+ apache_airflow_providers_amazon-9.8.0rc1.dist-info/entry_points.txt,sha256=vlc0ZzhBkMrav1maTRofgksnAw4SwoQLFX9cmnTgktk,102
+ apache_airflow_providers_amazon-9.8.0rc1.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+ apache_airflow_providers_amazon-9.8.0rc1.dist-info/METADATA,sha256=4tfW46Y4El9haP5LZt0H4_TJSc3CgwPS3jl854NFIcg,10220
+ apache_airflow_providers_amazon-9.8.0rc1.dist-info/RECORD,,