google-genai 0.7.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/tunings.py CHANGED
@@ -15,7 +15,8 @@
15
15
 
16
16
  # Code generated by the Google Gen AI SDK generator DO NOT EDIT.
17
17
 
18
- from typing import Optional, Union
18
+ import time
19
+ from typing import Any, Optional, Union
19
20
  from urllib.parse import urlencode
20
21
  from . import _api_module
21
22
  from . import _common
@@ -655,42 +656,50 @@ def _ListTuningJobsResponse_from_vertex(
655
656
  return to_object
656
657
 
657
658
 
658
def _Operation_from_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds an Operation dict from an ML Dev API operation payload.

  Only the long-running-operation fields `name`, `metadata`, `done`, `error`
  and `response` are carried over, in that order, and only when their value
  is not None. Values are copied as-is (no nested conversion).

  Args:
    api_client: The API client. Not read by this converter.
    from_object: The raw payload, as a dict or an object readable by `getv`.
    parent_object: Not read by this converter.

  Returns:
    A new dict holding the copied operation fields.
  """
  result = {}
  for field_name in ('name', 'metadata', 'done', 'error', 'response'):
    field_value = getv(from_object, [field_name])
    if field_value is not None:
      setv(result, [field_name], field_value)
  return result
676
681
 
677
682
 
678
def _Operation_from_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds an Operation dict from a Vertex AI operation payload.

  Copies, in order, whichever of `name`, `metadata`, `done`, `error` and
  `response` are present with a non-None value; everything else is dropped.

  Args:
    api_client: The API client. Not read by this converter.
    from_object: The raw payload, as a dict or an object readable by `getv`.
    parent_object: Not read by this converter.

  Returns:
    A new dict holding the copied operation fields.
  """
  operation_fields = ('name', 'metadata', 'done', 'error', 'response')
  converted = {}
  for key in operation_fields:
    raw = getv(from_object, [key])
    if raw is None:
      continue
    setv(converted, [key], raw)
  return converted
696
705
 
@@ -823,7 +832,7 @@ class Tunings(_api_module.BaseModule):
823
832
  base_model: str,
824
833
  training_dataset: types.TuningDatasetOrDict,
825
834
  config: Optional[types.CreateTuningJobConfigOrDict] = None,
826
- ) -> types.TuningJobOrOperation:
835
+ ) -> types.TuningJob:
827
836
  """Creates a supervised fine-tuning job.
828
837
 
829
838
  Args:
@@ -841,16 +850,76 @@ class Tunings(_api_module.BaseModule):
841
850
  config=config,
842
851
  )
843
852
 
844
- if self._api_client.vertexai:
853
+ if not self._api_client.vertexai:
854
+ raise ValueError('This method is only supported in the Vertex AI client.')
855
+ else:
845
856
  request_dict = _CreateTuningJobParameters_to_vertex(
846
857
  self._api_client, parameter_model
847
858
  )
848
859
  path = 'tuningJobs'.format_map(request_dict.get('_url'))
860
+
861
+ query_params = request_dict.get('_query')
862
+ if query_params:
863
+ path = f'{path}?{urlencode(query_params)}'
864
+ # TODO: remove the hack that pops config.
865
+ request_dict.pop('config', None)
866
+
867
+ http_options = None
868
+ if isinstance(config, dict):
869
+ http_options = config.get('http_options', None)
870
+ elif hasattr(config, 'http_options'):
871
+ http_options = config.http_options
872
+
873
+ request_dict = _common.convert_to_dict(request_dict)
874
+ request_dict = _common.encode_unserializable_types(request_dict)
875
+
876
+ response_dict = self._api_client.request(
877
+ 'post', path, request_dict, http_options
878
+ )
879
+
880
+ if self._api_client.vertexai:
881
+ response_dict = _TuningJob_from_vertex(self._api_client, response_dict)
882
+ else:
883
+ response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
884
+
885
+ return_value = types.TuningJob._from_response(
886
+ response=response_dict, kwargs=parameter_model
887
+ )
888
+ self._api_client._verify_response(return_value)
889
+ return return_value
890
+
891
+ def _tune_mldev(
892
+ self,
893
+ *,
894
+ base_model: str,
895
+ training_dataset: types.TuningDatasetOrDict,
896
+ config: Optional[types.CreateTuningJobConfigOrDict] = None,
897
+ ) -> types.Operation:
898
+ """Creates a supervised fine-tuning job.
899
+
900
+ Args:
901
+ base_model: The name of the model to tune.
902
+ training_dataset: The training dataset to use.
903
+ config: The configuration to use for the tuning job.
904
+
905
+ Returns:
906
+ A TuningJob operation.
907
+ """
908
+
909
+ parameter_model = types._CreateTuningJobParameters(
910
+ base_model=base_model,
911
+ training_dataset=training_dataset,
912
+ config=config,
913
+ )
914
+
915
+ if self._api_client.vertexai:
916
+ raise ValueError('This method is only supported in the default client.')
849
917
  else:
850
918
  request_dict = _CreateTuningJobParameters_to_mldev(
851
919
  self._api_client, parameter_model
852
920
  )
853
921
  path = 'tunedModels'.format_map(request_dict.get('_url'))
922
+
854
923
  query_params = request_dict.get('_query')
855
924
  if query_params:
856
925
  path = f'{path}?{urlencode(query_params)}'
@@ -871,17 +940,13 @@ class Tunings(_api_module.BaseModule):
871
940
  )
872
941
 
873
942
  if self._api_client.vertexai:
874
- response_dict = _TuningJobOrOperation_from_vertex(
875
- self._api_client, response_dict
876
- )
943
+ response_dict = _Operation_from_vertex(self._api_client, response_dict)
877
944
  else:
878
- response_dict = _TuningJobOrOperation_from_mldev(
879
- self._api_client, response_dict
880
- )
945
+ response_dict = _Operation_from_mldev(self._api_client, response_dict)
881
946
 
882
- return_value = types.TuningJobOrOperation._from_response(
947
+ return_value = types.Operation._from_response(
883
948
  response=response_dict, kwargs=parameter_model
884
- ).tuning_job
949
+ )
885
950
  self._api_client._verify_response(return_value)
886
951
  return return_value
887
952
 
@@ -909,21 +974,106 @@ class Tunings(_api_module.BaseModule):
909
974
  )
910
975
  return job
911
976
 
977
@_common.experimental_warning(
    "The SDK's tuning implementation is experimental, "
    'and may change in future versions.',
)
def tune(
    self,
    *,
    base_model: str,
    training_dataset: types.TuningDatasetOrDict,
    config: Optional[types.CreateTuningJobConfigOrDict] = None,
) -> types.TuningJob:
  """Creates a supervised fine-tuning job and returns it as a TuningJob.

  With a Vertex AI client, the job returned by `_tune` is returned directly.
  When that job has a name, `_IpythonUtils.display_model_tuning_button` is
  also called. With any other client, the operation returned by
  `_tune_mldev` is resolved with the blocking `_resolve_operation`, and its
  response is converted with `_TuningJob_from_mldev`.

  Args:
    base_model: The name of the model to tune.
    training_dataset: The training dataset to use.
    config: The configuration to use for the tuning job.

  Returns:
    The resulting TuningJob.
  """
  tune_kwargs = dict(
      base_model=base_model,
      training_dataset=training_dataset,
      config=config,
  )

  if not self._api_client.vertexai:
    # This backend returns an operation; wait for it to finish before
    # converting its response into a TuningJob.
    operation = self._tune_mldev(**tune_kwargs)
    tuned_model = _resolve_operation(
        self._api_client, operation.to_json_dict()
    )
    return types.TuningJob._from_response(
        _TuningJob_from_mldev(self._api_client, tuned_model), None
    )

  tuning_job = self._tune(**tune_kwargs)
  if tuning_job.name:
    _IpythonUtils.display_model_tuning_button(
        tuning_job_resource=tuning_job.name
    )
  return tuning_job
1011
+
1012
+
1013
# Polling schedule used by the LRO resolvers below. Polling starts with the
# initial delay, which is multiplied by _LRO_POLLING_MULTIPLIER after every
# poll and capped at the maximum delay. The resolver gives up once the summed
# delays exceed the timeout.
_LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
_LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
_LRO_POLLING_TIMEOUT_SECONDS = 900.0
_LRO_POLLING_MULTIPLIER = 1.5


def _lro_should_sleep(api_client: Any) -> bool:
  """Returns False when the client's class name contains 'ReplayApiClient'.

  Such clients skip the real wait (presumably because they replay recorded
  responses).
  """
  return 'ReplayApiClient' not in type(api_client).__name__


def _lro_next_delay(delay_seconds: float) -> float:
  """Returns the next exponential-backoff delay, capped at the maximum."""
  return min(
      delay_seconds * _LRO_POLLING_MULTIPLIER,
      _LRO_POLLING_MAXIMUM_DELAY_SECONDS,
  )


def _lro_result(name: str, operation: dict[str, Any]) -> Any:
  """Returns a finished operation's `response`, raising if it has `error`."""
  if error := operation.get('error'):
    raise RuntimeError(
        f'Operation {name} failed with error: {error}.\n{operation}'
    )
  return operation.get('response')


