google-genai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/tunings.py CHANGED
@@ -17,6 +17,7 @@
17
17
 
18
18
  from typing import Optional, Union
19
19
  from urllib.parse import urlencode
20
+ from . import _api_module
20
21
  from . import _common
21
22
  from . import _transformers as t
22
23
  from . import types
@@ -26,30 +27,6 @@ from ._common import set_value_by_path as setv
26
27
  from .pagers import AsyncPager, Pager
27
28
 
28
29
 
29
- def _GetTuningJobConfig_to_mldev(
30
- api_client: ApiClient,
31
- from_object: Union[dict, object],
32
- parent_object: dict = None,
33
- ) -> dict:
34
- to_object = {}
35
- if getv(from_object, ['http_options']) is not None:
36
- setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
37
-
38
- return to_object
39
-
40
-
41
- def _GetTuningJobConfig_to_vertex(
42
- api_client: ApiClient,
43
- from_object: Union[dict, object],
44
- parent_object: dict = None,
45
- ) -> dict:
46
- to_object = {}
47
- if getv(from_object, ['http_options']) is not None:
48
- setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
49
-
50
- return to_object
51
-
52
-
53
30
  def _GetTuningJobParameters_to_mldev(
54
31
  api_client: ApiClient,
55
32
  from_object: Union[dict, object],
@@ -60,13 +37,7 @@ def _GetTuningJobParameters_to_mldev(
60
37
  setv(to_object, ['_url', 'name'], getv(from_object, ['name']))
61
38
 
62
39
  if getv(from_object, ['config']) is not None:
63
- setv(
64
- to_object,
65
- ['config'],
66
- _GetTuningJobConfig_to_mldev(
67
- api_client, getv(from_object, ['config']), to_object
68
- ),
69
- )
40
+ setv(to_object, ['config'], getv(from_object, ['config']))
70
41
 
71
42
  return to_object
72
43
 
@@ -81,13 +52,7 @@ def _GetTuningJobParameters_to_vertex(
81
52
  setv(to_object, ['_url', 'name'], getv(from_object, ['name']))
82
53
 
83
54
  if getv(from_object, ['config']) is not None:
84
- setv(
85
- to_object,
86
- ['config'],
87
- _GetTuningJobConfig_to_vertex(
88
- api_client, getv(from_object, ['config']), to_object
89
- ),
90
- )
55
+ setv(to_object, ['config'], getv(from_object, ['config']))
91
56
 
92
57
  return to_object
93
58
 
@@ -98,6 +63,7 @@ def _ListTuningJobsConfig_to_mldev(
98
63
  parent_object: dict = None,
99
64
  ) -> dict:
100
65
  to_object = {}
66
+
101
67
  if getv(from_object, ['page_size']) is not None:
102
68
  setv(
103
69
  parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size'])
@@ -122,6 +88,7 @@ def _ListTuningJobsConfig_to_vertex(
122
88
  parent_object: dict = None,
123
89
  ) -> dict:
124
90
  to_object = {}
91
+
125
92
  if getv(from_object, ['page_size']) is not None:
126
93
  setv(
127
94
  parent_object, ['_query', 'pageSize'], getv(from_object, ['page_size'])
@@ -213,7 +180,7 @@ def _TuningDataset_to_mldev(
213
180
  ) -> dict:
214
181
  to_object = {}
215
182
  if getv(from_object, ['gcs_uri']) is not None:
216
- raise ValueError('gcs_uri parameter is not supported in Google AI.')
183
+ raise ValueError('gcs_uri parameter is not supported in Gemini API.')
217
184
 
218
185
  if getv(from_object, ['examples']) is not None:
219
186
  setv(
@@ -254,7 +221,7 @@ def _TuningValidationDataset_to_mldev(
254
221
  ) -> dict:
255
222
  to_object = {}
256
223
  if getv(from_object, ['gcs_uri']) is not None:
257
- raise ValueError('gcs_uri parameter is not supported in Google AI.')
224
+ raise ValueError('gcs_uri parameter is not supported in Gemini API.')
258
225
 
259
226
  return to_object
260
227
 
@@ -277,12 +244,10 @@ def _CreateTuningJobConfig_to_mldev(
277
244
  parent_object: dict = None,
278
245
  ) -> dict:
279
246
  to_object = {}
280
- if getv(from_object, ['http_options']) is not None:
281
- setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
282
247
 
283
248
  if getv(from_object, ['validation_dataset']) is not None:
284
249
  raise ValueError(
285
- 'validation_dataset parameter is not supported in Google AI.'
250
+ 'validation_dataset parameter is not supported in Gemini API.'
286
251
  )
287
252
 
288
253
  if getv(from_object, ['tuned_model_display_name']) is not None:
@@ -293,7 +258,7 @@ def _CreateTuningJobConfig_to_mldev(
293
258
  )
294
259
 
295
260
  if getv(from_object, ['description']) is not None:
296
- raise ValueError('description parameter is not supported in Google AI.')
261
+ raise ValueError('description parameter is not supported in Gemini API.')
297
262
 
298
263
  if getv(from_object, ['epoch_count']) is not None:
299
264
  setv(
@@ -310,7 +275,7 @@ def _CreateTuningJobConfig_to_mldev(
310
275
  )
311
276
 
312
277
  if getv(from_object, ['adapter_size']) is not None:
313
- raise ValueError('adapter_size parameter is not supported in Google AI.')
278
+ raise ValueError('adapter_size parameter is not supported in Gemini API.')
314
279
 
315
280
  if getv(from_object, ['batch_size']) is not None:
316
281
  setv(
@@ -335,8 +300,6 @@ def _CreateTuningJobConfig_to_vertex(
335
300
  parent_object: dict = None,
336
301
  ) -> dict:
337
302
  to_object = {}
338
- if getv(from_object, ['http_options']) is not None:
339
- setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
340
303
 
341
304
  if getv(from_object, ['validation_dataset']) is not None:
342
305
  setv(
@@ -447,234 +410,6 @@ def _CreateTuningJobParameters_to_vertex(
447
410
  return to_object
448
411
 
449
412
 
450
- def _DistillationDataset_to_mldev(
451
- api_client: ApiClient,
452
- from_object: Union[dict, object],
453
- parent_object: dict = None,
454
- ) -> dict:
455
- to_object = {}
456
- if getv(from_object, ['gcs_uri']) is not None:
457
- raise ValueError('gcs_uri parameter is not supported in Google AI.')
458
-
459
- return to_object
460
-
461
-
462
- def _DistillationDataset_to_vertex(
463
- api_client: ApiClient,
464
- from_object: Union[dict, object],
465
- parent_object: dict = None,
466
- ) -> dict:
467
- to_object = {}
468
- if getv(from_object, ['gcs_uri']) is not None:
469
- setv(
470
- parent_object,
471
- ['distillationSpec', 'trainingDatasetUri'],
472
- getv(from_object, ['gcs_uri']),
473
- )
474
-
475
- return to_object
476
-
477
-
478
- def _DistillationValidationDataset_to_mldev(
479
- api_client: ApiClient,
480
- from_object: Union[dict, object],
481
- parent_object: dict = None,
482
- ) -> dict:
483
- to_object = {}
484
- if getv(from_object, ['gcs_uri']) is not None:
485
- raise ValueError('gcs_uri parameter is not supported in Google AI.')
486
-
487
- return to_object
488
-
489
-
490
- def _DistillationValidationDataset_to_vertex(
491
- api_client: ApiClient,
492
- from_object: Union[dict, object],
493
- parent_object: dict = None,
494
- ) -> dict:
495
- to_object = {}
496
- if getv(from_object, ['gcs_uri']) is not None:
497
- setv(to_object, ['validationDatasetUri'], getv(from_object, ['gcs_uri']))
498
-
499
- return to_object
500
-
501
-
502
- def _CreateDistillationJobConfig_to_mldev(
503
- api_client: ApiClient,
504
- from_object: Union[dict, object],
505
- parent_object: dict = None,
506
- ) -> dict:
507
- to_object = {}
508
- if getv(from_object, ['http_options']) is not None:
509
- setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
510
-
511
- if getv(from_object, ['validation_dataset']) is not None:
512
- raise ValueError(
513
- 'validation_dataset parameter is not supported in Google AI.'
514
- )
515
-
516
- if getv(from_object, ['tuned_model_display_name']) is not None:
517
- setv(
518
- parent_object,
519
- ['displayName'],
520
- getv(from_object, ['tuned_model_display_name']),
521
- )
522
-
523
- if getv(from_object, ['epoch_count']) is not None:
524
- setv(
525
- parent_object,
526
- ['tuningTask', 'hyperparameters', 'epochCount'],
527
- getv(from_object, ['epoch_count']),
528
- )
529
-
530
- if getv(from_object, ['learning_rate_multiplier']) is not None:
531
- setv(
532
- parent_object,
533
- ['tuningTask', 'hyperparameters', 'learningRateMultiplier'],
534
- getv(from_object, ['learning_rate_multiplier']),
535
- )
536
-
537
- if getv(from_object, ['adapter_size']) is not None:
538
- raise ValueError('adapter_size parameter is not supported in Google AI.')
539
-
540
- if getv(from_object, ['pipeline_root_directory']) is not None:
541
- raise ValueError(
542
- 'pipeline_root_directory parameter is not supported in Google AI.'
543
- )
544
-
545
- return to_object
546
-
547
-
548
- def _CreateDistillationJobConfig_to_vertex(
549
- api_client: ApiClient,
550
- from_object: Union[dict, object],
551
- parent_object: dict = None,
552
- ) -> dict:
553
- to_object = {}
554
- if getv(from_object, ['http_options']) is not None:
555
- setv(to_object, ['httpOptions'], getv(from_object, ['http_options']))
556
-
557
- if getv(from_object, ['validation_dataset']) is not None:
558
- setv(
559
- parent_object,
560
- ['distillationSpec'],
561
- _DistillationValidationDataset_to_vertex(
562
- api_client, getv(from_object, ['validation_dataset']), to_object
563
- ),
564
- )
565
-
566
- if getv(from_object, ['tuned_model_display_name']) is not None:
567
- setv(
568
- parent_object,
569
- ['tunedModelDisplayName'],
570
- getv(from_object, ['tuned_model_display_name']),
571
- )
572
-
573
- if getv(from_object, ['epoch_count']) is not None:
574
- setv(
575
- parent_object,
576
- ['distillationSpec', 'hyperParameters', 'epochCount'],
577
- getv(from_object, ['epoch_count']),
578
- )
579
-
580
- if getv(from_object, ['learning_rate_multiplier']) is not None:
581
- setv(
582
- parent_object,
583
- ['distillationSpec', 'hyperParameters', 'learningRateMultiplier'],
584
- getv(from_object, ['learning_rate_multiplier']),
585
- )
586
-
587
- if getv(from_object, ['adapter_size']) is not None:
588
- setv(
589
- parent_object,
590
- ['distillationSpec', 'hyperParameters', 'adapterSize'],
591
- getv(from_object, ['adapter_size']),
592
- )
593
-
594
- if getv(from_object, ['pipeline_root_directory']) is not None:
595
- setv(
596
- parent_object,
597
- ['distillationSpec', 'pipelineRootDirectory'],
598
- getv(from_object, ['pipeline_root_directory']),
599
- )
600
-
601
- return to_object
602
-
603
-
604
- def _CreateDistillationJobParameters_to_mldev(
605
- api_client: ApiClient,
606
- from_object: Union[dict, object],
607
- parent_object: dict = None,
608
- ) -> dict:
609
- to_object = {}
610
- if getv(from_object, ['student_model']) is not None:
611
- raise ValueError('student_model parameter is not supported in Google AI.')
612
-
613
- if getv(from_object, ['teacher_model']) is not None:
614
- raise ValueError('teacher_model parameter is not supported in Google AI.')
615
-
616
- if getv(from_object, ['training_dataset']) is not None:
617
- setv(
618
- to_object,
619
- ['tuningTask', 'trainingData'],
620
- _DistillationDataset_to_mldev(
621
- api_client, getv(from_object, ['training_dataset']), to_object
622
- ),
623
- )
624
-
625
- if getv(from_object, ['config']) is not None:
626
- setv(
627
- to_object,
628
- ['config'],
629
- _CreateDistillationJobConfig_to_mldev(
630
- api_client, getv(from_object, ['config']), to_object
631
- ),
632
- )
633
-
634
- return to_object
635
-
636
-
637
- def _CreateDistillationJobParameters_to_vertex(
638
- api_client: ApiClient,
639
- from_object: Union[dict, object],
640
- parent_object: dict = None,
641
- ) -> dict:
642
- to_object = {}
643
- if getv(from_object, ['student_model']) is not None:
644
- setv(
645
- to_object,
646
- ['distillationSpec', 'studentModel'],
647
- getv(from_object, ['student_model']),
648
- )
649
-
650
- if getv(from_object, ['teacher_model']) is not None:
651
- setv(
652
- to_object,
653
- ['distillationSpec', 'baseTeacherModel'],
654
- getv(from_object, ['teacher_model']),
655
- )
656
-
657
- if getv(from_object, ['training_dataset']) is not None:
658
- setv(
659
- to_object,
660
- ['distillationSpec', 'trainingDatasetUri'],
661
- _DistillationDataset_to_vertex(
662
- api_client, getv(from_object, ['training_dataset']), to_object
663
- ),
664
- )
665
-
666
- if getv(from_object, ['config']) is not None:
667
- setv(
668
- to_object,
669
- ['config'],
670
- _CreateDistillationJobConfig_to_vertex(
671
- api_client, getv(from_object, ['config']), to_object
672
- ),
673
- )
674
-
675
- return to_object
676
-
677
-
678
413
  def _TunedModel_from_mldev(
679
414
  api_client: ApiClient,
680
415
  from_object: Union[dict, object],
@@ -756,12 +491,22 @@ def _TuningJob_from_mldev(
756
491
  ),
757
492
  )
758
493
 
494
+ if getv(from_object, ['distillationSpec']) is not None:
495
+ setv(
496
+ to_object,
497
+ ['distillation_spec'],
498
+ getv(from_object, ['distillationSpec']),
499
+ )
500
+
759
501
  if getv(from_object, ['experiment']) is not None:
760
502
  setv(to_object, ['experiment'], getv(from_object, ['experiment']))
761
503
 
762
504
  if getv(from_object, ['labels']) is not None:
763
505
  setv(to_object, ['labels'], getv(from_object, ['labels']))
764
506
 
507
+ if getv(from_object, ['pipelineJob']) is not None:
508
+ setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob']))
509
+
765
510
  if getv(from_object, ['tunedModelDisplayName']) is not None:
766
511
  setv(
767
512
  to_object,
@@ -833,13 +578,6 @@ def _TuningJob_from_vertex(
833
578
  if getv(from_object, ['encryptionSpec']) is not None:
834
579
  setv(to_object, ['encryption_spec'], getv(from_object, ['encryptionSpec']))
835
580
 
836
- if getv(from_object, ['distillationSpec']) is not None:
837
- setv(
838
- to_object,
839
- ['distillation_spec'],
840
- getv(from_object, ['distillationSpec']),
841
- )
842
-
843
581
  if getv(from_object, ['partnerModelTuningSpec']) is not None:
844
582
  setv(
845
583
  to_object,
@@ -847,8 +585,12 @@ def _TuningJob_from_vertex(
847
585
  getv(from_object, ['partnerModelTuningSpec']),
848
586
  )
849
587
 
850
- if getv(from_object, ['pipelineJob']) is not None:
851
- setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob']))
588
+ if getv(from_object, ['distillationSpec']) is not None:
589
+ setv(
590
+ to_object,
591
+ ['distillation_spec'],
592
+ getv(from_object, ['distillationSpec']),
593
+ )
852
594
 
853
595
  if getv(from_object, ['experiment']) is not None:
854
596
  setv(to_object, ['experiment'], getv(from_object, ['experiment']))
@@ -856,6 +598,9 @@ def _TuningJob_from_vertex(
856
598
  if getv(from_object, ['labels']) is not None:
857
599
  setv(to_object, ['labels'], getv(from_object, ['labels']))
858
600
 
601
+ if getv(from_object, ['pipelineJob']) is not None:
602
+ setv(to_object, ['pipeline_job'], getv(from_object, ['pipelineJob']))
603
+
859
604
  if getv(from_object, ['tunedModelDisplayName']) is not None:
860
605
  setv(
861
606
  to_object,
@@ -950,7 +695,7 @@ def _TuningJobOrOperation_from_vertex(
950
695
  return to_object
951
696
 
952
697
 
953
- class Tunings(_common.BaseModule):
698
+ class Tunings(_api_module.BaseModule):
954
699
 
955
700
  def _get(
956
701
  self,
@@ -986,8 +731,14 @@ class Tunings(_common.BaseModule):
986
731
  if query_params:
987
732
  path = f'{path}?{urlencode(query_params)}'
988
733
  # TODO: remove the hack that pops config.
989
- config = request_dict.pop('config', None)
990
- http_options = config.pop('httpOptions', None) if config else None
734
+ request_dict.pop('config', None)
735
+
736
+ http_options = None
737
+ if isinstance(config, dict):
738
+ http_options = config.get('http_options', None)
739
+ elif hasattr(config, 'http_options'):
740
+ http_options = config.http_options
741
+
991
742
  request_dict = _common.convert_to_dict(request_dict)
992
743
  request_dict = _common.encode_unserializable_types(request_dict)
993
744
 
@@ -1001,7 +752,7 @@ class Tunings(_common.BaseModule):
1001
752
  response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
1002
753
 
1003
754
  return_value = types.TuningJob._from_response(
1004
- response_dict, parameter_model
755
+ response=response_dict, kwargs=parameter_model
1005
756
  )
1006
757
  self._api_client._verify_response(return_value)
1007
758
  return return_value
@@ -1036,8 +787,14 @@ class Tunings(_common.BaseModule):
1036
787
  if query_params:
1037
788
  path = f'{path}?{urlencode(query_params)}'
1038
789
  # TODO: remove the hack that pops config.
1039
- config = request_dict.pop('config', None)
1040
- http_options = config.pop('httpOptions', None) if config else None
790
+ request_dict.pop('config', None)
791
+
792
+ http_options = None
793
+ if isinstance(config, dict):
794
+ http_options = config.get('http_options', None)
795
+ elif hasattr(config, 'http_options'):
796
+ http_options = config.http_options
797
+
1041
798
  request_dict = _common.convert_to_dict(request_dict)
1042
799
  request_dict = _common.encode_unserializable_types(request_dict)
1043
800
 
@@ -1055,7 +812,7 @@ class Tunings(_common.BaseModule):
1055
812
  )
1056
813
 
1057
814
  return_value = types.ListTuningJobsResponse._from_response(
1058
- response_dict, parameter_model
815
+ response=response_dict, kwargs=parameter_model
1059
816
  )
1060
817
  self._api_client._verify_response(return_value)
1061
818
  return return_value
@@ -1098,8 +855,14 @@ class Tunings(_common.BaseModule):
1098
855
  if query_params:
1099
856
  path = f'{path}?{urlencode(query_params)}'
1100
857
  # TODO: remove the hack that pops config.
1101
- config = request_dict.pop('config', None)
1102
- http_options = config.pop('httpOptions', None) if config else None
858
+ request_dict.pop('config', None)
859
+
860
+ http_options = None
861
+ if isinstance(config, dict):
862
+ http_options = config.get('http_options', None)
863
+ elif hasattr(config, 'http_options'):
864
+ http_options = config.http_options
865
+
1103
866
  request_dict = _common.convert_to_dict(request_dict)
1104
867
  request_dict = _common.encode_unserializable_types(request_dict)
1105
868
 
@@ -1117,70 +880,11 @@ class Tunings(_common.BaseModule):
1117
880
  )
1118
881
 
1119
882
  return_value = types.TuningJobOrOperation._from_response(
1120
- response_dict, parameter_model
883
+ response=response_dict, kwargs=parameter_model
1121
884
  ).tuning_job
1122
885
  self._api_client._verify_response(return_value)
1123
886
  return return_value
1124
887
 
1125
- def distill(
1126
- self,
1127
- *,
1128
- student_model: str,
1129
- teacher_model: str,
1130
- training_dataset: types.DistillationDatasetOrDict,
1131
- config: Optional[types.CreateDistillationJobConfigOrDict] = None,
1132
- ) -> types.TuningJob:
1133
- """Creates a distillation job.
1134
-
1135
- Args:
1136
- student_model: The name of the model to tune.
1137
- teacher_model: The name of the model to distill from.
1138
- training_dataset: The training dataset to use.
1139
- config: The configuration to use for the distillation job.
1140
-
1141
- Returns:
1142
- A TuningJob object.
1143
- """
1144
-
1145
- parameter_model = types._CreateDistillationJobParameters(
1146
- student_model=student_model,
1147
- teacher_model=teacher_model,
1148
- training_dataset=training_dataset,
1149
- config=config,
1150
- )
1151
-
1152
- if not self._api_client.vertexai:
1153
- raise ValueError('This method is only supported in the Vertex AI client.')
1154
- else:
1155
- request_dict = _CreateDistillationJobParameters_to_vertex(
1156
- self._api_client, parameter_model
1157
- )
1158
- path = 'tuningJobs'.format_map(request_dict.get('_url'))
1159
-
1160
- query_params = request_dict.get('_query')
1161
- if query_params:
1162
- path = f'{path}?{urlencode(query_params)}'
1163
- # TODO: remove the hack that pops config.
1164
- config = request_dict.pop('config', None)
1165
- http_options = config.pop('httpOptions', None) if config else None
1166
- request_dict = _common.convert_to_dict(request_dict)
1167
- request_dict = _common.encode_unserializable_types(request_dict)
1168
-
1169
- response_dict = self._api_client.request(
1170
- 'post', path, request_dict, http_options
1171
- )
1172
-
1173
- if self._api_client.vertexai:
1174
- response_dict = _TuningJob_from_vertex(self._api_client, response_dict)
1175
- else:
1176
- response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
1177
-
1178
- return_value = types.TuningJob._from_response(
1179
- response_dict, parameter_model
1180
- )
1181
- self._api_client._verify_response(return_value)
1182
- return return_value
1183
-
1184
888
  def list(
1185
889
  self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None
1186
890
  ) -> Pager[types.TuningJob]:
@@ -1222,7 +926,7 @@ class Tunings(_common.BaseModule):
1222
926
  return result
1223
927
 
1224
928
 
1225
- class AsyncTunings(_common.BaseModule):
929
+ class AsyncTunings(_api_module.BaseModule):
1226
930
 
1227
931
  async def _get(
1228
932
  self,
@@ -1258,8 +962,14 @@ class AsyncTunings(_common.BaseModule):
1258
962
  if query_params:
1259
963
  path = f'{path}?{urlencode(query_params)}'
1260
964
  # TODO: remove the hack that pops config.
1261
- config = request_dict.pop('config', None)
1262
- http_options = config.pop('httpOptions', None) if config else None
965
+ request_dict.pop('config', None)
966
+
967
+ http_options = None
968
+ if isinstance(config, dict):
969
+ http_options = config.get('http_options', None)
970
+ elif hasattr(config, 'http_options'):
971
+ http_options = config.http_options
972
+
1263
973
  request_dict = _common.convert_to_dict(request_dict)
1264
974
  request_dict = _common.encode_unserializable_types(request_dict)
1265
975
 
@@ -1273,7 +983,7 @@ class AsyncTunings(_common.BaseModule):
1273
983
  response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
1274
984
 
1275
985
  return_value = types.TuningJob._from_response(
1276
- response_dict, parameter_model
986
+ response=response_dict, kwargs=parameter_model
1277
987
  )
1278
988
  self._api_client._verify_response(return_value)
1279
989
  return return_value
@@ -1308,8 +1018,14 @@ class AsyncTunings(_common.BaseModule):
1308
1018
  if query_params:
1309
1019
  path = f'{path}?{urlencode(query_params)}'
1310
1020
  # TODO: remove the hack that pops config.
1311
- config = request_dict.pop('config', None)
1312
- http_options = config.pop('httpOptions', None) if config else None
1021
+ request_dict.pop('config', None)
1022
+
1023
+ http_options = None
1024
+ if isinstance(config, dict):
1025
+ http_options = config.get('http_options', None)
1026
+ elif hasattr(config, 'http_options'):
1027
+ http_options = config.http_options
1028
+
1313
1029
  request_dict = _common.convert_to_dict(request_dict)
1314
1030
  request_dict = _common.encode_unserializable_types(request_dict)
1315
1031
 
@@ -1327,7 +1043,7 @@ class AsyncTunings(_common.BaseModule):
1327
1043
  )
1328
1044
 
1329
1045
  return_value = types.ListTuningJobsResponse._from_response(
1330
- response_dict, parameter_model
1046
+ response=response_dict, kwargs=parameter_model
1331
1047
  )
1332
1048
  self._api_client._verify_response(return_value)
1333
1049
  return return_value
@@ -1370,8 +1086,14 @@ class AsyncTunings(_common.BaseModule):
1370
1086
  if query_params:
1371
1087
  path = f'{path}?{urlencode(query_params)}'
1372
1088
  # TODO: remove the hack that pops config.
1373
- config = request_dict.pop('config', None)
1374
- http_options = config.pop('httpOptions', None) if config else None
1089
+ request_dict.pop('config', None)
1090
+
1091
+ http_options = None
1092
+ if isinstance(config, dict):
1093
+ http_options = config.get('http_options', None)
1094
+ elif hasattr(config, 'http_options'):
1095
+ http_options = config.http_options
1096
+
1375
1097
  request_dict = _common.convert_to_dict(request_dict)
1376
1098
  request_dict = _common.encode_unserializable_types(request_dict)
1377
1099
 
@@ -1389,70 +1111,11 @@ class AsyncTunings(_common.BaseModule):
1389
1111
  )
1390
1112
 
1391
1113
  return_value = types.TuningJobOrOperation._from_response(
1392
- response_dict, parameter_model
1114
+ response=response_dict, kwargs=parameter_model
1393
1115
  ).tuning_job
1394
1116
  self._api_client._verify_response(return_value)
1395
1117
  return return_value
1396
1118
 
1397
- async def distill(
1398
- self,
1399
- *,
1400
- student_model: str,
1401
- teacher_model: str,
1402
- training_dataset: types.DistillationDatasetOrDict,
1403
- config: Optional[types.CreateDistillationJobConfigOrDict] = None,
1404
- ) -> types.TuningJob:
1405
- """Creates a distillation job.
1406
-
1407
- Args:
1408
- student_model: The name of the model to tune.
1409
- teacher_model: The name of the model to distill from.
1410
- training_dataset: The training dataset to use.
1411
- config: The configuration to use for the distillation job.
1412
-
1413
- Returns:
1414
- A TuningJob object.
1415
- """
1416
-
1417
- parameter_model = types._CreateDistillationJobParameters(
1418
- student_model=student_model,
1419
- teacher_model=teacher_model,
1420
- training_dataset=training_dataset,
1421
- config=config,
1422
- )
1423
-
1424
- if not self._api_client.vertexai:
1425
- raise ValueError('This method is only supported in the Vertex AI client.')
1426
- else:
1427
- request_dict = _CreateDistillationJobParameters_to_vertex(
1428
- self._api_client, parameter_model
1429
- )
1430
- path = 'tuningJobs'.format_map(request_dict.get('_url'))
1431
-
1432
- query_params = request_dict.get('_query')
1433
- if query_params:
1434
- path = f'{path}?{urlencode(query_params)}'
1435
- # TODO: remove the hack that pops config.
1436
- config = request_dict.pop('config', None)
1437
- http_options = config.pop('httpOptions', None) if config else None
1438
- request_dict = _common.convert_to_dict(request_dict)
1439
- request_dict = _common.encode_unserializable_types(request_dict)
1440
-
1441
- response_dict = await self._api_client.async_request(
1442
- 'post', path, request_dict, http_options
1443
- )
1444
-
1445
- if self._api_client.vertexai:
1446
- response_dict = _TuningJob_from_vertex(self._api_client, response_dict)
1447
- else:
1448
- response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
1449
-
1450
- return_value = types.TuningJob._from_response(
1451
- response_dict, parameter_model
1452
- )
1453
- self._api_client._verify_response(return_value)
1454
- return return_value
1455
-
1456
1119
  async def list(
1457
1120
  self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None
1458
1121
  ) -> AsyncPager[types.TuningJob]: