apache_airflow_providers_google-14.0.0-py3-none-any.whl → apache_airflow_providers_google-19.1.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. airflow/providers/google/3rd-party-licenses/LICENSES.txt +14 -0
  2. airflow/providers/google/3rd-party-licenses/NOTICE +5 -0
  3. airflow/providers/google/__init__.py +3 -3
  4. airflow/providers/google/_vendor/__init__.py +0 -0
  5. airflow/providers/google/_vendor/json_merge_patch.py +91 -0
  6. airflow/providers/google/ads/hooks/ads.py +52 -43
  7. airflow/providers/google/ads/operators/ads.py +2 -2
  8. airflow/providers/google/ads/transfers/ads_to_gcs.py +3 -19
  9. airflow/providers/google/assets/gcs.py +1 -11
  10. airflow/providers/google/cloud/_internal_client/secret_manager_client.py +3 -2
  11. airflow/providers/google/cloud/bundles/gcs.py +161 -0
  12. airflow/providers/google/cloud/hooks/alloy_db.py +2 -3
  13. airflow/providers/google/cloud/hooks/bigquery.py +195 -318
  14. airflow/providers/google/cloud/hooks/bigquery_dts.py +8 -8
  15. airflow/providers/google/cloud/hooks/bigtable.py +3 -2
  16. airflow/providers/google/cloud/hooks/cloud_batch.py +8 -9
  17. airflow/providers/google/cloud/hooks/cloud_build.py +6 -65
  18. airflow/providers/google/cloud/hooks/cloud_composer.py +292 -24
  19. airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
  20. airflow/providers/google/cloud/hooks/cloud_memorystore.py +4 -3
  21. airflow/providers/google/cloud/hooks/cloud_run.py +20 -11
  22. airflow/providers/google/cloud/hooks/cloud_sql.py +136 -64
  23. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +35 -15
  24. airflow/providers/google/cloud/hooks/compute.py +7 -6
  25. airflow/providers/google/cloud/hooks/compute_ssh.py +7 -4
  26. airflow/providers/google/cloud/hooks/datacatalog.py +12 -3
  27. airflow/providers/google/cloud/hooks/dataflow.py +87 -242
  28. airflow/providers/google/cloud/hooks/dataform.py +9 -14
  29. airflow/providers/google/cloud/hooks/datafusion.py +7 -9
  30. airflow/providers/google/cloud/hooks/dataplex.py +13 -12
  31. airflow/providers/google/cloud/hooks/dataprep.py +2 -2
  32. airflow/providers/google/cloud/hooks/dataproc.py +76 -74
  33. airflow/providers/google/cloud/hooks/dataproc_metastore.py +4 -3
  34. airflow/providers/google/cloud/hooks/dlp.py +5 -4
  35. airflow/providers/google/cloud/hooks/gcs.py +144 -33
  36. airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
  37. airflow/providers/google/cloud/hooks/kms.py +3 -2
  38. airflow/providers/google/cloud/hooks/kubernetes_engine.py +22 -17
  39. airflow/providers/google/cloud/hooks/looker.py +6 -1
  40. airflow/providers/google/cloud/hooks/managed_kafka.py +227 -3
  41. airflow/providers/google/cloud/hooks/mlengine.py +7 -8
  42. airflow/providers/google/cloud/hooks/natural_language.py +3 -2
  43. airflow/providers/google/cloud/hooks/os_login.py +3 -2
  44. airflow/providers/google/cloud/hooks/pubsub.py +6 -6
  45. airflow/providers/google/cloud/hooks/secret_manager.py +105 -12
  46. airflow/providers/google/cloud/hooks/spanner.py +75 -10
  47. airflow/providers/google/cloud/hooks/speech_to_text.py +3 -2
  48. airflow/providers/google/cloud/hooks/stackdriver.py +18 -18
  49. airflow/providers/google/cloud/hooks/tasks.py +4 -3
  50. airflow/providers/google/cloud/hooks/text_to_speech.py +3 -2
  51. airflow/providers/google/cloud/hooks/translate.py +8 -17
  52. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -222
  53. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +9 -15
  54. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +33 -283
  55. airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +5 -12
  56. airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +6 -12
  57. airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
  58. airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +311 -10
  59. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
  60. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +7 -13
  61. airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +8 -12
  62. airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +6 -12
  63. airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +3 -2
  64. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
  65. airflow/providers/google/cloud/hooks/video_intelligence.py +3 -2
  66. airflow/providers/google/cloud/hooks/vision.py +7 -7
  67. airflow/providers/google/cloud/hooks/workflows.py +4 -3
  68. airflow/providers/google/cloud/links/alloy_db.py +0 -46
  69. airflow/providers/google/cloud/links/base.py +77 -7
  70. airflow/providers/google/cloud/links/bigquery.py +0 -47
  71. airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
  72. airflow/providers/google/cloud/links/bigtable.py +0 -48
  73. airflow/providers/google/cloud/links/cloud_build.py +0 -73
  74. airflow/providers/google/cloud/links/cloud_functions.py +0 -33
  75. airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
  76. airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
  77. airflow/providers/google/cloud/links/cloud_sql.py +0 -33
  78. airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -46
  79. airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
  80. airflow/providers/google/cloud/links/compute.py +0 -58
  81. airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
  82. airflow/providers/google/cloud/links/datacatalog.py +23 -54
  83. airflow/providers/google/cloud/links/dataflow.py +0 -34
  84. airflow/providers/google/cloud/links/dataform.py +0 -64
  85. airflow/providers/google/cloud/links/datafusion.py +1 -90
  86. airflow/providers/google/cloud/links/dataplex.py +0 -154
  87. airflow/providers/google/cloud/links/dataprep.py +0 -24
  88. airflow/providers/google/cloud/links/dataproc.py +11 -89
  89. airflow/providers/google/cloud/links/datastore.py +0 -31
  90. airflow/providers/google/cloud/links/kubernetes_engine.py +11 -61
  91. airflow/providers/google/cloud/links/managed_kafka.py +11 -51
  92. airflow/providers/google/cloud/links/mlengine.py +0 -70
  93. airflow/providers/google/cloud/links/pubsub.py +0 -32
  94. airflow/providers/google/cloud/links/spanner.py +0 -33
  95. airflow/providers/google/cloud/links/stackdriver.py +0 -30
  96. airflow/providers/google/cloud/links/translate.py +17 -187
  97. airflow/providers/google/cloud/links/vertex_ai.py +28 -195
  98. airflow/providers/google/cloud/links/workflows.py +0 -52
  99. airflow/providers/google/cloud/log/gcs_task_handler.py +166 -118
  100. airflow/providers/google/cloud/log/stackdriver_task_handler.py +14 -9
  101. airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
  102. airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
  103. airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
  104. airflow/providers/google/cloud/openlineage/facets.py +141 -40
  105. airflow/providers/google/cloud/openlineage/mixins.py +14 -13
  106. airflow/providers/google/cloud/openlineage/utils.py +19 -3
  107. airflow/providers/google/cloud/operators/alloy_db.py +76 -61
  108. airflow/providers/google/cloud/operators/bigquery.py +104 -667
  109. airflow/providers/google/cloud/operators/bigquery_dts.py +12 -12
  110. airflow/providers/google/cloud/operators/bigtable.py +38 -7
  111. airflow/providers/google/cloud/operators/cloud_base.py +22 -1
  112. airflow/providers/google/cloud/operators/cloud_batch.py +18 -18
  113. airflow/providers/google/cloud/operators/cloud_build.py +80 -36
  114. airflow/providers/google/cloud/operators/cloud_composer.py +157 -71
  115. airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
  116. airflow/providers/google/cloud/operators/cloud_memorystore.py +74 -46
  117. airflow/providers/google/cloud/operators/cloud_run.py +39 -20
  118. airflow/providers/google/cloud/operators/cloud_sql.py +46 -61
  119. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -14
  120. airflow/providers/google/cloud/operators/compute.py +18 -50
  121. airflow/providers/google/cloud/operators/datacatalog.py +167 -29
  122. airflow/providers/google/cloud/operators/dataflow.py +38 -15
  123. airflow/providers/google/cloud/operators/dataform.py +19 -7
  124. airflow/providers/google/cloud/operators/datafusion.py +43 -43
  125. airflow/providers/google/cloud/operators/dataplex.py +212 -126
  126. airflow/providers/google/cloud/operators/dataprep.py +1 -5
  127. airflow/providers/google/cloud/operators/dataproc.py +134 -207
  128. airflow/providers/google/cloud/operators/dataproc_metastore.py +102 -84
  129. airflow/providers/google/cloud/operators/datastore.py +22 -6
  130. airflow/providers/google/cloud/operators/dlp.py +24 -45
  131. airflow/providers/google/cloud/operators/functions.py +21 -14
  132. airflow/providers/google/cloud/operators/gcs.py +15 -12
  133. airflow/providers/google/cloud/operators/gen_ai.py +389 -0
  134. airflow/providers/google/cloud/operators/kubernetes_engine.py +115 -106
  135. airflow/providers/google/cloud/operators/looker.py +1 -1
  136. airflow/providers/google/cloud/operators/managed_kafka.py +362 -40
  137. airflow/providers/google/cloud/operators/natural_language.py +5 -3
  138. airflow/providers/google/cloud/operators/pubsub.py +69 -21
  139. airflow/providers/google/cloud/operators/spanner.py +53 -45
  140. airflow/providers/google/cloud/operators/speech_to_text.py +5 -4
  141. airflow/providers/google/cloud/operators/stackdriver.py +5 -11
  142. airflow/providers/google/cloud/operators/tasks.py +6 -15
  143. airflow/providers/google/cloud/operators/text_to_speech.py +4 -3
  144. airflow/providers/google/cloud/operators/translate.py +46 -20
  145. airflow/providers/google/cloud/operators/translate_speech.py +4 -3
  146. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +44 -34
  147. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +34 -12
  148. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +62 -53
  149. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +75 -11
  150. airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +48 -12
  151. airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
  152. airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
  153. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
  154. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +16 -12
  155. airflow/providers/google/cloud/operators/vertex_ai/model_service.py +62 -14
  156. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +35 -10
  157. airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
  158. airflow/providers/google/cloud/operators/video_intelligence.py +5 -3
  159. airflow/providers/google/cloud/operators/vision.py +7 -5
  160. airflow/providers/google/cloud/operators/workflows.py +24 -19
  161. airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
  162. airflow/providers/google/cloud/sensors/bigquery.py +2 -2
  163. airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -4
  164. airflow/providers/google/cloud/sensors/bigtable.py +14 -6
  165. airflow/providers/google/cloud/sensors/cloud_composer.py +535 -33
  166. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -5
  167. airflow/providers/google/cloud/sensors/dataflow.py +27 -10
  168. airflow/providers/google/cloud/sensors/dataform.py +2 -2
  169. airflow/providers/google/cloud/sensors/datafusion.py +4 -4
  170. airflow/providers/google/cloud/sensors/dataplex.py +7 -5
  171. airflow/providers/google/cloud/sensors/dataprep.py +2 -2
  172. airflow/providers/google/cloud/sensors/dataproc.py +10 -9
  173. airflow/providers/google/cloud/sensors/dataproc_metastore.py +4 -3
  174. airflow/providers/google/cloud/sensors/gcs.py +22 -21
  175. airflow/providers/google/cloud/sensors/looker.py +5 -5
  176. airflow/providers/google/cloud/sensors/pubsub.py +20 -20
  177. airflow/providers/google/cloud/sensors/tasks.py +2 -2
  178. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
  179. airflow/providers/google/cloud/sensors/workflows.py +6 -4
  180. airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
  181. airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
  182. airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
  183. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
  184. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +14 -13
  185. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
  186. airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
  187. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
  188. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
  189. airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
  190. airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +18 -22
  191. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -5
  192. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +45 -38
  193. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
  194. airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
  195. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
  196. airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
  197. airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
  198. airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
  199. airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
  200. airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
  201. airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
  202. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +44 -12
  203. airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
  204. airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
  205. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +36 -14
  206. airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
  207. airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
  208. airflow/providers/google/cloud/triggers/bigquery.py +75 -34
  209. airflow/providers/google/cloud/triggers/bigquery_dts.py +2 -1
  210. airflow/providers/google/cloud/triggers/cloud_batch.py +2 -1
  211. airflow/providers/google/cloud/triggers/cloud_build.py +3 -2
  212. airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
  213. airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
  214. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +96 -5
  215. airflow/providers/google/cloud/triggers/dataflow.py +125 -2
  216. airflow/providers/google/cloud/triggers/datafusion.py +1 -1
  217. airflow/providers/google/cloud/triggers/dataplex.py +16 -3
  218. airflow/providers/google/cloud/triggers/dataproc.py +124 -53
  219. airflow/providers/google/cloud/triggers/kubernetes_engine.py +46 -28
  220. airflow/providers/google/cloud/triggers/mlengine.py +1 -1
  221. airflow/providers/google/cloud/triggers/pubsub.py +17 -20
  222. airflow/providers/google/cloud/triggers/vertex_ai.py +8 -7
  223. airflow/providers/google/cloud/utils/bigquery.py +5 -7
  224. airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
  225. airflow/providers/google/cloud/utils/credentials_provider.py +4 -3
  226. airflow/providers/google/cloud/utils/dataform.py +1 -1
  227. airflow/providers/google/cloud/utils/external_token_supplier.py +0 -1
  228. airflow/providers/google/cloud/utils/field_validator.py +1 -2
  229. airflow/providers/google/cloud/utils/validators.py +43 -0
  230. airflow/providers/google/common/auth_backend/google_openid.py +26 -9
  231. airflow/providers/google/common/consts.py +2 -1
  232. airflow/providers/google/common/deprecated.py +2 -1
  233. airflow/providers/google/common/hooks/base_google.py +40 -43
  234. airflow/providers/google/common/hooks/operation_helpers.py +78 -0
  235. airflow/providers/google/common/links/storage.py +0 -22
  236. airflow/providers/google/common/utils/get_secret.py +31 -0
  237. airflow/providers/google/common/utils/id_token_credentials.py +4 -5
  238. airflow/providers/google/firebase/operators/firestore.py +2 -2
  239. airflow/providers/google/get_provider_info.py +61 -216
  240. airflow/providers/google/go_module_utils.py +35 -3
  241. airflow/providers/google/leveldb/hooks/leveldb.py +30 -6
  242. airflow/providers/google/leveldb/operators/leveldb.py +2 -2
  243. airflow/providers/google/marketing_platform/hooks/analytics_admin.py +3 -2
  244. airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
  245. airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
  246. airflow/providers/google/marketing_platform/links/analytics_admin.py +4 -5
  247. airflow/providers/google/marketing_platform/operators/analytics_admin.py +7 -6
  248. airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
  249. airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
  250. airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
  251. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
  252. airflow/providers/google/marketing_platform/sensors/display_video.py +4 -64
  253. airflow/providers/google/suite/hooks/calendar.py +1 -1
  254. airflow/providers/google/suite/hooks/drive.py +2 -2
  255. airflow/providers/google/suite/hooks/sheets.py +15 -1
  256. airflow/providers/google/suite/operators/sheets.py +8 -3
  257. airflow/providers/google/suite/sensors/drive.py +2 -2
  258. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
  259. airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
  260. airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
  261. airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
  262. airflow/providers/google/version_compat.py +15 -1
  263. {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +117 -72
  264. apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
  265. {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +1 -1
  266. apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
  267. airflow/providers/google/cloud/example_dags/example_cloud_task.py +0 -54
  268. airflow/providers/google/cloud/hooks/automl.py +0 -679
  269. airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
  270. airflow/providers/google/cloud/links/automl.py +0 -193
  271. airflow/providers/google/cloud/operators/automl.py +0 -1360
  272. airflow/providers/google/cloud/operators/life_sciences.py +0 -119
  273. airflow/providers/google/cloud/operators/mlengine.py +0 -1515
  274. airflow/providers/google/cloud/utils/mlengine_operator_utils.py +0 -273
  275. apache_airflow_providers_google-14.0.0.dist-info/RECORD +0 -318
  276. airflow/providers/google/cloud/{example_dags → bundles}/__init__.py +0 -0
  277. {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
  278. {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
@@ -20,107 +20,27 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  from collections.abc import Sequence
23
- from typing import TYPE_CHECKING
23
+ from typing import TYPE_CHECKING, Any, Literal
24
24
 
25
- from airflow.exceptions import AirflowProviderDeprecationWarning
26
- from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
25
+ from google.api_core import exceptions
26
+
27
+ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
28
+ from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
29
+ ExperimentRunHook,
30
+ GenerativeModelHook,
31
+ )
27
32
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
28
33
  from airflow.providers.google.common.deprecated import deprecated
29
34
 
30
35
  if TYPE_CHECKING:
31
- from airflow.utils.context import Context
36
+ from airflow.providers.common.compat.sdk import Context
32
37
 
33
38
 
34
39
  @deprecated(
35
- planned_removal_date="April 09, 2025",
36
- use_instead="GenerativeModelGenerateContentOperator",
40
+ planned_removal_date="January 3, 2026",
41
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateEmbeddingsOperator",
37
42
  category=AirflowProviderDeprecationWarning,
38
43
  )
39
- class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
40
- """
41
- Uses the Vertex AI PaLM API to generate natural language text.
42
-
43
- :param project_id: Required. The ID of the Google Cloud project that the
44
- service belongs to (templated).
45
- :param location: Required. The ID of the Google Cloud location that the
46
- service belongs to (templated).
47
- :param prompt: Required. Inputs or queries that a user or a program gives
48
- to the Vertex AI PaLM API, in order to elicit a specific response (templated).
49
- :param pretrained_model: By default uses the pre-trained model `text-bison`,
50
- optimized for performing natural language tasks such as classification,
51
- summarization, extraction, content creation, and ideation.
52
- :param temperature: Temperature controls the degree of randomness in token
53
- selection. Defaults to 0.0.
54
- :param max_output_tokens: Token limit determines the maximum amount of text
55
- output. Defaults to 256.
56
- :param top_p: Tokens are selected from most probable to least until the sum
57
- of their probabilities equals the top_p value. Defaults to 0.8.
58
- :param top_k: A top_k of 1 means the selected token is the most probable
59
- among all tokens. Defaults to 0.4.
60
- :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
61
- :param impersonation_chain: Optional service account to impersonate using short-term
62
- credentials, or chained list of accounts required to get the access_token
63
- of the last account in the list, which will be impersonated in the request.
64
- If set as a string, the account must grant the originating account
65
- the Service Account Token Creator IAM role.
66
- If set as a sequence, the identities from the list must grant
67
- Service Account Token Creator IAM role to the directly preceding identity, with first
68
- account from the list granting this role to the originating account (templated).
69
- """
70
-
71
- template_fields = ("location", "project_id", "impersonation_chain", "prompt")
72
-
73
- def __init__(
74
- self,
75
- *,
76
- project_id: str,
77
- location: str,
78
- prompt: str,
79
- pretrained_model: str = "text-bison",
80
- temperature: float = 0.0,
81
- max_output_tokens: int = 256,
82
- top_p: float = 0.8,
83
- top_k: int = 40,
84
- gcp_conn_id: str = "google_cloud_default",
85
- impersonation_chain: str | Sequence[str] | None = None,
86
- **kwargs,
87
- ) -> None:
88
- super().__init__(**kwargs)
89
- self.project_id = project_id
90
- self.location = location
91
- self.prompt = prompt
92
- self.pretrained_model = pretrained_model
93
- self.temperature = temperature
94
- self.max_output_tokens = max_output_tokens
95
- self.top_p = top_p
96
- self.top_k = top_k
97
- self.gcp_conn_id = gcp_conn_id
98
- self.impersonation_chain = impersonation_chain
99
-
100
- def execute(self, context: Context):
101
- self.hook = GenerativeModelHook(
102
- gcp_conn_id=self.gcp_conn_id,
103
- impersonation_chain=self.impersonation_chain,
104
- )
105
-
106
- self.log.info("Submitting prompt")
107
- response = self.hook.text_generation_model_predict(
108
- project_id=self.project_id,
109
- location=self.location,
110
- prompt=self.prompt,
111
- pretrained_model=self.pretrained_model,
112
- temperature=self.temperature,
113
- max_output_tokens=self.max_output_tokens,
114
- top_p=self.top_p,
115
- top_k=self.top_k,
116
- )
117
-
118
- self.log.info("Model response: %s", response)
119
- self.xcom_push(context, key="model_response", value=response)
120
-
121
- return response
122
-
123
-
124
44
  class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
125
45
  """
126
46
  Uses the Vertex AI Embeddings API to generate embeddings based on prompt.
@@ -130,9 +50,8 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
130
50
  :param location: Required. The ID of the Google Cloud location that the
131
51
  service belongs to (templated).
132
52
  :param prompt: Required. Inputs or queries that a user or a program gives
133
- to the Vertex AI PaLM API, in order to elicit a specific response (templated).
134
- :param pretrained_model: By default uses the pre-trained model `textembedding-gecko`,
135
- optimized for performing text embeddings.
53
+ to the Vertex AI Generative Model API, in order to elicit a specific response (templated).
54
+ :param pretrained_model: Required. Model, optimized for performing text embeddings.
136
55
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
137
56
  :param impersonation_chain: Optional service account to impersonate using short-term
138
57
  credentials, or chained list of accounts required to get the access_token
@@ -144,7 +63,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
144
63
  account from the list granting this role to the originating account (templated).
145
64
  """
146
65
 
147
- template_fields = ("location", "project_id", "impersonation_chain", "prompt")
66
+ template_fields = ("location", "project_id", "impersonation_chain", "prompt", "pretrained_model")
148
67
 
149
68
  def __init__(
150
69
  self,
@@ -152,7 +71,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
152
71
  project_id: str,
153
72
  location: str,
154
73
  prompt: str,
155
- pretrained_model: str = "textembedding-gecko",
74
+ pretrained_model: str,
156
75
  gcp_conn_id: str = "google_cloud_default",
157
76
  impersonation_chain: str | Sequence[str] | None = None,
158
77
  **kwargs,
@@ -180,11 +99,16 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
180
99
  )
181
100
 
182
101
  self.log.info("Model response: %s", response)
183
- self.xcom_push(context, key="model_response", value=response)
102
+ context["ti"].xcom_push(key="model_response", value=response)
184
103
 
185
104
  return response
186
105
 
187
106
 
107
+ @deprecated(
108
+ planned_removal_date="January 3, 2026",
109
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
110
+ category=AirflowProviderDeprecationWarning,
111
+ )
188
112
  class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
189
113
  """
190
114
  Use the Vertex AI Gemini Pro foundation model to generate content.
@@ -199,10 +123,9 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
199
123
  :param safety_settings: Optional. Per request settings for blocking unsafe content.
200
124
  :param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
201
125
  :param system_instruction: Optional. An instruction given to the model to guide its behavior.
202
- :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
203
- supporting prompts with text-only input, including natural language
204
- tasks, multi-turn text and code chat, and code generation. It can
205
- output text and code.
126
+ :param pretrained_model: Required. The name of the model to use for content generation,
127
+ which can be a text-only or multimodal model. For example, `gemini-pro` or
128
+ `gemini-pro-vision`.
206
129
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
207
130
  :param impersonation_chain: Optional service account to impersonate using short-term
208
131
  credentials, or chained list of accounts required to get the access_token
@@ -226,7 +149,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
226
149
  generation_config: dict | None = None,
227
150
  safety_settings: dict | None = None,
228
151
  system_instruction: str | None = None,
229
- pretrained_model: str = "gemini-pro",
152
+ pretrained_model: str,
230
153
  gcp_conn_id: str = "google_cloud_default",
231
154
  impersonation_chain: str | Sequence[str] | None = None,
232
155
  **kwargs,
@@ -260,11 +183,16 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
260
183
  )
261
184
 
262
185
  self.log.info("Model response: %s", response)
263
- self.xcom_push(context, key="model_response", value=response)
186
+ context["ti"].xcom_push(key="model_response", value=response)
264
187
 
265
188
  return response
266
189
 
267
190
 
191
+ @deprecated(
192
+ planned_removal_date="January 3, 2026",
193
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAISupervisedFineTuningTrainOperator",
194
+ category=AirflowProviderDeprecationWarning,
195
+ )
268
196
  class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
269
197
  """
270
198
  Use the Supervised Fine Tuning API to create a tuning job.
@@ -298,7 +226,14 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
298
226
  account from the list granting this role to the originating account (templated).
299
227
  """
300
228
 
301
- template_fields = ("location", "project_id", "impersonation_chain", "train_dataset", "validation_dataset")
229
+ template_fields = (
230
+ "location",
231
+ "project_id",
232
+ "impersonation_chain",
233
+ "train_dataset",
234
+ "validation_dataset",
235
+ "source_model",
236
+ )
302
237
 
303
238
  def __init__(
304
239
  self,
@@ -310,7 +245,7 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
310
245
  tuned_model_display_name: str | None = None,
311
246
  validation_dataset: str | None = None,
312
247
  epochs: int | None = None,
313
- adapter_size: int | None = None,
248
+ adapter_size: Literal[1, 4, 8, 16] | None = None,
314
249
  learning_rate_multiplier: float | None = None,
315
250
  gcp_conn_id: str = "google_cloud_default",
316
251
  impersonation_chain: str | Sequence[str] | None = None,
@@ -349,8 +284,8 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
349
284
  self.log.info("Tuned Model Name: %s", response.tuned_model_name)
350
285
  self.log.info("Tuned Model Endpoint Name: %s", response.tuned_model_endpoint_name)
351
286
 
352
- self.xcom_push(context, key="tuned_model_name", value=response.tuned_model_name)
353
- self.xcom_push(context, key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
287
+ context["ti"].xcom_push(key="tuned_model_name", value=response.tuned_model_name)
288
+ context["ti"].xcom_push(key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
354
289
 
355
290
  result = {
356
291
  "tuned_model_name": response.tuned_model_name,
@@ -360,6 +295,11 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
360
295
  return result
361
296
 
362
297
 
298
+ @deprecated(
299
+ planned_removal_date="January 3, 2026",
300
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICountTokensOperator",
301
+ category=AirflowProviderDeprecationWarning,
302
+ )
363
303
  class CountTokensOperator(GoogleCloudBaseOperator):
364
304
  """
365
305
  Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.
@@ -370,10 +310,9 @@ class CountTokensOperator(GoogleCloudBaseOperator):
370
310
  service belongs to (templated).
371
311
  :param contents: Required. The multi-part content of a message that a user or a program
372
312
  gives to the generative model, in order to elicit a specific response.
373
- :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
374
- supporting prompts with text-only input, including natural language
375
- tasks, multi-turn text and code chat, and code generation. It can
376
- output text and code.
313
+ :param pretrained_model: Required. Model, supporting prompts with text-only input,
314
+ including natural language tasks, multi-turn text and code chat,
315
+ and code generation. It can output text and code.
377
316
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
378
317
  :param impersonation_chain: Optional service account to impersonate using short-term
379
318
  credentials, or chained list of accounts required to get the access_token
@@ -393,7 +332,7 @@ class CountTokensOperator(GoogleCloudBaseOperator):
393
332
  project_id: str,
394
333
  location: str,
395
334
  contents: list,
396
- pretrained_model: str = "gemini-pro",
335
+ pretrained_model: str,
397
336
  gcp_conn_id: str = "google_cloud_default",
398
337
  impersonation_chain: str | Sequence[str] | None = None,
399
338
  **kwargs,
@@ -421,8 +360,8 @@ class CountTokensOperator(GoogleCloudBaseOperator):
421
360
  self.log.info("Total tokens: %s", response.total_tokens)
422
361
  self.log.info("Total billable characters: %s", response.total_billable_characters)
423
362
 
424
- self.xcom_push(context, key="total_tokens", value=response.total_tokens)
425
- self.xcom_push(context, key="total_billable_characters", value=response.total_billable_characters)
363
+ context["ti"].xcom_push(key="total_tokens", value=response.total_tokens)
364
+ context["ti"].xcom_push(key="total_billable_characters", value=response.total_billable_characters)
426
365
 
427
366
 
428
367
  class RunEvaluationOperator(GoogleCloudBaseOperator):
@@ -524,6 +463,11 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
524
463
  return response.summary_metrics
525
464
 
526
465
 
466
+ @deprecated(
467
+ planned_removal_date="January 3, 2026",
468
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICreateCachedContentOperator",
469
+ category=AirflowProviderDeprecationWarning,
470
+ )
527
471
  class CreateCachedContentOperator(GoogleCloudBaseOperator):
528
472
  """
529
473
  Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.
@@ -562,8 +506,8 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
562
506
  project_id: str,
563
507
  location: str,
564
508
  model_name: str,
565
- system_instruction: str | None = None,
566
- contents: list | None = None,
509
+ system_instruction: Any | None = None,
510
+ contents: list[Any] | None = None,
567
511
  ttl_hours: float = 1,
568
512
  display_name: str | None = None,
569
513
  gcp_conn_id: str = "google_cloud_default",
@@ -603,6 +547,11 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
603
547
  return cached_content_name
604
548
 
605
549
 
550
+ @deprecated(
551
+ planned_removal_date="January 3, 2026",
552
+ use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
553
+ category=AirflowProviderDeprecationWarning,
554
+ )
606
555
  class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
607
556
  """
608
557
  Generate a response from CachedContent.
@@ -674,3 +623,73 @@ class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
674
623
  self.log.info("Cached Content Response: %s", cached_content_text)
675
624
 
676
625
  return cached_content_text
626
+
627
+
628
+ @deprecated(
629
+ planned_removal_date="January 3, 2026",
630
+ use_instead="airflow.providers.google.cloud.operators.vertex_ai.experiment_service.DeleteExperimentRunOperator",
631
+ category=AirflowProviderDeprecationWarning,
632
+ )
633
+ class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
634
+ """
635
+ Use the Rapid Evaluation API to evaluate a model.
636
+
637
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
638
+ :param location: Required. The ID of the Google Cloud location that the service belongs to.
639
+ :param experiment_name: Required. The name of the evaluation experiment.
640
+ :param experiment_run_name: Required. The specific run name or ID for this experiment.
641
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
642
+ :param impersonation_chain: Optional service account to impersonate using short-term
643
+ credentials, or chained list of accounts required to get the access_token
644
+ of the last account in the list, which will be impersonated in the request.
645
+ If set as a string, the account must grant the originating account
646
+ the Service Account Token Creator IAM role.
647
+ If set as a sequence, the identities from the list must grant
648
+ Service Account Token Creator IAM role to the directly preceding identity, with first
649
+ account from the list granting this role to the originating account (templated).
650
+ """
651
+
652
+ template_fields = (
653
+ "location",
654
+ "project_id",
655
+ "impersonation_chain",
656
+ "experiment_name",
657
+ "experiment_run_name",
658
+ )
659
+
660
+ def __init__(
661
+ self,
662
+ *,
663
+ project_id: str,
664
+ location: str,
665
+ experiment_name: str,
666
+ experiment_run_name: str,
667
+ gcp_conn_id: str = "google_cloud_default",
668
+ impersonation_chain: str | Sequence[str] | None = None,
669
+ **kwargs,
670
+ ) -> None:
671
+ super().__init__(**kwargs)
672
+ self.project_id = project_id
673
+ self.location = location
674
+ self.experiment_name = experiment_name
675
+ self.experiment_run_name = experiment_run_name
676
+ self.gcp_conn_id = gcp_conn_id
677
+ self.impersonation_chain = impersonation_chain
678
+
679
+ def execute(self, context: Context) -> None:
680
+ self.hook = ExperimentRunHook(
681
+ gcp_conn_id=self.gcp_conn_id,
682
+ impersonation_chain=self.impersonation_chain,
683
+ )
684
+
685
+ try:
686
+ self.hook.delete_experiment_run(
687
+ project_id=self.project_id,
688
+ location=self.location,
689
+ experiment_name=self.experiment_name,
690
+ experiment_run_name=self.experiment_run_name,
691
+ )
692
+ except exceptions.NotFound:
693
+ raise AirflowException(f"Experiment Run with name {self.experiment_run_name} not found")
694
+
695
+ self.log.info("Deleted experiment run: %s", self.experiment_run_name)
@@ -23,6 +23,10 @@ from __future__ import annotations
23
23
  from collections.abc import Sequence
24
24
  from typing import TYPE_CHECKING, Any
25
25
 
26
+ from google.api_core.exceptions import NotFound
27
+ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
28
+ from google.cloud.aiplatform_v1 import types
29
+
26
30
  from airflow.configuration import conf
27
31
  from airflow.exceptions import AirflowException
28
32
  from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
@@ -34,15 +38,13 @@ from airflow.providers.google.cloud.links.vertex_ai import (
34
38
  )
35
39
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
36
40
  from airflow.providers.google.cloud.triggers.vertex_ai import CreateHyperparameterTuningJobTrigger
37
- from google.api_core.exceptions import NotFound
38
- from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
39
- from google.cloud.aiplatform_v1 import types
40
41
 
41
42
  if TYPE_CHECKING:
42
- from airflow.utils.context import Context
43
43
  from google.api_core.retry import Retry
44
44
  from google.cloud.aiplatform import HyperparameterTuningJob, gapic, hyperparameter_tuning
45
45
 
46
+ from airflow.providers.common.compat.sdk import Context
47
+
46
48
 
47
49
  class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
48
50
  """
@@ -255,10 +257,8 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
255
257
  hyperparameter_tuning_job_id = hyperparameter_tuning_job.name
256
258
  self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id)
257
259
 
258
- self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
259
- VertexAITrainingLink.persist(
260
- context=context, task_instance=self, training_id=hyperparameter_tuning_job_id
261
- )
260
+ context["ti"].xcom_push(key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
261
+ VertexAITrainingLink.persist(context=context, training_id=hyperparameter_tuning_job_id)
262
262
 
263
263
  if self.deferrable:
264
264
  self.defer(
@@ -353,9 +353,7 @@ class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
353
353
  timeout=self.timeout,
354
354
  metadata=self.metadata,
355
355
  )
356
- VertexAITrainingLink.persist(
357
- context=context, task_instance=self, training_id=self.hyperparameter_tuning_job_id
358
- )
356
+ VertexAITrainingLink.persist(context=context, training_id=self.hyperparameter_tuning_job_id)
359
357
  self.log.info("Hyperparameter tuning job was gotten.")
360
358
  return types.HyperparameterTuningJob.to_dict(result)
361
359
  except NotFound:
@@ -485,6 +483,12 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
485
483
  self.gcp_conn_id = gcp_conn_id
486
484
  self.impersonation_chain = impersonation_chain
487
485
 
486
+ @property
487
+ def extra_links_params(self) -> dict[str, Any]:
488
+ return {
489
+ "project_id": self.project_id,
490
+ }
491
+
488
492
  def execute(self, context: Context):
489
493
  hook = HyperparameterTuningJobHook(
490
494
  gcp_conn_id=self.gcp_conn_id,
@@ -501,5 +505,5 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
501
505
  timeout=self.timeout,
502
506
  metadata=self.metadata,
503
507
  )
504
- VertexAIHyperparameterTuningJobListLink.persist(context=context, task_instance=self)
508
+ VertexAIHyperparameterTuningJobListLink.persist(context=context)
505
509
  return [types.HyperparameterTuningJob.to_dict(result) for result in results]
@@ -20,7 +20,11 @@
20
20
  from __future__ import annotations
21
21
 
22
22
  from collections.abc import Sequence
23
- from typing import TYPE_CHECKING
23
+ from typing import TYPE_CHECKING, Any
24
+
25
+ from google.api_core.exceptions import NotFound
26
+ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
27
+ from google.cloud.aiplatform_v1.types import Model, model_service
24
28
 
25
29
  from airflow.providers.google.cloud.hooks.vertex_ai.model_service import ModelServiceHook
26
30
  from airflow.providers.google.cloud.links.vertex_ai import (
@@ -29,14 +33,12 @@ from airflow.providers.google.cloud.links.vertex_ai import (
29
33
  VertexAIModelListLink,
30
34
  )
31
35
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
32
- from google.api_core.exceptions import NotFound
33
- from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
34
- from google.cloud.aiplatform_v1.types import Model, model_service
35
36
 
36
37
  if TYPE_CHECKING:
37
- from airflow.utils.context import Context
38
38
  from google.api_core.retry import Retry
39
39
 
40
+ from airflow.providers.common.compat.sdk import Context
41
+
40
42
 
41
43
  class DeleteModelOperator(GoogleCloudBaseOperator):
42
44
  """
@@ -159,6 +161,13 @@ class GetModelOperator(GoogleCloudBaseOperator):
159
161
  self.gcp_conn_id = gcp_conn_id
160
162
  self.impersonation_chain = impersonation_chain
161
163
 
164
+ @property
165
+ def extra_links_params(self) -> dict[str, Any]:
166
+ return {
167
+ "region": self.region,
168
+ "project_id": self.project_id,
169
+ }
170
+
162
171
  def execute(self, context: Context):
163
172
  hook = ModelServiceHook(
164
173
  gcp_conn_id=self.gcp_conn_id,
@@ -177,8 +186,8 @@ class GetModelOperator(GoogleCloudBaseOperator):
177
186
  )
178
187
  self.log.info("Model found. Model ID: %s", self.model_id)
179
188
 
180
- self.xcom_push(context, key="model_id", value=self.model_id)
181
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
189
+ context["ti"].xcom_push(key="model_id", value=self.model_id)
190
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
182
191
  return Model.to_dict(model)
183
192
  except NotFound:
184
193
  self.log.info("The Model ID %s does not exist.", self.model_id)
@@ -255,7 +264,12 @@ class ExportModelOperator(GoogleCloudBaseOperator):
255
264
  metadata=self.metadata,
256
265
  )
257
266
  hook.wait_for_operation(timeout=self.timeout, operation=operation)
258
- VertexAIModelExportLink.persist(context=context, task_instance=self)
267
+ VertexAIModelExportLink.persist(
268
+ context=context,
269
+ output_config=self.output_config,
270
+ model_id=self.model_id,
271
+ project_id=self.project_id,
272
+ )
259
273
  self.log.info("Model was exported.")
260
274
  except NotFound:
261
275
  self.log.info("The Model ID %s does not exist.", self.model_id)
@@ -333,6 +347,12 @@ class ListModelsOperator(GoogleCloudBaseOperator):
333
347
  self.gcp_conn_id = gcp_conn_id
334
348
  self.impersonation_chain = impersonation_chain
335
349
 
350
+ @property
351
+ def extra_links_params(self) -> dict[str, Any]:
352
+ return {
353
+ "project_id": self.project_id,
354
+ }
355
+
336
356
  def execute(self, context: Context):
337
357
  hook = ModelServiceHook(
338
358
  gcp_conn_id=self.gcp_conn_id,
@@ -350,7 +370,7 @@ class ListModelsOperator(GoogleCloudBaseOperator):
350
370
  timeout=self.timeout,
351
371
  metadata=self.metadata,
352
372
  )
353
- VertexAIModelListLink.persist(context=context, task_instance=self)
373
+ VertexAIModelListLink.persist(context=context)
354
374
  return [Model.to_dict(result) for result in results]
355
375
 
356
376
 
@@ -405,6 +425,13 @@ class UploadModelOperator(GoogleCloudBaseOperator):
405
425
  self.gcp_conn_id = gcp_conn_id
406
426
  self.impersonation_chain = impersonation_chain
407
427
 
428
+ @property
429
+ def extra_links_params(self) -> dict[str, Any]:
430
+ return {
431
+ "region": self.region,
432
+ "project_id": self.project_id,
433
+ }
434
+
408
435
  def execute(self, context: Context):
409
436
  hook = ModelServiceHook(
410
437
  gcp_conn_id=self.gcp_conn_id,
@@ -426,8 +453,8 @@ class UploadModelOperator(GoogleCloudBaseOperator):
426
453
  model_id = hook.extract_model_id(model_resp)
427
454
  self.log.info("Model was uploaded. Model ID: %s", model_id)
428
455
 
429
- self.xcom_push(context, key="model_id", value=model_id)
430
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
456
+ context["ti"].xcom_push(key="model_id", value=model_id)
457
+ VertexAIModelLink.persist(context=context, model_id=model_id)
431
458
  return model_resp
432
459
 
433
460
 
@@ -551,6 +578,13 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
551
578
  self.gcp_conn_id = gcp_conn_id
552
579
  self.impersonation_chain = impersonation_chain
553
580
 
581
+ @property
582
+ def extra_links_params(self) -> dict[str, Any]:
583
+ return {
584
+ "region": self.region,
585
+ "project_id": self.project_id,
586
+ }
587
+
554
588
  def execute(self, context: Context):
555
589
  hook = ModelServiceHook(
556
590
  gcp_conn_id=self.gcp_conn_id,
@@ -569,7 +603,7 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
569
603
  timeout=self.timeout,
570
604
  metadata=self.metadata,
571
605
  )
572
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
606
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
573
607
  return Model.to_dict(updated_model)
574
608
 
575
609
 
@@ -625,6 +659,13 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
625
659
  self.gcp_conn_id = gcp_conn_id
626
660
  self.impersonation_chain = impersonation_chain
627
661
 
662
+ @property
663
+ def extra_links_params(self) -> dict[str, Any]:
664
+ return {
665
+ "region": self.region,
666
+ "project_id": self.project_id,
667
+ }
668
+
628
669
  def execute(self, context: Context):
629
670
  hook = ModelServiceHook(
630
671
  gcp_conn_id=self.gcp_conn_id,
@@ -643,7 +684,7 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
643
684
  timeout=self.timeout,
644
685
  metadata=self.metadata,
645
686
  )
646
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
687
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
647
688
  return Model.to_dict(updated_model)
648
689
 
649
690
 
@@ -699,6 +740,13 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
699
740
  self.gcp_conn_id = gcp_conn_id
700
741
  self.impersonation_chain = impersonation_chain
701
742
 
743
+ @property
744
+ def extra_links_params(self) -> dict[str, Any]:
745
+ return {
746
+ "region": self.region,
747
+ "project_id": self.project_id,
748
+ }
749
+
702
750
  def execute(self, context: Context):
703
751
  hook = ModelServiceHook(
704
752
  gcp_conn_id=self.gcp_conn_id,
@@ -719,7 +767,7 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
719
767
  timeout=self.timeout,
720
768
  metadata=self.metadata,
721
769
  )
722
- VertexAIModelLink.persist(context=context, task_instance=self, model_id=self.model_id)
770
+ VertexAIModelLink.persist(context=context, model_id=self.model_id)
723
771
  return Model.to_dict(updated_model)
724
772
 
725
773