apache-airflow-providers-google 15.1.0rc1__py3-none-any.whl → 19.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/3rd-party-licenses/NOTICE +2 -12
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/ads/hooks/ads.py +39 -6
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +2 -2
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/bundles/__init__.py +16 -0
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/alloy_db.py +1 -1
- airflow/providers/google/cloud/hooks/bigquery.py +176 -293
- airflow/providers/google/cloud/hooks/cloud_batch.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_build.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_composer.py +288 -15
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +1 -1
- airflow/providers/google/cloud/hooks/cloud_run.py +18 -10
- airflow/providers/google/cloud/hooks/cloud_sql.py +102 -23
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +29 -7
- airflow/providers/google/cloud/hooks/compute.py +1 -1
- airflow/providers/google/cloud/hooks/compute_ssh.py +6 -2
- airflow/providers/google/cloud/hooks/datacatalog.py +10 -1
- airflow/providers/google/cloud/hooks/dataflow.py +72 -95
- airflow/providers/google/cloud/hooks/dataform.py +1 -1
- airflow/providers/google/cloud/hooks/datafusion.py +21 -19
- airflow/providers/google/cloud/hooks/dataplex.py +2 -2
- airflow/providers/google/cloud/hooks/dataprep.py +1 -1
- airflow/providers/google/cloud/hooks/dataproc.py +73 -72
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +1 -1
- airflow/providers/google/cloud/hooks/dlp.py +1 -1
- airflow/providers/google/cloud/hooks/functions.py +1 -1
- airflow/providers/google/cloud/hooks/gcs.py +112 -15
- airflow/providers/google/cloud/hooks/gdm.py +1 -1
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +3 -3
- airflow/providers/google/cloud/hooks/looker.py +6 -2
- airflow/providers/google/cloud/hooks/managed_kafka.py +1 -1
- airflow/providers/google/cloud/hooks/mlengine.py +4 -3
- airflow/providers/google/cloud/hooks/pubsub.py +3 -0
- airflow/providers/google/cloud/hooks/secret_manager.py +102 -10
- airflow/providers/google/cloud/hooks/spanner.py +74 -9
- airflow/providers/google/cloud/hooks/stackdriver.py +11 -9
- airflow/providers/google/cloud/hooks/tasks.py +1 -1
- airflow/providers/google/cloud/hooks/translate.py +2 -2
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +2 -210
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +3 -3
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +28 -2
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +308 -8
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +1 -1
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/vision.py +3 -3
- airflow/providers/google/cloud/hooks/workflows.py +1 -1
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -13
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -44
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -96
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -95
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +9 -60
- airflow/providers/google/cloud/links/managed_kafka.py +0 -70
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +58 -22
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +9 -6
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +102 -1
- airflow/providers/google/cloud/openlineage/mixins.py +10 -8
- airflow/providers/google/cloud/openlineage/utils.py +15 -1
- airflow/providers/google/cloud/operators/alloy_db.py +71 -56
- airflow/providers/google/cloud/operators/bigquery.py +73 -636
- airflow/providers/google/cloud/operators/bigquery_dts.py +4 -6
- airflow/providers/google/cloud/operators/bigtable.py +37 -8
- airflow/providers/google/cloud/operators/cloud_base.py +21 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +3 -3
- airflow/providers/google/cloud/operators/cloud_build.py +76 -33
- airflow/providers/google/cloud/operators/cloud_composer.py +129 -41
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +69 -43
- airflow/providers/google/cloud/operators/cloud_run.py +24 -6
- airflow/providers/google/cloud/operators/cloud_sql.py +8 -17
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +93 -12
- airflow/providers/google/cloud/operators/compute.py +9 -41
- airflow/providers/google/cloud/operators/datacatalog.py +157 -21
- airflow/providers/google/cloud/operators/dataflow.py +40 -16
- airflow/providers/google/cloud/operators/dataform.py +15 -5
- airflow/providers/google/cloud/operators/datafusion.py +42 -21
- airflow/providers/google/cloud/operators/dataplex.py +194 -110
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +80 -36
- airflow/providers/google/cloud/operators/dataproc_metastore.py +97 -89
- airflow/providers/google/cloud/operators/datastore.py +23 -7
- airflow/providers/google/cloud/operators/dlp.py +6 -29
- airflow/providers/google/cloud/operators/functions.py +17 -8
- airflow/providers/google/cloud/operators/gcs.py +12 -9
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +62 -100
- airflow/providers/google/cloud/operators/looker.py +2 -2
- airflow/providers/google/cloud/operators/managed_kafka.py +108 -53
- airflow/providers/google/cloud/operators/natural_language.py +1 -1
- airflow/providers/google/cloud/operators/pubsub.py +68 -15
- airflow/providers/google/cloud/operators/spanner.py +26 -13
- airflow/providers/google/cloud/operators/speech_to_text.py +2 -3
- airflow/providers/google/cloud/operators/stackdriver.py +1 -9
- airflow/providers/google/cloud/operators/tasks.py +1 -12
- airflow/providers/google/cloud/operators/text_to_speech.py +2 -3
- airflow/providers/google/cloud/operators/translate.py +41 -17
- airflow/providers/google/cloud/operators/translate_speech.py +2 -3
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +39 -19
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +30 -10
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +55 -27
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +70 -8
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +43 -9
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -115
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +12 -10
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +57 -11
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +31 -8
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +1 -1
- airflow/providers/google/cloud/operators/vision.py +2 -2
- airflow/providers/google/cloud/operators/workflows.py +18 -15
- airflow/providers/google/cloud/secrets/secret_manager.py +3 -2
- airflow/providers/google/cloud/sensors/bigquery.py +3 -3
- airflow/providers/google/cloud/sensors/bigquery_dts.py +2 -3
- airflow/providers/google/cloud/sensors/bigtable.py +11 -4
- airflow/providers/google/cloud/sensors/cloud_composer.py +533 -30
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +2 -3
- airflow/providers/google/cloud/sensors/dataflow.py +26 -10
- airflow/providers/google/cloud/sensors/dataform.py +2 -3
- airflow/providers/google/cloud/sensors/datafusion.py +4 -5
- airflow/providers/google/cloud/sensors/dataplex.py +2 -3
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +2 -3
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +2 -3
- airflow/providers/google/cloud/sensors/gcs.py +4 -5
- airflow/providers/google/cloud/sensors/looker.py +2 -3
- airflow/providers/google/cloud/sensors/pubsub.py +4 -5
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -3
- airflow/providers/google/cloud/sensors/workflows.py +2 -3
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +4 -3
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +10 -5
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -4
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +21 -13
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +4 -3
- airflow/providers/google/cloud/transfers/gcs_to_local.py +6 -4
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +11 -5
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +42 -9
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +13 -7
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +14 -5
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +76 -35
- airflow/providers/google/cloud/triggers/cloud_build.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
- airflow/providers/google/cloud/triggers/cloud_run.py +3 -3
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +92 -2
- airflow/providers/google/cloud/triggers/dataflow.py +122 -0
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +14 -2
- airflow/providers/google/cloud/triggers/dataproc.py +123 -53
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +47 -28
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +15 -19
- airflow/providers/google/cloud/triggers/vertex_ai.py +1 -1
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +2 -2
- airflow/providers/google/cloud/utils/field_sanitizer.py +1 -1
- airflow/providers/google/cloud/utils/field_validator.py +2 -3
- airflow/providers/google/common/auth_backend/google_openid.py +4 -4
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +27 -9
- airflow/providers/google/common/hooks/operation_helpers.py +1 -1
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +3 -4
- airflow/providers/google/firebase/hooks/firestore.py +1 -1
- airflow/providers/google/firebase/operators/firestore.py +3 -3
- airflow/providers/google/get_provider_info.py +56 -52
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +27 -2
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/campaign_manager.py +1 -1
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +5 -14
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +2 -3
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +6 -6
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +3 -64
- airflow/providers/google/suite/hooks/calendar.py +2 -2
- airflow/providers/google/suite/hooks/sheets.py +16 -2
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +3 -3
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/METADATA +90 -46
- apache_airflow_providers_google-19.3.0.dist-info/RECORD +331 -0
- apache_airflow_providers_google-19.3.0.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/hooks/automl.py +0 -673
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1362
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -112
- apache_airflow_providers_google-15.1.0rc1.dist-info/RECORD +0 -321
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-15.1.0rc1.dist-info → apache_airflow_providers_google-19.3.0.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.3.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -20,107 +20,28 @@
|
|
|
20
20
|
from __future__ import annotations
|
|
21
21
|
|
|
22
22
|
from collections.abc import Sequence
|
|
23
|
-
from typing import TYPE_CHECKING
|
|
23
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
24
|
+
|
|
25
|
+
from google.api_core import exceptions
|
|
24
26
|
|
|
25
27
|
from airflow.exceptions import AirflowProviderDeprecationWarning
|
|
26
|
-
from airflow.providers.
|
|
28
|
+
from airflow.providers.common.compat.sdk import AirflowException
|
|
29
|
+
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
|
|
30
|
+
ExperimentRunHook,
|
|
31
|
+
GenerativeModelHook,
|
|
32
|
+
)
|
|
27
33
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
|
28
34
|
from airflow.providers.google.common.deprecated import deprecated
|
|
29
35
|
|
|
30
36
|
if TYPE_CHECKING:
|
|
31
|
-
from airflow.
|
|
37
|
+
from airflow.providers.common.compat.sdk import Context
|
|
32
38
|
|
|
33
39
|
|
|
34
40
|
@deprecated(
|
|
35
|
-
planned_removal_date="
|
|
36
|
-
use_instead="
|
|
41
|
+
planned_removal_date="January 3, 2026",
|
|
42
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateEmbeddingsOperator",
|
|
37
43
|
category=AirflowProviderDeprecationWarning,
|
|
38
44
|
)
|
|
39
|
-
class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
|
|
40
|
-
"""
|
|
41
|
-
Uses the Vertex AI PaLM API to generate natural language text.
|
|
42
|
-
|
|
43
|
-
:param project_id: Required. The ID of the Google Cloud project that the
|
|
44
|
-
service belongs to (templated).
|
|
45
|
-
:param location: Required. The ID of the Google Cloud location that the
|
|
46
|
-
service belongs to (templated).
|
|
47
|
-
:param prompt: Required. Inputs or queries that a user or a program gives
|
|
48
|
-
to the Vertex AI PaLM API, in order to elicit a specific response (templated).
|
|
49
|
-
:param pretrained_model: By default uses the pre-trained model `text-bison`,
|
|
50
|
-
optimized for performing natural language tasks such as classification,
|
|
51
|
-
summarization, extraction, content creation, and ideation.
|
|
52
|
-
:param temperature: Temperature controls the degree of randomness in token
|
|
53
|
-
selection. Defaults to 0.0.
|
|
54
|
-
:param max_output_tokens: Token limit determines the maximum amount of text
|
|
55
|
-
output. Defaults to 256.
|
|
56
|
-
:param top_p: Tokens are selected from most probable to least until the sum
|
|
57
|
-
of their probabilities equals the top_p value. Defaults to 0.8.
|
|
58
|
-
:param top_k: A top_k of 1 means the selected token is the most probable
|
|
59
|
-
among all tokens. Defaults to 0.4.
|
|
60
|
-
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
61
|
-
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
62
|
-
credentials, or chained list of accounts required to get the access_token
|
|
63
|
-
of the last account in the list, which will be impersonated in the request.
|
|
64
|
-
If set as a string, the account must grant the originating account
|
|
65
|
-
the Service Account Token Creator IAM role.
|
|
66
|
-
If set as a sequence, the identities from the list must grant
|
|
67
|
-
Service Account Token Creator IAM role to the directly preceding identity, with first
|
|
68
|
-
account from the list granting this role to the originating account (templated).
|
|
69
|
-
"""
|
|
70
|
-
|
|
71
|
-
template_fields = ("location", "project_id", "impersonation_chain", "prompt")
|
|
72
|
-
|
|
73
|
-
def __init__(
|
|
74
|
-
self,
|
|
75
|
-
*,
|
|
76
|
-
project_id: str,
|
|
77
|
-
location: str,
|
|
78
|
-
prompt: str,
|
|
79
|
-
pretrained_model: str = "text-bison",
|
|
80
|
-
temperature: float = 0.0,
|
|
81
|
-
max_output_tokens: int = 256,
|
|
82
|
-
top_p: float = 0.8,
|
|
83
|
-
top_k: int = 40,
|
|
84
|
-
gcp_conn_id: str = "google_cloud_default",
|
|
85
|
-
impersonation_chain: str | Sequence[str] | None = None,
|
|
86
|
-
**kwargs,
|
|
87
|
-
) -> None:
|
|
88
|
-
super().__init__(**kwargs)
|
|
89
|
-
self.project_id = project_id
|
|
90
|
-
self.location = location
|
|
91
|
-
self.prompt = prompt
|
|
92
|
-
self.pretrained_model = pretrained_model
|
|
93
|
-
self.temperature = temperature
|
|
94
|
-
self.max_output_tokens = max_output_tokens
|
|
95
|
-
self.top_p = top_p
|
|
96
|
-
self.top_k = top_k
|
|
97
|
-
self.gcp_conn_id = gcp_conn_id
|
|
98
|
-
self.impersonation_chain = impersonation_chain
|
|
99
|
-
|
|
100
|
-
def execute(self, context: Context):
|
|
101
|
-
self.hook = GenerativeModelHook(
|
|
102
|
-
gcp_conn_id=self.gcp_conn_id,
|
|
103
|
-
impersonation_chain=self.impersonation_chain,
|
|
104
|
-
)
|
|
105
|
-
|
|
106
|
-
self.log.info("Submitting prompt")
|
|
107
|
-
response = self.hook.text_generation_model_predict(
|
|
108
|
-
project_id=self.project_id,
|
|
109
|
-
location=self.location,
|
|
110
|
-
prompt=self.prompt,
|
|
111
|
-
pretrained_model=self.pretrained_model,
|
|
112
|
-
temperature=self.temperature,
|
|
113
|
-
max_output_tokens=self.max_output_tokens,
|
|
114
|
-
top_p=self.top_p,
|
|
115
|
-
top_k=self.top_k,
|
|
116
|
-
)
|
|
117
|
-
|
|
118
|
-
self.log.info("Model response: %s", response)
|
|
119
|
-
self.xcom_push(context, key="model_response", value=response)
|
|
120
|
-
|
|
121
|
-
return response
|
|
122
|
-
|
|
123
|
-
|
|
124
45
|
class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
125
46
|
"""
|
|
126
47
|
Uses the Vertex AI Embeddings API to generate embeddings based on prompt.
|
|
@@ -130,9 +51,8 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
130
51
|
:param location: Required. The ID of the Google Cloud location that the
|
|
131
52
|
service belongs to (templated).
|
|
132
53
|
:param prompt: Required. Inputs or queries that a user or a program gives
|
|
133
|
-
to the Vertex AI
|
|
134
|
-
:param pretrained_model:
|
|
135
|
-
optimized for performing text embeddings.
|
|
54
|
+
to the Vertex AI Generative Model API, in order to elicit a specific response (templated).
|
|
55
|
+
:param pretrained_model: Required. Model, optimized for performing text embeddings.
|
|
136
56
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
137
57
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
138
58
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -144,7 +64,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
144
64
|
account from the list granting this role to the originating account (templated).
|
|
145
65
|
"""
|
|
146
66
|
|
|
147
|
-
template_fields = ("location", "project_id", "impersonation_chain", "prompt")
|
|
67
|
+
template_fields = ("location", "project_id", "impersonation_chain", "prompt", "pretrained_model")
|
|
148
68
|
|
|
149
69
|
def __init__(
|
|
150
70
|
self,
|
|
@@ -152,7 +72,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
152
72
|
project_id: str,
|
|
153
73
|
location: str,
|
|
154
74
|
prompt: str,
|
|
155
|
-
pretrained_model: str
|
|
75
|
+
pretrained_model: str,
|
|
156
76
|
gcp_conn_id: str = "google_cloud_default",
|
|
157
77
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
158
78
|
**kwargs,
|
|
@@ -180,11 +100,16 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
180
100
|
)
|
|
181
101
|
|
|
182
102
|
self.log.info("Model response: %s", response)
|
|
183
|
-
|
|
103
|
+
context["ti"].xcom_push(key="model_response", value=response)
|
|
184
104
|
|
|
185
105
|
return response
|
|
186
106
|
|
|
187
107
|
|
|
108
|
+
@deprecated(
|
|
109
|
+
planned_removal_date="January 3, 2026",
|
|
110
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
|
|
111
|
+
category=AirflowProviderDeprecationWarning,
|
|
112
|
+
)
|
|
188
113
|
class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
189
114
|
"""
|
|
190
115
|
Use the Vertex AI Gemini Pro foundation model to generate content.
|
|
@@ -199,10 +124,9 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
|
199
124
|
:param safety_settings: Optional. Per request settings for blocking unsafe content.
|
|
200
125
|
:param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
|
|
201
126
|
:param system_instruction: Optional. An instruction given to the model to guide its behavior.
|
|
202
|
-
:param pretrained_model:
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
output text and code.
|
|
127
|
+
:param pretrained_model: Required. The name of the model to use for content generation,
|
|
128
|
+
which can be a text-only or multimodal model. For example, `gemini-pro` or
|
|
129
|
+
`gemini-pro-vision`.
|
|
206
130
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
207
131
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
208
132
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -226,7 +150,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
|
226
150
|
generation_config: dict | None = None,
|
|
227
151
|
safety_settings: dict | None = None,
|
|
228
152
|
system_instruction: str | None = None,
|
|
229
|
-
pretrained_model: str
|
|
153
|
+
pretrained_model: str,
|
|
230
154
|
gcp_conn_id: str = "google_cloud_default",
|
|
231
155
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
232
156
|
**kwargs,
|
|
@@ -260,11 +184,16 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
|
260
184
|
)
|
|
261
185
|
|
|
262
186
|
self.log.info("Model response: %s", response)
|
|
263
|
-
|
|
187
|
+
context["ti"].xcom_push(key="model_response", value=response)
|
|
264
188
|
|
|
265
189
|
return response
|
|
266
190
|
|
|
267
191
|
|
|
192
|
+
@deprecated(
|
|
193
|
+
planned_removal_date="January 3, 2026",
|
|
194
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAISupervisedFineTuningTrainOperator",
|
|
195
|
+
category=AirflowProviderDeprecationWarning,
|
|
196
|
+
)
|
|
268
197
|
class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
269
198
|
"""
|
|
270
199
|
Use the Supervised Fine Tuning API to create a tuning job.
|
|
@@ -298,7 +227,14 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
298
227
|
account from the list granting this role to the originating account (templated).
|
|
299
228
|
"""
|
|
300
229
|
|
|
301
|
-
template_fields = (
|
|
230
|
+
template_fields = (
|
|
231
|
+
"location",
|
|
232
|
+
"project_id",
|
|
233
|
+
"impersonation_chain",
|
|
234
|
+
"train_dataset",
|
|
235
|
+
"validation_dataset",
|
|
236
|
+
"source_model",
|
|
237
|
+
)
|
|
302
238
|
|
|
303
239
|
def __init__(
|
|
304
240
|
self,
|
|
@@ -310,7 +246,7 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
310
246
|
tuned_model_display_name: str | None = None,
|
|
311
247
|
validation_dataset: str | None = None,
|
|
312
248
|
epochs: int | None = None,
|
|
313
|
-
adapter_size:
|
|
249
|
+
adapter_size: Literal[1, 4, 8, 16] | None = None,
|
|
314
250
|
learning_rate_multiplier: float | None = None,
|
|
315
251
|
gcp_conn_id: str = "google_cloud_default",
|
|
316
252
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
@@ -349,8 +285,8 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
349
285
|
self.log.info("Tuned Model Name: %s", response.tuned_model_name)
|
|
350
286
|
self.log.info("Tuned Model Endpoint Name: %s", response.tuned_model_endpoint_name)
|
|
351
287
|
|
|
352
|
-
|
|
353
|
-
|
|
288
|
+
context["ti"].xcom_push(key="tuned_model_name", value=response.tuned_model_name)
|
|
289
|
+
context["ti"].xcom_push(key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
|
|
354
290
|
|
|
355
291
|
result = {
|
|
356
292
|
"tuned_model_name": response.tuned_model_name,
|
|
@@ -360,6 +296,11 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
360
296
|
return result
|
|
361
297
|
|
|
362
298
|
|
|
299
|
+
@deprecated(
|
|
300
|
+
planned_removal_date="January 3, 2026",
|
|
301
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICountTokensOperator",
|
|
302
|
+
category=AirflowProviderDeprecationWarning,
|
|
303
|
+
)
|
|
363
304
|
class CountTokensOperator(GoogleCloudBaseOperator):
|
|
364
305
|
"""
|
|
365
306
|
Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.
|
|
@@ -370,10 +311,9 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
|
370
311
|
service belongs to (templated).
|
|
371
312
|
:param contents: Required. The multi-part content of a message that a user or a program
|
|
372
313
|
gives to the generative model, in order to elicit a specific response.
|
|
373
|
-
:param pretrained_model:
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
output text and code.
|
|
314
|
+
:param pretrained_model: Required. Model, supporting prompts with text-only input,
|
|
315
|
+
including natural language tasks, multi-turn text and code chat,
|
|
316
|
+
and code generation. It can output text and code.
|
|
377
317
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
378
318
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
379
319
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -393,7 +333,7 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
|
393
333
|
project_id: str,
|
|
394
334
|
location: str,
|
|
395
335
|
contents: list,
|
|
396
|
-
pretrained_model: str
|
|
336
|
+
pretrained_model: str,
|
|
397
337
|
gcp_conn_id: str = "google_cloud_default",
|
|
398
338
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
399
339
|
**kwargs,
|
|
@@ -421,8 +361,8 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
|
421
361
|
self.log.info("Total tokens: %s", response.total_tokens)
|
|
422
362
|
self.log.info("Total billable characters: %s", response.total_billable_characters)
|
|
423
363
|
|
|
424
|
-
|
|
425
|
-
|
|
364
|
+
context["ti"].xcom_push(key="total_tokens", value=response.total_tokens)
|
|
365
|
+
context["ti"].xcom_push(key="total_billable_characters", value=response.total_billable_characters)
|
|
426
366
|
|
|
427
367
|
|
|
428
368
|
class RunEvaluationOperator(GoogleCloudBaseOperator):
|
|
@@ -524,6 +464,11 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
|
|
|
524
464
|
return response.summary_metrics
|
|
525
465
|
|
|
526
466
|
|
|
467
|
+
@deprecated(
|
|
468
|
+
planned_removal_date="January 3, 2026",
|
|
469
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICreateCachedContentOperator",
|
|
470
|
+
category=AirflowProviderDeprecationWarning,
|
|
471
|
+
)
|
|
527
472
|
class CreateCachedContentOperator(GoogleCloudBaseOperator):
|
|
528
473
|
"""
|
|
529
474
|
Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.
|
|
@@ -562,8 +507,8 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
|
|
|
562
507
|
project_id: str,
|
|
563
508
|
location: str,
|
|
564
509
|
model_name: str,
|
|
565
|
-
system_instruction:
|
|
566
|
-
contents: list | None = None,
|
|
510
|
+
system_instruction: Any | None = None,
|
|
511
|
+
contents: list[Any] | None = None,
|
|
567
512
|
ttl_hours: float = 1,
|
|
568
513
|
display_name: str | None = None,
|
|
569
514
|
gcp_conn_id: str = "google_cloud_default",
|
|
@@ -603,6 +548,11 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
|
|
|
603
548
|
return cached_content_name
|
|
604
549
|
|
|
605
550
|
|
|
551
|
+
@deprecated(
|
|
552
|
+
planned_removal_date="January 3, 2026",
|
|
553
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
|
|
554
|
+
category=AirflowProviderDeprecationWarning,
|
|
555
|
+
)
|
|
606
556
|
class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
|
|
607
557
|
"""
|
|
608
558
|
Generate a response from CachedContent.
|
|
@@ -674,3 +624,73 @@ class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
|
|
|
674
624
|
self.log.info("Cached Content Response: %s", cached_content_text)
|
|
675
625
|
|
|
676
626
|
return cached_content_text
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
@deprecated(
|
|
630
|
+
planned_removal_date="January 3, 2026",
|
|
631
|
+
use_instead="airflow.providers.google.cloud.operators.vertex_ai.experiment_service.DeleteExperimentRunOperator",
|
|
632
|
+
category=AirflowProviderDeprecationWarning,
|
|
633
|
+
)
|
|
634
|
+
class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
|
|
635
|
+
"""
|
|
636
|
+
Use the Rapid Evaluation API to evaluate a model.
|
|
637
|
+
|
|
638
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
|
639
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
|
640
|
+
:param experiment_name: Required. The name of the evaluation experiment.
|
|
641
|
+
:param experiment_run_name: Required. The specific run name or ID for this experiment.
|
|
642
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
643
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
644
|
+
credentials, or chained list of accounts required to get the access_token
|
|
645
|
+
of the last account in the list, which will be impersonated in the request.
|
|
646
|
+
If set as a string, the account must grant the originating account
|
|
647
|
+
the Service Account Token Creator IAM role.
|
|
648
|
+
If set as a sequence, the identities from the list must grant
|
|
649
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
|
650
|
+
account from the list granting this role to the originating account (templated).
|
|
651
|
+
"""
|
|
652
|
+
|
|
653
|
+
template_fields = (
|
|
654
|
+
"location",
|
|
655
|
+
"project_id",
|
|
656
|
+
"impersonation_chain",
|
|
657
|
+
"experiment_name",
|
|
658
|
+
"experiment_run_name",
|
|
659
|
+
)
|
|
660
|
+
|
|
661
|
+
def __init__(
|
|
662
|
+
self,
|
|
663
|
+
*,
|
|
664
|
+
project_id: str,
|
|
665
|
+
location: str,
|
|
666
|
+
experiment_name: str,
|
|
667
|
+
experiment_run_name: str,
|
|
668
|
+
gcp_conn_id: str = "google_cloud_default",
|
|
669
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
|
670
|
+
**kwargs,
|
|
671
|
+
) -> None:
|
|
672
|
+
super().__init__(**kwargs)
|
|
673
|
+
self.project_id = project_id
|
|
674
|
+
self.location = location
|
|
675
|
+
self.experiment_name = experiment_name
|
|
676
|
+
self.experiment_run_name = experiment_run_name
|
|
677
|
+
self.gcp_conn_id = gcp_conn_id
|
|
678
|
+
self.impersonation_chain = impersonation_chain
|
|
679
|
+
|
|
680
|
+
def execute(self, context: Context) -> None:
|
|
681
|
+
self.hook = ExperimentRunHook(
|
|
682
|
+
gcp_conn_id=self.gcp_conn_id,
|
|
683
|
+
impersonation_chain=self.impersonation_chain,
|
|
684
|
+
)
|
|
685
|
+
|
|
686
|
+
try:
|
|
687
|
+
self.hook.delete_experiment_run(
|
|
688
|
+
project_id=self.project_id,
|
|
689
|
+
location=self.location,
|
|
690
|
+
experiment_name=self.experiment_name,
|
|
691
|
+
experiment_run_name=self.experiment_run_name,
|
|
692
|
+
)
|
|
693
|
+
except exceptions.NotFound:
|
|
694
|
+
raise AirflowException(f"Experiment Run with name {self.experiment_run_name} not found")
|
|
695
|
+
|
|
696
|
+
self.log.info("Deleted experiment run: %s", self.experiment_run_name)
|
|
@@ -28,7 +28,7 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
|
28
28
|
from google.cloud.aiplatform_v1 import types
|
|
29
29
|
|
|
30
30
|
from airflow.configuration import conf
|
|
31
|
-
from airflow.
|
|
31
|
+
from airflow.providers.common.compat.sdk import AirflowException
|
|
32
32
|
from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
|
|
33
33
|
HyperparameterTuningJobHook,
|
|
34
34
|
)
|
|
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
|
|
|
43
43
|
from google.api_core.retry import Retry
|
|
44
44
|
from google.cloud.aiplatform import HyperparameterTuningJob, gapic, hyperparameter_tuning
|
|
45
45
|
|
|
46
|
-
from airflow.
|
|
46
|
+
from airflow.providers.common.compat.sdk import Context
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
@@ -257,10 +257,8 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
257
257
|
hyperparameter_tuning_job_id = hyperparameter_tuning_job.name
|
|
258
258
|
self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id)
|
|
259
259
|
|
|
260
|
-
|
|
261
|
-
VertexAITrainingLink.persist(
|
|
262
|
-
context=context, task_instance=self, training_id=hyperparameter_tuning_job_id
|
|
263
|
-
)
|
|
260
|
+
context["ti"].xcom_push(key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
|
|
261
|
+
VertexAITrainingLink.persist(context=context, training_id=hyperparameter_tuning_job_id)
|
|
264
262
|
|
|
265
263
|
if self.deferrable:
|
|
266
264
|
self.defer(
|
|
@@ -355,9 +353,7 @@ class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
355
353
|
timeout=self.timeout,
|
|
356
354
|
metadata=self.metadata,
|
|
357
355
|
)
|
|
358
|
-
VertexAITrainingLink.persist(
|
|
359
|
-
context=context, task_instance=self, training_id=self.hyperparameter_tuning_job_id
|
|
360
|
-
)
|
|
356
|
+
VertexAITrainingLink.persist(context=context, training_id=self.hyperparameter_tuning_job_id)
|
|
361
357
|
self.log.info("Hyperparameter tuning job was gotten.")
|
|
362
358
|
return types.HyperparameterTuningJob.to_dict(result)
|
|
363
359
|
except NotFound:
|
|
@@ -487,6 +483,12 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
487
483
|
self.gcp_conn_id = gcp_conn_id
|
|
488
484
|
self.impersonation_chain = impersonation_chain
|
|
489
485
|
|
|
486
|
+
@property
|
|
487
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
488
|
+
return {
|
|
489
|
+
"project_id": self.project_id,
|
|
490
|
+
}
|
|
491
|
+
|
|
490
492
|
def execute(self, context: Context):
|
|
491
493
|
hook = HyperparameterTuningJobHook(
|
|
492
494
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -503,5 +505,5 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
503
505
|
timeout=self.timeout,
|
|
504
506
|
metadata=self.metadata,
|
|
505
507
|
)
|
|
506
|
-
VertexAIHyperparameterTuningJobListLink.persist(context=context
|
|
508
|
+
VertexAIHyperparameterTuningJobListLink.persist(context=context)
|
|
507
509
|
return [types.HyperparameterTuningJob.to_dict(result) for result in results]
|
|
@@ -20,7 +20,7 @@
|
|
|
20
20
|
from __future__ import annotations
|
|
21
21
|
|
|
22
22
|
from collections.abc import Sequence
|
|
23
|
-
from typing import TYPE_CHECKING
|
|
23
|
+
from typing import TYPE_CHECKING, Any
|
|
24
24
|
|
|
25
25
|
from google.api_core.exceptions import NotFound
|
|
26
26
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
@@ -37,7 +37,7 @@ from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseO
|
|
|
37
37
|
if TYPE_CHECKING:
|
|
38
38
|
from google.api_core.retry import Retry
|
|
39
39
|
|
|
40
|
-
from airflow.
|
|
40
|
+
from airflow.providers.common.compat.sdk import Context
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
class DeleteModelOperator(GoogleCloudBaseOperator):
|
|
@@ -161,6 +161,13 @@ class GetModelOperator(GoogleCloudBaseOperator):
|
|
|
161
161
|
self.gcp_conn_id = gcp_conn_id
|
|
162
162
|
self.impersonation_chain = impersonation_chain
|
|
163
163
|
|
|
164
|
+
@property
|
|
165
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
166
|
+
return {
|
|
167
|
+
"region": self.region,
|
|
168
|
+
"project_id": self.project_id,
|
|
169
|
+
}
|
|
170
|
+
|
|
164
171
|
def execute(self, context: Context):
|
|
165
172
|
hook = ModelServiceHook(
|
|
166
173
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -179,8 +186,8 @@ class GetModelOperator(GoogleCloudBaseOperator):
|
|
|
179
186
|
)
|
|
180
187
|
self.log.info("Model found. Model ID: %s", self.model_id)
|
|
181
188
|
|
|
182
|
-
|
|
183
|
-
VertexAIModelLink.persist(context=context,
|
|
189
|
+
context["ti"].xcom_push(key="model_id", value=self.model_id)
|
|
190
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
184
191
|
return Model.to_dict(model)
|
|
185
192
|
except NotFound:
|
|
186
193
|
self.log.info("The Model ID %s does not exist.", self.model_id)
|
|
@@ -257,7 +264,12 @@ class ExportModelOperator(GoogleCloudBaseOperator):
|
|
|
257
264
|
metadata=self.metadata,
|
|
258
265
|
)
|
|
259
266
|
hook.wait_for_operation(timeout=self.timeout, operation=operation)
|
|
260
|
-
VertexAIModelExportLink.persist(
|
|
267
|
+
VertexAIModelExportLink.persist(
|
|
268
|
+
context=context,
|
|
269
|
+
output_config=self.output_config,
|
|
270
|
+
model_id=self.model_id,
|
|
271
|
+
project_id=self.project_id,
|
|
272
|
+
)
|
|
261
273
|
self.log.info("Model was exported.")
|
|
262
274
|
except NotFound:
|
|
263
275
|
self.log.info("The Model ID %s does not exist.", self.model_id)
|
|
@@ -335,6 +347,12 @@ class ListModelsOperator(GoogleCloudBaseOperator):
|
|
|
335
347
|
self.gcp_conn_id = gcp_conn_id
|
|
336
348
|
self.impersonation_chain = impersonation_chain
|
|
337
349
|
|
|
350
|
+
@property
|
|
351
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
352
|
+
return {
|
|
353
|
+
"project_id": self.project_id,
|
|
354
|
+
}
|
|
355
|
+
|
|
338
356
|
def execute(self, context: Context):
|
|
339
357
|
hook = ModelServiceHook(
|
|
340
358
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -352,7 +370,7 @@ class ListModelsOperator(GoogleCloudBaseOperator):
|
|
|
352
370
|
timeout=self.timeout,
|
|
353
371
|
metadata=self.metadata,
|
|
354
372
|
)
|
|
355
|
-
VertexAIModelListLink.persist(context=context
|
|
373
|
+
VertexAIModelListLink.persist(context=context)
|
|
356
374
|
return [Model.to_dict(result) for result in results]
|
|
357
375
|
|
|
358
376
|
|
|
@@ -407,6 +425,13 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
|
407
425
|
self.gcp_conn_id = gcp_conn_id
|
|
408
426
|
self.impersonation_chain = impersonation_chain
|
|
409
427
|
|
|
428
|
+
@property
|
|
429
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
430
|
+
return {
|
|
431
|
+
"region": self.region,
|
|
432
|
+
"project_id": self.project_id,
|
|
433
|
+
}
|
|
434
|
+
|
|
410
435
|
def execute(self, context: Context):
|
|
411
436
|
hook = ModelServiceHook(
|
|
412
437
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -428,8 +453,8 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
|
428
453
|
model_id = hook.extract_model_id(model_resp)
|
|
429
454
|
self.log.info("Model was uploaded. Model ID: %s", model_id)
|
|
430
455
|
|
|
431
|
-
|
|
432
|
-
VertexAIModelLink.persist(context=context,
|
|
456
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
457
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
433
458
|
return model_resp
|
|
434
459
|
|
|
435
460
|
|
|
@@ -553,6 +578,13 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
|
|
|
553
578
|
self.gcp_conn_id = gcp_conn_id
|
|
554
579
|
self.impersonation_chain = impersonation_chain
|
|
555
580
|
|
|
581
|
+
@property
|
|
582
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
583
|
+
return {
|
|
584
|
+
"region": self.region,
|
|
585
|
+
"project_id": self.project_id,
|
|
586
|
+
}
|
|
587
|
+
|
|
556
588
|
def execute(self, context: Context):
|
|
557
589
|
hook = ModelServiceHook(
|
|
558
590
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -571,7 +603,7 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
|
|
|
571
603
|
timeout=self.timeout,
|
|
572
604
|
metadata=self.metadata,
|
|
573
605
|
)
|
|
574
|
-
VertexAIModelLink.persist(context=context,
|
|
606
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
575
607
|
return Model.to_dict(updated_model)
|
|
576
608
|
|
|
577
609
|
|
|
@@ -627,6 +659,13 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
627
659
|
self.gcp_conn_id = gcp_conn_id
|
|
628
660
|
self.impersonation_chain = impersonation_chain
|
|
629
661
|
|
|
662
|
+
@property
|
|
663
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
664
|
+
return {
|
|
665
|
+
"region": self.region,
|
|
666
|
+
"project_id": self.project_id,
|
|
667
|
+
}
|
|
668
|
+
|
|
630
669
|
def execute(self, context: Context):
|
|
631
670
|
hook = ModelServiceHook(
|
|
632
671
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -645,7 +684,7 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
645
684
|
timeout=self.timeout,
|
|
646
685
|
metadata=self.metadata,
|
|
647
686
|
)
|
|
648
|
-
VertexAIModelLink.persist(context=context,
|
|
687
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
649
688
|
return Model.to_dict(updated_model)
|
|
650
689
|
|
|
651
690
|
|
|
@@ -701,6 +740,13 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
701
740
|
self.gcp_conn_id = gcp_conn_id
|
|
702
741
|
self.impersonation_chain = impersonation_chain
|
|
703
742
|
|
|
743
|
+
@property
|
|
744
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
745
|
+
return {
|
|
746
|
+
"region": self.region,
|
|
747
|
+
"project_id": self.project_id,
|
|
748
|
+
}
|
|
749
|
+
|
|
704
750
|
def execute(self, context: Context):
|
|
705
751
|
hook = ModelServiceHook(
|
|
706
752
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -721,7 +767,7 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
721
767
|
timeout=self.timeout,
|
|
722
768
|
metadata=self.metadata,
|
|
723
769
|
)
|
|
724
|
-
VertexAIModelLink.persist(context=context,
|
|
770
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
725
771
|
return Model.to_dict(updated_model)
|
|
726
772
|
|
|
727
773
|
|