google-genai 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +74 -82
- google/genai/_api_module.py +24 -0
- google/genai/_automatic_function_calling_util.py +43 -22
- google/genai/_common.py +11 -8
- google/genai/_extra_utils.py +22 -16
- google/genai/_operations.py +365 -0
- google/genai/_replay_api_client.py +7 -2
- google/genai/_test_api_client.py +1 -1
- google/genai/_transformers.py +218 -97
- google/genai/batches.py +194 -155
- google/genai/caches.py +117 -134
- google/genai/chats.py +22 -18
- google/genai/client.py +31 -37
- google/genai/files.py +154 -183
- google/genai/live.py +11 -5
- google/genai/models.py +506 -254
- google/genai/tunings.py +85 -422
- google/genai/types.py +647 -458
- google/genai/version.py +1 -1
- {google_genai-0.6.0.dist-info → google_genai-0.8.0.dist-info}/METADATA +119 -70
- google_genai-0.8.0.dist-info/RECORD +27 -0
- google_genai-0.6.0.dist-info/RECORD +0 -25
- {google_genai-0.6.0.dist-info → google_genai-0.8.0.dist-info}/LICENSE +0 -0
- {google_genai-0.6.0.dist-info → google_genai-0.8.0.dist-info}/WHEEL +0 -0
- {google_genai-0.6.0.dist-info → google_genai-0.8.0.dist-info}/top_level.txt +0 -0
google/genai/tunings.py
CHANGED
@@ -17,6 +17,7 @@
|
|
17
17
|
|
18
18
|
from typing import Optional, Union
|
19
19
|
from urllib.parse import urlencode
|
20
|
+
from . import _api_module
|
20
21
|
from . import _common
|
21
22
|
from . import _transformers as t
|
22
23
|
from . import types
|
@@ -26,30 +27,6 @@ from ._common import set_value_by_path as setv
|
|
26
27
|
from .pagers import AsyncPager, Pager
|
27
28
|
|
28
29
|
|
29
|
-
def _GetTuningJobConfig_to_mldev(
|
30
|
-
api_client: ApiClient,
|
31
|
-
from_object: Union[dict, object],
|
32
|
-
parent_object: dict = None,
|
33
|
-
) -> dict:
|
34
|
-
to_object = {}
|
35
|
-
if getv(from_object, ['http_options']) is not None:
|
36
|
-
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
|
37
|
-
|
38
|
-
return to_object
|
39
|
-
|
40
|
-
|
41
|
-
def _GetTuningJobConfig_to_vertex(
|
42
|
-
api_client: ApiClient,
|
43
|
-
from_object: Union[dict, object],
|
44
|
-
parent_object: dict = None,
|
45
|
-
) -> dict:
|
46
|
-
to_object = {}
|
47
|
-
if getv(from_object, ['http_options']) is not None:
|
48
|
-
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
|
49
|
-
|
50
|
-
return to_object
|
51
|
-
|
52
|
-
|
53
30
|
def _GetTuningJobParameters_to_mldev(
|
54
31
|
api_client: ApiClient,
|
55
32
|
from_object: Union[dict, object],
|
@@ -60,13 +37,7 @@ def _GetTuningJobParameters_to_mldev(
|
|
60
37
|
setv(to_object, ['_url', 'name'], getv(from_object, ['name']))
|
61
38
|
|
62
39
|
if getv(from_object, ['config']) is not None:
|
63
|
-
setv(
|
64
|
-
to_object,
|
65
|
-
['config'],
|
66
|
-
_GetTuningJobConfig_to_mldev(
|
67
|
-
api_client, getv(from_object, ['config']), to_object
|
68
|
-
),
|
69
|
-
)
|
40
|
+
setv(to_object, ['config'], getv(from_object, ['config']))
|
70
41
|
|
71
42
|
return to_object
|
72
43
|
|
@@ -81,13 +52,7 @@ def _GetTuningJobParameters_to_vertex(
|
|
81
52
|
setv(to_object, ['_url', 'name'], getv(from_object, ['name']))
|
82
53
|
|
83
54
|
if getv(from_object, ['config']) is not None:
|
84
|
-
setv(
|
85
|
-
to_object,
|
86
|
-
['config'],
|
87
|
-
_GetTuningJobConfig_to_vertex(
|
88
|
-
api_client, getv(from_object, ['config']), to_object
|
89
|
-
),
|
90
|
-
)
|
55
|
+
setv(to_object, ['config'], getv(from_object, ['config']))
|
91
56
|
|
92
57
|
return to_object
|
93
58
|
|
@@ -98,6 +63,7 @@ def _ListTuningJobsConfig_to_mldev(
|
|
98
63
|
parent_object: dict = None,
|
99
64
|
) -> dict:
|
100
65
|
to_object = {}
|
66
|
+
|
101
67
|
if getv(from_object, ['page_size']) is not None:
|
102
68
|
setv(
|
103
69
|
parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size'])
|
@@ -122,6 +88,7 @@ def _ListTuningJobsConfig_to_vertex(
|
|
122
88
|
parent_object: dict = None,
|
123
89
|
) -> dict:
|
124
90
|
to_object = {}
|
91
|
+
|
125
92
|
if getv(from_object, ['page_size']) is not None:
|
126
93
|
setv(
|
127
94
|
parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size'])
|
@@ -213,7 +180,7 @@ def _TuningDataset_to_mldev(
|
|
213
180
|
) -> dict:
|
214
181
|
to_object = {}
|
215
182
|
if getv(from_object, ['gcs_uri']) is not None:
|
216
|
-
raise ValueError('gcs_uri parameter is not supported in Google AI.')
|
183
|
+
raise ValueError('gcs_uri parameter is not supported in Gemini API.')
|
217
184
|
|
218
185
|
if getv(from_object, ['examples']) is not None:
|
219
186
|
setv(
|
@@ -254,7 +221,7 @@ def _TuningValidationDataset_to_mldev(
|
|
254
221
|
) -> dict:
|
255
222
|
to_object = {}
|
256
223
|
if getv(from_object, ['gcs_uri']) is not None:
|
257
|
-
raise ValueError('gcs_uri parameter is not supported in Google AI.')
|
224
|
+
raise ValueError('gcs_uri parameter is not supported in Gemini API.')
|
258
225
|
|
259
226
|
return to_object
|
260
227
|
|
@@ -277,12 +244,10 @@ def _CreateTuningJobConfig_to_mldev(
|
|
277
244
|
parent_object: dict = None,
|
278
245
|
) -> dict:
|
279
246
|
to_object = {}
|
280
|
-
if getv(from_object, ['http_options']) is not None:
|
281
|
-
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
|
282
247
|
|
283
248
|
if getv(from_object, ['validation_dataset']) is not None:
|
284
249
|
raise ValueError(
|
285
|
-
'validation_dataset parameter is not supported in Google AI.'
|
250
|
+
'validation_dataset parameter is not supported in Gemini API.'
|
286
251
|
)
|
287
252
|
|
288
253
|
if getv(from_object, ['tuned_model_display_name']) is not None:
|
@@ -293,7 +258,7 @@ def _CreateTuningJobConfig_to_mldev(
|
|
293
258
|
)
|
294
259
|
|
295
260
|
if getv(from_object, ['description']) is not None:
|
296
|
-
raise ValueError('description parameter is not supported in Google AI.')
|
261
|
+
raise ValueError('description parameter is not supported in Gemini API.')
|
297
262
|
|
298
263
|
if getv(from_object, ['epoch_count']) is not None:
|
299
264
|
setv(
|
@@ -310,7 +275,7 @@ def _CreateTuningJobConfig_to_mldev(
|
|
310
275
|
)
|
311
276
|
|
312
277
|
if getv(from_object, ['adapter_size']) is not None:
|
313
|
-
raise ValueError('adapter_size parameter is not supported in Google AI.')
|
278
|
+
raise ValueError('adapter_size parameter is not supported in Gemini API.')
|
314
279
|
|
315
280
|
if getv(from_object, ['batch_size']) is not None:
|
316
281
|
setv(
|
@@ -335,8 +300,6 @@ def _CreateTuningJobConfig_to_vertex(
|
|
335
300
|
parent_object: dict = None,
|
336
301
|
) -> dict:
|
337
302
|
to_object = {}
|
338
|
-
if getv(from_object, ['http_options']) is not None:
|
339
|
-
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
|
340
303
|
|
341
304
|
if getv(from_object, ['validation_dataset']) is not None:
|
342
305
|
setv(
|
@@ -447,234 +410,6 @@ def _CreateTuningJobParameters_to_vertex(
|
|
447
410
|
return to_object
|
448
411
|
|
449
412
|
|
450
|
-
def _DistillationDataset_to_mldev(
|
451
|
-
api_client: ApiClient,
|
452
|
-
from_object: Union[dict, object],
|
453
|
-
parent_object: dict = None,
|
454
|
-
) -> dict:
|
455
|
-
to_object = {}
|
456
|
-
if getv(from_object, ['gcs_uri']) is not None:
|
457
|
-
raise ValueError('gcs_uri parameter is not supported in Google AI.')
|
458
|
-
|
459
|
-
return to_object
|
460
|
-
|
461
|
-
|
462
|
-
def _DistillationDataset_to_vertex(
|
463
|
-
api_client: ApiClient,
|
464
|
-
from_object: Union[dict, object],
|
465
|
-
parent_object: dict = None,
|
466
|
-
) -> dict:
|
467
|
-
to_object = {}
|
468
|
-
if getv(from_object, ['gcs_uri']) is not None:
|
469
|
-
setv(
|
470
|
-
parent_object,
|
471
|
-
['distillationSpec', 'trainingDatasetUri'],
|
472
|
-
getv(from_object, ['gcs_uri']),
|
473
|
-
)
|
474
|
-
|
475
|
-
return to_object
|
476
|
-
|
477
|
-
|
478
|
-
def _DistillationValidationDataset_to_mldev(
|
479
|
-
api_client: ApiClient,
|
480
|
-
from_object: Union[dict, object],
|
481
|
-
parent_object: dict = None,
|
482
|
-
) -> dict:
|
483
|
-
to_object = {}
|
484
|
-
if getv(from_object, ['gcs_uri']) is not None:
|
485
|
-
raise ValueError('gcs_uri parameter is not supported in Google AI.')
|
486
|
-
|
487
|
-
return to_object
|
488
|
-
|
489
|
-
|
490
|
-
def _DistillationValidationDataset_to_vertex(
|
491
|
-
api_client: ApiClient,
|
492
|
-
from_object: Union[dict, object],
|
493
|
-
parent_object: dict = None,
|
494
|
-
) -> dict:
|
495
|
-
to_object = {}
|
496
|
-
if getv(from_object, ['gcs_uri']) is not None:
|
497
|
-
setv(to_object, ['validationDatasetUri'], getv(from_object, ['gcs_uri']))
|
498
|
-
|
499
|
-
return to_object
|
500
|
-
|
501
|
-
|
502
|
-
def _CreateDistillationJobConfig_to_mldev(
|
503
|
-
api_client: ApiClient,
|
504
|
-
from_object: Union[dict, object],
|
505
|
-
parent_object: dict = None,
|
506
|
-
) -> dict:
|
507
|
-
to_object = {}
|
508
|
-
if getv(from_object, ['http_options']) is not None:
|
509
|
-
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
|
510
|
-
|
511
|
-
if getv(from_object, ['validation_dataset']) is not None:
|
512
|
-
raise ValueError(
|
513
|
-
'validation_dataset parameter is not supported in Google AI.'
|
514
|
-
)
|
515
|
-
|
516
|
-
if getv(from_object, ['tuned_model_display_name']) is not None:
|
517
|
-
setv(
|
518
|
-
parent_object,
|
519
|
-
['displayName'],
|
520
|
-
getv(from_object, ['tuned_model_display_name']),
|
521
|
-
)
|
522
|
-
|
523
|
-
if getv(from_object, ['epoch_count']) is not None:
|
524
|
-
setv(
|
525
|
-
parent_object,
|
526
|
-
['tuningTask', 'hyperparameters', 'epochCount'],
|
527
|
-
getv(from_object, ['epoch_count']),
|
528
|
-
)
|
529
|
-
|
530
|
-
if getv(from_object, ['learning_rate_multiplier']) is not None:
|
531
|
-
setv(
|
532
|
-
parent_object,
|
533
|
-
['tuningTask', 'hyperparameters', 'learningRateMultiplier'],
|
534
|
-
getv(from_object, ['learning_rate_multiplier']),
|
535
|
-
)
|
536
|
-
|
537
|
-
if getv(from_object, ['adapter_size']) is not None:
|
538
|
-
raise ValueError('adapter_size parameter is not supported in Google AI.')
|
539
|
-
|
540
|
-
if getv(from_object, ['pipeline_root_directory']) is not None:
|
541
|
-
raise ValueError(
|
542
|
-
'pipeline_root_directory parameter is not supported in Google AI.'
|
543
|
-
)
|
544
|
-
|
545
|
-
return to_object
|
546
|
-
|
547
|
-
|
548
|
-
def _CreateDistillationJobConfig_to_vertex(
|
549
|
-
api_client: ApiClient,
|
550
|
-
from_object: Union[dict, object],
|
551
|
-
parent_object: dict = None,
|
552
|
-
) -> dict:
|
553
|
-
to_object = {}
|
554
|
-
if getv(from_object, ['http_options']) is not None:
|
555
|
-
setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
|
556
|
-
|
557
|
-
if getv(from_object, ['validation_dataset']) is not None:
|
558
|
-
setv(
|
559
|
-
parent_object,
|
560
|
-
['distillationSpec'],
|
561
|
-
_DistillationValidationDataset_to_vertex(
|
562
|
-
api_client, getv(from_object, ['validation_dataset']), to_object
|
563
|
-
),
|
564
|
-
)
|
565
|
-
|
566
|
-
if getv(from_object, ['tuned_model_display_name']) is not None:
|
567
|
-
setv(
|
568
|
-
parent_object,
|
569
|
-
['tunedModelDisplayName'],
|
570
|
-
getv(from_object, ['tuned_model_display_name']),
|
571
|
-
)
|
572
|
-
|
573
|
-
if getv(from_object, ['epoch_count']) is not None:
|
574
|
-
setv(
|
575
|
-
parent_object,
|
576
|
-
['distillationSpec', 'hyperParameters', 'epochCount'],
|
577
|
-
getv(from_object, ['epoch_count']),
|
578
|
-
)
|
579
|
-
|
580
|
-
if getv(from_object, ['learning_rate_multiplier']) is not None:
|
581
|
-
setv(
|
582
|
-
parent_object,
|
583
|
-
['distillationSpec', 'hyperParameters', 'learningRateMultiplier'],
|
584
|
-
getv(from_object, ['learning_rate_multiplier']),
|
585
|
-
)
|
586
|
-
|
587
|
-
if getv(from_object, ['adapter_size']) is not None:
|
588
|
-
setv(
|
589
|
-
parent_object,
|
590
|
-
['distillationSpec', 'hyperParameters', 'adapterSize'],
|
591
|
-
getv(from_object, ['adapter_size']),
|
592
|
-
)
|
593
|
-
|
594
|
-
if getv(from_object, ['pipeline_root_directory']) is not None:
|
595
|
-
setv(
|
596
|
-
parent_object,
|
597
|
-
['distillationSpec', 'pipelineRootDirectory'],
|
598
|
-
getv(from_object, ['pipeline_root_directory']),
|
599
|
-
)
|
600
|
-
|
601
|
-
return to_object
|
602
|
-
|
603
|
-
|
604
|
-
def _CreateDistillationJobParameters_to_mldev(
|
605
|
-
api_client: ApiClient,
|
606
|
-
from_object: Union[dict, object],
|
607
|
-
parent_object: dict = None,
|
608
|
-
) -> dict:
|
609
|
-
to_object = {}
|
610
|
-
if getv(from_object, ['student_model']) is not None:
|
611
|
-
raise ValueError('student_model parameter is not supported in Google AI.')
|
612
|
-
|
613
|
-
if getv(from_object, ['teacher_model']) is not None:
|
614
|
-
raise ValueError('teacher_model parameter is not supported in Google AI.')
|
615
|
-
|
616
|
-
if getv(from_object, ['training_dataset']) is not None:
|
617
|
-
setv(
|
618
|
-
to_object,
|
619
|
-
['tuningTask', 'trainingData'],
|
620
|
-
_DistillationDataset_to_mldev(
|
621
|
-
api_client, getv(from_object, ['training_dataset']), to_object
|
622
|
-
),
|
623
|
-
)
|
624
|
-
|
625
|
-
if getv(from_object, ['config']) is not None:
|
626
|
-
setv(
|
627
|
-
to_object,
|
628
|
-
['config'],
|
629
|
-
_CreateDistillationJobConfig_to_mldev(
|
630
|
-
api_client, getv(from_object, ['config']), to_object
|
631
|
-
),
|
632
|
-
)
|
633
|
-
|
634
|
-
return to_object
|
635
|
-
|
636
|
-
|
637
|
-
def _CreateDistillationJobParameters_to_vertex(
|
638
|
-
api_client: ApiClient,
|
639
|
-
from_object: Union[dict, object],
|
640
|
-
parent_object: dict = None,
|
641
|
-
) -> dict:
|
642
|
-
to_object = {}
|
643
|
-
if getv(from_object, ['student_model']) is not None:
|
644
|
-
setv(
|
645
|
-
to_object,
|
646
|
-
['distillationSpec', 'studentModel'],
|
647
|
-
getv(from_object, ['student_model']),
|
648
|
-
)
|
649
|
-
|
650
|
-
if getv(from_object, ['teacher_model']) is not None:
|
651
|
-
setv(
|
652
|
-
to_object,
|
653
|
-
['distillationSpec', 'baseTeacherModel'],
|
654
|
-
getv(from_object, ['teacher_model']),
|
655
|
-
)
|
656
|
-
|
657
|
-
if getv(from_object, ['training_dataset']) is not None:
|
658
|
-
setv(
|
659
|
-
to_object,
|
660
|
-
['distillationSpec', 'trainingDatasetUri'],
|
661
|
-
_DistillationDataset_to_vertex(
|
662
|
-
api_client, getv(from_object, ['training_dataset']), to_object
|
663
|
-
),
|
664
|
-
)
|
665
|
-
|
666
|
-
if getv(from_object, ['config']) is not None:
|
667
|
-
setv(
|
668
|
-
to_object,
|
669
|
-
['config'],
|
670
|
-
_CreateDistillationJobConfig_to_vertex(
|
671
|
-
api_client, getv(from_object, ['config']), to_object
|
672
|
-
),
|
673
|
-
)
|
674
|
-
|
675
|
-
return to_object
|
676
|
-
|
677
|
-
|
678
413
|
def _TunedModel_from_mldev(
|
679
414
|
api_client: ApiClient,
|
680
415
|
from_object: Union[dict, object],
|
@@ -756,12 +491,22 @@ def _TuningJob_from_mldev(
|
|
756
491
|
),
|
757
492
|
)
|
758
493
|
|
494
|
+
if getv(from_object, ['distillationSpec']) is not None:
|
495
|
+
setv(
|
496
|
+
to_object,
|
497
|
+
['distillation_spec'],
|
498
|
+
getv(from_object, ['distillationSpec']),
|
499
|
+
)
|
500
|
+
|
759
501
|
if getv(from_object, ['experiment']) is not None:
|
760
502
|
setv(to_object, ['experiment'], getv(from_object, ['experiment']))
|
761
503
|
|
762
504
|
if getv(from_object, ['labels']) is not None:
|
763
505
|
setv(to_object, ['labels'], getv(from_object, ['labels']))
|
764
506
|
|
507
|
+
if getv(from_object, ['pipelineJob']) is not None:
|
508
|
+
setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob']))
|
509
|
+
|
765
510
|
if getv(from_object, ['tunedModelDisplayName']) is not None:
|
766
511
|
setv(
|
767
512
|
to_object,
|
@@ -833,13 +578,6 @@ def _TuningJob_from_vertex(
|
|
833
578
|
if getv(from_object, ['encryptionSpec']) is not None:
|
834
579
|
setv(to_object, ['encryption_spec'], getv(from_object, ['encryptionSpec']))
|
835
580
|
|
836
|
-
if getv(from_object, ['distillationSpec']) is not None:
|
837
|
-
setv(
|
838
|
-
to_object,
|
839
|
-
['distillation_spec'],
|
840
|
-
getv(from_object, ['distillationSpec']),
|
841
|
-
)
|
842
|
-
|
843
581
|
if getv(from_object, ['partnerModelTuningSpec']) is not None:
|
844
582
|
setv(
|
845
583
|
to_object,
|
@@ -847,8 +585,12 @@ def _TuningJob_from_vertex(
|
|
847
585
|
getv(from_object, ['partnerModelTuningSpec']),
|
848
586
|
)
|
849
587
|
|
850
|
-
if getv(from_object, ['
|
851
|
-
setv(
|
588
|
+
if getv(from_object, ['distillationSpec']) is not None:
|
589
|
+
setv(
|
590
|
+
to_object,
|
591
|
+
['distillation_spec'],
|
592
|
+
getv(from_object, ['distillationSpec']),
|
593
|
+
)
|
852
594
|
|
853
595
|
if getv(from_object, ['experiment']) is not None:
|
854
596
|
setv(to_object, ['experiment'], getv(from_object, ['experiment']))
|
@@ -856,6 +598,9 @@ def _TuningJob_from_vertex(
|
|
856
598
|
if getv(from_object, ['labels']) is not None:
|
857
599
|
setv(to_object, ['labels'], getv(from_object, ['labels']))
|
858
600
|
|
601
|
+
if getv(from_object, ['pipelineJob']) is not None:
|
602
|
+
setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob']))
|
603
|
+
|
859
604
|
if getv(from_object, ['tunedModelDisplayName']) is not None:
|
860
605
|
setv(
|
861
606
|
to_object,
|
@@ -950,7 +695,7 @@ def _TuningJobOrOperation_from_vertex(
|
|
950
695
|
return to_object
|
951
696
|
|
952
697
|
|
953
|
-
class Tunings(_common.BaseModule):
|
698
|
+
class Tunings(_api_module.BaseModule):
|
954
699
|
|
955
700
|
def _get(
|
956
701
|
self,
|
@@ -986,8 +731,14 @@ class Tunings(_common.BaseModule):
|
|
986
731
|
if query_params:
|
987
732
|
path = f'{path}?{urlencode(query_params)}'
|
988
733
|
# TODO: remove the hack that pops config.
|
989
|
-
|
990
|
-
|
734
|
+
request_dict.pop('config', None)
|
735
|
+
|
736
|
+
http_options = None
|
737
|
+
if isinstance(config, dict):
|
738
|
+
http_options = config.get('http_options', None)
|
739
|
+
elif hasattr(config, 'http_options'):
|
740
|
+
http_options = config.http_options
|
741
|
+
|
991
742
|
request_dict = _common.convert_to_dict(request_dict)
|
992
743
|
request_dict = _common.encode_unserializable_types(request_dict)
|
993
744
|
|
@@ -1001,7 +752,7 @@ class Tunings(_common.BaseModule):
|
|
1001
752
|
response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
|
1002
753
|
|
1003
754
|
return_value = types.TuningJob._from_response(
|
1004
|
-
response_dict, parameter_model
|
755
|
+
response=response_dict, kwargs=parameter_model
|
1005
756
|
)
|
1006
757
|
self._api_client._verify_response(return_value)
|
1007
758
|
return return_value
|
@@ -1036,8 +787,14 @@ class Tunings(_common.BaseModule):
|
|
1036
787
|
if query_params:
|
1037
788
|
path = f'{path}?{urlencode(query_params)}'
|
1038
789
|
# TODO: remove the hack that pops config.
|
1039
|
-
|
1040
|
-
|
790
|
+
request_dict.pop('config', None)
|
791
|
+
|
792
|
+
http_options = None
|
793
|
+
if isinstance(config, dict):
|
794
|
+
http_options = config.get('http_options', None)
|
795
|
+
elif hasattr(config, 'http_options'):
|
796
|
+
http_options = config.http_options
|
797
|
+
|
1041
798
|
request_dict = _common.convert_to_dict(request_dict)
|
1042
799
|
request_dict = _common.encode_unserializable_types(request_dict)
|
1043
800
|
|
@@ -1055,7 +812,7 @@ class Tunings(_common.BaseModule):
|
|
1055
812
|
)
|
1056
813
|
|
1057
814
|
return_value = types.ListTuningJobsResponse._from_response(
|
1058
|
-
response_dict, parameter_model
|
815
|
+
response=response_dict, kwargs=parameter_model
|
1059
816
|
)
|
1060
817
|
self._api_client._verify_response(return_value)
|
1061
818
|
return return_value
|
@@ -1098,8 +855,14 @@ class Tunings(_common.BaseModule):
|
|
1098
855
|
if query_params:
|
1099
856
|
path = f'{path}?{urlencode(query_params)}'
|
1100
857
|
# TODO: remove the hack that pops config.
|
1101
|
-
|
1102
|
-
|
858
|
+
request_dict.pop('config', None)
|
859
|
+
|
860
|
+
http_options = None
|
861
|
+
if isinstance(config, dict):
|
862
|
+
http_options = config.get('http_options', None)
|
863
|
+
elif hasattr(config, 'http_options'):
|
864
|
+
http_options = config.http_options
|
865
|
+
|
1103
866
|
request_dict = _common.convert_to_dict(request_dict)
|
1104
867
|
request_dict = _common.encode_unserializable_types(request_dict)
|
1105
868
|
|
@@ -1117,70 +880,11 @@ class Tunings(_common.BaseModule):
|
|
1117
880
|
)
|
1118
881
|
|
1119
882
|
return_value = types.TuningJobOrOperation._from_response(
|
1120
|
-
response_dict, parameter_model
|
883
|
+
response=response_dict, kwargs=parameter_model
|
1121
884
|
).tuning_job
|
1122
885
|
self._api_client._verify_response(return_value)
|
1123
886
|
return return_value
|
1124
887
|
|
1125
|
-
def distill(
|
1126
|
-
self,
|
1127
|
-
*,
|
1128
|
-
student_model: str,
|
1129
|
-
teacher_model: str,
|
1130
|
-
training_dataset: types.DistillationDatasetOrDict,
|
1131
|
-
config: Optional[types.CreateDistillationJobConfigOrDict] = None,
|
1132
|
-
) -> types.TuningJob:
|
1133
|
-
"""Creates a distillation job.
|
1134
|
-
|
1135
|
-
Args:
|
1136
|
-
student_model: The name of the model to tune.
|
1137
|
-
teacher_model: The name of the model to distill from.
|
1138
|
-
training_dataset: The training dataset to use.
|
1139
|
-
config: The configuration to use for the distillation job.
|
1140
|
-
|
1141
|
-
Returns:
|
1142
|
-
A TuningJob object.
|
1143
|
-
"""
|
1144
|
-
|
1145
|
-
parameter_model = types._CreateDistillationJobParameters(
|
1146
|
-
student_model=student_model,
|
1147
|
-
teacher_model=teacher_model,
|
1148
|
-
training_dataset=training_dataset,
|
1149
|
-
config=config,
|
1150
|
-
)
|
1151
|
-
|
1152
|
-
if not self._api_client.vertexai:
|
1153
|
-
raise ValueError('This method is only supported in the Vertex AI client.')
|
1154
|
-
else:
|
1155
|
-
request_dict = _CreateDistillationJobParameters_to_vertex(
|
1156
|
-
self._api_client, parameter_model
|
1157
|
-
)
|
1158
|
-
path = 'tuningJobs'.format_map(request_dict.get('_url'))
|
1159
|
-
|
1160
|
-
query_params = request_dict.get('_query')
|
1161
|
-
if query_params:
|
1162
|
-
path = f'{path}?{urlencode(query_params)}'
|
1163
|
-
# TODO: remove the hack that pops config.
|
1164
|
-
config = request_dict.pop('config', None)
|
1165
|
-
http_options = config.pop('httpOptions', None) if config else None
|
1166
|
-
request_dict = _common.convert_to_dict(request_dict)
|
1167
|
-
request_dict = _common.encode_unserializable_types(request_dict)
|
1168
|
-
|
1169
|
-
response_dict = self._api_client.request(
|
1170
|
-
'post', path, request_dict, http_options
|
1171
|
-
)
|
1172
|
-
|
1173
|
-
if self._api_client.vertexai:
|
1174
|
-
response_dict = _TuningJob_from_vertex(self._api_client, response_dict)
|
1175
|
-
else:
|
1176
|
-
response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
|
1177
|
-
|
1178
|
-
return_value = types.TuningJob._from_response(
|
1179
|
-
response_dict, parameter_model
|
1180
|
-
)
|
1181
|
-
self._api_client._verify_response(return_value)
|
1182
|
-
return return_value
|
1183
|
-
|
1184
888
|
def list(
|
1185
889
|
self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None
|
1186
890
|
) -> Pager[types.TuningJob]:
|
@@ -1222,7 +926,7 @@ class Tunings(_common.BaseModule):
|
|
1222
926
|
return result
|
1223
927
|
|
1224
928
|
|
1225
|
-
class AsyncTunings(_common.BaseModule):
|
929
|
+
class AsyncTunings(_api_module.BaseModule):
|
1226
930
|
|
1227
931
|
async def _get(
|
1228
932
|
self,
|
@@ -1258,8 +962,14 @@ class AsyncTunings(_common.BaseModule):
|
|
1258
962
|
if query_params:
|
1259
963
|
path = f'{path}?{urlencode(query_params)}'
|
1260
964
|
# TODO: remove the hack that pops config.
|
1261
|
-
|
1262
|
-
|
965
|
+
request_dict.pop('config', None)
|
966
|
+
|
967
|
+
http_options = None
|
968
|
+
if isinstance(config, dict):
|
969
|
+
http_options = config.get('http_options', None)
|
970
|
+
elif hasattr(config, 'http_options'):
|
971
|
+
http_options = config.http_options
|
972
|
+
|
1263
973
|
request_dict = _common.convert_to_dict(request_dict)
|
1264
974
|
request_dict = _common.encode_unserializable_types(request_dict)
|
1265
975
|
|
@@ -1273,7 +983,7 @@ class AsyncTunings(_common.BaseModule):
|
|
1273
983
|
response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
|
1274
984
|
|
1275
985
|
return_value = types.TuningJob._from_response(
|
1276
|
-
response_dict, parameter_model
|
986
|
+
response=response_dict, kwargs=parameter_model
|
1277
987
|
)
|
1278
988
|
self._api_client._verify_response(return_value)
|
1279
989
|
return return_value
|
@@ -1308,8 +1018,14 @@ class AsyncTunings(_common.BaseModule):
|
|
1308
1018
|
if query_params:
|
1309
1019
|
path = f'{path}?{urlencode(query_params)}'
|
1310
1020
|
# TODO: remove the hack that pops config.
|
1311
|
-
|
1312
|
-
|
1021
|
+
request_dict.pop('config', None)
|
1022
|
+
|
1023
|
+
http_options = None
|
1024
|
+
if isinstance(config, dict):
|
1025
|
+
http_options = config.get('http_options', None)
|
1026
|
+
elif hasattr(config, 'http_options'):
|
1027
|
+
http_options = config.http_options
|
1028
|
+
|
1313
1029
|
request_dict = _common.convert_to_dict(request_dict)
|
1314
1030
|
request_dict = _common.encode_unserializable_types(request_dict)
|
1315
1031
|
|
@@ -1327,7 +1043,7 @@ class AsyncTunings(_common.BaseModule):
|
|
1327
1043
|
)
|
1328
1044
|
|
1329
1045
|
return_value = types.ListTuningJobsResponse._from_response(
|
1330
|
-
response_dict, parameter_model
|
1046
|
+
response=response_dict, kwargs=parameter_model
|
1331
1047
|
)
|
1332
1048
|
self._api_client._verify_response(return_value)
|
1333
1049
|
return return_value
|
@@ -1370,8 +1086,14 @@ class AsyncTunings(_common.BaseModule):
|
|
1370
1086
|
if query_params:
|
1371
1087
|
path = f'{path}?{urlencode(query_params)}'
|
1372
1088
|
# TODO: remove the hack that pops config.
|
1373
|
-
|
1374
|
-
|
1089
|
+
request_dict.pop('config', None)
|
1090
|
+
|
1091
|
+
http_options = None
|
1092
|
+
if isinstance(config, dict):
|
1093
|
+
http_options = config.get('http_options', None)
|
1094
|
+
elif hasattr(config, 'http_options'):
|
1095
|
+
http_options = config.http_options
|
1096
|
+
|
1375
1097
|
request_dict = _common.convert_to_dict(request_dict)
|
1376
1098
|
request_dict = _common.encode_unserializable_types(request_dict)
|
1377
1099
|
|
@@ -1389,70 +1111,11 @@ class AsyncTunings(_common.BaseModule):
|
|
1389
1111
|
)
|
1390
1112
|
|
1391
1113
|
return_value = types.TuningJobOrOperation._from_response(
|
1392
|
-
response_dict, parameter_model
|
1114
|
+
response=response_dict, kwargs=parameter_model
|
1393
1115
|
).tuning_job
|
1394
1116
|
self._api_client._verify_response(return_value)
|
1395
1117
|
return return_value
|
1396
1118
|
|
1397
|
-
async def distill(
|
1398
|
-
self,
|
1399
|
-
*,
|
1400
|
-
student_model: str,
|
1401
|
-
teacher_model: str,
|
1402
|
-
training_dataset: types.DistillationDatasetOrDict,
|
1403
|
-
config: Optional[types.CreateDistillationJobConfigOrDict] = None,
|
1404
|
-
) -> types.TuningJob:
|
1405
|
-
"""Creates a distillation job.
|
1406
|
-
|
1407
|
-
Args:
|
1408
|
-
student_model: The name of the model to tune.
|
1409
|
-
teacher_model: The name of the model to distill from.
|
1410
|
-
training_dataset: The training dataset to use.
|
1411
|
-
config: The configuration to use for the distillation job.
|
1412
|
-
|
1413
|
-
Returns:
|
1414
|
-
A TuningJob object.
|
1415
|
-
"""
|
1416
|
-
|
1417
|
-
parameter_model = types._CreateDistillationJobParameters(
|
1418
|
-
student_model=student_model,
|
1419
|
-
teacher_model=teacher_model,
|
1420
|
-
training_dataset=training_dataset,
|
1421
|
-
config=config,
|
1422
|
-
)
|
1423
|
-
|
1424
|
-
if not self._api_client.vertexai:
|
1425
|
-
raise ValueError('This method is only supported in the Vertex AI client.')
|
1426
|
-
else:
|
1427
|
-
request_dict = _CreateDistillationJobParameters_to_vertex(
|
1428
|
-
self._api_client, parameter_model
|
1429
|
-
)
|
1430
|
-
path = 'tuningJobs'.format_map(request_dict.get('_url'))
|
1431
|
-
|
1432
|
-
query_params = request_dict.get('_query')
|
1433
|
-
if query_params:
|
1434
|
-
path = f'{path}?{urlencode(query_params)}'
|
1435
|
-
# TODO: remove the hack that pops config.
|
1436
|
-
config = request_dict.pop('config', None)
|
1437
|
-
http_options = config.pop('httpOptions', None) if config else None
|
1438
|
-
request_dict = _common.convert_to_dict(request_dict)
|
1439
|
-
request_dict = _common.encode_unserializable_types(request_dict)
|
1440
|
-
|
1441
|
-
response_dict = await self._api_client.async_request(
|
1442
|
-
'post', path, request_dict, http_options
|
1443
|
-
)
|
1444
|
-
|
1445
|
-
if self._api_client.vertexai:
|
1446
|
-
response_dict = _TuningJob_from_vertex(self._api_client, response_dict)
|
1447
|
-
else:
|
1448
|
-
response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
|
1449
|
-
|
1450
|
-
return_value = types.TuningJob._from_response(
|
1451
|
-
response_dict, parameter_model
|
1452
|
-
)
|
1453
|
-
self._api_client._verify_response(return_value)
|
1454
|
-
return return_value
|
1455
|
-
|
1456
1119
|
async def list(
|
1457
1120
|
self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None
|
1458
1121
|
) -> AsyncPager[types.TuningJob]:
|