google-cloud-pipeline-components 2.6.0__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. google_cloud_pipeline_components/_implementation/llm/arbiter_preprocess.py +137 -0
  2. google_cloud_pipeline_components/_implementation/llm/autosxs_arbiter.py +105 -0
  3. google_cloud_pipeline_components/_implementation/llm/autosxs_metrics_computer.py +66 -0
  4. google_cloud_pipeline_components/_implementation/llm/deployment_graph.py +10 -16
  5. google_cloud_pipeline_components/_implementation/llm/env.py +1 -1
  6. google_cloud_pipeline_components/_implementation/llm/function_based.py +82 -5
  7. google_cloud_pipeline_components/_implementation/llm/reinforcement_learning_graph.py +6 -0
  8. google_cloud_pipeline_components/_implementation/llm/reinforcer.py +7 -2
  9. google_cloud_pipeline_components/_implementation/llm/reward_model_graph.py +6 -0
  10. google_cloud_pipeline_components/_implementation/llm/reward_model_trainer.py +7 -2
  11. google_cloud_pipeline_components/_implementation/llm/supervised_fine_tuner.py +5 -0
  12. google_cloud_pipeline_components/_implementation/llm/task_preprocess.py +97 -0
  13. google_cloud_pipeline_components/_implementation/llm/upload_llm_model.py +5 -0
  14. google_cloud_pipeline_components/_implementation/model_evaluation/__init__.py +4 -0
  15. google_cloud_pipeline_components/_implementation/model_evaluation/endpoint_batch_predict/component.py +1 -1
  16. google_cloud_pipeline_components/_implementation/model_evaluation/import_evaluation/component.py +10 -0
  17. google_cloud_pipeline_components/_implementation/model_evaluation/llm_embedding/evaluation_llm_embedding_pipeline.py +64 -15
  18. google_cloud_pipeline_components/_implementation/model_evaluation/llm_evaluation/component.py +9 -2
  19. google_cloud_pipeline_components/_implementation/model_evaluation/model_inference/__init__.py +14 -0
  20. google_cloud_pipeline_components/_implementation/model_evaluation/model_inference/component.py +324 -0
  21. google_cloud_pipeline_components/_implementation/model_evaluation/version.py +2 -2
  22. google_cloud_pipeline_components/container/_implementation/model_evaluation/import_model_evaluation.py +8 -0
  23. google_cloud_pipeline_components/container/v1/automl_training_job/__init__.py +14 -0
  24. google_cloud_pipeline_components/container/v1/automl_training_job/image/__init__.py +14 -0
  25. google_cloud_pipeline_components/container/v1/automl_training_job/image/launcher.py +236 -0
  26. google_cloud_pipeline_components/container/v1/automl_training_job/image/remote_runner.py +250 -0
  27. google_cloud_pipeline_components/preview/model_evaluation/evaluation_llm_text_generation_pipeline.py +6 -1
  28. google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/__init__.py +20 -0
  29. google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/autosxs/__init__.py +13 -0
  30. google_cloud_pipeline_components/preview/model_evaluation/model_based_llm_evaluation/autosxs/autosxs_pipeline.py +234 -0
  31. google_cloud_pipeline_components/version.py +1 -1
  32. {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/METADATA +1 -1
  33. {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/RECORD +36 -23
  34. {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/LICENSE +0 -0
  35. {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/WHEEL +0 -0
  36. {google_cloud_pipeline_components-2.6.0.dist-info → google_cloud_pipeline_components-2.7.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,324 @@
+ # Copyright 2023 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Third party inference component."""
+ from typing import Any, Dict, List, NamedTuple
+
+ from google_cloud_pipeline_components import utils as gcpc_utils
+ from google_cloud_pipeline_components._implementation.model_evaluation import LLMEvaluationTextGenerationOp
+ from google_cloud_pipeline_components._implementation.model_evaluation import utils
+ from kfp.dsl import Artifact
+ from kfp.dsl import container_component
+ from kfp.dsl import Metrics
+ from kfp.dsl import Output
+ from kfp.dsl import OutputPath
+ from kfp.dsl import pipeline
+
+
+ _IMAGE_URI = 'gcr.io/model-evaluation-dev/llm_eval:clyu-test'
+
+
+ @container_component
+ def model_inference_component_internal(
+     gcp_resources: OutputPath(str),
+     gcs_output_directory: Output[Artifact],
+     project: str,
+     location: str,
+     client_api_key_path: str,
+     prediction_instances_source_uri: str,
+     output_inference_gcs_prefix: str,
+     inference_platform: str = 'openai_chat_completions',
+     model_id: str = 'gpt-3.5-turbo',
+     request_params: Dict[str, Any] = {},
+     max_request_per_second: float = 3,
+     max_tokens_per_minute: float = 100,
+     display_name: str = 'third-party-inference',
+     machine_type: str = 'e2-highmem-16',
+     service_account: str = '',
+     network: str = '',
+     reserved_ip_ranges: List[str] = [],
+     encryption_spec_key_name: str = '',
+ ):
+ """Internal component to run Third Party Model Inference.
53
+
54
+ Args:
55
+ gcp_resources (str): Serialized gcp_resources proto tracking the custom
56
+ job.
57
+ model_inference_output_gcs_uri: The storage URI pointing toward a GCS
58
+ location to store CSV for third party inference.
59
+ project: Required. The GCP project that runs the pipeline component.
60
+ location: Required. The GCP region that runs the pipeline component.
61
+ client_api_key_path: The GCS URI where client API key.
62
+ output_inference_gcs_prefix: GCS file prefix for writing output.
63
+ display_name: display name of the pipeline.
64
+ machine_type: The machine type of this custom job. If not set, defaulted
65
+ to `e2-highmem-16`. More details:
66
+ https://cloud.google.com/compute/docs/machine-resource
67
+ service_account: Sets the default service account for workload run-as
68
+ account. The service account running the pipeline
69
+ (https://cloud.google.com/vertex-ai/docs/pipelines/configure-project#service-account)
70
+ submitting jobs must have act-as permission on this run-as account. If
71
+ unspecified, the Vertex AI Custom Code Service
72
+ Agent(https://cloud.google.com/vertex-ai/docs/general/access-control#service-agents)
73
+ for the CustomJob's project.
74
+ network: The full name of the Compute Engine network to which the job
75
+ should be peered. For example, projects/12345/global/networks/myVPC.
76
+ Format is of the form projects/{project}/global/networks/{network}.
77
+ Where {project} is a project number, as in 12345, and {network} is a
78
+ network name. Private services access must already be configured for the
79
+ network. If left unspecified, the job is not peered with any network.
80
+ reserved_ip_ranges: A list of names for the reserved ip ranges under the
81
+ VPC network that can be used for this job. If set, we will deploy the
82
+ job within the provided ip ranges. Otherwise, the job will be deployed
83
+ to any ip ranges under the provided VPC network.
84
+ encryption_spec_key_name: Customer-managed encryption key options for the
85
+ CustomJob. If this is set, then all resources created by the CustomJob
86
+ will be encrypted with the provided encryption key.
87
+
88
+ Returns:
89
+ gcp_resources (str): Serialized gcp_resources proto tracking the custom
90
+ job.
91
+ model_inference_output_gcs_uri: The storage URI pointing toward a
92
+ GCS location to store CSV for third party inference.
93
+ """
94
+   return gcpc_utils.build_serverless_customjob_container_spec(
+       project=project,
+       location=location,
+       custom_job_payload=utils.build_custom_job_payload(
+           display_name=display_name,
+           machine_type=machine_type,
+           image_uri=_IMAGE_URI,
+           args=[
+               f'--3p_model_inference={True}',
+               f'--project={project}',
+               f'--location={location}',
+               f'--prediction_instances_source_uri={prediction_instances_source_uri}',
+               f'--inference_platform={inference_platform}',
+               f'--output_inference_gcs_prefix={output_inference_gcs_prefix}',
+               f'--model_id={model_id}',
+               f'--request_params={request_params}',
+               f'--client_api_key_path={client_api_key_path}',
+               f'--max_request_per_second={max_request_per_second}',
+               f'--max_tokens_per_minute={max_tokens_per_minute}',
+               f'--gcs_output_directory={gcs_output_directory.path}',
+               '--executor_input={{$.json_escape[1]}}',
+           ],
+           service_account=service_account,
+           network=network,
+           reserved_ip_ranges=reserved_ip_ranges,
+           encryption_spec_key_name=encryption_spec_key_name,
+       ),
+       gcp_resources=gcp_resources,
+   )
+
+
+ @pipeline(name='ModelEvaluationModelInferenceOp')
+ def model_inference_component(
+     project: str,
+     location: str,
+     client_api_key_path: str,
+     prediction_instances_source_uri: str,
+     output_inference_gcs_prefix: str,
+     inference_platform: str = 'openai_chat_completions',
+     model_id: str = 'gpt-3.5-turbo',
+     request_params: Dict[str, Any] = {},
+     max_request_per_second: float = 3,
+     max_tokens_per_minute: float = 100,
+     display_name: str = 'third-party-inference',
+     machine_type: str = 'e2-highmem-16',
+     service_account: str = '',
+     network: str = '',
+     reserved_ip_ranges: List[str] = [],
+     encryption_spec_key_name: str = '',
+ ) -> NamedTuple(
+     'outputs',
+     gcs_output_directory=Artifact,
+ ):
+ """Component to run Third Party Model Inference.
149
+
150
+ Args:
151
+ project: Required. The GCP project that runs the pipeline component.
152
+ location: Required. The GCP region that runs the pipeline component.
153
+ client_api_key_path: The GCS URI where client API key.
154
+ output_inference_gcs_prefix: GCS file prefix for writing output.
155
+ display_name: display name of the pipeline.
156
+ machine_type: The machine type of this custom job. If not set, defaulted
157
+ to `e2-highmem-16`. More details:
158
+ https://cloud.google.com/compute/docs/machine-resource
159
+ service_account: Sets the default service account for workload run-as
160
+ account. The service account running the pipeline
161
+ (https://cloud.google.com/vertex-ai/docs/pipelines/configure-project#service-account)
162
+ submitting jobs must have act-as permission on this run-as account. If
163
+ unspecified, the Vertex AI Custom Code Service
164
+ Agent(https://cloud.google.com/vertex-ai/docs/general/access-control#service-agents)
165
+ for the CustomJob's project.
166
+ network: The full name of the Compute Engine network to which the job
167
+ should be peered. For example, projects/12345/global/networks/myVPC.
168
+ Format is of the form projects/{project}/global/networks/{network}.
169
+ Where {project} is a project number, as in 12345, and {network} is a
170
+ network name. Private services access must already be configured for the
171
+ network. If left unspecified, the job is not peered with any network.
172
+ reserved_ip_ranges: A list of names for the reserved ip ranges under the
173
+ VPC network that can be used for this job. If set, we will deploy the
174
+ job within the provided ip ranges. Otherwise, the job will be deployed
175
+ to any ip ranges under the provided VPC network.
176
+ encryption_spec_key_name: Customer-managed encryption key options for the
177
+ CustomJob. If this is set, then all resources created by the CustomJob
178
+ will be encrypted with the provided encryption key.
179
+
180
+ Returns:
181
+ NamedTuple:
182
+ model_inference_output_gcs_uri: CSV file output containing third
183
+ party prediction results.
184
+ """
185
+   outputs = NamedTuple(
+       'outputs',
+       gcs_output_directory=Artifact,
+   )
+
+   inference_task = model_inference_component_internal(
+       project=project,
+       location=location,
+       client_api_key_path=client_api_key_path,
+       prediction_instances_source_uri=prediction_instances_source_uri,
+       inference_platform=inference_platform,
+       model_id=model_id,
+       request_params=request_params,
+       max_request_per_second=max_request_per_second,
+       max_tokens_per_minute=max_tokens_per_minute,
+       output_inference_gcs_prefix=output_inference_gcs_prefix,
+       display_name=display_name,
+       machine_type=machine_type,
+       service_account=service_account,
+       network=network,
+       reserved_ip_ranges=reserved_ip_ranges,
+       encryption_spec_key_name=encryption_spec_key_name,
+   )
+
+   return outputs(
+       gcs_output_directory=inference_task.outputs['gcs_output_directory'],
+   )
+
+
+ @pipeline(name='ModelEvaluationModelInferenceAndEvaluationPipeline')
+ def model_inference_and_evaluation_component(
+     project: str,
+     location: str,
+     client_api_key_path: str,
+     prediction_instances_source_uri: str,
+     output_inference_gcs_prefix: str,
+     target_field_name: str = '',
+     inference_platform: str = 'openai_chat_completions',
+     model_id: str = 'gpt-3.5-turbo',
+     request_params: Dict[str, Any] = {},
+     max_request_per_second: float = 3,
+     max_tokens_per_minute: float = 100,
+     display_name: str = 'third-party-inference',
+     machine_type: str = 'e2-highmem-16',
+     service_account: str = '',
+     network: str = '',
+     reserved_ip_ranges: List[str] = [],
+     encryption_spec_key_name: str = '',
+ ) -> NamedTuple(
+     'outputs',
+     gcs_output_directory=Artifact,
+     evaluation_metrics=Metrics,
+ ):
+ """Component tun Third Party Model Inference and evaluation.
239
+
240
+ Args:
241
+ project: Required. The GCP project that runs the pipeline component.
242
+ location: Required. The GCP region that runs the pipeline component.
243
+ client_api_key_path: The GCS URI where client API key.
244
+ output_inference_gcs_prefix: GCS file prefix for writing output.
245
+ display_name: display name of the pipeline.
246
+ machine_type: The machine type of this custom job. If not set, defaulted
247
+ to `e2-highmem-16`. More details:
248
+ https://cloud.google.com/compute/docs/machine-resource
249
+ service_account: Sets the default service account for workload run-as
250
+ account. The service account running the pipeline
251
+ (https://cloud.google.com/vertex-ai/docs/pipelines/configure-project#service-account)
252
+ submitting jobs must have act-as permission on this run-as account. If
253
+ unspecified, the Vertex AI Custom Code Service
254
+ Agent(https://cloud.google.com/vertex-ai/docs/general/access-control#service-agents)
255
+ for the CustomJob's project.
256
+ network: The full name of the Compute Engine network to which the job
257
+ should be peered. For example, projects/12345/global/networks/myVPC.
258
+ Format is of the form projects/{project}/global/networks/{network}.
259
+ Where {project} is a project number, as in 12345, and {network} is a
260
+ network name. Private services access must already be configured for the
261
+ network. If left unspecified, the job is not peered with any network.
262
+ reserved_ip_ranges: A list of names for the reserved ip ranges under the
263
+ VPC network that can be used for this job. If set, we will deploy the
264
+ job within the provided ip ranges. Otherwise, the job will be deployed
265
+ to any ip ranges under the provided VPC network.
266
+ encryption_spec_key_name: Customer-managed encryption key options for the
267
+ CustomJob. If this is set, then all resources created by the CustomJob
268
+ will be encrypted with the provided encryption key.
269
+
270
+ Returns:
271
+ NamedTuple:
272
+ model_inference_output_gcs_uri: CSV file output containing third
273
+ party prediction results.
274
+ """
275
+   outputs = NamedTuple(
+       'outputs',
+       gcs_output_directory=Artifact,
+       evaluation_metrics=Metrics,
+   )
+
+   inference_task = model_inference_component_internal(
+       project=project,
+       location=location,
+       client_api_key_path=client_api_key_path,
+       prediction_instances_source_uri=prediction_instances_source_uri,
+       inference_platform=inference_platform,
+       model_id=model_id,
+       request_params=request_params,
+       max_request_per_second=max_request_per_second,
+       max_tokens_per_minute=max_tokens_per_minute,
+       output_inference_gcs_prefix=output_inference_gcs_prefix,
+       display_name=display_name,
+       machine_type=machine_type,
+       service_account=service_account,
+       network=network,
+       reserved_ip_ranges=reserved_ip_ranges,
+       encryption_spec_key_name=encryption_spec_key_name,
+   )
+
+   # Pick the prediction field that matches the response schema of the
+   # selected inference platform.
+   if inference_platform == 'openai_chat_completions':
+     prediction_field_name = 'predictions.0.message.content'
+   elif inference_platform == 'anthropic_predictions':
+     prediction_field_name = 'predictions'
+   else:
+     prediction_field_name = ''
+
+   eval_task = LLMEvaluationTextGenerationOp(
+       project=project,
+       location=location,
+       evaluation_task='text-generation',
+       target_field_name=target_field_name,
+       prediction_field_name=prediction_field_name,
+       predictions_format='jsonl',
+       joined_predictions_gcs_source=inference_task.outputs[
+           'gcs_output_directory'
+       ],
+       machine_type=machine_type,
+       encryption_spec_key_name=encryption_spec_key_name,
+   )
+
+   return outputs(
+       gcs_output_directory=inference_task.outputs['gcs_output_directory'],
+       evaluation_metrics=eval_task.outputs['evaluation_metrics'],
+   )
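For orientation, here is a minimal usage sketch of the new pipeline added above. This is not part of the release: the import path and function name come from the file list at the top of this diff, the compiler call assumes the standard KFP v2 SDK, and the output path is a placeholder.

    # Sketch: compile the new third-party inference + evaluation pipeline.
    from kfp import compiler

    from google_cloud_pipeline_components._implementation.model_evaluation.model_inference.component import (
        model_inference_and_evaluation_component,
    )

    # Compile the pipeline definition into a pipeline spec file that can be
    # submitted to Vertex AI Pipelines.
    compiler.Compiler().compile(
        pipeline_func=model_inference_and_evaluation_component,
        package_path='third_party_inference_eval.json',  # placeholder path
    )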
@@ -13,8 +13,8 @@
  # limitations under the License.
  """Version constants for model evaluation components."""

- _EVAL_VERSION = 'v0.9.3'
- _LLM_EVAL_VERSION = 'v0.3'
+ _EVAL_VERSION = 'v0.9.4'
+ _LLM_EVAL_VERSION = 'v0.5'

  _EVAL_IMAGE_NAME = 'gcr.io/ml-pipeline/model-evaluation'
  _LLM_EVAL_IMAGE_NAME = 'gcr.io/ml-pipeline/llm-model-evaluation'
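The version bump matters because these constants are presumably joined onto the image names to form the evaluation container tags; the exact combination is not shown in this diff, so the sketch below is an assumption for illustration only.

    # Assumed usage: version constants appended to the image names as tags.
    _EVAL_VERSION = 'v0.9.4'
    _EVAL_IMAGE_NAME = 'gcr.io/ml-pipeline/model-evaluation'

    eval_image_uri = f'{_EVAL_IMAGE_NAME}:{_EVAL_VERSION}'
    print(eval_image_uri)  # gcr.io/ml-pipeline/model-evaluation:v0.9.4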
@@ -56,6 +56,9 @@ parser = argparse.ArgumentParser(
      prog='Vertex Model Service evaluation importer', description=''
  )
  parser.add_argument('--metrics', dest='metrics', type=str, default='')
+ parser.add_argument(
+     '--row_based_metrics', dest='row_based_metrics', type=str, default=''
+ )
  parser.add_argument(
      '--classification_metrics',
      dest='classification_metrics',
@@ -274,6 +277,11 @@ def main(argv):
          'pipeline_job_resource_name': parsed_args.pipeline_job_resource_name,
          'evaluation_dataset_type': parsed_args.dataset_type,
          'evaluation_dataset_path': dataset_paths or None,
+         'row_based_metrics_path': (
+             parsed_args.row_based_metrics
+             if parsed_args.row_based_metrics
+             else None
+         ),
      }.items()
      if value
  }
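To make the new `--row_based_metrics` plumbing concrete, here is a small, self-contained sketch of the same pattern (the GCS path is a placeholder, not a release default): when the flag is set, its value is carried through as `row_based_metrics_path`; when it is empty, the ternary yields None and the surrounding `if value` filter drops the key entirely.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--row_based_metrics', dest='row_based_metrics', type=str, default=''
    )
    args = parser.parse_args(
        ['--row_based_metrics', 'gs://my-bucket/row_metrics.jsonl']  # placeholder
    )

    # Mirrors the dict-comprehension filter in main(): empty values drop out.
    metadata = {
        key: value
        for key, value in {
            'row_based_metrics_path': (
                args.row_based_metrics if args.row_based_metrics else None
            ),
        }.items()
        if value
    }
    print(metadata)  # {'row_based_metrics_path': 'gs://my-bucket/row_metrics.jsonl'}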
@@ -0,0 +1,14 @@
+ # Copyright 2023 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Google Cloud Pipeline Components - AutoML Image Training Job container code."""
@@ -0,0 +1,14 @@
+ # Copyright 2023 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Google Cloud Pipeline Components - AutoML Image Training Job Launcher and Remote Runner."""
@@ -0,0 +1,236 @@
+ # Copyright 2023 The Kubeflow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """GCP launcher for AutoML image training jobs based on the AI Platform SDK."""
+
+ import argparse
+ import json
+ import logging
+ import sys
+ from typing import List
+
+ from google_cloud_pipeline_components.container.v1.automl_training_job.image import remote_runner
+ from google_cloud_pipeline_components.container.v1.gcp_launcher.utils import parser_util
+
+
+ def _parse_args(args: List[str]):
+   """Parse command line arguments."""
+   args.append('--payload')
+   args.append('"{}"')  # Unused but required by parser_util.
+   parser, _ = parser_util.parse_default_args(args)
+   # Parse the conditionally required arguments.
+   parser.add_argument(
+       '--display_name',
+       dest='display_name',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--prediction_type',
+       dest='prediction_type',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--multi_label',
+       dest='multi_label',
+       type=parser_util.parse_bool,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--model_type',
+       dest='model_type',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--labels',
+       dest='labels',
+       type=json.loads,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--dataset',
+       dest='dataset',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--disable_early_stopping',
+       dest='disable_early_stopping',
+       type=parser_util.parse_bool,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--training_encryption_spec_key_name',
+       dest='training_encryption_spec_key_name',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--model_encryption_spec_key_name',
+       dest='model_encryption_spec_key_name',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--model_display_name',
+       dest='model_display_name',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--training_fraction_split',
+       dest='training_fraction_split',
+       type=float,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--validation_fraction_split',
+       dest='validation_fraction_split',
+       type=float,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--test_fraction_split',
+       dest='test_fraction_split',
+       type=float,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--budget_milli_node_hours',
+       dest='budget_milli_node_hours',
+       type=int,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--training_filter_split',
+       dest='training_filter_split',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--validation_filter_split',
+       dest='validation_filter_split',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--test_filter_split',
+       dest='test_filter_split',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--base_model',
+       dest='base_model',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--incremental_train_base_model',
+       dest='incremental_train_base_model',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--parent_model',
+       dest='parent_model',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--is_default_version',
+       dest='is_default_version',
+       type=parser_util.parse_bool,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--model_version_aliases',
+       dest='model_version_aliases',
+       type=json.loads,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--model_version_description',
+       dest='model_version_description',
+       type=str,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parser.add_argument(
+       '--model_labels',
+       dest='model_labels',
+       type=json.loads,
+       required=False,
+       default=argparse.SUPPRESS,
+   )
+   parsed_args, _ = parser.parse_known_args(args)
+   args_dict = vars(parsed_args)
+   del args_dict['payload']
+   return args_dict
+
+
+ def main(argv: List[str]):
+ """Main entry.
208
+
209
+ Expected input args are as follows:
210
+ Project - Required. The project of which the resource will be launched.
211
+ Region - Required. The region of which the resource will be launched.
212
+ Type - Required. GCP launcher is a single container. This Enum will
213
+ specify which resource to be launched.
214
+ gcp_resources - placeholder output for returning job_id.
215
+ Extra arguments - For constructing request payload. See remote_runner.py for
216
+ more information.
217
+
218
+ Args:
219
+ argv: A list of system arguments.
220
+ """
221
+   parsed_args = _parse_args(argv)
+   job_type = parsed_args['type']
+
+   if job_type != 'AutoMLImageTrainingJob':
+     raise ValueError('Incorrect job type: ' + job_type)
+
+   logging.info(
+       'Starting AutoMLImageTrainingJob using the following arguments: %s',
+       parsed_args,
+   )
+
+   remote_runner.create_pipeline(**parsed_args)
+
+
+ if __name__ == '__main__':
+   main(sys.argv[1:])
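As a rough illustration of how this entrypoint is driven, the sketch below calls main() with a hypothetical argument list. The common flags (--type, --project, --location, --gcp_resources) are assumed to be defined by parser_util.parse_default_args; the remaining flags appear in _parse_args above, and all values are placeholders. Running this for real would call remote_runner.create_pipeline and launch an actual training job.

    # Sketch only: placeholder project, dataset, and output locations.
    from google_cloud_pipeline_components.container.v1.automl_training_job.image import launcher

    launcher.main([
        '--type', 'AutoMLImageTrainingJob',
        '--project', 'my-project',                     # placeholder
        '--location', 'us-central1',                   # placeholder
        '--gcp_resources', '/tmp/gcp_resources.json',  # placeholder
        '--display_name', 'flowers-classifier',
        '--prediction_type', 'classification',
        '--dataset', 'projects/123/locations/us-central1/datasets/456',
        '--model_display_name', 'flowers-model',
        '--budget_milli_node_hours', '8000',
    ])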