apache-airflow-providers-google 10.22.0rc1__py3-none-any.whl → 10.23.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +91 -54
- airflow/providers/google/cloud/hooks/cloud_build.py +3 -2
- airflow/providers/google/cloud/hooks/dataflow.py +112 -47
- airflow/providers/google/cloud/hooks/datapipeline.py +3 -3
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +15 -26
- airflow/providers/google/cloud/hooks/life_sciences.py +5 -7
- airflow/providers/google/cloud/hooks/secret_manager.py +3 -3
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +28 -8
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +11 -6
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +214 -34
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +11 -4
- airflow/providers/google/cloud/links/automl.py +13 -22
- airflow/providers/google/cloud/log/gcs_task_handler.py +1 -2
- airflow/providers/google/cloud/operators/bigquery.py +6 -4
- airflow/providers/google/cloud/operators/dataflow.py +186 -4
- airflow/providers/google/cloud/operators/datafusion.py +3 -2
- airflow/providers/google/cloud/operators/datapipeline.py +5 -6
- airflow/providers/google/cloud/operators/dataproc.py +30 -33
- airflow/providers/google/cloud/operators/gcs.py +4 -4
- airflow/providers/google/cloud/operators/kubernetes_engine.py +16 -2
- airflow/providers/google/cloud/operators/life_sciences.py +5 -7
- airflow/providers/google/cloud/operators/mlengine.py +42 -65
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +18 -4
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +5 -5
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +280 -9
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +4 -0
- airflow/providers/google/cloud/secrets/secret_manager.py +3 -5
- airflow/providers/google/cloud/sensors/bigquery.py +8 -27
- airflow/providers/google/cloud/sensors/bigquery_dts.py +1 -4
- airflow/providers/google/cloud/sensors/cloud_composer.py +9 -14
- airflow/providers/google/cloud/sensors/dataflow.py +1 -25
- airflow/providers/google/cloud/sensors/dataform.py +1 -4
- airflow/providers/google/cloud/sensors/datafusion.py +1 -7
- airflow/providers/google/cloud/sensors/dataplex.py +1 -31
- airflow/providers/google/cloud/sensors/dataproc.py +1 -16
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +1 -7
- airflow/providers/google/cloud/sensors/gcs.py +5 -27
- airflow/providers/google/cloud/sensors/looker.py +1 -13
- airflow/providers/google/cloud/sensors/pubsub.py +11 -5
- airflow/providers/google/cloud/sensors/workflows.py +1 -4
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +6 -0
- airflow/providers/google/cloud/triggers/dataflow.py +145 -1
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +66 -3
- airflow/providers/google/common/deprecated.py +176 -0
- airflow/providers/google/common/hooks/base_google.py +3 -2
- airflow/providers/google/get_provider_info.py +8 -10
- airflow/providers/google/marketing_platform/hooks/analytics.py +4 -2
- airflow/providers/google/marketing_platform/hooks/search_ads.py +169 -30
- airflow/providers/google/marketing_platform/operators/analytics.py +16 -33
- airflow/providers/google/marketing_platform/operators/search_ads.py +217 -156
- airflow/providers/google/marketing_platform/sensors/display_video.py +1 -4
- {apache_airflow_providers_google-10.22.0rc1.dist-info → apache_airflow_providers_google-10.23.0rc1.dist-info}/METADATA +18 -16
- {apache_airflow_providers_google-10.22.0rc1.dist-info → apache_airflow_providers_google-10.23.0rc1.dist-info}/RECORD +56 -56
- airflow/providers/google/marketing_platform/sensors/search_ads.py +0 -92
- {apache_airflow_providers_google-10.22.0rc1.dist-info → apache_airflow_providers_google-10.23.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.22.0rc1.dist-info → apache_airflow_providers_google-10.23.0rc1.dist-info}/entry_points.txt +0 -0
@@ -19,16 +19,23 @@
|
|
19
19
|
|
20
20
|
from __future__ import annotations
|
21
21
|
|
22
|
-
|
22
|
+
import time
|
23
|
+
from typing import TYPE_CHECKING, Sequence
|
23
24
|
|
24
25
|
import vertexai
|
25
|
-
from deprecated import deprecated
|
26
26
|
from vertexai.generative_models import GenerativeModel, Part
|
27
27
|
from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
|
28
|
+
from vertexai.preview.evaluation import EvalResult, EvalTask
|
29
|
+
from vertexai.preview.tuning import sft
|
28
30
|
|
29
31
|
from airflow.exceptions import AirflowProviderDeprecationWarning
|
32
|
+
from airflow.providers.google.common.deprecated import deprecated
|
30
33
|
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
|
31
34
|
|
35
|
+
if TYPE_CHECKING:
|
36
|
+
from google.cloud.aiplatform_v1 import types as types_v1
|
37
|
+
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
|
38
|
+
|
32
39
|
|
33
40
|
class GenerativeModelHook(GoogleBaseHook):
|
34
41
|
"""Hook for Google Cloud Vertex AI Generative Model APIs."""
|
@@ -56,15 +63,43 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
56
63
|
model = TextEmbeddingModel.from_pretrained(pretrained_model)
|
57
64
|
return model
|
58
65
|
|
59
|
-
def get_generative_model(self, pretrained_model: str) -> GenerativeModel:
|
66
|
+
def get_generative_model(
|
67
|
+
self,
|
68
|
+
pretrained_model: str,
|
69
|
+
system_instruction: str | None = None,
|
70
|
+
generation_config: dict | None = None,
|
71
|
+
safety_settings: dict | None = None,
|
72
|
+
tools: list | None = None,
|
73
|
+
) -> GenerativeModel:
|
60
74
|
"""Return a Generative Model object."""
|
61
|
-
model = GenerativeModel(pretrained_model)
|
75
|
+
model = GenerativeModel(
|
76
|
+
model_name=pretrained_model,
|
77
|
+
system_instruction=system_instruction,
|
78
|
+
generation_config=generation_config,
|
79
|
+
safety_settings=safety_settings,
|
80
|
+
tools=tools,
|
81
|
+
)
|
62
82
|
return model
|
63
83
|
|
84
|
+
def get_eval_task(
|
85
|
+
self,
|
86
|
+
dataset: dict,
|
87
|
+
metrics: list,
|
88
|
+
experiment: str,
|
89
|
+
) -> EvalTask:
|
90
|
+
"""Return an EvalTask object."""
|
91
|
+
eval_task = EvalTask(
|
92
|
+
dataset=dataset,
|
93
|
+
metrics=metrics,
|
94
|
+
experiment=experiment,
|
95
|
+
)
|
96
|
+
return eval_task
|
97
|
+
|
64
98
|
@deprecated(
|
65
|
-
|
66
|
-
|
67
|
-
|
99
|
+
planned_removal_date="January 01, 2025",
|
100
|
+
use_instead="Part objects included in contents parameter of "
|
101
|
+
"airflow.providers.google.cloud.hooks.generative_model."
|
102
|
+
"GenerativeModelHook.generative_model_generate_content",
|
68
103
|
category=AirflowProviderDeprecationWarning,
|
69
104
|
)
|
70
105
|
def get_generative_model_part(self, content_gcs_path: str, content_mime_type: str | None = None) -> Part:
|
@@ -73,9 +108,9 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
73
108
|
return part
|
74
109
|
|
75
110
|
@deprecated(
|
76
|
-
|
77
|
-
|
78
|
-
|
111
|
+
planned_removal_date="January 01, 2025",
|
112
|
+
use_instead="airflow.providers.google.cloud.hooks.generative_model."
|
113
|
+
"GenerativeModelHook.text_generation_model_predict",
|
79
114
|
category=AirflowProviderDeprecationWarning,
|
80
115
|
)
|
81
116
|
@GoogleBaseHook.fallback_to_default_project_id
|
@@ -93,6 +128,8 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
93
128
|
"""
|
94
129
|
Use the Vertex AI PaLM API to generate natural language text.
|
95
130
|
|
131
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
132
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
96
133
|
:param prompt: Required. Inputs or queries that a user or a program gives
|
97
134
|
to the Vertex AI PaLM API, in order to elicit a specific response.
|
98
135
|
:param pretrained_model: A pre-trained model optimized for performing natural
|
@@ -106,8 +143,6 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
106
143
|
of their probabilities equals the top_p value. Defaults to 0.8.
|
107
144
|
:param top_k: A top_k of 1 means the selected token is the most probable
|
108
145
|
among all tokens.
|
109
|
-
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
110
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
111
146
|
"""
|
112
147
|
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
113
148
|
|
@@ -127,9 +162,9 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
127
162
|
return response.text
|
128
163
|
|
129
164
|
@deprecated(
|
130
|
-
|
131
|
-
|
132
|
-
|
165
|
+
planned_removal_date="January 01, 2025",
|
166
|
+
use_instead="airflow.providers.google.cloud.hooks.generative_model."
|
167
|
+
"GenerativeModelHook.text_embedding_model_get_embeddings",
|
133
168
|
category=AirflowProviderDeprecationWarning,
|
134
169
|
)
|
135
170
|
@GoogleBaseHook.fallback_to_default_project_id
|
@@ -143,11 +178,11 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
143
178
|
"""
|
144
179
|
Use the Vertex AI PaLM API to generate text embeddings.
|
145
180
|
|
181
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
182
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
146
183
|
:param prompt: Required. Inputs or queries that a user or a program gives
|
147
184
|
to the Vertex AI PaLM API, in order to elicit a specific response.
|
148
185
|
:param pretrained_model: A pre-trained model optimized for generating text embeddings.
|
149
|
-
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
150
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
151
186
|
"""
|
152
187
|
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
153
188
|
model = self.get_text_embedding_model(pretrained_model)
|
@@ -157,9 +192,9 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
157
192
|
return response.values
|
158
193
|
|
159
194
|
@deprecated(
|
160
|
-
|
161
|
-
|
162
|
-
|
195
|
+
planned_removal_date="January 01, 2025",
|
196
|
+
use_instead="airflow.providers.google.cloud.hooks.generative_model."
|
197
|
+
"GenerativeModelHook.generative_model_generate_content",
|
163
198
|
category=AirflowProviderDeprecationWarning,
|
164
199
|
)
|
165
200
|
@GoogleBaseHook.fallback_to_default_project_id
|
@@ -175,16 +210,16 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
175
210
|
"""
|
176
211
|
Use the Vertex AI Gemini Pro foundation model to generate natural language text.
|
177
212
|
|
213
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
214
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
178
215
|
:param prompt: Required. Inputs or queries that a user or a program gives
|
179
216
|
to the Multi-modal model, in order to elicit a specific response.
|
180
|
-
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
181
217
|
:param generation_config: Optional. Generation configuration settings.
|
182
218
|
:param safety_settings: Optional. Per request settings for blocking unsafe content.
|
183
219
|
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
|
184
220
|
supporting prompts with text-only input, including natural language
|
185
221
|
tasks, multi-turn text and code chat, and code generation. It can
|
186
222
|
output text and code.
|
187
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
188
223
|
"""
|
189
224
|
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
190
225
|
|
@@ -196,9 +231,9 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
196
231
|
return response.text
|
197
232
|
|
198
233
|
@deprecated(
|
199
|
-
|
200
|
-
|
201
|
-
|
234
|
+
planned_removal_date="January 01, 2025",
|
235
|
+
use_instead="airflow.providers.google.cloud.hooks.generative_model."
|
236
|
+
"GenerativeModelHook.generative_model_generate_content",
|
202
237
|
category=AirflowProviderDeprecationWarning,
|
203
238
|
)
|
204
239
|
@GoogleBaseHook.fallback_to_default_project_id
|
@@ -216,6 +251,8 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
216
251
|
"""
|
217
252
|
Use the Vertex AI Gemini Pro foundation model to generate natural language text.
|
218
253
|
|
254
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
255
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
219
256
|
:param prompt: Required. Inputs or queries that a user or a program gives
|
220
257
|
to the Multi-modal model, in order to elicit a specific response.
|
221
258
|
:param generation_config: Optional. Generation configuration settings.
|
@@ -227,8 +264,6 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
227
264
|
:param media_gcs_path: A GCS path to a content file such as an image or a video.
|
228
265
|
Can be passed to the multi-modal model as part of the prompt. Used with vision models.
|
229
266
|
:param mime_type: Validates the media type presented by the file in the media_gcs_path.
|
230
|
-
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
231
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
232
267
|
"""
|
233
268
|
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
234
269
|
|
@@ -255,6 +290,8 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
255
290
|
"""
|
256
291
|
Use the Vertex AI PaLM API to generate natural language text.
|
257
292
|
|
293
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
294
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
258
295
|
:param prompt: Required. Inputs or queries that a user or a program gives
|
259
296
|
to the Vertex AI PaLM API, in order to elicit a specific response.
|
260
297
|
:param pretrained_model: A pre-trained model optimized for performing natural
|
@@ -268,8 +305,6 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
268
305
|
of their probabilities equals the top_p value. Defaults to 0.8.
|
269
306
|
:param top_k: A top_k of 1 means the selected token is the most probable
|
270
307
|
among all tokens.
|
271
|
-
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
272
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
273
308
|
"""
|
274
309
|
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
275
310
|
|
@@ -299,11 +334,11 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
299
334
|
"""
|
300
335
|
Use the Vertex AI PaLM API to generate text embeddings.
|
301
336
|
|
337
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
338
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
302
339
|
:param prompt: Required. Inputs or queries that a user or a program gives
|
303
340
|
to the Vertex AI PaLM API, in order to elicit a specific response.
|
304
341
|
:param pretrained_model: A pre-trained model optimized for generating text embeddings.
|
305
|
-
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
306
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
307
342
|
"""
|
308
343
|
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
309
344
|
model = self.get_text_embedding_model(pretrained_model)
|
@@ -320,26 +355,31 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
320
355
|
tools: list | None = None,
|
321
356
|
generation_config: dict | None = None,
|
322
357
|
safety_settings: dict | None = None,
|
358
|
+
system_instruction: str | None = None,
|
323
359
|
pretrained_model: str = "gemini-pro",
|
324
360
|
project_id: str = PROVIDE_PROJECT_ID,
|
325
361
|
) -> str:
|
326
362
|
"""
|
327
363
|
Use the Vertex AI Gemini Pro foundation model to generate natural language text.
|
328
364
|
|
365
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
366
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
329
367
|
:param contents: Required. The multi-part content of a message that a user or a program
|
330
368
|
gives to the generative model, in order to elicit a specific response.
|
331
|
-
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
332
369
|
:param generation_config: Optional. Generation configuration settings.
|
333
370
|
:param safety_settings: Optional. Per request settings for blocking unsafe content.
|
371
|
+
:param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
|
372
|
+
:param system_instruction: Optional. An instruction given to the model to guide its behavior.
|
334
373
|
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
|
335
374
|
supporting prompts with text-only input, including natural language
|
336
375
|
tasks, multi-turn text and code chat, and code generation. It can
|
337
376
|
output text and code.
|
338
|
-
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
339
377
|
"""
|
340
378
|
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
341
379
|
|
342
|
-
model = self.get_generative_model(pretrained_model)
|
380
|
+
model = self.get_generative_model(
|
381
|
+
pretrained_model=pretrained_model, system_instruction=system_instruction
|
382
|
+
)
|
343
383
|
response = model.generate_content(
|
344
384
|
contents=contents,
|
345
385
|
tools=tools,
|
@@ -348,3 +388,143 @@ class GenerativeModelHook(GoogleBaseHook):
|
|
348
388
|
)
|
349
389
|
|
350
390
|
return response.text
|
391
|
+
|
392
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
393
|
+
def supervised_fine_tuning_train(
|
394
|
+
self,
|
395
|
+
source_model: str,
|
396
|
+
train_dataset: str,
|
397
|
+
location: str,
|
398
|
+
tuned_model_display_name: str | None = None,
|
399
|
+
validation_dataset: str | None = None,
|
400
|
+
epochs: int | None = None,
|
401
|
+
adapter_size: int | None = None,
|
402
|
+
learning_rate_multiplier: float | None = None,
|
403
|
+
project_id: str = PROVIDE_PROJECT_ID,
|
404
|
+
) -> types_v1.TuningJob:
|
405
|
+
"""
|
406
|
+
Use the Supervised Fine Tuning API to create a tuning job.
|
407
|
+
|
408
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
409
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
410
|
+
:param source_model: Required. A pre-trained model optimized for performing natural
|
411
|
+
language tasks such as classification, summarization, extraction, content
|
412
|
+
creation, and ideation.
|
413
|
+
:param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset
|
414
|
+
must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
|
415
|
+
:param tuned_model_display_name: Optional. Display name of the TunedModel. The name can be up
|
416
|
+
to 128 characters long and can consist of any UTF-8 characters.
|
417
|
+
:param validation_dataset: Optional. Cloud Storage URI of your training dataset. The dataset must be
|
418
|
+
formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
|
419
|
+
:param epochs: Optional. To optimize performance on a specific dataset, try using a higher
|
420
|
+
epoch value. Increasing the number of epochs might improve results. However, be cautious
|
421
|
+
about over-fitting, especially when dealing with small datasets. If over-fitting occurs,
|
422
|
+
consider lowering the epoch number.
|
423
|
+
:param adapter_size: Optional. Adapter size for tuning.
|
424
|
+
:param learning_rate_multiplier: Optional. Multiplier for adjusting the default learning rate.
|
425
|
+
"""
|
426
|
+
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
427
|
+
|
428
|
+
sft_tuning_job = sft.train(
|
429
|
+
source_model=source_model,
|
430
|
+
train_dataset=train_dataset,
|
431
|
+
validation_dataset=validation_dataset,
|
432
|
+
epochs=epochs,
|
433
|
+
adapter_size=adapter_size,
|
434
|
+
learning_rate_multiplier=learning_rate_multiplier,
|
435
|
+
tuned_model_display_name=tuned_model_display_name,
|
436
|
+
)
|
437
|
+
|
438
|
+
# Polling for job completion
|
439
|
+
while not sft_tuning_job.has_ended:
|
440
|
+
time.sleep(60)
|
441
|
+
sft_tuning_job.refresh()
|
442
|
+
|
443
|
+
return sft_tuning_job
|
444
|
+
|
445
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
446
|
+
def count_tokens(
|
447
|
+
self,
|
448
|
+
contents: list,
|
449
|
+
location: str,
|
450
|
+
pretrained_model: str = "gemini-pro",
|
451
|
+
project_id: str = PROVIDE_PROJECT_ID,
|
452
|
+
) -> types_v1beta1.CountTokensResponse:
|
453
|
+
"""
|
454
|
+
Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.
|
455
|
+
|
456
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
457
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
458
|
+
:param contents: Required. The multi-part content of a message that a user or a program
|
459
|
+
gives to the generative model, in order to elicit a specific response.
|
460
|
+
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
|
461
|
+
supporting prompts with text-only input, including natural language
|
462
|
+
tasks, multi-turn text and code chat, and code generation. It can
|
463
|
+
output text and code.
|
464
|
+
"""
|
465
|
+
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
466
|
+
|
467
|
+
model = self.get_generative_model(pretrained_model=pretrained_model)
|
468
|
+
response = model.count_tokens(
|
469
|
+
contents=contents,
|
470
|
+
)
|
471
|
+
|
472
|
+
return response
|
473
|
+
|
474
|
+
@GoogleBaseHook.fallback_to_default_project_id
|
475
|
+
def run_evaluation(
|
476
|
+
self,
|
477
|
+
pretrained_model: str,
|
478
|
+
eval_dataset: dict,
|
479
|
+
metrics: list,
|
480
|
+
experiment_name: str,
|
481
|
+
experiment_run_name: str,
|
482
|
+
prompt_template: str,
|
483
|
+
location: str,
|
484
|
+
generation_config: dict | None = None,
|
485
|
+
safety_settings: dict | None = None,
|
486
|
+
system_instruction: str | None = None,
|
487
|
+
tools: list | None = None,
|
488
|
+
project_id: str = PROVIDE_PROJECT_ID,
|
489
|
+
) -> EvalResult:
|
490
|
+
"""
|
491
|
+
Use the Rapid Evaluation API to evaluate a model.
|
492
|
+
|
493
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
494
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
495
|
+
:param pretrained_model: Required. A pre-trained model optimized for performing natural
|
496
|
+
language tasks such as classification, summarization, extraction, content
|
497
|
+
creation, and ideation.
|
498
|
+
:param eval_dataset: Required. A fixed dataset for evaluating a model against. Adheres to Rapid Evaluation API.
|
499
|
+
:param metrics: Required. A list of evaluation metrics to be used in the experiment. Adheres to Rapid Evaluation API.
|
500
|
+
:param experiment_name: Required. The name of the evaluation experiment.
|
501
|
+
:param experiment_run_name: Required. The specific run name or ID for this experiment.
|
502
|
+
:param prompt_template: Required. The template used to format the model's prompts during evaluation. Adheres to Rapid Evaluation API.
|
503
|
+
:param generation_config: Optional. A dictionary containing generation parameters for the model.
|
504
|
+
:param safety_settings: Optional. A dictionary specifying harm category thresholds for blocking model outputs.
|
505
|
+
:param system_instruction: Optional. An instruction given to the model to guide its behavior.
|
506
|
+
:param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
|
507
|
+
"""
|
508
|
+
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())
|
509
|
+
|
510
|
+
model = self.get_generative_model(
|
511
|
+
pretrained_model=pretrained_model,
|
512
|
+
system_instruction=system_instruction,
|
513
|
+
generation_config=generation_config,
|
514
|
+
safety_settings=safety_settings,
|
515
|
+
tools=tools,
|
516
|
+
)
|
517
|
+
|
518
|
+
eval_task = self.get_eval_task(
|
519
|
+
dataset=eval_dataset,
|
520
|
+
metrics=metrics,
|
521
|
+
experiment=experiment_name,
|
522
|
+
)
|
523
|
+
|
524
|
+
eval_result = eval_task.evaluate(
|
525
|
+
model=model,
|
526
|
+
prompt_template=prompt_template,
|
527
|
+
experiment_run_name=experiment_run_name,
|
528
|
+
)
|
529
|
+
|
530
|
+
return eval_result
|
@@ -208,6 +208,7 @@ class ModelServiceHook(GoogleBaseHook):
|
|
208
208
|
project_id: str,
|
209
209
|
region: str,
|
210
210
|
model: Model | dict,
|
211
|
+
parent_model: str | None = None,
|
211
212
|
retry: Retry | _MethodDefault = DEFAULT,
|
212
213
|
timeout: float | None = None,
|
213
214
|
metadata: Sequence[tuple[str, str]] = (),
|
@@ -218,6 +219,7 @@ class ModelServiceHook(GoogleBaseHook):
|
|
218
219
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
219
220
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
220
221
|
:param model: Required. The Model to create.
|
222
|
+
:param parent_model: The name of the parent model to create a new version under.
|
221
223
|
:param retry: Designation of what errors, if any, should be retried.
|
222
224
|
:param timeout: The timeout for this request.
|
223
225
|
:param metadata: Strings which should be sent along with the request as metadata.
|
@@ -225,11 +227,16 @@ class ModelServiceHook(GoogleBaseHook):
|
|
225
227
|
client = self.get_model_service_client(region)
|
226
228
|
parent = client.common_location_path(project_id, region)
|
227
229
|
|
230
|
+
request = {
|
231
|
+
"parent": parent,
|
232
|
+
"model": model,
|
233
|
+
}
|
234
|
+
|
235
|
+
if parent_model:
|
236
|
+
request["parent_model"] = parent_model
|
237
|
+
|
228
238
|
result = client.upload_model(
|
229
|
-
request={
|
230
|
-
"parent": parent,
|
231
|
-
"model": model,
|
232
|
-
},
|
239
|
+
request=request,
|
233
240
|
retry=retry,
|
234
241
|
timeout=timeout,
|
235
242
|
metadata=metadata,
|
@@ -21,10 +21,9 @@ from __future__ import annotations
|
|
21
21
|
|
22
22
|
from typing import TYPE_CHECKING
|
23
23
|
|
24
|
-
from deprecated import deprecated
|
25
|
-
|
26
24
|
from airflow.exceptions import AirflowProviderDeprecationWarning
|
27
25
|
from airflow.providers.google.cloud.links.base import BaseGoogleLink
|
26
|
+
from airflow.providers.google.common.deprecated import deprecated
|
28
27
|
|
29
28
|
if TYPE_CHECKING:
|
30
29
|
from airflow.utils.context import Context
|
@@ -48,10 +47,8 @@ AUTOML_MODEL_PREDICT_LINK = (
|
|
48
47
|
|
49
48
|
|
50
49
|
@deprecated(
|
51
|
-
|
52
|
-
|
53
|
-
"Please use `TranslationLegacyDatasetLink` from `airflow/providers/google/cloud/links/translate.py` instead."
|
54
|
-
),
|
50
|
+
planned_removal_date="December 31, 2024",
|
51
|
+
use_instead="TranslationLegacyDatasetLink class from airflow/providers/google/cloud/links/translate.py",
|
55
52
|
category=AirflowProviderDeprecationWarning,
|
56
53
|
)
|
57
54
|
class AutoMLDatasetLink(BaseGoogleLink):
|
@@ -76,10 +73,8 @@ class AutoMLDatasetLink(BaseGoogleLink):
|
|
76
73
|
|
77
74
|
|
78
75
|
@deprecated(
|
79
|
-
|
80
|
-
|
81
|
-
"Please use `TranslationDatasetListLink` from `airflow/providers/google/cloud/links/translate.py` instead."
|
82
|
-
),
|
76
|
+
planned_removal_date="December 31, 2024",
|
77
|
+
use_instead="TranslationDatasetListLink class from airflow/providers/google/cloud/links/translate.py",
|
83
78
|
category=AirflowProviderDeprecationWarning,
|
84
79
|
)
|
85
80
|
class AutoMLDatasetListLink(BaseGoogleLink):
|
@@ -105,10 +100,8 @@ class AutoMLDatasetListLink(BaseGoogleLink):
|
|
105
100
|
|
106
101
|
|
107
102
|
@deprecated(
|
108
|
-
|
109
|
-
|
110
|
-
"Please use `TranslationLegacyModelLink` from `airflow/providers/google/cloud/links/translate.py` instead."
|
111
|
-
),
|
103
|
+
planned_removal_date="December 31, 2024",
|
104
|
+
use_instead="TranslationLegacyModelLink class from airflow/providers/google/cloud/links/translate.py",
|
112
105
|
category=AirflowProviderDeprecationWarning,
|
113
106
|
)
|
114
107
|
class AutoMLModelLink(BaseGoogleLink):
|
@@ -139,10 +132,9 @@ class AutoMLModelLink(BaseGoogleLink):
|
|
139
132
|
|
140
133
|
|
141
134
|
@deprecated(
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
),
|
135
|
+
planned_removal_date="December 31, 2024",
|
136
|
+
use_instead="TranslationLegacyModelTrainLink class from "
|
137
|
+
"airflow/providers/google/cloud/links/translate.py",
|
146
138
|
category=AirflowProviderDeprecationWarning,
|
147
139
|
)
|
148
140
|
class AutoMLModelTrainLink(BaseGoogleLink):
|
@@ -170,10 +162,9 @@ class AutoMLModelTrainLink(BaseGoogleLink):
|
|
170
162
|
|
171
163
|
|
172
164
|
@deprecated(
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
),
|
165
|
+
planned_removal_date="December 31, 2024",
|
166
|
+
use_instead="TranslationLegacyModelPredictLink class from "
|
167
|
+
"airflow/providers/google/cloud/links/translate.py",
|
177
168
|
category=AirflowProviderDeprecationWarning,
|
178
169
|
)
|
179
170
|
class AutoMLModelPredictLink(BaseGoogleLink):
|
@@ -79,14 +79,13 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
|
|
79
79
|
*,
|
80
80
|
base_log_folder: str,
|
81
81
|
gcs_log_folder: str,
|
82
|
-
filename_template: str | None = None,
|
83
82
|
gcp_key_path: str | None = None,
|
84
83
|
gcp_keyfile_dict: dict | None = None,
|
85
84
|
gcp_scopes: Collection[str] | None = _DEFAULT_SCOPESS,
|
86
85
|
project_id: str = PROVIDE_PROJECT_ID,
|
87
86
|
**kwargs,
|
88
87
|
):
|
89
|
-
super().__init__(base_log_folder, filename_template)
|
88
|
+
super().__init__(base_log_folder)
|
90
89
|
self.handler: logging.FileHandler | None = None
|
91
90
|
self.remote_base = gcs_log_folder
|
92
91
|
self.log_relative_path = ""
|
@@ -27,7 +27,6 @@ from functools import cached_property
|
|
27
27
|
from typing import TYPE_CHECKING, Any, Iterable, Sequence, SupportsAbs
|
28
28
|
|
29
29
|
import attr
|
30
|
-
from deprecated import deprecated
|
31
30
|
from google.api_core.exceptions import Conflict
|
32
31
|
from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob, Row
|
33
32
|
from google.cloud.bigquery.table import RowIterator
|
@@ -57,6 +56,7 @@ from airflow.providers.google.cloud.triggers.bigquery import (
|
|
57
56
|
BigQueryValueCheckTrigger,
|
58
57
|
)
|
59
58
|
from airflow.providers.google.cloud.utils.bigquery import convert_job_id
|
59
|
+
from airflow.providers.google.common.deprecated import deprecated
|
60
60
|
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
|
61
61
|
from airflow.utils.helpers import exactly_one
|
62
62
|
|
@@ -1203,7 +1203,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator, _BigQueryOperatorsEncrypt
|
|
1203
1203
|
|
1204
1204
|
|
1205
1205
|
@deprecated(
|
1206
|
-
|
1206
|
+
planned_removal_date="November 01, 2024",
|
1207
|
+
use_instead="BigQueryInsertJobOperator",
|
1207
1208
|
category=AirflowProviderDeprecationWarning,
|
1208
1209
|
)
|
1209
1210
|
class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
|
@@ -1415,7 +1416,7 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
|
|
1415
1416
|
raise AirflowException(f"argument 'sql' of type {type(str)} is neither a string nor an iterable")
|
1416
1417
|
project_id = self.hook.project_id
|
1417
1418
|
if project_id:
|
1418
|
-
job_id_path = convert_job_id(job_id=self.job_id, project_id=project_id, location=self.location)
|
1419
|
+
job_id_path = convert_job_id(job_id=self.job_id, project_id=project_id, location=self.location) # type: ignore[arg-type]
|
1419
1420
|
context["task_instance"].xcom_push(key="job_id_path", value=job_id_path)
|
1420
1421
|
return self.job_id
|
1421
1422
|
|
@@ -2298,7 +2299,8 @@ class BigQueryGetDatasetTablesOperator(GoogleCloudBaseOperator):
|
|
2298
2299
|
|
2299
2300
|
|
2300
2301
|
@deprecated(
|
2301
|
-
|
2302
|
+
planned_removal_date="November 01, 2024",
|
2303
|
+
use_instead="BigQueryUpdateDatasetOperator",
|
2302
2304
|
category=AirflowProviderDeprecationWarning,
|
2303
2305
|
)
|
2304
2306
|
class BigQueryPatchDatasetOperator(GoogleCloudBaseOperator):
|