def _resolve_operation(api_client: 'ApiClient', struct: dict[str, Any]):
  """Waits for a long-running operation to finish and returns its response.

  If `struct['name']` is missing or does not contain '/operations/', `struct`
  is returned unchanged. Otherwise the operation is re-fetched with GET
  requests until its `done` field is truthy. Calls use exponential backoff
  between polls.

  Args:
    api_client: Client used to issue the polling GET requests.
    struct: The operation (or plain resource) dict to resolve.

  Returns:
    The finished operation's `response` value (None if absent), or `struct`
    itself when it is not an operation.

  Raises:
    RuntimeError: If the finished operation reports an error. Also raised
      when the accumulated polling delay exceeds
      _LRO_POLLING_TIMEOUT_SECONDS.
  """
  name = struct.get('name')
  if not (name and '/operations/' in name):
    return struct

  operation = struct
  elapsed_seconds = 0.0
  delay_seconds = _LRO_POLLING_INITIAL_DELAY_SECONDS
  while not operation.get('done'):
    if elapsed_seconds > _LRO_POLLING_TIMEOUT_SECONDS:
      raise RuntimeError(f'Operation {name} timed out.\n{operation}')
    # TODO(b/374433890): Replace with LRO module once it's available.
    operation = api_client.request(
        http_method='GET', path=name, request_dict={}
    )
    if _lro_should_sleep(api_client):
      time.sleep(delay_seconds)
    # Count the scheduled delay even when the sleep is skipped. Otherwise
    # the timeout above would never bound the loop. The previous
    # `total_seconds += total_seconds` stayed at 0.0 forever.
    elapsed_seconds += delay_seconds
    delay_seconds = _lro_next_delay(delay_seconds)
  return _lro_result(name, operation)


