apache-airflow-providers-google 10.16.0rc1__py3-none-any.whl → 10.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +5 -4
  3. airflow/providers/google/ads/operators/ads.py +1 -0
  4. airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +1 -0
  5. airflow/providers/google/cloud/example_dags/example_cloud_task.py +1 -0
  6. airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py +1 -0
  7. airflow/providers/google/cloud/example_dags/example_looker.py +1 -0
  8. airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py +1 -0
  9. airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py +1 -0
  10. airflow/providers/google/cloud/fs/gcs.py +1 -2
  11. airflow/providers/google/cloud/hooks/automl.py +1 -0
  12. airflow/providers/google/cloud/hooks/bigquery.py +87 -24
  13. airflow/providers/google/cloud/hooks/bigquery_dts.py +1 -0
  14. airflow/providers/google/cloud/hooks/bigtable.py +1 -0
  15. airflow/providers/google/cloud/hooks/cloud_build.py +1 -0
  16. airflow/providers/google/cloud/hooks/cloud_memorystore.py +1 -0
  17. airflow/providers/google/cloud/hooks/cloud_sql.py +1 -0
  18. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +9 -4
  19. airflow/providers/google/cloud/hooks/compute.py +1 -0
  20. airflow/providers/google/cloud/hooks/compute_ssh.py +2 -2
  21. airflow/providers/google/cloud/hooks/dataflow.py +6 -5
  22. airflow/providers/google/cloud/hooks/datafusion.py +1 -0
  23. airflow/providers/google/cloud/hooks/datapipeline.py +1 -0
  24. airflow/providers/google/cloud/hooks/dataplex.py +1 -0
  25. airflow/providers/google/cloud/hooks/dataprep.py +1 -0
  26. airflow/providers/google/cloud/hooks/dataproc.py +3 -2
  27. airflow/providers/google/cloud/hooks/dataproc_metastore.py +1 -0
  28. airflow/providers/google/cloud/hooks/datastore.py +1 -0
  29. airflow/providers/google/cloud/hooks/dlp.py +1 -0
  30. airflow/providers/google/cloud/hooks/functions.py +1 -0
  31. airflow/providers/google/cloud/hooks/gcs.py +12 -5
  32. airflow/providers/google/cloud/hooks/kms.py +1 -0
  33. airflow/providers/google/cloud/hooks/kubernetes_engine.py +178 -300
  34. airflow/providers/google/cloud/hooks/life_sciences.py +1 -0
  35. airflow/providers/google/cloud/hooks/looker.py +1 -0
  36. airflow/providers/google/cloud/hooks/mlengine.py +1 -0
  37. airflow/providers/google/cloud/hooks/natural_language.py +1 -0
  38. airflow/providers/google/cloud/hooks/os_login.py +1 -0
  39. airflow/providers/google/cloud/hooks/pubsub.py +1 -0
  40. airflow/providers/google/cloud/hooks/secret_manager.py +1 -0
  41. airflow/providers/google/cloud/hooks/spanner.py +1 -0
  42. airflow/providers/google/cloud/hooks/speech_to_text.py +1 -0
  43. airflow/providers/google/cloud/hooks/stackdriver.py +1 -0
  44. airflow/providers/google/cloud/hooks/text_to_speech.py +1 -0
  45. airflow/providers/google/cloud/hooks/translate.py +1 -0
  46. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +1 -0
  47. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +255 -3
  48. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1 -0
  49. airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +1 -0
  50. airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +1 -0
  51. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +197 -0
  52. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +9 -9
  53. airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +231 -12
  54. airflow/providers/google/cloud/hooks/video_intelligence.py +1 -0
  55. airflow/providers/google/cloud/hooks/vision.py +1 -0
  56. airflow/providers/google/cloud/links/automl.py +1 -0
  57. airflow/providers/google/cloud/links/bigquery.py +1 -0
  58. airflow/providers/google/cloud/links/bigquery_dts.py +1 -0
  59. airflow/providers/google/cloud/links/cloud_memorystore.py +1 -0
  60. airflow/providers/google/cloud/links/cloud_sql.py +1 -0
  61. airflow/providers/google/cloud/links/cloud_tasks.py +1 -0
  62. airflow/providers/google/cloud/links/compute.py +1 -0
  63. airflow/providers/google/cloud/links/datacatalog.py +1 -0
  64. airflow/providers/google/cloud/links/dataflow.py +1 -0
  65. airflow/providers/google/cloud/links/dataform.py +1 -0
  66. airflow/providers/google/cloud/links/datafusion.py +1 -0
  67. airflow/providers/google/cloud/links/dataplex.py +1 -0
  68. airflow/providers/google/cloud/links/dataproc.py +1 -0
  69. airflow/providers/google/cloud/links/kubernetes_engine.py +28 -0
  70. airflow/providers/google/cloud/links/mlengine.py +1 -0
  71. airflow/providers/google/cloud/links/pubsub.py +1 -0
  72. airflow/providers/google/cloud/links/spanner.py +1 -0
  73. airflow/providers/google/cloud/links/stackdriver.py +1 -0
  74. airflow/providers/google/cloud/links/workflows.py +1 -0
  75. airflow/providers/google/cloud/log/stackdriver_task_handler.py +18 -4
  76. airflow/providers/google/cloud/operators/automl.py +1 -0
  77. airflow/providers/google/cloud/operators/bigquery.py +21 -0
  78. airflow/providers/google/cloud/operators/bigquery_dts.py +1 -0
  79. airflow/providers/google/cloud/operators/bigtable.py +1 -0
  80. airflow/providers/google/cloud/operators/cloud_base.py +1 -0
  81. airflow/providers/google/cloud/operators/cloud_build.py +1 -0
  82. airflow/providers/google/cloud/operators/cloud_memorystore.py +1 -0
  83. airflow/providers/google/cloud/operators/cloud_sql.py +1 -0
  84. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +11 -5
  85. airflow/providers/google/cloud/operators/compute.py +1 -0
  86. airflow/providers/google/cloud/operators/dataflow.py +1 -0
  87. airflow/providers/google/cloud/operators/datafusion.py +1 -0
  88. airflow/providers/google/cloud/operators/datapipeline.py +1 -0
  89. airflow/providers/google/cloud/operators/dataprep.py +1 -0
  90. airflow/providers/google/cloud/operators/dataproc.py +3 -2
  91. airflow/providers/google/cloud/operators/dataproc_metastore.py +1 -0
  92. airflow/providers/google/cloud/operators/datastore.py +1 -0
  93. airflow/providers/google/cloud/operators/functions.py +1 -0
  94. airflow/providers/google/cloud/operators/gcs.py +1 -0
  95. airflow/providers/google/cloud/operators/kubernetes_engine.py +600 -4
  96. airflow/providers/google/cloud/operators/life_sciences.py +1 -0
  97. airflow/providers/google/cloud/operators/looker.py +1 -0
  98. airflow/providers/google/cloud/operators/mlengine.py +283 -259
  99. airflow/providers/google/cloud/operators/natural_language.py +1 -0
  100. airflow/providers/google/cloud/operators/pubsub.py +1 -0
  101. airflow/providers/google/cloud/operators/spanner.py +1 -0
  102. airflow/providers/google/cloud/operators/speech_to_text.py +1 -0
  103. airflow/providers/google/cloud/operators/text_to_speech.py +1 -0
  104. airflow/providers/google/cloud/operators/translate.py +1 -0
  105. airflow/providers/google/cloud/operators/translate_speech.py +1 -0
  106. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +14 -7
  107. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +67 -13
  108. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +26 -8
  109. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +1 -0
  110. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +306 -0
  111. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +29 -48
  112. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +52 -17
  113. airflow/providers/google/cloud/operators/video_intelligence.py +1 -0
  114. airflow/providers/google/cloud/operators/vision.py +1 -0
  115. airflow/providers/google/cloud/secrets/secret_manager.py +1 -0
  116. airflow/providers/google/cloud/sensors/bigquery.py +1 -0
  117. airflow/providers/google/cloud/sensors/bigquery_dts.py +1 -0
  118. airflow/providers/google/cloud/sensors/bigtable.py +1 -0
  119. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +1 -0
  120. airflow/providers/google/cloud/sensors/dataflow.py +1 -0
  121. airflow/providers/google/cloud/sensors/dataform.py +1 -0
  122. airflow/providers/google/cloud/sensors/datafusion.py +1 -0
  123. airflow/providers/google/cloud/sensors/dataplex.py +1 -0
  124. airflow/providers/google/cloud/sensors/dataprep.py +1 -0
  125. airflow/providers/google/cloud/sensors/dataproc.py +1 -0
  126. airflow/providers/google/cloud/sensors/gcs.py +1 -0
  127. airflow/providers/google/cloud/sensors/looker.py +1 -0
  128. airflow/providers/google/cloud/sensors/pubsub.py +1 -0
  129. airflow/providers/google/cloud/sensors/tasks.py +1 -0
  130. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +1 -0
  131. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -0
  132. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -0
  133. airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +1 -0
  134. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +1 -0
  135. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
  136. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -0
  137. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +3 -2
  138. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -0
  139. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -0
  140. airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -0
  141. airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -0
  142. airflow/providers/google/cloud/transfers/mysql_to_gcs.py +1 -0
  143. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +19 -1
  144. airflow/providers/google/cloud/transfers/s3_to_gcs.py +3 -5
  145. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -0
  146. airflow/providers/google/cloud/transfers/sql_to_gcs.py +4 -2
  147. airflow/providers/google/cloud/triggers/bigquery.py +4 -3
  148. airflow/providers/google/cloud/triggers/cloud_batch.py +1 -1
  149. airflow/providers/google/cloud/triggers/cloud_run.py +1 -0
  150. airflow/providers/google/cloud/triggers/cloud_sql.py +2 -0
  151. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +14 -2
  152. airflow/providers/google/cloud/triggers/dataplex.py +1 -0
  153. airflow/providers/google/cloud/triggers/dataproc.py +1 -0
  154. airflow/providers/google/cloud/triggers/kubernetes_engine.py +72 -2
  155. airflow/providers/google/cloud/triggers/mlengine.py +2 -0
  156. airflow/providers/google/cloud/triggers/pubsub.py +3 -3
  157. airflow/providers/google/cloud/triggers/vertex_ai.py +107 -15
  158. airflow/providers/google/cloud/utils/field_sanitizer.py +2 -1
  159. airflow/providers/google/cloud/utils/field_validator.py +1 -0
  160. airflow/providers/google/cloud/utils/helpers.py +1 -0
  161. airflow/providers/google/cloud/utils/mlengine_operator_utils.py +1 -0
  162. airflow/providers/google/cloud/utils/mlengine_prediction_summary.py +1 -0
  163. airflow/providers/google/cloud/utils/openlineage.py +1 -0
  164. airflow/providers/google/common/auth_backend/google_openid.py +1 -0
  165. airflow/providers/google/common/hooks/base_google.py +2 -1
  166. airflow/providers/google/common/hooks/discovery_api.py +1 -0
  167. airflow/providers/google/common/links/storage.py +1 -0
  168. airflow/providers/google/common/utils/id_token_credentials.py +1 -0
  169. airflow/providers/google/firebase/hooks/firestore.py +1 -0
  170. airflow/providers/google/get_provider_info.py +9 -3
  171. airflow/providers/google/go_module_utils.py +1 -0
  172. airflow/providers/google/leveldb/hooks/leveldb.py +8 -7
  173. airflow/providers/google/marketing_platform/example_dags/example_display_video.py +1 -0
  174. airflow/providers/google/marketing_platform/hooks/analytics_admin.py +1 -0
  175. airflow/providers/google/marketing_platform/hooks/campaign_manager.py +1 -0
  176. airflow/providers/google/marketing_platform/hooks/display_video.py +1 -0
  177. airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -0
  178. airflow/providers/google/marketing_platform/operators/analytics.py +1 -0
  179. airflow/providers/google/marketing_platform/operators/analytics_admin.py +4 -2
  180. airflow/providers/google/marketing_platform/operators/campaign_manager.py +1 -0
  181. airflow/providers/google/marketing_platform/operators/display_video.py +1 -0
  182. airflow/providers/google/marketing_platform/operators/search_ads.py +1 -0
  183. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +1 -0
  184. airflow/providers/google/marketing_platform/sensors/display_video.py +2 -1
  185. airflow/providers/google/marketing_platform/sensors/search_ads.py +1 -0
  186. airflow/providers/google/suite/hooks/calendar.py +1 -0
  187. airflow/providers/google/suite/hooks/drive.py +1 -0
  188. airflow/providers/google/suite/hooks/sheets.py +1 -0
  189. airflow/providers/google/suite/sensors/drive.py +1 -0
  190. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +7 -0
  191. airflow/providers/google/suite/transfers/gcs_to_sheets.py +4 -1
  192. airflow/providers/google/suite/transfers/local_to_drive.py +1 -0
  193. {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0.dist-info}/METADATA +18 -13
  194. {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0.dist-info}/RECORD +196 -194
  195. {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0.dist-info}/WHEEL +0 -0
  196. {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,306 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one
3
+ # or more contributor license agreements. See the NOTICE file
4
+ # distributed with this work for additional information
5
+ # regarding copyright ownership. The ASF licenses this file
6
+ # to you under the Apache License, Version 2.0 (the
7
+ # "License"); you may not use this file except in compliance
8
+ # with the License. You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing,
13
+ # software distributed under the License is distributed on an
14
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ # KIND, either express or implied. See the License for the
16
+ # specific language governing permissions and limitations
17
+ # under the License.
18
+ """This module contains Google Vertex AI Generative AI operators."""
19
+
20
+ from __future__ import annotations
21
+
22
+ from typing import TYPE_CHECKING, Sequence
23
+
24
+ from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
25
+ from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
26
+
27
+ if TYPE_CHECKING:
28
+ from airflow.utils.context import Context
29
+
30
+
31
+ class PromptLanguageModelOperator(GoogleCloudBaseOperator):
32
+ """
33
+ Uses the Vertex AI PaLM API to generate natural language text.
34
+
35
+ :param project_id: Required. The ID of the Google Cloud project that the
36
+ service belongs to.
37
+ :param location: Required. The ID of the Google Cloud location that the
38
+ service belongs to.
39
+ :param prompt: Required. Inputs or queries that a user or a program gives
40
+ to the Vertex AI PaLM API, in order to elicit a specific response.
41
+ :param pretrained_model: By default uses the pre-trained model `text-bison`,
42
+ optimized for performing natural language tasks such as classification,
43
+ summarization, extraction, content creation, and ideation.
44
+ :param temperature: Temperature controls the degree of randomness in token
45
+ selection. Defaults to 0.0.
46
+ :param max_output_tokens: Token limit determines the maximum amount of text
47
+ output. Defaults to 256.
48
+ :param top_p: Tokens are selected from most probable to least until the sum
49
+ of their probabilities equals the top_p value. Defaults to 0.8.
50
+ :param top_k: A top_k of 1 means the selected token is the most probable
51
+ among all tokens. Defaults to 40.
52
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
53
+ :param impersonation_chain: Optional service account to impersonate using short-term
54
+ credentials, or chained list of accounts required to get the access_token
55
+ of the last account in the list, which will be impersonated in the request.
56
+ If set as a string, the account must grant the originating account
57
+ the Service Account Token Creator IAM role.
58
+ If set as a sequence, the identities from the list must grant
59
+ Service Account Token Creator IAM role to the directly preceding identity, with first
60
+ account from the list granting this role to the originating account (templated).
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ *,
66
+ project_id: str,
67
+ location: str,
68
+ prompt: str,
69
+ pretrained_model: str = "text-bison",
70
+ temperature: float = 0.0,
71
+ max_output_tokens: int = 256,
72
+ top_p: float = 0.8,
73
+ top_k: int = 40,
74
+ gcp_conn_id: str = "google_cloud_default",
75
+ impersonation_chain: str | Sequence[str] | None = None,
76
+ **kwargs,
77
+ ) -> None:
78
+ super().__init__(**kwargs)
79
+ self.project_id = project_id
80
+ self.location = location
81
+ self.prompt = prompt
82
+ self.pretrained_model = pretrained_model
83
+ self.temperature = temperature
84
+ self.max_output_tokens = max_output_tokens
85
+ self.top_p = top_p
86
+ self.top_k = top_k
87
+ self.gcp_conn_id = gcp_conn_id
88
+ self.impersonation_chain = impersonation_chain
89
+
90
+ def execute(self, context: Context):
91
+ self.hook = GenerativeModelHook(
92
+ gcp_conn_id=self.gcp_conn_id,
93
+ impersonation_chain=self.impersonation_chain,
94
+ )
95
+
96
+ self.log.info("Submitting prompt")
97
+ response = self.hook.prompt_language_model(
98
+ project_id=self.project_id,
99
+ location=self.location,
100
+ prompt=self.prompt,
101
+ pretrained_model=self.pretrained_model,
102
+ temperature=self.temperature,
103
+ max_output_tokens=self.max_output_tokens,
104
+ top_p=self.top_p,
105
+ top_k=self.top_k,
106
+ )
107
+
108
+ self.log.info("Model response: %s", response)
109
+ self.xcom_push(context, key="prompt_response", value=response)
110
+
111
+ return response
112
+
113
+
114
+ class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator):
115
+ """
116
+ Uses the Vertex AI PaLM API to generate text embeddings.
117
+
118
+ :param project_id: Required. The ID of the Google Cloud project that the
119
+ service belongs to.
120
+ :param location: Required. The ID of the Google Cloud location that the
121
+ service belongs to.
122
+ :param prompt: Required. Inputs or queries that a user or a program gives
123
+ to the Vertex AI PaLM API, in order to elicit a specific response.
124
+ :param pretrained_model: By default uses the pre-trained model `textembedding-gecko`,
125
+ optimized for performing text embeddings.
126
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
127
+ :param impersonation_chain: Optional service account to impersonate using short-term
128
+ credentials, or chained list of accounts required to get the access_token
129
+ of the last account in the list, which will be impersonated in the request.
130
+ If set as a string, the account must grant the originating account
131
+ the Service Account Token Creator IAM role.
132
+ If set as a sequence, the identities from the list must grant
133
+ Service Account Token Creator IAM role to the directly preceding identity, with first
134
+ account from the list granting this role to the originating account (templated).
135
+ """
136
+
137
+ def __init__(
138
+ self,
139
+ *,
140
+ project_id: str,
141
+ location: str,
142
+ prompt: str,
143
+ pretrained_model: str = "textembedding-gecko",
144
+ gcp_conn_id: str = "google_cloud_default",
145
+ impersonation_chain: str | Sequence[str] | None = None,
146
+ **kwargs,
147
+ ) -> None:
148
+ super().__init__(**kwargs)
149
+ self.project_id = project_id
150
+ self.location = location
151
+ self.prompt = prompt
152
+ self.pretrained_model = pretrained_model
153
+ self.gcp_conn_id = gcp_conn_id
154
+ self.impersonation_chain = impersonation_chain
155
+
156
+ def execute(self, context: Context):
157
+ self.hook = GenerativeModelHook(
158
+ gcp_conn_id=self.gcp_conn_id,
159
+ impersonation_chain=self.impersonation_chain,
160
+ )
161
+
162
+ self.log.info("Generating text embeddings")
163
+ response = self.hook.generate_text_embeddings(
164
+ project_id=self.project_id,
165
+ location=self.location,
166
+ prompt=self.prompt,
167
+ pretrained_model=self.pretrained_model,
168
+ )
169
+
170
+ self.log.info("Model response: %s", response)
171
+ self.xcom_push(context, key="prompt_response", value=response)
172
+
173
+ return response
174
+
175
+
176
+ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
177
+ """
178
+ Use the Vertex AI Gemini Pro foundation model to generate natural language text.
179
+
180
+ :param project_id: Required. The ID of the Google Cloud project that the
181
+ service belongs to.
182
+ :param location: Required. The ID of the Google Cloud location that the
183
+ service belongs to.
184
+ :param prompt: Required. Inputs or queries that a user or a program gives
185
+ to the Multi-modal model, in order to elicit a specific response.
186
+ :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
187
+ supporting prompts with text-only input, including natural language
188
+ tasks, multi-turn text and code chat, and code generation. It can
189
+ output text and code.
190
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
191
+ :param impersonation_chain: Optional service account to impersonate using short-term
192
+ credentials, or chained list of accounts required to get the access_token
193
+ of the last account in the list, which will be impersonated in the request.
194
+ If set as a string, the account must grant the originating account
195
+ the Service Account Token Creator IAM role.
196
+ If set as a sequence, the identities from the list must grant
197
+ Service Account Token Creator IAM role to the directly preceding identity, with first
198
+ account from the list granting this role to the originating account (templated).
199
+ """
200
+
201
+ def __init__(
202
+ self,
203
+ *,
204
+ project_id: str,
205
+ location: str,
206
+ prompt: str,
207
+ pretrained_model: str = "gemini-pro",
208
+ gcp_conn_id: str = "google_cloud_default",
209
+ impersonation_chain: str | Sequence[str] | None = None,
210
+ **kwargs,
211
+ ) -> None:
212
+ super().__init__(**kwargs)
213
+ self.project_id = project_id
214
+ self.location = location
215
+ self.prompt = prompt
216
+ self.pretrained_model = pretrained_model
217
+ self.gcp_conn_id = gcp_conn_id
218
+ self.impersonation_chain = impersonation_chain
219
+
220
+ def execute(self, context: Context):
221
+ self.hook = GenerativeModelHook(
222
+ gcp_conn_id=self.gcp_conn_id,
223
+ impersonation_chain=self.impersonation_chain,
224
+ )
225
+ response = self.hook.prompt_multimodal_model(
226
+ project_id=self.project_id,
227
+ location=self.location,
228
+ prompt=self.prompt,
229
+ pretrained_model=self.pretrained_model,
230
+ )
231
+
232
+ self.log.info("Model response: %s", response)
233
+ self.xcom_push(context, key="prompt_response", value=response)
234
+
235
+ return response
236
+
237
+
238
+ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
239
+ """
240
+ Use the Vertex AI Gemini Pro foundation model to generate natural language text from a prompt and a media file.
241
+
242
+ :param project_id: Required. The ID of the Google Cloud project that the
243
+ service belongs to.
244
+ :param location: Required. The ID of the Google Cloud location that the
245
+ service belongs to.
246
+ :param prompt: Required. Inputs or queries that a user or a program gives
247
+ to the Multi-modal model, in order to elicit a specific response.
248
+ :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
249
+ supporting prompts with text-only input, including natural language
250
+ tasks, multi-turn text and code chat, and code generation. It can
251
+ output text and code.
252
+ :param media_gcs_path: A GCS path to a media file such as an image or a video.
253
+ Can be passed to the multi-modal model as part of the prompt. Used with vision models.
254
+ :param mime_type: Validates the media type presented by the file in the media_gcs_path.
255
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
256
+ :param impersonation_chain: Optional service account to impersonate using short-term
257
+ credentials, or chained list of accounts required to get the access_token
258
+ of the last account in the list, which will be impersonated in the request.
259
+ If set as a string, the account must grant the originating account
260
+ the Service Account Token Creator IAM role.
261
+ If set as a sequence, the identities from the list must grant
262
+ Service Account Token Creator IAM role to the directly preceding identity, with first
263
+ account from the list granting this role to the originating account (templated).
264
+ """
265
+
266
+ def __init__(
267
+ self,
268
+ *,
269
+ project_id: str,
270
+ location: str,
271
+ prompt: str,
272
+ media_gcs_path: str,
273
+ mime_type: str,
274
+ pretrained_model: str = "gemini-pro-vision",
275
+ gcp_conn_id: str = "google_cloud_default",
276
+ impersonation_chain: str | Sequence[str] | None = None,
277
+ **kwargs,
278
+ ) -> None:
279
+ super().__init__(**kwargs)
280
+ self.project_id = project_id
281
+ self.location = location
282
+ self.prompt = prompt
283
+ self.pretrained_model = pretrained_model
284
+ self.media_gcs_path = media_gcs_path
285
+ self.mime_type = mime_type
286
+ self.gcp_conn_id = gcp_conn_id
287
+ self.impersonation_chain = impersonation_chain
288
+
289
+ def execute(self, context: Context):
290
+ self.hook = GenerativeModelHook(
291
+ gcp_conn_id=self.gcp_conn_id,
292
+ impersonation_chain=self.impersonation_chain,
293
+ )
294
+ response = self.hook.prompt_multimodal_model_with_media(
295
+ project_id=self.project_id,
296
+ location=self.location,
297
+ prompt=self.prompt,
298
+ pretrained_model=self.pretrained_model,
299
+ media_gcs_path=self.media_gcs_path,
300
+ mime_type=self.mime_type,
301
+ )
302
+
303
+ self.log.info("Model response: %s", response)
304
+ self.xcom_push(context, key="prompt_response", value=response)
305
+
306
+ return response
@@ -20,14 +20,15 @@
20
20
 
21
21
  from __future__ import annotations
22
22
 
23
+ import warnings
23
24
  from typing import TYPE_CHECKING, Any, Sequence
24
25
 
25
26
  from google.api_core.exceptions import NotFound
26
27
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
27
- from google.cloud.aiplatform_v1.types import HyperparameterTuningJob
28
+ from google.cloud.aiplatform_v1 import types
28
29
 
29
30
  from airflow.configuration import conf
30
- from airflow.exceptions import AirflowException
31
+ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
31
32
  from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
32
33
  HyperparameterTuningJobHook,
33
34
  )
@@ -40,7 +41,7 @@ from airflow.providers.google.cloud.triggers.vertex_ai import CreateHyperparamet
40
41
 
41
42
  if TYPE_CHECKING:
42
43
  from google.api_core.retry import Retry
43
- from google.cloud.aiplatform import gapic, hyperparameter_tuning
44
+ from google.cloud.aiplatform import HyperparameterTuningJob, gapic, hyperparameter_tuning
44
45
 
45
46
  from airflow.utils.context import Context
46
47
 
@@ -127,8 +128,8 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
127
128
  `service_account` is required with provided `tensorboard`. For more information on configuring
128
129
  your service account please visit:
129
130
  https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
130
- :param sync: Whether to execute this method synchronously. If False, this method will unblock, and it
131
- will be executed in a concurrent Future.
131
+ :param sync: (Deprecated) Whether to execute this method synchronously. If False, this method will
132
+ unblock, and it will be executed in a concurrent Future.
132
133
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
133
134
  :param impersonation_chain: Optional service account to impersonate using short-term
134
135
  credentials, or chained list of accounts required to get the access_token
@@ -138,8 +139,7 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
138
139
  If set as a sequence, the identities from the list must grant
139
140
  Service Account Token Creator IAM role to the directly preceding identity, with first
140
141
  account from the list granting this role to the originating account (templated).
141
- :param deferrable: Run operator in the deferrable mode. Note that it requires calling the operator
142
- with `sync=False` parameter.
142
+ :param deferrable: Run operator in the deferrable mode.
143
143
  :param poll_interval: Interval size which defines how often job status is checked in deferrable mode.
144
144
  """
145
145
 
@@ -221,19 +221,18 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
221
221
  self.poll_interval = poll_interval
222
222
 
223
223
  def execute(self, context: Context):
224
- if self.deferrable and self.sync:
225
- raise AirflowException(
226
- "Deferrable mode can be used only with sync=False option. "
227
- "If you are willing to run the operator in deferrable mode, please, set sync=False. "
228
- "Otherwise, disable deferrable mode `deferrable=False`."
229
- )
224
+ warnings.warn(
225
+ "The 'sync' parameter is deprecated and will be removed after 01.09.2024.",
226
+ AirflowProviderDeprecationWarning,
227
+ stacklevel=2,
228
+ )
230
229
 
231
230
  self.log.info("Creating Hyperparameter Tuning job")
232
231
  self.hook = HyperparameterTuningJobHook(
233
232
  gcp_conn_id=self.gcp_conn_id,
234
233
  impersonation_chain=self.impersonation_chain,
235
234
  )
236
- result = self.hook.create_hyperparameter_tuning_job(
235
+ hyperparameter_tuning_job: HyperparameterTuningJob = self.hook.create_hyperparameter_tuning_job(
237
236
  project_id=self.project_id,
238
237
  region=self.region,
239
238
  display_name=self.display_name,
@@ -259,14 +258,19 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
259
258
  restart_job_on_worker_restart=self.restart_job_on_worker_restart,
260
259
  enable_web_access=self.enable_web_access,
261
260
  tensorboard=self.tensorboard,
262
- sync=self.sync,
263
- wait_job_completed=not self.deferrable,
261
+ sync=False,
262
+ wait_job_completed=False,
264
263
  )
265
264
 
266
- hyperparameter_tuning_job = result.to_dict()
267
- hyperparameter_tuning_job_id = self.hook.extract_hyperparameter_tuning_job_id(
268
- hyperparameter_tuning_job
265
+ hyperparameter_tuning_job.wait_for_resource_creation()
266
+ hyperparameter_tuning_job_id = hyperparameter_tuning_job.name
267
+ self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id)
268
+
269
+ self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
270
+ VertexAITrainingLink.persist(
271
+ context=context, task_instance=self, training_id=hyperparameter_tuning_job_id
269
272
  )
273
+
270
274
  if self.deferrable:
271
275
  self.defer(
272
276
  trigger=CreateHyperparameterTuningJobTrigger(
@@ -279,14 +283,10 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
279
283
  ),
280
284
  method_name="execute_complete",
281
285
  )
286
+ return
282
287
 
283
- self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id)
284
-
285
- self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
286
- VertexAITrainingLink.persist(
287
- context=context, task_instance=self, training_id=hyperparameter_tuning_job_id
288
- )
289
- return hyperparameter_tuning_job
288
+ hyperparameter_tuning_job.wait_for_completion()
289
+ return hyperparameter_tuning_job.to_dict()
290
290
 
291
291
  def on_kill(self) -> None:
292
292
  """Act as a callback called when the operator is killed; cancel any running job."""
@@ -298,26 +298,7 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
298
298
  raise AirflowException(event["message"])
299
299
  job: dict[str, Any] = event["job"]
300
300
  self.log.info("Hyperparameter tuning job %s created and completed successfully.", job["name"])
301
- hook = HyperparameterTuningJobHook(
302
- gcp_conn_id=self.gcp_conn_id,
303
- impersonation_chain=self.impersonation_chain,
304
- )
305
- job_id = hook.extract_hyperparameter_tuning_job_id(job)
306
- self.xcom_push(
307
- context,
308
- key="hyperparameter_tuning_job_id",
309
- value=job_id,
310
- )
311
- self.xcom_push(
312
- context,
313
- key="training_conf",
314
- value={
315
- "training_conf_id": job_id,
316
- "region": self.region,
317
- "project_id": self.project_id,
318
- },
319
- )
320
- return event["job"]
301
+ return job
321
302
 
322
303
 
323
304
  class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
@@ -387,7 +368,7 @@ class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
387
368
  context=context, task_instance=self, training_id=self.hyperparameter_tuning_job_id
388
369
  )
389
370
  self.log.info("Hyperparameter tuning job was gotten.")
390
- return HyperparameterTuningJob.to_dict(result)
371
+ return types.HyperparameterTuningJob.to_dict(result)
391
372
  except NotFound:
392
373
  self.log.info(
393
374
  "The Hyperparameter tuning job %s does not exist.", self.hyperparameter_tuning_job_id
@@ -532,4 +513,4 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
532
513
  metadata=self.metadata,
533
514
  )
534
515
  VertexAIHyperparameterTuningJobListLink.persist(context=context, task_instance=self)
535
- return [HyperparameterTuningJob.to_dict(result) for result in results]
516
+ return [types.HyperparameterTuningJob.to_dict(result) for result in results]
@@ -16,23 +16,29 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains Google Vertex AI operators."""
19
+
19
20
  from __future__ import annotations
20
21
 
22
+ from functools import cached_property
21
23
  from typing import TYPE_CHECKING, Any, Sequence
22
24
 
23
25
  from google.api_core.exceptions import NotFound
24
26
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
25
- from google.cloud.aiplatform_v1.types import PipelineJob
27
+ from google.cloud.aiplatform_v1 import types
26
28
 
29
+ from airflow.configuration import conf
30
+ from airflow.exceptions import AirflowException
27
31
  from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobHook
28
32
  from airflow.providers.google.cloud.links.vertex_ai import (
29
33
  VertexAIPipelineJobLink,
30
34
  VertexAIPipelineJobListLink,
31
35
  )
32
36
  from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
37
+ from airflow.providers.google.cloud.triggers.vertex_ai import RunPipelineJobTrigger
33
38
 
34
39
  if TYPE_CHECKING:
35
40
  from google.api_core.retry import Retry
41
+ from google.cloud.aiplatform import PipelineJob
36
42
  from google.cloud.aiplatform.metadata import experiment_resources
37
43
 
38
44
  from airflow.utils.context import Context
@@ -40,7 +46,7 @@ if TYPE_CHECKING:
40
46
 
41
47
  class RunPipelineJobOperator(GoogleCloudBaseOperator):
42
48
  """
