apache-airflow-providers-amazon 9.9.0__py3-none-any.whl → 9.9.1rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/executors/batch/batch_executor.py +51 -0
- airflow/providers/amazon/aws/executors/ecs/utils.py +2 -2
- airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +1 -1
- airflow/providers/amazon/aws/fs/s3.py +2 -1
- airflow/providers/amazon/aws/hooks/athena_sql.py +12 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +24 -5
- airflow/providers/amazon/aws/hooks/batch_client.py +2 -1
- airflow/providers/amazon/aws/hooks/batch_waiters.py +2 -1
- airflow/providers/amazon/aws/hooks/chime.py +5 -1
- airflow/providers/amazon/aws/hooks/ec2.py +2 -1
- airflow/providers/amazon/aws/hooks/eks.py +1 -2
- airflow/providers/amazon/aws/hooks/glue.py +82 -7
- airflow/providers/amazon/aws/hooks/rds.py +2 -1
- airflow/providers/amazon/aws/hooks/s3.py +2 -2
- airflow/providers/amazon/aws/hooks/sagemaker.py +2 -2
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +5 -1
- airflow/providers/amazon/aws/links/base_aws.py +2 -10
- airflow/providers/amazon/aws/operators/base_aws.py +1 -1
- airflow/providers/amazon/aws/operators/batch.py +6 -22
- airflow/providers/amazon/aws/operators/ecs.py +1 -1
- airflow/providers/amazon/aws/operators/glue.py +22 -8
- airflow/providers/amazon/aws/operators/redshift_data.py +1 -1
- airflow/providers/amazon/aws/operators/sagemaker.py +2 -2
- airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +1 -1
- airflow/providers/amazon/aws/sensors/base_aws.py +1 -1
- airflow/providers/amazon/aws/sensors/glue.py +56 -12
- airflow/providers/amazon/aws/sensors/s3.py +2 -2
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +1 -1
- airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/base.py +1 -1
- airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +2 -2
- airflow/providers/amazon/aws/transfers/exasol_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/ftp_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/gcs_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +1 -1
- airflow/providers/amazon/aws/transfers/google_api_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +3 -3
- airflow/providers/amazon/aws/transfers/http_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/local_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_ftp.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_sftp.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_sql.py +8 -4
- airflow/providers/amazon/aws/transfers/salesforce_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/sftp_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +7 -5
- airflow/providers/amazon/aws/triggers/base.py +0 -1
- airflow/providers/amazon/aws/triggers/glue.py +37 -24
- airflow/providers/amazon/aws/utils/connection_wrapper.py +4 -1
- airflow/providers/amazon/aws/utils/suppress.py +2 -1
- airflow/providers/amazon/aws/utils/waiter.py +1 -1
- airflow/providers/amazon/aws/waiters/glue.json +55 -0
- airflow/providers/amazon/version_compat.py +10 -0
- {apache_airflow_providers_amazon-9.9.0.dist-info → apache_airflow_providers_amazon-9.9.1rc1.dist-info}/METADATA +14 -15
- {apache_airflow_providers_amazon-9.9.0.dist-info → apache_airflow_providers_amazon-9.9.1rc1.dist-info}/RECORD +62 -62
- {apache_airflow_providers_amazon-9.9.0.dist-info → apache_airflow_providers_amazon-9.9.1rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.9.0.dist-info → apache_airflow_providers_amazon-9.9.1rc1.dist-info}/entry_points.txt +0 -0

airflow/providers/amazon/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "9.9.0"
+__version__ = "9.9.1"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.10.0"

airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -36,11 +36,15 @@ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
     exponential_backoff_retry,
 )
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.stats import Stats
 from airflow.utils import timezone
 from airflow.utils.helpers import merge_dicts
 
 if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+    from airflow.executors import workloads
     from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
     from airflow.providers.amazon.aws.executors.batch.boto_schema import (
         BatchDescribeJobsResponseSchema,
@@ -97,6 +101,11 @@ class AwsBatchExecutor(BaseExecutor):
     # AWS only allows a maximum number of JOBs in the describe_jobs function
     DESCRIBE_JOBS_BATCH_SIZE = 99
 
+    if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
+        # In the v3 path, we store workloads, not commands as strings.
+        # TODO: TaskSDK: move this type change into BaseExecutor
+        queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: ignore[assignment]
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.active_workers = BatchJobCollection()
@@ -106,6 +115,30 @@ class AwsBatchExecutor(BaseExecutor):
         self.IS_BOTO_CONNECTION_HEALTHY = False
         self.submit_job_kwargs = self._load_submit_kwargs()
 
+    def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
+        from airflow.executors import workloads
+
+        if not isinstance(workload, workloads.ExecuteTask):
+            raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
+        ti = workload.ti
+        self.queued_tasks[ti.key] = workload
+
+    def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
+        from airflow.executors.workloads import ExecuteTask
+
+        # Airflow V3 version
+        for w in workloads:
+            if not isinstance(w, ExecuteTask):
+                raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")
+            command = [w]
+            key = w.ti.key
+            queue = w.ti.queue
+            executor_config = w.ti.executor_config or {}
+
+            del self.queued_tasks[key]
+            self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)  # type: ignore[arg-type]
+            self.running.add(key)
+
     def check_health(self):
         """Make a test API call to check the health of the Batch Executor."""
         success_status = "succeeded."
@@ -343,6 +376,24 @@ class AwsBatchExecutor(BaseExecutor):
         if executor_config and "command" in executor_config:
             raise ValueError('Executor Config should never override "command"')
 
+        if len(command) == 1:
+            from airflow.executors.workloads import ExecuteTask
+
+            if isinstance(command[0], ExecuteTask):
+                workload = command[0]
+                ser_input = workload.model_dump_json()
+                command = [
+                    "python",
+                    "-m",
+                    "airflow.sdk.execution_time.execute_workload",
+                    "--json-string",
+                    ser_input,
+                ]
+            else:
+                raise ValueError(
+                    f"BatchExecutor doesn't know how to handle workload of type: {type(command[0])}"
+                )
+
         self.pending_jobs.append(
             BatchQueuedJob(
                 key=key,
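
For orientation (an illustrative sketch, not part of the diff): under Airflow 3 the executor is handed an ExecuteTask workload rather than a ready-made shell command, and the block above serializes it before the Batch job is submitted. A minimal standalone version of that conversion, where the helper name build_batch_command is hypothetical:

from airflow.executors.workloads import ExecuteTask


def build_batch_command(workload: ExecuteTask) -> list[str]:
    # Serialize the whole workload to JSON and hand it to the Task SDK
    # entrypoint that runs inside the Batch container, mirroring what
    # execute_async now does when it sees a single ExecuteTask in `command`.
    return [
        "python",
        "-m",
        "airflow.sdk.execution_time.execute_workload",
        "--json-string",
        workload.model_dump_json(),
    ]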

airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -25,9 +25,9 @@ from __future__ import annotations
 
 import datetime
 from collections import defaultdict
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
 
 from inflection import camelize
 

airflow/providers/amazon/aws/fs/s3.py
@@ -18,8 +18,9 @@ from __future__ import annotations
 
 import asyncio
 import logging
+from collections.abc import Callable
 from functools import partial
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
 
 import requests
 from botocore import UNSIGNED

airflow/providers/amazon/aws/hooks/athena_sql.py
@@ -111,7 +111,14 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
                 connection.login = athena_conn.login
                 connection.password = athena_conn.password
                 connection.schema = athena_conn.schema
-
+                merged_extra = {**athena_conn.extra_dejson, **connection.extra_dejson}
+                try:
+                    extra_json = json.dumps(merged_extra)
+                    connection.extra = extra_json
+                except (TypeError, ValueError):
+                    raise ValueError(
+                        f"Encountered non-JSON in `extra` field for connection {self.aws_conn_id!r}."
+                    )
             except AirflowNotFoundException:
                 connection = athena_conn
                 connection.conn_type = "aws"
@@ -120,7 +127,10 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
                 )
 
         return AwsConnectionWrapper(
-            conn=connection,
+            conn=connection,
+            region_name=self._region_name,
+            botocore_config=self._config,
+            verify=self._verify,
         )
 
     @property
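
For orientation (a hedged sketch, not part of the diff): the `{**athena_conn.extra_dejson, **connection.extra_dejson}` merge gives the AWS connection's extras precedence over the Athena connection's extras, and the merged dict must stay JSON-serializable or the new ValueError is raised. The extra values below are placeholders:

import json

athena_extra = {"work_group": "primary", "s3_staging_dir": "s3://example-bucket/athena/"}
aws_extra = {"region_name": "us-east-1", "work_group": "analytics"}

# Later keys win in a dict merge, so the AWS connection's "work_group" overrides
# the Athena connection's value; json.dumps would fail on non-serializable values.
merged_extra = {**athena_extra, **aws_extra}
print(json.dumps(merged_extra))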

airflow/providers/amazon/aws/hooks/base_aws.py
@@ -31,10 +31,11 @@ import json
 import logging
 import os
 import warnings
+from collections.abc import Callable
 from copy import deepcopy
 from functools import cached_property, wraps
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union
 
 import boto3
 import botocore
@@ -43,6 +44,8 @@ import jinja2
 import requests
 import tenacity
 from asgiref.sync import sync_to_async
+from boto3.resources.base import ServiceResource
+from botocore.client import BaseClient
 from botocore.config import Config
 from botocore.waiter import Waiter, WaiterModel
 from dateutil.tz import tzlocal
@@ -54,16 +57,29 @@ from airflow.exceptions import (
     AirflowNotFoundException,
     AirflowProviderDeprecationWarning,
 )
-from airflow.hooks.base import BaseHook
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
 from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
 from airflow.providers.amazon.aws.utils.suppress import return_on_error
 from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.providers_manager import ProvidersManager
+
+try:
+    from airflow.sdk import BaseHook
+except ImportError:
+    from airflow.hooks.base import BaseHook  # type: ignore[attr-defined,no-redef]
 from airflow.utils.helpers import exactly_one
 from airflow.utils.log.logging_mixin import LoggingMixin
 
-
+# We need to set typeignore, sadly without it Sphinx build and mypy don't agree.
+# ideally the code should be:
+# BaseAwsConnection = TypeVar("BaseAwsConnection", bound=BaseClient | ServiceResource)
+# but if we do that Sphinx complains about:
+# TypeError: unsupported operand type(s) for |: 'BaseClient' and 'ServiceResource'
+# If we change to Union syntax then mypy is not happy with UP007 Use `X | Y` for type annotations
+# The only way to workaround it for now is to keep the union syntax with ignore for mypy
+# We should try to resolve this later.
+BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[BaseClient, ServiceResource])  # type: ignore[operator]  # noqa: UP007
+
 
 if AIRFLOW_V_3_0_PLUS:
     from airflow.sdk.exceptions import AirflowRuntimeError
@@ -627,7 +643,10 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
             raise
 
         return AwsConnectionWrapper(
-            conn=connection,
+            conn=connection,  # type: ignore[arg-type]
+            region_name=self._region_name,
+            botocore_config=self._config,
+            verify=self._verify,
         )
 
     def _resolve_service_name(self, is_resource_type: bool = False) -> str:
@@ -1038,7 +1057,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         return WaiterModel(model_config).waiter_names
 
 
-class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
+class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):  # type: ignore[operator]  # noqa: UP007
     """
     Base class for interact with AWS.
 

airflow/providers/amazon/aws/hooks/batch_client.py
@@ -30,7 +30,8 @@ from __future__ import annotations
 import itertools
 import random
 import time
-from typing import TYPE_CHECKING, Callable, Protocol, runtime_checkable
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Protocol, runtime_checkable
 
 import botocore.client
 import botocore.exceptions

airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -28,9 +28,10 @@ from __future__ import annotations
 
 import json
 import sys
+from collections.abc import Callable
 from copy import deepcopy
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
 
 import botocore.client
 import botocore.exceptions

airflow/providers/amazon/aws/hooks/chime.py
@@ -66,9 +66,13 @@ class ChimeWebhookHook(HttpHook):
         :return: Endpoint(str) for chime webhook.
         """
         conn = self.get_connection(conn_id)
-        token = conn.get_password()
+        token = conn.password
         if token is None:
             raise AirflowException("Webhook token field is missing and is required.")
+        if not conn.schema:
+            raise AirflowException("Webook schema field is missing and is required")
+        if not conn.host:
+            raise AirflowException("Webhook host field is missing and is required.")
         url = conn.schema + "://" + conn.host
         endpoint = url + token
         # Check to make sure the endpoint matches what Chime expects

airflow/providers/amazon/aws/hooks/ec2.py
@@ -19,7 +19,8 @@ from __future__ import annotations
 
 import functools
 import time
-from typing import Callable, TypeVar
+from collections.abc import Callable
+from typing import TypeVar
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

airflow/providers/amazon/aws/hooks/eks.py
@@ -23,11 +23,10 @@ import json
 import os
 import sys
 import tempfile
-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial
-from typing import Callable
 
 from botocore.exceptions import ClientError
 from botocore.signers import RequestSigner

airflow/providers/amazon/aws/hooks/glue.py
@@ -24,6 +24,14 @@ from functools import cached_property
 from typing import Any
 
 from botocore.exceptions import ClientError
+from tenacity import (
+    AsyncRetrying,
+    Retrying,
+    before_sleep_log,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -46,11 +54,11 @@ class GlueJobHook(AwsBaseHook):
     :param script_location: path to etl script on s3
     :param retry_limit: Maximum number of times to retry this job if it fails
     :param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job
-    :param region_name: aws region name (example: us-east-1)
     :param iam_role_name: AWS IAM Role for Glue Job Execution. If set `iam_role_arn` must equal None.
     :param iam_role_arn: AWS IAM Role ARN for Glue Job Execution, If set `iam_role_name` must equal None.
     :param create_job_kwargs: Extra arguments for Glue Job Creation
     :param update_config: Update job configuration on Glue (default: False)
+    :param api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
 
     Additional arguments (such as ``aws_conn_id``) may be specified and
         are passed down to the underlying AwsBaseHook.
@@ -80,6 +88,7 @@ class GlueJobHook(AwsBaseHook):
         create_job_kwargs: dict | None = None,
         update_config: bool = False,
         job_poll_interval: int | float = 6,
+        api_retry_args: dict[Any, Any] | None = None,
         *args,
         **kwargs,
     ):
@@ -96,6 +105,17 @@ class GlueJobHook(AwsBaseHook):
         self.update_config = update_config
         self.job_poll_interval = job_poll_interval
 
+        self.retry_config: dict[str, Any] = {
+            "retry": retry_if_exception(self._should_retry_on_error),
+            "wait": wait_exponential(multiplier=1, min=1, max=60),
+            "stop": stop_after_attempt(5),
+            "before_sleep": before_sleep_log(self.log, log_level=20),
+            "reraise": True,
+        }
+
+        if api_retry_args:
+            self.retry_config.update(api_retry_args)
+
         worker_type_exists = "WorkerType" in self.create_job_kwargs
         num_workers_exists = "NumberOfWorkers" in self.create_job_kwargs
 
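
As a usage illustration (a sketch under assumptions, not taken from the diff): api_retry_args is merged over the default retry_config, so individual keys such as "stop" or "wait" can be overridden per hook instance while the default retryable-error predicate stays in place. The job name and run id below are placeholders:

from tenacity import stop_after_attempt, wait_exponential

from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

# Hypothetical configuration: allow more attempts and a longer backoff for a busy account.
hook = GlueJobHook(
    job_name="example-glue-job",  # placeholder
    api_retry_args={
        "stop": stop_after_attempt(10),
        "wait": wait_exponential(multiplier=2, min=2, max=120),
    },
)
state = hook.get_job_state(job_name="example-glue-job", run_id="jr_0123456789")  # placeholder run id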
@@ -116,6 +136,29 @@ class GlueJobHook(AwsBaseHook):
         kwargs["client_type"] = "glue"
         super().__init__(*args, **kwargs)
 
+    def _should_retry_on_error(self, exception: BaseException) -> bool:
+        """
+        Determine if an exception should trigger a retry.
+
+        :param exception: The exception that occurred
+        :return: True if the exception should trigger a retry, False otherwise
+        """
+        if isinstance(exception, ClientError):
+            error_code = exception.response.get("Error", {}).get("Code", "")
+            retryable_errors = {
+                "ThrottlingException",
+                "RequestLimitExceeded",
+                "ServiceUnavailable",
+                "InternalFailure",
+                "InternalServerError",
+                "TooManyRequestsException",
+                "RequestTimeout",
+                "RequestTimeoutException",
+                "HttpTimeoutException",
+            }
+            return error_code in retryable_errors
+        return False
+
     def create_glue_job_config(self) -> dict:
         default_command = {
             "Name": "glueetl",
@@ -217,8 +260,21 @@ class GlueJobHook(AwsBaseHook):
         :param run_id: The job-run ID of the predecessor job run
         :return: State of the Glue job
         """
-        job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
-        return job_run["JobRun"]["JobRunState"]
+        for attempt in Retrying(**self.retry_config):
+            with attempt:
+                try:
+                    job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
+                    return job_run["JobRun"]["JobRunState"]
+                except ClientError as e:
+                    self.log.error("Failed to get job state for job %s run %s: %s", job_name, run_id, e)
+                    raise
+                except Exception as e:
+                    self.log.error(
+                        "Unexpected error getting job state for job %s run %s: %s", job_name, run_id, e
+                    )
+                    raise
+        # This should never be reached due to reraise=True, but mypy needs it
+        raise RuntimeError("Unexpected end of retry loop")
 
     async def async_get_job_state(self, job_name: str, run_id: str) -> str:
         """
@@ -226,9 +282,22 @@ class GlueJobHook(AwsBaseHook):
 
         The async version of get_job_state.
         """
-        async with await self.get_async_conn() as client:
-            job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
-            return job_run["JobRun"]["JobRunState"]
+        async for attempt in AsyncRetrying(**self.retry_config):
+            with attempt:
+                try:
+                    async with await self.get_async_conn() as client:
+                        job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
+                        return job_run["JobRun"]["JobRunState"]
+                except ClientError as e:
+                    self.log.error("Failed to get job state for job %s run %s: %s", job_name, run_id, e)
+                    raise
+                except Exception as e:
+                    self.log.error(
+                        "Unexpected error getting job state for job %s run %s: %s", job_name, run_id, e
+                    )
+                    raise
+        # This should never be reached due to reraise=True, but mypy needs it
+        raise RuntimeError("Unexpected end of retry loop")
 
     @cached_property
     def logs_hook(self):
@@ -372,7 +441,7 @@ class GlueJobHook(AwsBaseHook):
         )
         return None
 
-    def has_job(self, job_name) -> bool:
+    def has_job(self, job_name: str) -> bool:
         """
         Check if the job already exists.
 
@@ -422,6 +491,9 @@ class GlueJobHook(AwsBaseHook):
 
         :return:Name of the Job
         """
+        if self.job_name is None:
+            raise ValueError("job_name must be set to get or create a Glue job")
+
         if self.has_job(self.job_name):
             return self.job_name
 
@@ -441,6 +513,9 @@ class GlueJobHook(AwsBaseHook):
 
         :return:Name of the Job
         """
+        if self.job_name is None:
+            raise ValueError("job_name must be set to create or update a Glue job")
+
         config = self.create_glue_job_config()
 
         if self.has_job(self.job_name):

airflow/providers/amazon/aws/hooks/rds.py
@@ -20,7 +20,8 @@
 from __future__ import annotations
 
 import time
-from typing import TYPE_CHECKING, Callable
+from collections.abc import Callable
+from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook

airflow/providers/amazon/aws/hooks/s3.py
@@ -28,7 +28,7 @@ import os
 import re
 import shutil
 import time
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Callable
 from contextlib import suppress
 from copy import deepcopy
 from datetime import datetime
@@ -37,7 +37,7 @@ from inspect import signature
 from io import BytesIO
 from pathlib import Path
 from tempfile import NamedTemporaryFile, gettempdir
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
 from urllib.parse import urlsplit
 from uuid import uuid4
 

airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -23,10 +23,10 @@ import tarfile
 import tempfile
 import time
 from collections import Counter, namedtuple
-from collections.abc import AsyncGenerator, Generator
+from collections.abc import AsyncGenerator, Callable, Generator
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, cast
+from typing import Any, cast
 
 from asgiref.sync import sync_to_async
 from botocore.exceptions import ClientError

airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
@@ -25,9 +25,13 @@ from sagemaker_studio import ClientConfig
 from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI
 
 from airflow.exceptions import AirflowException
-from airflow.hooks.base import BaseHook
 from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner
 
+try:
+    from airflow.sdk import BaseHook
+except ImportError:
+    from airflow.hooks.base import BaseHook  # type: ignore[attr-defined,no-redef]
+
 
 class SageMakerNotebookHook(BaseHook):
     """

airflow/providers/amazon/aws/links/base_aws.py
@@ -20,20 +20,13 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, ClassVar
 
 from airflow.providers.amazon.aws.utils.suppress import return_on_error
-from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.amazon.version_compat import BaseOperatorLink, XCom
 
 if TYPE_CHECKING:
     from airflow.models import BaseOperator
     from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.utils.context import Context
 
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk import BaseOperatorLink
-    from airflow.sdk.execution_time.xcom import XCom
-else:
-    from airflow.models import XCom  # type: ignore[no-redef]
-    from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]
-
 
 BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
 
@@ -94,8 +87,7 @@ class BaseAwsLink(BaseOperatorLink):
         if not operator.do_xcom_push:
             return
 
-        operator.xcom_push(
-            context,
+        context["ti"].xcom_push(
             key=cls.key,
             value={
                 "region_name": region_name,

airflow/providers/amazon/aws/operators/base_aws.py
@@ -19,13 +19,13 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.utils.mixins import (
     AwsBaseHookMixin,
     AwsHookParams,
     AwsHookType,
     aws_template_fields,
 )
+from airflow.providers.amazon.version_compat import BaseOperator
 from airflow.utils.types import NOTSET, ArgNotSet
 
 

airflow/providers/amazon/aws/operators/batch.py
@@ -140,28 +140,12 @@ class BatchOperator(AwsBaseOperator[BatchClientHook]):
         "retry_strategy": "json",
     }
 
-    @property
-    def operator_extra_links(self):
-        op_extra_links = [BatchJobDetailsLink()]
-
-        if isinstance(self, MappedOperator):
-            wait_for_completion = self.partial_kwargs.get(
-                "wait_for_completion"
-            ) or self.expand_input.value.get("wait_for_completion")
-            array_properties = self.partial_kwargs.get("array_properties") or self.expand_input.value.get(
-                "array_properties"
-            )
-        else:
-            wait_for_completion = self.wait_for_completion
-            array_properties = self.array_properties
-
-        if wait_for_completion:
-            op_extra_links.extend([BatchJobDefinitionLink(), BatchJobQueueLink()])
-        if not array_properties:
-            # There is no CloudWatch Link to the parent Batch Job available.
-            op_extra_links.append(CloudWatchEventsLink())
-
-        return tuple(op_extra_links)
+    operator_extra_links = (
+        BatchJobDetailsLink(),
+        BatchJobDefinitionLink(),
+        BatchJobQueueLink(),
+        CloudWatchEventsLink(),
+    )
 
     def __init__(
         self,

airflow/providers/amazon/aws/operators/ecs.py
@@ -526,7 +526,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         self._start_task()
 
         if self.do_xcom_push:
-            self.xcom_push(context, key="ecs_task_arn", value=self.arn)
+            context["ti"].xcom_push(key="ecs_task_arn", value=self.arn)
 
         if self.deferrable:
             self.defer(