apache-airflow-providers-google 10.14.0rc2__py3-none-any.whl → 10.15.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/cloud/hooks/automl.py +13 -13
  3. airflow/providers/google/cloud/hooks/bigquery.py +193 -246
  4. airflow/providers/google/cloud/hooks/bigquery_dts.py +6 -6
  5. airflow/providers/google/cloud/hooks/bigtable.py +8 -8
  6. airflow/providers/google/cloud/hooks/cloud_batch.py +1 -1
  7. airflow/providers/google/cloud/hooks/cloud_build.py +19 -20
  8. airflow/providers/google/cloud/hooks/cloud_composer.py +4 -4
  9. airflow/providers/google/cloud/hooks/cloud_memorystore.py +10 -10
  10. airflow/providers/google/cloud/hooks/cloud_run.py +1 -1
  11. airflow/providers/google/cloud/hooks/cloud_sql.py +17 -17
  12. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +3 -3
  13. airflow/providers/google/cloud/hooks/compute.py +16 -16
  14. airflow/providers/google/cloud/hooks/compute_ssh.py +1 -1
  15. airflow/providers/google/cloud/hooks/datacatalog.py +22 -22
  16. airflow/providers/google/cloud/hooks/dataflow.py +48 -49
  17. airflow/providers/google/cloud/hooks/dataform.py +16 -16
  18. airflow/providers/google/cloud/hooks/datafusion.py +15 -15
  19. airflow/providers/google/cloud/hooks/datapipeline.py +3 -3
  20. airflow/providers/google/cloud/hooks/dataplex.py +19 -19
  21. airflow/providers/google/cloud/hooks/dataprep.py +8 -8
  22. airflow/providers/google/cloud/hooks/dataproc.py +88 -0
  23. airflow/providers/google/cloud/hooks/dataproc_metastore.py +13 -13
  24. airflow/providers/google/cloud/hooks/datastore.py +3 -3
  25. airflow/providers/google/cloud/hooks/dlp.py +25 -25
  26. airflow/providers/google/cloud/hooks/gcs.py +25 -23
  27. airflow/providers/google/cloud/hooks/gdm.py +3 -3
  28. airflow/providers/google/cloud/hooks/kms.py +3 -3
  29. airflow/providers/google/cloud/hooks/kubernetes_engine.py +63 -48
  30. airflow/providers/google/cloud/hooks/life_sciences.py +13 -12
  31. airflow/providers/google/cloud/hooks/looker.py +7 -7
  32. airflow/providers/google/cloud/hooks/mlengine.py +12 -12
  33. airflow/providers/google/cloud/hooks/natural_language.py +2 -2
  34. airflow/providers/google/cloud/hooks/os_login.py +1 -1
  35. airflow/providers/google/cloud/hooks/pubsub.py +9 -9
  36. airflow/providers/google/cloud/hooks/secret_manager.py +1 -1
  37. airflow/providers/google/cloud/hooks/spanner.py +11 -11
  38. airflow/providers/google/cloud/hooks/speech_to_text.py +1 -1
  39. airflow/providers/google/cloud/hooks/stackdriver.py +7 -7
  40. airflow/providers/google/cloud/hooks/tasks.py +11 -11
  41. airflow/providers/google/cloud/hooks/text_to_speech.py +1 -1
  42. airflow/providers/google/cloud/hooks/translate.py +1 -1
  43. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +13 -13
  44. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +6 -6
  45. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +45 -50
  46. airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +13 -13
  47. airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +9 -9
  48. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +128 -11
  49. airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +10 -10
  50. airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +8 -8
  51. airflow/providers/google/cloud/hooks/video_intelligence.py +2 -2
  52. airflow/providers/google/cloud/hooks/vision.py +1 -1
  53. airflow/providers/google/cloud/hooks/workflows.py +10 -10
  54. airflow/providers/google/cloud/links/datafusion.py +12 -5
  55. airflow/providers/google/cloud/operators/bigquery.py +9 -11
  56. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +3 -1
  57. airflow/providers/google/cloud/operators/dataflow.py +16 -16
  58. airflow/providers/google/cloud/operators/datafusion.py +9 -1
  59. airflow/providers/google/cloud/operators/dataproc.py +298 -65
  60. airflow/providers/google/cloud/operators/kubernetes_engine.py +6 -6
  61. airflow/providers/google/cloud/operators/life_sciences.py +10 -9
  62. airflow/providers/google/cloud/operators/mlengine.py +96 -96
  63. airflow/providers/google/cloud/operators/pubsub.py +2 -0
  64. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +33 -3
  65. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +59 -2
  66. airflow/providers/google/cloud/secrets/secret_manager.py +8 -7
  67. airflow/providers/google/cloud/sensors/bigquery.py +20 -16
  68. airflow/providers/google/cloud/sensors/cloud_composer.py +11 -8
  69. airflow/providers/google/cloud/sensors/gcs.py +8 -7
  70. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +4 -4
  71. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -1
  72. airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
  73. airflow/providers/google/cloud/transfers/mysql_to_gcs.py +1 -1
  74. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +1 -1
  75. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +1 -1
  76. airflow/providers/google/cloud/transfers/presto_to_gcs.py +1 -1
  77. airflow/providers/google/cloud/transfers/s3_to_gcs.py +3 -3
  78. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -1
  79. airflow/providers/google/cloud/transfers/sql_to_gcs.py +3 -3
  80. airflow/providers/google/cloud/transfers/trino_to_gcs.py +1 -1
  81. airflow/providers/google/cloud/triggers/bigquery.py +12 -12
  82. airflow/providers/google/cloud/triggers/bigquery_dts.py +1 -1
  83. airflow/providers/google/cloud/triggers/cloud_batch.py +3 -1
  84. airflow/providers/google/cloud/triggers/cloud_build.py +2 -2
  85. airflow/providers/google/cloud/triggers/cloud_run.py +1 -1
  86. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +6 -6
  87. airflow/providers/google/cloud/triggers/dataflow.py +3 -1
  88. airflow/providers/google/cloud/triggers/datafusion.py +2 -2
  89. airflow/providers/google/cloud/triggers/dataplex.py +2 -2
  90. airflow/providers/google/cloud/triggers/dataproc.py +2 -2
  91. airflow/providers/google/cloud/triggers/gcs.py +12 -8
  92. airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -2
  93. airflow/providers/google/cloud/triggers/mlengine.py +2 -2
  94. airflow/providers/google/cloud/triggers/pubsub.py +1 -1
  95. airflow/providers/google/cloud/triggers/vertex_ai.py +99 -0
  96. airflow/providers/google/cloud/utils/bigquery.py +2 -2
  97. airflow/providers/google/cloud/utils/credentials_provider.py +2 -2
  98. airflow/providers/google/cloud/utils/dataform.py +1 -1
  99. airflow/providers/google/cloud/utils/field_validator.py +2 -2
  100. airflow/providers/google/cloud/utils/helpers.py +2 -2
  101. airflow/providers/google/cloud/utils/mlengine_operator_utils.py +1 -1
  102. airflow/providers/google/cloud/utils/mlengine_prediction_summary.py +1 -1
  103. airflow/providers/google/common/auth_backend/google_openid.py +2 -2
  104. airflow/providers/google/common/hooks/base_google.py +29 -22
  105. airflow/providers/google/common/hooks/discovery_api.py +2 -2
  106. airflow/providers/google/common/utils/id_token_credentials.py +5 -5
  107. airflow/providers/google/firebase/hooks/firestore.py +3 -3
  108. airflow/providers/google/get_provider_info.py +7 -2
  109. airflow/providers/google/leveldb/hooks/leveldb.py +2 -2
  110. airflow/providers/google/marketing_platform/hooks/analytics.py +11 -14
  111. airflow/providers/google/marketing_platform/hooks/campaign_manager.py +11 -11
  112. airflow/providers/google/marketing_platform/hooks/display_video.py +13 -13
  113. airflow/providers/google/marketing_platform/hooks/search_ads.py +4 -4
  114. airflow/providers/google/marketing_platform/operators/analytics.py +37 -32
  115. airflow/providers/google/suite/hooks/calendar.py +2 -2
  116. airflow/providers/google/suite/hooks/drive.py +7 -7
  117. airflow/providers/google/suite/hooks/sheets.py +8 -8
  118. {apache_airflow_providers_google-10.14.0rc2.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/METADATA +11 -11
  119. {apache_airflow_providers_google-10.14.0rc2.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/RECORD +121 -120
  120. {apache_airflow_providers_google-10.14.0rc2.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/WHEEL +0 -0
  121. {apache_airflow_providers_google-10.14.0rc2.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/entry_points.txt +0 -0
@@ -52,7 +52,7 @@ class DatasetHook(GoogleBaseHook):
52
52
  super().__init__(**kwargs)
53
53
 
54
54
  def get_dataset_service_client(self, region: str | None = None) -> DatasetServiceClient:
55
- """Returns DatasetServiceClient."""
55
+ """Return DatasetServiceClient."""
56
56
  if region and region != "global":
57
57
  client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
58
58
  else:
@@ -63,7 +63,7 @@ class DatasetHook(GoogleBaseHook):
63
63
  )
64
64
 
65
65
  def wait_for_operation(self, operation: Operation, timeout: float | None = None):
66
- """Waits for long-lasting operation to complete."""
66
+ """Wait for long-lasting operation to complete."""
67
67
  try:
68
68
  return operation.result(timeout=timeout)
69
69
  except Exception:
@@ -72,7 +72,7 @@ class DatasetHook(GoogleBaseHook):
72
72
 
73
73
  @staticmethod
74
74
  def extract_dataset_id(obj: dict) -> str:
75
- """Returns unique id of the dataset."""
75
+ """Return unique id of the dataset."""
76
76
  return obj["name"].rpartition("/")[-1]
77
77
 
78
78
  @GoogleBaseHook.fallback_to_default_project_id
@@ -86,7 +86,7 @@ class DatasetHook(GoogleBaseHook):
86
86
  metadata: Sequence[tuple[str, str]] = (),
87
87
  ) -> Operation:
88
88
  """
89
- Creates a Dataset.
89
+ Create a Dataset.
90
90
 
91
91
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
92
92
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -120,7 +120,7 @@ class DatasetHook(GoogleBaseHook):
120
120
  metadata: Sequence[tuple[str, str]] = (),
121
121
  ) -> Operation:
122
122
  """
123
- Deletes a Dataset.
123
+ Delete a Dataset.
124
124
 
125
125
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
126
126
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -154,7 +154,7 @@ class DatasetHook(GoogleBaseHook):
154
154
  metadata: Sequence[tuple[str, str]] = (),
155
155
  ) -> Operation:
156
156
  """
157
- Exports data from a Dataset.
157
+ Export data from a Dataset.
158
158
 
159
159
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
160
160
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -191,7 +191,7 @@ class DatasetHook(GoogleBaseHook):
191
191
  metadata: Sequence[tuple[str, str]] = (),
192
192
  ) -> AnnotationSpec:
193
193
  """
194
- Gets an AnnotationSpec.
194
+ Get an AnnotationSpec.
195
195
 
196
196
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
197
197
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -228,7 +228,7 @@ class DatasetHook(GoogleBaseHook):
228
228
  metadata: Sequence[tuple[str, str]] = (),
229
229
  ) -> Dataset:
