google-genai 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +20 -0
- google/genai/_api_client.py +467 -0
- google/genai/_automatic_function_calling_util.py +341 -0
- google/genai/_common.py +256 -0
- google/genai/_extra_utils.py +295 -0
- google/genai/_replay_api_client.py +478 -0
- google/genai/_test_api_client.py +149 -0
- google/genai/_transformers.py +438 -0
- google/genai/batches.py +1041 -0
- google/genai/caches.py +1830 -0
- google/genai/chats.py +184 -0
- google/genai/client.py +277 -0
- google/genai/errors.py +110 -0
- google/genai/files.py +1211 -0
- google/genai/live.py +629 -0
- google/genai/models.py +5307 -0
- google/genai/pagers.py +245 -0
- google/genai/tunings.py +1366 -0
- google/genai/types.py +7639 -0
- google_genai-0.0.1.dist-info/LICENSE +202 -0
- google_genai-0.0.1.dist-info/METADATA +763 -0
- google_genai-0.0.1.dist-info/RECORD +24 -0
- google_genai-0.0.1.dist-info/WHEEL +5 -0
- google_genai-0.0.1.dist-info/top_level.txt +1 -0
google/genai/tunings.py
ADDED
@@ -0,0 +1,1366 @@
|
|
1
|
+
# Copyright 2024 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
from typing import Optional, Union
|
17
|
+
from urllib.parse import urlencode
|
18
|
+
from . import _common
|
19
|
+
from . import _transformers as t
|
20
|
+
from . import types
|
21
|
+
from ._api_client import ApiClient
|
22
|
+
from ._common import get_value_by_path as getv
|
23
|
+
from ._common import set_value_by_path as setv
|
24
|
+
from .pagers import AsyncPager, Pager
|
25
|
+
|
26
|
+
|
27
|
+
def _GetTuningJobParameters_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Converts get-tuning-job parameters to the Google AI (mldev) request dict."""
  to_object = {}

  # The job name is routed into the request URL, not the request body.
  name = getv(from_object, ['name'])
  if name is not None:
    setv(to_object, ['_url', 'name'], name)

  return to_object
|
37
|
+
|
38
|
+
|
39
|
+
def _GetTuningJobParameters_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Converts get-tuning-job parameters to the Vertex AI request dict."""
  to_object = {}

  # The job name is routed into the request URL, not the request body.
  name = getv(from_object, ['name'])
  if name is not None:
    setv(to_object, ['_url', 'name'], name)

  return to_object
|
49
|
+
|
50
|
+
|
51
|
+
def _ListTuningJobsConfig_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Maps list-config fields onto the parent request's query parameters (mldev)."""
  to_object = {}

  # Every list option becomes a URL query parameter on the enclosing request.
  page_size = getv(from_object, ['page_size'])
  if page_size is not None:
    setv(parent_object, ['_query', 'pageSize'], page_size)

  page_token = getv(from_object, ['page_token'])
  if page_token is not None:
    setv(parent_object, ['_query', 'pageToken'], page_token)

  list_filter = getv(from_object, ['filter'])
  if list_filter is not None:
    setv(parent_object, ['_query', 'filter'], list_filter)

  return to_object
|
73
|
+
|
74
|
+
|
75
|
+
def _ListTuningJobsConfig_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Maps list-config fields onto the parent request's query parameters (Vertex)."""
  to_object = {}

  # Every list option becomes a URL query parameter on the enclosing request.
  page_size = getv(from_object, ['page_size'])
  if page_size is not None:
    setv(parent_object, ['_query', 'pageSize'], page_size)

  page_token = getv(from_object, ['page_token'])
  if page_token is not None:
    setv(parent_object, ['_query', 'pageToken'], page_token)

  list_filter = getv(from_object, ['filter'])
  if list_filter is not None:
    setv(parent_object, ['_query', 'filter'], list_filter)

  return to_object
|
97
|
+
|
98
|
+
|
99
|
+
def _ListTuningJobsParameters_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds the mldev list-tuning-jobs request dict."""
  to_object = {}

  config = getv(from_object, ['config'])
  if config is not None:
    # The config transformer writes query params into to_object as a side
    # effect; its return value is kept under 'config' (popped later).
    setv(
        to_object,
        ['config'],
        _ListTuningJobsConfig_to_mldev(api_client, config, to_object),
    )

  return to_object
|
115
|
+
|
116
|
+
|
117
|
+
def _ListTuningJobsParameters_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds the Vertex AI list-tuning-jobs request dict."""
  to_object = {}

  config = getv(from_object, ['config'])
  if config is not None:
    # The config transformer writes query params into to_object as a side
    # effect; its return value is kept under 'config' (popped later).
    setv(
        to_object,
        ['config'],
        _ListTuningJobsConfig_to_vertex(api_client, config, to_object),
    )

  return to_object
|
133
|
+
|
134
|
+
|
135
|
+
def _TuningExample_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Converts a single tuning example to the mldev wire format."""
  to_object = {}

  text_input = getv(from_object, ['text_input'])
  if text_input is not None:
    setv(to_object, ['textInput'], text_input)

  output = getv(from_object, ['output'])
  if output is not None:
    setv(to_object, ['output'], output)

  return to_object
|
148
|
+
|
149
|
+
|
150
|
+
def _TuningExample_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Rejects inline tuning examples, which Vertex AI does not accept."""
  to_object = {}

  # Neither field has a Vertex AI equivalent; reject truthy values.
  for field in ('text_input', 'output'):
    if getv(from_object, [field]):
      raise ValueError(f'{field} parameter is not supported in Vertex AI.')

  return to_object
|
163
|
+
|
164
|
+
|
165
|
+
def _TuningDataset_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Converts a tuning dataset to the mldev wire format (inline examples only)."""
  to_object = {}

  # mldev takes inline examples; GCS-backed datasets are Vertex-only.
  if getv(from_object, ['gcs_uri']):
    raise ValueError('gcs_uri parameter is not supported in Google AI.')

  examples = getv(from_object, ['examples'])
  if examples is not None:
    converted = [
        _TuningExample_to_mldev(api_client, example, to_object)
        for example in examples
    ]
    setv(to_object, ['examples', 'examples'], converted)

  return to_object
|
185
|
+
|
186
|
+
|
187
|
+
def _TuningDataset_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Converts a tuning dataset to the Vertex AI wire format (GCS URI only)."""
  to_object = {}

  gcs_uri = getv(from_object, ['gcs_uri'])
  if gcs_uri is not None:
    # The dataset URI lives on the enclosing supervised tuning spec.
    setv(
        parent_object,
        ['supervisedTuningSpec', 'trainingDatasetUri'],
        gcs_uri,
    )

  # Vertex AI only accepts GCS-backed datasets, never inline examples.
  if getv(from_object, ['examples']):
    raise ValueError('examples parameter is not supported in Vertex AI.')

  return to_object
|
204
|
+
|
205
|
+
|
206
|
+
def _TuningValidationDataset_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Rejects validation datasets, which Google AI (mldev) does not support."""
  to_object = {}

  if getv(from_object, ['gcs_uri']):
    raise ValueError('gcs_uri parameter is not supported in Google AI.')

  return to_object
|
216
|
+
|
217
|
+
|
218
|
+
def _TuningValidationDataset_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Converts a validation dataset to the Vertex AI wire format."""
  to_object = {}

  gcs_uri = getv(from_object, ['gcs_uri'])
  if gcs_uri is not None:
    setv(to_object, ['validationDatasetUri'], gcs_uri)

  return to_object
|
228
|
+
|
229
|
+
|
230
|
+
def _CreateTuningJobConfig_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Maps CreateTuningJobConfig fields onto the mldev request body.

  Supported fields are written into ``parent_object`` (the request body);
  fields with no Google AI equivalent raise ``ValueError``. The returned
  ``to_object`` becomes the request's ``config`` entry, which callers pop
  and discard before sending.
  """
  to_object = {}
  if getv(from_object, ['validation_dataset']):
    raise ValueError(
        'validation_dataset parameter is not supported in Google AI.'
    )

  if getv(from_object, ['tuned_model_display_name']) is not None:
    setv(
        parent_object,
        ['displayName'],
        getv(from_object, ['tuned_model_display_name']),
    )

  if getv(from_object, ['description']):
    raise ValueError('description parameter is not supported in Google AI.')

  if getv(from_object, ['epoch_count']) is not None:
    setv(
        parent_object,
        ['tuningTask', 'hyperparameters', 'epochCount'],
        getv(from_object, ['epoch_count']),
    )

  if getv(from_object, ['learning_rate_multiplier']) is not None:
    # BUGFIX: this was previously written into `to_object`, but `to_object`
    # is the 'config' entry that callers pop off the request before sending,
    # so the multiplier was silently dropped. Write it into the request body
    # like the sibling hyperparameters (epochCount, batchSize, learningRate).
    setv(
        parent_object,
        ['tuningTask', 'hyperparameters', 'learningRateMultiplier'],
        getv(from_object, ['learning_rate_multiplier']),
    )

  if getv(from_object, ['adapter_size']):
    raise ValueError('adapter_size parameter is not supported in Google AI.')

  if getv(from_object, ['batch_size']) is not None:
    setv(
        parent_object,
        ['tuningTask', 'hyperparameters', 'batchSize'],
        getv(from_object, ['batch_size']),
    )

  if getv(from_object, ['learning_rate']) is not None:
    setv(
        parent_object,
        ['tuningTask', 'hyperparameters', 'learningRate'],
        getv(from_object, ['learning_rate']),
    )

  return to_object
|
283
|
+
|
284
|
+
|
285
|
+
def _CreateTuningJobConfig_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Maps CreateTuningJobConfig fields onto the Vertex AI request body.

  Supported fields are written into ``parent_object`` (the request body);
  fields with no Vertex AI equivalent raise ``ValueError``. The returned
  ``to_object`` becomes the request's ``config`` entry, which callers pop
  and discard before sending.
  """
  to_object = {}
  if getv(from_object, ['validation_dataset']) is not None:
    setv(
        parent_object,
        ['supervisedTuningSpec'],
        _TuningValidationDataset_to_vertex(
            api_client, getv(from_object, ['validation_dataset']), to_object
        ),
    )

  if getv(from_object, ['tuned_model_display_name']) is not None:
    setv(
        parent_object,
        ['tunedModelDisplayName'],
        getv(from_object, ['tuned_model_display_name']),
    )

  if getv(from_object, ['description']) is not None:
    setv(parent_object, ['description'], getv(from_object, ['description']))

  if getv(from_object, ['epoch_count']) is not None:
    setv(
        parent_object,
        ['supervisedTuningSpec', 'hyperParameters', 'epochCount'],
        getv(from_object, ['epoch_count']),
    )

  if getv(from_object, ['learning_rate_multiplier']) is not None:
    # BUGFIX: this was previously written into `to_object`, but `to_object`
    # is the 'config' entry that callers pop off the request before sending,
    # so the multiplier was silently dropped. Write it into the request body
    # like the sibling hyperparameters (epochCount, adapterSize).
    setv(
        parent_object,
        ['supervisedTuningSpec', 'hyperParameters', 'learningRateMultiplier'],
        getv(from_object, ['learning_rate_multiplier']),
    )

  if getv(from_object, ['adapter_size']) is not None:
    setv(
        parent_object,
        ['supervisedTuningSpec', 'hyperParameters', 'adapterSize'],
        getv(from_object, ['adapter_size']),
    )

  if getv(from_object, ['batch_size']):
    raise ValueError('batch_size parameter is not supported in Vertex AI.')

  if getv(from_object, ['learning_rate']):
    raise ValueError('learning_rate parameter is not supported in Vertex AI.')

  return to_object
