apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/3rd-party-licenses/NOTICE +2 -12
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/ads/hooks/ads.py +39 -5
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -2
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/bundles/__init__.py +16 -0
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/bigquery.py +166 -281
- airflow/providers/google/cloud/hooks/cloud_composer.py +287 -14
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_run.py +17 -9
- airflow/providers/google/cloud/hooks/cloud_sql.py +101 -22
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +27 -6
- airflow/providers/google/cloud/hooks/compute_ssh.py +5 -1
- airflow/providers/google/cloud/hooks/datacatalog.py +9 -1
- airflow/providers/google/cloud/hooks/dataflow.py +71 -94
- airflow/providers/google/cloud/hooks/datafusion.py +1 -1
- airflow/providers/google/cloud/hooks/dataplex.py +1 -1
- airflow/providers/google/cloud/hooks/dataprep.py +1 -1
- airflow/providers/google/cloud/hooks/dataproc.py +72 -71
- airflow/providers/google/cloud/hooks/gcs.py +111 -14
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +2 -2
- airflow/providers/google/cloud/hooks/looker.py +6 -1
- airflow/providers/google/cloud/hooks/mlengine.py +3 -2
- airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
- airflow/providers/google/cloud/hooks/spanner.py +73 -8
- airflow/providers/google/cloud/hooks/stackdriver.py +10 -8
- airflow/providers/google/cloud/hooks/translate.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +0 -209
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +2 -2
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +27 -1
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +307 -7
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/vision.py +2 -2
- airflow/providers/google/cloud/hooks/workflows.py +1 -1
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -13
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -44
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -96
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -95
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +9 -60
- airflow/providers/google/cloud/links/managed_kafka.py +0 -70
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +17 -9
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +9 -6
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +102 -1
- airflow/providers/google/cloud/openlineage/mixins.py +10 -8
- airflow/providers/google/cloud/openlineage/utils.py +15 -1
- airflow/providers/google/cloud/operators/alloy_db.py +70 -55
- airflow/providers/google/cloud/operators/bigquery.py +73 -636
- airflow/providers/google/cloud/operators/bigquery_dts.py +3 -5
- airflow/providers/google/cloud/operators/bigtable.py +36 -7
- airflow/providers/google/cloud/operators/cloud_base.py +21 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +2 -2
- airflow/providers/google/cloud/operators/cloud_build.py +75 -32
- airflow/providers/google/cloud/operators/cloud_composer.py +128 -40
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +69 -43
- airflow/providers/google/cloud/operators/cloud_run.py +23 -5
- airflow/providers/google/cloud/operators/cloud_sql.py +8 -16
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -11
- airflow/providers/google/cloud/operators/compute.py +8 -40
- airflow/providers/google/cloud/operators/datacatalog.py +157 -21
- airflow/providers/google/cloud/operators/dataflow.py +38 -15
- airflow/providers/google/cloud/operators/dataform.py +15 -5
- airflow/providers/google/cloud/operators/datafusion.py +41 -20
- airflow/providers/google/cloud/operators/dataplex.py +193 -109
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +78 -35
- airflow/providers/google/cloud/operators/dataproc_metastore.py +96 -88
- airflow/providers/google/cloud/operators/datastore.py +22 -6
- airflow/providers/google/cloud/operators/dlp.py +6 -29
- airflow/providers/google/cloud/operators/functions.py +16 -7
- airflow/providers/google/cloud/operators/gcs.py +10 -8
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +60 -99
- airflow/providers/google/cloud/operators/looker.py +1 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +107 -52
- airflow/providers/google/cloud/operators/natural_language.py +1 -1
- airflow/providers/google/cloud/operators/pubsub.py +60 -14
- airflow/providers/google/cloud/operators/spanner.py +25 -12
- airflow/providers/google/cloud/operators/speech_to_text.py +1 -2
- airflow/providers/google/cloud/operators/stackdriver.py +1 -9
- airflow/providers/google/cloud/operators/tasks.py +1 -12
- airflow/providers/google/cloud/operators/text_to_speech.py +1 -2
- airflow/providers/google/cloud/operators/translate.py +40 -16
- airflow/providers/google/cloud/operators/translate_speech.py +1 -2
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +39 -19
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +29 -9
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +54 -26
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +70 -8
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +43 -9
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +11 -9
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +57 -11
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +30 -7
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +1 -1
- airflow/providers/google/cloud/operators/vision.py +2 -2
- airflow/providers/google/cloud/operators/workflows.py +18 -15
- airflow/providers/google/cloud/sensors/bigquery.py +2 -2
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -2
- airflow/providers/google/cloud/sensors/bigtable.py +11 -4
- airflow/providers/google/cloud/sensors/cloud_composer.py +533 -29
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -2
- airflow/providers/google/cloud/sensors/dataflow.py +26 -9
- airflow/providers/google/cloud/sensors/dataform.py +2 -2
- airflow/providers/google/cloud/sensors/datafusion.py +4 -4
- airflow/providers/google/cloud/sensors/dataplex.py +2 -2
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +2 -2
- airflow/providers/google/cloud/sensors/gcs.py +4 -4
- airflow/providers/google/cloud/sensors/looker.py +2 -2
- airflow/providers/google/cloud/sensors/pubsub.py +4 -4
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
- airflow/providers/google/cloud/sensors/workflows.py +2 -2
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +4 -4
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +20 -12
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +13 -4
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +75 -34
- airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_composer.py +302 -46
- airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +91 -1
- airflow/providers/google/cloud/triggers/dataflow.py +122 -0
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +14 -2
- airflow/providers/google/cloud/triggers/dataproc.py +122 -52
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +45 -27
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +15 -19
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +1 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/common/auth_backend/google_openid.py +4 -4
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +27 -8
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +3 -4
- airflow/providers/google/firebase/operators/firestore.py +2 -2
- airflow/providers/google/get_provider_info.py +56 -52
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +26 -1
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/links/analytics_admin.py +5 -14
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +1 -2
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +3 -63
- airflow/providers/google/suite/hooks/calendar.py +1 -1
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +92 -48
- apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
- apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/hooks/automl.py +0 -673
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1362
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -112
- apache_airflow_providers_google-15.1.0rc1.dist-info/RECORD +0 -321
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
|
@@ -20,107 +20,27 @@
|
|
|
20
20
|
from __future__ import annotations
|
|
21
21
|
|
|
22
22
|
from collections.abc import Sequence
|
|
23
|
-
from typing import TYPE_CHECKING
|
|
23
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
24
24
|
|
|
25
|
-
from
|
|
26
|
-
|
|
25
|
+
from google.api_core import exceptions
|
|
26
|
+
|
|
27
|
+
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
|
28
|
+
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
|
|
29
|
+
ExperimentRunHook,
|
|
30
|
+
GenerativeModelHook,
|
|
31
|
+
)
|
|
27
32
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
|
28
33
|
from airflow.providers.google.common.deprecated import deprecated
|
|
29
34
|
|
|
30
35
|
if TYPE_CHECKING:
|
|
31
|
-
from airflow.
|
|
36
|
+
from airflow.providers.common.compat.sdk import Context
|
|
32
37
|
|
|
33
38
|
|
|
34
39
|
@deprecated(
|
|
35
|
-
planned_removal_date="
|
|
36
|
-
use_instead="
|
|
40
|
+
planned_removal_date="January 3, 2026",
|
|
41
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateEmbeddingsOperator",
|
|
37
42
|
category=AirflowProviderDeprecationWarning,
|
|
38
43
|
)
|
|
39
|
-
class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
|
|
40
|
-
"""
|
|
41
|
-
Uses the Vertex AI PaLM API to generate natural language text.
|
|
42
|
-
|
|
43
|
-
:param project_id: Required. The ID of the Google Cloud project that the
|
|
44
|
-
service belongs to (templated).
|
|
45
|
-
:param location: Required. The ID of the Google Cloud location that the
|
|
46
|
-
service belongs to (templated).
|
|
47
|
-
:param prompt: Required. Inputs or queries that a user or a program gives
|
|
48
|
-
to the Vertex AI PaLM API, in order to elicit a specific response (templated).
|
|
49
|
-
:param pretrained_model: By default uses the pre-trained model `text-bison`,
|
|
50
|
-
optimized for performing natural language tasks such as classification,
|
|
51
|
-
summarization, extraction, content creation, and ideation.
|
|
52
|
-
:param temperature: Temperature controls the degree of randomness in token
|
|
53
|
-
selection. Defaults to 0.0.
|
|
54
|
-
:param max_output_tokens: Token limit determines the maximum amount of text
|
|
55
|
-
output. Defaults to 256.
|
|
56
|
-
:param top_p: Tokens are selected from most probable to least until the sum
|
|
57
|
-
of their probabilities equals the top_p value. Defaults to 0.8.
|
|
58
|
-
:param top_k: A top_k of 1 means the selected token is the most probable
|
|
59
|
-
among all tokens. Defaults to 0.4.
|
|
60
|
-
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
61
|
-
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
62
|
-
credentials, or chained list of accounts required to get the access_token
|
|
63
|
-
of the last account in the list, which will be impersonated in the request.
|
|
64
|
-
If set as a string, the account must grant the originating account
|
|
65
|
-
the Service Account Token Creator IAM role.
|
|
66
|
-
If set as a sequence, the identities from the list must grant
|
|
67
|
-
Service Account Token Creator IAM role to the directly preceding identity, with first
|
|
68
|
-
account from the list granting this role to the originating account (templated).
|
|
69
|
-
"""
|
|
70
|
-
|
|
71
|
-
template_fields = ("location", "project_id", "impersonation_chain", "prompt")
|
|
72
|
-
|
|
73
|
-
def __init__(
|
|
74
|
-
self,
|
|
75
|
-
*,
|
|
76
|
-
project_id: str,
|
|
77
|
-
location: str,
|
|
78
|
-
prompt: str,
|
|
79
|
-
pretrained_model: str = "text-bison",
|
|
80
|
-
temperature: float = 0.0,
|
|
81
|
-
max_output_tokens: int = 256,
|
|
82
|
-
top_p: float = 0.8,
|
|
83
|
-
top_k: int = 40,
|
|
84
|
-
gcp_conn_id: str = "google_cloud_default",
|
|
85
|
-
impersonation_chain: str | Sequence[str] | None = None,
|
|
86
|
-
**kwargs,
|
|
87
|
-
) -> None:
|
|
88
|
-
super().__init__(**kwargs)
|
|
89
|
-
self.project_id = project_id
|
|
90
|
-
self.location = location
|
|
91
|
-
self.prompt = prompt
|
|
92
|
-
self.pretrained_model = pretrained_model
|
|
93
|
-
self.temperature = temperature
|
|
94
|
-
self.max_output_tokens = max_output_tokens
|
|
95
|
-
self.top_p = top_p
|
|
96
|
-
self.top_k = top_k
|
|
97
|
-
self.gcp_conn_id = gcp_conn_id
|
|
98
|
-
self.impersonation_chain = impersonation_chain
|
|
99
|
-
|
|
100
|
-
def execute(self, context: Context):
|
|
101
|
-
self.hook = GenerativeModelHook(
|
|
102
|
-
gcp_conn_id=self.gcp_conn_id,
|
|
103
|
-
impersonation_chain=self.impersonation_chain,
|
|
104
|
-
)
|
|
105
|
-
|
|
106
|
-
self.log.info("Submitting prompt")
|
|
107
|
-
response = self.hook.text_generation_model_predict(
|
|
108
|
-
project_id=self.project_id,
|
|
109
|
-
location=self.location,
|
|
110
|
-
prompt=self.prompt,
|
|
111
|
-
pretrained_model=self.pretrained_model,
|
|
112
|
-
temperature=self.temperature,
|
|
113
|
-
max_output_tokens=self.max_output_tokens,
|
|
114
|
-
top_p=self.top_p,
|
|
115
|
-
top_k=self.top_k,
|
|
116
|
-
)
|
|
117
|
-
|
|
118
|
-
self.log.info("Model response: %s", response)
|
|
119
|
-
self.xcom_push(context, key="model_response", value=response)
|
|
120
|
-
|
|
121
|
-
return response
|
|
122
|
-
|
|
123
|
-
|
|
124
44
|
class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
125
45
|
"""
|
|
126
46
|
Uses the Vertex AI Embeddings API to generate embeddings based on prompt.
|
|
@@ -130,9 +50,8 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
130
50
|
:param location: Required. The ID of the Google Cloud location that the
|
|
131
51
|
service belongs to (templated).
|
|
132
52
|
:param prompt: Required. Inputs or queries that a user or a program gives
|
|
133
|
-
to the Vertex AI
|
|
134
|
-
:param pretrained_model:
|
|
135
|
-
optimized for performing text embeddings.
|
|
53
|
+
to the Vertex AI Generative Model API, in order to elicit a specific response (templated).
|
|
54
|
+
:param pretrained_model: Required. Model, optimized for performing text embeddings.
|
|
136
55
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
137
56
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
138
57
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -144,7 +63,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
144
63
|
account from the list granting this role to the originating account (templated).
|
|
145
64
|
"""
|
|
146
65
|
|
|
147
|
-
template_fields = ("location", "project_id", "impersonation_chain", "prompt")
|
|
66
|
+
template_fields = ("location", "project_id", "impersonation_chain", "prompt", "pretrained_model")
|
|
148
67
|
|
|
149
68
|
def __init__(
|
|
150
69
|
self,
|
|
@@ -152,7 +71,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
152
71
|
project_id: str,
|
|
153
72
|
location: str,
|
|
154
73
|
prompt: str,
|
|
155
|
-
pretrained_model: str
|
|
74
|
+
pretrained_model: str,
|
|
156
75
|
gcp_conn_id: str = "google_cloud_default",
|
|
157
76
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
158
77
|
**kwargs,
|
|
@@ -180,11 +99,16 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
180
99
|
)
|
|
181
100
|
|
|
182
101
|
self.log.info("Model response: %s", response)
|
|
183
|
-
|
|
102
|
+
context["ti"].xcom_push(key="model_response", value=response)
|
|
184
103
|
|
|
185
104
|
return response
|
|
186
105
|
|
|
187
106
|
|
|
107
|
+
@deprecated(
|
|
108
|
+
planned_removal_date="January 3, 2026",
|
|
109
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
|
|
110
|
+
category=AirflowProviderDeprecationWarning,
|
|
111
|
+
)
|
|
188
112
|
class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
189
113
|
"""
|
|
190
114
|
Use the Vertex AI Gemini Pro foundation model to generate content.
|
|
@@ -199,10 +123,9 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
|
199
123
|
:param safety_settings: Optional. Per request settings for blocking unsafe content.
|
|
200
124
|
:param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
|
|
201
125
|
:param system_instruction: Optional. An instruction given to the model to guide its behavior.
|
|
202
|
-
:param pretrained_model:
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
output text and code.
|
|
126
|
+
:param pretrained_model: Required. The name of the model to use for content generation,
|
|
127
|
+
which can be a text-only or multimodal model. For example, `gemini-pro` or
|
|
128
|
+
`gemini-pro-vision`.
|
|
206
129
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
207
130
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
208
131
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -226,7 +149,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
|
226
149
|
generation_config: dict | None = None,
|
|
227
150
|
safety_settings: dict | None = None,
|
|
228
151
|
system_instruction: str | None = None,
|
|
229
|
-
pretrained_model: str
|
|
152
|
+
pretrained_model: str,
|
|
230
153
|
gcp_conn_id: str = "google_cloud_default",
|
|
231
154
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
232
155
|
**kwargs,
|
|
@@ -260,11 +183,16 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
|
260
183
|
)
|
|
261
184
|
|
|
262
185
|
self.log.info("Model response: %s", response)
|
|
263
|
-
|
|
186
|
+
context["ti"].xcom_push(key="model_response", value=response)
|
|
264
187
|
|
|
265
188
|
return response
|
|
266
189
|
|
|
267
190
|
|
|
191
|
+
@deprecated(
|
|
192
|
+
planned_removal_date="January 3, 2026",
|
|
193
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAISupervisedFineTuningTrainOperator",
|
|
194
|
+
category=AirflowProviderDeprecationWarning,
|
|
195
|
+
)
|
|
268
196
|
class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
269
197
|
"""
|
|
270
198
|
Use the Supervised Fine Tuning API to create a tuning job.
|
|
@@ -298,7 +226,14 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
298
226
|
account from the list granting this role to the originating account (templated).
|
|
299
227
|
"""
|
|
300
228
|
|
|
301
|
-
template_fields = (
|
|
229
|
+
template_fields = (
|
|
230
|
+
"location",
|
|
231
|
+
"project_id",
|
|
232
|
+
"impersonation_chain",
|
|
233
|
+
"train_dataset",
|
|
234
|
+
"validation_dataset",
|
|
235
|
+
"source_model",
|
|
236
|
+
)
|
|
302
237
|
|
|
303
238
|
def __init__(
|
|
304
239
|
self,
|
|
@@ -310,7 +245,7 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
310
245
|
tuned_model_display_name: str | None = None,
|
|
311
246
|
validation_dataset: str | None = None,
|
|
312
247
|
epochs: int | None = None,
|
|
313
|
-
adapter_size:
|
|
248
|
+
adapter_size: Literal[1, 4, 8, 16] | None = None,
|
|
314
249
|
learning_rate_multiplier: float | None = None,
|
|
315
250
|
gcp_conn_id: str = "google_cloud_default",
|
|
316
251
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
@@ -349,8 +284,8 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
349
284
|
self.log.info("Tuned Model Name: %s", response.tuned_model_name)
|
|
350
285
|
self.log.info("Tuned Model Endpoint Name: %s", response.tuned_model_endpoint_name)
|
|
351
286
|
|
|
352
|
-
|
|
353
|
-
|
|
287
|
+
context["ti"].xcom_push(key="tuned_model_name", value=response.tuned_model_name)
|
|
288
|
+
context["ti"].xcom_push(key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
|
|
354
289
|
|
|
355
290
|
result = {
|
|
356
291
|
"tuned_model_name": response.tuned_model_name,
|
|
@@ -360,6 +295,11 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
360
295
|
return result
|
|
361
296
|
|
|
362
297
|
|
|
298
|
+
@deprecated(
|
|
299
|
+
planned_removal_date="January 3, 2026",
|
|
300
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICountTokensOperator",
|
|
301
|
+
category=AirflowProviderDeprecationWarning,
|
|
302
|
+
)
|
|
363
303
|
class CountTokensOperator(GoogleCloudBaseOperator):
|
|
364
304
|
"""
|
|
365
305
|
Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.
|
|
@@ -370,10 +310,9 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
|
370
310
|
service belongs to (templated).
|
|
371
311
|
:param contents: Required. The multi-part content of a message that a user or a program
|
|
372
312
|
gives to the generative model, in order to elicit a specific response.
|
|
373
|
-
:param pretrained_model:
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
output text and code.
|
|
313
|
+
:param pretrained_model: Required. Model, supporting prompts with text-only input,
|
|
314
|
+
including natural language tasks, multi-turn text and code chat,
|
|
315
|
+
and code generation. It can output text and code.
|
|
377
316
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
378
317
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
379
318
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -393,7 +332,7 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
|
393
332
|
project_id: str,
|
|
394
333
|
location: str,
|
|
395
334
|
contents: list,
|
|
396
|
-
pretrained_model: str
|
|
335
|
+
pretrained_model: str,
|
|
397
336
|
gcp_conn_id: str = "google_cloud_default",
|
|
398
337
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
399
338
|
**kwargs,
|
|
@@ -421,8 +360,8 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
|
421
360
|
self.log.info("Total tokens: %s", response.total_tokens)
|
|
422
361
|
self.log.info("Total billable characters: %s", response.total_billable_characters)
|
|
423
362
|
|
|
424
|
-
|
|
425
|
-
|
|
363
|
+
context["ti"].xcom_push(key="total_tokens", value=response.total_tokens)
|
|
364
|
+
context["ti"].xcom_push(key="total_billable_characters", value=response.total_billable_characters)
|
|
426
365
|
|
|
427
366
|
|
|
428
367
|
class RunEvaluationOperator(GoogleCloudBaseOperator):
|
|
@@ -524,6 +463,11 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
|
|
|
524
463
|
return response.summary_metrics
|
|
525
464
|
|
|
526
465
|
|
|
466
|
+
@deprecated(
|
|
467
|
+
planned_removal_date="January 3, 2026",
|
|
468
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICreateCachedContentOperator",
|
|
469
|
+
category=AirflowProviderDeprecationWarning,
|
|
470
|
+
)
|
|
527
471
|
class CreateCachedContentOperator(GoogleCloudBaseOperator):
|
|
528
472
|
"""
|
|
529
473
|
Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.
|
|
@@ -562,8 +506,8 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
|
|
|
562
506
|
project_id: str,
|
|
563
507
|
location: str,
|
|
564
508
|
model_name: str,
|
|
565
|
-
system_instruction:
|
|
566
|
-
contents: list | None = None,
|
|
509
|
+
system_instruction: Any | None = None,
|
|
510
|
+
contents: list[Any] | None = None,
|
|
567
511
|
ttl_hours: float = 1,
|
|
568
512
|
display_name: str | None = None,
|
|
569
513
|
gcp_conn_id: str = "google_cloud_default",
|
|
@@ -603,6 +547,11 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
|
|
|
603
547
|
return cached_content_name
|
|
604
548
|
|
|
605
549
|
|
|
550
|
+
@deprecated(
|
|
551
|
+
planned_removal_date="January 3, 2026",
|
|
552
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
|
|
553
|
+
category=AirflowProviderDeprecationWarning,
|
|
554
|
+
)
|
|
606
555
|
class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
|
|
607
556
|
"""
|
|
608
557
|
Generate a response from CachedContent.
|
|
@@ -674,3 +623,73 @@ class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
|
|
|
674
623
|
self.log.info("Cached Content Response: %s", cached_content_text)
|
|
675
624
|
|
|
676
625
|
return cached_content_text
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
@deprecated(
|
|
629
|
+
planned_removal_date="January 3, 2026",
|
|
630
|
+
use_instead="airflow.providers.google.cloud.operators.vertex_ai.experiment_service.DeleteExperimentRunOperator",
|
|
631
|
+
category=AirflowProviderDeprecationWarning,
|
|
632
|
+
)
|
|
633
|
+
class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
|
|
634
|
+
"""
|
|
635
|
+
Use the Rapid Evaluation API to evaluate a model.
|
|
636
|
+
|
|
637
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
|
638
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
|
639
|
+
:param experiment_name: Required. The name of the evaluation experiment.
|
|
640
|
+
:param experiment_run_name: Required. The specific run name or ID for this experiment.
|
|
641
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
642
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
643
|
+
credentials, or chained list of accounts required to get the access_token
|
|
644
|
+
of the last account in the list, which will be impersonated in the request.
|
|
645
|
+
If set as a string, the account must grant the originating account
|
|
646
|
+
the Service Account Token Creator IAM role.
|
|
647
|
+
If set as a sequence, the identities from the list must grant
|
|
648
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
|
649
|
+
account from the list granting this role to the originating account (templated).
|
|
650
|
+
"""
|
|
651
|
+
|
|
652
|
+
template_fields = (
|
|
653
|
+
"location",
|
|
654
|
+
"project_id",
|
|
655
|
+
"impersonation_chain",
|
|
656
|
+
"experiment_name",
|
|
657
|
+
"experiment_run_name",
|
|
658
|
+
)
|
|
659
|
+
|
|
660
|
+
def __init__(
|
|
661
|
+
self,
|
|
662
|
+
*,
|
|
663
|
+
project_id: str,
|
|
664
|
+
location: str,
|
|
665
|
+
experiment_name: str,
|
|
666
|
+
experiment_run_name: str,
|
|
667
|
+
gcp_conn_id: str = "google_cloud_default",
|
|
668
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
|
669
|
+
**kwargs,
|
|
670
|
+
) -> None:
|
|
671
|
+
super().__init__(**kwargs)
|
|
672
|
+
self.project_id = project_id
|
|
673
|
+
self.location = location
|
|
674
|
+
self.experiment_name = experiment_name
|
|
675
|
+
self.experiment_run_name = experiment_run_name
|
|
676
|
+
self.gcp_conn_id = gcp_conn_id
|
|
677
|
+
self.impersonation_chain = impersonation_chain
|
|
678
|
+
|
|
679
|
+
def execute(self, context: Context) -> None:
|
|
680
|
+
self.hook = ExperimentRunHook(
|
|
681
|
+
gcp_conn_id=self.gcp_conn_id,
|
|
682
|
+
impersonation_chain=self.impersonation_chain,
|
|
683
|
+
)
|
|
684
|
+
|
|
685
|
+
try:
|
|
686
|
+
self.hook.delete_experiment_run(
|
|
687
|
+
project_id=self.project_id,
|
|
688
|
+
location=self.location,
|
|
689
|
+
experiment_name=self.experiment_name,
|
|
690
|
+
experiment_run_name=self.experiment_run_name,
|
|
691
|
+
)
|
|
692
|
+
except exceptions.NotFound:
|
|
693
|
+
raise AirflowException(f"Experiment Run with name {self.experiment_run_name} not found")
|
|
694
|
+
|
|
695
|
+
self.log.info("Deleted experiment run: %s", self.experiment_run_name)
|
|
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
|
|
|
43
43
|
from google.api_core.retry import Retry
|
|
44
44
|
from google.cloud.aiplatform import HyperparameterTuningJob, gapic, hyperparameter_tuning
|
|
45
45
|
|
|
46
|
-
from airflow.
|
|
46
|
+
from airflow.providers.common.compat.sdk import Context
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
@@ -257,10 +257,8 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
257
257
|
hyperparameter_tuning_job_id = hyperparameter_tuning_job.name
|
|
258
258
|
self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id)
|
|
259
259
|
|
|
260
|
-
|
|
261
|
-
VertexAITrainingLink.persist(
|
|
262
|
-
context=context, task_instance=self, training_id=hyperparameter_tuning_job_id
|
|
263
|
-
)
|
|
260
|
+
context["ti"].xcom_push(key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
|
|
261
|
+
VertexAITrainingLink.persist(context=context, training_id=hyperparameter_tuning_job_id)
|
|
264
262
|
|
|
265
263
|
if self.deferrable:
|
|
266
264
|
self.defer(
|
|
@@ -355,9 +353,7 @@ class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
355
353
|
timeout=self.timeout,
|
|
356
354
|
metadata=self.metadata,
|
|
357
355
|
)
|
|
358
|
-
VertexAITrainingLink.persist(
|
|
359
|
-
context=context, task_instance=self, training_id=self.hyperparameter_tuning_job_id
|
|
360
|
-
)
|
|
356
|
+
VertexAITrainingLink.persist(context=context, training_id=self.hyperparameter_tuning_job_id)
|
|
361
357
|
self.log.info("Hyperparameter tuning job was gotten.")
|
|
362
358
|
return types.HyperparameterTuningJob.to_dict(result)
|
|
363
359
|
except NotFound:
|
|
@@ -487,6 +483,12 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
487
483
|
self.gcp_conn_id = gcp_conn_id
|
|
488
484
|
self.impersonation_chain = impersonation_chain
|
|
489
485
|
|
|
486
|
+
@property
|
|
487
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
488
|
+
return {
|
|
489
|
+
"project_id": self.project_id,
|
|
490
|
+
}
|
|
491
|
+
|
|
490
492
|
def execute(self, context: Context):
|
|
491
493
|
hook = HyperparameterTuningJobHook(
|
|
492
494
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -503,5 +505,5 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
503
505
|
timeout=self.timeout,
|
|
504
506
|
metadata=self.metadata,
|
|
505
507
|
)
|
|
506
|
-
VertexAIHyperparameterTuningJobListLink.persist(context=context
|
|
508
|
+
VertexAIHyperparameterTuningJobListLink.persist(context=context)
|
|
507
509
|
return [types.HyperparameterTuningJob.to_dict(result) for result in results]
|
|
@@ -20,7 +20,7 @@
|
|
|
20
20
|
from __future__ import annotations
|
|
21
21
|
|
|
22
22
|
from collections.abc import Sequence
|
|
23
|
-
from typing import TYPE_CHECKING
|
|
23
|
+
from typing import TYPE_CHECKING, Any
|
|
24
24
|
|
|
25
25
|
from google.api_core.exceptions import NotFound
|
|
26
26
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
@@ -37,7 +37,7 @@ from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseO
|
|
|
37
37
|
if TYPE_CHECKING:
|
|
38
38
|
from google.api_core.retry import Retry
|
|
39
39
|
|
|
40
|
-
from airflow.
|
|
40
|
+
from airflow.providers.common.compat.sdk import Context
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
class DeleteModelOperator(GoogleCloudBaseOperator):
|
|
@@ -161,6 +161,13 @@ class GetModelOperator(GoogleCloudBaseOperator):
|
|
|
161
161
|
self.gcp_conn_id = gcp_conn_id
|
|
162
162
|
self.impersonation_chain = impersonation_chain
|
|
163
163
|
|
|
164
|
+
@property
|
|
165
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
166
|
+
return {
|
|
167
|
+
"region": self.region,
|
|
168
|
+
"project_id": self.project_id,
|
|
169
|
+
}
|
|
170
|
+
|
|
164
171
|
def execute(self, context: Context):
|
|
165
172
|
hook = ModelServiceHook(
|
|
166
173
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -179,8 +186,8 @@ class GetModelOperator(GoogleCloudBaseOperator):
|
|
|
179
186
|
)
|
|
180
187
|
self.log.info("Model found. Model ID: %s", self.model_id)
|
|
181
188
|
|
|
182
|
-
|
|
183
|
-
VertexAIModelLink.persist(context=context,
|
|
189
|
+
context["ti"].xcom_push(key="model_id", value=self.model_id)
|
|
190
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
184
191
|
return Model.to_dict(model)
|
|
185
192
|
except NotFound:
|
|
186
193
|
self.log.info("The Model ID %s does not exist.", self.model_id)
|
|
@@ -257,7 +264,12 @@ class ExportModelOperator(GoogleCloudBaseOperator):
|
|
|
257
264
|
metadata=self.metadata,
|
|
258
265
|
)
|
|
259
266
|
hook.wait_for_operation(timeout=self.timeout, operation=operation)
|
|
260
|
-
VertexAIModelExportLink.persist(
|
|
267
|
+
VertexAIModelExportLink.persist(
|
|
268
|
+
context=context,
|
|
269
|
+
output_config=self.output_config,
|
|
270
|
+
model_id=self.model_id,
|
|
271
|
+
project_id=self.project_id,
|
|
272
|
+
)
|
|
261
273
|
self.log.info("Model was exported.")
|
|
262
274
|
except NotFound:
|
|
263
275
|
self.log.info("The Model ID %s does not exist.", self.model_id)
|
|
@@ -335,6 +347,12 @@ class ListModelsOperator(GoogleCloudBaseOperator):
|
|
|
335
347
|
self.gcp_conn_id = gcp_conn_id
|
|
336
348
|
self.impersonation_chain = impersonation_chain
|
|
337
349
|
|
|
350
|
+
@property
|
|
351
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
352
|
+
return {
|
|
353
|
+
"project_id": self.project_id,
|
|
354
|
+
}
|
|
355
|
+
|
|
338
356
|
def execute(self, context: Context):
|
|
339
357
|
hook = ModelServiceHook(
|
|
340
358
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -352,7 +370,7 @@ class ListModelsOperator(GoogleCloudBaseOperator):
|
|
|
352
370
|
timeout=self.timeout,
|
|
353
371
|
metadata=self.metadata,
|
|
354
372
|
)
|
|
355
|
-
VertexAIModelListLink.persist(context=context
|
|
373
|
+
VertexAIModelListLink.persist(context=context)
|
|
356
374
|
return [Model.to_dict(result) for result in results]
|
|
357
375
|
|
|
358
376
|
|
|
@@ -407,6 +425,13 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
|
407
425
|
self.gcp_conn_id = gcp_conn_id
|
|
408
426
|
self.impersonation_chain = impersonation_chain
|
|
409
427
|
|
|
428
|
+
@property
|
|
429
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
430
|
+
return {
|
|
431
|
+
"region": self.region,
|
|
432
|
+
"project_id": self.project_id,
|
|
433
|
+
}
|
|
434
|
+
|
|
410
435
|
def execute(self, context: Context):
|
|
411
436
|
hook = ModelServiceHook(
|
|
412
437
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -428,8 +453,8 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
|
428
453
|
model_id = hook.extract_model_id(model_resp)
|
|
429
454
|
self.log.info("Model was uploaded. Model ID: %s", model_id)
|
|
430
455
|
|
|
431
|
-
|
|
432
|
-
VertexAIModelLink.persist(context=context,
|
|
456
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
457
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
433
458
|
return model_resp
|
|
434
459
|
|
|
435
460
|
|
|
@@ -553,6 +578,13 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
|
|
|
553
578
|
self.gcp_conn_id = gcp_conn_id
|
|
554
579
|
self.impersonation_chain = impersonation_chain
|
|
555
580
|
|
|
581
|
+
@property
|
|
582
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
583
|
+
return {
|
|
584
|
+
"region": self.region,
|
|
585
|
+
"project_id": self.project_id,
|
|
586
|
+
}
|
|
587
|
+
|
|
556
588
|
def execute(self, context: Context):
|
|
557
589
|
hook = ModelServiceHook(
|
|
558
590
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -571,7 +603,7 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
|
|
|
571
603
|
timeout=self.timeout,
|
|
572
604
|
metadata=self.metadata,
|
|
573
605
|
)
|
|
574
|
-
VertexAIModelLink.persist(context=context,
|
|
606
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
575
607
|
return Model.to_dict(updated_model)
|
|
576
608
|
|
|
577
609
|
|
|
@@ -627,6 +659,13 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
627
659
|
self.gcp_conn_id = gcp_conn_id
|
|
628
660
|
self.impersonation_chain = impersonation_chain
|
|
629
661
|
|
|
662
|
+
@property
|
|
663
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
664
|
+
return {
|
|
665
|
+
"region": self.region,
|
|
666
|
+
"project_id": self.project_id,
|
|
667
|
+
}
|
|
668
|
+
|
|
630
669
|
def execute(self, context: Context):
|
|
631
670
|
hook = ModelServiceHook(
|
|
632
671
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -645,7 +684,7 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
645
684
|
timeout=self.timeout,
|
|
646
685
|
metadata=self.metadata,
|
|
647
686
|
)
|
|
648
|
-
VertexAIModelLink.persist(context=context,
|
|
687
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
649
688
|
return Model.to_dict(updated_model)
|
|
650
689
|
|
|
651
690
|
|
|
@@ -701,6 +740,13 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
701
740
|
self.gcp_conn_id = gcp_conn_id
|
|
702
741
|
self.impersonation_chain = impersonation_chain
|
|
703
742
|
|
|
743
|
+
@property
|
|
744
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
745
|
+
return {
|
|
746
|
+
"region": self.region,
|
|
747
|
+
"project_id": self.project_id,
|
|
748
|
+
}
|
|
749
|
+
|
|
704
750
|
def execute(self, context: Context):
|
|
705
751
|
hook = ModelServiceHook(
|
|
706
752
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -721,7 +767,7 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
721
767
|
timeout=self.timeout,
|
|
722
768
|
metadata=self.metadata,
|
|
723
769
|
)
|
|
724
|
-
VertexAIModelLink.persist(context=context,
|
|
770
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
725
771
|
return Model.to_dict(updated_model)
|
|
726
772
|
|
|
727
773
|
|