apache-airflow-providers-google 16.0.0a1__py3-none-any.whl → 16.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +43 -5
- airflow/providers/google/ads/operators/ads.py +1 -1
- airflow/providers/google/ads/transfers/ads_to_gcs.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +63 -77
- airflow/providers/google/cloud/hooks/cloud_sql.py +8 -4
- airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
- airflow/providers/google/cloud/hooks/dataflow.py +2 -2
- airflow/providers/google/cloud/hooks/dataplex.py +1 -1
- airflow/providers/google/cloud/hooks/dataprep.py +4 -1
- airflow/providers/google/cloud/hooks/gcs.py +5 -5
- airflow/providers/google/cloud/hooks/looker.py +10 -1
- airflow/providers/google/cloud/hooks/mlengine.py +2 -1
- airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
- airflow/providers/google/cloud/hooks/spanner.py +2 -2
- airflow/providers/google/cloud/hooks/translate.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -36
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +44 -80
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +11 -2
- airflow/providers/google/cloud/hooks/vision.py +2 -2
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +75 -11
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/cloud_run.py +27 -0
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +16 -43
- airflow/providers/google/cloud/links/cloud_tasks.py +6 -25
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -96
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +14 -90
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +5 -59
- airflow/providers/google/cloud/links/life_sciences.py +0 -19
- airflow/providers/google/cloud/links/managed_kafka.py +0 -70
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +16 -186
- airflow/providers/google/cloud/links/vertex_ai.py +8 -224
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +4 -4
- airflow/providers/google/cloud/operators/alloy_db.py +69 -54
- airflow/providers/google/cloud/operators/automl.py +16 -14
- airflow/providers/google/cloud/operators/bigquery.py +49 -25
- airflow/providers/google/cloud/operators/bigquery_dts.py +2 -4
- airflow/providers/google/cloud/operators/bigtable.py +35 -6
- airflow/providers/google/cloud/operators/cloud_base.py +21 -1
- airflow/providers/google/cloud/operators/cloud_build.py +74 -31
- airflow/providers/google/cloud/operators/cloud_composer.py +34 -35
- airflow/providers/google/cloud/operators/cloud_memorystore.py +68 -42
- airflow/providers/google/cloud/operators/cloud_run.py +9 -1
- airflow/providers/google/cloud/operators/cloud_sql.py +11 -15
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +0 -2
- airflow/providers/google/cloud/operators/compute.py +7 -39
- airflow/providers/google/cloud/operators/datacatalog.py +156 -20
- airflow/providers/google/cloud/operators/dataflow.py +37 -14
- airflow/providers/google/cloud/operators/dataform.py +14 -4
- airflow/providers/google/cloud/operators/datafusion.py +4 -12
- airflow/providers/google/cloud/operators/dataplex.py +180 -96
- airflow/providers/google/cloud/operators/dataprep.py +0 -4
- airflow/providers/google/cloud/operators/dataproc.py +10 -16
- airflow/providers/google/cloud/operators/dataproc_metastore.py +95 -87
- airflow/providers/google/cloud/operators/datastore.py +21 -5
- airflow/providers/google/cloud/operators/dlp.py +3 -26
- airflow/providers/google/cloud/operators/functions.py +15 -6
- airflow/providers/google/cloud/operators/gcs.py +1 -7
- airflow/providers/google/cloud/operators/kubernetes_engine.py +53 -92
- airflow/providers/google/cloud/operators/life_sciences.py +0 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +106 -51
- airflow/providers/google/cloud/operators/mlengine.py +0 -1
- airflow/providers/google/cloud/operators/pubsub.py +4 -5
- airflow/providers/google/cloud/operators/spanner.py +0 -4
- airflow/providers/google/cloud/operators/speech_to_text.py +0 -1
- airflow/providers/google/cloud/operators/stackdriver.py +0 -8
- airflow/providers/google/cloud/operators/tasks.py +0 -11
- airflow/providers/google/cloud/operators/text_to_speech.py +0 -1
- airflow/providers/google/cloud/operators/translate.py +37 -13
- airflow/providers/google/cloud/operators/translate_speech.py +0 -1
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +31 -18
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +28 -8
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +38 -25
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +69 -7
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +42 -8
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +531 -0
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +93 -117
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +10 -8
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +56 -10
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +29 -6
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +9 -6
- airflow/providers/google/cloud/operators/workflows.py +1 -9
- airflow/providers/google/cloud/sensors/bigquery.py +1 -1
- airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -1
- airflow/providers/google/cloud/sensors/bigtable.py +15 -3
- airflow/providers/google/cloud/sensors/cloud_composer.py +6 -1
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -1
- airflow/providers/google/cloud/sensors/dataflow.py +3 -3
- airflow/providers/google/cloud/sensors/dataform.py +6 -1
- airflow/providers/google/cloud/sensors/datafusion.py +6 -1
- airflow/providers/google/cloud/sensors/dataplex.py +6 -1
- airflow/providers/google/cloud/sensors/dataprep.py +6 -1
- airflow/providers/google/cloud/sensors/dataproc.py +6 -1
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +6 -1
- airflow/providers/google/cloud/sensors/gcs.py +9 -3
- airflow/providers/google/cloud/sensors/looker.py +6 -1
- airflow/providers/google/cloud/sensors/pubsub.py +8 -3
- airflow/providers/google/cloud/sensors/tasks.py +6 -1
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +6 -1
- airflow/providers/google/cloud/sensors/workflows.py +6 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +10 -7
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -2
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +0 -1
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -1
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +1 -2
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/gcs_to_local.py +1 -1
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -1
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +5 -1
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +1 -1
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +11 -5
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +1 -1
- airflow/providers/google/cloud/triggers/bigquery.py +32 -5
- airflow/providers/google/cloud/triggers/dataproc.py +62 -10
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/common/auth_backend/google_openid.py +2 -1
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +7 -3
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/firebase/operators/firestore.py +1 -1
- airflow/providers/google/get_provider_info.py +14 -16
- airflow/providers/google/leveldb/hooks/leveldb.py +30 -1
- airflow/providers/google/leveldb/operators/leveldb.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +3 -6
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +0 -1
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +4 -4
- airflow/providers/google/marketing_platform/operators/display_video.py +6 -6
- airflow/providers/google/marketing_platform/operators/search_ads.py +1 -1
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +6 -1
- airflow/providers/google/marketing_platform/sensors/display_video.py +6 -1
- airflow/providers/google/suite/operators/sheets.py +3 -3
- airflow/providers/google/suite/sensors/drive.py +6 -1
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +1 -1
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +1 -1
- airflow/providers/google/version_compat.py +28 -0
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0.dist-info}/METADATA +35 -35
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0.dist-info}/RECORD +171 -170
- airflow/providers/google/cloud/links/automl.py +0 -193
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-16.0.0a1.dist-info → apache_airflow_providers_google-16.1.0.dist-info}/entry_points.txt +0 -0
@@ -20,107 +20,21 @@
|
|
20
20
|
from __future__ import annotations
|
21
21
|
|
22
22
|
from collections.abc import Sequence
|
23
|
-
from typing import TYPE_CHECKING
|
23
|
+
from typing import TYPE_CHECKING, Any, Literal
|
24
24
|
|
25
|
-
from
|
26
|
-
|
25
|
+
from google.api_core import exceptions
|
26
|
+
|
27
|
+
from airflow.exceptions import AirflowException
|
28
|
+
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
|
29
|
+
ExperimentRunHook,
|
30
|
+
GenerativeModelHook,
|
31
|
+
)
|
27
32
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
28
|
-
from airflow.providers.google.common.deprecated import deprecated
|
29
33
|
|
30
34
|
if TYPE_CHECKING:
|
31
35
|
from airflow.utils.context import Context
|
32
36
|
|
33
37
|
|
34
|
-
@deprecated(
|
35
|
-
planned_removal_date="April 09, 2025",
|
36
|
-
use_instead="GenerativeModelGenerateContentOperator",
|
37
|
-
category=AirflowProviderDeprecationWarning,
|
38
|
-
)
|
39
|
-
class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
|
40
|
-
"""
|
41
|
-
Uses the Vertex AI PaLM API to generate natural language text.
|
42
|
-
|
43
|
-
:param project_id: Required. The ID of the Google Cloud project that the
|
44
|
-
service belongs to (templated).
|
45
|
-
:param location: Required. The ID of the Google Cloud location that the
|
46
|
-
service belongs to (templated).
|
47
|
-
:param prompt: Required. Inputs or queries that a user or a program gives
|
48
|
-
to the Vertex AI PaLM API, in order to elicit a specific response (templated).
|
49
|
-
:param pretrained_model: By default uses the pre-trained model `text-bison`,
|
50
|
-
optimized for performing natural language tasks such as classification,
|
51
|
-
summarization, extraction, content creation, and ideation.
|
52
|
-
:param temperature: Temperature controls the degree of randomness in token
|
53
|
-
selection. Defaults to 0.0.
|
54
|
-
:param max_output_tokens: Token limit determines the maximum amount of text
|
55
|
-
output. Defaults to 256.
|
56
|
-
:param top_p: Tokens are selected from most probable to least until the sum
|
57
|
-
of their probabilities equals the top_p value. Defaults to 0.8.
|
58
|
-
:param top_k: A top_k of 1 means the selected token is the most probable
|
59
|
-
among all tokens. Defaults to 0.4.
|
60
|
-
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
61
|
-
:param impersonation_chain: Optional service account to impersonate using short-term
|
62
|
-
credentials, or chained list of accounts required to get the access_token
|
63
|
-
of the last account in the list, which will be impersonated in the request.
|
64
|
-
If set as a string, the account must grant the originating account
|
65
|
-
the Service Account Token Creator IAM role.
|
66
|
-
If set as a sequence, the identities from the list must grant
|
67
|
-
Service Account Token Creator IAM role to the directly preceding identity, with first
|
68
|
-
account from the list granting this role to the originating account (templated).
|
69
|
-
"""
|
70
|
-
|
71
|
-
template_fields = ("location", "project_id", "impersonation_chain", "prompt")
|
72
|
-
|
73
|
-
def __init__(
|
74
|
-
self,
|
75
|
-
*,
|
76
|
-
project_id: str,
|
77
|
-
location: str,
|
78
|
-
prompt: str,
|
79
|
-
pretrained_model: str = "text-bison",
|
80
|
-
temperature: float = 0.0,
|
81
|
-
max_output_tokens: int = 256,
|
82
|
-
top_p: float = 0.8,
|
83
|
-
top_k: int = 40,
|
84
|
-
gcp_conn_id: str = "google_cloud_default",
|
85
|
-
impersonation_chain: str | Sequence[str] | None = None,
|
86
|
-
**kwargs,
|
87
|
-
) -> None:
|
88
|
-
super().__init__(**kwargs)
|
89
|
-
self.project_id = project_id
|
90
|
-
self.location = location
|
91
|
-
self.prompt = prompt
|
92
|
-
self.pretrained_model = pretrained_model
|
93
|
-
self.temperature = temperature
|
94
|
-
self.max_output_tokens = max_output_tokens
|
95
|
-
self.top_p = top_p
|
96
|
-
self.top_k = top_k
|
97
|
-
self.gcp_conn_id = gcp_conn_id
|
98
|
-
self.impersonation_chain = impersonation_chain
|
99
|
-
|
100
|
-
def execute(self, context: Context):
|
101
|
-
self.hook = GenerativeModelHook(
|
102
|
-
gcp_conn_id=self.gcp_conn_id,
|
103
|
-
impersonation_chain=self.impersonation_chain,
|
104
|
-
)
|
105
|
-
|
106
|
-
self.log.info("Submitting prompt")
|
107
|
-
response = self.hook.text_generation_model_predict(
|
108
|
-
project_id=self.project_id,
|
109
|
-
location=self.location,
|
110
|
-
prompt=self.prompt,
|
111
|
-
pretrained_model=self.pretrained_model,
|
112
|
-
temperature=self.temperature,
|
113
|
-
max_output_tokens=self.max_output_tokens,
|
114
|
-
top_p=self.top_p,
|
115
|
-
top_k=self.top_k,
|
116
|
-
)
|
117
|
-
|
118
|
-
self.log.info("Model response: %s", response)
|
119
|
-
self.xcom_push(context, key="model_response", value=response)
|
120
|
-
|
121
|
-
return response
|
122
|
-
|
123
|
-
|
124
38
|
class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
125
39
|
"""
|
126
40
|
Uses the Vertex AI Embeddings API to generate embeddings based on prompt.
|
@@ -130,9 +44,8 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
130
44
|
:param location: Required. The ID of the Google Cloud location that the
|
131
45
|
service belongs to (templated).
|
132
46
|
:param prompt: Required. Inputs or queries that a user or a program gives
|
133
|
-
to the Vertex AI
|
134
|
-
:param pretrained_model:
|
135
|
-
optimized for performing text embeddings.
|
47
|
+
to the Vertex AI Generative Model API, in order to elicit a specific response (templated).
|
48
|
+
:param pretrained_model: Required. Model, optimized for performing text embeddings.
|
136
49
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
137
50
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
138
51
|
credentials, or chained list of accounts required to get the access_token
|
@@ -152,7 +65,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
152
65
|
project_id: str,
|
153
66
|
location: str,
|
154
67
|
prompt: str,
|
155
|
-
pretrained_model: str
|
68
|
+
pretrained_model: str,
|
156
69
|
gcp_conn_id: str = "google_cloud_default",
|
157
70
|
impersonation_chain: str | Sequence[str] | None = None,
|
158
71
|
**kwargs,
|
@@ -180,7 +93,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
180
93
|
)
|
181
94
|
|
182
95
|
self.log.info("Model response: %s", response)
|
183
|
-
|
96
|
+
context["ti"].xcom_push(key="model_response", value=response)
|
184
97
|
|
185
98
|
return response
|
186
99
|
|
@@ -199,10 +112,9 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
199
112
|
:param safety_settings: Optional. Per request settings for blocking unsafe content.
|
200
113
|
:param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
|
201
114
|
:param system_instruction: Optional. An instruction given to the model to guide its behavior.
|
202
|
-
:param pretrained_model:
|
203
|
-
|
204
|
-
|
205
|
-
output text and code.
|
115
|
+
:param pretrained_model: Required. The name of the model to use for content generation,
|
116
|
+
which can be a text-only or multimodal model. For example, `gemini-pro` or
|
117
|
+
`gemini-pro-vision`.
|
206
118
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
207
119
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
208
120
|
credentials, or chained list of accounts required to get the access_token
|
@@ -226,7 +138,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
226
138
|
generation_config: dict | None = None,
|
227
139
|
safety_settings: dict | None = None,
|
228
140
|
system_instruction: str | None = None,
|
229
|
-
pretrained_model: str
|
141
|
+
pretrained_model: str,
|
230
142
|
gcp_conn_id: str = "google_cloud_default",
|
231
143
|
impersonation_chain: str | Sequence[str] | None = None,
|
232
144
|
**kwargs,
|
@@ -260,7 +172,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
260
172
|
)
|
261
173
|
|
262
174
|
self.log.info("Model response: %s", response)
|
263
|
-
|
175
|
+
context["ti"].xcom_push(key="model_response", value=response)
|
264
176
|
|
265
177
|
return response
|
266
178
|
|
@@ -310,7 +222,7 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
310
222
|
tuned_model_display_name: str | None = None,
|
311
223
|
validation_dataset: str | None = None,
|
312
224
|
epochs: int | None = None,
|
313
|
-
adapter_size:
|
225
|
+
adapter_size: Literal[1, 4, 8, 16] | None = None,
|
314
226
|
learning_rate_multiplier: float | None = None,
|
315
227
|
gcp_conn_id: str = "google_cloud_default",
|
316
228
|
impersonation_chain: str | Sequence[str] | None = None,
|
@@ -349,8 +261,8 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
349
261
|
self.log.info("Tuned Model Name: %s", response.tuned_model_name)
|
350
262
|
self.log.info("Tuned Model Endpoint Name: %s", response.tuned_model_endpoint_name)
|
351
263
|
|
352
|
-
|
353
|
-
|
264
|
+
context["ti"].xcom_push(key="tuned_model_name", value=response.tuned_model_name)
|
265
|
+
context["ti"].xcom_push(key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
|
354
266
|
|
355
267
|
result = {
|
356
268
|
"tuned_model_name": response.tuned_model_name,
|
@@ -370,10 +282,9 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
370
282
|
service belongs to (templated).
|
371
283
|
:param contents: Required. The multi-part content of a message that a user or a program
|
372
284
|
gives to the generative model, in order to elicit a specific response.
|
373
|
-
:param pretrained_model:
|
374
|
-
|
375
|
-
|
376
|
-
output text and code.
|
285
|
+
:param pretrained_model: Required. Model, supporting prompts with text-only input,
|
286
|
+
including natural language tasks, multi-turn text and code chat,
|
287
|
+
and code generation. It can output text and code.
|
377
288
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
378
289
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
379
290
|
credentials, or chained list of accounts required to get the access_token
|
@@ -393,7 +304,7 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
393
304
|
project_id: str,
|
394
305
|
location: str,
|
395
306
|
contents: list,
|
396
|
-
pretrained_model: str
|
307
|
+
pretrained_model: str,
|
397
308
|
gcp_conn_id: str = "google_cloud_default",
|
398
309
|
impersonation_chain: str | Sequence[str] | None = None,
|
399
310
|
**kwargs,
|
@@ -421,8 +332,8 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
421
332
|
self.log.info("Total tokens: %s", response.total_tokens)
|
422
333
|
self.log.info("Total billable characters: %s", response.total_billable_characters)
|
423
334
|
|
424
|
-
|
425
|
-
|
335
|
+
context["ti"].xcom_push(key="total_tokens", value=response.total_tokens)
|
336
|
+
context["ti"].xcom_push(key="total_billable_characters", value=response.total_billable_characters)
|
426
337
|
|
427
338
|
|
428
339
|
class RunEvaluationOperator(GoogleCloudBaseOperator):
|
@@ -562,8 +473,8 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
|
|
562
473
|
project_id: str,
|
563
474
|
location: str,
|
564
475
|
model_name: str,
|
565
|
-
system_instruction:
|
566
|
-
contents: list | None = None,
|
476
|
+
system_instruction: Any | None = None,
|
477
|
+
contents: list[Any] | None = None,
|
567
478
|
ttl_hours: float = 1,
|
568
479
|
display_name: str | None = None,
|
569
480
|
gcp_conn_id: str = "google_cloud_default",
|
@@ -674,3 +585,68 @@ class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
|
|
674
585
|
self.log.info("Cached Content Response: %s", cached_content_text)
|
675
586
|
|
676
587
|
return cached_content_text
|
588
|
+
|
589
|
+
|
590
|
+
class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
|
591
|
+
"""
|
592
|
+
Use the Rapid Evaluation API to evaluate a model.
|
593
|
+
|
594
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
595
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
596
|
+
:param experiment_name: Required. The name of the evaluation experiment.
|
597
|
+
:param experiment_run_name: Required. The specific run name or ID for this experiment.
|
598
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
599
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
600
|
+
credentials, or chained list of accounts required to get the access_token
|
601
|
+
of the last account in the list, which will be impersonated in the request.
|
602
|
+
If set as a string, the account must grant the originating account
|
603
|
+
the Service Account Token Creator IAM role.
|
604
|
+
If set as a sequence, the identities from the list must grant
|
605
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
606
|
+
account from the list granting this role to the originating account (templated).
|
607
|
+
"""
|
608
|
+
|
609
|
+
template_fields = (
|
610
|
+
"location",
|
611
|
+
"project_id",
|
612
|
+
"impersonation_chain",
|
613
|
+
"experiment_name",
|
614
|
+
"experiment_run_name",
|
615
|
+
)
|
616
|
+
|
617
|
+
def __init__(
|
618
|
+
self,
|
619
|
+
*,
|
620
|
+
project_id: str,
|
621
|
+
location: str,
|
622
|
+
experiment_name: str,
|
623
|
+
experiment_run_name: str,
|
624
|
+
gcp_conn_id: str = "google_cloud_default",
|
625
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
626
|
+
**kwargs,
|
627
|
+
) -> None:
|
628
|
+
super().__init__(**kwargs)
|
629
|
+
self.project_id = project_id
|
630
|
+
self.location = location
|
631
|
+
self.experiment_name = experiment_name
|
632
|
+
self.experiment_run_name = experiment_run_name
|
633
|
+
self.gcp_conn_id = gcp_conn_id
|
634
|
+
self.impersonation_chain = impersonation_chain
|
635
|
+
|
636
|
+
def execute(self, context: Context) -> None:
|
637
|
+
self.hook = ExperimentRunHook(
|
638
|
+
gcp_conn_id=self.gcp_conn_id,
|
639
|
+
impersonation_chain=self.impersonation_chain,
|
640
|
+
)
|
641
|
+
|
642
|
+
try:
|
643
|
+
self.hook.delete_experiment_run(
|
644
|
+
project_id=self.project_id,
|
645
|
+
location=self.location,
|
646
|
+
experiment_name=self.experiment_name,
|
647
|
+
experiment_run_name=self.experiment_run_name,
|
648
|
+
)
|
649
|
+
except exceptions.NotFound:
|
650
|
+
raise AirflowException(f"Experiment Run with name {self.experiment_run_name} not found")
|
651
|
+
|
652
|
+
self.log.info("Deleted experiment run: %s", self.experiment_run_name)
|
@@ -257,10 +257,8 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
257
257
|
hyperparameter_tuning_job_id = hyperparameter_tuning_job.name
|
258
258
|
self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id)
|
259
259
|
|
260
|
-
|
261
|
-
VertexAITrainingLink.persist(
|
262
|
-
context=context, task_instance=self, training_id=hyperparameter_tuning_job_id
|
263
|
-
)
|
260
|
+
context["ti"].xcom_push(key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
|
261
|
+
VertexAITrainingLink.persist(context=context, training_id=hyperparameter_tuning_job_id)
|
264
262
|
|
265
263
|
if self.deferrable:
|
266
264
|
self.defer(
|
@@ -355,9 +353,7 @@ class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
355
353
|
timeout=self.timeout,
|
356
354
|
metadata=self.metadata,
|
357
355
|
)
|
358
|
-
VertexAITrainingLink.persist(
|
359
|
-
context=context, task_instance=self, training_id=self.hyperparameter_tuning_job_id
|
360
|
-
)
|
356
|
+
VertexAITrainingLink.persist(context=context, training_id=self.hyperparameter_tuning_job_id)
|
361
357
|
self.log.info("Hyperparameter tuning job was gotten.")
|
362
358
|
return types.HyperparameterTuningJob.to_dict(result)
|
363
359
|
except NotFound:
|
@@ -487,6 +483,12 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
487
483
|
self.gcp_conn_id = gcp_conn_id
|
488
484
|
self.impersonation_chain = impersonation_chain
|
489
485
|
|
486
|
+
@property
|
487
|
+
def extra_links_params(self) -> dict[str, Any]:
|
488
|
+
return {
|
489
|
+
"project_id": self.project_id,
|
490
|
+
}
|
491
|
+
|
490
492
|
def execute(self, context: Context):
|
491
493
|
hook = HyperparameterTuningJobHook(
|
492
494
|
gcp_conn_id=self.gcp_conn_id,
|
@@ -503,5 +505,5 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
503
505
|
timeout=self.timeout,
|
504
506
|
metadata=self.metadata,
|
505
507
|
)
|
506
|
-
VertexAIHyperparameterTuningJobListLink.persist(context=context
|
508
|
+
VertexAIHyperparameterTuningJobListLink.persist(context=context)
|
507
509
|
return [types.HyperparameterTuningJob.to_dict(result) for result in results]
|
@@ -20,7 +20,7 @@
|
|
20
20
|
from __future__ import annotations
|
21
21
|
|
22
22
|
from collections.abc import Sequence
|
23
|
-
from typing import TYPE_CHECKING
|
23
|
+
from typing import TYPE_CHECKING, Any
|
24
24
|
|
25
25
|
from google.api_core.exceptions import NotFound
|
26
26
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
@@ -161,6 +161,13 @@ class GetModelOperator(GoogleCloudBaseOperator):
|
|
161
161
|
self.gcp_conn_id = gcp_conn_id
|
162
162
|
self.impersonation_chain = impersonation_chain
|
163
163
|
|
164
|
+
@property
|
165
|
+
def extra_links_params(self) -> dict[str, Any]:
|
166
|
+
return {
|
167
|
+
"region": self.region,
|
168
|
+
"project_id": self.project_id,
|
169
|
+
}
|
170
|
+
|
164
171
|
def execute(self, context: Context):
|
165
172
|
hook = ModelServiceHook(
|
166
173
|
gcp_conn_id=self.gcp_conn_id,
|
@@ -179,8 +186,8 @@ class GetModelOperator(GoogleCloudBaseOperator):
|
|
179
186
|
)
|
180
187
|
self.log.info("Model found. Model ID: %s", self.model_id)
|
181
188
|
|
182
|
-
|
183
|
-
VertexAIModelLink.persist(context=context,
|
189
|
+
context["ti"].xcom_push(key="model_id", value=self.model_id)
|
190
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
184
191
|
return Model.to_dict(model)
|
185
192
|
except NotFound:
|
186
193
|
self.log.info("The Model ID %s does not exist.", self.model_id)
|
@@ -257,7 +264,12 @@ class ExportModelOperator(GoogleCloudBaseOperator):
|
|
257
264
|
metadata=self.metadata,
|
258
265
|
)
|
259
266
|
hook.wait_for_operation(timeout=self.timeout, operation=operation)
|
260
|
-
VertexAIModelExportLink.persist(
|
267
|
+
VertexAIModelExportLink.persist(
|
268
|
+
context=context,
|
269
|
+
output_config=self.output_config,
|
270
|
+
model_id=self.model_id,
|
271
|
+
project_id=self.project_id,
|
272
|
+
)
|
261
273
|
self.log.info("Model was exported.")
|
262
274
|
except NotFound:
|
263
275
|
self.log.info("The Model ID %s does not exist.", self.model_id)
|
@@ -335,6 +347,12 @@ class ListModelsOperator(GoogleCloudBaseOperator):
|
|
335
347
|
self.gcp_conn_id = gcp_conn_id
|
336
348
|
self.impersonation_chain = impersonation_chain
|
337
349
|
|
350
|
+
@property
|
351
|
+
def extra_links_params(self) -> dict[str, Any]:
|
352
|
+
return {
|
353
|
+
"project_id": self.project_id,
|
354
|
+
}
|
355
|
+
|
338
356
|
def execute(self, context: Context):
|
339
357
|
hook = ModelServiceHook(
|
340
358
|
gcp_conn_id=self.gcp_conn_id,
|
@@ -352,7 +370,7 @@ class ListModelsOperator(GoogleCloudBaseOperator):
|
|
352
370
|
timeout=self.timeout,
|
353
371
|
metadata=self.metadata,
|
354
372
|
)
|
355
|
-
VertexAIModelListLink.persist(context=context
|
373
|
+
VertexAIModelListLink.persist(context=context)
|
356
374
|
return [Model.to_dict(result) for result in results]
|
357
375
|
|
358
376
|
|
@@ -407,6 +425,13 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
407
425
|
self.gcp_conn_id = gcp_conn_id
|
408
426
|
self.impersonation_chain = impersonation_chain
|
409
427
|
|
428
|
+
@property
|
429
|
+
def extra_links_params(self) -> dict[str, Any]:
|
430
|
+
return {
|
431
|
+
"region": self.region,
|
432
|
+
"project_id": self.project_id,
|
433
|
+
}
|
434
|
+
|
410
435
|
def execute(self, context: Context):
|
411
436
|
hook = ModelServiceHook(
|
412
437
|
gcp_conn_id=self.gcp_conn_id,
|
@@ -428,8 +453,8 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
428
453
|
model_id = hook.extract_model_id(model_resp)
|
429
454
|
self.log.info("Model was uploaded. Model ID: %s", model_id)
|
430
455
|
|
431
|
-
|
432
|
-
VertexAIModelLink.persist(context=context,
|
456
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
457
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
433
458
|
return model_resp
|
434
459
|
|
435
460
|
|
@@ -553,6 +578,13 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
|
|
553
578
|
self.gcp_conn_id = gcp_conn_id
|
554
579
|
self.impersonation_chain = impersonation_chain
|
555
580
|
|
581
|
+
@property
|
582
|
+
def extra_links_params(self) -> dict[str, Any]:
|
583
|
+
return {
|
584
|
+
"region": self.region,
|
585
|
+
"project_id": self.project_id,
|
586
|
+
}
|
587
|
+
|
556
588
|
def execute(self, context: Context):
|
557
589
|
hook = ModelServiceHook(
|
558
590
|
gcp_conn_id=self.gcp_conn_id,
|
@@ -571,7 +603,7 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
|
|
571
603
|
timeout=self.timeout,
|
572
604
|
metadata=self.metadata,
|
573
605
|
)
|
574
|
-
VertexAIModelLink.persist(context=context,
|
606
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
575
607
|
return Model.to_dict(updated_model)
|
576
608
|
|
577
609
|
|
@@ -627,6 +659,13 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
627
659
|
self.gcp_conn_id = gcp_conn_id
|
628
660
|
self.impersonation_chain = impersonation_chain
|
629
661
|
|
662
|
+
@property
|
663
|
+
def extra_links_params(self) -> dict[str, Any]:
|
664
|
+
return {
|
665
|
+
"region": self.region,
|
666
|
+
"project_id": self.project_id,
|
667
|
+
}
|
668
|
+
|
630
669
|
def execute(self, context: Context):
|
631
670
|
hook = ModelServiceHook(
|
632
671
|
gcp_conn_id=self.gcp_conn_id,
|
@@ -645,7 +684,7 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
645
684
|
timeout=self.timeout,
|
646
685
|
metadata=self.metadata,
|
647
686
|
)
|
648
|
-
VertexAIModelLink.persist(context=context,
|
687
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
649
688
|
return Model.to_dict(updated_model)
|
650
689
|
|
651
690
|
|
@@ -701,6 +740,13 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
701
740
|
self.gcp_conn_id = gcp_conn_id
|
702
741
|
self.impersonation_chain = impersonation_chain
|
703
742
|
|
743
|
+
@property
|
744
|
+
def extra_links_params(self) -> dict[str, Any]:
|
745
|
+
return {
|
746
|
+
"region": self.region,
|
747
|
+
"project_id": self.project_id,
|
748
|
+
}
|
749
|
+
|
704
750
|
def execute(self, context: Context):
|
705
751
|
hook = ModelServiceHook(
|
706
752
|
gcp_conn_id=self.gcp_conn_id,
|
@@ -721,7 +767,7 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
721
767
|
timeout=self.timeout,
|
722
768
|
metadata=self.metadata,
|
723
769
|
)
|
724
|
-
VertexAIModelLink.persist(context=context,
|
770
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
725
771
|
return Model.to_dict(updated_model)
|
726
772
|
|
727
773
|
|
@@ -112,6 +112,10 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
|
|
112
112
|
"project_id",
|
113
113
|
"input_artifacts",
|
114
114
|
"impersonation_chain",
|
115
|
+
"template_path",
|
116
|
+
"pipeline_root",
|
117
|
+
"parameter_values",
|
118
|
+
"service_account",
|
115
119
|
]
|
116
120
|
operator_extra_links = (VertexAIPipelineJobLink(),)
|
117
121
|
|
@@ -162,6 +166,13 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
|
|
162
166
|
self.deferrable = deferrable
|
163
167
|
self.poll_interval = poll_interval
|
164
168
|
|
169
|
+
@property
|
170
|
+
def extra_links_params(self) -> dict[str, Any]:
|
171
|
+
return {
|
172
|
+
"region": self.region,
|
173
|
+
"project_id": self.project_id,
|
174
|
+
}
|
175
|
+
|
165
176
|
def execute(self, context: Context):
|
166
177
|
self.log.info("Running Pipeline job")
|
167
178
|
pipeline_job_obj: PipelineJob = self.hook.submit_pipeline_job(
|
@@ -184,8 +195,8 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
|
|
184
195
|
)
|
185
196
|
pipeline_job_id = pipeline_job_obj.job_id
|
186
197
|
self.log.info("Pipeline job was created. Job id: %s", pipeline_job_id)
|
187
|
-
|
188
|
-
VertexAIPipelineJobLink.persist(context=context,
|
198
|
+
context["ti"].xcom_push(key="pipeline_job_id", value=pipeline_job_id)
|
199
|
+
VertexAIPipelineJobLink.persist(context=context, pipeline_id=pipeline_job_id)
|
189
200
|
|
190
201
|
if self.deferrable:
|
191
202
|
pipeline_job_obj.wait_for_resource_creation()
|
@@ -276,6 +287,13 @@ class GetPipelineJobOperator(GoogleCloudBaseOperator):
|
|
276
287
|
self.gcp_conn_id = gcp_conn_id
|
277
288
|
self.impersonation_chain = impersonation_chain
|
278
289
|
|
290
|
+
@property
|
291
|
+
def extra_links_params(self) -> dict[str, Any]:
|
292
|
+
return {
|
293
|
+
"region": self.region,
|
294
|
+
"project_id": self.project_id,
|
295
|
+
}
|
296
|
+
|
279
297
|
def execute(self, context: Context):
|
280
298
|
hook = PipelineJobHook(
|
281
299
|
gcp_conn_id=self.gcp_conn_id,
|
@@ -292,9 +310,7 @@ class GetPipelineJobOperator(GoogleCloudBaseOperator):
|
|
292
310
|
timeout=self.timeout,
|
293
311
|
metadata=self.metadata,
|
294
312
|
)
|
295
|
-
VertexAIPipelineJobLink.persist(
|
296
|
-
context=context, task_instance=self, pipeline_id=self.pipeline_job_id
|
297
|
-
)
|
313
|
+
VertexAIPipelineJobLink.persist(context=context, pipeline_id=self.pipeline_job_id)
|
298
314
|
self.log.info("Pipeline job was gotten.")
|
299
315
|
return types.PipelineJob.to_dict(result)
|
300
316
|
except NotFound:
|
@@ -408,6 +424,13 @@ class ListPipelineJobOperator(GoogleCloudBaseOperator):
|
|
408
424
|
self.gcp_conn_id = gcp_conn_id
|
409
425
|
self.impersonation_chain = impersonation_chain
|
410
426
|
|
427
|
+
@property
|
428
|
+
def extra_links_params(self) -> dict[str, Any]:
|
429
|
+
return {
|
430
|
+
"region": self.region,
|
431
|
+
"project_id": self.project_id,
|
432
|
+
}
|
433
|
+
|
411
434
|
def execute(self, context: Context):
|
412
435
|
hook = PipelineJobHook(
|
413
436
|
gcp_conn_id=self.gcp_conn_id,
|
@@ -424,7 +447,7 @@ class ListPipelineJobOperator(GoogleCloudBaseOperator):
|
|
424
447
|
timeout=self.timeout,
|
425
448
|
metadata=self.metadata,
|
426
449
|
)
|
427
|
-
VertexAIPipelineJobListLink.persist(context=context
|
450
|
+
VertexAIPipelineJobListLink.persist(context=context)
|
428
451
|
return [types.PipelineJob.to_dict(result) for result in results]
|
429
452
|
|
430
453
|
|
@@ -188,12 +188,13 @@ class CreateRayClusterOperator(RayBaseOperator):
|
|
188
188
|
labels=self.labels,
|
189
189
|
)
|
190
190
|
cluster_id = self.hook.extract_cluster_id(cluster_path)
|
191
|
-
|
192
|
-
context=context,
|
191
|
+
context["ti"].xcom_push(
|
193
192
|
key="cluster_id",
|
194
193
|
value=cluster_id,
|
195
194
|
)
|
196
|
-
VertexAIRayClusterLink.persist(
|
195
|
+
VertexAIRayClusterLink.persist(
|
196
|
+
context=context, location=self.location, cluster_id=cluster_id, project_id=self.project_id
|
197
|
+
)
|
197
198
|
self.log.info("Ray cluster was created.")
|
198
199
|
except Exception as error:
|
199
200
|
raise AirflowException(error)
|
@@ -220,7 +221,7 @@ class ListRayClustersOperator(RayBaseOperator):
|
|
220
221
|
operator_extra_links = (VertexAIRayClusterListLink(),)
|
221
222
|
|
222
223
|
def execute(self, context: Context):
|
223
|
-
VertexAIRayClusterListLink.persist(context=context,
|
224
|
+
VertexAIRayClusterListLink.persist(context=context, project_id=self.project_id)
|
224
225
|
self.log.info("Listing Clusters from location %s.", self.location)
|
225
226
|
try:
|
226
227
|
ray_cluster_list = self.hook.list_ray_clusters(
|
@@ -268,8 +269,9 @@ class GetRayClusterOperator(RayBaseOperator):
|
|
268
269
|
def execute(self, context: Context):
|
269
270
|
VertexAIRayClusterLink.persist(
|
270
271
|
context=context,
|
271
|
-
|
272
|
+
location=self.location,
|
272
273
|
cluster_id=self.cluster_id,
|
274
|
+
project_id=self.project_id,
|
273
275
|
)
|
274
276
|
self.log.info("Getting Cluster: %s", self.cluster_id)
|
275
277
|
try:
|
@@ -325,8 +327,9 @@ class UpdateRayClusterOperator(RayBaseOperator):
|
|
325
327
|
def execute(self, context: Context):
|
326
328
|
VertexAIRayClusterLink.persist(
|
327
329
|
context=context,
|
328
|
-
|
330
|
+
location=self.location,
|
329
331
|
cluster_id=self.cluster_id,
|
332
|
+
project_id=self.project_id,
|
330
333
|
)
|
331
334
|
self.log.info("Updating a Ray cluster.")
|
332
335
|
try:
|