apache-airflow-providers-google 10.16.0rc1__py3-none-any.whl → 10.17.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +5 -4
  3. airflow/providers/google/ads/operators/ads.py +1 -0
  4. airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +1 -0
  5. airflow/providers/google/cloud/example_dags/example_cloud_task.py +1 -0
  6. airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py +1 -0
  7. airflow/providers/google/cloud/example_dags/example_looker.py +1 -0
  8. airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py +1 -0
  9. airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py +1 -0
  10. airflow/providers/google/cloud/fs/gcs.py +1 -2
  11. airflow/providers/google/cloud/hooks/automl.py +1 -0
  12. airflow/providers/google/cloud/hooks/bigquery.py +87 -24
  13. airflow/providers/google/cloud/hooks/bigquery_dts.py +1 -0
  14. airflow/providers/google/cloud/hooks/bigtable.py +1 -0
  15. airflow/providers/google/cloud/hooks/cloud_build.py +1 -0
  16. airflow/providers/google/cloud/hooks/cloud_memorystore.py +1 -0
  17. airflow/providers/google/cloud/hooks/cloud_sql.py +1 -0
  18. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +9 -4
  19. airflow/providers/google/cloud/hooks/compute.py +1 -0
  20. airflow/providers/google/cloud/hooks/compute_ssh.py +2 -2
  21. airflow/providers/google/cloud/hooks/dataflow.py +6 -5
  22. airflow/providers/google/cloud/hooks/datafusion.py +1 -0
  23. airflow/providers/google/cloud/hooks/datapipeline.py +1 -0
  24. airflow/providers/google/cloud/hooks/dataplex.py +1 -0
  25. airflow/providers/google/cloud/hooks/dataprep.py +1 -0
  26. airflow/providers/google/cloud/hooks/dataproc.py +3 -2
  27. airflow/providers/google/cloud/hooks/dataproc_metastore.py +1 -0
  28. airflow/providers/google/cloud/hooks/datastore.py +1 -0
  29. airflow/providers/google/cloud/hooks/dlp.py +1 -0
  30. airflow/providers/google/cloud/hooks/functions.py +1 -0
  31. airflow/providers/google/cloud/hooks/gcs.py +12 -5
  32. airflow/providers/google/cloud/hooks/kms.py +1 -0
  33. airflow/providers/google/cloud/hooks/kubernetes_engine.py +178 -300
  34. airflow/providers/google/cloud/hooks/life_sciences.py +1 -0
  35. airflow/providers/google/cloud/hooks/looker.py +1 -0
  36. airflow/providers/google/cloud/hooks/mlengine.py +1 -0
  37. airflow/providers/google/cloud/hooks/natural_language.py +1 -0
  38. airflow/providers/google/cloud/hooks/os_login.py +1 -0
  39. airflow/providers/google/cloud/hooks/pubsub.py +1 -0
  40. airflow/providers/google/cloud/hooks/secret_manager.py +1 -0
  41. airflow/providers/google/cloud/hooks/spanner.py +1 -0
  42. airflow/providers/google/cloud/hooks/speech_to_text.py +1 -0
  43. airflow/providers/google/cloud/hooks/stackdriver.py +1 -0
  44. airflow/providers/google/cloud/hooks/text_to_speech.py +1 -0
  45. airflow/providers/google/cloud/hooks/translate.py +1 -0
  46. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +1 -0
  47. airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +255 -3
  48. airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1 -0
  49. airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +1 -0
  50. airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +1 -0
  51. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +197 -0
  52. airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +9 -9
  53. airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +231 -12
  54. airflow/providers/google/cloud/hooks/video_intelligence.py +1 -0
  55. airflow/providers/google/cloud/hooks/vision.py +1 -0
  56. airflow/providers/google/cloud/links/automl.py +1 -0
  57. airflow/providers/google/cloud/links/bigquery.py +1 -0
  58. airflow/providers/google/cloud/links/bigquery_dts.py +1 -0
  59. airflow/providers/google/cloud/links/cloud_memorystore.py +1 -0
  60. airflow/providers/google/cloud/links/cloud_sql.py +1 -0
  61. airflow/providers/google/cloud/links/cloud_tasks.py +1 -0
  62. airflow/providers/google/cloud/links/compute.py +1 -0
  63. airflow/providers/google/cloud/links/datacatalog.py +1 -0
  64. airflow/providers/google/cloud/links/dataflow.py +1 -0
  65. airflow/providers/google/cloud/links/dataform.py +1 -0
  66. airflow/providers/google/cloud/links/datafusion.py +1 -0
  67. airflow/providers/google/cloud/links/dataplex.py +1 -0
  68. airflow/providers/google/cloud/links/dataproc.py +1 -0
  69. airflow/providers/google/cloud/links/kubernetes_engine.py +28 -0
  70. airflow/providers/google/cloud/links/mlengine.py +1 -0
  71. airflow/providers/google/cloud/links/pubsub.py +1 -0
  72. airflow/providers/google/cloud/links/spanner.py +1 -0
  73. airflow/providers/google/cloud/links/stackdriver.py +1 -0
  74. airflow/providers/google/cloud/links/workflows.py +1 -0
  75. airflow/providers/google/cloud/log/stackdriver_task_handler.py +18 -4
  76. airflow/providers/google/cloud/operators/automl.py +1 -0
  77. airflow/providers/google/cloud/operators/bigquery.py +21 -0
  78. airflow/providers/google/cloud/operators/bigquery_dts.py +1 -0
  79. airflow/providers/google/cloud/operators/bigtable.py +1 -0
  80. airflow/providers/google/cloud/operators/cloud_base.py +1 -0
  81. airflow/providers/google/cloud/operators/cloud_build.py +1 -0
  82. airflow/providers/google/cloud/operators/cloud_memorystore.py +1 -0
  83. airflow/providers/google/cloud/operators/cloud_sql.py +1 -0
  84. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +11 -5
  85. airflow/providers/google/cloud/operators/compute.py +1 -0
  86. airflow/providers/google/cloud/operators/dataflow.py +1 -0
  87. airflow/providers/google/cloud/operators/datafusion.py +1 -0
  88. airflow/providers/google/cloud/operators/datapipeline.py +1 -0
  89. airflow/providers/google/cloud/operators/dataprep.py +1 -0
  90. airflow/providers/google/cloud/operators/dataproc.py +3 -2
  91. airflow/providers/google/cloud/operators/dataproc_metastore.py +1 -0
  92. airflow/providers/google/cloud/operators/datastore.py +1 -0
  93. airflow/providers/google/cloud/operators/functions.py +1 -0
  94. airflow/providers/google/cloud/operators/gcs.py +1 -0
  95. airflow/providers/google/cloud/operators/kubernetes_engine.py +600 -4
  96. airflow/providers/google/cloud/operators/life_sciences.py +1 -0
  97. airflow/providers/google/cloud/operators/looker.py +1 -0
  98. airflow/providers/google/cloud/operators/mlengine.py +283 -259
  99. airflow/providers/google/cloud/operators/natural_language.py +1 -0
  100. airflow/providers/google/cloud/operators/pubsub.py +1 -0
  101. airflow/providers/google/cloud/operators/spanner.py +1 -0
  102. airflow/providers/google/cloud/operators/speech_to_text.py +1 -0
  103. airflow/providers/google/cloud/operators/text_to_speech.py +1 -0
  104. airflow/providers/google/cloud/operators/translate.py +1 -0
  105. airflow/providers/google/cloud/operators/translate_speech.py +1 -0
  106. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +14 -7
  107. airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +67 -13
  108. airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +26 -8
  109. airflow/providers/google/cloud/operators/vertex_ai/dataset.py +1 -0
  110. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +306 -0
  111. airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +29 -48
  112. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +52 -17
  113. airflow/providers/google/cloud/operators/video_intelligence.py +1 -0
  114. airflow/providers/google/cloud/operators/vision.py +1 -0
  115. airflow/providers/google/cloud/secrets/secret_manager.py +1 -0
  116. airflow/providers/google/cloud/sensors/bigquery.py +1 -0
  117. airflow/providers/google/cloud/sensors/bigquery_dts.py +1 -0
  118. airflow/providers/google/cloud/sensors/bigtable.py +1 -0
  119. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +1 -0
  120. airflow/providers/google/cloud/sensors/dataflow.py +1 -0
  121. airflow/providers/google/cloud/sensors/dataform.py +1 -0
  122. airflow/providers/google/cloud/sensors/datafusion.py +1 -0
  123. airflow/providers/google/cloud/sensors/dataplex.py +1 -0
  124. airflow/providers/google/cloud/sensors/dataprep.py +1 -0
  125. airflow/providers/google/cloud/sensors/dataproc.py +1 -0
  126. airflow/providers/google/cloud/sensors/gcs.py +1 -0
  127. airflow/providers/google/cloud/sensors/looker.py +1 -0
  128. airflow/providers/google/cloud/sensors/pubsub.py +1 -0
  129. airflow/providers/google/cloud/sensors/tasks.py +1 -0
  130. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +1 -0
  131. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -0
  132. airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -0
  133. airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +1 -0
  134. airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +1 -0
  135. airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
  136. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -0
  137. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +3 -2
  138. airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -0
  139. airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -0
  140. airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -0
  141. airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -0
  142. airflow/providers/google/cloud/transfers/mysql_to_gcs.py +1 -0
  143. airflow/providers/google/cloud/transfers/postgres_to_gcs.py +19 -1
  144. airflow/providers/google/cloud/transfers/s3_to_gcs.py +3 -5
  145. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -0
  146. airflow/providers/google/cloud/transfers/sql_to_gcs.py +4 -2
  147. airflow/providers/google/cloud/triggers/bigquery.py +4 -3
  148. airflow/providers/google/cloud/triggers/cloud_batch.py +1 -1
  149. airflow/providers/google/cloud/triggers/cloud_run.py +1 -0
  150. airflow/providers/google/cloud/triggers/cloud_sql.py +2 -0
  151. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +14 -2
  152. airflow/providers/google/cloud/triggers/dataplex.py +1 -0
  153. airflow/providers/google/cloud/triggers/dataproc.py +1 -0
  154. airflow/providers/google/cloud/triggers/kubernetes_engine.py +72 -2
  155. airflow/providers/google/cloud/triggers/mlengine.py +2 -0
  156. airflow/providers/google/cloud/triggers/pubsub.py +3 -3
  157. airflow/providers/google/cloud/triggers/vertex_ai.py +107 -15
  158. airflow/providers/google/cloud/utils/field_sanitizer.py +2 -1
  159. airflow/providers/google/cloud/utils/field_validator.py +1 -0
  160. airflow/providers/google/cloud/utils/helpers.py +1 -0
  161. airflow/providers/google/cloud/utils/mlengine_operator_utils.py +1 -0
  162. airflow/providers/google/cloud/utils/mlengine_prediction_summary.py +1 -0
  163. airflow/providers/google/cloud/utils/openlineage.py +1 -0
  164. airflow/providers/google/common/auth_backend/google_openid.py +1 -0
  165. airflow/providers/google/common/hooks/base_google.py +2 -1
  166. airflow/providers/google/common/hooks/discovery_api.py +1 -0
  167. airflow/providers/google/common/links/storage.py +1 -0
  168. airflow/providers/google/common/utils/id_token_credentials.py +1 -0
  169. airflow/providers/google/firebase/hooks/firestore.py +1 -0
  170. airflow/providers/google/get_provider_info.py +9 -3
  171. airflow/providers/google/go_module_utils.py +1 -0
  172. airflow/providers/google/leveldb/hooks/leveldb.py +8 -7
  173. airflow/providers/google/marketing_platform/example_dags/example_display_video.py +1 -0
  174. airflow/providers/google/marketing_platform/hooks/analytics_admin.py +1 -0
  175. airflow/providers/google/marketing_platform/hooks/campaign_manager.py +1 -0
  176. airflow/providers/google/marketing_platform/hooks/display_video.py +1 -0
  177. airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -0
  178. airflow/providers/google/marketing_platform/operators/analytics.py +1 -0
  179. airflow/providers/google/marketing_platform/operators/analytics_admin.py +4 -2
  180. airflow/providers/google/marketing_platform/operators/campaign_manager.py +1 -0
  181. airflow/providers/google/marketing_platform/operators/display_video.py +1 -0
  182. airflow/providers/google/marketing_platform/operators/search_ads.py +1 -0
  183. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +1 -0
  184. airflow/providers/google/marketing_platform/sensors/display_video.py +2 -1
  185. airflow/providers/google/marketing_platform/sensors/search_ads.py +1 -0
  186. airflow/providers/google/suite/hooks/calendar.py +1 -0
  187. airflow/providers/google/suite/hooks/drive.py +1 -0
  188. airflow/providers/google/suite/hooks/sheets.py +1 -0
  189. airflow/providers/google/suite/sensors/drive.py +1 -0
  190. airflow/providers/google/suite/transfers/gcs_to_gdrive.py +7 -0
  191. airflow/providers/google/suite/transfers/gcs_to_sheets.py +4 -1
  192. airflow/providers/google/suite/transfers/local_to_drive.py +1 -0
  193. {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/METADATA +22 -17
  194. {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/RECORD +196 -194
  195. {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/WHEEL +0 -0
  196. {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/entry_points.txt +0 -0
@@ -20,6 +20,7 @@
20
20
  ImportSshPublicKeyResponse
21
21
  oslogin
22
22
  """
23
+
23
24
  from __future__ import annotations
24
25
 
25
26
  from typing import TYPE_CHECKING, Sequence
@@ -23,6 +23,7 @@ This module contains a Google Pub/Sub Hook.
23
23
  MessageStoragePolicy
24
24
  ReceivedMessage
25
25
  """
26
+
26
27
  from __future__ import annotations
27
28
 
28
29
  import warnings
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """Hook for Secrets Manager service."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from typing import Sequence
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains a Google Cloud Spanner Hook."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from typing import TYPE_CHECKING, Callable, NamedTuple, Sequence
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains a Google Cloud Speech Hook."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from typing import TYPE_CHECKING, Sequence
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains Google Cloud Stackdriver operators."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  import contextlib
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains a Google Cloud Text to Speech Hook."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from typing import TYPE_CHECKING, Sequence
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains a Google Cloud Translate Hook."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from typing import Sequence
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains a Google Cloud Vertex AI hook."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  import warnings
@@ -16,21 +16,24 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains a Google Cloud Vertex AI hook."""
19
+
19
20
  from __future__ import annotations
20
21
 
22
+ import asyncio
21
23
  from typing import TYPE_CHECKING, Sequence
22
24
 
23
25
  from google.api_core.client_options import ClientOptions
24
26
  from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
25
27
  from google.cloud.aiplatform import BatchPredictionJob, Model, explain
26
- from google.cloud.aiplatform_v1 import JobServiceClient
28
+ from google.cloud.aiplatform_v1 import JobServiceAsyncClient, JobServiceClient, JobState, types
27
29
 
28
30
  from airflow.exceptions import AirflowException
29
- from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
31
+ from airflow.providers.google.common.consts import CLIENT_INFO
32
+ from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
30
33
 
31
34
  if TYPE_CHECKING:
32
35
  from google.api_core.operation import Operation
33
- from google.api_core.retry import Retry
36
+ from google.api_core.retry import AsyncRetry, Retry
34
37
  from google.cloud.aiplatform_v1.services.job_service.pagers import ListBatchPredictionJobsPager
35
38
 
36
39
 
@@ -241,6 +244,159 @@ class BatchPredictionJobHook(GoogleBaseHook):
241
244
  )
242
245
  return self._batch_prediction_job
243
246
 
247
+ @GoogleBaseHook.fallback_to_default_project_id
248
+ def submit_batch_prediction_job(
249
+ self,
250
+ project_id: str,
251
+ region: str,
252
+ job_display_name: str,
253
+ model_name: str | Model,
254
+ instances_format: str = "jsonl",
255
+ predictions_format: str = "jsonl",
256
+ gcs_source: str | Sequence[str] | None = None,
257
+ bigquery_source: str | None = None,
258
+ gcs_destination_prefix: str | None = None,
259
+ bigquery_destination_prefix: str | None = None,
260
+ model_parameters: dict | None = None,
261
+ machine_type: str | None = None,
262
+ accelerator_type: str | None = None,
263
+ accelerator_count: int | None = None,
264
+ starting_replica_count: int | None = None,
265
+ max_replica_count: int | None = None,
266
+ generate_explanation: bool | None = False,
267
+ explanation_metadata: explain.ExplanationMetadata | None = None,
268
+ explanation_parameters: explain.ExplanationParameters | None = None,
269
+ labels: dict[str, str] | None = None,
270
+ encryption_spec_key_name: str | None = None,
271
+ create_request_timeout: float | None = None,
272
+ batch_size: int | None = None,
273
+ ) -> BatchPredictionJob:
274
+ """
275
+ Create a batch prediction job.
276
+
277
+ :param project_id: Required. Project to run training in.
278
+ :param region: Required. Location to run training in.
279
+ :param job_display_name: Required. The user-defined name of the BatchPredictionJob. The name can be
280
+ up to 128 characters long and can be consist of any UTF-8 characters.
281
+ :param model_name: Required. A fully-qualified model resource name or model ID.
282
+ :param instances_format: Required. The format in which instances are provided. Must be one of the
283
+ formats listed in `Model.supported_input_storage_formats`. Default is "jsonl" when using
284
+ `gcs_source`. If a `bigquery_source` is provided, this is overridden to "bigquery".
285
+ :param predictions_format: Required. The format in which Vertex AI outputs the predictions, must be
286
+ one of the formats specified in `Model.supported_output_storage_formats`. Default is "jsonl" when
287
+ using `gcs_destination_prefix`. If a `bigquery_destination_prefix` is provided, this is
288
+ overridden to "bigquery".
289
+ :param gcs_source: Google Cloud Storage URI(-s) to your instances to run batch prediction on. They
290
+ must match `instances_format`. May contain wildcards. For more information on wildcards, see
291
+ https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
292
+ :param bigquery_source: BigQuery URI to a table, up to 2000 characters long.
293
+ For example: `bq://projectId.bqDatasetId.bqTableId`
294
+ :param gcs_destination_prefix: The Google Cloud Storage location of the directory where the output is
295
+ to be written to. In the given directory a new directory is created. Its name is
296
+ ``prediction-<model-display-name>-<job-create-time>``, where timestamp is in
297
+ YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format. Inside of it files ``predictions_0001.<extension>``,
298
+ ``predictions_0002.<extension>``, ..., ``predictions_N.<extension>`` are created where
299
+ ``<extension>`` depends on chosen ``predictions_format``, and N may equal 0001 and depends on the
300
+ total number of successfully predicted instances. If the Model has both ``instance`` and
301
+ ``prediction`` schemata defined then each such file contains predictions as per the
302
+ ``predictions_format``. If prediction for any instance failed (partially or completely), then an
303
+ additional ``errors_0001.<extension>``, ``errors_0002.<extension>``,..., ``errors_N.<extension>``
304
+ files are created (N depends on total number of failed predictions). These files contain the
305
+ failed instances, as per their schema, followed by an additional ``error`` field which as value
306
+ has ```google.rpc.Status`` <Status>`__ containing only ``code`` and ``message`` fields.
307
+ :param bigquery_destination_prefix: The BigQuery project location where the output is to be written
308
+ to. In the given project a new dataset is created with name
309
+ ``prediction_<model-display-name>_<job-create-time>`` where is made BigQuery-dataset-name
310
+ compatible (for example, most special characters become underscores), and timestamp is in
311
+ YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601" format. In the dataset two tables will be created,
312
+ ``predictions``, and ``errors``. If the Model has both ``instance`` and ``prediction`` schemata
313
+ defined then the tables have columns as follows: The ``predictions`` table contains instances for
314
+ which the prediction succeeded, it has columns as per a concatenation of the Model's instance and
315
+ prediction schemata. The ``errors`` table contains rows for which the prediction has failed, it
316
+ has instance columns, as per the instance schema, followed by a single "errors" column, which as
317
+ values has ```google.rpc.Status`` <Status>`__ represented as a STRUCT, and containing only
318
+ ``code`` and ``message``.
319
+ :param model_parameters: The parameters that govern the predictions. The schema of the parameters may
320
+ be specified via the Model's `parameters_schema_uri`.
321
+ :param machine_type: The type of machine for running batch prediction on dedicated resources. Not
322
+ specifying machine type will result in batch prediction job being run with automatic resources.
323
+ :param accelerator_type: The type of accelerator(s) that may be attached to the machine as per
324
+ `accelerator_count`. Only used if `machine_type` is set.
325
+ :param accelerator_count: The number of accelerators to attach to the `machine_type`. Only used if
326
+ `machine_type` is set.
327
+ :param starting_replica_count: The number of machine replicas used at the start of the batch
328
+ operation. If not set, Vertex AI decides starting number, not greater than `max_replica_count`.
329
+ Only used if `machine_type` is set.
330
+ :param max_replica_count: The maximum number of machine replicas the batch operation may be scaled
331
+ to. Only used if `machine_type` is set. Default is 10.
332
+ :param generate_explanation: Optional. Generate explanation along with the batch prediction results.
333
+ This will cause the batch prediction output to include explanations based on the
334
+ `prediction_format`:
335
+ - `bigquery`: output includes a column named `explanation`. The value is a struct that conforms
336
+ to the [aiplatform.gapic.Explanation] object.
337
+ - `jsonl`: The JSON objects on each line include an additional entry keyed `explanation`. The
338
+ value of the entry is a JSON object that conforms to the [aiplatform.gapic.Explanation] object.
339
+ - `csv`: Generating explanations for CSV format is not supported.
340
+ :param explanation_metadata: Optional. Explanation metadata configuration for this
341
+ BatchPredictionJob. Can be specified only if `generate_explanation` is set to `True`.
342
+ This value overrides the value of `Model.explanation_metadata`. All fields of
343
+ `explanation_metadata` are optional in the request. If a field of the `explanation_metadata`
344
+ object is not populated, the corresponding field of the `Model.explanation_metadata` object is
345
+ inherited. For more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
346
+ :param explanation_parameters: Optional. Parameters to configure explaining for Model's predictions.
347
+ Can be specified only if `generate_explanation` is set to `True`.
348
+ This value overrides the value of `Model.explanation_parameters`. All fields of
349
+ `explanation_parameters` are optional in the request. If a field of the `explanation_parameters`
350
+ object is not populated, the corresponding field of the `Model.explanation_parameters` object is
351
+ inherited. For more details, see `Ref docs <http://tinyurl.com/1an4zake>`
352
+ :param labels: Optional. The labels with user-defined metadata to organize your BatchPredictionJobs.
353
+ Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain
354
+ lowercase letters, numeric characters, underscores and dashes. International characters are
355
+ allowed. See https://goo.gl/xmQnxf for more information and examples of labels.
356
+ :param encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed
357
+ encryption key used to protect the job. Has the form:
358
+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. The key needs to be
359
+ in the same region as where the compute resource is created.
360
+ If this is set, then all resources created by the BatchPredictionJob will be encrypted with the
361
+ provided encryption key.
362
+ Overrides encryption_spec_key_name set in aiplatform.init.
363
+ :param create_request_timeout: Optional. The timeout for the create request in seconds.
364
+ :param batch_size: Optional. The number of the records (e.g. instances)
365
+ of the operation given in each batch
366
+ to a machine replica. Machine type, and size of a single record should be considered
367
+ when setting this parameter, higher value speeds up the batch operation's execution,
368
+ but too high value will result in a whole batch not fitting in a machine's memory,
369
+ and the whole operation will fail.
370
+ The default value is same as in the aiplatform's BatchPredictionJob.
371
+ """
372
+ self._batch_prediction_job = BatchPredictionJob.submit(
373
+ job_display_name=job_display_name,
374
+ model_name=model_name,
375
+ instances_format=instances_format,
376
+ predictions_format=predictions_format,
377
+ gcs_source=gcs_source,
378
+ bigquery_source=bigquery_source,
379
+ gcs_destination_prefix=gcs_destination_prefix,
380
+ bigquery_destination_prefix=bigquery_destination_prefix,
381
+ model_parameters=model_parameters,
382
+ machine_type=machine_type,
383
+ accelerator_type=accelerator_type,
384
+ accelerator_count=accelerator_count,
385
+ starting_replica_count=starting_replica_count,
386
+ max_replica_count=max_replica_count,
387
+ generate_explanation=generate_explanation,
388
+ explanation_metadata=explanation_metadata,
389
+ explanation_parameters=explanation_parameters,
390
+ labels=labels,
391
+ project=project_id,
392
+ location=region,
393
+ credentials=self.get_credentials(),
394
+ encryption_spec_key_name=encryption_spec_key_name,
395
+ create_request_timeout=create_request_timeout,
396
+ batch_size=batch_size,
397
+ )
398
+ return self._batch_prediction_job
399
+
244
400
  @GoogleBaseHook.fallback_to_default_project_id
245
401
  def delete_batch_prediction_job(
246
402
  self,
@@ -358,3 +514,99 @@ class BatchPredictionJobHook(GoogleBaseHook):
358
514
  metadata=metadata,
359
515
  )
360
516
  return result
517
+
518
+
519
+ class BatchPredictionJobAsyncHook(GoogleBaseAsyncHook):
520
+ """Hook for Google Cloud Vertex AI Batch Prediction Job Async APIs."""
521
+
522
+ sync_hook_class = BatchPredictionJobHook
523
+
524
+ def __init__(
525
+ self,
526
+ gcp_conn_id: str = "google_cloud_default",
527
+ impersonation_chain: str | Sequence[str] | None = None,
528
+ **kwargs,
529
+ ):
530
+ super().__init__(
531
+ gcp_conn_id=gcp_conn_id,
532
+ impersonation_chain=impersonation_chain,
533
+ **kwargs,
534
+ )
535
+
536
+ async def get_job_service_client(self, region: str | None = None) -> JobServiceAsyncClient:
537
+ """Return JobServiceAsyncClient object."""
538
+ endpoint = f"{region}-aiplatform.googleapis.com:443" if region and region != "global" else None
539
+ return JobServiceAsyncClient(
540
+ credentials=(await self.get_sync_hook()).get_credentials(),
541
+ client_info=CLIENT_INFO,
542
+ client_options=ClientOptions(api_endpoint=endpoint),
543
+ )
544
+
545
+ async def get_batch_prediction_job(
546
+ self,
547
+ project_id: str,
548
+ location: str,
549
+ job_id: str,
550
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
551
+ timeout: float | None = None,
552
+ metadata: Sequence[tuple[str, str]] = (),
553
+ ) -> types.BatchPredictionJob:
554
+ """Retrieve a batch prediction tuning job.
555
+
556
+ :param project_id: Required. The ID of the Google Cloud project that the job belongs to.
557
+ :param location: Required. The ID of the Google Cloud region that the job belongs to.
558
+ :param job_id: Required. The hyperparameter tuning job id.
559
+ :param retry: Designation of what errors, if any, should be retried.
560
+ :param timeout: The timeout for this request.
561
+ :param metadata: Strings which should be sent along with the request as metadata.
562
+ """
563
+ client: JobServiceAsyncClient = await self.get_job_service_client(region=location)
564
+ job_name = client.batch_prediction_job_path(project_id, location, job_id)
565
+
566
+ result = await client.get_batch_prediction_job(
567
+ request={
568
+ "name": job_name,
569
+ },
570
+ retry=retry,
571
+ timeout=timeout,
572
+ metadata=metadata,
573
+ )
574
+ return result
575
+
576
+ async def wait_batch_prediction_job(
577
+ self,
578
+ project_id: str,
579
+ location: str,
580
+ job_id: str,
581
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
582
+ timeout: float | None = None,
583
+ metadata: Sequence[tuple[str, str]] = (),
584
+ poll_interval: int = 10,
585
+ ) -> types.BatchPredictionJob:
586
+ statuses_complete = {
587
+ JobState.JOB_STATE_CANCELLED,
588
+ JobState.JOB_STATE_FAILED,
589
+ JobState.JOB_STATE_PAUSED,
590
+ JobState.JOB_STATE_SUCCEEDED,
591
+ }
592
+ while True:
593
+ try:
594
+ self.log.info("Requesting batch prediction tuning job with id %s", job_id)
595
+ job: types.BatchPredictionJob = await self.get_batch_prediction_job(
596
+ project_id=project_id,
597
+ location=location,
598
+ job_id=job_id,
599
+ retry=retry,
600
+ timeout=timeout,
601
+ metadata=metadata,
602
+ )
603
+ except Exception as ex:
604
+ self.log.exception("Exception occurred while requesting job %s", job_id)
605
+ raise AirflowException(ex)
606
+
607
+ self.log.info("Status of the batch prediction job %s is %s", job.name, job.state.name)
608
+ if job.state in statuses_complete:
609
+ return job
610
+
611
+ self.log.info("Sleeping for %s seconds.", poll_interval)
612
+ await asyncio.sleep(poll_interval)
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains a Google Cloud Vertex AI hook."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from typing import TYPE_CHECKING, Sequence
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains a Google Cloud Vertex AI hook."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from typing import TYPE_CHECKING, Sequence
@@ -16,6 +16,7 @@
16
16
  # specific language governing permissions and limitations
17
17
  # under the License.
18
18
  """This module contains a Google Cloud Vertex AI hook."""
19
+
19
20
  from __future__ import annotations
20
21
 
21
22
  from typing import TYPE_CHECKING, Sequence
@@ -0,0 +1,197 @@
1
+ #
2
+ # Licensed to the Apache Software Foundation (ASF) under one
3
+ # or more contributor license agreements. See the NOTICE file
4
+ # distributed with this work for additional information
5
+ # regarding copyright ownership. The ASF licenses this file
6
+ # to you under the Apache License, Version 2.0 (the
7
+ # "License"); you may not use this file except in compliance
8
+ # with the License. You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing,
13
+ # software distributed under the License is distributed on an
14
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ # KIND, either express or implied. See the License for the
16
+ # specific language governing permissions and limitations
17
+ # under the License.
18
+ """This module contains a Google Cloud Vertex AI Generative Model hook."""
19
+
20
+ from __future__ import annotations
21
+
22
+ from typing import Sequence
23
+
24
+ import vertexai
25
+ from vertexai.generative_models import GenerativeModel, Part
26
+ from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
27
+
28
+ from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
29
+
30
+
31
class GenerativeModelHook(GoogleBaseHook):
    """
    Hook for Google Cloud Vertex AI Generative Model APIs.

    Each ``prompt_*`` / ``generate_*`` method first initialises the Vertex AI SDK
    (``vertexai.init``) with the given project, location and this hook's
    credentials, then loads the requested model. Every call can therefore target
    a different project/location.
    """

    def __init__(
        self,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        # ``delegate_to`` is no longer supported. Reject it explicitly instead of
        # silently dropping it, so users know to switch to impersonation.
        if kwargs.get("delegate_to") is not None:
            raise RuntimeError(
                "The `delegate_to` parameter has been deprecated before and finally removed in this version"
                " of Google Provider. You MUST convert it to `impersonation_chain`"
            )
        super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)

    def _init_vertexai(self, project_id: str, location: str) -> None:
        """Initialise the Vertex AI SDK for ``project_id``/``location`` using this hook's credentials."""
        vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

    def get_text_generation_model(self, pretrained_model: str) -> TextGenerationModel:
        """
        Return a Model Garden Model object based on Text Generation.

        :param pretrained_model: Name passed to ``TextGenerationModel.from_pretrained``.
        """
        return TextGenerationModel.from_pretrained(pretrained_model)

    def get_text_embedding_model(self, pretrained_model: str) -> TextEmbeddingModel:
        """
        Return a Model Garden Model object based on Text Embedding.

        :param pretrained_model: Name passed to ``TextEmbeddingModel.from_pretrained``.
        """
        return TextEmbeddingModel.from_pretrained(pretrained_model)

    def get_generative_model(self, pretrained_model: str) -> GenerativeModel:
        """
        Return a Generative Model object.

        :param pretrained_model: Name of the generative model, e.g. ``gemini-pro``.
        """
        return GenerativeModel(pretrained_model)

    def get_generative_model_part(self, content_gcs_path: str, content_mime_type: str | None = None) -> Part:
        """
        Return a Generative Model Part object referencing a file by URI.

        :param content_gcs_path: URI of the content file (e.g. a ``gs://`` path).
        :param content_mime_type: MIME type of that file, forwarded to ``Part.from_uri``.
        """
        return Part.from_uri(content_gcs_path, mime_type=content_mime_type)

    @GoogleBaseHook.fallback_to_default_project_id
    def prompt_language_model(
        self,
        prompt: str,
        pretrained_model: str,
        temperature: float,
        max_output_tokens: int,
        top_p: float,
        top_k: int,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Use the Vertex AI PaLM API to generate natural language text.

        :param prompt: Required. Inputs or queries that a user or a program gives
            to the Vertex AI PaLM API, in order to elicit a specific response.
        :param pretrained_model: A pre-trained model optimized for performing natural
            language tasks such as classification, summarization, extraction, content
            creation, and ideation.
        :param temperature: Temperature controls the degree of randomness in token
            selection.
        :param max_output_tokens: Token limit determines the maximum amount of text
            output.
        :param top_p: Tokens are selected from most probable to least until the sum
            of their probabilities equals the top_p value. Required; this hook
            passes the value through to the model as given.
        :param top_k: A top_k of 1 means the selected token is the most probable
            among all tokens.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :return: The text of the model's prediction.
        """
        self._init_vertexai(project_id=project_id, location=location)

        parameters = {
            "temperature": temperature,
            "max_output_tokens": max_output_tokens,
            "top_p": top_p,
            "top_k": top_k,
        }

        model = self.get_text_generation_model(pretrained_model)
        response = model.predict(
            prompt=prompt,
            **parameters,
        )
        return response.text

    @GoogleBaseHook.fallback_to_default_project_id
    def generate_text_embeddings(
        self,
        prompt: str,
        pretrained_model: str,
        location: str,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> list:
        """
        Use the Vertex AI PaLM API to generate text embeddings.

        :param prompt: Required. Inputs or queries that a user or a program gives
            to the Vertex AI PaLM API, in order to elicit a specific response.
        :param pretrained_model: A pre-trained model optimized for generating text embeddings.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :return: The ``values`` of the embedding computed for ``prompt``.
        """
        self._init_vertexai(project_id=project_id, location=location)
        model = self.get_text_embedding_model(pretrained_model)

        # The SDK takes a batch of texts; a single prompt is sent, so take the first result.
        response = model.get_embeddings([prompt])[0]
        return response.values

    @GoogleBaseHook.fallback_to_default_project_id
    def prompt_multimodal_model(
        self,
        prompt: str,
        location: str,
        pretrained_model: str = "gemini-pro",
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Use the Vertex AI Gemini Pro foundation model to generate natural language text.

        :param prompt: Required. Inputs or queries that a user or a program gives
            to the Multi-modal model, in order to elicit a specific response.
        :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
            supporting prompts with text-only input, including natural language
            tasks, multi-turn text and code chat, and code generation. It can
            output text and code.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :return: The text of the generated response.
        """
        self._init_vertexai(project_id=project_id, location=location)

        model = self.get_generative_model(pretrained_model)
        response = model.generate_content(prompt)
        return response.text

    @GoogleBaseHook.fallback_to_default_project_id
    def prompt_multimodal_model_with_media(
        self,
        prompt: str,
        location: str,
        media_gcs_path: str,
        mime_type: str,
        pretrained_model: str = "gemini-pro-vision",
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Use a Vertex AI Gemini multimodal model to generate text from a prompt plus a media file.

        The text ``prompt`` and the media file referenced by ``media_gcs_path`` are
        sent together as the content of a single request.

        :param prompt: Required. Inputs or queries that a user or a program gives
            to the Multi-modal model, in order to elicit a specific response.
        :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`.
            The model must accept the text prompt together with the media part
            (e.g. an image or a video).
        :param media_gcs_path: A GCS path to a content file such as an image or a video.
            It is passed to the multi-modal model as part of the prompt.
        :param mime_type: MIME type of the file at ``media_gcs_path`` (e.g. ``image/png``),
            forwarded to ``Part.from_uri``.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :return: The text of the generated response.
        """
        self._init_vertexai(project_id=project_id, location=location)

        model = self.get_generative_model(pretrained_model)
        part = self.get_generative_model_part(media_gcs_path, mime_type)
        response = model.generate_content([prompt, part])
        return response.text
@@ -22,10 +22,10 @@ This module contains a Google Cloud Vertex AI hook.
22
22
 
23
23
  JobServiceAsyncClient
24
24
  """
25
+
25
26
  from __future__ import annotations
26
27
 
27
28
  import asyncio
28
- from functools import lru_cache
29
29
  from typing import TYPE_CHECKING, Sequence
30
30
 
31
31
  from google.api_core.client_options import ClientOptions
@@ -35,12 +35,11 @@ from google.cloud.aiplatform_v1 import JobServiceAsyncClient, JobServiceClient,
35
35
 
36
36
  from airflow.exceptions import AirflowException
37
37
  from airflow.providers.google.common.consts import CLIENT_INFO
38
- from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
38
+ from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
39
39
 
40
40
  if TYPE_CHECKING:
41
41
  from google.api_core.operation import Operation
42
- from google.api_core.retry import Retry
43
- from google.api_core.retry_async import AsyncRetry
42
+ from google.api_core.retry import AsyncRetry, Retry
44
43
  from google.cloud.aiplatform_v1.services.job_service.pagers import ListHyperparameterTuningJobsPager
45
44
 
46
45
 
@@ -431,9 +430,11 @@ class HyperparameterTuningJobHook(GoogleBaseHook):
431
430
  return result
432
431
 
433
432
 
434
- class HyperparameterTuningJobAsyncHook(GoogleBaseHook):
433
+ class HyperparameterTuningJobAsyncHook(GoogleBaseAsyncHook):
435
434
  """Async hook for Google Cloud Vertex AI Hyperparameter Tuning Job APIs."""
436
435
 
436
+ sync_hook_class = HyperparameterTuningJobHook
437
+
437
438
  def __init__(
438
439
  self,
439
440
  gcp_conn_id: str = "google_cloud_default",
@@ -446,8 +447,7 @@ class HyperparameterTuningJobAsyncHook(GoogleBaseHook):
446
447
  **kwargs,
447
448
  )
448
449
 
449
- @lru_cache
450
- def get_job_service_client(self, region: str | None = None) -> JobServiceAsyncClient:
450
+ async def get_job_service_client(self, region: str | None = None) -> JobServiceAsyncClient:
451
451
  """
452
452
  Retrieve Vertex AI async client.
453
453
 
@@ -455,7 +455,7 @@ class HyperparameterTuningJobAsyncHook(GoogleBaseHook):
455
455
  """
456
456
  endpoint = f"{region}-aiplatform.googleapis.com:443" if region and region != "global" else None
457
457
  return JobServiceAsyncClient(
458
- credentials=self.get_credentials(),
458
+ credentials=(await self.get_sync_hook()).get_credentials(),
459
459
  client_info=CLIENT_INFO,
460
460
  client_options=ClientOptions(api_endpoint=endpoint),
461
461
  )
@@ -479,7 +479,7 @@ class HyperparameterTuningJobAsyncHook(GoogleBaseHook):
479
479
  :param timeout: The timeout for this request.
480
480
  :param metadata: Strings which should be sent along with the request as metadata.
481
481
  """
482
- client: JobServiceAsyncClient = self.get_job_service_client(region=location)
482
+ client: JobServiceAsyncClient = await self.get_job_service_client(region=location)
483
483
  job_name = client.hyperparameter_tuning_job_path(project_id, location, job_id)
484
484
 
485
485
  result = await client.get_hyperparameter_tuning_job(