apache-airflow-providers-google 10.17.0rc1__py3-none-any.whl → 10.18.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (91)
  1. airflow/providers/google/__init__.py +5 -8
  2. airflow/providers/google/cloud/hooks/automl.py +35 -1
  3. airflow/providers/google/cloud/hooks/bigquery.py +126 -41
  4. airflow/providers/google/cloud/hooks/cloud_composer.py +250 -2
  5. airflow/providers/google/cloud/hooks/cloud_sql.py +154 -7
  6. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +7 -2
  7. airflow/providers/google/cloud/hooks/compute_ssh.py +2 -1
  8. airflow/providers/google/cloud/hooks/dataflow.py +246 -32
  9. airflow/providers/google/cloud/hooks/dataplex.py +6 -2
  10. airflow/providers/google/cloud/hooks/dlp.py +14 -14
  11. airflow/providers/google/cloud/hooks/gcs.py +6 -2
  12. airflow/providers/google/cloud/hooks/gdm.py +2 -2
  13. airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
  14. airflow/providers/google/cloud/hooks/mlengine.py +8 -4
  15. airflow/providers/google/cloud/hooks/pubsub.py +1 -1
  16. airflow/providers/google/cloud/hooks/secret_manager.py +252 -4
  17. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1431 -74
  18. airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +91 -0
  19. airflow/providers/google/cloud/links/vertex_ai.py +2 -1
  20. airflow/providers/google/cloud/log/gcs_task_handler.py +2 -1
  21. airflow/providers/google/cloud/operators/automl.py +243 -37
  22. airflow/providers/google/cloud/operators/bigquery.py +164 -62
  23. airflow/providers/google/cloud/operators/bigquery_dts.py +4 -3
  24. airflow/providers/google/cloud/operators/bigtable.py +7 -6
  25. airflow/providers/google/cloud/operators/cloud_build.py +12 -11
  26. airflow/providers/google/cloud/operators/cloud_composer.py +147 -2
  27. airflow/providers/google/cloud/operators/cloud_memorystore.py +17 -16
  28. airflow/providers/google/cloud/operators/cloud_sql.py +60 -17
  29. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +35 -16
  30. airflow/providers/google/cloud/operators/compute.py +12 -11
  31. airflow/providers/google/cloud/operators/datacatalog.py +21 -20
  32. airflow/providers/google/cloud/operators/dataflow.py +59 -42
  33. airflow/providers/google/cloud/operators/datafusion.py +11 -10
  34. airflow/providers/google/cloud/operators/datapipeline.py +3 -2
  35. airflow/providers/google/cloud/operators/dataprep.py +5 -4
  36. airflow/providers/google/cloud/operators/dataproc.py +20 -17
  37. airflow/providers/google/cloud/operators/datastore.py +8 -7
  38. airflow/providers/google/cloud/operators/dlp.py +31 -30
  39. airflow/providers/google/cloud/operators/functions.py +4 -3
  40. airflow/providers/google/cloud/operators/gcs.py +66 -41
  41. airflow/providers/google/cloud/operators/kubernetes_engine.py +256 -49
  42. airflow/providers/google/cloud/operators/life_sciences.py +2 -1
  43. airflow/providers/google/cloud/operators/mlengine.py +11 -10
  44. airflow/providers/google/cloud/operators/pubsub.py +6 -5
  45. airflow/providers/google/cloud/operators/spanner.py +7 -6
  46. airflow/providers/google/cloud/operators/speech_to_text.py +2 -1
  47. airflow/providers/google/cloud/operators/stackdriver.py +11 -10
  48. airflow/providers/google/cloud/operators/tasks.py +14 -13
  49. airflow/providers/google/cloud/operators/text_to_speech.py +2 -1
  50. airflow/providers/google/cloud/operators/translate_speech.py +2 -1
  51. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +333 -26
  52. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +20 -12
  53. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +0 -1
  54. airflow/providers/google/cloud/operators/vision.py +13 -12
  55. airflow/providers/google/cloud/operators/workflows.py +12 -14
  56. airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
  57. airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -1
  58. airflow/providers/google/cloud/sensors/bigtable.py +2 -1
  59. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -1
  60. airflow/providers/google/cloud/sensors/dataflow.py +239 -52
  61. airflow/providers/google/cloud/sensors/datafusion.py +2 -1
  62. airflow/providers/google/cloud/sensors/dataproc.py +3 -2
  63. airflow/providers/google/cloud/sensors/gcs.py +14 -12
  64. airflow/providers/google/cloud/sensors/tasks.py +2 -1
  65. airflow/providers/google/cloud/sensors/workflows.py +2 -1
  66. airflow/providers/google/cloud/transfers/adls_to_gcs.py +8 -2
  67. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +7 -1
  68. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +7 -1
  69. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
  70. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -1
  71. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
  72. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +5 -6
  73. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +22 -12
  74. airflow/providers/google/cloud/triggers/bigquery.py +75 -6
  75. airflow/providers/google/cloud/triggers/cloud_composer.py +68 -0
  76. airflow/providers/google/cloud/triggers/cloud_sql.py +2 -1
  77. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +2 -1
  78. airflow/providers/google/cloud/triggers/dataflow.py +504 -4
  79. airflow/providers/google/cloud/triggers/dataproc.py +190 -27
  80. airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -3
  81. airflow/providers/google/cloud/triggers/mlengine.py +2 -1
  82. airflow/providers/google/cloud/triggers/vertex_ai.py +94 -0
  83. airflow/providers/google/common/hooks/base_google.py +45 -7
  84. airflow/providers/google/firebase/hooks/firestore.py +2 -2
  85. airflow/providers/google/firebase/operators/firestore.py +2 -1
  86. airflow/providers/google/get_provider_info.py +5 -3
  87. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0.dist-info}/METADATA +18 -18
  88. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0.dist-info}/RECORD +90 -90
  89. airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +0 -289
  90. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0.dist-info}/WHEEL +0 -0
  91. {apache_airflow_providers_google-10.17.0rc1.dist-info → apache_airflow_providers_google-10.18.0.dist-info}/entry_points.txt +0 -0
@@ -22,14 +22,22 @@ from __future__ import annotations
  import asyncio
  import re
  import time
- from typing import Any, AsyncIterator, Sequence
+ from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence

  from google.api_core.exceptions import NotFound
- from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
+ from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus

- from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
+ from airflow.exceptions import AirflowException
+ from airflow.models.taskinstance import TaskInstance
+ from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
  from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
  from airflow.triggers.base import BaseTrigger, TriggerEvent
+ from airflow.utils.session import provide_session
+ from airflow.utils.state import TaskInstanceState
+
+ if TYPE_CHECKING:
+     from sqlalchemy.orm.session import Session


  class DataprocBaseTrigger(BaseTrigger):
@@ -38,10 +46,12 @@ class DataprocBaseTrigger(BaseTrigger):
      def __init__(
          self,
          region: str,
-         project_id: str | None = None,
+         project_id: str = PROVIDE_PROJECT_ID,
          gcp_conn_id: str = "google_cloud_default",
          impersonation_chain: str | Sequence[str] | None = None,
          polling_interval_seconds: int = 30,
+         cancel_on_kill: bool = True,
+         delete_on_error: bool = True,
      ):
          super().__init__()
          self.region = region
@@ -49,6 +59,8 @@ class DataprocBaseTrigger(BaseTrigger):
          self.gcp_conn_id = gcp_conn_id
          self.impersonation_chain = impersonation_chain
          self.polling_interval_seconds = polling_interval_seconds
+         self.cancel_on_kill = cancel_on_kill
+         self.delete_on_error = delete_on_error

      def get_async_hook(self):
          return DataprocAsyncHook(
@@ -56,6 +68,16 @@ class DataprocBaseTrigger(BaseTrigger):
              impersonation_chain=self.impersonation_chain,
          )

+     def get_sync_hook(self):
+         # The synchronous hook is utilized to delete the cluster when a task is cancelled.
+         # This is because the asynchronous hook deletion is not awaited when the trigger task
+         # is cancelled. The call for deleting the cluster or job through the sync hook is not a blocking
+         # call, which means it does not wait until the cluster or job is deleted.
+         return DataprocHook(
+             gcp_conn_id=self.gcp_conn_id,
+             impersonation_chain=self.impersonation_chain,
+         )
+

  class DataprocSubmitTrigger(DataprocBaseTrigger):
      """
@@ -90,20 +112,78 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
                  "gcp_conn_id": self.gcp_conn_id,
                  "impersonation_chain": self.impersonation_chain,
                  "polling_interval_seconds": self.polling_interval_seconds,
+                 "cancel_on_kill": self.cancel_on_kill,
              },
          )

-     async def run(self):
-         while True:
-             job = await self.get_async_hook().get_job(
-                 project_id=self.project_id, region=self.region, job_id=self.job_id
+     @provide_session
+     def get_task_instance(self, session: Session) -> TaskInstance:
+         """
+         Get the task instance for the current task.
+
+         :param session: Sqlalchemy session
+         """
+         query = session.query(TaskInstance).filter(
+             TaskInstance.dag_id == self.task_instance.dag_id,
+             TaskInstance.task_id == self.task_instance.task_id,
+             TaskInstance.run_id == self.task_instance.run_id,
+             TaskInstance.map_index == self.task_instance.map_index,
+         )
+         task_instance = query.one_or_none()
+         if task_instance is None:
+             raise AirflowException(
+                 "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
+                 self.task_instance.dag_id,
+                 self.task_instance.task_id,
+                 self.task_instance.run_id,
+                 self.task_instance.map_index,
              )
-             state = job.status.state
-             self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
-             if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
-                 break
-             await asyncio.sleep(self.polling_interval_seconds)
-         yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job})
+         return task_instance
+
+     def safe_to_cancel(self) -> bool:
+         """
+         Whether it is safe to cancel the external job which is being executed by this trigger.
+
+         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
+         Because in those cases, we should NOT cancel the external job.
+         """
+         # Database query is needed to get the latest state of the task instance.
+         task_instance = self.get_task_instance()  # type: ignore[call-arg]
+         return task_instance.state != TaskInstanceState.DEFERRED
+
+     async def run(self):
+         try:
+             while True:
+                 job = await self.get_async_hook().get_job(
+                     project_id=self.project_id, region=self.region, job_id=self.job_id
+                 )
+                 state = job.status.state
+                 self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
+                 if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
+                     break
+                 await asyncio.sleep(self.polling_interval_seconds)
+             yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job})
+         except asyncio.CancelledError:
+             self.log.info("Task got cancelled.")
+             try:
+                 if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                     self.log.info(
+                         "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
+                         " in deferred state."
+                     )
+                     self.log.info("Cancelling the job: %s", self.job_id)
+                     # The synchronous hook is utilized to delete the cluster when a task is cancelled. This
+                     # is because the asynchronous hook deletion is not awaited when the trigger task is
+                     # cancelled. The call for deleting the cluster or job through the sync hook is not a
+                     # blocking call, which means it does not wait until the cluster or job is deleted.
+                     self.get_sync_hook().cancel_job(
+                         job_id=self.job_id, project_id=self.project_id, region=self.region
+                     )
+                     self.log.info("Job: %s is cancelled", self.job_id)
+                     yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING})
+             except Exception as e:
+                 self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e))
+                 raise e


  class DataprocClusterTrigger(DataprocBaseTrigger):
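
For context: the hunk above makes DataprocSubmitTrigger cancellation-aware. On asyncio.CancelledError it now cancels the running Dataproc job through the synchronous hook, but only when cancel_on_kill is set and the task instance is no longer deferred (safe_to_cancel). The sketch below shows how an operator might defer to this trigger; it is illustrative only and not part of the diff, the operator class is hypothetical, and the keyword arguments mirror the serialize() fields visible above.

    # Illustrative sketch (not part of the diff); names starting with "Example" are hypothetical.
    from airflow.exceptions import AirflowException
    from airflow.models import BaseOperator
    from airflow.providers.google.cloud.triggers.dataproc import DataprocSubmitTrigger
    from google.cloud.dataproc_v1 import JobStatus


    class ExampleSubmitDataprocJobOperator(BaseOperator):
        def execute(self, context):
            job_id = "job-run-123"  # assume the job was already submitted earlier in execute()
            self.defer(
                trigger=DataprocSubmitTrigger(
                    job_id=job_id,
                    project_id="my-project",
                    region="europe-west1",
                    gcp_conn_id="google_cloud_default",
                    polling_interval_seconds=30,
                    cancel_on_kill=True,  # cancel the Dataproc job if this task is killed
                ),
                method_name="execute_complete",
            )

        def execute_complete(self, context, event=None):
            # The trigger yields {"job_id": ..., "job_state": ..., "job": ...} on completion.
            if event["job_state"] != JobStatus.State.DONE:
                raise AirflowException(f"Job {event['job_id']} ended in state {event['job_state']}")
            return event["job_id"]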
@@ -139,24 +219,107 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                  "gcp_conn_id": self.gcp_conn_id,
                  "impersonation_chain": self.impersonation_chain,
                  "polling_interval_seconds": self.polling_interval_seconds,
+                 "delete_on_error": self.delete_on_error,
              },
          )

+     @provide_session
+     def get_task_instance(self, session: Session) -> TaskInstance:
+         query = session.query(TaskInstance).filter(
+             TaskInstance.dag_id == self.task_instance.dag_id,
+             TaskInstance.task_id == self.task_instance.task_id,
+             TaskInstance.run_id == self.task_instance.run_id,
+             TaskInstance.map_index == self.task_instance.map_index,
+         )
+         task_instance = query.one_or_none()
+         if task_instance is None:
+             raise AirflowException(
+                 "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.",
+                 self.task_instance.dag_id,
+                 self.task_instance.task_id,
+                 self.task_instance.run_id,
+                 self.task_instance.map_index,
+             )
+         return task_instance
+
+     def safe_to_cancel(self) -> bool:
+         """
+         Whether it is safe to cancel the external job which is being executed by this trigger.
+
+         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
+         Because in those cases, we should NOT cancel the external job.
+         """
+         # Database query is needed to get the latest state of the task instance.
+         task_instance = self.get_task_instance()  # type: ignore[call-arg]
+         return task_instance.state != TaskInstanceState.DEFERRED
+
      async def run(self) -> AsyncIterator[TriggerEvent]:
-         while True:
-             cluster = await self.get_async_hook().get_cluster(
-                 project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
+         try:
+             while True:
+                 cluster = await self.fetch_cluster()
+                 state = cluster.status.state
+                 if state == ClusterStatus.State.ERROR:
+                     await self.delete_when_error_occurred(cluster)
+                     yield TriggerEvent(
+                         {
+                             "cluster_name": self.cluster_name,
+                             "cluster_state": ClusterStatus.State.DELETING,
+                             "cluster": cluster,
+                         }
+                     )
+                     return
+                 elif state == ClusterStatus.State.RUNNING:
+                     yield TriggerEvent(
+                         {
+                             "cluster_name": self.cluster_name,
+                             "cluster_state": state,
+                             "cluster": cluster,
+                         }
+                     )
+                     return
+                 self.log.info("Current state is %s", state)
+                 self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
+                 await asyncio.sleep(self.polling_interval_seconds)
+         except asyncio.CancelledError:
+             try:
+                 if self.delete_on_error and self.safe_to_cancel():
+                     self.log.info(
+                         "Deleting the cluster as it is safe to delete as the airflow TaskInstance is not in "
+                         "deferred state."
+                     )
+                     self.log.info("Deleting cluster %s.", self.cluster_name)
+                     # The synchronous hook is utilized to delete the cluster when a task is cancelled.
+                     # This is because the asynchronous hook deletion is not awaited when the trigger task
+                     # is cancelled. The call for deleting the cluster through the sync hook is not a blocking
+                     # call, which means it does not wait until the cluster is deleted.
+                     self.get_sync_hook().delete_cluster(
+                         region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
+                     )
+                     self.log.info("Deleted cluster %s during cancellation.", self.cluster_name)
+             except Exception as e:
+                 self.log.error("Error during cancellation handling: %s", e)
+                 raise AirflowException("Error during cancellation handling: %s", e)
+
+     async def fetch_cluster(self) -> Cluster:
+         """Fetch the cluster status."""
+         return await self.get_async_hook().get_cluster(
+             project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
+         )
+
+     async def delete_when_error_occurred(self, cluster: Cluster) -> None:
+         """
+         Delete the cluster on error.
+
+         :param cluster: The cluster to delete.
+         """
+         if self.delete_on_error:
+             self.log.info("Deleting cluster %s.", self.cluster_name)
+             await self.get_async_hook().delete_cluster(
+                 region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
              )
-             state = cluster.status.state
-             self.log.info("Dataproc cluster: %s is in state: %s", self.cluster_name, state)
-             if state in (
-                 ClusterStatus.State.ERROR,
-                 ClusterStatus.State.RUNNING,
-             ):
-                 break
-             self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
-             await asyncio.sleep(self.polling_interval_seconds)
-         yield TriggerEvent({"cluster_name": self.cluster_name, "cluster_state": state, "cluster": cluster})
+             self.log.info("Cluster %s has been deleted.", self.cluster_name)
+         else:
+             self.log.info("Cluster %s is not deleted as delete_on_error is set to False.", self.cluster_name)


  class DataprocBatchTrigger(DataprocBaseTrigger):
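
With this change DataprocClusterTrigger resolves an ERROR cluster itself: when delete_on_error is enabled the broken cluster is deleted and the event reports ClusterStatus.State.DELETING rather than ERROR, so the deferring operator only needs to distinguish RUNNING from everything else. A small sketch of handling that event (illustrative only, not part of the diff):

    # Illustrative sketch: interpreting the TriggerEvent payload emitted above.
    from airflow.exceptions import AirflowException
    from google.cloud.dataproc_v1 import ClusterStatus


    def handle_cluster_event(event: dict) -> str:
        state = event["cluster_state"]
        if state == ClusterStatus.State.DELETING:
            # ERROR path: the trigger already started deleting the broken cluster.
            raise AirflowException(f"Cluster {event['cluster_name']} entered ERROR and is being deleted.")
        if state != ClusterStatus.State.RUNNING:
            raise AirflowException(f"Unexpected cluster state: {state}")
        return event["cluster_name"]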
@@ -30,7 +30,6 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
  from airflow.providers.google.cloud.hooks.kubernetes_engine import (
      GKEAsyncHook,
      GKEKubernetesAsyncHook,
-     GKEPodAsyncHook,
  )
  from airflow.triggers.base import BaseTrigger, TriggerEvent

@@ -147,8 +146,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
          )

      @cached_property
-     def hook(self) -> GKEPodAsyncHook:  # type: ignore[override]
-         return GKEPodAsyncHook(
+     def hook(self) -> GKEKubernetesAsyncHook:  # type: ignore[override]
+         return GKEKubernetesAsyncHook(
              cluster_url=self._cluster_url,
              ssl_ca_cert=self._ssl_ca_cert,
              gcp_conn_id=self.gcp_conn_id,
@@ -20,6 +20,7 @@ import asyncio
  from typing import Any, AsyncIterator, Sequence

  from airflow.providers.google.cloud.hooks.mlengine import MLEngineAsyncHook
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
  from airflow.triggers.base import BaseTrigger, TriggerEvent


@@ -45,7 +46,7 @@ class MLEngineStartTrainingJobTrigger(BaseTrigger):
          runtime_version: str | None = None,
          python_version: str | None = None,
          job_dir: str | None = None,
-         project_id: str | None = None,
+         project_id: str = PROVIDE_PROJECT_ID,
          labels: dict[str, str] | None = None,
          gcp_conn_id: str = "google_cloud_default",
          impersonation_chain: str | Sequence[str] | None = None,
@@ -29,6 +29,7 @@ from google.cloud.aiplatform_v1 import (

  from airflow.exceptions import AirflowException
  from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobAsyncHook
+ from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook
  from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
      HyperparameterTuningJobAsyncHook,
  )
@@ -189,3 +190,96 @@ class RunPipelineJobTrigger(BaseVertexAIJobTrigger):
              poll_interval=self.poll_interval,
          )
          return job
+
+
+ class CustomTrainingJobTrigger(BaseVertexAIJobTrigger):
+     """
+     Make async calls to Vertex AI to check the state of a running custom training job.
+
+     Return the job when it enters a completed state.
+     """
+
+     job_type_verbose_name = "Custom Training Job"
+     job_serializer_class = types.TrainingPipeline
+     statuses_success = {
+         PipelineState.PIPELINE_STATE_PAUSED,
+         PipelineState.PIPELINE_STATE_SUCCEEDED,
+     }
+
+     @cached_property
+     def async_hook(self) -> CustomJobAsyncHook:
+         return CustomJobAsyncHook(
+             gcp_conn_id=self.conn_id,
+             impersonation_chain=self.impersonation_chain,
+         )
+
+     async def _wait_job(self) -> types.TrainingPipeline:
+         pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline(
+             project_id=self.project_id,
+             location=self.location,
+             pipeline_id=self.job_id,
+             poll_interval=self.poll_interval,
+         )
+         return pipeline
+
+
+ class CustomContainerTrainingJobTrigger(BaseVertexAIJobTrigger):
+     """
+     Make async calls to Vertex AI to check the state of a running custom container training job.
+
+     Return the job when it enters a completed state.
+     """
+
+     job_type_verbose_name = "Custom Container Training Job"
+     job_serializer_class = types.TrainingPipeline
+     statuses_success = {
+         PipelineState.PIPELINE_STATE_PAUSED,
+         PipelineState.PIPELINE_STATE_SUCCEEDED,
+     }
+
+     @cached_property
+     def async_hook(self) -> CustomJobAsyncHook:
+         return CustomJobAsyncHook(
+             gcp_conn_id=self.conn_id,
+             impersonation_chain=self.impersonation_chain,
+         )
+
+     async def _wait_job(self) -> types.TrainingPipeline:
+         pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline(
+             project_id=self.project_id,
+             location=self.location,
+             pipeline_id=self.job_id,
+             poll_interval=self.poll_interval,
+         )
+         return pipeline
+
+
+ class CustomPythonPackageTrainingJobTrigger(BaseVertexAIJobTrigger):
+     """
+     Make async calls to Vertex AI to check the state of a running custom python package training job.
+
+     Return the job when it enters a completed state.
+     """
+
+     job_type_verbose_name = "Custom Python Package Training Job"
+     job_serializer_class = types.TrainingPipeline
+     statuses_success = {
+         PipelineState.PIPELINE_STATE_PAUSED,
+         PipelineState.PIPELINE_STATE_SUCCEEDED,
+     }
+
+     @cached_property
+     def async_hook(self) -> CustomJobAsyncHook:
+         return CustomJobAsyncHook(
+             gcp_conn_id=self.conn_id,
+             impersonation_chain=self.impersonation_chain,
+         )
+
+     async def _wait_job(self) -> types.TrainingPipeline:
+         pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline(
+             project_id=self.project_id,
+             location=self.location,
+             pipeline_id=self.job_id,
+             poll_interval=self.poll_interval,
+         )
+         return pipeline
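
The three new trigger classes above share the same polling logic and success states (PIPELINE_STATE_PAUSED, PIPELINE_STATE_SUCCEEDED) and differ only in job_type_verbose_name. A hedged sketch of instantiating one of them follows; the constructor arguments belong to BaseVertexAIJobTrigger, which is outside this hunk, so the argument names are assumed from the attributes referenced in async_hook and _wait_job.

    # Illustrative sketch (not part of the diff); argument names are assumed.
    from airflow.providers.google.cloud.triggers.vertex_ai import CustomContainerTrainingJobTrigger

    trigger = CustomContainerTrainingJobTrigger(
        conn_id="google_cloud_default",
        project_id="my-project",
        location="us-central1",
        job_id="1234567890",  # id of the training pipeline to poll
        poll_interval=60,  # seconds between wait_for_training_pipeline calls
        impersonation_chain=None,
    )
    # Inside a deferrable operator's execute():
    #     self.defer(trigger=trigger, method_name="execute_complete")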
@@ -114,6 +114,19 @@ def is_operation_in_progress_exception(exception: Exception) -> bool:
      return False


+ def is_refresh_credentials_exception(exception: Exception) -> bool:
+     """
+     Handle refresh credentials exceptions.
+
+     Some calls return 502 (server error) in case a new token cannot be obtained.
+
+     * Google BigQuery
+     """
+     if isinstance(exception, RefreshError):
+         return "Unable to acquire impersonated credentials" in str(exception)
+     return False
+
+
  class retry_if_temporary_quota(tenacity.retry_if_exception):
      """Retries if there was an exception for exceeding the temporary quote limit."""

@@ -122,12 +135,19 @@ class retry_if_temporary_quota(tenacity.retry_if_exception):


  class retry_if_operation_in_progress(tenacity.retry_if_exception):
-     """Retries if there was an exception for exceeding the temporary quote limit."""
+     """Retries if there was an exception in case of operation in progress."""

      def __init__(self):
          super().__init__(is_operation_in_progress_exception)


+ class retry_if_temporary_refresh_credentials(tenacity.retry_if_exception):
+     """Retries if there was an exception for refreshing credentials."""
+
+     def __init__(self):
+         super().__init__(is_refresh_credentials_exception)
+
+
  # A fake project_id to use in functions decorated by fallback_to_default_project_id
  # This allows the 'project_id' argument to be of type str instead of str | None,
  # making it easier to type hint the function body without dealing with the None
@@ -364,14 +384,14 @@ class GoogleBaseHook(BaseHook):
          return hasattr(self, "extras") and get_field(self.extras, f) or default

      @property
-     def project_id(self) -> str | None:
+     def project_id(self) -> str:
          """
          Returns project id.

          :return: id of the project
          """
          _, project_id = self.get_credentials_and_project_id()
-         return project_id
+         return project_id or PROVIDE_PROJECT_ID

      @property
      def num_retries(self) -> int:
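
Across this release the provider replaces project_id: str | None = None defaults with the PROVIDE_PROJECT_ID sentinel (see the mlengine, dataproc and firestore hunks in this diff), and GoogleBaseHook.project_id is now typed as str. The sentinel is a fake project id used purely for type-hinting (per the comments retained above); fallback_to_default_project_id still substitutes the connection's project id at runtime. A minimal sketch of the pattern, with a hypothetical hook and method:

    # Illustrative sketch (not part of the diff); the hook and method are hypothetical.
    from __future__ import annotations

    from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook


    class ExampleGoogleHook(GoogleBaseHook):
        @GoogleBaseHook.fallback_to_default_project_id
        def list_things(self, project_id: str = PROVIDE_PROJECT_ID) -> list[str]:
            # By the time the body runs, the decorator has swapped the sentinel for the
            # project id derived from the connection, so project_id can be typed as str.
            return [f"projects/{project_id}/things/example"]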
@@ -426,7 +446,7 @@ class GoogleBaseHook(BaseHook):
      def quota_retry(*args, **kwargs) -> Callable:
          """Provide a mechanism to repeat requests in response to exceeding a temporary quota limit."""

-         def decorator(fun: Callable):
+         def decorator(func: Callable):
              default_kwargs = {
                  "wait": tenacity.wait_exponential(multiplier=1, max=100),
                  "retry": retry_if_temporary_quota(),
@@ -434,7 +454,7 @@ class GoogleBaseHook(BaseHook):
                  "after": tenacity.after_log(log, logging.DEBUG),
              }
              default_kwargs.update(**kwargs)
-             return tenacity.retry(*args, **default_kwargs)(fun)
+             return tenacity.retry(*args, **default_kwargs)(func)

          return decorator

@@ -442,7 +462,7 @@ class GoogleBaseHook(BaseHook):
      def operation_in_progress_retry(*args, **kwargs) -> Callable[[T], T]:
          """Provide a mechanism to repeat requests in response to operation in progress (HTTP 409) limit."""

-         def decorator(fun: T):
+         def decorator(func: T):
              default_kwargs = {
                  "wait": tenacity.wait_exponential(multiplier=1, max=300),
                  "retry": retry_if_operation_in_progress(),
@@ -450,7 +470,25 @@ class GoogleBaseHook(BaseHook):
                  "after": tenacity.after_log(log, logging.DEBUG),
              }
              default_kwargs.update(**kwargs)
-             return cast(T, tenacity.retry(*args, **default_kwargs)(fun))
+             return cast(T, tenacity.retry(*args, **default_kwargs)(func))
+
+         return decorator
+
+     @staticmethod
+     def refresh_credentials_retry(*args, **kwargs) -> Callable[[T], T]:
+         """Provide a mechanism to repeat requests in response to a temporary refresh credential issue."""
+
+         def decorator(func: T):
+             default_kwargs = {
+                 "wait": tenacity.wait_exponential(multiplier=1, max=5),
+                 "stop": tenacity.stop_after_attempt(3),
+                 "retry": retry_if_temporary_refresh_credentials(),
+                 "reraise": True,
+                 "before": tenacity.before_log(log, logging.DEBUG),
+                 "after": tenacity.after_log(log, logging.DEBUG),
+             }
+             default_kwargs.update(**kwargs)
+             return cast(T, tenacity.retry(*args, **default_kwargs)(func))

          return decorator
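
The new refresh_credentials_retry decorator pairs with is_refresh_credentials_exception above: it retries a call up to three attempts with a short exponential backoff when a RefreshError about impersonated credentials is raised, then re-raises. A hedged sketch of applying it to a hook method (the hook and method below are hypothetical):

    # Illustrative sketch (not part of the diff).
    from airflow.providers.google.common.hooks.base_google import GoogleBaseHook


    class ExampleBigQueryStyleHook(GoogleBaseHook):
        @GoogleBaseHook.refresh_credentials_retry()
        def get_credentials_with_retry(self):
            # get_credentials() can raise google.auth.exceptions.RefreshError while
            # impersonated credentials are being minted; the decorator retries up to
            # 3 attempts and then re-raises (reraise=True).
            return self.get_credentials()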
 
@@ -25,7 +25,7 @@ from typing import Sequence
  from googleapiclient.discovery import build, build_from_document

  from airflow.exceptions import AirflowException
- from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook

  # Time to sleep between active checks of the operation results
  TIME_TO_SLEEP_IN_SECONDS = 5
@@ -84,7 +84,7 @@ class CloudFirestoreHook(GoogleBaseHook):

      @GoogleBaseHook.fallback_to_default_project_id
      def export_documents(
-         self, body: dict, database_id: str = "(default)", project_id: str | None = None
+         self, body: dict, database_id: str = "(default)", project_id: str = PROVIDE_PROJECT_ID
      ) -> None:
          """
          Start a export with the specified configuration.
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Sequence

  from airflow.exceptions import AirflowException
  from airflow.models import BaseOperator
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
  from airflow.providers.google.firebase.hooks.firestore import CloudFirestoreHook

  if TYPE_CHECKING:
@@ -64,7 +65,7 @@ class CloudFirestoreExportDatabaseOperator(BaseOperator):
          *,
          body: dict,
          database_id: str = "(default)",
-         project_id: str | None = None,
+         project_id: str = PROVIDE_PROJECT_ID,
          gcp_conn_id: str = "google_cloud_default",
          api_version: str = "v1",
          impersonation_chain: str | Sequence[str] | None = None,
@@ -28,8 +28,9 @@ def get_provider_info():
          "name": "Google",
          "description": "Google services including:\n\n - `Google Ads <https://ads.google.com/>`__\n - `Google Cloud (GCP) <https://cloud.google.com/>`__\n - `Google Firebase <https://firebase.google.com/>`__\n - `Google LevelDB <https://github.com/google/leveldb/>`__\n - `Google Marketing Platform <https://marketingplatform.google.com/>`__\n - `Google Workspace <https://workspace.google.com/>`__ (formerly Google Suite)\n",
          "state": "ready",
-         "source-date-epoch": 1712665855,
+         "source-date-epoch": 1715384437,
          "versions": [
+             "10.18.0",
              "10.17.0",
              "10.16.0",
              "10.15.0",
@@ -86,7 +87,7 @@ def get_provider_info():
              "1.0.0",
          ],
          "dependencies": [
-             "apache-airflow>=2.6.0",
+             "apache-airflow>=2.7.0",
              "apache-airflow-providers-common-sql>=1.7.2",
              "asgiref>=3.5.2",
              "gcloud-aio-auth>=4.0.0,<5.0.0",
@@ -95,7 +96,7 @@ def get_provider_info():
              "gcsfs>=2023.10.0",
              "google-ads>=23.1.0",
              "google-analytics-admin",
-             "google-api-core>=2.11.0,!=2.16.0",
+             "google-api-core>=2.11.0,!=2.16.0,!=2.18.0",
              "google-api-python-client>=1.6.0",
              "google-auth>=1.0.0",
              "google-auth-httplib2>=0.0.1",
@@ -1146,6 +1147,7 @@ def get_provider_info():
                  "airflow.providers.google.cloud.hooks.vertex_ai.model_service",
                  "airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job",
                  "airflow.providers.google.cloud.hooks.vertex_ai.generative_model",
+                 "airflow.providers.google.cloud.hooks.vertex_ai.prediction_service",
              ],
          },
          {