apache-airflow-providers-google 10.14.0rc2__py3-none-any.whl → 10.15.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/cloud/hooks/automl.py +13 -13
- airflow/providers/google/cloud/hooks/bigquery.py +193 -246
- airflow/providers/google/cloud/hooks/bigquery_dts.py +6 -6
- airflow/providers/google/cloud/hooks/bigtable.py +8 -8
- airflow/providers/google/cloud/hooks/cloud_batch.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_build.py +19 -20
- airflow/providers/google/cloud/hooks/cloud_composer.py +4 -4
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +10 -10
- airflow/providers/google/cloud/hooks/cloud_run.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_sql.py +17 -17
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +3 -3
- airflow/providers/google/cloud/hooks/compute.py +16 -16
- airflow/providers/google/cloud/hooks/compute_ssh.py +1 -1
- airflow/providers/google/cloud/hooks/datacatalog.py +22 -22
- airflow/providers/google/cloud/hooks/dataflow.py +48 -49
- airflow/providers/google/cloud/hooks/dataform.py +16 -16
- airflow/providers/google/cloud/hooks/datafusion.py +15 -15
- airflow/providers/google/cloud/hooks/datapipeline.py +3 -3
- airflow/providers/google/cloud/hooks/dataplex.py +19 -19
- airflow/providers/google/cloud/hooks/dataprep.py +8 -8
- airflow/providers/google/cloud/hooks/dataproc.py +88 -0
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +13 -13
- airflow/providers/google/cloud/hooks/datastore.py +3 -3
- airflow/providers/google/cloud/hooks/dlp.py +25 -25
- airflow/providers/google/cloud/hooks/gcs.py +25 -23
- airflow/providers/google/cloud/hooks/gdm.py +3 -3
- airflow/providers/google/cloud/hooks/kms.py +3 -3
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +63 -48
- airflow/providers/google/cloud/hooks/life_sciences.py +13 -12
- airflow/providers/google/cloud/hooks/looker.py +7 -7
- airflow/providers/google/cloud/hooks/mlengine.py +12 -12
- airflow/providers/google/cloud/hooks/natural_language.py +2 -2
- airflow/providers/google/cloud/hooks/os_login.py +1 -1
- airflow/providers/google/cloud/hooks/pubsub.py +9 -9
- airflow/providers/google/cloud/hooks/secret_manager.py +1 -1
- airflow/providers/google/cloud/hooks/spanner.py +11 -11
- airflow/providers/google/cloud/hooks/speech_to_text.py +1 -1
- airflow/providers/google/cloud/hooks/stackdriver.py +7 -7
- airflow/providers/google/cloud/hooks/tasks.py +11 -11
- airflow/providers/google/cloud/hooks/text_to_speech.py +1 -1
- airflow/providers/google/cloud/hooks/translate.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +13 -13
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +6 -6
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +45 -50
- airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +13 -13
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +9 -9
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +128 -11
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +10 -10
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +8 -8
- airflow/providers/google/cloud/hooks/video_intelligence.py +2 -2
- airflow/providers/google/cloud/hooks/vision.py +1 -1
- airflow/providers/google/cloud/hooks/workflows.py +10 -10
- airflow/providers/google/cloud/links/datafusion.py +12 -5
- airflow/providers/google/cloud/operators/bigquery.py +9 -11
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +3 -1
- airflow/providers/google/cloud/operators/dataflow.py +16 -16
- airflow/providers/google/cloud/operators/datafusion.py +9 -1
- airflow/providers/google/cloud/operators/dataproc.py +298 -65
- airflow/providers/google/cloud/operators/kubernetes_engine.py +6 -6
- airflow/providers/google/cloud/operators/life_sciences.py +10 -9
- airflow/providers/google/cloud/operators/mlengine.py +96 -96
- airflow/providers/google/cloud/operators/pubsub.py +2 -0
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +33 -3
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +59 -2
- airflow/providers/google/cloud/secrets/secret_manager.py +8 -7
- airflow/providers/google/cloud/sensors/bigquery.py +20 -16
- airflow/providers/google/cloud/sensors/cloud_composer.py +11 -8
- airflow/providers/google/cloud/sensors/gcs.py +8 -7
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +4 -4
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -1
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/mysql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/presto_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/trino_to_gcs.py +1 -1
- airflow/providers/google/cloud/triggers/bigquery.py +12 -12
- airflow/providers/google/cloud/triggers/bigquery_dts.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_batch.py +3 -1
- airflow/providers/google/cloud/triggers/cloud_build.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_run.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +6 -6
- airflow/providers/google/cloud/triggers/dataflow.py +3 -1
- airflow/providers/google/cloud/triggers/datafusion.py +2 -2
- airflow/providers/google/cloud/triggers/dataplex.py +2 -2
- airflow/providers/google/cloud/triggers/dataproc.py +2 -2
- airflow/providers/google/cloud/triggers/gcs.py +12 -8
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/triggers/mlengine.py +2 -2
- airflow/providers/google/cloud/triggers/pubsub.py +1 -1
- airflow/providers/google/cloud/triggers/vertex_ai.py +99 -0
- airflow/providers/google/cloud/utils/bigquery.py +2 -2
- airflow/providers/google/cloud/utils/credentials_provider.py +2 -2
- airflow/providers/google/cloud/utils/dataform.py +1 -1
- airflow/providers/google/cloud/utils/field_validator.py +2 -2
- airflow/providers/google/cloud/utils/helpers.py +2 -2
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +1 -1
- airflow/providers/google/cloud/utils/mlengine_prediction_summary.py +1 -1
- airflow/providers/google/common/auth_backend/google_openid.py +2 -2
- airflow/providers/google/common/hooks/base_google.py +29 -22
- airflow/providers/google/common/hooks/discovery_api.py +2 -2
- airflow/providers/google/common/utils/id_token_credentials.py +5 -5
- airflow/providers/google/firebase/hooks/firestore.py +3 -3
- airflow/providers/google/get_provider_info.py +7 -2
- airflow/providers/google/leveldb/hooks/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/analytics.py +11 -14
- airflow/providers/google/marketing_platform/hooks/campaign_manager.py +11 -11
- airflow/providers/google/marketing_platform/hooks/display_video.py +13 -13
- airflow/providers/google/marketing_platform/hooks/search_ads.py +4 -4
- airflow/providers/google/marketing_platform/operators/analytics.py +37 -32
- airflow/providers/google/suite/hooks/calendar.py +2 -2
- airflow/providers/google/suite/hooks/drive.py +7 -7
- airflow/providers/google/suite/hooks/sheets.py +8 -8
- {apache_airflow_providers_google-10.14.0rc2.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/METADATA +11 -11
- {apache_airflow_providers_google-10.14.0rc2.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/RECORD +121 -120
- {apache_airflow_providers_google-10.14.0rc2.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.14.0rc2.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/entry_points.txt +0 -0
@@ -52,7 +52,7 @@ class DatasetHook(GoogleBaseHook):
|
|
52
52
|
super().__init__(**kwargs)
|
53
53
|
|
54
54
|
def get_dataset_service_client(self, region: str | None = None) -> DatasetServiceClient:
|
55
|
-
"""
|
55
|
+
"""Return DatasetServiceClient."""
|
56
56
|
if region and region != "global":
|
57
57
|
client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
|
58
58
|
else:
|
@@ -63,7 +63,7 @@ class DatasetHook(GoogleBaseHook):
|
|
63
63
|
)
|
64
64
|
|
65
65
|
def wait_for_operation(self, operation: Operation, timeout: float | None = None):
|
66
|
-
"""
|
66
|
+
"""Wait for long-lasting operation to complete."""
|
67
67
|
try:
|
68
68
|
return operation.result(timeout=timeout)
|
69
69
|
except Exception:
|
@@ -72,7 +72,7 @@ class DatasetHook(GoogleBaseHook):
|
|
72
72
|
|
73
73
|
@staticmethod
|
74
74
|
def extract_dataset_id(obj: dict) -> str:
|
75
|
-
"""
|
75
|
+
"""Return unique id of the dataset."""
|
76
76
|
return obj["name"].rpartition("/")[-1]
|
77
77
|
|
78
78
|
@GoogleBaseHook.fallback_to_default_project_id
|
@@ -86,7 +86,7 @@ class DatasetHook(GoogleBaseHook):
|
|
86
86
|
metadata: Sequence[tuple[str, str]] = (),
|
87
87
|
) -> Operation:
|
88
88
|
"""
|
89
|
-
|
89
|
+
Create a Dataset.
|
90
90
|
|
91
91
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
92
92
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -120,7 +120,7 @@ class DatasetHook(GoogleBaseHook):
|
|
120
120
|
metadata: Sequence[tuple[str, str]] = (),
|
121
121
|
) -> Operation:
|
122
122
|
"""
|
123
|
-
|
123
|
+
Delete a Dataset.
|
124
124
|
|
125
125
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
126
126
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -154,7 +154,7 @@ class DatasetHook(GoogleBaseHook):
|
|
154
154
|
metadata: Sequence[tuple[str, str]] = (),
|
155
155
|
) -> Operation:
|
156
156
|
"""
|
157
|
-
|
157
|
+
Export data from a Dataset.
|
158
158
|
|
159
159
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
160
160
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -191,7 +191,7 @@ class DatasetHook(GoogleBaseHook):
|
|
191
191
|
metadata: Sequence[tuple[str, str]] = (),
|
192
192
|
) -> AnnotationSpec:
|
193
193
|
"""
|
194
|
-
|
194
|
+
Get an AnnotationSpec.
|
195
195
|
|
196
196
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
197
197
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -228,7 +228,7 @@ class DatasetHook(GoogleBaseHook):
|
|
228
228
|
metadata: Sequence[tuple[str, str]] = (),
|
229
229
|
) -> Dataset:
|
230
230
|
"""
|
231
|
-
|
231
|
+
Get a Dataset.
|
232
232
|
|
233
233
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
234
234
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -264,7 +264,7 @@ class DatasetHook(GoogleBaseHook):
|
|
264
264
|
metadata: Sequence[tuple[str, str]] = (),
|
265
265
|
) -> Operation:
|
266
266
|
"""
|
267
|
-
|
267
|
+
Import data into a Dataset.
|
268
268
|
|
269
269
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
270
270
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -306,7 +306,7 @@ class DatasetHook(GoogleBaseHook):
|
|
306
306
|
metadata: Sequence[tuple[str, str]] = (),
|
307
307
|
) -> ListAnnotationsPager:
|
308
308
|
"""
|
309
|
-
|
309
|
+
List Annotations belongs to a data item.
|
310
310
|
|
311
311
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
312
312
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -356,7 +356,7 @@ class DatasetHook(GoogleBaseHook):
|
|
356
356
|
metadata: Sequence[tuple[str, str]] = (),
|
357
357
|
) -> ListDataItemsPager:
|
358
358
|
"""
|
359
|
-
|
359
|
+
List DataItems in a Dataset.
|
360
360
|
|
361
361
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
362
362
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -404,7 +404,7 @@ class DatasetHook(GoogleBaseHook):
|
|
404
404
|
metadata: Sequence[tuple[str, str]] = (),
|
405
405
|
) -> ListDatasetsPager:
|
406
406
|
"""
|
407
|
-
|
407
|
+
List Datasets in a Location.
|
408
408
|
|
409
409
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
410
410
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -448,7 +448,7 @@ class DatasetHook(GoogleBaseHook):
|
|
448
448
|
metadata: Sequence[tuple[str, str]] = (),
|
449
449
|
) -> Dataset:
|
450
450
|
"""
|
451
|
-
|
451
|
+
Update a Dataset.
|
452
452
|
|
453
453
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
454
454
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -47,7 +47,7 @@ class EndpointServiceHook(GoogleBaseHook):
|
|
47
47
|
super().__init__(**kwargs)
|
48
48
|
|
49
49
|
def get_endpoint_service_client(self, region: str | None = None) -> EndpointServiceClient:
|
50
|
-
"""
|
50
|
+
"""Return EndpointServiceClient."""
|
51
51
|
if region and region != "global":
|
52
52
|
client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
|
53
53
|
else:
|
@@ -58,7 +58,7 @@ class EndpointServiceHook(GoogleBaseHook):
|
|
58
58
|
)
|
59
59
|
|
60
60
|
def wait_for_operation(self, operation: Operation, timeout: float | None = None):
|
61
|
-
"""
|
61
|
+
"""Wait for long-lasting operation to complete."""
|
62
62
|
try:
|
63
63
|
return operation.result(timeout=timeout)
|
64
64
|
except Exception:
|
@@ -67,12 +67,12 @@ class EndpointServiceHook(GoogleBaseHook):
|
|
67
67
|
|
68
68
|
@staticmethod
|
69
69
|
def extract_endpoint_id(obj: dict) -> str:
|
70
|
-
"""
|
70
|
+
"""Return unique id of the endpoint."""
|
71
71
|
return obj["name"].rpartition("/")[-1]
|
72
72
|
|
73
73
|
@staticmethod
|
74
74
|
def extract_deployed_model_id(obj: dict) -> str:
|
75
|
-
"""
|
75
|
+
"""Return unique id of the deploy model."""
|
76
76
|
return obj["deployed_model"]["id"]
|
77
77
|
|
78
78
|
@GoogleBaseHook.fallback_to_default_project_id
|
@@ -87,7 +87,7 @@ class EndpointServiceHook(GoogleBaseHook):
|
|
87
87
|
metadata: Sequence[tuple[str, str]] = (),
|
88
88
|
) -> Operation:
|
89
89
|
"""
|
90
|
-
|
90
|
+
Create an Endpoint.
|
91
91
|
|
92
92
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
93
93
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -124,7 +124,7 @@ class EndpointServiceHook(GoogleBaseHook):
|
|
124
124
|
metadata: Sequence[tuple[str, str]] = (),
|
125
125
|
) -> Operation:
|
126
126
|
"""
|
127
|
-
|
127
|
+
Delete an Endpoint.
|
128
128
|
|
129
129
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
130
130
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -210,7 +210,7 @@ class EndpointServiceHook(GoogleBaseHook):
|
|
210
210
|
metadata: Sequence[tuple[str, str]] = (),
|
211
211
|
) -> Endpoint:
|
212
212
|
"""
|
213
|
-
|
213
|
+
Get an Endpoint.
|
214
214
|
|
215
215
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
216
216
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -247,7 +247,7 @@ class EndpointServiceHook(GoogleBaseHook):
|
|
247
247
|
metadata: Sequence[tuple[str, str]] = (),
|
248
248
|
) -> ListEndpointsPager:
|
249
249
|
"""
|
250
|
-
|
250
|
+
List Endpoints in a Location.
|
251
251
|
|
252
252
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
253
253
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -352,7 +352,7 @@ class EndpointServiceHook(GoogleBaseHook):
|
|
352
352
|
metadata: Sequence[tuple[str, str]] = (),
|
353
353
|
) -> Endpoint:
|
354
354
|
"""
|
355
|
-
|
355
|
+
Update an Endpoint.
|
356
356
|
|
357
357
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
358
358
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -15,22 +15,32 @@
|
|
15
15
|
# KIND, either express or implied. See the License for the
|
16
16
|
# specific language governing permissions and limitations
|
17
17
|
# under the License.
|
18
|
-
"""
|
18
|
+
"""
|
19
|
+
This module contains a Google Cloud Vertex AI hook.
|
20
|
+
|
21
|
+
.. spelling:word-list::
|
22
|
+
|
23
|
+
JobServiceAsyncClient
|
24
|
+
"""
|
19
25
|
from __future__ import annotations
|
20
26
|
|
27
|
+
import asyncio
|
28
|
+
from functools import lru_cache
|
21
29
|
from typing import TYPE_CHECKING, Sequence
|
22
30
|
|
23
31
|
from google.api_core.client_options import ClientOptions
|
24
32
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
25
33
|
from google.cloud.aiplatform import CustomJob, HyperparameterTuningJob, gapic, hyperparameter_tuning
|
26
|
-
from google.cloud.aiplatform_v1 import JobServiceClient, types
|
34
|
+
from google.cloud.aiplatform_v1 import JobServiceAsyncClient, JobServiceClient, JobState, types
|
27
35
|
|
28
36
|
from airflow.exceptions import AirflowException
|
37
|
+
from airflow.providers.google.common.consts import CLIENT_INFO
|
29
38
|
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
|
30
39
|
|
31
40
|
if TYPE_CHECKING:
|
32
41
|
from google.api_core.operation import Operation
|
33
42
|
from google.api_core.retry import Retry
|
43
|
+
from google.api_core.retry_async import AsyncRetry
|
34
44
|
from google.cloud.aiplatform_v1.services.job_service.pagers import ListHyperparameterTuningJobsPager
|
35
45
|
|
36
46
|
|
@@ -55,7 +65,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
|
|
55
65
|
self._hyperparameter_tuning_job: HyperparameterTuningJob | None = None
|
56
66
|
|
57
67
|
def get_job_service_client(self, region: str | None = None) -> JobServiceClient:
|
58
|
-
"""
|
68
|
+
"""Return JobServiceClient."""
|
59
69
|
if region and region != "global":
|
60
70
|
client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
|
61
71
|
else:
|
@@ -81,7 +91,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
|
|
81
91
|
labels: dict[str, str] | None = None,
|
82
92
|
encryption_spec_key_name: str | None = None,
|
83
93
|
) -> HyperparameterTuningJob:
|
84
|
-
"""
|
94
|
+
"""Return HyperparameterTuningJob object."""
|
85
95
|
return HyperparameterTuningJob(
|
86
96
|
display_name=display_name,
|
87
97
|
custom_job=custom_job,
|
@@ -110,7 +120,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
|
|
110
120
|
encryption_spec_key_name: str | None = None,
|
111
121
|
staging_bucket: str | None = None,
|
112
122
|
) -> CustomJob:
|
113
|
-
"""
|
123
|
+
"""Return CustomJob object."""
|
114
124
|
return CustomJob(
|
115
125
|
display_name=display_name,
|
116
126
|
worker_pool_specs=worker_pool_specs,
|
@@ -125,11 +135,11 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
|
|
125
135
|
|
126
136
|
@staticmethod
|
127
137
|
def extract_hyperparameter_tuning_job_id(obj: dict) -> str:
|
128
|
-
"""
|
138
|
+
"""Return unique id of the hyperparameter_tuning_job."""
|
129
139
|
return obj["name"].rpartition("/")[-1]
|
130
140
|
|
131
141
|
def wait_for_operation(self, operation: Operation, timeout: float | None = None):
|
132
|
-
"""
|
142
|
+
"""Wait for long-lasting operation to complete."""
|
133
143
|
try:
|
134
144
|
return operation.result(timeout=timeout)
|
135
145
|
except Exception:
|
@@ -172,6 +182,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
|
|
172
182
|
tensorboard: str | None = None,
|
173
183
|
sync: bool = True,
|
174
184
|
# END: run param
|
185
|
+
wait_job_completed: bool = True,
|
175
186
|
) -> HyperparameterTuningJob:
|
176
187
|
"""
|
177
188
|
Create a HyperparameterTuningJob.
|
@@ -256,6 +267,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
|
|
256
267
|
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
|
257
268
|
:param sync: Whether to execute this method synchronously. If False, this method will unblock and it
|
258
269
|
will be executed in a concurrent Future.
|
270
|
+
:param wait_job_completed: Whether to wait for the job completed.
|
259
271
|
"""
|
260
272
|
custom_job = self.get_custom_job_object(
|
261
273
|
project=project_id,
|
@@ -292,7 +304,11 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
|
|
292
304
|
tensorboard=tensorboard,
|
293
305
|
sync=sync,
|
294
306
|
)
|
295
|
-
|
307
|
+
|
308
|
+
if wait_job_completed:
|
309
|
+
self._hyperparameter_tuning_job.wait()
|
310
|
+
else:
|
311
|
+
self._hyperparameter_tuning_job._wait_for_resource_creation()
|
296
312
|
return self._hyperparameter_tuning_job
|
297
313
|
|
298
314
|
@GoogleBaseHook.fallback_to_default_project_id
|
@@ -306,7 +322,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
|
|
306
322
|
metadata: Sequence[tuple[str, str]] = (),
|
307
323
|
) -> types.HyperparameterTuningJob:
|
308
324
|
"""
|
309
|
-
|
325
|
+
Get a HyperparameterTuningJob.
|
310
326
|
|
311
327
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
312
328
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -342,7 +358,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
|
|
342
358
|
metadata: Sequence[tuple[str, str]] = (),
|
343
359
|
) -> ListHyperparameterTuningJobsPager:
|
344
360
|
"""
|
345
|
-
|
361
|
+
List HyperparameterTuningJobs in a Location.
|
346
362
|
|
347
363
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
348
364
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -391,7 +407,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
|
|
391
407
|
metadata: Sequence[tuple[str, str]] = (),
|
392
408
|
) -> Operation:
|
393
409
|
"""
|
394
|
-
|
410
|
+
Delete a HyperparameterTuningJob.
|
395
411
|
|
396
412
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
397
413
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -413,3 +429,104 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
|
|
413
429
|
metadata=metadata,
|
414
430
|
)
|
415
431
|
return result
|
432
|
+
|
433
|
+
|
434
|
+
class HyperparameterTuningJobAsyncHook(GoogleBaseHook):
|
435
|
+
"""Async hook for Google Cloud Vertex AI Hyperparameter Tuning Job APIs."""
|
436
|
+
|
437
|
+
def __init__(
|
438
|
+
self,
|
439
|
+
gcp_conn_id: str = "google_cloud_default",
|
440
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
441
|
+
**kwargs,
|
442
|
+
):
|
443
|
+
super().__init__(
|
444
|
+
gcp_conn_id=gcp_conn_id,
|
445
|
+
impersonation_chain=impersonation_chain,
|
446
|
+
**kwargs,
|
447
|
+
)
|
448
|
+
|
449
|
+
@lru_cache
|
450
|
+
def get_job_service_client(self, region: str | None = None) -> JobServiceAsyncClient:
|
451
|
+
"""
|
452
|
+
Retrieve Vertex AI async client.
|
453
|
+
|
454
|
+
:return: Google Cloud Vertex AI client object.
|
455
|
+
"""
|
456
|
+
endpoint = f"{region}-aiplatform.googleapis.com:443" if region and region != "global" else None
|
457
|
+
return JobServiceAsyncClient(
|
458
|
+
credentials=self.get_credentials(),
|
459
|
+
client_info=CLIENT_INFO,
|
460
|
+
client_options=ClientOptions(api_endpoint=endpoint),
|
461
|
+
)
|
462
|
+
|
463
|
+
async def get_hyperparameter_tuning_job(
|
464
|
+
self,
|
465
|
+
project_id: str,
|
466
|
+
location: str,
|
467
|
+
job_id: str,
|
468
|
+
retry: AsyncRetry | _MethodDefault = DEFAULT,
|
469
|
+
timeout: float | None = None,
|
470
|
+
metadata: Sequence[tuple[str, str]] = (),
|
471
|
+
) -> types.HyperparameterTuningJob:
|
472
|
+
"""
|
473
|
+
Retrieve a hyperparameter tuning job.
|
474
|
+
|
475
|
+
:param project_id: Required. The ID of the Google Cloud project that the job belongs to.
|
476
|
+
:param location: Required. The ID of the Google Cloud region that the job belongs to.
|
477
|
+
:param job_id: Required. The hyperparameter tuning job id.
|
478
|
+
:param retry: Designation of what errors, if any, should be retried.
|
479
|
+
:param timeout: The timeout for this request.
|
480
|
+
:param metadata: Strings which should be sent along with the request as metadata.
|
481
|
+
"""
|
482
|
+
client: JobServiceAsyncClient = self.get_job_service_client(region=location)
|
483
|
+
job_name = client.hyperparameter_tuning_job_path(project_id, location, job_id)
|
484
|
+
|
485
|
+
result = await client.get_hyperparameter_tuning_job(
|
486
|
+
request={
|
487
|
+
"name": job_name,
|
488
|
+
},
|
489
|
+
retry=retry,
|
490
|
+
timeout=timeout,
|
491
|
+
metadata=metadata,
|
492
|
+
)
|
493
|
+
|
494
|
+
return result
|
495
|
+
|
496
|
+
async def wait_hyperparameter_tuning_job(
|
497
|
+
self,
|
498
|
+
project_id: str,
|
499
|
+
location: str,
|
500
|
+
job_id: str,
|
501
|
+
retry: AsyncRetry | _MethodDefault = DEFAULT,
|
502
|
+
timeout: float | None = None,
|
503
|
+
metadata: Sequence[tuple[str, str]] = (),
|
504
|
+
poll_interval: int = 10,
|
505
|
+
) -> types.HyperparameterTuningJob:
|
506
|
+
statuses_complete = {
|
507
|
+
JobState.JOB_STATE_CANCELLED,
|
508
|
+
JobState.JOB_STATE_FAILED,
|
509
|
+
JobState.JOB_STATE_PAUSED,
|
510
|
+
JobState.JOB_STATE_SUCCEEDED,
|
511
|
+
}
|
512
|
+
while True:
|
513
|
+
try:
|
514
|
+
self.log.info("Requesting hyperparameter tuning job with id %s", job_id)
|
515
|
+
job: types.HyperparameterTuningJob = await self.get_hyperparameter_tuning_job(
|
516
|
+
project_id=project_id,
|
517
|
+
location=location,
|
518
|
+
job_id=job_id,
|
519
|
+
retry=retry,
|
520
|
+
timeout=timeout,
|
521
|
+
metadata=metadata,
|
522
|
+
)
|
523
|
+
except Exception as ex:
|
524
|
+
self.log.exception("Exception occurred while requesting job %s", job_id)
|
525
|
+
raise AirflowException(ex)
|
526
|
+
|
527
|
+
self.log.info("Status of the hyperparameter tuning job %s is %s", job.name, job.state.name)
|
528
|
+
if job.state in statuses_complete:
|
529
|
+
return job
|
530
|
+
|
531
|
+
self.log.info("Sleeping for %s seconds.", poll_interval)
|
532
|
+
await asyncio.sleep(poll_interval)
|
@@ -51,7 +51,7 @@ class ModelServiceHook(GoogleBaseHook):
|
|
51
51
|
super().__init__(**kwargs)
|
52
52
|
|
53
53
|
def get_model_service_client(self, region: str | None = None) -> ModelServiceClient:
|
54
|
-
"""
|
54
|
+
"""Return ModelServiceClient object."""
|
55
55
|
if region and region != "global":
|
56
56
|
client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
|
57
57
|
else:
|
@@ -63,11 +63,11 @@ class ModelServiceHook(GoogleBaseHook):
|
|
63
63
|
|
64
64
|
@staticmethod
|
65
65
|
def extract_model_id(obj: dict) -> str:
|
66
|
-
"""
|
66
|
+
"""Return unique id of the model."""
|
67
67
|
return obj["model"].rpartition("/")[-1]
|
68
68
|
|
69
69
|
def wait_for_operation(self, operation: Operation, timeout: float | None = None):
|
70
|
-
"""
|
70
|
+
"""Wait for long-lasting operation to complete."""
|
71
71
|
try:
|
72
72
|
return operation.result(timeout=timeout)
|
73
73
|
except Exception:
|
@@ -85,7 +85,7 @@ class ModelServiceHook(GoogleBaseHook):
|
|
85
85
|
metadata: Sequence[tuple[str, str]] = (),
|
86
86
|
) -> Operation:
|
87
87
|
"""
|
88
|
-
|
88
|
+
Delete a Model.
|
89
89
|
|
90
90
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
91
91
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -119,7 +119,7 @@ class ModelServiceHook(GoogleBaseHook):
|
|
119
119
|
metadata: Sequence[tuple[str, str]] = (),
|
120
120
|
) -> Operation:
|
121
121
|
"""
|
122
|
-
|
122
|
+
Export a trained, exportable Model to a location specified by the user.
|
123
123
|
|
124
124
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
125
125
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -158,7 +158,7 @@ class ModelServiceHook(GoogleBaseHook):
|
|
158
158
|
metadata: Sequence[tuple[str, str]] = (),
|
159
159
|
) -> ListModelsPager:
|
160
160
|
r"""
|
161
|
-
|
161
|
+
List Models in a Location.
|
162
162
|
|
163
163
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
164
164
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -213,7 +213,7 @@ class ModelServiceHook(GoogleBaseHook):
|
|
213
213
|
metadata: Sequence[tuple[str, str]] = (),
|
214
214
|
) -> Operation:
|
215
215
|
"""
|
216
|
-
|
216
|
+
Upload a Model artifact into Vertex AI.
|
217
217
|
|
218
218
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
219
219
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -247,7 +247,7 @@ class ModelServiceHook(GoogleBaseHook):
|
|
247
247
|
metadata: Sequence[tuple[str, str]] = (),
|
248
248
|
) -> ListModelVersionsPager:
|
249
249
|
"""
|
250
|
-
|
250
|
+
List all versions of the existing Model.
|
251
251
|
|
252
252
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
253
253
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -280,7 +280,7 @@ class ModelServiceHook(GoogleBaseHook):
|
|
280
280
|
metadata: Sequence[tuple[str, str]] = (),
|
281
281
|
) -> Operation:
|
282
282
|
"""
|
283
|
-
|
283
|
+
Delete version of the Model. The version could not be deleted if this version is default.
|
284
284
|
|
285
285
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
286
286
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -313,7 +313,7 @@ class ModelServiceHook(GoogleBaseHook):
|
|
313
313
|
metadata: Sequence[tuple[str, str]] = (),
|
314
314
|
) -> Model:
|
315
315
|
"""
|
316
|
-
|
316
|
+
Retrieve Model of specific name and version. If version is not specified, the default is retrieved.
|
317
317
|
|
318
318
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
319
319
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -60,7 +60,7 @@ class PipelineJobHook(GoogleBaseHook):
|
|
60
60
|
self,
|
61
61
|
region: str | None = None,
|
62
62
|
) -> PipelineServiceClient:
|
63
|
-
"""
|
63
|
+
"""Return PipelineServiceClient object."""
|
64
64
|
if region and region != "global":
|
65
65
|
client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
|
66
66
|
else:
|
@@ -84,7 +84,7 @@ class PipelineJobHook(GoogleBaseHook):
|
|
84
84
|
location: str | None = None,
|
85
85
|
failure_policy: str | None = None,
|
86
86
|
) -> PipelineJob:
|
87
|
-
"""
|
87
|
+
"""Return PipelineJob object."""
|
88
88
|
return PipelineJob(
|
89
89
|
display_name=display_name,
|
90
90
|
template_path=template_path,
|
@@ -103,11 +103,11 @@ class PipelineJobHook(GoogleBaseHook):
|
|
103
103
|
|
104
104
|
@staticmethod
|
105
105
|
def extract_pipeline_job_id(obj: dict) -> str:
|
106
|
-
"""
|
106
|
+
"""Return unique id of the pipeline_job."""
|
107
107
|
return obj["name"].rpartition("/")[-1]
|
108
108
|
|
109
109
|
def wait_for_operation(self, operation: Operation, timeout: float | None = None):
|
110
|
-
"""
|
110
|
+
"""Wait for long-lasting operation to complete."""
|
111
111
|
try:
|
112
112
|
return operation.result(timeout=timeout)
|
113
113
|
except Exception:
|
@@ -131,7 +131,7 @@ class PipelineJobHook(GoogleBaseHook):
|
|
131
131
|
metadata: Sequence[tuple[str, str]] = (),
|
132
132
|
) -> PipelineJob:
|
133
133
|
"""
|
134
|
-
|
134
|
+
Create a PipelineJob. A PipelineJob will run immediately when created.
|
135
135
|
|
136
136
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
137
137
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -265,7 +265,7 @@ class PipelineJobHook(GoogleBaseHook):
|
|
265
265
|
metadata: Sequence[tuple[str, str]] = (),
|
266
266
|
) -> PipelineJob:
|
267
267
|
"""
|
268
|
-
|
268
|
+
Get a PipelineJob.
|
269
269
|
|
270
270
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
271
271
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -301,7 +301,7 @@ class PipelineJobHook(GoogleBaseHook):
|
|
301
301
|
metadata: Sequence[tuple[str, str]] = (),
|
302
302
|
) -> ListPipelineJobsPager:
|
303
303
|
"""
|
304
|
-
|
304
|
+
List PipelineJobs in a Location.
|
305
305
|
|
306
306
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
307
307
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -386,7 +386,7 @@ class PipelineJobHook(GoogleBaseHook):
|
|
386
386
|
metadata: Sequence[tuple[str, str]] = (),
|
387
387
|
) -> Operation:
|
388
388
|
"""
|
389
|
-
|
389
|
+
Delete a PipelineJob.
|
390
390
|
|
391
391
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
392
392
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -71,7 +71,7 @@ class CloudVideoIntelligenceHook(GoogleBaseHook):
|
|
71
71
|
self._conn: VideoIntelligenceServiceClient | None = None
|
72
72
|
|
73
73
|
def get_conn(self) -> VideoIntelligenceServiceClient:
|
74
|
-
"""
|
74
|
+
"""Return Gcp Video Intelligence Service client."""
|
75
75
|
if not self._conn:
|
76
76
|
self._conn = VideoIntelligenceServiceClient(
|
77
77
|
credentials=self.get_credentials(), client_info=CLIENT_INFO
|
@@ -92,7 +92,7 @@ class CloudVideoIntelligenceHook(GoogleBaseHook):
|
|
92
92
|
metadata: Sequence[tuple[str, str]] = (),
|
93
93
|
) -> Operation:
|
94
94
|
"""
|
95
|
-
|
95
|
+
Perform video annotation.
|
96
96
|
|
97
97
|
:param input_uri: Input video location. Currently, only Google Cloud Storage URIs are supported,
|
98
98
|
which must be specified in the following format: ``gs://bucket-id/object-id``.
|