google-genai 1.0.0rc0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
google/genai/tunings.py CHANGED
@@ -655,42 +655,50 @@ def _ListTuningJobsResponse_from_vertex(
655
655
  return to_object
656
656
 
657
657
 
658
- def _TuningJobOrOperation_from_mldev(
658
+ def _Operation_from_mldev(
659
659
  api_client: ApiClient,
660
660
  from_object: Union[dict, object],
661
661
  parent_object: dict = None,
662
662
  ) -> dict:
663
663
  to_object = {}
664
- if getv(from_object, ['_self']) is not None:
665
- setv(
666
- to_object,
667
- ['tuning_job'],
668
- _TuningJob_from_mldev(
669
- api_client,
670
- t.t_resolve_operation(api_client, getv(from_object, ['_self'])),
671
- to_object,
672
- ),
673
- )
664
+ if getv(from_object, ['name']) is not None:
665
+ setv(to_object, ['name'], getv(from_object, ['name']))
666
+
667
+ if getv(from_object, ['metadata']) is not None:
668
+ setv(to_object, ['metadata'], getv(from_object, ['metadata']))
669
+
670
+ if getv(from_object, ['done']) is not None:
671
+ setv(to_object, ['done'], getv(from_object, ['done']))
672
+
673
+ if getv(from_object, ['error']) is not None:
674
+ setv(to_object, ['error'], getv(from_object, ['error']))
675
+
676
+ if getv(from_object, ['response']) is not None:
677
+ setv(to_object, ['response'], getv(from_object, ['response']))
674
678
 
675
679
  return to_object
676
680
 
677
681
 
678
- def _TuningJobOrOperation_from_vertex(
682
+ def _Operation_from_vertex(
679
683
  api_client: ApiClient,
680
684
  from_object: Union[dict, object],
681
685
  parent_object: dict = None,
682
686
  ) -> dict:
683
687
  to_object = {}
684
- if getv(from_object, ['_self']) is not None:
685
- setv(
686
- to_object,
687
- ['tuning_job'],
688
- _TuningJob_from_vertex(
689
- api_client,
690
- t.t_resolve_operation(api_client, getv(from_object, ['_self'])),
691
- to_object,
692
- ),
693
- )
688
+ if getv(from_object, ['name']) is not None:
689
+ setv(to_object, ['name'], getv(from_object, ['name']))
690
+
691
+ if getv(from_object, ['metadata']) is not None:
692
+ setv(to_object, ['metadata'], getv(from_object, ['metadata']))
693
+
694
+ if getv(from_object, ['done']) is not None:
695
+ setv(to_object, ['done'], getv(from_object, ['done']))
696
+
697
+ if getv(from_object, ['error']) is not None:
698
+ setv(to_object, ['error'], getv(from_object, ['error']))
699
+
700
+ if getv(from_object, ['response']) is not None:
701
+ setv(to_object, ['response'], getv(from_object, ['response']))
694
702
 
695
703
  return to_object
696
704
 
@@ -823,7 +831,7 @@ class Tunings(_api_module.BaseModule):
823
831
  base_model: str,
824
832
  training_dataset: types.TuningDatasetOrDict,
825
833
  config: Optional[types.CreateTuningJobConfigOrDict] = None,
826
- ) -> types.TuningJobOrOperation:
834
+ ) -> types.TuningJob:
827
835
  """Creates a supervised fine-tuning job.
828
836
 
829
837
  Args:
@@ -841,16 +849,76 @@ class Tunings(_api_module.BaseModule):
841
849
  config=config,
842
850
  )
843
851
 
844
- if self._api_client.vertexai:
852
+ if not self._api_client.vertexai:
853
+ raise ValueError('This method is only supported in the Vertex AI client.')
854
+ else:
845
855
  request_dict = _CreateTuningJobParameters_to_vertex(
846
856
  self._api_client, parameter_model
847
857
  )
848
858
  path = 'tuningJobs'.format_map(request_dict.get('_url'))
859
+
860
+ query_params = request_dict.get('_query')
861
+ if query_params:
862
+ path = f'{path}?{urlencode(query_params)}'
863
+ # TODO: remove the hack that pops config.
864
+ request_dict.pop('config', None)
865
+
866
+ http_options = None
867
+ if isinstance(config, dict):
868
+ http_options = config.get('http_options', None)
869
+ elif hasattr(config, 'http_options'):
870
+ http_options = config.http_options
871
+
872
+ request_dict = _common.convert_to_dict(request_dict)
873
+ request_dict = _common.encode_unserializable_types(request_dict)
874
+
875
+ response_dict = self._api_client.request(
876
+ 'post', path, request_dict, http_options
877
+ )
878
+
879
+ if self._api_client.vertexai:
880
+ response_dict = _TuningJob_from_vertex(self._api_client, response_dict)
881
+ else:
882
+ response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
883
+
884
+ return_value = types.TuningJob._from_response(
885
+ response=response_dict, kwargs=parameter_model
886
+ )
887
+ self._api_client._verify_response(return_value)
888
+ return return_value
889
+
890
+ def _tune_mldev(
891
+ self,
892
+ *,
893
+ base_model: str,
894
+ training_dataset: types.TuningDatasetOrDict,
895
+ config: Optional[types.CreateTuningJobConfigOrDict] = None,
896
+ ) -> types.Operation:
897
+ """Creates a supervised fine-tuning job.
898
+
899
+ Args:
900
+ base_model: The name of the model to tune.
901
+ training_dataset: The training dataset to use.
902
+ config: The configuration to use for the tuning job.
903
+
904
+ Returns:
905
+ A TuningJob operation.
906
+ """
907
+
908
+ parameter_model = types._CreateTuningJobParameters(
909
+ base_model=base_model,
910
+ training_dataset=training_dataset,
911
+ config=config,
912
+ )
913
+
914
+ if self._api_client.vertexai:
915
+ raise ValueError('This method is only supported in the default client.')
849
916
  else:
850
917
  request_dict = _CreateTuningJobParameters_to_mldev(
851
918
  self._api_client, parameter_model
852
919
  )
853
920
  path = 'tunedModels'.format_map(request_dict.get('_url'))
921
+
854
922
  query_params = request_dict.get('_query')
855
923
  if query_params:
856
924
  path = f'{path}?{urlencode(query_params)}'
@@ -871,17 +939,13 @@ class Tunings(_api_module.BaseModule):
871
939
  )
872
940
 
873
941
  if self._api_client.vertexai:
874
- response_dict = _TuningJobOrOperation_from_vertex(
875
- self._api_client, response_dict
876
- )
942
+ response_dict = _Operation_from_vertex(self._api_client, response_dict)
877
943
  else:
878
- response_dict = _TuningJobOrOperation_from_mldev(
879
- self._api_client, response_dict
880
- )
944
+ response_dict = _Operation_from_mldev(self._api_client, response_dict)
881
945
 
882
- return_value = types.TuningJobOrOperation._from_response(
946
+ return_value = types.Operation._from_response(
883
947
  response=response_dict, kwargs=parameter_model
884
- ).tuning_job
948
+ )
885
949
  self._api_client._verify_response(return_value)
886
950
  return return_value
887
951
 
@@ -909,21 +973,43 @@ class Tunings(_api_module.BaseModule):
909
973
  )
910
974
  return job
911
975
 
976
+ @_common.experimental_warning(
977
+ "The SDK's tuning implementation is experimental, "
978
+ 'and may change in future versions.',
979
+ )
912
980
  def tune(
913
981
  self,
914
982
  *,
915
983
  base_model: str,
916
984
  training_dataset: types.TuningDatasetOrDict,
917
985
  config: Optional[types.CreateTuningJobConfigOrDict] = None,
918
- ) -> types.TuningJobOrOperation:
919
- result = self._tune(
920
- base_model=base_model,
921
- training_dataset=training_dataset,
922
- config=config,
923
- )
924
- if result.name and self._api_client.vertexai:
925
- _IpythonUtils.display_model_tuning_button(tuning_job_resource=result.name)
926
- return result
986
+ ) -> types.TuningJob:
987
+ if self._api_client.vertexai:
988
+ tuning_job = self._tune(
989
+ base_model=base_model,
990
+ training_dataset=training_dataset,
991
+ config=config,
992
+ )
993
+ else:
994
+ operation = self._tune_mldev(
995
+ base_model=base_model,
996
+ training_dataset=training_dataset,
997
+ config=config,
998
+ )
999
+ operation_dict = operation.to_json_dict()
1000
+ try:
1001
+ tuned_model_name = operation_dict['metadata']['tunedModel']
1002
+ except KeyError:
1003
+ tuned_model_name = operation_dict['name'].partition('/operations/')[0]
1004
+ tuning_job = types.TuningJob(
1005
+ name=tuned_model_name,
1006
+ state=types.JobState.JOB_STATE_QUEUED,
1007
+ )
1008
+ if tuning_job.name and self._api_client.vertexai:
1009
+ _IpythonUtils.display_model_tuning_button(
1010
+ tuning_job_resource=tuning_job.name
1011
+ )
1012
+ return tuning_job
927
1013
 
928
1014
 
929
1015
  class AsyncTunings(_api_module.BaseModule):
@@ -1054,7 +1140,7 @@ class AsyncTunings(_api_module.BaseModule):
1054
1140
  base_model: str,
1055
1141
  training_dataset: types.TuningDatasetOrDict,
1056
1142
  config: Optional[types.CreateTuningJobConfigOrDict] = None,
1057
- ) -> types.TuningJobOrOperation:
1143
+ ) -> types.TuningJob:
1058
1144
  """Creates a supervised fine-tuning job.
1059
1145
 
1060
1146
  Args:
@@ -1072,16 +1158,76 @@ class AsyncTunings(_api_module.BaseModule):
1072
1158
  config=config,
1073
1159
  )
