apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/3rd-party-licenses/NOTICE +2 -12
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/ads/hooks/ads.py +39 -5
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -2
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/bundles/__init__.py +16 -0
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/bigquery.py +166 -281
- airflow/providers/google/cloud/hooks/cloud_composer.py +287 -14
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_run.py +17 -9
- airflow/providers/google/cloud/hooks/cloud_sql.py +101 -22
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +27 -6
- airflow/providers/google/cloud/hooks/compute_ssh.py +5 -1
- airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
- airflow/providers/google/cloud/hooks/dataflow.py +71 -94
- airflow/providers/google/cloud/hooks/datafusion.py +1 -1
- airflow/providers/google/cloud/hooks/dataplex.py +1 -1
- airflow/providers/google/cloud/hooks/dataprep.py +1 -1
- airflow/providers/google/cloud/hooks/dataproc.py +72 -71
- airflow/providers/google/cloud/hooks/gcs.py +111 -14
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/hooks/looker.py +6 -1
- airflow/providers/google/cloud/hooks/mlengine.py +3 -2
- airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
- airflow/providers/google/cloud/hooks/spanner.py +73 -8
- airflow/providers/google/cloud/hooks/stackdriver.py +10 -8
- airflow/providers/google/cloud/hooks/translate.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -209
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +2 -2
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +27 -1
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/vision.py +2 -2
- airflow/providers/google/cloud/hooks/workflows.py +1 -1
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -13
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -44
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -96
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -95
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +9 -60
- airflow/providers/google/cloud/links/managed_kafka.py +0 -70
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +17 -9
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +9 -6
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +102 -1
- airflow/providers/google/cloud/openlineage/mixins.py +10 -8
- airflow/providers/google/cloud/openlineage/utils.py +15 -1
- airflow/providers/google/cloud/operators/alloy_db.py +70 -55
- airflow/providers/google/cloud/operators/bigquery.py +73 -636
- airflow/providers/google/cloud/operators/bigquery_dts.py +3 -5
- airflow/providers/google/cloud/operators/bigtable.py +36 -7
- airflow/providers/google/cloud/operators/cloud_base.py +21 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +2 -2
- airflow/providers/google/cloud/operators/cloud_build.py +75 -32
- airflow/providers/google/cloud/operators/cloud_composer.py +128 -40
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +69 -43
- airflow/providers/google/cloud/operators/cloud_run.py +23 -5
- airflow/providers/google/cloud/operators/cloud_sql.py +8 -16
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -11
- airflow/providers/google/cloud/operators/compute.py +8 -40
- airflow/providers/google/cloud/operators/datacatalog.py +157 -21
- airflow/providers/google/cloud/operators/dataflow.py +38 -15
- airflow/providers/google/cloud/operators/dataform.py +15 -5
- airflow/providers/google/cloud/operators/datafusion.py +41 -20
- airflow/providers/google/cloud/operators/dataplex.py +193 -109
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +78 -35
- airflow/providers/google/cloud/operators/dataproc_metastore.py +96 -88
- airflow/providers/google/cloud/operators/datastore.py +22 -6
- airflow/providers/google/cloud/operators/dlp.py +6 -29
- airflow/providers/google/cloud/operators/functions.py +16 -7
- airflow/providers/google/cloud/operators/gcs.py +10 -8
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +60 -99
- airflow/providers/google/cloud/operators/looker.py +1 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +107 -52
- airflow/providers/google/cloud/operators/natural_language.py +1 -1
- airflow/providers/google/cloud/operators/pubsub.py +60 -14
- airflow/providers/google/cloud/operators/spanner.py +25 -12
- airflow/providers/google/cloud/operators/speech_to_text.py +1 -2
- airflow/providers/google/cloud/operators/stackdriver.py +1 -9
- airflow/providers/google/cloud/operators/tasks.py +1 -12
- airflow/providers/google/cloud/operators/text_to_speech.py +1 -2
- airflow/providers/google/cloud/operators/translate.py +40 -16
- airflow/providers/google/cloud/operators/translate_speech.py +1 -2
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +39 -19
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +29 -9
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +54 -26
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +70 -8
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +43 -9
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +11 -9
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +57 -11
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +30 -7
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +1 -1
- airflow/providers/google/cloud/operators/vision.py +2 -2
- airflow/providers/google/cloud/operators/workflows.py +18 -15
- airflow/providers/google/cloud/sensors/bigquery.py +2 -2
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -2
- airflow/providers/google/cloud/sensors/bigtable.py +11 -4
- airflow/providers/google/cloud/sensors/cloud_composer.py +533 -29
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -2
- airflow/providers/google/cloud/sensors/dataflow.py +26 -9
- airflow/providers/google/cloud/sensors/dataform.py +2 -2
- airflow/providers/google/cloud/sensors/datafusion.py +4 -4
- airflow/providers/google/cloud/sensors/dataplex.py +2 -2
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +2 -2
- airflow/providers/google/cloud/sensors/gcs.py +4 -4
- airflow/providers/google/cloud/sensors/looker.py +2 -2
- airflow/providers/google/cloud/sensors/pubsub.py +4 -4
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
- airflow/providers/google/cloud/sensors/workflows.py +2 -2
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +4 -4
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +20 -12
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +13 -4
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +75 -34
- airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_composer.py +302 -46
- airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +91 -1
- airflow/providers/google/cloud/triggers/dataflow.py +122 -0
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +14 -2
- airflow/providers/google/cloud/triggers/dataproc.py +122 -52
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +45 -27
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +15 -19
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +1 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/common/auth_backend/google_openid.py +4 -4
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +27 -8
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +3 -4
- airflow/providers/google/firebase/operators/firestore.py +2 -2
- airflow/providers/google/get_provider_info.py +56 -52
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +26 -1
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/links/analytics_admin.py +5 -14
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +1 -2
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +3 -63
- airflow/providers/google/suite/hooks/calendar.py +1 -1
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +92 -48
- apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
- apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/hooks/automl.py +0 -673
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1362
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -112
- apache_airflow_providers_google-15.1.0rc1.dist-info/RECORD +0 -321
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
|
3
|
+
# or more contributor license agreements. See the NOTICE file
|
|
4
|
+
# distributed with this work for additional information
|
|
5
|
+
# regarding copyright ownership. The ASF licenses this file
|
|
6
|
+
# to you under the Apache License, Version 2.0 (the
|
|
7
|
+
# "License"); you may not use this file except in compliance
|
|
8
|
+
# with the License. You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing,
|
|
13
|
+
# software distributed under the License is distributed on an
|
|
14
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
15
|
+
# KIND, either express or implied. See the License for the
|
|
16
|
+
# specific language governing permissions and limitations
|
|
17
|
+
# under the License.
|
|
18
|
+
"""This module contains a Google Cloud GenAI Generative Model hook."""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import time
|
|
23
|
+
from typing import TYPE_CHECKING, Any
|
|
24
|
+
|
|
25
|
+
from google import genai
|
|
26
|
+
|
|
27
|
+
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from google.genai.types import (
|
|
31
|
+
ContentListUnion,
|
|
32
|
+
ContentListUnionDict,
|
|
33
|
+
CountTokensConfigOrDict,
|
|
34
|
+
CountTokensResponse,
|
|
35
|
+
CreateCachedContentConfigOrDict,
|
|
36
|
+
CreateTuningJobConfigOrDict,
|
|
37
|
+
EmbedContentConfigOrDict,
|
|
38
|
+
EmbedContentResponse,
|
|
39
|
+
GenerateContentConfig,
|
|
40
|
+
TuningDatasetOrDict,
|
|
41
|
+
TuningJob,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class GenAIGenerativeModelHook(GoogleBaseHook):
    """
    Hook for Google Cloud Generative AI models accessed through the ``google-genai`` SDK on Vertex AI.

    Every public method builds a Vertex AI-backed ``genai.Client`` for the requested project and
    location and delegates to the matching SDK call.
    """

    def get_genai_client(self, project_id: str, location: str):
        """
        Return a ``google.genai`` client that targets Vertex AI.

        The credentials resolved from this hook's Airflow connection are passed explicitly.
        Without them the SDK resolves credentials from the environment on its own, and the
        configured connection (key file, impersonation chain, ...) would be silently ignored.

        :param project_id: The ID of the Google Cloud project to send requests to.
        :param location: The ID of the Google Cloud location to send requests to.
        :return: A ``genai.Client`` configured for Vertex AI.
        """
        return genai.Client(
            vertexai=True,
            credentials=self.get_credentials(),
            project=project_id,
            location=location,
        )

    @GoogleBaseHook.fallback_to_default_project_id
    def embed_content(
        self,
        model: str,
        location: str,
        contents: ContentListUnion | ContentListUnionDict | list[str],
        config: EmbedContentConfigOrDict | None = None,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> EmbedContentResponse:
        """
        Generate embeddings for words, phrases, sentences, and code.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param model: Required. The model to use.
        :param contents: Required. The contents to use for embedding.
        :param config: Optional. Configuration for embeddings.
        :return: The ``EmbedContentResponse`` returned by the SDK.
        """
        client = self.get_genai_client(project_id=project_id, location=location)

        resp = client.models.embed_content(model=model, contents=contents, config=config)
        return resp

    @GoogleBaseHook.fallback_to_default_project_id
    def generate_content(
        self,
        location: str,
        model: str,
        contents: ContentListUnionDict,
        generation_config: GenerateContentConfig | None = None,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Make an API request to generate content using a model.

        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param model: Required. The model to use.
        :param contents: Required. The multi-part content of a message that a user or a program
            gives to the generative model, in order to elicit a specific response.
        :param generation_config: Optional. Generation configuration settings.
        :return: The ``text`` attribute of the SDK response.
        """
        client = self.get_genai_client(project_id=project_id, location=location)
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=generation_config,
        )

        return response.text

    @GoogleBaseHook.fallback_to_default_project_id
    def supervised_fine_tuning_train(
        self,
        source_model: str,
        location: str,
        training_dataset: TuningDatasetOrDict,
        tuning_job_config: CreateTuningJobConfigOrDict | dict[str, Any] | None = None,
        project_id: str = PROVIDE_PROJECT_ID,
        poll_interval: float = 60,
    ) -> TuningJob:
        """
        Create a tuning job to adapt model behavior with a labeled dataset and wait for it to finish.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param source_model: Required. A pre-trained model optimized for performing natural
            language tasks such as classification, summarization, extraction, content
            creation, and ideation.
        :param training_dataset: Required. Cloud Storage URI of your training dataset. The dataset
            must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
        :param tuning_job_config: Optional. Configuration of the Tuning job to be created.
        :param poll_interval: Optional. Seconds to wait between job state checks. Defaults to 60.
        :return: The tuning job as last fetched, once it is no longer queued, pending or running.
        """
        client = self.get_genai_client(project_id=project_id, location=location)

        tuning_job = client.tunings.tune(
            base_model=source_model,
            training_dataset=training_dataset,
            config=tuning_job_config,
        )

        # Poll until the job leaves every not-yet-finished state. A freshly created job may
        # still be QUEUED; treating that as finished would return an unstarted job.
        running = {"JOB_STATE_QUEUED", "JOB_STATE_PENDING", "JOB_STATE_RUNNING"}
        while tuning_job.state in running:
            time.sleep(poll_interval)
            tuning_job = client.tunings.get(name=tuning_job.name)

        return tuning_job

    @GoogleBaseHook.fallback_to_default_project_id
    def count_tokens(
        self,
        location: str,
        model: str,
        contents: ContentListUnion | ContentListUnionDict,
        config: CountTokensConfigOrDict | None = None,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> CountTokensResponse:
        """
        Use Count Tokens API to calculate the number of input tokens before sending a request to Gemini API.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param contents: Required. The multi-part content of a message that a user or a program
            gives to the generative model, in order to elicit a specific response.
        :param model: Required. The model whose tokenizer is used to count the tokens of
            ``contents``.
        :param config: Optional. Configuration for Count Tokens.
        :return: The ``CountTokensResponse`` returned by the SDK.
        """
        client = self.get_genai_client(project_id=project_id, location=location)
        response = client.models.count_tokens(
            model=model,
            contents=contents,
            config=config,
        )

        return response

    @GoogleBaseHook.fallback_to_default_project_id
    def create_cached_content(
        self,
        model: str,
        location: str,
        cached_content_config: CreateCachedContentConfigOrDict | None = None,
        project_id: str = PROVIDE_PROJECT_ID,
    ) -> str:
        """
        Create CachedContent to reduce the cost of requests containing repeat content.

        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        :param location: Required. The ID of the Google Cloud location that the service belongs to.
        :param model: Required. The name of the publisher model to use for cached content.
        :param cached_content_config: Optional. Configuration of the Cached Content.
        :return: The resource name of the created cached content.
        """
        client = self.get_genai_client(project_id=project_id, location=location)
        resp = client.caches.create(
            model=model,
            config=cached_content_config,
        )

        return resp.name
|
|
@@ -30,7 +30,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
|
30
30
|
from google.auth.transport import requests as google_requests
|
|
31
31
|
|
|
32
32
|
# not sure why but mypy complains on missing `container_v1` but it is clearly there and is importable
|
|
33
|
-
from google.cloud import exceptions
|
|
33
|
+
from google.cloud import exceptions
|
|
34
34
|
from google.cloud.container_v1 import ClusterManagerAsyncClient, ClusterManagerClient
|
|
35
35
|
from google.cloud.container_v1.types import Cluster, Operation
|
|
36
36
|
from kubernetes import client
|
|
@@ -498,7 +498,7 @@ class GKEKubernetesAsyncHook(GoogleBaseAsyncHook, AsyncKubernetesHook):
|
|
|
498
498
|
)
|
|
499
499
|
|
|
500
500
|
@contextlib.asynccontextmanager
|
|
501
|
-
async def get_conn(self) -> async_client.ApiClient:
|
|
501
|
+
async def get_conn(self) -> async_client.ApiClient:
|
|
502
502
|
kube_client = None
|
|
503
503
|
try:
|
|
504
504
|
kube_client = await self._load_config()
|
|
@@ -29,7 +29,7 @@ from looker_sdk.sdk.api40 import methods as methods40
|
|
|
29
29
|
from packaging.version import parse as parse_version
|
|
30
30
|
|
|
31
31
|
from airflow.exceptions import AirflowException
|
|
32
|
-
from airflow.
|
|
32
|
+
from airflow.providers.common.compat.sdk import BaseHook
|
|
33
33
|
from airflow.version import version
|
|
34
34
|
|
|
35
35
|
if TYPE_CHECKING:
|
|
@@ -39,6 +39,11 @@ if TYPE_CHECKING:
|
|
|
39
39
|
class LookerHook(BaseHook):
|
|
40
40
|
"""Hook for Looker APIs."""
|
|
41
41
|
|
|
42
|
+
conn_name_attr = "looker_conn_id"
|
|
43
|
+
default_conn_name = "looker_default"
|
|
44
|
+
conn_type = "gcp_looker"
|
|
45
|
+
hook_name = "Google Looker"
|
|
46
|
+
|
|
42
47
|
def __init__(
|
|
43
48
|
self,
|
|
44
49
|
looker_conn_id: str,
|
|
@@ -23,7 +23,8 @@ import contextlib
|
|
|
23
23
|
import logging
|
|
24
24
|
import random
|
|
25
25
|
import time
|
|
26
|
-
from
|
|
26
|
+
from collections.abc import Callable
|
|
27
|
+
from typing import TYPE_CHECKING
|
|
27
28
|
|
|
28
29
|
from aiohttp import ClientSession
|
|
29
30
|
from gcloud.aio.auth import AioSession, Token
|
|
@@ -587,7 +588,7 @@ class MLEngineAsyncHook(GoogleBaseAsyncHook):
|
|
|
587
588
|
job = await self.get_job(
|
|
588
589
|
project_id=project_id,
|
|
589
590
|
job_id=job_id,
|
|
590
|
-
session=session, #
|
|
591
|
+
session=session, # type: ignore
|
|
591
592
|
)
|
|
592
593
|
job = await job.json(content_type=None)
|
|
593
594
|
self.log.info("Retrieving json_response: %s", job)
|
|
@@ -47,6 +47,10 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
47
47
|
See https://cloud.google.com/secret-manager
|
|
48
48
|
"""
|
|
49
49
|
|
|
50
|
+
def __init__(self, location: str | None = None, **base_hook_kwargs) -> None:
    """
    Initialize the Secret Manager hook.

    :param location: Optional. Default Secret Manager location, stored on ``self.location``.
        When it is ``None`` the hook works with global (non-regional) secrets.
    :param base_hook_kwargs: Keyword arguments forwarded unchanged to the parent hook
        (for example ``gcp_conn_id`` or ``impersonation_chain``).
    """
    # The parent initializer runs first, exactly as before; the location is recorded afterwards.
    super().__init__(**base_hook_kwargs)
    self.location = location
|
|
53
|
+
|
|
50
54
|
@cached_property
|
|
51
55
|
def client(self):
|
|
52
56
|
"""
|
|
@@ -54,7 +58,16 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
54
58
|
|
|
55
59
|
:return: Secret Manager client.
|
|
56
60
|
"""
|
|
57
|
-
|
|
61
|
+
if self.location is not None:
|
|
62
|
+
return SecretManagerServiceClient(
|
|
63
|
+
credentials=self.get_credentials(),
|
|
64
|
+
client_info=CLIENT_INFO,
|
|
65
|
+
client_options={"api_endpoint": f"secretmanager.{self.location}.rep.googleapis.com"},
|
|
66
|
+
)
|
|
67
|
+
return SecretManagerServiceClient(
|
|
68
|
+
credentials=self.get_credentials(),
|
|
69
|
+
client_info=CLIENT_INFO,
|
|
70
|
+
)
|
|
58
71
|
|
|
59
72
|
def get_conn(self) -> SecretManagerServiceClient:
|
|
60
73
|
"""
|
|
@@ -64,6 +77,60 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
64
77
|
"""
|
|
65
78
|
return self.client
|
|
66
79
|
|
|
80
|
+
def _get_parent(self, project_id: str, location: str | None = None) -> str:
|
|
81
|
+
"""
|
|
82
|
+
Return parent path.
|
|
83
|
+
|
|
84
|
+
:param project_id: Required. ID of the GCP project that owns the job.
|
|
85
|
+
:param location: Optional. Target location. If set to ``None`` or missing, the location provided for GoogleCloudSecretHook instantiation is used
|
|
86
|
+
For more details : https://cloud.google.com/secret-manager/docs/locations
|
|
87
|
+
|
|
88
|
+
:return: Parent path.
|
|
89
|
+
"""
|
|
90
|
+
_location = location or self.location
|
|
91
|
+
if _location is not None:
|
|
92
|
+
return self.client.common_location_path(project_id, _location)
|
|
93
|
+
return self.client.common_project_path(project_id)
|
|
94
|
+
|
|
95
|
+
def _get_secret_path(self, project_id: str, secret_id: str, location: str | None = None) -> str:
|
|
96
|
+
"""
|
|
97
|
+
Return secret path.
|
|
98
|
+
|
|
99
|
+
:param project_id: Required. ID of the GCP project that owns the job.
|
|
100
|
+
:param secret_id: Required. Secret ID for which path is required.
|
|
101
|
+
:param location: Optional. Target location. If set to ``None`` or missing, the location provided for GoogleCloudSecretHook instantiation is used
|
|
102
|
+
For more details : https://cloud.google.com/secret-manager/docs/locations
|
|
103
|
+
|
|
104
|
+
:return: Parent path.
|
|
105
|
+
"""
|
|
106
|
+
_location = location or self.location
|
|
107
|
+
if _location is not None:
|
|
108
|
+
# Google's client library does not provide a method to construct regional secret paths, so constructing manually.
|
|
109
|
+
return f"projects/{project_id}/locations/{_location}/secrets/{secret_id}"
|
|
110
|
+
return self.client.secret_path(project_id, secret_id)
|
|
111
|
+
|
|
112
|
+
def _get_secret_version_path(
|
|
113
|
+
self, project_id: str, secret_id: str, secret_version: str, location: str | None = None
|
|
114
|
+
) -> str:
|
|
115
|
+
"""
|
|
116
|
+
Return secret version path.
|
|
117
|
+
|
|
118
|
+
:param project_id: Required. ID of the GCP project that owns the job.
|
|
119
|
+
:param secret_id: Required. Secret ID for which path is required.
|
|
120
|
+
:param secret_version: Required. Secret version for which path is required.
|
|
121
|
+
:param location: Optional. Target location. If set to ``None`` or missing, the location provided for GoogleCloudSecretHook instantiation is used
|
|
122
|
+
For more details : https://cloud.google.com/secret-manager/docs/locations
|
|
123
|
+
|
|
124
|
+
:return: Parent path.
|
|
125
|
+
"""
|
|
126
|
+
_location = location or self.location
|
|
127
|
+
if _location is not None:
|
|
128
|
+
# Google's client library does not provide a method to construct regional secret version paths, so constructing manually.
|
|
129
|
+
return (
|
|
130
|
+
f"projects/{project_id}/locations/{_location}/secrets/{secret_id}/versions/{secret_version}"
|
|
131
|
+
)
|
|
132
|
+
return self.client.secret_version_path(project_id, secret_id, secret_version)
|
|
133
|
+
|
|
67
134
|
@GoogleBaseHook.fallback_to_default_project_id
|
|
68
135
|
def create_secret(
|
|
69
136
|
self,
|
|
@@ -73,6 +140,7 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
73
140
|
retry: Retry | _MethodDefault = DEFAULT,
|
|
74
141
|
timeout: float | None = None,
|
|
75
142
|
metadata: Sequence[tuple[str, str]] = (),
|
|
143
|
+
location: str | None = None,
|
|
76
144
|
) -> Secret:
|
|
77
145
|
"""
|
|
78
146
|
Create a secret.
|
|
@@ -88,12 +156,20 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
88
156
|
:param retry: Optional. Designation of what errors, if any, should be retried.
|
|
89
157
|
:param timeout: Optional. The timeout for this request.
|
|
90
158
|
:param metadata: Optional. Strings which should be sent along with the request as metadata.
|
|
159
|
+
:param location: Optional. Location where secret should be created. Used for creating regional secret. If set to ``None`` or missing, the location provided for GoogleCloudSecretHook instantiation is used
|
|
160
|
+
For more details : https://cloud.google.com/secret-manager/docs/locations
|
|
91
161
|
:return: Secret object.
|
|
92
162
|
"""
|
|
93
|
-
|
|
163
|
+
if not secret:
|
|
164
|
+
_secret: dict | Secret = {}
|
|
165
|
+
if (location or self.location) is None:
|
|
166
|
+
_secret["replication"] = {"automatic": {}}
|
|
167
|
+
else:
|
|
168
|
+
_secret = secret
|
|
169
|
+
|
|
94
170
|
response = self.client.create_secret(
|
|
95
171
|
request={
|
|
96
|
-
"parent":
|
|
172
|
+
"parent": self._get_parent(project_id, location),
|
|
97
173
|
"secret_id": secret_id,
|
|
98
174
|
"secret": _secret,
|
|
99
175
|
},
|
|
@@ -113,6 +189,7 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
113
189
|
retry: Retry | _MethodDefault = DEFAULT,
|
|
114
190
|
timeout: float | None = None,
|
|
115
191
|
metadata: Sequence[tuple[str, str]] = (),
|
|
192
|
+
location: str | None = None,
|
|
116
193
|
) -> SecretVersion:
|
|
117
194
|
"""
|
|
118
195
|
Add a version to the secret.
|
|
@@ -128,11 +205,13 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
128
205
|
:param retry: Optional. Designation of what errors, if any, should be retried.
|
|
129
206
|
:param timeout: Optional. The timeout for this request.
|
|
130
207
|
:param metadata: Optional. Strings which should be sent along with the request as metadata.
|
|
208
|
+
:param location: Optional. Location where secret is located. Used for adding version to regional secret. If set to ``None`` or missing, the location provided for GoogleCloudSecretHook instantiation is used
|
|
209
|
+
For more details : https://cloud.google.com/secret-manager/docs/locations
|
|
131
210
|
:return: Secret version object.
|
|
132
211
|
"""
|
|
133
212
|
response = self.client.add_secret_version(
|
|
134
213
|
request={
|
|
135
|
-
"parent":
|
|
214
|
+
"parent": self._get_secret_path(project_id, secret_id, location),
|
|
136
215
|
"payload": secret_payload,
|
|
137
216
|
},
|
|
138
217
|
retry=retry,
|
|
@@ -152,6 +231,7 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
152
231
|
retry: Retry | _MethodDefault = DEFAULT,
|
|
153
232
|
timeout: float | None = None,
|
|
154
233
|
metadata: Sequence[tuple[str, str]] = (),
|
|
234
|
+
location: str | None = None,
|
|
155
235
|
) -> ListSecretsPager:
|
|
156
236
|
"""
|
|
157
237
|
List secrets.
|
|
@@ -168,11 +248,13 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
168
248
|
:param retry: Optional. Designation of what errors, if any, should be retried.
|
|
169
249
|
:param timeout: Optional. The timeout for this request.
|
|
170
250
|
:param metadata: Optional. Strings which should be sent along with the request as metadata.
|
|
251
|
+
:param location: Optional. The regional secrets stored in the provided location will be listed. If set to ``None`` or missing, the location provided for GoogleCloudSecretHook instantiation is used
|
|
252
|
+
For more details : https://cloud.google.com/secret-manager/docs/locations
|
|
171
253
|
:return: Secret List object.
|
|
172
254
|
"""
|
|
173
255
|
response = self.client.list_secrets(
|
|
174
256
|
request={
|
|
175
|
-
"parent":
|
|
257
|
+
"parent": self._get_parent(project_id, location),
|
|
176
258
|
"page_size": page_size,
|
|
177
259
|
"page_token": page_token,
|
|
178
260
|
"filter": secret_filter,
|
|
@@ -185,18 +267,22 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
185
267
|
return response
|
|
186
268
|
|
|
187
269
|
@GoogleBaseHook.fallback_to_default_project_id
|
|
188
|
-
def secret_exists(self, project_id: str, secret_id: str) -> bool:
|
|
270
|
+
def secret_exists(self, project_id: str, secret_id: str, location: str | None = None) -> bool:
|
|
189
271
|
"""
|
|
190
272
|
Check whether secret exists.
|
|
191
273
|
|
|
192
274
|
:param project_id: Required. ID of the GCP project that owns the job.
|
|
193
275
|
If set to ``None`` or missing, the default project_id from the GCP connection is used.
|
|
194
276
|
:param secret_id: Required. ID of the secret to find.
|
|
277
|
+
:param location: Optional. Location where secret is expected to be stored regionally. If set to ``None`` or missing, the location provided for GoogleCloudSecretHook instantiation is used
|
|
278
|
+
For more details : https://cloud.google.com/secret-manager/docs/locations
|
|
195
279
|
:return: True if the secret exists, False otherwise.
|
|
196
280
|
"""
|
|
197
281
|
secret_filter = f"name:{secret_id}"
|
|
198
|
-
secret_name = self.
|
|
199
|
-
for secret in self.list_secrets(
|
|
282
|
+
secret_name = self._get_secret_path(project_id, secret_id, location)
|
|
283
|
+
for secret in self.list_secrets(
|
|
284
|
+
project_id=project_id, page_size=100, secret_filter=secret_filter, location=location
|
|
285
|
+
):
|
|
200
286
|
if secret.name.split("/")[-1] == secret_id:
|
|
201
287
|
self.log.info("Secret %s exists.", secret_name)
|
|
202
288
|
return True
|
|
@@ -212,6 +298,7 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
212
298
|
retry: Retry | _MethodDefault = DEFAULT,
|
|
213
299
|
timeout: float | None = None,
|
|
214
300
|
metadata: Sequence[tuple[str, str]] = (),
|
|
301
|
+
location: str | None = None,
|
|
215
302
|
) -> AccessSecretVersionResponse:
|
|
216
303
|
"""
|
|
217
304
|
Access a secret version.
|
|
@@ -227,11 +314,13 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
227
314
|
:param retry: Optional. Designation of what errors, if any, should be retried.
|
|
228
315
|
:param timeout: Optional. The timeout for this request.
|
|
229
316
|
:param metadata: Optional. Strings which should be sent along with the request as metadata.
|
|
317
|
+
:param location: Optional. Location where secret is stored regionally. If set to ``None`` or missing, the location provided for GoogleCloudSecretHook instantiation is used
|
|
318
|
+
For more details : https://cloud.google.com/secret-manager/docs/locations
|
|
230
319
|
:return: Access secret version response object.
|
|
231
320
|
"""
|
|
232
321
|
response = self.client.access_secret_version(
|
|
233
322
|
request={
|
|
234
|
-
"name": self.
|
|
323
|
+
"name": self._get_secret_version_path(project_id, secret_id, secret_version, location),
|
|
235
324
|
},
|
|
236
325
|
retry=retry,
|
|
237
326
|
timeout=timeout,
|
|
@@ -248,6 +337,7 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
248
337
|
retry: Retry | _MethodDefault = DEFAULT,
|
|
249
338
|
timeout: float | None = None,
|
|
250
339
|
metadata: Sequence[tuple[str, str]] = (),
|
|
340
|
+
location: str | None = None,
|
|
251
341
|
) -> None:
|
|
252
342
|
"""
|
|
253
343
|
Delete a secret.
|
|
@@ -262,9 +352,11 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
|
262
352
|
:param retry: Optional. Designation of what errors, if any, should be retried.
|
|
263
353
|
:param timeout: Optional. The timeout for this request.
|
|
264
354
|
:param metadata: Optional. Strings which should be sent along with the request as metadata.
|
|
355
|
+
:param location: Optional. Location where secret is stored regionally. If set to ``None`` or missing, the location provided for GoogleCloudSecretHook instantiation is used.
|
|
356
|
+
For more details : https://cloud.google.com/secret-manager/docs/locations
|
|
265
357
|
:return: Access secret version response object.
|
|
266
358
|
"""
|
|
267
|
-
name = self.
|
|
359
|
+
name = self._get_secret_path(project_id, secret_id, location)
|
|
268
360
|
self.client.delete_secret(
|
|
269
361
|
request={"name": name},
|
|
270
362
|
retry=retry,
|
|
@@ -19,8 +19,9 @@
|
|
|
19
19
|
|
|
20
20
|
from __future__ import annotations
|
|
21
21
|
|
|
22
|
-
from collections
|
|
23
|
-
from
|
|
22
|
+
from collections import OrderedDict
|
|
23
|
+
from collections.abc import Callable, Sequence
|
|
24
|
+
from typing import TYPE_CHECKING, NamedTuple
|
|
24
25
|
|
|
25
26
|
from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
|
|
26
27
|
from google.cloud.spanner_v1.client import Client
|
|
@@ -30,6 +31,7 @@ from airflow.exceptions import AirflowException
|
|
|
30
31
|
from airflow.providers.common.sql.hooks.sql import DbApiHook
|
|
31
32
|
from airflow.providers.google.common.consts import CLIENT_INFO
|
|
32
33
|
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
|
|
34
|
+
from airflow.providers.openlineage.sqlparser import DatabaseInfo
|
|
33
35
|
|
|
34
36
|
if TYPE_CHECKING:
|
|
35
37
|
from google.cloud.spanner_v1.database import Database
|
|
@@ -37,6 +39,8 @@ if TYPE_CHECKING:
|
|
|
37
39
|
from google.cloud.spanner_v1.transaction import Transaction
|
|
38
40
|
from google.longrunning.operations_grpc_pb2 import Operation
|
|
39
41
|
|
|
42
|
+
from airflow.models.connection import Connection
|
|
43
|
+
|
|
40
44
|
|
|
41
45
|
class SpannerConnectionParams(NamedTuple):
|
|
42
46
|
"""Information about Google Spanner connection parameters."""
|
|
@@ -388,7 +392,7 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
|
|
|
388
392
|
database_id: str,
|
|
389
393
|
queries: list[str],
|
|
390
394
|
project_id: str,
|
|
391
|
-
) ->
|
|
395
|
+
) -> list[int]:
|
|
392
396
|
"""
|
|
393
397
|
Execute an arbitrary DML query (INSERT, UPDATE, DELETE).
|
|
394
398
|
|
|
@@ -398,12 +402,73 @@ class SpannerHook(GoogleBaseHook, DbApiHook):
|
|
|
398
402
|
:param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
|
|
399
403
|
database. If set to None or missing, the default project_id from the Google Cloud connection
|
|
400
404
|
is used.
|
|
405
|
+
:return: list of numbers of affected rows by DML query
|
|
401
406
|
"""
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
407
|
+
db = (
|
|
408
|
+
self._get_client(project_id=project_id)
|
|
409
|
+
.instance(instance_id=instance_id)
|
|
410
|
+
.database(database_id=database_id)
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
def _tx_runner(tx: Transaction) -> dict[str, int]:
|
|
414
|
+
return self._execute_sql_in_transaction(tx, queries)
|
|
415
|
+
|
|
416
|
+
result = db.run_in_transaction(_tx_runner)
|
|
417
|
+
|
|
418
|
+
result_rows_count_per_query = []
|
|
419
|
+
for i, (sql, rc) in enumerate(result.items(), start=1):
|
|
420
|
+
if not sql.startswith("SELECT"):
|
|
421
|
+
preview = sql if len(sql) <= 300 else sql[:300] + "…"
|
|
422
|
+
self.log.info("[DML %d/%d] affected rows=%d | %s", i, len(result), rc, preview)
|
|
423
|
+
result_rows_count_per_query.append(rc)
|
|
424
|
+
return result_rows_count_per_query
|
|
405
425
|
|
|
406
426
|
@staticmethod
|
|
407
|
-
def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]):
|
|
427
|
+
def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]) -> dict[str, int]:
|
|
428
|
+
counts: OrderedDict[str, int] = OrderedDict()
|
|
408
429
|
for sql in queries:
|
|
409
|
-
transaction.execute_update(sql)
|
|
430
|
+
rc = transaction.execute_update(sql)
|
|
431
|
+
counts[sql] = rc
|
|
432
|
+
return counts
|
|
433
|
+
|
|
434
|
+
def _get_openlineage_authority_part(self, connection: Connection) -> str | None:
|
|
435
|
+
"""Build Spanner-specific authority part for OpenLineage. Returns {project}/{instance}."""
|
|
436
|
+
extras = connection.extra_dejson
|
|
437
|
+
project_id = extras.get("project_id")
|
|
438
|
+
instance_id = extras.get("instance_id")
|
|
439
|
+
|
|
440
|
+
if not project_id or not instance_id:
|
|
441
|
+
return None
|
|
442
|
+
|
|
443
|
+
return f"{project_id}/{instance_id}"
|
|
444
|
+
|
|
445
|
+
def get_openlineage_database_dialect(self, connection: Connection) -> str:
|
|
446
|
+
"""Return database dialect for OpenLineage."""
|
|
447
|
+
return "spanner"
|
|
448
|
+
|
|
449
|
+
def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
|
|
450
|
+
"""Return Spanner specific information for OpenLineage."""
|
|
451
|
+
extras = connection.extra_dejson
|
|
452
|
+
database_id = extras.get("database_id")
|
|
453
|
+
|
|
454
|
+
return DatabaseInfo(
|
|
455
|
+
scheme=self.get_openlineage_database_dialect(connection),
|
|
456
|
+
authority=self._get_openlineage_authority_part(connection),
|
|
457
|
+
database=database_id,
|
|
458
|
+
information_schema_columns=[
|
|
459
|
+
"table_schema",
|
|
460
|
+
"table_name",
|
|
461
|
+
"column_name",
|
|
462
|
+
"ordinal_position",
|
|
463
|
+
"spanner_type",
|
|
464
|
+
],
|
|
465
|
+
)
|
|
466
|
+
|
|
467
|
+
def get_openlineage_default_schema(self) -> str | None:
|
|
468
|
+
"""
|
|
469
|
+
Spanner expose 'public' or '' schema depending on dialect(Postgres vs GoogleSQL).
|
|
470
|
+
|
|
471
|
+
SQLAlchemy dialect for Spanner does not expose default schema, so we return None
|
|
472
|
+
to follow the same approach.
|
|
473
|
+
"""
|
|
474
|
+
return None
|
|
@@ -261,8 +261,9 @@ class StackdriverHook(GoogleBaseHook):
|
|
|
261
261
|
channel_name_map = {}
|
|
262
262
|
|
|
263
263
|
for channel in channels:
|
|
264
|
+
# This field is immutable, illegal to specifying non-default UNVERIFIED or VERIFIED, so setting default
|
|
264
265
|
channel.verification_status = (
|
|
265
|
-
monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
|
|
266
|
+
monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED # type: ignore[assignment]
|
|
266
267
|
)
|
|
267
268
|
|
|
268
269
|
if channel.name in existing_channels:
|
|
@@ -274,7 +275,7 @@ class StackdriverHook(GoogleBaseHook):
|
|
|
274
275
|
)
|
|
275
276
|
else:
|
|
276
277
|
old_name = channel.name
|
|
277
|
-
channel.name
|
|
278
|
+
del channel.name
|
|
278
279
|
new_channel = channel_client.create_notification_channel(
|
|
279
280
|
request={"name": f"projects/{project_id}", "notification_channel": channel},
|
|
280
281
|
retry=retry,
|
|
@@ -284,8 +285,8 @@ class StackdriverHook(GoogleBaseHook):
|
|
|
284
285
|
channel_name_map[old_name] = new_channel.name
|
|
285
286
|
|
|
286
287
|
for policy in policies_:
|
|
287
|
-
policy.creation_record
|
|
288
|
-
policy.mutation_record
|
|
288
|
+
del policy.creation_record
|
|
289
|
+
del policy.mutation_record
|
|
289
290
|
|
|
290
291
|
for i, channel in enumerate(policy.notification_channels):
|
|
291
292
|
new_channel = channel_name_map.get(channel)
|
|
@@ -301,9 +302,9 @@ class StackdriverHook(GoogleBaseHook):
|
|
|
301
302
|
metadata=metadata,
|
|
302
303
|
)
|
|
303
304
|
else:
|
|
304
|
-
policy.name
|
|
305
|
+
del policy.name
|
|
305
306
|
for condition in policy.conditions:
|
|
306
|
-
condition.name
|
|
307
|
+
del condition.name
|
|
307
308
|
policy_client.create_alert_policy(
|
|
308
309
|
request={"name": f"projects/{project_id}", "alert_policy": policy},
|
|
309
310
|
retry=retry,
|
|
@@ -531,8 +532,9 @@ class StackdriverHook(GoogleBaseHook):
|
|
|
531
532
|
channels_list.append(NotificationChannel(**channel))
|
|
532
533
|
|
|
533
534
|
for channel in channels_list:
|
|
535
|
+
# This field is immutable, illegal to specifying non-default UNVERIFIED or VERIFIED, so setting default
|
|
534
536
|
channel.verification_status = (
|
|
535
|
-
monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED
|
|
537
|
+
monitoring_v3.NotificationChannel.VerificationStatus.VERIFICATION_STATUS_UNSPECIFIED # type: ignore[assignment]
|
|
536
538
|
)
|
|
537
539
|
|
|
538
540
|
if channel.name in existing_channels:
|
|
@@ -544,7 +546,7 @@ class StackdriverHook(GoogleBaseHook):
|
|
|
544
546
|
)
|
|
545
547
|
else:
|
|
546
548
|
old_name = channel.name
|
|
547
|
-
channel.name
|
|
549
|
+
del channel.name
|
|
548
550
|
new_channel = channel_client.create_notification_channel(
|
|
549
551
|
request={"name": f"projects/{project_id}", "notification_channel": channel},
|
|
550
552
|
retry=retry,
|
|
@@ -429,7 +429,7 @@ class TranslateHook(GoogleBaseHook, OperationHelper):
|
|
|
429
429
|
project_id: str,
|
|
430
430
|
location: str,
|
|
431
431
|
retry: Retry | _MethodDefault = DEFAULT,
|
|
432
|
-
timeout: float | _MethodDefault = DEFAULT,
|
|
432
|
+
timeout: float | None | _MethodDefault = DEFAULT,
|
|
433
433
|
metadata: Sequence[tuple[str, str]] = (),
|
|
434
434
|
) -> automl_translation.Dataset:
|
|
435
435
|
"""
|