apache-airflow-providers-google 10.14.0rc1__py3-none-any.whl → 10.15.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +1 -2
  3. airflow/providers/google/cloud/hooks/automl.py +13 -13
  4. airflow/providers/google/cloud/hooks/bigquery.py +208 -256
  5. airflow/providers/google/cloud/hooks/bigquery_dts.py +6 -6
  6. airflow/providers/google/cloud/hooks/bigtable.py +8 -8
  7. airflow/providers/google/cloud/hooks/cloud_batch.py +1 -1
  8. airflow/providers/google/cloud/hooks/cloud_build.py +19 -20
  9. airflow/providers/google/cloud/hooks/cloud_composer.py +4 -4
  10. airflow/providers/google/cloud/hooks/cloud_memorystore.py +10 -10
  11. airflow/providers/google/cloud/hooks/cloud_run.py +1 -1
  12. airflow/providers/google/cloud/hooks/cloud_sql.py +18 -19
  13. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +3 -3
  14. airflow/providers/google/cloud/hooks/compute.py +16 -16
  15. airflow/providers/google/cloud/hooks/compute_ssh.py +1 -1
  16. airflow/providers/google/cloud/hooks/datacatalog.py +22 -22
  17. airflow/providers/google/cloud/hooks/dataflow.py +48 -49
  18. airflow/providers/google/cloud/hooks/dataform.py +16 -16
  19. airflow/providers/google/cloud/hooks/datafusion.py +15 -15
  20. airflow/providers/google/cloud/hooks/datapipeline.py +3 -3
  21. airflow/providers/google/cloud/hooks/dataplex.py +19 -19
  22. airflow/providers/google/cloud/hooks/dataprep.py +10 -10
  23. airflow/providers/google/cloud/hooks/dataproc.py +132 -14
  24. airflow/providers/google/cloud/hooks/dataproc_metastore.py +13 -13
  25. airflow/providers/google/cloud/hooks/datastore.py +3 -3
  26. airflow/providers/google/cloud/hooks/dlp.py +25 -25
  27. airflow/providers/google/cloud/hooks/gcs.py +39 -27
  28. airflow/providers/google/cloud/hooks/gdm.py +3 -3
  29. airflow/providers/google/cloud/hooks/kms.py +3 -3
  30. airflow/providers/google/cloud/hooks/kubernetes_engine.py +63 -48
  31. airflow/providers/google/cloud/hooks/life_sciences.py +13 -12
  32. airflow/providers/google/cloud/hooks/looker.py +8 -9
  33. airflow/providers/google/cloud/hooks/mlengine.py +12 -12
  34. airflow/providers/google/cloud/hooks/natural_language.py +2 -2
  35. airflow/providers/google/cloud/hooks/os_login.py +1 -1
  36. airflow/providers/google/cloud/hooks/pubsub.py +9 -9
  37. airflow/providers/google/cloud/hooks/secret_manager.py +1 -1
  38. airflow/providers/google/cloud/hooks/spanner.py +11 -11
  39. airflow/providers/google/cloud/hooks/speech_to_text.py +1 -1
  40. airflow/providers/google/cloud/hooks/stackdriver.py +7 -7
  41. airflow/providers/google/cloud/hooks/tasks.py +11 -11
  42. airflow/providers/google/cloud/hooks/text_to_speech.py +1 -1
  43. airflow/providers/google/cloud/hooks/translate.py +1 -1
  44. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +13 -13
  45. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +6 -6
  46. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +45 -50
  47. airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +13 -13
  48. airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +9 -9
  49. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +128 -11
  50. airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +10 -10
  51. airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +8 -8
  52. airflow/providers/google/cloud/hooks/video_intelligence.py +2 -2
  53. airflow/providers/google/cloud/hooks/vision.py +1 -1
  54. airflow/providers/google/cloud/hooks/workflows.py +10 -10
  55. airflow/providers/google/cloud/links/datafusion.py +12 -5
  56. airflow/providers/google/cloud/operators/bigquery.py +11 -11
  57. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +3 -1
  58. airflow/providers/google/cloud/operators/dataflow.py +16 -16
  59. airflow/providers/google/cloud/operators/datafusion.py +9 -1
  60. airflow/providers/google/cloud/operators/dataproc.py +444 -69
  61. airflow/providers/google/cloud/operators/kubernetes_engine.py +6 -6
  62. airflow/providers/google/cloud/operators/life_sciences.py +10 -9
  63. airflow/providers/google/cloud/operators/mlengine.py +96 -96
  64. airflow/providers/google/cloud/operators/pubsub.py +2 -0
  65. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +33 -3
  66. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +59 -2
  67. airflow/providers/google/cloud/secrets/secret_manager.py +8 -7
  68. airflow/providers/google/cloud/sensors/bigquery.py +20 -16
  69. airflow/providers/google/cloud/sensors/cloud_composer.py +11 -8
  70. airflow/providers/google/cloud/sensors/dataproc_metastore.py +12 -2
  71. airflow/providers/google/cloud/sensors/gcs.py +8 -7
  72. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -0
  73. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +4 -4
  74. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -0
  75. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -1
  76. airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
  77. airflow/providers/google/cloud/transfers/mysql_to_gcs.py +1 -1
  78. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +1 -1
  79. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +1 -1
  80. airflow/providers/google/cloud/transfers/presto_to_gcs.py +1 -1
  81. airflow/providers/google/cloud/transfers/s3_to_gcs.py +3 -3
  82. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -1
  83. airflow/providers/google/cloud/transfers/sql_to_gcs.py +3 -3
  84. airflow/providers/google/cloud/transfers/trino_to_gcs.py +1 -1
  85. airflow/providers/google/cloud/triggers/bigquery.py +12 -12
  86. airflow/providers/google/cloud/triggers/bigquery_dts.py +1 -1
  87. airflow/providers/google/cloud/triggers/cloud_batch.py +3 -1
  88. airflow/providers/google/cloud/triggers/cloud_build.py +2 -2
  89. airflow/providers/google/cloud/triggers/cloud_run.py +1 -1
  90. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +6 -6
  91. airflow/providers/google/cloud/triggers/dataflow.py +3 -1
  92. airflow/providers/google/cloud/triggers/datafusion.py +2 -2
  93. airflow/providers/google/cloud/triggers/dataplex.py +2 -2
  94. airflow/providers/google/cloud/triggers/dataproc.py +34 -14
  95. airflow/providers/google/cloud/triggers/gcs.py +12 -8
  96. airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -2
  97. airflow/providers/google/cloud/triggers/mlengine.py +2 -2
  98. airflow/providers/google/cloud/triggers/pubsub.py +1 -1
  99. airflow/providers/google/cloud/triggers/vertex_ai.py +99 -0
  100. airflow/providers/google/cloud/utils/bigquery.py +2 -2
  101. airflow/providers/google/cloud/utils/credentials_provider.py +2 -2
  102. airflow/providers/google/cloud/utils/dataform.py +1 -1
  103. airflow/providers/google/cloud/utils/dataproc.py +25 -0
  104. airflow/providers/google/cloud/utils/field_validator.py +2 -2
  105. airflow/providers/google/cloud/utils/helpers.py +2 -2
  106. airflow/providers/google/cloud/utils/mlengine_operator_utils.py +1 -1
  107. airflow/providers/google/cloud/utils/mlengine_prediction_summary.py +1 -1
  108. airflow/providers/google/common/auth_backend/google_openid.py +2 -2
  109. airflow/providers/google/common/hooks/base_google.py +87 -23
  110. airflow/providers/google/common/hooks/discovery_api.py +2 -2
  111. airflow/providers/google/common/utils/id_token_credentials.py +5 -5
  112. airflow/providers/google/firebase/hooks/firestore.py +3 -3
  113. airflow/providers/google/get_provider_info.py +7 -2
  114. airflow/providers/google/leveldb/hooks/leveldb.py +4 -4
  115. airflow/providers/google/marketing_platform/hooks/analytics.py +11 -14
  116. airflow/providers/google/marketing_platform/hooks/campaign_manager.py +11 -11
  117. airflow/providers/google/marketing_platform/hooks/display_video.py +13 -13
  118. airflow/providers/google/marketing_platform/hooks/search_ads.py +4 -4
  119. airflow/providers/google/marketing_platform/operators/analytics.py +37 -32
  120. airflow/providers/google/suite/hooks/calendar.py +2 -2
  121. airflow/providers/google/suite/hooks/drive.py +7 -7
  122. airflow/providers/google/suite/hooks/sheets.py +8 -8
  123. {apache_airflow_providers_google-10.14.0rc1.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/METADATA +11 -11
  124. {apache_airflow_providers_google-10.14.0rc1.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/RECORD +126 -124
  125. {apache_airflow_providers_google-10.14.0rc1.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/WHEEL +0 -0
  126. {apache_airflow_providers_google-10.14.0rc1.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/triggers/dataproc.py

@@ -19,6 +19,7 @@
 from __future__ import annotations

 import asyncio
+import re
 import time
 from typing import Any, AsyncIterator, Sequence

@@ -26,6 +27,7 @@ from google.api_core.exceptions import NotFound
 from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus

 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
+from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.triggers.base import BaseTrigger, TriggerEvent


@@ -180,7 +182,7 @@ class DataprocBatchTrigger(DataprocBaseTrigger):
         self.batch_id = batch_id

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes DataprocBatchTrigger arguments and classpath."""
+        """Serialize DataprocBatchTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger",
             {
@@ -242,7 +244,7 @@ class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
         self.metadata = metadata

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes DataprocDeleteClusterTrigger arguments and classpath."""
+        """Serialize DataprocDeleteClusterTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.dataproc.DataprocDeleteClusterTrigger",
             {
@@ -281,22 +283,24 @@ class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
                 yield TriggerEvent({"status": "error", "message": "Timeout"})


-class DataprocWorkflowTrigger(DataprocBaseTrigger):
+class DataprocOperationTrigger(DataprocBaseTrigger):
     """
-    Trigger that periodically polls information from Dataproc API to verify status.
+    Trigger that periodically polls information on a long running operation from Dataproc API to verify status.

     Implementation leverages asynchronous transport.
     """

-    def __init__(self, name: str, **kwargs: Any):
+    def __init__(self, name: str, operation_type: str | None = None, **kwargs: Any):
         super().__init__(**kwargs)
         self.name = name
+        self.operation_type = operation_type

     def serialize(self):
         return (
-            "airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger",
+            "airflow.providers.google.cloud.triggers.dataproc.DataprocOperationTrigger",
             {
                 "name": self.name,
+                "operation_type": self.operation_type,
                 "project_id": self.project_id,
                 "region": self.region,
                 "gcp_conn_id": self.gcp_conn_id,
@@ -317,14 +321,30 @@ class DataprocWorkflowTrigger(DataprocBaseTrigger):
                     else:
                         status = "success"
                         message = "Operation is successfully ended."
-                    yield TriggerEvent(
-                        {
-                            "operation_name": operation.name,
-                            "operation_done": operation.done,
-                            "status": status,
-                            "message": message,
-                        }
-                    )
+                    if self.operation_type == DataprocOperationType.DIAGNOSE.value:
+                        gcs_regex = rb"gs:\/\/[a-z0-9][a-z0-9_-]{1,61}[a-z0-9_\-\/]*"
+                        gcs_uri_value = operation.response.value
+                        match = re.search(gcs_regex, gcs_uri_value)
+                        if match:
+                            output_uri = match.group(0).decode("utf-8", "ignore")
+                        else:
+                            output_uri = gcs_uri_value
+                        yield TriggerEvent(
+                            {
+                                "status": status,
+                                "message": message,
+                                "output_uri": output_uri,
+                            }
+                        )
+                    else:
+                        yield TriggerEvent(
+                            {
+                                "operation_name": operation.name,
+                                "operation_done": operation.done,
+                                "status": status,
+                                "message": message,
+                            }
+                        )
                     return
                 else:
                     self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
airflow/providers/google/cloud/triggers/gcs.py

@@ -60,7 +60,7 @@ class GCSBlobTrigger(BaseTrigger):
         self.hook_params = hook_params

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes GCSBlobTrigger arguments and classpath."""
+        """Serialize GCSBlobTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger",
             {
@@ -93,7 +93,7 @@ class GCSBlobTrigger(BaseTrigger):

     async def _object_exists(self, hook: GCSAsyncHook, bucket_name: str, object_name: str) -> str:
         """
-        Checks for the existence of a file in Google Cloud Storage.
+        Check for the existence of a file in Google Cloud Storage.

         :param bucket_name: The Google Cloud Storage bucket where the object is.
         :param object_name: The name of the blob_name to check in the Google cloud
@@ -143,7 +143,7 @@ class GCSCheckBlobUpdateTimeTrigger(BaseTrigger):
         self.hook_params = hook_params

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes GCSCheckBlobUpdateTimeTrigger arguments and classpath."""
+        """Serialize GCSCheckBlobUpdateTimeTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.gcs.GCSCheckBlobUpdateTimeTrigger",
             {
@@ -181,7 +181,7 @@ class GCSCheckBlobUpdateTimeTrigger(BaseTrigger):
         self, hook: GCSAsyncHook, bucket_name: str, object_name: str, target_date: datetime
     ) -> tuple[bool, dict[str, Any]]:
         """
-        Checks if the object in the bucket is updated.
+        Check if the object in the bucket is updated.

         :param hook: GCSAsyncHook Hook class
         :param bucket_name: The Google Cloud Storage bucket where the object is.
@@ -248,7 +248,7 @@ class GCSPrefixBlobTrigger(GCSBlobTrigger):
         self.prefix = prefix

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes GCSPrefixBlobTrigger arguments and classpath."""
+        """Serialize GCSPrefixBlobTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.gcs.GCSPrefixBlobTrigger",
             {
@@ -282,7 +282,7 @@ class GCSPrefixBlobTrigger(GCSBlobTrigger):

     async def _list_blobs_with_prefix(self, hook: GCSAsyncHook, bucket_name: str, prefix: str) -> list[str]:
         """
-        Returns names of blobs which match the given prefix for a given bucket.
+        Return names of blobs which match the given prefix for a given bucket.

         :param hook: The async hook to use for listing the blobs
         :param bucket_name: The Google Cloud Storage bucket where the object is.
@@ -344,7 +344,7 @@ class GCSUploadSessionTrigger(GCSPrefixBlobTrigger):
         self.last_activity_time: datetime | None = None

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes GCSUploadSessionTrigger arguments and classpath."""
+        """Serialize GCSUploadSessionTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.gcs.GCSUploadSessionTrigger",
             {
@@ -377,7 +377,11 @@ class GCSUploadSessionTrigger(GCSPrefixBlobTrigger):
             yield TriggerEvent({"status": "error", "message": str(e)})

     def _get_time(self) -> datetime:
-        """This is just a wrapper of datetime.datetime.now to simplify mocking in the unittests."""
+        """
+        Get current local date and time.
+
+        This is just a wrapper of datetime.datetime.now to simplify mocking in the unittests.
+        """
         return datetime.now()

     def _is_bucket_updated(self, current_objects: set[str]) -> dict[str, str]:
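Note: per the expanded docstring, _get_time exists solely so tests can pin the clock. A sketch of that pattern (the surrounding test scaffolding is assumed, not taken from the provider's test suite):

from datetime import datetime
from unittest import mock

from airflow.providers.google.cloud.triggers.gcs import GCSUploadSessionTrigger

# Pin "now" by patching the wrapper instead of patching datetime.datetime.now globally.
with mock.patch.object(GCSUploadSessionTrigger, "_get_time", return_value=datetime(2024, 1, 1)):
    ...  # exercise the trigger's inactivity-window logic against a frozen clock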
airflow/providers/google/cloud/triggers/kubernetes_engine.py

@@ -174,7 +174,7 @@ class GKEOperationTrigger(BaseTrigger):
         self._hook: GKEAsyncHook | None = None

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes GKEOperationTrigger arguments and classpath."""
+        """Serialize GKEOperationTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.kubernetes_engine.GKEOperationTrigger",
             {
@@ -188,7 +188,7 @@ class GKEOperationTrigger(BaseTrigger):
         )

     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
-        """Gets operation status and yields corresponding event."""
+        """Get operation status and yields corresponding event."""
         hook = self._get_hook()
         try:
             while True:
airflow/providers/google/cloud/triggers/mlengine.py

@@ -69,7 +69,7 @@ class MLEngineStartTrainingJobTrigger(BaseTrigger):
         self.impersonation_chain = impersonation_chain

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes MLEngineStartTrainingJobTrigger arguments and classpath."""
+        """Serialize MLEngineStartTrainingJobTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.mlengine.MLEngineStartTrainingJobTrigger",
             {
@@ -89,7 +89,7 @@ class MLEngineStartTrainingJobTrigger(BaseTrigger):
         )

     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
-        """Gets current job execution status and yields a TriggerEvent."""
+        """Get current job execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
             while True:
airflow/providers/google/cloud/triggers/pubsub.py

@@ -79,7 +79,7 @@ class PubsubPullTrigger(BaseTrigger):
         self.hook = PubSubAsyncHook()

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes PubsubPullTrigger arguments and classpath."""
+        """Serialize PubsubPullTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.pubsub.PubsubPullTrigger",
             {
airflow/providers/google/cloud/triggers/vertex_ai.py (new file)

@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any, AsyncIterator, Sequence
+
+from google.cloud.aiplatform_v1 import HyperparameterTuningJob, JobState
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
+    HyperparameterTuningJobAsyncHook,
+)
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class CreateHyperparameterTuningJobTrigger(BaseTrigger):
+    """CreateHyperparameterTuningJobTrigger run on the trigger worker to perform create operation."""
+
+    statuses_success = {
+        JobState.JOB_STATE_PAUSED,
+        JobState.JOB_STATE_SUCCEEDED,
+    }
+
+    def __init__(
+        self,
+        conn_id: str,
+        project_id: str,
+        location: str,
+        job_id: str,
+        poll_interval: int,
+        impersonation_chain: str | Sequence[str] | None = None,
+    ):
+        super().__init__()
+        self.conn_id = conn_id
+        self.project_id = project_id
+        self.location = location
+        self.job_id = job_id
+        self.poll_interval = poll_interval
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.google.cloud.triggers.vertex_ai.CreateHyperparameterTuningJobTrigger",
+            {
+                "conn_id": self.conn_id,
+                "project_id": self.project_id,
+                "location": self.location,
+                "job_id": self.job_id,
+                "poll_interval": self.poll_interval,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        hook = self._get_async_hook()
+        try:
+            job = await hook.wait_hyperparameter_tuning_job(
+                project_id=self.project_id,
+                location=self.location,
+                job_id=self.job_id,
+                poll_interval=self.poll_interval,
+            )
+        except AirflowException as ex:
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(ex),
+                }
+            )
+            return
+
+        status = "success" if job.state in self.statuses_success else "error"
+        message = f"Hyperparameter tuning job {job.name} completed with status {job.state.name}"
+        yield TriggerEvent(
+            {
+                "status": status,
+                "message": message,
+                "job": HyperparameterTuningJob.to_dict(job),
+            }
+        )
+
+    def _get_async_hook(self) -> HyperparameterTuningJobAsyncHook:
+        return HyperparameterTuningJobAsyncHook(
+            gcp_conn_id=self.conn_id, impersonation_chain=self.impersonation_chain
+        )
airflow/providers/google/cloud/utils/bigquery.py

@@ -21,7 +21,7 @@ from typing import Any

 def bq_cast(string_field: str, bq_type: str) -> None | int | float | bool | str:
     """
-    Helper method that casts a BigQuery row to the appropriate data types.
+    Cast a BigQuery row to the appropriate data types.

     This is useful because BigQuery returns all fields as strings.
     """
@@ -41,7 +41,7 @@ def bq_cast(string_field: str, bq_type: str) -> None | int | float | bool | str:

 def convert_job_id(job_id: str | list[str], project_id: str, location: str | None) -> Any:
     """
-    Helper method that converts to path: project_id:location:job_id.
+    Convert job_id to path: project_id:location:job_id.

     :param project_id: Required. The ID of the Google Cloud project where workspace located.
     :param location: Optional. The ID of the Google Cloud region where workspace located.
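Note: bq_cast's behavior is unchanged here, only its docstring. Since BigQuery's REST API returns every cell as a string, its effect is easiest to see by example; the expected values below are a sketch based on the docstring and return annotation:

from airflow.providers.google.cloud.utils.bigquery import bq_cast

assert bq_cast("42", "INTEGER") == 42
assert bq_cast("3.14", "FLOAT") == 3.14
assert bq_cast("true", "BOOLEAN") is True  # assumed: "true"/"false" map to bool
assert bq_cast(None, "STRING") is None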
airflow/providers/google/cloud/utils/credentials_provider.py

@@ -358,7 +358,7 @@ class _CredentialProvider(LoggingMixin):


 def get_credentials_and_project_id(*args, **kwargs) -> tuple[google.auth.credentials.Credentials, str]:
-    """Returns the Credentials object for Google API and the associated project_id."""
+    """Return the Credentials object for Google API and the associated project_id."""
     return _CredentialProvider(*args, **kwargs).get_credentials_and_project()


@@ -398,7 +398,7 @@ def _get_target_principal_and_delegates(

 def _get_project_id_from_service_account_email(service_account_email: str) -> str:
     """
-    Extracts project_id from service account's email address.
+    Extract project_id from service account's email address.

     :param service_account_email: email of the service account.

airflow/providers/google/cloud/utils/dataform.py

@@ -45,7 +45,7 @@ def make_initialization_workspace_flow(
     without_installation: bool = False,
 ) -> tuple:
     """
-    Creates flow which simulates the initialization of the default project.
+    Create flow which simulates the initialization of the default project.

     :param project_id: Required. The ID of the Google Cloud project where workspace located.
     :param region: Required. The ID of the Google Cloud region where workspace located.
airflow/providers/google/cloud/utils/dataproc.py (new file)

@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from enum import Enum
+
+
+class DataprocOperationType(Enum):
+    """Contains types of long running operations."""
+
+    DIAGNOSE = "DIAGNOSE"
airflow/providers/google/cloud/utils/field_validator.py

@@ -309,7 +309,7 @@ class GcpBodyFieldValidator(LoggingMixin):

     def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, force_optional=False):
         """
-        Validates if field is OK.
+        Validate if field is OK.

         :param validation_spec: specification of the field
         :param dictionary_to_validate: dictionary where the field should be present
@@ -413,7 +413,7 @@ class GcpBodyFieldValidator(LoggingMixin):

     def validate(self, body_to_validate: dict) -> None:
         """
-        Validates if the body (dictionary) follows specification that the validator was instantiated with.
+        Validate if the body (dictionary) follows specification that the validator was instantiated with.

         Raises ValidationSpecificationException or ValidationFieldException in case of problems
         with specification or the body not conforming to the specification respectively.
airflow/providers/google/cloud/utils/helpers.py

@@ -19,12 +19,12 @@ from __future__ import annotations


 def normalize_directory_path(source_object: str | None) -> str | None:
-    """Makes sure dir path ends with a slash."""
+    """Make sure dir path ends with a slash."""
     return source_object + "/" if source_object and not source_object.endswith("/") else source_object


 def resource_path_to_dict(resource_name: str) -> dict[str, str]:
-    """Converts a path-like GCP resource name into a dictionary.
+    """Convert a path-like GCP resource name into a dictionary.

     For example, the path `projects/my-project/locations/my-location/instances/my-instance` will be converted
     to a dict:
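Note: the docstring's own example, run directly (output shape as the docstring describes; assumes the provider package is installed):

from airflow.providers.google.cloud.utils.helpers import resource_path_to_dict

path = "projects/my-project/locations/my-location/instances/my-instance"
print(resource_path_to_dict(resource_name=path))
# {'projects': 'my-project', 'locations': 'my-location', 'instances': 'my-instance'}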
airflow/providers/google/cloud/utils/mlengine_operator_utils.py

@@ -57,7 +57,7 @@ def create_evaluate_ops(
     py_interpreter="python3",
 ) -> tuple[MLEngineStartBatchPredictionJobOperator, BeamRunPythonPipelineOperator, PythonOperator]:
     r"""
-    Creates Operators needed for model evaluation and returns.
+    Create Operators needed for model evaluation and returns.

     This function is deprecated. All the functionality of legacy MLEngine and new features are available
     on the Vertex AI platform.
airflow/providers/google/cloud/utils/mlengine_prediction_summary.py

@@ -151,7 +151,7 @@ def MakeSummary(pcoll, metric_fn, metric_keys):


 def run(argv=None):
-    """Helper for obtaining prediction summary."""
+    """Obtain prediction summary."""
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--prediction_path",
airflow/providers/google/common/auth_backend/google_openid.py

@@ -52,7 +52,7 @@ def create_client_session():


 def init_app(_):
-    """Initializes authentication."""
+    """Initialize authentication."""


 def _get_id_token_from_request(request) -> str | None:
@@ -110,7 +110,7 @@ T = TypeVar("T", bound=Callable)


 def requires_authentication(function: T):
-    """Decorator for functions that require authentication."""
+    """Decorator for function that require authentication."""

     @wraps(function)
     def decorated(*args, **kwargs):