apache-airflow-providers-amazon 9.9.0__py3-none-any.whl → 9.9.1rc1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the respective public registries.
Files changed (62)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/executors/batch/batch_executor.py +51 -0
  3. airflow/providers/amazon/aws/executors/ecs/utils.py +2 -2
  4. airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +1 -1
  5. airflow/providers/amazon/aws/fs/s3.py +2 -1
  6. airflow/providers/amazon/aws/hooks/athena_sql.py +12 -2
  7. airflow/providers/amazon/aws/hooks/base_aws.py +24 -5
  8. airflow/providers/amazon/aws/hooks/batch_client.py +2 -1
  9. airflow/providers/amazon/aws/hooks/batch_waiters.py +2 -1
  10. airflow/providers/amazon/aws/hooks/chime.py +5 -1
  11. airflow/providers/amazon/aws/hooks/ec2.py +2 -1
  12. airflow/providers/amazon/aws/hooks/eks.py +1 -2
  13. airflow/providers/amazon/aws/hooks/glue.py +82 -7
  14. airflow/providers/amazon/aws/hooks/rds.py +2 -1
  15. airflow/providers/amazon/aws/hooks/s3.py +2 -2
  16. airflow/providers/amazon/aws/hooks/sagemaker.py +2 -2
  17. airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +5 -1
  18. airflow/providers/amazon/aws/links/base_aws.py +2 -10
  19. airflow/providers/amazon/aws/operators/base_aws.py +1 -1
  20. airflow/providers/amazon/aws/operators/batch.py +6 -22
  21. airflow/providers/amazon/aws/operators/ecs.py +1 -1
  22. airflow/providers/amazon/aws/operators/glue.py +22 -8
  23. airflow/providers/amazon/aws/operators/redshift_data.py +1 -1
  24. airflow/providers/amazon/aws/operators/sagemaker.py +2 -2
  25. airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +1 -1
  26. airflow/providers/amazon/aws/sensors/base_aws.py +1 -1
  27. airflow/providers/amazon/aws/sensors/glue.py +56 -12
  28. airflow/providers/amazon/aws/sensors/s3.py +2 -2
  29. airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +1 -1
  30. airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +1 -1
  31. airflow/providers/amazon/aws/transfers/base.py +1 -1
  32. airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +2 -2
  33. airflow/providers/amazon/aws/transfers/exasol_to_s3.py +1 -1
  34. airflow/providers/amazon/aws/transfers/ftp_to_s3.py +1 -1
  35. airflow/providers/amazon/aws/transfers/gcs_to_s3.py +1 -1
  36. airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +1 -1
  37. airflow/providers/amazon/aws/transfers/google_api_to_s3.py +1 -1
  38. airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +3 -3
  39. airflow/providers/amazon/aws/transfers/http_to_s3.py +1 -1
  40. airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +1 -1
  41. airflow/providers/amazon/aws/transfers/local_to_s3.py +1 -1
  42. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +1 -1
  43. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -1
  44. airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +1 -1
  45. airflow/providers/amazon/aws/transfers/s3_to_ftp.py +1 -1
  46. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -1
  47. airflow/providers/amazon/aws/transfers/s3_to_sftp.py +1 -1
  48. airflow/providers/amazon/aws/transfers/s3_to_sql.py +8 -4
  49. airflow/providers/amazon/aws/transfers/salesforce_to_s3.py +1 -1
  50. airflow/providers/amazon/aws/transfers/sftp_to_s3.py +1 -1
  51. airflow/providers/amazon/aws/transfers/sql_to_s3.py +7 -5
  52. airflow/providers/amazon/aws/triggers/base.py +0 -1
  53. airflow/providers/amazon/aws/triggers/glue.py +37 -24
  54. airflow/providers/amazon/aws/utils/connection_wrapper.py +4 -1
  55. airflow/providers/amazon/aws/utils/suppress.py +2 -1
  56. airflow/providers/amazon/aws/utils/waiter.py +1 -1
  57. airflow/providers/amazon/aws/waiters/glue.json +55 -0
  58. airflow/providers/amazon/version_compat.py +10 -0
  59. {apache_airflow_providers_amazon-9.9.0.dist-info → apache_airflow_providers_amazon-9.9.1rc1.dist-info}/METADATA +14 -15
  60. {apache_airflow_providers_amazon-9.9.0.dist-info → apache_airflow_providers_amazon-9.9.1rc1.dist-info}/RECORD +62 -62
  61. {apache_airflow_providers_amazon-9.9.0.dist-info → apache_airflow_providers_amazon-9.9.1rc1.dist-info}/WHEEL +0 -0
  62. {apache_airflow_providers_amazon-9.9.0.dist-info → apache_airflow_providers_amazon-9.9.1rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/__init__.py
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
  
  __all__ = ["__version__"]
  
- __version__ = "9.9.0"
+ __version__ = "9.9.1"
  
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
      "2.10.0"
airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -36,11 +36,15 @@ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
      exponential_backoff_retry,
  )
  from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
  from airflow.stats import Stats
  from airflow.utils import timezone
  from airflow.utils.helpers import merge_dicts
  
  if TYPE_CHECKING:
+     from sqlalchemy.orm import Session
+ 
+     from airflow.executors import workloads
      from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
      from airflow.providers.amazon.aws.executors.batch.boto_schema import (
          BatchDescribeJobsResponseSchema,
@@ -97,6 +101,11 @@ class AwsBatchExecutor(BaseExecutor):
      # AWS only allows a maximum number of JOBs in the describe_jobs function
      DESCRIBE_JOBS_BATCH_SIZE = 99
  
+     if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
+         # In the v3 path, we store workloads, not commands as strings.
+         # TODO: TaskSDK: move this type change into BaseExecutor
+         queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: ignore[assignment]
+ 
      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.active_workers = BatchJobCollection()
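Note: the `if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS` block above only re-annotates an attribute inherited from BaseExecutor; nothing changes at runtime. A minimal sketch of the same type-checking-only override pattern (the Base/Derived classes here are illustrative, not from the provider):

from typing import TYPE_CHECKING

class Base:
    values: dict[str, str] = {}

class Derived(Base):
    if TYPE_CHECKING:
        # Visible to type checkers only; at runtime the inherited attribute
        # is untouched, hence the assignment-override ignore.
        values: dict[str, object]  # type: ignore[assignment]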
@@ -106,6 +115,30 @@ class AwsBatchExecutor(BaseExecutor):
          self.IS_BOTO_CONNECTION_HEALTHY = False
          self.submit_job_kwargs = self._load_submit_kwargs()
  
+     def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
+         from airflow.executors import workloads
+ 
+         if not isinstance(workload, workloads.ExecuteTask):
+             raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
+         ti = workload.ti
+         self.queued_tasks[ti.key] = workload
+ 
+     def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
+         from airflow.executors.workloads import ExecuteTask
+ 
+         # Airflow V3 version
+         for w in workloads:
+             if not isinstance(w, ExecuteTask):
+                 raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")
+             command = [w]
+             key = w.ti.key
+             queue = w.ti.queue
+             executor_config = w.ti.executor_config or {}
+ 
+             del self.queued_tasks[key]
+             self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)  # type: ignore[arg-type]
+             self.running.add(key)
+ 
      def check_health(self):
          """Make a test API call to check the health of the Batch Executor."""
          success_status = "succeeded."
@@ -343,6 +376,24 @@ class AwsBatchExecutor(BaseExecutor):
          if executor_config and "command" in executor_config:
              raise ValueError('Executor Config should never override "command"')
  
+         if len(command) == 1:
+             from airflow.executors.workloads import ExecuteTask
+ 
+             if isinstance(command[0], ExecuteTask):
+                 workload = command[0]
+                 ser_input = workload.model_dump_json()
+                 command = [
+                     "python",
+                     "-m",
+                     "airflow.sdk.execution_time.execute_workload",
+                     "--json-string",
+                     ser_input,
+                 ]
+             else:
+                 raise ValueError(
+                     f"BatchExecutor doesn't know how to handle workload of type: {type(command[0])}"
+                 )
+ 
          self.pending_jobs.append(
              BatchQueuedJob(
                  key=key,
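On Airflow 3, the executor no longer receives a shell command: it receives an ExecuteTask workload, serializes it to JSON, and hands the Batch container a command that re-hydrates and runs it. A minimal sketch of that conversion, assuming an ExecuteTask workload (the helper name to_batch_command is illustrative, not from the diff):

from airflow.executors.workloads import ExecuteTask

def to_batch_command(command: list) -> list[str]:
    # A single-element command list carrying the workload is the Airflow 3 case.
    if len(command) == 1 and isinstance(command[0], ExecuteTask):
        # The Batch container re-hydrates the workload from the JSON string and runs it.
        return [
            "python",
            "-m",
            "airflow.sdk.execution_time.execute_workload",
            "--json-string",
            command[0].model_dump_json(),
        ]
    raise ValueError(f"Unsupported workload type: {type(command[0])}")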
airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -25,9 +25,9 @@ from __future__ import annotations
  
  import datetime
  from collections import defaultdict
- from collections.abc import Sequence
+ from collections.abc import Callable, Sequence
  from dataclasses import dataclass
- from typing import TYPE_CHECKING, Any, Callable
+ from typing import TYPE_CHECKING, Any
  
  from inflection import camelize
  
airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py
@@ -17,8 +17,8 @@
  from __future__ import annotations
  
  import logging
+ from collections.abc import Callable
  from datetime import datetime, timedelta
- from typing import Callable
  
  from airflow.utils import timezone
  
airflow/providers/amazon/aws/fs/s3.py
@@ -18,8 +18,9 @@ from __future__ import annotations
  
  import asyncio
  import logging
+ from collections.abc import Callable
  from functools import partial
- from typing import TYPE_CHECKING, Any, Callable
+ from typing import TYPE_CHECKING, Any
  
  import requests
  from botocore import UNSIGNED
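The three import-only hunks above, and many of the hook hunks below, apply the same cleanup: Callable is now imported from collections.abc rather than typing, since typing.Callable has been a deprecated alias since Python 3.9. A minimal illustration of the preferred form:

from collections.abc import Callable  # preferred over typing.Callable

def apply(fn: Callable[[int], int], value: int) -> int:
    return fn(value)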
airflow/providers/amazon/aws/hooks/athena_sql.py
@@ -111,7 +111,14 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
                  connection.login = athena_conn.login
                  connection.password = athena_conn.password
                  connection.schema = athena_conn.schema
-                 connection.set_extra(json.dumps({**athena_conn.extra_dejson, **connection.extra_dejson}))
+                 merged_extra = {**athena_conn.extra_dejson, **connection.extra_dejson}
+                 try:
+                     extra_json = json.dumps(merged_extra)
+                     connection.extra = extra_json
+                 except (TypeError, ValueError):
+                     raise ValueError(
+                         f"Encountered non-JSON in `extra` field for connection {self.aws_conn_id!r}."
+                     )
              except AirflowNotFoundException:
                  connection = athena_conn
                  connection.conn_type = "aws"
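The new try/except exists because json.dumps rejects values that are not JSON-serializable by raising TypeError; the hook now surfaces that as a ValueError naming the connection. A small illustration (the object() value below is just a stand-in for a non-serializable extra):

import json

json.dumps({"work_group": "primary"})  # fine: JSON-native types only

try:
    json.dumps({"session": object()})  # stand-in for a non-serializable value
except TypeError as exc:
    print(f"extras are not JSON-serializable: {exc}")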
@@ -120,7 +127,10 @@ class AthenaSQLHook(AwsBaseHook, DbApiHook):
                  )
  
          return AwsConnectionWrapper(
-             conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
+             conn=connection,
+             region_name=self._region_name,
+             botocore_config=self._config,
+             verify=self._verify,
          )
  
      @property
airflow/providers/amazon/aws/hooks/base_aws.py
@@ -31,10 +31,11 @@ import json
  import logging
  import os
  import warnings
+ from collections.abc import Callable
  from copy import deepcopy
  from functools import cached_property, wraps
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
+ from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union
  
  import boto3
  import botocore
@@ -43,6 +44,8 @@ import jinja2
  import requests
  import tenacity
  from asgiref.sync import sync_to_async
+ from boto3.resources.base import ServiceResource
+ from botocore.client import BaseClient
  from botocore.config import Config
  from botocore.waiter import Waiter, WaiterModel
  from dateutil.tz import tzlocal
@@ -54,16 +57,29 @@ from airflow.exceptions import (
      AirflowNotFoundException,
      AirflowProviderDeprecationWarning,
  )
- from airflow.hooks.base import BaseHook
  from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
  from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
  from airflow.providers.amazon.aws.utils.suppress import return_on_error
  from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS
  from airflow.providers_manager import ProvidersManager
+ 
+ try:
+     from airflow.sdk import BaseHook
+ except ImportError:
+     from airflow.hooks.base import BaseHook  # type: ignore[attr-defined,no-redef]
  from airflow.utils.helpers import exactly_one
  from airflow.utils.log.logging_mixin import LoggingMixin
  
- BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])
+ # We need to set typeignore, sadly without it Sphinx build and mypy don't agree.
+ # ideally the code should be:
+ # BaseAwsConnection = TypeVar("BaseAwsConnection", bound=BaseClient | ServiceResource)
+ # but if we do that Sphinx complains about:
+ # TypeError: unsupported operand type(s) for |: 'BaseClient' and 'ServiceResource'
+ # If we change to Union syntax then mypy is not happy with UP007 Use `X | Y` for type annotations
+ # The only way to workaround it for now is to keep the union syntax with ignore for mypy
+ # We should try to resolve this later.
+ BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[BaseClient, ServiceResource])  # type: ignore[operator]  # noqa: UP007
+ 
  
  if AIRFLOW_V_3_0_PLUS:
      from airflow.sdk.exceptions import AirflowRuntimeError
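BaseHook is now imported from the Airflow 3 Task SDK first, with a fallback to the legacy location for Airflow 2. A minimal sketch of the same compatibility pattern in downstream code (the MyServiceHook class and conn_id are illustrative, not part of the provider):

try:
    from airflow.sdk import BaseHook  # Airflow 3 Task SDK location
except ImportError:
    from airflow.hooks.base import BaseHook  # type: ignore[attr-defined,no-redef]  # Airflow 2 fallback

class MyServiceHook(BaseHook):
    """Illustrative hook that works on both Airflow 2 and Airflow 3."""

    def get_conn(self):
        return self.get_connection("my_service_default")  # hypothetical conn_id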
@@ -627,7 +643,10 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
              raise
  
          return AwsConnectionWrapper(
-             conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
+             conn=connection,  # type: ignore[arg-type]
+             region_name=self._region_name,
+             botocore_config=self._config,
+             verify=self._verify,
          )
  
      def _resolve_service_name(self, is_resource_type: bool = False) -> str:
@@ -1038,7 +1057,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
          return WaiterModel(model_config).waiter_names
  
  
- class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
+ class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):  # type: ignore[operator]  # noqa: UP007
      """
      Base class for interact with AWS.
  
airflow/providers/amazon/aws/hooks/batch_client.py
@@ -30,7 +30,8 @@ from __future__ import annotations
  import itertools
  import random
  import time
- from typing import TYPE_CHECKING, Callable, Protocol, runtime_checkable
+ from collections.abc import Callable
+ from typing import TYPE_CHECKING, Protocol, runtime_checkable
  
  import botocore.client
  import botocore.exceptions
airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -28,9 +28,10 @@ from __future__ import annotations
  
  import json
  import sys
+ from collections.abc import Callable
  from copy import deepcopy
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable
+ from typing import TYPE_CHECKING, Any
  
  import botocore.client
  import botocore.exceptions
airflow/providers/amazon/aws/hooks/chime.py
@@ -66,9 +66,13 @@ class ChimeWebhookHook(HttpHook):
          :return: Endpoint(str) for chime webhook.
          """
          conn = self.get_connection(conn_id)
-         token = conn.get_password()
+         token = conn.password
          if token is None:
              raise AirflowException("Webhook token field is missing and is required.")
+         if not conn.schema:
+             raise AirflowException("Webook schema field is missing and is required")
+         if not conn.host:
+             raise AirflowException("Webhook host field is missing and is required.")
          url = conn.schema + "://" + conn.host
          endpoint = url + token
          # Check to make sure the endpoint matches what Chime expects
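With the added checks, a Chime webhook connection must provide a schema, a host, and a password (the webhook token); the endpoint is then built as schema://host followed by the token. An illustrative connection that passes the validation (every value below is a placeholder):

from airflow.models import Connection

conn = Connection(
    conn_id="chime_webhook_default",  # placeholder conn_id
    conn_type="chime",
    schema="https",
    host="hooks.chime.aws/incomingwebhooks/",
    password="abcd1234?token=example-token",  # placeholder webhook token
)
endpoint = conn.schema + "://" + conn.host + conn.password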
airflow/providers/amazon/aws/hooks/ec2.py
@@ -19,7 +19,8 @@ from __future__ import annotations
  
  import functools
  import time
- from typing import Callable, TypeVar
+ from collections.abc import Callable
+ from typing import TypeVar
  
  from airflow.exceptions import AirflowException
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
airflow/providers/amazon/aws/hooks/eks.py
@@ -23,11 +23,10 @@ import json
  import os
  import sys
  import tempfile
- from collections.abc import Generator
+ from collections.abc import Callable, Generator
  from contextlib import contextmanager
  from enum import Enum
  from functools import partial
- from typing import Callable
  
  from botocore.exceptions import ClientError
  from botocore.signers import RequestSigner
airflow/providers/amazon/aws/hooks/glue.py
@@ -24,6 +24,14 @@ from functools import cached_property
  from typing import Any
  
  from botocore.exceptions import ClientError
+ from tenacity import (
+     AsyncRetrying,
+     Retrying,
+     before_sleep_log,
+     retry_if_exception,
+     stop_after_attempt,
+     wait_exponential,
+ )
  
  from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -46,11 +54,11 @@ class GlueJobHook(AwsBaseHook):
      :param script_location: path to etl script on s3
      :param retry_limit: Maximum number of times to retry this job if it fails
      :param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job
-     :param region_name: aws region name (example: us-east-1)
      :param iam_role_name: AWS IAM Role for Glue Job Execution. If set `iam_role_arn` must equal None.
      :param iam_role_arn: AWS IAM Role ARN for Glue Job Execution, If set `iam_role_name` must equal None.
      :param create_job_kwargs: Extra arguments for Glue Job Creation
      :param update_config: Update job configuration on Glue (default: False)
+     :param api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes.
  
      Additional arguments (such as ``aws_conn_id``) may be specified and
      are passed down to the underlying AwsBaseHook.
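The new api_retry_args dictionary is merged over the hook's default tenacity settings (shown in the __init__ hunk below), so callers can override individual keys. A hedged usage sketch (the job name is a placeholder and only the keys shown are overridden):

from tenacity import stop_after_attempt, wait_exponential

from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

hook = GlueJobHook(
    job_name="example-glue-job",  # placeholder job name
    api_retry_args={
        "stop": stop_after_attempt(10),                        # retry up to 10 attempts
        "wait": wait_exponential(multiplier=2, min=2, max=120),
    },
)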
@@ -80,6 +88,7 @@ class GlueJobHook(AwsBaseHook):
          create_job_kwargs: dict | None = None,
          update_config: bool = False,
          job_poll_interval: int | float = 6,
+         api_retry_args: dict[Any, Any] | None = None,
          *args,
          **kwargs,
      ):
@@ -96,6 +105,17 @@ class GlueJobHook(AwsBaseHook):
          self.update_config = update_config
          self.job_poll_interval = job_poll_interval
  
+         self.retry_config: dict[str, Any] = {
+             "retry": retry_if_exception(self._should_retry_on_error),
+             "wait": wait_exponential(multiplier=1, min=1, max=60),
+             "stop": stop_after_attempt(5),
+             "before_sleep": before_sleep_log(self.log, log_level=20),
+             "reraise": True,
+         }
+ 
+         if api_retry_args:
+             self.retry_config.update(api_retry_args)
+ 
          worker_type_exists = "WorkerType" in self.create_job_kwargs
          num_workers_exists = "NumberOfWorkers" in self.create_job_kwargs
  
@@ -116,6 +136,29 @@ class GlueJobHook(AwsBaseHook):
          kwargs["client_type"] = "glue"
          super().__init__(*args, **kwargs)
  
+     def _should_retry_on_error(self, exception: BaseException) -> bool:
+         """
+         Determine if an exception should trigger a retry.
+ 
+         :param exception: The exception that occurred
+         :return: True if the exception should trigger a retry, False otherwise
+         """
+         if isinstance(exception, ClientError):
+             error_code = exception.response.get("Error", {}).get("Code", "")
+             retryable_errors = {
+                 "ThrottlingException",
+                 "RequestLimitExceeded",
+                 "ServiceUnavailable",
+                 "InternalFailure",
+                 "InternalServerError",
+                 "TooManyRequestsException",
+                 "RequestTimeout",
+                 "RequestTimeoutException",
+                 "HttpTimeoutException",
+             }
+             return error_code in retryable_errors
+         return False
+ 
      def create_glue_job_config(self) -> dict:
          default_command = {
              "Name": "glueetl",
@@ -217,8 +260,21 @@ class GlueJobHook(AwsBaseHook):
          :param run_id: The job-run ID of the predecessor job run
          :return: State of the Glue job
          """
-         job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
-         return job_run["JobRun"]["JobRunState"]
+         for attempt in Retrying(**self.retry_config):
+             with attempt:
+                 try:
+                     job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
+                     return job_run["JobRun"]["JobRunState"]
+                 except ClientError as e:
+                     self.log.error("Failed to get job state for job %s run %s: %s", job_name, run_id, e)
+                     raise
+                 except Exception as e:
+                     self.log.error(
+                         "Unexpected error getting job state for job %s run %s: %s", job_name, run_id, e
+                     )
+                     raise
+         # This should never be reached due to reraise=True, but mypy needs it
+         raise RuntimeError("Unexpected end of retry loop")
  
      async def async_get_job_state(self, job_name: str, run_id: str) -> str:
          """
@@ -226,9 +282,22 @@ class GlueJobHook(AwsBaseHook):
  
          The async version of get_job_state.
          """
-         async with await self.get_async_conn() as client:
-             job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
-             return job_run["JobRun"]["JobRunState"]
+         async for attempt in AsyncRetrying(**self.retry_config):
+             with attempt:
+                 try:
+                     async with await self.get_async_conn() as client:
+                         job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
+                         return job_run["JobRun"]["JobRunState"]
+                 except ClientError as e:
+                     self.log.error("Failed to get job state for job %s run %s: %s", job_name, run_id, e)
+                     raise
+                 except Exception as e:
+                     self.log.error(
+                         "Unexpected error getting job state for job %s run %s: %s", job_name, run_id, e
+                     )
+                     raise
+         # This should never be reached due to reraise=True, but mypy needs it
+         raise RuntimeError("Unexpected end of retry loop")
  
      @cached_property
      def logs_hook(self):
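The retry predicate used by both loops inspects the service error code that botocore attaches to ClientError under response["Error"]["Code"]. A small illustration of that lookup (the error payload below is fabricated for the example):

from botocore.exceptions import ClientError

error = ClientError(
    error_response={"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}},
    operation_name="GetJobRun",
)
code = error.response.get("Error", {}).get("Code", "")
print(code)                                                     # ThrottlingException
print(code in {"ThrottlingException", "RequestLimitExceeded"})  # True -> retry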
@@ -372,7 +441,7 @@ class GlueJobHook(AwsBaseHook):
          )
          return None
  
-     def has_job(self, job_name) -> bool:
+     def has_job(self, job_name: str) -> bool:
          """
          Check if the job already exists.
  
@@ -422,6 +491,9 @@ class GlueJobHook(AwsBaseHook):
  
          :return:Name of the Job
          """
+         if self.job_name is None:
+             raise ValueError("job_name must be set to get or create a Glue job")
+ 
          if self.has_job(self.job_name):
              return self.job_name
  
@@ -441,6 +513,9 @@ class GlueJobHook(AwsBaseHook):
  
          :return:Name of the Job
          """
+         if self.job_name is None:
+             raise ValueError("job_name must be set to create or update a Glue job")
+ 
          config = self.create_glue_job_config()
  
          if self.has_job(self.job_name):
airflow/providers/amazon/aws/hooks/rds.py
@@ -20,7 +20,8 @@
  from __future__ import annotations
  
  import time
- from typing import TYPE_CHECKING, Callable
+ from collections.abc import Callable
+ from typing import TYPE_CHECKING
  
  from airflow.exceptions import AirflowException, AirflowNotFoundException
  from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
airflow/providers/amazon/aws/hooks/s3.py
@@ -28,7 +28,7 @@ import os
  import re
  import shutil
  import time
- from collections.abc import AsyncIterator
+ from collections.abc import AsyncIterator, Callable
  from contextlib import suppress
  from copy import deepcopy
  from datetime import datetime
@@ -37,7 +37,7 @@ from inspect import signature
  from io import BytesIO
  from pathlib import Path
  from tempfile import NamedTemporaryFile, gettempdir
- from typing import TYPE_CHECKING, Any, Callable
+ from typing import TYPE_CHECKING, Any
  from urllib.parse import urlsplit
  from uuid import uuid4
  
airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -23,10 +23,10 @@ import tarfile
  import tempfile
  import time
  from collections import Counter, namedtuple
- from collections.abc import AsyncGenerator, Generator
+ from collections.abc import AsyncGenerator, Callable, Generator
  from datetime import datetime
  from functools import partial
- from typing import Any, Callable, cast
+ from typing import Any, cast
  
  from asgiref.sync import sync_to_async
  from botocore.exceptions import ClientError
airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py
@@ -25,9 +25,13 @@ from sagemaker_studio import ClientConfig
  from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI
  
  from airflow.exceptions import AirflowException
- from airflow.hooks.base import BaseHook
  from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner
  
+ try:
+     from airflow.sdk import BaseHook
+ except ImportError:
+     from airflow.hooks.base import BaseHook  # type: ignore[attr-defined,no-redef]
+ 
  
  class SageMakerNotebookHook(BaseHook):
      """
airflow/providers/amazon/aws/links/base_aws.py
@@ -20,20 +20,13 @@ from __future__ import annotations
  from typing import TYPE_CHECKING, ClassVar
  
  from airflow.providers.amazon.aws.utils.suppress import return_on_error
- from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+ from airflow.providers.amazon.version_compat import BaseOperatorLink, XCom
  
  if TYPE_CHECKING:
      from airflow.models import BaseOperator
      from airflow.models.taskinstancekey import TaskInstanceKey
      from airflow.utils.context import Context
  
- if AIRFLOW_V_3_0_PLUS:
-     from airflow.sdk import BaseOperatorLink
-     from airflow.sdk.execution_time.xcom import XCom
- else:
-     from airflow.models import XCom  # type: ignore[no-redef]
-     from airflow.models.baseoperatorlink import BaseOperatorLink  # type: ignore[no-redef]
- 
  
  
  BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
@@ -94,8 +87,7 @@ class BaseAwsLink(BaseOperatorLink):
          if not operator.do_xcom_push:
              return
  
-         operator.xcom_push(
-             context,
+         context["ti"].xcom_push(
              key=cls.key,
              value={
                  "region_name": region_name,
airflow/providers/amazon/aws/operators/base_aws.py
@@ -19,13 +19,13 @@ from __future__ import annotations
  
  from collections.abc import Sequence
  
- from airflow.models import BaseOperator
  from airflow.providers.amazon.aws.utils.mixins import (
      AwsBaseHookMixin,
      AwsHookParams,
      AwsHookType,
      aws_template_fields,
  )
+ from airflow.providers.amazon.version_compat import BaseOperator
  from airflow.utils.types import NOTSET, ArgNotSet
  
  
airflow/providers/amazon/aws/operators/batch.py
@@ -140,28 +140,12 @@ class BatchOperator(AwsBaseOperator[BatchClientHook]):
          "retry_strategy": "json",
      }
  
-     @property
-     def operator_extra_links(self):
-         op_extra_links = [BatchJobDetailsLink()]
- 
-         if self.is_mapped:
-             wait_for_completion = self.partial_kwargs.get(
-                 "wait_for_completion"
-             ) or self.expand_input.value.get("wait_for_completion")
-             array_properties = self.partial_kwargs.get("array_properties") or self.expand_input.value.get(
-                 "array_properties"
-             )
-         else:
-             wait_for_completion = self.wait_for_completion
-             array_properties = self.array_properties
- 
-         if wait_for_completion:
-             op_extra_links.extend([BatchJobDefinitionLink(), BatchJobQueueLink()])
-             if not array_properties:
-                 # There is no CloudWatch Link to the parent Batch Job available.
-                 op_extra_links.append(CloudWatchEventsLink())
- 
-         return tuple(op_extra_links)
+     operator_extra_links = (
+         BatchJobDetailsLink(),
+         BatchJobDefinitionLink(),
+         BatchJobQueueLink(),
+         CloudWatchEventsLink(),
+     )
  
      def __init__(
          self,
airflow/providers/amazon/aws/operators/ecs.py
@@ -526,7 +526,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
          self._start_task()
  
          if self.do_xcom_push:
-             self.xcom_push(context, key="ecs_task_arn", value=self.arn)
+             context["ti"].xcom_push(key="ecs_task_arn", value=self.arn)
  
          if self.deferrable:
              self.defer(
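Both the operator-link hunk above and this ECS hunk replace operator.xcom_push(context, ...) with a push through the runtime TaskInstance taken from the execution context. A hedged sketch of the same call pattern from inside an operator's execute (the operator class, key, and payload are illustrative):

from airflow.models import BaseOperator

class ExampleAwsLinkOperator(BaseOperator):  # illustrative operator
    def execute(self, context):
        # Push link metadata via the TaskInstance in the execution context.
        context["ti"].xcom_push(
            key="aws_console_link",  # illustrative key
            value={"region_name": "us-east-1", "aws_partition": "aws"},
        )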