apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/3rd-party-licenses/NOTICE +2 -12
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/ads/hooks/ads.py +39 -5
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -2
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/bundles/__init__.py +16 -0
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/bigquery.py +166 -281
- airflow/providers/google/cloud/hooks/cloud_composer.py +287 -14
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_run.py +17 -9
- airflow/providers/google/cloud/hooks/cloud_sql.py +101 -22
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +27 -6
- airflow/providers/google/cloud/hooks/compute_ssh.py +5 -1
- airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
- airflow/providers/google/cloud/hooks/dataflow.py +71 -94
- airflow/providers/google/cloud/hooks/datafusion.py +1 -1
- airflow/providers/google/cloud/hooks/dataplex.py +1 -1
- airflow/providers/google/cloud/hooks/dataprep.py +1 -1
- airflow/providers/google/cloud/hooks/dataproc.py +72 -71
- airflow/providers/google/cloud/hooks/gcs.py +111 -14
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/hooks/looker.py +6 -1
- airflow/providers/google/cloud/hooks/mlengine.py +3 -2
- airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
- airflow/providers/google/cloud/hooks/spanner.py +73 -8
- airflow/providers/google/cloud/hooks/stackdriver.py +10 -8
- airflow/providers/google/cloud/hooks/translate.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -209
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +2 -2
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +27 -1
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/vision.py +2 -2
- airflow/providers/google/cloud/hooks/workflows.py +1 -1
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -13
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -44
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -96
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -95
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +9 -60
- airflow/providers/google/cloud/links/managed_kafka.py +0 -70
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +17 -9
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +9 -6
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +102 -1
- airflow/providers/google/cloud/openlineage/mixins.py +10 -8
- airflow/providers/google/cloud/openlineage/utils.py +15 -1
- airflow/providers/google/cloud/operators/alloy_db.py +70 -55
- airflow/providers/google/cloud/operators/bigquery.py +73 -636
- airflow/providers/google/cloud/operators/bigquery_dts.py +3 -5
- airflow/providers/google/cloud/operators/bigtable.py +36 -7
- airflow/providers/google/cloud/operators/cloud_base.py +21 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +2 -2
- airflow/providers/google/cloud/operators/cloud_build.py +75 -32
- airflow/providers/google/cloud/operators/cloud_composer.py +128 -40
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +69 -43
- airflow/providers/google/cloud/operators/cloud_run.py +23 -5
- airflow/providers/google/cloud/operators/cloud_sql.py +8 -16
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -11
- airflow/providers/google/cloud/operators/compute.py +8 -40
- airflow/providers/google/cloud/operators/datacatalog.py +157 -21
- airflow/providers/google/cloud/operators/dataflow.py +38 -15
- airflow/providers/google/cloud/operators/dataform.py +15 -5
- airflow/providers/google/cloud/operators/datafusion.py +41 -20
- airflow/providers/google/cloud/operators/dataplex.py +193 -109
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +78 -35
- airflow/providers/google/cloud/operators/dataproc_metastore.py +96 -88
- airflow/providers/google/cloud/operators/datastore.py +22 -6
- airflow/providers/google/cloud/operators/dlp.py +6 -29
- airflow/providers/google/cloud/operators/functions.py +16 -7
- airflow/providers/google/cloud/operators/gcs.py +10 -8
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +60 -99
- airflow/providers/google/cloud/operators/looker.py +1 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +107 -52
- airflow/providers/google/cloud/operators/natural_language.py +1 -1
- airflow/providers/google/cloud/operators/pubsub.py +60 -14
- airflow/providers/google/cloud/operators/spanner.py +25 -12
- airflow/providers/google/cloud/operators/speech_to_text.py +1 -2
- airflow/providers/google/cloud/operators/stackdriver.py +1 -9
- airflow/providers/google/cloud/operators/tasks.py +1 -12
- airflow/providers/google/cloud/operators/text_to_speech.py +1 -2
- airflow/providers/google/cloud/operators/translate.py +40 -16
- airflow/providers/google/cloud/operators/translate_speech.py +1 -2
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +39 -19
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +29 -9
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +54 -26
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +70 -8
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +43 -9
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +11 -9
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +57 -11
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +30 -7
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +1 -1
- airflow/providers/google/cloud/operators/vision.py +2 -2
- airflow/providers/google/cloud/operators/workflows.py +18 -15
- airflow/providers/google/cloud/sensors/bigquery.py +2 -2
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -2
- airflow/providers/google/cloud/sensors/bigtable.py +11 -4
- airflow/providers/google/cloud/sensors/cloud_composer.py +533 -29
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -2
- airflow/providers/google/cloud/sensors/dataflow.py +26 -9
- airflow/providers/google/cloud/sensors/dataform.py +2 -2
- airflow/providers/google/cloud/sensors/datafusion.py +4 -4
- airflow/providers/google/cloud/sensors/dataplex.py +2 -2
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +2 -2
- airflow/providers/google/cloud/sensors/gcs.py +4 -4
- airflow/providers/google/cloud/sensors/looker.py +2 -2
- airflow/providers/google/cloud/sensors/pubsub.py +4 -4
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
- airflow/providers/google/cloud/sensors/workflows.py +2 -2
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +4 -4
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +20 -12
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +13 -4
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +75 -34
- airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_composer.py +302 -46
- airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +91 -1
- airflow/providers/google/cloud/triggers/dataflow.py +122 -0
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +14 -2
- airflow/providers/google/cloud/triggers/dataproc.py +122 -52
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +45 -27
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +15 -19
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +1 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/common/auth_backend/google_openid.py +4 -4
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +27 -8
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +3 -4
- airflow/providers/google/firebase/operators/firestore.py +2 -2
- airflow/providers/google/get_provider_info.py +56 -52
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +26 -1
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/links/analytics_admin.py +5 -14
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +1 -2
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +3 -63
- airflow/providers/google/suite/hooks/calendar.py +1 -1
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +92 -48
- apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
- apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/hooks/automl.py +0 -673
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1362
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -112
- apache_airflow_providers_google-15.1.0rc1.dist-info/RECORD +0 -321
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
|
@@ -38,7 +38,6 @@ from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient
|
|
|
38
38
|
|
|
39
39
|
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
|
40
40
|
from airflow.providers.google.common.consts import CLIENT_INFO
|
|
41
|
-
from airflow.providers.google.common.deprecated import deprecated
|
|
42
41
|
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
|
|
43
42
|
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
|
|
44
43
|
|
|
@@ -185,42 +184,6 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
|
|
|
185
184
|
model_encryption_spec_key_name=model_encryption_spec_key_name,
|
|
186
185
|
)
|
|
187
186
|
|
|
188
|
-
@deprecated(
|
|
189
|
-
planned_removal_date="June 15, 2025",
|
|
190
|
-
category=AirflowProviderDeprecationWarning,
|
|
191
|
-
reason="Deprecation of AutoMLText API",
|
|
192
|
-
)
|
|
193
|
-
def get_auto_ml_text_training_job(
|
|
194
|
-
self,
|
|
195
|
-
display_name: str,
|
|
196
|
-
prediction_type: str,
|
|
197
|
-
multi_label: bool = False,
|
|
198
|
-
sentiment_max: int = 10,
|
|
199
|
-
project: str | None = None,
|
|
200
|
-
location: str | None = None,
|
|
201
|
-
labels: dict[str, str] | None = None,
|
|
202
|
-
training_encryption_spec_key_name: str | None = None,
|
|
203
|
-
model_encryption_spec_key_name: str | None = None,
|
|
204
|
-
) -> AutoMLTextTrainingJob:
|
|
205
|
-
"""
|
|
206
|
-
Return AutoMLTextTrainingJob object.
|
|
207
|
-
|
|
208
|
-
WARNING: Text creation API is deprecated since September 15, 2024
|
|
209
|
-
(https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
|
|
210
|
-
"""
|
|
211
|
-
return AutoMLTextTrainingJob(
|
|
212
|
-
display_name=display_name,
|
|
213
|
-
prediction_type=prediction_type,
|
|
214
|
-
multi_label=multi_label,
|
|
215
|
-
sentiment_max=sentiment_max,
|
|
216
|
-
project=project,
|
|
217
|
-
location=location,
|
|
218
|
-
credentials=self.get_credentials(),
|
|
219
|
-
labels=labels,
|
|
220
|
-
training_encryption_spec_key_name=training_encryption_spec_key_name,
|
|
221
|
-
model_encryption_spec_key_name=model_encryption_spec_key_name,
|
|
222
|
-
)
|
|
223
|
-
|
|
224
187
|
def get_auto_ml_video_training_job(
|
|
225
188
|
self,
|
|
226
189
|
display_name: str,
|
|
@@ -987,178 +950,6 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
|
|
|
987
950
|
)
|
|
988
951
|
return model, training_id
|
|
989
952
|
|
|
990
|
-
@GoogleBaseHook.fallback_to_default_project_id
|
|
991
|
-
@deprecated(
|
|
992
|
-
planned_removal_date="September 15, 2025",
|
|
993
|
-
category=AirflowProviderDeprecationWarning,
|
|
994
|
-
reason="Deprecation of AutoMLText API",
|
|
995
|
-
)
|
|
996
|
-
def create_auto_ml_text_training_job(
|
|
997
|
-
self,
|
|
998
|
-
project_id: str,
|
|
999
|
-
region: str,
|
|
1000
|
-
display_name: str,
|
|
1001
|
-
dataset: datasets.TextDataset,
|
|
1002
|
-
prediction_type: str,
|
|
1003
|
-
multi_label: bool = False,
|
|
1004
|
-
sentiment_max: int = 10,
|
|
1005
|
-
labels: dict[str, str] | None = None,
|
|
1006
|
-
training_encryption_spec_key_name: str | None = None,
|
|
1007
|
-
model_encryption_spec_key_name: str | None = None,
|
|
1008
|
-
training_fraction_split: float | None = None,
|
|
1009
|
-
validation_fraction_split: float | None = None,
|
|
1010
|
-
test_fraction_split: float | None = None,
|
|
1011
|
-
training_filter_split: str | None = None,
|
|
1012
|
-
validation_filter_split: str | None = None,
|
|
1013
|
-
test_filter_split: str | None = None,
|
|
1014
|
-
model_display_name: str | None = None,
|
|
1015
|
-
model_labels: dict[str, str] | None = None,
|
|
1016
|
-
sync: bool = True,
|
|
1017
|
-
parent_model: str | None = None,
|
|
1018
|
-
is_default_version: bool | None = None,
|
|
1019
|
-
model_version_aliases: list[str] | None = None,
|
|
1020
|
-
model_version_description: str | None = None,
|
|
1021
|
-
) -> tuple[models.Model | None, str]:
|
|
1022
|
-
"""
|
|
1023
|
-
Create an AutoML Text Training Job.
|
|
1024
|
-
|
|
1025
|
-
WARNING: Text creation API is deprecated since September 15, 2024
|
|
1026
|
-
(https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
|
|
1027
|
-
|
|
1028
|
-
:param project_id: Required. Project to run training in.
|
|
1029
|
-
:param region: Required. Location to run training in.
|
|
1030
|
-
:param display_name: Required. The user-defined name of this TrainingPipeline.
|
|
1031
|
-
:param dataset: Required. The dataset within the same Project from which data will be used to train
|
|
1032
|
-
the Model. The Dataset must use schema compatible with Model being trained, and what is
|
|
1033
|
-
compatible should be described in the used TrainingPipeline's [training_task_definition]
|
|
1034
|
-
[google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
|
|
1035
|
-
:param prediction_type: The type of prediction the Model is to produce, one of:
|
|
1036
|
-
"classification" - A classification model analyzes text data and returns a list of categories
|
|
1037
|
-
that apply to the text found in the data. Vertex AI offers both single-label and multi-label text
|
|
1038
|
-
classification models.
|
|
1039
|
-
"extraction" - An entity extraction model inspects text data for known entities referenced in the
|
|
1040
|
-
data and labels those entities in the text.
|
|
1041
|
-
"sentiment" - A sentiment analysis model inspects text data and identifies the prevailing
|
|
1042
|
-
emotional opinion within it, especially to determine a writer's attitude as positive, negative,
|
|
1043
|
-
or neutral.
|
|
1044
|
-
:param parent_model: Optional. The resource name or model ID of an existing model.
|
|
1045
|
-
The new model uploaded by this job will be a version of `parent_model`.
|
|
1046
|
-
Only set this field when training a new version of an existing model.
|
|
1047
|
-
:param is_default_version: Optional. When set to True, the newly uploaded model version will
|
|
1048
|
-
automatically have alias "default" included. Subsequent uses of
|
|
1049
|
-
the model produced by this job without a version specified will
|
|
1050
|
-
use this "default" version.
|
|
1051
|
-
When set to False, the "default" alias will not be moved.
|
|
1052
|
-
Actions targeting the model version produced by this job will need
|
|
1053
|
-
to specifically reference this version by ID or alias.
|
|
1054
|
-
New model uploads, i.e. version 1, will always be "default" aliased.
|
|
1055
|
-
:param model_version_aliases: Optional. User provided version aliases so that the model version
|
|
1056
|
-
uploaded by this job can be referenced via alias instead of
|
|
1057
|
-
auto-generated version ID. A default version alias will be created
|
|
1058
|
-
for the first version of the model.
|
|
1059
|
-
The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
|
|
1060
|
-
:param model_version_description: Optional. The description of the model version
|
|
1061
|
-
being uploaded by this job.
|
|
1062
|
-
:param multi_label: Required and only applicable for text classification task. If false, a
|
|
1063
|
-
single-label (multi-class) Model will be trained (i.e. assuming that for each text snippet just
|
|
1064
|
-
up to one annotation may be applicable). If true, a multi-label Model will be trained (i.e.
|
|
1065
|
-
assuming that for each text snippet multiple annotations may be applicable).
|
|
1066
|
-
:param sentiment_max: Required and only applicable for sentiment task. A sentiment is expressed as an
|
|
1067
|
-
integer ordinal, where higher value means a more positive sentiment. The range of sentiments that
|
|
1068
|
-
will be used is between 0 and sentimentMax (inclusive on both ends), and all the values in the
|
|
1069
|
-
range must be represented in the dataset before a model can be created. Only the Annotations with
|
|
1070
|
-
this sentimentMax will be used for training. sentimentMax value must be between 1 and 10
|
|
1071
|
-
(inclusive).
|
|
1072
|
-
:param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label
|
|
1073
|
-
keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
|
|
1074
|
-
lowercase letters, numeric characters, underscores and dashes. International characters are
|
|
1075
|
-
allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
|
|
1076
|
-
:param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
|
|
1077
|
-
managed encryption key used to protect the training pipeline. Has the form:
|
|
1078
|
-
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
|
|
1079
|
-
The key needs to be in the same region as where the compute resource is created.
|
|
1080
|
-
If set, this TrainingPipeline will be secured by this key.
|
|
1081
|
-
Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload``
|
|
1082
|
-
is not set separately.
|
|
1083
|
-
:param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
|
|
1084
|
-
managed encryption key used to protect the model. Has the form:
|
|
1085
|
-
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
|
|
1086
|
-
The key needs to be in the same region as where the compute resource is created.
|
|
1087
|
-
If set, the trained Model will be secured by this key.
|
|
1088
|
-
:param training_fraction_split: Optional. The fraction of the input data that is to be used to train
|
|
1089
|
-
the Model. This is ignored if Dataset is not provided.
|
|
1090
|
-
:param validation_fraction_split: Optional. The fraction of the input data that is to be used to
|
|
1091
|
-
validate the Model. This is ignored if Dataset is not provided.
|
|
1092
|
-
:param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
|
|
1093
|
-
the Model. This is ignored if Dataset is not provided.
|
|
1094
|
-
:param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
|
|
1095
|
-
this filter are used to train the Model. A filter with same syntax as the one used in
|
|
1096
|
-
DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
|
|
1097
|
-
FilterSplit filters, then it is assigned to the first set that applies to it in the training,
|
|
1098
|
-
validation, test order. This is ignored if Dataset is not provided.
|
|
1099
|
-
:param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
|
|
1100
|
-
this filter are used to validate the Model. A filter with same syntax as the one used in
|
|
1101
|
-
DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
|
|
1102
|
-
FilterSplit filters, then it is assigned to the first set that applies to it in the training,
|
|
1103
|
-
validation, test order. This is ignored if Dataset is not provided.
|
|
1104
|
-
:param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this
|
|
1105
|
-
filter are used to test the Model. A filter with same syntax as the one used in
|
|
1106
|
-
DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
|
|
1107
|
-
FilterSplit filters, then it is assigned to the first set that applies to it in the training,
|
|
1108
|
-
validation, test order. This is ignored if Dataset is not provided.
|
|
1109
|
-
:param model_display_name: Optional. The display name of the managed Vertex AI Model. The name can be
|
|
1110
|
-
up to 128 characters long and can consist of any UTF-8 characters.
|
|
1111
|
-
If not provided upon creation, the job's display_name is used.
|
|
1112
|
-
:param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label
|
|
1113
|
-
keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
|
|
1114
|
-
lowercase letters, numeric characters, underscores and dashes. International characters are
|
|
1115
|
-
allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
|
|
1116
|
-
:param sync: Whether to execute this method synchronously. If False, this method will be executed in
|
|
1117
|
-
concurrent Future and any downstream object will be immediately returned and synced when the
|
|
1118
|
-
Future has completed.
|
|
1119
|
-
"""
|
|
1120
|
-
self._job = AutoMLTextTrainingJob(
|
|
1121
|
-
display_name=display_name,
|
|
1122
|
-
prediction_type=prediction_type,
|
|
1123
|
-
multi_label=multi_label,
|
|
1124
|
-
sentiment_max=sentiment_max,
|
|
1125
|
-
project=project_id,
|
|
1126
|
-
location=region,
|
|
1127
|
-
credentials=self.get_credentials(),
|
|
1128
|
-
labels=labels,
|
|
1129
|
-
training_encryption_spec_key_name=training_encryption_spec_key_name,
|
|
1130
|
-
model_encryption_spec_key_name=model_encryption_spec_key_name,
|
|
1131
|
-
)
|
|
1132
|
-
|
|
1133
|
-
if not self._job:
|
|
1134
|
-
raise AirflowException("AutoMLTextTrainingJob was not created")
|
|
1135
|
-
|
|
1136
|
-
model = self._job.run(
|
|
1137
|
-
dataset=dataset, # type: ignore[arg-type]
|
|
1138
|
-
training_fraction_split=training_fraction_split, # type: ignore[call-arg]
|
|
1139
|
-
validation_fraction_split=validation_fraction_split, # type: ignore[call-arg]
|
|
1140
|
-
test_fraction_split=test_fraction_split,
|
|
1141
|
-
training_filter_split=training_filter_split,
|
|
1142
|
-
validation_filter_split=validation_filter_split,
|
|
1143
|
-
test_filter_split=test_filter_split, # type: ignore[call-arg]
|
|
1144
|
-
model_display_name=model_display_name,
|
|
1145
|
-
model_labels=model_labels,
|
|
1146
|
-
sync=sync,
|
|
1147
|
-
parent_model=parent_model,
|
|
1148
|
-
is_default_version=is_default_version,
|
|
1149
|
-
model_version_aliases=model_version_aliases,
|
|
1150
|
-
model_version_description=model_version_description,
|
|
1151
|
-
)
|
|
1152
|
-
training_id = self.extract_training_id(self._job.resource_name)
|
|
1153
|
-
if model:
|
|
1154
|
-
model.wait()
|
|
1155
|
-
else:
|
|
1156
|
-
self.log.warning(
|
|
1157
|
-
"Training did not produce a Managed Model returning None. AutoML Text Training "
|
|
1158
|
-
"Pipeline is not configured to upload a Model."
|
|
1159
|
-
)
|
|
1160
|
-
return model, training_id
|
|
1161
|
-
|
|
1162
953
|
@GoogleBaseHook.fallback_to_default_project_id
|
|
1163
954
|
def create_auto_ml_video_training_job(
|
|
1164
955
|
self,
|
|
@@ -110,7 +110,7 @@ class BatchPredictionJobHook(GoogleBaseHook, OperationHelper):
|
|
|
110
110
|
:param project_id: Required. Project to run training in.
|
|
111
111
|
:param region: Required. Location to run training in.
|
|
112
112
|
:param job_display_name: Required. The user-defined name of the BatchPredictionJob. The name can be
|
|
113
|
-
up to 128 characters long and can
|
|
113
|
+
up to 128 characters long and can consist of any UTF-8 characters.
|
|
114
114
|
:param model_name: Required. A fully-qualified model resource name or model ID.
|
|
115
115
|
:param instances_format: Required. The format in which instances are provided. Must be one of the
|
|
116
116
|
formats listed in `Model.supported_input_storage_formats`. Default is "jsonl" when using
|
|
@@ -267,7 +267,7 @@ class BatchPredictionJobHook(GoogleBaseHook, OperationHelper):
|
|
|
267
267
|
:param project_id: Required. Project to run training in.
|
|
268
268
|
:param region: Required. Location to run training in.
|
|
269
269
|
:param job_display_name: Required. The user-defined name of the BatchPredictionJob. The name can be
|
|
270
|
-
up to 128 characters long and can
|
|
270
|
+
up to 128 characters long and can consist of any UTF-8 characters.
|
|
271
271
|
:param model_name: Required. A fully-qualified model resource name or model ID.
|
|
272
272
|
:param instances_format: Required. The format in which instances are provided. Must be one of the
|
|
273
273
|
formats listed in `Model.supported_input_storage_formats`. Default is "jsonl" when using
|
|
@@ -55,7 +55,7 @@ if TYPE_CHECKING:
|
|
|
55
55
|
from google.cloud.aiplatform_v1.services.pipeline_service.pagers import (
|
|
56
56
|
ListTrainingPipelinesPager,
|
|
57
57
|
)
|
|
58
|
-
from google.cloud.aiplatform_v1.types import CustomJob, TrainingPipeline
|
|
58
|
+
from google.cloud.aiplatform_v1.types import CustomJob, PscInterfaceConfig, TrainingPipeline
|
|
59
59
|
|
|
60
60
|
|
|
61
61
|
class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
@@ -317,6 +317,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
317
317
|
is_default_version: bool | None = None,
|
|
318
318
|
model_version_aliases: list[str] | None = None,
|
|
319
319
|
model_version_description: str | None = None,
|
|
320
|
+
psc_interface_config: PscInterfaceConfig | None = None,
|
|
320
321
|
) -> tuple[models.Model | None, str, str]:
|
|
321
322
|
"""Run a training pipeline job and wait until its completion."""
|
|
322
323
|
model = job.run(
|
|
@@ -350,6 +351,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
350
351
|
is_default_version=is_default_version,
|
|
351
352
|
model_version_aliases=model_version_aliases,
|
|
352
353
|
model_version_description=model_version_description,
|
|
354
|
+
psc_interface_config=psc_interface_config,
|
|
353
355
|
)
|
|
354
356
|
training_id = self.extract_training_id(job.resource_name)
|
|
355
357
|
custom_job_id = self.extract_custom_job_id(
|
|
@@ -574,6 +576,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
574
576
|
timestamp_split_column_name: str | None = None,
|
|
575
577
|
tensorboard: str | None = None,
|
|
576
578
|
sync=True,
|
|
579
|
+
psc_interface_config: PscInterfaceConfig | None = None,
|
|
577
580
|
) -> tuple[models.Model | None, str, str]:
|
|
578
581
|
"""
|
|
579
582
|
Create Custom Container Training Job.
|
|
@@ -837,6 +840,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
837
840
|
:param sync: Whether to execute the AI Platform job synchronously. If False, this method
|
|
838
841
|
will be executed in concurrent Future and any downstream object will
|
|
839
842
|
be immediately returned and synced when the Future has completed.
|
|
843
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
844
|
+
training.
|
|
840
845
|
"""
|
|
841
846
|
self._job = self.get_custom_container_training_job(
|
|
842
847
|
project=project_id,
|
|
@@ -896,6 +901,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
896
901
|
is_default_version=is_default_version,
|
|
897
902
|
model_version_aliases=model_version_aliases,
|
|
898
903
|
model_version_description=model_version_description,
|
|
904
|
+
psc_interface_config=psc_interface_config,
|
|
899
905
|
)
|
|
900
906
|
|
|
901
907
|
return model, training_id, custom_job_id
|
|
@@ -958,6 +964,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
958
964
|
model_version_aliases: list[str] | None = None,
|
|
959
965
|
model_version_description: str | None = None,
|
|
960
966
|
sync=True,
|
|
967
|
+
psc_interface_config: PscInterfaceConfig | None = None,
|
|
961
968
|
) -> tuple[models.Model | None, str, str]:
|
|
962
969
|
"""
|
|
963
970
|
Create Custom Python Package Training Job.
|
|
@@ -1220,6 +1227,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
1220
1227
|
:param sync: Whether to execute the AI Platform job synchronously. If False, this method
|
|
1221
1228
|
will be executed in concurrent Future and any downstream object will
|
|
1222
1229
|
be immediately returned and synced when the Future has completed.
|
|
1230
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
1231
|
+
training.
|
|
1223
1232
|
"""
|
|
1224
1233
|
self._job = self.get_custom_python_package_training_job(
|
|
1225
1234
|
project=project_id,
|
|
@@ -1280,6 +1289,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
1280
1289
|
is_default_version=is_default_version,
|
|
1281
1290
|
model_version_aliases=model_version_aliases,
|
|
1282
1291
|
model_version_description=model_version_description,
|
|
1292
|
+
psc_interface_config=psc_interface_config,
|
|
1283
1293
|
)
|
|
1284
1294
|
|
|
1285
1295
|
return model, training_id, custom_job_id
|
|
@@ -1342,6 +1352,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
1342
1352
|
timestamp_split_column_name: str | None = None,
|
|
1343
1353
|
tensorboard: str | None = None,
|
|
1344
1354
|
sync=True,
|
|
1355
|
+
psc_interface_config: PscInterfaceConfig | None = None,
|
|
1345
1356
|
) -> tuple[models.Model | None, str, str]:
|
|
1346
1357
|
"""
|
|
1347
1358
|
Create Custom Training Job.
|
|
@@ -1604,6 +1615,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
1604
1615
|
:param sync: Whether to execute the AI Platform job synchronously. If False, this method
|
|
1605
1616
|
will be executed in concurrent Future and any downstream object will
|
|
1606
1617
|
be immediately returned and synced when the Future has completed.
|
|
1618
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
1619
|
+
training.
|
|
1607
1620
|
"""
|
|
1608
1621
|
self._job = self.get_custom_training_job(
|
|
1609
1622
|
project=project_id,
|
|
@@ -1664,6 +1677,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
1664
1677
|
is_default_version=is_default_version,
|
|
1665
1678
|
model_version_aliases=model_version_aliases,
|
|
1666
1679
|
model_version_description=model_version_description,
|
|
1680
|
+
psc_interface_config=psc_interface_config,
|
|
1667
1681
|
)
|
|
1668
1682
|
|
|
1669
1683
|
return model, training_id, custom_job_id
|
|
@@ -1725,6 +1739,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
1725
1739
|
predefined_split_column_name: str | None = None,
|
|
1726
1740
|
timestamp_split_column_name: str | None = None,
|
|
1727
1741
|
tensorboard: str | None = None,
|
|
1742
|
+
psc_interface_config: PscInterfaceConfig | None = None,
|
|
1728
1743
|
) -> CustomContainerTrainingJob:
|
|
1729
1744
|
"""
|
|
1730
1745
|
Create and submit a Custom Container Training Job pipeline, then exit without waiting for it to complete.
|
|
@@ -1985,6 +2000,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
1985
2000
|
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
|
|
1986
2001
|
For more information on configuring your service account please visit:
|
|
1987
2002
|
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
|
|
2003
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
2004
|
+
training.
|
|
1988
2005
|
"""
|
|
1989
2006
|
self._job = self.get_custom_container_training_job(
|
|
1990
2007
|
project=project_id,
|
|
@@ -2043,6 +2060,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
2043
2060
|
model_version_aliases=model_version_aliases,
|
|
2044
2061
|
model_version_description=model_version_description,
|
|
2045
2062
|
sync=False,
|
|
2063
|
+
psc_interface_config=psc_interface_config,
|
|
2046
2064
|
)
|
|
2047
2065
|
return self._job
|
|
2048
2066
|
|
|
@@ -2104,6 +2122,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
2104
2122
|
is_default_version: bool | None = None,
|
|
2105
2123
|
model_version_aliases: list[str] | None = None,
|
|
2106
2124
|
model_version_description: str | None = None,
|
|
2125
|
+
psc_interface_config: PscInterfaceConfig | None = None,
|
|
2107
2126
|
) -> CustomPythonPackageTrainingJob:
|
|
2108
2127
|
"""
|
|
2109
2128
|
Create and submit a Custom Python Package Training Job pipeline, then exit without waiting for it to complete.
|
|
@@ -2363,6 +2382,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
2363
2382
|
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
|
|
2364
2383
|
For more information on configuring your service account please visit:
|
|
2365
2384
|
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
|
|
2385
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
2386
|
+
training.
|
|
2366
2387
|
"""
|
|
2367
2388
|
self._job = self.get_custom_python_package_training_job(
|
|
2368
2389
|
project=project_id,
|
|
@@ -2422,6 +2443,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
2422
2443
|
model_version_aliases=model_version_aliases,
|
|
2423
2444
|
model_version_description=model_version_description,
|
|
2424
2445
|
sync=False,
|
|
2446
|
+
psc_interface_config=psc_interface_config,
|
|
2425
2447
|
)
|
|
2426
2448
|
|
|
2427
2449
|
return self._job
|
|
@@ -2484,6 +2506,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
2484
2506
|
predefined_split_column_name: str | None = None,
|
|
2485
2507
|
timestamp_split_column_name: str | None = None,
|
|
2486
2508
|
tensorboard: str | None = None,
|
|
2509
|
+
psc_interface_config: PscInterfaceConfig | None = None,
|
|
2487
2510
|
) -> CustomTrainingJob:
|
|
2488
2511
|
"""
|
|
2489
2512
|
Create and submit a Custom Training Job pipeline, then exit without waiting for it to complete.
|
|
@@ -2747,6 +2770,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
2747
2770
|
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
|
|
2748
2771
|
For more information on configuring your service account please visit:
|
|
2749
2772
|
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
|
|
2773
|
+
:param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
|
|
2774
|
+
training.
|
|
2750
2775
|
"""
|
|
2751
2776
|
self._job = self.get_custom_training_job(
|
|
2752
2777
|
project=project_id,
|
|
@@ -2806,6 +2831,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
|
|
|
2806
2831
|
model_version_aliases=model_version_aliases,
|
|
2807
2832
|
model_version_description=model_version_description,
|
|
2808
2833
|
sync=False,
|
|
2834
|
+
psc_interface_config=psc_interface_config,
|
|
2809
2835
|
)
|
|
2810
2836
|
return self._job
|
|
2811
2837
|
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
|
2
|
+
# or more contributor license agreements. See the NOTICE file
|
|
3
|
+
# distributed with this work for additional information
|
|
4
|
+
# regarding copyright ownership. The ASF licenses this file
|
|
5
|
+
# to you under the Apache License, Version 2.0 (the
|
|
6
|
+
# "License"); you may not use this file except in compliance
|
|
7
|
+
# with the License. You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing,
|
|
12
|
+
# software distributed under the License is distributed on an
|
|
13
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
14
|
+
# KIND, either express or implied. See the License for the
|
|
15
|
+
# specific language governing permissions and limitations
|
|
16
|
+
# under the License.
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from google.cloud import aiplatform
|
|
21
|
+
from google.cloud.aiplatform.compat.types import execution_v1 as gca_execution
|
|
22
|
+
|
|
23
|
+
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ExperimentHook(GoogleBaseHook):
    """Use the Vertex AI SDK for Python to manage your experiments."""

    @GoogleBaseHook.fallback_to_default_project_id
    def create_experiment(
        self,
        experiment_name: str,
        location: str,
        experiment_description: str = "",
        project_id: str = PROVIDE_PROJECT_ID,
        experiment_tensorboard: str | None = None,
    ) -> None:
        """
        Create an experiment and, optionally, associate a Vertex AI TensorBoard instance using the Vertex AI SDK for Python.

        The experiment is created via ``aiplatform.init``, authenticated with the credentials of this
        hook's Google Cloud connection.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param experiment_name: Required. The name of the evaluation experiment.
        :param experiment_description: Optional. Description of the evaluation experiment.
        :param experiment_tensorboard: Optional. The Vertex TensorBoard instance to use as a backing
            TensorBoard for the provided experiment. If no TensorBoard is provided (``None`` or an
            empty string), no backing TensorBoard is associated with the experiment.
        """
        aiplatform.init(
            experiment=experiment_name,
            experiment_description=experiment_description,
            # An explicit ``False`` opts out of a backing TensorBoard; per the ``aiplatform.init``
            # documentation, omitting it would make the SDK use (and create if needed) a default one.
            experiment_tensorboard=experiment_tensorboard or False,
            project=project_id,
            location=location,
            # Use the Airflow connection's credentials rather than ambient ADC.
            credentials=self.get_credentials(),
        )
        self.log.info("Created experiment with name: %s", experiment_name)

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_experiment(
        self,
        experiment_name: str,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
        delete_backing_tensorboard_runs: bool = False,
    ) -> None:
        """
        Delete an experiment.

        Deleting an experiment deletes that experiment and all experiment runs associated with the experiment.
        The Vertex AI TensorBoard experiment associated with the experiment is not deleted.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param experiment_name: Required. The name of the evaluation experiment.
        :param delete_backing_tensorboard_runs: Optional. If True will also delete the Vertex AI TensorBoard
            runs associated with the experiment runs under this experiment that we used to store time series
            metrics.
        """
        experiment = aiplatform.Experiment(
            experiment_name=experiment_name,
            project=project_id,
            location=location,
            credentials=self.get_credentials(),
        )

        experiment.delete(delete_backing_tensorboard_runs=delete_backing_tensorboard_runs)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ExperimentRunHook(GoogleBaseHook):
    """Use the Vertex AI SDK for Python to create and manage your experiment runs."""

    @GoogleBaseHook.fallback_to_default_project_id
    def create_experiment_run(
        self,
        experiment_run_name: str,
        experiment_name: str,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
        experiment_run_tensorboard: str | None = None,
        run_after_creation: bool = False,
    ) -> None:
        """
        Create an experiment run for the experiment.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param experiment_name: Required. The name of the evaluation experiment.
        :param experiment_run_name: Required. The specific run name or ID for this experiment.
        :param experiment_run_tensorboard: Optional. A backing TensorBoard resource to enable and store time
            series metrics logged to this experiment run.
        :param run_after_creation: Optional. Responsible for state after creation of experiment run.
            If true experiment run will be created with state RUNNING, otherwise with state NEW.
        """
        experiment_run_state = (
            gca_execution.Execution.State.RUNNING
            if run_after_creation
            else gca_execution.Execution.State.NEW
        )
        experiment_run = aiplatform.ExperimentRun.create(
            run_name=experiment_run_name,
            experiment=experiment_name,
            project=project_id,
            location=location,
            state=experiment_run_state,
            tensorboard=experiment_run_tensorboard,
            # Use the Airflow connection's credentials rather than ambient ADC.
            credentials=self.get_credentials(),
        )
        self.log.info(
            "Created experiment run with name: %s and status: %s",
            experiment_run.name,
            experiment_run.state,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def list_experiment_runs(
        self,
        experiment_name: str,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> list[aiplatform.ExperimentRun]:
        """
        List experiment runs for the experiment.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param experiment_name: Required. The name of the evaluation experiment.
        :return: The experiment runs returned by ``aiplatform.ExperimentRun.list``.
        """
        return aiplatform.ExperimentRun.list(
            experiment=experiment_name,
            project=project_id,
            location=location,
            credentials=self.get_credentials(),
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def update_experiment_run_state(
        self,
        experiment_run_name: str,
        experiment_name: str,
        location: str,
        new_state: gca_execution.Execution.State,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> None:
        """
        Update state of the experiment run.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param experiment_name: Required. The name of the evaluation experiment.
        :param experiment_run_name: Required. The specific run name or ID for this experiment.
        :param new_state: Required. New state of the experiment run.
        """
        experiment_run = aiplatform.ExperimentRun(
            run_name=experiment_run_name,
            experiment=experiment_name,
            project=project_id,
            location=location,
            credentials=self.get_credentials(),
        )
        self.log.info("State of the %s before update is: %s", experiment_run.name, experiment_run.state)

        experiment_run.update_state(new_state)

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_experiment_run(
        self,
        experiment_run_name: str,
        experiment_name: str,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
        delete_backing_tensorboard_run: bool = False,
    ) -> None:
        """
        Delete experiment run from the experiment.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param experiment_name: Required. The name of the evaluation experiment.
        :param experiment_run_name: Required. The specific run name or ID for this experiment.
        :param delete_backing_tensorboard_run: Whether to delete the backing Vertex AI TensorBoard run
            that stores time series metrics for this run.
        """
        self.log.info("Next experiment run will be deleted: %s", experiment_run_name)
        experiment_run = aiplatform.ExperimentRun(
            run_name=experiment_run_name,
            experiment=experiment_name,
            project=project_id,
            location=location,
            credentials=self.get_credentials(),
        )
        experiment_run.delete(delete_backing_tensorboard_run=delete_backing_tensorboard_run)
|