apache-airflow-providers-google 14.0.0__py3-none-any.whl → 19.1.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/3rd-party-licenses/LICENSES.txt +14 -0
- airflow/providers/google/3rd-party-licenses/NOTICE +5 -0
- airflow/providers/google/__init__.py +3 -3
- airflow/providers/google/_vendor/__init__.py +0 -0
- airflow/providers/google/_vendor/json_merge_patch.py +91 -0
- airflow/providers/google/ads/hooks/ads.py +52 -43
- airflow/providers/google/ads/operators/ads.py +2 -2
- airflow/providers/google/ads/transfers/ads_to_gcs.py +3 -19
- airflow/providers/google/assets/gcs.py +1 -11
- airflow/providers/google/cloud/_internal_client/secret_manager_client.py +3 -2
- airflow/providers/google/cloud/bundles/gcs.py +161 -0
- airflow/providers/google/cloud/hooks/alloy_db.py +2 -3
- airflow/providers/google/cloud/hooks/bigquery.py +195 -318
- airflow/providers/google/cloud/hooks/bigquery_dts.py +8 -8
- airflow/providers/google/cloud/hooks/bigtable.py +3 -2
- airflow/providers/google/cloud/hooks/cloud_batch.py +8 -9
- airflow/providers/google/cloud/hooks/cloud_build.py +6 -65
- airflow/providers/google/cloud/hooks/cloud_composer.py +292 -24
- airflow/providers/google/cloud/hooks/cloud_logging.py +109 -0
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +4 -3
- airflow/providers/google/cloud/hooks/cloud_run.py +20 -11
- airflow/providers/google/cloud/hooks/cloud_sql.py +136 -64
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +35 -15
- airflow/providers/google/cloud/hooks/compute.py +7 -6
- airflow/providers/google/cloud/hooks/compute_ssh.py +7 -4
- airflow/providers/google/cloud/hooks/datacatalog.py +12 -3
- airflow/providers/google/cloud/hooks/dataflow.py +87 -242
- airflow/providers/google/cloud/hooks/dataform.py +9 -14
- airflow/providers/google/cloud/hooks/datafusion.py +7 -9
- airflow/providers/google/cloud/hooks/dataplex.py +13 -12
- airflow/providers/google/cloud/hooks/dataprep.py +2 -2
- airflow/providers/google/cloud/hooks/dataproc.py +76 -74
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +4 -3
- airflow/providers/google/cloud/hooks/dlp.py +5 -4
- airflow/providers/google/cloud/hooks/gcs.py +144 -33
- airflow/providers/google/cloud/hooks/gen_ai.py +196 -0
- airflow/providers/google/cloud/hooks/kms.py +3 -2
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +22 -17
- airflow/providers/google/cloud/hooks/looker.py +6 -1
- airflow/providers/google/cloud/hooks/managed_kafka.py +227 -3
- airflow/providers/google/cloud/hooks/mlengine.py +7 -8
- airflow/providers/google/cloud/hooks/natural_language.py +3 -2
- airflow/providers/google/cloud/hooks/os_login.py +3 -2
- airflow/providers/google/cloud/hooks/pubsub.py +6 -6
- airflow/providers/google/cloud/hooks/secret_manager.py +105 -12
- airflow/providers/google/cloud/hooks/spanner.py +75 -10
- airflow/providers/google/cloud/hooks/speech_to_text.py +3 -2
- airflow/providers/google/cloud/hooks/stackdriver.py +18 -18
- airflow/providers/google/cloud/hooks/tasks.py +4 -3
- airflow/providers/google/cloud/hooks/text_to_speech.py +3 -2
- airflow/providers/google/cloud/hooks/translate.py +8 -17
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +8 -222
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +9 -15
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +33 -283
- airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +5 -12
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +6 -12
- airflow/providers/google/cloud/hooks/vertex_ai/experiment_service.py +202 -0
- airflow/providers/google/cloud/hooks/vertex_ai/feature_store.py +311 -10
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +79 -75
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +7 -13
- airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +8 -12
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +6 -12
- airflow/providers/google/cloud/hooks/vertex_ai/prediction_service.py +3 -2
- airflow/providers/google/cloud/hooks/vertex_ai/ray.py +223 -0
- airflow/providers/google/cloud/hooks/video_intelligence.py +3 -2
- airflow/providers/google/cloud/hooks/vision.py +7 -7
- airflow/providers/google/cloud/hooks/workflows.py +4 -3
- airflow/providers/google/cloud/links/alloy_db.py +0 -46
- airflow/providers/google/cloud/links/base.py +77 -7
- airflow/providers/google/cloud/links/bigquery.py +0 -47
- airflow/providers/google/cloud/links/bigquery_dts.py +0 -20
- airflow/providers/google/cloud/links/bigtable.py +0 -48
- airflow/providers/google/cloud/links/cloud_build.py +0 -73
- airflow/providers/google/cloud/links/cloud_functions.py +0 -33
- airflow/providers/google/cloud/links/cloud_memorystore.py +0 -58
- airflow/providers/google/cloud/links/{life_sciences.py → cloud_run.py} +5 -27
- airflow/providers/google/cloud/links/cloud_sql.py +0 -33
- airflow/providers/google/cloud/links/cloud_storage_transfer.py +17 -46
- airflow/providers/google/cloud/links/cloud_tasks.py +7 -26
- airflow/providers/google/cloud/links/compute.py +0 -58
- airflow/providers/google/cloud/links/data_loss_prevention.py +0 -169
- airflow/providers/google/cloud/links/datacatalog.py +23 -54
- airflow/providers/google/cloud/links/dataflow.py +0 -34
- airflow/providers/google/cloud/links/dataform.py +0 -64
- airflow/providers/google/cloud/links/datafusion.py +1 -90
- airflow/providers/google/cloud/links/dataplex.py +0 -154
- airflow/providers/google/cloud/links/dataprep.py +0 -24
- airflow/providers/google/cloud/links/dataproc.py +11 -89
- airflow/providers/google/cloud/links/datastore.py +0 -31
- airflow/providers/google/cloud/links/kubernetes_engine.py +11 -61
- airflow/providers/google/cloud/links/managed_kafka.py +11 -51
- airflow/providers/google/cloud/links/mlengine.py +0 -70
- airflow/providers/google/cloud/links/pubsub.py +0 -32
- airflow/providers/google/cloud/links/spanner.py +0 -33
- airflow/providers/google/cloud/links/stackdriver.py +0 -30
- airflow/providers/google/cloud/links/translate.py +17 -187
- airflow/providers/google/cloud/links/vertex_ai.py +28 -195
- airflow/providers/google/cloud/links/workflows.py +0 -52
- airflow/providers/google/cloud/log/gcs_task_handler.py +166 -118
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +14 -9
- airflow/providers/google/cloud/openlineage/CloudStorageTransferJobFacet.json +68 -0
- airflow/providers/google/cloud/openlineage/CloudStorageTransferRunFacet.json +60 -0
- airflow/providers/google/cloud/openlineage/DataFusionRunFacet.json +32 -0
- airflow/providers/google/cloud/openlineage/facets.py +141 -40
- airflow/providers/google/cloud/openlineage/mixins.py +14 -13
- airflow/providers/google/cloud/openlineage/utils.py +19 -3
- airflow/providers/google/cloud/operators/alloy_db.py +76 -61
- airflow/providers/google/cloud/operators/bigquery.py +104 -667
- airflow/providers/google/cloud/operators/bigquery_dts.py +12 -12
- airflow/providers/google/cloud/operators/bigtable.py +38 -7
- airflow/providers/google/cloud/operators/cloud_base.py +22 -1
- airflow/providers/google/cloud/operators/cloud_batch.py +18 -18
- airflow/providers/google/cloud/operators/cloud_build.py +80 -36
- airflow/providers/google/cloud/operators/cloud_composer.py +157 -71
- airflow/providers/google/cloud/operators/cloud_logging_sink.py +341 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +74 -46
- airflow/providers/google/cloud/operators/cloud_run.py +39 -20
- airflow/providers/google/cloud/operators/cloud_sql.py +46 -61
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +92 -14
- airflow/providers/google/cloud/operators/compute.py +18 -50
- airflow/providers/google/cloud/operators/datacatalog.py +167 -29
- airflow/providers/google/cloud/operators/dataflow.py +38 -15
- airflow/providers/google/cloud/operators/dataform.py +19 -7
- airflow/providers/google/cloud/operators/datafusion.py +43 -43
- airflow/providers/google/cloud/operators/dataplex.py +212 -126
- airflow/providers/google/cloud/operators/dataprep.py +1 -5
- airflow/providers/google/cloud/operators/dataproc.py +134 -207
- airflow/providers/google/cloud/operators/dataproc_metastore.py +102 -84
- airflow/providers/google/cloud/operators/datastore.py +22 -6
- airflow/providers/google/cloud/operators/dlp.py +24 -45
- airflow/providers/google/cloud/operators/functions.py +21 -14
- airflow/providers/google/cloud/operators/gcs.py +15 -12
- airflow/providers/google/cloud/operators/gen_ai.py +389 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +115 -106
- airflow/providers/google/cloud/operators/looker.py +1 -1
- airflow/providers/google/cloud/operators/managed_kafka.py +362 -40
- airflow/providers/google/cloud/operators/natural_language.py +5 -3
- airflow/providers/google/cloud/operators/pubsub.py +69 -21
- airflow/providers/google/cloud/operators/spanner.py +53 -45
- airflow/providers/google/cloud/operators/speech_to_text.py +5 -4
- airflow/providers/google/cloud/operators/stackdriver.py +5 -11
- airflow/providers/google/cloud/operators/tasks.py +6 -15
- airflow/providers/google/cloud/operators/text_to_speech.py +4 -3
- airflow/providers/google/cloud/operators/translate.py +46 -20
- airflow/providers/google/cloud/operators/translate_speech.py +4 -3
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +44 -34
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +34 -12
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +62 -53
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +75 -11
- airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py +48 -12
- airflow/providers/google/cloud/operators/vertex_ai/experiment_service.py +435 -0
- airflow/providers/google/cloud/operators/vertex_ai/feature_store.py +532 -1
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +135 -116
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +16 -12
- airflow/providers/google/cloud/operators/vertex_ai/model_service.py +62 -14
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +35 -10
- airflow/providers/google/cloud/operators/vertex_ai/ray.py +393 -0
- airflow/providers/google/cloud/operators/video_intelligence.py +5 -3
- airflow/providers/google/cloud/operators/vision.py +7 -5
- airflow/providers/google/cloud/operators/workflows.py +24 -19
- airflow/providers/google/cloud/secrets/secret_manager.py +2 -1
- airflow/providers/google/cloud/sensors/bigquery.py +2 -2
- airflow/providers/google/cloud/sensors/bigquery_dts.py +6 -4
- airflow/providers/google/cloud/sensors/bigtable.py +14 -6
- airflow/providers/google/cloud/sensors/cloud_composer.py +535 -33
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +6 -5
- airflow/providers/google/cloud/sensors/dataflow.py +27 -10
- airflow/providers/google/cloud/sensors/dataform.py +2 -2
- airflow/providers/google/cloud/sensors/datafusion.py +4 -4
- airflow/providers/google/cloud/sensors/dataplex.py +7 -5
- airflow/providers/google/cloud/sensors/dataprep.py +2 -2
- airflow/providers/google/cloud/sensors/dataproc.py +10 -9
- airflow/providers/google/cloud/sensors/dataproc_metastore.py +4 -3
- airflow/providers/google/cloud/sensors/gcs.py +22 -21
- airflow/providers/google/cloud/sensors/looker.py +5 -5
- airflow/providers/google/cloud/sensors/pubsub.py +20 -20
- airflow/providers/google/cloud/sensors/tasks.py +2 -2
- airflow/providers/google/cloud/sensors/vertex_ai/feature_store.py +2 -2
- airflow/providers/google/cloud/sensors/workflows.py +6 -4
- airflow/providers/google/cloud/transfers/adls_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +11 -8
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +14 -13
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +7 -3
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +12 -1
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +24 -10
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +104 -5
- airflow/providers/google/cloud/transfers/calendar_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +18 -22
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +4 -5
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +45 -38
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/gcs_to_local.py +5 -3
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +10 -4
- airflow/providers/google/cloud/transfers/gdrive_to_gcs.py +6 -2
- airflow/providers/google/cloud/transfers/gdrive_to_local.py +2 -2
- airflow/providers/google/cloud/transfers/http_to_gcs.py +193 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -1
- airflow/providers/google/cloud/transfers/oracle_to_gcs.py +36 -11
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +44 -12
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +12 -6
- airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +2 -2
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +36 -14
- airflow/providers/google/cloud/transfers/sheets_to_gcs.py +3 -3
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +10 -10
- airflow/providers/google/cloud/triggers/bigquery.py +75 -34
- airflow/providers/google/cloud/triggers/bigquery_dts.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_batch.py +2 -1
- airflow/providers/google/cloud/triggers/cloud_build.py +3 -2
- airflow/providers/google/cloud/triggers/cloud_composer.py +303 -47
- airflow/providers/google/cloud/triggers/cloud_run.py +2 -2
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +96 -5
- airflow/providers/google/cloud/triggers/dataflow.py +125 -2
- airflow/providers/google/cloud/triggers/datafusion.py +1 -1
- airflow/providers/google/cloud/triggers/dataplex.py +16 -3
- airflow/providers/google/cloud/triggers/dataproc.py +124 -53
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +46 -28
- airflow/providers/google/cloud/triggers/mlengine.py +1 -1
- airflow/providers/google/cloud/triggers/pubsub.py +17 -20
- airflow/providers/google/cloud/triggers/vertex_ai.py +8 -7
- airflow/providers/google/cloud/utils/bigquery.py +5 -7
- airflow/providers/google/cloud/utils/bigquery_get_data.py +1 -1
- airflow/providers/google/cloud/utils/credentials_provider.py +4 -3
- airflow/providers/google/cloud/utils/dataform.py +1 -1
- airflow/providers/google/cloud/utils/external_token_supplier.py +0 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -2
- airflow/providers/google/cloud/utils/validators.py +43 -0
- airflow/providers/google/common/auth_backend/google_openid.py +26 -9
- airflow/providers/google/common/consts.py +2 -1
- airflow/providers/google/common/deprecated.py +2 -1
- airflow/providers/google/common/hooks/base_google.py +40 -43
- airflow/providers/google/common/hooks/operation_helpers.py +78 -0
- airflow/providers/google/common/links/storage.py +0 -22
- airflow/providers/google/common/utils/get_secret.py +31 -0
- airflow/providers/google/common/utils/id_token_credentials.py +4 -5
- airflow/providers/google/firebase/operators/firestore.py +2 -2
- airflow/providers/google/get_provider_info.py +61 -216
- airflow/providers/google/go_module_utils.py +35 -3
- airflow/providers/google/leveldb/hooks/leveldb.py +30 -6
- airflow/providers/google/leveldb/operators/leveldb.py +2 -2
- airflow/providers/google/marketing_platform/hooks/analytics_admin.py +3 -2
- airflow/providers/google/marketing_platform/hooks/display_video.py +3 -109
- airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -1
- airflow/providers/google/marketing_platform/links/analytics_admin.py +4 -5
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +7 -6
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +5 -5
- airflow/providers/google/marketing_platform/operators/display_video.py +28 -489
- airflow/providers/google/marketing_platform/operators/search_ads.py +2 -2
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +2 -2
- airflow/providers/google/marketing_platform/sensors/display_video.py +4 -64
- airflow/providers/google/suite/hooks/calendar.py +1 -1
- airflow/providers/google/suite/hooks/drive.py +2 -2
- airflow/providers/google/suite/hooks/sheets.py +15 -1
- airflow/providers/google/suite/operators/sheets.py +8 -3
- airflow/providers/google/suite/sensors/drive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +2 -2
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +1 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +3 -3
- airflow/providers/google/suite/transfers/sql_to_sheets.py +5 -4
- airflow/providers/google/version_compat.py +15 -1
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/METADATA +117 -72
- apache_airflow_providers_google-19.1.0rc1.dist-info/RECORD +331 -0
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/WHEEL +1 -1
- apache_airflow_providers_google-19.1.0rc1.dist-info/licenses/NOTICE +5 -0
- airflow/providers/google/cloud/example_dags/example_cloud_task.py +0 -54
- airflow/providers/google/cloud/hooks/automl.py +0 -679
- airflow/providers/google/cloud/hooks/life_sciences.py +0 -159
- airflow/providers/google/cloud/links/automl.py +0 -193
- airflow/providers/google/cloud/operators/automl.py +0 -1360
- airflow/providers/google/cloud/operators/life_sciences.py +0 -119
- airflow/providers/google/cloud/operators/mlengine.py +0 -1515
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +0 -273
- apache_airflow_providers_google-14.0.0.dist-info/RECORD +0 -318
- /airflow/providers/google/cloud/{example_dags → bundles}/__init__.py +0 -0
- {apache_airflow_providers_google-14.0.0.dist-info → apache_airflow_providers_google-19.1.0rc1.dist-info}/entry_points.txt +0 -0
- {airflow/providers/google → apache_airflow_providers_google-19.1.0rc1.dist-info/licenses}/LICENSE +0 -0
|
@@ -20,107 +20,27 @@
|
|
|
20
20
|
from __future__ import annotations
|
|
21
21
|
|
|
22
22
|
from collections.abc import Sequence
|
|
23
|
-
from typing import TYPE_CHECKING
|
|
23
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
24
24
|
|
|
25
|
-
from
|
|
26
|
-
|
|
25
|
+
from google.api_core import exceptions
|
|
26
|
+
|
|
27
|
+
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
|
28
|
+
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
|
|
29
|
+
ExperimentRunHook,
|
|
30
|
+
GenerativeModelHook,
|
|
31
|
+
)
|
|
27
32
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
|
28
33
|
from airflow.providers.google.common.deprecated import deprecated
|
|
29
34
|
|
|
30
35
|
if TYPE_CHECKING:
|
|
31
|
-
from airflow.
|
|
36
|
+
from airflow.providers.common.compat.sdk import Context
|
|
32
37
|
|
|
33
38
|
|
|
34
39
|
@deprecated(
|
|
35
|
-
planned_removal_date="
|
|
36
|
-
use_instead="
|
|
40
|
+
planned_removal_date="January 3, 2026",
|
|
41
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateEmbeddingsOperator",
|
|
37
42
|
category=AirflowProviderDeprecationWarning,
|
|
38
43
|
)
|
|
39
|
-
class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
|
|
40
|
-
"""
|
|
41
|
-
Uses the Vertex AI PaLM API to generate natural language text.
|
|
42
|
-
|
|
43
|
-
:param project_id: Required. The ID of the Google Cloud project that the
|
|
44
|
-
service belongs to (templated).
|
|
45
|
-
:param location: Required. The ID of the Google Cloud location that the
|
|
46
|
-
service belongs to (templated).
|
|
47
|
-
:param prompt: Required. Inputs or queries that a user or a program gives
|
|
48
|
-
to the Vertex AI PaLM API, in order to elicit a specific response (templated).
|
|
49
|
-
:param pretrained_model: By default uses the pre-trained model `text-bison`,
|
|
50
|
-
optimized for performing natural language tasks such as classification,
|
|
51
|
-
summarization, extraction, content creation, and ideation.
|
|
52
|
-
:param temperature: Temperature controls the degree of randomness in token
|
|
53
|
-
selection. Defaults to 0.0.
|
|
54
|
-
:param max_output_tokens: Token limit determines the maximum amount of text
|
|
55
|
-
output. Defaults to 256.
|
|
56
|
-
:param top_p: Tokens are selected from most probable to least until the sum
|
|
57
|
-
of their probabilities equals the top_p value. Defaults to 0.8.
|
|
58
|
-
:param top_k: A top_k of 1 means the selected token is the most probable
|
|
59
|
-
among all tokens. Defaults to 0.4.
|
|
60
|
-
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
61
|
-
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
62
|
-
credentials, or chained list of accounts required to get the access_token
|
|
63
|
-
of the last account in the list, which will be impersonated in the request.
|
|
64
|
-
If set as a string, the account must grant the originating account
|
|
65
|
-
the Service Account Token Creator IAM role.
|
|
66
|
-
If set as a sequence, the identities from the list must grant
|
|
67
|
-
Service Account Token Creator IAM role to the directly preceding identity, with first
|
|
68
|
-
account from the list granting this role to the originating account (templated).
|
|
69
|
-
"""
|
|
70
|
-
|
|
71
|
-
template_fields = ("location", "project_id", "impersonation_chain", "prompt")
|
|
72
|
-
|
|
73
|
-
def __init__(
|
|
74
|
-
self,
|
|
75
|
-
*,
|
|
76
|
-
project_id: str,
|
|
77
|
-
location: str,
|
|
78
|
-
prompt: str,
|
|
79
|
-
pretrained_model: str = "text-bison",
|
|
80
|
-
temperature: float = 0.0,
|
|
81
|
-
max_output_tokens: int = 256,
|
|
82
|
-
top_p: float = 0.8,
|
|
83
|
-
top_k: int = 40,
|
|
84
|
-
gcp_conn_id: str = "google_cloud_default",
|
|
85
|
-
impersonation_chain: str | Sequence[str] | None = None,
|
|
86
|
-
**kwargs,
|
|
87
|
-
) -> None:
|
|
88
|
-
super().__init__(**kwargs)
|
|
89
|
-
self.project_id = project_id
|
|
90
|
-
self.location = location
|
|
91
|
-
self.prompt = prompt
|
|
92
|
-
self.pretrained_model = pretrained_model
|
|
93
|
-
self.temperature = temperature
|
|
94
|
-
self.max_output_tokens = max_output_tokens
|
|
95
|
-
self.top_p = top_p
|
|
96
|
-
self.top_k = top_k
|
|
97
|
-
self.gcp_conn_id = gcp_conn_id
|
|
98
|
-
self.impersonation_chain = impersonation_chain
|
|
99
|
-
|
|
100
|
-
def execute(self, context: Context):
|
|
101
|
-
self.hook = GenerativeModelHook(
|
|
102
|
-
gcp_conn_id=self.gcp_conn_id,
|
|
103
|
-
impersonation_chain=self.impersonation_chain,
|
|
104
|
-
)
|
|
105
|
-
|
|
106
|
-
self.log.info("Submitting prompt")
|
|
107
|
-
response = self.hook.text_generation_model_predict(
|
|
108
|
-
project_id=self.project_id,
|
|
109
|
-
location=self.location,
|
|
110
|
-
prompt=self.prompt,
|
|
111
|
-
pretrained_model=self.pretrained_model,
|
|
112
|
-
temperature=self.temperature,
|
|
113
|
-
max_output_tokens=self.max_output_tokens,
|
|
114
|
-
top_p=self.top_p,
|
|
115
|
-
top_k=self.top_k,
|
|
116
|
-
)
|
|
117
|
-
|
|
118
|
-
self.log.info("Model response: %s", response)
|
|
119
|
-
self.xcom_push(context, key="model_response", value=response)
|
|
120
|
-
|
|
121
|
-
return response
|
|
122
|
-
|
|
123
|
-
|
|
124
44
|
class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
125
45
|
"""
|
|
126
46
|
Uses the Vertex AI Embeddings API to generate embeddings based on prompt.
|
|
@@ -130,9 +50,8 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
130
50
|
:param location: Required. The ID of the Google Cloud location that the
|
|
131
51
|
service belongs to (templated).
|
|
132
52
|
:param prompt: Required. Inputs or queries that a user or a program gives
|
|
133
|
-
to the Vertex AI
|
|
134
|
-
:param pretrained_model:
|
|
135
|
-
optimized for performing text embeddings.
|
|
53
|
+
to the Vertex AI Generative Model API, in order to elicit a specific response (templated).
|
|
54
|
+
:param pretrained_model: Required. Model, optimized for performing text embeddings.
|
|
136
55
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
137
56
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
138
57
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -144,7 +63,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
144
63
|
account from the list granting this role to the originating account (templated).
|
|
145
64
|
"""
|
|
146
65
|
|
|
147
|
-
template_fields = ("location", "project_id", "impersonation_chain", "prompt")
|
|
66
|
+
template_fields = ("location", "project_id", "impersonation_chain", "prompt", "pretrained_model")
|
|
148
67
|
|
|
149
68
|
def __init__(
|
|
150
69
|
self,
|
|
@@ -152,7 +71,7 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
152
71
|
project_id: str,
|
|
153
72
|
location: str,
|
|
154
73
|
prompt: str,
|
|
155
|
-
pretrained_model: str
|
|
74
|
+
pretrained_model: str,
|
|
156
75
|
gcp_conn_id: str = "google_cloud_default",
|
|
157
76
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
158
77
|
**kwargs,
|
|
@@ -180,11 +99,16 @@ class TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
|
|
|
180
99
|
)
|
|
181
100
|
|
|
182
101
|
self.log.info("Model response: %s", response)
|
|
183
|
-
|
|
102
|
+
context["ti"].xcom_push(key="model_response", value=response)
|
|
184
103
|
|
|
185
104
|
return response
|
|
186
105
|
|
|
187
106
|
|
|
107
|
+
@deprecated(
|
|
108
|
+
planned_removal_date="January 3, 2026",
|
|
109
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
|
|
110
|
+
category=AirflowProviderDeprecationWarning,
|
|
111
|
+
)
|
|
188
112
|
class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
189
113
|
"""
|
|
190
114
|
Use the Vertex AI Gemini Pro foundation model to generate content.
|
|
@@ -199,10 +123,9 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
|
199
123
|
:param safety_settings: Optional. Per request settings for blocking unsafe content.
|
|
200
124
|
:param tools: Optional. A list of tools available to the model during evaluation, such as a data store.
|
|
201
125
|
:param system_instruction: Optional. An instruction given to the model to guide its behavior.
|
|
202
|
-
:param pretrained_model:
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
output text and code.
|
|
126
|
+
:param pretrained_model: Required. The name of the model to use for content generation,
|
|
127
|
+
which can be a text-only or multimodal model. For example, `gemini-pro` or
|
|
128
|
+
`gemini-pro-vision`.
|
|
206
129
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
207
130
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
208
131
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -226,7 +149,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
|
226
149
|
generation_config: dict | None = None,
|
|
227
150
|
safety_settings: dict | None = None,
|
|
228
151
|
system_instruction: str | None = None,
|
|
229
|
-
pretrained_model: str
|
|
152
|
+
pretrained_model: str,
|
|
230
153
|
gcp_conn_id: str = "google_cloud_default",
|
|
231
154
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
232
155
|
**kwargs,
|
|
@@ -260,11 +183,16 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
|
|
|
260
183
|
)
|
|
261
184
|
|
|
262
185
|
self.log.info("Model response: %s", response)
|
|
263
|
-
|
|
186
|
+
context["ti"].xcom_push(key="model_response", value=response)
|
|
264
187
|
|
|
265
188
|
return response
|
|
266
189
|
|
|
267
190
|
|
|
191
|
+
@deprecated(
|
|
192
|
+
planned_removal_date="January 3, 2026",
|
|
193
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAISupervisedFineTuningTrainOperator",
|
|
194
|
+
category=AirflowProviderDeprecationWarning,
|
|
195
|
+
)
|
|
268
196
|
class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
269
197
|
"""
|
|
270
198
|
Use the Supervised Fine Tuning API to create a tuning job.
|
|
@@ -298,7 +226,14 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
298
226
|
account from the list granting this role to the originating account (templated).
|
|
299
227
|
"""
|
|
300
228
|
|
|
301
|
-
template_fields = (
|
|
229
|
+
template_fields = (
|
|
230
|
+
"location",
|
|
231
|
+
"project_id",
|
|
232
|
+
"impersonation_chain",
|
|
233
|
+
"train_dataset",
|
|
234
|
+
"validation_dataset",
|
|
235
|
+
"source_model",
|
|
236
|
+
)
|
|
302
237
|
|
|
303
238
|
def __init__(
|
|
304
239
|
self,
|
|
@@ -310,7 +245,7 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
310
245
|
tuned_model_display_name: str | None = None,
|
|
311
246
|
validation_dataset: str | None = None,
|
|
312
247
|
epochs: int | None = None,
|
|
313
|
-
adapter_size:
|
|
248
|
+
adapter_size: Literal[1, 4, 8, 16] | None = None,
|
|
314
249
|
learning_rate_multiplier: float | None = None,
|
|
315
250
|
gcp_conn_id: str = "google_cloud_default",
|
|
316
251
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
@@ -349,8 +284,8 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
349
284
|
self.log.info("Tuned Model Name: %s", response.tuned_model_name)
|
|
350
285
|
self.log.info("Tuned Model Endpoint Name: %s", response.tuned_model_endpoint_name)
|
|
351
286
|
|
|
352
|
-
|
|
353
|
-
|
|
287
|
+
context["ti"].xcom_push(key="tuned_model_name", value=response.tuned_model_name)
|
|
288
|
+
context["ti"].xcom_push(key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)
|
|
354
289
|
|
|
355
290
|
result = {
|
|
356
291
|
"tuned_model_name": response.tuned_model_name,
|
|
@@ -360,6 +295,11 @@ class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
|
|
|
360
295
|
return result
|
|
361
296
|
|
|
362
297
|
|
|
298
|
+
@deprecated(
|
|
299
|
+
planned_removal_date="January 3, 2026",
|
|
300
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICountTokensOperator",
|
|
301
|
+
category=AirflowProviderDeprecationWarning,
|
|
302
|
+
)
|
|
363
303
|
class CountTokensOperator(GoogleCloudBaseOperator):
|
|
364
304
|
"""
|
|
365
305
|
Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.
|
|
@@ -370,10 +310,9 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
|
370
310
|
service belongs to (templated).
|
|
371
311
|
:param contents: Required. The multi-part content of a message that a user or a program
|
|
372
312
|
gives to the generative model, in order to elicit a specific response.
|
|
373
|
-
:param pretrained_model:
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
output text and code.
|
|
313
|
+
:param pretrained_model: Required. Model, supporting prompts with text-only input,
|
|
314
|
+
including natural language tasks, multi-turn text and code chat,
|
|
315
|
+
and code generation. It can output text and code.
|
|
377
316
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
378
317
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
379
318
|
credentials, or chained list of accounts required to get the access_token
|
|
@@ -393,7 +332,7 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
|
393
332
|
project_id: str,
|
|
394
333
|
location: str,
|
|
395
334
|
contents: list,
|
|
396
|
-
pretrained_model: str
|
|
335
|
+
pretrained_model: str,
|
|
397
336
|
gcp_conn_id: str = "google_cloud_default",
|
|
398
337
|
impersonation_chain: str | Sequence[str] | None = None,
|
|
399
338
|
**kwargs,
|
|
@@ -421,8 +360,8 @@ class CountTokensOperator(GoogleCloudBaseOperator):
|
|
|
421
360
|
self.log.info("Total tokens: %s", response.total_tokens)
|
|
422
361
|
self.log.info("Total billable characters: %s", response.total_billable_characters)
|
|
423
362
|
|
|
424
|
-
|
|
425
|
-
|
|
363
|
+
context["ti"].xcom_push(key="total_tokens", value=response.total_tokens)
|
|
364
|
+
context["ti"].xcom_push(key="total_billable_characters", value=response.total_billable_characters)
|
|
426
365
|
|
|
427
366
|
|
|
428
367
|
class RunEvaluationOperator(GoogleCloudBaseOperator):
|
|
@@ -524,6 +463,11 @@ class RunEvaluationOperator(GoogleCloudBaseOperator):
|
|
|
524
463
|
return response.summary_metrics
|
|
525
464
|
|
|
526
465
|
|
|
466
|
+
@deprecated(
|
|
467
|
+
planned_removal_date="January 3, 2026",
|
|
468
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAICreateCachedContentOperator",
|
|
469
|
+
category=AirflowProviderDeprecationWarning,
|
|
470
|
+
)
|
|
527
471
|
class CreateCachedContentOperator(GoogleCloudBaseOperator):
|
|
528
472
|
"""
|
|
529
473
|
Create CachedContent to reduce the cost of requests that contain repeat content with high input token counts.
|
|
@@ -562,8 +506,8 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
|
|
|
562
506
|
project_id: str,
|
|
563
507
|
location: str,
|
|
564
508
|
model_name: str,
|
|
565
|
-
system_instruction:
|
|
566
|
-
contents: list | None = None,
|
|
509
|
+
system_instruction: Any | None = None,
|
|
510
|
+
contents: list[Any] | None = None,
|
|
567
511
|
ttl_hours: float = 1,
|
|
568
512
|
display_name: str | None = None,
|
|
569
513
|
gcp_conn_id: str = "google_cloud_default",
|
|
@@ -603,6 +547,11 @@ class CreateCachedContentOperator(GoogleCloudBaseOperator):
|
|
|
603
547
|
return cached_content_name
|
|
604
548
|
|
|
605
549
|
|
|
550
|
+
@deprecated(
|
|
551
|
+
planned_removal_date="January 3, 2026",
|
|
552
|
+
use_instead="airflow.providers.google.cloud.operators.gen_ai.generative_model.GenAIGenerateContentOperator",
|
|
553
|
+
category=AirflowProviderDeprecationWarning,
|
|
554
|
+
)
|
|
606
555
|
class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
|
|
607
556
|
"""
|
|
608
557
|
Generate a response from CachedContent.
|
|
@@ -674,3 +623,73 @@ class GenerateFromCachedContentOperator(GoogleCloudBaseOperator):
|
|
|
674
623
|
self.log.info("Cached Content Response: %s", cached_content_text)
|
|
675
624
|
|
|
676
625
|
return cached_content_text
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
@deprecated(
|
|
629
|
+
planned_removal_date="January 3, 2026",
|
|
630
|
+
use_instead="airflow.providers.google.cloud.operators.vertex_ai.experiment_service.DeleteExperimentRunOperator",
|
|
631
|
+
category=AirflowProviderDeprecationWarning,
|
|
632
|
+
)
|
|
633
|
+
class DeleteExperimentRunOperator(GoogleCloudBaseOperator):
|
|
634
|
+
"""
|
|
635
|
+
Use the Rapid Evaluation API to evaluate a model.
|
|
636
|
+
|
|
637
|
+
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
|
638
|
+
:param location: Required. The ID of the Google Cloud location that the service belongs to.
|
|
639
|
+
:param experiment_name: Required. The name of the evaluation experiment.
|
|
640
|
+
:param experiment_run_name: Required. The specific run name or ID for this experiment.
|
|
641
|
+
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
|
642
|
+
:param impersonation_chain: Optional service account to impersonate using short-term
|
|
643
|
+
credentials, or chained list of accounts required to get the access_token
|
|
644
|
+
of the last account in the list, which will be impersonated in the request.
|
|
645
|
+
If set as a string, the account must grant the originating account
|
|
646
|
+
the Service Account Token Creator IAM role.
|
|
647
|
+
If set as a sequence, the identities from the list must grant
|
|
648
|
+
Service Account Token Creator IAM role to the directly preceding identity, with first
|
|
649
|
+
account from the list granting this role to the originating account (templated).
|
|
650
|
+
"""
|
|
651
|
+
|
|
652
|
+
template_fields = (
|
|
653
|
+
"location",
|
|
654
|
+
"project_id",
|
|
655
|
+
"impersonation_chain",
|
|
656
|
+
"experiment_name",
|
|
657
|
+
"experiment_run_name",
|
|
658
|
+
)
|
|
659
|
+
|
|
660
|
+
def __init__(
|
|
661
|
+
self,
|
|
662
|
+
*,
|
|
663
|
+
project_id: str,
|
|
664
|
+
location: str,
|
|
665
|
+
experiment_name: str,
|
|
666
|
+
experiment_run_name: str,
|
|
667
|
+
gcp_conn_id: str = "google_cloud_default",
|
|
668
|
+
impersonation_chain: str | Sequence[str] | None = None,
|
|
669
|
+
**kwargs,
|
|
670
|
+
) -> None:
|
|
671
|
+
super().__init__(**kwargs)
|
|
672
|
+
self.project_id = project_id
|
|
673
|
+
self.location = location
|
|
674
|
+
self.experiment_name = experiment_name
|
|
675
|
+
self.experiment_run_name = experiment_run_name
|
|
676
|
+
self.gcp_conn_id = gcp_conn_id
|
|
677
|
+
self.impersonation_chain = impersonation_chain
|
|
678
|
+
|
|
679
|
+
def execute(self, context: Context) -> None:
|
|
680
|
+
self.hook = ExperimentRunHook(
|
|
681
|
+
gcp_conn_id=self.gcp_conn_id,
|
|
682
|
+
impersonation_chain=self.impersonation_chain,
|
|
683
|
+
)
|
|
684
|
+
|
|
685
|
+
try:
|
|
686
|
+
self.hook.delete_experiment_run(
|
|
687
|
+
project_id=self.project_id,
|
|
688
|
+
location=self.location,
|
|
689
|
+
experiment_name=self.experiment_name,
|
|
690
|
+
experiment_run_name=self.experiment_run_name,
|
|
691
|
+
)
|
|
692
|
+
except exceptions.NotFound:
|
|
693
|
+
raise AirflowException(f"Experiment Run with name {self.experiment_run_name} not found")
|
|
694
|
+
|
|
695
|
+
self.log.info("Deleted experiment run: %s", self.experiment_run_name)
|
|
@@ -23,6 +23,10 @@ from __future__ import annotations
|
|
|
23
23
|
from collections.abc import Sequence
|
|
24
24
|
from typing import TYPE_CHECKING, Any
|
|
25
25
|
|
|
26
|
+
from google.api_core.exceptions import NotFound
|
|
27
|
+
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
28
|
+
from google.cloud.aiplatform_v1 import types
|
|
29
|
+
|
|
26
30
|
from airflow.configuration import conf
|
|
27
31
|
from airflow.exceptions import AirflowException
|
|
28
32
|
from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
|
|
@@ -34,15 +38,13 @@ from airflow.providers.google.cloud.links.vertex_ai import (
|
|
|
34
38
|
)
|
|
35
39
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
|
36
40
|
from airflow.providers.google.cloud.triggers.vertex_ai import CreateHyperparameterTuningJobTrigger
|
|
37
|
-
from google.api_core.exceptions import NotFound
|
|
38
|
-
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
39
|
-
from google.cloud.aiplatform_v1 import types
|
|
40
41
|
|
|
41
42
|
if TYPE_CHECKING:
|
|
42
|
-
from airflow.utils.context import Context
|
|
43
43
|
from google.api_core.retry import Retry
|
|
44
44
|
from google.cloud.aiplatform import HyperparameterTuningJob, gapic, hyperparameter_tuning
|
|
45
45
|
|
|
46
|
+
from airflow.providers.common.compat.sdk import Context
|
|
47
|
+
|
|
46
48
|
|
|
47
49
|
class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
48
50
|
"""
|
|
@@ -255,10 +257,8 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
255
257
|
hyperparameter_tuning_job_id = hyperparameter_tuning_job.name
|
|
256
258
|
self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id)
|
|
257
259
|
|
|
258
|
-
|
|
259
|
-
VertexAITrainingLink.persist(
|
|
260
|
-
context=context, task_instance=self, training_id=hyperparameter_tuning_job_id
|
|
261
|
-
)
|
|
260
|
+
context["ti"].xcom_push(key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
|
|
261
|
+
VertexAITrainingLink.persist(context=context, training_id=hyperparameter_tuning_job_id)
|
|
262
262
|
|
|
263
263
|
if self.deferrable:
|
|
264
264
|
self.defer(
|
|
@@ -353,9 +353,7 @@ class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
353
353
|
timeout=self.timeout,
|
|
354
354
|
metadata=self.metadata,
|
|
355
355
|
)
|
|
356
|
-
VertexAITrainingLink.persist(
|
|
357
|
-
context=context, task_instance=self, training_id=self.hyperparameter_tuning_job_id
|
|
358
|
-
)
|
|
356
|
+
VertexAITrainingLink.persist(context=context, training_id=self.hyperparameter_tuning_job_id)
|
|
359
357
|
self.log.info("Hyperparameter tuning job was gotten.")
|
|
360
358
|
return types.HyperparameterTuningJob.to_dict(result)
|
|
361
359
|
except NotFound:
|
|
@@ -485,6 +483,12 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
485
483
|
self.gcp_conn_id = gcp_conn_id
|
|
486
484
|
self.impersonation_chain = impersonation_chain
|
|
487
485
|
|
|
486
|
+
@property
|
|
487
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
488
|
+
return {
|
|
489
|
+
"project_id": self.project_id,
|
|
490
|
+
}
|
|
491
|
+
|
|
488
492
|
def execute(self, context: Context):
|
|
489
493
|
hook = HyperparameterTuningJobHook(
|
|
490
494
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -501,5 +505,5 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
|
501
505
|
timeout=self.timeout,
|
|
502
506
|
metadata=self.metadata,
|
|
503
507
|
)
|
|
504
|
-
VertexAIHyperparameterTuningJobListLink.persist(context=context
|
|
508
|
+
VertexAIHyperparameterTuningJobListLink.persist(context=context)
|
|
505
509
|
return [types.HyperparameterTuningJob.to_dict(result) for result in results]
|
|
@@ -20,7 +20,11 @@
|
|
|
20
20
|
from __future__ import annotations
|
|
21
21
|
|
|
22
22
|
from collections.abc import Sequence
|
|
23
|
-
from typing import TYPE_CHECKING
|
|
23
|
+
from typing import TYPE_CHECKING, Any
|
|
24
|
+
|
|
25
|
+
from google.api_core.exceptions import NotFound
|
|
26
|
+
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
27
|
+
from google.cloud.aiplatform_v1.types import Model, model_service
|
|
24
28
|
|
|
25
29
|
from airflow.providers.google.cloud.hooks.vertex_ai.model_service import ModelServiceHook
|
|
26
30
|
from airflow.providers.google.cloud.links.vertex_ai import (
|
|
@@ -29,14 +33,12 @@ from airflow.providers.google.cloud.links.vertex_ai import (
|
|
|
29
33
|
VertexAIModelListLink,
|
|
30
34
|
)
|
|
31
35
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
|
32
|
-
from google.api_core.exceptions import NotFound
|
|
33
|
-
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
|
34
|
-
from google.cloud.aiplatform_v1.types import Model, model_service
|
|
35
36
|
|
|
36
37
|
if TYPE_CHECKING:
|
|
37
|
-
from airflow.utils.context import Context
|
|
38
38
|
from google.api_core.retry import Retry
|
|
39
39
|
|
|
40
|
+
from airflow.providers.common.compat.sdk import Context
|
|
41
|
+
|
|
40
42
|
|
|
41
43
|
class DeleteModelOperator(GoogleCloudBaseOperator):
|
|
42
44
|
"""
|
|
@@ -159,6 +161,13 @@ class GetModelOperator(GoogleCloudBaseOperator):
|
|
|
159
161
|
self.gcp_conn_id = gcp_conn_id
|
|
160
162
|
self.impersonation_chain = impersonation_chain
|
|
161
163
|
|
|
164
|
+
@property
|
|
165
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
166
|
+
return {
|
|
167
|
+
"region": self.region,
|
|
168
|
+
"project_id": self.project_id,
|
|
169
|
+
}
|
|
170
|
+
|
|
162
171
|
def execute(self, context: Context):
|
|
163
172
|
hook = ModelServiceHook(
|
|
164
173
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -177,8 +186,8 @@ class GetModelOperator(GoogleCloudBaseOperator):
|
|
|
177
186
|
)
|
|
178
187
|
self.log.info("Model found. Model ID: %s", self.model_id)
|
|
179
188
|
|
|
180
|
-
|
|
181
|
-
VertexAIModelLink.persist(context=context,
|
|
189
|
+
context["ti"].xcom_push(key="model_id", value=self.model_id)
|
|
190
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
182
191
|
return Model.to_dict(model)
|
|
183
192
|
except NotFound:
|
|
184
193
|
self.log.info("The Model ID %s does not exist.", self.model_id)
|
|
@@ -255,7 +264,12 @@ class ExportModelOperator(GoogleCloudBaseOperator):
|
|
|
255
264
|
metadata=self.metadata,
|
|
256
265
|
)
|
|
257
266
|
hook.wait_for_operation(timeout=self.timeout, operation=operation)
|
|
258
|
-
VertexAIModelExportLink.persist(
|
|
267
|
+
VertexAIModelExportLink.persist(
|
|
268
|
+
context=context,
|
|
269
|
+
output_config=self.output_config,
|
|
270
|
+
model_id=self.model_id,
|
|
271
|
+
project_id=self.project_id,
|
|
272
|
+
)
|
|
259
273
|
self.log.info("Model was exported.")
|
|
260
274
|
except NotFound:
|
|
261
275
|
self.log.info("The Model ID %s does not exist.", self.model_id)
|
|
@@ -333,6 +347,12 @@ class ListModelsOperator(GoogleCloudBaseOperator):
|
|
|
333
347
|
self.gcp_conn_id = gcp_conn_id
|
|
334
348
|
self.impersonation_chain = impersonation_chain
|
|
335
349
|
|
|
350
|
+
@property
|
|
351
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
352
|
+
return {
|
|
353
|
+
"project_id": self.project_id,
|
|
354
|
+
}
|
|
355
|
+
|
|
336
356
|
def execute(self, context: Context):
|
|
337
357
|
hook = ModelServiceHook(
|
|
338
358
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -350,7 +370,7 @@ class ListModelsOperator(GoogleCloudBaseOperator):
|
|
|
350
370
|
timeout=self.timeout,
|
|
351
371
|
metadata=self.metadata,
|
|
352
372
|
)
|
|
353
|
-
VertexAIModelListLink.persist(context=context
|
|
373
|
+
VertexAIModelListLink.persist(context=context)
|
|
354
374
|
return [Model.to_dict(result) for result in results]
|
|
355
375
|
|
|
356
376
|
|
|
@@ -405,6 +425,13 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
|
405
425
|
self.gcp_conn_id = gcp_conn_id
|
|
406
426
|
self.impersonation_chain = impersonation_chain
|
|
407
427
|
|
|
428
|
+
@property
|
|
429
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
430
|
+
return {
|
|
431
|
+
"region": self.region,
|
|
432
|
+
"project_id": self.project_id,
|
|
433
|
+
}
|
|
434
|
+
|
|
408
435
|
def execute(self, context: Context):
|
|
409
436
|
hook = ModelServiceHook(
|
|
410
437
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -426,8 +453,8 @@ class UploadModelOperator(GoogleCloudBaseOperator):
|
|
|
426
453
|
model_id = hook.extract_model_id(model_resp)
|
|
427
454
|
self.log.info("Model was uploaded. Model ID: %s", model_id)
|
|
428
455
|
|
|
429
|
-
|
|
430
|
-
VertexAIModelLink.persist(context=context,
|
|
456
|
+
context["ti"].xcom_push(key="model_id", value=model_id)
|
|
457
|
+
VertexAIModelLink.persist(context=context, model_id=model_id)
|
|
431
458
|
return model_resp
|
|
432
459
|
|
|
433
460
|
|
|
@@ -551,6 +578,13 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
|
|
|
551
578
|
self.gcp_conn_id = gcp_conn_id
|
|
552
579
|
self.impersonation_chain = impersonation_chain
|
|
553
580
|
|
|
581
|
+
@property
|
|
582
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
583
|
+
return {
|
|
584
|
+
"region": self.region,
|
|
585
|
+
"project_id": self.project_id,
|
|
586
|
+
}
|
|
587
|
+
|
|
554
588
|
def execute(self, context: Context):
|
|
555
589
|
hook = ModelServiceHook(
|
|
556
590
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -569,7 +603,7 @@ class SetDefaultVersionOnModelOperator(GoogleCloudBaseOperator):
|
|
|
569
603
|
timeout=self.timeout,
|
|
570
604
|
metadata=self.metadata,
|
|
571
605
|
)
|
|
572
|
-
VertexAIModelLink.persist(context=context,
|
|
606
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
573
607
|
return Model.to_dict(updated_model)
|
|
574
608
|
|
|
575
609
|
|
|
@@ -625,6 +659,13 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
625
659
|
self.gcp_conn_id = gcp_conn_id
|
|
626
660
|
self.impersonation_chain = impersonation_chain
|
|
627
661
|
|
|
662
|
+
@property
|
|
663
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
664
|
+
return {
|
|
665
|
+
"region": self.region,
|
|
666
|
+
"project_id": self.project_id,
|
|
667
|
+
}
|
|
668
|
+
|
|
628
669
|
def execute(self, context: Context):
|
|
629
670
|
hook = ModelServiceHook(
|
|
630
671
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -643,7 +684,7 @@ class AddVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
643
684
|
timeout=self.timeout,
|
|
644
685
|
metadata=self.metadata,
|
|
645
686
|
)
|
|
646
|
-
VertexAIModelLink.persist(context=context,
|
|
687
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
647
688
|
return Model.to_dict(updated_model)
|
|
648
689
|
|
|
649
690
|
|
|
@@ -699,6 +740,13 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
699
740
|
self.gcp_conn_id = gcp_conn_id
|
|
700
741
|
self.impersonation_chain = impersonation_chain
|
|
701
742
|
|
|
743
|
+
@property
|
|
744
|
+
def extra_links_params(self) -> dict[str, Any]:
|
|
745
|
+
return {
|
|
746
|
+
"region": self.region,
|
|
747
|
+
"project_id": self.project_id,
|
|
748
|
+
}
|
|
749
|
+
|
|
702
750
|
def execute(self, context: Context):
|
|
703
751
|
hook = ModelServiceHook(
|
|
704
752
|
gcp_conn_id=self.gcp_conn_id,
|
|
@@ -719,7 +767,7 @@ class DeleteVersionAliasesOnModelOperator(GoogleCloudBaseOperator):
|
|
|
719
767
|
timeout=self.timeout,
|
|
720
768
|
metadata=self.metadata,
|
|
721
769
|
)
|
|
722
|
-
VertexAIModelLink.persist(context=context,
|
|
770
|
+
VertexAIModelLink.persist(context=context, model_id=self.model_id)
|
|
723
771
|
return Model.to_dict(updated_model)
|
|
724
772
|
|
|
725
773
|
|