apache-airflow-providers-google 18.0.0__py3-none-any.whl → 18.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +5 -5
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/bundles/__init__.py +16 -0
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/bigquery.py +45 -42
- airflow/providers/google/cloud/hooks/cloud_composer.py +131 -1
- airflow/providers/google/cloud/hooks/cloud_sql.py +88 -13
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +16 -0
- airflow/providers/google/cloud/hooks/dataflow.py +1 -1
- airflow/providers/google/cloud/hooks/dataprep.py +1 -1
- airflow/providers/google/cloud/hooks/dataproc.py +3 -0
- airflow/providers/google/cloud/hooks/gcs.py +107 -3
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/looker.py +1 -1
- airflow/providers/google/cloud/hooks/spanner.py +45 -0
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +30 -0
- airflow/providers/google/cloud/links/base.py +11 -11
- airflow/providers/google/cloud/links/dataproc.py +2 -10
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +102 -1
- airflow/providers/google/cloud/openlineage/mixins.py +3 -1
- airflow/providers/google/cloud/operators/bigquery.py +2 -9
- airflow/providers/google/cloud/operators/cloud_run.py +2 -1
- airflow/providers/google/cloud/operators/cloud_sql.py +1 -1
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +89 -6
- airflow/providers/google/cloud/operators/datafusion.py +36 -7
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/spanner.py +22 -6
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +7 -0
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +30 -0
- airflow/providers/google/cloud/operators/workflows.py +17 -6
- airflow/providers/google/cloud/sensors/bigquery.py +1 -1
- airflow/providers/google/cloud/sensors/bigquery_dts.py +1 -6
- airflow/providers/google/cloud/sensors/bigtable.py +1 -6
- airflow/providers/google/cloud/sensors/cloud_composer.py +65 -31
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +1 -6
- airflow/providers/google/cloud/sensors/dataflow.py +1 -1
- airflow/providers/google/cloud/sensors/dataform.py +1 -6
- airflow/providers/google/cloud/sensors/datafusion.py +1 -6
- airflow/providers/google/cloud/sensors/dataplex.py +1 -6
- airflow/providers/google/cloud/sensors/dataprep.py +1 -6
- airflow/providers/google/cloud/sensors/dataproc.py +1 -6
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +1 -6
- airflow/providers/google/cloud/sensors/gcs.py +1 -7
- airflow/providers/google/cloud/sensors/looker.py +1 -6
- airflow/providers/google/cloud/sensors/pubsub.py +1 -6
- airflow/providers/google/cloud/sensors/tasks.py +1 -6
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +1 -6
- airflow/providers/google/cloud/sensors/workflows.py +1 -6
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +2 -1
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +2 -1
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +11 -2
- airflow/providers/google/cloud/triggers/bigquery.py +15 -3
- airflow/providers/google/cloud/triggers/cloud_composer.py +51 -21
- airflow/providers/google/cloud/triggers/cloud_run.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +90 -0
- airflow/providers/google/cloud/triggers/pubsub.py +14 -18
- airflow/providers/google/common/hooks/base_google.py +1 -1
- airflow/providers/google/get_provider_info.py +15 -0
- airflow/providers/google/leveldb/hooks/leveldb.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +2 -8
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +1 -6
- airflow/providers/google/marketing_platform/sensors/display_video.py +1 -6
- airflow/providers/google/suite/sensors/drive.py +1 -6
- airflow/providers/google/version_compat.py +0 -20
- {apache_airflow_providers_google-18.0.0.dist-info → apache_airflow_providers_google-18.1.0.dist-info}/METADATA +8 -8
- {apache_airflow_providers_google-18.0.0.dist-info → apache_airflow_providers_google-18.1.0.dist-info}/RECORD +72 -65
- {apache_airflow_providers_google-18.0.0.dist-info → apache_airflow_providers_google-18.1.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-18.0.0.dist-info → apache_airflow_providers_google-18.1.0.dist-info}/entry_points.txt +0 -0
airflow/providers/google/cloud/hooks/cloud_sql.py

@@ -50,7 +50,14 @@ from googleapiclient.errors import HttpError
 # Number of retries - used by googleapiclient method calls to perform retries
 # For requests that are "retriable"
 from airflow.exceptions import AirflowException
-from airflow.
+from airflow.providers.google.version_compat import AIRFLOW_V_3_1_PLUS
+
+if AIRFLOW_V_3_1_PLUS:
+    from airflow.sdk import Connection
+else:
+    from airflow.models import Connection  # type: ignore[assignment,attr-defined,no-redef]
+
+from airflow.providers.common.compat.sdk import BaseHook
 from airflow.providers.google.cloud.hooks.secret_manager import (
     GoogleCloudSecretManagerHook,
 )

@@ -60,7 +67,6 @@ from airflow.providers.google.common.hooks.base_google import (
     GoogleBaseHook,
     get_field,
 )
-from airflow.providers.google.version_compat import BaseHook
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:

@@ -1045,15 +1051,26 @@ class CloudSQLDatabaseHook(BaseHook):
     def _quote(value) -> str | None:
         return quote_plus(value) if value else None
 
-    def
+    def _reserve_port(self):
         if self.use_proxy:
             if self.sql_proxy_use_tcp:
                 if not self.sql_proxy_tcp_port:
                     self.reserve_free_tcp_port()
             if not self.sql_proxy_unique_path:
                 self.sql_proxy_unique_path = self._generate_unique_path()
+
+    def _generate_connection_uri(self) -> str:
+        self._reserve_port()
         if not self.database_type:
             raise ValueError("The database_type should be set")
+        if not self.user:
+            raise AirflowException("The login parameter needs to be set in connection")
+        if not self.public_ip:
+            raise AirflowException("The location parameter needs to be set in connection")
+        if not self.password:
+            raise AirflowException("The password parameter needs to be set in connection")
+        if not self.database:
+            raise AirflowException("The database parameter needs to be set in connection")
 
         database_uris = CONNECTION_URIS[self.database_type]
         ssl_spec = None

@@ -1072,14 +1089,6 @@
             ssl_spec = {"cert": self.sslcert, "key": self.sslkey, "ca": self.sslrootcert}
         else:
             format_string = public_uris["non-ssl"]
-        if not self.user:
-            raise AirflowException("The login parameter needs to be set in connection")
-        if not self.public_ip:
-            raise AirflowException("The location parameter needs to be set in connection")
-        if not self.password:
-            raise AirflowException("The password parameter needs to be set in connection")
-        if not self.database:
-            raise AirflowException("The database parameter needs to be set in connection")
 
         connection_uri = format_string.format(
             user=quote_plus(self.user) if self.user else "",

@@ -1113,6 +1122,69 @@
             instance_specification += f"=tcp:{self.sql_proxy_tcp_port}"
         return instance_specification
 
+    def _generate_connection_parameters(self) -> dict:
+        self._reserve_port()
+        if not self.database_type:
+            raise ValueError("The database_type should be set")
+        if not self.user:
+            raise AirflowException("The login parameter needs to be set in connection")
+        if not self.public_ip:
+            raise AirflowException("The location parameter needs to be set in connection")
+        if not self.password:
+            raise AirflowException("The password parameter needs to be set in connection")
+        if not self.database:
+            raise AirflowException("The database parameter needs to be set in connection")
+
+        connection_parameters = {}
+
+        connection_parameters["conn_type"] = self.database_type
+        connection_parameters["login"] = self.user
+        connection_parameters["password"] = self.password
+        connection_parameters["schema"] = self.database
+        connection_parameters["extra"] = {}
+
+        database_uris = CONNECTION_URIS[self.database_type]
+        if self.use_proxy:
+            proxy_uris = database_uris["proxy"]
+            if self.sql_proxy_use_tcp:
+                connection_parameters["host"] = "127.0.0.1"
+                connection_parameters["port"] = self.sql_proxy_tcp_port
+            else:
+                socket_path = f"{self.sql_proxy_unique_path}/{self._get_instance_socket_name()}"
+                if "localhost" in proxy_uris["socket"]:
+                    connection_parameters["host"] = "localhost"
+                    connection_parameters["extra"].update({"unix_socket": socket_path})
+                else:
+                    connection_parameters["host"] = socket_path
+        else:
+            public_uris = database_uris["public"]
+            if self.use_ssl:
+                connection_parameters["host"] = self.public_ip
+                connection_parameters["port"] = self.public_port
+                if "ssl_spec" in public_uris["ssl"]:
+                    connection_parameters["extra"].update(
+                        {
+                            "ssl": json.dumps(
+                                {"cert": self.sslcert, "key": self.sslkey, "ca": self.sslrootcert}
+                            )
+                        }
+                    )
+                else:
+                    connection_parameters["extra"].update(
+                        {
+                            "sslmode": "verify-ca",
+                            "sslcert": self.sslcert,
+                            "sslkey": self.sslkey,
+                            "sslrootcert": self.sslrootcert,
+                        }
+                    )
+            else:
+                connection_parameters["host"] = self.public_ip
+                connection_parameters["port"] = self.public_port
+        if connection_parameters.get("extra"):
+            connection_parameters["extra"] = json.dumps(connection_parameters["extra"])
+        return connection_parameters
+
     def create_connection(self) -> Connection:
         """
         Create a connection.

@@ -1120,8 +1192,11 @@
         Connection ID will be randomly generated according to whether it uses
         proxy, TCP, UNIX sockets, SSL.
         """
-
-
+        if AIRFLOW_V_3_1_PLUS:
+            kwargs = self._generate_connection_parameters()
+        else:
+            kwargs = {"uri": self._generate_connection_uri()}
+        connection = Connection(conn_id=self.db_conn_id, **kwargs)
         self.log.info("Creating connection %s", self.db_conn_id)
         return connection
 
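The CloudSQLDatabaseHook changes above split port reservation into _reserve_port() and add a parameter-based _generate_connection_parameters() path that create_connection() uses on Airflow 3.1+, falling back to the URI-based path on older versions. A minimal usage sketch follows; the connection ids are placeholders, not values from this release.

# Hypothetical sketch; "my_cloudsql_conn" and "google_cloud_default" are placeholder connection ids.
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook

hook = CloudSQLDatabaseHook(
    gcp_cloudsql_conn_id="my_cloudsql_conn",
    gcp_conn_id="google_cloud_default",
)
connection = hook.create_connection()
# On Airflow 3.1+ the Connection is assembled from discrete parameters
# (conn_type, login, password, schema, host, port, extra); on older versions
# it is still built from a generated URI, as the diff above shows.
print(connection.conn_type, connection.host, connection.port)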
airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py

@@ -38,6 +38,7 @@ from typing import TYPE_CHECKING, Any
 
 from google.cloud.storage_transfer_v1 import (
     ListTransferJobsRequest,
+    RunTransferJobRequest,
     StorageTransferServiceAsyncClient,
     TransferJob,
     TransferOperation,

@@ -55,6 +56,7 @@ from airflow.providers.google.common.hooks.base_google import (
 )
 
 if TYPE_CHECKING:
+    from google.api_core import operation_async
     from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
         ListTransferJobsAsyncPager,
     )

@@ -712,3 +714,17 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
                 f"Expected: {', '.join(expected_statuses_set)}"
             )
         return False
+
+    async def run_transfer_job(self, job_name: str) -> operation_async.AsyncOperation:
+        """
+        Run Google Storage Transfer Service job.
+
+        :param job_name: (Required) Name of the job to run.
+        """
+        client = await self.get_conn()
+        request = RunTransferJobRequest(
+            job_name=job_name,
+            project_id=self.project_id,
+        )
+        operation = await client.run_transfer_job(request=request)
+        return operation
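The new async run_transfer_job() wraps RunTransferJobRequest and returns a google.api_core operation handle. A minimal sketch of driving it directly, assuming placeholder project and job names:

# Hypothetical sketch; the project id and job name are placeholders.
import asyncio

from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
    CloudDataTransferServiceAsyncHook,
)


async def start_transfer_run() -> None:
    hook = CloudDataTransferServiceAsyncHook(project_id="my-project")
    # run_transfer_job() returns an operation_async.AsyncOperation.
    operation = await hook.run_transfer_job(job_name="transferJobs/1234567890")
    await operation.result()  # wait until the API reports the operation done


asyncio.run(start_transfer_run())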
airflow/providers/google/cloud/hooks/dataflow.py

@@ -51,12 +51,12 @@ from googleapiclient.discovery import Resource, build
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType, beam_options_to_args
+from airflow.providers.common.compat.sdk import timeout
 from airflow.providers.google.common.hooks.base_google import (
     PROVIDE_PROJECT_ID,
     GoogleBaseAsyncHook,
     GoogleBaseHook,
 )
-from airflow.providers.google.version_compat import timeout
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
airflow/providers/google/cloud/hooks/dataprep.py

@@ -28,7 +28,7 @@ import requests
 from requests import HTTPError
 from tenacity import retry, stop_after_attempt, wait_exponential
 
-from airflow.providers.
+from airflow.providers.common.compat.sdk import BaseHook
 
 
 def _get_field(extras: dict, field_name: str) -> str | None:
airflow/providers/google/cloud/hooks/dataproc.py

@@ -912,12 +912,15 @@ class DataprocHook(GoogleBaseHook):
         state = None
         start = time.monotonic()
         while state not in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED):
+            self.log.debug("Waiting for job %s to complete", job_id)
             if timeout and start + timeout < time.monotonic():
                 raise AirflowException(f"Timeout: dataproc job {job_id} is not ready after {timeout}s")
+            self.log.debug("Sleeping for %s seconds", wait_time)
             time.sleep(wait_time)
             try:
                 job = self.get_job(project_id=project_id, region=region, job_id=job_id)
                 state = job.status.state
+                self.log.debug("Job %s is in state %s", job_id, state)
             except ServerError as err:
                 self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)
 
airflow/providers/google/cloud/hooks/gcs.py

@@ -28,8 +28,10 @@ import time
 import warnings
 from collections.abc import Callable, Generator, Sequence
 from contextlib import contextmanager
+from datetime import datetime
 from functools import partial
 from io import BytesIO
+from pathlib import Path
 from tempfile import NamedTemporaryFile
 from typing import IO, TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload
 from urllib.parse import urlsplit

@@ -50,12 +52,14 @@ from airflow.providers.google.common.hooks.base_google import (
     GoogleBaseAsyncHook,
     GoogleBaseHook,
 )
-
+
+try:
+    from airflow.sdk import timezone
+except ImportError:
+    from airflow.utils import timezone  # type: ignore[attr-defined,no-redef]
 from airflow.version import version
 
 if TYPE_CHECKING:
-    from datetime import datetime
-
     from aiohttp import ClientSession
     from google.api_core.retry import Retry
     from google.cloud.storage.blob import Blob

@@ -1249,6 +1253,106 @@ class GCSHook(GoogleBaseHook):
 
         self.log.info("Completed successfully.")
 
+    def _sync_to_local_dir_delete_stale_local_files(self, current_gcs_objects: List[Path], local_dir: Path):
+        current_gcs_keys = {key.resolve() for key in current_gcs_objects}
+
+        for item in local_dir.rglob("*"):
+            if item.is_file():
+                if item.resolve() not in current_gcs_keys:
+                    self.log.debug("Deleting stale local file: %s", item)
+                    item.unlink()
+        # Clean up empty directories
+        for root, dirs, _ in os.walk(local_dir, topdown=False):
+            for d in dirs:
+                dir_path = os.path.join(root, d)
+                if not os.listdir(dir_path):
+                    self.log.debug("Deleting stale empty directory: %s", dir_path)
+                    os.rmdir(dir_path)
+
+    def _sync_to_local_dir_if_changed(self, blob: Blob, local_target_path: Path):
+        should_download = False
+        download_msg = ""
+        if not local_target_path.exists():
+            should_download = True
+            download_msg = f"Local file {local_target_path} does not exist."
+        else:
+            local_stats = local_target_path.stat()
+            # Reload blob to get fresh metadata, including size and updated time
+            blob.reload()
+
+            if blob.size != local_stats.st_size:
+                should_download = True
+                download_msg = (
+                    f"GCS object size ({blob.size}) and local file size ({local_stats.st_size}) differ."
+                )
+
+            gcs_last_modified = blob.updated
+            if (
+                not should_download
+                and gcs_last_modified
+                and local_stats.st_mtime < gcs_last_modified.timestamp()
+            ):
+                should_download = True
+                download_msg = f"GCS object last modified ({gcs_last_modified}) is newer than local file last modified ({datetime.fromtimestamp(local_stats.st_mtime, tz=timezone.utc)})."
+
+        if should_download:
+            self.log.debug("%s Downloading %s to %s", download_msg, blob.name, local_target_path.as_posix())
+            self.download(
+                bucket_name=blob.bucket.name, object_name=blob.name, filename=str(local_target_path)
+            )
+        else:
+            self.log.debug(
+                "Local file %s is up-to-date with GCS object %s. Skipping download.",
+                local_target_path.as_posix(),
+                blob.name,
+            )
+
+    def sync_to_local_dir(
+        self,
+        bucket_name: str,
+        local_dir: str | Path,
+        prefix: str | None = None,
+        delete_stale: bool = False,
+    ) -> None:
+        """
+        Download files from a GCS bucket to a local directory.
+
+        It will download all files from the given ``prefix`` and create the corresponding
+        directory structure in the ``local_dir``.
+
+        If ``delete_stale`` is ``True``, it will delete all local files that do not exist in the GCS bucket.
+
+        :param bucket_name: The name of the GCS bucket.
+        :param local_dir: The local directory to which the files will be downloaded.
+        :param prefix: The prefix of the files to be downloaded.
+        :param delete_stale: If ``True``, deletes local files that don't exist in the bucket.
+        """
+        prefix = prefix or ""
+        local_dir_path = Path(local_dir)
+        self.log.debug("Downloading data from gs://%s/%s to %s", bucket_name, prefix, local_dir_path)
+
+        gcs_bucket = self.get_bucket(bucket_name)
+        local_gcs_objects = []
+
+        for blob in gcs_bucket.list_blobs(prefix=prefix):
+            # GCS lists "directories" as objects ending with a slash. We should skip them.
+            if blob.name.endswith("/"):
+                continue
+
+            blob_path = Path(blob.name)
+            local_target_path = local_dir_path.joinpath(blob_path.relative_to(prefix))
+            if not local_target_path.parent.exists():
+                local_target_path.parent.mkdir(parents=True, exist_ok=True)
+                self.log.debug("Created local directory: %s", local_target_path.parent)
+
+            self._sync_to_local_dir_if_changed(blob=blob, local_target_path=local_target_path)
+            local_gcs_objects.append(local_target_path)
+
+        if delete_stale:
+            self._sync_to_local_dir_delete_stale_local_files(
+                current_gcs_objects=local_gcs_objects, local_dir=local_dir_path
+            )
+
     def sync(
         self,
         source_bucket: str,
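The new GCSHook.sync_to_local_dir() mirrors a bucket prefix into a local directory, skipping objects whose size and modification time already match the local copy and optionally deleting stale local files. A minimal usage sketch, with bucket name, prefix and target directory as placeholders:

# Hypothetical sketch; bucket, prefix and local path are placeholders.
from airflow.providers.google.cloud.hooks.gcs import GCSHook

hook = GCSHook(gcp_conn_id="google_cloud_default")
# Mirror gs://example-bucket/dags/ into /tmp/dag-bundle, removing local files
# that no longer have a matching object under the prefix.
hook.sync_to_local_dir(
    bucket_name="example-bucket",
    local_dir="/tmp/dag-bundle",
    prefix="dags/",
    delete_stale=True,
)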
airflow/providers/google/cloud/hooks/gen_ai.py (new file)

@@ -0,0 +1,196 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains a Google Cloud GenAI Generative Model hook."""
+
+from __future__ import annotations
+
+import time
+from typing import TYPE_CHECKING, Any
+
+from google import genai
+
+from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
+
+if TYPE_CHECKING:
+    from google.genai.types import (
+        ContentListUnion,
+        ContentListUnionDict,
+        CountTokensConfigOrDict,
+        CountTokensResponse,
+        CreateCachedContentConfigOrDict,
+        CreateTuningJobConfigOrDict,
+        EmbedContentConfigOrDict,
+        EmbedContentResponse,
+        GenerateContentConfig,
+        TuningDatasetOrDict,
+        TuningJob,
+    )
+
+
+class GenAIGenerativeModelHook(GoogleBaseHook):
+    """Class for Google Cloud Generative AI Vertex AI hook."""
+
+    def get_genai_client(self, project_id: str, location: str):
+        return genai.Client(
+            vertexai=True,
+            project=project_id,
+            location=location,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def embed_content(
+        self,
+        model: str,
+        location: str,
+        contents: ContentListUnion | ContentListUnionDict | list[str],
+        config: EmbedContentConfigOrDict | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> EmbedContentResponse:
+        """
+        Generate embeddings for words, phrases, sentences, and code.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param model: Required. The model to use.
+        :param contents: Optional. The contents to use for embedding.
+        :param config: Optional. Configuration for embeddings.
+        """
+        client = self.get_genai_client(project_id=project_id, location=location)
+
+        resp = client.models.embed_content(model=model, contents=contents, config=config)
+        return resp
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def generate_content(
+        self,
+        location: str,
+        model: str,
+        contents: ContentListUnionDict,
+        generation_config: GenerateContentConfig | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> str:
+        """
+        Make an API request to generate content using a model.
+
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param model: Required. The model to use.
+        :param contents: Required. The multi-part content of a message that a user or a program
+            gives to the generative model, in order to elicit a specific response.
+        :param generation_config: Optional. Generation configuration settings.
+        """
+        client = self.get_genai_client(project_id=project_id, location=location)
+        response = client.models.generate_content(
+            model=model,
+            contents=contents,
+            config=generation_config,
+        )
+
+        return response.text
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def supervised_fine_tuning_train(
+        self,
+        source_model: str,
+        location: str,
+        training_dataset: TuningDatasetOrDict,
+        tuning_job_config: CreateTuningJobConfigOrDict | dict[str, Any] | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> TuningJob:
+        """
+        Create a tuning job to adapt model behavior with a labeled dataset.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param source_model: Required. A pre-trained model optimized for performing natural
+            language tasks such as classification, summarization, extraction, content
+            creation, and ideation.
+        :param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset
+            must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
+        :param tuning_job_config: Optional. Configuration of the Tuning job to be created.
+        """
+        client = self.get_genai_client(project_id=project_id, location=location)
+
+        tuning_job = client.tunings.tune(
+            base_model=source_model,
+            training_dataset=training_dataset,
+            config=tuning_job_config,
+        )
+
+        # Poll until completion
+        running = {"JOB_STATE_PENDING", "JOB_STATE_RUNNING"}
+        while tuning_job.state in running:
+            time.sleep(60)
+            tuning_job = client.tunings.get(name=tuning_job.name)
+
+        return tuning_job
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def count_tokens(
+        self,
+        location: str,
+        model: str,
+        contents: ContentListUnion | ContentListUnionDict,
+        config: CountTokensConfigOrDict | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> CountTokensResponse:
+        """
+        Use Count Tokens API to calculate the number of input tokens before sending a request to Gemini API.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param contents: Required. The multi-part content of a message that a user or a program
+            gives to the generative model, in order to elicit a specific response.
+        :param model: Required. Model,
+            supporting prompts with text-only input, including natural language
+            tasks, multi-turn text and code chat, and code generation. It can
+            output text and code.
+        :param config: Optional. Configuration for Count Tokens.
+        """
+        client = self.get_genai_client(project_id=project_id, location=location)
+        response = client.models.count_tokens(
+            model=model,
+            contents=contents,
+            config=config,
+        )
+
+        return response
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def create_cached_content(
+        self,
+        model: str,
+        location: str,
+        cached_content_config: CreateCachedContentConfigOrDict | None = None,
+        project_id: str = PROVIDE_PROJECT_ID,
+    ) -> str:
+        """
+        Create CachedContent to reduce the cost of requests containing repeat content.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud location that the service belongs to.
+        :param model: Required. The name of the publisher model to use for cached content.
+        :param cached_content_config: Optional. Configuration of the Cached Content.
+        """
+        client = self.get_genai_client(project_id=project_id, location=location)
+        resp = client.caches.create(
+            model=model,
+            config=cached_content_config,
+        )
+
+        return resp.name
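The new gen_ai.py hook wraps the google-genai client for Vertex AI. A minimal sketch of calling generate_content() and count_tokens(); the project, location and model ids below are placeholders, not values from the release:

# Hypothetical sketch; project, location and model ids are placeholders.
from airflow.providers.google.cloud.hooks.gen_ai import GenAIGenerativeModelHook

hook = GenAIGenerativeModelHook(gcp_conn_id="google_cloud_default")

text = hook.generate_content(
    project_id="my-project",
    location="us-central1",
    model="gemini-2.0-flash",
    contents=["Summarize Apache Airflow in one sentence."],
)

tokens = hook.count_tokens(
    project_id="my-project",
    location="us-central1",
    model="gemini-2.0-flash",
    contents=["Summarize Apache Airflow in one sentence."],
)
print(text, tokens.total_tokens)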
airflow/providers/google/cloud/hooks/looker.py

@@ -29,7 +29,7 @@ from looker_sdk.sdk.api40 import methods as methods40
 from packaging.version import parse as parse_version
 
 from airflow.exceptions import AirflowException
-from airflow.providers.
+from airflow.providers.common.compat.sdk import BaseHook
 from airflow.version import version
 
 if TYPE_CHECKING:
airflow/providers/google/cloud/hooks/spanner.py

@@ -31,6 +31,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.common.consts import CLIENT_INFO
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
+from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
 if TYPE_CHECKING:
     from google.cloud.spanner_v1.database import Database

@@ -38,6 +39,8 @@ if TYPE_CHECKING:
     from google.cloud.spanner_v1.transaction import Transaction
     from google.longrunning.operations_grpc_pb2 import Operation
 
+    from airflow.models.connection import Connection
+
 
 class SpannerConnectionParams(NamedTuple):
     """Information about Google Spanner connection parameters."""

@@ -427,3 +430,45 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
                 rc = transaction.execute_update(sql)
                 counts[sql] = rc
         return counts
+
+    def _get_openlineage_authority_part(self, connection: Connection) -> str | None:
+        """Build Spanner-specific authority part for OpenLineage. Returns {project}/{instance}."""
+        extras = connection.extra_dejson
+        project_id = extras.get("project_id")
+        instance_id = extras.get("instance_id")
+
+        if not project_id or not instance_id:
+            return None
+
+        return f"{project_id}/{instance_id}"
+
+    def get_openlineage_database_dialect(self, connection: Connection) -> str:
+        """Return database dialect for OpenLineage."""
+        return "spanner"
+
+    def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
+        """Return Spanner specific information for OpenLineage."""
+        extras = connection.extra_dejson
+        database_id = extras.get("database_id")
+
+        return DatabaseInfo(
+            scheme=self.get_openlineage_database_dialect(connection),
+            authority=self._get_openlineage_authority_part(connection),
+            database=database_id,
+            information_schema_columns=[
+                "table_schema",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "spanner_type",
+            ],
+        )
+
+    def get_openlineage_default_schema(self) -> str | None:
+        """
+        Spanner expose 'public' or '' schema depending on dialect(Postgres vs GoogleSQL).
+
+        SQLAlchemy dialect for Spanner does not expose default schema, so we return None
+        to follow the same approach.
+        """
+        return None