google-genai 1.29.0__py3-none-any.whl → 1.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +41 -37
- google/genai/_automatic_function_calling_util.py +12 -0
- google/genai/_live_converters.py +51 -6
- google/genai/_tokens_converters.py +26 -3
- google/genai/_transformers.py +51 -0
- google/genai/batches.py +166 -3
- google/genai/caches.py +51 -6
- google/genai/chats.py +1 -0
- google/genai/files.py +1 -0
- google/genai/live.py +92 -88
- google/genai/models.py +416 -16
- google/genai/operations.py +1 -0
- google/genai/tokens.py +1 -0
- google/genai/tunings.py +315 -43
- google/genai/types.py +1518 -421
- google/genai/version.py +1 -1
- {google_genai-1.29.0.dist-info → google_genai-1.31.0.dist-info}/METADATA +1 -1
- google_genai-1.31.0.dist-info/RECORD +35 -0
- google_genai-1.29.0.dist-info/RECORD +0 -35
- {google_genai-1.29.0.dist-info → google_genai-1.31.0.dist-info}/WHEEL +0 -0
- {google_genai-1.29.0.dist-info → google_genai-1.31.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.29.0.dist-info → google_genai-1.31.0.dist-info}/top_level.txt +0 -0
google/genai/tunings.py
CHANGED
@@ -28,6 +28,7 @@ from ._common import get_value_by_path as getv
|
|
28
28
|
from ._common import set_value_by_path as setv
|
29
29
|
from .pagers import AsyncPager, Pager
|
30
30
|
|
31
|
+
|
31
32
|
logger = logging.getLogger('google_genai.tunings')
|
32
33
|
|
33
34
|
|
@@ -166,6 +167,12 @@ def _CreateTuningJobConfig_to_mldev(
|
|
166
167
|
'export_last_checkpoint_only parameter is not supported in Gemini API.'
|
167
168
|
)
|
168
169
|
|
170
|
+
if getv(from_object, ['pre_tuned_model_checkpoint_id']) is not None:
|
171
|
+
raise ValueError(
|
172
|
+
'pre_tuned_model_checkpoint_id parameter is not supported in Gemini'
|
173
|
+
' API.'
|
174
|
+
)
|
175
|
+
|
169
176
|
if getv(from_object, ['adapter_size']) is not None:
|
170
177
|
raise ValueError('adapter_size parameter is not supported in Gemini API.')
|
171
178
|
|
@@ -183,10 +190,15 @@ def _CreateTuningJobConfig_to_mldev(
|
|
183
190
|
getv(from_object, ['learning_rate']),
|
184
191
|
)
|
185
192
|
|
193
|
+
if getv(from_object, ['evaluation_config']) is not None:
|
194
|
+
raise ValueError(
|
195
|
+
'evaluation_config parameter is not supported in Gemini API.'
|
196
|
+
)
|
197
|
+
|
186
198
|
return to_object
|
187
199
|
|
188
200
|
|
189
|
-
def
|
201
|
+
def _CreateTuningJobParametersPrivate_to_mldev(
|
190
202
|
from_object: Union[dict[str, Any], object],
|
191
203
|
parent_object: Optional[dict[str, Any]] = None,
|
192
204
|
) -> dict[str, Any]:
|
@@ -194,6 +206,9 @@ def _CreateTuningJobParameters_to_mldev(
|
|
194
206
|
if getv(from_object, ['base_model']) is not None:
|
195
207
|
setv(to_object, ['baseModel'], getv(from_object, ['base_model']))
|
196
208
|
|
209
|
+
if getv(from_object, ['pre_tuned_model']) is not None:
|
210
|
+
setv(to_object, ['preTunedModel'], getv(from_object, ['pre_tuned_model']))
|
211
|
+
|
197
212
|
if getv(from_object, ['training_dataset']) is not None:
|
198
213
|
setv(
|
199
214
|
to_object,
|
@@ -313,6 +328,82 @@ def _TuningValidationDataset_to_vertex(
|
|
313
328
|
return to_object
|
314
329
|
|
315
330
|
|
331
|
+
def _GcsDestination_to_vertex(
|
332
|
+
from_object: Union[dict[str, Any], object],
|
333
|
+
parent_object: Optional[dict[str, Any]] = None,
|
334
|
+
) -> dict[str, Any]:
|
335
|
+
to_object: dict[str, Any] = {}
|
336
|
+
if getv(from_object, ['output_uri_prefix']) is not None:
|
337
|
+
setv(
|
338
|
+
to_object, ['outputUriPrefix'], getv(from_object, ['output_uri_prefix'])
|
339
|
+
)
|
340
|
+
|
341
|
+
return to_object
|
342
|
+
|
343
|
+
|
344
|
+
def _OutputConfig_to_vertex(
|
345
|
+
from_object: Union[dict[str, Any], object],
|
346
|
+
parent_object: Optional[dict[str, Any]] = None,
|
347
|
+
) -> dict[str, Any]:
|
348
|
+
to_object: dict[str, Any] = {}
|
349
|
+
if getv(from_object, ['gcs_destination']) is not None:
|
350
|
+
setv(
|
351
|
+
to_object,
|
352
|
+
['gcsDestination'],
|
353
|
+
_GcsDestination_to_vertex(
|
354
|
+
getv(from_object, ['gcs_destination']), to_object
|
355
|
+
),
|
356
|
+
)
|
357
|
+
|
358
|
+
return to_object
|
359
|
+
|
360
|
+
|
361
|
+
def _AutoraterConfig_to_vertex(
|
362
|
+
from_object: Union[dict[str, Any], object],
|
363
|
+
parent_object: Optional[dict[str, Any]] = None,
|
364
|
+
) -> dict[str, Any]:
|
365
|
+
to_object: dict[str, Any] = {}
|
366
|
+
if getv(from_object, ['sampling_count']) is not None:
|
367
|
+
setv(to_object, ['samplingCount'], getv(from_object, ['sampling_count']))
|
368
|
+
|
369
|
+
if getv(from_object, ['flip_enabled']) is not None:
|
370
|
+
setv(to_object, ['flipEnabled'], getv(from_object, ['flip_enabled']))
|
371
|
+
|
372
|
+
if getv(from_object, ['autorater_model']) is not None:
|
373
|
+
setv(to_object, ['autoraterModel'], getv(from_object, ['autorater_model']))
|
374
|
+
|
375
|
+
return to_object
|
376
|
+
|
377
|
+
|
378
|
+
def _EvaluationConfig_to_vertex(
|
379
|
+
from_object: Union[dict[str, Any], object],
|
380
|
+
parent_object: Optional[dict[str, Any]] = None,
|
381
|
+
) -> dict[str, Any]:
|
382
|
+
to_object: dict[str, Any] = {}
|
383
|
+
if getv(from_object, ['metrics']) is not None:
|
384
|
+
setv(to_object, ['metrics'], t.t_metrics(getv(from_object, ['metrics'])))
|
385
|
+
|
386
|
+
if getv(from_object, ['output_config']) is not None:
|
387
|
+
setv(
|
388
|
+
to_object,
|
389
|
+
['outputConfig'],
|
390
|
+
_OutputConfig_to_vertex(
|
391
|
+
getv(from_object, ['output_config']), to_object
|
392
|
+
),
|
393
|
+
)
|
394
|
+
|
395
|
+
if getv(from_object, ['autorater_config']) is not None:
|
396
|
+
setv(
|
397
|
+
to_object,
|
398
|
+
['autoraterConfig'],
|
399
|
+
_AutoraterConfig_to_vertex(
|
400
|
+
getv(from_object, ['autorater_config']), to_object
|
401
|
+
),
|
402
|
+
)
|
403
|
+
|
404
|
+
return to_object
|
405
|
+
|
406
|
+
|
316
407
|
def _CreateTuningJobConfig_to_vertex(
|
317
408
|
from_object: Union[dict[str, Any], object],
|
318
409
|
parent_object: Optional[dict[str, Any]] = None,
|
@@ -359,6 +450,13 @@ def _CreateTuningJobConfig_to_vertex(
|
|
359
450
|
getv(from_object, ['export_last_checkpoint_only']),
|
360
451
|
)
|
361
452
|
|
453
|
+
if getv(from_object, ['pre_tuned_model_checkpoint_id']) is not None:
|
454
|
+
setv(
|
455
|
+
to_object,
|
456
|
+
['preTunedModel', 'checkpointId'],
|
457
|
+
getv(from_object, ['pre_tuned_model_checkpoint_id']),
|
458
|
+
)
|
459
|
+
|
362
460
|
if getv(from_object, ['adapter_size']) is not None:
|
363
461
|
setv(
|
364
462
|
parent_object,
|
@@ -372,10 +470,19 @@ def _CreateTuningJobConfig_to_vertex(
|
|
372
470
|
if getv(from_object, ['learning_rate']) is not None:
|
373
471
|
raise ValueError('learning_rate parameter is not supported in Vertex AI.')
|
374
472
|
|
473
|
+
if getv(from_object, ['evaluation_config']) is not None:
|
474
|
+
setv(
|
475
|
+
parent_object,
|
476
|
+
['supervisedTuningSpec', 'evaluationConfig'],
|
477
|
+
_EvaluationConfig_to_vertex(
|
478
|
+
getv(from_object, ['evaluation_config']), to_object
|
479
|
+
),
|
480
|
+
)
|
481
|
+
|
375
482
|
return to_object
|
376
483
|
|
377
484
|
|
378
|
-
def
|
485
|
+
def _CreateTuningJobParametersPrivate_to_vertex(
|
379
486
|
from_object: Union[dict[str, Any], object],
|
380
487
|
parent_object: Optional[dict[str, Any]] = None,
|
381
488
|
) -> dict[str, Any]:
|
@@ -383,6 +490,9 @@ def _CreateTuningJobParameters_to_vertex(
|
|
383
490
|
if getv(from_object, ['base_model']) is not None:
|
384
491
|
setv(to_object, ['baseModel'], getv(from_object, ['base_model']))
|
385
492
|
|
493
|
+
if getv(from_object, ['pre_tuned_model']) is not None:
|
494
|
+
setv(to_object, ['preTunedModel'], getv(from_object, ['pre_tuned_model']))
|
495
|
+
|
386
496
|
if getv(from_object, ['training_dataset']) is not None:
|
387
497
|
setv(
|
388
498
|
to_object,
|
@@ -471,11 +581,9 @@ def _TuningJob_from_mldev(
|
|
471
581
|
_TunedModel_from_mldev(getv(from_object, ['_self']), to_object),
|
472
582
|
)
|
473
583
|
|
474
|
-
if getv(from_object, ['
|
584
|
+
if getv(from_object, ['customBaseModel']) is not None:
|
475
585
|
setv(
|
476
|
-
to_object,
|
477
|
-
['distillation_spec'],
|
478
|
-
getv(from_object, ['distillationSpec']),
|
586
|
+
to_object, ['custom_base_model'], getv(from_object, ['customBaseModel'])
|
479
587
|
)
|
480
588
|
|
481
589
|
if getv(from_object, ['experiment']) is not None:
|
@@ -484,15 +592,12 @@ def _TuningJob_from_mldev(
|
|
484
592
|
if getv(from_object, ['labels']) is not None:
|
485
593
|
setv(to_object, ['labels'], getv(from_object, ['labels']))
|
486
594
|
|
595
|
+
if getv(from_object, ['outputUri']) is not None:
|
596
|
+
setv(to_object, ['output_uri'], getv(from_object, ['outputUri']))
|
597
|
+
|
487
598
|
if getv(from_object, ['pipelineJob']) is not None:
|
488
599
|
setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob']))
|
489
600
|
|
490
|
-
if getv(from_object, ['satisfiesPzi']) is not None:
|
491
|
-
setv(to_object, ['satisfies_pzi'], getv(from_object, ['satisfiesPzi']))
|
492
|
-
|
493
|
-
if getv(from_object, ['satisfiesPzs']) is not None:
|
494
|
-
setv(to_object, ['satisfies_pzs'], getv(from_object, ['satisfiesPzs']))
|
495
|
-
|
496
601
|
if getv(from_object, ['serviceAccount']) is not None:
|
497
602
|
setv(to_object, ['service_account'], getv(from_object, ['serviceAccount']))
|
498
603
|
|
@@ -601,6 +706,82 @@ def _TunedModel_from_vertex(
|
|
601
706
|
return to_object
|
602
707
|
|
603
708
|
|
709
|
+
def _GcsDestination_from_vertex(
|
710
|
+
from_object: Union[dict[str, Any], object],
|
711
|
+
parent_object: Optional[dict[str, Any]] = None,
|
712
|
+
) -> dict[str, Any]:
|
713
|
+
to_object: dict[str, Any] = {}
|
714
|
+
if getv(from_object, ['outputUriPrefix']) is not None:
|
715
|
+
setv(
|
716
|
+
to_object, ['output_uri_prefix'], getv(from_object, ['outputUriPrefix'])
|
717
|
+
)
|
718
|
+
|
719
|
+
return to_object
|
720
|
+
|
721
|
+
|
722
|
+
def _OutputConfig_from_vertex(
|
723
|
+
from_object: Union[dict[str, Any], object],
|
724
|
+
parent_object: Optional[dict[str, Any]] = None,
|
725
|
+
) -> dict[str, Any]:
|
726
|
+
to_object: dict[str, Any] = {}
|
727
|
+
if getv(from_object, ['gcsDestination']) is not None:
|
728
|
+
setv(
|
729
|
+
to_object,
|
730
|
+
['gcs_destination'],
|
731
|
+
_GcsDestination_from_vertex(
|
732
|
+
getv(from_object, ['gcsDestination']), to_object
|
733
|
+
),
|
734
|
+
)
|
735
|
+
|
736
|
+
return to_object
|
737
|
+
|
738
|
+
|
739
|
+
def _AutoraterConfig_from_vertex(
|
740
|
+
from_object: Union[dict[str, Any], object],
|
741
|
+
parent_object: Optional[dict[str, Any]] = None,
|
742
|
+
) -> dict[str, Any]:
|
743
|
+
to_object: dict[str, Any] = {}
|
744
|
+
if getv(from_object, ['samplingCount']) is not None:
|
745
|
+
setv(to_object, ['sampling_count'], getv(from_object, ['samplingCount']))
|
746
|
+
|
747
|
+
if getv(from_object, ['flipEnabled']) is not None:
|
748
|
+
setv(to_object, ['flip_enabled'], getv(from_object, ['flipEnabled']))
|
749
|
+
|
750
|
+
if getv(from_object, ['autoraterModel']) is not None:
|
751
|
+
setv(to_object, ['autorater_model'], getv(from_object, ['autoraterModel']))
|
752
|
+
|
753
|
+
return to_object
|
754
|
+
|
755
|
+
|
756
|
+
def _EvaluationConfig_from_vertex(
|
757
|
+
from_object: Union[dict[str, Any], object],
|
758
|
+
parent_object: Optional[dict[str, Any]] = None,
|
759
|
+
) -> dict[str, Any]:
|
760
|
+
to_object: dict[str, Any] = {}
|
761
|
+
if getv(from_object, ['metrics']) is not None:
|
762
|
+
setv(to_object, ['metrics'], t.t_metrics(getv(from_object, ['metrics'])))
|
763
|
+
|
764
|
+
if getv(from_object, ['outputConfig']) is not None:
|
765
|
+
setv(
|
766
|
+
to_object,
|
767
|
+
['output_config'],
|
768
|
+
_OutputConfig_from_vertex(
|
769
|
+
getv(from_object, ['outputConfig']), to_object
|
770
|
+
),
|
771
|
+
)
|
772
|
+
|
773
|
+
if getv(from_object, ['autoraterConfig']) is not None:
|
774
|
+
setv(
|
775
|
+
to_object,
|
776
|
+
['autorater_config'],
|
777
|
+
_AutoraterConfig_from_vertex(
|
778
|
+
getv(from_object, ['autoraterConfig']), to_object
|
779
|
+
),
|
780
|
+
)
|
781
|
+
|
782
|
+
return to_object
|
783
|
+
|
784
|
+
|
604
785
|
def _TuningJob_from_vertex(
|
605
786
|
from_object: Union[dict[str, Any], object],
|
606
787
|
parent_object: Optional[dict[str, Any]] = None,
|
@@ -649,6 +830,9 @@ def _TuningJob_from_vertex(
|
|
649
830
|
_TunedModel_from_vertex(getv(from_object, ['tunedModel']), to_object),
|
650
831
|
)
|
651
832
|
|
833
|
+
if getv(from_object, ['preTunedModel']) is not None:
|
834
|
+
setv(to_object, ['pre_tuned_model'], getv(from_object, ['preTunedModel']))
|
835
|
+
|
652
836
|
if getv(from_object, ['supervisedTuningSpec']) is not None:
|
653
837
|
setv(
|
654
838
|
to_object,
|
@@ -671,11 +855,18 @@ def _TuningJob_from_vertex(
|
|
671
855
|
getv(from_object, ['partnerModelTuningSpec']),
|
672
856
|
)
|
673
857
|
|
674
|
-
if getv(from_object, ['
|
858
|
+
if getv(from_object, ['evaluationConfig']) is not None:
|
675
859
|
setv(
|
676
860
|
to_object,
|
677
|
-
['
|
678
|
-
|
861
|
+
['evaluation_config'],
|
862
|
+
_EvaluationConfig_from_vertex(
|
863
|
+
getv(from_object, ['evaluationConfig']), to_object
|
864
|
+
),
|
865
|
+
)
|
866
|
+
|
867
|
+
if getv(from_object, ['customBaseModel']) is not None:
|
868
|
+
setv(
|
869
|
+
to_object, ['custom_base_model'], getv(from_object, ['customBaseModel'])
|
679
870
|
)
|
680
871
|
|
681
872
|
if getv(from_object, ['experiment']) is not None:
|
@@ -684,15 +875,12 @@ def _TuningJob_from_vertex(
|
|
684
875
|
if getv(from_object, ['labels']) is not None:
|
685
876
|
setv(to_object, ['labels'], getv(from_object, ['labels']))
|
686
877
|
|
878
|
+
if getv(from_object, ['outputUri']) is not None:
|
879
|
+
setv(to_object, ['output_uri'], getv(from_object, ['outputUri']))
|
880
|
+
|
687
881
|
if getv(from_object, ['pipelineJob']) is not None:
|
688
882
|
setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob']))
|
689
883
|
|
690
|
-
if getv(from_object, ['satisfiesPzi']) is not None:
|
691
|
-
setv(to_object, ['satisfies_pzi'], getv(from_object, ['satisfiesPzi']))
|
692
|
-
|
693
|
-
if getv(from_object, ['satisfiesPzs']) is not None:
|
694
|
-
setv(to_object, ['satisfies_pzs'], getv(from_object, ['satisfiesPzs']))
|
695
|
-
|
696
884
|
if getv(from_object, ['serviceAccount']) is not None:
|
697
885
|
setv(to_object, ['service_account'], getv(from_object, ['serviceAccount']))
|
698
886
|
|
@@ -875,7 +1063,8 @@ class Tunings(_api_module.BaseModule):
|
|
875
1063
|
def _tune(
|
876
1064
|
self,
|
877
1065
|
*,
|
878
|
-
base_model: str,
|
1066
|
+
base_model: Optional[str] = None,
|
1067
|
+
pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
|
879
1068
|
training_dataset: types.TuningDatasetOrDict,
|
880
1069
|
config: Optional[types.CreateTuningJobConfigOrDict] = None,
|
881
1070
|
) -> types.TuningJob:
|
@@ -890,8 +1079,9 @@ class Tunings(_api_module.BaseModule):
|
|
890
1079
|
A TuningJob object.
|
891
1080
|
"""
|
892
1081
|
|
893
|
-
parameter_model = types.
|
1082
|
+
parameter_model = types._CreateTuningJobParametersPrivate(
|
894
1083
|
base_model=base_model,
|
1084
|
+
pre_tuned_model=pre_tuned_model,
|
895
1085
|
training_dataset=training_dataset,
|
896
1086
|
config=config,
|
897
1087
|
)
|
@@ -900,7 +1090,9 @@ class Tunings(_api_module.BaseModule):
|
|
900
1090
|
if not self._api_client.vertexai:
|
901
1091
|
raise ValueError('This method is only supported in the Vertex AI client.')
|
902
1092
|
else:
|
903
|
-
request_dict =
|
1093
|
+
request_dict = _CreateTuningJobParametersPrivate_to_vertex(
|
1094
|
+
parameter_model
|
1095
|
+
)
|
904
1096
|
request_url_dict = request_dict.get('_url')
|
905
1097
|
if request_url_dict:
|
906
1098
|
path = 'tuningJobs'.format_map(request_url_dict)
|
@@ -944,7 +1136,8 @@ class Tunings(_api_module.BaseModule):
|
|
944
1136
|
def _tune_mldev(
|
945
1137
|
self,
|
946
1138
|
*,
|
947
|
-
base_model: str,
|
1139
|
+
base_model: Optional[str] = None,
|
1140
|
+
pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
|
948
1141
|
training_dataset: types.TuningDatasetOrDict,
|
949
1142
|
config: Optional[types.CreateTuningJobConfigOrDict] = None,
|
950
1143
|
) -> types.TuningOperation:
|
@@ -959,8 +1152,9 @@ class Tunings(_api_module.BaseModule):
|
|
959
1152
|
A TuningJob operation.
|
960
1153
|
"""
|
961
1154
|
|
962
|
-
parameter_model = types.
|
1155
|
+
parameter_model = types._CreateTuningJobParametersPrivate(
|
963
1156
|
base_model=base_model,
|
1157
|
+
pre_tuned_model=pre_tuned_model,
|
964
1158
|
training_dataset=training_dataset,
|
965
1159
|
config=config,
|
966
1160
|
)
|
@@ -971,7 +1165,7 @@ class Tunings(_api_module.BaseModule):
|
|
971
1165
|
'This method is only supported in the Gemini Developer client.'
|
972
1166
|
)
|
973
1167
|
else:
|
974
|
-
request_dict =
|
1168
|
+
request_dict = _CreateTuningJobParametersPrivate_to_mldev(parameter_model)
|
975
1169
|
request_url_dict = request_dict.get('_url')
|
976
1170
|
if request_url_dict:
|
977
1171
|
path = 'tunedModels'.format_map(request_url_dict)
|
@@ -1052,11 +1246,50 @@ class Tunings(_api_module.BaseModule):
|
|
1052
1246
|
config: Optional[types.CreateTuningJobConfigOrDict] = None,
|
1053
1247
|
) -> types.TuningJob:
|
1054
1248
|
if self._api_client.vertexai:
|
1055
|
-
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1249
|
+
if base_model.startswith('projects/'): # Pre-tuned model
|
1250
|
+
pre_tuned_model = types.PreTunedModel(tuned_model_name=base_model)
|
1251
|
+
tuning_job = self._tune(
|
1252
|
+
pre_tuned_model=pre_tuned_model,
|
1253
|
+
training_dataset=training_dataset,
|
1254
|
+
config=config,
|
1255
|
+
)
|
1256
|
+
else:
|
1257
|
+
validated_evaluation_config: Optional[types.EvaluationConfig] = None
|
1258
|
+
if (
|
1259
|
+
config is not None
|
1260
|
+
and getattr(config, 'evaluation_config', None) is not None
|
1261
|
+
):
|
1262
|
+
evaluation_config = getattr(config, 'evaluation_config')
|
1263
|
+
if isinstance(evaluation_config, dict):
|
1264
|
+
evaluation_config = types.EvaluationConfig(**evaluation_config)
|
1265
|
+
if (
|
1266
|
+
not evaluation_config.metrics
|
1267
|
+
or not evaluation_config.output_config
|
1268
|
+
):
|
1269
|
+
raise ValueError(
|
1270
|
+
'Evaluation config must have at least one metric and an output'
|
1271
|
+
' config.'
|
1272
|
+
)
|
1273
|
+
for i in range(len(evaluation_config.metrics)):
|
1274
|
+
if isinstance(evaluation_config.metrics[i], dict):
|
1275
|
+
evaluation_config.metrics[i] = types.Metric.model_validate(
|
1276
|
+
evaluation_config.metrics[i]
|
1277
|
+
)
|
1278
|
+
if isinstance(config, dict):
|
1279
|
+
config['evaluation_config'] = evaluation_config
|
1280
|
+
else:
|
1281
|
+
config.evaluation_config = evaluation_config
|
1282
|
+
validated_evaluation_config = evaluation_config
|
1283
|
+
tuning_job = self._tune(
|
1284
|
+
base_model=base_model,
|
1285
|
+
training_dataset=training_dataset,
|
1286
|
+
config=config,
|
1287
|
+
)
|
1288
|
+
if (
|
1289
|
+
config is not None
|
1290
|
+
and getattr(config, 'evaluation_config', None) is not None
|
1291
|
+
):
|
1292
|
+
tuning_job.evaluation_config = validated_evaluation_config
|
1060
1293
|
else:
|
1061
1294
|
operation = self._tune_mldev(
|
1062
1295
|
base_model=base_model,
|
@@ -1227,7 +1460,8 @@ class AsyncTunings(_api_module.BaseModule):
|
|
1227
1460
|
async def _tune(
|
1228
1461
|
self,
|
1229
1462
|
*,
|
1230
|
-
base_model: str,
|
1463
|
+
base_model: Optional[str] = None,
|
1464
|
+
pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
|
1231
1465
|
training_dataset: types.TuningDatasetOrDict,
|
1232
1466
|
config: Optional[types.CreateTuningJobConfigOrDict] = None,
|
1233
1467
|
) -> types.TuningJob:
|
@@ -1242,8 +1476,9 @@ class AsyncTunings(_api_module.BaseModule):
|
|
1242
1476
|
A TuningJob object.
|
1243
1477
|
"""
|
1244
1478
|
|
1245
|
-
parameter_model = types.
|
1479
|
+
parameter_model = types._CreateTuningJobParametersPrivate(
|
1246
1480
|
base_model=base_model,
|
1481
|
+
pre_tuned_model=pre_tuned_model,
|
1247
1482
|
training_dataset=training_dataset,
|
1248
1483
|
config=config,
|
1249
1484
|
)
|
@@ -1252,7 +1487,9 @@ class AsyncTunings(_api_module.BaseModule):
|
|
1252
1487
|
if not self._api_client.vertexai:
|
1253
1488
|
raise ValueError('This method is only supported in the Vertex AI client.')
|
1254
1489
|
else:
|
1255
|
-
request_dict =
|
1490
|
+
request_dict = _CreateTuningJobParametersPrivate_to_vertex(
|
1491
|
+
parameter_model
|
1492
|
+
)
|
1256
1493
|
request_url_dict = request_dict.get('_url')
|
1257
1494
|
if request_url_dict:
|
1258
1495
|
path = 'tuningJobs'.format_map(request_url_dict)
|
@@ -1296,7 +1533,8 @@ class AsyncTunings(_api_module.BaseModule):
|
|
1296
1533
|
async def _tune_mldev(
|
1297
1534
|
self,
|
1298
1535
|
*,
|
1299
|
-
base_model: str,
|
1536
|
+
base_model: Optional[str] = None,
|
1537
|
+
pre_tuned_model: Optional[types.PreTunedModelOrDict] = None,
|
1300
1538
|
training_dataset: types.TuningDatasetOrDict,
|
1301
1539
|
config: Optional[types.CreateTuningJobConfigOrDict] = None,
|
1302
1540
|
) -> types.TuningOperation:
|
@@ -1311,8 +1549,9 @@ class AsyncTunings(_api_module.BaseModule):
|
|
1311
1549
|
A TuningJob operation.
|
1312
1550
|
"""
|
1313
1551
|
|
1314
|
-
parameter_model = types.
|
1552
|
+
parameter_model = types._CreateTuningJobParametersPrivate(
|
1315
1553
|
base_model=base_model,
|
1554
|
+
pre_tuned_model=pre_tuned_model,
|
1316
1555
|
training_dataset=training_dataset,
|
1317
1556
|
config=config,
|
1318
1557
|
)
|
@@ -1323,7 +1562,7 @@ class AsyncTunings(_api_module.BaseModule):
|
|
1323
1562
|
'This method is only supported in the Gemini Developer client.'
|
1324
1563
|
)
|
1325
1564
|
else:
|
1326
|
-
request_dict =
|
1565
|
+
request_dict = _CreateTuningJobParametersPrivate_to_mldev(parameter_model)
|
1327
1566
|
request_url_dict = request_dict.get('_url')
|
1328
1567
|
if request_url_dict:
|
1329
1568
|
path = 'tunedModels'.format_map(request_url_dict)
|
@@ -1404,11 +1643,44 @@ class AsyncTunings(_api_module.BaseModule):
|
|
1404
1643
|
config: Optional[types.CreateTuningJobConfigOrDict] = None,
|
1405
1644
|
) -> types.TuningJob:
|
1406
1645
|
if self._api_client.vertexai:
|
1407
|
-
|
1408
|
-
|
1409
|
-
|
1410
|
-
|
1411
|
-
|
1646
|
+
if base_model.startswith('projects/'): # Pre-tuned model
|
1647
|
+
pre_tuned_model = types.PreTunedModel(tuned_model_name=base_model)
|
1648
|
+
|
1649
|
+
tuning_job = await self._tune(
|
1650
|
+
pre_tuned_model=pre_tuned_model,
|
1651
|
+
training_dataset=training_dataset,
|
1652
|
+
config=config,
|
1653
|
+
)
|
1654
|
+
else:
|
1655
|
+
if (
|
1656
|
+
config is not None
|
1657
|
+
and getattr(config, 'evaluation_config', None) is not None
|
1658
|
+
):
|
1659
|
+
evaluation_config = getattr(config, 'evaluation_config')
|
1660
|
+
if isinstance(evaluation_config, dict):
|
1661
|
+
evaluation_config = types.EvaluationConfig(**evaluation_config)
|
1662
|
+
if (
|
1663
|
+
not evaluation_config.metrics
|
1664
|
+
or not evaluation_config.output_config
|
1665
|
+
):
|
1666
|
+
raise ValueError(
|
1667
|
+
'Evaluation config must have at least one metric and an output'
|
1668
|
+
' config.'
|
1669
|
+
)
|
1670
|
+
for i in range(len(evaluation_config.metrics)):
|
1671
|
+
if isinstance(evaluation_config.metrics[i], dict):
|
1672
|
+
evaluation_config.metrics[i] = types.Metric.model_validate(
|
1673
|
+
evaluation_config.metrics[i]
|
1674
|
+
)
|
1675
|
+
if isinstance(config, dict):
|
1676
|
+
config['evaluation_config'] = evaluation_config
|
1677
|
+
else:
|
1678
|
+
config.evaluation_config = evaluation_config
|
1679
|
+
tuning_job = await self._tune(
|
1680
|
+
base_model=base_model,
|
1681
|
+
training_dataset=training_dataset,
|
1682
|
+
config=config,
|
1683
|
+
)
|
1412
1684
|
else:
|
1413
1685
|
operation = await self._tune_mldev(
|
1414
1686
|
base_model=base_model,
|