google-cloud-pipeline-components 2.6.0__py3-none-any.whl → 2.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_cloud_pipeline_components/_implementation/llm/arbiter_preprocess.py +137 -0
- google_cloud_pipeline_components/_implementation/llm/autosxs_arbiter.py +105 -0
- google_cloud_pipeline_components/_implementation/llm/autosxs_metrics_computer.py +66 -0
- google_cloud_pipeline_components/_implementation/llm/deployment_graph.py +10 -16
- google_cloud_pipeline_components/_implementation/llm/env.py +1 -1
- google_cloud_pipeline_components/_implementation/llm/function_based.py +82 -5
- google_cloud_pipeline_components/_implementation/llm/reinforcement_learning_graph.py +6 -0
- google_cloud_pipeline_components/_implementation/llm/reinforcer.py +7 -2
- google_cloud_pipeline_components/_implementation/llm/reward_model_graph.py +6 -0
- google_cloud_pipeline_components/_implementation/llm/reward_model_trainer.py +7 -2
- google_cloud_pipeline_components/_implementation/llm/supervised_fine_tuner.py +5 -0
- google_cloud_pipeline_components/_implementation/llm/task_preprocess.py +97 -0
- google_cloud_pipeline_components/_implementation/llm/upload_llm_model.py +5 -0
- google_cloud_pipeline_components/_implementation/model_evaluation/__init__.py +4 -0
- google_cloud_pipeline_components/_implementation/model_evaluation/endpoint_batch_predict/component.py +1 -1
- google_cloud_pipeline_components/_implementation/model_evaluation/import_evaluation/component.py +10 -0
- google_cloud_pipeline_components/_implementation/model_evaluation/llm_embedding/evaluation_llm_embedding_pipeline.py +64 -15
- google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation/component.py +9 -2
- google_cloud_pipeline_components/_implementation/model_evaluation/model_inference/__init__.py +14 -0
- google_cloud_pipeline_components/_implementation/model_evaluation/model_inference/component.py +324 -0
- google_cloud_pipeline_components/_implementation/model_evaluation/version.py +2 -2
- google_cloud_pipeline_components/container/_implementation/model_evaluation/import_model_evaluation.py +8 -0
- google_cloud_pipeline_components/container/v1/automl_training_job/__init__.py +14 -0
- google_cloud_pipeline_components/container/v1/automl_training_job/image/__init__.py +14 -0
- google_cloud_pipeline_components/container/v1/automl_training_job/image/launcher.py +236 -0
- google_cloud_pipeline_components/container/v1/automl_training_job/image/remote_runner.py +250 -0
- google_cloud_pipeline_components/preview/model_evaluation/evaluation_llm_text_generation_pipeline.py +6 -1
- google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/__init__.py +20 -0
- google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/autosxs/__init__.py +13 -0
- google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/autosxs/autosxs_pipeline.py +234 -0
- google_cloud_pipeline_components/version.py +1 -1
- {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/METADATA +1 -1
- {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/RECORD +36 -23
- {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/LICENSE +0 -0
- {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/WHEEL +0 -0
- {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
# Copyright 2023 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""KFP Container component for preprocessing predictions for the Arbiter."""
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from typing import Dict, List
|
|
18
|
+
|
|
19
|
+
from google_cloud_pipeline_components import _placeholders
|
|
20
|
+
from google_cloud_pipeline_components import utils as gcpc_utils
|
|
21
|
+
from google_cloud_pipeline_components._implementation.llm import utils
|
|
22
|
+
from kfp import dsl
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _resolve_image() -> str:
|
|
26
|
+
"""Determines the image URI to create a container from."""
|
|
27
|
+
return (
|
|
28
|
+
os.environ.get('AUTOSXS_IMAGE_OVERRIDE')
|
|
29
|
+
or utils.get_default_image_uri('autosxs'))
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# pylint: disable=unused-argument,dangerous-default-value
@dsl.container_component
def arbiter_preprocess(
    evaluation_dataset: str,
    id_columns: List[str],
    response_column_a: str,
    response_column_b: str,
    task: str,
    is_bp_output_a: bool,
    is_bp_output_b: bool,
    autorater_prompt_parameters: Dict[str, Dict[str, str]],
    preprocessed_evaluation_dataset: dsl.Output[dsl.Dataset],  # pylint: disable=unused-argument # pytype: disable=unsupported-operands
    preprocessed_evaluation_dataset_uri: dsl.OutputPath(str),  # pylint: disable=unused-argument # pytype: disable=invalid-annotation
    gcp_resources: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
    prediction_uris_a: str = '',
    prediction_uris_b: str = '',
    model_a_prompt_parameters: Dict[str, Dict[str, str]] = {},
    model_b_prompt_parameters: Dict[str, Dict[str, str]] = {},
    human_preference_column: str = '',
) -> dsl.ContainerSpec:  # pylint: disable=g-doc-args
  """Preprocesses predictions tables for the AutoSxS Arbiter.

  The work runs as a serverless Vertex custom job (`n1-standard-4`) in the
  pipeline's location. Intermediate output is written under the pipeline root.

  Args:
    evaluation_dataset: GCS or BigQuery URIs representing a dataset of prompts
      and responses.
    id_columns: The columns which distinguish unique evaluation examples.
    response_column_a: The column containing responses for model A.
    response_column_b: The column containing responses for model B.
    task: Task to evaluate.
    is_bp_output_a: If True, the prediction URIs for model A will be parsed as
      if they came from Vertex Batch Prediction, where response_column_a
      represents a field in the model output containing the response. If
      False, the expected format will be a table containing all
      model_prompt_parameters and the response_column.
    is_bp_output_b: If True, the prediction URIs for model B will be parsed as
      if they came from Vertex Batch Prediction, where response_column_b
      represents a field in the model output containing the response. If
      False, the expected format will be a table containing all
      model_prompt_parameters and the response_column.
    autorater_prompt_parameters: Map of autorater prompt template parameters to
      columns or templates.
    prediction_uris_a: GCS or BigQuery URI(s) representing a dataset of prompts
      and responses for model A, passed to the container as a single string.
      Defaults to an empty string.
    prediction_uris_b: GCS or BigQuery URI(s) representing a dataset of prompts
      and responses for model B, passed to the container as a single string.
      Defaults to an empty string.
    model_a_prompt_parameters: Map of model A prompt template parameters to
      columns or templates.
    model_b_prompt_parameters: Map of model B prompt template parameters to
      columns or templates.
    human_preference_column: The column containing ground truths. Defaults to
      an empty string when not provided by users.

  Returns:
    preprocessed_evaluation_dataset: Dataset of the table containing the inputs
      expected by the Arbiter.
    preprocessed_evaluation_dataset_uri: URI of the table containing the inputs
      expected by the Arbiter.
    gcp_resources: Tracker for GCP resources created by this component.
  """
  return gcpc_utils.build_serverless_customjob_container_spec(
      project=_placeholders.PROJECT_ID_PLACEHOLDER,
      location=_placeholders.LOCATION_PLACEHOLDER,
      custom_job_payload=utils.build_payload(
          display_name='arbiter_preprocess',
          machine_type='n1-standard-4',
          image_uri=_resolve_image(),
          args=[
              '--',  # Used to mark the start of component flags.
              'arbiter_preprocess',
              f'--evaluation_dataset={evaluation_dataset}',
              f'--prediction_uris_a={prediction_uris_a}',
              f'--prediction_uris_b={prediction_uris_b}',
              # Structured (list/dict) inputs are referenced through raw
              # input-parameter placeholders with `json_escape` rather than
              # f-string interpolation.
              (
                  '--id_columns='
                  "{{$.inputs.parameters['id_columns'].json_escape[0]}}"
              ),
              (
                  '--autorater_prompt_parameters='
                  "{{$.inputs.parameters['autorater_prompt_parameters']"
                  '.json_escape[0]}}'
              ),
              (
                  '--model_a_prompt_parameters='
                  "{{$.inputs.parameters['model_a_prompt_parameters']"
                  '.json_escape[0]}}'
              ),
              (
                  '--model_b_prompt_parameters='
                  "{{$.inputs.parameters['model_b_prompt_parameters']"
                  '.json_escape[0]}}'
              ),
              f'--response_column_a={response_column_a}',
              f'--response_column_b={response_column_b}',
              f'--human_preference_column={human_preference_column}',
              f'--task={task}',
              f'--is_batch_prediction_output_a={is_bp_output_a}',
              f'--is_batch_prediction_output_b={is_bp_output_b}',
              f'--output_dir={dsl.PIPELINE_ROOT_PLACEHOLDER}',
              f'--preprocessed_evaluation_dataset_uri={preprocessed_evaluation_dataset_uri}',
              '--executor_input={{$.json_escape[1]}}',
          ],
      ),
      gcp_resources=gcp_resources,
  )
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# Copyright 2023 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""KFP Container component that performs AutoSxS."""
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from typing import Any, Dict, List
|
|
18
|
+
|
|
19
|
+
from google_cloud_pipeline_components import _placeholders
|
|
20
|
+
from google_cloud_pipeline_components import utils as gcpc_utils
|
|
21
|
+
from google_cloud_pipeline_components._implementation.llm import utils
|
|
22
|
+
from kfp import dsl
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _resolve_image() -> str:
|
|
26
|
+
"""Determines the image URI to create a container from."""
|
|
27
|
+
return (
|
|
28
|
+
os.environ.get('AUTOSXS_IMAGE_OVERRIDE')
|
|
29
|
+
or utils.get_default_image_uri('autosxs'))
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _get_prediction_endpoint_overrides() -> str:
|
|
33
|
+
"""Used for integration tests to override the prediction endpoint."""
|
|
34
|
+
return os.environ.get('PREDICTION_ENDPOINT_OVERRIDES', '')
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dsl.container_component
def autosxs_arbiter(
    inference_output_uri: str,
    id_columns: List[str],
    task: str,
    judgments: dsl.Output[dsl.Dataset],  # pylint: disable=unused-argument # pytype: disable=unsupported-operands
    judgments_uri: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
    gcp_resources: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
    metadata: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
    human_preference_column: str = '',
    judgments_format: str = 'jsonl',
    bigquery_destination_prefix: str = '',
    experimental_args: Dict[str, Any] = {},  # pylint: disable=dangerous-default-value
) -> dsl.ContainerSpec:  # pylint: disable=g-doc-args
  """Evaluate two models using an autorater.

  The work runs as a serverless Vertex custom job (`n1-standard-4`) that is
  always placed in `us-central1`.

  Args:
    inference_output_uri: Directory of model A's inference output.
    id_columns: The columns which distinguish unique evaluation examples.
    human_preference_column: Human preference column included in our inference
      output.
    task: Evaluation task in the form {task}@{version}. task can be one of
      "summarization", "question_answer". Version is an integer with 3 digits or
      "latest". Ex: summarization@001 or question_answer@latest.
    judgments_format: The format to write judgments to. Can be either 'jsonl'
      (the default) or 'bigquery'.
    bigquery_destination_prefix: BigQuery table to write judgments to if the
      specified format is 'bigquery'.
    experimental_args: Experimentally released arguments. Subject to change.

  Returns:
    judgments: Individual judgments used to calculate the win rates.
    judgments_uri: URI of the Judgments Artifact.
    gcp_resources: Tracker for GCP resources created by this component.
    metadata: Computed runtime metrics metadata from this component.
  """
  return gcpc_utils.build_serverless_customjob_container_spec(
      project=_placeholders.PROJECT_ID_PLACEHOLDER,
      # Hardcode location to us-central1 for text-bison availability.
      location='us-central1',
      custom_job_payload=utils.build_payload(
          display_name='autosxs_arbiter',
          machine_type='n1-standard-4',
          image_uri=_resolve_image(),
          args=[
              '--',  # Used to mark the start of component flags.
              'arbiter',
              f'--inference_output_uri={inference_output_uri}',
              f'--human_preference_column={human_preference_column}',
              f'--task={task}',
              # Read from the environment when this function body runs
              # (presumably at pipeline compile time), not inside the job.
              f'--prediction_endpoint_overrides={_get_prediction_endpoint_overrides()}',
              f'--output_dir={dsl.PIPELINE_ROOT_PLACEHOLDER}',
              f'--judgments_uri={judgments_uri}',
              f'--judgments_format={judgments_format}',
              f'--bigquery_destination_prefix={bigquery_destination_prefix}',
              # Structured (list/dict) inputs are referenced through raw
              # input-parameter placeholders with `json_escape`.
              (
                  '--id_columns='
                  "{{$.inputs.parameters['id_columns'].json_escape[0]}}"
              ),
              (
                  '--experimental_args='
                  "{{$.inputs.parameters['experimental_args'].json_escape[0]}}"
              ),
              '--executor_input={{$.json_escape[1]}}',
              f'--metadata_path={metadata}',
          ],
      ),
      gcp_resources=gcp_resources,
  )
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# Copyright 2023 The Kubeflow Authors. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""KFP Container component for computing AutoSXS metrics."""
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from google_cloud_pipeline_components import _placeholders
|
|
19
|
+
from google_cloud_pipeline_components import utils as gcpc_utils
|
|
20
|
+
from google_cloud_pipeline_components._implementation.llm import utils
|
|
21
|
+
from kfp import dsl
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _resolve_image() -> str:
|
|
25
|
+
"""Determines the image URI to create a container from."""
|
|
26
|
+
return os.environ.get(
|
|
27
|
+
'AUTOSXS_IMAGE_OVERRIDE'
|
|
28
|
+
) or utils.get_default_image_uri('autosxs')
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dsl.container_component
def autosxs_metrics_computer(
    judgments_dir: str,
    has_human_preference: bool,
    autosxs_metrics: dsl.Output[dsl.Metrics],  # pylint: disable=unused-argument # pytype: disable=unsupported-operands
    gcp_resources: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
) -> dsl.ContainerSpec:  # pylint: disable=g-doc-args
  """Compute AutoSxS metrics using judgments outputs from the Arbiter.

  The work runs as a serverless Vertex custom job (`n1-standard-4`) that is
  always placed in `us-central1`.

  Args:
    judgments_dir: Path where the Arbiter's judgments are stored.
    has_human_preference: True if users provided human preference data,
      otherwise False.

  Returns:
    autosxs_metrics: AutoSxS win rate metrics and human alignment metrics.
    gcp_resources: Tracker for GCP resources created by this component.
  """
  return gcpc_utils.build_serverless_customjob_container_spec(
      project=_placeholders.PROJECT_ID_PLACEHOLDER,
      # Hardcode location to us-central1 for text-bison availability.
      location='us-central1',
      custom_job_payload=utils.build_payload(
          display_name='autosxs_metrics_computer',
          machine_type='n1-standard-4',
          image_uri=_resolve_image(),
          args=[
              '--',  # Used to mark the start of component flags.
              'autosxs_metrics',
              f'--judgments_dir={judgments_dir}',
              f'--has_human_preference={has_human_preference}',
              '--executor_input={{$.json_escape[1]}}',
          ],
      ),
      gcp_resources=gcp_resources,
  )
|
|
@@ -75,22 +75,16 @@ def pipeline(
|
|
|
75
75
|
'large_model_reference'
|
|
76
76
|
]
|
|
77
77
|
).set_display_name('Resolve Upload Model')
|
|
78
|
-
upload_task = (
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
.set_env_variable(
|
|
89
|
-
name='VERTEX_AI_PIPELINES_RUN_LABELS',
|
|
90
|
-
value=json.dumps({'tune-type': 'rlhf'}),
|
|
91
|
-
)
|
|
92
|
-
.set_display_name('Upload Model')
|
|
93
|
-
)
|
|
78
|
+
upload_task = upload_llm_model.upload_llm_model(
|
|
79
|
+
project=_placeholders.PROJECT_ID_PLACEHOLDER,
|
|
80
|
+
location=upload_location,
|
|
81
|
+
regional_endpoint=regional_endpoint.output,
|
|
82
|
+
artifact_uri=adapter_artifact.output,
|
|
83
|
+
model_display_name=display_name.output,
|
|
84
|
+
model_reference_name='text-bison@001',
|
|
85
|
+
upload_model=upload_model.output,
|
|
86
|
+
tune_type='rlhf',
|
|
87
|
+
).set_display_name('Upload Model')
|
|
94
88
|
deploy_model = function_based.resolve_deploy_model(
|
|
95
89
|
deploy_model=deploy_model,
|
|
96
90
|
large_model_reference=reference_model_metadata.outputs[
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
"""Python function-based components used in KFP pipelies."""
|
|
15
15
|
import functools
|
|
16
|
-
from typing import List, NamedTuple, Optional
|
|
16
|
+
from typing import Any, Dict, List, NamedTuple, Optional
|
|
17
17
|
|
|
18
18
|
from google_cloud_pipeline_components import _image
|
|
19
19
|
from google_cloud_pipeline_components._implementation.llm import env
|
|
@@ -302,8 +302,8 @@ def resolve_reference_model_metadata(
|
|
|
302
302
|
'llama-2-13b': reference_model_metadata(
|
|
303
303
|
large_model_reference='LLAMA_2_13B',
|
|
304
304
|
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b/',
|
|
305
|
-
reward_model_reference='
|
|
306
|
-
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/
|
|
305
|
+
reward_model_reference='LLAMA_2_7B',
|
|
306
|
+
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b/',
|
|
307
307
|
is_supported=True,
|
|
308
308
|
),
|
|
309
309
|
'llama-2-7b-chat': reference_model_metadata(
|
|
@@ -316,8 +316,8 @@ def resolve_reference_model_metadata(
|
|
|
316
316
|
'llama-2-13b-chat': reference_model_metadata(
|
|
317
317
|
large_model_reference='LLAMA_2_13B_CHAT',
|
|
318
318
|
reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b_chat/',
|
|
319
|
-
reward_model_reference='
|
|
320
|
-
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/
|
|
319
|
+
reward_model_reference='LLAMA_2_7B',
|
|
320
|
+
reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b/',
|
|
321
321
|
is_supported=True,
|
|
322
322
|
),
|
|
323
323
|
}
|
|
@@ -495,3 +495,80 @@ def resolve_instruction(
|
|
|
495
495
|
"""
|
|
496
496
|
instruction = instruction or ''
|
|
497
497
|
return instruction if 'chat' not in large_model_reference.lower() else ''
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
@dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
def resolve_num_microbatches(large_model_reference: str) -> int:
  """Chooses how many microbatches a training job splits its batch into.

  Args:
    large_model_reference: Base model tuned by the pipeline.

  Returns:
    2 if `large_model_reference` contains "llama" (case-insensitive), else 0.
    The trainer components document values <= 1 as "train on the full batch
    size directly".
  """
  is_llama_model = 'llama' in large_model_reference.lower()
  return 2 if is_llama_model else 0
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
@dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
def read_file(path: str) -> str:
  """Reads the contents of the given file.

  A leading `gs://` is rewritten to `/gcs/` before opening. This presumably
  relies on a Cloud Storage FUSE mount in the component's runtime; confirm for
  the target environment.

  Args:
    path: Local path or `gs://` URI of a UTF-8 encoded text file.

  Returns:
    The full contents of the file.
  """
  # pylint: disable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
  import re
  # pylint: enable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported

  local_path = re.sub('^gs://', '/gcs/', path)
  # Decode explicitly as UTF-8 rather than depending on the container locale.
  with open(local_path, 'r', encoding='utf-8') as f:
    return f.read()
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
@dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
def get_usage_metric(metadata: Dict[str, Any], key: str) -> bool:  # pytype: disable=unsupported-operands
  """Looks up a single usage metric in a metadata mapping.

  Args:
    metadata: Mapping from metric names to values.
    key: Name of the metric to return.

  Returns:
    The value stored under `key`; the component output is declared as a bool.

  Raises:
    KeyError: If `key` is not present in `metadata`.
  """
  metric_value = metadata[key]
  return metric_value
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
@dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
def dump_dict(value: Dict[Any, Any]) -> str:
  """Serializes a dict to JSON with every double quote backslash-escaped.

  Args:
    value: JSON-serializable dict.

  Returns:
    `json.dumps(value)` in which each `"` has been replaced by `\\"`.
    NOTE(review): presumably escaped so the result can be embedded inside
    another double-quoted string; confirm with callers.
  """
  # pylint: disable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
  import json
  # pylint: enable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported

  serialized = json.dumps(value)
  # Map the double-quote character to backslash + double quote.
  return serialized.translate({ord('"'): '\\"'})
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
@dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
def dump_list(value: List[Any]) -> str:
  """Dumps the given list to a JSON string with double quotes escaped.

  Args:
    value: JSON-serializable list.

  Returns:
    `json.dumps(value)` in which each `"` has been replaced by `\\"`.
    NOTE(review): presumably escaped so the result can be embedded inside
    another double-quoted string; confirm with callers.
  """
  # pylint: disable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
  import json
  # pylint: enable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported

  return json.dumps(value).replace('"', '\\"')
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
@dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
def identity(
    x: str,
) -> str:
  """Returns its string input unchanged.

  Args:
    x: Value to pass through.

  Returns:
    `x`, unmodified.
  """
  return x
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
@dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
def get_uri(artifact: dsl.Input[dsl.Artifact], is_dir: bool = False) -> str:  # pytype: disable=unsupported-operands
  """Returns an artifact's URI, optionally with a trailing `*` component.

  Args:
    artifact: Artifact whose `uri` is read.
    is_dir: If True, `*` is appended as a final path component via
      `os.path.join`.

  Returns:
    `os.path.join(artifact.uri, '*')` when `is_dir` is True, otherwise
    `artifact.uri` unchanged.
  """
  # pylint: disable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
  import os
  # pylint: enable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported

  base_uri = artifact.uri
  return os.path.join(base_uri, '*') if is_dir else base_uri
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
@dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
def get_empty_string() -> str:
  """Returns an empty string as a component output.

  Returns:
    `''`.
  """
  return ''
|
|
@@ -117,6 +117,11 @@ def pipeline(
|
|
|
117
117
|
accelerator_type=machine_spec.outputs['accelerator_type'],
|
|
118
118
|
accelerator_count=machine_spec.outputs['accelerator_count'],
|
|
119
119
|
).set_display_name('Resolve Reinforcer Image URI')
|
|
120
|
+
num_microbatches = function_based.resolve_num_microbatches(
|
|
121
|
+
large_model_reference=reference_model_metadata.outputs[
|
|
122
|
+
'large_model_reference'
|
|
123
|
+
]
|
|
124
|
+
).set_display_name('Resolve Number of Microbatches')
|
|
120
125
|
rl_model = (
|
|
121
126
|
reinforcer.Reinforcer(
|
|
122
127
|
project=project,
|
|
@@ -145,6 +150,7 @@ def pipeline(
|
|
|
145
150
|
learning_rate_multiplier=reinforcement_learning_rate_multiplier,
|
|
146
151
|
kl_coeff=kl_coeff,
|
|
147
152
|
lora_dim=lora_dim,
|
|
153
|
+
num_microbatches=num_microbatches.output,
|
|
148
154
|
)
|
|
149
155
|
.set_display_name('Reinforcer')
|
|
150
156
|
.set_caching_options(False)
|
|
@@ -43,6 +43,7 @@ def Reinforcer( # pylint: disable=invalid-name
|
|
|
43
43
|
learning_rate_multiplier: float = 1.0,
|
|
44
44
|
kl_coeff: float = 0.1,
|
|
45
45
|
lora_dim: int = 0,
|
|
46
|
+
num_microbatches: int = 0,
|
|
46
47
|
) -> kfp.dsl.ContainerSpec: # pylint: disable=g-doc-args
|
|
47
48
|
"""Trains a model using reinforcement learning.
|
|
48
49
|
|
|
@@ -53,8 +54,8 @@ def Reinforcer( # pylint: disable=invalid-name
|
|
|
53
54
|
input_reward_model_path: Path to the reward model to use during
|
|
54
55
|
reinforcement learning.
|
|
55
56
|
input_dataset_path: Path to training dataset.
|
|
56
|
-
train_steps: Number of training steps. These are the number of steps
|
|
57
|
-
|
|
57
|
+
train_steps: Number of training steps. These are the number of steps on top
|
|
58
|
+
of any steps used to train the base model.
|
|
58
59
|
targets_length: Maximum decoder steps. Outputs will be at most this length.
|
|
59
60
|
accelerator_type: Type of TPU accelerator. Can be either TPU_V2 or TPU_V3.
|
|
60
61
|
accelerator_count: Number of TPU accelerators.
|
|
@@ -75,6 +76,9 @@ def Reinforcer( # pylint: disable=invalid-name
|
|
|
75
76
|
then use full-tuning.
|
|
76
77
|
learning_rate_multiplier: Constant multiplied by the base learning rate used
|
|
77
78
|
to adjust the learning rate during reinforcement learning.
|
|
79
|
+
num_microbatches: Number of microbatches to break the total batch size into
|
|
80
|
+
during training. If <= 1, the model is trained on the full batch size
|
|
81
|
+
directly.
|
|
78
82
|
|
|
79
83
|
Returns:
|
|
80
84
|
output_model_path: Path to the trained model checkpoint.
|
|
@@ -110,6 +114,7 @@ def Reinforcer( # pylint: disable=invalid-name
|
|
|
110
114
|
f'--learning_rate_multiplier={learning_rate_multiplier}',
|
|
111
115
|
f'--kl_coeff={kl_coeff}',
|
|
112
116
|
f'--lora_dim={lora_dim}',
|
|
117
|
+
f'--num_microbatches={num_microbatches}',
|
|
113
118
|
],
|
|
114
119
|
),
|
|
115
120
|
gcp_resources=gcp_resources,
|
|
@@ -118,6 +118,11 @@ def pipeline(
|
|
|
118
118
|
accelerator_type=machine_spec.outputs['accelerator_type'],
|
|
119
119
|
accelerator_count=machine_spec.outputs['accelerator_count'],
|
|
120
120
|
).set_display_name('Resolve Reward Model Image URI')
|
|
121
|
+
num_microbatches = function_based.resolve_num_microbatches(
|
|
122
|
+
large_model_reference=reference_model_metadata.outputs[
|
|
123
|
+
'reward_model_reference'
|
|
124
|
+
]
|
|
125
|
+
).set_display_name('Resolve Number of Microbatches')
|
|
121
126
|
reward_model = (
|
|
122
127
|
reward_model_trainer.RewardModelTrainer(
|
|
123
128
|
project=project,
|
|
@@ -141,6 +146,7 @@ def pipeline(
|
|
|
141
146
|
batch_size=batch_size,
|
|
142
147
|
learning_rate_multiplier=reward_model_learning_rate_multiplier,
|
|
143
148
|
lora_dim=lora_dim,
|
|
149
|
+
num_microbatches=num_microbatches.output,
|
|
144
150
|
)
|
|
145
151
|
.set_display_name('Reward Model Trainer')
|
|
146
152
|
.set_caching_options(False)
|
|
@@ -39,6 +39,7 @@ def RewardModelTrainer( # pylint: disable=invalid-name
|
|
|
39
39
|
batch_size: int = 64,
|
|
40
40
|
learning_rate_multiplier: float = 1.0,
|
|
41
41
|
lora_dim: int = 0,
|
|
42
|
+
num_microbatches: int = 0,
|
|
42
43
|
) -> kfp.dsl.ContainerSpec: # pylint: disable=g-doc-args
|
|
43
44
|
"""Trains a reward model.
|
|
44
45
|
|
|
@@ -47,8 +48,8 @@ def RewardModelTrainer( # pylint: disable=invalid-name
|
|
|
47
48
|
location: Location used to run the job.
|
|
48
49
|
input_model_path: Path to the base model to fine tune.
|
|
49
50
|
input_dataset_path: Path to dataset to use to train a reward model.
|
|
50
|
-
train_steps: Number of training steps. These are the number of steps
|
|
51
|
-
|
|
51
|
+
train_steps: Number of training steps. These are the number of steps on top
|
|
52
|
+
of any steps used to train the base model.
|
|
52
53
|
accelerator_type: Type of TPU accelerator. Can be either TPU_V2 or TPU_V3.
|
|
53
54
|
accelerator_count: Number of TPU accelerators.
|
|
54
55
|
large_model_reference: Predefined model used to create the ``input_model``.
|
|
@@ -64,6 +65,9 @@ def RewardModelTrainer( # pylint: disable=invalid-name
|
|
|
64
65
|
then use full-tuning.
|
|
65
66
|
learning_rate_multiplier: Constant multiplied by the base learning rate used
|
|
66
67
|
to adjust the learning rate when training a reward model.
|
|
68
|
+
num_microbatches: Number of microbatches to break the total batch size into
|
|
69
|
+
during training. If <= 1, the model is trained on the full batch size
|
|
70
|
+
directly.
|
|
67
71
|
|
|
68
72
|
Returns:
|
|
69
73
|
output_model: Trained reward model.
|
|
@@ -98,6 +102,7 @@ def RewardModelTrainer( # pylint: disable=invalid-name
|
|
|
98
102
|
f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
|
|
99
103
|
),
|
|
100
104
|
f'--lora_dim={lora_dim}',
|
|
105
|
+
f'--num_microbatches={num_microbatches}',
|
|
101
106
|
],
|
|
102
107
|
),
|
|
103
108
|
gcp_resources=gcp_resources,
|
|
@@ -39,6 +39,7 @@ def SupervisedFineTuner( # pylint: disable=invalid-name
|
|
|
39
39
|
batch_size: int = 64,
|
|
40
40
|
learning_rate_multiplier: float = 1.0,
|
|
41
41
|
lora_dim: int = 0,
|
|
42
|
+
num_microbatches: int = 0,
|
|
42
43
|
) -> kfp.dsl.ContainerSpec: # pylint: disable=g-doc-args
|
|
43
44
|
"""Performs supervised fine tuning.
|
|
44
45
|
|
|
@@ -65,6 +66,9 @@ def SupervisedFineTuner( # pylint: disable=invalid-name
|
|
|
65
66
|
then use full-tuning.
|
|
66
67
|
learning_rate_multiplier: Constant multiplied by the base learning rate used
|
|
67
68
|
to adjust the learning rate during supervised fine tuning.
|
|
69
|
+
num_microbatches: Number of microbatches to break the total batch size into
|
|
70
|
+
during training. If <= 1, the model is trained on the full batch size
|
|
71
|
+
directly.
|
|
68
72
|
|
|
69
73
|
Returns:
|
|
70
74
|
output_model_path: Fine-tuned model path.
|
|
@@ -99,6 +103,7 @@ def SupervisedFineTuner( # pylint: disable=invalid-name
|
|
|
99
103
|
f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
|
|
100
104
|
),
|
|
101
105
|
f'--lora_dim={lora_dim}',
|
|
106
|
+
f'--num_microbatches={num_microbatches}',
|
|
102
107
|
],
|
|
103
108
|
),
|
|
104
109
|
gcp_resources=gcp_resources,
|