230
230
  """
231
- Gets a Dataset.
231
+ Get a Dataset.
232
232
 
233
233
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
234
234
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -264,7 +264,7 @@ class DatasetHook(GoogleBaseHook):
264
264
  metadata: Sequence[tuple[str, str]] = (),
265
265
  ) -> Operation:
266
266
  """
267
- Imports data into a Dataset.
267
+ Import data into a Dataset.
268
268
 
269
269
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
270
270
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -306,7 +306,7 @@ class DatasetHook(GoogleBaseHook):
306
306
  metadata: Sequence[tuple[str, str]] = (),
307
307
  ) -> ListAnnotationsPager:
308
308
  """
309
- Lists Annotations belongs to a data item.
309
+ List Annotations belongs to a data item.
310
310
 
311
311
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
312
312
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -356,7 +356,7 @@ class DatasetHook(GoogleBaseHook):
356
356
  metadata: Sequence[tuple[str, str]] = (),
357
357
  ) -> ListDataItemsPager:
358
358
  """
359
- Lists DataItems in a Dataset.
359
+ List DataItems in a Dataset.
360
360
 
361
361
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
362
362
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -404,7 +404,7 @@ class DatasetHook(GoogleBaseHook):
404
404
  metadata: Sequence[tuple[str, str]] = (),
