apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (234)
  1. airflow/providers/google/3rd-party-licenses/NOTICE +2 -12
  2. airflow/providers/google/__init__.py +3 -3
  3. airflow/providers/google/ads/hooks/ads.py +39 -5
  4. airflow/providers/google/ads/operators/ads.py +2 -2
  5. airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -2
  6. airflow/providers/google/assets/gcs.py +1 -11
  7. airflow/providers/google/cloud/bundles/__init__.py +16 -0
  8. airflow/providers/google/cloud/bundles/gcs.py +161 -0
  9. airflow/providers/google/cloud/hooks/bigquery.py +166 -281
  10. airflow/providers/google/cloud/hooks/cloud_composer.py +287 -14
  11. airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
  12. airflow/providers/google/cloud/hooks/cloud_run.py +17 -9
  13. airflow/providers/google/cloud/hooks/cloud_sql.py +101 -22
  14. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +27 -6
  15. airflow/providers/google/cloud/hooks/compute_ssh.py +5 -1
  16. airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
  17. airflow/providers/google/cloud/hooks/dataflow.py +71 -94
  18. airflow/providers/google/cloud/hooks/datafusion.py +1 -1
  19. airflow/providers/google/cloud/hooks/dataplex.py +1 -1
  20. airflow/providers/google/cloud/hooks/dataprep.py +1 -1
  21. airflow/providers/google/cloud/hooks/dataproc.py +72 -71
  22. airflow/providers/google/cloud/hooks/gcs.py +111 -14
  23. airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
  24. airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
  25. airflow/providers/google/cloud/hooks/looker.py +6 -1
  26. airflow/providers/google/cloud/hooks/mlengine.py +3 -2
  27. airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
  28. airflow/providers/google/cloud/hooks/spanner.py +73 -8
  29. airflow/providers/google/cloud/hooks/stackdriver.py +10 -8
  30. airflow/providers/google/cloud/hooks/translate.py +1 -1
  31. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -209
  32. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +2 -2
  33. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +27 -1
  34. airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
  35. airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
  36. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
  37. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
  38. airflow/providers/google/cloud/hooks/vision.py +2 -2
  39. airflow/providers/google/cloud/hooks/workflows.py +1 -1
  40. airflow/providers/google/cloud/links/alloy_db.py +0 -46
  41. airflow/providers/google/cloud/links/base.py +77 -13
  42. airflow/providers/google/cloud/links/bigquery.py +0 -47
  43. airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
  44. airflow/providers/google/cloud/links/bigtable.py +0 -48
  45. airflow/providers/google/cloud/links/cloud_build.py +0 -73
  46. airflow/providers/google/cloud/links/cloud_functions.py +0 -33
  47. airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
  48. airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
  49. airflow/providers/google/cloud/links/cloud_sql.py +0 -33
  50. airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -44
  51. airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
  52. airflow/providers/google/cloud/links/compute.py +0 -58
  53. airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
  54. airflow/providers/google/cloud/links/datacatalog.py +23 -54
  55. airflow/providers/google/cloud/links/dataflow.py +0 -34
  56. airflow/providers/google/cloud/links/dataform.py +0 -64
  57. airflow/providers/google/cloud/links/datafusion.py +1 -96
  58. airflow/providers/google/cloud/links/dataplex.py +0 -154
  59. airflow/providers/google/cloud/links/dataprep.py +0 -24
  60. airflow/providers/google/cloud/links/dataproc.py +11 -95
  61. airflow/providers/google/cloud/links/datastore.py +0 -31
  62. airflow/providers/google/cloud/links/kubernetes_engine.py +9 -60
  63. airflow/providers/google/cloud/links/managed_kafka.py +0 -70
  64. airflow/providers/google/cloud/links/mlengine.py +0 -70
  65. airflow/providers/google/cloud/links/pubsub.py +0 -32
  66. airflow/providers/google/cloud/links/spanner.py +0 -33
  67. airflow/providers/google/cloud/links/stackdriver.py +0 -30
  68. airflow/providers/google/cloud/links/translate.py +17 -187
  69. airflow/providers/google/cloud/links/vertex_ai.py +28 -195
  70. airflow/providers/google/cloud/links/workflows.py +0 -52
  71. airflow/providers/google/cloud/log/gcs_task_handler.py +17 -9
  72. airflow/providers/google/cloud/log/stackdriver_task_handler.py +9 -6
  73. airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
  74. airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
  75. airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
  76. airflow/providers/google/cloud/openlineage/facets.py +102 -1
  77. airflow/providers/google/cloud/openlineage/mixins.py +10 -8
  78. airflow/providers/google/cloud/openlineage/utils.py +15 -1
  79. airflow/providers/google/cloud/operators/alloy_db.py +70 -55
  80. airflow/providers/google/cloud/operators/bigquery.py +73 -636
  81. airflow/providers/google/cloud/operators/bigquery_dts.py +3 -5
  82. airflow/providers/google/cloud/operators/bigtable.py +36 -7
  83. airflow/providers/google/cloud/operators/cloud_base.py +21 -1
  84. airflow/providers/google/cloud/operators/cloud_batch.py +2 -2
  85. airflow/providers/google/cloud/operators/cloud_build.py +75 -32
  86. airflow/providers/google/cloud/operators/cloud_composer.py +128 -40
  87. airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
  88. airflow/providers/google/cloud/operators/cloud_memorystore.py +69 -43
  89. airflow/providers/google/cloud/operators/cloud_run.py +23 -5
  90. airflow/providers/google/cloud/operators/cloud_sql.py +8 -16
  91. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -11
  92. airflow/providers/google/cloud/operators/compute.py +8 -40
  93. airflow/providers/google/cloud/operators/datacatalog.py +157 -21
  94. airflow/providers/google/cloud/operators/dataflow.py +38 -15
  95. airflow/providers/google/cloud/operators/dataform.py +15 -5
  96. airflow/providers/google/cloud/operators/datafusion.py +41 -20
  97. airflow/providers/google/cloud/operators/dataplex.py +193 -109
  98. airflow/providers/google/cloud/operators/dataprep.py +1 -5
  99. airflow/providers/google/cloud/operators/dataproc.py +78 -35
  100. airflow/providers/google/cloud/operators/dataproc_metastore.py +96 -88
  101. airflow/providers/google/cloud/operators/datastore.py +22 -6
  102. airflow/providers/google/cloud/operators/dlp.py +6 -29
  103. airflow/providers/google/cloud/operators/functions.py +16 -7
  104. airflow/providers/google/cloud/operators/gcs.py +10 -8
  105. airflow/providers/google/cloud/operators/gen_ai.py +389 -0
  106. airflow/providers/google/cloud/operators/kubernetes_engine.py +60 -99
  107. airflow/providers/google/cloud/operators/looker.py +1 -1
  108. airflow/providers/google/cloud/operators/managed_kafka.py +107 -52
  109. airflow/providers/google/cloud/operators/natural_language.py +1 -1
  110. airflow/providers/google/cloud/operators/pubsub.py +60 -14
  111. airflow/providers/google/cloud/operators/spanner.py +25 -12
  112. airflow/providers/google/cloud/operators/speech_to_text.py +1 -2
  113. airflow/providers/google/cloud/operators/stackdriver.py +1 -9
  114. airflow/providers/google/cloud/operators/tasks.py +1 -12
  115. airflow/providers/google/cloud/operators/text_to_speech.py +1 -2
  116. airflow/providers/google/cloud/operators/translate.py +40 -16
  117. airflow/providers/google/cloud/operators/translate_speech.py +1 -2
  118. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +39 -19
  119. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +29 -9
  120. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +54 -26
  121. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +70 -8
  122. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +43 -9
  123. airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
  124. airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
  125. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
  126. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +11 -9
  127. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +57 -11
  128. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +30 -7
  129. airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
  130. airflow/providers/google/cloud/operators/video_intelligence.py +1 -1
  131. airflow/providers/google/cloud/operators/vision.py +2 -2
  132. airflow/providers/google/cloud/operators/workflows.py +18 -15
  133. airflow/providers/google/cloud/sensors/bigquery.py +2 -2
  134. airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -2
  135. airflow/providers/google/cloud/sensors/bigtable.py +11 -4
  136. airflow/providers/google/cloud/sensors/cloud_composer.py +533 -29
  137. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -2
  138. airflow/providers/google/cloud/sensors/dataflow.py +26 -9
  139. airflow/providers/google/cloud/sensors/dataform.py +2 -2
  140. airflow/providers/google/cloud/sensors/datafusion.py +4 -4
  141. airflow/providers/google/cloud/sensors/dataplex.py +2 -2
  142. airflow/providers/google/cloud/sensors/dataprep.py +2 -2
  143. airflow/providers/google/cloud/sensors/dataproc.py +2 -2
  144. airflow/providers/google/cloud/sensors/dataproc_metastore.py +2 -2
  145. airflow/providers/google/cloud/sensors/gcs.py +4 -4
  146. airflow/providers/google/cloud/sensors/looker.py +2 -2
  147. airflow/providers/google/cloud/sensors/pubsub.py +4 -4
  148. airflow/providers/google/cloud/sensors/tasks.py +2 -2
  149. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
  150. airflow/providers/google/cloud/sensors/workflows.py +2 -2
  151. airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
  152. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
  153. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
  154. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
  155. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +4 -4
  156. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
  157. airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
  158. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
  159. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
  160. airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
  161. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +2 -2
  162. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +3 -3
  163. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +20 -12
  164. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
  165. airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
  166. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
  167. airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
  168. airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
  169. airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
  170. airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
  171. airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
  172. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
  173. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
  174. airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
  175. airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
  176. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +13 -4
  177. airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
  178. airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
  179. airflow/providers/google/cloud/triggers/bigquery.py +75 -34
  180. airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
  181. airflow/providers/google/cloud/triggers/cloud_composer.py +302 -46
  182. airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
  183. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +91 -1
  184. airflow/providers/google/cloud/triggers/dataflow.py +122 -0
  185. airflow/providers/google/cloud/triggers/datafusion.py +1 -1
  186. airflow/providers/google/cloud/triggers/dataplex.py +14 -2
  187. airflow/providers/google/cloud/triggers/dataproc.py +122 -52
  188. airflow/providers/google/cloud/triggers/kubernetes_engine.py +45 -27
  189. airflow/providers/google/cloud/triggers/mlengine.py +1 -1
  190. airflow/providers/google/cloud/triggers/pubsub.py +15 -19
  191. airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
  192. airflow/providers/google/cloud/utils/credentials_provider.py +1 -1
  193. airflow/providers/google/cloud/utils/field_validator.py +1 -2
  194. airflow/providers/google/common/auth_backend/google_openid.py +4 -4
  195. airflow/providers/google/common/deprecated.py +2 -1
  196. airflow/providers/google/common/hooks/base_google.py +27 -8
  197. airflow/providers/google/common/links/storage.py +0 -22
  198. airflow/providers/google/common/utils/get_secret.py +31 -0
  199. airflow/providers/google/common/utils/id_token_credentials.py +3 -4
  200. airflow/providers/google/firebase/operators/firestore.py +2 -2
  201. airflow/providers/google/get_provider_info.py +56 -52
  202. airflow/providers/google/go_module_utils.py +35 -3
  203. airflow/providers/google/leveldb/hooks/leveldb.py +26 -1
  204. airflow/providers/google/leveldb/operators/leveldb.py +2 -2
  205. airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
  206. airflow/providers/google/marketing_platform/links/analytics_admin.py +5 -14
  207. airflow/providers/google/marketing_platform/operators/analytics_admin.py +1 -2
  208. airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
  209. airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
  210. airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
  211. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
  212. airflow/providers/google/marketing_platform/sensors/display_video.py +3 -63
  213. airflow/providers/google/suite/hooks/calendar.py +1 -1
  214. airflow/providers/google/suite/hooks/sheets.py +15 -1
  215. airflow/providers/google/suite/operators/sheets.py +8 -3
  216. airflow/providers/google/suite/sensors/drive.py +2 -2
  217. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
  218. airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
  219. airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
  220. airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
  221. airflow/providers/google/version_compat.py +15 -1
  222. {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +92 -48
  223. apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
  224. apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
  225. airflow/providers/google/cloud/hooks/automl.py +0 -673
  226. airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
  227. airflow/providers/google/cloud/links/automl.py +0 -193
  228. airflow/providers/google/cloud/operators/automl.py +0 -1362
  229. airflow/providers/google/cloud/operators/life_sciences.py +0 -119
  230. airflow/providers/google/cloud/operators/mlengine.py +0 -112
  231. apache_airflow_providers_google-15.1.0rc1.dist-info/RECORD +0 -321
  232. {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +0 -0
  233. {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
  234. {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
@@ -38,7 +38,6 @@ from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient
38
38
 
39
39
  from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
40
40
  from airflow.providers.google.common.consts import CLIENT_INFO
41
- from airflow.providers.google.common.deprecated import deprecated
42
41
  from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
43
42
  from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
44
43
 
@@ -185,42 +184,6 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
185
184
  model_encryption_spec_key_name=model_encryption_spec_key_name,
186
185
  )
187
186
 
188
- @deprecated(
189
- planned_removal_date="June 15, 2025",
190
- category=AirflowProviderDeprecationWarning,
191
- reason="Deprecation of AutoMLText API",
192
- )
193
- def get_auto_ml_text_training_job(
194
- self,
195
- display_name: str,
196
- prediction_type: str,
197
- multi_label: bool = False,
198
- sentiment_max: int = 10,
199
- project: str | None = None,
200
- location: str | None = None,
201
- labels: dict[str, str] | None = None,
202
- training_encryption_spec_key_name: str | None = None,
203
- model_encryption_spec_key_name: str | None = None,
204
- ) -> AutoMLTextTrainingJob:
205
- """
206
- Return AutoMLTextTrainingJob object.
207
-
208
- WARNING: Text creation API is deprecated since September 15, 2024
209
- (https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
210
- """
211
- return AutoMLTextTrainingJob(
212
- display_name=display_name,
213
- prediction_type=prediction_type,
214
- multi_label=multi_label,
215
- sentiment_max=sentiment_max,
216
- project=project,
217
- location=location,
218
- credentials=self.get_credentials(),
219
- labels=labels,
220
- training_encryption_spec_key_name=training_encryption_spec_key_name,
221
- model_encryption_spec_key_name=model_encryption_spec_key_name,
222
- )
223
-
224
187
  def get_auto_ml_video_training_job(
225
188
  self,
226
189
  display_name: str,
@@ -987,178 +950,6 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
987
950
  )
988
951
  return model, training_id
989
952
 
990
- @GoogleBaseHook.fallback_to_default_project_id
991
- @deprecated(
992
- planned_removal_date="September 15, 2025",
993
- category=AirflowProviderDeprecationWarning,
994
- reason="Deprecation of AutoMLText API",
995
- )
996
- def create_auto_ml_text_training_job(
997
- self,
998
- project_id: str,
999
- region: str,
1000
- display_name: str,
1001
- dataset: datasets.TextDataset,
1002
- prediction_type: str,
1003
- multi_label: bool = False,
1004
- sentiment_max: int = 10,
1005
- labels: dict[str, str] | None = None,
1006
- training_encryption_spec_key_name: str | None = None,
1007
- model_encryption_spec_key_name: str | None = None,
1008
- training_fraction_split: float | None = None,
1009
- validation_fraction_split: float | None = None,
1010
- test_fraction_split: float | None = None,
1011
- training_filter_split: str | None = None,
1012
- validation_filter_split: str | None = None,
1013
- test_filter_split: str | None = None,
1014
- model_display_name: str | None = None,
1015
- model_labels: dict[str, str] | None = None,
1016
- sync: bool = True,
1017
- parent_model: str | None = None,
1018
- is_default_version: bool | None = None,
1019
- model_version_aliases: list[str] | None = None,
1020
- model_version_description: str | None = None,
1021
- ) -> tuple[models.Model | None, str]:
1022
- """
1023
- Create an AutoML Text Training Job.
1024
-
1025
- WARNING: Text creation API is deprecated since September 15, 2024
1026
- (https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
1027
-
1028
- :param project_id: Required. Project to run training in.
1029
- :param region: Required. Location to run training in.
1030
- :param display_name: Required. The user-defined name of this TrainingPipeline.
1031
- :param dataset: Required. The dataset within the same Project from which data will be used to train
1032
- the Model. The Dataset must use schema compatible with Model being trained, and what is
1033
- compatible should be described in the used TrainingPipeline's [training_task_definition]
1034
- [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
1035
- :param prediction_type: The type of prediction the Model is to produce, one of:
1036
- "classification" - A classification model analyzes text data and returns a list of categories
1037
- that apply to the text found in the data. Vertex AI offers both single-label and multi-label text
1038
- classification models.
1039
- "extraction" - An entity extraction model inspects text data for known entities referenced in the
1040
- data and labels those entities in the text.
1041
- "sentiment" - A sentiment analysis model inspects text data and identifies the prevailing
1042
- emotional opinion within it, especially to determine a writer's attitude as positive, negative,
1043
- or neutral.
1044
- :param parent_model: Optional. The resource name or model ID of an existing model.
1045
- The new model uploaded by this job will be a version of `parent_model`.
1046
- Only set this field when training a new version of an existing model.
1047
- :param is_default_version: Optional. When set to True, the newly uploaded model version will
1048
- automatically have alias "default" included. Subsequent uses of
1049
- the model produced by this job without a version specified will
1050
- use this "default" version.
1051
- When set to False, the "default" alias will not be moved.
1052
- Actions targeting the model version produced by this job will need
1053
- to specifically reference this version by ID or alias.
1054
- New model uploads, i.e. version 1, will always be "default" aliased.
1055
- :param model_version_aliases: Optional. User provided version aliases so that the model version
1056
- uploaded by this job can be referenced via alias instead of
1057
- auto-generated version ID. A default version alias will be created
1058
- for the first version of the model.
1059
- The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9]
1060
- :param model_version_description: Optional. The description of the model version
1061
- being uploaded by this job.
1062
- :param multi_label: Required and only applicable for text classification task. If false, a
1063
- single-label (multi-class) Model will be trained (i.e. assuming that for each text snippet just
1064
- up to one annotation may be applicable). If true, a multi-label Model will be trained (i.e.
1065
- assuming that for each text snippet multiple annotations may be applicable).
1066
- :param sentiment_max: Required and only applicable for sentiment task. A sentiment is expressed as an
1067
- integer ordinal, where higher value means a more positive sentiment. The range of sentiments that
1068
- will be used is between 0 and sentimentMax (inclusive on both ends), and all the values in the
1069
- range must be represented in the dataset before a model can be created. Only the Annotations with
1070
- this sentimentMax will be used for training. sentimentMax value must be between 1 and 10
1071
- (inclusive).
1072
- :param labels: Optional. The labels with user-defined metadata to organize TrainingPipelines. Label
1073
- keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
1074
- lowercase letters, numeric characters, underscores and dashes. International characters are
1075
- allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
1076
- :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
1077
- managed encryption key used to protect the training pipeline. Has the form:
1078
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
1079
- The key needs to be in the same region as where the compute resource is created.
1080
- If set, this TrainingPipeline will be secured by this key.
1081
- Note: Model trained by this TrainingPipeline is also secured by this key if ``model_to_upload``
1082
- is not set separately.
1083
- :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer
1084
- managed encryption key used to protect the model. Has the form:
1085
- ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
1086
- The key needs to be in the same region as where the compute resource is created.
1087
- If set, the trained Model will be secured by this key.
1088
- :param training_fraction_split: Optional. The fraction of the input data that is to be used to train
1089
- the Model. This is ignored if Dataset is not provided.
1090
- :param validation_fraction_split: Optional. The fraction of the input data that is to be used to
1091
- validate the Model. This is ignored if Dataset is not provided.
1092
- :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate
1093
- the Model. This is ignored if Dataset is not provided.
1094
- :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
1095
- this filter are used to train the Model. A filter with same syntax as the one used in
1096
- DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
1097
- FilterSplit filters, then it is assigned to the first set that applies to it in the training,
1098
- validation, test order. This is ignored if Dataset is not provided.
1099
- :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match
1100
- this filter are used to validate the Model. A filter with same syntax as the one used in
1101
- DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
1102
- FilterSplit filters, then it is assigned to the first set that applies to it in the training,
1103
- validation, test order. This is ignored if Dataset is not provided.
1104
- :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match this
1105
- filter are used to test the Model. A filter with same syntax as the one used in
1106
- DatasetService.ListDataItems may be used. If a single DataItem is matched by more than one of the
1107
- FilterSplit filters, then it is assigned to the first set that applies to it in the training,
1108
- validation, test order. This is ignored if Dataset is not provided.
1109
- :param model_display_name: Optional. The display name of the managed Vertex AI Model. The name can be
1110
- up to 128 characters long and can consist of any UTF-8 characters.
1111
- If not provided upon creation, the job's display_name is used.
1112
- :param model_labels: Optional. The labels with user-defined metadata to organize your Models. Label
1113
- keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
1114
- lowercase letters, numeric characters, underscores and dashes. International characters are
1115
- allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
1116
- :param sync: Whether to execute this method synchronously. If False, this method will be executed in
1117
- concurrent Future and any downstream object will be immediately returned and synced when the
1118
- Future has completed.
1119
- """
1120
- self._job = AutoMLTextTrainingJob(
1121
- display_name=display_name,
1122
- prediction_type=prediction_type,
1123
- multi_label=multi_label,
1124
- sentiment_max=sentiment_max,
1125
- project=project_id,
1126
- location=region,
1127
- credentials=self.get_credentials(),
1128
- labels=labels,
1129
- training_encryption_spec_key_name=training_encryption_spec_key_name,
1130
- model_encryption_spec_key_name=model_encryption_spec_key_name,
1131
- )
1132
-
1133
- if not self._job:
1134
- raise AirflowException("AutoMLTextTrainingJob was not created")
1135
-
1136
- model = self._job.run(
1137
- dataset=dataset, # type: ignore[arg-type]
1138
- training_fraction_split=training_fraction_split, # type: ignore[call-arg]
1139
- validation_fraction_split=validation_fraction_split, # type: ignore[call-arg]
1140
- test_fraction_split=test_fraction_split,
1141
- training_filter_split=training_filter_split,
1142
- validation_filter_split=validation_filter_split,
1143
- test_filter_split=test_filter_split, # type: ignore[call-arg]
1144
- model_display_name=model_display_name,
1145
- model_labels=model_labels,
1146
- sync=sync,
1147
- parent_model=parent_model,
1148
- is_default_version=is_default_version,
1149
- model_version_aliases=model_version_aliases,
1150
- model_version_description=model_version_description,
1151
- )
1152
- training_id = self.extract_training_id(self._job.resource_name)
1153
- if model:
1154
- model.wait()
1155
- else:
1156
- self.log.warning(
1157
- "Training did not produce a Managed Model returning None. AutoML Text Training "
1158
- "Pipeline is not configured to upload a Model."
1159
- )
1160
- return model, training_id
1161
-
1162
953
  @GoogleBaseHook.fallback_to_default_project_id
1163
954
  def create_auto_ml_video_training_job(
1164
955
  self,
@@ -110,7 +110,7 @@ class BatchPredictionJobHook(GoogleBaseHook, OperationHelper):
110
110
  :param project_id: Required. Project to run training in.
111
111
  :param region: Required. Location to run training in.
112
112
  :param job_display_name: Required. The user-defined name of the BatchPredictionJob. The name can be
113
- up to 128 characters long and can be consist of any UTF-8 characters.
113
+ up to 128 characters long and can consist of any UTF-8 characters.
114
114
  :param model_name: Required. A fully-qualified model resource name or model ID.
115
115
  :param instances_format: Required. The format in which instances are provided. Must be one of the
116
116
  formats listed in `Model.supported_input_storage_formats`. Default is "jsonl" when using
@@ -267,7 +267,7 @@ class BatchPredictionJobHook(GoogleBaseHook, OperationHelper):
267
267
  :param project_id: Required. Project to run training in.
268
268
  :param region: Required. Location to run training in.
269
269
  :param job_display_name: Required. The user-defined name of the BatchPredictionJob. The name can be
270
- up to 128 characters long and can be consist of any UTF-8 characters.
270
+ up to 128 characters long and can consist of any UTF-8 characters.
271
271
  :param model_name: Required. A fully-qualified model resource name or model ID.
272
272
  :param instances_format: Required. The format in which instances are provided. Must be one of the
273
273
  formats listed in `Model.supported_input_storage_formats`. Default is "jsonl" when using
@@ -55,7 +55,7 @@ if TYPE_CHECKING:
55
55
  from google.cloud.aiplatform_v1.services.pipeline_service.pagers import (
56
56
  ListTrainingPipelinesPager,
57
57
  )
58
- from google.cloud.aiplatform_v1.types import CustomJob, TrainingPipeline
58
+ from google.cloud.aiplatform_v1.types import CustomJob, PscInterfaceConfig, TrainingPipeline
59
59
 
60
60
 
61
61
  class CustomJobHook(GoogleBaseHook, OperationHelper):
@@ -317,6 +317,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
317
317
  is_default_version: bool | None = None,
318
318
  model_version_aliases: list[str] | None = None,
319
319
  model_version_description: str | None = None,
320
+ psc_interface_config: PscInterfaceConfig | None = None,
320
321
  ) -> tuple[models.Model | None, str, str]:
321
322
  """Run a training pipeline job and wait until its completion."""
322
323
  model = job.run(
@@ -350,6 +351,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
350
351
  is_default_version=is_default_version,
351
352
  model_version_aliases=model_version_aliases,
352
353
  model_version_description=model_version_description,
354
+ psc_interface_config=psc_interface_config,
353
355
  )
354
356
  training_id = self.extract_training_id(job.resource_name)
355
357
  custom_job_id = self.extract_custom_job_id(
@@ -574,6 +576,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
574
576
  timestamp_split_column_name: str | None = None,
575
577
  tensorboard: str | None = None,
576
578
  sync=True,
579
+ psc_interface_config: PscInterfaceConfig | None = None,
577
580
  ) -> tuple[models.Model | None, str, str]:
578
581
  """
579
582
  Create Custom Container Training Job.
@@ -837,6 +840,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
837
840
  :param sync: Whether to execute the AI Platform job synchronously. If False, this method
838
841
  will be executed in concurrent Future and any downstream object will
839
842
  be immediately returned and synced when the Future has completed.
843
+ :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
844
+ training.
840
845
  """
841
846
  self._job = self.get_custom_container_training_job(
842
847
  project=project_id,
@@ -896,6 +901,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
896
901
  is_default_version=is_default_version,
897
902
  model_version_aliases=model_version_aliases,
898
903
  model_version_description=model_version_description,
904
+ psc_interface_config=psc_interface_config,
899
905
  )
900
906
 
901
907
  return model, training_id, custom_job_id
@@ -958,6 +964,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
958
964
  model_version_aliases: list[str] | None = None,
959
965
  model_version_description: str | None = None,
960
966
  sync=True,
967
+ psc_interface_config: PscInterfaceConfig | None = None,
961
968
  ) -> tuple[models.Model | None, str, str]:
962
969
  """
963
970
  Create Custom Python Package Training Job.
@@ -1220,6 +1227,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
1220
1227
  :param sync: Whether to execute the AI Platform job synchronously. If False, this method
1221
1228
  will be executed in concurrent Future and any downstream object will
1222
1229
  be immediately returned and synced when the Future has completed.
1230
+ :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
1231
+ training.
1223
1232
  """
1224
1233
  self._job = self.get_custom_python_package_training_job(
1225
1234
  project=project_id,
@@ -1280,6 +1289,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
1280
1289
  is_default_version=is_default_version,
1281
1290
  model_version_aliases=model_version_aliases,
1282
1291
  model_version_description=model_version_description,
1292
+ psc_interface_config=psc_interface_config,
1283
1293
  )
1284
1294
 
1285
1295
  return model, training_id, custom_job_id
@@ -1342,6 +1352,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
1342
1352
  timestamp_split_column_name: str | None = None,
1343
1353
  tensorboard: str | None = None,
1344
1354
  sync=True,
1355
+ psc_interface_config: PscInterfaceConfig | None = None,
1345
1356
  ) -> tuple[models.Model | None, str, str]:
1346
1357
  """
1347
1358
  Create Custom Training Job.
@@ -1604,6 +1615,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
1604
1615
  :param sync: Whether to execute the AI Platform job synchronously. If False, this method
1605
1616
  will be executed in concurrent Future and any downstream object will
1606
1617
  be immediately returned and synced when the Future has completed.
1618
+ :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
1619
+ training.
1607
1620
  """
1608
1621
  self._job = self.get_custom_training_job(
1609
1622
  project=project_id,
@@ -1664,6 +1677,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
1664
1677
  is_default_version=is_default_version,
1665
1678
  model_version_aliases=model_version_aliases,
1666
1679
  model_version_description=model_version_description,
1680
+ psc_interface_config=psc_interface_config,
1667
1681
  )
1668
1682
 
1669
1683
  return model, training_id, custom_job_id
@@ -1725,6 +1739,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
1725
1739
  predefined_split_column_name: str | None = None,
1726
1740
  timestamp_split_column_name: str | None = None,
1727
1741
  tensorboard: str | None = None,
1742
+ psc_interface_config: PscInterfaceConfig | None = None,
1728
1743
  ) -> CustomContainerTrainingJob:
1729
1744
  """
1730
1745
  Create and submit a Custom Container Training Job pipeline, then exit without waiting for it to complete.
@@ -1985,6 +2000,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
1985
2000
  ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
1986
2001
  For more information on configuring your service account please visit:
1987
2002
  https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
2003
+ :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
2004
+ training.
1988
2005
  """
1989
2006
  self._job = self.get_custom_container_training_job(
1990
2007
  project=project_id,
@@ -2043,6 +2060,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
2043
2060
  model_version_aliases=model_version_aliases,
2044
2061
  model_version_description=model_version_description,
2045
2062
  sync=False,
2063
+ psc_interface_config=psc_interface_config,
2046
2064
  )
2047
2065
  return self._job
2048
2066
 
@@ -2104,6 +2122,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
2104
2122
  is_default_version: bool | None = None,
2105
2123
  model_version_aliases: list[str] | None = None,
2106
2124
  model_version_description: str | None = None,
2125
+ psc_interface_config: PscInterfaceConfig | None = None,
2107
2126
  ) -> CustomPythonPackageTrainingJob:
2108
2127
  """
2109
2128
  Create and submit a Custom Python Package Training Job pipeline, then exit without waiting for it to complete.
@@ -2363,6 +2382,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
2363
2382
  ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
2364
2383
  For more information on configuring your service account please visit:
2365
2384
  https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
2385
+ :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
2386
+ training.
2366
2387
  """