|
338
|
+
|
339
|
+
|
340
|
+
def _CreateTuningJobParameters_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds the mldev create-tuning-job request dict."""
  to_object = {}

  base_model = getv(from_object, ['base_model'])
  if base_model is not None:
    setv(to_object, ['baseModel'], base_model)

  training_dataset = getv(from_object, ['training_dataset'])
  if training_dataset is not None:
    setv(
        to_object,
        ['tuningTask', 'trainingData'],
        _TuningDataset_to_mldev(api_client, training_dataset, to_object),
    )

  config = getv(from_object, ['config'])
  if config is not None:
    # The config transformer writes into to_object as a side effect; its
    # return value is kept under 'config' (popped later by the caller).
    setv(
        to_object,
        ['config'],
        _CreateTuningJobConfig_to_mldev(api_client, config, to_object),
    )

  return to_object
|
368
|
+
|
369
|
+
|
370
|
+
def _CreateTuningJobParameters_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds the Vertex AI create-tuning-job request dict."""
  to_object = {}

  base_model = getv(from_object, ['base_model'])
  if base_model is not None:
    setv(to_object, ['baseModel'], base_model)

  training_dataset = getv(from_object, ['training_dataset'])
  if training_dataset is not None:
    # _TuningDataset_to_vertex also writes the URI into to_object directly.
    setv(
        to_object,
        ['supervisedTuningSpec', 'trainingDatasetUri'],
        _TuningDataset_to_vertex(api_client, training_dataset, to_object),
    )

  config = getv(from_object, ['config'])
  if config is not None:
    # The config transformer writes into to_object as a side effect; its
    # return value is kept under 'config' (popped later by the caller).
    setv(
        to_object,
        ['config'],
        _CreateTuningJobConfig_to_vertex(api_client, config, to_object),
    )

  return to_object
|
398
|
+
|
399
|
+
|
400
|
+
def _DistillationDataset_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Rejects distillation datasets, which Google AI (mldev) does not support."""
  to_object = {}

  if getv(from_object, ['gcs_uri']):
    raise ValueError('gcs_uri parameter is not supported in Google AI.')

  return to_object
|
410
|
+
|
411
|
+
|
412
|
+
def _DistillationDataset_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Converts a distillation dataset to the Vertex AI wire format."""
  to_object = {}

  gcs_uri = getv(from_object, ['gcs_uri'])
  if gcs_uri is not None:
    # The dataset URI lives on the enclosing distillation spec.
    setv(
        parent_object,
        ['distillationSpec', 'trainingDatasetUri'],
        gcs_uri,
    )

  return to_object
|
426
|
+
|
427
|
+
|
428
|
+
def _DistillationValidationDataset_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Rejects distillation validation datasets (unsupported in Google AI)."""
  to_object = {}

  if getv(from_object, ['gcs_uri']):
    raise ValueError('gcs_uri parameter is not supported in Google AI.')

  return to_object
|
438
|
+
|
439
|
+
|
440
|
+
def _DistillationValidationDataset_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Converts a distillation validation dataset to the Vertex AI wire format."""
  to_object = {}

  gcs_uri = getv(from_object, ['gcs_uri'])
  if gcs_uri is not None:
    setv(to_object, ['validationDatasetUri'], gcs_uri)

  return to_object
|
450
|
+
|
451
|
+
|
452
|
+
def _CreateDistillationJobConfig_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Maps CreateDistillationJobConfig fields onto the mldev request body.

  Fields with no Google AI equivalent raise ``ValueError``.
  """
  to_object = {}

  if getv(from_object, ['validation_dataset']):
    raise ValueError(
        'validation_dataset parameter is not supported in Google AI.'
    )

  display_name = getv(from_object, ['tuned_model_display_name'])
  if display_name is not None:
    setv(parent_object, ['displayName'], display_name)

  epoch_count = getv(from_object, ['epoch_count'])
  if epoch_count is not None:
    setv(
        parent_object,
        ['tuningTask', 'hyperparameters', 'epochCount'],
        epoch_count,
    )

  lr_multiplier = getv(from_object, ['learning_rate_multiplier'])
  if lr_multiplier is not None:
    setv(
        parent_object,
        ['tuningTask', 'hyperparameters', 'learningRateMultiplier'],
        lr_multiplier,
    )

  if getv(from_object, ['adapter_size']):
    raise ValueError('adapter_size parameter is not supported in Google AI.')

  if getv(from_object, ['pipeline_root_directory']):
    raise ValueError(
        'pipeline_root_directory parameter is not supported in Google AI.'
    )

  return to_object
|
493
|
+
|
494
|
+
|
495
|
+
def _CreateDistillationJobConfig_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Maps CreateDistillationJobConfig fields onto the Vertex AI request body."""
  to_object = {}

  validation_dataset = getv(from_object, ['validation_dataset'])
  if validation_dataset is not None:
    setv(
        parent_object,
        ['distillationSpec'],
        _DistillationValidationDataset_to_vertex(
            api_client, validation_dataset, to_object
        ),
    )

  display_name = getv(from_object, ['tuned_model_display_name'])
  if display_name is not None:
    setv(parent_object, ['tunedModelDisplayName'], display_name)

  epoch_count = getv(from_object, ['epoch_count'])
  if epoch_count is not None:
    setv(
        parent_object,
        ['distillationSpec', 'hyperParameters', 'epochCount'],
        epoch_count,
    )

  lr_multiplier = getv(from_object, ['learning_rate_multiplier'])
  if lr_multiplier is not None:
    setv(
        parent_object,
        ['distillationSpec', 'hyperParameters', 'learningRateMultiplier'],
        lr_multiplier,
    )

  adapter_size = getv(from_object, ['adapter_size'])
  if adapter_size is not None:
    setv(
        parent_object,
        ['distillationSpec', 'hyperParameters', 'adapterSize'],
        adapter_size,
    )

  pipeline_root = getv(from_object, ['pipeline_root_directory'])
  if pipeline_root is not None:
    setv(
        parent_object,
        ['distillationSpec', 'pipelineRootDirectory'],
        pipeline_root,
    )

  return to_object
|
546
|
+
|
547
|
+
|
548
|
+
def _CreateDistillationJobParameters_to_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds the mldev create-distillation-job request dict."""
  to_object = {}

  # Neither model selector has a Google AI equivalent; reject truthy values.
  for field in ('student_model', 'teacher_model'):
    if getv(from_object, [field]):
      raise ValueError(f'{field} parameter is not supported in Google AI.')

  training_dataset = getv(from_object, ['training_dataset'])
  if training_dataset is not None:
    setv(
        to_object,
        ['tuningTask', 'trainingData'],
        _DistillationDataset_to_mldev(api_client, training_dataset, to_object),
    )

  config = getv(from_object, ['config'])
  if config is not None:
    # The config transformer writes into to_object as a side effect; its
    # return value is kept under 'config' (popped later by the caller).
    setv(
        to_object,
        ['config'],
        _CreateDistillationJobConfig_to_mldev(api_client, config, to_object),
    )

  return to_object
|
579
|
+
|
580
|
+
|
581
|
+
def _CreateDistillationJobParameters_to_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds the Vertex AI create-distillation-job request dict."""
  to_object = {}

  student_model = getv(from_object, ['student_model'])
  if student_model is not None:
    setv(to_object, ['distillationSpec', 'studentModel'], student_model)

  teacher_model = getv(from_object, ['teacher_model'])
  if teacher_model is not None:
    setv(to_object, ['distillationSpec', 'baseTeacherModel'], teacher_model)

  training_dataset = getv(from_object, ['training_dataset'])
  if training_dataset is not None:
    # _DistillationDataset_to_vertex also writes the URI into to_object.
    setv(
        to_object,
        ['distillationSpec', 'trainingDatasetUri'],
        _DistillationDataset_to_vertex(api_client, training_dataset, to_object),
    )

  config = getv(from_object, ['config'])
  if config is not None:
    # The config transformer writes into to_object as a side effect; its
    # return value is kept under 'config' (popped later by the caller).
    setv(
        to_object,
        ['config'],
        _CreateDistillationJobConfig_to_vertex(api_client, config, to_object),
    )

  return to_object
|
620
|
+
|
621
|
+
|
622
|
+
def _TunedModel_from_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds a TunedModel dict from an mldev response."""
  to_object = {}

  # mldev exposes one resource name that serves as both model and endpoint.
  name = getv(from_object, ['name'])
  if name is not None:
    setv(to_object, ['model'], name)
    setv(to_object, ['endpoint'], name)

  return to_object
|
635
|
+
|
636
|
+
|
637
|
+
def _TunedModel_from_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds a TunedModel dict from a Vertex AI response."""
  to_object = {}

  model = getv(from_object, ['model'])
  if model is not None:
    setv(to_object, ['model'], model)

  endpoint = getv(from_object, ['endpoint'])
  if endpoint is not None:
    setv(to_object, ['endpoint'], endpoint)

  return to_object