405
405
  ) -> ListDatasetsPager:
406
406
  """
407
- Lists Datasets in a Location.
407
+ List Datasets in a Location.
408
408
 
409
409
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
410
410
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -448,7 +448,7 @@ class DatasetHook(GoogleBaseHook):
448
448
  metadata: Sequence[tuple[str, str]] = (),
449
449
  ) -> Dataset:
450
450
  """
451
- Updates a Dataset.
451
+ Update a Dataset.
452
452
 
453
453
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
454
454
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -47,7 +47,7 @@ class EndpointServiceHook(GoogleBaseHook):
47
47
  super().__init__(**kwargs)
48
48
 
49
49
  def get_endpoint_service_client(self, region: str | None = None) -> EndpointServiceClient:
50
- """Returns EndpointServiceClient."""
50
+ """Return EndpointServiceClient."""
51
51
  if region and region != "global":
52
52
  client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
53
53
  else:
@@ -58,7 +58,7 @@ class EndpointServiceHook(GoogleBaseHook):
58
58
  )
59
59
 
60
60
  def wait_for_operation(self, operation: Operation, timeout: float | None = None):
61
- """Waits for long-lasting operation to complete."""
61
+ """Wait for long-lasting operation to complete."""
62
62
  try:
63
63
  return operation.result(timeout=timeout)
