google-cloud-pipeline-components 2.13.1__py3-none-any.whl → 2.14.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of google-cloud-pipeline-components might be problematic. See the package's registry page for more details.

Files changed (82):
  1. google_cloud_pipeline_components/__init__.py +5 -6
  2. google_cloud_pipeline_components/_implementation/llm/deployment_graph.py +12 -34
  3. google_cloud_pipeline_components/_implementation/llm/env.py +1 -1
  4. google_cloud_pipeline_components/_implementation/llm/function_based.py +14 -48
  5. google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py +1 -1
  6. google_cloud_pipeline_components/_implementation/llm/infer_preprocessor.py +109 -0
  7. google_cloud_pipeline_components/_implementation/llm/online_evaluation_pairwise.py +8 -0
  8. google_cloud_pipeline_components/_implementation/llm/reinforcement_learning_graph.py +27 -36
  9. google_cloud_pipeline_components/_implementation/llm/reward_model_graph.py +31 -47
  10. google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py +84 -0
  11. google_cloud_pipeline_components/_implementation/llm/validate_pipeline.py +11 -0
  12. google_cloud_pipeline_components/_implementation/model_evaluation/__init__.py +0 -12
  13. google_cloud_pipeline_components/_implementation/model_evaluation/llm_embedding/evaluation_llm_embedding_pipeline.py +2 -1
  14. google_cloud_pipeline_components/_placeholders.py +30 -1
  15. google_cloud_pipeline_components/preview/automl/forecasting/forecasting_ensemble.py +1 -1
  16. google_cloud_pipeline_components/preview/automl/forecasting/forecasting_stage_1_tuner.py +2 -2
  17. google_cloud_pipeline_components/preview/automl/forecasting/forecasting_stage_2_tuner.py +2 -2
  18. google_cloud_pipeline_components/preview/automl/forecasting/learn_to_learn_forecasting_pipeline.yaml +34 -34
  19. google_cloud_pipeline_components/preview/automl/forecasting/sequence_to_sequence_forecasting_pipeline.yaml +34 -34
  20. google_cloud_pipeline_components/preview/automl/forecasting/temporal_fusion_transformer_forecasting_pipeline.yaml +34 -34
  21. google_cloud_pipeline_components/preview/automl/forecasting/time_series_dense_encoder_forecasting_pipeline.yaml +34 -34
  22. google_cloud_pipeline_components/preview/automl/tabular/auto_feature_engineering.py +1 -1
  23. google_cloud_pipeline_components/preview/automl/tabular/automl_tabular_feature_selection_pipeline.yaml +39 -39
  24. google_cloud_pipeline_components/preview/automl/tabular/automl_tabular_v2_pipeline.yaml +41 -41
  25. google_cloud_pipeline_components/preview/automl/tabular/distillation_stage_feature_transform_engine.py +2 -2
  26. google_cloud_pipeline_components/preview/automl/tabular/feature_selection.py +2 -2
  27. google_cloud_pipeline_components/preview/automl/tabular/feature_selection_pipeline.yaml +4 -4
  28. google_cloud_pipeline_components/preview/automl/tabular/feature_transform_engine.py +3 -3
  29. google_cloud_pipeline_components/preview/automl/tabular/tabnet_hyperparameter_tuning_job.py +2 -2
  30. google_cloud_pipeline_components/preview/automl/tabular/tabnet_hyperparameter_tuning_job_pipeline.yaml +17 -17
  31. google_cloud_pipeline_components/preview/automl/tabular/tabnet_trainer.py +2 -2
  32. google_cloud_pipeline_components/preview/automl/tabular/tabnet_trainer_pipeline.yaml +15 -15
  33. google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_hyperparameter_tuning_job.py +2 -2
  34. google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_hyperparameter_tuning_job_pipeline.yaml +16 -16
  35. google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_trainer.py +2 -2
  36. google_cloud_pipeline_components/preview/automl/tabular/wide_and_deep_trainer_pipeline.yaml +15 -15
  37. google_cloud_pipeline_components/preview/automl/tabular/xgboost_hyperparameter_tuning_job_pipeline.yaml +14 -14
  38. google_cloud_pipeline_components/preview/automl/tabular/xgboost_trainer_pipeline.yaml +13 -13
  39. google_cloud_pipeline_components/preview/automl/vision/data_converter.py +3 -1
  40. google_cloud_pipeline_components/preview/custom_job/component.py +2 -2
  41. google_cloud_pipeline_components/preview/custom_job/utils.py +3 -2
  42. google_cloud_pipeline_components/preview/llm/infer/component.py +22 -25
  43. google_cloud_pipeline_components/preview/llm/rlhf/component.py +72 -10
  44. google_cloud_pipeline_components/preview/model_evaluation/__init__.py +5 -2
  45. google_cloud_pipeline_components/preview/model_evaluation/model_evaluation_import_component.py +209 -0
  46. google_cloud_pipeline_components/proto/task_error_pb2.py +33 -0
  47. google_cloud_pipeline_components/proto/template_metadata_pb2.py +22 -15
  48. google_cloud_pipeline_components/v1/automl/forecasting/bqml_arima_predict_pipeline.yaml +10 -10
  49. google_cloud_pipeline_components/v1/automl/forecasting/bqml_arima_train_pipeline.yaml +31 -31
  50. google_cloud_pipeline_components/v1/automl/forecasting/prophet_predict_pipeline.yaml +13 -13
  51. google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer.py +13 -3
  52. google_cloud_pipeline_components/v1/automl/forecasting/prophet_trainer_pipeline.yaml +18 -15
  53. google_cloud_pipeline_components/v1/automl/tabular/automl_tabular_pipeline.yaml +37 -37
  54. google_cloud_pipeline_components/v1/automl/tabular/cv_trainer.py +2 -2
  55. google_cloud_pipeline_components/v1/automl/tabular/ensemble.py +2 -2
  56. google_cloud_pipeline_components/v1/automl/tabular/finalizer.py +1 -1
  57. google_cloud_pipeline_components/v1/automl/tabular/infra_validator.py +1 -1
  58. google_cloud_pipeline_components/v1/automl/tabular/split_materialized_data.py +1 -1
  59. google_cloud_pipeline_components/v1/automl/tabular/stage_1_tuner.py +2 -2
  60. google_cloud_pipeline_components/v1/automl/tabular/stats_and_example_gen.py +2 -2
  61. google_cloud_pipeline_components/v1/automl/tabular/training_configurator_and_validator.py +1 -1
  62. google_cloud_pipeline_components/v1/automl/tabular/transform.py +2 -2
  63. google_cloud_pipeline_components/v1/model_evaluation/__init__.py +3 -1
  64. google_cloud_pipeline_components/v1/model_evaluation/classification_component.py +2 -2
  65. google_cloud_pipeline_components/v1/model_evaluation/error_analysis_pipeline.py +8 -10
  66. google_cloud_pipeline_components/v1/model_evaluation/evaluated_annotation_pipeline.py +2 -2
  67. google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_tabular_feature_attribution_pipeline.py +2 -2
  68. google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_tabular_pipeline.py +2 -2
  69. google_cloud_pipeline_components/v1/model_evaluation/evaluation_automl_unstructure_data_pipeline.py +2 -2
  70. google_cloud_pipeline_components/v1/model_evaluation/evaluation_feature_attribution_pipeline.py +2 -2
  71. google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_classification_pipeline.py +4 -2
  72. google_cloud_pipeline_components/v1/model_evaluation/evaluation_llm_text_generation_pipeline.py +4 -2
  73. google_cloud_pipeline_components/{preview → v1}/model_evaluation/model_based_llm_evaluation/__init__.py +2 -2
  74. google_cloud_pipeline_components/{preview → v1}/model_evaluation/model_based_llm_evaluation/autosxs/autosxs_pipeline.py +1 -0
  75. google_cloud_pipeline_components/version.py +1 -1
  76. {google_cloud_pipeline_components-2.13.1.dist-info → google_cloud_pipeline_components-2.14.1.dist-info}/METADATA +18 -19
  77. {google_cloud_pipeline_components-2.13.1.dist-info → google_cloud_pipeline_components-2.14.1.dist-info}/RECORD +81 -79
  78. {google_cloud_pipeline_components-2.13.1.dist-info → google_cloud_pipeline_components-2.14.1.dist-info}/WHEEL +1 -1
  79. google_cloud_pipeline_components/proto/preflight_validations_pb2.py +0 -47
  80. /google_cloud_pipeline_components/{preview → v1}/model_evaluation/model_based_llm_evaluation/autosxs/__init__.py +0 -0
  81. {google_cloud_pipeline_components-2.13.1.dist-info → google_cloud_pipeline_components-2.14.1.dist-info}/LICENSE +0 -0
  82. {google_cloud_pipeline_components-2.13.1.dist-info → google_cloud_pipeline_components-2.14.1.dist-info}/top_level.txt +0 -0
