google-cloud-pipeline-components 2.6.0__py3-none-any.whl → 2.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_cloud_pipeline_components/_implementation/llm/arbiter_preprocess.py +137 -0
- google_cloud_pipeline_components/_implementation/llm/autosxs_arbiter.py +105 -0
- google_cloud_pipeline_components/_implementation/llm/autosxs_metrics_computer.py +66 -0
- google_cloud_pipeline_components/_implementation/llm/deployment_graph.py +10 -16
- google_cloud_pipeline_components/_implementation/llm/env.py +1 -1
- google_cloud_pipeline_components/_implementation/llm/function_based.py +82 -5
- google_cloud_pipeline_components/_implementation/llm/reinforcement_learning_graph.py +6 -0
- google_cloud_pipeline_components/_implementation/llm/reinforcer.py +7 -2
- google_cloud_pipeline_components/_implementation/llm/reward_model_graph.py +6 -0
- google_cloud_pipeline_components/_implementation/llm/reward_model_trainer.py +7 -2
- google_cloud_pipeline_components/_implementation/llm/supervised_fine_tuner.py +5 -0
- google_cloud_pipeline_components/_implementation/llm/task_preprocess.py +97 -0
- google_cloud_pipeline_components/_implementation/llm/upload_llm_model.py +5 -0
- google_cloud_pipeline_components/_implementation/model_evaluation/__init__.py +4 -0
- google_cloud_pipeline_components/_implementation/model_evaluation/endpoint_batch_predict/component.py +1 -1
- google_cloud_pipeline_components/_implementation/model_evaluation/import_evaluation/component.py +10 -0
- google_cloud_pipeline_components/_implementation/model_evaluation/llm_embedding/evaluation_llm_embedding_pipeline.py +64 -15
- google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation/component.py +9 -2
- google_cloud_pipeline_components/_implementation/model_evaluation/model_inference/__init__.py +14 -0
- google_cloud_pipeline_components/_implementation/model_evaluation/model_inference/component.py +324 -0
- google_cloud_pipeline_components/_implementation/model_evaluation/version.py +2 -2
- google_cloud_pipeline_components/container/_implementation/model_evaluation/import_model_evaluation.py +8 -0
- google_cloud_pipeline_components/container/v1/automl_training_job/__init__.py +14 -0
- google_cloud_pipeline_components/container/v1/automl_training_job/image/__init__.py +14 -0
- google_cloud_pipeline_components/container/v1/automl_training_job/image/launcher.py +236 -0
- google_cloud_pipeline_components/container/v1/automl_training_job/image/remote_runner.py +250 -0
- google_cloud_pipeline_components/preview/model_evaluation/evaluation_llm_text_generation_pipeline.py +6 -1
- google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/__init__.py +20 -0
- google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/autosxs/__init__.py +13 -0
- google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/autosxs/autosxs_pipeline.py +234 -0
- google_cloud_pipeline_components/version.py +1 -1
- {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/METADATA +1 -1
- {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/RECORD +36 -23
- {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/LICENSE +0 -0
- {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/WHEEL +0 -0
- {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# Copyright 2023 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Component for preprocessing the evaluation dataset into prediction inputs."""
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from typing import Any, Dict, List
|
|
18
|
+
|
|
19
|
+
from google_cloud_pipeline_components import _placeholders
|
|
20
|
+
from google_cloud_pipeline_components import utils as gcpc_utils
|
|
21
|
+
from google_cloud_pipeline_components._implementation.llm import utils
|
|
22
|
+
from kfp import dsl
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _resolve_image() -> str:
  """Returns the container image URI used by the task preprocess job.

  A non-empty ``AUTOSXS_IMAGE_OVERRIDE`` environment variable wins; when it is
  unset or empty, the default ``autosxs`` image URI is used instead.
  """
  override_uri = os.environ.get('AUTOSXS_IMAGE_OVERRIDE')
  if override_uri:
    return override_uri
  return utils.get_default_image_uri('autosxs')


# pylint: disable=dangerous-default-value,g-bare-generic,unused-argument
@dsl.container_component
def task_preprocess(
    evaluation_dataset: str,
    id_columns: List[str],
    task: str,
    model_prompt_parameters: Dict[str, Dict[str, str]],
    prediction_inputs: dsl.OutputPath(List[str]),  # pytype: disable=invalid-annotation
    gcp_resources: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
    metadata: dsl.OutputPath(Dict[str, Any]),  # pytype: disable=invalid-annotation
    response_column: str,
    human_preference_column: str = '',
) -> dsl.ContainerSpec:  # pylint: disable=g-doc-args
  """Preprocesses evaluation dataset into prediction inputs.

  Args:
    evaluation_dataset: GCS or BigQuery URIs representing a dataset of prompts
      and responses.
    id_columns: The columns which distinguish unique evaluation examples.
    task: Evaluation task in the form {task}@{version}. task can be one of
      "summarization", "question_answer". Version is an integer with 3 digits or
      "latest". Ex: summarization@001 or question_answer@latest.
    model_prompt_parameters: Map of model prompt template parameters to columns
      or templates.
    response_column: Either an existing column containing predefined responses,
      or the name of the model output column containing responses.
    human_preference_column: The column containing ground truths. Only required
      when users want to check the autorater alignment against human preference.

  Returns:
    prediction_inputs: Path where the container writes the location of the
      preprocessed prediction inputs (forwarded as `--prediction_inputs_path`).
    gcp_resources: Tracker for GCP resources created by this component.
    metadata: Path where the container writes the object that stores computed
      metrics metadata for the task preprocess component (forwarded as
      `--metadata_path`).
  """
  return gcpc_utils.build_serverless_customjob_container_spec(
      project=_placeholders.PROJECT_ID_PLACEHOLDER,
      location=_placeholders.LOCATION_PLACEHOLDER,
      custom_job_payload=utils.build_payload(
          display_name='task_preprocess',
          machine_type='n1-standard-4',
          image_uri=_resolve_image(),
          args=[
              '--',  # Used to mark the start of component flags.
              'task_preprocess',
              f'--evaluation_dataset={evaluation_dataset}',
              f'--staging_dir={dsl.PIPELINE_ROOT_PLACEHOLDER}',
              f'--task={task}',
              f'--prediction_inputs_path={prediction_inputs}',
              # The list/dict inputs below are not interpolated from their
              # Python values; they are resolved at runtime through the
              # `json_escape` input-parameter placeholders (which is why
              # `unused-argument` is suppressed above).
              (
                  '--id_columns='
                  "{{$.inputs.parameters['id_columns'].json_escape[0]}}"
              ),
              (
                  '--model_prompt_parameters='
                  "{{$.inputs.parameters['model_prompt_parameters']"
                  '.json_escape[0]}}'
              ),
              f'--metadata_path={metadata}',
              f'--response_column={response_column}',
              f'--human_preference_column={human_preference_column}',
          ],
      ),
      gcp_resources=gcp_resources,
  )
|
|
@@ -34,6 +34,7 @@ def upload_llm_model(
|
|
|
34
34
|
gcp_resources: dsl.OutputPath(str),
|
|
35
35
|
encryption_spec_key_name: str = '',
|
|
36
36
|
upload_model: bool = True,
|
|
37
|
+
tune_type: str = '',
|
|
37
38
|
):
|
|
38
39
|
"""Uploads LLM model.
|
|
39
40
|
|
|
@@ -48,6 +49,8 @@ def upload_llm_model(
|
|
|
48
49
|
upload_model: Whether to upload the model to the Model Registry. Default
|
|
49
50
|
is ``True``. If ``False``, the model will not be uploaded and output
|
|
50
51
|
artifacts will contain empty strings.
|
|
52
|
+
tune_type: Method used to tune the model, e.g. ``rlhf``. If present, this
|
|
53
|
+
value is used to set the ``tune-type`` run label during model upload.
|
|
51
54
|
|
|
52
55
|
Returns:
|
|
53
56
|
model_resource_name: Path to the created Model on Model Registry.
|
|
@@ -76,6 +79,8 @@ def upload_llm_model(
|
|
|
76
79
|
labels['google-vertex-llm-tuning-base-model-id'] = (
|
|
77
80
|
model_reference_name.replace('@', '-')
|
|
78
81
|
)
|
|
82
|
+
if tune_type:
|
|
83
|
+
labels['tune-type'] = tune_type
|
|
79
84
|
|
|
80
85
|
model_upload_payload = {
|
|
81
86
|
'model': {
|
|
@@ -35,6 +35,8 @@ from google_cloud_pipeline_components._implementation.model_evaluation.llm_infor
|
|
|
35
35
|
from google_cloud_pipeline_components._implementation.model_evaluation.llm_retrieval_metrics.component import llm_retrieval_metrics as LLMRetrievalMetricsOp
|
|
36
36
|
from google_cloud_pipeline_components._implementation.model_evaluation.llm_safety_bias.component import llm_safety_bias_metrics as LLMSafetyBiasMetricsOp
|
|
37
37
|
from google_cloud_pipeline_components._implementation.model_evaluation.llm_safety_bias.evaluation_llm_safety_bias_pipeline import evaluation_llm_safety_bias_pipeline
|
|
38
|
+
from google_cloud_pipeline_components._implementation.model_evaluation.model_inference.component import model_inference_and_evaluation_component
|
|
39
|
+
from google_cloud_pipeline_components._implementation.model_evaluation.model_inference.component import model_inference_component
|
|
38
40
|
from google_cloud_pipeline_components._implementation.model_evaluation.target_field_data_remover.component import target_field_data_remover as TargetFieldDataRemoverOp
|
|
39
41
|
from google_cloud_pipeline_components._implementation.model_evaluation.text2sql.evaluation_llm_text2sql_pipeline import evaluation_llm_text2sql_pipeline
|
|
40
42
|
|
|
@@ -62,4 +64,6 @@ __all__ = [
|
|
|
62
64
|
'ModelImportEvaluatedAnnotationOp',
|
|
63
65
|
'ModelImportEvaluationOp',
|
|
64
66
|
'TargetFieldDataRemoverOp',
|
|
67
|
+
'model_inference_component',
|
|
68
|
+
'model_inference_and_evaluation_component',
|
|
65
69
|
]
|
|
@@ -169,7 +169,7 @@ def evaluation_llm_endpoint_batch_predict_pipeline_graph_component(
|
|
|
169
169
|
network: str = '',
|
|
170
170
|
encryption_spec_key_name: str = '',
|
|
171
171
|
) -> NamedTuple('outputs', gcs_output_directory=Artifact):
|
|
172
|
-
"""The
|
|
172
|
+
"""The First Party Model Endpoint Batch Predict Pipeline.
|
|
173
173
|
|
|
174
174
|
Args:
|
|
175
175
|
project: Required. The GCP project that runs the pipeline components.
|
google_cloud_pipeline_components/_implementation/model_evaluation/import_evaluation/component.py
CHANGED
|
@@ -31,6 +31,7 @@ def model_evaluation_import(
|
|
|
31
31
|
gcp_resources: dsl.OutputPath(str),
|
|
32
32
|
evaluation_resource_name: dsl.OutputPath(str),
|
|
33
33
|
metrics: Optional[Input[Metrics]] = None,
|
|
34
|
+
row_based_metrics: Optional[Input[Metrics]] = None,
|
|
34
35
|
problem_type: Optional[str] = None,
|
|
35
36
|
classification_metrics: Optional[Input[ClassificationMetrics]] = None,
|
|
36
37
|
forecasting_metrics: Optional[Input[ForecastingMetrics]] = None,
|
|
@@ -59,6 +60,8 @@ def model_evaluation_import(
|
|
|
59
60
|
model: Vertex model resource that will be the parent resource of the
|
|
60
61
|
uploaded evaluation.
|
|
61
62
|
metrics: Path of metrics generated from an evaluation component.
|
|
63
|
+
row_based_metrics:
|
|
64
|
+
Path of row_based_metrics generated from an evaluation component.
|
|
62
65
|
problem_type: The problem type of the metrics being imported to the
|
|
63
66
|
VertexModel. `classification`, `regression`, `forecasting`,
|
|
64
67
|
`text-generation`, `question-answering`, and `summarization` are the
|
|
@@ -106,6 +109,13 @@ def model_evaluation_import(
|
|
|
106
109
|
metrics.metadata["explanation_gcs_path"],
|
|
107
110
|
],
|
|
108
111
|
),
|
|
112
|
+
dsl.IfPresentPlaceholder(
|
|
113
|
+
input_name="row_based_metrics",
|
|
114
|
+
then=[
|
|
115
|
+
"--row_based_metrics",
|
|
116
|
+
row_based_metrics.uri,
|
|
117
|
+
],
|
|
118
|
+
),
|
|
109
119
|
dsl.IfPresentPlaceholder(
|
|
110
120
|
input_name="explanation",
|
|
111
121
|
then=[
|
|
@@ -13,6 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
"""LLM embedding evaluation pipeline based on information retrieval (IR) task."""
|
|
15
15
|
|
|
16
|
+
from typing import Dict, Optional, Union
|
|
17
|
+
from google_cloud_pipeline_components._implementation.model_evaluation.endpoint_batch_predict.component import evaluation_llm_endpoint_batch_predict_pipeline_graph_component as LLMEndpointBatchPredictOp
|
|
16
18
|
from google_cloud_pipeline_components._implementation.model_evaluation.import_evaluation.component import model_evaluation_import as ModelImportEvaluationOp
|
|
17
19
|
from google_cloud_pipeline_components._implementation.model_evaluation.llm_embedding_retrieval.component import llm_embedding_retrieval as LLMEmbeddingRetrievalOp
|
|
18
20
|
from google_cloud_pipeline_components._implementation.model_evaluation.llm_information_retrieval_preprocessor.component import llm_information_retrieval_preprocessor as LLMInformationRetrievalPreprocessorOp
|
|
@@ -34,6 +36,8 @@ def evaluation_llm_embedding_pipeline(
|
|
|
34
36
|
query_gcs_source: str,
|
|
35
37
|
golden_docs_gcs_source: str,
|
|
36
38
|
model_name: str,
|
|
39
|
+
qms_override: Optional[Dict[str, Union[int, float]]] = {},
|
|
40
|
+
model_parameters: Optional[Dict[str, Union[int, float]]] = {},
|
|
37
41
|
embedding_chunking_function: str = 'langchain-RecursiveCharacterTextSplitter',
|
|
38
42
|
embedding_chunk_size: int = 0,
|
|
39
43
|
embedding_chunk_overlap: int = 0,
|
|
@@ -65,7 +69,14 @@ def evaluation_llm_embedding_pipeline(
|
|
|
65
69
|
query_gcs_source: The gcs location for json file containing query documents.
|
|
66
70
|
golden_docs_gcs_source: The gcs location for csv file containing mapping of
|
|
67
71
|
each query to the golden docs.
|
|
68
|
-
model_name: The path for model to generate embeddings
|
|
72
|
+
model_name: The path for model to generate embeddings, example,
|
|
73
|
+
'publishers/google/models/textembedding-gecko-multilingual@latest'
|
|
74
|
+
qms_override: Manual control of a large language model's qms. Write up when
|
|
75
|
+
there's an approved quota increase for a LLM. Write down when limiting qms
|
|
76
|
+
of a LLM for this pipeline. Should be provided as a dictionary, for
|
|
77
|
+
example {'text-bison': 20}. For deployed model which doesn't have
|
|
78
|
+
google-vertex-llm-tuning-base-model-id label, override the default here.
|
|
79
|
+
model_parameters: The parameters that govern the prediction.
|
|
69
80
|
embedding_chunking_function: function used to split a document into chunks.
|
|
70
81
|
Supported values are `langchain-RecursiveCharacterTextSplitter` and
|
|
71
82
|
`sentence-splitter`. langchain-RecursiveCharacterTextSplitter:
|
|
@@ -157,38 +168,76 @@ def evaluation_llm_embedding_pipeline(
|
|
|
157
168
|
)
|
|
158
169
|
get_vertex_model_task.set_display_name('get-vertex-model')
|
|
159
170
|
|
|
160
|
-
batch_predict_corpus =
|
|
161
|
-
|
|
171
|
+
batch_predict_corpus = LLMEndpointBatchPredictOp(
|
|
172
|
+
display_name=(
|
|
173
|
+
'batch-prediction-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}'
|
|
174
|
+
),
|
|
175
|
+
publisher_model=model_name,
|
|
176
|
+
qms_override=qms_override,
|
|
162
177
|
location=location,
|
|
163
|
-
|
|
164
|
-
job_display_name='evaluation-batch-predict-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
|
|
165
|
-
gcs_source_uris=preprocessing_task.outputs[
|
|
178
|
+
source_gcs_uris=preprocessing_task.outputs[
|
|
166
179
|
'predictions_corpus_gcs_source'
|
|
167
180
|
],
|
|
168
|
-
|
|
169
|
-
predictions_format=batch_predict_predictions_format,
|
|
181
|
+
model_parameters=model_parameters,
|
|
170
182
|
gcs_destination_output_uri_prefix=(
|
|
171
183
|
f'{PIPELINE_ROOT_PLACEHOLDER}/batch_predict_output'
|
|
172
184
|
),
|
|
185
|
+
service_account=service_account,
|
|
173
186
|
encryption_spec_key_name=encryption_spec_key_name,
|
|
187
|
+
project=project,
|
|
174
188
|
)
|
|
175
189
|
|
|
176
|
-
batch_predict_query =
|
|
177
|
-
|
|
190
|
+
batch_predict_query = LLMEndpointBatchPredictOp(
|
|
191
|
+
display_name=(
|
|
192
|
+
'batch-prediction-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}'
|
|
193
|
+
),
|
|
194
|
+
publisher_model=model_name,
|
|
195
|
+
qms_override=qms_override,
|
|
178
196
|
location=location,
|
|
179
|
-
|
|
180
|
-
job_display_name='evaluation-batch-predict-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
|
|
181
|
-
gcs_source_uris=preprocessing_task.outputs[
|
|
197
|
+
source_gcs_uris=preprocessing_task.outputs[
|
|
182
198
|
'predictions_query_gcs_source'
|
|
183
199
|
],
|
|
184
|
-
|
|
185
|
-
predictions_format=batch_predict_predictions_format,
|
|
200
|
+
model_parameters=model_parameters,
|
|
186
201
|
gcs_destination_output_uri_prefix=(
|
|
187
202
|
f'{PIPELINE_ROOT_PLACEHOLDER}/batch_predict_output'
|
|
188
203
|
),
|
|
204
|
+
service_account=service_account,
|
|
189
205
|
encryption_spec_key_name=encryption_spec_key_name,
|
|
206
|
+
project=project,
|
|
190
207
|
)
|
|
191
208
|
|
|
209
|
+
# batch_predict_corpus = ModelBatchPredictOp(
|
|
210
|
+
# project=project,
|
|
211
|
+
# location=location,
|
|
212
|
+
# model=get_vertex_model_task.outputs['artifact'],
|
|
213
|
+
# job_display_name='evaluation-batch-predict-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
|
|
214
|
+
# gcs_source_uris=preprocessing_task.outputs[
|
|
215
|
+
# 'predictions_corpus_gcs_source'
|
|
216
|
+
# ],
|
|
217
|
+
# instances_format=batch_predict_instances_format,
|
|
218
|
+
# predictions_format=batch_predict_predictions_format,
|
|
219
|
+
# gcs_destination_output_uri_prefix=(
|
|
220
|
+
# f'{PIPELINE_ROOT_PLACEHOLDER}/batch_predict_output'
|
|
221
|
+
# ),
|
|
222
|
+
# encryption_spec_key_name=encryption_spec_key_name,
|
|
223
|
+
# )
|
|
224
|
+
|
|
225
|
+
# batch_predict_query = ModelBatchPredictOp(
|
|
226
|
+
# project=project,
|
|
227
|
+
# location=location,
|
|
228
|
+
# model=get_vertex_model_task.outputs['artifact'],
|
|
229
|
+
# job_display_name='evaluation-batch-predict-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
|
|
230
|
+
# gcs_source_uris=preprocessing_task.outputs[
|
|
231
|
+
# 'predictions_query_gcs_source'
|
|
232
|
+
# ],
|
|
233
|
+
# instances_format=batch_predict_instances_format,
|
|
234
|
+
# predictions_format=batch_predict_predictions_format,
|
|
235
|
+
# gcs_destination_output_uri_prefix=(
|
|
236
|
+
# f'{PIPELINE_ROOT_PLACEHOLDER}/batch_predict_output'
|
|
237
|
+
# ),
|
|
238
|
+
# encryption_spec_key_name=encryption_spec_key_name,
|
|
239
|
+
# )
|
|
240
|
+
|
|
192
241
|
# TODO(b/290838262): Revisit if/when the concurrent jobs limit is increased/removed.
|
|
193
242
|
batch_predict_query.after(batch_predict_corpus)
|
|
194
243
|
|
google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation/component.py
CHANGED
|
@@ -29,6 +29,7 @@ from kfp.dsl import OutputPath
|
|
|
29
29
|
def model_evaluation_text_generation(
|
|
30
30
|
gcp_resources: OutputPath(str),
|
|
31
31
|
evaluation_metrics: Output[Metrics],
|
|
32
|
+
row_based_metrics: Output[Metrics],
|
|
32
33
|
project: str,
|
|
33
34
|
location: str,
|
|
34
35
|
evaluation_task: str = 'text-generation',
|
|
@@ -38,6 +39,7 @@ def model_evaluation_text_generation(
|
|
|
38
39
|
joined_predictions_gcs_source: dsl.Input[Artifact] = None,
|
|
39
40
|
predictions_gcs_source: dsl.Input[Artifact] = None,
|
|
40
41
|
ground_truth_gcs_source: str = '',
|
|
42
|
+
enable_row_based_metrics: bool = False,
|
|
41
43
|
display_name: str = 'model-evaluation-text-generation',
|
|
42
44
|
machine_type: str = 'e2-highmem-16',
|
|
43
45
|
service_account: str = '',
|
|
@@ -106,11 +108,14 @@ def model_evaluation_text_generation(
|
|
|
106
108
|
created.
|
|
107
109
|
|
|
108
110
|
Returns:
|
|
109
|
-
evaluation_metrics: `Metrics` artifact representing the language model
|
|
110
|
-
evaluation metrics.
|
|
111
111
|
gcp_resources: Serialized gcp_resources proto tracking the custom job.
|
|
112
112
|
For more details, see
|
|
113
113
|
https://github.com/kubeflow/pipelines/blob/master/components/google-cloud/google_cloud_pipeline_components/proto/README.md.
|
|
114
|
+
evaluation_metrics: `Metrics` artifact representing the language model
|
|
115
|
+
evaluation metrics.
|
|
116
|
+
row_based_metrics: `Metrics` artifact representing the language model
|
|
117
|
+
evaluation metrics of each instance. This is only available if
|
|
118
|
+
enable_row_based_metrics is set to True.
|
|
114
119
|
"""
|
|
115
120
|
return gcpc_utils.build_serverless_customjob_container_spec(
|
|
116
121
|
project=project,
|
|
@@ -128,6 +133,8 @@ def model_evaluation_text_generation(
|
|
|
128
133
|
f'--predictions_gcs_source={predictions_gcs_source.uri}',
|
|
129
134
|
f'--ground_truth_gcs_source={ground_truth_gcs_source}',
|
|
130
135
|
f'--evaluation_metrics_output_path={evaluation_metrics.path}',
|
|
136
|
+
f'--enable_row_based_metrics={enable_row_based_metrics}',
|
|
137
|
+
f'--row_based_metrics_output_path={row_based_metrics.path}',
|
|
131
138
|
'--executor_input={{$.json_escape[1]}}',
|
|
132
139
|
],
|
|
133
140
|
service_account=service_account,
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2023 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Google Cloud Pipeline Evaluation Model Inference Component."""
|