64
64
  except Exception:
@@ -67,12 +67,12 @@ class EndpointServiceHook(GoogleBaseHook):
67
67
 
68
68
  @staticmethod
69
69
  def extract_endpoint_id(obj: dict) -> str:
70
- """Returns unique id of the endpoint."""
70
+ """Return unique id of the endpoint."""
71
71
  return obj["name"].rpartition("/")[-1]
72
72
 
73
73
  @staticmethod
74
74
  def extract_deployed_model_id(obj: dict) -> str:
75
- """Returns unique id of the deploy model."""
75
+ """Return unique id of the deploy model."""
76
76
  return obj["deployed_model"]["id"]
77
77
 
78
78
  @GoogleBaseHook.fallback_to_default_project_id
@@ -87,7 +87,7 @@ class EndpointServiceHook(GoogleBaseHook):
87
87
  metadata: Sequence[tuple[str, str]] = (),
88
88
  ) -> Operation:
89
89
  """
90
- Creates an Endpoint.
90
+ Create an Endpoint.
91
91
 
92
92
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
93
93
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -124,7 +124,7 @@ class EndpointServiceHook(GoogleBaseHook):
124
124
  metadata: Sequence[tuple[str, str]] = (),
125
125
  ) -> Operation:
126
126
  """
127
- Deletes an Endpoint.
127
+ Delete an Endpoint.
128
128
 
129
129
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
130
130
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -210,7 +210,7 @@ class EndpointServiceHook(GoogleBaseHook):
210
210
  metadata: Sequence[tuple[str, str]] = (),
211
211
  ) -> Endpoint:
212
212
  """
213
- Gets an Endpoint.
213
+ Get an Endpoint.
214
214
 
215
215
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
216
216
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -247,7 +247,7 @@ class EndpointServiceHook(GoogleBaseHook):
247
247
  metadata: Sequence[tuple[str, str]] = (),
248
248
  ) -> ListEndpointsPager:
249
249
  """
250
- Lists Endpoints in a Location.
250
+ List Endpoints in a Location.
251
251
 
252
252
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
253
253
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -352,7 +352,7 @@ class EndpointServiceHook(GoogleBaseHook):
352
352
  metadata: Sequence[tuple[str, str]] = (),
353
353
  ) -> Endpoint:
354
354
  """
355
- Updates an Endpoint.
355
+ Update an Endpoint.
356
356
 
357
357
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
358
358
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -15,22 +15,32 @@
15
15
  # KIND, either express or implied. See the License for the
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
- """This module contains a Google Cloud Vertex AI hook."""
18
+ """
19
+ This module contains a Google Cloud Vertex AI hook.
20
+
21
+ .. spelling:word-list::
22
+
23
+ JobServiceAsyncClient
24
+ """
19
25
  from __future__ import annotations
20
26
 
27
+ import asyncio
28
+ from functools import lru_cache
21
29
  from typing import TYPE_CHECKING, Sequence
22
30
 
23
31
  from google.api_core.client_options import ClientOptions
24
32
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
25
33
  from google.cloud.aiplatform import CustomJob, HyperparameterTuningJob, gapic, hyperparameter_tuning
26
- from google.cloud.aiplatform_v1 import JobServiceClient, types
34
+ from google.cloud.aiplatform_v1 import JobServiceAsyncClient, JobServiceClient, JobState, types
27
35
 
28
36
  from airflow.exceptions import AirflowException
37
+ from airflow.providers.google.common.consts import CLIENT_INFO
29
38
  from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
30
39
 
31
40
  if TYPE_CHECKING:
32
41
  from google.api_core.operation import Operation
33
42
  from google.api_core.retry import Retry
43
+ from google.api_core.retry_async import AsyncRetry
34
44
  from google.cloud.aiplatform_v1.services.job_service.pagers import ListHyperparameterTuningJobsPager
35
45
 
36
46
 
@@ -55,7 +65,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
55
65
  self._hyperparameter_tuning_job: HyperparameterTuningJob | None = None
56
66
 
57
67
  def get_job_service_client(self, region: str | None = None) -> JobServiceClient:
58
- """Returns JobServiceClient."""
68
+ """Return JobServiceClient."""
59
69
  if region and region != "global":
60
70
  client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
61
71
  else:
@@ -81,7 +91,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
81
91
  labels: dict[str, str] | None = None,
82
92
  encryption_spec_key_name: str | None = None,
83
93
  ) -> HyperparameterTuningJob:
84
- """Returns HyperparameterTuningJob object."""
94
+ """Return HyperparameterTuningJob object."""
85
95
  return HyperparameterTuningJob(
86
96
  display_name=display_name,
87
97
  custom_job=custom_job,
@@ -110,7 +120,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
110
120
  encryption_spec_key_name: str | None = None,
111
121
  staging_bucket: str | None = None,
112
122
  ) -> CustomJob:
113
- """Returns CustomJob object."""
123
+ """Return CustomJob object."""
114
124
  return CustomJob(
115
125
  display_name=display_name,
116
126
  worker_pool_specs=worker_pool_specs,
@@ -125,11 +135,11 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
125
135
 
126
136
  @staticmethod
127
137
  def extract_hyperparameter_tuning_job_id(obj: dict) -> str:
128
- """Returns unique id of the hyperparameter_tuning_job."""
138
+ """Return unique id of the hyperparameter_tuning_job."""
129
139
  return obj["name"].rpartition("/")[-1]
130
140
 
131
141
  def wait_for_operation(self, operation: Operation, timeout: float | None = None):
132
- """Waits for long-lasting operation to complete."""
142
+ """Wait for long-lasting operation to complete."""
133
143
  try:
134
144
  return operation.result(timeout=timeout)
135
145
  except Exception:
@@ -172,6 +182,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
172
182
  tensorboard: str | None = None,
173
183
  sync: bool = True,
174
184
  # END: run param
185
+ wait_job_completed: bool = True,
175
186
  ) -> HyperparameterTuningJob:
176
187
  """
177
188
  Create a HyperparameterTuningJob.
@@ -256,6 +267,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
256
267
  https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
257
268
  :param sync: Whether to execute this method synchronously. If False, this method will unblock and it
258
269
  will be executed in a concurrent Future.
270
+ :param wait_job_completed: Whether to wait for the job completed.
259
271
  """
260
272
  custom_job = self.get_custom_job_object(
261
273
  project=project_id,
@@ -292,7 +304,11 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
292
304
  tensorboard=tensorboard,
293
305
  sync=sync,
294
306
  )
295
- self._hyperparameter_tuning_job.wait()
307
+
308
+ if wait_job_completed:
309
+ self._hyperparameter_tuning_job.wait()
310
+ else:
311
+ self._hyperparameter_tuning_job._wait_for_resource_creation()
296
312
  return self._hyperparameter_tuning_job
297
313
 
298
314
  @GoogleBaseHook.fallback_to_default_project_id
@@ -306,7 +322,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
306
322
  metadata: Sequence[tuple[str, str]] = (),
307
323
  ) -> types.HyperparameterTuningJob:
308
324
  """
309
- Gets a HyperparameterTuningJob.
325
+ Get a HyperparameterTuningJob.
310
326
 