@@ -17,14 +17,13 @@ import warnings
17
17
 
18
18
  from google_cloud_pipeline_components.version import __version__
19
19
 
20
- if sys.version_info < (3, 8):
20
+ if sys.version_info < (3, 9):
21
21
  warnings.warn(
22
22
  (
23
- 'Python 3.7 has reached end-of-life. Google Cloud Pipeline Components'
24
- ' will drop support for Python 3.7 on April 23, 2024. To use new'
25
- ' versions of the KFP SDK after that date, you will need to upgrade'
26
- ' to Python >= 3.8. See https://devguide.python.org/versions/ for'
27
- ' more details.'
23
+ ' Google Cloud Pipeline Components will drop support for Python 3.8'
24
+ ' on Oct 1, 2024. To use new versions of the GCPC SDK after that'
25
+ ' date, you will need to upgrade to Python >= 3.9. See'
26
+ ' https://devguide.python.org/versions/ for more details.'
28
27
  ),
29
28
  FutureWarning,
30
29
  stacklevel=2,
@@ -34,10 +34,13 @@ PipelineOutput = NamedTuple(
34
34
  def pipeline(
35
35
  output_adapter_path: str,
36
36
  large_model_reference: str,
37
+ policy_model_reference: str,
37
38
  model_display_name: Optional[str] = None,
38
39
  deploy_model: bool = True,
40
+ upload_model: bool = True,
39
41
  encryption_spec_key_name: str = '',
40
42
  upload_location: str = _placeholders.LOCATION_PLACEHOLDER,
43
+ regional_endpoint: str = '',
41
44
  ) -> PipelineOutput:
42
45
  # fmt: off
43
46
  """Uploads a tuned language model and (optionally) deploys it to an endpoint.
@@ -45,62 +48,37 @@ def pipeline(
45
48
  Args:
46
49
  output_adapter_path: Path to the trained model adapter if LoRA tuning was used.
47
50
  large_model_reference: Name of the base model. Supported values are `text-bison@001`, `t5-small`, `t5-large`, `t5-xl` and `t5-xxl`. `text-bison@001` and `t5-small` are supported in `us-central1` and `europe-west4`. `t5-large`, `t5-xl` and `t5-xxl` are only supported in `europe-west4`.
51
+ policy_model_reference: The name of the model for deployment. The name should be in capitalized snake case format.
48
52
  model_display_name: Name of the fine-tuned model shown in the Model Registry. If not provided, a default name will be created.
49
53
  deploy_model: Whether to deploy the model to an endpoint in `us-central1`. Default is True.
50
54
  encryption_spec_key_name: Customer-managed encryption key. If this is set, then all resources created by the CustomJob will be encrypted with the provided encryption key. Note that this is not supported for TPU at the moment.
51
55
  upload_location: Region to upload and deploy the model to. Default is the location used to run the pipeline components.
56
+ regional_endpoint: Regional endpoint to upload the model.
52
57
 
53
58
  Returns:
54
59
  model_resource_name: Path to the model uploaded to the Model Registry. This will be an empty string if the model was not deployed.
55
60
  endpoint_resource_name: Path the Online Prediction Endpoint. This will be an empty string if the model was not deployed.
56
61
  """
57
62
  # fmt: on
58
- regional_endpoint = function_based.resolve_regional_endpoint(
59
- upload_location=upload_location
60
- ).set_display_name('Resolve Regional Endpoint')
61
-
62
- display_name = (
63
- function_based.resolve_model_display_name(
64
- large_model_reference=large_model_reference,
65
- model_display_name=model_display_name,
66
- )
67
- .set_caching_options(False)
68
- .set_display_name('Resolve Model Display Name')
69
- )
70
-
71
- reference_model_metadata = function_based.resolve_reference_model_metadata(
72
- large_model_reference=large_model_reference,
73
- ).set_display_name('Resolve Model Metadata')
74
-
75
- upload_model = function_based.resolve_upload_model(
76
- large_model_reference=reference_model_metadata.outputs[
77
- 'large_model_reference'
78
- ]
79
- ).set_display_name('Resolve Upload Model')
80
63
  upload_task = upload_llm_model.refined_upload_llm_model(
81
64
  project=_placeholders.PROJECT_ID_PLACEHOLDER,
82
65
  location=upload_location,
83
- regional_endpoint=regional_endpoint.output,
66
+ regional_endpoint=regional_endpoint,
84
67
  artifact_uri=output_adapter_path,
85
- model_display_name=display_name.output,
68
+ model_display_name=model_display_name,
86
69
  model_reference_name=large_model_reference,
87
- upload_model=upload_model.output,
70
+ upload_model=upload_model,
88
71
  encryption_spec_key_name=encryption_spec_key_name,
89
72
  tune_type='rlhf',
90
73
  ).set_display_name('Upload Model')
91
- deploy_model = function_based.resolve_deploy_model(
92
- deploy_model=deploy_model,
93
- large_model_reference=reference_model_metadata.outputs[
94
- 'large_model_reference'
95
- ],
96
- ).set_display_name('Resolve Deploy Model')
74
+
97
75
  deploy_task = deploy_llm_model.deploy_llm_model(
98
76
  project=_placeholders.PROJECT_ID_PLACEHOLDER,
99
77
  location=upload_location,
100
78
  model_resource_name=upload_task.outputs['model_resource_name'],
101
- display_name=display_name.output,
102
- regional_endpoint=regional_endpoint.output,
103
- deploy_model=deploy_model.output,
79
+ display_name=model_display_name,
80
+ regional_endpoint=regional_endpoint,
81
+ deploy_model=deploy_model,
104
82
  encryption_spec_key_name=encryption_spec_key_name,
105
83
  ).set_display_name('Deploy Model')
106
84
  return PipelineOutput(
@@ -19,7 +19,7 @@ from google_cloud_pipeline_components._implementation.llm.generated import refin
19
19
 
20
20
 
21
21
  def get_private_image_tag() -> str:
22
- return os.getenv('PRIVATE_IMAGE_TAG') or '20240330_0352_RC00'
22
+ return os.getenv('PRIVATE_IMAGE_TAG') or refined_image_versions.IMAGE_TAG
23
23
 
24
24
 
25
25
  def get_autosxs_image_tag() -> str:
@@ -231,8 +231,8 @@ def resolve_reference_model_metadata(
231
231
  'gs://vertex-llm-restricted/cloud-llm-restricted/checkpoints/'
232
232
  'safe_flan_t5/xxl/v1/checkpoint_1190000/'
233
233
  ),
234
- reward_model_reference='T5_XL',
235
- reward_model_path='gs://t5-data/pretrained_models/t5x/t5_1_1_xl',
234
+ reward_model_reference='T5_XXL',
235
+ reward_model_path='gs://t5-data/pretrained_models/t5x/t5_1_1_xxl',
236
236
  is_supported=True,
237
237
  ),
238
238
  'palm-tiny': reference_model_metadata(
@@ -265,8 +265,10 @@ def resolve_reference_model_metadata(
265
265
  reference_model_path=(
266
266
  'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/'
267
267
  ),
268
- reward_model_reference='OTTER',
269
- reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain/',
268
+ reward_model_reference='BISON',
269
+ reward_model_path=(
270
+ 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/'
271
+ ),
270
272
  is_supported=False, # Deprecated: Use text-bision@001 instead.
271
273
  ),
272
274
  'text-bison@001': reference_model_metadata(
@@ -274,8 +276,10 @@ def resolve_reference_model_metadata(
274
276
  reference_model_path=(
275
277
  'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/'
276
278
  ),
277
- reward_model_reference='OTTER',
278
- reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain/',
279
+ reward_model_reference='BISON',
280
+ reward_model_path=(
281
+ 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/'
282
+ ),
279
283
  is_supported=True,
280
284
  ),
281
285
  'text-bison@002': reference_model_metadata(
@@ -292,8 +296,10 @@ def resolve_reference_model_metadata(
292
296
  reference_model_path=(
293
297
  'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/'
294
298
  ),
295
- reward_model_reference='OTTER',
296
- reward_model_path='gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_otter_pretrain/',
299
+ reward_model_reference='BISON',
300
+ reward_model_path=(
301
+ 'gs://vertex-rlhf-restricted/pretrained_models/palm/t5x_bison/'
302
+ ),
297
303
  is_supported=True,
298
304
  ),
299
305
  'elephant': reference_model_metadata(
@@ -372,46 +378,6 @@ def convert_to_delimited_string(items: List[str], delimiter: str = ',') -> str:
372
378
  return delimiter.join(items)
373
379
 
374
380
 
375
- @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
376
- def generate_default_instruction(
377
- task: str,
378
- target_sequence_length: int,
379
- instruction_override: str = '',
380
- ) -> str:
381
- """Generates a default instruction if no override is provided."""
382
- if instruction_override:
383
- return instruction_override
384
- task = task.lower()
385
- if task == 'summarization':
386
- return f'Summarize in less than {target_sequence_length} words.'
387
-
388
- elif task == 'question_answer':
389
- return f'Answer the question in less than {target_sequence_length} words.'
390
-
391
- else:
392
- raise ValueError(
393
- f'Task not recognized: {task}. Supported tasks are: "summarization",'
394
- ' "question_answer".'
395
- )
396
-
397
-
398
- @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
399
- def resolve_upload_location(upload_location: Optional[str] = None) -> str:
400
- """Gets the region to upload the model.
401
-
402
- Args:
403
- upload_location: User-specified region to upload the model to.
404
-
405
- Returns:
406
- Where to upload the model. If no location is specified, the model will be
407
- uploaded to the region where the pipeline is running.
408
- """
409
- # pylint: disable=g-import-not-at-top
410
- import os
411
- # pylint: enable=g-import-not-at-top
412
- return upload_location or os.environ['CLOUD_ML_REGION']
413
-
414
-
415
381
  @dsl.component(base_image=_image.GCPC_IMAGE_TAG, install_kfp_package=False)
416
382
  def resolve_regional_endpoint(upload_location: str) -> str:
417
383
  """Gets the regional endpoint used to upload a model to the registry.
@@ -17,4 +17,4 @@
17
17
  DO NOT EDIT - This file is generated, manual changes will be overridden.
18
18
  """
19
19
 
20
- IMAGE_TAG = '20240327_1338'
20
+ IMAGE_TAG = '20240506_1707'
@@ -0,0 +1,109 @@
1
+ # Copyright 2024 The Kubeflow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Component that preprocesses inputs for infer pipeline."""
15
+
16
+ from google_cloud_pipeline_components import _placeholders
17
+ from google_cloud_pipeline_components import utils as gcpc_utils
18
+ from google_cloud_pipeline_components._implementation.llm import utils
19
+ from kfp import dsl
20
+
21
+
22
+ @dsl.container_component
23
+ def infer_preprocessor(
24
+ large_model_reference: str,
25
+ accelerator_type: str,
26
+ use_test_spec: bool,
27
+ project: str,
28
+ location: str,
29
+ artifact_registry: str,
30
+ tag: str,
31
+ gcp_resources: dsl.OutputPath(str), # pytype: disable=invalid-annotation
32
+ metadata_large_model_reference: dsl.OutputPath(str), # pytype: disable=invalid-annotation
33
+ metadata_reference_model_path: dsl.OutputPath(str), # pytype: disable=invalid-annotation
34
+ metadata_reward_model_reference: dsl.OutputPath(str), # pytype: disable=invalid-annotation
35
+ metadata_reward_model_path: dsl.OutputPath(str), # pytype: disable=invalid-annotation
36
+ metadata_machine_type: dsl.OutputPath(str), # pytype: disable=invalid-annotation
37
+ metadata_tuning_location: dsl.OutputPath(str), # pytype: disable=invalid-annotation
38
+ metadata_accelerator_type: dsl.OutputPath(str), # pytype: disable=invalid-annotation
39
+ metadata_accelerator_count: dsl.OutputPath(int), # pytype: disable=invalid-annotation
40
+ metadata_instruction: dsl.OutputPath(str), # pytype: disable=invalid-annotation
41
+ metadata_refined_image_uri: dsl.OutputPath(str), # pytype: disable=invalid-annotation
42
+ use_experimental_image: bool = False,
43
+ input_reference_model_path: str = '',
44
+ instruction: str = '',
45
+ image_uri: str = utils.get_default_image_uri('refined_cpu', ''),
46
+ ) -> dsl.ContainerSpec: # pylint: disable=g-doc-args
47
+ # fmt: off
48
+ """Preprocess infer pipeline inputs.
49
+
50
+ Args:
51
+ large_model_reference: The model for fine tuning.
52
+ accelerator_type: Specific accelerator type for the job.
53
+ use_test_spec: Whether to use a lower resource machine for testing.
54
+ project: Project that contains the artifact registry.
55
+ location: Region that contains the artifact registry.
56
+ artifact_registry: Registry that contains Docker images.
57
+ tag: Image tag.
58
+ use_experimental_image: Whether to use refined experimental image.
59
+ input_reference_model_path: The model checkpoint path for the reference model
60
+ instruction: The instruction to let the model know what task it needs to perform.
61
+ image_uri: Docker image URI to use for the custom job.
62
+
63
+ Returns:
64
+ gcp_resources: GCP resources that can be used to track the custom job.
65
+ metadata_large_model_reference: The base model for fine tuning. The name should be in capitalized snake case format.
66
+ metadata_reference_model_path: The model checkpoint path for the reinforcer model
67
+ metadata_reward_model_reference: The base model for training reward model. The name should be in capitalized snake case format.
68
+ metadata_reward_model_path: The model checkpoint path for the reward model.
69
+ metadata_machine_type: The type of the machine to provision for the custom job.
70
+ metadata_tuning_location: The GCP region to run the custom job.
71
+ metadata_accelerator_type: Specific accelerator type for the custom job.
72
+ metadata_accelerator_count: The number of accelerator.
73
+ metadata_instruction: The instruction to let the model know what task it needs to perform.
74
+ metadata_refined_image_uri: Docker image URI to use for the custom job.
75
+ """
76
+ # fmt: on
77
+ return gcpc_utils.build_serverless_customjob_container_spec(
78
+ project=_placeholders.PROJECT_ID_PLACEHOLDER,
79
+ location=_placeholders.LOCATION_PLACEHOLDER,
80
+ custom_job_payload=utils.build_payload(
81
+ display_name='infer_preprocessor',
82
+ machine_type='n1-standard-4',
83
+ image_uri=image_uri,
84
+ args=[
85
+ '--app_name=infer_preprocessor',
86
+ f'--large_model_reference={large_model_reference}',
87
+ f'--input_reference_model_path={input_reference_model_path}',
88
+ f'--accelerator_type={accelerator_type}',
89
+ f'--use_test_spec={use_test_spec}',
90
+ f'--project={project}',
91
+ f'--location={location}',
92
+ f'--artifact_registry={artifact_registry}',
93
+ f'--tag={tag}',
94
+ f'--use_experimental_image={use_experimental_image}',
95
+ f'--instruction={instruction}',
96
+ f'--metadata_large_model_reference_path={metadata_large_model_reference}',
97
+ f'--metadata_reference_model_path_path={metadata_reference_model_path}',
98
+ f'--metadata_reward_model_reference_path={metadata_reward_model_reference}',
99
+ f'--metadata_reward_model_path_path={metadata_reward_model_path}',
100
+ f'--metadata_machine_type_path={metadata_machine_type}',
101
+ f'--metadata_tuning_location_path={metadata_tuning_location}',
102
+ f'--metadata_accelerator_type_path={metadata_accelerator_type}',
103
+ f'--metadata_accelerator_count_path={metadata_accelerator_count}',
104
+ f'--metadata_instruction_path={metadata_instruction}',
105
+ f'--metadata_refined_image_uri_path={metadata_refined_image_uri}',
106
+ ],
107
+ ),
108
+ gcp_resources=gcp_resources,
109
+ )
@@ -52,6 +52,7 @@ def online_evaluation_pairwise(
52
52
  project: str = _placeholders.PROJECT_ID_PLACEHOLDER,
53
53
  location: str = _placeholders.LOCATION_PLACEHOLDER,
54
54
  encryption_spec_key_name: str = '',
55
+ autorater_prompt_parameters: Dict[str, Dict[str, str]] = {},
55
56
  ) -> dsl.ContainerSpec: # pylint: disable=g-doc-args
56
57
  """Evaluate two models using an autorater.
57
58
 
@@ -73,6 +74,8 @@ def online_evaluation_pairwise(
73
74
  encryption_spec_key_name: Customer-managed encryption key options. If this
74
75
  is set, then all resources created by the component will be encrypted with
75
76
  the provided encryption key.
77
+ autorater_prompt_parameters: Map of autorater prompt template parameters to
78
+ columns or templates.
76
79
 
77
80
  Returns:
78
81
  judgments: Individual judgments used to calculate the win rates.
@@ -112,6 +115,11 @@ def online_evaluation_pairwise(
112
115
  '--executor_input={{$.json_escape[1]}}',
113
116
  f'--kms_key_name={encryption_spec_key_name}',
114
117
  f'--metadata_path={metadata}',
118
+ (
119
+ '--autorater_prompt_parameters='
120
+ "{{$.inputs.parameters['autorater_prompt_parameters']"
121
+ '.json_escape[0]}}'
122
+ ),
115
123
  ],
116
124
  encryption_spec_key_name=encryption_spec_key_name,
117
125
  ),
@@ -41,6 +41,14 @@ def pipeline(
41
41
  input_reward_adapter_path: str,
42
42
  input_preference_dataset_path: str,
43
43
  large_model_reference: str,
44
+ reward_model_reference: str,
45
+ policy_model_reference: str,
46
+ policy_model_path: str,
47
+ machine_type: str,
48
+ tuning_location: str,
49
+ accelerator_type: str,
50
+ accelerator_count: int,
51
+ rl_image_uri: str,
44
52
  prompt_sequence_length: int = 512,
45
53
  target_sequence_length: int = 64,
46
54
  lora_dim: int = 1,
@@ -51,10 +59,10 @@ def pipeline(
51
59
  kl_coeff: float = 0.1,
52
60
  instruction: Optional[str] = None,
53
61
  project: str = _placeholders.PROJECT_ID_PLACEHOLDER,
54
- accelerator_type: str = 'GPU',
55
62
  location: str = _placeholders.LOCATION_PLACEHOLDER,
56
63
  tensorboard_resource_id: str = '',
57
64
  encryption_spec_key_name: str = '',
65
+ num_microbatches: int = 0,
58
66
  ) -> PipelineOutput:
59
67
  # fmt: off
60
68
  """Trains a reward model.
@@ -64,6 +72,14 @@ def pipeline(
64
72
  input_reward_adapter_path: Path to the reward LoRA adapter to use during reinforcement learning.
65
73
  input_preference_dataset_path: Path to preference dataset used by the reward model.
66
74
  large_model_reference: Name of the base model. Supported values are `text-bison@001`, `t5-small`, `t5-large`, `t5-xl` and `t5-xxl`. `text-bison@001` and `t5-small` are supported in `us-central1` and `europe-west4`. `t5-large`, `t5-xl` and `t5-xxl` are only supported in `europe-west4`.
75
+ reward_model_reference: Name of the reward model. The name should be in capitalized snake case format.
76
+ policy_model_reference: Name of the policy model. The name should be in capitalized snake case format.
77
+ policy_model_path: The model checkpoint path to the reinforcer model.
78
+ machine_type: The type of the machine to provision for the custom job. Must be a valid GCE instance type and compatible with the accelerator type.
79
+ tuning_location: The GCP region to run the custom job.
80
+ accelerator_type: Specific accelerator type for the custom job.
81
+ accelerator_count: The number of accelerator.
82
+ rl_image_uri: Docker image URI to use for the reinforcement learning training job.
67
83
  prompt_sequence_length: Maximum tokenized sequence length for input text. Higher values increase memory overhead. This value should be at most 8192. Default value is 512.
68
84
  target_sequence_length: Maximum tokenized sequence length for target text. Higher values increase memory overhead. This value should be at most 1024. Default value is 64.
69
85
  lora_dim: The rank of the LoRA adapter. If >0, then use LoRA-tuning. If =0, then use full-tuning. Default is 1.
@@ -74,7 +90,6 @@ def pipeline(
74
90
  kl_coeff: Coefficient for KL penalty. This regularizes the policy model and penalizes if it diverges from its initial distribution. If set to 0, the reference language model is not loaded into memory. Default value is 0.1.
75
91
  instruction: This field lets the model know what task it needs to perform. Base models have been trained over a large set of varied instructions. You can give a simple and intuitive description of the task and the model will follow it, e.g. "Classify this movie review as positive or negative" or "Translate this sentence to Danish". Do not specify this if your dataset already prepends the instruction to the inputs field.
76
92
  project: Project used to run custom jobs. If not specified the project used to run the pipeline will be used.
77
- accelerator_type: One of 'TPU' or 'GPU'. If 'TPU' is specified, tuning components run in europe-west4. Otherwise tuning components run in us-central1 on GPUs. Default is 'GPU'.
78
93
  location: Location used to run non-tuning components, i.e. components that do not require accelerators. If not specified the location used to run the pipeline will be used.
79
94
  tensorboard_resource_id: Optional tensorboard resource id in format `projects/{project_number}/locations/{location}/tensorboards/{tensorboard_id}`. If provided, tensorboard metrics will be uploaded to this location.
80
95
  encryption_spec_key_name: Customer-managed encryption key. If this is set, then all resources created by the CustomJob will be encrypted with the provided encryption key. Note that this is not supported for TPU at the moment.
@@ -85,14 +100,6 @@ def pipeline(
85
100
  """
86
101
  # fmt: on
87
102
  prompt_column = 'input_text'
88
- machine_spec = function_based.resolve_machine_spec(
89
- accelerator_type=accelerator_type,
90
- use_test_spec=env.get_use_test_machine_spec(),
91
- ).set_display_name('Resolve Machine Spec')
92
-
93
- reference_model_metadata = function_based.resolve_reference_model_metadata(
94
- large_model_reference=large_model_reference,
95
- ).set_display_name('Resolve Model Metadata')
96
103
 
97
104
  processed_dataset = preprocess_chat_dataset.preprocess_chat_dataset(
98
105
  large_model_reference=large_model_reference,
@@ -109,30 +116,18 @@ def pipeline(
109
116
  # Target field name does not matter because this field is not used.
110
117
  targets_field_name='non_existent_targets_field_name',
111
118
  output_split_name=env.TRAIN_SPLIT,
112
- large_model_reference=reference_model_metadata.outputs[
113
- 'large_model_reference'
114
- ],
119
+ large_model_reference=policy_model_reference,
115
120
  instruction=instruction,
116
121
  encryption_spec_key_name=encryption_spec_key_name,
117
122
  )
118
123
  .set_display_name('Import Prompt Dataset')
119
124
  .set_caching_options(False)
120
125
  )
121
- rl_image_uri = function_based.resolve_private_refined_image_uri(
122
- accelerator_type=machine_spec.outputs['accelerator_type'],
123
- ).set_display_name('Resolve Reinforcer Image URI')
124
- num_microbatches = function_based.resolve_num_microbatches(
125
- large_model_reference=reference_model_metadata.outputs[
126
- 'large_model_reference'
127
- ]
128
- ).set_display_name('Resolve Number of Microbatches')
129
126
  rl_model = (
130
127
  reinforcer.reinforcer(
131
128
  project=project,
132
- location=machine_spec.outputs['tuning_location'],
133
- input_reference_model_path=reference_model_metadata.outputs[
134
- 'reference_model_path'
135
- ],
129
+ location=tuning_location,
130
+ input_reference_model_path=policy_model_path,
136
131
  input_reward_model_path=input_reward_model_path,
137
132
  input_reward_adapter_path=input_reward_adapter_path,
138
133
  input_dataset_path=prompt_dataset_importer.outputs[
@@ -140,16 +135,12 @@ def pipeline(
140
135
  ],
141
136
  input_preference_dataset_path=input_preference_dataset_path,
142
137
  train_steps=reinforcement_learning_train_steps,
143
- accelerator_type=machine_spec.outputs['accelerator_type'],
144
- accelerator_count=machine_spec.outputs['accelerator_count'],
145
- large_model_reference=reference_model_metadata.outputs[
146
- 'large_model_reference'
147
- ],
148
- reward_model_reference=reference_model_metadata.outputs[
149
- 'reward_model_reference'
150
- ],
151
- machine_type=machine_spec.outputs['machine_type'],
152
- image_uri=rl_image_uri.output,
138
+ accelerator_type=accelerator_type,
139
+ accelerator_count=accelerator_count,
140
+ large_model_reference=policy_model_reference,
141
+ reward_model_reference=reward_model_reference,
142
+ machine_type=machine_type,
143
+ image_uri=rl_image_uri,
153
144
  inputs_sequence_length=prompt_sequence_length,
154
145
  targets_sequence_length=target_sequence_length,
155
146
  batch_size=batch_size,
@@ -157,7 +148,7 @@ def pipeline(
157
148
  kl_coeff=kl_coeff,
158
149
  lora_dim=lora_dim,
159
150
  reward_lora_dim=reward_lora_dim,
160
- num_microbatches=num_microbatches.output,
151
+ num_microbatches=num_microbatches,
161
152
  encryption_spec_key_name=encryption_spec_key_name,
162
153
  tensorboard_resource_id=tensorboard_resource_id,
163
154
  )