apache-airflow-providers-google 10.17.0rc1__py3-none-any.whl → 10.18.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to the supported registries. It is provided for informational purposes only.
- airflow/providers/google/__init__.py +5 -8
- airflow/providers/google/cloud/hooks/automl.py +35 -1
- airflow/providers/google/cloud/hooks/bigquery.py +126 -41
- airflow/providers/google/cloud/hooks/cloud_composer.py +250 -2
- airflow/providers/google/cloud/hooks/cloud_sql.py +154 -7
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +7 -2
- airflow/providers/google/cloud/hooks/compute_ssh.py +2 -1
- airflow/providers/google/cloud/hooks/dataflow.py +246 -32
- airflow/providers/google/cloud/hooks/dataplex.py +6 -2
- airflow/providers/google/cloud/hooks/dlp.py +14 -14
- airflow/providers/google/cloud/hooks/gcs.py +6 -2
- airflow/providers/google/cloud/hooks/gdm.py +2 -2
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/hooks/mlengine.py +8 -4
- airflow/providers/google/cloud/hooks/pubsub.py +1 -1
- airflow/providers/google/cloud/hooks/secret_manager.py +252 -4
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1431 -74
- airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +91 -0
- airflow/providers/google/cloud/links/vertex_ai.py +2 -1
- airflow/providers/google/cloud/log/gcs_task_handler.py +2 -1
- airflow/providers/google/cloud/operators/automl.py +243 -37
- airflow/providers/google/cloud/operators/bigquery.py +164 -62
- airflow/providers/google/cloud/operators/bigquery_dts.py +4 -3
- airflow/providers/google/cloud/operators/bigtable.py +7 -6
- airflow/providers/google/cloud/operators/cloud_build.py +12 -11
- airflow/providers/google/cloud/operators/cloud_composer.py +147 -2
- airflow/providers/google/cloud/operators/cloud_memorystore.py +17 -16
- airflow/providers/google/cloud/operators/cloud_sql.py +60 -17
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +35 -16
- airflow/providers/google/cloud/operators/compute.py +12 -11
- airflow/providers/google/cloud/operators/datacatalog.py +21 -20
- airflow/providers/google/cloud/operators/dataflow.py +59 -42
- airflow/providers/google/cloud/operators/datafusion.py +11 -10
- airflow/providers/google/cloud/operators/datapipeline.py +3 -2
- airflow/providers/google/cloud/operators/dataprep.py +5 -4
- airflow/providers/google/cloud/operators/dataproc.py +20 -17
- airflow/providers/google/cloud/operators/datastore.py +8 -7
- airflow/providers/google/cloud/operators/dlp.py +31 -30
- airflow/providers/google/cloud/operators/functions.py +4 -3
- airflow/providers/google/cloud/operators/gcs.py +66 -41
- airflow/providers/google/cloud/operators/kubernetes_engine.py +256 -49
- airflow/providers/google/cloud/operators/life_sciences.py +2 -1
- airflow/providers/google/cloud/operators/mlengine.py +11 -10
- airflow/providers/google/cloud/operators/pubsub.py +6 -5
- airflow/providers/google/cloud/operators/spanner.py +7 -6
- airflow/providers/google/cloud/operators/speech_to_text.py +2 -1
- airflow/providers/google/cloud/operators/stackdriver.py +11 -10
- airflow/providers/google/cloud/operators/tasks.py +14 -13
- airflow/providers/google/cloud/operators/text_to_speech.py +2 -1
- airflow/providers/google/cloud/operators/translate_speech.py +2 -1
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +333 -26
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +20 -12
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +0 -1
- airflow/providers/google/cloud/operators/vision.py +13 -12
- airflow/providers/google/cloud/operators/workflows.py +12 -14
- airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -1
- airflow/providers/google/cloud/sensors/bigtable.py +2 -1
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -1
- airflow/providers/google/cloud/sensors/dataflow.py +239 -52
- airflow/providers/google/cloud/sensors/datafusion.py +2 -1
- airflow/providers/google/cloud/sensors/dataproc.py +3 -2
- airflow/providers/google/cloud/sensors/gcs.py +14 -12
- airflow/providers/google/cloud/sensors/tasks.py +2 -1
- airflow/providers/google/cloud/sensors/workflows.py +2 -1
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +8 -2
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +7 -1
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +7 -1
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -1
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +5 -6
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +22 -12
- airflow/providers/google/cloud/triggers/bigquery.py +75 -6
- airflow/providers/google/cloud/triggers/cloud_composer.py +68 -0
- airflow/providers/google/cloud/triggers/cloud_sql.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -1
- airflow/providers/google/cloud/triggers/dataflow.py +504 -4
- airflow/providers/google/cloud/triggers/dataproc.py +190 -27
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -3
- airflow/providers/google/cloud/triggers/mlengine.py +2 -1
- airflow/providers/google/cloud/triggers/vertex_ai.py +94 -0
- airflow/providers/google/common/hooks/base_google.py +45 -7
- airflow/providers/google/firebase/hooks/firestore.py +2 -2
- airflow/providers/google/firebase/operators/firestore.py +2 -1
- airflow/providers/google/get_provider_info.py +5 -3
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0.dist-info}/METADATA +18 -18
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0.dist-info}/RECORD +90 -90
- airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +0 -289
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/triggers/dataproc.py

```diff
@@ -22,14 +22,22 @@ from __future__ import annotations
 import asyncio
 import re
 import time
-from typing import Any, AsyncIterator, Sequence
+from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence
 
 from google.api_core.exceptions import NotFound
-from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 
-from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
+from airflow.exceptions import AirflowException
+from airflow.models.taskinstance import TaskInstance
+from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.session import provide_session
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
 
 
 class DataprocBaseTrigger(BaseTrigger):
@@ -38,10 +46,12 @@ class DataprocBaseTrigger(BaseTrigger):
     def __init__(
         self,
         region: str,
-        project_id: str
+        project_id: str = PROVIDE_PROJECT_ID,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         polling_interval_seconds: int = 30,
+        cancel_on_kill: bool = True,
+        delete_on_error: bool = True,
     ):
         super().__init__()
         self.region = region
@@ -49,6 +59,8 @@ class DataprocBaseTrigger(BaseTrigger):
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.polling_interval_seconds = polling_interval_seconds
+        self.cancel_on_kill = cancel_on_kill
+        self.delete_on_error = delete_on_error
 
     def get_async_hook(self):
         return DataprocAsyncHook(
@@ -56,6 +68,16 @@ class DataprocBaseTrigger(BaseTrigger):
             impersonation_chain=self.impersonation_chain,
         )
 
+    def get_sync_hook(self):
+        # The synchronous hook is utilized to delete the cluster when a task is cancelled.
+        # This is because the asynchronous hook deletion is not awaited when the trigger task
+        # is cancelled. The call for deleting the cluster or job through the sync hook is not a blocking
+        # call, which means it does not wait until the cluster or job is deleted.
+        return DataprocHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
 
 class DataprocSubmitTrigger(DataprocBaseTrigger):
     """
@@ -90,20 +112,78 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
                 "polling_interval_seconds": self.polling_interval_seconds,
+                "cancel_on_kill": self.cancel_on_kill,
             },
         )
 
-    async def run(self):
-        while True:
-            job = await self.get_async_hook().get_job(
-                project_id=self.project_id, region=self.region, job_id=self.job_id
+    @provide_session
+    def get_task_instance(self, session: Session) -> TaskInstance:
+        """
+        Get the task instance for the current task.
+
+        :param session: Sqlalchemy session
+        """
+        query = session.query(TaskInstance).filter(
+            TaskInstance.dag_id == self.task_instance.dag_id,
+            TaskInstance.task_id == self.task_instance.task_id,
+            TaskInstance.run_id == self.task_instance.run_id,
+            TaskInstance.map_index == self.task_instance.map_index,
+        )
+        task_instance = query.one_or_none()
+        if task_instance is None:
+            raise AirflowException(
+                "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
+                self.task_instance.dag_id,
+                self.task_instance.task_id,
+                self.task_instance.run_id,
+                self.task_instance.map_index,
             )
-            state = job.status.state
-            self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
-            if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
-                break
-            await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job})
+        return task_instance
+
+    def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the external job which is being executed by this trigger.
+
+        This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
+        Because in those cases, we should NOT cancel the external job.
+        """
+        # Database query is needed to get the latest state of the task instance.
+        task_instance = self.get_task_instance()  # type: ignore[call-arg]
+        return task_instance.state != TaskInstanceState.DEFERRED
+
+    async def run(self):
+        try:
+            while True:
+                job = await self.get_async_hook().get_job(
+                    project_id=self.project_id, region=self.region, job_id=self.job_id
+                )
+                state = job.status.state
+                self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
+                if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
+                    break
+                await asyncio.sleep(self.polling_interval_seconds)
+            yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job})
+        except asyncio.CancelledError:
+            self.log.info("Task got cancelled.")
+            try:
+                if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                    self.log.info(
+                        "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
+                        " in deferred state."
+                    )
+                    self.log.info("Cancelling the job: %s", self.job_id)
+                    # The synchronous hook is utilized to delete the cluster when a task is cancelled. This
+                    # is because the asynchronous hook deletion is not awaited when the trigger task is
+                    # cancelled. The call for deleting the cluster or job through the sync hook is not a
+                    # blocking call, which means it does not wait until the cluster or job is deleted.
+                    self.get_sync_hook().cancel_job(
+                        job_id=self.job_id, project_id=self.project_id, region=self.region
+                    )
+                    self.log.info("Job: %s is cancelled", self.job_id)
+                    yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING})
+            except Exception as e:
+                self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e))
+                raise e
 
 
 class DataprocClusterTrigger(DataprocBaseTrigger):
```
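`safe_to_cancel` exists because a trigger coroutine receives `asyncio.CancelledError` both when the user actually kills the task and when the triggerer process merely shuts down or hands the deferred task over; only in the first case should the external Dataproc job be cancelled. The check is a plain ORM lookup of the task instance's current state. A standalone sketch of that kind of session-injected query, with hypothetical identifiers and a hypothetical helper name:

```python
from __future__ import annotations

from airflow.models.taskinstance import TaskInstance
from airflow.utils.session import provide_session
from airflow.utils.state import TaskInstanceState


@provide_session
def has_left_deferred_state(
    dag_id: str, task_id: str, run_id: str, map_index: int = -1, session=None
) -> bool:
    """Hypothetical helper: True when the task instance is no longer DEFERRED."""
    # provide_session injects a SQLAlchemy session when the caller does not pass one.
    ti = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.dag_id == dag_id,
            TaskInstance.task_id == task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index == map_index,
        )
        .one_or_none()
    )
    return ti is not None and ti.state != TaskInstanceState.DEFERRED
```

`DataprocClusterTrigger` applies the same check before deleting a cluster on cancellation: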
```diff
@@ -139,24 +219,107 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
                 "polling_interval_seconds": self.polling_interval_seconds,
+                "delete_on_error": self.delete_on_error,
             },
         )
 
+    @provide_session
+    def get_task_instance(self, session: Session) -> TaskInstance:
+        query = session.query(TaskInstance).filter(
+            TaskInstance.dag_id == self.task_instance.dag_id,
+            TaskInstance.task_id == self.task_instance.task_id,
+            TaskInstance.run_id == self.task_instance.run_id,
+            TaskInstance.map_index == self.task_instance.map_index,
+        )
+        task_instance = query.one_or_none()
+        if task_instance is None:
+            raise AirflowException(
+                "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.",
+                self.task_instance.dag_id,
+                self.task_instance.task_id,
+                self.task_instance.run_id,
+                self.task_instance.map_index,
+            )
+        return task_instance
+
+    def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the external job which is being executed by this trigger.
+
+        This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
+        Because in those cases, we should NOT cancel the external job.
+        """
+        # Database query is needed to get the latest state of the task instance.
+        task_instance = self.get_task_instance()  # type: ignore[call-arg]
+        return task_instance.state != TaskInstanceState.DEFERRED
+
     async def run(self) -> AsyncIterator[TriggerEvent]:
-        while True:
-            cluster = await self.get_async_hook().get_cluster(
-                project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
+        try:
+            while True:
+                cluster = await self.fetch_cluster()
+                state = cluster.status.state
+                if state == ClusterStatus.State.ERROR:
+                    await self.delete_when_error_occurred(cluster)
+                    yield TriggerEvent(
+                        {
+                            "cluster_name": self.cluster_name,
+                            "cluster_state": ClusterStatus.State.DELETING,
+                            "cluster": cluster,
+                        }
+                    )
+                    return
+                elif state == ClusterStatus.State.RUNNING:
+                    yield TriggerEvent(
+                        {
+                            "cluster_name": self.cluster_name,
+                            "cluster_state": state,
+                            "cluster": cluster,
+                        }
+                    )
+                    return
+                self.log.info("Current state is %s", state)
+                self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
+                await asyncio.sleep(self.polling_interval_seconds)
+        except asyncio.CancelledError:
+            try:
+                if self.delete_on_error and self.safe_to_cancel():
+                    self.log.info(
+                        "Deleting the cluster as it is safe to delete as the airflow TaskInstance is not in "
+                        "deferred state."
+                    )
+                    self.log.info("Deleting cluster %s.", self.cluster_name)
+                    # The synchronous hook is utilized to delete the cluster when a task is cancelled.
+                    # This is because the asynchronous hook deletion is not awaited when the trigger task
+                    # is cancelled. The call for deleting the cluster through the sync hook is not a blocking
+                    # call, which means it does not wait until the cluster is deleted.
+                    self.get_sync_hook().delete_cluster(
+                        region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
+                    )
+                    self.log.info("Deleted cluster %s during cancellation.", self.cluster_name)
+            except Exception as e:
+                self.log.error("Error during cancellation handling: %s", e)
+                raise AirflowException("Error during cancellation handling: %s", e)
+
+    async def fetch_cluster(self) -> Cluster:
+        """Fetch the cluster status."""
+        return await self.get_async_hook().get_cluster(
+            project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
+        )
+
+    async def delete_when_error_occurred(self, cluster: Cluster) -> None:
+        """
+        Delete the cluster on error.
+
+        :param cluster: The cluster to delete.
+        """
+        if self.delete_on_error:
+            self.log.info("Deleting cluster %s.", self.cluster_name)
+            await self.get_async_hook().delete_cluster(
+                region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
             )
-            state = cluster.status.state
-            self.log.info("Current state is %s", state)
-            if state in (
-                ClusterStatus.State.ERROR,
-                ClusterStatus.State.RUNNING,
-            ):
-                break
-            self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
-            await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent({"cluster_name": self.cluster_name, "cluster_state": state, "cluster": cluster})
+            self.log.info("Cluster %s has been deleted.", self.cluster_name)
+        else:
+            self.log.info("Cluster %s is not deleted as delete_on_error is set to False.", self.cluster_name)
 
 
 class DataprocBatchTrigger(DataprocBaseTrigger):
```
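In both triggers the cleanup itself must be synchronous: once the trigger's coroutine has been cancelled, any further `await` on the async hook is not guaranteed to run to completion, so the sync hook fires a non-blocking cancel/delete request instead. A minimal sketch of that pattern outside Airflow, where `poll_state`, `cancel_job`, and `is_safe_to_cancel` are placeholders rather than provider APIs:

```python
import asyncio


async def poll_state(job_id: str) -> str:
    """Placeholder for the async hook's polling call (e.g. get_job)."""
    await asyncio.sleep(0.1)
    return "DONE"


def cancel_job(job_id: str) -> None:
    """Placeholder for the sync hook's cancel_job(); it only requests cancellation."""
    print(f"cancel requested for {job_id}")


def is_safe_to_cancel() -> bool:
    """Placeholder for safe_to_cancel(): False while the task instance is still deferred."""
    return True


async def watch(job_id: str, cancel_on_kill: bool = True) -> str:
    try:
        while (state := await poll_state(job_id)) not in {"DONE", "ERROR", "CANCELLED"}:
            await asyncio.sleep(1)
        return state
    except asyncio.CancelledError:
        # Cleanup must not await the external service: once this coroutine is
        # cancelled, further awaits are not guaranteed to run to completion.
        if cancel_on_kill and is_safe_to_cancel():
            cancel_job(job_id)
        raise


print(asyncio.run(watch("job-123")))
```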
airflow/providers/google/cloud/triggers/kubernetes_engine.py

```diff
@@ -30,7 +30,6 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
 from airflow.providers.google.cloud.hooks.kubernetes_engine import (
     GKEAsyncHook,
     GKEKubernetesAsyncHook,
-    GKEPodAsyncHook,
 )
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
@@ -147,8 +146,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
         )
 
     @cached_property
-    def hook(self) -> GKEPodAsyncHook:  # type: ignore[override]
-        return GKEPodAsyncHook(
+    def hook(self) -> GKEKubernetesAsyncHook:  # type: ignore[override]
+        return GKEKubernetesAsyncHook(
             cluster_url=self._cluster_url,
             ssl_ca_cert=self._ssl_ca_cert,
             gcp_conn_id=self.gcp_conn_id,
```
airflow/providers/google/cloud/triggers/mlengine.py

```diff
@@ -20,6 +20,7 @@ import asyncio
 from typing import Any, AsyncIterator, Sequence
 
 from airflow.providers.google.cloud.hooks.mlengine import MLEngineAsyncHook
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -45,7 +46,7 @@ class MLEngineStartTrainingJobTrigger(BaseTrigger):
         runtime_version: str | None = None,
         python_version: str | None = None,
         job_dir: str | None = None,
-        project_id: str
+        project_id: str = PROVIDE_PROJECT_ID,
         labels: dict[str, str] | None = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
```
airflow/providers/google/cloud/triggers/vertex_ai.py

```diff
@@ -29,6 +29,7 @@ from google.cloud.aiplatform_v1 import (
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobAsyncHook
+from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook
 from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
     HyperparameterTuningJobAsyncHook,
 )
@@ -189,3 +190,96 @@ class RunPipelineJobTrigger(BaseVertexAIJobTrigger):
             poll_interval=self.poll_interval,
         )
         return job
+
+
+class CustomTrainingJobTrigger(BaseVertexAIJobTrigger):
+    """
+    Make async calls to Vertex AI to check the state of a running custom training job.
+
+    Return the job when it enters a completed state.
+    """
+
+    job_type_verbose_name = "Custom Training Job"
+    job_serializer_class = types.TrainingPipeline
+    statuses_success = {
+        PipelineState.PIPELINE_STATE_PAUSED,
+        PipelineState.PIPELINE_STATE_SUCCEEDED,
+    }
+
+    @cached_property
+    def async_hook(self) -> CustomJobAsyncHook:
+        return CustomJobAsyncHook(
+            gcp_conn_id=self.conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    async def _wait_job(self) -> types.TrainingPipeline:
+        pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline(
+            project_id=self.project_id,
+            location=self.location,
+            pipeline_id=self.job_id,
+            poll_interval=self.poll_interval,
+        )
+        return pipeline
+
+
+class CustomContainerTrainingJobTrigger(BaseVertexAIJobTrigger):
+    """
+    Make async calls to Vertex AI to check the state of a running custom container training job.
+
+    Return the job when it enters a completed state.
+    """
+
+    job_type_verbose_name = "Custom Container Training Job"
+    job_serializer_class = types.TrainingPipeline
+    statuses_success = {
+        PipelineState.PIPELINE_STATE_PAUSED,
+        PipelineState.PIPELINE_STATE_SUCCEEDED,
+    }
+
+    @cached_property
+    def async_hook(self) -> CustomJobAsyncHook:
+        return CustomJobAsyncHook(
+            gcp_conn_id=self.conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    async def _wait_job(self) -> types.TrainingPipeline:
+        pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline(
+            project_id=self.project_id,
+            location=self.location,
+            pipeline_id=self.job_id,
+            poll_interval=self.poll_interval,
+        )
+        return pipeline
+
+
+class CustomPythonPackageTrainingJobTrigger(BaseVertexAIJobTrigger):
+    """
+    Make async calls to Vertex AI to check the state of a running custom python package training job.
+
+    Return the job when it enters a completed state.
+    """
+
+    job_type_verbose_name = "Custom Python Package Training Job"
+    job_serializer_class = types.TrainingPipeline
+    statuses_success = {
+        PipelineState.PIPELINE_STATE_PAUSED,
+        PipelineState.PIPELINE_STATE_SUCCEEDED,
+    }
+
+    @cached_property
+    def async_hook(self) -> CustomJobAsyncHook:
+        return CustomJobAsyncHook(
+            gcp_conn_id=self.conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    async def _wait_job(self) -> types.TrainingPipeline:
+        pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline(
+            project_id=self.project_id,
+            location=self.location,
+            pipeline_id=self.job_id,
+            poll_interval=self.poll_interval,
+        )
+        return pipeline
```
airflow/providers/google/common/hooks/base_google.py

```diff
@@ -114,6 +114,19 @@ def is_operation_in_progress_exception(exception: Exception) -> bool:
     return False
 
 
+def is_refresh_credentials_exception(exception: Exception) -> bool:
+    """
+    Handle refresh credentials exceptions.
+
+    Some calls return 502 (server error) in case a new token cannot be obtained.
+
+    * Google BigQuery
+    """
+    if isinstance(exception, RefreshError):
+        return "Unable to acquire impersonated credentials" in str(exception)
+    return False
+
+
 class retry_if_temporary_quota(tenacity.retry_if_exception):
     """Retries if there was an exception for exceeding the temporary quote limit."""
 
@@ -122,12 +135,19 @@ class retry_if_temporary_quota(tenacity.retry_if_exception):
 
 
 class retry_if_operation_in_progress(tenacity.retry_if_exception):
-    """Retries if there was an exception
+    """Retries if there was an exception in case of operation in progress."""
 
     def __init__(self):
         super().__init__(is_operation_in_progress_exception)
 
 
+class retry_if_temporary_refresh_credentials(tenacity.retry_if_exception):
+    """Retries if there was an exception for refreshing credentials."""
+
+    def __init__(self):
+        super().__init__(is_refresh_credentials_exception)
+
+
 # A fake project_id to use in functions decorated by fallback_to_default_project_id
 # This allows the 'project_id' argument to be of type str instead of str | None,
 # making it easier to type hint the function body without dealing with the None
```
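`is_refresh_credentials_exception` is a plain predicate: it matches only `RefreshError`s whose message points at a transient failure to mint impersonated credentials, so unrelated auth errors still fail fast. The same behaviour can be reproduced with `tenacity` directly; a small sketch under those assumptions, where the decorated `fetch_token` is a stand-in for any call that refreshes credentials:

```python
import tenacity
from google.auth.exceptions import RefreshError


def is_refresh_credentials_exception(exception: Exception) -> bool:
    """Retry only the transient 'Unable to acquire impersonated credentials' failures."""
    if isinstance(exception, RefreshError):
        return "Unable to acquire impersonated credentials" in str(exception)
    return False


@tenacity.retry(
    retry=tenacity.retry_if_exception(is_refresh_credentials_exception),
    wait=tenacity.wait_exponential(multiplier=1, max=5),
    stop=tenacity.stop_after_attempt(3),
    reraise=True,
)
def fetch_token() -> str:
    # Stand-in for a call that may raise RefreshError while refreshing credentials.
    return "token"


print(fetch_token())
```

The remaining `base_google.py` hunks wire this predicate into `GoogleBaseHook` itself: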
```diff
@@ -364,14 +384,14 @@ class GoogleBaseHook(BaseHook):
         return hasattr(self, "extras") and get_field(self.extras, f) or default
 
     @property
-    def project_id(self) -> str
+    def project_id(self) -> str:
         """
         Returns project id.
 
         :return: id of the project
         """
         _, project_id = self.get_credentials_and_project_id()
-        return project_id
+        return project_id or PROVIDE_PROJECT_ID
 
     @property
     def num_retries(self) -> int:
@@ -426,7 +446,7 @@ class GoogleBaseHook(BaseHook):
     def quota_retry(*args, **kwargs) -> Callable:
         """Provide a mechanism to repeat requests in response to exceeding a temporary quota limit."""
 
-        def decorator(
+        def decorator(func: Callable):
             default_kwargs = {
                 "wait": tenacity.wait_exponential(multiplier=1, max=100),
                 "retry": retry_if_temporary_quota(),
@@ -434,7 +454,7 @@ class GoogleBaseHook(BaseHook):
                 "after": tenacity.after_log(log, logging.DEBUG),
             }
             default_kwargs.update(**kwargs)
-            return tenacity.retry(*args, **default_kwargs)(
+            return tenacity.retry(*args, **default_kwargs)(func)
 
         return decorator
 
@@ -442,7 +462,7 @@ class GoogleBaseHook(BaseHook):
     def operation_in_progress_retry(*args, **kwargs) -> Callable[[T], T]:
         """Provide a mechanism to repeat requests in response to operation in progress (HTTP 409) limit."""
 
-        def decorator(
+        def decorator(func: T):
             default_kwargs = {
                 "wait": tenacity.wait_exponential(multiplier=1, max=300),
                 "retry": retry_if_operation_in_progress(),
@@ -450,7 +470,25 @@ class GoogleBaseHook(BaseHook):
                 "after": tenacity.after_log(log, logging.DEBUG),
             }
             default_kwargs.update(**kwargs)
-            return cast(T, tenacity.retry(*args, **default_kwargs)(
+            return cast(T, tenacity.retry(*args, **default_kwargs)(func))
+
+        return decorator
+
+    @staticmethod
+    def refresh_credentials_retry(*args, **kwargs) -> Callable[[T], T]:
+        """Provide a mechanism to repeat requests in response to a temporary refresh credential issue."""
+
+        def decorator(func: T):
+            default_kwargs = {
+                "wait": tenacity.wait_exponential(multiplier=1, max=5),
+                "stop": tenacity.stop_after_attempt(3),
+                "retry": retry_if_temporary_refresh_credentials(),
+                "reraise": True,
+                "before": tenacity.before_log(log, logging.DEBUG),
+                "after": tenacity.after_log(log, logging.DEBUG),
+            }
+            default_kwargs.update(**kwargs)
+            return cast(T, tenacity.retry(*args, **default_kwargs)(func))
 
         return decorator
 
```
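`refresh_credentials_retry()` is exposed as a decorator factory, like `quota_retry()` and `operation_in_progress_retry()`, and `PROVIDE_PROJECT_ID` remains the sentinel that lets `project_id` keep the type `str` while still falling back to the connection's default project. A hypothetical hook method combining the two (the class and method names are illustrative, not part of the provider):

```python
from __future__ import annotations

from airflow.providers.google.common.hooks.base_google import (
    PROVIDE_PROJECT_ID,
    GoogleBaseHook,
)


class MyGoogleHook(GoogleBaseHook):
    """Illustrative hook subclass; not part of the provider."""

    @GoogleBaseHook.refresh_credentials_retry()
    @GoogleBaseHook.fallback_to_default_project_id
    def list_something(self, project_id: str = PROVIDE_PROJECT_ID) -> list[str]:
        # RefreshErrors mentioning "Unable to acquire impersonated credentials" are
        # retried up to three times with exponential backoff; project_id is replaced
        # by the connection's default project when the caller omits it.
        return [project_id]
```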
airflow/providers/google/firebase/hooks/firestore.py

```diff
@@ -25,7 +25,7 @@ from typing import Sequence
 from googleapiclient.discovery import build, build_from_document
 
 from airflow.exceptions import AirflowException
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
 
 # Time to sleep between active checks of the operation results
 TIME_TO_SLEEP_IN_SECONDS = 5
@@ -84,7 +84,7 @@ class CloudFirestoreHook(GoogleBaseHook):
 
     @GoogleBaseHook.fallback_to_default_project_id
     def export_documents(
-        self, body: dict, database_id: str = "(default)", project_id: str
+        self, body: dict, database_id: str = "(default)", project_id: str = PROVIDE_PROJECT_ID
     ) -> None:
         """
         Start a export with the specified configuration.
```
airflow/providers/google/firebase/operators/firestore.py

```diff
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Sequence
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.providers.google.firebase.hooks.firestore import CloudFirestoreHook
 
 if TYPE_CHECKING:
@@ -64,7 +65,7 @@ class CloudFirestoreExportDatabaseOperator(BaseOperator):
         *,
         body: dict,
         database_id: str = "(default)",
-        project_id: str
+        project_id: str = PROVIDE_PROJECT_ID,
         gcp_conn_id: str = "google_cloud_default",
         api_version: str = "v1",
         impersonation_chain: str | Sequence[str] | None = None,
```
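With `PROVIDE_PROJECT_ID` as the default, `project_id` can now simply be omitted and the hook resolves it from the connection through `fallback_to_default_project_id`. A hypothetical DAG snippet (the bucket name is a placeholder):

```python
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.google.firebase.operators.firestore import (
    CloudFirestoreExportDatabaseOperator,
)

with DAG(dag_id="example_firestore_export", start_date=datetime(2024, 1, 1), schedule=None):
    export_documents = CloudFirestoreExportDatabaseOperator(
        task_id="export_documents",
        body={"outputUriPrefix": "gs://my-bucket/firestore-export"},  # placeholder bucket
        database_id="(default)",
        # project_id is omitted on purpose: it falls back to the connection's project.
    )
```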
airflow/providers/google/get_provider_info.py

```diff
@@ -28,8 +28,9 @@ def get_provider_info():
         "name": "Google",
         "description": "Google services including:\n\n - `Google Ads <https://ads.google.com/>`__\n - `Google Cloud (GCP) <https://cloud.google.com/>`__\n - `Google Firebase <https://firebase.google.com/>`__\n - `Google LevelDB <https://github.com/google/leveldb/>`__\n - `Google Marketing Platform <https://marketingplatform.google.com/>`__\n - `Google Workspace <https://workspace.google.com/>`__ (formerly Google Suite)\n",
         "state": "ready",
-        "source-date-epoch":
+        "source-date-epoch": 1715384437,
         "versions": [
+            "10.18.0",
             "10.17.0",
             "10.16.0",
             "10.15.0",
@@ -86,7 +87,7 @@ def get_provider_info():
             "1.0.0",
         ],
         "dependencies": [
-            "apache-airflow>=2.
+            "apache-airflow>=2.7.0",
             "apache-airflow-providers-common-sql>=1.7.2",
             "asgiref>=3.5.2",
             "gcloud-aio-auth>=4.0.0,<5.0.0",
@@ -95,7 +96,7 @@ def get_provider_info():
             "gcsfs>=2023.10.0",
             "google-ads>=23.1.0",
             "google-analytics-admin",
-            "google-api-core>=2.11.0,!=2.16.0",
+            "google-api-core>=2.11.0,!=2.16.0,!=2.18.0",
             "google-api-python-client>=1.6.0",
             "google-auth>=1.0.0",
             "google-auth-httplib2>=0.0.1",
@@ -1146,6 +1147,7 @@ def get_provider_info():
                 "airflow.providers.google.cloud.hooks.vertex_ai.model_service",
                 "airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job",
                 "airflow.providers.google.cloud.hooks.vertex_ai.generative_model",
+                "airflow.providers.google.cloud.hooks.vertex_ai.prediction_service",
             ],
         },
         {
```