async def _resolve_operation_async(
    api_client: 'ApiClient', struct: dict[str, Any]
):
  """Async counterpart of `_resolve_operation`.

  The contract is the same: non-operation dicts are returned unchanged, and
  operations are polled with `api_client.async_request` until done. Waits
  between polls use `asyncio.sleep`. Unlike `time.sleep`, this does not
  block the event loop, so other tasks keep running while the operation is
  pending.

  Args:
    api_client: Client used to issue the polling GET requests.
    struct: The operation (or plain resource) dict to resolve.

  Returns:
    The finished operation's `response` value (None if absent), or `struct`
    itself when it is not an operation.

  Raises:
    RuntimeError: If the finished operation reports an error. Also raised
      when the accumulated polling delay exceeds
      _LRO_POLLING_TIMEOUT_SECONDS.
  """
  import asyncio  # Needed for a non-blocking sleep inside the coroutine.

  name = struct.get('name')
  if not (name and '/operations/' in name):
    return struct

  operation = struct
  elapsed_seconds = 0.0
  delay_seconds = _LRO_POLLING_INITIAL_DELAY_SECONDS
  while not operation.get('done'):
    if elapsed_seconds > _LRO_POLLING_TIMEOUT_SECONDS:
      raise RuntimeError(f'Operation {name} timed out.\n{operation}')
    # TODO(b/374433890): Replace with LRO module once it's available.
    operation = await api_client.async_request(
        http_method='GET', path=name, request_dict={}
    )
    if _lro_should_sleep(api_client):
      await asyncio.sleep(delay_seconds)
    # Count the scheduled delay even when the sleep is skipped, so the
    # timeout above always bounds the loop.
    elapsed_seconds += delay_seconds
    delay_seconds = _lro_next_delay(delay_seconds)
  return _lro_result(name, operation)
