apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 16.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/ads/hooks/ads.py +34 -0
- airflow/providers/google/cloud/hooks/bigquery.py +63 -76
- airflow/providers/google/cloud/hooks/dataflow.py +67 -5
- airflow/providers/google/cloud/hooks/gcs.py +3 -3
- airflow/providers/google/cloud/hooks/looker.py +5 -0
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -36
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +1 -66
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/links/cloud_run.py +59 -0
- airflow/providers/google/cloud/links/vertex_ai.py +49 -0
- airflow/providers/google/cloud/log/gcs_task_handler.py +7 -5
- airflow/providers/google/cloud/operators/bigquery.py +49 -10
- airflow/providers/google/cloud/operators/cloud_run.py +20 -2
- airflow/providers/google/cloud/operators/gcs.py +1 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +4 -86
- airflow/providers/google/cloud/operators/pubsub.py +2 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +0 -92
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +4 -0
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +388 -0
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +9 -5
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +2 -0
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +11 -5
- airflow/providers/google/cloud/triggers/bigquery.py +32 -5
- airflow/providers/google/cloud/triggers/dataflow.py +122 -0
- airflow/providers/google/cloud/triggers/dataproc.py +62 -10
- airflow/providers/google/get_provider_info.py +18 -5
- airflow/providers/google/leveldb/hooks/leveldb.py +25 -0
- airflow/providers/google/version_compat.py +0 -1
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/METADATA +92 -85
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/RECORD +35 -32
- airflow/providers/google/cloud/links/automl.py +0 -193
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-16.0.0.dist-info}/entry_points.txt +0 -0
airflow/providers/google/__init__.py
@@ -29,11 +29,11 @@ from airflow import __version__ as airflow_version

 __all__ = ["__version__"]

-__version__ = "15.1.0rc1"
+__version__ = "16.0.0"

 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
-    "2.9.0"
+    "2.10.0"
 ):
     raise RuntimeError(
-        f"The package `apache-airflow-providers-google:{__version__}` needs Apache Airflow 2.9.0+"
+        f"The package `apache-airflow-providers-google:{__version__}` needs Apache Airflow 2.10.0+"
     )
|
airflow/providers/google/ads/hooks/ads.py
@@ -101,6 +101,40 @@ class GoogleAdsHook(BaseHook):
     :param api_version: The Google Ads API version to use.
     """

+    conn_name_attr = "google_ads_conn_id"
+    default_conn_name = "google_ads_default"
+    conn_type = "google_ads"
+    hook_name = "Google Ads"
+
+    @classmethod
+    def get_connection_form_widgets(cls) -> dict[str, Any]:
+        """Return connection widgets to add to Google Ads connection form."""
+        from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
+        from flask_babel import lazy_gettext
+        from wtforms import PasswordField, StringField
+
+        return {
+            "developer_token": StringField(lazy_gettext("Developer token"), widget=BS3TextFieldWidget()),
+            "client_id": StringField(lazy_gettext("OAuth2 Client ID"), widget=BS3TextFieldWidget()),
+            "client_secret": PasswordField(
+                lazy_gettext("OAuth2 Client Secret"), widget=BS3PasswordFieldWidget()
+            ),
+            "refresh_token": PasswordField(
+                lazy_gettext("OAuth2 Refresh Token"), widget=BS3PasswordFieldWidget()
+            ),
+        }
+
+    @classmethod
+    def get_ui_field_behaviour(cls) -> dict[str, Any]:
+        """Return custom UI field behaviour for Google Ads connection."""
+        return {
+            "hidden_fields": ["host", "login", "schema", "port"],
+            "relabeling": {},
+            "placeholders": {
+                "password": "Leave blank (optional)",
+            },
+        }
+
     def __init__(
         self,
         api_version: str | None = None,
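The new class attributes and form widgets register a dedicated `google_ads` connection type with its own UI form. A minimal sketch of exercising the hook against such a connection (the customer id and query are illustrative placeholders, not values from this diff):

    from airflow.providers.google.ads.hooks.ads import GoogleAdsHook

    # Assumes an Airflow connection "google_ads_default" of the new
    # "google_ads" type, with the developer token and OAuth2 client
    # credentials entered through the form fields added above.
    hook = GoogleAdsHook()
    rows = hook.search(
        client_ids=["1234567890"],  # placeholder customer id
        query="SELECT campaign.id, campaign.name FROM campaign",
    )
    for row in rows:
        print(row.campaign.id, row.campaign.name)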
airflow/providers/google/cloud/hooks/bigquery.py
@@ -29,7 +29,7 @@ import uuid
 from collections.abc import Iterable, Mapping, Sequence
 from copy import deepcopy
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any, NoReturn, Union, cast
+from typing import TYPE_CHECKING, Any, NoReturn, Union, cast, overload

 from aiohttp import ClientSession as ClientSession
 from gcloud.aio.bigquery import Job, Table as Table_async
@@ -57,8 +57,13 @@ from googleapiclient.discovery import build
 from pandas_gbq import read_gbq
 from pandas_gbq.gbq import GbqConnector  # noqa: F401 used in ``airflow.contrib.hooks.bigquery``
 from sqlalchemy import create_engine
+from typing_extensions import Literal

-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import (
+    AirflowException,
+    AirflowOptionalProviderFeatureException,
+    AirflowProviderDeprecationWarning,
+)
 from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.cloud.utils.bigquery import bq_cast
@@ -77,6 +82,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin

 if TYPE_CHECKING:
     import pandas as pd
+    import polars as pl
     from google.api_core.page_iterator import HTTPIterator
     from google.api_core.retry import Retry
     from requests import Session
@@ -275,15 +281,57 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         """
         raise NotImplementedError()

-    def get_pandas_df(
+    def _get_pandas_df(
         self,
         sql: str,
         parameters: Iterable | Mapping[str, Any] | None = None,
         dialect: str | None = None,
         **kwargs,
     ) -> pd.DataFrame:
+        if dialect is None:
+            dialect = "legacy" if self.use_legacy_sql else "standard"
+
+        credentials, project_id = self.get_credentials_and_project_id()
+
+        return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs)
+
+    def _get_polars_df(self, sql, parameters=None, dialect=None, **kwargs) -> pl.DataFrame:
+        try:
+            import polars as pl
+        except ImportError:
+            raise AirflowOptionalProviderFeatureException(
+                "Polars is not installed. Please install it with `pip install polars`."
+            )
+
+        if dialect is None:
+            dialect = "legacy" if self.use_legacy_sql else "standard"
+
+        credentials, project_id = self.get_credentials_and_project_id()
+
+        pandas_df = read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs)
+        return pl.from_pandas(pandas_df)
+
+    @overload
+    def get_df(
+        self, sql, parameters=None, dialect=None, *, df_type: Literal["pandas"] = "pandas", **kwargs
+    ) -> pd.DataFrame: ...
+
+    @overload
+    def get_df(
+        self, sql, parameters=None, dialect=None, *, df_type: Literal["polars"], **kwargs
+    ) -> pl.DataFrame: ...
+
+    def get_df(
+        self,
+        sql,
+        parameters=None,
+        dialect=None,
+        *,
+        df_type: Literal["pandas", "polars"] = "pandas",
+        **kwargs,
+    ) -> pd.DataFrame | pl.DataFrame:
         """
-        Get a Pandas DataFrame for the BigQuery results.
+        Get a DataFrame for the BigQuery results.

         The DbApiHook method must be overridden because Pandas doesn't support
         PEP 249 connections, except for SQLite.
@@ -299,12 +347,19 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
             defaults to use `self.use_legacy_sql` if not specified
         :param kwargs: (optional) passed into pandas_gbq.read_gbq method
         """
-        if dialect is None:
-            dialect = "legacy" if self.use_legacy_sql else "standard"
+        if df_type == "polars":
+            return self._get_polars_df(sql, parameters, dialect, **kwargs)

-        credentials, project_id = self.get_credentials_and_project_id()
+        if df_type == "pandas":
+            return self._get_pandas_df(sql, parameters, dialect, **kwargs)

-        return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs)
+    @deprecated(
+        planned_removal_date="November 30, 2025",
+        use_instead="airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_df",
+        category=AirflowProviderDeprecationWarning,
+    )
+    def get_pandas_df(self, sql, parameters=None, dialect=None, **kwargs):
+        return self._get_pandas_df(sql, parameters, dialect, **kwargs)

     @GoogleBaseHook.fallback_to_default_project_id
     def table_exists(self, dataset_id: str, table_id: str, project_id: str) -> bool:
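With this change, `get_df` selects the frame type at the call site via the `overload`-typed `df_type` keyword, while `get_pandas_df` survives as a deprecated wrapper until its planned removal. A minimal usage sketch (connection id and query are illustrative):

    from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

    hook = BigQueryHook(gcp_conn_id="google_cloud_default", use_legacy_sql=False)

    # Default path: a pandas DataFrame, exactly as get_pandas_df returned.
    pandas_df = hook.get_df("SELECT 1 AS x")

    # New path: a polars DataFrame converted from the pandas result; raises
    # AirflowOptionalProviderFeatureException when polars is not installed.
    polars_df = hook.get_df("SELECT 1 AS x", df_type="polars")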
@@ -1937,74 +1992,6 @@ def _escape(s: str) -> str:
     return e


-@deprecated(
-    planned_removal_date="April 01, 2025",
-    use_instead="airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.split_tablename",
-    category=AirflowProviderDeprecationWarning,
-)
-def split_tablename(
-    table_input: str, default_project_id: str, var_name: str | None = None
-) -> tuple[str, str, str]:
-    if "." not in table_input:
-        raise ValueError(f"Expected table name in the format of <dataset>.<table>. Got: {table_input}")
-
-    if not default_project_id:
-        raise ValueError("INTERNAL: No default project is specified")
-
-    def var_print(var_name):
-        if var_name is None:
-            return ""
-        return f"Format exception for {var_name}: "
-
-    if table_input.count(".") + table_input.count(":") > 3:
-        raise ValueError(f"{var_print(var_name)}Use either : or . to specify project got {table_input}")
-    cmpt = table_input.rsplit(":", 1)
-    project_id = None
-    rest = table_input
-    if len(cmpt) == 1:
-        project_id = None
-        rest = cmpt[0]
-    elif len(cmpt) == 2 and cmpt[0].count(":") <= 1:
-        if cmpt[-1].count(".") != 2:
-            project_id = cmpt[0]
-            rest = cmpt[1]
-    else:
-        raise ValueError(
-            f"{var_print(var_name)}Expect format of (<project:)<dataset>.<table>, got {table_input}"
-        )
-
-    cmpt = rest.split(".")
-    if len(cmpt) == 3:
-        if project_id:
-            raise ValueError(f"{var_print(var_name)}Use either : or . to specify project")
-        project_id = cmpt[0]
-        dataset_id = cmpt[1]
-        table_id = cmpt[2]
-
-    elif len(cmpt) == 2:
-        dataset_id = cmpt[0]
-        table_id = cmpt[1]
-    else:
-        raise ValueError(
-            f"{var_print(var_name)}Expect format of (<project.|<project:)<dataset>.<table>, got {table_input}"
-        )
-
-    # Exclude partition from the table name
-    table_id = table_id.split("$")[0]
-
-    if project_id is None:
-        if var_name is not None:
-            log.info(
-                'Project is not included in %s: %s; using project "%s"',
-                var_name,
-                table_input,
-                default_project_id,
-            )
-        project_id = default_project_id
-
-    return project_id, dataset_id, table_id
-
-
 def _cleanse_time_partitioning(
     destination_dataset_table: str | None, time_partitioning_in: dict | None
 ) -> dict:  # if it is a partitioned table ($ is in the table name) add partition load option
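The module-level `split_tablename` function, deprecated in favour of the identically named `BigQueryHook` method, is removed here; the method keeps the parsing rules shown above. A sketch of the replacement call (table and project names are placeholders):

    from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

    hook = BigQueryHook()
    # Accepts "<project>.<dataset>.<table>" or "<project>:<dataset>.<table>"
    # and falls back to default_project_id when the project part is omitted.
    project, dataset, table = hook.split_tablename(
        table_input="my-project.my_dataset.my_table$20240101",
        default_project_id="fallback-project",
    )
    # project == "my-project"; the "$20240101" partition suffix is stripped,
    # so table == "my_table".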
airflow/providers/google/cloud/hooks/dataflow.py
@@ -185,7 +185,67 @@ class DataflowJobType:
     JOB_TYPE_STREAMING = "JOB_TYPE_STREAMING"


-class _DataflowJobsController(LoggingMixin):
+class DataflowJobTerminalStateHelper(LoggingMixin):
+    """Helper to define and validate the dataflow job terminal state."""
+
+    @staticmethod
+    def expected_terminal_state_is_allowed(expected_terminal_state):
+        job_allowed_terminal_states = DataflowJobStatus.TERMINAL_STATES | {
+            DataflowJobStatus.JOB_STATE_RUNNING
+        }
+        if expected_terminal_state not in job_allowed_terminal_states:
+            raise AirflowException(
+                f"Google Cloud Dataflow job's expected terminal state "
+                f"'{expected_terminal_state}' is invalid."
+                f" The value should be any of the following: {job_allowed_terminal_states}"
+            )
+        return True
+
+    @staticmethod
+    def expected_terminal_state_is_valid_for_job_type(expected_terminal_state, is_streaming: bool):
+        if is_streaming:
+            invalid_terminal_state = DataflowJobStatus.JOB_STATE_DONE
+            job_type = "streaming"
+        else:
+            invalid_terminal_state = DataflowJobStatus.JOB_STATE_DRAINED
+            job_type = "batch"
+
+        if expected_terminal_state == invalid_terminal_state:
+            raise AirflowException(
+                f"Google Cloud Dataflow job's expected terminal state cannot be {invalid_terminal_state} while it is a {job_type} job"
+            )
+        return True
+
+    def job_reached_terminal_state(self, job, wait_until_finished=None, custom_terminal_state=None) -> bool:
+        """
+        Check the job reached terminal state, if job failed raise exception.
+
+        :return: True if job is done.
+        :raise: Exception
+        """
+        current_state = job["currentState"]
+        is_streaming = job.get("type") == DataflowJobType.JOB_TYPE_STREAMING
+        expected_terminal_state = (
+            DataflowJobStatus.JOB_STATE_RUNNING if is_streaming else DataflowJobStatus.JOB_STATE_DONE
+        )
+        if custom_terminal_state is not None:
+            expected_terminal_state = custom_terminal_state
+        self.expected_terminal_state_is_allowed(expected_terminal_state)
+        self.expected_terminal_state_is_valid_for_job_type(expected_terminal_state, is_streaming=is_streaming)
+        if current_state == expected_terminal_state:
+            if expected_terminal_state == DataflowJobStatus.JOB_STATE_RUNNING and wait_until_finished:
+                return False
+            return True
+        if current_state in DataflowJobStatus.AWAITING_STATES:
+            return wait_until_finished is False
+        self.log.debug("Current job: %s", job)
+        raise AirflowException(
+            f"Google Cloud Dataflow job {job['name']} is in an unexpected terminal state: {current_state}, "
+            f"expected terminal state: {expected_terminal_state}"
+        )
+
+
+class _DataflowJobsController(DataflowJobTerminalStateHelper):
     """
     Interface for communication with Google Cloud Dataflow API.

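Because `job_reached_terminal_state` only inspects the job payload it is handed, the new helper can be exercised standalone. A sketch against the logic above (the job dicts are hand-written stand-ins for Dataflow API responses):

    from airflow.exceptions import AirflowException
    from airflow.providers.google.cloud.hooks.dataflow import (
        DataflowJobStatus,
        DataflowJobTerminalStateHelper,
    )

    helper = DataflowJobTerminalStateHelper()

    # Batch job in JOB_STATE_DONE: matches the default expected terminal
    # state for batch jobs, so the check returns True.
    batch_job = {"name": "batch-example", "currentState": "JOB_STATE_DONE", "type": "JOB_TYPE_BATCH"}
    assert helper.job_reached_terminal_state(batch_job)

    # Streaming job: the default expected terminal state is JOB_STATE_RUNNING,
    # but wait_until_finished=True keeps the controller waiting (False).
    streaming_job = {
        "name": "streaming-example",
        "currentState": "JOB_STATE_RUNNING",
        "type": "JOB_TYPE_STREAMING",
    }
    assert not helper.job_reached_terminal_state(streaming_job, wait_until_finished=True)

    # A custom terminal state is validated per job type: JOB_STATE_DONE is
    # rejected for streaming jobs by the second static check above.
    try:
        helper.job_reached_terminal_state(
            streaming_job, custom_terminal_state=DataflowJobStatus.JOB_STATE_DONE
        )
    except AirflowException as err:
        print(err)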
@@ -462,7 +522,10 @@ class _DataflowJobsController(LoggingMixin):
         """Wait for result of submitted job."""
         self.log.info("Start waiting for done.")
         self._refresh_jobs()
-        while self._jobs and not all(self._check_dataflow_job_state(job) for job in self._jobs):
+        while self._jobs and not all(
+            self.job_reached_terminal_state(job, self._wait_until_finished, self._expected_terminal_state)
+            for job in self._jobs
+        ):
             self.log.info("Waiting for done. Sleep %s s", self._poll_sleep)
             time.sleep(self._poll_sleep)
             self._refresh_jobs()
@@ -1295,8 +1358,7 @@ class DataflowHook(GoogleBaseHook):
             location=location,
         )
         job = job_controller.fetch_job_by_id(job_id)
-
-        return job_controller._check_dataflow_job_state(job)
+        return job_controller.job_reached_terminal_state(job)

     @GoogleBaseHook.fallback_to_default_project_id
     def create_data_pipeline(
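This hunk is the body of the hook's job-status check (`is_job_done` in the provider source), so the public behaviour is unchanged; only the terminal-state logic moved to the shared helper. A sketch of the call (project, region, and job id are placeholders):

    from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

    hook = DataflowHook(gcp_conn_id="google_cloud_default")
    done = hook.is_job_done(
        location="us-central1",
        project_id="example-project",
        job_id="2024-05-01_12_00_00-1234567890123456789",  # placeholder
    )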
@@ -1425,7 +1487,7 @@
         return f"projects/{project_id}/locations/{location}"


-class AsyncDataflowHook(GoogleBaseAsyncHook):
+class AsyncDataflowHook(GoogleBaseAsyncHook, DataflowJobTerminalStateHelper):
     """Async hook class for dataflow service."""

     sync_hook_class = DataflowHook
airflow/providers/google/cloud/hooks/gcs.py
@@ -549,13 +549,13 @@ class GCSHook(GoogleBaseHook):
         if cache_control:
             blob.cache_control = cache_control

-        if filename and data:
+        if filename is not None and data is not None:
             raise ValueError(
                 "'filename' and 'data' parameter provided. Please "
                 "specify a single parameter, either 'filename' for "
                 "local file uploads or 'data' for file content uploads."
             )
-        if filename:
+        if filename is not None:
             if not mime_type:
                 mime_type = "application/octet-stream"
             if gzip:
@@ -575,7 +575,7 @@
             if gzip:
                 os.remove(filename)
             self.log.info("File %s uploaded to %s in %s bucket", filename, object_name, bucket_name)
-        elif data:
+        elif data is not None:
             if not mime_type:
                 mime_type = "text/plain"
             if gzip:
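The move from truthiness to `is not None` checks changes behaviour for empty payloads: `data=""` was previously treated as if no data had been passed, whereas it now uploads a zero-byte object. A sketch (bucket and object names are placeholders):

    from airflow.providers.google.cloud.hooks.gcs import GCSHook

    hook = GCSHook(gcp_conn_id="google_cloud_default")
    # An empty string is now a valid payload and creates a zero-byte object
    # (useful for marker files) instead of falling through both branches.
    hook.upload(
        bucket_name="example-bucket",
        object_name="markers/_SUCCESS",
        data="",
        mime_type="text/plain",
    )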
airflow/providers/google/cloud/hooks/looker.py
@@ -39,6 +39,11 @@ if TYPE_CHECKING:
 class LookerHook(BaseHook):
     """Hook for Looker APIs."""

+    conn_name_attr = "looker_conn_id"
+    default_conn_name = "looker_default"
+    conn_type = "gcp_looker"
+    hook_name = "Google Looker"
+
     def __init__(
         self,
         looker_conn_id: str,
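These attributes give the hook a dedicated `gcp_looker` connection type. A sketch of defining a matching connection programmatically (host and credentials are placeholders; in practice this is usually configured via the UI or an environment variable):

    from airflow.models.connection import Connection

    conn = Connection(
        conn_id="looker_default",
        conn_type="gcp_looker",
        host="https://your.looker.instance",
        login="client_id_placeholder",
        password="client_secret_placeholder",
    )
    print(conn.get_uri())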
airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py
@@ -185,42 +185,6 @@ class AutoMLHook(GoogleBaseHook, OperationHelper):
             model_encryption_spec_key_name=model_encryption_spec_key_name,
         )

-    @deprecated(
-        planned_removal_date="June 15, 2025",
-        category=AirflowProviderDeprecationWarning,
-        reason="Deprecation of AutoMLText API",
-    )
-    def get_auto_ml_text_training_job(
-        self,
-        display_name: str,
-        prediction_type: str,
-        multi_label: bool = False,
-        sentiment_max: int = 10,
-        project: str | None = None,
-        location: str | None = None,
-        labels: dict[str, str] | None = None,
-        training_encryption_spec_key_name: str | None = None,
-        model_encryption_spec_key_name: str | None = None,
-    ) -> AutoMLTextTrainingJob:
-        """
-        Return AutoMLTextTrainingJob object.
-
-        WARNING: Text creation API is deprecated since September 15, 2024
-        (https://cloud.google.com/vertex-ai/docs/tutorials/text-classification-automl/overview).
-        """
-        return AutoMLTextTrainingJob(
-            display_name=display_name,
-            prediction_type=prediction_type,
-            multi_label=multi_label,
-            sentiment_max=sentiment_max,
-            project=project,
-            location=location,
-            credentials=self.get_credentials(),
-            labels=labels,
-            training_encryption_spec_key_name=training_encryption_spec_key_name,
-            model_encryption_spec_key_name=model_encryption_spec_key_name,
-        )
-
     def get_auto_ml_video_training_job(
         self,
         display_name: str,
airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -25,14 +25,12 @@ from typing import TYPE_CHECKING

 import vertexai
 from vertexai.generative_models import GenerativeModel
-from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
+from vertexai.language_models import TextEmbeddingModel
 from vertexai.preview.caching import CachedContent
 from vertexai.preview.evaluation import EvalResult, EvalTask
 from vertexai.preview.generative_models import GenerativeModel as preview_generative_model
 from vertexai.preview.tuning import sft

-from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.common.deprecated import deprecated
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook

 if TYPE_CHECKING:
@@ -43,16 +41,6 @@
 class GenerativeModelHook(GoogleBaseHook):
     """Hook for Google Cloud Vertex AI Generative Model APIs."""

-    @deprecated(
-        planned_removal_date="April 09, 2025",
-        use_instead="GenerativeModelHook.get_generative_model",
-        category=AirflowProviderDeprecationWarning,
-    )
-    def get_text_generation_model(self, pretrained_model: str):
-        """Return a Model Garden Model object based on Text Generation."""
-        model = TextGenerationModel.from_pretrained(pretrained_model)
-        return model
-
     def get_text_embedding_model(self, pretrained_model: str):
         """Return a Model Garden Model object based on Text Embedding."""
         model = TextEmbeddingModel.from_pretrained(pretrained_model)
@@ -100,59 +88,6 @@
         cached_context_model = preview_generative_model.from_cached_content(cached_content)
         return cached_context_model

-    @deprecated(
-        planned_removal_date="April 09, 2025",
-        use_instead="GenerativeModelHook.generative_model_generate_content",
-        category=AirflowProviderDeprecationWarning,
-    )
-    @GoogleBaseHook.fallback_to_default_project_id
-    def text_generation_model_predict(
-        self,
-        prompt: str,
-        pretrained_model: str,
-        temperature: float,
-        max_output_tokens: int,
-        top_p: float,
-        top_k: int,
-        location: str,
-        project_id: str = PROVIDE_PROJECT_ID,
-    ) -> str:
-        """
-        Use the Vertex AI PaLM API to generate natural language text.
-
-        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
-        :param location: Required. The ID of the Google Cloud location that the service belongs to.
-        :param prompt: Required. Inputs or queries that a user or a program gives
-            to the Vertex AI PaLM API, in order to elicit a specific response.
-        :param pretrained_model: A pre-trained model optimized for performing natural
-            language tasks such as classification, summarization, extraction, content
-            creation, and ideation.
-        :param temperature: Temperature controls the degree of randomness in token
-            selection.
-        :param max_output_tokens: Token limit determines the maximum amount of text
-            output.
-        :param top_p: Tokens are selected from most probable to least until the sum
-            of their probabilities equals the top_p value. Defaults to 0.8.
-        :param top_k: A top_k of 1 means the selected token is the most probable
-            among all tokens.
-        """
-        vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
-
-        parameters = {
-            "temperature": temperature,
-            "max_output_tokens": max_output_tokens,
-            "top_p": top_p,
-            "top_k": top_k,
-        }
-
-        model = self.get_text_generation_model(pretrained_model)
-
-        response = model.predict(
-            prompt=prompt,
-            **parameters,
-        )
-        return response.text
-
     @GoogleBaseHook.fallback_to_default_project_id
     def text_embedding_model_get_embeddings(
         self,