1074
1160
 
1075
- if self._api_client.vertexai:
1161
+ if not self._api_client.vertexai:
1162
+ raise ValueError('This method is only supported in the Vertex AI client.')
1163
+ else:
1076
1164
  request_dict = _CreateTuningJobParameters_to_vertex(
1077
1165
  self._api_client, parameter_model
1078
1166
  )
1079
1167
  path = 'tuningJobs'.format_map(request_dict.get('_url'))
1168
+
1169
+ query_params = request_dict.get('_query')
1170
+ if query_params:
1171
+ path = f'{path}?{urlencode(query_params)}'
1172
+ # TODO: remove the hack that pops config.
1173
+ request_dict.pop('config', None)
1174
+
1175
+ http_options = None
1176
+ if isinstance(config, dict):
1177
+ http_options = config.get('http_options', None)
1178
+ elif hasattr(config, 'http_options'):
1179
+ http_options = config.http_options
1180
+
1181
+ request_dict = _common.convert_to_dict(request_dict)
1182
+ request_dict = _common.encode_unserializable_types(request_dict)
1183
+
1184
+ response_dict = await self._api_client.async_request(
1185
+ 'post', path, request_dict, http_options
1186
+ )
1187
+
1188
+ if self._api_client.vertexai:
1189
+ response_dict = _TuningJob_from_vertex(self._api_client, response_dict)
1190
+ else:
1191
+ response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
1192
+
1193
+ return_value = types.TuningJob._from_response(
1194
+ response=response_dict, kwargs=parameter_model
1195
+ )
1196
+ self._api_client._verify_response(return_value)
1197
+ return return_value
1198
+
1199
+ async def _tune_mldev(
1200
+ self,
1201
+ *,
1202
+ base_model: str,
1203
+ training_dataset: types.TuningDatasetOrDict,
1204
+ config: Optional[types.CreateTuningJobConfigOrDict] = None,
1205
+ ) -> types.Operation:
1206
+ """Creates a supervised fine-tuning job.
1207
+
1208
+ Args:
1209
+ base_model: The name of the model to tune.
1210
+ training_dataset: The training dataset to use.
1211
+ config: The configuration to use for the tuning job.
1212
+
1213
+ Returns:
1214
+ A TuningJob operation.
1215
+ """
1216
+
1217
+ parameter_model = types._CreateTuningJobParameters(
1218
+ base_model=base_model,
1219
+ training_dataset=training_dataset,
1220
+ config=config,
1221
+ )
1222
+
1223
+ if self._api_client.vertexai:
1224
+ raise ValueError('This method is only supported in the default client.')
1080
1225
  else:
