google-genai 1.29.0__py3-none-any.whl → 1.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/tunings.py CHANGED
@@ -28,6 +28,7 @@ from ._common import get_value_by_path as getv
 from ._common import set_value_by_path as setv
 from .pagers import AsyncPager, Pager
 
+
 logger = logging.getLogger('google_genai.tunings')
 
 
@@ -166,6 +167,12 @@ def _CreateTuningJobConfig_to_mldev(
         'export_last_checkpoint_only parameter is not supported in Gemini API.'
     )
 
+  if getv(from_object, ['pre_tuned_model_checkpoint_id']) is not None:
+    raise ValueError(
+        'pre_tuned_model_checkpoint_id parameter is not supported in Gemini'
+        ' API.'
+    )
+
   if getv(from_object, ['adapter_size']) is not None:
     raise ValueError('adapter_size parameter is not supported in Gemini API.')
 
@@ -183,10 +190,15 @@ def _CreateTuningJobConfig_to_mldev(
         getv(from_object, ['learning_rate']),
     )
 
+  if getv(from_object, ['evaluation_config']) is not None:
+    raise ValueError(
+        'evaluation_config parameter is not supported in Gemini API.'
+    )
+
   return to_object
 
 
-def _CreateTuningJobParameters_to_mldev(
+def _CreateTuningJobParametersPrivate_to_mldev(
     from_object: Union[dict[str, Any], object],
     parent_object: Optional[dict[str, Any]] = None,
 ) -> dict[str, Any]:
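
The Gemini Developer API path now fails fast on the Vertex-only options instead of sending them upstream. A minimal sketch of that behavior, exercising the private converter directly (internal names are taken from the hunks above; this is not public API):

```python
from google.genai import types
from google.genai.tunings import _CreateTuningJobConfig_to_mldev

# An evaluation_config is only meaningful on Vertex AI; the Gemini (mldev)
# converter rejects it before any request is built.
config = types.CreateTuningJobConfig(evaluation_config=types.EvaluationConfig())

try:
  _CreateTuningJobConfig_to_mldev(config)
except ValueError as err:
  print(err)  # 'evaluation_config parameter is not supported in Gemini API.'
```

The same applies to `pre_tuned_model_checkpoint_id`, which raises the analogous error in the hunk further up.
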
@@ -194,6 +206,9 @@ def _CreateTuningJobParameters_to_mldev(
   if getv(from_object, ['base_model']) is not None:
     setv(to_object, ['baseModel'], getv(from_object, ['base_model']))
 
+  if getv(from_object, ['pre_tuned_model']) is not None:
+    setv(to_object, ['preTunedModel'], getv(from_object, ['pre_tuned_model']))
+
   if getv(from_object, ['training_dataset']) is not None:
     setv(
         to_object,
@@ -313,6 +328,82 @@ def _TuningValidationDataset_to_vertex(
   return to_object
 
 
+def _GcsDestination_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['output_uri_prefix']) is not None:
+    setv(
+        to_object, ['outputUriPrefix'], getv(from_object, ['output_uri_prefix'])
+    )
+
+  return to_object
+
+
+def _OutputConfig_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['gcs_destination']) is not None:
+    setv(
+        to_object,
+        ['gcsDestination'],
+        _GcsDestination_to_vertex(
+            getv(from_object, ['gcs_destination']), to_object
+        ),
+    )
+
+  return to_object
+
+
+def _AutoraterConfig_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['sampling_count']) is not None:
+    setv(to_object, ['samplingCount'], getv(from_object, ['sampling_count']))
+
+  if getv(from_object, ['flip_enabled']) is not None:
+    setv(to_object, ['flipEnabled'], getv(from_object, ['flip_enabled']))
+
+  if getv(from_object, ['autorater_model']) is not None:
+    setv(to_object, ['autoraterModel'], getv(from_object, ['autorater_model']))
+
+  return to_object
+
+
+def _EvaluationConfig_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['metrics']) is not None:
+    setv(to_object, ['metrics'], t.t_metrics(getv(from_object, ['metrics'])))
+
+  if getv(from_object, ['output_config']) is not None:
+    setv(
+        to_object,
+        ['outputConfig'],
+        _OutputConfig_to_vertex(
+            getv(from_object, ['output_config']), to_object
+        ),
+    )
+
+  if getv(from_object, ['autorater_config']) is not None:
+    setv(
+        to_object,
+        ['autoraterConfig'],
+        _AutoraterConfig_to_vertex(
+            getv(from_object, ['autorater_config']), to_object
+        ),
+    )
+
+  return to_object
+
+
 def _CreateTuningJobConfig_to_vertex(
     from_object: Union[dict[str, Any], object],
     parent_object: Optional[dict[str, Any]] = None,
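
The four new `*_to_vertex` helpers are plain dict-to-dict mappers that rename the SDK's snake_case evaluation fields to the camelCase keys Vertex AI expects (metrics additionally pass through the `t.t_metrics` transformer). A rough illustration of the mapping, using a plain dict and placeholder values; these helpers are module-private internals and may change:

```python
from google.genai.tunings import _EvaluationConfig_to_vertex

evaluation_config = {
    'output_config': {
        'gcs_destination': {'output_uri_prefix': 'gs://my-bucket/eval'}  # placeholder bucket
    },
    'autorater_config': {
        'sampling_count': 4,
        'flip_enabled': True,
        'autorater_model': 'gemini-2.0-flash-001',  # placeholder autorater model
    },
}

print(_EvaluationConfig_to_vertex(evaluation_config))
# {'outputConfig': {'gcsDestination': {'outputUriPrefix': 'gs://my-bucket/eval'}},
#  'autoraterConfig': {'samplingCount': 4, 'flipEnabled': True,
#                      'autoraterModel': 'gemini-2.0-flash-001'}}
```

The `*_from_vertex` counterparts added later in this diff apply the inverse mapping when parsing tuning-job responses.
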
@@ -359,6 +450,13 @@ def _CreateTuningJobConfig_to_vertex(
         getv(from_object, ['export_last_checkpoint_only']),
     )
 
+  if getv(from_object, ['pre_tuned_model_checkpoint_id']) is not None:
+    setv(
+        to_object,
+        ['preTunedModel', 'checkpointId'],
+        getv(from_object, ['pre_tuned_model_checkpoint_id']),
+    )
+
   if getv(from_object, ['adapter_size']) is not None:
     setv(
         parent_object,
@@ -372,10 +470,19 @@ def _CreateTuningJobConfig_to_vertex(
   if getv(from_object, ['learning_rate']) is not None:
     raise ValueError('learning_rate parameter is not supported in Vertex AI.')
 
+  if getv(from_object, ['evaluation_config']) is not None:
+    setv(
+        parent_object,
+        ['supervisedTuningSpec', 'evaluationConfig'],
+        _EvaluationConfig_to_vertex(
+            getv(from_object, ['evaluation_config']), to_object
+        ),
+    )
+
   return to_object
 
 
-def _CreateTuningJobParameters_to_vertex(
+def _CreateTuningJobParametersPrivate_to_vertex(
     from_object: Union[dict[str, Any], object],
     parent_object: Optional[dict[str, Any]] = None,
 ) -> dict[str, Any]:
@@ -383,6 +490,9 @@ def _CreateTuningJobParameters_to_vertex(
   if getv(from_object, ['base_model']) is not None:
     setv(to_object, ['baseModel'], getv(from_object, ['base_model']))
 
+  if getv(from_object, ['pre_tuned_model']) is not None:
+    setv(to_object, ['preTunedModel'], getv(from_object, ['pre_tuned_model']))
+
   if getv(from_object, ['training_dataset']) is not None:
     setv(
         to_object,
@@ -471,11 +581,9 @@ def _TuningJob_from_mldev(
         _TunedModel_from_mldev(getv(from_object, ['_self']), to_object),
     )
 
-  if getv(from_object, ['distillationSpec']) is not None:
+  if getv(from_object, ['customBaseModel']) is not None:
     setv(
-        to_object,
-        ['distillation_spec'],
-        getv(from_object, ['distillationSpec']),
+        to_object, ['custom_base_model'], getv(from_object, ['customBaseModel'])
     )
 
   if getv(from_object, ['experiment']) is not None:
@@ -484,15 +592,12 @@ def _TuningJob_from_mldev(
   if getv(from_object, ['labels']) is not None:
     setv(to_object, ['labels'], getv(from_object, ['labels']))
 
+  if getv(from_object, ['outputUri']) is not None:
+    setv(to_object, ['output_uri'], getv(from_object, ['outputUri']))
+
   if getv(from_object, ['pipelineJob']) is not None:
     setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob']))
 
-  if getv(from_object, ['satisfiesPzi']) is not None:
-    setv(to_object, ['satisfies_pzi'], getv(from_object, ['satisfiesPzi']))
-
-  if getv(from_object, ['satisfiesPzs']) is not None:
-    setv(to_object, ['satisfies_pzs'], getv(from_object, ['satisfiesPzs']))
-
   if getv(from_object, ['serviceAccount']) is not None:
     setv(to_object, ['service_account'], getv(from_object, ['serviceAccount']))
 
@@ -601,6 +706,82 @@ def _TunedModel_from_vertex(
   return to_object
 
 
+def _GcsDestination_from_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['outputUriPrefix']) is not None:
+    setv(
+        to_object, ['output_uri_prefix'], getv(from_object, ['outputUriPrefix'])
+    )
+
+  return to_object
+
+
+def _OutputConfig_from_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['gcsDestination']) is not None:
+    setv(
+        to_object,
+        ['gcs_destination'],
+        _GcsDestination_from_vertex(
+            getv(from_object, ['gcsDestination']), to_object
+        ),
+    )
+
+  return to_object
+
+
+def _AutoraterConfig_from_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['samplingCount']) is not None:
+    setv(to_object, ['sampling_count'], getv(from_object, ['samplingCount']))
+
+  if getv(from_object, ['flipEnabled']) is not None:
+    setv(to_object, ['flip_enabled'], getv(from_object, ['flipEnabled']))
+
+  if getv(from_object, ['autoraterModel']) is not None:
+    setv(to_object, ['autorater_model'], getv(from_object, ['autoraterModel']))
+
+  return to_object
+
+
+def _EvaluationConfig_from_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['metrics']) is not None:
+    setv(to_object, ['metrics'], t.t_metrics(getv(from_object, ['metrics'])))
+
+  if getv(from_object, ['outputConfig']) is not None:
+    setv(
+        to_object,
+        ['output_config'],
+        _OutputConfig_from_vertex(
+            getv(from_object, ['outputConfig']), to_object
+        ),
+    )
+
+  if getv(from_object, ['autoraterConfig']) is not None:
+    setv(
+        to_object,
+        ['autorater_config'],
+        _AutoraterConfig_from_vertex(
+            getv(from_object, ['autoraterConfig']), to_object
+        ),
+    )
+
+  return to_object
+
+
 def _TuningJob_from_vertex(
     from_object: Union[dict[str, Any], object],
     parent_object: Optional[dict[str, Any]] = None,
@@ -649,6 +830,9 @@ def _TuningJob_from_vertex(
         _TunedModel_from_vertex(getv(from_object, ['tunedModel']), to_object),
     )
 
+  if getv(from_object, ['preTunedModel']) is not None:
+    setv(to_object, ['pre_tuned_model'], getv(from_object, ['preTunedModel']))
+
   if getv(from_object, ['supervisedTuningSpec']) is not None:
     setv(
         to_object,
@@ -671,11 +855,18 @@ def _TuningJob_from_vertex(
         getv(from_object, ['partnerModelTuningSpec']),
     )
 
-  if getv(from_object, ['distillationSpec']) is not None:
+  if getv(from_object, ['evaluationConfig']) is not None:
     setv(
         to_object,
-        ['distillation_spec'],
-        getv(from_object, ['distillationSpec']),
+        ['evaluation_config'],
+        _EvaluationConfig_from_vertex(
+            getv(from_object, ['evaluationConfig']), to_object
+        ),
+    )
+
+  if getv(from_object, ['customBaseModel']) is not None:
+    setv(
+        to_object, ['custom_base_model'], getv(from_object, ['customBaseModel'])
     )
 
   if getv(from_object, ['experiment']) is not None:
@@ -684,15 +875,12 @@ def _TuningJob_from_vertex(
   if getv(from_object, ['labels']) is not None:
     setv(to_object, ['labels'], getv(from_object, ['labels']))
 
+  if getv(from_object, ['outputUri']) is not None:
+    setv(to_object, ['output_uri'], getv(from_object, ['outputUri']))
+
   if getv(from_object, ['pipelineJob']) is not None:
     setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob']))
 
-  if getv(from_object, ['satisfiesPzi']) is not None:
-    setv(to_object, ['satisfies_pzi'], getv(from_object, ['satisfiesPzi']))
-
-  if getv(from_object, ['satisfiesPzs']) is not None:
-    setv(to_object, ['satisfies_pzs'], getv(from_object, ['satisfiesPzs']))
-
   if getv(from_object, ['serviceAccount']) is not None:
     setv(to_object, ['service_account'], getv(from_object, ['serviceAccount']))
 
@@ -875,7 +1063,8 @@ class Tunings(_api_module.BaseModule):
   def _tune(
       self,
       *,
-      base_model: str,
+      base_model: Optional[str] = None,
+      pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
       training_dataset: types.TuningDatasetOrDict,
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningJob:
@@ -890,8 +1079,9 @@ class Tunings(_api_module.BaseModule):
       A TuningJob object.
     """
 
-    parameter_model = types._CreateTuningJobParameters(
+    parameter_model = types._CreateTuningJobParametersPrivate(
         base_model=base_model,
+        pre_tuned_model=pre_tuned_model,
         training_dataset=training_dataset,
         config=config,
     )
@@ -900,7 +1090,9 @@ class Tunings(_api_module.BaseModule):
     if not self._api_client.vertexai:
       raise ValueError('This method is only supported in the Vertex AI client.')
     else:
-      request_dict = _CreateTuningJobParameters_to_vertex(parameter_model)
+      request_dict = _CreateTuningJobParametersPrivate_to_vertex(
+          parameter_model
+      )
       request_url_dict = request_dict.get('_url')
       if request_url_dict:
         path = 'tuningJobs'.format_map(request_url_dict)
@@ -944,7 +1136,8 @@ class Tunings(_api_module.BaseModule):
   def _tune_mldev(
       self,
       *,
-      base_model: str,
+      base_model: Optional[str] = None,
+      pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
       training_dataset: types.TuningDatasetOrDict,
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningOperation:
@@ -959,8 +1152,9 @@ class Tunings(_api_module.BaseModule):
      A TuningJob operation.
     """
 
-    parameter_model = types._CreateTuningJobParameters(
+    parameter_model = types._CreateTuningJobParametersPrivate(
        base_model=base_model,
+        pre_tuned_model=pre_tuned_model,
        training_dataset=training_dataset,
        config=config,
     )
@@ -971,7 +1165,7 @@ class Tunings(_api_module.BaseModule):
          'This method is only supported in the Gemini Developer client.'
       )
     else:
-      request_dict = _CreateTuningJobParameters_to_mldev(parameter_model)
+      request_dict = _CreateTuningJobParametersPrivate_to_mldev(parameter_model)
       request_url_dict = request_dict.get('_url')
       if request_url_dict:
         path = 'tunedModels'.format_map(request_url_dict)
@@ -1052,11 +1246,50 @@ class Tunings(_api_module.BaseModule):
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningJob:
     if self._api_client.vertexai:
-      tuning_job = self._tune(
-          base_model=base_model,
-          training_dataset=training_dataset,
-          config=config,
-      )
+      if base_model.startswith('projects/'):  # Pre-tuned model
+        pre_tuned_model = types.PreTunedModel(tuned_model_name=base_model)
+        tuning_job = self._tune(
+            pre_tuned_model=pre_tuned_model,
+            training_dataset=training_dataset,
+            config=config,
+        )
+      else:
+        validated_evaluation_config: Optional[types.EvaluationConfig] = None
+        if (
+            config is not None
+            and getattr(config, 'evaluation_config', None) is not None
+        ):
+          evaluation_config = getattr(config, 'evaluation_config')
+          if isinstance(evaluation_config, dict):
+            evaluation_config = types.EvaluationConfig(**evaluation_config)
+          if (
+              not evaluation_config.metrics
+              or not evaluation_config.output_config
+          ):
+            raise ValueError(
+                'Evaluation config must have at least one metric and an output'
+                ' config.'
+            )
+          for i in range(len(evaluation_config.metrics)):
+            if isinstance(evaluation_config.metrics[i], dict):
+              evaluation_config.metrics[i] = types.Metric.model_validate(
+                  evaluation_config.metrics[i]
+              )
+          if isinstance(config, dict):
+            config['evaluation_config'] = evaluation_config
+          else:
+            config.evaluation_config = evaluation_config
+          validated_evaluation_config = evaluation_config
+        tuning_job = self._tune(
+            base_model=base_model,
+            training_dataset=training_dataset,
+            config=config,
+        )
+        if (
+            config is not None
+            and getattr(config, 'evaluation_config', None) is not None
+        ):
+          tuning_job.evaluation_config = validated_evaluation_config
     else:
       operation = self._tune_mldev(
           base_model=base_model,
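
At the public `tune()` level this means the Vertex path now accepts an `evaluation_config` inside `CreateTuningJobConfig`, checking that it carries at least one metric and an output config before the request is built, and writing the validated config back onto the returned job. A sketch of such a call, assuming a Vertex AI client; the project, model, bucket paths, and the metric's `name` field are placeholders inferred from this diff rather than confirmed API reference:

```python
from google import genai
from google.genai import types

# Vertex AI client; project and location are placeholders.
client = genai.Client(vertexai=True, project='my-project', location='us-central1')

tuning_job = client.tunings.tune(
    base_model='gemini-2.0-flash-001',  # placeholder base model
    training_dataset=types.TuningDataset(
        gcs_uri='gs://my-bucket/train.jsonl'  # placeholder training data
    ),
    config=types.CreateTuningJobConfig(
        tuned_model_display_name='eval-config-demo',
        evaluation_config=types.EvaluationConfig(
            # tune() only checks that metrics and output_config are both set;
            # the metric definition here is illustrative.
            metrics=[types.Metric(name='fluency')],
            output_config={
                'gcs_destination': {'output_uri_prefix': 'gs://my-bucket/eval'}
            },
        ),
    ),
)
print(tuning_job.name, tuning_job.state)
```

Per the added lines above, the validated config is also surfaced afterwards as `tuning_job.evaluation_config`.
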
@@ -1227,7 +1460,8 @@ class AsyncTunings(_api_module.BaseModule):
   async def _tune(
       self,
       *,
-      base_model: str,
+      base_model: Optional[str] = None,
+      pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
       training_dataset: types.TuningDatasetOrDict,
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningJob:
@@ -1242,8 +1476,9 @@ class AsyncTunings(_api_module.BaseModule):
      A TuningJob object.
     """
 
-    parameter_model = types._CreateTuningJobParameters(
+    parameter_model = types._CreateTuningJobParametersPrivate(
        base_model=base_model,
+        pre_tuned_model=pre_tuned_model,
        training_dataset=training_dataset,
        config=config,
     )
@@ -1252,7 +1487,9 @@ class AsyncTunings(_api_module.BaseModule):
     if not self._api_client.vertexai:
       raise ValueError('This method is only supported in the Vertex AI client.')
     else:
-      request_dict = _CreateTuningJobParameters_to_vertex(parameter_model)
+      request_dict = _CreateTuningJobParametersPrivate_to_vertex(
+          parameter_model
+      )
       request_url_dict = request_dict.get('_url')
       if request_url_dict:
         path = 'tuningJobs'.format_map(request_url_dict)
@@ -1296,7 +1533,8 @@ class AsyncTunings(_api_module.BaseModule):
   async def _tune_mldev(
       self,
       *,
-      base_model: str,
+      base_model: Optional[str] = None,
+      pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
       training_dataset: types.TuningDatasetOrDict,
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningOperation:
@@ -1311,8 +1549,9 @@ class AsyncTunings(_api_module.BaseModule):
      A TuningJob operation.
     """
 
-    parameter_model = types._CreateTuningJobParameters(
+    parameter_model = types._CreateTuningJobParametersPrivate(
        base_model=base_model,
+        pre_tuned_model=pre_tuned_model,
        training_dataset=training_dataset,
        config=config,
     )
@@ -1323,7 +1562,7 @@ class AsyncTunings(_api_module.BaseModule):
          'This method is only supported in the Gemini Developer client.'
       )
     else:
-      request_dict = _CreateTuningJobParameters_to_mldev(parameter_model)
+      request_dict = _CreateTuningJobParametersPrivate_to_mldev(parameter_model)
       request_url_dict = request_dict.get('_url')
       if request_url_dict:
         path = 'tunedModels'.format_map(request_url_dict)
@@ -1404,11 +1643,44 @@ class AsyncTunings(_api_module.BaseModule):
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningJob:
     if self._api_client.vertexai:
-      tuning_job = await self._tune(
-          base_model=base_model,
-          training_dataset=training_dataset,
-          config=config,
-      )
+      if base_model.startswith('projects/'):  # Pre-tuned model
+        pre_tuned_model = types.PreTunedModel(tuned_model_name=base_model)
+
+        tuning_job = await self._tune(
+            pre_tuned_model=pre_tuned_model,
+            training_dataset=training_dataset,
+            config=config,
+        )
+      else:
+        if (
+            config is not None
+            and getattr(config, 'evaluation_config', None) is not None
+        ):
+          evaluation_config = getattr(config, 'evaluation_config')
+          if isinstance(evaluation_config, dict):
+            evaluation_config = types.EvaluationConfig(**evaluation_config)
+          if (
+              not evaluation_config.metrics
+              or not evaluation_config.output_config
+          ):
+            raise ValueError(
+                'Evaluation config must have at least one metric and an output'
+                ' config.'
+            )
+          for i in range(len(evaluation_config.metrics)):
+            if isinstance(evaluation_config.metrics[i], dict):
+              evaluation_config.metrics[i] = types.Metric.model_validate(
+                  evaluation_config.metrics[i]
+              )
+          if isinstance(config, dict):
+            config['evaluation_config'] = evaluation_config
+          else:
+            config.evaluation_config = evaluation_config
+        tuning_job = await self._tune(
+            base_model=base_model,
+            training_dataset=training_dataset,
+            config=config,
+        )
     else:
       operation = await self._tune_mldev(
           base_model=base_model,
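
The async surface mirrors the sync branching: a `base_model` that is a full resource name (anything starting with `projects/`) is wrapped in `types.PreTunedModel(tuned_model_name=...)` and serialized as `preTunedModel` instead of `baseModel`, enabling continued tuning from an existing tuned model. A sketch of that path, assuming a Vertex AI client; the resource name and URIs below are placeholders:

```python
import asyncio

from google import genai
from google.genai import types


async def main() -> None:
  # Vertex AI client; project and location are placeholders.
  client = genai.Client(vertexai=True, project='my-project', location='us-central1')

  # Because base_model starts with 'projects/', tune() sends it as a
  # pre-tuned model rather than a base model.
  tuning_job = await client.aio.tunings.tune(
      base_model='projects/my-project/locations/us-central1/models/1234567890',
      training_dataset=types.TuningDataset(
          gcs_uri='gs://my-bucket/continue-train.jsonl'  # placeholder dataset
      ),
      config=types.CreateTuningJobConfig(
          tuned_model_display_name='continued-tuning-demo'
      ),
  )
  print(tuning_job.name)


asyncio.run(main())
```
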