apache-airflow-providers-google 15.1.0__py3-none-any.whl → 16.0.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (36)
  1. airflow/providers/google/__init__.py +3 -3
  2. airflow/providers/google/ads/hooks/ads.py +34 -0
  3. airflow/providers/google/cloud/hooks/bigquery.py +63 -76
  4. airflow/providers/google/cloud/hooks/dataflow.py +67 -5
  5. airflow/providers/google/cloud/hooks/gcs.py +3 -3
  6. airflow/providers/google/cloud/hooks/looker.py +5 -0
  7. airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -36
  8. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +1 -66
  9. airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
  10. airflow/providers/google/cloud/links/cloud_run.py +59 -0
  11. airflow/providers/google/cloud/links/vertex_ai.py +49 -0
  12. airflow/providers/google/cloud/log/gcs_task_handler.py +7 -5
  13. airflow/providers/google/cloud/operators/bigquery.py +49 -10
  14. airflow/providers/google/cloud/operators/cloud_run.py +20 -2
  15. airflow/providers/google/cloud/operators/gcs.py +1 -0
  16. airflow/providers/google/cloud/operators/kubernetes_engine.py +4 -86
  17. airflow/providers/google/cloud/operators/pubsub.py +2 -1
  18. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +0 -92
  19. airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +4 -0
  20. airflow/providers/google/cloud/operators/vertex_ai/ray.py +388 -0
  21. airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +9 -5
  22. airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -1
  23. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +2 -0
  24. airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
  25. airflow/providers/google/cloud/transfers/s3_to_gcs.py +11 -5
  26. airflow/providers/google/cloud/triggers/bigquery.py +32 -5
  27. airflow/providers/google/cloud/triggers/dataflow.py +122 -0
  28. airflow/providers/google/cloud/triggers/dataproc.py +62 -10
  29. airflow/providers/google/get_provider_info.py +18 -5
  30. airflow/providers/google/leveldb/hooks/leveldb.py +25 -0
  31. airflow/providers/google/version_compat.py +0 -1
  32. {apache_airflow_providers_google-15.1.0.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/METADATA +91 -84
  33. {apache_airflow_providers_google-15.1.0.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/RECORD +35 -32
  34. airflow/providers/google/cloud/links/automl.py +0 -193
  35. {apache_airflow_providers_google-15.1.0.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/WHEEL +0 -0
  36. {apache_airflow_providers_google-15.1.0.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/entry_points.txt +0 -0
airflow/providers/google/__init__.py
@@ -29,11 +29,11 @@ from airflow import __version__ as airflow_version
 
  __all__ = ["__version__"]
 
- __version__ = "15.1.0"
+ __version__ = "16.0.0"
 
  if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
- "2.9.0"
+ "2.10.0"
  ):
  raise RuntimeError(
- f"The package `apache-airflow-providers-google:{__version__}` needs Apache Airflow 2.9.0+"
+ f"The package `apache-airflow-providers-google:{__version__}` needs Apache Airflow 2.10.0+"
  )
airflow/providers/google/ads/hooks/ads.py
@@ -101,6 +101,40 @@ class GoogleAdsHook(BaseHook):
  :param api_version: The Google Ads API version to use.
  """
 
+ conn_name_attr = "google_ads_conn_id"
+ default_conn_name = "google_ads_default"
+ conn_type = "google_ads"
+ hook_name = "Google Ads"
+
+ @classmethod
+ def get_connection_form_widgets(cls) -> dict[str, Any]:
+ """Return connection widgets to add to Google Ads connection form."""
+ from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
+ from flask_babel import lazy_gettext
+ from wtforms import PasswordField, StringField
+
+ return {
+ "developer_token": StringField(lazy_gettext("Developer token"), widget=BS3TextFieldWidget()),
+ "client_id": StringField(lazy_gettext("OAuth2 Client ID"), widget=BS3TextFieldWidget()),
+ "client_secret": PasswordField(
+ lazy_gettext("OAuth2 Client Secret"), widget=BS3PasswordFieldWidget()
+ ),
+ "refresh_token": PasswordField(
+ lazy_gettext("OAuth2 Refresh Token"), widget=BS3PasswordFieldWidget()
+ ),
+ }
+
+ @classmethod
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
+ """Return custom UI field behaviour for Google Ads connection."""
+ return {
+ "hidden_fields": ["host", "login", "schema", "port"],
+ "relabeling": {},
+ "placeholders": {
+ "password": "Leave blank (optional)",
+ },
+ }
+
  def __init__(
  self,
  api_version: str | None = None,
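The new class-level attributes register a dedicated `google_ads` connection type with its own form fields. Below is a minimal sketch of defining such a connection programmatically; it assumes the form fields above are stored in the connection's extra JSON, and all values are placeholders:

```python
import json

from airflow.models.connection import Connection

# Hypothetical connection matching the new "google_ads" connection form;
# the widget fields are assumed to live in the extra JSON.
google_ads_conn = Connection(
    conn_id="google_ads_default",
    conn_type="google_ads",
    extra=json.dumps(
        {
            "developer_token": "EXAMPLE_DEVELOPER_TOKEN",
            "client_id": "example-client-id.apps.googleusercontent.com",
            "client_secret": "example-client-secret",
            "refresh_token": "example-refresh-token",
        }
    ),
)
print(google_ads_conn.conn_type)  # -> "google_ads"
```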
airflow/providers/google/cloud/hooks/bigquery.py
@@ -29,7 +29,7 @@ import uuid
  from collections.abc import Iterable, Mapping, Sequence
  from copy import deepcopy
  from datetime import datetime, timedelta
- from typing import TYPE_CHECKING, Any, NoReturn, Union, cast
+ from typing import TYPE_CHECKING, Any, NoReturn, Union, cast, overload
 
  from aiohttp import ClientSession as ClientSession
  from gcloud.aio.bigquery import Job, Table as Table_async
@@ -57,8 +57,13 @@ from googleapiclient.discovery import build
  from pandas_gbq import read_gbq
  from pandas_gbq.gbq import GbqConnector # noqa: F401 used in ``airflow.contrib.hooks.bigquery``
  from sqlalchemy import create_engine
+ from typing_extensions import Literal
 
- from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+ from airflow.exceptions import (
+ AirflowException,
+ AirflowOptionalProviderFeatureException,
+ AirflowProviderDeprecationWarning,
+ )
  from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
  from airflow.providers.common.sql.hooks.sql import DbApiHook
  from airflow.providers.google.cloud.utils.bigquery import bq_cast
@@ -77,6 +82,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 
  if TYPE_CHECKING:
  import pandas as pd
+ import polars as pl
  from google.api_core.page_iterator import HTTPIterator
  from google.api_core.retry import Retry
  from requests import Session
@@ -275,15 +281,57 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
  """
  raise NotImplementedError()
 
- def get_pandas_df(
+ def _get_pandas_df(
  self,
  sql: str,
  parameters: Iterable | Mapping[str, Any] | None = None,
  dialect: str | None = None,
  **kwargs,
  ) -> pd.DataFrame:
+ if dialect is None:
+ dialect = "legacy" if self.use_legacy_sql else "standard"
+
+ credentials, project_id = self.get_credentials_and_project_id()
+
+ return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs)
+
+ def _get_polars_df(self, sql, parameters=None, dialect=None, **kwargs) -> pl.DataFrame:
+ try:
+ import polars as pl
+ except ImportError:
+ raise AirflowOptionalProviderFeatureException(
+ "Polars is not installed. Please install it with `pip install polars`."
+ )
+
+ if dialect is None:
+ dialect = "legacy" if self.use_legacy_sql else "standard"
+
+ credentials, project_id = self.get_credentials_and_project_id()
+
+ pandas_df = read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs)
+ return pl.from_pandas(pandas_df)
+
+ @overload
+ def get_df(
+ self, sql, parameters=None, dialect=None, *, df_type: Literal["pandas"] = "pandas", **kwargs
+ ) -> pd.DataFrame: ...
+
+ @overload
+ def get_df(
+ self, sql, parameters=None, dialect=None, *, df_type: Literal["polars"], **kwargs
+ ) -> pl.DataFrame: ...
+
+ def get_df(
+ self,
+ sql,
+ parameters=None,
+ dialect=None,
+ *,
+ df_type: Literal["pandas", "polars"] = "pandas",
+ **kwargs,
+ ) -> pd.DataFrame | pl.DataFrame:
  """
- Get a Pandas DataFrame for the BigQuery results.
+ Get a DataFrame for the BigQuery results.
 
  The DbApiHook method must be overridden because Pandas doesn't support
  PEP 249 connections, except for SQLite.
@@ -299,12 +347,19 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
  defaults to use `self.use_legacy_sql` if not specified
  :param kwargs: (optional) passed into pandas_gbq.read_gbq method
  """
- if dialect is None:
- dialect = "legacy" if self.use_legacy_sql else "standard"
+ if df_type == "polars":
+ return self._get_polars_df(sql, parameters, dialect, **kwargs)
 
- credentials, project_id = self.get_credentials_and_project_id()
+ if df_type == "pandas":
+ return self._get_pandas_df(sql, parameters, dialect, **kwargs)
 
- return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs)
+ @deprecated(
+ planned_removal_date="November 30, 2025",
+ use_instead="airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_df",
+ category=AirflowProviderDeprecationWarning,
+ )
+ def get_pandas_df(self, sql, parameters=None, dialect=None, **kwargs):
+ return self._get_pandas_df(sql, parameters, dialect, **kwargs)
 
  @GoogleBaseHook.fallback_to_default_project_id
  def table_exists(self, dataset_id: str, table_id: str, project_id: str) -> bool:
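For context, a minimal sketch of calling the new `get_df` entry point once 16.0.0 is installed; the connection id and SQL are placeholders, and the polars path requires the optional `polars` package as noted in the hunk above:

```python
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

hook = BigQueryHook(gcp_conn_id="google_cloud_default", use_legacy_sql=False)

# Default keeps the old get_pandas_df behaviour and returns a pandas DataFrame.
pandas_df = hook.get_df("SELECT 1 AS x")

# Opting into polars; without `pip install polars` this raises
# AirflowOptionalProviderFeatureException per the diff above.
polars_df = hook.get_df("SELECT 1 AS x", df_type="polars")
```

The old `get_pandas_df` still works but now emits `AirflowProviderDeprecationWarning` and simply delegates to `_get_pandas_df`.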
@@ -1937,74 +1992,6 @@ def _escape(s: str) -> str:
  return e
 
 
- @deprecated(
- planned_removal_date="April 01, 2025",
- use_instead="airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.split_tablename",
- category=AirflowProviderDeprecationWarning,
- )
- def split_tablename(
- table_input: str, default_project_id: str, var_name: str | None = None
- ) -> tuple[str, str, str]:
- if "." not in table_input:
- raise ValueError(f"Expected table name in the format of <dataset>.<table>. Got: {table_input}")
-
- if not default_project_id:
- raise ValueError("INTERNAL: No default project is specified")
-
- def var_print(var_name):
- if var_name is None:
- return ""
- return f"Format exception for {var_name}: "
-
- if table_input.count(".") + table_input.count(":") > 3:
- raise ValueError(f"{var_print(var_name)}Use either : or . to specify project got {table_input}")
- cmpt = table_input.rsplit(":", 1)
- project_id = None
- rest = table_input
- if len(cmpt) == 1:
- project_id = None
- rest = cmpt[0]
- elif len(cmpt) == 2 and cmpt[0].count(":") <= 1:
- if cmpt[-1].count(".") != 2:
- project_id = cmpt[0]
- rest = cmpt[1]
- else:
- raise ValueError(
- f"{var_print(var_name)}Expect format of (<project:)<dataset>.<table>, got {table_input}"
- )
-
- cmpt = rest.split(".")
- if len(cmpt) == 3:
- if project_id:
- raise ValueError(f"{var_print(var_name)}Use either : or . to specify project")
- project_id = cmpt[0]
- dataset_id = cmpt[1]
- table_id = cmpt[2]
-
- elif len(cmpt) == 2:
- dataset_id = cmpt[0]
- table_id = cmpt[1]
- else:
- raise ValueError(
- f"{var_print(var_name)}Expect format of (<project.|<project:)<dataset>.<table>, got {table_input}"
- )
-
- # Exclude partition from the table name
- table_id = table_id.split("$")[0]
-
- if project_id is None:
- if var_name is not None:
- log.info(
- 'Project is not included in %s: %s; using project "%s"',
- var_name,
- table_input,
- default_project_id,
- )
- project_id = default_project_id
-
- return project_id, dataset_id, table_id
-
-
  def _cleanse_time_partitioning(
  destination_dataset_table: str | None, time_partitioning_in: dict | None
  ) -> dict: # if it is a partitioned table ($ is in the table name) add partition load option
airflow/providers/google/cloud/hooks/dataflow.py
@@ -185,7 +185,67 @@ class DataflowJobType:
  JOB_TYPE_STREAMING = "JOB_TYPE_STREAMING"
 
 
- class _DataflowJobsController(LoggingMixin):
+ class DataflowJobTerminalStateHelper(LoggingMixin):
+ """Helper to define and validate the dataflow job terminal state."""
+
+ @staticmethod
+ def expected_terminal_state_is_allowed(expected_terminal_state):
+ job_allowed_terminal_states = DataflowJobStatus.TERMINAL_STATES | {
+ DataflowJobStatus.JOB_STATE_RUNNING
+ }
+ if expected_terminal_state not in job_allowed_terminal_states:
+ raise AirflowException(
+ f"Google Cloud Dataflow job's expected terminal state "
+ f"'{expected_terminal_state}' is invalid."
+ f" The value should be any of the following: {job_allowed_terminal_states}"
+ )
+ return True
+
+ @staticmethod
+ def expected_terminal_state_is_valid_for_job_type(expected_terminal_state, is_streaming: bool):
+ if is_streaming:
+ invalid_terminal_state = DataflowJobStatus.JOB_STATE_DONE
+ job_type = "streaming"
+ else:
+ invalid_terminal_state = DataflowJobStatus.JOB_STATE_DRAINED
+ job_type = "batch"
+
+ if expected_terminal_state == invalid_terminal_state:
+ raise AirflowException(
+ f"Google Cloud Dataflow job's expected terminal state cannot be {invalid_terminal_state} while it is a {job_type} job"
+ )
+ return True
+
+ def job_reached_terminal_state(self, job, wait_until_finished=None, custom_terminal_state=None) -> bool:
+ """
+ Check the job reached terminal state, if job failed raise exception.
+
+ :return: True if job is done.
+ :raise: Exception
+ """
+ current_state = job["currentState"]
+ is_streaming = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING
+ expected_terminal_state = (
+ DataflowJobStatus.JOB_STATE_RUNNING if is_streaming else DataflowJobStatus.JOB_STATE_DONE
+ )
+ if custom_terminal_state is not None:
+ expected_terminal_state = custom_terminal_state
+ self.expected_terminal_state_is_allowed(expected_terminal_state)
+ self.expected_terminal_state_is_valid_for_job_type(expected_terminal_state, is_streaming=is_streaming)
+ if current_state == expected_terminal_state:
+ if expected_terminal_state == DataflowJobStatus.JOB_STATE_RUNNING and wait_until_finished:
+ return False
+ return True
+ if current_state in DataflowJobStatus.AWAITING_STATES:
+ return wait_until_finished is False
+ self.log.debug("Current job: %s", job)
+ raise AirflowException(
+ f"Google Cloud Dataflow job {job['name']} is in an unexpected terminal state: {current_state}, "
+ f"expected terminal state: {expected_terminal_state}"
+ )
+
+
+ class _DataflowJobsController(DataflowJobTerminalStateHelper):
  """
  Interface for communication with Google Cloud Dataflow API.
 
@@ -462,7 +522,10 @@ class _DataflowJobsController(LoggingMixin):
  """Wait for result of submitted job."""
  self.log.info("Start waiting for done.")
  self._refresh_jobs()
- while self._jobs and not all(self._check_dataflow_job_state(job) for job in self._jobs):
+ while self._jobs and not all(
+ self.job_reached_terminal_state(job, self._wait_until_finished, self._expected_terminal_state)
+ for job in self._jobs
+ ):
  self.log.info("Waiting for done. Sleep %s s", self._poll_sleep)
  time.sleep(self._poll_sleep)
  self._refresh_jobs()
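A small sketch of how the new `DataflowJobTerminalStateHelper` (introduced above and now used in this wait loop) evaluates a job payload; the job dict below is fabricated for illustration and only carries the fields the helper reads:

```python
from airflow.providers.google.cloud.hooks.dataflow import (
    DataflowJobStatus,
    DataflowJobTerminalStateHelper,
)

helper = DataflowJobTerminalStateHelper()

# Fabricated batch-job payload with only the fields the helper inspects.
job = {
    "name": "example-batch-job",
    "currentState": DataflowJobStatus.JOB_STATE_DONE,
}

# A batch job whose current state equals the default expected terminal
# state (JOB_STATE_DONE) is reported as terminal.
print(helper.job_reached_terminal_state(job))  # True
```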
@@ -1295,8 +1358,7 @@ class DataflowHook(GoogleBaseHook):
  location=location,
  )
  job = job_controller.fetch_job_by_id(job_id)
-
- return job_controller._check_dataflow_job_state(job)
+ return job_controller.job_reached_terminal_state(job)
 
  @GoogleBaseHook.fallback_to_default_project_id
  def create_data_pipeline(
@@ -1425,7 +1487,7 @@ class DataflowHook(GoogleBaseHook):
  return f"projects/{project_id}/locations/{location}"
 
 
- class AsyncDataflowHook(GoogleBaseAsyncHook):
+ class AsyncDataflowHook(GoogleBaseAsyncHook, DataflowJobTerminalStateHelper):
  """Async hook class for dataflow service."""
 
  sync_hook_class = DataflowHook
airflow/providers/google/cloud/hooks/gcs.py
@@ -549,13 +549,13 @@ class GCSHook(GoogleBaseHook):
  if cache_control:
  blob.cache_control = cache_control
 
- if filename and data:
+ if filename is not None and data is not None:
  raise ValueError(
  "'filename' and 'data' parameter provided. Please "
  "specify a single parameter, either 'filename' for "
  "local file uploads or 'data' for file content uploads."
  )
- if filename:
+ if filename is not None:
  if not mime_type:
  mime_type = "application/octet-stream"
  if gzip:
@@ -575,7 +575,7 @@
  if gzip:
  os.remove(filename)
  self.log.info("File %s uploaded to %s in %s bucket", filename, object_name, bucket_name)
- elif data:
+ elif data is not None:
  if not mime_type:
  mime_type = "text/plain"
  if gzip:
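The switch from truthiness checks to `is not None` means empty payloads are now uploaded instead of silently falling through both branches. A minimal sketch, with bucket and object names as placeholders:

```python
from airflow.providers.google.cloud.hooks.gcs import GCSHook

hook = GCSHook(gcp_conn_id="google_cloud_default")

# With `data is not None`, an empty string now creates a zero-byte object
# rather than being skipped by the old truthiness check.
hook.upload(
    bucket_name="example-bucket",
    object_name="markers/empty.txt",
    data="",
)
```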
airflow/providers/google/cloud/hooks/looker.py
@@ -39,6 +39,11 @@ if TYPE_CHECKING:
  class LookerHook(BaseHook):
  """Hook for Looker APIs."""
 
+ conn_name_attr = "looker_conn_id"
+ default_conn_name = "looker_default"
+ conn_type = "gcp_looker"
+ hook_name = "Google Looker"
+
  def __init__(
  self,
  looker_conn_id: str,
airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
@@ -185,42 +185,6 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
  model_encryption_spec_key_name=model_encryption_spec_key_name,
  )
 
- @deprecated(
- planned_removal_date="June 15, 2025",
- category=AirflowProviderDeprecationWarning,
- reason="Deprecation of AutoMLText API",
- )
- def get_auto_ml_text_training_job(
- self,
- display_name: str,
- prediction_type: str,
- multi_label: bool = False,
- sentiment_max: int = 10,
- project: str | None = None,
- location: str | None = None,
- labels: dict[str, str] | None = None,
- training_encryption_spec_key_name: str | None = None,
- model_encryption_spec_key_name: str | None = None,
- ) -> AutoMLTextTrainingJob:
- """
- Return AutoMLTextTrainingJob object.
-
- WARNING: Text creation API is deprecated since September 15, 2024
- (https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
- """
- return AutoMLTextTrainingJob(
- display_name=display_name,
- prediction_type=prediction_type,
- multi_label=multi_label,
- sentiment_max=sentiment_max,
- project=project,
- location=location,
- credentials=self.get_credentials(),
- labels=labels,
- training_encryption_spec_key_name=training_encryption_spec_key_name,
- model_encryption_spec_key_name=model_encryption_spec_key_name,
- )
-
  def get_auto_ml_video_training_job(
  self,
  display_name: str,
airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -25,14 +25,12 @@ from typing import TYPE_CHECKING
 
  import vertexai
  from vertexai.generative_models import GenerativeModel
- from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
+ from vertexai.language_models import TextEmbeddingModel
  from vertexai.preview.caching import CachedContent
  from vertexai.preview.evaluation import EvalResult, EvalTask
  from vertexai.preview.generative_models import GenerativeModel as preview_generative_model
  from vertexai.preview.tuning import sft
 
- from airflow.exceptions import AirflowProviderDeprecationWarning
- from airflow.providers.google.common.deprecated import deprecated
  from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
 
  if TYPE_CHECKING:
@@ -43,16 +41,6 @@ if TYPE_CHECKING:
  class GenerativeModelHook(GoogleBaseHook):
  """Hook for Google Cloud Vertex AI Generative Model APIs."""
 
- @deprecated(
- planned_removal_date="April 09, 2025",
- use_instead="GenerativeModelHook.get_generative_model",
- category=AirflowProviderDeprecationWarning,
- )
- def get_text_generation_model(self, pretrained_model: str):
- """Return a Model Garden Model object based on Text Generation."""
- model = TextGenerationModel.from_pretrained(pretrained_model)
- return model
-
  def get_text_embedding_model(self, pretrained_model: str):
  """Return a Model Garden Model object based on Text Embedding."""
  model = TextEmbeddingModel.from_pretrained(pretrained_model)
@@ -100,59 +88,6 @@ class GenerativeModelHook(GoogleBaseHook):
  cached_context_model = preview_generative_model.from_cached_content(cached_content)
  return cached_context_model
 
- @deprecated(
- planned_removal_date="April 09, 2025",
- use_instead="GenerativeModelHook.generative_model_generate_content",
- category=AirflowProviderDeprecationWarning,
- )
- @GoogleBaseHook.fallback_to_default_project_id
- def text_generation_model_predict(
- self,
- prompt: str,
- pretrained_model: str,
- temperature: float,
- max_output_tokens: int,
- top_p: float,
- top_k: int,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- ) -> str:
- """
- Use the Vertex AI PaLM API to generate natural language text.
-
- :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
- :param location: Required. The ID of the Google Cloud location that the service belongs to.
- :param prompt: Required. Inputs or queries that a user or a program gives
- to the Vertex AI PaLM API, in order to elicit a specific response.
- :param pretrained_model: A pre-trained model optimized for performing natural
- language tasks such as classification, summarization, extraction, content
- creation, and ideation.
- :param temperature: Temperature controls the degree of randomness in token
- selection.
- :param max_output_tokens: Token limit determines the maximum amount of text
- output.
- :param top_p: Tokens are selected from most probable to least until the sum
- of their probabilities equals the top_p value. Defaults to 0.8.
- :param top_k: A top_k of 1 means the selected token is the most probable
- among all tokens.
- """
- vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
-
- parameters = {
- "temperature": temperature,
- "max_output_tokens": max_output_tokens,
- "top_p": top_p,
- "top_k": top_k,
- }
-
- model = self.get_text_generation_model(pretrained_model)
-
- response = model.predict(
- prompt=prompt,
- **parameters,
- )
- return response.text
-
  @GoogleBaseHook.fallback_to_default_project_id
  def text_embedding_model_get_embeddings(
  self,