google-genai 1.60.0__py3-none-any.whl → 1.61.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (26)
  1. google/genai/_interactions/resources/interactions.py +50 -28
  2. google/genai/_interactions/types/__init__.py +2 -1
  3. google/genai/_interactions/types/content_delta.py +1 -1
  4. google/genai/_interactions/types/function_result_content.py +2 -1
  5. google/genai/_interactions/types/function_result_content_param.py +4 -4
  6. google/genai/_interactions/types/{interaction_event.py → interaction_complete_event.py} +3 -3
  7. google/genai/_interactions/types/interaction_create_params.py +4 -4
  8. google/genai/_interactions/types/interaction_get_params.py +3 -0
  9. google/genai/_interactions/types/interaction_sse_event.py +11 -2
  10. google/genai/_interactions/types/interaction_start_event.py +36 -0
  11. google/genai/batches.py +3 -0
  12. google/genai/files.py +15 -15
  13. google/genai/tests/batches/test_create_with_inlined_requests.py +31 -15
  14. google/genai/tests/batches/test_get.py +1 -1
  15. google/genai/tests/client/test_client_close.py +0 -1
  16. google/genai/tests/files/test_register_table.py +1 -1
  17. google/genai/tests/transformers/test_schema.py +10 -1
  18. google/genai/tests/tunings/test_tune.py +87 -0
  19. google/genai/tunings.py +163 -4
  20. google/genai/types.py +178 -14
  21. google/genai/version.py +1 -1
  22. {google_genai-1.60.0.dist-info → google_genai-1.61.0.dist-info}/METADATA +1 -1
  23. {google_genai-1.60.0.dist-info → google_genai-1.61.0.dist-info}/RECORD +26 -25
  24. {google_genai-1.60.0.dist-info → google_genai-1.61.0.dist-info}/WHEEL +1 -1
  25. {google_genai-1.60.0.dist-info → google_genai-1.61.0.dist-info}/licenses/LICENSE +0 -0
  26. {google_genai-1.60.0.dist-info → google_genai-1.61.0.dist-info}/top_level.txt +0 -0
google/genai/tests/batches/test_create_with_inlined_requests.py CHANGED
@@ -42,20 +42,36 @@ _SAFETY_SETTINGS = [
     },
 ]
 
-_INLINED_REQUEST = {
-    'contents': [{
-        'parts': [{
-            'text': 'Hello!',
+_INLINED_REQUESTS = [
+    {
+        'contents': [{
+            'parts': [{
+                'text': 'what is the number after 1? return just the number.',
+            }],
+            'role': 'user',
         }],
-        'role': 'user',
-    }],
-    'metadata': {
-        'key': 'request-1',
+        'metadata': {
+            'key': 'request-1',
+        },
+        'config': {
+            'safety_settings': _SAFETY_SETTINGS,
+        },
     },
-    'config': {
-        'safety_settings': _SAFETY_SETTINGS,
+    {
+        'contents': [{
+            'parts': [{
+                'text': 'what is the number after 2? return just the number.',
+            }],
+            'role': 'user',
+        }],
+        'metadata': {
+            'key': 'request-2',
+        },
+        'config': {
+            'safety_settings': _SAFETY_SETTINGS,
+        },
     },
-}
+]
 _INLINED_TEXT_REQUEST_UNION = {
     'contents': [{
         'parts': [{
@@ -154,7 +170,7 @@ test_table: list[pytest_helper.TestTableItem] = [
         name='test_union_with_inlined_request',
         parameters=types._CreateBatchJobParameters(
             model=_MLDEV_GEMINI_MODEL,
-            src=[_INLINED_REQUEST],
+            src=_INLINED_REQUESTS,
             config={
                 'display_name': _DISPLAY_NAME,
             },
@@ -166,7 +182,7 @@ test_table: list[pytest_helper.TestTableItem] = [
         name='test_with_inlined_request',
         parameters=types._CreateBatchJobParameters(
             model=_MLDEV_GEMINI_MODEL,
-            src={'inlined_requests': [_INLINED_REQUEST]},
+            src={'inlined_requests': _INLINED_REQUESTS},
             config={
                 'display_name': _DISPLAY_NAME,
             },
@@ -177,7 +193,7 @@ test_table: list[pytest_helper.TestTableItem] = [
         name='test_with_inlined_request_config',
         parameters=types._CreateBatchJobParameters(
             model=_MLDEV_GEMINI_MODEL,
-            src={'inlined_requests': [_INLINED_TEXT_REQUEST]},
+            src={'inlined_requests': _INLINED_REQUESTS},
             config={
                 'display_name': _DISPLAY_NAME,
             },
@@ -247,7 +263,7 @@ async def test_async_create(client):
   with pytest_helper.exception_if_vertex(client, ValueError):
     batch_job = await client.aio.batches.create(
         model=_GEMINI_MODEL,
-        src=[_INLINED_REQUEST],
+        src=_INLINED_REQUESTS,
     )
   assert batch_job.name.startswith('batches/')
   assert (
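The batch tests above now pass a list of inlined requests directly as src, either as the bare list or wrapped in {'inlined_requests': [...]}. As a rough illustration (not taken from the package itself), a call with that shape might look like the sketch below; the API key, model name, and display name are placeholders:

from google import genai

# Placeholder client setup.
client = genai.Client(api_key='YOUR_API_KEY')

inlined_requests = [
    {
        'contents': [{
            'parts': [{'text': 'what is the number after 1? return just the number.'}],
            'role': 'user',
        }],
        # per-request metadata, as in the test fixtures above
        'metadata': {'key': 'request-1'},
    },
]

# src may be the list itself or {'inlined_requests': inlined_requests},
# mirroring the two test table entries above.
batch_job = client.batches.create(
    model='gemini-2.0-flash',  # placeholder model name
    src=inlined_requests,
    config={'display_name': 'my-batch'},
)
print(batch_job.name)  # e.g. 'batches/...'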
google/genai/tests/batches/test_get.py CHANGED
@@ -29,7 +29,7 @@ _BATCH_JOB_FULL_RESOURCE_NAME = (
     f'batchPredictionJobs/{_BATCH_JOB_NAME}'
 )
 # MLDev batch operation name.
-_MLDEV_BATCH_OPERATION_NAME = 'batches/0yew7plxupyybd7appsrq5vw7w0lp3l79lab'
+_MLDEV_BATCH_OPERATION_NAME = 'batches/z2p8ksus4lyxt25rntl3fpd67p2niw4hfij5'
 _INVALID_BATCH_JOB_NAME = 'invalid_name'
 
 
google/genai/tests/client/test_client_close.py CHANGED
@@ -90,7 +90,6 @@ async def test_async_httpx_client_context_manager():
   assert async_client._api_client._async_httpx_client.is_closed
 
 
-@requires_aiohttp
 @pytest.fixture
 def mock_request():
   mock_aiohttp_response = mock.Mock(spec=aiohttp.ClientSession.request)
google/genai/tests/files/test_register_table.py CHANGED
@@ -42,7 +42,7 @@ def get_headers():
 test_table: list[pytest_helper.TestTableItem] = [
     pytest_helper.TestTableItem(
         name='test_register',
-        parameters=types._RegisterFilesParameters(uris=['gs://unified-genai-dev/image.jpg']),
+        parameters=types._InternalRegisterFilesParameters(uris=['gs://unified-genai-dev/image.jpg']),
         exception_if_vertex='only supported in the Gemini Developer client',
         skip_in_api_mode=(
             'The files have a TTL, they cannot be reliably retrieved for a long'
google/genai/tests/transformers/test_schema.py CHANGED
@@ -79,7 +79,6 @@ class CountryInfoWithAnyOf(pydantic.BaseModel):
 
 
 @pytest.fixture
-@pytest.mark.parametrize('use_vertex', [True, False])
 def client(use_vertex):
   if use_vertex:
     yield google_genai_client_module.Client(
@@ -91,6 +90,7 @@ def client(use_vertex):
     )
 
 
+@pytest.mark.parametrize('use_vertex', [True, False])
 def test_build_schema_for_list_of_pydantic_schema(client):
   """Tests _build_schema() when list[pydantic.BaseModel] is provided to response_schema."""
 
@@ -112,6 +112,7 @@ def test_build_schema_for_list_of_pydantic_schema(client):
   assert list_schema['required'] == list(country_info_fields.keys())
 
 
+@pytest.mark.parametrize('use_vertex', [True, False])
 def test_build_schema_for_list_of_nested_pydantic_schema(client):
   """Tests _build_schema() when list[pydantic.BaseModel] is provided to response_schema and the pydantic.BaseModel has nested pydantic fields."""
   list_schema = _transformers.t_schema(
@@ -132,6 +133,7 @@ def test_build_schema_for_list_of_nested_pydantic_schema(client):
   assert field_name in currency_info_fields
 
 
+@pytest.mark.parametrize('use_vertex', [True, False])
 def test_t_schema_for_pydantic_schema(client):
   """Tests t_schema when pydantic.BaseModel is passed to response_schema."""
   transformed_schema = _transformers.t_schema(client, CountryInfo)
@@ -143,6 +145,7 @@ def test_t_schema_for_pydantic_schema(client):
   )
 
 
+@pytest.mark.parametrize('use_vertex', [True, False])
 def test_t_schema_for_list_of_pydantic_schema(client):
   """Tests t_schema when list[pydantic.BaseModel] is passed to response_schema."""
   transformed_schema = _transformers.t_schema(client, list[CountryInfo])
@@ -156,6 +159,7 @@ def test_t_schema_for_list_of_pydantic_schema(client):
   )
 
 
+@pytest.mark.parametrize('use_vertex', [True, False])
 def test_t_schema_for_null_fields(client):
   """Tests t_schema when null fields are present."""
   transformed_schema = _transformers.t_schema(client, CountryInfoWithNullFields)
@@ -163,6 +167,7 @@ def test_t_schema_for_null_fields(client):
   assert transformed_schema.properties['population'].nullable
 
 
+#@pytest.mark.parametrize('use_vertex', [True, False])
 def test_schema_with_no_null_fields_is_unchanged():
   """Tests handle_null_fields() doesn't change anything when no null fields are present."""
   test_properties = {
@@ -207,6 +212,7 @@ def test_schema_with_default_value(client):
   assert transformed_schema == expected_schema
 
 
+@pytest.mark.parametrize('use_vertex', [True, False])
 def test_schema_with_any_of(client):
   transformed_schema = _transformers.t_schema(client, CountryInfoWithAnyOf)
   expected_schema = types.Schema(
@@ -601,6 +607,7 @@ def test_process_schema_order_properties_propagates_into_any_of(
   assert schema == schema_without_property_ordering
 
 
+@pytest.mark.parametrize('use_vertex', [True, False])
 def test_t_schema_does_not_change_property_ordering_if_set(client):
   """Tests t_schema doesn't overwrite the property_ordering field if already set."""
 
@@ -612,6 +619,7 @@ def test_t_schema_does_not_change_property_ordering_if_set(client):
   assert transformed_schema.property_ordering == custom_property_ordering
 
 
+@pytest.mark.parametrize('use_vertex', [True, False])
 def test_t_schema_sets_property_ordering_for_json_schema(client):
   """Tests t_schema sets the property_ordering field for json schemas."""
 
@@ -629,6 +637,7 @@ def test_t_schema_sets_property_ordering_for_json_schema(client):
   ]
 
 
+@pytest.mark.parametrize('use_vertex', [True, False])
 def test_t_schema_sets_property_ordering_for_schema_type(client):
   """Tests t_schema sets the property_ordering field for Schema types."""
 
google/genai/tests/tunings/test_tune.py CHANGED
@@ -20,6 +20,12 @@ from ... import types as genai_types
 from .. import pytest_helper
 import pytest
 
+
+VERTEX_HTTP_OPTIONS = {
+    'api_version': 'v1beta1',
+    'base_url': 'https://us-central1-autopush-aiplatform.sandbox.googleapis.com/',
+}
+
 evaluation_config=genai_types.EvaluationConfig(
     metrics=[
         genai_types.Metric(name="bleu", prompt_template="test prompt template")
@@ -158,6 +164,87 @@ test_table: list[pytest_helper.TestTableItem] = [
         ),
         exception_if_mldev="vertex_dataset_resource parameter is not supported in Gemini API.",
     ),
+    pytest_helper.TestTableItem(
+        name="test_tune_distillation",
+        parameters=genai_types.CreateTuningJobParameters(
+            base_model="meta/llama3_1@llama-3.1-8b-instruct",
+            training_dataset=genai_types.TuningDataset(
+                gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
+            ),
+            config=genai_types.CreateTuningJobConfig(
+                method="DISTILLATION",
+                base_teacher_model="deepseek-ai/deepseek-v3.1-maas",
+                epoch_count=20,
+                validation_dataset=genai_types.TuningValidationDataset(
+                    gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
+                ),
+                output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder",
+                http_options=VERTEX_HTTP_OPTIONS,
+            ),
+        ),
+        exception_if_mldev="parameter is not supported in Gemini API.",
+    ),
+    pytest_helper.TestTableItem(
+        name="test_tune_oss_sft",
+        parameters=genai_types.CreateTuningJobParameters(
+            base_model="meta/llama3_1@llama-3.1-8b-instruct",
+            training_dataset=genai_types.TuningDataset(
+                gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
+            ),
+            config=genai_types.CreateTuningJobConfig(
+                epoch_count=20,
+                validation_dataset=genai_types.TuningValidationDataset(
+                    gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
+                ),
+                custom_base_model="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder/postprocess/node-0/checkpoints/final",
+                output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
+                http_options=VERTEX_HTTP_OPTIONS,
+            ),
+        ),
+        exception_if_mldev="not supported in Gemini API",
+    ),
+    pytest_helper.TestTableItem(
+        name="test_tune_oss_sft_hyperparams",
+        parameters=genai_types.CreateTuningJobParameters(
+            base_model="meta/llama3_1@llama-3.1-8b-instruct",
+            training_dataset=genai_types.TuningDataset(
+                gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
+            ),
+            config=genai_types.CreateTuningJobConfig(
+                epoch_count=20,
+                validation_dataset=genai_types.TuningValidationDataset(
+                    gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
+                ),
+                learning_rate=2.5e-4,
+                tuning_mode="TUNING_MODE_FULL",
+                custom_base_model="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder/postprocess/node-0/checkpoints/final",
+                output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
+                http_options=VERTEX_HTTP_OPTIONS,
+            ),
+        ),
+        exception_if_mldev="not supported in Gemini API",
+    ),
+    pytest_helper.TestTableItem(
+        name="test_tune_oss_distillation",
+        parameters=genai_types.CreateTuningJobParameters(
+            base_model="meta/llama3_1@llama-3.1-8b-instruct",
+            training_dataset=genai_types.TuningDataset(
+                gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
+            ),
+            config=genai_types.CreateTuningJobConfig(
+                method="DISTILLATION",
+                base_teacher_model="deepseek-ai/deepseek-v3.1-maas",
+                epoch_count=20,
+                validation_dataset=genai_types.TuningValidationDataset(
+                    gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
+                ),
+                custom_base_model="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder/postprocess/node-0/checkpoints/final",
+                output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
+                http_options=VERTEX_HTTP_OPTIONS,
+            ),
+        ),
+        exception_if_mldev="not supported in Gemini API",
+    ),
 ]
 
 pytestmark = pytest_helper.setup(
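The four new table entries above exercise the distillation and OSS fine-tuning options added in this release (method, base_teacher_model, custom_base_model, tuning_mode, learning_rate, output_uri). As a hedged sketch only, mirroring the parameters of test_tune_distillation, a corresponding call against a Vertex AI client could look like this; the project, location, and GCS paths are placeholders:

from google import genai
from google.genai import types

# Placeholder Vertex AI client setup.
client = genai.Client(vertexai=True, project='my-project', location='us-central1')

tuning_job = client.tunings.tune(
    base_model='meta/llama3_1@llama-3.1-8b-instruct',
    training_dataset=types.TuningDataset(
        gcs_uri='gs://my-bucket/train-prompts.jsonl',  # placeholder path
    ),
    config=types.CreateTuningJobConfig(
        method='DISTILLATION',
        base_teacher_model='deepseek-ai/deepseek-v3.1-maas',
        epoch_count=20,
        validation_dataset=types.TuningValidationDataset(
            gcs_uri='gs://my-bucket/val-prompts.jsonl',  # placeholder path
        ),
        output_uri='gs://my-bucket/distillation-output',  # placeholder path
    ),
)
print(tuning_job.name)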
google/genai/tunings.py CHANGED
@@ -188,6 +188,14 @@ def _CreateTuningJobConfig_to_mldev(
   if getv(from_object, ['adapter_size']) is not None:
     raise ValueError('adapter_size parameter is not supported in Gemini API.')
 
+  if getv(from_object, ['tuning_mode']) is not None:
+    raise ValueError('tuning_mode parameter is not supported in Gemini API.')
+
+  if getv(from_object, ['custom_base_model']) is not None:
+    raise ValueError(
+        'custom_base_model parameter is not supported in Gemini API.'
+    )
+
   if getv(from_object, ['batch_size']) is not None:
     setv(
         parent_object,
@@ -213,6 +221,24 @@ def _CreateTuningJobConfig_to_mldev(
   if getv(from_object, ['beta']) is not None:
     raise ValueError('beta parameter is not supported in Gemini API.')
 
+  if getv(from_object, ['base_teacher_model']) is not None:
+    raise ValueError(
+        'base_teacher_model parameter is not supported in Gemini API.'
+    )
+
+  if getv(from_object, ['tuned_teacher_model_source']) is not None:
+    raise ValueError(
+        'tuned_teacher_model_source parameter is not supported in Gemini API.'
+    )
+
+  if getv(from_object, ['sft_loss_weight_multiplier']) is not None:
+    raise ValueError(
+        'sft_loss_weight_multiplier parameter is not supported in Gemini API.'
+    )
+
+  if getv(from_object, ['output_uri']) is not None:
+    raise ValueError('output_uri parameter is not supported in Gemini API.')
+
   return to_object
 
 
@@ -246,6 +272,16 @@ def _CreateTuningJobConfig_to_vertex(
           ),
       )
 
+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['validation_dataset']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec'],
+          _TuningValidationDataset_to_vertex(
+              getv(from_object, ['validation_dataset']), to_object, root_object
+          ),
+      )
+
   if getv(from_object, ['tuned_model_display_name']) is not None:
     setv(
         parent_object,
@@ -275,6 +311,14 @@ def _CreateTuningJobConfig_to_vertex(
           getv(from_object, ['epoch_count']),
      )
 
+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['epoch_count']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'hyperParameters', 'epochCount'],
+          getv(from_object, ['epoch_count']),
+      )
+
   discriminator = getv(root_object, ['config', 'method'])
   if discriminator is None:
     discriminator = 'SUPERVISED_FINE_TUNING'
@@ -298,6 +342,14 @@ def _CreateTuningJobConfig_to_vertex(
           getv(from_object, ['learning_rate_multiplier']),
       )
 
+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['learning_rate_multiplier']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'hyperParameters', 'learningRateMultiplier'],
+          getv(from_object, ['learning_rate_multiplier']),
+      )
+
   discriminator = getv(root_object, ['config', 'method'])
   if discriminator is None:
     discriminator = 'SUPERVISED_FINE_TUNING'
@@ -317,6 +369,14 @@ def _CreateTuningJobConfig_to_vertex(
           getv(from_object, ['export_last_checkpoint_only']),
       )
 
+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['export_last_checkpoint_only']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'exportLastCheckpointOnly'],
+          getv(from_object, ['export_last_checkpoint_only']),
+      )
+
   discriminator = getv(root_object, ['config', 'method'])
   if discriminator is None:
     discriminator = 'SUPERVISED_FINE_TUNING'
@@ -336,11 +396,53 @@ def _CreateTuningJobConfig_to_vertex(
           getv(from_object, ['adapter_size']),
       )
 
-  if getv(from_object, ['batch_size']) is not None:
-    raise ValueError('batch_size parameter is not supported in Vertex AI.')
+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['adapter_size']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'hyperParameters', 'adapterSize'],
+          getv(from_object, ['adapter_size']),
+      )
 
-  if getv(from_object, ['learning_rate']) is not None:
-    raise ValueError('learning_rate parameter is not supported in Vertex AI.')
+  discriminator = getv(root_object, ['config', 'method'])
+  if discriminator is None:
+    discriminator = 'SUPERVISED_FINE_TUNING'
+  if discriminator == 'SUPERVISED_FINE_TUNING':
+    if getv(from_object, ['tuning_mode']) is not None:
+      setv(
+          parent_object,
+          ['supervisedTuningSpec', 'tuningMode'],
+          getv(from_object, ['tuning_mode']),
+      )
+
+  if getv(from_object, ['custom_base_model']) is not None:
+    setv(
+        parent_object,
+        ['customBaseModel'],
+        getv(from_object, ['custom_base_model']),
+    )
+
+  discriminator = getv(root_object, ['config', 'method'])
+  if discriminator is None:
+    discriminator = 'SUPERVISED_FINE_TUNING'
+  if discriminator == 'SUPERVISED_FINE_TUNING':
+    if getv(from_object, ['batch_size']) is not None:
+      setv(
+          parent_object,
+          ['supervisedTuningSpec', 'hyperParameters', 'batchSize'],
+          getv(from_object, ['batch_size']),
+      )
+
+  discriminator = getv(root_object, ['config', 'method'])
+  if discriminator is None:
+    discriminator = 'SUPERVISED_FINE_TUNING'
+  if discriminator == 'SUPERVISED_FINE_TUNING':
+    if getv(from_object, ['learning_rate']) is not None:
+      setv(
+          parent_object,
+          ['supervisedTuningSpec', 'hyperParameters', 'learningRate'],
+          getv(from_object, ['learning_rate']),
+      )
 
   discriminator = getv(root_object, ['config', 'method'])
   if discriminator is None:
@@ -365,6 +467,16 @@ def _CreateTuningJobConfig_to_vertex(
           ),
       )
 
+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['evaluation_config']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'evaluationConfig'],
+          _EvaluationConfig_to_vertex(
+              getv(from_object, ['evaluation_config']), to_object, root_object
+          ),
+      )
+
   if getv(from_object, ['labels']) is not None:
     setv(parent_object, ['labels'], getv(from_object, ['labels']))
 
@@ -375,6 +487,30 @@ def _CreateTuningJobConfig_to_vertex(
         getv(from_object, ['beta']),
     )
 
+  if getv(from_object, ['base_teacher_model']) is not None:
+    setv(
+        parent_object,
+        ['distillationSpec', 'baseTeacherModel'],
+        getv(from_object, ['base_teacher_model']),
+    )
+
+  if getv(from_object, ['tuned_teacher_model_source']) is not None:
+    setv(
+        parent_object,
+        ['distillationSpec', 'tunedTeacherModelSource'],
+        getv(from_object, ['tuned_teacher_model_source']),
+    )
+
+  if getv(from_object, ['sft_loss_weight_multiplier']) is not None:
+    setv(
+        parent_object,
+        ['distillationSpec', 'hyperParameters', 'sftLossWeightMultiplier'],
+        getv(from_object, ['sft_loss_weight_multiplier']),
+    )
+
+  if getv(from_object, ['output_uri']) is not None:
+    setv(parent_object, ['outputUri'], getv(from_object, ['output_uri']))
+
   return to_object
 
 
@@ -920,6 +1056,14 @@ def _TuningDataset_to_vertex(
           getv(from_object, ['gcs_uri']),
      )
 
+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['gcs_uri']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'promptDatasetUri'],
+          getv(from_object, ['gcs_uri']),
+      )
+
   discriminator = getv(root_object, ['config', 'method'])
   if discriminator is None:
     discriminator = 'SUPERVISED_FINE_TUNING'
@@ -939,6 +1083,14 @@ def _TuningDataset_to_vertex(
          getv(from_object, ['vertex_dataset_resource']),
      )
 
+  elif discriminator == 'DISTILLATION':
+    if getv(from_object, ['vertex_dataset_resource']) is not None:
+      setv(
+          parent_object,
+          ['distillationSpec', 'promptDatasetUri'],
+          getv(from_object, ['vertex_dataset_resource']),
+      )
+
   if getv(from_object, ['examples']) is not None:
     raise ValueError('examples parameter is not supported in Vertex AI.')
 
@@ -1066,6 +1218,13 @@ def _TuningJob_from_vertex(
         getv(from_object, ['preferenceOptimizationSpec']),
     )
 
+  if getv(from_object, ['distillationSpec']) is not None:
+    setv(
+        to_object,
+        ['distillation_spec'],
+        getv(from_object, ['distillationSpec']),
+    )
+
   if getv(from_object, ['tuningDataStats']) is not None:
     setv(
         to_object, ['tuning_data_stats'], getv(from_object, ['tuningDataStats'])
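Taken together, the new branches above route a DISTILLATION config into a distillationSpec block of the Vertex request (with the previously rejected batch_size and learning_rate now mapped under supervisedTuningSpec for SFT jobs). A rough, illustrative sketch of the fields these setv() paths populate for a distillation job; values are placeholders, and fields whose exact names are not shown in this diff (such as the mapped validation dataset and evaluation config inside distillationSpec) are omitted:

# Illustrative only: approximate Vertex request fields assembled by the
# converter paths shown above for a DISTILLATION config.
distillation_fields = {
    'distillationSpec': {
        'baseTeacherModel': 'deepseek-ai/deepseek-v3.1-maas',      # from base_teacher_model
        'promptDatasetUri': 'gs://my-bucket/train-prompts.jsonl',  # from _TuningDataset_to_vertex
        'hyperParameters': {
            'epochCount': 20,               # from epoch_count
            'learningRateMultiplier': 1.0,  # from learning_rate_multiplier
            'sftLossWeightMultiplier': 0.5, # from sft_loss_weight_multiplier
        },
    },
    'customBaseModel': 'gs://my-bucket/checkpoints/final',    # from custom_base_model
    'outputUri': 'gs://my-bucket/distillation-output',        # from output_uri
}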