2367
2388
  self._job = self.get_custom_python_package_training_job(
2368
2389
  project=project_id,
@@ -2422,6 +2443,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
2422
2443
  model_version_aliases=model_version_aliases,
2423
2444
  model_version_description=model_version_description,
2424
2445
  sync=False,
2446
+ psc_interface_config=psc_interface_config,
2425
2447
  )
2426
2448
 
2427
2449
  return self._job
@@ -2484,6 +2506,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
2484
2506
  predefined_split_column_name: str | None = None,
2485
2507
  timestamp_split_column_name: str | None = None,
2486
2508
  tensorboard: str | None = None,
2509
+ psc_interface_config: PscInterfaceConfig | None = None,
2487
2510
  ) -> CustomTrainingJob:
2488
2511
  """
2489
2512
  Create and submit a Custom Training Job pipeline, then exit without waiting for it to complete.
@@ -2747,6 +2770,8 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
2747
2770
  ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
2748
2771
  For more information on configuring your service account please visit:
2749
2772
  https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
2773
+ :param psc_interface_config: Optional. Configuration for Private Service Connect interface used for
2774
+ training.
2750
2775
  """
2751
2776
  self._job = self.get_custom_training_job(
2752
2777
  project=project_id,
@@ -2806,6 +2831,7 @@ class CustomJobHook(GoogleBaseHook, OperationHelper):
2806
2831
  model_version_aliases=model_version_aliases,
2807
2832
  model_version_description=model_version_description,
2808
2833
  sync=False,
2834
+ psc_interface_config=psc_interface_config,
2809
2835
  )