927
1077
 
928
1078
 
929
1079
  class AsyncTunings(_api_module.BaseModule):
@@ -1054,7 +1204,7 @@ class AsyncTunings(_api_module.BaseModule):
1054
1204
  base_model: str,
1055
1205
  training_dataset: types.TuningDatasetOrDict,
1056
1206
  config: Optional[types.CreateTuningJobConfigOrDict] = None,
1057
- ) -> types.TuningJobOrOperation:
1207
+ ) -> types.TuningJob:
1058
1208
  """Creates a supervised fine-tuning job.
1059
1209
 
1060
1210
  Args:
@@ -1072,16 +1222,76 @@ class AsyncTunings(_api_module.BaseModule):
1072
1222
  config=config,
1073
1223
  )
1074
1224
 
1075
- if self._api_client.vertexai:
1225
+ if not self._api_client.vertexai:
1226
+ raise ValueError('This method is only supported in the Vertex AI client.')
1227
+ else:
1076
1228
  request_dict = _CreateTuningJobParameters_to_vertex(
1077
1229
  self._api_client, parameter_model
1078
1230
  )
1079
1231
  path = 'tuningJobs'.format_map(request_dict.get('_url'))
1232
+
1233
+ query_params = request_dict.get('_query')
1234
+ if query_params:
1235
+ path = f'{path}?{urlencode(query_params)}'
1236
+ # TODO: remove the hack that pops config.
1237
+ request_dict.pop('config', None)
1238
+
1239
+ http_options = None
1240
+ if isinstance(config, dict):
1241
+ http_options = config.get('http_options', None)
1242
+ elif hasattr(config, 'http_options'):
1243
+ http_options = config.http_options
1244
+
1245
+ request_dict = _common.convert_to_dict(request_dict)
1246
+ request_dict = _common.encode_unserializable_types(request_dict)
1247
+
1248
+ response_dict = await self._api_client.async_request(
1249
+ 'post', path, request_dict, http_options
1250
+ )
1251
+
1252
+ if self._api_client.vertexai:
1253
+ response_dict = _TuningJob_from_vertex(self._api_client, response_dict)
1254
+ else:
1255
+ response_dict = _TuningJob_from_mldev(self._api_client, response_dict)
1256
+
1257
+ return_value = types.TuningJob._from_response(
1258
+ response=response_dict, kwargs=parameter_model
1259
+ )
1260
+ self._api_client._verify_response(return_value)
1261
+ return return_value
1262
+
1263
+ async def _tune_mldev(
1264
+ self,
1265
+ *,
1266
+ base_model: str,
1267
+ training_dataset: types.TuningDatasetOrDict,
1268
+ config: Optional[types.CreateTuningJobConfigOrDict] = None,
1269
+ ) -> types.Operation:
1270
+ """Creates a supervised fine-tuning job.
1271
+
1272
+ Args:
1273
+ base_model: The name of the model to tune.
1274
+ training_dataset: The training dataset to use.
1275
+ config: The configuration to use for the tuning job.
1276
+
1277
+ Returns:
1278
+ A TuningJob operation.
1279
+ """
1280
+
1281
+ parameter_model = types._CreateTuningJobParameters(
1282
+ base_model=base_model,
1283
+ training_dataset=training_dataset,
1284
+ config=config,
1285
+ )
1286
+
1287
+ if self._api_client.vertexai:
1288
+ raise ValueError('This method is only supported in the default client.')
1080
1289
  else:
1081
1290
  request_dict = _CreateTuningJobParameters_to_mldev(
1082
1291
  self._api_client, parameter_model
1083
1292
  )
1084
1293
  path = 'tunedModels'.format_map(request_dict.get('_url'))
1294
+
1085
1295
  query_params = request_dict.get('_query')
1086
1296
  if query_params:
1087
1297
  path = f'{path}?{urlencode(query_params)}'
@@ -1102,17 +1312,13 @@ class AsyncTunings(_api_module.BaseModule):
1102
1312
  )
1103
1313
 
1104
1314
  if self._api_client.vertexai:
1105
- response_dict = _TuningJobOrOperation_from_vertex(
1106
- self._api_client, response_dict
1107
- )
1315
+ response_dict = _Operation_from_vertex(self._api_client, response_dict)
1108
1316
  else:
1109
- response_dict = _TuningJobOrOperation_from_mldev(
1110
- self._api_client, response_dict
1111
- )
1317
+ response_dict = _Operation_from_mldev(self._api_client, response_dict)
1112
1318
 
1113
- return_value = types.TuningJobOrOperation._from_response(
1319
+ return_value = types.Operation._from_response(
1114
1320
  response=response_dict, kwargs=parameter_model
1115
- ).tuning_job
1321
+ )
1116
1322
  self._api_client._verify_response(return_value)
1117
1323
  return return_value
1118
1324
 
@@ -1140,21 +1346,42 @@ class AsyncTunings(_api_module.BaseModule):
1140
1346
  )
1141
1347
  return job
1142
1348
 
1349
@_common.experimental_warning(
    "The SDK's tuning implementation is experimental, "
    'and may change in future versions.'
)
async def tune(
    self,
    *,
    base_model: str,
    training_dataset: types.TuningDatasetOrDict,
    config: Optional[types.CreateTuningJobConfigOrDict] = None,
) -> types.TuningJob:
  """Asynchronously creates a supervised fine-tuning job.

  With a Vertex AI client, the job returned by `_tune` is returned directly.
  When that job has a name, `_IpythonUtils.display_model_tuning_button` is
  also called. With any other client, the operation returned by
  `_tune_mldev` is awaited to completion via `_resolve_operation_async`, and
  its response is converted with `_TuningJob_from_mldev`.

  Args:
    base_model: The name of the model to tune.
    training_dataset: The training dataset to use.
    config: The configuration to use for the tuning job.

  Returns:
    The resulting TuningJob.
  """
  tune_kwargs = dict(
      base_model=base_model,
      training_dataset=training_dataset,
      config=config,
  )

  if not self._api_client.vertexai:
    # This backend returns an operation; wait for it to finish before
    # converting its response into a TuningJob.
    operation = await self._tune_mldev(**tune_kwargs)
    tuned_model = await _resolve_operation_async(
        self._api_client, operation.to_json_dict()
    )
    return types.TuningJob._from_response(
        _TuningJob_from_mldev(self._api_client, tuned_model), None
    )

  tuning_job = await self._tune(**tune_kwargs)
  if tuning_job.name:
    _IpythonUtils.display_model_tuning_button(
        tuning_job_resource=tuning_job.name
    )
  return tuning_job
1158
1385
 
1159
1386
 
1160
1387
  class _IpythonUtils: