apache-airflow-providers-google 14.0.0__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (278)
  1. airflow/providers/google/3rd-party-licenses/LICENSES.txt +14 -0
  2. airflow/providers/google/3rd-party-licenses/NOTICE +5 -0
  3. airflow/providers/google/__init__.py +3 -3
  4. airflow/providers/google/_vendor/__init__.py +0 -0
  5. airflow/providers/google/_vendor/json_merge_patch.py +91 -0
  6. airflow/providers/google/ads/hooks/ads.py +52 -43
  7. airflow/providers/google/ads/operators/ads.py +2 -2
  8. airflow/providers/google/ads/transfers/ads_to_gcs.py +3 -19
  9. airflow/providers/google/assets/gcs.py +1 -11
  10. airflow/providers/google/cloud/_internal_client/secret_manager_client.py +3 -2
  11. airflow/providers/google/cloud/bundles/gcs.py +161 -0
  12. airflow/providers/google/cloud/hooks/alloy_db.py +2 -3
  13. airflow/providers/google/cloud/hooks/bigquery.py +195 -318
  14. airflow/providers/google/cloud/hooks/bigquery_dts.py +8 -8
  15. airflow/providers/google/cloud/hooks/bigtable.py +3 -2
  16. airflow/providers/google/cloud/hooks/cloud_batch.py +8 -9
  17. airflow/providers/google/cloud/hooks/cloud_build.py +6 -65
  18. airflow/providers/google/cloud/hooks/cloud_composer.py +292 -24
  19. airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
  20. airflow/providers/google/cloud/hooks/cloud_memorystore.py +4 -3
  21. airflow/providers/google/cloud/hooks/cloud_run.py +20 -11
  22. airflow/providers/google/cloud/hooks/cloud_sql.py +136 -64
  23. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +35 -15
  24. airflow/providers/google/cloud/hooks/compute.py +7 -6
  25. airflow/providers/google/cloud/hooks/compute_ssh.py +7 -4
  26. airflow/providers/google/cloud/hooks/datacatalog.py +12 -3
  27. airflow/providers/google/cloud/hooks/dataflow.py +87 -242
  28. airflow/providers/google/cloud/hooks/dataform.py +9 -14
  29. airflow/providers/google/cloud/hooks/datafusion.py +7 -9
  30. airflow/providers/google/cloud/hooks/dataplex.py +13 -12
  31. airflow/providers/google/cloud/hooks/dataprep.py +2 -2
  32. airflow/providers/google/cloud/hooks/dataproc.py +76 -74
  33. airflow/providers/google/cloud/hooks/dataproc_metastore.py +4 -3
  34. airflow/providers/google/cloud/hooks/dlp.py +5 -4
  35. airflow/providers/google/cloud/hooks/gcs.py +144 -33
  36. airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
  37. airflow/providers/google/cloud/hooks/kms.py +3 -2
  38. airflow/providers/google/cloud/hooks/kubernetes_engine.py +22 -17
  39. airflow/providers/google/cloud/hooks/looker.py +6 -1
  40. airflow/providers/google/cloud/hooks/managed_kafka.py +227 -3
  41. airflow/providers/google/cloud/hooks/mlengine.py +7 -8
  42. airflow/providers/google/cloud/hooks/natural_language.py +3 -2
  43. airflow/providers/google/cloud/hooks/os_login.py +3 -2
  44. airflow/providers/google/cloud/hooks/pubsub.py +6 -6
  45. airflow/providers/google/cloud/hooks/secret_manager.py +105 -12
  46. airflow/providers/google/cloud/hooks/spanner.py +75 -10
  47. airflow/providers/google/cloud/hooks/speech_to_text.py +3 -2
  48. airflow/providers/google/cloud/hooks/stackdriver.py +18 -18
  49. airflow/providers/google/cloud/hooks/tasks.py +4 -3
  50. airflow/providers/google/cloud/hooks/text_to_speech.py +3 -2
  51. airflow/providers/google/cloud/hooks/translate.py +8 -17
  52. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -222
  53. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +9 -15
  54. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +33 -283
  55. airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +5 -12
  56. airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +6 -12
  57. airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
  58. airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +311 -10
  59. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
  60. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +7 -13
  61. airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +8 -12
  62. airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +6 -12
  63. airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +3 -2
  64. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
  65. airflow/providers/google/cloud/hooks/video_intelligence.py +3 -2
  66. airflow/providers/google/cloud/hooks/vision.py +7 -7
  67. airflow/providers/google/cloud/hooks/workflows.py +4 -3
  68. airflow/providers/google/cloud/links/alloy_db.py +0 -46
  69. airflow/providers/google/cloud/links/base.py +77 -7
  70. airflow/providers/google/cloud/links/bigquery.py +0 -47
  71. airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
  72. airflow/providers/google/cloud/links/bigtable.py +0 -48
  73. airflow/providers/google/cloud/links/cloud_build.py +0 -73
  74. airflow/providers/google/cloud/links/cloud_functions.py +0 -33
  75. airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
  76. airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
  77. airflow/providers/google/cloud/links/cloud_sql.py +0 -33
  78. airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -46
  79. airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
  80. airflow/providers/google/cloud/links/compute.py +0 -58
  81. airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
  82. airflow/providers/google/cloud/links/datacatalog.py +23 -54
  83. airflow/providers/google/cloud/links/dataflow.py +0 -34
  84. airflow/providers/google/cloud/links/dataform.py +0 -64
  85. airflow/providers/google/cloud/links/datafusion.py +1 -90
  86. airflow/providers/google/cloud/links/dataplex.py +0 -154
  87. airflow/providers/google/cloud/links/dataprep.py +0 -24
  88. airflow/providers/google/cloud/links/dataproc.py +11 -89
  89. airflow/providers/google/cloud/links/datastore.py +0 -31
  90. airflow/providers/google/cloud/links/kubernetes_engine.py +11 -61
  91. airflow/providers/google/cloud/links/managed_kafka.py +11 -51
  92. airflow/providers/google/cloud/links/mlengine.py +0 -70
  93. airflow/providers/google/cloud/links/pubsub.py +0 -32
  94. airflow/providers/google/cloud/links/spanner.py +0 -33
  95. airflow/providers/google/cloud/links/stackdriver.py +0 -30
  96. airflow/providers/google/cloud/links/translate.py +17 -187
  97. airflow/providers/google/cloud/links/vertex_ai.py +28 -195
  98. airflow/providers/google/cloud/links/workflows.py +0 -52
  99. airflow/providers/google/cloud/log/gcs_task_handler.py +166 -118
  100. airflow/providers/google/cloud/log/stackdriver_task_handler.py +14 -9
  101. airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
  102. airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
  103. airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
  104. airflow/providers/google/cloud/openlineage/facets.py +141 -40
  105. airflow/providers/google/cloud/openlineage/mixins.py +14 -13
  106. airflow/providers/google/cloud/openlineage/utils.py +19 -3
  107. airflow/providers/google/cloud/operators/alloy_db.py +76 -61
  108. airflow/providers/google/cloud/operators/bigquery.py +104 -667
  109. airflow/providers/google/cloud/operators/bigquery_dts.py +12 -12
  110. airflow/providers/google/cloud/operators/bigtable.py +38 -7
  111. airflow/providers/google/cloud/operators/cloud_base.py +22 -1
  112. airflow/providers/google/cloud/operators/cloud_batch.py +18 -18
  113. airflow/providers/google/cloud/operators/cloud_build.py +80 -36
  114. airflow/providers/google/cloud/operators/cloud_composer.py +157 -71
  115. airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
  116. airflow/providers/google/cloud/operators/cloud_memorystore.py +74 -46
  117. airflow/providers/google/cloud/operators/cloud_run.py +39 -20
  118. airflow/providers/google/cloud/operators/cloud_sql.py +46 -61
  119. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -14
  120. airflow/providers/google/cloud/operators/compute.py +18 -50
  121. airflow/providers/google/cloud/operators/datacatalog.py +167 -29
  122. airflow/providers/google/cloud/operators/dataflow.py +38 -15
  123. airflow/providers/google/cloud/operators/dataform.py +19 -7
  124. airflow/providers/google/cloud/operators/datafusion.py +43 -43
  125. airflow/providers/google/cloud/operators/dataplex.py +212 -126
  126. airflow/providers/google/cloud/operators/dataprep.py +1 -5
  127. airflow/providers/google/cloud/operators/dataproc.py +134 -207
  128. airflow/providers/google/cloud/operators/dataproc_metastore.py +102 -84
  129. airflow/providers/google/cloud/operators/datastore.py +22 -6
  130. airflow/providers/google/cloud/operators/dlp.py +24 -45
  131. airflow/providers/google/cloud/operators/functions.py +21 -14
  132. airflow/providers/google/cloud/operators/gcs.py +15 -12
  133. airflow/providers/google/cloud/operators/gen_ai.py +389 -0
  134. airflow/providers/google/cloud/operators/kubernetes_engine.py +115 -106
  135. airflow/providers/google/cloud/operators/looker.py +1 -1
  136. airflow/providers/google/cloud/operators/managed_kafka.py +362 -40
  137. airflow/providers/google/cloud/operators/natural_language.py +5 -3
  138. airflow/providers/google/cloud/operators/pubsub.py +69 -21
  139. airflow/providers/google/cloud/operators/spanner.py +53 -45
  140. airflow/providers/google/cloud/operators/speech_to_text.py +5 -4
  141. airflow/providers/google/cloud/operators/stackdriver.py +5 -11
  142. airflow/providers/google/cloud/operators/tasks.py +6 -15
  143. airflow/providers/google/cloud/operators/text_to_speech.py +4 -3
  144. airflow/providers/google/cloud/operators/translate.py +46 -20
  145. airflow/providers/google/cloud/operators/translate_speech.py +4 -3
  146. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +44 -34
  147. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +34 -12
  148. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +62 -53
  149. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +75 -11
  150. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +48 -12
  151. airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
  152. airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
  153. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
  154. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +16 -12
  155. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +62 -14
  156. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +35 -10
  157. airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
  158. airflow/providers/google/cloud/operators/video_intelligence.py +5 -3
  159. airflow/providers/google/cloud/operators/vision.py +7 -5
  160. airflow/providers/google/cloud/operators/workflows.py +24 -19
  161. airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
  162. airflow/providers/google/cloud/sensors/bigquery.py +2 -2
  163. airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -4
  164. airflow/providers/google/cloud/sensors/bigtable.py +14 -6
  165. airflow/providers/google/cloud/sensors/cloud_composer.py +535 -33
  166. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -5
  167. airflow/providers/google/cloud/sensors/dataflow.py +27 -10
  168. airflow/providers/google/cloud/sensors/dataform.py +2 -2
  169. airflow/providers/google/cloud/sensors/datafusion.py +4 -4
  170. airflow/providers/google/cloud/sensors/dataplex.py +7 -5
  171. airflow/providers/google/cloud/sensors/dataprep.py +2 -2
  172. airflow/providers/google/cloud/sensors/dataproc.py +10 -9
  173. airflow/providers/google/cloud/sensors/dataproc_metastore.py +4 -3
  174. airflow/providers/google/cloud/sensors/gcs.py +22 -21
  175. airflow/providers/google/cloud/sensors/looker.py +5 -5
  176. airflow/providers/google/cloud/sensors/pubsub.py +20 -20
  177. airflow/providers/google/cloud/sensors/tasks.py +2 -2
  178. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
  179. airflow/providers/google/cloud/sensors/workflows.py +6 -4
  180. airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
  181. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
  182. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
  183. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
  184. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +14 -13
  185. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
  186. airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
  187. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
  188. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
  189. airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
  190. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +18 -22
  191. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -5
  192. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +45 -38
  193. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
  194. airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
  195. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
  196. airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
  197. airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
  198. airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
  199. airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
  200. airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
  201. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
  202. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +44 -12
  203. airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
  204. airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
  205. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +36 -14
  206. airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
  207. airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
  208. airflow/providers/google/cloud/triggers/bigquery.py +75 -34
  209. airflow/providers/google/cloud/triggers/bigquery_dts.py +2 -1
  210. airflow/providers/google/cloud/triggers/cloud_batch.py +2 -1
  211. airflow/providers/google/cloud/triggers/cloud_build.py +3 -2
  212. airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
  213. airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
  214. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +96 -5
  215. airflow/providers/google/cloud/triggers/dataflow.py +125 -2
  216. airflow/providers/google/cloud/triggers/datafusion.py +1 -1
  217. airflow/providers/google/cloud/triggers/dataplex.py +16 -3
  218. airflow/providers/google/cloud/triggers/dataproc.py +124 -53
  219. airflow/providers/google/cloud/triggers/kubernetes_engine.py +46 -28
  220. airflow/providers/google/cloud/triggers/mlengine.py +1 -1
  221. airflow/providers/google/cloud/triggers/pubsub.py +17 -20
  222. airflow/providers/google/cloud/triggers/vertex_ai.py +8 -7
  223. airflow/providers/google/cloud/utils/bigquery.py +5 -7
  224. airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
  225. airflow/providers/google/cloud/utils/credentials_provider.py +4 -3
  226. airflow/providers/google/cloud/utils/dataform.py +1 -1
  227. airflow/providers/google/cloud/utils/external_token_supplier.py +0 -1
  228. airflow/providers/google/cloud/utils/field_validator.py +1 -2
  229. airflow/providers/google/cloud/utils/validators.py +43 -0
  230. airflow/providers/google/common/auth_backend/google_openid.py +26 -9
  231. airflow/providers/google/common/consts.py +2 -1
  232. airflow/providers/google/common/deprecated.py +2 -1
  233. airflow/providers/google/common/hooks/base_google.py +40 -43
  234. airflow/providers/google/common/hooks/operation_helpers.py +78 -0
  235. airflow/providers/google/common/links/storage.py +0 -22
  236. airflow/providers/google/common/utils/get_secret.py +31 -0
  237. airflow/providers/google/common/utils/id_token_credentials.py +4 -5
  238. airflow/providers/google/firebase/operators/firestore.py +2 -2
  239. airflow/providers/google/get_provider_info.py +61 -216
  240. airflow/providers/google/go_module_utils.py +35 -3
  241. airflow/providers/google/leveldb/hooks/leveldb.py +30 -6
  242. airflow/providers/google/leveldb/operators/leveldb.py +2 -2
  243. airflow/providers/google/marketing_platform/hooks/analytics_admin.py +3 -2
  244. airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
  245. airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
  246. airflow/providers/google/marketing_platform/links/analytics_admin.py +4 -5
  247. airflow/providers/google/marketing_platform/operators/analytics_admin.py +7 -6
  248. airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
  249. airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
  250. airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
  251. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
  252. airflow/providers/google/marketing_platform/sensors/display_video.py +4 -64
  253. airflow/providers/google/suite/hooks/calendar.py +1 -1
  254. airflow/providers/google/suite/hooks/drive.py +2 -2
  255. airflow/providers/google/suite/hooks/sheets.py +15 -1
  256. airflow/providers/google/suite/operators/sheets.py +8 -3
  257. airflow/providers/google/suite/sensors/drive.py +2 -2
  258. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
  259. airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
  260. airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
  261. airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
  262. airflow/providers/google/version_compat.py +15 -1
  263. {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +117 -72
  264. apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
  265. {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +1 -1
  266. apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
  267. airflow/providers/google/cloud/example_dags/example_cloud_task.py +0 -54
  268. airflow/providers/google/cloud/hooks/automl.py +0 -679
  269. airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
  270. airflow/providers/google/cloud/links/automl.py +0 -193
  271. airflow/providers/google/cloud/operators/automl.py +0 -1360
  272. airflow/providers/google/cloud/operators/life_sciences.py +0 -119
  273. airflow/providers/google/cloud/operators/mlengine.py +0 -1515
  274. airflow/providers/google/cloud/utils/mlengine_operator_utils.py +0 -273
  275. apache_airflow_providers_google-14.0.0.dist-info/RECORD +0 -318
  276. airflow/providers/google/cloud/{example_dags → bundles}/__init__.py +0 -0
  277. {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
  278. {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
@@ -21,14 +21,15 @@ from __future__ import annotations
21
21
 
22
22
  import time
23
23
  from datetime import timedelta
24
- from typing import TYPE_CHECKING
24
+ from typing import TYPE_CHECKING, Any, Literal
25
25
 
26
26
  import vertexai
27
+ from google.cloud import aiplatform
27
28
  from vertexai.generative_models import GenerativeModel
28
- from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
29
+ from vertexai.language_models import TextEmbeddingModel
30
+ from vertexai.preview import generative_models as preview_generative_model
29
31
  from vertexai.preview.caching import CachedContent
30
32
  from vertexai.preview.evaluation import EvalResult, EvalTask
31
- from vertexai.preview.generative_models import GenerativeModel as preview_generative_model
32
33
  from vertexai.preview.tuning import sft
33
34
 
34
35
  from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -36,23 +37,12 @@ from airflow.providers.google.common.deprecated import deprecated
36
37
  from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
37
38
 
38
39
  if TYPE_CHECKING:
39
- from google.cloud.aiplatform_v1 import types as types_v1
40
40
  from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
41
41
 
42
42
 
43
43
  class GenerativeModelHook(GoogleBaseHook):
44
44
  """Hook for Google Cloud Vertex AI Generative Model APIs."""
45
45
 
46
- @deprecated(
47
- planned_removal_date="April 09, 2025",
48
- use_instead="GenerativeModelHook.get_generative_model",
49
- category=AirflowProviderDeprecationWarning,
50
- )
51
- def get_text_generation_model(self, pretrained_model: str):
52
- """Return a Model Garden Model object based on Text Generation."""
53
- model = TextGenerationModel.from_pretrained(pretrained_model)
54
- return model
55
-
56
46
  def get_text_embedding_model(self, pretrained_model: str):
57
47
  """Return a Model Garden Model object based on Text Embedding."""
58
48
  model = TextEmbeddingModel.from_pretrained(pretrained_model)
@@ -61,7 +51,7 @@ class GenerativeModelHook(GoogleBaseHook):
61
51
  def get_generative_model(
62
52
  self,
63
53
  pretrained_model: str,
64
- system_instruction: str | None = None,
54
+ system_instruction: Any | None = None,
65
55
  generation_config: dict | None = None,
66
56
  safety_settings: dict | None = None,
67
57
  tools: list | None = None,
@@ -93,66 +83,18 @@ class GenerativeModelHook(GoogleBaseHook):
93
83
  def get_cached_context_model(
94
84
  self,
95
85
  cached_content_name: str,
96
- ) -> preview_generative_model:
86
+ ) -> Any:
97
87
  """Return a Generative Model with Cached Context."""
98
88
  cached_content = CachedContent(cached_content_name=cached_content_name)
99
89
 
100
- cached_context_model = preview_generative_model.from_cached_content(cached_content)
90
+ cached_context_model = preview_generative_model.GenerativeModel.from_cached_content(cached_content)
101
91
  return cached_context_model
102
92
 
103
93
  @deprecated(
104
- planned_removal_date="April 09, 2025",
105
- use_instead="GenerativeModelHook.generative_model_generate_content",
94
+ planned_removal_date="January 3, 2026",
95
+ use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.embed_content",
106
96
  category=AirflowProviderDeprecationWarning,
107
97
  )
108
- @GoogleBaseHook.fallback_to_default_project_id
109
- def text_generation_model_predict(
110
- self,
111
- prompt: str,
112
- pretrained_model: str,
113
- temperature: float,
114
- max_output_tokens: int,
115
- top_p: float,
116
- top_k: int,
117
- location: str,
118
- project_id: str = PROVIDE_PROJECT_ID,
119
- ) -> str:
120
- """
121
- Use the Vertex AI PaLM API to generate natural language text.
122
-
123
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
124
- :param location: Required. The ID of the Google Cloud location that the service belongs to.
125
- :param prompt: Required. Inputs or queries that a user or a program gives
126
- to the Vertex AI PaLM API, in order to elicit a specific response.
127
- :param pretrained_model: A pre-trained model optimized for performing natural
128
- language tasks such as classification, summarization, extraction, content
129
- creation, and ideation.
130
- :param temperature: Temperature controls the degree of randomness in token
131
- selection.
132
- :param max_output_tokens: Token limit determines the maximum amount of text
133
- output.
134
- :param top_p: Tokens are selected from most probable to least until the sum
135
- of their probabilities equals the top_p value. Defaults to 0.8.
136
- :param top_k: A top_k of 1 means the selected token is the most probable
137
- among all tokens.
138
- """
139
- vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
140
-
141
- parameters = {
142
- "temperature": temperature,
143
- "max_output_tokens": max_output_tokens,
144
- "top_p": top_p,
145
- "top_k": top_k,
146
- }
147
-
148
- model = self.get_text_generation_model(pretrained_model)
149
-
150
- response = model.predict(
151
- prompt=prompt,
152
- **parameters,
153
- )
154
- return response.text
155
-
156
98
  @GoogleBaseHook.fallback_to_default_project_id
157
99
  def text_embedding_model_get_embeddings(
158
100
  self,
@@ -177,16 +119,21 @@ class GenerativeModelHook(GoogleBaseHook):
177
119
 
178
120
  return response.values
179
121
 
122
+ @deprecated(
123
+ planned_removal_date="January 3, 2026",
124
+ use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.generate_content",
125
+ category=AirflowProviderDeprecationWarning,
126
+ )
180
127
  @GoogleBaseHook.fallback_to_default_project_id
181
128
  def generative_model_generate_content(
182
129
  self,
183
130
  contents: list,
184
131
  location: str,
132
+ pretrained_model: str,
185
133
  tools: list | None = None,
186
134
  generation_config: dict | None = None,
187
135
  safety_settings: dict | None = None,
188
136
  system_instruction: str | None = None,
189
- pretrained_model: str = "gemini-pro",
190
137
  project_id: str = PROVIDE_PROJECT_ID,
191
138
  ) -> str:
192
139
  """
@@ -200,7 +147,7 @@ class GenerativeModelHook(GoogleBaseHook):
200
147
  :param safety_settings: Optional. Per request settings for blocking unsafe content.
201
148
  :param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
202
149
  :param system_instruction: Optional. An instruction given to the model to guide its behavior.
203
- :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
150
+ :param pretrained_model: Required. Model,
204
151
  supporting prompts with text-only input, including natural language
205
152
  tasks, multi-turn text and code chat, and code generation. It can
206
153
  output text and code.
@@ -219,6 +166,11 @@ class GenerativeModelHook(GoogleBaseHook):
219
166
 
220
167
  return response.text
221
168
 
169
+ @deprecated(
170
+ planned_removal_date="January 3, 2026",
171
+ use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.supervised_fine_tuning_train",
172
+ category=AirflowProviderDeprecationWarning,
173
+ )
222
174
  @GoogleBaseHook.fallback_to_default_project_id
223
175
  def supervised_fine_tuning_train(
224
176
  self,
@@ -228,10 +180,10 @@ class GenerativeModelHook(GoogleBaseHook):
228
180
  tuned_model_display_name: str | None = None,
229
181
  validation_dataset: str | None = None,
230
182
  epochs: int | None = None,
231
- adapter_size: int | None = None,
183
+ adapter_size: Literal[1, 4, 8, 16] | None = None,
232
184
  learning_rate_multiplier: float | None = None,
233
185
  project_id: str = PROVIDE_PROJECT_ID,
234
- ) -> types_v1.TuningJob:
186
+ ) -> Any:
235
187
  """
236
188
  Use the Supervised Fine Tuning API to create a tuning job.
237
189
 
@@ -272,12 +224,17 @@ class GenerativeModelHook(GoogleBaseHook):
272
224
 
273
225
  return sft_tuning_job
274
226
 
227
+ @deprecated(
228
+ planned_removal_date="January 3, 2026",
229
+ use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.count_tokens",
230
+ category=AirflowProviderDeprecationWarning,
231
+ )
275
232
  @GoogleBaseHook.fallback_to_default_project_id
276
233
  def count_tokens(
277
234
  self,
278
235
  contents: list,
279
236
  location: str,
280
- pretrained_model: str = "gemini-pro",
237
+ pretrained_model: str,
281
238
  project_id: str = PROVIDE_PROJECT_ID,
282
239
  ) -> types_v1beta1.CountTokensResponse:
283
240
  """
@@ -287,7 +244,7 @@ class GenerativeModelHook(GoogleBaseHook):
287
244
  :param location: Required. The ID of the Google Cloud location that the service belongs to.
288
245
  :param contents: Required. The multi-part content of a message that a user or a program
289
246
  gives to the generative model, in order to elicit a specific response.
290
- :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
247
+ :param pretrained_model: Required. Model,
291
248
  supporting prompts with text-only input, including natural language
292
249
  tasks, multi-turn text and code chat, and code generation. It can
293
250
  output text and code.
@@ -359,13 +316,18 @@ class GenerativeModelHook(GoogleBaseHook):
359
316
 
360
317
  return eval_result
361
318
 
319
+ @deprecated(
320
+ planned_removal_date="January 3, 2026",
321
+ use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.create_cached_content",
322
+ category=AirflowProviderDeprecationWarning,
323
+ )
362
324
  def create_cached_content(
363
325
  self,
364
326
  model_name: str,
365
327
  location: str,
366
328
  ttl_hours: float = 1,
367
- system_instruction: str | None = None,
368
- contents: list | None = None,
329
+ system_instruction: Any | None = None,
330
+ contents: list[Any] | None = None,
369
331
  display_name: str | None = None,
370
332
  project_id: str = PROVIDE_PROJECT_ID,
371
333
  ) -> str:
@@ -393,6 +355,11 @@ class GenerativeModelHook(GoogleBaseHook):
393
355
 
394
356
  return response.name
395
357
 
358
+ @deprecated(
359
+ planned_removal_date="January 3, 2026",
360
+ use_instead="airflow.providers.google.cloud.hooks.gen_ai.generative_model.GenAIGenerativeModelHook.generate_content",
361
+ category=AirflowProviderDeprecationWarning,
362
+ )
396
363
  def generate_from_cached_content(
397
364
  self,
398
365
  location: str,
@@ -413,6 +380,9 @@ class GenerativeModelHook(GoogleBaseHook):
413
380
  :param generation_config: Optional. Generation configuration settings.
414
381
  :param safety_settings: Optional. Per request settings for blocking unsafe content.
415
382
  """
383
+ # During run of the system test it was found out that names from xcom, e.g. 3402922389 can be
384
+ # treated as int and throw an error TypeError: expected string or bytes-like object, got 'int'
385
+ cached_content_name = str(cached_content_name)
416
386
  vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
417
387
 
418
388
  cached_context_model = self.get_cached_context_model(cached_content_name=cached_content_name)
@@ -424,3 +394,37 @@ class GenerativeModelHook(GoogleBaseHook):
424
394
  )
425
395
 
426
396
  return response.text
397
+
398
+
399
+ @deprecated(
400
+ planned_removal_date="January 3, 2026",
401
+ use_instead="airflow.providers.google.cloud.hooks.vertex_ai.experiment_service.ExperimentRunHook",
402
+ category=AirflowProviderDeprecationWarning,
403
+ )
404
+ class ExperimentRunHook(GoogleBaseHook):
405
+ """Use the Vertex AI SDK for Python to create and manage your experiment runs."""
406
+
407
+ @GoogleBaseHook.fallback_to_default_project_id
408
+ def delete_experiment_run(
409
+ self,
410
+ experiment_run_name: str,
411
+ experiment_name: str,
412
+ location: str,
413
+ project_id: str = PROVIDE_PROJECT_ID,
414
+ delete_backing_tensorboard_run: bool = False,
415
+ ) -> None:
416
+ """
417
+ Delete experiment run from the experiment.
418
+
419
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
420
+ :param location: Required. The ID of the Google Cloud location that the service belongs to.
421
+ :param experiment_name: Required. The name of the evaluation experiment.
422
+ :param experiment_run_name: Required. The specific run name or ID for this experiment.
423
+ :param delete_backing_tensorboard_run: Whether to delete the backing Vertex AI TensorBoard run
424
+ that stores time series metrics for this run.
425
+ """
426
+ self.log.info("Next experiment run will be deleted: %s", experiment_run_name)
427
+ experiment_run = aiplatform.ExperimentRun(
428
+ run_name=experiment_run_name, experiment=experiment_name, project=project_id, location=location
429
+ )
430
+ experiment_run.delete(delete_backing_tensorboard_run=delete_backing_tensorboard_run)
@@ -29,21 +29,23 @@ import asyncio
29
29
  from collections.abc import Sequence
30
30
  from typing import TYPE_CHECKING
31
31
 
32
- from airflow.exceptions import AirflowException
33
- from airflow.providers.google.common.consts import CLIENT_INFO
34
- from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
35
32
  from google.api_core.client_options import ClientOptions
36
33
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
37
34
  from google.cloud.aiplatform import CustomJob, HyperparameterTuningJob, gapic, hyperparameter_tuning
38
35
  from google.cloud.aiplatform_v1 import JobServiceAsyncClient, JobServiceClient, JobState, types
39
36
 
37
+ from airflow.exceptions import AirflowException
38
+ from airflow.providers.google.common.consts import CLIENT_INFO
39
+ from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
40
+ from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
41
+
40
42
  if TYPE_CHECKING:
41
43
  from google.api_core.operation import Operation
42
44
  from google.api_core.retry import AsyncRetry, Retry
43
45
  from google.cloud.aiplatform_v1.services.job_service.pagers import ListHyperparameterTuningJobsPager
44
46
 
45
47
 
46
- class HyperparameterTuningJobHook(GoogleBaseHook):
48
+ class HyperparameterTuningJobHook(GoogleBaseHook, OperationHelper):
47
49
  """Hook for Google Cloud Vertex AI Hyperparameter Tuning Job APIs."""
48
50
 
49
51
  def __init__(
@@ -67,7 +69,7 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
67
69
  client_options = ClientOptions()
68
70
 
69
71
  return JobServiceClient(
70
- credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
72
+ credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
71
73
  )
72
74
 
73
75
  def get_hyperparameter_tuning_job_object(
@@ -133,14 +135,6 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
133
135
  """Return unique id of the hyperparameter_tuning_job."""
134
136
  return obj["name"].rpartition("/")[-1]
135
137
 
136
- def wait_for_operation(self, operation: Operation, timeout: float | None = None):
137
- """Wait for long-lasting operation to complete."""
138
- try:
139
- return operation.result(timeout=timeout)
140
- except Exception:
141
- error = operation.exception(timeout=timeout)
142
- raise AirflowException(error)
143
-
144
138
  def cancel_hyperparameter_tuning_job(self) -> None:
145
139
  """Cancel HyperparameterTuningJob."""
146
140
  if self._hyperparameter_tuning_job:
@@ -23,12 +23,14 @@ from __future__ import annotations
23
23
  from collections.abc import Sequence
24
24
  from typing import TYPE_CHECKING
25
25
 
26
- from airflow.exceptions import AirflowException
27
- from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
28
26
  from google.api_core.client_options import ClientOptions
29
27
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
30
28
  from google.cloud.aiplatform_v1 import ModelServiceClient
31
29
 
30
+ from airflow.exceptions import AirflowException
31
+ from airflow.providers.google.common.consts import CLIENT_INFO
32
+ from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
33
+
32
34
  if TYPE_CHECKING:
33
35
  from google.api_core.operation import Operation
34
36
  from google.api_core.retry import Retry
@@ -38,8 +40,10 @@ if TYPE_CHECKING:
38
40
  )
39
41
  from google.cloud.aiplatform_v1.types import Model, model_service
40
42
 
43
+ from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
44
+
41
45
 
42
- class ModelServiceHook(GoogleBaseHook):
46
+ class ModelServiceHook(GoogleBaseHook, OperationHelper):
43
47
  """Hook for Google Cloud Vertex AI Endpoint Service APIs."""
44
48
 
45
49
  def get_model_service_client(self, region: str | None = None) -> ModelServiceClient:
@@ -50,7 +54,7 @@ class ModelServiceHook(GoogleBaseHook):
50
54
  client_options = ClientOptions()
51
55
 
52
56
  return ModelServiceClient(
53
- credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
57
+ credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
54
58
  )
55
59
 
56
60
  @staticmethod
@@ -58,14 +62,6 @@ class ModelServiceHook(GoogleBaseHook):
58
62
  """Return unique id of the model."""
59
63
  return obj["model"].rpartition("/")[-1]
60
64
 
61
- def wait_for_operation(self, operation: Operation, timeout: float | None = None):
62
- """Wait for long-lasting operation to complete."""
63
- try:
64
- return operation.result(timeout=timeout)
65
- except Exception:
66
- error = operation.exception(timeout=timeout)
67
- raise AirflowException(error)
68
-
69
65
  @GoogleBaseHook.fallback_to_default_project_id
70
66
  def delete_model(
71
67
  self,
@@ -29,9 +29,6 @@ import asyncio
29
29
  from collections.abc import Sequence
30
30
  from typing import TYPE_CHECKING, Any
31
31
 
32
- from airflow.exceptions import AirflowException
33
- from airflow.providers.google.common.consts import CLIENT_INFO
34
- from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
35
32
  from google.api_core.client_options import ClientOptions
36
33
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
37
34
  from google.cloud.aiplatform import PipelineJob
@@ -42,6 +39,11 @@ from google.cloud.aiplatform_v1 import (
42
39
  types,
43
40
  )
44
41
 
42
+ from airflow.exceptions import AirflowException
43
+ from airflow.providers.google.common.consts import CLIENT_INFO
44
+ from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
45
+ from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
46
+
45
47
  if TYPE_CHECKING:
46
48
  from google.api_core.operation import Operation
47
49
  from google.api_core.retry import AsyncRetry, Retry
@@ -50,7 +52,7 @@ if TYPE_CHECKING:
50
52
  from google.cloud.aiplatform_v1.services.pipeline_service.pagers import ListPipelineJobsPager
51
53
 
52
54
 
53
- class PipelineJobHook(GoogleBaseHook):
55
+ class PipelineJobHook(GoogleBaseHook, OperationHelper):
54
56
  """Hook for Google Cloud Vertex AI Pipeline Job APIs."""
55
57
 
56
58
  def __init__(
@@ -111,14 +113,6 @@ class PipelineJobHook(GoogleBaseHook):
111
113
  failure_policy=failure_policy,
112
114
  )
113
115
 
114
- def wait_for_operation(self, operation: Operation, timeout: float | None = None):
115
- """Wait for long-lasting operation to complete."""
116
- try:
117
- return operation.result(timeout=timeout)
118
- except Exception:
119
- error = operation.exception(timeout=timeout)
120
- raise AirflowException(error)
121
-
122
116
  def cancel_pipeline_job(self) -> None:
123
117
  """Cancel PipelineJob."""
124
118
  if self._pipeline_job:
@@ -20,12 +20,13 @@ from __future__ import annotations
20
20
  from collections.abc import Sequence
21
21
  from typing import TYPE_CHECKING
22
22
 
23
- from airflow.providers.google.common.consts import CLIENT_INFO
24
- from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
25
23
  from google.api_core.client_options import ClientOptions
26
24
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
27
25
  from google.cloud.aiplatform_v1 import PredictionServiceClient
28
26
 
27
+ from airflow.providers.google.common.consts import CLIENT_INFO
28
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
29
+
29
30
  if TYPE_CHECKING:
30
31
  from google.api_core.retry import Retry
31
32
  from google.cloud.aiplatform_v1.types import PredictResponse
@@ -0,0 +1,223 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one
3
+ # or more contributor license agreements. See the NOTICE file
4
+ # distributed with this work for additional information
5
+ # regarding copyright ownership. The ASF licenses this file
6
+ # to you under the Apache License, Version 2.0 (the
7
+ # "License"); you may not use this file except in compliance
8
+ # with the License. You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing,
13
+ # software distributed under the License is distributed on an
14
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ # KIND, either express or implied. See the License for the
16
+ # specific language governing permissions and limitations
17
+ # under the License.
18
+ """This module contains a Google Cloud Vertex AI hook."""
19
+
20
+ from __future__ import annotations
21
+
22
+ import dataclasses
23
+ from collections.abc import MutableMapping
24
+ from typing import Any
25
+
26
+ import vertex_ray
27
+ from google.cloud import aiplatform
28
+ from google.cloud.aiplatform.vertex_ray.util import resources
29
+ from google.cloud.aiplatform_v1 import (
30
+ PersistentResourceServiceClient,
31
+ )
32
+ from proto.marshal.collections.repeated import Repeated
33
+
34
+ from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
35
+
36
+
37
class RayHook(GoogleBaseHook):
    """Hook for Google Cloud Vertex AI Ray APIs."""

    def _init_vertex_ai(self, project_id: str, location: str) -> None:
        """Initialise the Vertex AI SDK for the given project/location with this hook's credentials."""
        aiplatform.init(project=project_id, location=location, credentials=self.get_credentials())

    @staticmethod
    def _get_cluster_resource_name(project_id: str, location: str, cluster_id: str) -> str:
        """Build the fully-qualified persistent-resource name of a Ray cluster."""
        return PersistentResourceServiceClient.persistent_resource_path(
            project=project_id,
            location=location,
            persistent_resource=cluster_id,
        )

    def extract_cluster_id(self, cluster_path) -> str:
        """
        Extract cluster_id from cluster_path.

        :param cluster_path: Fully-qualified persistent-resource name of the cluster.
        :return: The ``persistent_resource`` segment parsed from ``cluster_path``.
        """
        return PersistentResourceServiceClient.parse_persistent_resource_path(cluster_path)[
            "persistent_resource"
        ]

    def serialize_cluster_obj(self, cluster_obj: resources.Cluster) -> dict:
        """
        Serialize a ``Cluster`` dataclass to a dict made only of builtin containers.

        Conversions, applied recursively:

        * proto ``Repeated`` containers and lists become lists;
        * mappings, including proto map containers, become dicts;
        * nested dataclass instances become dicts keyed by field name.

        :param cluster_obj: The cluster dataclass instance to serialize.
        :return: A dict with one entry per dataclass field of ``cluster_obj``.
        """

        def _encode_value(value: Any) -> Any:
            if isinstance(value, (list, Repeated)):
                return [_encode_value(nested_value) for nested_value in value]
            if isinstance(value, MutableMapping):
                # Plain dicts are recursed too, so proto containers nested inside them are converted.
                return {key: _encode_value(nested_value) for key, nested_value in dict(value).items()}
            if dataclasses.is_dataclass(value) and not isinstance(value, type):
                # Recurse field by field rather than via ``dataclasses.asdict``. That way proto
                # containers inside nested dataclasses are converted instead of being deep-copied.
                # Dataclass *types* fall through and are returned unchanged.
                return {
                    nested_field.name: _encode_value(getattr(value, nested_field.name))
                    for nested_field in dataclasses.fields(value)
                }
            return value

        return {
            field.name: _encode_value(getattr(cluster_obj, field.name))
            for field in dataclasses.fields(cluster_obj)
        }

    @GoogleBaseHook.fallback_to_default_project_id
    def create_ray_cluster(
        self,
        project_id: str,
        location: str,
        head_node_type: resources.Resources | None = None,
        python_version: str = "3.10",
        ray_version: str = "2.33",
        network: str | None = None,
        service_account: str | None = None,
        cluster_name: str | None = None,
        worker_node_types: list[resources.Resources] | None = None,
        custom_images: resources.NodeImages | None = None,
        enable_metrics_collection: bool = True,
        enable_logging: bool = True,
        psc_interface_config: resources.PscIConfig | None = None,
        reserved_ip_ranges: list[str] | None = None,
        labels: dict[str, str] | None = None,
    ) -> str:
        """
        Create a Ray cluster on the Vertex AI.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param head_node_type: The head node resource. Resources.node_count must be 1. If not set, a
            fresh ``Resources()`` instance is created for this call.
        :param python_version: Python version for the ray cluster.
        :param ray_version: Ray version for the ray cluster. Default is "2.33".
        :param network: Virtual private cloud (VPC) network. For Ray Client, VPC peering is required to
            connect to the Ray Cluster managed in the Vertex API service. For Ray Job API, VPC network is not
            required because Ray Cluster connection can be accessed through dashboard address.
        :param service_account: Service account to be used for running Ray programs on the cluster.
        :param cluster_name: This value may be up to 63 characters, and valid characters are `[a-z0-9_-]`.
            The first character cannot be a number or hyphen.
        :param worker_node_types: The list of Resources of the worker nodes. The same Resources object should
            not appear multiple times in the list.
        :param custom_images: The NodeImages which specifies head node and worker nodes images. All the
            workers will share the same image. If each Resource has a specific custom image, use
            `Resources.custom_image` for head/worker_node_type(s). Note that configuring
            `Resources.custom_image` will override `custom_images` here. Allowlist only.
        :param enable_metrics_collection: Enable Ray metrics collection for visualization.
        :param enable_logging: Enable exporting Ray logs to Cloud Logging.
        :param psc_interface_config: PSC-I config.
        :param reserved_ip_ranges: A list of names for the reserved IP ranges under the VPC network that can
            be used for this cluster. If set, we will deploy the cluster within the provided IP ranges.
            Otherwise, the cluster is deployed to any IP ranges under the provided VPC network.
            Example: ["vertex-ai-ip-range"].
        :param labels: The labels with user-defined metadata to organize Ray cluster.
            Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
            lowercase letters, numeric characters, underscores and dashes. International characters are allowed.
            See https://goo.gl/xmQnxf for more information and examples of labels.
        :return: The resource path of the created cluster, as returned by ``vertex_ray.create_ray_cluster``.
        """
        # Build the default per call. A ``Resources()`` default argument would be one shared,
        # mutable object reused by every call.
        if head_node_type is None:
            head_node_type = resources.Resources()
        self._init_vertex_ai(project_id=project_id, location=location)
        return vertex_ray.create_ray_cluster(
            head_node_type=head_node_type,
            python_version=python_version,
            ray_version=ray_version,
            network=network,
            service_account=service_account,
            cluster_name=cluster_name,
            worker_node_types=worker_node_types,
            custom_images=custom_images,
            enable_metrics_collection=enable_metrics_collection,
            enable_logging=enable_logging,
            psc_interface_config=psc_interface_config,
            reserved_ip_ranges=reserved_ip_ranges,
            labels=labels,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def list_ray_clusters(
        self,
        project_id: str,
        location: str,
    ) -> list[resources.Cluster]:
        """
        List Ray clusters in the given project and location.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :return: The clusters returned by ``vertex_ray.list_ray_clusters``.
        """
        self._init_vertex_ai(project_id=project_id, location=location)
        return vertex_ray.list_ray_clusters()

    @GoogleBaseHook.fallback_to_default_project_id
    def get_ray_cluster(
        self,
        project_id: str,
        location: str,
        cluster_id: str,
    ) -> resources.Cluster:
        """
        Get Ray cluster.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param cluster_id: Cluster resource ID.
        :return: The cluster returned by ``vertex_ray.get_ray_cluster``.
        """
        self._init_vertex_ai(project_id=project_id, location=location)
        return vertex_ray.get_ray_cluster(
            cluster_resource_name=self._get_cluster_resource_name(project_id, location, cluster_id),
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def update_ray_cluster(
        self,
        project_id: str,
        location: str,
        cluster_id: str,
        worker_node_types: list[resources.Resources],
    ) -> str:
        """
        Update Ray cluster (currently support resizing node counts for worker nodes).

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param cluster_id: Cluster resource ID.
        :param worker_node_types: The list of Resources of the resized worker nodes. The same Resources
            object should not appear multiple times in the list.
        :return: The cluster resource name returned by ``vertex_ray.update_ray_cluster``.
        """
        self._init_vertex_ai(project_id=project_id, location=location)
        return vertex_ray.update_ray_cluster(
            cluster_resource_name=self._get_cluster_resource_name(project_id, location, cluster_id),
            worker_node_types=worker_node_types,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_ray_cluster(
        self,
        project_id: str,
        location: str,
        cluster_id: str,
    ) -> None:
        """
        Delete Ray cluster.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param cluster_id: Cluster resource ID.
        """
        self._init_vertex_ai(project_id=project_id, location=location)
        vertex_ray.delete_ray_cluster(
            cluster_resource_name=self._get_cluster_resource_name(project_id, location, cluster_id)
        )