|
650
|
+
|
651
|
+
|
652
|
+
def _TuningJob_from_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds a TuningJob dict from an mldev (tunedModels) response."""
  to_object = {}

  if (name := getv(from_object, ['name'])) is not None:
    setv(to_object, ['name'], name)

  if (state := getv(from_object, ['state'])) is not None:
    # States are normalized to the SDK's common tuning-job status values.
    setv(to_object, ['state'], t.t_tuning_job_status(api_client, state))

  if (create_time := getv(from_object, ['createTime'])) is not None:
    setv(to_object, ['create_time'], create_time)

  # mldev nests the timing fields under tuningTask.
  if (start_time := getv(from_object, ['tuningTask', 'startTime'])) is not None:
    setv(to_object, ['start_time'], start_time)

  if (
      end_time := getv(from_object, ['tuningTask', 'completeTime'])
  ) is not None:
    setv(to_object, ['end_time'], end_time)

  if (update_time := getv(from_object, ['updateTime'])) is not None:
    setv(to_object, ['update_time'], update_time)

  if (description := getv(from_object, ['description'])) is not None:
    setv(to_object, ['description'], description)

  if (base_model := getv(from_object, ['baseModel'])) is not None:
    setv(to_object, ['base_model'], base_model)

  # The tuned-model info is derived from the whole response object ('_self').
  if (whole := getv(from_object, ['_self'])) is not None:
    setv(
        to_object,
        ['tuned_model'],
        _TunedModel_from_mldev(api_client, whole, to_object),
    )

  if (experiment := getv(from_object, ['experiment'])) is not None:
    setv(to_object, ['experiment'], experiment)

  if (labels := getv(from_object, ['labels'])) is not None:
    setv(to_object, ['labels'], labels)

  if (display_name := getv(from_object, ['tunedModelDisplayName'])) is not None:
    setv(to_object, ['tuned_model_display_name'], display_name)

  return to_object
|
717
|
+
|
718
|
+
|
719
|
+
def _TuningJob_from_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds a TuningJob dict from a Vertex AI (tuningJobs) response."""
  to_object = {}

  if (name := getv(from_object, ['name'])) is not None:
    setv(to_object, ['name'], name)

  if (state := getv(from_object, ['state'])) is not None:
    # States are normalized to the SDK's common tuning-job status values.
    setv(to_object, ['state'], t.t_tuning_job_status(api_client, state))

  if (create_time := getv(from_object, ['createTime'])) is not None:
    setv(to_object, ['create_time'], create_time)

  if (start_time := getv(from_object, ['startTime'])) is not None:
    setv(to_object, ['start_time'], start_time)

  if (end_time := getv(from_object, ['endTime'])) is not None:
    setv(to_object, ['end_time'], end_time)

  if (update_time := getv(from_object, ['updateTime'])) is not None:
    setv(to_object, ['update_time'], update_time)

  if (error := getv(from_object, ['error'])) is not None:
    setv(to_object, ['error'], error)

  if (description := getv(from_object, ['description'])) is not None:
    setv(to_object, ['description'], description)

  if (base_model := getv(from_object, ['baseModel'])) is not None:
    setv(to_object, ['base_model'], base_model)

  if (tuned_model := getv(from_object, ['tunedModel'])) is not None:
    setv(
        to_object,
        ['tuned_model'],
        _TunedModel_from_vertex(api_client, tuned_model, to_object),
    )

  if (sup_spec := getv(from_object, ['supervisedTuningSpec'])) is not None:
    setv(to_object, ['supervised_tuning_spec'], sup_spec)

  if (data_stats := getv(from_object, ['tuningDataStats'])) is not None:
    setv(to_object, ['tuning_data_stats'], data_stats)

  if (encryption := getv(from_object, ['encryptionSpec'])) is not None:
    setv(to_object, ['encryption_spec'], encryption)

  if (distill_spec := getv(from_object, ['distillationSpec'])) is not None:
    setv(to_object, ['distillation_spec'], distill_spec)

  if (
      partner_spec := getv(from_object, ['partnerModelTuningSpec'])
  ) is not None:
    setv(to_object, ['partner_model_tuning_spec'], partner_spec)

  if (pipeline_job := getv(from_object, ['pipelineJob'])) is not None:
    setv(to_object, ['pipeline_job'], pipeline_job)

  if (experiment := getv(from_object, ['experiment'])) is not None:
    setv(to_object, ['experiment'], experiment)

  if (labels := getv(from_object, ['labels'])) is not None:
    setv(to_object, ['labels'], labels)

  if (display_name := getv(from_object, ['tunedModelDisplayName'])) is not None:
    setv(to_object, ['tuned_model_display_name'], display_name)

  return to_object
|
811
|
+
|
812
|
+
|
813
|
+
def _ListTuningJobsResponse_from_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds a list-tuning-jobs response dict from an mldev response."""
  to_object = {}

  if (token := getv(from_object, ['nextPageToken'])) is not None:
    setv(to_object, ['next_page_token'], token)

  # mldev returns tuned models; each entry maps to one tuning job.
  if (tuned_models := getv(from_object, ['tunedModels'])) is not None:
    setv(
        to_object,
        ['tuning_jobs'],
        [
            _TuningJob_from_mldev(api_client, entry, to_object)
            for entry in tuned_models
        ],
    )

  return to_object
|
833
|
+
|
834
|
+
|
835
|
+
def _ListTuningJobsResponse_from_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Builds a list-tuning-jobs response dict from a Vertex AI response."""
  to_object = {}

  if (token := getv(from_object, ['nextPageToken'])) is not None:
    setv(to_object, ['next_page_token'], token)

  if (jobs := getv(from_object, ['tuningJobs'])) is not None:
    setv(
        to_object,
        ['tuning_jobs'],
        [
            _TuningJob_from_vertex(api_client, entry, to_object)
            for entry in jobs
        ],
    )

  return to_object
