apache-airflow-providers-amazon 9.6.0__py3-none-any.whl → 9.6.1rc1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in the supported public registries. It is provided for informational purposes only.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +15 -18
- airflow/providers/amazon/aws/auth_manager/router/login.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +3 -4
- airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/utils.py +1 -1
- airflow/providers/amazon/aws/hooks/athena.py +1 -1
- airflow/providers/amazon/aws/hooks/base_aws.py +12 -15
- airflow/providers/amazon/aws/hooks/batch_client.py +11 -0
- airflow/providers/amazon/aws/hooks/cloud_formation.py +1 -2
- airflow/providers/amazon/aws/hooks/datasync.py +2 -2
- airflow/providers/amazon/aws/hooks/dms.py +2 -3
- airflow/providers/amazon/aws/hooks/dynamodb.py +1 -2
- airflow/providers/amazon/aws/hooks/emr.py +14 -17
- airflow/providers/amazon/aws/hooks/glue.py +9 -13
- airflow/providers/amazon/aws/hooks/mwaa.py +6 -7
- airflow/providers/amazon/aws/hooks/redshift_data.py +1 -1
- airflow/providers/amazon/aws/hooks/redshift_sql.py +5 -6
- airflow/providers/amazon/aws/hooks/s3.py +3 -6
- airflow/providers/amazon/aws/hooks/sagemaker.py +6 -9
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +5 -6
- airflow/providers/amazon/aws/links/base_aws.py +2 -2
- airflow/providers/amazon/aws/links/emr.py +2 -4
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +3 -5
- airflow/providers/amazon/aws/log/s3_task_handler.py +1 -2
- airflow/providers/amazon/aws/operators/athena.py +1 -1
- airflow/providers/amazon/aws/operators/batch.py +37 -42
- airflow/providers/amazon/aws/operators/bedrock.py +1 -1
- airflow/providers/amazon/aws/operators/ecs.py +4 -6
- airflow/providers/amazon/aws/operators/eks.py +146 -139
- airflow/providers/amazon/aws/operators/emr.py +4 -5
- airflow/providers/amazon/aws/operators/mwaa.py +1 -1
- airflow/providers/amazon/aws/operators/neptune.py +2 -2
- airflow/providers/amazon/aws/operators/redshift_data.py +1 -2
- airflow/providers/amazon/aws/operators/s3.py +9 -13
- airflow/providers/amazon/aws/operators/sagemaker.py +11 -19
- airflow/providers/amazon/aws/secrets/secrets_manager.py +1 -2
- airflow/providers/amazon/aws/sensors/batch.py +33 -55
- airflow/providers/amazon/aws/sensors/eks.py +64 -54
- airflow/providers/amazon/aws/sensors/glacier.py +4 -5
- airflow/providers/amazon/aws/sensors/glue.py +6 -9
- airflow/providers/amazon/aws/sensors/glue_crawler.py +2 -4
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +1 -1
- airflow/providers/amazon/aws/sensors/s3.py +1 -2
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +4 -5
- airflow/providers/amazon/aws/sensors/sqs.py +1 -2
- airflow/providers/amazon/aws/utils/connection_wrapper.py +1 -1
- airflow/providers/amazon/aws/utils/sqs.py +1 -2
- airflow/providers/amazon/aws/utils/tags.py +2 -3
- airflow/providers/amazon/aws/waiters/mwaa.json +1 -1
- {apache_airflow_providers_amazon-9.6.0.dist-info → apache_airflow_providers_amazon-9.6.1rc1.dist-info}/METADATA +11 -10
- {apache_airflow_providers_amazon-9.6.0.dist-info → apache_airflow_providers_amazon-9.6.1rc1.dist-info}/RECORD +54 -54
- {apache_airflow_providers_amazon-9.6.0.dist-info → apache_airflow_providers_amazon-9.6.1rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.6.0.dist-info → apache_airflow_providers_amazon-9.6.1rc1.dist-info}/entry_points.txt +0 -0
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "9.6.0"
+__version__ = "9.6.1"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.9.0"
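
For context, the version gate in the hunk above compares the installed Airflow version (with pre-release suffixes stripped via base_version) against a minimum. A minimal standalone sketch of the same check, with an illustrative version string and error message that are not taken from the provider:

    import packaging.version

    airflow_version = "2.10.5"  # hypothetical installed version
    MIN_AIRFLOW_VERSION = "2.9.0"

    # base_version drops pre-release/dev suffixes, e.g. "2.9.0.dev0" -> "2.9.0"
    installed = packaging.version.parse(packaging.version.parse(airflow_version).base_version)
    if installed < packaging.version.parse(MIN_AIRFLOW_VERSION):
        raise RuntimeError(
            f"This provider requires Airflow {MIN_AIRFLOW_VERSION}+, got {airflow_version}"
        )
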
@@ -107,26 +107,23 @@ def _create_policy_store(client: BaseClient, args) -> tuple[str | None, bool]:
             f"There is already a policy store with description '{args.policy_store_description}' in Amazon Verified Permissions: '{existing_policy_stores[0]['policyStoreId']}'."
         )
         return existing_policy_stores[0]["policyStoreId"], False
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-            if args.verbose:
-                log.debug("Response from create_policy_store: %s", response)
+    print(f"No policy store with description '{args.policy_store_description}' found, creating one.")
+    if args.dry_run:
+        print(f"Dry run, not creating the policy store with description '{args.policy_store_description}'.")
+        return None, True
+
+    response = client.create_policy_store(
+        validationSettings={
+            "mode": "STRICT",
+        },
+        description=args.policy_store_description,
+    )
+    if args.verbose:
+        log.debug("Response from create_policy_store: %s", response)
 
-
+    print(f"Policy store created: '{response['policyStoreId']}'")
 
-
+    return response["policyStoreId"], True
 
 
 def _set_schema(client: BaseClient, policy_store_id: str, args) -> None:
@@ -91,7 +91,7 @@ def login_callback(request: Request):
 
 def _init_saml_auth(request: Request) -> OneLogin_Saml2_Auth:
     request_data = _prepare_request(request)
-    base_url = conf.get(section="api", key="base_url")
+    base_url = conf.get(section="api", key="base_url", fallback="/")
     settings = {
         # We want to keep this flag on in case of errors.
         # It provides an error reasons, if turned off, it does not
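
The change above adds a fallback so that a missing [api] base_url option no longer raises. As an analogy using the standard library rather than Airflow's config wrapper (the section and option names are just for illustration), the fallback semantics look like this:

    from configparser import ConfigParser

    conf = ConfigParser()
    conf.read_string("[api]\n")  # section exists but base_url is not set

    # Without fallback, a missing option raises NoOptionError;
    # with fallback, the default is returned instead.
    base_url = conf.get(section="api", option="base_url", fallback="/")
    print(base_url)  # "/"
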
@@ -23,6 +23,7 @@ Each Airflow task gets delegated out to an Amazon ECS Task.
 
 from __future__ import annotations
 
+import contextlib
 import time
 from collections import defaultdict, deque
 from collections.abc import Sequence
@@ -449,13 +450,11 @@ class AwsEcsExecutor(BaseExecutor):
             else:
                 task = run_task_response["tasks"][0]
                 self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number)
-                try:
-                    self.running_state(task_key, task.task_arn)
-                except AttributeError:
+                with contextlib.suppress(AttributeError):
                     # running_state is newly added, and only needed to support task adoption (an optional
                     # executor feature).
                     # TODO: remove when min airflow version >= 2.9.2
-                    pass
+                    self.running_state(task_key, task.task_arn)
 
     def _run_task(
         self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
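
Several hunks in this release swap a try/except block that ignores an exception for contextlib.suppress. A minimal standalone sketch of the equivalence (the class and method names are illustrative, not the executor's):

    import contextlib

    class LegacyExecutor:
        pass  # deliberately lacks running_state, to mimic an older base class

    executor = LegacyExecutor()

    # Before: swallow the missing-attribute error explicitly.
    try:
        executor.running_state("task-key", "arn:aws:ecs:example:task/demo")
    except AttributeError:
        pass

    # After: contextlib.suppress expresses the same intent in one statement.
    with contextlib.suppress(AttributeError):
        executor.running_state("task-key", "arn:aws:ecs:example:task/demo")
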
@@ -72,7 +72,7 @@ def build_task_kwargs() -> dict:
         raise ValueError(
             "capacity_provider_strategy and launch_type are mutually exclusive, you can not provide both."
         )
-
+    if "cluster" in task_kwargs and not (has_capacity_provider or has_launch_type):
         # Default API behavior if neither is provided is to fall back on the default capacity
         # provider if it exists. Since it is not a required value, check if there is one
         # before using it, and if there is not then use the FARGATE launch_type as
@@ -138,7 +138,7 @@ class EcsExecutorTask:
         """
         if self.last_status == "RUNNING":
             return State.RUNNING
-
+        if self.desired_status == "RUNNING":
             return State.QUEUED
         is_finished = self.desired_status == "STOPPED"
         has_exit_codes = all(["exit_code" in x for x in self.containers])
@@ -208,7 +208,7 @@ class AthenaHook(AwsBaseHook):
         if query_state is None:
             self.log.error("Invalid Query state. Query execution id: %s", query_execution_id)
             return None
-
+        if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
             self.log.error(
                 'Query is in "%s" state. Cannot fetch results. Query execution id: %s',
                 query_state,
@@ -191,15 +191,13 @@ class BaseSessionFactory(LoggingMixin):
                 session = self.get_async_session()
                 self._apply_session_kwargs(session)
                 return session
-
-
-        elif not self.role_arn:
+            return boto3.session.Session(region_name=self.region_name)
+        if not self.role_arn:
             if deferrable:
                 session = self.get_async_session()
                 self._apply_session_kwargs(session)
                 return session
-
-                return self.basic_session
+            return self.basic_session
 
     # Values stored in ``AwsConnectionWrapper.session_kwargs`` are intended to be used only
     # to create the initial boto3 session.
@@ -624,7 +622,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
             if is_resource_type:
                 raise LookupError("Requested `resource_type`, but `client_type` was set instead.")
             return self.client_type
-
+        if self.resource_type:
             if not is_resource_type:
                 raise LookupError("Requested `client_type`, but `resource_type` was set instead.")
             return self.resource_type
@@ -840,15 +838,14 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         """
         if "/" in role:
             return role
-
-
-
-
-
-
-
-
-            return _client.get_role(RoleName=role)["Role"]["Arn"]
+        session = self.get_session(region_name=region_name)
+        _client = session.client(
+            service_name="iam",
+            endpoint_url=self.conn_config.get_service_endpoint_url("iam"),
+            config=self.config,
+            verify=self.verify,
+        )
+        return _client.get_role(RoleName=role)["Role"]["Arn"]
 
     @staticmethod
     def retry(should_retry: Callable[[Exception], bool]):
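
The expand_role helper shown above passes a value through when it already looks like an ARN (contains "/") and otherwise resolves the bare role name via IAM GetRole. A rough standalone sketch of that resolution with plain boto3, omitting the hook's endpoint/config/verify plumbing and assuming AWS credentials are available in the environment:

    from __future__ import annotations

    import boto3

    def expand_role(role: str, region_name: str | None = None) -> str:
        """Return the role ARN, resolving a bare role name via IAM GetRole."""
        if "/" in role:
            # Already an ARN (or at least path-qualified); pass it through unchanged.
            return role
        iam = boto3.session.Session(region_name=region_name).client("iam")
        return iam.get_role(RoleName=role)["Role"]["Arn"]

    # Usage (role name is hypothetical):
    # print(expand_role("my-execution-role"))
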
@@ -142,6 +142,17 @@ class BatchProtocol(Protocol):
         """
         ...
 
+    def create_compute_environment(self, **kwargs) -> dict:
+        """
+        Create an AWS Batch compute environment.
+
+        :param kwargs: Arguments for boto3 create_compute_environment
+
+        .. seealso::
+            - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch/client/create_compute_environment.html
+        """
+        ...
+
 
 # Note that the use of invalid-name parameters should be restricted to the boto3 mappings only;
 # all the Airflow wrappers of boto3 clients should not adopt invalid-names to match boto3.
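
The stub added above extends a typing.Protocol that mirrors the boto3 Batch client for type-checking purposes. A minimal, self-contained sketch of the pattern, with class and method bodies that are illustrative rather than the provider's:

    from typing import Protocol

    class ComputeClientProtocol(Protocol):
        """Structural type for any client that exposes create_compute_environment."""

        def create_compute_environment(self, **kwargs) -> dict:
            """Mirrors the boto3 Batch call of the same name; the body is just a stub."""
            ...

    class FakeBatchClient:
        """A stand-in used here instead of a real boto3 client."""

        def create_compute_environment(self, **kwargs) -> dict:
            return {"computeEnvironmentName": kwargs.get("computeEnvironmentName", "demo")}

    def provision(client: ComputeClientProtocol) -> dict:
        # Any object with a matching create_compute_environment method satisfies the protocol.
        return client.create_compute_environment(computeEnvironmentName="demo", type="MANAGED")

    print(provision(FakeBatchClient()))
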
@@ -312,9 +312,9 @@ class DataSyncHook(AwsBaseHook):
             self.log.info("status=%s", status)
             if status in self.TASK_EXECUTION_SUCCESS_STATES:
                 return True
-
+            if status in self.TASK_EXECUTION_FAILURE_STATES:
                 return False
-
+            if status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES:
                 time.sleep(self.wait_interval_seconds)
             else:
                 raise AirflowException(f"Unknown status: {status}")  # Should never happen
@@ -108,9 +108,8 @@ class DmsHook(AwsBaseHook):
             status = replication_tasks[0]["Status"]
             self.log.info('Replication task with ARN(%s) has status "%s".', replication_task_arn, status)
             return status
-
-
-        return None
+        self.log.info("Replication task with ARN(%s) is not found.", replication_task_arn)
+        return None
 
     def create_replication_task(
         self,
@@ -83,11 +83,10 @@ class EmrHook(AwsBaseHook):
             cluster_id = matching_clusters[0]["Id"]
             self.log.info("Found cluster name = %s id = %s", emr_cluster_name, cluster_id)
             return cluster_id
-
+        if len(matching_clusters) > 1:
             raise AirflowException(f"More than one cluster found for name {emr_cluster_name}")
-
-
-        return None
+        self.log.info("No cluster found for name %s", emr_cluster_name)
+        return None
 
     def create_job_flow(self, job_flow_overrides: dict[str, Any]) -> dict[str, Any]:
         """
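
Many of the hunks in this release apply the same cleanup: branches that each end in a return or raise are flattened into early returns instead of if/elif/else chains. A generic before/after illustration of the pattern (not provider code):

    # Before: every branch nested under elif/else even though each one exits.
    def pick_cluster_before(matching: list[dict]) -> str | None:
        if len(matching) == 1:
            return matching[0]["Id"]
        elif len(matching) > 1:
            raise ValueError("More than one cluster found")
        else:
            return None

    # After: early returns keep the happy path flat and make the fallthrough explicit.
    def pick_cluster_after(matching: list[dict]) -> str | None:
        if len(matching) == 1:
            return matching[0]["Id"]
        if len(matching) > 1:
            raise ValueError("More than one cluster found")
        return None
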
@@ -387,12 +386,11 @@ class EmrContainerHook(AwsBaseHook):
 
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Create EMR EKS Cluster failed: {response}")
-
-
-
-
-
-            return response["id"]
+        self.log.info(
+            "Create EMR EKS Cluster success - virtual cluster id %s",
+            response["id"],
+        )
+        return response["id"]
 
     def submit_job(
         self,
@@ -446,13 +444,12 @@ class EmrContainerHook(AwsBaseHook):
 
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Start Job Run failed: {response}")
-
-
-
-
-
-
-            return response["id"]
+        self.log.info(
+            "Start Job Run success - Job Id %s and virtual cluster id %s",
+            response["id"],
+            response["virtualClusterId"],
+        )
+        return response["id"]
 
     def get_job_failure_reason(self, job_id: str) -> str | None:
         """
@@ -320,8 +320,7 @@ class GlueJobHook(AwsBaseHook):
             if ret:
                 time.sleep(sleep_before_return)
                 return ret
-
-                time.sleep(self.job_poll_interval)
+            time.sleep(self.job_poll_interval)
 
     async def async_job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]:
         """
@@ -338,8 +337,7 @@ class GlueJobHook(AwsBaseHook):
             ret = self._handle_state(job_run_state, job_name, run_id, verbose, next_log_tokens)
             if ret:
                 return ret
-
-                await asyncio.sleep(self.job_poll_interval)
+            await asyncio.sleep(self.job_poll_interval)
 
     def _handle_state(
         self,
@@ -367,13 +365,12 @@ class GlueJobHook(AwsBaseHook):
             job_error_message = f"Exiting Job {run_id} Run State: {state}"
             self.log.info(job_error_message)
             raise AirflowException(job_error_message)
-
-
-
-
-
-
-            return None
+        self.log.info(
+            "Polling for AWS Glue Job %s current run state with status %s",
+            job_name,
+            state,
+        )
+        return None
 
     def has_job(self, job_name) -> bool:
         """
@@ -414,8 +411,7 @@ class GlueJobHook(AwsBaseHook):
             self.conn.update_job(JobName=job_name, JobUpdate=job_kwargs)
             self.log.info("Updated configurations: %s", update_config)
             return True
-
-        return False
+        return False
 
     def get_or_create_glue_job(self) -> str | None:
         """
@@ -101,13 +101,12 @@ class MwaaHook(AwsBaseHook):
                     "Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
                 )
                 return self._invoke_rest_api_using_local_session_token(**api_kwargs)
-
-
-
-
-
-
-            raise
+            to_log = e.response
+            # ResponseMetadata is removed because it contains data that is either very unlikely to be
+            # useful in XComs and logs, or redundant given the data already included in the response
+            to_log.pop("ResponseMetadata", None)
+            self.log.error(to_log)
+            raise
 
     def _invoke_rest_api_using_local_session_token(
         self,
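
The new error handling above logs the service response minus its ResponseMetadata block before re-raising. A rough standalone sketch of the same idea; the error payload here is fabricated purely for illustration:

    import logging

    logging.basicConfig(level=logging.ERROR)
    log = logging.getLogger("mwaa-example")

    def scrub(response: dict) -> dict:
        """Return a copy of a botocore-style error response without ResponseMetadata."""
        to_log = dict(response)
        to_log.pop("ResponseMetadata", None)
        return to_log

    # Payload shaped like botocore's ClientError.response (made up for the example):
    response = {
        "Error": {"Code": "AccessDeniedException", "Message": "not authorized"},
        "ResponseMetadata": {"RequestId": "abc-123", "HTTPStatusCode": 403},
    }
    log.error(scrub(response))  # logs only the Error block
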
@@ -181,7 +181,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
             if num_rows is not None:
                 self.log.info("Processed %s rows", num_rows)
             return True
-
+        if status in FAILURE_STATES:
             exception_cls = (
                 RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
             )
@@ -245,15 +245,14 @@ class RedshiftSQLHook(DbApiHook):
         parts = hostname.split(".")
         if hostname.endswith("amazonaws.com") and len(parts) == 6:
             return f"{parts[0]}.{parts[2]}"
-
-
-            """Could not parse identifier from hostname '%s'.
+        self.log.debug(
+            """Could not parse identifier from hostname '%s'.
 You are probably using IP to connect to Redshift cluster.
 Expected format: 'cluster_identifier.id.region_name.redshift.amazonaws.com'
 Falling back to whole hostname.""",
-
-
-
+            hostname,
+        )
+        return hostname
 
     def get_openlineage_database_dialect(self, connection: Connection) -> str:
         """Return redshift dialect."""
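
The identifier parsing shown above keeps the first and third labels of a standard Redshift hostname and otherwise falls back to the whole value. A small illustrative check (the hostnames are made up):

    def parse_identifier(hostname: str) -> str:
        """Extract 'cluster_identifier.region_name' from a Redshift hostname, else return it unchanged."""
        parts = hostname.split(".")
        if hostname.endswith("amazonaws.com") and len(parts) == 6:
            return f"{parts[0]}.{parts[2]}"
        return hostname

    # Standard endpoint shape: cluster_identifier.id.region_name.redshift.amazonaws.com
    print(parse_identifier("my-cluster.abc123xyz.us-east-1.redshift.amazonaws.com"))  # my-cluster.us-east-1
    print(parse_identifier("10.0.0.12"))  # falls back to the whole hostname
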
@@ -442,8 +442,7 @@ class S3Hook(AwsBaseHook):
         except ClientError as e:
             if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                 return head_object_val
-
-                raise e
+            raise e
 
     async def list_prefixes_async(
         self,
@@ -936,8 +935,7 @@ class S3Hook(AwsBaseHook):
         except ClientError as e:
             if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                 return None
-
-                raise e
+            raise e
 
     @unify_bucket_name_and_key
     @provide_bucket_name
@@ -1469,8 +1467,7 @@ class S3Hook(AwsBaseHook):
                 raise AirflowNotFoundException(
                     f"The source file in Bucket {bucket_name} with path {key} does not exist"
                 )
-
-                raise e
+            raise e
 
         if preserve_file_name:
             local_dir = local_path or gettempdir()
@@ -750,7 +750,7 @@ class SageMakerHook(AwsBaseHook):
 
             if status in self.failed_states:
                 raise AirflowException(f"SageMaker resource failed because {response['FailureReason']}")
-
+            if status not in non_terminal_states:
                 break
 
             if max_ingestion_time and sec > max_ingestion_time:
@@ -1010,8 +1010,7 @@ class SageMakerHook(AwsBaseHook):
             if "NextToken" not in response or (max_results is not None and len(results) == max_results):
                 # Return when there are no results left (no NextToken) or when we've reached max_results.
                 return results
-
-                next_token = response["NextToken"]
+            next_token = response["NextToken"]
 
     @staticmethod
     def _name_matches_pattern(
@@ -1172,9 +1171,8 @@ class SageMakerHook(AwsBaseHook):
                 ):
                     self.log.warning("Cannot stop pipeline execution, as it was not running: %s", ce)
                     break
-
-
-                raise
+                self.log.error(ce)
+                raise
             else:
                 break
 
@@ -1214,9 +1212,8 @@ class SageMakerHook(AwsBaseHook):
                 # log msg only so it doesn't look like an error
                 self.log.info("%s", e.response["Error"]["Message"])
                 return False
-
-
-            raise
+            self.log.error("Error when trying to create Model Package Group: %s", e)
+            raise
 
     def _describe_auto_ml_job(self, job_name: str):
         res = self.conn.describe_auto_ml_job(AutoMLJobName=job_name)
@@ -180,9 +180,8 @@ class SageMakerNotebookHook(BaseHook):
         if status in finished_states:
             self.log.info(execution_message)
             return {"Status": status, "ExecutionId": execution_id}
-
-
-
-
-
-            raise AirflowException(error_message)
+        log_error_message = f"Execution {execution_id} failed with error: {error_message}"
+        self.log.error(log_error_message)
+        if error_message == "":
+            error_message = execution_message
+        raise AirflowException(error_message)
@@ -49,9 +49,9 @@ class BaseAwsLink(BaseOperatorLink):
     def get_aws_domain(aws_partition) -> str | None:
         if aws_partition == "aws":
             return "aws.amazon.com"
-
+        if aws_partition == "aws-cn":
             return "amazonaws.cn"
-
+        if aws_partition == "aws-us-gov":
             return "amazonaws-us-gov.com"
 
         return None
@@ -127,8 +127,7 @@ class EmrServerlessLogsLink(BaseAwsLink):
         )
         if url:
             return url._replace(path="/logs/SPARK_DRIVER/stdout.gz").geturl()
-
-            return ""
+        return ""
 
 
 class EmrServerlessDashboardLink(BaseAwsLink):
@@ -145,8 +144,7 @@ class EmrServerlessDashboardLink(BaseAwsLink):
         )
         if url:
             return url.geturl()
-
-            return ""
+        return ""
 
 
 class EmrServerlessS3LogsLink(BaseAwsLink):
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import contextlib
 import copy
 import json
 import logging
@@ -56,8 +57,7 @@ def json_serialize_legacy(value: Any) -> str | None:
     """
     if isinstance(value, (date, datetime)):
         return value.isoformat()
-
-        return None
+    return None
 
 
 def json_serialize(value: Any) -> str | None:
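
json_serialize_legacy above returns an ISO string for date/datetime values and None for everything else. A typical way such a serializer can be used is as a json.dumps default; this usage example is illustrative, not taken from the provider:

    import json
    from datetime import date, datetime
    from typing import Any

    def json_serialize_legacy(value: Any) -> str | None:
        """ISO-format dates and datetimes; let everything else fall back to null."""
        if isinstance(value, (date, datetime)):
            return value.isoformat()
        return None

    payload = {"run_date": date(2025, 1, 1), "started": datetime(2025, 1, 1, 12, 30)}
    print(json.dumps(payload, default=json_serialize_legacy))
    # {"run_date": "2025-01-01", "started": "2025-01-01T12:30:00"}
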
@@ -134,10 +134,8 @@ class CloudWatchRemoteLogIO(LoggingMixin): # noqa: D101
         msg = copy.copy(event)
         created = None
         if ts := msg.pop("timestamp", None):
-            try:
+            with contextlib.suppress(Exception):
                 created = datetime.fromisoformat(ts)
-            except Exception:
-                pass
         record = logRecordFactory(
             name, level, pathname="", lineno=0, msg=msg, args=(), exc_info=None, func=None, sinfo=None
         )
@@ -162,8 +162,7 @@ class S3RemoteLogIO(LoggingMixin): # noqa: D101
             for key in keys:
                 logs.append(self.s3_read(key, return_error=True))
             return messages, logs
-
-            return messages, None
+        return messages, None
 
 
 class S3TaskHandler(FileTaskHandler, LoggingMixin):
@@ -168,7 +168,7 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
                 f"Final state of Athena job is {query_status}, query_execution_id is "
                 f"{self.query_execution_id}. Error: {error_message}"
             )
-
+        if not query_status or query_status in AthenaHook.INTERMEDIATE_STATES:
             raise AirflowException(
                 f"Final state of Athena job is {query_status}. Max tries of poll status exceeded, "
                 f"query_execution_id is {self.query_execution_id}."