apache-airflow-providers-google 18.0.0rc1__py3-none-any.whl → 18.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (72)
  1. airflow/providers/google/__init__.py +1 -1
  2. airflow/providers/google/ads/hooks/ads.py +5 -5
  3. airflow/providers/google/assets/gcs.py +1 -11
  4. airflow/providers/google/cloud/bundles/__init__.py +16 -0
  5. airflow/providers/google/cloud/bundles/gcs.py +161 -0
  6. airflow/providers/google/cloud/hooks/bigquery.py +45 -42
  7. airflow/providers/google/cloud/hooks/cloud_composer.py +131 -1
  8. airflow/providers/google/cloud/hooks/cloud_sql.py +88 -13
  9. airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +16 -0
  10. airflow/providers/google/cloud/hooks/dataflow.py +1 -1
  11. airflow/providers/google/cloud/hooks/dataprep.py +1 -1
  12. airflow/providers/google/cloud/hooks/dataproc.py +3 -0
  13. airflow/providers/google/cloud/hooks/gcs.py +107 -3
  14. airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
  15. airflow/providers/google/cloud/hooks/looker.py +1 -1
  16. airflow/providers/google/cloud/hooks/spanner.py +45 -0
  17. airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +30 -0
  18. airflow/providers/google/cloud/links/base.py +11 -11
  19. airflow/providers/google/cloud/links/dataproc.py +2 -10
  20. airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
  21. airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
  22. airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
  23. airflow/providers/google/cloud/openlineage/facets.py +102 -1
  24. airflow/providers/google/cloud/openlineage/mixins.py +3 -1
  25. airflow/providers/google/cloud/operators/bigquery.py +2 -9
  26. airflow/providers/google/cloud/operators/cloud_run.py +2 -1
  27. airflow/providers/google/cloud/operators/cloud_sql.py +1 -1
  28. airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +89 -6
  29. airflow/providers/google/cloud/operators/datafusion.py +36 -7
  30. airflow/providers/google/cloud/operators/gen_ai.py +389 -0
  31. airflow/providers/google/cloud/operators/spanner.py +22 -6
  32. airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +7 -0
  33. airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +30 -0
  34. airflow/providers/google/cloud/operators/workflows.py +17 -6
  35. airflow/providers/google/cloud/sensors/bigquery.py +1 -1
  36. airflow/providers/google/cloud/sensors/bigquery_dts.py +1 -6
  37. airflow/providers/google/cloud/sensors/bigtable.py +1 -6
  38. airflow/providers/google/cloud/sensors/cloud_composer.py +65 -31
  39. airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +1 -6
  40. airflow/providers/google/cloud/sensors/dataflow.py +1 -1
  41. airflow/providers/google/cloud/sensors/dataform.py +1 -6
  42. airflow/providers/google/cloud/sensors/datafusion.py +1 -6
  43. airflow/providers/google/cloud/sensors/dataplex.py +1 -6
  44. airflow/providers/google/cloud/sensors/dataprep.py +1 -6
  45. airflow/providers/google/cloud/sensors/dataproc.py +1 -6
  46. airflow/providers/google/cloud/sensors/dataproc_metastore.py +1 -6
  47. airflow/providers/google/cloud/sensors/gcs.py +1 -7
  48. airflow/providers/google/cloud/sensors/looker.py +1 -6
  49. airflow/providers/google/cloud/sensors/pubsub.py +1 -6
  50. airflow/providers/google/cloud/sensors/tasks.py +1 -6
  51. airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +1 -6
  52. airflow/providers/google/cloud/sensors/workflows.py +1 -6
  53. airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
  54. airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +2 -1
  55. airflow/providers/google/cloud/transfers/sftp_to_gcs.py +11 -2
  56. airflow/providers/google/cloud/triggers/bigquery.py +15 -3
  57. airflow/providers/google/cloud/triggers/cloud_composer.py +51 -21
  58. airflow/providers/google/cloud/triggers/cloud_run.py +1 -1
  59. airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +90 -0
  60. airflow/providers/google/cloud/triggers/pubsub.py +14 -18
  61. airflow/providers/google/common/hooks/base_google.py +1 -1
  62. airflow/providers/google/get_provider_info.py +15 -0
  63. airflow/providers/google/leveldb/hooks/leveldb.py +1 -1
  64. airflow/providers/google/marketing_platform/links/analytics_admin.py +2 -8
  65. airflow/providers/google/marketing_platform/sensors/campaign_manager.py +1 -6
  66. airflow/providers/google/marketing_platform/sensors/display_video.py +1 -6
  67. airflow/providers/google/suite/sensors/drive.py +1 -6
  68. airflow/providers/google/version_compat.py +0 -20
  69. {apache_airflow_providers_google-18.0.0rc1.dist-info → apache_airflow_providers_google-18.1.0rc1.dist-info}/METADATA +8 -8
  70. {apache_airflow_providers_google-18.0.0rc1.dist-info → apache_airflow_providers_google-18.1.0rc1.dist-info}/RECORD +72 -65
  71. {apache_airflow_providers_google-18.0.0rc1.dist-info → apache_airflow_providers_google-18.1.0rc1.dist-info}/WHEEL +0 -0
  72. {apache_airflow_providers_google-18.0.0rc1.dist-info → apache_airflow_providers_google-18.1.0rc1.dist-info}/entry_points.txt +0 -0

airflow/providers/google/cloud/hooks/cloud_sql.py

@@ -50,7 +50,14 @@ from googleapiclient.errors import HttpError
 # Number of retries - used by googleapiclient method calls to perform retries
 # For requests that are "retriable"
 from airflow.exceptions import AirflowException
-from airflow.models import Connection
+from airflow.providers.google.version_compat import AIRFLOW_V_3_1_PLUS
+
+if AIRFLOW_V_3_1_PLUS:
+    from airflow.sdk import Connection
+else:
+    from airflow.models import Connection  # type: ignore[assignment,attr-defined,no-redef]
+
+from airflow.providers.common.compat.sdk import BaseHook
 from airflow.providers.google.cloud.hooks.secret_manager import (
     GoogleCloudSecretManagerHook,
 )
@@ -60,7 +67,6 @@ from airflow.providers.google.common.hooks.base_google import (
     GoogleBaseHook,
     get_field,
 )
-from airflow.providers.google.version_compat import BaseHook
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
@@ -1045,15 +1051,26 @@ class CloudSQLDatabaseHook(BaseHook):
     def _quote(value) -> str | None:
         return quote_plus(value) if value else None
 
-    def _generate_connection_uri(self) -> str:
+    def _reserve_port(self):
         if self.use_proxy:
             if self.sql_proxy_use_tcp:
                 if not self.sql_proxy_tcp_port:
                     self.reserve_free_tcp_port()
             if not self.sql_proxy_unique_path:
                 self.sql_proxy_unique_path = self._generate_unique_path()
+
+    def _generate_connection_uri(self) -> str:
+        self._reserve_port()
         if not self.database_type:
             raise ValueError("The database_type should be set")
+        if not self.user:
+            raise AirflowException("The login parameter needs to be set in connection")
+        if not self.public_ip:
+            raise AirflowException("The location parameter needs to be set in connection")
+        if not self.password:
+            raise AirflowException("The password parameter needs to be set in connection")
+        if not self.database:
+            raise AirflowException("The database parameter needs to be set in connection")
 
         database_uris = CONNECTION_URIS[self.database_type]
         ssl_spec = None
@@ -1072,14 +1089,6 @@ class CloudSQLDatabaseHook(BaseHook):
                 ssl_spec = {"cert": self.sslcert, "key": self.sslkey, "ca": self.sslrootcert}
             else:
                 format_string = public_uris["non-ssl"]
-        if not self.user:
-            raise AirflowException("The login parameter needs to be set in connection")
-        if not self.public_ip:
-            raise AirflowException("The location parameter needs to be set in connection")
-        if not self.password:
-            raise AirflowException("The password parameter needs to be set in connection")
-        if not self.database:
-            raise AirflowException("The database parameter needs to be set in connection")
 
         connection_uri = format_string.format(
             user=quote_plus(self.user) if self.user else "",
@@ -1113,6 +1122,69 @@ class CloudSQLDatabaseHook(BaseHook):
             instance_specification += f"=tcp:{self.sql_proxy_tcp_port}"
         return instance_specification
 
+    def _generate_connection_parameters(self) -> dict:
+        self._reserve_port()
+        if not self.database_type:
+            raise ValueError("The database_type should be set")
+        if not self.user:
+            raise AirflowException("The login parameter needs to be set in connection")
+        if not self.public_ip:
+            raise AirflowException("The location parameter needs to be set in connection")
+        if not self.password:
+            raise AirflowException("The password parameter needs to be set in connection")
+        if not self.database:
+            raise AirflowException("The database parameter needs to be set in connection")
+
+        connection_parameters = {}
+
+        connection_parameters["conn_type"] = self.database_type
+        connection_parameters["login"] = self.user
+        connection_parameters["password"] = self.password
+        connection_parameters["schema"] = self.database
+        connection_parameters["extra"] = {}
+
+        database_uris = CONNECTION_URIS[self.database_type]
+        if self.use_proxy:
+            proxy_uris = database_uris["proxy"]
+            if self.sql_proxy_use_tcp:
+                connection_parameters["host"] = "127.0.0.1"
+                connection_parameters["port"] = self.sql_proxy_tcp_port
+            else:
+                socket_path = f"{self.sql_proxy_unique_path}/{self._get_instance_socket_name()}"
+                if "localhost" in proxy_uris["socket"]:
+                    connection_parameters["host"] = "localhost"
+                    connection_parameters["extra"].update({"unix_socket": socket_path})
+                else:
+                    connection_parameters["host"] = socket_path
+        else:
+            public_uris = database_uris["public"]
+            if self.use_ssl:
+                connection_parameters["host"] = self.public_ip
+                connection_parameters["port"] = self.public_port
+                if "ssl_spec" in public_uris["ssl"]:
+                    connection_parameters["extra"].update(
+                        {
+                            "ssl": json.dumps(
+                                {"cert": self.sslcert, "key": self.sslkey, "ca": self.sslrootcert}
+                            )
+                        }
+                    )
+                else:
+                    connection_parameters["extra"].update(
+                        {
+                            "sslmode": "verify-ca",
+                            "sslcert": self.sslcert,
+                            "sslkey": self.sslkey,
+                            "sslrootcert": self.sslrootcert,
+                        }
+                    )
+            else:
+                connection_parameters["host"] = self.public_ip
+                connection_parameters["port"] = self.public_port
+        if connection_parameters.get("extra"):
+            connection_parameters["extra"] = json.dumps(connection_parameters["extra"])
+        return connection_parameters
+
     def create_connection(self) -> Connection:
         """
         Create a connection.
@@ -1120,8 +1192,11 @@ class CloudSQLDatabaseHook(BaseHook):
         Connection ID will be randomly generated according to whether it uses
         proxy, TCP, UNIX sockets, SSL.
         """
-        uri = self._generate_connection_uri()
-        connection = Connection(conn_id=self.db_conn_id, uri=uri)
+        if AIRFLOW_V_3_1_PLUS:
+            kwargs = self._generate_connection_parameters()
+        else:
+            kwargs = {"uri": self._generate_connection_uri()}
+        connection = Connection(conn_id=self.db_conn_id, **kwargs)
         self.log.info("Creating connection %s", self.db_conn_id)
         return connection
 
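
For context, a minimal usage sketch of the version-gated Connection construction above; the connection ID below is a placeholder, not taken from the diff:

    from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook

    # "my_cloudsql_db" is a placeholder Airflow connection ID.
    hook = CloudSQLDatabaseHook(gcp_cloudsql_conn_id="my_cloudsql_db")
    connection = hook.create_connection()
    # On Airflow 3.1+ the Connection is assembled from discrete parameters
    # (conn_type, login, password, schema, host, port, extra); on older versions
    # it is still built from a single generated URI, exactly as before.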

airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py

@@ -38,6 +38,7 @@ from typing import TYPE_CHECKING, Any
 
 from google.cloud.storage_transfer_v1 import (
     ListTransferJobsRequest,
+    RunTransferJobRequest,
     StorageTransferServiceAsyncClient,
     TransferJob,
     TransferOperation,
@@ -55,6 +56,7 @@ from airflow.providers.google.common.hooks.base_google import (
 )
 
 if TYPE_CHECKING:
+    from google.api_core import operation_async
     from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
         ListTransferJobsAsyncPager,
     )
@@ -712,3 +714,17 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
                     f"Expected: {', '.join(expected_statuses_set)}"
                 )
         return False
+
+    async def run_transfer_job(self, job_name: str) -> operation_async.AsyncOperation:
+        """
+        Run Google Storage Transfer Service job.
+
+        :param job_name: (Required) Name of the job to run.
+        """
+        client = await self.get_conn()
+        request = RunTransferJobRequest(
+            job_name=job_name,
+            project_id=self.project_id,
+        )
+        operation = await client.run_transfer_job(request=request)
+        return operation
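
A hedged sketch of how the new async method could be awaited, for example from a deferrable trigger; the project ID and job name below are placeholders:

    import asyncio

    from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
        CloudDataTransferServiceAsyncHook,
    )

    async def _rerun_transfer_job():
        # "my-project" and the job name are placeholders.
        hook = CloudDataTransferServiceAsyncHook(project_id="my-project")
        operation = await hook.run_transfer_job(job_name="transferJobs/1234567890")
        # The returned AsyncOperation can be awaited for the long-running result.
        return await operation.result()

    asyncio.run(_rerun_transfer_job())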

airflow/providers/google/cloud/hooks/dataflow.py

@@ -51,12 +51,12 @@ from googleapiclient.discovery import Resource, build
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
+from airflow.providers.common.compat.sdk import timeout
 from airflow.providers.google.common.hooks.base_google import (
     PROVIDE_PROJECT_ID,
     GoogleBaseAsyncHook,
     GoogleBaseHook,
 )
-from airflow.providers.google.version_compat import timeout
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:

airflow/providers/google/cloud/hooks/dataprep.py

@@ -28,7 +28,7 @@ import requests
 from requests import HTTPError
 from tenacity import retry, stop_after_attempt, wait_exponential
 
-from airflow.providers.google.version_compat import BaseHook
+from airflow.providers.common.compat.sdk import BaseHook
 
 
 def _get_field(extras: dict, field_name: str) -> str | None:

airflow/providers/google/cloud/hooks/dataproc.py

@@ -912,12 +912,15 @@ class DataprocHook(GoogleBaseHook):
         state = None
         start = time.monotonic()
         while state not in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED):
+            self.log.debug("Waiting for job %s to complete", job_id)
             if timeout and start + timeout < time.monotonic():
                 raise AirflowException(f"Timeout: dataproc job {job_id} is not ready after {timeout}s")
+            self.log.debug("Sleeping for %s seconds", wait_time)
             time.sleep(wait_time)
             try:
                 job = self.get_job(project_id=project_id, region=region, job_id=job_id)
                 state = job.status.state
+                self.log.debug("Job %s is in state %s", job_id, state)
             except ServerError as err:
                 self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)
 

airflow/providers/google/cloud/hooks/gcs.py

@@ -28,8 +28,10 @@ import time
 import warnings
 from collections.abc import Callable, Generator, Sequence
 from contextlib import contextmanager
+from datetime import datetime
 from functools import partial
 from io import BytesIO
+from pathlib import Path
 from tempfile import NamedTemporaryFile
 from typing import IO, TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload
 from urllib.parse import urlsplit
@@ -50,12 +52,14 @@ from airflow.providers.google.common.hooks.base_google import (
     GoogleBaseAsyncHook,
     GoogleBaseHook,
 )
-from airflow.utils import timezone
+
+try:
+    from airflow.sdk import timezone
+except ImportError:
+    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
 from airflow.version import version
 
 if TYPE_CHECKING:
-    from datetime import datetime
-
     from aiohttp import ClientSession
     from google.api_core.retry import Retry
     from google.cloud.storage.blob import Blob
@@ -1249,6 +1253,106 @@ class GCSHook(GoogleBaseHook):
 
         self.log.info("Completed successfully.")
 
+    def _sync_to_local_dir_delete_stale_local_files(self, current_gcs_objects: List[Path], local_dir: Path):
+        current_gcs_keys = {key.resolve() for key in current_gcs_objects}
+
+        for item in local_dir.rglob("*"):
+            if item.is_file():
+                if item.resolve() not in current_gcs_keys:
+                    self.log.debug("Deleting stale local file: %s", item)
+                    item.unlink()
+        # Clean up empty directories
+        for root, dirs, _ in os.walk(local_dir, topdown=False):
+            for d in dirs:
+                dir_path = os.path.join(root, d)
+                if not os.listdir(dir_path):
+                    self.log.debug("Deleting stale empty directory: %s", dir_path)
+                    os.rmdir(dir_path)
+
+    def _sync_to_local_dir_if_changed(self, blob: Blob, local_target_path: Path):
+        should_download = False
+        download_msg = ""
+        if not local_target_path.exists():
+            should_download = True
+            download_msg = f"Local file {local_target_path} does not exist."
+        else:
+            local_stats = local_target_path.stat()
+            # Reload blob to get fresh metadata, including size and updated time
+            blob.reload()
+
+            if blob.size != local_stats.st_size:
+                should_download = True
+                download_msg = (
+                    f"GCS object size ({blob.size}) and local file size ({local_stats.st_size}) differ."
+                )
+
+            gcs_last_modified = blob.updated
+            if (
+                not should_download
+                and gcs_last_modified
+                and local_stats.st_mtime < gcs_last_modified.timestamp()
+            ):
+                should_download = True
+                download_msg = f"GCS object last modified ({gcs_last_modified}) is newer than local file last modified ({datetime.fromtimestamp(local_stats.st_mtime, tz=timezone.utc)})."
+
+        if should_download:
+            self.log.debug("%s Downloading %s to %s", download_msg, blob.name, local_target_path.as_posix())
+            self.download(
+                bucket_name=blob.bucket.name, object_name=blob.name, filename=str(local_target_path)
+            )
+        else:
+            self.log.debug(
+                "Local file %s is up-to-date with GCS object %s. Skipping download.",
+                local_target_path.as_posix(),
+                blob.name,
+            )
+
+    def sync_to_local_dir(
+        self,
+        bucket_name: str,
+        local_dir: str | Path,
+        prefix: str | None = None,
+        delete_stale: bool = False,
+    ) -> None:
+        """
+        Download files from a GCS bucket to a local directory.
+
+        It will download all files from the given ``prefix`` and create the corresponding
+        directory structure in the ``local_dir``.
+
+        If ``delete_stale`` is ``True``, it will delete all local files that do not exist in the GCS bucket.
+
+        :param bucket_name: The name of the GCS bucket.
+        :param local_dir: The local directory to which the files will be downloaded.
+        :param prefix: The prefix of the files to be downloaded.
+        :param delete_stale: If ``True``, deletes local files that don't exist in the bucket.
+        """
+        prefix = prefix or ""
+        local_dir_path = Path(local_dir)
+        self.log.debug("Downloading data from gs://%s/%s to %s", bucket_name, prefix, local_dir_path)
+
+        gcs_bucket = self.get_bucket(bucket_name)
+        local_gcs_objects = []
+
+        for blob in gcs_bucket.list_blobs(prefix=prefix):
+            # GCS lists "directories" as objects ending with a slash. We should skip them.
+            if blob.name.endswith("/"):
+                continue
+
+            blob_path = Path(blob.name)
+            local_target_path = local_dir_path.joinpath(blob_path.relative_to(prefix))
+            if not local_target_path.parent.exists():
+                local_target_path.parent.mkdir(parents=True, exist_ok=True)
+                self.log.debug("Created local directory: %s", local_target_path.parent)
+
+            self._sync_to_local_dir_if_changed(blob=blob, local_target_path=local_target_path)
+            local_gcs_objects.append(local_target_path)
+
+        if delete_stale:
+            self._sync_to_local_dir_delete_stale_local_files(
+                current_gcs_objects=local_gcs_objects, local_dir=local_dir_path
+            )
+
     def sync(
         self,
         source_bucket: str,
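
A minimal sketch of calling the new sync_to_local_dir helper; the bucket, prefix, and local path below are placeholders:

    from airflow.providers.google.cloud.hooks.gcs import GCSHook

    hook = GCSHook(gcp_conn_id="google_cloud_default")
    hook.sync_to_local_dir(
        bucket_name="my-bucket",     # placeholder bucket
        prefix="dags/",              # placeholder prefix
        local_dir="/tmp/gcs-cache",  # placeholder local directory
        delete_stale=True,           # also remove local files no longer present in GCS
    )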

airflow/providers/google/cloud/hooks/gen_ai.py (new file)

@@ -0,0 +1,196 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a Google Cloud GenAI Generative Model hook."""
+
+from __future__ import annotations
+
+import time
+from typing import TYPE_CHECKING, Any
+
+from google import genai
+
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+
+if TYPE_CHECKING:
+    from google.genai.types import (
+        ContentListUnion,
+        ContentListUnionDict,
+        CountTokensConfigOrDict,
+        CountTokensResponse,
+        CreateCachedContentConfigOrDict,
+        CreateTuningJobConfigOrDict,
+        EmbedContentConfigOrDict,
+        EmbedContentResponse,
+        GenerateContentConfig,
+        TuningDatasetOrDict,
+        TuningJob,
+    )
+
+
+class GenAIGenerativeModelHook(GoogleBaseHook):
+    """Class for Google Cloud Generative AI Vertex AI hook."""
+
+    def get_genai_client(self, project_id: str, location: str):
+        return genai.Client(
+            vertexai=True,
+            project=project_id,
+            location=location,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def embed_content(
+        self,
+        model: str,
+        location: str,
+        contents: ContentListUnion | ContentListUnionDict | list[str],
+        config: EmbedContentConfigOrDict | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> EmbedContentResponse:
+        """
+        Generate embeddings for words, phrases, sentences, and code.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param model: Required. The model to use.
+        :param contents: Optional. The contents to use for embedding.
+        :param config: Optional. Configuration for embeddings.
+        """
+        client = self.get_genai_client(project_id=project_id, location=location)
+
+        resp = client.models.embed_content(model=model, contents=contents, config=config)
+        return resp
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def generate_content(
+        self,
+        location: str,
+        model: str,
+        contents: ContentListUnionDict,
+        generation_config: GenerateContentConfig | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> str:
+        """
+        Make an API request to generate content using a model.
+
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param model: Required. The model to use.
+        :param contents: Required. The multi-part content of a message that a user or a program
+            gives to the generative model, in order to elicit a specific response.
+        :param generation_config: Optional. Generation configuration settings.
+        """
+        client = self.get_genai_client(project_id=project_id, location=location)
+        response = client.models.generate_content(
+            model=model,
+            contents=contents,
+            config=generation_config,
+        )
+
+        return response.text
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def supervised_fine_tuning_train(
+        self,
+        source_model: str,
+        location: str,
+        training_dataset: TuningDatasetOrDict,
+        tuning_job_config: CreateTuningJobConfigOrDict | dict[str, Any] | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> TuningJob:
+        """
+        Create a tuning job to adapt model behavior with a labeled dataset.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param source_model: Required. A pre-trained model optimized for performing natural
+            language tasks such as classification, summarization, extraction, content
+            creation, and ideation.
+        :param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset
+            must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
+        :param tuning_job_config: Optional. Configuration of the Tuning job to be created.
+        """
+        client = self.get_genai_client(project_id=project_id, location=location)
+
+        tuning_job = client.tunings.tune(
+            base_model=source_model,
+            training_dataset=training_dataset,
+            config=tuning_job_config,
+        )
+
+        # Poll until completion
+        running = {"JOB_STATE_PENDING", "JOB_STATE_RUNNING"}
+        while tuning_job.state in running:
+            time.sleep(60)
+            tuning_job = client.tunings.get(name=tuning_job.name)
+
+        return tuning_job
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def count_tokens(
+        self,
+        location: str,
+        model: str,
+        contents: ContentListUnion | ContentListUnionDict,
+        config: CountTokensConfigOrDict | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> CountTokensResponse:
+        """
+        Use Count Tokens API to calculate the number of input tokens before sending a request to Gemini API.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param contents: Required. The multi-part content of a message that a user or a program
+            gives to the generative model, in order to elicit a specific response.
+        :param model: Required. Model,
+            supporting prompts with text-only input, including natural language
+            tasks, multi-turn text and code chat, and code generation. It can
+            output text and code.
+        :param config: Optional. Configuration for Count Tokens.
+        """
+        client = self.get_genai_client(project_id=project_id, location=location)
+        response = client.models.count_tokens(
+            model=model,
+            contents=contents,
+            config=config,
+        )
+
+        return response
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_cached_content(
+        self,
+        model: str,
+        location: str,
+        cached_content_config: CreateCachedContentConfigOrDict | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> str:
+        """
+        Create CachedContent to reduce the cost of requests containing repeat content.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param model: Required. The name of the publisher model to use for cached content.
+        :param cached_content_config: Optional. Configuration of the Cached Content.
+        """
+        client = self.get_genai_client(project_id=project_id, location=location)
+        resp = client.caches.create(
+            model=model,
+            config=cached_content_config,
+        )
+
+        return resp.name
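
An illustrative sketch of the new hook used directly; the project, location, and model name below are placeholders:

    from airflow.providers.google.cloud.hooks.gen_ai import GenAIGenerativeModelHook

    hook = GenAIGenerativeModelHook(gcp_conn_id="google_cloud_default")
    text = hook.generate_content(
        project_id="my-project",    # placeholder project
        location="us-central1",     # placeholder location
        model="gemini-2.0-flash",   # placeholder model name
        contents=["Summarize Apache Airflow in one sentence."],
    )
    print(text)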

airflow/providers/google/cloud/hooks/looker.py

@@ -29,7 +29,7 @@ from looker_sdk.sdk.api40 import methods as methods40
 from packaging.version import parse as parse_version
 
 from airflow.exceptions import AirflowException
-from airflow.providers.google.version_compat import BaseHook
+from airflow.providers.common.compat.sdk import BaseHook
 from airflow.version import version
 
 if TYPE_CHECKING:

airflow/providers/google/cloud/hooks/spanner.py

@@ -31,6 +31,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.common.consts import CLIENT_INFO
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
+from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
 if TYPE_CHECKING:
     from google.cloud.spanner_v1.database import Database
@@ -38,6 +39,8 @@ if TYPE_CHECKING:
     from google.cloud.spanner_v1.transaction import Transaction
     from google.longrunning.operations_grpc_pb2 import Operation
 
+    from airflow.models.connection import Connection
+
 
 class SpannerConnectionParams(NamedTuple):
     """Information about Google Spanner connection parameters."""
@@ -427,3 +430,45 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
                 rc = transaction.execute_update(sql)
                 counts[sql] = rc
         return counts
+
+    def _get_openlineage_authority_part(self, connection: Connection) -> str | None:
+        """Build Spanner-specific authority part for OpenLineage. Returns {project}/{instance}."""
+        extras = connection.extra_dejson
+        project_id = extras.get("project_id")
+        instance_id = extras.get("instance_id")
+
+        if not project_id or not instance_id:
+            return None
+
+        return f"{project_id}/{instance_id}"
+
+    def get_openlineage_database_dialect(self, connection: Connection) -> str:
+        """Return database dialect for OpenLineage."""
+        return "spanner"
+
+    def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
+        """Return Spanner specific information for OpenLineage."""
+        extras = connection.extra_dejson
+        database_id = extras.get("database_id")
+
+        return DatabaseInfo(
+            scheme=self.get_openlineage_database_dialect(connection),
+            authority=self._get_openlineage_authority_part(connection),
+            database=database_id,
+            information_schema_columns=[
+                "table_schema",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "spanner_type",
+            ],
+        )
+
+    def get_openlineage_default_schema(self) -> str | None:
+        """
+        Spanner expose 'public' or '' schema depending on dialect(Postgres vs GoogleSQL).
+
+        SQLAlchemy dialect for Spanner does not expose default schema, so we return None
+        to follow the same approach.
+        """
+        return None
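
A sketch, under assumed placeholder connection extras, of what the new OpenLineage helpers yield for a Spanner connection:

    import json

    from airflow.models.connection import Connection
    from airflow.providers.google.cloud.hooks.spanner import SpannerHook

    conn = Connection(
        conn_id="spanner_default",  # placeholder connection ID
        conn_type="gcpspanner",     # placeholder conn_type
        extra=json.dumps(
            {"project_id": "my-project", "instance_id": "my-instance", "database_id": "my-db"}
        ),
    )
    info = SpannerHook().get_openlineage_database_info(conn)
    # info.scheme == "spanner"
    # info.authority == "my-project/my-instance"  (None if project/instance extras are missing)
    # info.database == "my-db"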