2810
2836
  return self._job
2811
2837
 
@@ -0,0 +1,202 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+
18
+ from __future__ import annotations
19
+
20
+ from google.cloud import aiplatform
21
+ from google.cloud.aiplatform.compat.types import execution_v1 as gca_execution
22
+
23
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
24
+
25
+
26
+ class ExperimentHook(GoogleBaseHook):
27
+ """Use the Vertex AI SDK for Python to manage your experiments."""
28
+
29
+ @GoogleBaseHook.fallback_to_default_project_id
30
+ def create_experiment(
31
+ self,
32
+ experiment_name: str,
33
+ location: str,
34
+ experiment_description: str = "",
35
+ project_id: str = PROVIDE_PROJECT_ID,
36
+ experiment_tensorboard: str | None = None,
37
+ ):
38
+ """
39
+ Create an experiment and, optionally, associate a Vertex AI TensorBoard instance using the Vertex AI SDK for Python.
40
+
41
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
42
+ :param location: Required. The ID of the Google Cloud location that the service belongs to.
43
+ :param experiment_name: Required. The name of the evaluation experiment.
44
+ :param experiment_description: Optional. Description of the evaluation experiment.
45
+ :param experiment_tensorboard: Optional. The Vertex TensorBoard instance to use as a backing
46
+ TensorBoard for the provided experiment. If no TensorBoard is provided, a default Tensorboard
47
+ instance is created and used by this experiment.
48
+ """
49
+ aiplatform.init(
50
+ experiment=experiment_name,
51
+ experiment_description=experiment_description,
52
+ experiment_tensorboard=experiment_tensorboard if experiment_tensorboard else False,
53
+ project=project_id,
54
+ location=location,
55
+ )
56
+ self.log.info("Created experiment with name: %s", experiment_name)
57
+
58
+ @GoogleBaseHook.fallback_to_default_project_id
59
+ def delete_experiment(
60
+ self,
61
+ experiment_name: str,
62
+ location: str,
63
+ project_id: str = PROVIDE_PROJECT_ID,
64
+ delete_backing_tensorboard_runs: bool = False,
65
+ ) -> None:
66
+ """
67
+ Delete an experiment.
68
+
69
+ Deleting an experiment deletes that experiment and all experiment runs associated with the experiment.
70
+ The Vertex AI TensorBoard experiment associated with the experiment is not deleted.
71
+
72
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
73
+ :param location: Required. The ID of the Google Cloud location that the service belongs to.
74
+ :param experiment_name: Required. The name of the evaluation experiment.
75
+ :param delete_backing_tensorboard_runs: Optional. If True will also delete the Vertex AI TensorBoard
76
+ runs associated with the experiment runs under this experiment that we used to store time series
77
+ metrics.
78
+ """
79
+ experiment = aiplatform.Experiment(
80
+ experiment_name=experiment_name, project=project_id, location=location
81
+ )
82
+
83
+ experiment.delete(delete_backing_tensorboard_runs=delete_backing_tensorboard_runs)
84
+
85
+
86
+ class ExperimentRunHook(GoogleBaseHook):
87
+ """Use the Vertex AI SDK for Python to create and manage your experiment runs."""
88
+
89
+ @GoogleBaseHook.fallback_to_default_project_id
90
+ def create_experiment_run(
91
+ self,
92
+ experiment_run_name: str,
93
+ experiment_name: str,
94
+ location: str,
95
+ project_id: str = PROVIDE_PROJECT_ID,
96
+ experiment_run_tensorboard: str | None = None,
97
+ run_after_creation: bool = False,
98
+ ) -> None:
99
+ """
100
+ Create experiment run for the experiment.
101
+
102
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
103
+ :param location: Required. The ID of the Google Cloud location that the service belongs to.
104
+ :param experiment_name: Required. The name of the evaluation experiment.
105
+ :param experiment_run_name: Required. The specific run name or ID for this experiment.
106
+ :param experiment_run_tensorboard: Optional. A backing TensorBoard resource to enable and store time
107
+ series metrics logged to this experiment run.
108
+ :param run_after_creation: Optional. Responsible for state after creation of experiment run.
109
+ If true experiment run will be created with state RUNNING.
110
+ """
111
+ experiment_run_state = (
112
+ gca_execution.Execution.State.NEW
113
+ if not run_after_creation
114
+ else gca_execution.Execution.State.RUNNING
115
+ )
116
+ experiment_run = aiplatform.ExperimentRun.create(
117
+ run_name=experiment_run_name,
118
+ experiment=experiment_name,
119
+ project=project_id,
120
+ location=location,
121
+ state=experiment_run_state,
122
+ tensorboard=experiment_run_tensorboard,
123
+ )
124
+ self.log.info(
125
+ "Created experiment run with name: %s and status: %s",
126
+ experiment_run.name,
127
+ experiment_run.state,
128
+ )
129
+
130
+ @GoogleBaseHook.fallback_to_default_project_id
131
+ def list_experiment_runs(
132
+ self,
133
+ experiment_name: str,
134
+ location: str,
135
+ project_id: str = PROVIDE_PROJECT_ID,
136
+ ) -> list[aiplatform.ExperimentRun]:
137
+ """
138
+ List experiment run for the experiment.
139
+
140
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
141
+ :param location: Required. The ID of the Google Cloud location that the service belongs to.
142
+ :param experiment_name: Required. The name of the evaluation experiment.
143
+ """
144
+ experiment_runs = aiplatform.ExperimentRun.list(
145
+ experiment=experiment_name,
146
+ project=project_id,
147
+ location=location,
148
+ )
149
+ return experiment_runs
150
+
151
+ @GoogleBaseHook.fallback_to_default_project_id
152
+ def update_experiment_run_state(
153
+ self,
154
+ experiment_run_name: str,
155
+ experiment_name: str,
156
+ location: str,
157
+ new_state: gca_execution.Execution.State,
158
+ project_id: str = PROVIDE_PROJECT_ID,
159
+ ) -> None:
160
+ """
161
+ Update state of the experiment run.
162
+
163
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
164
+ :param location: Required. The ID of the Google Cloud location that the service belongs to.
165
+ :param experiment_name: Required. The name of the evaluation experiment.
166
+ :param experiment_run_name: Required. The specific run name or ID for this experiment.
167
+ :param new_state: Required. New state of the experiment run.
168
+ """
169
+ experiment_run = aiplatform.ExperimentRun(
170
+ run_name=experiment_run_name,
171
+ experiment=experiment_name,
172
+ project=project_id,
173
+ location=location,
174
+ )
175
+ self.log.info("State of the %s before update is: %s", experiment_run.name, experiment_run.state)
176
+
177
+ experiment_run.update_state(new_state)
178
+
179
+ @GoogleBaseHook.fallback_to_default_project_id
180
+ def delete_experiment_run(
181
+ self,
182
+ experiment_run_name: str,
183
+ experiment_name: str,
184
+ location: str,
185
+ project_id: str = PROVIDE_PROJECT_ID,
186
+ delete_backing_tensorboard_run: bool = False,
187
+ ) -> None:
188
+ """
189
+ Delete experiment run from the experiment.
190
+
191
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
192
+ :param location: Required. The ID of the Google Cloud location that the service belongs to.
193
+ :param experiment_name: Required. The name of the evaluation experiment.
194
+ :param experiment_run_name: Required. The specific run name or ID for this experiment.
195
+ :param delete_backing_tensorboard_run: Whether to delete the backing Vertex AI TensorBoard run
196
+ that stores time series metrics for this run.
197
+ """
198
+ self.log.info("Next experiment run will be deleted: %s", experiment_run_name)
199
+ experiment_run = aiplatform.ExperimentRun(
200
+ run_name=experiment_run_name, experiment=experiment_name, project=project_id, location=location
201
+ )
202
+ experiment_run.delete(delete_backing_tensorboard_run=delete_backing_tensorboard_run)