apache-airflow-providers-google 10.16.0rc1__py3-none-any.whl → 10.17.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/google/__init__.py +1 -1
- airflow/providers/google/ads/hooks/ads.py +5 -4
- airflow/providers/google/ads/operators/ads.py +1 -0
- airflow/providers/google/cloud/example_dags/example_cloud_sql_query.py +1 -0
- airflow/providers/google/cloud/example_dags/example_cloud_task.py +1 -0
- airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py +1 -0
- airflow/providers/google/cloud/example_dags/example_looker.py +1 -0
- airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py +1 -0
- airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py +1 -0
- airflow/providers/google/cloud/fs/gcs.py +1 -2
- airflow/providers/google/cloud/hooks/automl.py +1 -0
- airflow/providers/google/cloud/hooks/bigquery.py +87 -24
- airflow/providers/google/cloud/hooks/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/hooks/bigtable.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_build.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_memorystore.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_sql.py +1 -0
- airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +9 -4
- airflow/providers/google/cloud/hooks/compute.py +1 -0
- airflow/providers/google/cloud/hooks/compute_ssh.py +2 -2
- airflow/providers/google/cloud/hooks/dataflow.py +6 -5
- airflow/providers/google/cloud/hooks/datafusion.py +1 -0
- airflow/providers/google/cloud/hooks/datapipeline.py +1 -0
- airflow/providers/google/cloud/hooks/dataplex.py +1 -0
- airflow/providers/google/cloud/hooks/dataprep.py +1 -0
- airflow/providers/google/cloud/hooks/dataproc.py +3 -2
- airflow/providers/google/cloud/hooks/dataproc_metastore.py +1 -0
- airflow/providers/google/cloud/hooks/datastore.py +1 -0
- airflow/providers/google/cloud/hooks/dlp.py +1 -0
- airflow/providers/google/cloud/hooks/functions.py +1 -0
- airflow/providers/google/cloud/hooks/gcs.py +12 -5
- airflow/providers/google/cloud/hooks/kms.py +1 -0
- airflow/providers/google/cloud/hooks/kubernetes_engine.py +178 -300
- airflow/providers/google/cloud/hooks/life_sciences.py +1 -0
- airflow/providers/google/cloud/hooks/looker.py +1 -0
- airflow/providers/google/cloud/hooks/mlengine.py +1 -0
- airflow/providers/google/cloud/hooks/natural_language.py +1 -0
- airflow/providers/google/cloud/hooks/os_login.py +1 -0
- airflow/providers/google/cloud/hooks/pubsub.py +1 -0
- airflow/providers/google/cloud/hooks/secret_manager.py +1 -0
- airflow/providers/google/cloud/hooks/spanner.py +1 -0
- airflow/providers/google/cloud/hooks/speech_to_text.py +1 -0
- airflow/providers/google/cloud/hooks/stackdriver.py +1 -0
- airflow/providers/google/cloud/hooks/text_to_speech.py +1 -0
- airflow/providers/google/cloud/hooks/translate.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/batch_prediction_job.py +255 -3
- airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/dataset.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/endpoint_service.py +1 -0
- airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +197 -0
- airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py +9 -9
- airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py +231 -12
- airflow/providers/google/cloud/hooks/video_intelligence.py +1 -0
- airflow/providers/google/cloud/hooks/vision.py +1 -0
- airflow/providers/google/cloud/links/automl.py +1 -0
- airflow/providers/google/cloud/links/bigquery.py +1 -0
- airflow/providers/google/cloud/links/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/links/cloud_memorystore.py +1 -0
- airflow/providers/google/cloud/links/cloud_sql.py +1 -0
- airflow/providers/google/cloud/links/cloud_tasks.py +1 -0
- airflow/providers/google/cloud/links/compute.py +1 -0
- airflow/providers/google/cloud/links/datacatalog.py +1 -0
- airflow/providers/google/cloud/links/dataflow.py +1 -0
- airflow/providers/google/cloud/links/dataform.py +1 -0
- airflow/providers/google/cloud/links/datafusion.py +1 -0
- airflow/providers/google/cloud/links/dataplex.py +1 -0
- airflow/providers/google/cloud/links/dataproc.py +1 -0
- airflow/providers/google/cloud/links/kubernetes_engine.py +28 -0
- airflow/providers/google/cloud/links/mlengine.py +1 -0
- airflow/providers/google/cloud/links/pubsub.py +1 -0
- airflow/providers/google/cloud/links/spanner.py +1 -0
- airflow/providers/google/cloud/links/stackdriver.py +1 -0
- airflow/providers/google/cloud/links/workflows.py +1 -0
- airflow/providers/google/cloud/log/stackdriver_task_handler.py +18 -4
- airflow/providers/google/cloud/operators/automl.py +1 -0
- airflow/providers/google/cloud/operators/bigquery.py +21 -0
- airflow/providers/google/cloud/operators/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/operators/bigtable.py +1 -0
- airflow/providers/google/cloud/operators/cloud_base.py +1 -0
- airflow/providers/google/cloud/operators/cloud_build.py +1 -0
- airflow/providers/google/cloud/operators/cloud_memorystore.py +1 -0
- airflow/providers/google/cloud/operators/cloud_sql.py +1 -0
- airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py +11 -5
- airflow/providers/google/cloud/operators/compute.py +1 -0
- airflow/providers/google/cloud/operators/dataflow.py +1 -0
- airflow/providers/google/cloud/operators/datafusion.py +1 -0
- airflow/providers/google/cloud/operators/datapipeline.py +1 -0
- airflow/providers/google/cloud/operators/dataprep.py +1 -0
- airflow/providers/google/cloud/operators/dataproc.py +3 -2
- airflow/providers/google/cloud/operators/dataproc_metastore.py +1 -0
- airflow/providers/google/cloud/operators/datastore.py +1 -0
- airflow/providers/google/cloud/operators/functions.py +1 -0
- airflow/providers/google/cloud/operators/gcs.py +1 -0
- airflow/providers/google/cloud/operators/kubernetes_engine.py +600 -4
- airflow/providers/google/cloud/operators/life_sciences.py +1 -0
- airflow/providers/google/cloud/operators/looker.py +1 -0
- airflow/providers/google/cloud/operators/mlengine.py +283 -259
- airflow/providers/google/cloud/operators/natural_language.py +1 -0
- airflow/providers/google/cloud/operators/pubsub.py +1 -0
- airflow/providers/google/cloud/operators/spanner.py +1 -0
- airflow/providers/google/cloud/operators/speech_to_text.py +1 -0
- airflow/providers/google/cloud/operators/text_to_speech.py +1 -0
- airflow/providers/google/cloud/operators/translate.py +1 -0
- airflow/providers/google/cloud/operators/translate_speech.py +1 -0
- airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +14 -7
- airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py +67 -13
- airflow/providers/google/cloud/operators/vertex_ai/custom_job.py +26 -8
- airflow/providers/google/cloud/operators/vertex_ai/dataset.py +1 -0
- airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +306 -0
- airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py +29 -48
- airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +52 -17
- airflow/providers/google/cloud/operators/video_intelligence.py +1 -0
- airflow/providers/google/cloud/operators/vision.py +1 -0
- airflow/providers/google/cloud/secrets/secret_manager.py +1 -0
- airflow/providers/google/cloud/sensors/bigquery.py +1 -0
- airflow/providers/google/cloud/sensors/bigquery_dts.py +1 -0
- airflow/providers/google/cloud/sensors/bigtable.py +1 -0
- airflow/providers/google/cloud/sensors/cloud_storage_transfer_service.py +1 -0
- airflow/providers/google/cloud/sensors/dataflow.py +1 -0
- airflow/providers/google/cloud/sensors/dataform.py +1 -0
- airflow/providers/google/cloud/sensors/datafusion.py +1 -0
- airflow/providers/google/cloud/sensors/dataplex.py +1 -0
- airflow/providers/google/cloud/sensors/dataprep.py +1 -0
- airflow/providers/google/cloud/sensors/dataproc.py +1 -0
- airflow/providers/google/cloud/sensors/gcs.py +1 -0
- airflow/providers/google/cloud/sensors/looker.py +1 -0
- airflow/providers/google/cloud/sensors/pubsub.py +1 -0
- airflow/providers/google/cloud/sensors/tasks.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_mysql.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_postgres.py +1 -0
- airflow/providers/google/cloud/transfers/bigquery_to_sql.py +1 -0
- airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +3 -2
- airflow/providers/google/cloud/transfers/gcs_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/gcs_to_sftp.py +1 -0
- airflow/providers/google/cloud/transfers/local_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/mssql_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/mysql_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/postgres_to_gcs.py +19 -1
- airflow/providers/google/cloud/transfers/s3_to_gcs.py +3 -5
- airflow/providers/google/cloud/transfers/sftp_to_gcs.py +1 -0
- airflow/providers/google/cloud/transfers/sql_to_gcs.py +4 -2
- airflow/providers/google/cloud/triggers/bigquery.py +4 -3
- airflow/providers/google/cloud/triggers/cloud_batch.py +1 -1
- airflow/providers/google/cloud/triggers/cloud_run.py +1 -0
- airflow/providers/google/cloud/triggers/cloud_sql.py +2 -0
- airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +14 -2
- airflow/providers/google/cloud/triggers/dataplex.py +1 -0
- airflow/providers/google/cloud/triggers/dataproc.py +1 -0
- airflow/providers/google/cloud/triggers/kubernetes_engine.py +72 -2
- airflow/providers/google/cloud/triggers/mlengine.py +2 -0
- airflow/providers/google/cloud/triggers/pubsub.py +3 -3
- airflow/providers/google/cloud/triggers/vertex_ai.py +107 -15
- airflow/providers/google/cloud/utils/field_sanitizer.py +2 -1
- airflow/providers/google/cloud/utils/field_validator.py +1 -0
- airflow/providers/google/cloud/utils/helpers.py +1 -0
- airflow/providers/google/cloud/utils/mlengine_operator_utils.py +1 -0
- airflow/providers/google/cloud/utils/mlengine_prediction_summary.py +1 -0
- airflow/providers/google/cloud/utils/openlineage.py +1 -0
- airflow/providers/google/common/auth_backend/google_openid.py +1 -0
- airflow/providers/google/common/hooks/base_google.py +2 -1
- airflow/providers/google/common/hooks/discovery_api.py +1 -0
- airflow/providers/google/common/links/storage.py +1 -0
- airflow/providers/google/common/utils/id_token_credentials.py +1 -0
- airflow/providers/google/firebase/hooks/firestore.py +1 -0
- airflow/providers/google/get_provider_info.py +9 -3
- airflow/providers/google/go_module_utils.py +1 -0
- airflow/providers/google/leveldb/hooks/leveldb.py +8 -7
- airflow/providers/google/marketing_platform/example_dags/example_display_video.py +1 -0
- airflow/providers/google/marketing_platform/hooks/analytics_admin.py +1 -0
- airflow/providers/google/marketing_platform/hooks/campaign_manager.py +1 -0
- airflow/providers/google/marketing_platform/hooks/display_video.py +1 -0
- airflow/providers/google/marketing_platform/hooks/search_ads.py +1 -0
- airflow/providers/google/marketing_platform/operators/analytics.py +1 -0
- airflow/providers/google/marketing_platform/operators/analytics_admin.py +4 -2
- airflow/providers/google/marketing_platform/operators/campaign_manager.py +1 -0
- airflow/providers/google/marketing_platform/operators/display_video.py +1 -0
- airflow/providers/google/marketing_platform/operators/search_ads.py +1 -0
- airflow/providers/google/marketing_platform/sensors/campaign_manager.py +1 -0
- airflow/providers/google/marketing_platform/sensors/display_video.py +2 -1
- airflow/providers/google/marketing_platform/sensors/search_ads.py +1 -0
- airflow/providers/google/suite/hooks/calendar.py +1 -0
- airflow/providers/google/suite/hooks/drive.py +1 -0
- airflow/providers/google/suite/hooks/sheets.py +1 -0
- airflow/providers/google/suite/sensors/drive.py +1 -0
- airflow/providers/google/suite/transfers/gcs_to_gdrive.py +7 -0
- airflow/providers/google/suite/transfers/gcs_to_sheets.py +4 -1
- airflow/providers/google/suite/transfers/local_to_drive.py +1 -0
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/METADATA +22 -17
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/RECORD +196 -194
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_google-10.16.0rc1.dist-info → apache_airflow_providers_google-10.17.0rc1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,306 @@
|
|
1
|
+
#
|
2
|
+
# Licensed to the Apache Software Foundation (ASF) under one
|
3
|
+
# or more contributor license agreements. See the NOTICE file
|
4
|
+
# distributed with this work for additional information
|
5
|
+
# regarding copyright ownership. The ASF licenses this file
|
6
|
+
# to you under the Apache License, Version 2.0 (the
|
7
|
+
# "License"); you may not use this file except in compliance
|
8
|
+
# with the License. You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing,
|
13
|
+
# software distributed under the License is distributed on an
|
14
|
+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
15
|
+
# KIND, either express or implied. See the License for the
|
16
|
+
# specific language governing permissions and limitations
|
17
|
+
# under the License.
|
18
|
+
"""This module contains Google Vertex AI Generative AI operators."""
|
19
|
+
|
20
|
+
from __future__ import annotations
|
21
|
+
|
22
|
+
from typing import TYPE_CHECKING, Sequence
|
23
|
+
|
24
|
+
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
|
25
|
+
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
26
|
+
|
27
|
+
if TYPE_CHECKING:
|
28
|
+
from airflow.utils.context import Context
|
29
|
+
|
30
|
+
|
31
|
+
class PromptLanguageModelOperator(GoogleCloudBaseOperator):
    """
    Uses the Vertex AI PaLM API to generate natural language text.

    :param project_id: Required. The ID of the Google Cloud project that the
        service belongs to (templated).
    :param location: Required. The ID of the Google Cloud location that the
        service belongs to (templated).
    :param prompt: Required. Inputs or queries that a user or a program gives
        to the Vertex AI PaLM API, in order to elicit a specific response (templated).
    :param pretrained_model: By default uses the pre-trained model `text-bison`,
        optimized for performing natural language tasks such as classification,
        summarization, extraction, content creation, and ideation.
    :param temperature: Temperature controls the degree of randomness in token
        selection. Defaults to 0.0.
    :param max_output_tokens: Token limit determines the maximum amount of text
        output. Defaults to 256.
    :param top_p: Tokens are selected from most probable to least until the sum
        of their probabilities equals the top_p value. Defaults to 0.8.
    :param top_k: A top_k of 1 means the selected token is the most probable
        among all tokens. Defaults to 40.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    """

    # The parameters documented as "(templated)" must be listed here; without
    # ``template_fields`` Airflow performs no Jinja rendering on any of them.
    template_fields: Sequence[str] = ("location", "project_id", "impersonation_chain", "prompt")

    def __init__(
        self,
        *,
        project_id: str,
        location: str,
        prompt: str,
        pretrained_model: str = "text-bison",
        temperature: float = 0.0,
        max_output_tokens: int = 256,
        top_p: float = 0.8,
        top_k: int = 40,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location
        self.prompt = prompt
        self.pretrained_model = pretrained_model
        self.temperature = temperature
        self.max_output_tokens = max_output_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        """
        Send the prompt to the language model through ``GenerativeModelHook``.

        :return: The value returned by ``GenerativeModelHook.prompt_language_model``.
            The same value is also pushed to XCom under the key ``prompt_response``.
        """
        self.hook = GenerativeModelHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

        self.log.info("Submitting prompt")
        response = self.hook.prompt_language_model(
            project_id=self.project_id,
            location=self.location,
            prompt=self.prompt,
            pretrained_model=self.pretrained_model,
            temperature=self.temperature,
            max_output_tokens=self.max_output_tokens,
            top_p=self.top_p,
            top_k=self.top_k,
        )

        self.log.info("Model response: %s", response)
        self.xcom_push(context, key="prompt_response", value=response)

        return response
|
112
|
+
|
113
|
+
|
114
|
+
class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator):
    """
    Uses the Vertex AI PaLM API to generate text embeddings for a prompt.

    :param project_id: Required. The ID of the Google Cloud project that the
        service belongs to (templated).
    :param location: Required. The ID of the Google Cloud location that the
        service belongs to (templated).
    :param prompt: Required. The text to generate embeddings for, sent to the
        Vertex AI PaLM API (templated).
    :param pretrained_model: By default uses the pre-trained model `textembedding-gecko`,
        optimized for performing text embeddings.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    """

    # The parameters documented as "(templated)" must be listed here; without
    # ``template_fields`` Airflow performs no Jinja rendering on any of them.
    template_fields: Sequence[str] = ("location", "project_id", "impersonation_chain", "prompt")

    def __init__(
        self,
        *,
        project_id: str,
        location: str,
        prompt: str,
        pretrained_model: str = "textembedding-gecko",
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location
        self.prompt = prompt
        self.pretrained_model = pretrained_model
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        """
        Request text embeddings for the prompt through ``GenerativeModelHook``.

        :return: The value returned by ``GenerativeModelHook.generate_text_embeddings``.
            The same value is also pushed to XCom under the key ``prompt_response``.
        """
        self.hook = GenerativeModelHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

        self.log.info("Generating text embeddings")
        response = self.hook.generate_text_embeddings(
            project_id=self.project_id,
            location=self.location,
            prompt=self.prompt,
            pretrained_model=self.pretrained_model,
        )

        self.log.info("Model response: %s", response)
        self.xcom_push(context, key="prompt_response", value=response)

        return response
|
174
|
+
|
175
|
+
|
176
|
+
class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
    """
    Use the Vertex AI Gemini Pro foundation model to generate natural language text.

    :param project_id: Required. The ID of the Google Cloud project that the
        service belongs to (templated).
    :param location: Required. The ID of the Google Cloud location that the
        service belongs to (templated).
    :param prompt: Required. Inputs or queries that a user or a program gives
        to the Multi-modal model, in order to elicit a specific response (templated).
    :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
        supporting prompts with text-only input, including natural language
        tasks, multi-turn text and code chat, and code generation. It can
        output text and code.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    """

    # The parameters documented as "(templated)" must be listed here; without
    # ``template_fields`` Airflow performs no Jinja rendering on any of them.
    template_fields: Sequence[str] = ("location", "project_id", "impersonation_chain", "prompt")

    def __init__(
        self,
        *,
        project_id: str,
        location: str,
        prompt: str,
        pretrained_model: str = "gemini-pro",
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location
        self.prompt = prompt
        self.pretrained_model = pretrained_model
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        """
        Send the text prompt to the multimodal model through ``GenerativeModelHook``.

        :return: The value returned by ``GenerativeModelHook.prompt_multimodal_model``.
            The same value is also pushed to XCom under the key ``prompt_response``.
        """
        self.hook = GenerativeModelHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

        # Log before the remote call, matching the sibling operators in this module.
        self.log.info("Submitting prompt")
        response = self.hook.prompt_multimodal_model(
            project_id=self.project_id,
            location=self.location,
            prompt=self.prompt,
            pretrained_model=self.pretrained_model,
        )

        self.log.info("Model response: %s", response)
        self.xcom_push(context, key="prompt_response", value=response)

        return response
|
236
|
+
|
237
|
+
|
238
|
+
class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
    """
    Use the Vertex AI Gemini Pro Vision foundation model to generate text from a prompt plus media.

    :param project_id: Required. The ID of the Google Cloud project that the
        service belongs to (templated).
    :param location: Required. The ID of the Google Cloud location that the
        service belongs to (templated).
    :param prompt: Required. Inputs or queries that a user or a program gives
        to the Multi-modal model, in order to elicit a specific response (templated).
    :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
        a multimodal model that accepts the text prompt together with the media
        file referenced by ``media_gcs_path``.
    :param media_gcs_path: A GCS path to a media file such as an image or a video.
        Passed to the multi-modal model as part of the prompt. Used with vision models (templated).
    :param mime_type: The MIME type of the file at ``media_gcs_path`` (for example
        ``image/jpeg``). It is forwarded to the hook together with the path; this
        operator does not inspect the file itself.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    """

    # The parameters documented as "(templated)" must be listed here; without
    # ``template_fields`` Airflow performs no Jinja rendering on any of them.
    template_fields: Sequence[str] = (
        "location",
        "project_id",
        "impersonation_chain",
        "prompt",
        "media_gcs_path",
    )

    def __init__(
        self,
        *,
        project_id: str,
        location: str,
        prompt: str,
        media_gcs_path: str,
        mime_type: str,
        pretrained_model: str = "gemini-pro-vision",
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location
        self.prompt = prompt
        self.pretrained_model = pretrained_model
        self.media_gcs_path = media_gcs_path
        self.mime_type = mime_type
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        """
        Send the prompt and media reference to the model through ``GenerativeModelHook``.

        :return: The value returned by ``GenerativeModelHook.prompt_multimodal_model_with_media``.
            The same value is also pushed to XCom under the key ``prompt_response``.
        """
        self.hook = GenerativeModelHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )

        # Log before the remote call, matching the sibling operators in this module.
        self.log.info("Submitting prompt")
        response = self.hook.prompt_multimodal_model_with_media(
            project_id=self.project_id,
            location=self.location,
            prompt=self.prompt,
            pretrained_model=self.pretrained_model,
            media_gcs_path=self.media_gcs_path,
            mime_type=self.mime_type,
        )

        self.log.info("Model response: %s", response)
        self.xcom_push(context, key="prompt_response", value=response)

        return response
|
@@ -20,14 +20,15 @@
|
|
20
20
|
|
21
21
|
from __future__ import annotations
|
22
22
|
|
23
|
+
import warnings
|
23
24
|
from typing import TYPE_CHECKING, Any, Sequence
|
24
25
|
|
25
26
|
from google.api_core.exceptions import NotFound
|
26
27
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
27
|
-
from google.cloud.aiplatform_v1
|
28
|
+
from google.cloud.aiplatform_v1 import types
|
28
29
|
|
29
30
|
from airflow.configuration import conf
|
30
|
-
from airflow.exceptions import AirflowException
|
31
|
+
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
|
31
32
|
from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
|
32
33
|
HyperparameterTuningJobHook,
|
33
34
|
)
|
@@ -40,7 +41,7 @@ from airflow.providers.google.cloud.triggers.vertex_ai import CreateHyperparamet
|
|
40
41
|
|
41
42
|
if TYPE_CHECKING:
|
42
43
|
from google.api_core.retry import Retry
|
43
|
-
from google.cloud.aiplatform import gapic, hyperparameter_tuning
|
44
|
+
from google.cloud.aiplatform import HyperparameterTuningJob, gapic, hyperparameter_tuning
|
44
45
|
|
45
46
|
from airflow.utils.context import Context
|
46
47
|
|
@@ -127,8 +128,8 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
127
128
|
`service_account` is required with provided `tensorboard`. For more information on configuring
|
128
129
|
your service account please visit:
|
129
130
|
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
|
130
|
-
:param sync: Whether to execute this method synchronously. If False, this method will
|
131
|
-
will be executed in a concurrent Future.
|
131
|
+
:param sync: (Deprecated) Whether to execute this method synchronously. If False, this method will
|
132
|
+
unblock, and it will be executed in a concurrent Future.
|
132
133
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
133
134
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
134
135
|
credentials, or chained list of accounts required to get the access_token
|
@@ -138,8 +139,7 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
138
139
|
If set as a sequence, the identities from the list must grant
|
139
140
|
Service Account Token Creator IAM role to the directly preceding identity, with first
|
140
141
|
account from the list granting this role to the originating account (templated).
|
141
|
-
:param deferrable: Run operator in the deferrable mode.
|
142
|
-
with `sync=False` parameter.
|
142
|
+
:param deferrable: Run operator in the deferrable mode.
|
143
143
|
:param poll_interval: Interval size which defines how often job status is checked in deferrable mode.
|
144
144
|
"""
|
145
145
|
|
@@ -221,19 +221,18 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
221
221
|
self.poll_interval = poll_interval
|
222
222
|
|
223
223
|
def execute(self, context: Context):
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
)
|
224
|
+
warnings.warn(
|
225
|
+
"The 'sync' parameter is deprecated and will be removed after 01.09.2024.",
|
226
|
+
AirflowProviderDeprecationWarning,
|
227
|
+
stacklevel=2,
|
228
|
+
)
|
230
229
|
|
231
230
|
self.log.info("Creating Hyperparameter Tuning job")
|
232
231
|
self.hook = HyperparameterTuningJobHook(
|
233
232
|
gcp_conn_id=self.gcp_conn_id,
|
234
233
|
impersonation_chain=self.impersonation_chain,
|
235
234
|
)
|
236
|
-
|
235
|
+
hyperparameter_tuning_job: HyperparameterTuningJob = self.hook.create_hyperparameter_tuning_job(
|
237
236
|
project_id=self.project_id,
|
238
237
|
region=self.region,
|
239
238
|
display_name=self.display_name,
|
@@ -259,14 +258,19 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
259
258
|
restart_job_on_worker_restart=self.restart_job_on_worker_restart,
|
260
259
|
enable_web_access=self.enable_web_access,
|
261
260
|
tensorboard=self.tensorboard,
|
262
|
-
sync=
|
263
|
-
wait_job_completed=
|
261
|
+
sync=False,
|
262
|
+
wait_job_completed=False,
|
264
263
|
)
|
265
264
|
|
266
|
-
hyperparameter_tuning_job
|
267
|
-
hyperparameter_tuning_job_id =
|
268
|
-
|
265
|
+
hyperparameter_tuning_job.wait_for_resource_creation()
|
266
|
+
hyperparameter_tuning_job_id = hyperparameter_tuning_job.name
|
267
|
+
self.log.info("Hyperparameter Tuning job was created. Job id: %s", hyperparameter_tuning_job_id)
|
268
|
+
|
269
|
+
self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
|
270
|
+
VertexAITrainingLink.persist(
|
271
|
+
context=context, task_instance=self, training_id=hyperparameter_tuning_job_id
|
269
272
|
)
|
273
|
+
|
270
274
|
if self.deferrable:
|
271
275
|
self.defer(
|
272
276
|
trigger=CreateHyperparameterTuningJobTrigger(
|
@@ -279,14 +283,10 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
279
283
|
),
|
280
284
|
method_name="execute_complete",
|
281
285
|
)
|
286
|
+
return
|
282
287
|
|
283
|
-
|
284
|
-
|
285
|
-
self.xcom_push(context, key="hyperparameter_tuning_job_id", value=hyperparameter_tuning_job_id)
|
286
|
-
VertexAITrainingLink.persist(
|
287
|
-
context=context, task_instance=self, training_id=hyperparameter_tuning_job_id
|
288
|
-
)
|
289
|
-
return hyperparameter_tuning_job
|
288
|
+
hyperparameter_tuning_job.wait_for_completion()
|
289
|
+
return hyperparameter_tuning_job.to_dict()
|
290
290
|
|
291
291
|
def on_kill(self) -> None:
|
292
292
|
"""Act as a callback called when the operator is killed; cancel any running job."""
|
@@ -298,26 +298,7 @@ class CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
298
298
|
raise AirflowException(event["message"])
|
299
299
|
job: dict[str, Any] = event["job"]
|
300
300
|
self.log.info("Hyperparameter tuning job %s created and completed successfully.", job["name"])
|
301
|
-
|
302
|
-
gcp_conn_id=self.gcp_conn_id,
|
303
|
-
impersonation_chain=self.impersonation_chain,
|
304
|
-
)
|
305
|
-
job_id = hook.extract_hyperparameter_tuning_job_id(job)
|
306
|
-
self.xcom_push(
|
307
|
-
context,
|
308
|
-
key="hyperparameter_tuning_job_id",
|
309
|
-
value=job_id,
|
310
|
-
)
|
311
|
-
self.xcom_push(
|
312
|
-
context,
|
313
|
-
key="training_conf",
|
314
|
-
value={
|
315
|
-
"training_conf_id": job_id,
|
316
|
-
"region": self.region,
|
317
|
-
"project_id": self.project_id,
|
318
|
-
},
|
319
|
-
)
|
320
|
-
return event["job"]
|
301
|
+
return job
|
321
302
|
|
322
303
|
|
323
304
|
class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
@@ -387,7 +368,7 @@ class GetHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
387
368
|
context=context, task_instance=self, training_id=self.hyperparameter_tuning_job_id
|
388
369
|
)
|
389
370
|
self.log.info("Hyperparameter tuning job was gotten.")
|
390
|
-
return HyperparameterTuningJob.to_dict(result)
|
371
|
+
return types.HyperparameterTuningJob.to_dict(result)
|
391
372
|
except NotFound:
|
392
373
|
self.log.info(
|
393
374
|
"The Hyperparameter tuning job %s does not exist.", self.hyperparameter_tuning_job_id
|
@@ -532,4 +513,4 @@ class ListHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
|
|
532
513
|
metadata=self.metadata,
|
533
514
|
)
|
534
515
|
VertexAIHyperparameterTuningJobListLink.persist(context=context, task_instance=self)
|
535
|
-
return [HyperparameterTuningJob.to_dict(result) for result in results]
|
516
|
+
return [types.HyperparameterTuningJob.to_dict(result) for result in results]
|
@@ -16,23 +16,29 @@
|
|
16
16
|
# specific language governing permissions and limitations
|
17
17
|
# under the License.
|
18
18
|
"""This module contains Google Vertex AI operators."""
|
19
|
+
|
19
20
|
from __future__ import annotations
|
20
21
|
|
22
|
+
from functools import cached_property
|
21
23
|
from typing import TYPE_CHECKING, Any, Sequence
|
22
24
|
|
23
25
|
from google.api_core.exceptions import NotFound
|
24
26
|
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
|
25
|
-
from google.cloud.aiplatform_v1
|
27
|
+
from google.cloud.aiplatform_v1 import types
|
26
28
|
|
29
|
+
from airflow.configuration import conf
|
30
|
+
from airflow.exceptions import AirflowException
|
27
31
|
from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobHook
|
28
32
|
from airflow.providers.google.cloud.links.vertex_ai import (
|
29
33
|
VertexAIPipelineJobLink,
|
30
34
|
VertexAIPipelineJobListLink,
|
31
35
|
)
|
32
36
|
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
|
37
|
+
from airflow.providers.google.cloud.triggers.vertex_ai import RunPipelineJobTrigger
|
33
38
|
|
34
39
|
if TYPE_CHECKING:
|
35
40
|
from google.api_core.retry import Retry
|
41
|
+
from google.cloud.aiplatform import PipelineJob
|
36
42
|
from google.cloud.aiplatform.metadata import experiment_resources
|
37
43
|
|
38
44
|
from airflow.utils.context import Context
|
@@ -40,7 +46,7 @@ if TYPE_CHECKING:
|
|
40
46
|
|
41
47
|
class RunPipelineJobOperator(GoogleCloudBaseOperator):
|
42
48
|
"""
|
43
|
-
|
49
|
+
Create and run a Pipeline job.
|
44
50
|
|
45
51
|
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
|
46
52
|
:param region: Required. The ID of the Google Cloud region that the service belongs to.
|
@@ -82,9 +88,9 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
|
|
82
88
|
Private services access must already be configured for the network. If left unspecified, the
|
83
89
|
network set in aiplatform.init will be used. Otherwise, the job is not peered with any network.
|
84
90
|
:param create_request_timeout: Optional. The timeout for the create request in seconds.
|
85
|
-
:param experiment: Optional. The Vertex AI experiment name or instance to associate to this
|
86
|
-
|
87
|
-
|
91
|
+
:param experiment: Optional. The Vertex AI experiment name or instance to associate to this PipelineJob.
|
92
|
+
Metrics produced by the PipelineJob as system.Metric Artifacts will be associated as metrics
|
93
|
+
to the current Experiment Run. Pipeline parameters will be associated as parameters to
|
88
94
|
the current Experiment Run.
|
89
95
|
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
|
90
96
|
:param impersonation_chain: Optional service account to impersonate using short-term
|
@@ -95,6 +101,10 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
|
|
95
101
|
If set as a sequence, the identities from the list must grant
|
96
102
|
Service Account Token Creator IAM role to the directly preceding identity, with first
|
97
103
|
account from the list granting this role to the originating account (templated).
|
104
|
+
:param deferrable: If True, run the task in the deferrable mode.
|
105
|
+
Note that it requires calling the operator with `sync=False` parameter.
|
106
|
+
:param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
|
107
|
+
The default is 300 seconds.
|
98
108
|
"""
|
99
109
|
|
100
110
|
template_fields = [
|
@@ -126,6 +136,8 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
|
|
126
136
|
experiment: str | experiment_resources.Experiment | None = None,
|
127
137
|
gcp_conn_id: str = "google_cloud_default",
|
128
138
|
impersonation_chain: str | Sequence[str] | None = None,
|
139
|
+
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
|
140
|
+
poll_interval: int = 5 * 60,
|
129
141
|
**kwargs,
|
130
142
|
) -> None:
|
131
143
|
super().__init__(**kwargs)
|
@@ -147,15 +159,12 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
|
|
147
159
|
self.experiment = experiment
|
148
160
|
self.gcp_conn_id = gcp_conn_id
|
149
161
|
self.impersonation_chain = impersonation_chain
|
150
|
-
self.
|
162
|
+
self.deferrable = deferrable
|
163
|
+
self.poll_interval = poll_interval
|
151
164
|
|
152
165
|
def execute(self, context: Context):
|
153
166
|
self.log.info("Running Pipeline job")
|
154
|
-
self.hook
|
155
|
-
gcp_conn_id=self.gcp_conn_id,
|
156
|
-
impersonation_chain=self.impersonation_chain,
|
157
|
-
)
|
158
|
-
result = self.hook.run_pipeline_job(
|
167
|
+
pipeline_job_obj: PipelineJob = self.hook.submit_pipeline_job(
|
159
168
|
project_id=self.project_id,
|
160
169
|
region=self.region,
|
161
170
|
display_name=self.display_name,
|
@@ -173,20 +182,46 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
|
|
173
182
|
create_request_timeout=self.create_request_timeout,
|
174
183
|
experiment=self.experiment,
|
175
184
|
)
|
176
|
-
|
177
|
-
pipeline_job = result.to_dict()
|
178
|
-
pipeline_job_id = self.hook.extract_pipeline_job_id(pipeline_job)
|
185
|
+
pipeline_job_id = pipeline_job_obj.job_id
|
179
186
|
self.log.info("Pipeline job was created. Job id: %s", pipeline_job_id)
|
180
|
-
|
181
187
|
self.xcom_push(context, key="pipeline_job_id", value=pipeline_job_id)
|
182
188
|
VertexAIPipelineJobLink.persist(context=context, task_instance=self, pipeline_id=pipeline_job_id)
|
189
|
+
|
190
|
+
if self.deferrable:
|
191
|
+
pipeline_job_obj.wait_for_resource_creation()
|
192
|
+
self.defer(
|
193
|
+
trigger=RunPipelineJobTrigger(
|
194
|
+
conn_id=self.gcp_conn_id,
|
195
|
+
project_id=self.project_id,
|
196
|
+
location=pipeline_job_obj.location,
|
197
|
+
job_id=pipeline_job_id,
|
198
|
+
poll_interval=self.poll_interval,
|
199
|
+
impersonation_chain=self.impersonation_chain,
|
200
|
+
),
|
201
|
+
method_name="execute_complete",
|
202
|
+
)
|
203
|
+
|
204
|
+
pipeline_job_obj.wait()
|
205
|
+
pipeline_job = pipeline_job_obj.to_dict()
|
183
206
|
return pipeline_job
|
184
207
|
|
208
|
+
def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]:
|
209
|
+
if event["status"] == "error":
|
210
|
+
raise AirflowException(event["message"])
|
211
|
+
return event["job"]
|
212
|
+
|
185
213
|
def on_kill(self) -> None:
|
186
214
|
"""Act as a callback called when the operator is killed; cancel any running job."""
|
187
215
|
if self.hook:
|
188
216
|
self.hook.cancel_pipeline_job()
|
189
217
|
|
218
|
+
@cached_property
|
219
|
+
def hook(self) -> PipelineJobHook:
|
220
|
+
return PipelineJobHook(
|
221
|
+
gcp_conn_id=self.gcp_conn_id,
|
222
|
+
impersonation_chain=self.impersonation_chain,
|
223
|
+
)
|
224
|
+
|
190
225
|
|
191
226
|
class GetPipelineJobOperator(GoogleCloudBaseOperator):
|
192
227
|
"""
|
@@ -261,7 +296,7 @@ class GetPipelineJobOperator(GoogleCloudBaseOperator):
|
|
261
296
|
context=context, task_instance=self, pipeline_id=self.pipeline_job_id
|
262
297
|
)
|
263
298
|
self.log.info("Pipeline job was gotten.")
|
264
|
-
return PipelineJob.to_dict(result)
|
299
|
+
return types.PipelineJob.to_dict(result)
|
265
300
|
except NotFound:
|
266
301
|
self.log.info("The Pipeline job %s does not exist.", self.pipeline_job_id)
|
267
302
|
|
@@ -389,7 +424,7 @@ class ListPipelineJobOperator(GoogleCloudBaseOperator):
|
|
389
424
|
metadata=self.metadata,
|
390
425
|
)
|
391
426
|
VertexAIPipelineJobListLink.persist(context=context, task_instance=self)
|
392
|
-
return [PipelineJob.to_dict(result) for result in results]
|
427
|
+
return [types.PipelineJob.to_dict(result) for result in results]
|
393
428
|
|
394
429
|
|
395
430
|
class DeletePipelineJobOperator(GoogleCloudBaseOperator):
|