apache-airflow-providers-google 14.0.0__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl
This diff covers publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- airflow/providers/google/3rd-party-licenses/LICENSES.txt +14 -0
- airflow/providers/google/3rd-party-licenses/NOTICE +5 -0
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/_vendor/__init__.py +0 -0
- airflow/providers/google/_vendor/json_merge_patch.py +91 -0
- airflow/providers/google/ads/hooks/ads.py +52 -43
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +3 -19
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/_internal_client/secret_manager_client.py +3 -2
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/alloy_db.py +2 -3
- airflow/providers/google/cloud/hooks/bigquery.py +195 -318
- airflow/providers/google/cloud/hooks/bigquery_dts.py +8 -8
- airflow/providers/google/cloud/hooks/bigtable.py +3 -2
- airflow/providers/google/cloud/hooks/cloud_batch.py +8 -9
- airflow/providers/google/cloud/hooks/cloud_build.py +6 -65
- airflow/providers/google/cloud/hooks/cloud_composer.py +292 -24
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +4 -3
- airflow/providers/google/cloud/hooks/cloud_run.py +20 -11
- airflow/providers/google/cloud/hooks/cloud_sql.py +136 -64
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +35 -15
- airflow/providers/google/cloud/hooks/compute.py +7 -6
- airflow/providers/google/cloud/hooks/compute_ssh.py +7 -4
- airflow/providers/google/cloud/hooks/datacatalog.py +12 -3
- airflow/providers/google/cloud/hooks/dataflow.py +87 -242
- airflow/providers/google/cloud/hooks/dataform.py +9 -14
- airflow/providers/google/cloud/hooks/datafusion.py +7 -9
- airflow/providers/google/cloud/hooks/dataplex.py +13 -12
- airflow/providers/google/cloud/hooks/dataprep.py +2 -2
- airflow/providers/google/cloud/hooks/dataproc.py +76 -74
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +4 -3
- airflow/providers/google/cloud/hooks/dlp.py +5 -4
- airflow/providers/google/cloud/hooks/gcs.py +144 -33
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kms.py +3 -2
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +22 -17
- airflow/providers/google/cloud/hooks/looker.py +6 -1
- airflow/providers/google/cloud/hooks/managed_kafka.py +227 -3
- airflow/providers/google/cloud/hooks/mlengine.py +7 -8
- airflow/providers/google/cloud/hooks/natural_language.py +3 -2
- airflow/providers/google/cloud/hooks/os_login.py +3 -2
- airflow/providers/google/cloud/hooks/pubsub.py +6 -6
- airflow/providers/google/cloud/hooks/secret_manager.py +105 -12
- airflow/providers/google/cloud/hooks/spanner.py +75 -10
- airflow/providers/google/cloud/hooks/speech_to_text.py +3 -2
- airflow/providers/google/cloud/hooks/stackdriver.py +18 -18
- airflow/providers/google/cloud/hooks/tasks.py +4 -3
- airflow/providers/google/cloud/hooks/text_to_speech.py +3 -2
- airflow/providers/google/cloud/hooks/translate.py +8 -17
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -222
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +9 -15
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +33 -283
- airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +5 -12
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +6 -12
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +311 -10
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +7 -13
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +8 -12
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +6 -12
- airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +3 -2
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/video_intelligence.py +3 -2
- airflow/providers/google/cloud/hooks/vision.py +7 -7
- airflow/providers/google/cloud/hooks/workflows.py +4 -3
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -7
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -46
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -90
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -89
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +11 -61
- airflow/providers/google/cloud/links/managed_kafka.py +11 -51
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +166 -118
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +14 -9
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +141 -40
- airflow/providers/google/cloud/openlineage/mixins.py +14 -13
- airflow/providers/google/cloud/openlineage/utils.py +19 -3
- airflow/providers/google/cloud/operators/alloy_db.py +76 -61
- airflow/providers/google/cloud/operators/bigquery.py +104 -667
- airflow/providers/google/cloud/operators/bigquery_dts.py +12 -12
- airflow/providers/google/cloud/operators/bigtable.py +38 -7
- airflow/providers/google/cloud/operators/cloud_base.py +22 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +18 -18
- airflow/providers/google/cloud/operators/cloud_build.py +80 -36
- airflow/providers/google/cloud/operators/cloud_composer.py +157 -71
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +74 -46
- airflow/providers/google/cloud/operators/cloud_run.py +39 -20
- airflow/providers/google/cloud/operators/cloud_sql.py +46 -61
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -14
- airflow/providers/google/cloud/operators/compute.py +18 -50
- airflow/providers/google/cloud/operators/datacatalog.py +167 -29
- airflow/providers/google/cloud/operators/dataflow.py +38 -15
- airflow/providers/google/cloud/operators/dataform.py +19 -7
- airflow/providers/google/cloud/operators/datafusion.py +43 -43
- airflow/providers/google/cloud/operators/dataplex.py +212 -126
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +134 -207
- airflow/providers/google/cloud/operators/dataproc_metastore.py +102 -84
- airflow/providers/google/cloud/operators/datastore.py +22 -6
- airflow/providers/google/cloud/operators/dlp.py +24 -45
- airflow/providers/google/cloud/operators/functions.py +21 -14
- airflow/providers/google/cloud/operators/gcs.py +15 -12
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +115 -106
- airflow/providers/google/cloud/operators/looker.py +1 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +362 -40
- airflow/providers/google/cloud/operators/natural_language.py +5 -3
- airflow/providers/google/cloud/operators/pubsub.py +69 -21
- airflow/providers/google/cloud/operators/spanner.py +53 -45
- airflow/providers/google/cloud/operators/speech_to_text.py +5 -4
- airflow/providers/google/cloud/operators/stackdriver.py +5 -11
- airflow/providers/google/cloud/operators/tasks.py +6 -15
- airflow/providers/google/cloud/operators/text_to_speech.py +4 -3
- airflow/providers/google/cloud/operators/translate.py +46 -20
- airflow/providers/google/cloud/operators/translate_speech.py +4 -3
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +44 -34
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +34 -12
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +62 -53
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +75 -11
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +48 -12
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +16 -12
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +62 -14
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +35 -10
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +5 -3
- airflow/providers/google/cloud/operators/vision.py +7 -5
- airflow/providers/google/cloud/operators/workflows.py +24 -19
- airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
- airflow/providers/google/cloud/sensors/bigquery.py +2 -2
- airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -4
- airflow/providers/google/cloud/sensors/bigtable.py +14 -6
- airflow/providers/google/cloud/sensors/cloud_composer.py +535 -33
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -5
- airflow/providers/google/cloud/sensors/dataflow.py +27 -10
- airflow/providers/google/cloud/sensors/dataform.py +2 -2
- airflow/providers/google/cloud/sensors/datafusion.py +4 -4
- airflow/providers/google/cloud/sensors/dataplex.py +7 -5
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +10 -9
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +4 -3
- airflow/providers/google/cloud/sensors/gcs.py +22 -21
- airflow/providers/google/cloud/sensors/looker.py +5 -5
- airflow/providers/google/cloud/sensors/pubsub.py +20 -20
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
- airflow/providers/google/cloud/sensors/workflows.py +6 -4
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +14 -13
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +18 -22
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -5
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +45 -38
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +44 -12
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +36 -14
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +75 -34
- airflow/providers/google/cloud/triggers/bigquery_dts.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_batch.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_build.py +3 -2
- airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
- airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +96 -5
- airflow/providers/google/cloud/triggers/dataflow.py +125 -2
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +16 -3
- airflow/providers/google/cloud/triggers/dataproc.py +124 -53
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +46 -28
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +17 -20
- airflow/providers/google/cloud/triggers/vertex_ai.py +8 -7
- airflow/providers/google/cloud/utils/bigquery.py +5 -7
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +4 -3
- airflow/providers/google/cloud/utils/dataform.py +1 -1
- airflow/providers/google/cloud/utils/external_token_supplier.py +0 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/cloud/utils/validators.py +43 -0
- airflow/providers/google/common/auth_backend/google_openid.py +26 -9
- airflow/providers/google/common/consts.py +2 -1
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +40 -43
- airflow/providers/google/common/hooks/operation_helpers.py +78 -0
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +4 -5
- airflow/providers/google/firebase/operators/firestore.py +2 -2
- airflow/providers/google/get_provider_info.py +61 -216
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +30 -6
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/analytics_admin.py +3 -2
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +4 -5
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +7 -6
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +4 -64
- airflow/providers/google/suite/hooks/calendar.py +1 -1
- airflow/providers/google/suite/hooks/drive.py +2 -2
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +117 -72
- apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +1 -1
- apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/example_dags/example_cloud_task.py +0 -54
- airflow/providers/google/cloud/hooks/automl.py +0 -679
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1360
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -1515
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +0 -273
- apache_airflow_providers_google-14.0.0.dist-info/RECORD +0 -318
- /airflow/providers/google/cloud/{example_dags → bundles}/__init__.py +0 -0
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py

@@ -23,21 +23,23 @@ import asyncio
 from collections.abc import Sequence
 from typing import TYPE_CHECKING

-from airflow.exceptions import AirflowException
-from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.aiplatform import BatchPredictionJob, Model, explain
 from google.cloud.aiplatform_v1 import JobServiceAsyncClient, JobServiceClient, JobState, types

+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
+from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
+
 if TYPE_CHECKING:
     from google.api_core.operation import Operation
     from google.api_core.retry import AsyncRetry, Retry
     from google.cloud.aiplatform_v1.services.job_service.pagers import ListBatchPredictionJobsPager


-class BatchPredictionJobHook(GoogleBaseHook):
+class BatchPredictionJobHook(GoogleBaseHook, OperationHelper):
     """Hook for Google Cloud Vertex AI Batch Prediction Job APIs."""

     def __init__(
@@ -61,17 +63,9 @@ class BatchPredictionJobHook(GoogleBaseHook):
         client_options = ClientOptions()

         return JobServiceClient(
-            credentials=self.get_credentials(), client_info=
+            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )

-    def wait_for_operation(self, operation: Operation, timeout: float | None = None):
-        """Wait for long-lasting operation to complete."""
-        try:
-            return operation.result(timeout=timeout)
-        except Exception:
-            error = operation.exception(timeout=timeout)
-            raise AirflowException(error)
-
     @staticmethod
     def extract_batch_prediction_job_id(obj: dict) -> str:
         """Return unique id of the batch_prediction_job."""
@@ -116,7 +110,7 @@ class BatchPredictionJobHook(GoogleBaseHook):
         :param project_id: Required. Project to run training in.
         :param region: Required. Location to run training in.
         :param job_display_name: Required. The user-defined name of the BatchPredictionJob. The name can be
-            up to 128 characters long and can
+            up to 128 characters long and can consist of any UTF-8 characters.
         :param model_name: Required. A fully-qualified model resource name or model ID.
         :param instances_format: Required. The format in which instances are provided. Must be one of the
             formats listed in `Model.supported_input_storage_formats`. Default is "jsonl" when using
@@ -273,7 +267,7 @@ class BatchPredictionJobHook(GoogleBaseHook):
         :param project_id: Required. Project to run training in.
         :param region: Required. Location to run training in.
         :param job_display_name: Required. The user-defined name of the BatchPredictionJob. The name can be
-            up to 128 characters long and can
+            up to 128 characters long and can consist of any UTF-8 characters.
         :param model_name: Required. A fully-qualified model resource name or model ID.
         :param instances_format: Required. The format in which instances are provided. Must be one of the
             formats listed in `Model.supported_input_storage_formats`. Default is "jsonl" when using
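Throughout these Vertex AI hook diffs, the copy-pasted wait_for_operation method is deleted from each hook and replaced by the OperationHelper mixin imported from the new airflow/providers/google/common/hooks/operation_helpers.py module (listed above, +78 lines). A minimal sketch of what such a mixin can look like, reconstructed from the removed method bodies shown in these hunks; the actual module shipped in 19.1.0rc1 may differ:

# Minimal sketch, not the provider's actual module: a shared mixin that replaces the
# per-hook wait_for_operation copies removed in the hunks above.
from __future__ import annotations

from airflow.exceptions import AirflowException


class OperationHelper:
    """Mixin providing a shared wait helper for google-api-core long-running Operations."""

    def wait_for_operation(self, operation, timeout: float | None = None):
        """Block until the operation finishes, re-raising any failure as AirflowException."""
        try:
            return operation.result(timeout=timeout)
        except Exception:
            error = operation.exception(timeout=timeout)
            raise AirflowException(error)

Because the hooks now inherit the mixin, callers keep using hook.wait_for_operation(operation, timeout=...) exactly as before; only the definition moved to one place.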
airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py

@@ -23,10 +23,6 @@ import asyncio
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any

-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
-from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.deprecated import deprecated
-from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.aiplatform import (
@@ -46,19 +42,23 @@ from google.cloud.aiplatform_v1 import (
     types,
 )

+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
+from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
+
 if TYPE_CHECKING:
     from google.api_core.operation import Operation
     from google.api_core.retry import AsyncRetry, Retry
     from google.auth.credentials import Credentials
     from google.cloud.aiplatform_v1.services.job_service.pagers import ListCustomJobsPager
     from google.cloud.aiplatform_v1.services.pipeline_service.pagers import (
-        ListPipelineJobsPager,
         ListTrainingPipelinesPager,
     )
-    from google.cloud.aiplatform_v1.types import CustomJob,
+    from google.cloud.aiplatform_v1.types import CustomJob, PscInterfaceConfig, TrainingPipeline


-class CustomJobHook(GoogleBaseHook):
+class CustomJobHook(GoogleBaseHook, OperationHelper):
     """Hook for Google Cloud Vertex AI Custom Job APIs."""

     def __init__(
@@ -276,14 +276,6 @@ class CustomJobHook(GoogleBaseHook):
         """Return a unique Custom Job id from a serialized TrainingPipeline proto."""
         return training_pipeline["training_task_metadata"]["backingCustomJob"].rpartition("/")[-1]

-    def wait_for_operation(self, operation: Operation, timeout: float | None = None):
-        """Wait for long-lasting operation to complete."""
-        try:
-            return operation.result(timeout=timeout)
-        except Exception:
-            error = operation.exception(timeout=timeout)
-            raise AirflowException(error)
-
     def cancel_job(self) -> None:
         """Cancel Job for training pipeline."""
         if self._job:
@@ -325,6 +317,7 @@ class CustomJobHook(GoogleBaseHook):
         is_default_version: bool | None = None,
         model_version_aliases: list[str] | None = None,
         model_version_description: str | None = None,
+        psc_interface_config: PscInterfaceConfig | None = None,
     ) -> tuple[models.Model | None, str, str]:
         """Run a training pipeline job and wait until its completion."""
         model = job.run(
@@ -358,6 +351,7 @@ class CustomJobHook(GoogleBaseHook):
             is_default_version=is_default_version,
             model_version_aliases=model_version_aliases,
             model_version_description=model_version_description,
+            psc_interface_config=psc_interface_config,
         )
         training_id = self.extract_training_id(job.resource_name)
         custom_job_id = self.extract_custom_job_id(
@@ -374,54 +368,6 @@ class CustomJobHook(GoogleBaseHook):
         )
         return model, training_id, custom_job_id

-    @GoogleBaseHook.fallback_to_default_project_id
-    @deprecated(
-        planned_removal_date="March 01, 2025",
-        use_instead="PipelineJobHook.cancel_pipeline_job",
-        category=AirflowProviderDeprecationWarning,
-    )
-    def cancel_pipeline_job(
-        self,
-        project_id: str,
-        region: str,
-        pipeline_job: str,
-        retry: Retry | _MethodDefault = DEFAULT,
-        timeout: float | None = None,
-        metadata: Sequence[tuple[str, str]] = (),
-    ) -> None:
-        """
-        Cancel a PipelineJob.
-
-        Starts asynchronous cancellation on the PipelineJob. The server makes the best
-        effort to cancel the pipeline, but success is not guaranteed. Clients can use
-        [PipelineService.GetPipelineJob][google.cloud.aiplatform.v1.PipelineService.GetPipelineJob] or other
-        methods to check whether the cancellation succeeded or whether the pipeline completed despite
-        cancellation. On successful cancellation, the PipelineJob is not deleted; instead it becomes a
-        pipeline with a [PipelineJob.error][google.cloud.aiplatform.v1.PipelineJob.error] value with a
-        [google.rpc.Status.code][google.rpc.Status.code] of 1, corresponding to ``Code.CANCELLED``, and
-        [PipelineJob.state][google.cloud.aiplatform.v1.PipelineJob.state] is set to ``CANCELLED``.
-
-        This method is deprecated, please use `PipelineJobHook.cancel_pipeline_job` method.
-
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
-        :param region: Required. The ID of the Google Cloud region that the service belongs to.
-        :param pipeline_job: The name of the PipelineJob to cancel.
-        :param retry: Designation of what errors, if any, should be retried.
-        :param timeout: The timeout for this request.
-        :param metadata: Strings which should be sent along with the request as metadata.
-        """
-        client = self.get_pipeline_service_client(region)
-        name = client.pipeline_job_path(project_id, region, pipeline_job)
-
-        client.cancel_pipeline_job(
-            request={
-                "name": name,
-            },
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-
     @GoogleBaseHook.fallback_to_default_project_id
     def cancel_training_pipeline(
         self,
@@ -504,53 +450,6 @@ class CustomJobHook(GoogleBaseHook):
             metadata=metadata,
         )

-    @GoogleBaseHook.fallback_to_default_project_id
-    @deprecated(
-        planned_removal_date="March 01, 2025",
-        use_instead="PipelineJobHook.create_pipeline_job",
-        category=AirflowProviderDeprecationWarning,
-    )
-    def create_pipeline_job(
-        self,
-        project_id: str,
-        region: str,
-        pipeline_job: PipelineJob,
-        pipeline_job_id: str,
-        retry: Retry | _MethodDefault = DEFAULT,
-        timeout: float | None = None,
-        metadata: Sequence[tuple[str, str]] = (),
-    ) -> PipelineJob:
-        """
-        Create a PipelineJob. A PipelineJob will run immediately when created.
-
-        This method is deprecated, please use `PipelineJobHook.create_pipeline_job` method.
-
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
-        :param region: Required. The ID of the Google Cloud region that the service belongs to.
-        :param pipeline_job: Required. The PipelineJob to create.
-        :param pipeline_job_id: The ID to use for the PipelineJob, which will become the final component of
-            the PipelineJob name. If not provided, an ID will be automatically generated.
-
-            This value should be less than 128 characters, and valid characters are /[a-z][0-9]-/.
-        :param retry: Designation of what errors, if any, should be retried.
-        :param timeout: The timeout for this request.
-        :param metadata: Strings which should be sent along with the request as metadata.
-        """
-        client = self.get_pipeline_service_client(region)
-        parent = client.common_location_path(project_id, region)
-
-        result = client.create_pipeline_job(
-            request={
-                "parent": parent,
-                "pipeline_job": pipeline_job,
-                "pipeline_job_id": pipeline_job_id,
-            },
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-        return result
-
     @GoogleBaseHook.fallback_to_default_project_id
     def create_training_pipeline(
         self,
@@ -677,6 +576,7 @@ class CustomJobHook(GoogleBaseHook):
         timestamp_split_column_name: str | None = None,
         tensorboard: str | None = None,
         sync=True,
+        psc_interface_config: PscInterfaceConfig | None = None,
     ) -> tuple[models.Model | None, str, str]:
         """
         Create Custom Container Training Job.
@@ -940,6 +840,8 @@ class CustomJobHook(GoogleBaseHook):
         :param sync: Whether to execute the AI Platform job synchronously. If False, this method
             will be executed in concurrent Future and any downstream object will
             be immediately returned and synced when the Future has completed.
+        :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
+            training.
         """
         self._job = self.get_custom_container_training_job(
             project=project_id,
@@ -999,6 +901,7 @@ class CustomJobHook(GoogleBaseHook):
             is_default_version=is_default_version,
             model_version_aliases=model_version_aliases,
             model_version_description=model_version_description,
+            psc_interface_config=psc_interface_config,
         )

         return model, training_id, custom_job_id
@@ -1061,6 +964,7 @@ class CustomJobHook(GoogleBaseHook):
         model_version_aliases: list[str] | None = None,
         model_version_description: str | None = None,
         sync=True,
+        psc_interface_config: PscInterfaceConfig | None = None,
     ) -> tuple[models.Model | None, str, str]:
         """
         Create Custom Python Package Training Job.
@@ -1323,6 +1227,8 @@ class CustomJobHook(GoogleBaseHook):
         :param sync: Whether to execute the AI Platform job synchronously. If False, this method
             will be executed in concurrent Future and any downstream object will
             be immediately returned and synced when the Future has completed.
+        :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
+            training.
         """
         self._job = self.get_custom_python_package_training_job(
             project=project_id,
@@ -1383,6 +1289,7 @@ class CustomJobHook(GoogleBaseHook):
             is_default_version=is_default_version,
             model_version_aliases=model_version_aliases,
             model_version_description=model_version_description,
+            psc_interface_config=psc_interface_config,
         )

         return model, training_id, custom_job_id
@@ -1445,6 +1352,7 @@ class CustomJobHook(GoogleBaseHook):
         timestamp_split_column_name: str | None = None,
         tensorboard: str | None = None,
         sync=True,
+        psc_interface_config: PscInterfaceConfig | None = None,
     ) -> tuple[models.Model | None, str, str]:
         """
         Create Custom Training Job.
@@ -1707,6 +1615,8 @@ class CustomJobHook(GoogleBaseHook):
         :param sync: Whether to execute the AI Platform job synchronously. If False, this method
             will be executed in concurrent Future and any downstream object will
             be immediately returned and synced when the Future has completed.
+        :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
+            training.
         """
         self._job = self.get_custom_training_job(
             project=project_id,
@@ -1767,6 +1677,7 @@ class CustomJobHook(GoogleBaseHook):
             is_default_version=is_default_version,
             model_version_aliases=model_version_aliases,
             model_version_description=model_version_description,
+            psc_interface_config=psc_interface_config,
         )

         return model, training_id, custom_job_id
@@ -1828,6 +1739,7 @@ class CustomJobHook(GoogleBaseHook):
         predefined_split_column_name: str | None = None,
         timestamp_split_column_name: str | None = None,
         tensorboard: str | None = None,
+        psc_interface_config: PscInterfaceConfig | None = None,
     ) -> CustomContainerTrainingJob:
         """
         Create and submit a Custom Container Training Job pipeline, then exit without waiting for it to complete.
@@ -2088,6 +2000,8 @@ class CustomJobHook(GoogleBaseHook):
             ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
             For more information on configuring your service account please visit:
             https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+        :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
+            training.
         """
         self._job = self.get_custom_container_training_job(
             project=project_id,
@@ -2146,6 +2060,7 @@ class CustomJobHook(GoogleBaseHook):
             model_version_aliases=model_version_aliases,
             model_version_description=model_version_description,
             sync=False,
+            psc_interface_config=psc_interface_config,
         )
         return self._job

@@ -2207,6 +2122,7 @@ class CustomJobHook(GoogleBaseHook):
         is_default_version: bool | None = None,
         model_version_aliases: list[str] | None = None,
         model_version_description: str | None = None,
+        psc_interface_config: PscInterfaceConfig | None = None,
     ) -> CustomPythonPackageTrainingJob:
         """
         Create and submit a Custom Python Package Training Job pipeline, then exit without waiting for it to complete.
@@ -2466,6 +2382,8 @@ class CustomJobHook(GoogleBaseHook):
             ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
             For more information on configuring your service account please visit:
             https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+        :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
+            training.
         """
         self._job = self.get_custom_python_package_training_job(
             project=project_id,
@@ -2525,6 +2443,7 @@ class CustomJobHook(GoogleBaseHook):
             model_version_aliases=model_version_aliases,
             model_version_description=model_version_description,
             sync=False,
+            psc_interface_config=psc_interface_config,
         )

         return self._job
@@ -2587,6 +2506,7 @@ class CustomJobHook(GoogleBaseHook):
         predefined_split_column_name: str | None = None,
         timestamp_split_column_name: str | None = None,
         tensorboard: str | None = None,
+        psc_interface_config: PscInterfaceConfig | None = None,
     ) -> CustomTrainingJob:
         """
         Create and submit a Custom Training Job pipeline, then exit without waiting for it to complete.
@@ -2850,6 +2770,8 @@ class CustomJobHook(GoogleBaseHook):
             ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
             For more information on configuring your service account please visit:
             https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+        :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
+            training.
         """
         self._job = self.get_custom_training_job(
             project=project_id,
@@ -2909,6 +2831,7 @@ class CustomJobHook(GoogleBaseHook):
             model_version_aliases=model_version_aliases,
             model_version_description=model_version_description,
             sync=False,
+            psc_interface_config=psc_interface_config,
         )
         return self._job

@@ -2976,46 +2899,6 @@ class CustomJobHook(GoogleBaseHook):
         )
         return result

-    @GoogleBaseHook.fallback_to_default_project_id
-    @deprecated(
-        planned_removal_date="March 01, 2025",
-        use_instead="PipelineJobHook.get_pipeline_job",
-        category=AirflowProviderDeprecationWarning,
-    )
-    def get_pipeline_job(
-        self,
-        project_id: str,
-        region: str,
-        pipeline_job: str,
-        retry: Retry | _MethodDefault = DEFAULT,
-        timeout: float | None = None,
-        metadata: Sequence[tuple[str, str]] = (),
-    ) -> PipelineJob:
-        """
-        Get a PipelineJob.
-
-        This method is deprecated, please use `PipelineJobHook.get_pipeline_job` method.
-
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
-        :param region: Required. The ID of the Google Cloud region that the service belongs to.
-        :param pipeline_job: Required. The name of the PipelineJob resource.
-        :param retry: Designation of what errors, if any, should be retried.
-        :param timeout: The timeout for this request.
-        :param metadata: Strings which should be sent along with the request as metadata.
-        """
-        client = self.get_pipeline_service_client(region)
-        name = client.pipeline_job_path(project_id, region, pipeline_job)
-
-        result = client.get_pipeline_job(
-            request={
-                "name": name,
-            },
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-        return result
-
     @GoogleBaseHook.fallback_to_default_project_id
     def get_training_pipeline(
         self,
@@ -3082,101 +2965,6 @@ class CustomJobHook(GoogleBaseHook):
         )
         return result

-    @GoogleBaseHook.fallback_to_default_project_id
-    @deprecated(
-        planned_removal_date="March 01, 2025",
-        use_instead="PipelineJobHook.list_pipeline_jobs",
-        category=AirflowProviderDeprecationWarning,
-    )
-    def list_pipeline_jobs(
-        self,
-        project_id: str,
-        region: str,
-        page_size: int | None = None,
-        page_token: str | None = None,
-        filter: str | None = None,
-        order_by: str | None = None,
-        retry: Retry | _MethodDefault = DEFAULT,
-        timeout: float | None = None,
-        metadata: Sequence[tuple[str, str]] = (),
-    ) -> ListPipelineJobsPager:
-        """
-        List PipelineJobs in a Location.
-
-        This method is deprecated, please use `PipelineJobHook.list_pipeline_jobs` method.
-
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
-        :param region: Required. The ID of the Google Cloud region that the service belongs to.
-        :param filter: Optional. Lists the PipelineJobs that match the filter expression. The
-            following fields are supported:
-
-            - ``pipeline_name``: Supports ``=`` and ``!=`` comparisons.
-            - ``display_name``: Supports ``=``, ``!=`` comparisons, and
-              ``:`` wildcard.
-            - ``pipeline_job_user_id``: Supports ``=``, ``!=``
-              comparisons, and ``:`` wildcard. for example, can check
-              if pipeline's display_name contains *step* by doing
-              display_name:"*step*"
-            - ``create_time``: Supports ``=``, ``!=``, ``<``, ``>``,
-              ``<=``, and ``>=`` comparisons. Values must be in RFC
-              3339 format.
-            - ``update_time``: Supports ``=``, ``!=``, ``<``, ``>``,
-              ``<=``, and ``>=`` comparisons. Values must be in RFC
-              3339 format.
-            - ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``,
-              ``<=``, and ``>=`` comparisons. Values must be in RFC
-              3339 format.
-            - ``labels``: Supports key-value equality and key presence.
-
-            Filter expressions can be combined together using logical
-            operators (``AND`` & ``OR``). For example:
-            ``pipeline_name="test" AND create_time>"2020-05-18T13:30:00Z"``.
-
-            The syntax to define filter expression is based on
-            https://google.aip.dev/160.
-        :param page_size: Optional. The standard list page size.
-        :param page_token: Optional. The standard list page token. Typically obtained via
-            [ListPipelineJobsResponse.next_page_token][google.cloud.aiplatform.v1.ListPipelineJobsResponse.next_page_token]
-            of the previous
-            [PipelineService.ListPipelineJobs][google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs]
-            call.
-        :param order_by: Optional. A comma-separated list of fields to order by. The default
-            sort order is in ascending order. Use "desc" after a field
-            name for descending. You can have multiple order_by fields
-            provided e.g. "create_time desc, end_time", "end_time,
-            start_time, update_time" For example, using "create_time
-            desc, end_time" will order results by create time in
-            descending order, and if there are multiple jobs having the
-            same create time, order them by the end time in ascending
-            order. if order_by is not specified, it will order by
-            default order is create time in descending order. Supported
-            fields:
-
-            - ``create_time``
-            - ``update_time``
-            - ``end_time``
-            - ``start_time``
-        :param retry: Designation of what errors, if any, should be retried.
-        :param timeout: The timeout for this request.
-        :param metadata: Strings which should be sent along with the request as metadata.
-        """
-        client = self.get_pipeline_service_client(region)
-        parent = client.common_location_path(project_id, region)
-
-        result = client.list_pipeline_jobs(
-            request={
-                "parent": parent,
-                "page_size": page_size,
-                "page_token": page_token,
-                "filter": filter,
-                "order_by": order_by,
-            },
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-        return result
-
     @GoogleBaseHook.fallback_to_default_project_id
     def list_training_pipelines(
         self,
@@ -3299,44 +3087,6 @@ class CustomJobHook(GoogleBaseHook):
         )
         return result

-    @GoogleBaseHook.fallback_to_default_project_id
-    @deprecated(
-        planned_removal_date="March 01, 2025",
-        use_instead="PipelineJobHook.delete_pipeline_job",
-        category=AirflowProviderDeprecationWarning,
-    )
-    def delete_pipeline_job(
-        self,
-        project_id: str,
-        region: str,
-        pipeline_job: str,
-        retry: Retry | _MethodDefault = DEFAULT,
-        timeout: float | None = None,
-        metadata: Sequence[tuple[str, str]] = (),
-    ) -> Operation:
-        """
-        Delete a PipelineJob.
-
-        This method is deprecated, please use `PipelineJobHook.delete_pipeline_job` method.
-
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
-        :param region: Required. The ID of the Google Cloud region that the service belongs to.
-        :param pipeline_job: Required. The name of the PipelineJob resource to be deleted.
-        :param retry: Designation of what errors, if any, should be retried.
-        :param timeout: The timeout for this request.
-        :param metadata: Strings which should be sent along with the request as metadata.
-        """
-        client = self.get_pipeline_service_client(region)
-        name = client.pipeline_job_path(project_id, region, pipeline_job)
-
-        result = client.delete_pipeline_job(
-            request={"name": name},
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-        return result
-

 class CustomJobAsyncHook(GoogleBaseAsyncHook):
     """Async hook for Custom Job Service Client."""
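Two changes dominate the custom_job.py hunks: the deprecated PipelineJob passthrough methods (cancel_pipeline_job, create_pipeline_job, get_pipeline_job, list_pipeline_jobs, delete_pipeline_job) are removed in favor of the PipelineJobHook equivalents their own docstrings already pointed to, and every training method gains an optional psc_interface_config argument that is forwarded into job.run(). A hedged sketch of supplying the new argument; the network_attachment field name and the surrounding training-call signature are assumptions, not something this diff confirms:

# Illustrative sketch only: names outside this diff (network_attachment, the exact
# training-method signature) are assumptions, not confirmed by the hunks above.
from google.cloud.aiplatform_v1.types import PscInterfaceConfig

psc_config = PscInterfaceConfig(
    network_attachment=(
        "projects/example-project/regions/us-central1/networkAttachments/example-attachment"
    ),
)

# The new keyword is accepted by the CustomJobHook training methods patched above and is
# forwarded unchanged into job.run(), e.g.:
#
#   hook.create_custom_container_training_job(
#       ...,  # existing required arguments unchanged
#       psc_interface_config=psc_config,
#   )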
airflow/providers/google/cloud/hooks/vertex_ai/dataset.py

@@ -22,13 +22,14 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import TYPE_CHECKING

-from airflow.exceptions import AirflowException
-from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.aiplatform_v1 import DatasetServiceClient

+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
+
 if TYPE_CHECKING:
     from google.api_core.operation import Operation
     from google.api_core.retry import Retry
@@ -41,7 +42,7 @@ if TYPE_CHECKING:
     from google.protobuf.field_mask_pb2 import FieldMask


-class DatasetHook(GoogleBaseHook):
+class DatasetHook(GoogleBaseHook, OperationHelper):
     """Hook for Google Cloud Vertex AI Dataset APIs."""

     def get_dataset_service_client(self, region: str | None = None) -> DatasetServiceClient:
@@ -55,14 +56,6 @@ class DatasetHook(GoogleBaseHook):
             credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )

-    def wait_for_operation(self, operation: Operation, timeout: float | None = None):
-        """Wait for long-lasting operation to complete."""
-        try:
-            return operation.result(timeout=timeout)
-        except Exception:
-            error = operation.exception(timeout=timeout)
-            raise AirflowException(error)
-
     @staticmethod
     def extract_dataset_id(obj: dict) -> str:
         """Return unique id of the dataset."""
airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py

@@ -22,12 +22,14 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import TYPE_CHECKING

-from airflow.exceptions import AirflowException
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.aiplatform_v1 import EndpointServiceClient

+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
+
 if TYPE_CHECKING:
     from google.api_core.operation import Operation
     from google.api_core.retry import Retry
@@ -36,7 +38,7 @@ if TYPE_CHECKING:
     from google.protobuf.field_mask_pb2 import FieldMask


-class EndpointServiceHook(GoogleBaseHook):
+class EndpointServiceHook(GoogleBaseHook, OperationHelper):
     """Hook for Google Cloud Vertex AI Endpoint Service APIs."""

     def get_endpoint_service_client(self, region: str | None = None) -> EndpointServiceClient:
@@ -47,17 +49,9 @@ class EndpointServiceHook(GoogleBaseHook):
         client_options = ClientOptions()

         return EndpointServiceClient(
-            credentials=self.get_credentials(), client_info=
+            credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
         )

-    def wait_for_operation(self, operation: Operation, timeout: float | None = None):
-        """Wait for long-lasting operation to complete."""
-        try:
-            return operation.result(timeout=timeout)
-        except Exception:
-            error = operation.exception(timeout=timeout)
-            raise AirflowException(error)
-
     @staticmethod
     def extract_endpoint_id(obj: dict) -> str:
         """Return unique id of the endpoint."""