google-cloud-pipeline-components 2.6.0__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. google_cloud_pipeline_components/_implementation/llm/arbiter_preprocess.py +137 -0
  2. google_cloud_pipeline_components/_implementation/llm/autosxs_arbiter.py +105 -0
  3. google_cloud_pipeline_components/_implementation/llm/autosxs_metrics_computer.py +66 -0
  4. google_cloud_pipeline_components/_implementation/llm/deployment_graph.py +10 -16
  5. google_cloud_pipeline_components/_implementation/llm/env.py +1 -1
  6. google_cloud_pipeline_components/_implementation/llm/function_based.py +82 -5
  7. google_cloud_pipeline_components/_implementation/llm/reinforcement_learning_graph.py +6 -0
  8. google_cloud_pipeline_components/_implementation/llm/reinforcer.py +7 -2
  9. google_cloud_pipeline_components/_implementation/llm/reward_model_graph.py +6 -0
  10. google_cloud_pipeline_components/_implementation/llm/reward_model_trainer.py +7 -2
  11. google_cloud_pipeline_components/_implementation/llm/supervised_fine_tuner.py +5 -0
  12. google_cloud_pipeline_components/_implementation/llm/task_preprocess.py +97 -0
  13. google_cloud_pipeline_components/_implementation/llm/upload_llm_model.py +5 -0
  14. google_cloud_pipeline_components/_implementation/model_evaluation/__init__.py +4 -0
  15. google_cloud_pipeline_components/_implementation/model_evaluation/endpoint_batch_predict/component.py +1 -1
  16. google_cloud_pipeline_components/_implementation/model_evaluation/import_evaluation/component.py +10 -0
  17. google_cloud_pipeline_components/_implementation/model_evaluation/llm_embedding/evaluation_llm_embedding_pipeline.py +64 -15
  18. google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation/component.py +9 -2
  19. google_cloud_pipeline_components/_implementation/model_evaluation/model_inference/__init__.py +14 -0
  20. google_cloud_pipeline_components/_implementation/model_evaluation/model_inference/component.py +324 -0
  21. google_cloud_pipeline_components/_implementation/model_evaluation/version.py +2 -2
  22. google_cloud_pipeline_components/container/_implementation/model_evaluation/import_model_evaluation.py +8 -0
  23. google_cloud_pipeline_components/container/v1/automl_training_job/__init__.py +14 -0
  24. google_cloud_pipeline_components/container/v1/automl_training_job/image/__init__.py +14 -0
  25. google_cloud_pipeline_components/container/v1/automl_training_job/image/launcher.py +236 -0
  26. google_cloud_pipeline_components/container/v1/automl_training_job/image/remote_runner.py +250 -0
  27. google_cloud_pipeline_components/preview/model_evaluation/evaluation_llm_text_generation_pipeline.py +6 -1
  28. google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/__init__.py +20 -0
  29. google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/autosxs/__init__.py +13 -0
  30. google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/autosxs/autosxs_pipeline.py +234 -0
  31. google_cloud_pipeline_components/version.py +1 -1
  32. {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/METADATA +1 -1
  33. {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/RECORD +36 -23
  34. {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/LICENSE +0 -0
  35. {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/WHEEL +0 -0
  36. {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/top_level.txt +0 -0
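The headline additions in 2.7.0 are the AutoSxS model-based LLM evaluation preview (files 1–3 and 28–30), microbatching support threaded through the RLHF training components (files 7–11), and a new AutoML image training job launcher (files 23–26). As a rough orientation aid, the snippet below shows how the new preview pipeline might be compiled with the KFP SDK; the module path comes from file 30, but the exported function name and its parameters are assumptions, not taken from this diff, and should be checked against autosxs_pipeline.py.

# Hypothetical compile sketch; 'autosxs_pipeline' as the exported pipeline
# function name is an assumption inferred from the file name in entry 30.
from kfp import compiler
from google_cloud_pipeline_components.preview.model_evaluation.model_based_llm_evaluation.autosxs import autosxs_pipeline

compiler.Compiler().compile(
    pipeline_func=autosxs_pipeline.autosxs_pipeline,
    package_path='autosxs_pipeline.yaml',
)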
google_cloud_pipeline_components/_implementation/llm/arbiter_preprocess.py
@@ -0,0 +1,137 @@
+ # Copyright 2023 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """KFP Container component for preprocessing predictions for the Arbiter."""
+
+ import os
+ from typing import Dict, List
+
+ from google_cloud_pipeline_components import _placeholders
+ from google_cloud_pipeline_components import utils as gcpc_utils
+ from google_cloud_pipeline_components._implementation.llm import utils
+ from kfp import dsl
+
+
+ def _resolve_image() -> str:
+ """Determines the image URI to create a container from."""
+ return (
+ os.environ.get('AUTOSXS_IMAGE_OVERRIDE')
+ or utils.get_default_image_uri('autosxs'))
+
+
+ # pylint: disable=unused-argument,dangerous-default-value
+ @dsl.container_component
+ def arbiter_preprocess(
+ evaluation_dataset: str,
+ id_columns: List[str],
+ response_column_a: str,
+ response_column_b: str,
+ task: str,
+ is_bp_output_a: bool,
+ is_bp_output_b: bool,
+ autorater_prompt_parameters: Dict[str, Dict[str, str]],
+ preprocessed_evaluation_dataset: dsl.Output[dsl.Dataset], # pylint: disable=unused-argument # pytype: disable=unsupported-operands
+ preprocessed_evaluation_dataset_uri: dsl.OutputPath(str), # pylint: disable=unused-argument # pytype: disable=invalid-annotation
+ gcp_resources: dsl.OutputPath(str), # pytype: disable=invalid-annotation
+ prediction_uris_a: str = '',
+ prediction_uris_b: str = '',
+ model_a_prompt_parameters: Dict[str, Dict[str, str]] = {},
+ model_b_prompt_parameters: Dict[str, Dict[str, str]] = {},
+ human_preference_column: str = '',
+ ) -> dsl.ContainerSpec: # pylint: disable=g-doc-args
+ """Preprocesses predictions tables for the AutoSxS Arbiter.
+
+ Args:
+ evaluation_dataset: GCS or BigQuery URIs representing a dataset of prompts
+ and responses.
+ id_columns: The columns which distinguish unique evaluation examples.
+ response_column_a: The column containing responses for model a.
+ response_column_b: The column containing responses for model a.
+ task: Task to evaluate.
+ output_path: Path to write the path where preprocessed predictions are
+ stored.
+ is_bp_output_a: If True, the prediction URIs will be parsed as if they came
+ from Vertex Batch Prediction, where response_column_a represents a field
+ in the model output containing the response. If False, the expected format
+ will be a table containing all model_prompt_parameters and the
+ response_column.
+ is_bp_output_b: If True, the prediction URIs will be parsed as if they came
+ from Vertex Batch Prediction, where response_column_b represents a field
+ in the model output containing the response. If False, the expected format
+ will be a table containing all model_prompt_parameters and the
+ response_column.
+ prediction_uris: A list of GCS or BigQuery URIs representing a dataset of
+ prompts and responses for model a.
+ prediction_uris: A list of GCS or BigQuery URIs representing a dataset of
+ prompts and responses for model b.
+ model_a_prompt_parameters: Map of model A prompt template parameters to
+ columns or templates.
+ model_b_prompt_parameters: Map of model B prompt template parameters to
+ columns or templates.
+ autorater_prompt_parameters: Map of autorater prompt template parameters to
+ columns or templates.
+ human_preference_column: The column containing ground truths. The default
+ value is an empty string if not be provided by users.
+
+ Returns:
+ preprocessed_evaluation_dataset: Dataset of the table containing the inputs
+ expected by the Arbiter.
+ preprocessed_evaluation_dataset_uri: URI of the table containing the inputs
+ expected by the Arbiter.
+ gcp_resources: Tracker for GCP resources created by this component.
+ """
+ return gcpc_utils.build_serverless_customjob_container_spec(
+ project=_placeholders.PROJECT_ID_PLACEHOLDER,
+ location=_placeholders.LOCATION_PLACEHOLDER,
+ custom_job_payload=utils.build_payload(
+ display_name='arbiter_preprocess',
+ machine_type='n1-standard-4',
+ image_uri=_resolve_image(),
+ args=[
+ '--', # Used to mark the start of component flags.
+ 'arbiter_preprocess',
+ f'--evaluation_dataset={evaluation_dataset}',
+ f'--prediction_uris_a={prediction_uris_a}',
+ f'--prediction_uris_b={prediction_uris_b}',
+ (
+ '--id_columns='
+ "{{$.inputs.parameters['id_columns'].json_escape[0]}}"
+ ),
+ (
+ '--autorater_prompt_parameters='
+ "{{$.inputs.parameters['autorater_prompt_parameters']"
+ '.json_escape[0]}}'
+ ),
+ (
+ '--model_a_prompt_parameters='
+ "{{$.inputs.parameters['model_a_prompt_parameters']"
+ '.json_escape[0]}}'
+ ),
+ (
+ '--model_b_prompt_parameters='
+ "{{$.inputs.parameters['model_b_prompt_parameters']"
+ '.json_escape[0]}}'
+ ),
+ f'--response_column_a={response_column_a}',
+ f'--response_column_b={response_column_b}',
+ f'--human_preference_column={human_preference_column}',
+ f'--task={task}',
+ f'--is_batch_prediction_output_a={is_bp_output_a}',
+ f'--is_batch_prediction_output_b={is_bp_output_b}',
+ f'--output_dir={dsl.PIPELINE_ROOT_PLACEHOLDER}',
+ f'--preprocessed_evaluation_dataset_uri={preprocessed_evaluation_dataset_uri}',
+ '--executor_input={{$.json_escape[1]}}',
+ ],
+ ),
+ gcp_resources=gcp_resources,
+ )
google_cloud_pipeline_components/_implementation/llm/autosxs_arbiter.py
@@ -0,0 +1,105 @@
+ # Copyright 2023 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """KFP Container component that performs AutoSxS."""
+
+ import os
+ from typing import Any, Dict, List
+
+ from google_cloud_pipeline_components import _placeholders
+ from google_cloud_pipeline_components import utils as gcpc_utils
+ from google_cloud_pipeline_components._implementation.llm import utils
+ from kfp import dsl
+
+
+ def _resolve_image() -> str:
+ """Determines the image URI to create a container from."""
+ return (
+ os.environ.get('AUTOSXS_IMAGE_OVERRIDE')
+ or utils.get_default_image_uri('autosxs'))
+
+
+ def _get_prediction_endpoint_overrides() -> str:
+ """Used for integration tests to override the prediction endpoint."""
+ return os.environ.get('PREDICTION_ENDPOINT_OVERRIDES', '')
+
+
+ @dsl.container_component
+ def autosxs_arbiter(
+ inference_output_uri: str,
+ id_columns: List[str],
+ task: str,
+ judgments: dsl.Output[dsl.Dataset], # pylint: disable=unused-argument # pytype: disable=unsupported-operands
+ judgments_uri: dsl.OutputPath(str), # pytype: disable=invalid-annotation
+ gcp_resources: dsl.OutputPath(str),
+ metadata: dsl.OutputPath(str),
+ human_preference_column: str = '',
+ judgments_format: str = 'jsonl',
+ bigquery_destination_prefix: str = '',
+ experimental_args: Dict[str, Any] = {},
+ ) -> dsl.ContainerSpec: # pylint: disable=g-doc-args
+ """Evaluate two models using an autorater.
+
+ Args:
+ inference_output_uri: Directory of model A's inference output.
+ id_columns: The columns which distinguish unique evaluation examples.
+ human_preference_column: Human preference column included in our inference
+ output.
+ task: Evaluation task in the form {task}@{version}. task can be one of
+ "summarization", "question_answer". Version is an integer with 3 digits or
+ "latest". Ex: summarization@001 or question_answer@latest.
+ judgments_format: The format to write judgments to. Can be either 'json' or
+ 'bigquery'.
+ bigquery_destination_prefix: BigQuery table to write judgments to if the
+ specified format is 'bigquery'.
+ experimental_args: Experimentally released arguments. Subject to change.
+
+ Returns:
+ judgments: Individual judgments used to calculate the win rates.
+ judgments_uri: URI of the Judgments Artifact.
+ gcp_resources: Tracker for GCP resources created by this component.
+ metadata: Computed runtime metrics metadata from this component.
+ """
+ return gcpc_utils.build_serverless_customjob_container_spec(
+ project=_placeholders.PROJECT_ID_PLACEHOLDER,
+ # Hardcode location to us-central1 for text-bison availability.
+ location='us-central1',
+ custom_job_payload=utils.build_payload(
+ display_name='autosxs_arbiter',
+ machine_type='n1-standard-4',
+ image_uri=_resolve_image(),
+ args=[
+ '--', # Used to mark the start of component flags.
+ 'arbiter',
+ f'--inference_output_uri={inference_output_uri}',
+ f'--human_preference_column={human_preference_column}',
+ f'--task={task}',
+ f'--prediction_endpoint_overrides={_get_prediction_endpoint_overrides()}',
+ f'--output_dir={dsl.PIPELINE_ROOT_PLACEHOLDER}',
+ f'--judgments_uri={judgments_uri}',
+ f'--judgments_format={judgments_format}',
+ f'--bigquery_destination_prefix={bigquery_destination_prefix}',
+ (
+ '--id_columns='
+ "{{$.inputs.parameters['id_columns'].json_escape[0]}}"
+ ),
+ (
+ '--experimental_args='
+ "{{$.inputs.parameters['experimental_args'].json_escape[0]}}"
+ ),
+ '--executor_input={{$.json_escape[1]}}',
+ f'--metadata_path={metadata}',
+ ],
+ ),
+ gcp_resources=gcp_resources,
+ )
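The autosxs_arbiter docstring above pins down the expected task format: {task}@{version}, where task is "summarization" or "question_answer" and version is a three-digit integer or "latest". A small stand-alone illustration of that format follows; it is not part of the package, and the container performs its own validation.

# Stand-alone illustration of the documented {task}@{version} format.
import re

_TASK_PATTERN = re.compile(r'^(summarization|question_answer)@(\d{3}|latest)$')


def is_valid_task(task: str) -> bool:
    """Returns True if `task` matches the documented format."""
    return _TASK_PATTERN.fullmatch(task) is not None


assert is_valid_task('summarization@001')
assert is_valid_task('question_answer@latest')
assert not is_valid_task('summarization@01')  # version must be 3 digits or 'latest'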
google_cloud_pipeline_components/_implementation/llm/autosxs_metrics_computer.py
@@ -0,0 +1,66 @@
+ # Copyright 2023 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """KFP Container component for computing AutoSXS metrics."""
+
+ import os
+
+ from google_cloud_pipeline_components import _placeholders
+ from google_cloud_pipeline_components import utils as gcpc_utils
+ from google_cloud_pipeline_components._implementation.llm import utils
+ from kfp import dsl
+
+
+ def _resolve_image() -> str:
+ """Determines the image URI to create a container from."""
+ return os.environ.get(
+ 'AUTOSXS_IMAGE_OVERRIDE'
+ ) or utils.get_default_image_uri('autosxs')
+
+
+ @dsl.container_component
+ def autosxs_metrics_computer(
+ judgments_dir: str,
+ has_human_preference: bool,
+ autosxs_metrics: dsl.Output[dsl.Metrics], # pylint: disable=unused-argument # pytype: disable=unsupported-operands
+ gcp_resources: dsl.OutputPath(str), # pytype: disable=invalid-annotation
+ ) -> dsl.ContainerSpec: # pylint: disable=g-doc-args
+ """Compute AutoSXS metrics using judgments outputs from Arbiter.
+
+ Args:
+ judgments_dir: Path where store the Judgments.
+ has_human_preference: Boolean value. True if users provided human preference
+ data, otherwise false.
+
+ Returns:
+ autosxs_metrics: Autosxs win rate metrics and human alignment metrics.
+ gcp_resources: Tracker for GCP resources created by this component.
+ """
+ return gcpc_utils.build_serverless_customjob_container_spec(
+ project=_placeholders.PROJECT_ID_PLACEHOLDER,
+ # Hardcode location to us-central1 for text-bison availability.
+ location='us-central1',
+ custom_job_payload=utils.build_payload(
+ display_name='autosxs_metrics_computer',
+ machine_type='n1-standard-4',
+ image_uri=_resolve_image(),
+ args=[
+ '--', # Used to mark the start of component flags.
+ 'autosxs_metrics',
+ f'--judgments_dir={judgments_dir}',
+ f'--has_human_preference={has_human_preference}',
+ '--executor_input={{$.json_escape[1]}}',
+ ],
+ ),
+ gcp_resources=gcp_resources,
+ )
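Taken together, the three new components above are the building blocks of the AutoSxS preview pipeline (file 30 in the list): arbiter_preprocess prepares prediction tables for the Arbiter, autosxs_arbiter produces per-example judgments, and autosxs_metrics_computer turns those judgments into win-rate metrics. The sketch below wires them in that order inside a toy KFP pipeline purely for orientation; the real wiring, including the inference steps implied by is_bp_output_a/b, lives in autosxs_pipeline.py, and every argument value here is a placeholder.

# Orientation-only sketch; see autosxs_pipeline.py for the actual pipeline.
from kfp import dsl

from google_cloud_pipeline_components._implementation.llm import arbiter_preprocess
from google_cloud_pipeline_components._implementation.llm import autosxs_arbiter
from google_cloud_pipeline_components._implementation.llm import autosxs_metrics_computer


@dsl.pipeline(name='autosxs-orientation-sketch')
def autosxs_sketch(evaluation_dataset: str, task: str = 'summarization@001'):
    # Prepare the evaluation table for the Arbiter (placeholder column names).
    preprocess = arbiter_preprocess.arbiter_preprocess(
        evaluation_dataset=evaluation_dataset,
        id_columns=['example_id'],
        response_column_a='response_a',
        response_column_b='response_b',
        task=task,
        is_bp_output_a=False,
        is_bp_output_b=False,
        autorater_prompt_parameters={'prompt': {'column': 'prompt'}},
    )
    # Have the autorater judge each example.
    arbiter = autosxs_arbiter.autosxs_arbiter(
        inference_output_uri=preprocess.outputs[
            'preprocessed_evaluation_dataset_uri'
        ],
        id_columns=['example_id'],
        task=task,
    )
    # Aggregate judgments into win-rate metrics.
    autosxs_metrics_computer.autosxs_metrics_computer(
        judgments_dir=arbiter.outputs['judgments_uri'],
        has_human_preference=False,
    )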
google_cloud_pipeline_components/_implementation/llm/deployment_graph.py
@@ -75,22 +75,16 @@ def pipeline(
  'large_model_reference'
  ]
  ).set_display_name('Resolve Upload Model')
- upload_task = (
- upload_llm_model.upload_llm_model(
- project=_placeholders.PROJECT_ID_PLACEHOLDER,
- location=upload_location,
- regional_endpoint=regional_endpoint.output,
- artifact_uri=adapter_artifact.output,
- model_display_name=display_name.output,
- model_reference_name='text-bison@001',
- upload_model=upload_model.output,
- )
- .set_env_variable(
- name='VERTEX_AI_PIPELINES_RUN_LABELS',
- value=json.dumps({'tune-type': 'rlhf'}),
- )
- .set_display_name('Upload Model')
- )
+ upload_task = upload_llm_model.upload_llm_model(
+ project=_placeholders.PROJECT_ID_PLACEHOLDER,
+ location=upload_location,
+ regional_endpoint=regional_endpoint.output,
+ artifact_uri=adapter_artifact.output,
+ model_display_name=display_name.output,
+ model_reference_name='text-bison@001',
+ upload_model=upload_model.output,
+ tune_type='rlhf',
+ ).set_display_name('Upload Model')
  deploy_model = function_based.resolve_deploy_model(
  deploy_model=deploy_model,
  large_model_reference=reference_model_metadata.outputs[
google_cloud_pipeline_components/_implementation/llm/env.py
@@ -16,7 +16,7 @@ import os


  def get_private_image_tag() -> str:
- return os.getenv('PRIVATE_IMAGE_TAG', '20231031_0507_RC00')
+ return os.getenv('PRIVATE_IMAGE_TAG', '20231213_0507_RC00')


  def get_use_test_machine_spec() -> bool:
google_cloud_pipeline_components/_implementation/llm/function_based.py
@@ -13,7 +13,7 @@
  # limitations under the License.
  """Python function-based components used in KFP pipelies."""
  import functools
- from typing import List, NamedTuple, Optional
+ from typing import Any, Dict, List, NamedTuple, Optional

  from google_cloud_pipeline_components import _image
  from google_cloud_pipeline_components._implementation.llm import env
@@ -302,8 +302,8 @@ def resolve_reference_model_metadata(
  'llama-2-13b': reference_model_metadata(
  large_model_reference='LLAMA_2_13B',
  reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b/',
- reward_model_reference='LLAMA_2_13B',
- reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b/',
+ reward_model_reference='LLAMA_2_7B',
+ reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b/',
  is_supported=True,
  ),
  'llama-2-7b-chat': reference_model_metadata(
@@ -316,8 +316,8 @@ def resolve_reference_model_metadata(
  'llama-2-13b-chat': reference_model_metadata(
  large_model_reference='LLAMA_2_13B_CHAT',
  reference_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b_chat/',
- reward_model_reference='LLAMA_2_13B',
- reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_13b/',
+ reward_model_reference='LLAMA_2_7B',
+ reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/llama/t5x_llama_2_7b/',
  is_supported=True,
  ),
  }
@@ -495,3 +495,80 @@ def resolve_instruction(
  """
  instruction = instruction or ''
  return instruction if 'chat' not in large_model_reference.lower() else ''
+
+
+ @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
+ def resolve_num_microbatches(large_model_reference: str) -> int:
+ """Resolves the number of microbatches to use during training.
+
+ Args:
+ large_model_reference: Base model tuned by the pipeline.
+
+ Returns:
+ Number of microbatches to break the total batch size into during training.
+ """
+ if 'llama' in large_model_reference.lower():
+ return 2
+ return 0
+
+
+ @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
+ def read_file(path: str) -> str:
+ """Reads the contents of the given file."""
+ # pylint: disable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
+ import re
+ # pylint: enable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
+
+ path = re.sub('^gs://', '/gcs/', path)
+ with open(path, 'r') as f:
+ return f.read()
+
+
+ @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
+ def get_usage_metric(metadata: Dict[str, Any], key: str) -> bool: # pytype: disable=unsupported-operands
+ """Extracts a single usage metric from metadata."""
+ return metadata[key]
+
+
+ @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
+ def dump_dict(value: Dict[Any, Any]) -> str:
+ """Dumps the given dict to a JSON string."""
+ # pylint: disable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
+ import json
+ # pylint: enable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
+
+ return json.dumps(value).replace('"', '\\"')
+
+
+ @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
+ def dump_list(value: List[Any]) -> str:
+ """Dumps the given dict to a JSON string."""
+ # pylint: disable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
+ import json
+ # pylint: enable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
+
+ return json.dumps(value).replace('"', '\\"')
+
+
+ @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
+ def identity(
+ x: str,
+ ) -> str:
+ return x
+
+
+ @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
+ def get_uri(artifact: dsl.Input[dsl.Artifact], is_dir: bool = False) -> str: # pytype: disable=unsupported-operands
+ """Extracts the URI from an artifact."""
+ # pylint: disable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
+ import os
+ # pylint: enable=g-import-not-at-top,import-outside-toplevel,redefined-outer-name,reimported
+
+ if is_dir:
+ return os.path.join(artifact.uri, '*')
+ return artifact.uri
+
+
+ @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
+ def get_empty_string() -> str:
+ return ''
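Note that dump_dict and dump_list JSON-serialize their input and then escape double quotes, which keeps the result safe to embed in a downstream command-line flag. A quick local illustration of that escaping (plain Python, outside any pipeline):

# Local illustration of the escaping done by dump_dict/dump_list above.
import json


def dump_dict_local(value: dict) -> str:
    # Same expression as the component body: JSON-encode, then escape quotes.
    return json.dumps(value).replace('"', '\\"')


print(dump_dict_local({'tune-type': 'rlhf'}))
# Prints: {\"tune-type\": \"rlhf\"}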
google_cloud_pipeline_components/_implementation/llm/reinforcement_learning_graph.py
@@ -117,6 +117,11 @@ def pipeline(
  accelerator_type=machine_spec.outputs['accelerator_type'],
  accelerator_count=machine_spec.outputs['accelerator_count'],
  ).set_display_name('Resolve Reinforcer Image URI')
+ num_microbatches = function_based.resolve_num_microbatches(
+ large_model_reference=reference_model_metadata.outputs[
+ 'large_model_reference'
+ ]
+ ).set_display_name('Resolve Number of Microbatches')
  rl_model = (
  reinforcer.Reinforcer(
  project=project,
@@ -145,6 +150,7 @@ def pipeline(
  learning_rate_multiplier=reinforcement_learning_rate_multiplier,
  kl_coeff=kl_coeff,
  lora_dim=lora_dim,
+ num_microbatches=num_microbatches.output,
  )
  .set_display_name('Reinforcer')
  .set_caching_options(False)
google_cloud_pipeline_components/_implementation/llm/reinforcer.py
@@ -43,6 +43,7 @@ def Reinforcer( # pylint: disable=invalid-name
  learning_rate_multiplier: float = 1.0,
  kl_coeff: float = 0.1,
  lora_dim: int = 0,
+ num_microbatches: int = 0,
  ) -> kfp.dsl.ContainerSpec: # pylint: disable=g-doc-args
  """Trains a model using reinforcement learning.

@@ -53,8 +54,8 @@ def Reinforcer( # pylint: disable=invalid-name
  input_reward_model_path: Path to the reward model to use during
  reinforcement learning.
  input_dataset_path: Path to training dataset.
- train_steps: Number of training steps. These are the number of steps
- on top of any steps used to train the base model.
+ train_steps: Number of training steps. These are the number of steps on top
+ of any steps used to train the base model.
  targets_length: Maximum decoder steps. Outputs will be at most this length.
  accelerator_type: Type of TPU accelerator. Can be either TPU_V2 or TPU_V3.
  accelerator_count: Number of TPU accelerators.
@@ -75,6 +76,9 @@ def Reinforcer( # pylint: disable=invalid-name
  then use full-tuning.
  learning_rate_multiplier: Constant multiplied by the base learning rate used
  to adjust the learning rate during reinforcement learning.
+ num_microbatches: Number of microbatches to break the total batch size into
+ during training. If <= 1, the model is trained on the full batch size
+ directly.

  Returns:
  output_model_path: Path to the trained model checkpoint.
@@ -110,6 +114,7 @@ def Reinforcer( # pylint: disable=invalid-name
  f'--learning_rate_multiplier={learning_rate_multiplier}',
  f'--kl_coeff={kl_coeff}',
  f'--lora_dim={lora_dim}',
+ f'--num_microbatches={num_microbatches}',
  ],
  ),
  gcp_resources=gcp_resources,
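The num_microbatches flag threaded through the Reinforcer (and the reward model and supervised fine-tuning trainers below) follows the usual microbatching idea: the per-step batch is split into smaller chunks that are processed sequentially, trading step latency for memory. A framework-agnostic sketch of that splitting, assuming nothing about the private training images that actually consume the flag:

# Conceptual sketch of microbatching; the real trainers run inside the
# private training images and are not shown in this diff.
from typing import Iterator, Sequence


def split_into_microbatches(
    batch: Sequence, num_microbatches: int
) -> Iterator[Sequence]:
    """Yields `num_microbatches` chunks of `batch` (the whole batch if <= 1)."""
    if num_microbatches <= 1:
        yield batch
        return
    chunk = (len(batch) + num_microbatches - 1) // num_microbatches
    for start in range(0, len(batch), chunk):
        yield batch[start:start + chunk]


# Example: a batch of 8 examples with num_microbatches=2 -> two chunks of 4.
assert list(split_into_microbatches(list(range(8)), 2)) == [
    [0, 1, 2, 3],
    [4, 5, 6, 7],
]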
google_cloud_pipeline_components/_implementation/llm/reward_model_graph.py
@@ -118,6 +118,11 @@ def pipeline(
  accelerator_type=machine_spec.outputs['accelerator_type'],
  accelerator_count=machine_spec.outputs['accelerator_count'],
  ).set_display_name('Resolve Reward Model Image URI')
+ num_microbatches = function_based.resolve_num_microbatches(
+ large_model_reference=reference_model_metadata.outputs[
+ 'reward_model_reference'
+ ]
+ ).set_display_name('Resolve Number of Microbatches')
  reward_model = (
  reward_model_trainer.RewardModelTrainer(
  project=project,
@@ -141,6 +146,7 @@ def pipeline(
  batch_size=batch_size,
  learning_rate_multiplier=reward_model_learning_rate_multiplier,
  lora_dim=lora_dim,
+ num_microbatches=num_microbatches.output,
  )
  .set_display_name('Reward Model Trainer')
  .set_caching_options(False)
google_cloud_pipeline_components/_implementation/llm/reward_model_trainer.py
@@ -39,6 +39,7 @@ def RewardModelTrainer( # pylint: disable=invalid-name
  batch_size: int = 64,
  learning_rate_multiplier: float = 1.0,
  lora_dim: int = 0,
+ num_microbatches: int = 0,
  ) -> kfp.dsl.ContainerSpec: # pylint: disable=g-doc-args
  """Trains a reward model.

@@ -47,8 +48,8 @@ def RewardModelTrainer( # pylint: disable=invalid-name
  location: Location used to run the job.
  input_model_path: Path to the base model to fine tune.
  input_dataset_path: Path to dataset to use to train a reward model.
- train_steps: Number of training steps. These are the number of steps
- on top of any steps used to train the base model.
+ train_steps: Number of training steps. These are the number of steps on top
+ of any steps used to train the base model.
  accelerator_type: Type of TPU accelerator. Can be either TPU_V2 or TPU_V3.
  accelerator_count: Number of TPU accelerators.
  large_model_reference: Predefined model used to create the ``input_model``.
@@ -64,6 +65,9 @@ def RewardModelTrainer( # pylint: disable=invalid-name
  then use full-tuning.
  learning_rate_multiplier: Constant multiplied by the base learning rate used
  to adjust the learning rate when training a reward model.
+ num_microbatches: Number of microbatches to break the total batch size into
+ during training. If <= 1, the model is trained on the full batch size
+ directly.

  Returns:
  output_model: Trained reward model.
@@ -98,6 +102,7 @@ def RewardModelTrainer( # pylint: disable=invalid-name
  f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
  ),
  f'--lora_dim={lora_dim}',
+ f'--num_microbatches={num_microbatches}',
  ],
  ),
  gcp_resources=gcp_resources,
google_cloud_pipeline_components/_implementation/llm/supervised_fine_tuner.py
@@ -39,6 +39,7 @@ def SupervisedFineTuner( # pylint: disable=invalid-name
  batch_size: int = 64,
  learning_rate_multiplier: float = 1.0,
  lora_dim: int = 0,
+ num_microbatches: int = 0,
  ) -> kfp.dsl.ContainerSpec: # pylint: disable=g-doc-args
  """Performs supervised fine tuning.

@@ -65,6 +66,9 @@ def SupervisedFineTuner( # pylint: disable=invalid-name
  then use full-tuning.
  learning_rate_multiplier: Constant multiplied by the base learning rate used
  to adjust the learning rate during supervised fine tuning.
+ num_microbatches: Number of microbatches to break the total batch size into
+ during training. If <= 1, the model is trained on the full batch size
+ directly.

  Returns:
  output_model_path: Fine-tuned model path.
@@ -99,6 +103,7 @@ def SupervisedFineTuner( # pylint: disable=invalid-name
  f'{kfp.dsl.PIPELINE_TASK_ID_PLACEHOLDER}'
  ),
  f'--lora_dim={lora_dim}',
+ f'--num_microbatches={num_microbatches}',
  ],
  ),
  gcp_resources=gcp_resources,