google-genai 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_automatic_function_calling_util.py +15 -15
- google/genai/_common.py +2 -1
- google/genai/_extra_utils.py +2 -2
- google/genai/_transformers.py +31 -3
- google/genai/chats.py +24 -8
- google/genai/models.py +19 -9
- google/genai/tunings.py +15 -78
- google/genai/types.py +55 -8
- google/genai/version.py +1 -1
- {google_genai-1.0.0.dist-info → google_genai-1.1.0.dist-info}/METADATA +182 -150
- {google_genai-1.0.0.dist-info → google_genai-1.1.0.dist-info}/RECORD +14 -14
- {google_genai-1.0.0.dist-info → google_genai-1.1.0.dist-info}/LICENSE +0 -0
- {google_genai-1.0.0.dist-info → google_genai-1.1.0.dist-info}/WHEEL +0 -0
- {google_genai-1.0.0.dist-info → google_genai-1.1.0.dist-info}/top_level.txt +0 -0
google/genai/_automatic_function_calling_util.py
CHANGED
@@ -62,8 +62,8 @@ def _raise_for_default_if_mldev(schema: types.Schema):
     )


-def _raise_if_schema_unsupported(
-  if
+def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
+  if api_option == 'GEMINI_API':
     _raise_for_any_of_if_mldev(schema)
     _raise_for_default_if_mldev(schema)

@@ -114,7 +114,7 @@ def _is_default_value_compatible(


 def _parse_schema_from_parameter(
-
+    api_option: Literal['VERTEX_AI', 'GEMINI_API'],
     param: inspect.Parameter,
     func_name: str,
 ) -> types.Schema:
@@ -134,7 +134,7 @@ def _parse_schema_from_parameter(
       raise ValueError(default_value_error_msg)
     schema.default = param.default
     schema.type = _py_builtin_type_to_schema_type[param.annotation]
-    _raise_if_schema_unsupported(
+    _raise_if_schema_unsupported(api_option, schema)
     return schema
   if (
       isinstance(param.annotation, VersionedUnionType)
@@ -153,7 +153,7 @@ def _parse_schema_from_parameter(
         schema.nullable = True
         continue
       schema_in_any_of = _parse_schema_from_parameter(
-
+          api_option,
           inspect.Parameter(
               'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
           ),
@@ -175,7 +175,7 @@ def _parse_schema_from_parameter(
     if not _is_default_value_compatible(param.default, param.annotation):
       raise ValueError(default_value_error_msg)
     schema.default = param.default
-    _raise_if_schema_unsupported(
+    _raise_if_schema_unsupported(api_option, schema)
     return schema
   if isinstance(param.annotation, _GenericAlias) or isinstance(
       param.annotation, builtin_types.GenericAlias
@@ -188,7 +188,7 @@ def _parse_schema_from_parameter(
     if not _is_default_value_compatible(param.default, param.annotation):
      raise ValueError(default_value_error_msg)
     schema.default = param.default
-    _raise_if_schema_unsupported(
+    _raise_if_schema_unsupported(api_option, schema)
     return schema
   if origin is Literal:
     if not all(isinstance(arg, str) for arg in args):
@@ -201,12 +201,12 @@ def _parse_schema_from_parameter(
     if not _is_default_value_compatible(param.default, param.annotation):
       raise ValueError(default_value_error_msg)
     schema.default = param.default
-    _raise_if_schema_unsupported(
+    _raise_if_schema_unsupported(api_option, schema)
     return schema
   if origin is list:
     schema.type = 'ARRAY'
     schema.items = _parse_schema_from_parameter(
-
+        api_option,
         inspect.Parameter(
             'item',
             inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -218,7 +218,7 @@ def _parse_schema_from_parameter(
     if not _is_default_value_compatible(param.default, param.annotation):
       raise ValueError(default_value_error_msg)
     schema.default = param.default
-    _raise_if_schema_unsupported(
+    _raise_if_schema_unsupported(api_option, schema)
     return schema
   if origin is Union:
     schema.any_of = []
@@ -233,7 +233,7 @@ def _parse_schema_from_parameter(
         schema.nullable = True
         continue
       schema_in_any_of = _parse_schema_from_parameter(
-
+          api_option,
          inspect.Parameter(
              'item',
              inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -268,7 +268,7 @@ def _parse_schema_from_parameter(
     if not _is_default_value_compatible(param.default, param.annotation):
       raise ValueError(default_value_error_msg)
     schema.default = param.default
-    _raise_if_schema_unsupported(
+    _raise_if_schema_unsupported(api_option, schema)
     return schema
   # all other generic alias will be invoked in raise branch
   if (
@@ -284,7 +284,7 @@ def _parse_schema_from_parameter(
     schema.properties = {}
     for field_name, field_info in param.annotation.model_fields.items():
       schema.properties[field_name] = _parse_schema_from_parameter(
-
+          api_option,
           inspect.Parameter(
               field_name,
               inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -292,9 +292,9 @@ def _parse_schema_from_parameter(
          ),
          func_name,
      )
-    if
+    if api_option == 'VERTEX_AI':
      schema.required = _get_required_fields(schema)
-    _raise_if_schema_unsupported(
+    _raise_if_schema_unsupported(api_option, schema)
    return schema
  raise ValueError(
      f'Failed to parse the parameter {param} of function {func_name} for'
google/genai/_common.py
CHANGED
@@ -220,7 +220,8 @@ class CaseInSensitiveEnum(str, enum.Enum):
       warnings.warn(f"{value} is not a valid {cls.__name__}")
       try:
         # Creating a enum instance based on the value
-
+        # We need to use super() to avoid infinite recursion.
+        unknown_enum_val = super().__new__(cls, value)
         unknown_enum_val._name_ = str(value)  # pylint: disable=protected-access
         unknown_enum_val._value_ = value  # pylint: disable=protected-access
         return unknown_enum_val
google/genai/_extra_utils.py
CHANGED
@@ -271,8 +271,8 @@ def should_disable_afc(
       and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
   ):
     logging.warning(
-        '`automatic_function_calling.disable` is set to `True`.
-        ' `automatic_function_calling.maximum_remote_calls` is
+        '`automatic_function_calling.disable` is set to `True`. And'
+        ' `automatic_function_calling.maximum_remote_calls` is a'
         ' positive number'
         f' {config_model.automatic_function_calling.maximum_remote_calls}.'
         ' Disabling automatic function calling. If you want to enable'
google/genai/_transformers.py
CHANGED
@@ -374,8 +374,8 @@ def handle_null_fields(schema: dict[str, Any]):
     schema['anyOf'].remove({'type': 'null'})
     if len(schema['anyOf']) == 1:
       # If there is only one type left after removing null, remove the anyOf field.
-
-
+      for key,val in schema['anyOf'][0].items():
+        schema[key] = val
       del schema['anyOf']


@@ -446,9 +446,17 @@ def process_schema(

   if schema.get('default') is not None:
     raise ValueError(
-        'Default value is not supported in the response schema for the
+        'Default value is not supported in the response schema for the Gemini API.'
     )

+  if schema.get('title') == 'PlaceholderLiteralEnum':
+    schema.pop('title', None)
+
+  # If a dict is provided directly to response_schema, it may use `any_of`
+  # instead of `anyOf`. Otherwise model_json_schema() uses `anyOf`
+  if schema.get('any_of', None) is not None:
+    schema['anyOf'] = schema.pop('any_of')
+
   if defs is None:
     defs = schema.pop('$defs', {})
     for _, sub_schema in defs.items():
@@ -456,6 +464,15 @@ def process_schema(

   handle_null_fields(schema)

+  # After removing null fields, Optional fields with only one possible type
+  # will have a $ref key that needs to be flattened
+  # For example: {'default': None, 'description': 'Name of the person', 'nullable': True, '$ref': '#/$defs/TestPerson'}
+  if schema.get('$ref', None):
+    ref = defs[schema.get('$ref').split('defs/')[-1]]
+    for schema_key in list(ref.keys()):
+      schema[schema_key] = ref[schema_key]
+    del schema['$ref']
+
   any_of = schema.get('anyOf', None)
   if any_of is not None:
     if not client.vertexai:
@@ -478,6 +495,16 @@ def process_schema(
       schema_type = schema_type.value
     schema_type = schema_type.upper()

+  # model_json_schema() returns a schema with a 'const' field when a Literal with one value is provided as a pydantic field
+  # For example `genre: Literal['action']` becomes: {'const': 'action', 'title': 'Genre', 'type': 'string'}
+  const = schema.get('const', None)
+  if const is not None:
+    if schema_type == 'STRING':
+      schema['enum'] = [const]
+      del schema['const']
+    else:
+      raise ValueError('Literal values must be strings.')
+
   if schema_type == 'OBJECT':
     properties = schema.get('properties', None)
     if properties is None:
@@ -502,6 +529,7 @@ def process_schema(
       process_schema(ref, client, defs)
     schema['items'] = ref

+
 def _process_enum(
     enum: EnumMeta, client: Optional[_api_client.ApiClient] = None
 ) -> types.Schema:
google/genai/chats.py
CHANGED
@@ -57,12 +57,16 @@ class Chat(_BaseChat):
   """Chat session."""

   def send_message(
-      self,
+      self,
+      message: Union[list[PartUnionDict], PartUnionDict],
+      config: Optional[GenerateContentConfigOrDict] = None,
   ) -> GenerateContentResponse:
     """Sends the conversation history with the additional message and returns the model's response.

     Args:
       message: The message to send to the model.
+      config: Optional config to override the default Chat config for this
+        request.

     Returns:
       The model's response.
@@ -79,7 +83,7 @@ class Chat(_BaseChat):
     response = self._modules.generate_content(
         model=self._model,
         contents=self._curated_history + [input_content],
-        config=self._config,
+        config=config if config else self._config,
     )
     if _validate_response(response):
       if response.automatic_function_calling_history:
@@ -92,12 +96,16 @@ class Chat(_BaseChat):
     return response

   def send_message_stream(
-      self,
+      self,
+      message: Union[list[PartUnionDict], PartUnionDict],
+      config: Optional[GenerateContentConfigOrDict] = None,
   ):
     """Sends the conversation history with the additional message and yields the model's response in chunks.

     Args:
       message: The message to send to the model.
+      config: Optional config to override the default Chat config for this
+        request.

     Yields:
       The model's response in chunks.
@@ -117,7 +125,7 @@ class Chat(_BaseChat):
     for chunk in self._modules.generate_content_stream(
         model=self._model,
         contents=self._curated_history + [input_content],
-        config=self._config,
+        config=config if config else self._config,
     ):
       if _validate_response(chunk):
         output_contents.append(chunk.candidates[0].content)
@@ -164,12 +172,16 @@ class AsyncChat(_BaseChat):
   """Async chat session."""

   async def send_message(
-      self,
+      self,
+      message: Union[list[PartUnionDict], PartUnionDict],
+      config: Optional[GenerateContentConfigOrDict] = None,
   ) -> GenerateContentResponse:
     """Sends the conversation history with the additional message and returns model's response.

     Args:
       message: The message to send to the model.
+      config: Optional config to override the default Chat config for this
+        request.

     Returns:
       The model's response.
@@ -186,7 +198,7 @@ class AsyncChat(_BaseChat):
     response = await self._modules.generate_content(
         model=self._model,
         contents=self._curated_history + [input_content],
-        config=self._config,
+        config=config if config else self._config,
     )
     if _validate_response(response):
       if response.automatic_function_calling_history:
@@ -199,12 +211,16 @@ class AsyncChat(_BaseChat):
     return response

   async def send_message_stream(
-      self,
+      self,
+      message: Union[list[PartUnionDict], PartUnionDict],
+      config: Optional[GenerateContentConfigOrDict] = None,
   ) -> Awaitable[AsyncIterator[GenerateContentResponse]]:
     """Sends the conversation history with the additional message and yields the model's response in chunks.

     Args:
       message: The message to send to the model.
+      config: Optional config to override the default Chat config for this
+        request.

     Yields:
       The model's response in chunks.
@@ -225,7 +241,7 @@ class AsyncChat(_BaseChat):
     async for chunk in await self._modules.generate_content_stream(
         model=self._model,
         contents=self._curated_history + [input_content],
-        config=self._config,
+        config=config if config else self._config,
     ):
       if _validate_response(chunk):
         output_contents.append(chunk.candidates[0].content)
google/genai/models.py
CHANGED
@@ -4784,15 +4784,18 @@ class Models(_api_module.BaseModule):
     """

     if _extra_utils.should_disable_afc(config):
-
+      yield from self._generate_content_stream(
          model=model, contents=contents, config=config
      )
+      return
+
     remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config)
     logging.info(
         f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.'
     )
     automatic_function_calling_history = []
     chunk = None
+    func_response_parts = None
     i = 0
     while remaining_remote_calls_afc > 0:
       i += 1
@@ -4828,6 +4831,10 @@ class Models(_api_module.BaseModule):
            automatic_function_calling_history
        )
      yield chunk
+      func_response_parts = _extra_utils.get_function_response_parts(
+          chunk, function_map
+      )
+
      if not chunk:
        break
      if (
@@ -4840,9 +4847,6 @@ class Models(_api_module.BaseModule):

      if not function_map:
        break
-      func_response_parts = _extra_utils.get_function_response_parts(
-          chunk, function_map
-      )
      if not func_response_parts:
        break

@@ -5930,10 +5934,16 @@ class AsyncModels(_api_module.BaseModule):
     """

     if _extra_utils.should_disable_afc(config):
-
+      response = await self._generate_content_stream(
          model=model, contents=contents, config=config
      )

+      async def base_async_generator(model, contents, config):
+        async for chunk in response:
+          yield chunk
+
+      return base_async_generator(model, contents, config)
+
     async def async_generator(model, contents, config):
       remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config)
       logging.info(
@@ -5979,7 +5989,9 @@ class AsyncModels(_api_module.BaseModule):
            automatic_function_calling_history
        )
        yield chunk
-
+        func_response_parts = _extra_utils.get_function_response_parts(
+            chunk, function_map
+        )
        if not chunk:
          break
        if (
@@ -5991,9 +6003,7 @@ class AsyncModels(_api_module.BaseModule):
          break
        if not function_map:
          break
-
-          chunk, function_map
-      )
+
        if not func_response_parts:
          break

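
Note: the disabled-AFC branch above now delegates with `yield from` and exits with a bare `return`, which keeps the method a well-formed generator while skipping the AFC loop entirely. A small standalone sketch of that control-flow pattern (names are illustrative):

```python
from typing import Iterator


def stream_items(limit: int, passthrough: bool) -> Iterator[int]:
  if passthrough:
    # Delegate to the underlying stream, then stop; a bare `return`
    # inside a generator simply ends the iteration.
    yield from range(limit)
    return
  # Otherwise do extra per-item work before yielding.
  for n in range(limit):
    yield n * 2


assert list(stream_items(3, passthrough=True)) == [0, 1, 2]
assert list(stream_items(3, passthrough=False)) == [0, 2, 4]
```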
google/genai/tunings.py
CHANGED
@@ -15,8 +15,7 @@

 # Code generated by the Google Gen AI SDK generator DO NOT EDIT.

-import
-from typing import Any, Optional, Union
+from typing import Optional, Union
 from urllib.parse import urlencode
 from . import _api_module
 from . import _common
@@ -998,11 +997,14 @@ class Tunings(_api_module.BaseModule):
         config=config,
     )
     operation_dict = operation.to_json_dict()
-
-
-
+    try:
+      tuned_model_name = operation_dict['metadata']['tunedModel']
+    except KeyError:
+      tuned_model_name = operation_dict['name'].partition('/operations/')[0]
+    tuning_job = types.TuningJob(
+        name=tuned_model_name,
+        state=types.JobState.JOB_STATE_QUEUED,
     )
-    tuning_job = types.TuningJob._from_response(tuning_job_dict, None)
     if tuning_job.name and self._api_client.vertexai:
       _IpythonUtils.display_model_tuning_button(
           tuning_job_resource=tuning_job.name
@@ -1010,72 +1012,6 @@ class Tunings(_api_module.BaseModule):
     return tuning_job


-_LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
-_LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
-_LRO_POLLING_TIMEOUT_SECONDS = 900.0
-_LRO_POLLING_MULTIPLIER = 1.5
-
-
-def _resolve_operation(api_client: ApiClient, struct: dict[str, Any]):
-  if (name := struct.get('name')) and '/operations/' in name:
-    operation: dict[str, Any] = struct
-    total_seconds = 0.0
-    delay_seconds = _LRO_POLLING_INITIAL_DELAY_SECONDS
-    while not operation.get('done'):
-      if total_seconds > _LRO_POLLING_TIMEOUT_SECONDS:
-        raise RuntimeError(f'Operation {name} timed out.\n{operation}')
-      # TODO(b/374433890): Replace with LRO module once it's available.
-      operation: dict[str, Any] = api_client.request(
-          http_method='GET', path=name, request_dict={}
-      )
-      if 'ReplayApiClient' not in type(api_client).__name__:
-        time.sleep(delay_seconds)
-      total_seconds += total_seconds
-      # Exponential backoff
-      delay_seconds = min(
-          delay_seconds * _LRO_POLLING_MULTIPLIER,
-          _LRO_POLLING_MAXIMUM_DELAY_SECONDS,
-      )
-    if error := operation.get('error'):
-      raise RuntimeError(
-          f'Operation {name} failed with error: {error}.\n{operation}'
-      )
-    return operation.get('response')
-  else:
-    return struct
-
-
-async def _resolve_operation_async(
-    api_client: ApiClient, struct: dict[str, Any]
-):
-  if (name := struct.get('name')) and '/operations/' in name:
-    operation: dict[str, Any] = struct
-    total_seconds = 0.0
-    delay_seconds = _LRO_POLLING_INITIAL_DELAY_SECONDS
-    while not operation.get('done'):
-      if total_seconds > _LRO_POLLING_TIMEOUT_SECONDS:
-        raise RuntimeError(f'Operation {name} timed out.\n{operation}')
-      # TODO(b/374433890): Replace with LRO module once it's available.
-      operation: dict[str, Any] = await api_client.async_request(
-          http_method='GET', path=name, request_dict={}
-      )
-      if 'ReplayApiClient' not in type(api_client).__name__:
-        time.sleep(delay_seconds)
-      total_seconds += total_seconds
-      # Exponential backoff
-      delay_seconds = min(
-          delay_seconds * _LRO_POLLING_MULTIPLIER,
-          _LRO_POLLING_MAXIMUM_DELAY_SECONDS,
-      )
-    if error := operation.get('error'):
-      raise RuntimeError(
-          f'Operation {name} failed with error: {error}.\n{operation}'
-      )
-    return operation.get('response')
-  else:
-    return struct
-
-
 class AsyncTunings(_api_module.BaseModule):

   async def _get(
@@ -1370,13 +1306,14 @@ class AsyncTunings(_api_module.BaseModule):
         config=config,
     )
     operation_dict = operation.to_json_dict()
-
-
-
-
-
+    try:
+      tuned_model_name = operation_dict['metadata']['tunedModel']
+    except KeyError:
+      tuned_model_name = operation_dict['name'].partition('/operations/')[0]
+    tuning_job = types.TuningJob(
+        name=tuned_model_name,
+        state=types.JobState.JOB_STATE_QUEUED,
     )
-    tuning_job = types.TuningJob._from_response(tuning_job_dict, None)
     if tuning_job.name and self._api_client.vertexai:
       _IpythonUtils.display_model_tuning_button(
           tuning_job_resource=tuning_job.name
google/genai/types.py
CHANGED
@@ -971,13 +971,24 @@ class FunctionDeclaration(_common.BaseModel):
   )

   @classmethod
-  def
+  def from_callable_with_api_option(
       cls,
       *,
-      client,
       callable: Callable,
+      api_option: Literal['VERTEX_AI', 'GEMINI_API'] = 'GEMINI_API',
   ) -> 'FunctionDeclaration':
-    """Converts a Callable to a FunctionDeclaration based on the
+    """Converts a Callable to a FunctionDeclaration based on the API option.
+
+    Supported API option is 'VERTEX_AI' or 'GEMINI_API'. If api_option is unset,
+    it will default to 'GEMINI_API'. If unsupported api_option is provided, it
+    will raise ValueError.
+    """
+    supported_api_options = ['VERTEX_AI', 'GEMINI_API']
+    if api_option not in supported_api_options:
+      raise ValueError(
+          f'Unsupported api_option value: {api_option}. Supported api_option'
+          f' value is one of: {supported_api_options}.'
+      )
     from . import _automatic_function_calling_util

     parameters_properties = {}
@@ -988,7 +999,7 @@ class FunctionDeclaration(_common.BaseModel):
         inspect.Parameter.POSITIONAL_ONLY,
     ):
       schema = _automatic_function_calling_util._parse_schema_from_parameter(
-
+          api_option, param, callable.__name__
       )
       parameters_properties[name] = schema
     declaration = FunctionDeclaration(
@@ -1000,13 +1011,13 @@ class FunctionDeclaration(_common.BaseModel):
         type='OBJECT',
         properties=parameters_properties,
     )
-    if
+    if api_option == 'VERTEX_AI':
      declaration.parameters.required = (
          _automatic_function_calling_util._get_required_fields(
              declaration.parameters
          )
      )
-    if
+    if api_option == 'GEMINI_API':
      return declaration

    return_annotation = inspect.signature(callable).return_annotation
@@ -1015,7 +1026,7 @@ class FunctionDeclaration(_common.BaseModel):

     declaration.response = (
         _automatic_function_calling_util._parse_schema_from_parameter(
-
+            api_option,
             inspect.Parameter(
                 'return_value',
                 inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -1026,6 +1037,23 @@ class FunctionDeclaration(_common.BaseModel):
     )
     return declaration

+  @classmethod
+  def from_callable(
+      cls,
+      *,
+      client,
+      callable: Callable,
+  ) -> 'FunctionDeclaration':
+    """Converts a Callable to a FunctionDeclaration based on the client."""
+    if client.vertexai:
+      return cls.from_callable_with_api_option(
+          callable=callable, api_option='VERTEX_AI'
+      )
+    else:
+      return cls.from_callable_with_api_option(
+          callable=callable, api_option='GEMINI_API'
+      )
+

 class FunctionDeclarationDict(TypedDict, total=False):
   """Defines a function that the model can generate JSON inputs for.
@@ -1909,6 +1937,20 @@ class GenerateContentConfig(_common.BaseModel):
       """,
   )

+  @pydantic.field_validator('response_schema', mode='before')
+  @classmethod
+  def _convert_literal_to_enum(cls, value):
+    if typing.get_origin(value) is typing.Literal:
+      enum_vals = typing.get_args(value)
+      if not all(isinstance(arg, str) for arg in enum_vals):
+        # This doesn't stop execution, it tells pydantic to raise a ValidationError
+        # when the class is instantiated with an unsupported Literal
+        raise ValueError(f'Literal type {value} must be a list of strings.')
+      # The title 'PlaceholderLiteralEnum' is removed from the generated Schema
+      # before sending the request
+      return Enum('PlaceholderLiteralEnum', {s: s for s in enum_vals})
+    return value
+

 class GenerateContentConfigDict(TypedDict, total=False):
   """Optional model configuration parameters.
@@ -2815,7 +2857,7 @@ class GenerateContentResponse(_common.BaseModel):
       text = ''
       any_text_part_text = False
       for part in self.candidates[0].content.parts:
-      for field_name, field_value in part.
+        for field_name, field_value in part.model_dump(
            exclude={'text', 'thought'}
        ).items():
          if field_value is not None:
@@ -2883,6 +2925,11 @@ class GenerateContentResponse(_common.BaseModel):
        enum_value = result.text.replace('"', '')
        try:
          result.parsed = response_schema(enum_value)
+          if (
+              hasattr(response_schema, '__name__')
+              and response_schema.__name__ == 'PlaceholderLiteralEnum'
+          ):
+            result.parsed = str(response_schema(enum_value).name)
        except ValueError:
          pass
      elif isinstance(response_schema, GenericAlias) or isinstance(
google/genai/version.py
CHANGED