43
- Run Pipeline job.
49
+ Create and run a Pipeline job.
44
50
 
45
51
  :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
46
52
  :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -82,9 +88,9 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
82
88
  Private services access must already be configured for the network. If left unspecified, the
83
89
  network set in aiplatform.init will be used. Otherwise, the job is not peered with any network.
84
90
  :param create_request_timeout: Optional. The timeout for the create request in seconds.
85
- :param experiment: Optional. The Vertex AI experiment name or instance to associate to this
86
- PipelineJob. Metrics produced by the PipelineJob as system.Metric Artifacts will be associated as
87
- metrics to the current Experiment Run. Pipeline parameters will be associated as parameters to
91
+ :param experiment: Optional. The Vertex AI experiment name or instance to associate to this PipelineJob.
92
+ Metrics produced by the PipelineJob as system.Metric Artifacts will be associated as metrics
93
+ to the current Experiment Run. Pipeline parameters will be associated as parameters to
88
94
  the current Experiment Run.
89
95
  :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
90
96
  :param impersonation_chain: Optional service account to impersonate using short-term
@@ -95,6 +101,10 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
95
101
  If set as a sequence, the identities from the list must grant
96
102
  Service Account Token Creator IAM role to the directly preceding identity, with first
97
103
  account from the list granting this role to the originating account (templated).
104
+ :param deferrable: If True, run the task in the deferrable mode.
105
+ Note that it requires calling the operator with `sync=False` parameter.
106
+ :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
107
+ The default is 300 seconds.
98
108
  """
99
109
 
100
110
  template_fields = [
@@ -126,6 +136,8 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
126
136
  experiment: str | experiment_resources.Experiment | None = None,
127
137
  gcp_conn_id: str = "google_cloud_default",
128
138
  impersonation_chain: str | Sequence[str] | None = None,
139
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
140
+ poll_interval: int = 5 * 60,
129
141
  **kwargs,
130
142
  ) -> None:
131
143
  super().__init__(**kwargs)
@@ -147,15 +159,12 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
147
159
  self.experiment = experiment
148
160
  self.gcp_conn_id = gcp_conn_id
149
161
  self.impersonation_chain = impersonation_chain
150
- self.hook: PipelineJobHook | None = None
162
+ self.deferrable = deferrable
163
+ self.poll_interval = poll_interval
151
164
 
152
165
  def execute(self, context: Context):
153
166
  self.log.info("Running Pipeline job")
154
- self.hook = PipelineJobHook(
155
- gcp_conn_id=self.gcp_conn_id,
156
- impersonation_chain=self.impersonation_chain,
157
- )
158
- result = self.hook.run_pipeline_job(
167
+ pipeline_job_obj: PipelineJob = self.hook.submit_pipeline_job(
159
168
  project_id=self.project_id,
160
169
  region=self.region,
161
170
  display_name=self.display_name,
@@ -173,20 +182,46 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
173
182
  create_request_timeout=self.create_request_timeout,
174
183
  experiment=self.experiment,
175
184
  )
176
-
177
- pipeline_job = result.to_dict()
178
- pipeline_job_id = self.hook.extract_pipeline_job_id(pipeline_job)
185
+ pipeline_job_id = pipeline_job_obj.job_id
179
186
  self.log.info("Pipeline job was created. Job id: %s", pipeline_job_id)
180
-
181
187
  self.xcom_push(context, key="pipeline_job_id", value=pipeline_job_id)
182
188
  VertexAIPipelineJobLink.persist(context=context, task_instance=self, pipeline_id=pipeline_job_id)
189
+
190
+ if self.deferrable:
191
+ pipeline_job_obj.wait_for_resource_creation()
192
+ self.defer(
193
+ trigger=RunPipelineJobTrigger(
194
+ conn_id=self.gcp_conn_id,
195
+ project_id=self.project_id,
196
+ location=pipeline_job_obj.location,
197
+ job_id=pipeline_job_id,
198
+ poll_interval=self.poll_interval,
199
+ impersonation_chain=self.impersonation_chain,
200
+ ),
201
+ method_name="execute_complete",
202
+ )
203
+
204
+ pipeline_job_obj.wait()
205
+ pipeline_job = pipeline_job_obj.to_dict()
183
206
  return pipeline_job
184
207
 
208
+ def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]:
209
+ if event["status"] == "error":
210
+ raise AirflowException(event["message"])
211
+ return event["job"]
212
+
185
213
  def on_kill(self) -> None:
186
214
  """Act as a callback called when the operator is killed; cancel any running job."""
187
215
  if self.hook:
188
216
  self.hook.cancel_pipeline_job()
189
217
 
218
+ @cached_property
219
+ def hook(self) -> PipelineJobHook:
220
+ return PipelineJobHook(
221
+ gcp_conn_id=self.gcp_conn_id,
222
+ impersonation_chain=self.impersonation_chain,
223
+ )
224
+
190
225
 
191
226
  class GetPipelineJobOperator(GoogleCloudBaseOperator):
192
227
  """
@@ -261,7 +296,7 @@ class GetPipelineJobOperator(GoogleCloudBaseOperator):
261
296
  context=context, task_instance=self, pipeline_id=self.pipeline_job_id
262
297
  )
263
298
  self.log.info("Pipeline job was gotten.")
264
- return PipelineJob.to_dict(result)
299
+ return types.PipelineJob.to_dict(result)
265
300
  except NotFound:
266
301
  self.log.info("The Pipeline job %s does not exist.", self.pipeline_job_id)
267
302
 
@@ -389,7 +424,7 @@ class ListPipelineJobOperator(GoogleCloudBaseOperator):
389
424
  metadata=self.metadata,
390
425
  )
391
426
  VertexAIPipelineJobListLink.persist(context=context, task_instance=self)
392
- return [PipelineJob.to_dict(result) for result in results]
427
+ return [types.PipelineJob.to_dict(result) for result in results]
393
428
 
394
429
 
395
430
  class DeletePipelineJobOperator(GoogleCloudBaseOperator):
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains Google Cloud Vision operators."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from typing import TYPE_CHECKING, Sequence
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains a Google Cloud Vision operator."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from copy import deepcopy
@@ -15,6 +15,7 @@
15
15
  # specific language governing permissions and limitations
16
16
  # under the License.
17
17
  """Objects relating to sourcing connections from Google Cloud Secrets Manager."""
18
+
18
19
  from __future__ import annotations
19
20
 
20
21
  import logging
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains Google BigQuery sensors."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  import warnings
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains a Google BigQuery Data Transfer Service sensor."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from typing import TYPE_CHECKING, Sequence
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains Google Cloud Bigtable sensor."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from typing import TYPE_CHECKING, Sequence