311
327
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
312
328
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -342,7 +358,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
342
358
  metadata: Sequence[tuple[str, str]] = (),
343
359
  ) -> ListHyperparameterTuningJobsPager:
344
360
  """
345
- Lists HyperparameterTuningJobs in a Location.
361
+ List HyperparameterTuningJobs in a Location.
346
362
 
347
363
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
348
364
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -391,7 +407,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
391
407
  metadata: Sequence[tuple[str, str]] = (),
392
408
  ) -> Operation:
393
409
  """
394
- Deletes a HyperparameterTuningJob.
410
+ Delete a HyperparameterTuningJob.
395
411
 
396
412
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
397
413
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -413,3 +429,104 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
413
429
  metadata=metadata,
414
430
  )
415
431
  return result
432
+
433
+
434
+ class HyperparameterTuningJobAsyncHook(GoogleBaseHook):
435
+ """Async hook for Google Cloud Vertex AI Hyperparameter Tuning Job APIs."""
436
+
437
+ def __init__(
438
+ self,
439
+ gcp_conn_id: str = "google_cloud_default",
440
+ impersonation_chain: str | Sequence[str] | None = None,
441
+ **kwargs,
442
+ ):
443
+ super().__init__(
444
+ gcp_conn_id=gcp_conn_id,
445
+ impersonation_chain=impersonation_chain,
446
+ **kwargs,
447
+ )
448
+
449
+ @lru_cache
450
+ def get_job_service_client(self, region: str | None = None) -> JobServiceAsyncClient:
451
+ """
452
+ Retrieve Vertex AI async client.
453
+
454
+ :return: Google Cloud Vertex AI client object.
455
+ """
456
+ endpoint = f"{region}-aiplatform.googleapis.com:443" if region and region != "global" else None
457
+ return JobServiceAsyncClient(
458
+ credentials=self.get_credentials(),
459
+ client_info=CLIENT_INFO,
460
+ client_options=ClientOptions(api_endpoint=endpoint),
461
+ )
462
+
463
+ async def get_hyperparameter_tuning_job(
464
+ self,
465
+ project_id: str,
466
+ location: str,
467
+ job_id: str,
468
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
469
+ timeout: float | None = None,
470
+ metadata: Sequence[tuple[str, str]] = (),
471
+ ) -> types.HyperparameterTuningJob:
472
+ """
473
+ Retrieve a hyperparameter tuning job.
474
+
475
+ :param project_id: Required. The ID of the Google Cloud project that the job belongs to.
476
+ :param location: Required. The ID of the Google Cloud region that the job belongs to.
477
+ :param job_id: Required. The hyperparameter tuning job id.
478
+ :param retry: Designation of what errors, if any, should be retried.
479
+ :param timeout: The timeout for this request.
480
+ :param metadata: Strings which should be sent along with the request as metadata.
481
+ """
482
+ client: JobServiceAsyncClient = self.get_job_service_client(region=location)
483
+ job_name = client.hyperparameter_tuning_job_path(project_id, location, job_id)
484
+
485
+ result = await client.get_hyperparameter_tuning_job(
486
+ request={
487
+ "name": job_name,
488
+ },
489
+ retry=retry,
490
+ timeout=timeout,
491
+ metadata=metadata,
492
+ )
493
+
494
+ return result
495
+
496
+ async def wait_hyperparameter_tuning_job(
497
+ self,
498
+ project_id: str,
499
+ location: str,
500
+ job_id: str,
501
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
502
+ timeout: float | None = None,
503
+ metadata: Sequence[tuple[str, str]] = (),
504
+ poll_interval: int = 10,
505
+ ) -> types.HyperparameterTuningJob:
506
+ statuses_complete = {
507
+ JobState.JOB_STATE_CANCELLED,
508
+ JobState.JOB_STATE_FAILED,
509
+ JobState.JOB_STATE_PAUSED,
510
+ JobState.JOB_STATE_SUCCEEDED,
511
+ }
512
+ while True:
513
+ try:
514
+ self.log.info("Requesting hyperparameter tuning job with id %s", job_id)
515
+ job: types.HyperparameterTuningJob = await self.get_hyperparameter_tuning_job(
516
+ project_id=project_id,
517
+ location=location,
518
+ job_id=job_id,
519
+ retry=retry,
520
+ timeout=timeout,
521
+ metadata=metadata,
522
+ )
523
+ except Exception as ex:
524
+ self.log.exception("Exception occurred while requesting job %s", job_id)
525
+ raise AirflowException(ex)
526
+
527
+ self.log.info("Status of the hyperparameter tuning job %s is %s", job.name, job.state.name)
528
+ if job.state in statuses_complete:
529
+ return job
530
+
531
+ self.log.info("Sleeping for %s seconds.", poll_interval)
532
+ await asyncio.sleep(poll_interval)
@@ -51,7 +51,7 @@ class ModelServiceHook(GoogleBaseHook):
51
51
  super().__init__(**kwargs)
