google-genai 1.60.0__py3-none-any.whl → 1.61.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_interactions/resources/interactions.py +50 -28
- google/genai/_interactions/types/__init__.py +2 -1
- google/genai/_interactions/types/content_delta.py +1 -1
- google/genai/_interactions/types/function_result_content.py +2 -1
- google/genai/_interactions/types/function_result_content_param.py +4 -4
- google/genai/_interactions/types/{interaction_event.py → interaction_complete_event.py} +3 -3
- google/genai/_interactions/types/interaction_create_params.py +4 -4
- google/genai/_interactions/types/interaction_get_params.py +3 -0
- google/genai/_interactions/types/interaction_sse_event.py +11 -2
- google/genai/_interactions/types/interaction_start_event.py +36 -0
- google/genai/batches.py +3 -0
- google/genai/files.py +15 -15
- google/genai/tests/batches/test_create_with_inlined_requests.py +31 -15
- google/genai/tests/batches/test_get.py +1 -1
- google/genai/tests/client/test_client_close.py +0 -1
- google/genai/tests/files/test_register_table.py +1 -1
- google/genai/tests/transformers/test_schema.py +10 -1
- google/genai/tests/tunings/test_tune.py +87 -0
- google/genai/tunings.py +163 -4
- google/genai/types.py +178 -14
- google/genai/version.py +1 -1
- {google_genai-1.60.0.dist-info → google_genai-1.61.0.dist-info}/METADATA +1 -1
- {google_genai-1.60.0.dist-info → google_genai-1.61.0.dist-info}/RECORD +26 -25
- {google_genai-1.60.0.dist-info → google_genai-1.61.0.dist-info}/WHEEL +1 -1
- {google_genai-1.60.0.dist-info → google_genai-1.61.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.60.0.dist-info → google_genai-1.61.0.dist-info}/top_level.txt +0 -0
|
@@ -42,20 +42,36 @@ _SAFETY_SETTINGS = [
|
|
|
42
42
|
},
|
|
43
43
|
]
|
|
44
44
|
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
'
|
|
48
|
-
'
|
|
45
|
+
_INLINED_REQUESTS = [
|
|
46
|
+
{
|
|
47
|
+
'contents': [{
|
|
48
|
+
'parts': [{
|
|
49
|
+
'text': 'what is the number after 1? return just the number.',
|
|
50
|
+
}],
|
|
51
|
+
'role': 'user',
|
|
49
52
|
}],
|
|
50
|
-
'
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
'
|
|
53
|
+
'metadata': {
|
|
54
|
+
'key': 'request-1',
|
|
55
|
+
},
|
|
56
|
+
'config': {
|
|
57
|
+
'safety_settings': _SAFETY_SETTINGS,
|
|
58
|
+
},
|
|
54
59
|
},
|
|
55
|
-
|
|
56
|
-
'
|
|
60
|
+
{
|
|
61
|
+
'contents': [{
|
|
62
|
+
'parts': [{
|
|
63
|
+
'text': 'what is the number after 2? return just the number.',
|
|
64
|
+
}],
|
|
65
|
+
'role': 'user',
|
|
66
|
+
}],
|
|
67
|
+
'metadata': {
|
|
68
|
+
'key': 'request-2',
|
|
69
|
+
},
|
|
70
|
+
'config': {
|
|
71
|
+
'safety_settings': _SAFETY_SETTINGS,
|
|
72
|
+
},
|
|
57
73
|
},
|
|
58
|
-
|
|
74
|
+
]
|
|
59
75
|
_INLINED_TEXT_REQUEST_UNION = {
|
|
60
76
|
'contents': [{
|
|
61
77
|
'parts': [{
|
|
@@ -154,7 +170,7 @@ test_table: list[pytest_helper.TestTableItem] = [
|
|
|
154
170
|
name='test_union_with_inlined_request',
|
|
155
171
|
parameters=types._CreateBatchJobParameters(
|
|
156
172
|
model=_MLDEV_GEMINI_MODEL,
|
|
157
|
-
src=
|
|
173
|
+
src=_INLINED_REQUESTS,
|
|
158
174
|
config={
|
|
159
175
|
'display_name': _DISPLAY_NAME,
|
|
160
176
|
},
|
|
@@ -166,7 +182,7 @@ test_table: list[pytest_helper.TestTableItem] = [
|
|
|
166
182
|
name='test_with_inlined_request',
|
|
167
183
|
parameters=types._CreateBatchJobParameters(
|
|
168
184
|
model=_MLDEV_GEMINI_MODEL,
|
|
169
|
-
src={'inlined_requests':
|
|
185
|
+
src={'inlined_requests': _INLINED_REQUESTS},
|
|
170
186
|
config={
|
|
171
187
|
'display_name': _DISPLAY_NAME,
|
|
172
188
|
},
|
|
@@ -177,7 +193,7 @@ test_table: list[pytest_helper.TestTableItem] = [
|
|
|
177
193
|
name='test_with_inlined_request_config',
|
|
178
194
|
parameters=types._CreateBatchJobParameters(
|
|
179
195
|
model=_MLDEV_GEMINI_MODEL,
|
|
180
|
-
src={'inlined_requests':
|
|
196
|
+
src={'inlined_requests': _INLINED_REQUESTS},
|
|
181
197
|
config={
|
|
182
198
|
'display_name': _DISPLAY_NAME,
|
|
183
199
|
},
|
|
@@ -247,7 +263,7 @@ async def test_async_create(client):
|
|
|
247
263
|
with pytest_helper.exception_if_vertex(client, ValueError):
|
|
248
264
|
batch_job = await client.aio.batches.create(
|
|
249
265
|
model=_GEMINI_MODEL,
|
|
250
|
-
src=
|
|
266
|
+
src=_INLINED_REQUESTS,
|
|
251
267
|
)
|
|
252
268
|
assert batch_job.name.startswith('batches/')
|
|
253
269
|
assert (
|
|
@@ -29,7 +29,7 @@ _BATCH_JOB_FULL_RESOURCE_NAME = (
|
|
|
29
29
|
f'batchPredictionJobs/{_BATCH_JOB_NAME}'
|
|
30
30
|
)
|
|
31
31
|
# MLDev batch operation name.
|
|
32
|
-
_MLDEV_BATCH_OPERATION_NAME = 'batches/
|
|
32
|
+
_MLDEV_BATCH_OPERATION_NAME = 'batches/z2p8ksus4lyxt25rntl3fpd67p2niw4hfij5'
|
|
33
33
|
_INVALID_BATCH_JOB_NAME = 'invalid_name'
|
|
34
34
|
|
|
35
35
|
|
|
@@ -42,7 +42,7 @@ def get_headers():
|
|
|
42
42
|
test_table: list[pytest_helper.TestTableItem] = [
|
|
43
43
|
pytest_helper.TestTableItem(
|
|
44
44
|
name='test_register',
|
|
45
|
-
parameters=types.
|
|
45
|
+
parameters=types._InternalRegisterFilesParameters(uris=['gs://unified-genai-dev/image.jpg']),
|
|
46
46
|
exception_if_vertex='only supported in the Gemini Developer client',
|
|
47
47
|
skip_in_api_mode=(
|
|
48
48
|
'The files have a TTL, they cannot be reliably retrieved for a long'
|
|
@@ -79,7 +79,6 @@ class CountryInfoWithAnyOf(pydantic.BaseModel):
|
|
|
79
79
|
|
|
80
80
|
|
|
81
81
|
@pytest.fixture
|
|
82
|
-
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
83
82
|
def client(use_vertex):
|
|
84
83
|
if use_vertex:
|
|
85
84
|
yield google_genai_client_module.Client(
|
|
@@ -91,6 +90,7 @@ def client(use_vertex):
|
|
|
91
90
|
)
|
|
92
91
|
|
|
93
92
|
|
|
93
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
94
94
|
def test_build_schema_for_list_of_pydantic_schema(client):
|
|
95
95
|
"""Tests _build_schema() when list[pydantic.BaseModel] is provided to response_schema."""
|
|
96
96
|
|
|
@@ -112,6 +112,7 @@ def test_build_schema_for_list_of_pydantic_schema(client):
|
|
|
112
112
|
assert list_schema['required'] == list(country_info_fields.keys())
|
|
113
113
|
|
|
114
114
|
|
|
115
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
115
116
|
def test_build_schema_for_list_of_nested_pydantic_schema(client):
|
|
116
117
|
"""Tests _build_schema() when list[pydantic.BaseModel] is provided to response_schema and the pydantic.BaseModel has nested pydantic fields."""
|
|
117
118
|
list_schema = _transformers.t_schema(
|
|
@@ -132,6 +133,7 @@ def test_build_schema_for_list_of_nested_pydantic_schema(client):
|
|
|
132
133
|
assert field_name in currency_info_fields
|
|
133
134
|
|
|
134
135
|
|
|
136
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
135
137
|
def test_t_schema_for_pydantic_schema(client):
|
|
136
138
|
"""Tests t_schema when pydantic.BaseModel is passed to response_schema."""
|
|
137
139
|
transformed_schema = _transformers.t_schema(client, CountryInfo)
|
|
@@ -143,6 +145,7 @@ def test_t_schema_for_pydantic_schema(client):
|
|
|
143
145
|
)
|
|
144
146
|
|
|
145
147
|
|
|
148
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
146
149
|
def test_t_schema_for_list_of_pydantic_schema(client):
|
|
147
150
|
"""Tests t_schema when list[pydantic.BaseModel] is passed to response_schema."""
|
|
148
151
|
transformed_schema = _transformers.t_schema(client, list[CountryInfo])
|
|
@@ -156,6 +159,7 @@ def test_t_schema_for_list_of_pydantic_schema(client):
|
|
|
156
159
|
)
|
|
157
160
|
|
|
158
161
|
|
|
162
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
159
163
|
def test_t_schema_for_null_fields(client):
|
|
160
164
|
"""Tests t_schema when null fields are present."""
|
|
161
165
|
transformed_schema = _transformers.t_schema(client, CountryInfoWithNullFields)
|
|
@@ -163,6 +167,7 @@ def test_t_schema_for_null_fields(client):
|
|
|
163
167
|
assert transformed_schema.properties['population'].nullable
|
|
164
168
|
|
|
165
169
|
|
|
170
|
+
#@pytest.mark.parametrize('use_vertex', [True, False])
|
|
166
171
|
def test_schema_with_no_null_fields_is_unchanged():
|
|
167
172
|
"""Tests handle_null_fields() doesn't change anything when no null fields are present."""
|
|
168
173
|
test_properties = {
|
|
@@ -207,6 +212,7 @@ def test_schema_with_default_value(client):
|
|
|
207
212
|
assert transformed_schema == expected_schema
|
|
208
213
|
|
|
209
214
|
|
|
215
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
210
216
|
def test_schema_with_any_of(client):
|
|
211
217
|
transformed_schema = _transformers.t_schema(client, CountryInfoWithAnyOf)
|
|
212
218
|
expected_schema = types.Schema(
|
|
@@ -601,6 +607,7 @@ def test_process_schema_order_properties_propagates_into_any_of(
|
|
|
601
607
|
assert schema == schema_without_property_ordering
|
|
602
608
|
|
|
603
609
|
|
|
610
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
604
611
|
def test_t_schema_does_not_change_property_ordering_if_set(client):
|
|
605
612
|
"""Tests t_schema doesn't overwrite the property_ordering field if already set."""
|
|
606
613
|
|
|
@@ -612,6 +619,7 @@ def test_t_schema_does_not_change_property_ordering_if_set(client):
|
|
|
612
619
|
assert transformed_schema.property_ordering == custom_property_ordering
|
|
613
620
|
|
|
614
621
|
|
|
622
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
615
623
|
def test_t_schema_sets_property_ordering_for_json_schema(client):
|
|
616
624
|
"""Tests t_schema sets the property_ordering field for json schemas."""
|
|
617
625
|
|
|
@@ -629,6 +637,7 @@ def test_t_schema_sets_property_ordering_for_json_schema(client):
|
|
|
629
637
|
]
|
|
630
638
|
|
|
631
639
|
|
|
640
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
632
641
|
def test_t_schema_sets_property_ordering_for_schema_type(client):
|
|
633
642
|
"""Tests t_schema sets the property_ordering field for Schema types."""
|
|
634
643
|
|
|
@@ -20,6 +20,12 @@ from ... import types as genai_types
|
|
|
20
20
|
from .. import pytest_helper
|
|
21
21
|
import pytest
|
|
22
22
|
|
|
23
|
+
|
|
24
|
+
VERTEX_HTTP_OPTIONS = {
|
|
25
|
+
'api_version': 'v1beta1',
|
|
26
|
+
'base_url': 'https://us-central1-autopush-aiplatform.sandbox.googleapis.com/',
|
|
27
|
+
}
|
|
28
|
+
|
|
23
29
|
evaluation_config=genai_types.EvaluationConfig(
|
|
24
30
|
metrics=[
|
|
25
31
|
genai_types.Metric(name="bleu", prompt_template="test prompt template")
|
|
@@ -158,6 +164,87 @@ test_table: list[pytest_helper.TestTableItem] = [
|
|
|
158
164
|
),
|
|
159
165
|
exception_if_mldev="vertex_dataset_resource parameter is not supported in Gemini API.",
|
|
160
166
|
),
|
|
167
|
+
pytest_helper.TestTableItem(
|
|
168
|
+
name="test_tune_distillation",
|
|
169
|
+
parameters=genai_types.CreateTuningJobParameters(
|
|
170
|
+
base_model="meta/llama3_1@llama-3.1-8b-instruct",
|
|
171
|
+
training_dataset=genai_types.TuningDataset(
|
|
172
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
|
|
173
|
+
),
|
|
174
|
+
config=genai_types.CreateTuningJobConfig(
|
|
175
|
+
method="DISTILLATION",
|
|
176
|
+
base_teacher_model="deepseek-ai/deepseek-v3.1-maas",
|
|
177
|
+
epoch_count=20,
|
|
178
|
+
validation_dataset=genai_types.TuningValidationDataset(
|
|
179
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
|
|
180
|
+
),
|
|
181
|
+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder",
|
|
182
|
+
http_options=VERTEX_HTTP_OPTIONS,
|
|
183
|
+
),
|
|
184
|
+
),
|
|
185
|
+
exception_if_mldev="parameter is not supported in Gemini API.",
|
|
186
|
+
),
|
|
187
|
+
pytest_helper.TestTableItem(
|
|
188
|
+
name="test_tune_oss_sft",
|
|
189
|
+
parameters=genai_types.CreateTuningJobParameters(
|
|
190
|
+
base_model="meta/llama3_1@llama-3.1-8b-instruct",
|
|
191
|
+
training_dataset=genai_types.TuningDataset(
|
|
192
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
|
|
193
|
+
),
|
|
194
|
+
config=genai_types.CreateTuningJobConfig(
|
|
195
|
+
epoch_count=20,
|
|
196
|
+
validation_dataset=genai_types.TuningValidationDataset(
|
|
197
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
|
|
198
|
+
),
|
|
199
|
+
custom_base_model="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder/postprocess/node-0/checkpoints/final",
|
|
200
|
+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
|
|
201
|
+
http_options=VERTEX_HTTP_OPTIONS,
|
|
202
|
+
),
|
|
203
|
+
),
|
|
204
|
+
exception_if_mldev="not supported in Gemini API",
|
|
205
|
+
),
|
|
206
|
+
pytest_helper.TestTableItem(
|
|
207
|
+
name="test_tune_oss_sft_hyperparams",
|
|
208
|
+
parameters=genai_types.CreateTuningJobParameters(
|
|
209
|
+
base_model="meta/llama3_1@llama-3.1-8b-instruct",
|
|
210
|
+
training_dataset=genai_types.TuningDataset(
|
|
211
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
|
|
212
|
+
),
|
|
213
|
+
config=genai_types.CreateTuningJobConfig(
|
|
214
|
+
epoch_count=20,
|
|
215
|
+
validation_dataset=genai_types.TuningValidationDataset(
|
|
216
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
|
|
217
|
+
),
|
|
218
|
+
learning_rate=2.5e-4,
|
|
219
|
+
tuning_mode="TUNING_MODE_FULL",
|
|
220
|
+
custom_base_model="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder/postprocess/node-0/checkpoints/final",
|
|
221
|
+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
|
|
222
|
+
http_options=VERTEX_HTTP_OPTIONS,
|
|
223
|
+
),
|
|
224
|
+
),
|
|
225
|
+
exception_if_mldev="not supported in Gemini API",
|
|
226
|
+
),
|
|
227
|
+
pytest_helper.TestTableItem(
|
|
228
|
+
name="test_tune_oss_distillation",
|
|
229
|
+
parameters=genai_types.CreateTuningJobParameters(
|
|
230
|
+
base_model="meta/llama3_1@llama-3.1-8b-instruct",
|
|
231
|
+
training_dataset=genai_types.TuningDataset(
|
|
232
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
|
|
233
|
+
),
|
|
234
|
+
config=genai_types.CreateTuningJobConfig(
|
|
235
|
+
method="DISTILLATION",
|
|
236
|
+
base_teacher_model="deepseek-ai/deepseek-v3.1-maas",
|
|
237
|
+
epoch_count=20,
|
|
238
|
+
validation_dataset=genai_types.TuningValidationDataset(
|
|
239
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
|
|
240
|
+
),
|
|
241
|
+
custom_base_model="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder/postprocess/node-0/checkpoints/final",
|
|
242
|
+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
|
|
243
|
+
http_options=VERTEX_HTTP_OPTIONS,
|
|
244
|
+
),
|
|
245
|
+
),
|
|
246
|
+
exception_if_mldev="not supported in Gemini API",
|
|
247
|
+
),
|
|
161
248
|
]
|
|
162
249
|
|
|
163
250
|
pytestmark = pytest_helper.setup(
|
google/genai/tunings.py
CHANGED
|
@@ -188,6 +188,14 @@ def _CreateTuningJobConfig_to_mldev(
|
|
|
188
188
|
if getv(from_object, ['adapter_size']) is not None:
|
|
189
189
|
raise ValueError('adapter_size parameter is not supported in Gemini API.')
|
|
190
190
|
|
|
191
|
+
if getv(from_object, ['tuning_mode']) is not None:
|
|
192
|
+
raise ValueError('tuning_mode parameter is not supported in Gemini API.')
|
|
193
|
+
|
|
194
|
+
if getv(from_object, ['custom_base_model']) is not None:
|
|
195
|
+
raise ValueError(
|
|
196
|
+
'custom_base_model parameter is not supported in Gemini API.'
|
|
197
|
+
)
|
|
198
|
+
|
|
191
199
|
if getv(from_object, ['batch_size']) is not None:
|
|
192
200
|
setv(
|
|
193
201
|
parent_object,
|
|
@@ -213,6 +221,24 @@ def _CreateTuningJobConfig_to_mldev(
|
|
|
213
221
|
if getv(from_object, ['beta']) is not None:
|
|
214
222
|
raise ValueError('beta parameter is not supported in Gemini API.')
|
|
215
223
|
|
|
224
|
+
if getv(from_object, ['base_teacher_model']) is not None:
|
|
225
|
+
raise ValueError(
|
|
226
|
+
'base_teacher_model parameter is not supported in Gemini API.'
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
if getv(from_object, ['tuned_teacher_model_source']) is not None:
|
|
230
|
+
raise ValueError(
|
|
231
|
+
'tuned_teacher_model_source parameter is not supported in Gemini API.'
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
if getv(from_object, ['sft_loss_weight_multiplier']) is not None:
|
|
235
|
+
raise ValueError(
|
|
236
|
+
'sft_loss_weight_multiplier parameter is not supported in Gemini API.'
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
if getv(from_object, ['output_uri']) is not None:
|
|
240
|
+
raise ValueError('output_uri parameter is not supported in Gemini API.')
|
|
241
|
+
|
|
216
242
|
return to_object
|
|
217
243
|
|
|
218
244
|
|
|
@@ -246,6 +272,16 @@ def _CreateTuningJobConfig_to_vertex(
|
|
|
246
272
|
),
|
|
247
273
|
)
|
|
248
274
|
|
|
275
|
+
elif discriminator == 'DISTILLATION':
|
|
276
|
+
if getv(from_object, ['validation_dataset']) is not None:
|
|
277
|
+
setv(
|
|
278
|
+
parent_object,
|
|
279
|
+
['distillationSpec'],
|
|
280
|
+
_TuningValidationDataset_to_vertex(
|
|
281
|
+
getv(from_object, ['validation_dataset']), to_object, root_object
|
|
282
|
+
),
|
|
283
|
+
)
|
|
284
|
+
|
|
249
285
|
if getv(from_object, ['tuned_model_display_name']) is not None:
|
|
250
286
|
setv(
|
|
251
287
|
parent_object,
|
|
@@ -275,6 +311,14 @@ def _CreateTuningJobConfig_to_vertex(
|
|
|
275
311
|
getv(from_object, ['epoch_count']),
|
|
276
312
|
)
|
|
277
313
|
|
|
314
|
+
elif discriminator == 'DISTILLATION':
|
|
315
|
+
if getv(from_object, ['epoch_count']) is not None:
|
|
316
|
+
setv(
|
|
317
|
+
parent_object,
|
|
318
|
+
['distillationSpec', 'hyperParameters', 'epochCount'],
|
|
319
|
+
getv(from_object, ['epoch_count']),
|
|
320
|
+
)
|
|
321
|
+
|
|
278
322
|
discriminator = getv(root_object, ['config', 'method'])
|
|
279
323
|
if discriminator is None:
|
|
280
324
|
discriminator = 'SUPERVISED_FINE_TUNING'
|
|
@@ -298,6 +342,14 @@ def _CreateTuningJobConfig_to_vertex(
|
|
|
298
342
|
getv(from_object, ['learning_rate_multiplier']),
|
|
299
343
|
)
|
|
300
344
|
|
|
345
|
+
elif discriminator == 'DISTILLATION':
|
|
346
|
+
if getv(from_object, ['learning_rate_multiplier']) is not None:
|
|
347
|
+
setv(
|
|
348
|
+
parent_object,
|
|
349
|
+
['distillationSpec', 'hyperParameters', 'learningRateMultiplier'],
|
|
350
|
+
getv(from_object, ['learning_rate_multiplier']),
|
|
351
|
+
)
|
|
352
|
+
|
|
301
353
|
discriminator = getv(root_object, ['config', 'method'])
|
|
302
354
|
if discriminator is None:
|
|
303
355
|
discriminator = 'SUPERVISED_FINE_TUNING'
|
|
@@ -317,6 +369,14 @@ def _CreateTuningJobConfig_to_vertex(
|
|
|
317
369
|
getv(from_object, ['export_last_checkpoint_only']),
|
|
318
370
|
)
|
|
319
371
|
|
|
372
|
+
elif discriminator == 'DISTILLATION':
|
|
373
|
+
if getv(from_object, ['export_last_checkpoint_only']) is not None:
|
|
374
|
+
setv(
|
|
375
|
+
parent_object,
|
|
376
|
+
['distillationSpec', 'exportLastCheckpointOnly'],
|
|
377
|
+
getv(from_object, ['export_last_checkpoint_only']),
|
|
378
|
+
)
|
|
379
|
+
|
|
320
380
|
discriminator = getv(root_object, ['config', 'method'])
|
|
321
381
|
if discriminator is None:
|
|
322
382
|
discriminator = 'SUPERVISED_FINE_TUNING'
|
|
@@ -336,11 +396,53 @@ def _CreateTuningJobConfig_to_vertex(
|
|
|
336
396
|
getv(from_object, ['adapter_size']),
|
|
337
397
|
)
|
|
338
398
|
|
|
339
|
-
|
|
340
|
-
|
|
399
|
+
elif discriminator == 'DISTILLATION':
|
|
400
|
+
if getv(from_object, ['adapter_size']) is not None:
|
|
401
|
+
setv(
|
|
402
|
+
parent_object,
|
|
403
|
+
['distillationSpec', 'hyperParameters', 'adapterSize'],
|
|
404
|
+
getv(from_object, ['adapter_size']),
|
|
405
|
+
)
|
|
341
406
|
|
|
342
|
-
|
|
343
|
-
|
|
407
|
+
discriminator = getv(root_object, ['config', 'method'])
|
|
408
|
+
if discriminator is None:
|
|
409
|
+
discriminator = 'SUPERVISED_FINE_TUNING'
|
|
410
|
+
if discriminator == 'SUPERVISED_FINE_TUNING':
|
|
411
|
+
if getv(from_object, ['tuning_mode']) is not None:
|
|
412
|
+
setv(
|
|
413
|
+
parent_object,
|
|
414
|
+
['supervisedTuningSpec', 'tuningMode'],
|
|
415
|
+
getv(from_object, ['tuning_mode']),
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
if getv(from_object, ['custom_base_model']) is not None:
|
|
419
|
+
setv(
|
|
420
|
+
parent_object,
|
|
421
|
+
['customBaseModel'],
|
|
422
|
+
getv(from_object, ['custom_base_model']),
|
|
423
|
+
)
|
|
424
|
+
|
|
425
|
+
discriminator = getv(root_object, ['config', 'method'])
|
|
426
|
+
if discriminator is None:
|
|
427
|
+
discriminator = 'SUPERVISED_FINE_TUNING'
|
|
428
|
+
if discriminator == 'SUPERVISED_FINE_TUNING':
|
|
429
|
+
if getv(from_object, ['batch_size']) is not None:
|
|
430
|
+
setv(
|
|
431
|
+
parent_object,
|
|
432
|
+
['supervisedTuningSpec', 'hyperParameters', 'batchSize'],
|
|
433
|
+
getv(from_object, ['batch_size']),
|
|
434
|
+
)
|
|
435
|
+
|
|
436
|
+
discriminator = getv(root_object, ['config', 'method'])
|
|
437
|
+
if discriminator is None:
|
|
438
|
+
discriminator = 'SUPERVISED_FINE_TUNING'
|
|
439
|
+
if discriminator == 'SUPERVISED_FINE_TUNING':
|
|
440
|
+
if getv(from_object, ['learning_rate']) is not None:
|
|
441
|
+
setv(
|
|
442
|
+
parent_object,
|
|
443
|
+
['supervisedTuningSpec', 'hyperParameters', 'learningRate'],
|
|
444
|
+
getv(from_object, ['learning_rate']),
|
|
445
|
+
)
|
|
344
446
|
|
|
345
447
|
discriminator = getv(root_object, ['config', 'method'])
|
|
346
448
|
if discriminator is None:
|
|
@@ -365,6 +467,16 @@ def _CreateTuningJobConfig_to_vertex(
|
|
|
365
467
|
),
|
|
366
468
|
)
|
|
367
469
|
|
|
470
|
+
elif discriminator == 'DISTILLATION':
|
|
471
|
+
if getv(from_object, ['evaluation_config']) is not None:
|
|
472
|
+
setv(
|
|
473
|
+
parent_object,
|
|
474
|
+
['distillationSpec', 'evaluationConfig'],
|
|
475
|
+
_EvaluationConfig_to_vertex(
|
|
476
|
+
getv(from_object, ['evaluation_config']), to_object, root_object
|
|
477
|
+
),
|
|
478
|
+
)
|
|
479
|
+
|
|
368
480
|
if getv(from_object, ['labels']) is not None:
|
|
369
481
|
setv(parent_object, ['labels'], getv(from_object, ['labels']))
|
|
370
482
|
|
|
@@ -375,6 +487,30 @@ def _CreateTuningJobConfig_to_vertex(
|
|
|
375
487
|
getv(from_object, ['beta']),
|
|
376
488
|
)
|
|
377
489
|
|
|
490
|
+
if getv(from_object, ['base_teacher_model']) is not None:
|
|
491
|
+
setv(
|
|
492
|
+
parent_object,
|
|
493
|
+
['distillationSpec', 'baseTeacherModel'],
|
|
494
|
+
getv(from_object, ['base_teacher_model']),
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
if getv(from_object, ['tuned_teacher_model_source']) is not None:
|
|
498
|
+
setv(
|
|
499
|
+
parent_object,
|
|
500
|
+
['distillationSpec', 'tunedTeacherModelSource'],
|
|
501
|
+
getv(from_object, ['tuned_teacher_model_source']),
|
|
502
|
+
)
|
|
503
|
+
|
|
504
|
+
if getv(from_object, ['sft_loss_weight_multiplier']) is not None:
|
|
505
|
+
setv(
|
|
506
|
+
parent_object,
|
|
507
|
+
['distillationSpec', 'hyperParameters', 'sftLossWeightMultiplier'],
|
|
508
|
+
getv(from_object, ['sft_loss_weight_multiplier']),
|
|
509
|
+
)
|
|
510
|
+
|
|
511
|
+
if getv(from_object, ['output_uri']) is not None:
|
|
512
|
+
setv(parent_object, ['outputUri'], getv(from_object, ['output_uri']))
|
|
513
|
+
|
|
378
514
|
return to_object
|
|
379
515
|
|
|
380
516
|
|
|
@@ -920,6 +1056,14 @@ def _TuningDataset_to_vertex(
|
|
|
920
1056
|
getv(from_object, ['gcs_uri']),
|
|
921
1057
|
)
|
|
922
1058
|
|
|
1059
|
+
elif discriminator == 'DISTILLATION':
|
|
1060
|
+
if getv(from_object, ['gcs_uri']) is not None:
|
|
1061
|
+
setv(
|
|
1062
|
+
parent_object,
|
|
1063
|
+
['distillationSpec', 'promptDatasetUri'],
|
|
1064
|
+
getv(from_object, ['gcs_uri']),
|
|
1065
|
+
)
|
|
1066
|
+
|
|
923
1067
|
discriminator = getv(root_object, ['config', 'method'])
|
|
924
1068
|
if discriminator is None:
|
|
925
1069
|
discriminator = 'SUPERVISED_FINE_TUNING'
|
|
@@ -939,6 +1083,14 @@ def _TuningDataset_to_vertex(
|
|
|
939
1083
|
getv(from_object, ['vertex_dataset_resource']),
|
|
940
1084
|
)
|
|
941
1085
|
|
|
1086
|
+
elif discriminator == 'DISTILLATION':
|
|
1087
|
+
if getv(from_object, ['vertex_dataset_resource']) is not None:
|
|
1088
|
+
setv(
|
|
1089
|
+
parent_object,
|
|
1090
|
+
['distillationSpec', 'promptDatasetUri'],
|
|
1091
|
+
getv(from_object, ['vertex_dataset_resource']),
|
|
1092
|
+
)
|
|
1093
|
+
|
|
942
1094
|
if getv(from_object, ['examples']) is not None:
|
|
943
1095
|
raise ValueError('examples parameter is not supported in Vertex AI.')
|
|
944
1096
|
|
|
@@ -1066,6 +1218,13 @@ def _TuningJob_from_vertex(
|
|
|
1066
1218
|
getv(from_object, ['preferenceOptimizationSpec']),
|
|
1067
1219
|
)
|
|
1068
1220
|
|
|
1221
|
+
if getv(from_object, ['distillationSpec']) is not None:
|
|
1222
|
+
setv(
|
|
1223
|
+
to_object,
|
|
1224
|
+
['distillation_spec'],
|
|
1225
|
+
getv(from_object, ['distillationSpec']),
|
|
1226
|
+
)
|
|
1227
|
+
|
|
1069
1228
|
if getv(from_object, ['tuningDataStats']) is not None:
|
|
1070
1229
|
setv(
|
|
1071
1230
|
to_object, ['tuning_data_stats'], getv(from_object, ['tuningDataStats'])
|