google-genai 1.28.0__py3-none-any.whl → 1.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/tunings.py CHANGED
@@ -166,6 +166,12 @@ def _CreateTuningJobConfig_to_mldev(
         'export_last_checkpoint_only parameter is not supported in Gemini API.'
     )
 
+  if getv(from_object, ['pre_tuned_model_checkpoint_id']) is not None:
+    raise ValueError(
+        'pre_tuned_model_checkpoint_id parameter is not supported in Gemini'
+        ' API.'
+    )
+
   if getv(from_object, ['adapter_size']) is not None:
     raise ValueError('adapter_size parameter is not supported in Gemini API.')
 
@@ -183,10 +189,15 @@ def _CreateTuningJobConfig_to_mldev(
         getv(from_object, ['learning_rate']),
     )
 
+  if getv(from_object, ['evaluation_config']) is not None:
+    raise ValueError(
+        'evaluation_config parameter is not supported in Gemini API.'
+    )
+
   return to_object
 
 
-def _CreateTuningJobParameters_to_mldev(
+def _CreateTuningJobParametersPrivate_to_mldev(
     from_object: Union[dict[str, Any], object],
     parent_object: Optional[dict[str, Any]] = None,
 ) -> dict[str, Any]:
@@ -194,6 +205,9 @@ def _CreateTuningJobParameters_to_mldev(
   if getv(from_object, ['base_model']) is not None:
     setv(to_object, ['baseModel'], getv(from_object, ['base_model']))
 
+  if getv(from_object, ['pre_tuned_model']) is not None:
+    setv(to_object, ['preTunedModel'], getv(from_object, ['pre_tuned_model']))
+
   if getv(from_object, ['training_dataset']) is not None:
     setv(
         to_object,
@@ -313,6 +327,82 @@ def _TuningValidationDataset_to_vertex(
   return to_object
 
 
+def _GcsDestination_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['output_uri_prefix']) is not None:
+    setv(
+        to_object, ['outputUriPrefix'], getv(from_object, ['output_uri_prefix'])
+    )
+
+  return to_object
+
+
+def _OutputConfig_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['gcs_destination']) is not None:
+    setv(
+        to_object,
+        ['gcsDestination'],
+        _GcsDestination_to_vertex(
+            getv(from_object, ['gcs_destination']), to_object
+        ),
+    )
+
+  return to_object
+
+
+def _AutoraterConfig_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['sampling_count']) is not None:
+    setv(to_object, ['samplingCount'], getv(from_object, ['sampling_count']))
+
+  if getv(from_object, ['flip_enabled']) is not None:
+    setv(to_object, ['flipEnabled'], getv(from_object, ['flip_enabled']))
+
+  if getv(from_object, ['autorater_model']) is not None:
+    setv(to_object, ['autoraterModel'], getv(from_object, ['autorater_model']))
+
+  return to_object
+
+
+def _EvaluationConfig_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['metrics']) is not None:
+    setv(to_object, ['metrics'], t.t_metrics(getv(from_object, ['metrics'])))
+
+  if getv(from_object, ['output_config']) is not None:
+    setv(
+        to_object,
+        ['outputConfig'],
+        _OutputConfig_to_vertex(
+            getv(from_object, ['output_config']), to_object
+        ),
+    )
+
+  if getv(from_object, ['autorater_config']) is not None:
+    setv(
+        to_object,
+        ['autoraterConfig'],
+        _AutoraterConfig_to_vertex(
+            getv(from_object, ['autorater_config']), to_object
+        ),
+    )
+
+  return to_object
+
+
 def _CreateTuningJobConfig_to_vertex(
     from_object: Union[dict[str, Any], object],
     parent_object: Optional[dict[str, Any]] = None,
@@ -359,6 +449,13 @@ def _CreateTuningJobConfig_to_vertex(
         getv(from_object, ['export_last_checkpoint_only']),
     )
 
+  if getv(from_object, ['pre_tuned_model_checkpoint_id']) is not None:
+    setv(
+        to_object,
+        ['preTunedModel', 'checkpointId'],
+        getv(from_object, ['pre_tuned_model_checkpoint_id']),
+    )
+
   if getv(from_object, ['adapter_size']) is not None:
     setv(
         parent_object,
@@ -372,10 +469,19 @@ def _CreateTuningJobConfig_to_vertex(
   if getv(from_object, ['learning_rate']) is not None:
     raise ValueError('learning_rate parameter is not supported in Vertex AI.')
 
+  if getv(from_object, ['evaluation_config']) is not None:
+    setv(
+        parent_object,
+        ['supervisedTuningSpec', 'evaluationConfig'],
+        _EvaluationConfig_to_vertex(
+            getv(from_object, ['evaluation_config']), to_object
+        ),
+    )
+
   return to_object
 
 
-def _CreateTuningJobParameters_to_vertex(
+def _CreateTuningJobParametersPrivate_to_vertex(
     from_object: Union[dict[str, Any], object],
     parent_object: Optional[dict[str, Any]] = None,
 ) -> dict[str, Any]:
@@ -383,6 +489,9 @@ def _CreateTuningJobParameters_to_vertex(
   if getv(from_object, ['base_model']) is not None:
     setv(to_object, ['baseModel'], getv(from_object, ['base_model']))
 
+  if getv(from_object, ['pre_tuned_model']) is not None:
+    setv(to_object, ['preTunedModel'], getv(from_object, ['pre_tuned_model']))
+
   if getv(from_object, ['training_dataset']) is not None:
     setv(
         to_object,
@@ -471,11 +580,9 @@ def _TuningJob_from_mldev(
         _TunedModel_from_mldev(getv(from_object, ['_self']), to_object),
     )
 
-  if getv(from_object, ['distillationSpec']) is not None:
+  if getv(from_object, ['customBaseModel']) is not None:
     setv(
-        to_object,
-        ['distillation_spec'],
-        getv(from_object, ['distillationSpec']),
+        to_object, ['custom_base_model'], getv(from_object, ['customBaseModel'])
     )
 
   if getv(from_object, ['experiment']) is not None:
@@ -484,15 +591,12 @@ def _TuningJob_from_mldev(
   if getv(from_object, ['labels']) is not None:
     setv(to_object, ['labels'], getv(from_object, ['labels']))
 
+  if getv(from_object, ['outputUri']) is not None:
+    setv(to_object, ['output_uri'], getv(from_object, ['outputUri']))
+
   if getv(from_object, ['pipelineJob']) is not None:
     setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob']))
 
-  if getv(from_object, ['satisfiesPzi']) is not None:
-    setv(to_object, ['satisfies_pzi'], getv(from_object, ['satisfiesPzi']))
-
-  if getv(from_object, ['satisfiesPzs']) is not None:
-    setv(to_object, ['satisfies_pzs'], getv(from_object, ['satisfiesPzs']))
-
   if getv(from_object, ['serviceAccount']) is not None:
     setv(to_object, ['service_account'], getv(from_object, ['serviceAccount']))
 
@@ -601,6 +705,82 @@ def _TunedModel_from_vertex(
   return to_object
 
 
+def _GcsDestination_from_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['outputUriPrefix']) is not None:
+    setv(
+        to_object, ['output_uri_prefix'], getv(from_object, ['outputUriPrefix'])
+    )
+
+  return to_object
+
+
+def _OutputConfig_from_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['gcsDestination']) is not None:
+    setv(
+        to_object,
+        ['gcs_destination'],
+        _GcsDestination_from_vertex(
+            getv(from_object, ['gcsDestination']), to_object
+        ),
+    )
+
+  return to_object
+
+
+def _AutoraterConfig_from_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['samplingCount']) is not None:
+    setv(to_object, ['sampling_count'], getv(from_object, ['samplingCount']))
+
+  if getv(from_object, ['flipEnabled']) is not None:
+    setv(to_object, ['flip_enabled'], getv(from_object, ['flipEnabled']))
+
+  if getv(from_object, ['autoraterModel']) is not None:
+    setv(to_object, ['autorater_model'], getv(from_object, ['autoraterModel']))
+
+  return to_object
+
+
+def _EvaluationConfig_from_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+  to_object: dict[str, Any] = {}
+  if getv(from_object, ['metrics']) is not None:
+    setv(to_object, ['metrics'], t.t_metrics(getv(from_object, ['metrics'])))
+
+  if getv(from_object, ['outputConfig']) is not None:
+    setv(
+        to_object,
+        ['output_config'],
+        _OutputConfig_from_vertex(
+            getv(from_object, ['outputConfig']), to_object
+        ),
+    )
+
+  if getv(from_object, ['autoraterConfig']) is not None:
+    setv(
+        to_object,
+        ['autorater_config'],
+        _AutoraterConfig_from_vertex(
+            getv(from_object, ['autoraterConfig']), to_object
+        ),
+    )
+
+  return to_object
+
+
 def _TuningJob_from_vertex(
     from_object: Union[dict[str, Any], object],
     parent_object: Optional[dict[str, Any]] = None,
@@ -649,6 +829,9 @@ def _TuningJob_from_vertex(
         _TunedModel_from_vertex(getv(from_object, ['tunedModel']), to_object),
     )
 
+  if getv(from_object, ['preTunedModel']) is not None:
+    setv(to_object, ['pre_tuned_model'], getv(from_object, ['preTunedModel']))
+
   if getv(from_object, ['supervisedTuningSpec']) is not None:
     setv(
         to_object,
@@ -671,11 +854,18 @@ def _TuningJob_from_vertex(
         getv(from_object, ['partnerModelTuningSpec']),
     )
 
-  if getv(from_object, ['distillationSpec']) is not None:
+  if getv(from_object, ['evaluationConfig']) is not None:
     setv(
         to_object,
-        ['distillation_spec'],
-        getv(from_object, ['distillationSpec']),
+        ['evaluation_config'],
+        _EvaluationConfig_from_vertex(
+            getv(from_object, ['evaluationConfig']), to_object
+        ),
+    )
+
+  if getv(from_object, ['customBaseModel']) is not None:
+    setv(
+        to_object, ['custom_base_model'], getv(from_object, ['customBaseModel'])
     )
 
   if getv(from_object, ['experiment']) is not None:
@@ -684,15 +874,12 @@ def _TuningJob_from_vertex(
   if getv(from_object, ['labels']) is not None:
     setv(to_object, ['labels'], getv(from_object, ['labels']))
 
+  if getv(from_object, ['outputUri']) is not None:
+    setv(to_object, ['output_uri'], getv(from_object, ['outputUri']))
+
   if getv(from_object, ['pipelineJob']) is not None:
     setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob']))
 
-  if getv(from_object, ['satisfiesPzi']) is not None:
-    setv(to_object, ['satisfies_pzi'], getv(from_object, ['satisfiesPzi']))
-
-  if getv(from_object, ['satisfiesPzs']) is not None:
-    setv(to_object, ['satisfies_pzs'], getv(from_object, ['satisfiesPzs']))
-
   if getv(from_object, ['serviceAccount']) is not None:
     setv(to_object, ['service_account'], getv(from_object, ['serviceAccount']))
 
@@ -875,7 +1062,8 @@ class Tunings(_api_module.BaseModule):
   def _tune(
       self,
       *,
-      base_model: str,
+      base_model: Optional[str] = None,
+      pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
       training_dataset: types.TuningDatasetOrDict,
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningJob:
@@ -890,8 +1078,9 @@ class Tunings(_api_module.BaseModule):
       A TuningJob object.
     """
 
-    parameter_model = types._CreateTuningJobParameters(
+    parameter_model = types._CreateTuningJobParametersPrivate(
         base_model=base_model,
+        pre_tuned_model=pre_tuned_model,
         training_dataset=training_dataset,
         config=config,
     )
@@ -900,7 +1089,9 @@ class Tunings(_api_module.BaseModule):
     if not self._api_client.vertexai:
       raise ValueError('This method is only supported in the Vertex AI client.')
     else:
-      request_dict = _CreateTuningJobParameters_to_vertex(parameter_model)
+      request_dict = _CreateTuningJobParametersPrivate_to_vertex(
+          parameter_model
+      )
       request_url_dict = request_dict.get('_url')
       if request_url_dict:
         path = 'tuningJobs'.format_map(request_url_dict)
@@ -944,7 +1135,8 @@ class Tunings(_api_module.BaseModule):
   def _tune_mldev(
       self,
       *,
-      base_model: str,
+      base_model: Optional[str] = None,
+      pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
       training_dataset: types.TuningDatasetOrDict,
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningOperation:
@@ -959,8 +1151,9 @@ class Tunings(_api_module.BaseModule):
       A TuningJob operation.
     """
 
-    parameter_model = types._CreateTuningJobParameters(
+    parameter_model = types._CreateTuningJobParametersPrivate(
        base_model=base_model,
+        pre_tuned_model=pre_tuned_model,
        training_dataset=training_dataset,
        config=config,
    )
@@ -971,7 +1164,7 @@ class Tunings(_api_module.BaseModule):
           'This method is only supported in the Gemini Developer client.'
       )
     else:
-      request_dict = _CreateTuningJobParameters_to_mldev(parameter_model)
+      request_dict = _CreateTuningJobParametersPrivate_to_mldev(parameter_model)
       request_url_dict = request_dict.get('_url')
       if request_url_dict:
         path = 'tunedModels'.format_map(request_url_dict)
@@ -1052,11 +1245,50 @@ class Tunings(_api_module.BaseModule):
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningJob:
     if self._api_client.vertexai:
-      tuning_job = self._tune(
-          base_model=base_model,
-          training_dataset=training_dataset,
-          config=config,
-      )
+      if base_model.startswith('projects/'):  # Pre-tuned model
+        pre_tuned_model = types.PreTunedModel(tuned_model_name=base_model)
+        tuning_job = self._tune(
+            pre_tuned_model=pre_tuned_model,
+            training_dataset=training_dataset,
+            config=config,
+        )
+      else:
+        validated_evaluation_config: Optional[types.EvaluationConfig] = None
+        if (
+            config is not None
+            and getattr(config, 'evaluation_config', None) is not None
+        ):
+          evaluation_config = getattr(config, 'evaluation_config')
+          if isinstance(evaluation_config, dict):
+            evaluation_config = types.EvaluationConfig(**evaluation_config)
+          if (
+              not evaluation_config.metrics
+              or not evaluation_config.output_config
+          ):
+            raise ValueError(
+                'Evaluation config must have at least one metric and an output'
+                ' config.'
+            )
+          for i in range(len(evaluation_config.metrics)):
+            if isinstance(evaluation_config.metrics[i], dict):
+              evaluation_config.metrics[i] = types.Metric.model_validate(
+                  evaluation_config.metrics[i]
+              )
+          if isinstance(config, dict):
+            config['evaluation_config'] = evaluation_config
+          else:
+            config.evaluation_config = evaluation_config
+          validated_evaluation_config = evaluation_config
+        tuning_job = self._tune(
+            base_model=base_model,
+            training_dataset=training_dataset,
+            config=config,
+        )
+        if (
+            config is not None
+            and getattr(config, 'evaluation_config', None) is not None
+        ):
+          tuning_job.evaluation_config = validated_evaluation_config
     else:
       operation = self._tune_mldev(
           base_model=base_model,
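
Illustrative usage sketch (not part of the diff): the hunk above makes the Vertex AI path of tune() treat a base_model that starts with 'projects/' as a pre-tuned model, wrapping it in types.PreTunedModel(tuned_model_name=...) and passing it through the new pre_tuned_model parameter of _tune(). The project, location, model resource name, and dataset URI below are placeholders, and TuningDataset(gcs_uri=...) is assumed from the existing SDK surface rather than shown in this diff.

from google import genai
from google.genai import types

client = genai.Client(vertexai=True, project='my-project', location='us-central1')

# A base_model beginning with 'projects/' is routed through the new
# PreTunedModel branch instead of being sent as baseModel.
continued_job = client.tunings.tune(
    base_model='projects/my-project/locations/us-central1/models/1234567890',  # hypothetical resource name
    training_dataset=types.TuningDataset(gcs_uri='gs://my-bucket/train.jsonl'),  # placeholder dataset
)
print(continued_job.name, continued_job.state)
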
@@ -1227,7 +1459,8 @@ class AsyncTunings(_api_module.BaseModule):
   async def _tune(
       self,
       *,
-      base_model: str,
+      base_model: Optional[str] = None,
+      pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
       training_dataset: types.TuningDatasetOrDict,
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningJob:
@@ -1242,8 +1475,9 @@ class AsyncTunings(_api_module.BaseModule):
       A TuningJob object.
     """
 
-    parameter_model = types._CreateTuningJobParameters(
+    parameter_model = types._CreateTuningJobParametersPrivate(
         base_model=base_model,
+        pre_tuned_model=pre_tuned_model,
         training_dataset=training_dataset,
         config=config,
     )
@@ -1252,7 +1486,9 @@ class AsyncTunings(_api_module.BaseModule):
     if not self._api_client.vertexai:
       raise ValueError('This method is only supported in the Vertex AI client.')
     else:
-      request_dict = _CreateTuningJobParameters_to_vertex(parameter_model)
+      request_dict = _CreateTuningJobParametersPrivate_to_vertex(
+          parameter_model
+      )
       request_url_dict = request_dict.get('_url')
       if request_url_dict:
         path = 'tuningJobs'.format_map(request_url_dict)
@@ -1296,7 +1532,8 @@ class AsyncTunings(_api_module.BaseModule):
   async def _tune_mldev(
       self,
       *,
-      base_model: str,
+      base_model: Optional[str] = None,
+      pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
       training_dataset: types.TuningDatasetOrDict,
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningOperation:
@@ -1311,8 +1548,9 @@ class AsyncTunings(_api_module.BaseModule):
       A TuningJob operation.
     """
 
-    parameter_model = types._CreateTuningJobParameters(
+    parameter_model = types._CreateTuningJobParametersPrivate(
         base_model=base_model,
+        pre_tuned_model=pre_tuned_model,
         training_dataset=training_dataset,
         config=config,
     )
@@ -1323,7 +1561,7 @@ class AsyncTunings(_api_module.BaseModule):
           'This method is only supported in the Gemini Developer client.'
       )
     else:
-      request_dict = _CreateTuningJobParameters_to_mldev(parameter_model)
+      request_dict = _CreateTuningJobParametersPrivate_to_mldev(parameter_model)
       request_url_dict = request_dict.get('_url')
       if request_url_dict:
         path = 'tunedModels'.format_map(request_url_dict)
@@ -1404,11 +1642,44 @@ class AsyncTunings(_api_module.BaseModule):
       config: Optional[types.CreateTuningJobConfigOrDict] = None,
   ) -> types.TuningJob:
     if self._api_client.vertexai:
-      tuning_job = await self._tune(
-          base_model=base_model,
-          training_dataset=training_dataset,
-          config=config,
-      )
+      if base_model.startswith('projects/'):  # Pre-tuned model
+        pre_tuned_model = types.PreTunedModel(tuned_model_name=base_model)
+
+        tuning_job = await self._tune(
+            pre_tuned_model=pre_tuned_model,
+            training_dataset=training_dataset,
+            config=config,
+        )
+      else:
+        if (
+            config is not None
+            and getattr(config, 'evaluation_config', None) is not None
+        ):
+          evaluation_config = getattr(config, 'evaluation_config')
+          if isinstance(evaluation_config, dict):
+            evaluation_config = types.EvaluationConfig(**evaluation_config)
+          if (
+              not evaluation_config.metrics
+              or not evaluation_config.output_config
+          ):
+            raise ValueError(
+                'Evaluation config must have at least one metric and an output'
+                ' config.'
+            )
+          for i in range(len(evaluation_config.metrics)):
+            if isinstance(evaluation_config.metrics[i], dict):
+              evaluation_config.metrics[i] = types.Metric.model_validate(
+                  evaluation_config.metrics[i]
+              )
+          if isinstance(config, dict):
+            config['evaluation_config'] = evaluation_config
+          else:
+            config.evaluation_config = evaluation_config
+        tuning_job = await self._tune(
+            base_model=base_model,
+            training_dataset=training_dataset,
+            config=config,
+        )
     else:
       operation = await self._tune_mldev(
           base_model=base_model,
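
Illustrative usage sketch (not part of the diff): the new evaluation_config field on CreateTuningJobConfig flows through the Vertex AI path only and, per the validation added in tune() above, must carry at least one metric and an output_config; dict values are coerced with types.EvaluationConfig(...) and types.Metric.model_validate(...). The metric dict's field names, the model id, and the GCS paths below are assumptions for illustration; the exact fields of Metric are not visible in this diff.

from google import genai
from google.genai import types

client = genai.Client(vertexai=True, project='my-project', location='us-central1')

job = client.tunings.tune(
    base_model='gemini-2.0-flash-001',  # placeholder base model id
    training_dataset=types.TuningDataset(gcs_uri='gs://my-bucket/train.jsonl'),  # placeholder dataset
    config=types.CreateTuningJobConfig(
        evaluation_config={
            # Coerced to types.Metric via model_validate(); the 'name' field is assumed.
            'metrics': [{'name': 'my-eval-metric'}],
            # Mapped by _EvaluationConfig_to_vertex to
            # supervisedTuningSpec.evaluationConfig.outputConfig.gcsDestination.
            'output_config': {
                'gcs_destination': {'output_uri_prefix': 'gs://my-bucket/eval-output/'}
            },
            # Optional; mapped to autoraterConfig.
            'autorater_config': {'sampling_count': 1, 'flip_enabled': True},
        },
    ),
)
print(job.name, job.state)
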