apache-airflow-providers-google 10.20.0rc1__py3-none-any.whl → 10.21.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +16 -8
- airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -1
- airflow/providers/google/cloud/_internal_client/secret_manager_client.py +6 -3
- airflow/providers/google/cloud/hooks/bigquery.py +158 -79
- airflow/providers/google/cloud/hooks/cloud_sql.py +12 -6
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +34 -17
- airflow/providers/google/cloud/hooks/dataflow.py +30 -26
- airflow/providers/google/cloud/hooks/dataform.py +2 -1
- airflow/providers/google/cloud/hooks/datafusion.py +4 -2
- airflow/providers/google/cloud/hooks/dataproc.py +102 -51
- airflow/providers/google/cloud/hooks/functions.py +20 -10
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +22 -11
- airflow/providers/google/cloud/hooks/os_login.py +2 -1
- airflow/providers/google/cloud/hooks/secret_manager.py +18 -9
- airflow/providers/google/cloud/hooks/translate.py +2 -1
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +2 -1
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +141 -0
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +2 -1
- airflow/providers/google/cloud/links/base.py +2 -1
- airflow/providers/google/cloud/links/datafusion.py +2 -1
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +4 -2
- airflow/providers/google/cloud/openlineage/mixins.py +10 -0
- airflow/providers/google/cloud/openlineage/utils.py +4 -2
- airflow/providers/google/cloud/operators/bigquery.py +55 -21
- airflow/providers/google/cloud/operators/cloud_batch.py +3 -1
- airflow/providers/google/cloud/operators/cloud_sql.py +22 -11
- airflow/providers/google/cloud/operators/dataform.py +2 -1
- airflow/providers/google/cloud/operators/dataproc.py +75 -34
- airflow/providers/google/cloud/operators/dataproc_metastore.py +24 -12
- airflow/providers/google/cloud/operators/gcs.py +2 -1
- airflow/providers/google/cloud/operators/pubsub.py +10 -5
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +3 -3
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +12 -9
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +243 -0
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +2 -1
- airflow/providers/google/cloud/operators/vision.py +36 -18
- airflow/providers/google/cloud/sensors/gcs.py +11 -2
- airflow/providers/google/cloud/sensors/pubsub.py +2 -1
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +21 -12
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +1 -1
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +2 -1
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +17 -5
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +12 -6
- airflow/providers/google/cloud/transfers/local_to_gcs.py +5 -1
- airflow/providers/google/cloud/transfers/mysql_to_gcs.py +2 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +2 -1
- airflow/providers/google/cloud/transfers/presto_to_gcs.py +2 -1
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +2 -1
- airflow/providers/google/cloud/transfers/trino_to_gcs.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_batch.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_run.py +2 -1
- airflow/providers/google/cloud/triggers/dataflow.py +2 -1
- airflow/providers/google/cloud/triggers/vertex_ai.py +2 -1
- airflow/providers/google/cloud/utils/external_token_supplier.py +4 -2
- airflow/providers/google/cloud/utils/field_sanitizer.py +4 -2
- airflow/providers/google/cloud/utils/field_validator.py +6 -3
- airflow/providers/google/cloud/utils/helpers.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +2 -1
- airflow/providers/google/common/utils/id_token_credentials.py +2 -1
- airflow/providers/google/get_provider_info.py +3 -2
- airflow/providers/google/go_module_utils.py +4 -2
- airflow/providers/google/marketing_platform/hooks/analytics_admin.py +12 -6
- airflow/providers/google/marketing_platform/links/analytics_admin.py +2 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +2 -1
- {apache_airflow_providers_google-10.20.0rc1.dist-info → apache_airflow_providers_google-10.21.0rc1.dist-info}/METADATA +8 -8
- {apache_airflow_providers_google-10.20.0rc1.dist-info → apache_airflow_providers_google-10.21.0rc1.dist-info}/RECORD +69 -69
- {apache_airflow_providers_google-10.20.0rc1.dist-info → apache_airflow_providers_google-10.21.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.20.0rc1.dist-info → apache_airflow_providers_google-10.21.0rc1.dist-info}/entry_points.txt +0 -0
@@ -104,7 +104,8 @@ class GKEClusterConnection:
|
|
104
104
|
|
105
105
|
|
106
106
|
class GKEHook(GoogleBaseHook):
|
107
|
-
"""
|
107
|
+
"""
|
108
|
+
Google Kubernetes Engine cluster APIs.
|
108
109
|
|
109
110
|
All the methods in the hook where project_id is used must be called with
|
110
111
|
keyword arguments rather than positional.
|
@@ -157,7 +158,8 @@ class GKEHook(GoogleBaseHook):
|
|
157
158
|
return self.get_conn()
|
158
159
|
|
159
160
|
def wait_for_operation(self, operation: Operation, project_id: str = PROVIDE_PROJECT_ID) -> Operation:
|
160
|
-
"""
|
161
|
+
"""
|
162
|
+
Continuously fetch the status from Google Cloud.
|
161
163
|
|
162
164
|
This is done until the given operation completes, or raises an error.
|
163
165
|
|
@@ -177,7 +179,8 @@ class GKEHook(GoogleBaseHook):
|
|
177
179
|
return operation
|
178
180
|
|
179
181
|
def get_operation(self, operation_name: str, project_id: str = PROVIDE_PROJECT_ID) -> Operation:
|
180
|
-
"""
|
182
|
+
"""
|
183
|
+
Get an operation from Google Cloud.
|
181
184
|
|
182
185
|
:param operation_name: Name of operation to fetch
|
183
186
|
:param project_id: Google Cloud project ID
|
@@ -192,7 +195,8 @@ class GKEHook(GoogleBaseHook):
|
|
192
195
|
|
193
196
|
@staticmethod
|
194
197
|
def _append_label(cluster_proto: Cluster, key: str, val: str) -> Cluster:
|
195
|
-
"""
|
198
|
+
"""
|
199
|
+
Append labels to provided Cluster Protobuf.
|
196
200
|
|
197
201
|
Labels must fit the regex ``[a-z]([-a-z0-9]*[a-z0-9])?`` (current
|
198
202
|
airflow version string follows semantic versioning spec: x.y.z).
|
@@ -216,7 +220,8 @@ class GKEHook(GoogleBaseHook):
|
|
216
220
|
retry: Retry | _MethodDefault = DEFAULT,
|
217
221
|
timeout: float | None = None,
|
218
222
|
) -> Operation | None:
|
219
|
-
"""
|
223
|
+
"""
|
224
|
+
Delete the cluster, the Kubernetes endpoint, and all worker nodes.
|
220
225
|
|
221
226
|
Firewalls and routes that were configured during cluster creation are
|
222
227
|
also deleted. Other Google Compute Engine resources that might be in use
|
@@ -259,7 +264,8 @@ class GKEHook(GoogleBaseHook):
|
|
259
264
|
retry: Retry | _MethodDefault = DEFAULT,
|
260
265
|
timeout: float | None = None,
|
261
266
|
) -> Operation | Cluster:
|
262
|
-
"""
|
267
|
+
"""
|
268
|
+
Create a cluster.
|
263
269
|
|
264
270
|
This should consist of the specified number, and the type of Google
|
265
271
|
Compute Engine instances.
|
@@ -314,7 +320,8 @@ class GKEHook(GoogleBaseHook):
|
|
314
320
|
retry: Retry | _MethodDefault = DEFAULT,
|
315
321
|
timeout: float | None = None,
|
316
322
|
) -> Cluster:
|
317
|
-
"""
|
323
|
+
"""
|
324
|
+
Get details of specified cluster.
|
318
325
|
|
319
326
|
:param name: The name of the cluster to retrieve.
|
320
327
|
:param project_id: Google Cloud project ID.
|
@@ -404,7 +411,8 @@ class GKEAsyncHook(GoogleBaseAsyncHook):
|
|
404
411
|
operation_name: str,
|
405
412
|
project_id: str = PROVIDE_PROJECT_ID,
|
406
413
|
) -> Operation:
|
407
|
-
"""
|
414
|
+
"""
|
415
|
+
Fetch an operation from Google Cloud.
|
408
416
|
|
409
417
|
:param operation_name: Name of operation to fetch.
|
410
418
|
:param project_id: Google Cloud project ID.
|
@@ -420,7 +428,8 @@ class GKEAsyncHook(GoogleBaseAsyncHook):
|
|
420
428
|
|
421
429
|
|
422
430
|
class GKEKubernetesHook(GoogleBaseHook, KubernetesHook):
|
423
|
-
"""
|
431
|
+
"""
|
432
|
+
GKE authenticated hook for standard Kubernetes API.
|
424
433
|
|
425
434
|
This hook provides the full set of the standard Kubernetes API provided by the KubernetesHook,
|
426
435
|
and at the same time it provides a GKE authentication, so it makes it possible to KubernetesHook
|
@@ -506,7 +515,8 @@ class GKEKubernetesHook(GoogleBaseHook, KubernetesHook):
|
|
506
515
|
|
507
516
|
|
508
517
|
class GKEKubernetesAsyncHook(GoogleBaseAsyncHook, AsyncKubernetesHook):
|
509
|
-
"""
|
518
|
+
"""
|
519
|
+
Async GKE authenticated hook for standard Kubernetes API.
|
510
520
|
|
511
521
|
This hook provides the full set of the standard Kubernetes API provided by the AsyncKubernetesHook,
|
512
522
|
and at the same time it provides a GKE authentication, so it makes it possible to KubernetesHook
|
@@ -639,7 +649,8 @@ class GKEJobHook(GKEKubernetesHook):
|
|
639
649
|
category=AirflowProviderDeprecationWarning,
|
640
650
|
)
|
641
651
|
class GKEPodAsyncHook(GKEKubernetesAsyncHook):
|
642
|
-
"""
|
652
|
+
"""
|
653
|
+
Google Kubernetes Engine pods APIs asynchronously.
|
643
654
|
|
644
655
|
:param cluster_url: The URL pointed to the cluster.
|
645
656
|
:param ssl_ca_cert: SSL certificate used for authentication to the pod.
|
@@ -111,21 +111,24 @@ class SecretsManagerHook(GoogleBaseHook):
|
|
111
111
|
|
112
112
|
|
113
113
|
class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
114
|
-
"""
|
114
|
+
"""
|
115
|
+
Hook for the Google Cloud Secret Manager API.
|
115
116
|
|
116
117
|
See https://cloud.google.com/secret-manager
|
117
118
|
"""
|
118
119
|
|
119
120
|
@cached_property
|
120
121
|
def client(self):
|
121
|
-
"""
|
122
|
+
"""
|
123
|
+
Create a Secret Manager Client.
|
122
124
|
|
123
125
|
:return: Secret Manager client.
|
124
126
|
"""
|
125
127
|
return SecretManagerServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
|
126
128
|
|
127
129
|
def get_conn(self) -> SecretManagerServiceClient:
|
128
|
-
"""
|
130
|
+
"""
|
131
|
+
Retrieve the connection to Secret Manager.
|
129
132
|
|
130
133
|
:return: Secret Manager client.
|
131
134
|
"""
|
@@ -141,7 +144,8 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
141
144
|
timeout: float | None = None,
|
142
145
|
metadata: Sequence[tuple[str, str]] = (),
|
143
146
|
) -> Secret:
|
144
|
-
"""
|
147
|
+
"""
|
148
|
+
Create a secret.
|
145
149
|
|
146
150
|
.. seealso::
|
147
151
|
For more details see API documentation:
|
@@ -180,7 +184,8 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
180
184
|
timeout: float | None = None,
|
181
185
|
metadata: Sequence[tuple[str, str]] = (),
|
182
186
|
) -> SecretVersion:
|
183
|
-
"""
|
187
|
+
"""
|
188
|
+
Add a version to the secret.
|
184
189
|
|
185
190
|
.. seealso::
|
186
191
|
For more details see API documentation:
|
@@ -218,7 +223,8 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
218
223
|
timeout: float | None = None,
|
219
224
|
metadata: Sequence[tuple[str, str]] = (),
|
220
225
|
) -> ListSecretsPager:
|
221
|
-
"""
|
226
|
+
"""
|
227
|
+
List secrets.
|
222
228
|
|
223
229
|
.. seealso::
|
224
230
|
For more details see API documentation:
|
@@ -250,7 +256,8 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
250
256
|
|
251
257
|
@GoogleBaseHook.fallback_to_default_project_id
|
252
258
|
def secret_exists(self, project_id: str, secret_id: str) -> bool:
|
253
|
-
"""
|
259
|
+
"""
|
260
|
+
Check whether secret exists.
|
254
261
|
|
255
262
|
:param project_id: Required. ID of the GCP project that owns the secret.
|
256
263
|
If set to ``None`` or missing, the default project_id from the GCP connection is used.
|
@@ -276,7 +283,8 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
276
283
|
timeout: float | None = None,
|
277
284
|
metadata: Sequence[tuple[str, str]] = (),
|
278
285
|
) -> AccessSecretVersionResponse:
|
279
|
-
"""
|
286
|
+
"""
|
287
|
+
Access a secret version.
|
280
288
|
|
281
289
|
.. seealso::
|
282
290
|
For more details see API documentation:
|
@@ -311,7 +319,8 @@ class GoogleCloudSecretManagerHook(GoogleBaseHook):
|
|
311
319
|
timeout: float | None = None,
|
312
320
|
metadata: Sequence[tuple[str, str]] = (),
|
313
321
|
) -> None:
|
314
|
-
"""
|
322
|
+
"""
|
323
|
+
Delete a secret.
|
315
324
|
|
316
325
|
.. seealso::
|
317
326
|
For more details see API documentation:
|
@@ -71,7 +71,8 @@ class CloudTranslateHook(GoogleBaseHook):
|
|
71
71
|
source_language: str | None = None,
|
72
72
|
model: str | list[str] | None = None,
|
73
73
|
) -> dict:
|
74
|
-
"""
|
74
|
+
"""
|
75
|
+
Translate a string or list of strings.
|
75
76
|
|
76
77
|
See https://cloud.google.com/translate/docs/translating-text
|
77
78
|
|
@@ -551,7 +551,8 @@ class BatchPredictionJobAsyncHook(GoogleBaseAsyncHook):
|
|
551
551
|
timeout: float | None = None,
|
552
552
|
metadata: Sequence[tuple[str, str]] = (),
|
553
553
|
) -> types.BatchPredictionJob:
|
554
|
-
"""
|
554
|
+
"""
|
555
|
+
Retrieve a batch prediction tuning job.
|
555
556
|
|
556
557
|
:param project_id: Required. The ID of the Google Cloud project that the job belongs to.
|
557
558
|
:param location: Required. The ID of the Google Cloud region that the job belongs to.
|
@@ -22,9 +22,11 @@ from __future__ import annotations
|
|
22
22
|
from typing import Sequence
|
23
23
|
|
24
24
|
import vertexai
|
25
|
+
from deprecated import deprecated
|
25
26
|
from vertexai.generative_models import GenerativeModel, Part
|
26
27
|
from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
|
27
28
|
|
29
|
+
from airflow.exceptions import AirflowProviderDeprecationWarning
|
28
30
|
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
|
29
31
|
|
30
32
|
|
@@ -59,11 +61,23 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
59
61
|
model = GenerativeModel(pretrained_model)
|
60
62
|
return model
|
61
63
|
|
64
|
+
@deprecated(
|
65
|
+
reason=(
|
66
|
+
"The `get_generative_model_part` method is deprecated and will be removed after 01.01.2025, please include `Part` objects in `contents` parameter of `airflow.providers.google.cloud.hooks.generative_model.GenerativeModelHook.generative_model_generate_content`"
|
67
|
+
),
|
68
|
+
category=AirflowProviderDeprecationWarning,
|
69
|
+
)
|
62
70
|
def get_generative_model_part(self, content_gcs_path: str, content_mime_type: str | None = None) -> Part:
|
63
71
|
"""Return a Generative Model Part object."""
|
64
72
|
part = Part.from_uri(content_gcs_path, mime_type=content_mime_type)
|
65
73
|
return part
|
66
74
|
|
75
|
+
@deprecated(
|
76
|
+
reason=(
|
77
|
+
"The `prompt_language_model` method is deprecated and will be removed after 01.01.2025, please use `airflow.providers.google.cloud.hooks.generative_model.GenerativeModelHook.text_generation_model_predict` method."
|
78
|
+
),
|
79
|
+
category=AirflowProviderDeprecationWarning,
|
80
|
+
)
|
67
81
|
@GoogleBaseHook.fallback_to_default_project_id
|
68
82
|
def prompt_language_model(
|
69
83
|
self,
|
@@ -112,6 +126,12 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
112
126
|
)
|
113
127
|
return response.text
|
114
128
|
|
129
|
+
@deprecated(
|
130
|
+
reason=(
|
131
|
+
"The `generate_text_embeddings` method is deprecated and will be removed after 01.01.2025, please use `airflow.providers.google.cloud.hooks.generative_model.GenerativeModelHook.text_embedding_model_get_embeddings` method."
|
132
|
+
),
|
133
|
+
category=AirflowProviderDeprecationWarning,
|
134
|
+
)
|
115
135
|
@GoogleBaseHook.fallback_to_default_project_id
|
116
136
|
def generate_text_embeddings(
|
117
137
|
self,
|
@@ -136,6 +156,12 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
136
156
|
|
137
157
|
return response.values
|
138
158
|
|
159
|
+
@deprecated(
|
160
|
+
reason=(
|
161
|
+
"The `prompt_multimodal_model` method is deprecated and will be removed after 01.01.2025, please use `airflow.providers.google.cloud.hooks.generative_model.GenerativeModelHook.generative_model_generate_content` method."
|
162
|
+
),
|
163
|
+
category=AirflowProviderDeprecationWarning,
|
164
|
+
)
|
139
165
|
@GoogleBaseHook.fallback_to_default_project_id
|
140
166
|
def prompt_multimodal_model(
|
141
167
|
self,
|
@@ -169,6 +195,12 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
169
195
|
|
170
196
|
return response.text
|
171
197
|
|
198
|
+
@deprecated(
|
199
|
+
reason=(
|
200
|
+
"The `prompt_multimodal_model_with_media` method is deprecated and will be removed after 01.01.2025, please use `airflow.providers.google.cloud.hooks.generative_model.GenerativeModelHook.generative_model_generate_content` method."
|
201
|
+
),
|
202
|
+
category=AirflowProviderDeprecationWarning,
|
203
|
+
)
|
172
204
|
@GoogleBaseHook.fallback_to_default_project_id
|
173
205
|
def prompt_multimodal_model_with_media(
|
174
206
|
self,
|
@@ -207,3 +239,112 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
207
239
|
)
|
208
240
|
|
209
241
|
return response.text
|
242
|
+
|
243
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
244
|
+
def text_generation_model_predict(
|
245
|
+
self,
|
246
|
+
prompt: str,
|
247
|
+
pretrained_model: str,
|
248
|
+
temperature: float,
|
249
|
+
max_output_tokens: int,
|
250
|
+
top_p: float,
|
251
|
+
top_k: int,
|
252
|
+
location: str,
|
253
|
+
project_id: str = PROVIDE_PROJECT_ID,
|
254
|
+
) -> str:
|
255
|
+
"""
|
256
|
+
Use the Vertex AI PaLM API to generate natural language text.
|
257
|
+
|
258
|
+
:param prompt: Required. Inputs or queries that a user or a program gives
|
259
|
+
to the Vertex AI PaLM API, in order to elicit a specific response.
|
260
|
+
:param pretrained_model: A pre-trained model optimized for performing natural
|
261
|
+
language tasks such as classification, summarization, extraction, content
|
262
|
+
creation, and ideation.
|
263
|
+
:param temperature: Temperature controls the degree of randomness in token
|
264
|
+
selection.
|
265
|
+
:param max_output_tokens: Token limit determines the maximum amount of text
|
266
|
+
output.
|
267
|
+
:param top_p: Tokens are selected from most probable to least until the sum
|
268
|
+
of their probabilities equals the top_p value.
|
269
|
+
:param top_k: A top_k of 1 means the selected token is the most probable
|
270
|
+
among all tokens.
|
271
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
272
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
273
|
+
"""
|
274
|
+
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
275
|
+
|
276
|
+
parameters = {
|
277
|
+
"temperature": temperature,
|
278
|
+
"max_output_tokens": max_output_tokens,
|
279
|
+
"top_p": top_p,
|
280
|
+
"top_k": top_k,
|
281
|
+
}
|
282
|
+
|
283
|
+
model = self.get_text_generation_model(pretrained_model)
|
284
|
+
|
285
|
+
response = model.predict(
|
286
|
+
prompt=prompt,
|
287
|
+
**parameters,
|
288
|
+
)
|
289
|
+
return response.text
|
290
|
+
|
291
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
292
|
+
def text_embedding_model_get_embeddings(
|
293
|
+
self,
|
294
|
+
prompt: str,
|
295
|
+
pretrained_model: str,
|
296
|
+
location: str,
|
297
|
+
project_id: str = PROVIDE_PROJECT_ID,
|
298
|
+
) -> list:
|
299
|
+
"""
|
300
|
+
Use the Vertex AI PaLM API to generate text embeddings.
|
301
|
+
|
302
|
+
:param prompt: Required. Inputs or queries that a user or a program gives
|
303
|
+
to the Vertex AI PaLM API, in order to elicit a specific response.
|
304
|
+
:param pretrained_model: A pre-trained model optimized for generating text embeddings.
|
305
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
306
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
307
|
+
"""
|
308
|
+
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
309
|
+
model = self.get_text_embedding_model(pretrained_model)
|
310
|
+
|
311
|
+
response = model.get_embeddings([prompt])[0] # single prompt
|
312
|
+
|
313
|
+
return response.values
|
314
|
+
|
315
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
316
|
+
def generative_model_generate_content(
|
317
|
+
self,
|
318
|
+
contents: list,
|
319
|
+
location: str,
|
320
|
+
tools: list | None = None,
|
321
|
+
generation_config: dict | None = None,
|
322
|
+
safety_settings: dict | None = None,
|
323
|
+
pretrained_model: str = "gemini-pro",
|
324
|
+
project_id: str = PROVIDE_PROJECT_ID,
|
325
|
+
) -> str:
|
326
|
+
"""
|
327
|
+
Use the Vertex AI Gemini Pro foundation model to generate natural language text.
|
328
|
+
|
329
|
+
:param contents: Required. The multi-part content of a message that a user or a program
|
330
|
+
gives to the generative model, in order to elicit a specific response.
|
331
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
332
|
+
:param generation_config: Optional. Generation configuration settings.
|
333
|
+
:param safety_settings: Optional. Per request settings for blocking unsafe content.
|
334
|
+
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
|
335
|
+
supporting prompts with text-only input, including natural language
|
336
|
+
tasks, multi-turn text and code chat, and code generation. It can
|
337
|
+
output text and code.
|
338
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
339
|
+
"""
|
340
|
+
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
341
|
+
|
342
|
+
model = self.get_generative_model(pretrained_model)
|
343
|
+
response = model.generate_content(
|
344
|
+
contents=contents,
|
345
|
+
tools=tools,
|
346
|
+
generation_config=generation_config,
|
347
|
+
safety_settings=safety_settings,
|
348
|
+
)
|
349
|
+
|
350
|
+
return response.text
|
@@ -15,7 +15,8 @@
|
|
15
15
|
# KIND, either express or implied. See the License for the
|
16
16
|
# specific language governing permissions and limitations
|
17
17
|
# under the License.
|
18
|
-
"""
|
18
|
+
"""
|
19
|
+
This module contains a Google Cloud Vertex AI hook.
|
19
20
|
|
20
21
|
.. spelling:word-list::
|
21
22
|
|
@@ -36,7 +36,8 @@ DATAFUSION_PIPELINE_LINK = "{uri}/pipelines/ns/{namespace}/view/{pipeline_name}"
|
|
36
36
|
|
37
37
|
|
38
38
|
class BaseGoogleLink(BaseOperatorLink):
|
39
|
-
"""
|
39
|
+
"""
|
40
|
+
Link for Google operators.
|
40
41
|
|
41
42
|
Prevent adding ``https://console.cloud.google.com`` in front of every link
|
42
43
|
where URI is used.
|
@@ -50,7 +50,8 @@ _DEFAULT_SCOPESS = frozenset(
|
|
50
50
|
|
51
51
|
|
52
52
|
class StackdriverTaskHandler(logging.Handler):
|
53
|
-
"""
|
53
|
+
"""
|
54
|
+
Handler that directly makes Stackdriver logging API calls.
|
54
55
|
|
55
56
|
This is a Python standard ``logging`` handler that can be used to
|
56
57
|
route Python standard logging messages directly to the Stackdriver
|
@@ -174,7 +175,8 @@ class StackdriverTaskHandler(logging.Handler):
|
|
174
175
|
return labels or {}
|
175
176
|
|
176
177
|
def emit(self, record: logging.LogRecord) -> None:
|
177
|
-
"""
|
178
|
+
"""
|
179
|
+
Actually log the specified logging record.
|
178
180
|
|
179
181
|
:param record: The record to be logged.
|
180
182
|
"""
|
@@ -67,8 +67,18 @@ class _BigQueryOpenLineageMixin:
|
|
67
67
|
from airflow.providers.openlineage.sqlparser import SQLParser
|
68
68
|
|
69
69
|
if not self.job_id:
|
70
|
+
if hasattr(self, "log"):
|
71
|
+
self.log.warning("No BigQuery job_id was found by OpenLineage.")
|
70
72
|
return OperatorLineage()
|
71
73
|
|
74
|
+
if not self.hook:
|
75
|
+
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
|
76
|
+
|
77
|
+
self.hook = BigQueryHook(
|
78
|
+
gcp_conn_id=self.gcp_conn_id,
|
79
|
+
impersonation_chain=self.impersonation_chain,
|
80
|
+
)
|
81
|
+
|
72
82
|
run_facets: dict[str, BaseFacet] = {
|
73
83
|
"externalQuery": ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery")
|
74
84
|
}
|
@@ -89,7 +89,8 @@ def get_identity_column_lineage_facet(
|
|
89
89
|
|
90
90
|
@define
|
91
91
|
class BigQueryJobRunFacet(BaseFacet):
|
92
|
-
"""
|
92
|
+
"""
|
93
|
+
Facet that represents relevant statistics of bigquery run.
|
93
94
|
|
94
95
|
This facet is used to provide statistics about bigquery run.
|
95
96
|
|
@@ -134,7 +135,8 @@ class BigQueryErrorRunFacet(BaseFacet):
|
|
134
135
|
|
135
136
|
|
136
137
|
def get_from_nullable_chain(source: Any, chain: list[str]) -> Any | None:
|
137
|
-
"""
|
138
|
+
"""
|
139
|
+
Get object from nested structure of objects, where it's not guaranteed that all keys in the nested structure exist.
|
138
140
|
|
139
141
|
Intended to replace chain of `dict.get()` statements.
|
140
142
|
|