apache-airflow-providers-google 10.14.0rc1__py3-none-any.whl → 10.15.0rc1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +1 -2
- airflow/providers/google/cloud/hooks/automl.py +13 -13
- airflow/providers/google/cloud/hooks/bigquery.py +208 -256
- airflow/providers/google/cloud/hooks/bigquery_dts.py +6 -6
- airflow/providers/google/cloud/hooks/bigtable.py +8 -8
- airflow/providers/google/cloud/hooks/cloud_batch.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_build.py +19 -20
- airflow/providers/google/cloud/hooks/cloud_composer.py +4 -4
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +10 -10
- airflow/providers/google/cloud/hooks/cloud_run.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_sql.py +18 -19
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +3 -3
- airflow/providers/google/cloud/hooks/compute.py +16 -16
- airflow/providers/google/cloud/hooks/compute_ssh.py +1 -1
- airflow/providers/google/cloud/hooks/datacatalog.py +22 -22
- airflow/providers/google/cloud/hooks/dataflow.py +48 -49
- airflow/providers/google/cloud/hooks/dataform.py +16 -16
- airflow/providers/google/cloud/hooks/datafusion.py +15 -15
- airflow/providers/google/cloud/hooks/datapipeline.py +3 -3
- airflow/providers/google/cloud/hooks/dataplex.py +19 -19
- airflow/providers/google/cloud/hooks/dataprep.py +10 -10
- airflow/providers/google/cloud/hooks/dataproc.py +132 -14
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +13 -13
- airflow/providers/google/cloud/hooks/datastore.py +3 -3
- airflow/providers/google/cloud/hooks/dlp.py +25 -25
- airflow/providers/google/cloud/hooks/gcs.py +39 -27
- airflow/providers/google/cloud/hooks/gdm.py +3 -3
- airflow/providers/google/cloud/hooks/kms.py +3 -3
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +63 -48
- airflow/providers/google/cloud/hooks/life_sciences.py +13 -12
- airflow/providers/google/cloud/hooks/looker.py +8 -9
- airflow/providers/google/cloud/hooks/mlengine.py +12 -12
- airflow/providers/google/cloud/hooks/natural_language.py +2 -2
- airflow/providers/google/cloud/hooks/os_login.py +1 -1
- airflow/providers/google/cloud/hooks/pubsub.py +9 -9
- airflow/providers/google/cloud/hooks/secret_manager.py +1 -1
- airflow/providers/google/cloud/hooks/spanner.py +11 -11
- airflow/providers/google/cloud/hooks/speech_to_text.py +1 -1
- airflow/providers/google/cloud/hooks/stackdriver.py +7 -7
- airflow/providers/google/cloud/hooks/tasks.py +11 -11
- airflow/providers/google/cloud/hooks/text_to_speech.py +1 -1
- airflow/providers/google/cloud/hooks/translate.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +13 -13
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +6 -6
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +45 -50
- airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +13 -13
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +9 -9
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +128 -11
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +10 -10
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +8 -8
- airflow/providers/google/cloud/hooks/video_intelligence.py +2 -2
- airflow/providers/google/cloud/hooks/vision.py +1 -1
- airflow/providers/google/cloud/hooks/workflows.py +10 -10
- airflow/providers/google/cloud/links/datafusion.py +12 -5
- airflow/providers/google/cloud/operators/bigquery.py +11 -11
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +3 -1
- airflow/providers/google/cloud/operators/dataflow.py +16 -16
- airflow/providers/google/cloud/operators/datafusion.py +9 -1
- airflow/providers/google/cloud/operators/dataproc.py +444 -69
- airflow/providers/google/cloud/operators/kubernetes_engine.py +6 -6
- airflow/providers/google/cloud/operators/life_sciences.py +10 -9
- airflow/providers/google/cloud/operators/mlengine.py +96 -96
- airflow/providers/google/cloud/operators/pubsub.py +2 -0
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +33 -3
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +59 -2
- airflow/providers/google/cloud/secrets/secret_manager.py +8 -7
- airflow/providers/google/cloud/sensors/bigquery.py +20 -16
- airflow/providers/google/cloud/sensors/cloud_composer.py +11 -8
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +12 -2
- airflow/providers/google/cloud/sensors/gcs.py +8 -7
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +4 -4
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -0
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -1
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/mysql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/presto_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/trino_to_gcs.py +1 -1
- airflow/providers/google/cloud/triggers/bigquery.py +12 -12
- airflow/providers/google/cloud/triggers/bigquery_dts.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_batch.py +3 -1
- airflow/providers/google/cloud/triggers/cloud_build.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_run.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +6 -6
- airflow/providers/google/cloud/triggers/dataflow.py +3 -1
- airflow/providers/google/cloud/triggers/datafusion.py +2 -2
- airflow/providers/google/cloud/triggers/dataplex.py +2 -2
- airflow/providers/google/cloud/triggers/dataproc.py +34 -14
- airflow/providers/google/cloud/triggers/gcs.py +12 -8
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/triggers/mlengine.py +2 -2
- airflow/providers/google/cloud/triggers/pubsub.py +1 -1
- airflow/providers/google/cloud/triggers/vertex_ai.py +99 -0
- airflow/providers/google/cloud/utils/bigquery.py +2 -2
- airflow/providers/google/cloud/utils/credentials_provider.py +2 -2
- airflow/providers/google/cloud/utils/dataform.py +1 -1
- airflow/providers/google/cloud/utils/dataproc.py +25 -0
- airflow/providers/google/cloud/utils/field_validator.py +2 -2
- airflow/providers/google/cloud/utils/helpers.py +2 -2
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +1 -1
- airflow/providers/google/cloud/utils/mlengine_prediction_summary.py +1 -1
- airflow/providers/google/common/auth_backend/google_openid.py +2 -2
- airflow/providers/google/common/hooks/base_google.py +87 -23
- airflow/providers/google/common/hooks/discovery_api.py +2 -2
- airflow/providers/google/common/utils/id_token_credentials.py +5 -5
- airflow/providers/google/firebase/hooks/firestore.py +3 -3
- airflow/providers/google/get_provider_info.py +7 -2
- airflow/providers/google/leveldb/hooks/leveldb.py +4 -4
- airflow/providers/google/marketing_platform/hooks/analytics.py +11 -14
- airflow/providers/google/marketing_platform/hooks/campaign_manager.py +11 -11
- airflow/providers/google/marketing_platform/hooks/display_video.py +13 -13
- airflow/providers/google/marketing_platform/hooks/search_ads.py +4 -4
- airflow/providers/google/marketing_platform/operators/analytics.py +37 -32
- airflow/providers/google/suite/hooks/calendar.py +2 -2
- airflow/providers/google/suite/hooks/drive.py +7 -7
- airflow/providers/google/suite/hooks/sheets.py +8 -8
- {apache_airflow_providers_google-10.14.0rc1.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/METADATA +11 -11
- {apache_airflow_providers_google-10.14.0rc1.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/RECORD +126 -124
- {apache_airflow_providers_google-10.14.0rc1.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.14.0rc1.dist-info → apache_airflow_providers_google-10.15.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/triggers/dataproc.py

@@ -19,6 +19,7 @@
 from __future__ import annotations
 
 import asyncio
+import re
 import time
 from typing import Any, AsyncIterator, Sequence
 
@@ -26,6 +27,7 @@ from google.api_core.exceptions import NotFound
 from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
 
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
+from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -180,7 +182,7 @@ class DataprocBatchTrigger(DataprocBaseTrigger):
         self.batch_id = batch_id
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes DataprocBatchTrigger arguments and classpath."""
+        """Serialize DataprocBatchTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger",
             {
@@ -242,7 +244,7 @@ class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
         self.metadata = metadata
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes DataprocDeleteClusterTrigger arguments and classpath."""
+        """Serialize DataprocDeleteClusterTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.dataproc.DataprocDeleteClusterTrigger",
             {
@@ -281,22 +283,24 @@ class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
                 yield TriggerEvent({"status": "error", "message": "Timeout"})
 
 
-class DataprocWorkflowTrigger(DataprocBaseTrigger):
+class DataprocOperationTrigger(DataprocBaseTrigger):
     """
-    Trigger that periodically polls information from Dataproc API to verify status.
+    Trigger that periodically polls information on a long running operation from Dataproc API to verify status.
 
     Implementation leverages asynchronous transport.
     """
 
-    def __init__(self, name: str, **kwargs: Any):
+    def __init__(self, name: str, operation_type: str | None = None, **kwargs: Any):
         super().__init__(**kwargs)
         self.name = name
+        self.operation_type = operation_type
 
     def serialize(self):
         return (
-            "airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger",
+            "airflow.providers.google.cloud.triggers.dataproc.DataprocOperationTrigger",
             {
                 "name": self.name,
+                "operation_type": self.operation_type,
                 "project_id": self.project_id,
                 "region": self.region,
                 "gcp_conn_id": self.gcp_conn_id,
@@ -317,14 +321,30 @@ class DataprocWorkflowTrigger(DataprocBaseTrigger):
                 else:
                     status = "success"
                     message = "Operation is successfully ended."
-                yield TriggerEvent(
-                    {
-                        "operation_name": operation.name,
-                        "operation_done": operation.done,
-                        "status": status,
-                        "message": message,
-                    }
-                )
+                if self.operation_type == DataprocOperationType.DIAGNOSE.value:
+                    gcs_regex = rb"gs:\/\/[a-z0-9][a-z0-9_-]{1,61}[a-z0-9_\-\/]*"
+                    gcs_uri_value = operation.response.value
+                    match = re.search(gcs_regex, gcs_uri_value)
+                    if match:
+                        output_uri = match.group(0).decode("utf-8", "ignore")
+                    else:
+                        output_uri = gcs_uri_value
+                    yield TriggerEvent(
+                        {
+                            "status": status,
+                            "message": message,
+                            "output_uri": output_uri,
+                        }
+                    )
+                else:
+                    yield TriggerEvent(
+                        {
+                            "operation_name": operation.name,
+                            "operation_done": operation.done,
+                            "status": status,
+                            "message": message,
+                        }
+                    )
                 return
             else:
                 self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
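Annotation (not part of the diff): the new DIAGNOSE branch above fishes the GCS URI of the diagnostic tarball out of the raw bytes of the operation response with a byte regex. A minimal, self-contained sketch of that extraction; the sample byte string is invented for illustration, standing in for `operation.response.value` (a serialized protobuf payload that embeds the URI):

import re

gcs_regex = rb"gs:\/\/[a-z0-9][a-z0-9_-]{1,61}[a-z0-9_\-\/]*"
# Hypothetical payload standing in for operation.response.value.
sample_value = b"\x12:gs://my-diagnostic-bucket/diagnostic-output/summary"

match = re.search(gcs_regex, sample_value)
# Mirror the trigger: decode the match, or fall back to the raw value.
output_uri = match.group(0).decode("utf-8", "ignore") if match else sample_value
print(output_uri)  # gs://my-diagnostic-bucket/diagnostic-output/summary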
airflow/providers/google/cloud/triggers/gcs.py

@@ -60,7 +60,7 @@ class GCSBlobTrigger(BaseTrigger):
         self.hook_params = hook_params
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes GCSBlobTrigger arguments and classpath."""
+        """Serialize GCSBlobTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.gcs.GCSBlobTrigger",
             {
@@ -93,7 +93,7 @@ class GCSBlobTrigger(BaseTrigger):
 
     async def _object_exists(self, hook: GCSAsyncHook, bucket_name: str, object_name: str) -> str:
         """
-        Checks for the existence of a file in Google Cloud Storage.
+        Check for the existence of a file in Google Cloud Storage.
 
         :param bucket_name: The Google Cloud Storage bucket where the object is.
         :param object_name: The name of the blob_name to check in the Google cloud
@@ -143,7 +143,7 @@ class GCSCheckBlobUpdateTimeTrigger(BaseTrigger):
         self.hook_params = hook_params
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes GCSCheckBlobUpdateTimeTrigger arguments and classpath."""
+        """Serialize GCSCheckBlobUpdateTimeTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.gcs.GCSCheckBlobUpdateTimeTrigger",
             {
@@ -181,7 +181,7 @@ class GCSCheckBlobUpdateTimeTrigger(BaseTrigger):
         self, hook: GCSAsyncHook, bucket_name: str, object_name: str, target_date: datetime
     ) -> tuple[bool, dict[str, Any]]:
         """
-        Checks if the object in the bucket is updated.
+        Check if the object in the bucket is updated.
 
         :param hook: GCSAsyncHook Hook class
         :param bucket_name: The Google Cloud Storage bucket where the object is.
@@ -248,7 +248,7 @@ class GCSPrefixBlobTrigger(GCSBlobTrigger):
         self.prefix = prefix
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes GCSPrefixBlobTrigger arguments and classpath."""
+        """Serialize GCSPrefixBlobTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.gcs.GCSPrefixBlobTrigger",
             {
@@ -282,7 +282,7 @@ class GCSPrefixBlobTrigger(GCSBlobTrigger):
 
     async def _list_blobs_with_prefix(self, hook: GCSAsyncHook, bucket_name: str, prefix: str) -> list[str]:
         """
-        Returns names of blobs which match the given prefix for a given bucket.
+        Return names of blobs which match the given prefix for a given bucket.
 
         :param hook: The async hook to use for listing the blobs
         :param bucket_name: The Google Cloud Storage bucket where the object is.
@@ -344,7 +344,7 @@ class GCSUploadSessionTrigger(GCSPrefixBlobTrigger):
         self.last_activity_time: datetime | None = None
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes GCSUploadSessionTrigger arguments and classpath."""
+        """Serialize GCSUploadSessionTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.gcs.GCSUploadSessionTrigger",
             {
@@ -377,7 +377,11 @@ class GCSUploadSessionTrigger(GCSPrefixBlobTrigger):
             yield TriggerEvent({"status": "error", "message": str(e)})
 
     def _get_time(self) -> datetime:
-        """This is just a wrapper of datetime.datetime.now to simplify mocking in the unittests."""
+        """
+        Get current local date and time.
+
+        This is just a wrapper of datetime.datetime.now to simplify mocking in the unittests.
+        """
         return datetime.now()
 
     def _is_bucket_updated(self, current_objects: set[str]) -> dict[str, str]:
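Annotation (not part of the diff): the expanded `_get_time` docstring spells out why the wrapper exists — routing every "what time is it" question through one method gives tests a single seam to patch. A generic sketch of the pattern; the class below is an illustrative stand-in, not the provider's code:

from datetime import datetime
from unittest import mock


class UploadSessionLike:
    """Toy stand-in for GCSUploadSessionTrigger's time handling."""

    def _get_time(self) -> datetime:
        # Wrapper over datetime.now so tests can patch it in one place.
        return datetime.now()

    def is_inactive(self, last_activity: datetime, timeout_seconds: int) -> bool:
        return (self._get_time() - last_activity).total_seconds() > timeout_seconds


session = UploadSessionLike()
frozen = datetime(2024, 1, 1, 0, 1, 0)
with mock.patch.object(UploadSessionLike, "_get_time", return_value=frozen):
    # 60 seconds elapsed against a 30-second timeout -> inactive.
    assert session.is_inactive(datetime(2024, 1, 1), timeout_seconds=30)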
airflow/providers/google/cloud/triggers/kubernetes_engine.py

@@ -174,7 +174,7 @@ class GKEOperationTrigger(BaseTrigger):
         self._hook: GKEAsyncHook | None = None
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes GKEOperationTrigger arguments and classpath."""
+        """Serialize GKEOperationTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.kubernetes_engine.GKEOperationTrigger",
             {
@@ -188,7 +188,7 @@ class GKEOperationTrigger(BaseTrigger):
         )
 
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
-        """Gets operation status and yields corresponding event."""
+        """Get operation status and yields corresponding event."""
         hook = self._get_hook()
         try:
             while True:
airflow/providers/google/cloud/triggers/mlengine.py

@@ -69,7 +69,7 @@ class MLEngineStartTrainingJobTrigger(BaseTrigger):
         self.impersonation_chain = impersonation_chain
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes MLEngineStartTrainingJobTrigger arguments and classpath."""
+        """Serialize MLEngineStartTrainingJobTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.mlengine.MLEngineStartTrainingJobTrigger",
             {
@@ -89,7 +89,7 @@ class MLEngineStartTrainingJobTrigger(BaseTrigger):
         )
 
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
-        """Gets current job execution status and yields a TriggerEvent."""
+        """Get current job execution status and yields a TriggerEvent."""
         hook = self._get_async_hook()
         try:
             while True:
airflow/providers/google/cloud/triggers/pubsub.py

@@ -79,7 +79,7 @@ class PubsubPullTrigger(BaseTrigger):
         self.hook = PubSubAsyncHook()
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serializes PubsubPullTrigger arguments and classpath."""
+        """Serialize PubsubPullTrigger arguments and classpath."""
         return (
             "airflow.providers.google.cloud.triggers.pubsub.PubsubPullTrigger",
             {
airflow/providers/google/cloud/triggers/vertex_ai.py (new file)

@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any, AsyncIterator, Sequence
+
+from google.cloud.aiplatform_v1 import HyperparameterTuningJob, JobState
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
+    HyperparameterTuningJobAsyncHook,
+)
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class CreateHyperparameterTuningJobTrigger(BaseTrigger):
+    """CreateHyperparameterTuningJobTrigger run on the trigger worker to perform create operation."""
+
+    statuses_success = {
+        JobState.JOB_STATE_PAUSED,
+        JobState.JOB_STATE_SUCCEEDED,
+    }
+
+    def __init__(
+        self,
+        conn_id: str,
+        project_id: str,
+        location: str,
+        job_id: str,
+        poll_interval: int,
+        impersonation_chain: str | Sequence[str] | None = None,
+    ):
+        super().__init__()
+        self.conn_id = conn_id
+        self.project_id = project_id
+        self.location = location
+        self.job_id = job_id
+        self.poll_interval = poll_interval
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.google.cloud.triggers.vertex_ai.CreateHyperparameterTuningJobTrigger",
+            {
+                "conn_id": self.conn_id,
+                "project_id": self.project_id,
+                "location": self.location,
+                "job_id": self.job_id,
+                "poll_interval": self.poll_interval,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        hook = self._get_async_hook()
+        try:
+            job = await hook.wait_hyperparameter_tuning_job(
+                project_id=self.project_id,
+                location=self.location,
+                job_id=self.job_id,
+                poll_interval=self.poll_interval,
+            )
+        except AirflowException as ex:
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(ex),
+                }
+            )
+            return
+
+        status = "success" if job.state in self.statuses_success else "error"
+        message = f"Hyperparameter tuning job {job.name} completed with status {job.state.name}"
+        yield TriggerEvent(
+            {
+                "status": status,
+                "message": message,
+                "job": HyperparameterTuningJob.to_dict(job),
+            }
+        )
+
+    def _get_async_hook(self) -> HyperparameterTuningJobAsyncHook:
+        return HyperparameterTuningJobAsyncHook(
+            gcp_conn_id=self.conn_id, impersonation_chain=self.impersonation_chain
+        )
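Annotation (not part of the diff): this is how the new trigger is meant to be consumed — a deferrable operator serializes it, and the triggerer process re-instantiates it from the classpath and kwargs before awaiting run(). A hedged usage sketch (requires Airflow plus this provider installed; the connection id, project, location, and job values are placeholders):

from airflow.providers.google.cloud.triggers.vertex_ai import (
    CreateHyperparameterTuningJobTrigger,
)

trigger = CreateHyperparameterTuningJobTrigger(
    conn_id="google_cloud_default",
    project_id="my-project",   # placeholder
    location="us-central1",    # placeholder
    job_id="1234567890",       # placeholder
    poll_interval=60,
)

# The triggerer uses this (classpath, kwargs) pair to rebuild the trigger.
classpath, kwargs = trigger.serialize()
assert classpath.endswith("CreateHyperparameterTuningJobTrigger")
assert kwargs["poll_interval"] == 60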
airflow/providers/google/cloud/utils/bigquery.py

@@ -21,7 +21,7 @@ from typing import Any
 
 def bq_cast(string_field: str, bq_type: str) -> None | int | float | bool | str:
     """
-    Casts a BigQuery row to the appropriate data types.
+    Cast a BigQuery row to the appropriate data types.
 
     This is useful because BigQuery returns all fields as strings.
     """
@@ -41,7 +41,7 @@ def bq_cast(string_field: str, bq_type: str) -> None | int | float | bool | str
 
 def convert_job_id(job_id: str | list[str], project_id: str, location: str | None) -> Any:
     """
-    Converts job_id to path: project_id:location:job_id.
+    Convert job_id to path: project_id:location:job_id.
 
     :param project_id: Required. The ID of the Google Cloud project where workspace located.
     :param location: Optional. The ID of the Google Cloud region where workspace located.
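Annotation (not part of the diff): the freshened docstrings describe the contract of these two helpers. A hedged illustration of that contract; the exact casting rules are assumptions recalled from the provider source, not something this diff shows:

from airflow.providers.google.cloud.utils.bigquery import bq_cast, convert_job_id

# BigQuery returns row values as strings; bq_cast converts them by declared type.
assert bq_cast("42", "INTEGER") == 42
assert bq_cast("true", "BOOLEAN") is True
assert bq_cast(None, "STRING") is None

# convert_job_id builds the "project:location:job_id" path described above.
assert convert_job_id("job-1", "my-project", "EU") == "my-project:EU:job-1"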
airflow/providers/google/cloud/utils/credentials_provider.py

@@ -358,7 +358,7 @@ class _CredentialProvider(LoggingMixin):
 
 
 def get_credentials_and_project_id(*args, **kwargs) -> tuple[google.auth.credentials.Credentials, str]:
-    """Returns the Credentials object for Google API and the associated project_id."""
+    """Return the Credentials object for Google API and the associated project_id."""
     return _CredentialProvider(*args, **kwargs).get_credentials_and_project()
 
 
@@ -398,7 +398,7 @@ def _get_target_principal_and_delegates(
 
 def _get_project_id_from_service_account_email(service_account_email: str) -> str:
     """
-    Extracts project_id from service account's email address.
+    Extract project_id from service account's email address.
 
     :param service_account_email: email of the service account.
 
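Annotation (not part of the diff): the idea behind `_get_project_id_from_service_account_email` in plain terms — a service account email has the form `<name>@<project-id>.iam.gserviceaccount.com`, so the project id is the first label of the domain. A standalone sketch; the provider's helper additionally raises AirflowException on malformed input, which is omitted here:

email = "my-sa@my-project.iam.gserviceaccount.com"  # placeholder address
project_id = email.split("@")[-1].split(".")[0]
assert project_id == "my-project"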
airflow/providers/google/cloud/utils/dataform.py

@@ -45,7 +45,7 @@ def make_initialization_workspace_flow(
     without_installation: bool = False,
 ) -> tuple:
     """
-    Creates flow which simulates the initialization of the default project.
+    Create flow which simulates the initialization of the default project.
 
     :param project_id: Required. The ID of the Google Cloud project where workspace located.
     :param region: Required. The ID of the Google Cloud region where workspace located.
airflow/providers/google/cloud/utils/dataproc.py (new file)

@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from enum import Enum
+
+
+class DataprocOperationType(Enum):
+    """Contains types of long running operations."""
+
+    DIAGNOSE = "DIAGNOSE"
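Annotation (not part of the diff): this enum's single member is the glue between the operator side and the DIAGNOSE branch in DataprocOperationTrigger shown earlier. A hedged sketch of wiring the two together; the constructor kwargs are inferred from the serialized fields in that trigger, and the resource names are placeholders:

from airflow.providers.google.cloud.triggers.dataproc import DataprocOperationTrigger
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType

trigger = DataprocOperationTrigger(
    name="projects/my-project/regions/us-central1/operations/op-123",  # placeholder
    operation_type=DataprocOperationType.DIAGNOSE.value,  # "DIAGNOSE"
    project_id="my-project",
    region="us-central1",
)
# run() will now yield an event carrying "output_uri" instead of the
# generic operation_name/operation_done payload.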
airflow/providers/google/cloud/utils/field_validator.py

@@ -309,7 +309,7 @@ class GcpBodyFieldValidator(LoggingMixin):
 
     def _validate_field(self, validation_spec, dictionary_to_validate, parent=None, force_optional=False):
         """
-        Validates if field is OK.
+        Validate if field is OK.
 
         :param validation_spec: specification of the field
         :param dictionary_to_validate: dictionary where the field should be present
@@ -413,7 +413,7 @@ class GcpBodyFieldValidator(LoggingMixin):
 
     def validate(self, body_to_validate: dict) -> None:
         """
-        Validates if the body (dictionary) follows specification that the validator was instantiated with.
+        Validate if the body (dictionary) follows specification that the validator was instantiated with.
 
         Raises ValidationSpecificationException or ValidationFieldException in case of problems
         with specification or the body not conforming to the specification respectively.
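Annotation (not part of the diff): `validate` is the public entry point and `_validate_field` the per-field recursion. A hedged sketch of driving the validator; the spec format (a list of dicts with keys like "name", "regexp", "optional") is an assumption recalled from the provider source rather than something this diff documents:

from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator

# Hypothetical specification for a body with a required "name" field and an
# optional "description" field.
SPECIFICATION = [
    dict(name="name", regexp=r"^[a-z][a-z0-9-]*$"),
    dict(name="description", optional=True),
]
validator = GcpBodyFieldValidator(SPECIFICATION, api_version="v1")
validator.validate({"name": "my-resource"})  # returns None; raises on violations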
airflow/providers/google/cloud/utils/helpers.py

@@ -19,12 +19,12 @@ from __future__ import annotations
 
 
 def normalize_directory_path(source_object: str | None) -> str | None:
-    """Makes sure dir path ends with a slash."""
+    """Make sure dir path ends with a slash."""
     return source_object + "/" if source_object and not source_object.endswith("/") else source_object
 
 
 def resource_path_to_dict(resource_name: str) -> dict[str, str]:
-    """Converts a path-like GCP resource name into a dictionary.
+    """Convert a path-like GCP resource name into a dictionary.
 
     For example, the path `projects/my-project/locations/my-location/instances/my-instance` will be converted
     to a dict:
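Annotation (not part of the diff): the `resource_path_to_dict` docstring's example, worked by hand — a path of alternating collection/id segments zips into a dict. This mirrors only the documented behavior; the real helper presumably also validates that the segment count is even (an assumption):

path = "projects/my-project/locations/my-location/instances/my-instance"
segments = path.split("/")
# Pair even-indexed collection names with odd-indexed resource ids.
as_dict = dict(zip(segments[::2], segments[1::2]))
assert as_dict == {
    "projects": "my-project",
    "locations": "my-location",
    "instances": "my-instance",
}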
airflow/providers/google/cloud/utils/mlengine_operator_utils.py

@@ -57,7 +57,7 @@ def create_evaluate_ops(
     py_interpreter="python3",
 ) -> tuple[MLEngineStartBatchPredictionJobOperator, BeamRunPythonPipelineOperator, PythonOperator]:
     r"""
-    Creates Operators needed for model evaluation and returns.
+    Create Operators needed for model evaluation and returns.
 
     This function is deprecated. All the functionality of legacy MLEngine and new features are available
     on the Vertex AI platform.
airflow/providers/google/cloud/utils/mlengine_prediction_summary.py

@@ -151,7 +151,7 @@ def MakeSummary(pcoll, metric_fn, metric_keys):
 
 
 def run(argv=None):
-    """Obtains prediction summary."""
+    """Obtain prediction summary."""
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--prediction_path",
airflow/providers/google/common/auth_backend/google_openid.py

@@ -52,7 +52,7 @@ def create_client_session():
 
 
 def init_app(_):
-    """Initializes authentication."""
+    """Initialize authentication."""
 
 
 def _get_id_token_from_request(request) -> str | None:
@@ -110,7 +110,7 @@ T = TypeVar("T", bound=Callable)
 
 
 def requires_authentication(function: T):
-    """Decorator for functions that require authentication."""
+    """Decorator for function that require authentication."""
 
     @wraps(function)
     def decorated(*args, **kwargs):