|
855
|
+
|
856
|
+
|
857
|
+
def _TuningJobOrOperation_from_mldev(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Resolves an mldev operation-or-job response into a tuning job dict."""
  to_object = {}

  if (payload := getv(from_object, ['_self'])) is not None:
    # Long-running operations are first resolved to their job payload.
    resolved = t.t_resolve_operation(api_client, payload)
    setv(
        to_object,
        ['tuning_job'],
        _TuningJob_from_mldev(api_client, resolved, to_object),
    )

  return to_object
|
875
|
+
|
876
|
+
|
877
|
+
def _TuningJobOrOperation_from_vertex(
    api_client: ApiClient,
    from_object: Union[dict, object],
    parent_object: dict = None,
) -> dict:
  """Resolves a Vertex AI operation-or-job response into a tuning job dict."""
  to_object = {}

  if (payload := getv(from_object, ['_self'])) is not None:
    # Long-running operations are first resolved to their job payload.
    resolved = t.t_resolve_operation(api_client, payload)
    setv(
        to_object,
        ['tuning_job'],
        _TuningJob_from_vertex(api_client, resolved, to_object),
    )

  return to_object
|
895
|
+
|
896
|
+
|
897
|
+
class Tunings(_common.BaseModule):
  """Blocking (synchronous) operations for tuning jobs.

  Each method builds a wire-format request with the backend-specific
  ``*_to_vertex`` / ``*_to_mldev`` converter, sends it via
  ``self.api_client.request``, and converts the response back into SDK
  (snake_case) field names before wrapping it in a ``types`` object.
  """

  def get(self, *, name: str) -> types.TuningJob:
    """Gets a TuningJob.

    Args:
      name: The resource name of the tuning job.

    Returns:
      A TuningJob object.
    """

    parameter_model = types._GetTuningJobParameters(
        name=name,
    )

    # Vertex AI and ML Dev (Gemini API) use different request converters;
    # the URL template itself is the same for get.
    if self.api_client.vertexai:
      request_dict = _GetTuningJobParameters_to_vertex(
          self.api_client, parameter_model
      )
      path = '{name}'.format_map(request_dict.get('_url'))
    else:
      request_dict = _GetTuningJobParameters_to_mldev(
          self.api_client, parameter_model
      )
      path = '{name}'.format_map(request_dict.get('_url'))
    # Fold any query parameters into the request path.
    query_params = request_dict.get('_query')
    if query_params:
      path = f'{path}?{urlencode(query_params)}'
    # TODO: remove the hack that pops config.
    # Per-request HTTP options travel under config['httpOptions'] and must
    # not be serialized into the request body.
    config = request_dict.pop('config', None)
    http_options = config.pop('httpOptions', None) if config else None
    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.apply_base64_encoding(request_dict)

    response_dict = self.api_client.request(
        'get', path, request_dict, http_options
    )

    # Convert the wire response back into SDK (snake_case) field names.
    if self.api_client.vertexai:
      response_dict = _TuningJob_from_vertex(self.api_client, response_dict)
    else:
      response_dict = _TuningJob_from_mldev(self.api_client, response_dict)

    return_value = types.TuningJob._from_response(
        response_dict, parameter_model
    )
    self.api_client._verify_response(return_value)
    return return_value

  def _list(
      self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None
  ) -> types.ListTuningJobsResponse:
    """Lists tuning jobs.

    Args:
      config: The configuration for the list request.

    Returns:
      A list of tuning jobs.
    """

    parameter_model = types._ListTuningJobsParameters(
        config=config,
    )

    # The two backends expose different collection names for tuning jobs.
    if self.api_client.vertexai:
      request_dict = _ListTuningJobsParameters_to_vertex(
          self.api_client, parameter_model
      )
      path = 'tuningJobs'.format_map(request_dict.get('_url'))
    else:
      request_dict = _ListTuningJobsParameters_to_mldev(
          self.api_client, parameter_model
      )
      path = 'tunedModels'.format_map(request_dict.get('_url'))
    query_params = request_dict.get('_query')
    if query_params:
      path = f'{path}?{urlencode(query_params)}'
    # TODO: remove the hack that pops config.
    config = request_dict.pop('config', None)
    http_options = config.pop('httpOptions', None) if config else None
    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.apply_base64_encoding(request_dict)

    response_dict = self.api_client.request(
        'get', path, request_dict, http_options
    )

    if self.api_client.vertexai:
      response_dict = _ListTuningJobsResponse_from_vertex(
          self.api_client, response_dict
      )
    else:
      response_dict = _ListTuningJobsResponse_from_mldev(
          self.api_client, response_dict
      )

    return_value = types.ListTuningJobsResponse._from_response(
        response_dict, parameter_model
    )
    self.api_client._verify_response(return_value)
    return return_value

  def tune(
      self,
      *,
      base_model: str,
      training_dataset: types.TuningDatasetOrDict,
      config: Optional[types.CreateTuningJobConfigOrDict] = None,
  ) -> types.TuningJobOrOperation:
    """Creates a supervised fine-tuning job.

    Args:
      base_model: The name of the model to tune.
      training_dataset: The training dataset to use.
      config: The configuration to use for the tuning job.

    Returns:
      A TuningJob object.
    """

    parameter_model = types._CreateTuningJobParameters(
        base_model=base_model,
        training_dataset=training_dataset,
        config=config,
    )

    if self.api_client.vertexai:
      request_dict = _CreateTuningJobParameters_to_vertex(
          self.api_client, parameter_model
      )
      path = 'tuningJobs'.format_map(request_dict.get('_url'))
    else:
      request_dict = _CreateTuningJobParameters_to_mldev(
          self.api_client, parameter_model
      )
      path = 'tunedModels'.format_map(request_dict.get('_url'))
    query_params = request_dict.get('_query')
    if query_params:
      path = f'{path}?{urlencode(query_params)}'
    # TODO: remove the hack that pops config.
    config = request_dict.pop('config', None)
    http_options = config.pop('httpOptions', None) if config else None
    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.apply_base64_encoding(request_dict)

    response_dict = self.api_client.request(
        'post', path, request_dict, http_options
    )

    if self.api_client.vertexai:
      response_dict = _TuningJobOrOperation_from_vertex(
          self.api_client, response_dict
      )
    else:
      response_dict = _TuningJobOrOperation_from_mldev(
          self.api_client, response_dict
      )

    # NOTE(review): the runtime value is the unwrapped .tuning_job field
    # (a TuningJob), despite the TuningJobOrOperation return annotation —
    # confirm the intended annotation.
    return_value = types.TuningJobOrOperation._from_response(
        response_dict, parameter_model
    ).tuning_job
    self.api_client._verify_response(return_value)
    return return_value

  def distill(
      self,
      *,
      student_model: str,
      teacher_model: str,
      training_dataset: types.DistillationDatasetOrDict,
      config: Optional[types.CreateDistillationJobConfigOrDict] = None,
  ) -> types.TuningJob:
    """Creates a distillation job.

    Args:
      student_model: The name of the model to tune.
      teacher_model: The name of the model to distill from.
      training_dataset: The training dataset to use.
      config: The configuration to use for the distillation job.

    Returns:
      A TuningJob object.
    """

    parameter_model = types._CreateDistillationJobParameters(
        student_model=student_model,
        teacher_model=teacher_model,
        training_dataset=training_dataset,
        config=config,
    )

    # Distillation is Vertex AI only; non-Vertex clients fail fast here.
    if not self.api_client.vertexai:
      raise ValueError('This method is only supported in the Vertex AI client.')
    else:
      request_dict = _CreateDistillationJobParameters_to_vertex(
          self.api_client, parameter_model
      )
      path = 'tuningJobs'.format_map(request_dict.get('_url'))

    query_params = request_dict.get('_query')
    if query_params:
      path = f'{path}?{urlencode(query_params)}'
    # TODO: remove the hack that pops config.
    config = request_dict.pop('config', None)
    http_options = config.pop('httpOptions', None) if config else None
    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.apply_base64_encoding(request_dict)

    response_dict = self.api_client.request(
        'post', path, request_dict, http_options
    )

    if self.api_client.vertexai:
      response_dict = _TuningJob_from_vertex(self.api_client, response_dict)
    else:
      # NOTE(review): unreachable — distill raises above when not vertexai.
      response_dict = _TuningJob_from_mldev(self.api_client, response_dict)

    return_value = types.TuningJob._from_response(
        response_dict, parameter_model
    )
    self.api_client._verify_response(return_value)
    return return_value

  def list(
      self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None
  ) -> Pager[types.TuningJob]:
    """Returns a Pager over tuning jobs.

    The first page is fetched eagerly via ``self._list``; the Pager uses
    the same callable to fetch subsequent pages on demand.
    """
    return Pager(
        'tuning_jobs',
        self._list,
        self._list(config=config),
        config,
    )
|
1131
|
+
|
1132
|
+
|
1133
|
+
class AsyncTunings(_common.BaseModule):
  """Asynchronous (awaitable) operations for tuning jobs.

  Mirrors ``Tunings`` but dispatches through
  ``self.api_client.async_request`` and returns coroutines.
  """

  async def get(self, *, name: str) -> types.TuningJob:
    """Gets a TuningJob.

    Args:
      name: The resource name of the tuning job.

    Returns:
      A TuningJob object.
    """

    parameter_model = types._GetTuningJobParameters(
        name=name,
    )

    # Vertex AI and ML Dev (Gemini API) use different request converters;
    # the URL template itself is the same for get.
    if self.api_client.vertexai:
      request_dict = _GetTuningJobParameters_to_vertex(
          self.api_client, parameter_model
      )
      path = '{name}'.format_map(request_dict.get('_url'))
    else:
      request_dict = _GetTuningJobParameters_to_mldev(
          self.api_client, parameter_model
      )
      path = '{name}'.format_map(request_dict.get('_url'))
    # Fold any query parameters into the request path.
    query_params = request_dict.get('_query')
    if query_params:
      path = f'{path}?{urlencode(query_params)}'
    # TODO: remove the hack that pops config.
    # Per-request HTTP options travel under config['httpOptions'] and must
    # not be serialized into the request body.
    config = request_dict.pop('config', None)
    http_options = config.pop('httpOptions', None) if config else None
    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.apply_base64_encoding(request_dict)

    response_dict = await self.api_client.async_request(
        'get', path, request_dict, http_options
    )

    # Convert the wire response back into SDK (snake_case) field names.
    if self.api_client.vertexai:
      response_dict = _TuningJob_from_vertex(self.api_client, response_dict)
    else:
      response_dict = _TuningJob_from_mldev(self.api_client, response_dict)

    return_value = types.TuningJob._from_response(
        response_dict, parameter_model
    )
    self.api_client._verify_response(return_value)
    return return_value

  async def _list(
      self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None
  ) -> types.ListTuningJobsResponse:
    """Lists tuning jobs.

    Args:
      config: The configuration for the list request.

    Returns:
      A list of tuning jobs.
    """

    parameter_model = types._ListTuningJobsParameters(
        config=config,
    )

    # The two backends expose different collection names for tuning jobs.
    if self.api_client.vertexai:
      request_dict = _ListTuningJobsParameters_to_vertex(
          self.api_client, parameter_model
      )
      path = 'tuningJobs'.format_map(request_dict.get('_url'))
    else:
      request_dict = _ListTuningJobsParameters_to_mldev(
          self.api_client, parameter_model
      )
      path = 'tunedModels'.format_map(request_dict.get('_url'))
    query_params = request_dict.get('_query')
    if query_params:
      path = f'{path}?{urlencode(query_params)}'
    # TODO: remove the hack that pops config.
    config = request_dict.pop('config', None)
    http_options = config.pop('httpOptions', None) if config else None
    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.apply_base64_encoding(request_dict)

    response_dict = await self.api_client.async_request(
        'get', path, request_dict, http_options
    )

    if self.api_client.vertexai:
      response_dict = _ListTuningJobsResponse_from_vertex(
          self.api_client, response_dict
      )
    else:
      response_dict = _ListTuningJobsResponse_from_mldev(
          self.api_client, response_dict
      )

    return_value = types.ListTuningJobsResponse._from_response(
        response_dict, parameter_model
    )
    self.api_client._verify_response(return_value)
    return return_value
1236
|
+
|
1237
|
+
async def tune(
|
1238
|
+
self,
|
1239
|
+
*,
|
1240
|
+
base_model: str,
|
1241
|
+
training_dataset: types.TuningDatasetOrDict,
|
1242
|
+
config: Optional[types.CreateTuningJobConfigOrDict] = None,
|
1243
|
+
) -> types.TuningJobOrOperation:
|
1244
|
+
"""Creates a supervised fine-tuning job.
|
1245
|
+
|
1246
|
+
Args:
|
1247
|
+
base_model: The name of the model to tune.
|
1248
|
+
training_dataset: The training dataset to use.
|
1249
|
+
config: The configuration to use for the tuning job.
|
1250
|
+
|
1251
|
+
Returns:
|
1252
|
+
A TuningJob object.
|
1253
|
+
"""
|
1254
|
+
|
1255
|
+
parameter_model = types._CreateTuningJobParameters(
|
1256
|
+
base_model=base_model,
|
1257
|
+
training_dataset=training_dataset,
|
1258
|
+
config=config,
|
1259
|
+
)
|
1260
|
+
|
1261
|
+
if self.api_client.vertexai:
|
1262
|
+
request_dict = _CreateTuningJobParameters_to_vertex(
|
1263
|
+
self.api_client, parameter_model
|
1264
|
+
)
|
1265
|
+
path = 'tuningJobs'.format_map(request_dict.get('_url'))
|
1266
|
+
else:
|
1267
|
+
request_dict = _CreateTuningJobParameters_to_mldev(
|
1268
|
+
self.api_client, parameter_model
|
1269
|
+
)
|
1270
|
+
path = 'tunedModels'.format_map(request_dict.get('_url'))
|
1271
|
+
query_params = request_dict.get('_query')
|
1272
|
+
if query_params:
|
1273
|
+
path = f'{path}?{urlencode(query_params)}'
|
1274
|
+
# TODO: remove the hack that pops config.
|
1275
|
+
config = request_dict.pop('config', None)
|
1276
|
+
http_options = config.pop('httpOptions', None) if config else None
|
1277
|
+
request_dict = _common.convert_to_dict(request_dict)
|
1278
|
+
request_dict = _common.apply_base64_encoding(request_dict)
|
1279
|
+
|
1280
|
+
response_dict = await self.api_client.async_request(
|
1281
|
+
'post', path, request_dict, http_options
|
1282
|
+
)
|
1283
|
+
|
1284
|
+
if self.api_client.vertexai:
|
1285
|
+
response_dict = _TuningJobOrOperation_from_vertex(
|
1286
|
+
self.api_client, response_dict
|
1287
|
+
)
|
1288
|
+
else:
|
1289
|
+
response_dict = _TuningJobOrOperation_from_mldev(
|
1290
|
+
self.api_client, response_dict
|
1291
|
+
)
|
1292
|
+
|
1293
|
+
return_value = types.TuningJobOrOperation._from_response(
|
1294
|
+
response_dict, parameter_model
|
1295
|
+
).tuning_job
|
1296
|
+
self.api_client._verify_response(return_value)
|
1297
|
+
return return_value
|
1298
|
+
|
1299
|
+
async def distill(
|
1300
|
+
self,
|
1301
|
+
*,
|
1302
|
+
student_model: str,
|
1303
|
+
teacher_model: str,
|
1304
|
+
training_dataset: types.DistillationDatasetOrDict,
|
1305
|
+
config: Optional[types.CreateDistillationJobConfigOrDict] = None,
|
1306
|
+
) -> types.TuningJob:
|
1307
|
+
"""Creates a distillation job.
|
1308
|
+
|
1309
|
+
Args:
|
1310
|
+
student_model: The name of the model to tune.
|
1311
|
+
teacher_model: The name of the model to distill from.
|
1312
|
+
training_dataset: The training dataset to use.
|
1313
|
+
config: The configuration to use for the distillation job.
|
1314
|
+
|
1315
|
+
Returns:
|
1316
|
+
A TuningJob object.
|
1317
|
+
"""
|
1318
|
+
|
1319
|
+
parameter_model = types._CreateDistillationJobParameters(
|
1320
|
+
student_model=student_model,
|
1321
|
+
teacher_model=teacher_model,
|
1322
|
+
training_dataset=training_dataset,
|
1323
|
+
config=config,
|
1324
|
+
)
|
1325
|
+
|
1326
|
+
if not self.api_client.vertexai:
|
1327
|
+
raise ValueError('This method is only supported in the Vertex AI client.')
|
1328
|
+
else:
|
1329
|
+
request_dict = _CreateDistillationJobParameters_to_vertex(
|
1330
|
+
self.api_client, parameter_model
|
1331
|
+
)
|
1332
|
+
path = 'tuningJobs'.format_map(request_dict.get('_url'))
|
1333
|
+
|
1334
|
+
query_params = request_dict.get('_query')
|
1335
|
+
if query_params:
|
1336
|
+
path = f'{path}?{urlencode(query_params)}'
|
1337
|
+
# TODO: remove the hack that pops config.
|
1338
|
+
config = request_dict.pop('config', None)
|
1339
|
+
http_options = config.pop('httpOptions', None) if config else None
|
1340
|
+
request_dict = _common.convert_to_dict(request_dict)
|
1341
|
+
request_dict = _common.apply_base64_encoding(request_dict)
|
1342
|
+
|
1343
|
+
response_dict = await self.api_client.async_request(
|
1344
|
+
'post', path, request_dict, http_options
|
1345
|
+
)
|
1346
|
+
|
1347
|
+
if self.api_client.vertexai:
|
1348
|
+
response_dict = _TuningJob_from_vertex(self.api_client, response_dict)
|
1349
|
+
else:
|
1350
|
+
response_dict = _TuningJob_from_mldev(self.api_client, response_dict)
|
1351
|
+
|
1352
|
+
return_value = types.TuningJob._from_response(
|
1353
|
+
response_dict, parameter_model
|
1354
|
+
)
|
1355
|
+
self.api_client._verify_response(return_value)
|
1356
|
+
return return_value
|
1357
|
+
|
1358
|
+
async def list(
|
1359
|
+
self, *, config: Optional[types.ListTuningJobsConfigOrDict] = None
|
1360
|
+
) -> AsyncPager[types.TuningJob]:
|
1361
|
+
return AsyncPager(
|
1362
|
+
'tuning_jobs',
|
1363
|
+
self._list,
|
1364
|
+
await self._list(config=config),
|
1365
|
+
config,
|
1366
|
+
)
|