1081
1226
  request_dict = _CreateTuningJobParameters_to_mldev(
1082
1227
  self._api_client, parameter_model
1083
1228
  )
1084
1229
  path = 'tunedModels'.format_map(request_dict.get('_url'))
1230
+
1085
1231
  query_params = request_dict.get('_query')
1086
1232
  if query_params:
1087
1233
  path = f'{path}?{urlencode(query_params)}'
@@ -1102,17 +1248,13 @@ class AsyncTunings(_api_module.BaseModule):
1102
1248
  )
1103
1249
 
1104
1250
  if self._api_client.vertexai:
1105
- response_dict = _TuningJobOrOperation_from_vertex(
1106
- self._api_client, response_dict
1107
- )
1251
+ response_dict = _Operation_from_vertex(self._api_client, response_dict)
1108
1252
  else:
1109
- response_dict = _TuningJobOrOperation_from_mldev(
1110
- self._api_client, response_dict
1111
- )
1253
+ response_dict = _Operation_from_mldev(self._api_client, response_dict)
1112
1254
 
1113
- return_value = types.TuningJobOrOperation._from_response(
1255
+ return_value = types.Operation._from_response(
1114
1256
  response=response_dict, kwargs=parameter_model
1115
- ).tuning_job
1257
+ )
1116
1258
  self._api_client._verify_response(return_value)
1117
1259
  return return_value
1118
1260
 
@@ -1140,21 +1282,43 @@ class AsyncTunings(_api_module.BaseModule):
1140
1282
  )
1141
1283
  return job
1142
1284
 
1285
+ @_common.experimental_warning(
1286
+ "The SDK's tuning implementation is experimental, "
1287
+ 'and may change in future versions.'
1288
+ )
1143
1289
  async def tune(
1144
1290
  self,
1145
1291
  *,
1146
1292
  base_model: str,
1147
1293
  training_dataset: types.TuningDatasetOrDict,
1148
1294
  config: Optional[types.CreateTuningJobConfigOrDict] = None,
1149
- ) -> types.TuningJobOrOperation:
1150
- result = await self._tune(
1151
- base_model=base_model,
1152
- training_dataset=training_dataset,
1153
- config=config,
1154
- )
1155
- if result.name and self._api_client.vertexai:
1156
- _IpythonUtils.display_model_tuning_button(tuning_job_resource=result.name)
1157
- return result
1295
+ ) -> types.TuningJob:
1296
+ if self._api_client.vertexai:
1297
+ tuning_job = await self._tune(
1298
+ base_model=base_model,
1299
+ training_dataset=training_dataset,
1300
+ config=config,
1301
+ )
1302
+ else:
1303
+ operation = await self._tune_mldev(
1304
+ base_model=base_model,
1305
+ training_dataset=training_dataset,
1306
+ config=config,
1307
+ )
1308
+ operation_dict = operation.to_json_dict()
1309
+ try:
1310
+ tuned_model_name = operation_dict['metadata']['tunedModel']
1311
+ except KeyError:
1312
+ tuned_model_name = operation_dict['name'].partition('/operations/')[0]
1313
+ tuning_job = types.TuningJob(
1314
+ name=tuned_model_name,
1315
+ state=types.JobState.JOB_STATE_QUEUED,
1316
+ )
1317
+ if tuning_job.name and self._api_client.vertexai:
1318
+ _IpythonUtils.display_model_tuning_button(
1319
+ tuning_job_resource=tuning_job.name
1320
+ )
1321
+ return tuning_job
1158
1322
 
1159
1323
 
1160
1324
  class _IpythonUtils: