google-genai 1.59.0__py3-none-any.whl → 1.61.0__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- google/genai/_interactions/resources/interactions.py +60 -38
- google/genai/_interactions/types/__init__.py +2 -1
- google/genai/_interactions/types/content_delta.py +1 -1
- google/genai/_interactions/types/function_result_content.py +2 -1
- google/genai/_interactions/types/function_result_content_param.py +4 -4
- google/genai/_interactions/types/{interaction_event.py → interaction_complete_event.py} +3 -3
- google/genai/_interactions/types/interaction_create_params.py +6 -6
- google/genai/_interactions/types/interaction_get_params.py +3 -0
- google/genai/_interactions/types/interaction_sse_event.py +11 -2
- google/genai/_interactions/types/interaction_start_event.py +36 -0
- google/genai/batches.py +8 -0
- google/genai/files.py +15 -15
- google/genai/models.py +12 -0
- google/genai/tests/batches/test_create_with_inlined_requests.py +31 -15
- google/genai/tests/batches/test_get.py +1 -1
- google/genai/tests/client/test_client_close.py +0 -1
- google/genai/tests/files/test_register_table.py +1 -1
- google/genai/tests/models/test_generate_content.py +16 -0
- google/genai/tests/transformers/test_schema.py +10 -1
- google/genai/tests/tunings/test_tune.py +87 -0
- google/genai/tunings.py +163 -4
- google/genai/types.py +221 -14
- google/genai/version.py +1 -1
- {google_genai-1.59.0.dist-info → google_genai-1.61.0.dist-info}/METADATA +1 -1
- {google_genai-1.59.0.dist-info → google_genai-1.61.0.dist-info}/RECORD +28 -27
- {google_genai-1.59.0.dist-info → google_genai-1.61.0.dist-info}/WHEEL +1 -1
- {google_genai-1.59.0.dist-info → google_genai-1.61.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.59.0.dist-info → google_genai-1.61.0.dist-info}/top_level.txt +0 -0
google/genai/batches.py
CHANGED
|
@@ -1013,6 +1013,11 @@ def _GenerateContentConfig_to_mldev(
|
|
|
1013
1013
|
getv(from_object, ['enable_enhanced_civic_answers']),
|
|
1014
1014
|
)
|
|
1015
1015
|
|
|
1016
|
+
if getv(from_object, ['model_armor_config']) is not None:
|
|
1017
|
+
raise ValueError(
|
|
1018
|
+
'model_armor_config parameter is not supported in Gemini API.'
|
|
1019
|
+
)
|
|
1020
|
+
|
|
1016
1021
|
return to_object
|
|
1017
1022
|
|
|
1018
1023
|
|
|
@@ -1203,6 +1208,9 @@ def _InlinedResponse_from_mldev(
|
|
|
1203
1208
|
),
|
|
1204
1209
|
)
|
|
1205
1210
|
|
|
1211
|
+
if getv(from_object, ['metadata']) is not None:
|
|
1212
|
+
setv(to_object, ['metadata'], getv(from_object, ['metadata']))
|
|
1213
|
+
|
|
1206
1214
|
if getv(from_object, ['error']) is not None:
|
|
1207
1215
|
setv(to_object, ['error'], getv(from_object, ['error']))
|
|
1208
1216
|
|
google/genai/files.py
CHANGED
|
@@ -101,6 +101,17 @@ def _GetFileParameters_to_mldev(
|
|
|
101
101
|
return to_object
|
|
102
102
|
|
|
103
103
|
|
|
104
|
+
def _InternalRegisterFilesParameters_to_mldev(
|
|
105
|
+
from_object: Union[dict[str, Any], object],
|
|
106
|
+
parent_object: Optional[dict[str, Any]] = None,
|
|
107
|
+
) -> dict[str, Any]:
|
|
108
|
+
to_object: dict[str, Any] = {}
|
|
109
|
+
if getv(from_object, ['uris']) is not None:
|
|
110
|
+
setv(to_object, ['uris'], getv(from_object, ['uris']))
|
|
111
|
+
|
|
112
|
+
return to_object
|
|
113
|
+
|
|
114
|
+
|
|
104
115
|
def _ListFilesConfig_to_mldev(
|
|
105
116
|
from_object: Union[dict[str, Any], object],
|
|
106
117
|
parent_object: Optional[dict[str, Any]] = None,
|
|
@@ -152,17 +163,6 @@ def _ListFilesResponse_from_mldev(
|
|
|
152
163
|
return to_object
|
|
153
164
|
|
|
154
165
|
|
|
155
|
-
def _RegisterFilesParameters_to_mldev(
|
|
156
|
-
from_object: Union[dict[str, Any], object],
|
|
157
|
-
parent_object: Optional[dict[str, Any]] = None,
|
|
158
|
-
) -> dict[str, Any]:
|
|
159
|
-
to_object: dict[str, Any] = {}
|
|
160
|
-
if getv(from_object, ['uris']) is not None:
|
|
161
|
-
setv(to_object, ['uris'], getv(from_object, ['uris']))
|
|
162
|
-
|
|
163
|
-
return to_object
|
|
164
|
-
|
|
165
|
-
|
|
166
166
|
def _RegisterFilesResponse_from_mldev(
|
|
167
167
|
from_object: Union[dict[str, Any], object],
|
|
168
168
|
parent_object: Optional[dict[str, Any]] = None,
|
|
@@ -438,7 +438,7 @@ class Files(_api_module.BaseModule):
|
|
|
438
438
|
uris: list[str],
|
|
439
439
|
config: Optional[types.RegisterFilesConfigOrDict] = None,
|
|
440
440
|
) -> types.RegisterFilesResponse:
|
|
441
|
-
parameter_model = types._RegisterFilesParameters(
|
|
441
|
+
parameter_model = types._InternalRegisterFilesParameters(
|
|
442
442
|
uris=uris,
|
|
443
443
|
config=config,
|
|
444
444
|
)
|
|
@@ -449,7 +449,7 @@ class Files(_api_module.BaseModule):
|
|
|
449
449
|
'This method is only supported in the Gemini Developer client.'
|
|
450
450
|
)
|
|
451
451
|
else:
|
|
452
|
-
request_dict = _RegisterFilesParameters_to_mldev(parameter_model)
|
|
452
|
+
request_dict = _InternalRegisterFilesParameters_to_mldev(parameter_model)
|
|
453
453
|
request_url_dict = request_dict.get('_url')
|
|
454
454
|
if request_url_dict:
|
|
455
455
|
path = 'files:register'.format_map(request_url_dict)
|
|
@@ -977,7 +977,7 @@ class AsyncFiles(_api_module.BaseModule):
|
|
|
977
977
|
uris: list[str],
|
|
978
978
|
config: Optional[types.RegisterFilesConfigOrDict] = None,
|
|
979
979
|
) -> types.RegisterFilesResponse:
|
|
980
|
-
parameter_model = types._RegisterFilesParameters(
|
|
980
|
+
parameter_model = types._InternalRegisterFilesParameters(
|
|
981
981
|
uris=uris,
|
|
982
982
|
config=config,
|
|
983
983
|
)
|
|
@@ -988,7 +988,7 @@ class AsyncFiles(_api_module.BaseModule):
|
|
|
988
988
|
'This method is only supported in the Gemini Developer client.'
|
|
989
989
|
)
|
|
990
990
|
else:
|
|
991
|
-
request_dict = _RegisterFilesParameters_to_mldev(parameter_model)
|
|
991
|
+
request_dict = _InternalRegisterFilesParameters_to_mldev(parameter_model)
|
|
992
992
|
request_url_dict = request_dict.get('_url')
|
|
993
993
|
if request_url_dict:
|
|
994
994
|
path = 'files:register'.format_map(request_url_dict)
|
google/genai/models.py
CHANGED
|
@@ -1117,6 +1117,11 @@ def _GenerateContentConfig_to_mldev(
|
|
|
1117
1117
|
getv(from_object, ['enable_enhanced_civic_answers']),
|
|
1118
1118
|
)
|
|
1119
1119
|
|
|
1120
|
+
if getv(from_object, ['model_armor_config']) is not None:
|
|
1121
|
+
raise ValueError(
|
|
1122
|
+
'model_armor_config parameter is not supported in Gemini API.'
|
|
1123
|
+
)
|
|
1124
|
+
|
|
1120
1125
|
return to_object
|
|
1121
1126
|
|
|
1122
1127
|
|
|
@@ -1279,6 +1284,13 @@ def _GenerateContentConfig_to_vertex(
|
|
|
1279
1284
|
'enable_enhanced_civic_answers parameter is not supported in Vertex AI.'
|
|
1280
1285
|
)
|
|
1281
1286
|
|
|
1287
|
+
if getv(from_object, ['model_armor_config']) is not None:
|
|
1288
|
+
setv(
|
|
1289
|
+
parent_object,
|
|
1290
|
+
['modelArmorConfig'],
|
|
1291
|
+
getv(from_object, ['model_armor_config']),
|
|
1292
|
+
)
|
|
1293
|
+
|
|
1282
1294
|
return to_object
|
|
1283
1295
|
|
|
1284
1296
|
|
|
google/genai/tests/batches/test_create_with_inlined_requests.py
CHANGED
|
@@ -42,20 +42,36 @@ _SAFETY_SETTINGS = [
|
|
|
42
42
|
},
|
|
43
43
|
]
|
|
44
44
|
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
'
|
|
48
|
-
'
|
|
45
|
+
_INLINED_REQUESTS = [
|
|
46
|
+
{
|
|
47
|
+
'contents': [{
|
|
48
|
+
'parts': [{
|
|
49
|
+
'text': 'what is the number after 1? return just the number.',
|
|
50
|
+
}],
|
|
51
|
+
'role': 'user',
|
|
49
52
|
}],
|
|
50
|
-
'
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
'
|
|
53
|
+
'metadata': {
|
|
54
|
+
'key': 'request-1',
|
|
55
|
+
},
|
|
56
|
+
'config': {
|
|
57
|
+
'safety_settings': _SAFETY_SETTINGS,
|
|
58
|
+
},
|
|
54
59
|
},
|
|
55
|
-
|
|
56
|
-
'
|
|
60
|
+
{
|
|
61
|
+
'contents': [{
|
|
62
|
+
'parts': [{
|
|
63
|
+
'text': 'what is the number after 2? return just the number.',
|
|
64
|
+
}],
|
|
65
|
+
'role': 'user',
|
|
66
|
+
}],
|
|
67
|
+
'metadata': {
|
|
68
|
+
'key': 'request-2',
|
|
69
|
+
},
|
|
70
|
+
'config': {
|
|
71
|
+
'safety_settings': _SAFETY_SETTINGS,
|
|
72
|
+
},
|
|
57
73
|
},
|
|
58
|
-
|
|
74
|
+
]
|
|
59
75
|
_INLINED_TEXT_REQUEST_UNION = {
|
|
60
76
|
'contents': [{
|
|
61
77
|
'parts': [{
|
|
@@ -154,7 +170,7 @@ test_table: list[pytest_helper.TestTableItem] = [
|
|
|
154
170
|
name='test_union_with_inlined_request',
|
|
155
171
|
parameters=types._CreateBatchJobParameters(
|
|
156
172
|
model=_MLDEV_GEMINI_MODEL,
|
|
157
|
-
src=
|
|
173
|
+
src=_INLINED_REQUESTS,
|
|
158
174
|
config={
|
|
159
175
|
'display_name': _DISPLAY_NAME,
|
|
160
176
|
},
|
|
@@ -166,7 +182,7 @@ test_table: list[pytest_helper.TestTableItem] = [
|
|
|
166
182
|
name='test_with_inlined_request',
|
|
167
183
|
parameters=types._CreateBatchJobParameters(
|
|
168
184
|
model=_MLDEV_GEMINI_MODEL,
|
|
169
|
-
src={'inlined_requests':
|
|
185
|
+
src={'inlined_requests': _INLINED_REQUESTS},
|
|
170
186
|
config={
|
|
171
187
|
'display_name': _DISPLAY_NAME,
|
|
172
188
|
},
|
|
@@ -177,7 +193,7 @@ test_table: list[pytest_helper.TestTableItem] = [
|
|
|
177
193
|
name='test_with_inlined_request_config',
|
|
178
194
|
parameters=types._CreateBatchJobParameters(
|
|
179
195
|
model=_MLDEV_GEMINI_MODEL,
|
|
180
|
-
src={'inlined_requests':
|
|
196
|
+
src={'inlined_requests': _INLINED_REQUESTS},
|
|
181
197
|
config={
|
|
182
198
|
'display_name': _DISPLAY_NAME,
|
|
183
199
|
},
|
|
@@ -247,7 +263,7 @@ async def test_async_create(client):
|
|
|
247
263
|
with pytest_helper.exception_if_vertex(client, ValueError):
|
|
248
264
|
batch_job = await client.aio.batches.create(
|
|
249
265
|
model=_GEMINI_MODEL,
|
|
250
|
-
src=
|
|
266
|
+
src=_INLINED_REQUESTS,
|
|
251
267
|
)
|
|
252
268
|
assert batch_job.name.startswith('batches/')
|
|
253
269
|
assert (
|
|
google/genai/tests/batches/test_get.py
CHANGED
|
@@ -29,7 +29,7 @@ _BATCH_JOB_FULL_RESOURCE_NAME = (
|
|
|
29
29
|
f'batchPredictionJobs/{_BATCH_JOB_NAME}'
|
|
30
30
|
)
|
|
31
31
|
# MLDev batch operation name.
|
|
32
|
-
_MLDEV_BATCH_OPERATION_NAME = 'batches/
|
|
32
|
+
_MLDEV_BATCH_OPERATION_NAME = 'batches/z2p8ksus4lyxt25rntl3fpd67p2niw4hfij5'
|
|
33
33
|
_INVALID_BATCH_JOB_NAME = 'invalid_name'
|
|
34
34
|
|
|
35
35
|
|
|
google/genai/tests/files/test_register_table.py
CHANGED
|
@@ -42,7 +42,7 @@ def get_headers():
|
|
|
42
42
|
test_table: list[pytest_helper.TestTableItem] = [
|
|
43
43
|
pytest_helper.TestTableItem(
|
|
44
44
|
name='test_register',
|
|
45
|
-
parameters=types._RegisterFilesParameters(uris=['gs://unified-genai-dev/image.jpg']),
|
|
45
|
+
parameters=types._InternalRegisterFilesParameters(uris=['gs://unified-genai-dev/image.jpg']),
|
|
46
46
|
exception_if_vertex='only supported in the Gemini Developer client',
|
|
47
47
|
skip_in_api_mode=(
|
|
48
48
|
'The files have a TTL, they cannot be reliably retrieved for a long'
|
|
google/genai/tests/models/test_generate_content.py
CHANGED
|
@@ -546,6 +546,22 @@ test_table: list[pytest_helper.TestTableItem] = [
|
|
|
546
546
|
),
|
|
547
547
|
exception_if_vertex='not supported',
|
|
548
548
|
),
|
|
549
|
+
pytest_helper.TestTableItem(
|
|
550
|
+
name='test_model_armor_config',
|
|
551
|
+
parameters=types._GenerateContentParameters(
|
|
552
|
+
model=GEMINI_FLASH_LATEST,
|
|
553
|
+
contents=t.t_contents('What is your name?'),
|
|
554
|
+
config={
|
|
555
|
+
'model_armor_config': {
|
|
556
|
+
'prompt_template_name': '',
|
|
557
|
+
'response_template_name': '',
|
|
558
|
+
# Intentionally left blank just to test that the SDK doesn't
|
|
559
|
+
# throw an exception.
|
|
560
|
+
},
|
|
561
|
+
},
|
|
562
|
+
),
|
|
563
|
+
exception_if_mldev='not supported',
|
|
564
|
+
),
|
|
549
565
|
]
|
|
550
566
|
|
|
551
567
|
pytestmark = pytest_helper.setup(
|
|
google/genai/tests/transformers/test_schema.py
CHANGED
|
@@ -79,7 +79,6 @@ class CountryInfoWithAnyOf(pydantic.BaseModel):
|
|
|
79
79
|
|
|
80
80
|
|
|
81
81
|
@pytest.fixture
|
|
82
|
-
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
83
82
|
def client(use_vertex):
|
|
84
83
|
if use_vertex:
|
|
85
84
|
yield google_genai_client_module.Client(
|
|
@@ -91,6 +90,7 @@ def client(use_vertex):
|
|
|
91
90
|
)
|
|
92
91
|
|
|
93
92
|
|
|
93
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
94
94
|
def test_build_schema_for_list_of_pydantic_schema(client):
|
|
95
95
|
"""Tests _build_schema() when list[pydantic.BaseModel] is provided to response_schema."""
|
|
96
96
|
|
|
@@ -112,6 +112,7 @@ def test_build_schema_for_list_of_pydantic_schema(client):
|
|
|
112
112
|
assert list_schema['required'] == list(country_info_fields.keys())
|
|
113
113
|
|
|
114
114
|
|
|
115
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
115
116
|
def test_build_schema_for_list_of_nested_pydantic_schema(client):
|
|
116
117
|
"""Tests _build_schema() when list[pydantic.BaseModel] is provided to response_schema and the pydantic.BaseModel has nested pydantic fields."""
|
|
117
118
|
list_schema = _transformers.t_schema(
|
|
@@ -132,6 +133,7 @@ def test_build_schema_for_list_of_nested_pydantic_schema(client):
|
|
|
132
133
|
assert field_name in currency_info_fields
|
|
133
134
|
|
|
134
135
|
|
|
136
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
135
137
|
def test_t_schema_for_pydantic_schema(client):
|
|
136
138
|
"""Tests t_schema when pydantic.BaseModel is passed to response_schema."""
|
|
137
139
|
transformed_schema = _transformers.t_schema(client, CountryInfo)
|
|
@@ -143,6 +145,7 @@ def test_t_schema_for_pydantic_schema(client):
|
|
|
143
145
|
)
|
|
144
146
|
|
|
145
147
|
|
|
148
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
146
149
|
def test_t_schema_for_list_of_pydantic_schema(client):
|
|
147
150
|
"""Tests t_schema when list[pydantic.BaseModel] is passed to response_schema."""
|
|
148
151
|
transformed_schema = _transformers.t_schema(client, list[CountryInfo])
|
|
@@ -156,6 +159,7 @@ def test_t_schema_for_list_of_pydantic_schema(client):
|
|
|
156
159
|
)
|
|
157
160
|
|
|
158
161
|
|
|
162
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
159
163
|
def test_t_schema_for_null_fields(client):
|
|
160
164
|
"""Tests t_schema when null fields are present."""
|
|
161
165
|
transformed_schema = _transformers.t_schema(client, CountryInfoWithNullFields)
|
|
@@ -163,6 +167,7 @@ def test_t_schema_for_null_fields(client):
|
|
|
163
167
|
assert transformed_schema.properties['population'].nullable
|
|
164
168
|
|
|
165
169
|
|
|
170
|
+
#@pytest.mark.parametrize('use_vertex', [True, False])
|
|
166
171
|
def test_schema_with_no_null_fields_is_unchanged():
|
|
167
172
|
"""Tests handle_null_fields() doesn't change anything when no null fields are present."""
|
|
168
173
|
test_properties = {
|
|
@@ -207,6 +212,7 @@ def test_schema_with_default_value(client):
|
|
|
207
212
|
assert transformed_schema == expected_schema
|
|
208
213
|
|
|
209
214
|
|
|
215
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
210
216
|
def test_schema_with_any_of(client):
|
|
211
217
|
transformed_schema = _transformers.t_schema(client, CountryInfoWithAnyOf)
|
|
212
218
|
expected_schema = types.Schema(
|
|
@@ -601,6 +607,7 @@ def test_process_schema_order_properties_propagates_into_any_of(
|
|
|
601
607
|
assert schema == schema_without_property_ordering
|
|
602
608
|
|
|
603
609
|
|
|
610
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
604
611
|
def test_t_schema_does_not_change_property_ordering_if_set(client):
|
|
605
612
|
"""Tests t_schema doesn't overwrite the property_ordering field if already set."""
|
|
606
613
|
|
|
@@ -612,6 +619,7 @@ def test_t_schema_does_not_change_property_ordering_if_set(client):
|
|
|
612
619
|
assert transformed_schema.property_ordering == custom_property_ordering
|
|
613
620
|
|
|
614
621
|
|
|
622
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
615
623
|
def test_t_schema_sets_property_ordering_for_json_schema(client):
|
|
616
624
|
"""Tests t_schema sets the property_ordering field for json schemas."""
|
|
617
625
|
|
|
@@ -629,6 +637,7 @@ def test_t_schema_sets_property_ordering_for_json_schema(client):
|
|
|
629
637
|
]
|
|
630
638
|
|
|
631
639
|
|
|
640
|
+
@pytest.mark.parametrize('use_vertex', [True, False])
|
|
632
641
|
def test_t_schema_sets_property_ordering_for_schema_type(client):
|
|
633
642
|
"""Tests t_schema sets the property_ordering field for Schema types."""
|
|
634
643
|
|
|
google/genai/tests/tunings/test_tune.py
CHANGED
|
@@ -20,6 +20,12 @@ from ... import types as genai_types
|
|
|
20
20
|
from .. import pytest_helper
|
|
21
21
|
import pytest
|
|
22
22
|
|
|
23
|
+
|
|
24
|
+
VERTEX_HTTP_OPTIONS = {
|
|
25
|
+
'api_version': 'v1beta1',
|
|
26
|
+
'base_url': 'https://us-central1-autopush-aiplatform.sandbox.googleapis.com/',
|
|
27
|
+
}
|
|
28
|
+
|
|
23
29
|
evaluation_config=genai_types.EvaluationConfig(
|
|
24
30
|
metrics=[
|
|
25
31
|
genai_types.Metric(name="bleu", prompt_template="test prompt template")
|
|
@@ -158,6 +164,87 @@ test_table: list[pytest_helper.TestTableItem] = [
|
|
|
158
164
|
),
|
|
159
165
|
exception_if_mldev="vertex_dataset_resource parameter is not supported in Gemini API.",
|
|
160
166
|
),
|
|
167
|
+
pytest_helper.TestTableItem(
|
|
168
|
+
name="test_tune_distillation",
|
|
169
|
+
parameters=genai_types.CreateTuningJobParameters(
|
|
170
|
+
base_model="meta/llama3_1@llama-3.1-8b-instruct",
|
|
171
|
+
training_dataset=genai_types.TuningDataset(
|
|
172
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
|
|
173
|
+
),
|
|
174
|
+
config=genai_types.CreateTuningJobConfig(
|
|
175
|
+
method="DISTILLATION",
|
|
176
|
+
base_teacher_model="deepseek-ai/deepseek-v3.1-maas",
|
|
177
|
+
epoch_count=20,
|
|
178
|
+
validation_dataset=genai_types.TuningValidationDataset(
|
|
179
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
|
|
180
|
+
),
|
|
181
|
+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder",
|
|
182
|
+
http_options=VERTEX_HTTP_OPTIONS,
|
|
183
|
+
),
|
|
184
|
+
),
|
|
185
|
+
exception_if_mldev="parameter is not supported in Gemini API.",
|
|
186
|
+
),
|
|
187
|
+
pytest_helper.TestTableItem(
|
|
188
|
+
name="test_tune_oss_sft",
|
|
189
|
+
parameters=genai_types.CreateTuningJobParameters(
|
|
190
|
+
base_model="meta/llama3_1@llama-3.1-8b-instruct",
|
|
191
|
+
training_dataset=genai_types.TuningDataset(
|
|
192
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
|
|
193
|
+
),
|
|
194
|
+
config=genai_types.CreateTuningJobConfig(
|
|
195
|
+
epoch_count=20,
|
|
196
|
+
validation_dataset=genai_types.TuningValidationDataset(
|
|
197
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
|
|
198
|
+
),
|
|
199
|
+
custom_base_model="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder/postprocess/node-0/checkpoints/final",
|
|
200
|
+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
|
|
201
|
+
http_options=VERTEX_HTTP_OPTIONS,
|
|
202
|
+
),
|
|
203
|
+
),
|
|
204
|
+
exception_if_mldev="not supported in Gemini API",
|
|
205
|
+
),
|
|
206
|
+
pytest_helper.TestTableItem(
|
|
207
|
+
name="test_tune_oss_sft_hyperparams",
|
|
208
|
+
parameters=genai_types.CreateTuningJobParameters(
|
|
209
|
+
base_model="meta/llama3_1@llama-3.1-8b-instruct",
|
|
210
|
+
training_dataset=genai_types.TuningDataset(
|
|
211
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
|
|
212
|
+
),
|
|
213
|
+
config=genai_types.CreateTuningJobConfig(
|
|
214
|
+
epoch_count=20,
|
|
215
|
+
validation_dataset=genai_types.TuningValidationDataset(
|
|
216
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
|
|
217
|
+
),
|
|
218
|
+
learning_rate=2.5e-4,
|
|
219
|
+
tuning_mode="TUNING_MODE_FULL",
|
|
220
|
+
custom_base_model="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder/postprocess/node-0/checkpoints/final",
|
|
221
|
+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
|
|
222
|
+
http_options=VERTEX_HTTP_OPTIONS,
|
|
223
|
+
),
|
|
224
|
+
),
|
|
225
|
+
exception_if_mldev="not supported in Gemini API",
|
|
226
|
+
),
|
|
227
|
+
pytest_helper.TestTableItem(
|
|
228
|
+
name="test_tune_oss_distillation",
|
|
229
|
+
parameters=genai_types.CreateTuningJobParameters(
|
|
230
|
+
base_model="meta/llama3_1@llama-3.1-8b-instruct",
|
|
231
|
+
training_dataset=genai_types.TuningDataset(
|
|
232
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
|
|
233
|
+
),
|
|
234
|
+
config=genai_types.CreateTuningJobConfig(
|
|
235
|
+
method="DISTILLATION",
|
|
236
|
+
base_teacher_model="deepseek-ai/deepseek-v3.1-maas",
|
|
237
|
+
epoch_count=20,
|
|
238
|
+
validation_dataset=genai_types.TuningValidationDataset(
|
|
239
|
+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
|
|
240
|
+
),
|
|
241
|
+
custom_base_model="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder/postprocess/node-0/checkpoints/final",
|
|
242
|
+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
|
|
243
|
+
http_options=VERTEX_HTTP_OPTIONS,
|
|
244
|
+
),
|
|
245
|
+
),
|
|
246
|
+
exception_if_mldev="not supported in Gemini API",
|
|
247
|
+
),
|
|
161
248
|
]
|
|
162
249
|
|
|
163
250
|
pytestmark = pytest_helper.setup(
|