52
52
 
53
53
  def get_model_service_client(self, region: str | None = None) -> ModelServiceClient:
54
- """Returns ModelServiceClient."""
54
+ """Return ModelServiceClient object."""
55
55
  if region and region != "global":
56
56
  client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
57
57
  else:
@@ -63,11 +63,11 @@ class ModelServiceHook(GoogleBaseHook):
63
63
 
64
64
  @staticmethod
65
65
  def extract_model_id(obj: dict) -> str:
66
- """Returns unique id of the model."""
66
+ """Return unique id of the model."""
67
67
  return obj["model"].rpartition("/")[-1]
68
68
 
69
69
  def wait_for_operation(self, operation: Operation, timeout: float | None = None):
70
- """Waits for long-lasting operation to complete."""
70
+ """Wait for long-lasting operation to complete."""
71
71
  try:
72
72
  return operation.result(timeout=timeout)
73
73
  except Exception:
@@ -85,7 +85,7 @@ class ModelServiceHook(GoogleBaseHook):
85
85
  metadata: Sequence[tuple[str, str]] = (),
86
86
  ) -> Operation:
87
87
  """
88
- Deletes a Model.
88
+ Delete a Model.
89
89
 
90
90
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
91
91
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -119,7 +119,7 @@ class ModelServiceHook(GoogleBaseHook):
119
119
  metadata: Sequence[tuple[str, str]] = (),
120
120
  ) -> Operation:
121
121
  """
122
- Exports a trained, exportable Model to a location specified by the user.
122
+ Export a trained, exportable Model to a location specified by the user.
123
123
 
124
124
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
125
125
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -158,7 +158,7 @@ class ModelServiceHook(GoogleBaseHook):
158
158
  metadata: Sequence[tuple[str, str]] = (),
159
159
  ) -> ListModelsPager:
160
160
  r"""
161
- Lists Models in a Location.
161
+ List Models in a Location.
162
162
 
163
163
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
164
164
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -213,7 +213,7 @@ class ModelServiceHook(GoogleBaseHook):
213
213
  metadata: Sequence[tuple[str, str]] = (),
214
214
  ) -> Operation:
215
215
  """
216
- Uploads a Model artifact into Vertex AI.
216
+ Upload a Model artifact into Vertex AI.
217
217
 
218
218
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
219
219
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -247,7 +247,7 @@ class ModelServiceHook(GoogleBaseHook):
247
247
  metadata: Sequence[tuple[str, str]] = (),
248
248
  ) -> ListModelVersionsPager:
249
249
  """
250
- Lists all versions of the existing Model.
250
+ List all versions of the existing Model.
251
251
 
252
252
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
253
253
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -280,7 +280,7 @@ class ModelServiceHook(GoogleBaseHook):
280
280
  metadata: Sequence[tuple[str, str]] = (),
281
281
  ) -> Operation:
282
282
  """
283
- Deletes version of the Model. The version could not be deleted if this version is default.
283
+ Delete version of the Model. The version could not be deleted if this version is default.
284
284
 
285
285
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
286
286
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -313,7 +313,7 @@ class ModelServiceHook(GoogleBaseHook):
313
313
  metadata: Sequence[tuple[str, str]] = (),
314
314
  ) -> Model:
315
315
  """
316
- Retrieves Model of specific name and version. If version is not specified, the default is retrieved.
316
+ Retrieve Model of specific name and version. If version is not specified, the default is retrieved.
317
317
 
318
318
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
319
319
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -60,7 +60,7 @@ class PipelineJobHook(GoogleBaseHook):
60
60
  self,
61
61
  region: str | None = None,
62
62
  ) -> PipelineServiceClient:
63
- """Returns PipelineServiceClient."""
63
+ """Return PipelineServiceClient object."""
64
64
  if region and region != "global":
65
65
  client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
66
66
  else:
@@ -84,7 +84,7 @@ class PipelineJobHook(GoogleBaseHook):
84
84
  location: str | None = None,
85
85
  failure_policy: str | None = None,
86
86
  ) -> PipelineJob:
87
- """Returns PipelineJob object."""
87
+ """Return PipelineJob object."""
88
88
  return PipelineJob(
89
89
  display_name=display_name,
90
90
  template_path=template_path,
@@ -103,11 +103,11 @@ class PipelineJobHook(GoogleBaseHook):
103
103
 
104
104
  @staticmethod
105
105
  def extract_pipeline_job_id(obj: dict) -> str:
106
- """Returns unique id of the pipeline_job."""
106
+ """Return unique id of the pipeline_job."""
107
107
  return obj["name"].rpartition("/")[-1]
108
108
 
109
109
  def wait_for_operation(self, operation: Operation, timeout: float | None = None):
110
- """Waits for long-lasting operation to complete."""
110
+ """Wait for long-lasting operation to complete."""
111
111
  try:
112
112
  return operation.result(timeout=timeout)
113
113
  except Exception:
@@ -131,7 +131,7 @@ class PipelineJobHook(GoogleBaseHook):
131
131
  metadata: Sequence[tuple[str, str]] = (),
132
132
  ) -> PipelineJob:
133
133
  """
134
- Creates a PipelineJob. A PipelineJob will run immediately when created.
134
+ Create a PipelineJob. A PipelineJob will run immediately when created.
135
135
 
136
136
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
137
137
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -265,7 +265,7 @@ class PipelineJobHook(GoogleBaseHook):
265
265
  metadata: Sequence[tuple[str, str]] = (),
266
266
  ) -> PipelineJob:
267
267
  """
268
- Gets a PipelineJob.
268
+ Get a PipelineJob.
269
269
 
270
270
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
271
271
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -301,7 +301,7 @@ class PipelineJobHook(GoogleBaseHook):
301
301
  metadata: Sequence[tuple[str, str]] = (),
302
302
  ) -> ListPipelineJobsPager:
303
303
  """
304
- Lists PipelineJobs in a Location.
304
+ List PipelineJobs in a Location.
305
305
 
306
306
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
307
307
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -386,7 +386,7 @@ class PipelineJobHook(GoogleBaseHook):
386
386
  metadata: Sequence[tuple[str, str]] = (),
387
387
  ) -> Operation:
388
388
  """
389
- Deletes a PipelineJob.
389
+ Delete a PipelineJob.
390
390
 
391
391
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
392
392
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -71,7 +71,7 @@ class CloudVideoIntelligenceHook(GoogleBaseHook):
71
71
  self._conn: VideoIntelligenceServiceClient | None = None
72
72
 
73
73
  def get_conn(self) -> VideoIntelligenceServiceClient:
74
- """Returns Gcp Video Intelligence Service client."""
74
+ """Return Gcp Video Intelligence Service client."""
75
75
  if not self._conn:
76
76
  self._conn = VideoIntelligenceServiceClient(
77
77
  credentials=self.get_credentials(), client_info=CLIENT_INFO
@@ -92,7 +92,7 @@ class CloudVideoIntelligenceHook(GoogleBaseHook):
92
92
  metadata: Sequence[tuple[str, str]] = (),
93
93
  ) -> Operation:
94
94
  """
95
- Performs video annotation.
95
+ Perform video annotation.
96
96
 
97
97
  :param input_uri: Input video location. Currently, only Google Cloud Storage URIs are supported,
98
98
  which must be specified in the following format: ``gs://bucket-id/object-id``.
@@ -142,7 +142,7 @@ class CloudVisionHook(GoogleBaseHook):
142
142
 
143
143
  def get_conn(self) -> ProductSearchClient:
144
144
  """
145
- Retrieves connection to Cloud Vision.
145
+ Retrieve a connection to Cloud Vision.
146
146
 
147
147
  :return: Google Cloud Vision client object.
148
148
  """