google-genai 1.4.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +207 -111
- google/genai/_automatic_function_calling_util.py +6 -16
- google/genai/_common.py +5 -2
- google/genai/_extra_utils.py +62 -47
- google/genai/_replay_api_client.py +70 -2
- google/genai/_transformers.py +98 -57
- google/genai/batches.py +14 -10
- google/genai/caches.py +30 -36
- google/genai/client.py +3 -2
- google/genai/errors.py +11 -19
- google/genai/files.py +28 -15
- google/genai/live.py +276 -93
- google/genai/models.py +201 -112
- google/genai/operations.py +40 -12
- google/genai/pagers.py +17 -10
- google/genai/tunings.py +40 -30
- google/genai/types.py +146 -58
- google/genai/version.py +1 -1
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/METADATA +194 -24
- google_genai-1.6.0.dist-info/RECORD +27 -0
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/WHEEL +1 -1
- google_genai-1.4.0.dist-info/RECORD +0 -27
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/LICENSE +0 -0
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/top_level.txt +0 -0
google/genai/_extra_utils.py
CHANGED
@@ -17,9 +17,9 @@
|
|
17
17
|
|
18
18
|
import inspect
|
19
19
|
import logging
|
20
|
+
import sys
|
20
21
|
import typing
|
21
22
|
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
|
22
|
-
import sys
|
23
23
|
|
24
24
|
import pydantic
|
25
25
|
|
@@ -37,6 +37,15 @@ _DEFAULT_MAX_REMOTE_CALLS_AFC = 10
|
|
37
37
|
logger = logging.getLogger('google_genai.models')
|
38
38
|
|
39
39
|
|
40
|
+
def _create_generate_content_config_model(
|
41
|
+
config: types.GenerateContentConfigOrDict,
|
42
|
+
) -> types.GenerateContentConfig:
|
43
|
+
if isinstance(config, dict):
|
44
|
+
return types.GenerateContentConfig(**config)
|
45
|
+
else:
|
46
|
+
return config
|
47
|
+
|
48
|
+
|
40
49
|
def format_destination(
|
41
50
|
src: str,
|
42
51
|
config: Optional[types.CreateBatchJobConfigOrDict] = None,
|
@@ -69,16 +78,12 @@ def format_destination(
|
|
69
78
|
|
70
79
|
def get_function_map(
|
71
80
|
config: Optional[types.GenerateContentConfigOrDict] = None,
|
72
|
-
) -> dict[str,
|
81
|
+
) -> dict[str, Callable]:
|
73
82
|
"""Returns a function map from the config."""
|
74
|
-
|
75
|
-
|
76
|
-
if config and isinstance(config, dict)
|
77
|
-
else config
|
78
|
-
)
|
79
|
-
function_map: dict[str, object] = {}
|
80
|
-
if not config_model:
|
83
|
+
function_map: dict[str, Callable] = {}
|
84
|
+
if not config:
|
81
85
|
return function_map
|
86
|
+
config_model = _create_generate_content_config_model(config)
|
82
87
|
if config_model.tools:
|
83
88
|
for tool in config_model.tools:
|
84
89
|
if callable(tool):
|
@@ -92,6 +97,16 @@ def get_function_map(
|
|
92
97
|
return function_map
|
93
98
|
|
94
99
|
|
100
|
+
def convert_number_values_for_dict_function_call_args(
|
101
|
+
args: dict[str, Any],
|
102
|
+
) -> dict[str, Any]:
|
103
|
+
"""Converts float values in dict with no decimal to integers."""
|
104
|
+
return {
|
105
|
+
key: convert_number_values_for_function_call_args(value)
|
106
|
+
for key, value in args.items()
|
107
|
+
}
|
108
|
+
|
109
|
+
|
95
110
|
def convert_number_values_for_function_call_args(
|
96
111
|
args: Union[dict[str, object], list[object], object],
|
97
112
|
) -> Union[dict[str, object], list[object], object]:
|
@@ -210,26 +225,35 @@ def invoke_function_from_dict_args(
|
|
210
225
|
|
211
226
|
def get_function_response_parts(
|
212
227
|
response: types.GenerateContentResponse,
|
213
|
-
function_map: dict[str,
|
228
|
+
function_map: dict[str, Callable],
|
214
229
|
) -> list[types.Part]:
|
215
230
|
"""Returns the function response parts from the response."""
|
216
231
|
func_response_parts = []
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
232
|
+
if (
|
233
|
+
response.candidates is not None
|
234
|
+
and isinstance(response.candidates[0].content, types.Content)
|
235
|
+
and response.candidates[0].content.parts is not None
|
236
|
+
):
|
237
|
+
for part in response.candidates[0].content.parts:
|
238
|
+
if not part.function_call:
|
239
|
+
continue
|
240
|
+
func_name = part.function_call.name
|
241
|
+
if func_name is not None and part.function_call.args is not None:
|
242
|
+
func = function_map[func_name]
|
243
|
+
args = convert_number_values_for_dict_function_call_args(
|
244
|
+
part.function_call.args
|
245
|
+
)
|
246
|
+
func_response: dict[str, Any]
|
247
|
+
try:
|
248
|
+
func_response = {
|
249
|
+
'result': invoke_function_from_dict_args(args, func)
|
250
|
+
}
|
251
|
+
except Exception as e: # pylint: disable=broad-except
|
252
|
+
func_response = {'error': str(e)}
|
253
|
+
func_response_part = types.Part.from_function_response(
|
254
|
+
name=func_name, response=func_response
|
255
|
+
)
|
256
|
+
func_response_parts.append(func_response_part)
|
233
257
|
return func_response_parts
|
234
258
|
|
235
259
|
|
@@ -237,12 +261,9 @@ def should_disable_afc(
|
|
237
261
|
config: Optional[types.GenerateContentConfigOrDict] = None,
|
238
262
|
) -> bool:
|
239
263
|
"""Returns whether automatic function calling is enabled."""
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
else config
|
244
|
-
)
|
245
|
-
|
264
|
+
if not config:
|
265
|
+
return False
|
266
|
+
config_model = _create_generate_content_config_model(config)
|
246
267
|
# If max_remote_calls is less or equal to 0, warn and disable AFC.
|
247
268
|
if (
|
248
269
|
config_model
|
@@ -261,8 +282,7 @@ def should_disable_afc(
|
|
261
282
|
|
262
283
|
# Default to enable AFC if not specified.
|
263
284
|
if (
|
264
|
-
not config_model
|
265
|
-
or not config_model.automatic_function_calling
|
285
|
+
not config_model.automatic_function_calling
|
266
286
|
or config_model.automatic_function_calling.disable is None
|
267
287
|
):
|
268
288
|
return False
|
@@ -295,20 +315,17 @@ def should_disable_afc(
|
|
295
315
|
def get_max_remote_calls_afc(
|
296
316
|
config: Optional[types.GenerateContentConfigOrDict] = None,
|
297
317
|
) -> int:
|
318
|
+
if not config:
|
319
|
+
return _DEFAULT_MAX_REMOTE_CALLS_AFC
|
298
320
|
"""Returns the remaining remote calls for automatic function calling."""
|
299
321
|
if should_disable_afc(config):
|
300
322
|
raise ValueError(
|
301
323
|
'automatic function calling is not enabled, but SDK is trying to get'
|
302
324
|
' max remote calls.'
|
303
325
|
)
|
304
|
-
config_model = (
|
305
|
-
types.GenerateContentConfig(**config)
|
306
|
-
if config and isinstance(config, dict)
|
307
|
-
else config
|
308
|
-
)
|
326
|
+
config_model = _create_generate_content_config_model(config)
|
309
327
|
if (
|
310
|
-
not config_model
|
311
|
-
or not config_model.automatic_function_calling
|
328
|
+
not config_model.automatic_function_calling
|
312
329
|
or config_model.automatic_function_calling.maximum_remote_calls is None
|
313
330
|
):
|
314
331
|
return _DEFAULT_MAX_REMOTE_CALLS_AFC
|
@@ -318,11 +335,9 @@ def get_max_remote_calls_afc(
|
|
318
335
|
def should_append_afc_history(
|
319
336
|
config: Optional[types.GenerateContentConfigOrDict] = None,
|
320
337
|
) -> bool:
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
)
|
326
|
-
if not config_model or not config_model.automatic_function_calling:
|
338
|
+
if not config:
|
339
|
+
return True
|
340
|
+
config_model = _create_generate_content_config_model(config)
|
341
|
+
if not config_model.automatic_function_calling:
|
327
342
|
return True
|
328
343
|
return not config_model.automatic_function_calling.ignore_call_history
|
@@ -109,7 +109,8 @@ def _redact_project_location_path(path: str) -> str:
|
|
109
109
|
return path
|
110
110
|
|
111
111
|
|
112
|
-
def _redact_request_body(body: dict[str, object])
|
112
|
+
def _redact_request_body(body: dict[str, object]):
|
113
|
+
"""Redacts fields in the request body in place."""
|
113
114
|
for key, value in body.items():
|
114
115
|
if isinstance(value, str):
|
115
116
|
body[key] = _redact_project_location_path(value)
|
@@ -302,13 +303,24 @@ class ReplayApiClient(BaseApiClient):
|
|
302
303
|
status_code=http_response.status_code,
|
303
304
|
sdk_response_segments=[],
|
304
305
|
)
|
305
|
-
|
306
|
+
elif isinstance(http_response, errors.APIError):
|
306
307
|
response = ReplayResponse(
|
307
308
|
headers=dict(http_response.response.headers),
|
308
309
|
body_segments=[http_response._to_replay_record()],
|
309
310
|
status_code=http_response.code,
|
310
311
|
sdk_response_segments=[],
|
311
312
|
)
|
313
|
+
elif isinstance(http_response, bytes):
|
314
|
+
response = ReplayResponse(
|
315
|
+
headers={},
|
316
|
+
body_segments=[],
|
317
|
+
byte_segments=[http_response],
|
318
|
+
sdk_response_segments=[],
|
319
|
+
)
|
320
|
+
else:
|
321
|
+
raise ValueError(
|
322
|
+
'Unsupported http_response type: ' + str(type(http_response))
|
323
|
+
)
|
312
324
|
self.replay_session.interactions.append(
|
313
325
|
ReplayInteraction(request=request, response=response)
|
314
326
|
)
|
@@ -471,6 +483,43 @@ class ReplayApiClient(BaseApiClient):
|
|
471
483
|
else:
|
472
484
|
return self._build_response_from_replay(request).json
|
473
485
|
|
486
|
+
async def async_upload_file(
|
487
|
+
self,
|
488
|
+
file_path: Union[str, io.IOBase],
|
489
|
+
upload_url: str,
|
490
|
+
upload_size: int,
|
491
|
+
) -> str:
|
492
|
+
if isinstance(file_path, io.IOBase):
|
493
|
+
offset = file_path.tell()
|
494
|
+
content = file_path.read()
|
495
|
+
file_path.seek(offset, os.SEEK_SET)
|
496
|
+
request = HttpRequest(
|
497
|
+
method='POST',
|
498
|
+
url='',
|
499
|
+
data={'bytes': base64.b64encode(content).decode('utf-8')},
|
500
|
+
headers={},
|
501
|
+
)
|
502
|
+
else:
|
503
|
+
request = HttpRequest(
|
504
|
+
method='POST', url='', data={'file_path': file_path}, headers={}
|
505
|
+
)
|
506
|
+
if self._should_call_api():
|
507
|
+
result: Union[str, HttpResponse]
|
508
|
+
try:
|
509
|
+
result = await super().async_upload_file(
|
510
|
+
file_path, upload_url, upload_size
|
511
|
+
)
|
512
|
+
except HTTPError as e:
|
513
|
+
result = HttpResponse(
|
514
|
+
e.response.headers, [json.dumps({'reason': e.response.reason})]
|
515
|
+
)
|
516
|
+
result.status_code = e.response.status_code
|
517
|
+
raise e
|
518
|
+
self._record_interaction(request, HttpResponse({}, [json.dumps(result)]))
|
519
|
+
return result
|
520
|
+
else:
|
521
|
+
return self._build_response_from_replay(request).json
|
522
|
+
|
474
523
|
def _download_file_request(self, request):
|
475
524
|
self._initialize_replay_session_if_not_loaded()
|
476
525
|
if self._should_call_api():
|
@@ -486,3 +535,22 @@ class ReplayApiClient(BaseApiClient):
|
|
486
535
|
return result
|
487
536
|
else:
|
488
537
|
return self._build_response_from_replay(request)
|
538
|
+
|
539
|
+
async def async_download_file(self, path: str, http_options):
|
540
|
+
self._initialize_replay_session_if_not_loaded()
|
541
|
+
request = self._build_request(
|
542
|
+
'get', path=path, request_dict={}, http_options=http_options
|
543
|
+
)
|
544
|
+
if self._should_call_api():
|
545
|
+
try:
|
546
|
+
result = await super().async_download_file(path, http_options)
|
547
|
+
except HTTPError as e:
|
548
|
+
result = HttpResponse(
|
549
|
+
e.response.headers, [json.dumps({'reason': e.response.reason})]
|
550
|
+
)
|
551
|
+
result.status_code = e.response.status_code
|
552
|
+
raise e
|
553
|
+
self._record_interaction(request, result)
|
554
|
+
return result
|
555
|
+
else:
|
556
|
+
return self._build_response_from_replay(request).byte_stream[0]
|
google/genai/_transformers.py
CHANGED
@@ -26,9 +26,7 @@ import sys
|
|
26
26
|
import time
|
27
27
|
import types as builtin_types
|
28
28
|
import typing
|
29
|
-
from typing import Any, GenericAlias, Optional, Union
|
30
|
-
|
31
|
-
import types as builtin_types
|
29
|
+
from typing import Any, GenericAlias, Optional, Union # type: ignore[attr-defined]
|
32
30
|
|
33
31
|
if typing.TYPE_CHECKING:
|
34
32
|
import PIL.Image
|
@@ -43,10 +41,11 @@ logger = logging.getLogger('google_genai._transformers')
|
|
43
41
|
if sys.version_info >= (3, 10):
|
44
42
|
VersionedUnionType = builtin_types.UnionType
|
45
43
|
_UNION_TYPES = (typing.Union, builtin_types.UnionType)
|
44
|
+
from typing import TypeGuard
|
46
45
|
else:
|
47
46
|
VersionedUnionType = typing._UnionGenericAlias
|
48
47
|
_UNION_TYPES = (typing.Union,)
|
49
|
-
|
48
|
+
from typing_extensions import TypeGuard
|
50
49
|
|
51
50
|
def _resource_name(
|
52
51
|
client: _api_client.BaseApiClient,
|
@@ -165,7 +164,9 @@ def t_model(client: _api_client.BaseApiClient, model: str):
|
|
165
164
|
return f'models/{model}'
|
166
165
|
|
167
166
|
|
168
|
-
def t_models_url(
|
167
|
+
def t_models_url(
|
168
|
+
api_client: _api_client.BaseApiClient, base_models: bool
|
169
|
+
) -> str:
|
169
170
|
if api_client.vertexai:
|
170
171
|
if base_models:
|
171
172
|
return 'publishers/google/models'
|
@@ -179,8 +180,9 @@ def t_models_url(api_client: _api_client.BaseApiClient, base_models: bool) -> st
|
|
179
180
|
|
180
181
|
|
181
182
|
def t_extract_models(
|
182
|
-
api_client: _api_client.BaseApiClient,
|
183
|
-
|
183
|
+
api_client: _api_client.BaseApiClient,
|
184
|
+
response: dict[str, list[types.ModelDict]],
|
185
|
+
) -> Optional[list[types.ModelDict]]:
|
184
186
|
if not response:
|
185
187
|
return []
|
186
188
|
elif response.get('models') is not None:
|
@@ -240,9 +242,7 @@ def pil_to_blob(img) -> types.Blob:
|
|
240
242
|
return types.Blob(mime_type=mime_type, data=data)
|
241
243
|
|
242
244
|
|
243
|
-
def t_part(
|
244
|
-
client: _api_client.BaseApiClient, part: Optional[types.PartUnionDict]
|
245
|
-
) -> types.Part:
|
245
|
+
def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
|
246
246
|
try:
|
247
247
|
import PIL.Image
|
248
248
|
|
@@ -268,22 +268,21 @@ def t_part(
|
|
268
268
|
|
269
269
|
|
270
270
|
def t_parts(
|
271
|
-
|
272
|
-
parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict]],
|
271
|
+
parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict, list[types.Part]]],
|
273
272
|
) -> list[types.Part]:
|
274
273
|
#
|
275
274
|
if parts is None or (isinstance(parts, list) and not parts):
|
276
275
|
raise ValueError('content parts are required.')
|
277
276
|
if isinstance(parts, list):
|
278
|
-
return [t_part(
|
277
|
+
return [t_part(part) for part in parts]
|
279
278
|
else:
|
280
|
-
return [t_part(
|
279
|
+
return [t_part(parts)]
|
281
280
|
|
282
281
|
|
283
282
|
def t_image_predictions(
|
284
283
|
client: _api_client.BaseApiClient,
|
285
284
|
predictions: Optional[Iterable[Mapping[str, Any]]],
|
286
|
-
) -> list[types.GeneratedImage]:
|
285
|
+
) -> Optional[list[types.GeneratedImage]]:
|
287
286
|
if not predictions:
|
288
287
|
return None
|
289
288
|
images = []
|
@@ -333,22 +332,35 @@ def t_content(
|
|
333
332
|
def t_contents_for_embed(
|
334
333
|
client: _api_client.BaseApiClient,
|
335
334
|
contents: Union[list[types.Content], list[types.ContentDict], ContentType],
|
336
|
-
):
|
337
|
-
if
|
338
|
-
|
339
|
-
return [t_content(client, content).parts[0].text for content in contents]
|
340
|
-
elif client.vertexai:
|
341
|
-
return [t_content(client, contents).parts[0].text]
|
342
|
-
elif isinstance(contents, list):
|
343
|
-
return [t_content(client, content) for content in contents]
|
335
|
+
) -> Union[list[str], list[types.Content]]:
|
336
|
+
if isinstance(contents, list):
|
337
|
+
transformed_contents = [t_content(client, content) for content in contents]
|
344
338
|
else:
|
345
|
-
|
339
|
+
transformed_contents = [t_content(client, contents)]
|
340
|
+
|
341
|
+
if client.vertexai:
|
342
|
+
text_parts = []
|
343
|
+
for content in transformed_contents:
|
344
|
+
if content is not None:
|
345
|
+
if isinstance(content, dict):
|
346
|
+
content = types.Content.model_validate(content)
|
347
|
+
if content.parts is not None:
|
348
|
+
for part in content.parts:
|
349
|
+
if part.text:
|
350
|
+
text_parts.append(part.text)
|
351
|
+
else:
|
352
|
+
logger.warning(
|
353
|
+
f'Non-text part found, only returning text parts.'
|
354
|
+
)
|
355
|
+
return text_parts
|
356
|
+
else:
|
357
|
+
return transformed_contents
|
346
358
|
|
347
359
|
|
348
360
|
def t_contents(
|
349
361
|
client: _api_client.BaseApiClient,
|
350
362
|
contents: Optional[
|
351
|
-
Union[types.ContentListUnion, types.ContentListUnionDict]
|
363
|
+
Union[types.ContentListUnion, types.ContentListUnionDict, types.Content]
|
352
364
|
],
|
353
365
|
) -> list[types.Content]:
|
354
366
|
if contents is None or (isinstance(contents, list) and not contents):
|
@@ -366,7 +378,7 @@ def t_contents(
|
|
366
378
|
result: list[types.Content] = []
|
367
379
|
accumulated_parts: list[types.Part] = []
|
368
380
|
|
369
|
-
def _is_part(part: types.PartUnionDict) ->
|
381
|
+
def _is_part(part: Union[types.PartUnionDict, Any]) -> TypeGuard[types.PartUnionDict]:
|
370
382
|
if (
|
371
383
|
isinstance(part, str)
|
372
384
|
or isinstance(part, types.File)
|
@@ -408,7 +420,7 @@ def t_contents(
|
|
408
420
|
accumulated_parts: list[types.Part],
|
409
421
|
current_part: types.PartUnionDict,
|
410
422
|
):
|
411
|
-
current_part = t_part(
|
423
|
+
current_part = t_part(current_part)
|
412
424
|
if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
|
413
425
|
accumulated_parts.append(current_part)
|
414
426
|
else:
|
@@ -430,11 +442,11 @@ def t_contents(
|
|
430
442
|
):
|
431
443
|
_append_accumulated_parts_as_content(result, accumulated_parts)
|
432
444
|
if isinstance(content, list):
|
433
|
-
result.append(types.UserContent(parts=content))
|
445
|
+
result.append(types.UserContent(parts=content)) # type: ignore[arg-type]
|
434
446
|
else:
|
435
447
|
result.append(content)
|
436
|
-
elif (_is_part(content)):
|
437
|
-
_handle_current_part(result, accumulated_parts, content)
|
448
|
+
elif (_is_part(content)):
|
449
|
+
_handle_current_part(result, accumulated_parts, content)
|
438
450
|
elif isinstance(content, dict):
|
439
451
|
# PactDict is already handled in _is_part
|
440
452
|
result.append(types.Content.model_validate(content))
|
@@ -500,14 +512,14 @@ def handle_null_fields(schema: dict[str, Any]):
|
|
500
512
|
schema['anyOf'].remove({'type': 'null'})
|
501
513
|
if len(schema['anyOf']) == 1:
|
502
514
|
# If there is only one type left after removing null, remove the anyOf field.
|
503
|
-
for key,val in schema['anyOf'][0].items():
|
515
|
+
for key, val in schema['anyOf'][0].items():
|
504
516
|
schema[key] = val
|
505
517
|
del schema['anyOf']
|
506
518
|
|
507
519
|
|
508
520
|
def process_schema(
|
509
521
|
schema: dict[str, Any],
|
510
|
-
client:
|
522
|
+
client: _api_client.BaseApiClient,
|
511
523
|
defs: Optional[dict[str, Any]] = None,
|
512
524
|
*,
|
513
525
|
order_properties: bool = True,
|
@@ -570,12 +582,13 @@ def process_schema(
|
|
570
582
|
'type': 'array'
|
571
583
|
}
|
572
584
|
"""
|
573
|
-
if
|
585
|
+
if not client.vertexai:
|
574
586
|
schema.pop('title', None)
|
575
587
|
|
576
588
|
if schema.get('default') is not None:
|
577
589
|
raise ValueError(
|
578
|
-
'Default value is not supported in the response schema for the Gemini
|
590
|
+
'Default value is not supported in the response schema for the Gemini'
|
591
|
+
' API.'
|
579
592
|
)
|
580
593
|
|
581
594
|
if schema.get('title') == 'PlaceholderLiteralEnum':
|
@@ -596,18 +609,15 @@ def process_schema(
|
|
596
609
|
# After removing null fields, Optional fields with only one possible type
|
597
610
|
# will have a $ref key that needs to be flattened
|
598
611
|
# For example: {'default': None, 'description': 'Name of the person', 'nullable': True, '$ref': '#/$defs/TestPerson'}
|
599
|
-
|
600
|
-
|
612
|
+
schema_ref = schema.get('$ref', None)
|
613
|
+
if schema_ref is not None:
|
614
|
+
ref = defs[schema_ref.split('defs/')[-1]]
|
601
615
|
for schema_key in list(ref.keys()):
|
602
616
|
schema[schema_key] = ref[schema_key]
|
603
617
|
del schema['$ref']
|
604
618
|
|
605
619
|
any_of = schema.get('anyOf', None)
|
606
620
|
if any_of is not None:
|
607
|
-
if not client.vertexai:
|
608
|
-
raise ValueError(
|
609
|
-
'AnyOf is not supported in the response schema for the Gemini API.'
|
610
|
-
)
|
611
621
|
for sub_schema in any_of:
|
612
622
|
# $ref is present in any_of if the schema is a union of Pydantic classes
|
613
623
|
ref_key = sub_schema.get('$ref', None)
|
@@ -670,7 +680,7 @@ def process_schema(
|
|
670
680
|
|
671
681
|
|
672
682
|
def _process_enum(
|
673
|
-
enum: EnumMeta, client:
|
683
|
+
enum: EnumMeta, client: _api_client.BaseApiClient
|
674
684
|
) -> types.Schema:
|
675
685
|
for member in enum: # type: ignore
|
676
686
|
if not isinstance(member.value, str):
|
@@ -680,7 +690,7 @@ def _process_enum(
|
|
680
690
|
)
|
681
691
|
|
682
692
|
class Placeholder(pydantic.BaseModel):
|
683
|
-
placeholder: enum
|
693
|
+
placeholder: enum # type: ignore[valid-type]
|
684
694
|
|
685
695
|
enum_schema = Placeholder.model_json_schema()
|
686
696
|
process_schema(enum_schema, client)
|
@@ -688,12 +698,19 @@ def _process_enum(
|
|
688
698
|
return types.Schema.model_validate(enum_schema)
|
689
699
|
|
690
700
|
|
701
|
+
def _is_type_dict_str_any(origin: Union[types.SchemaUnionDict, Any]) -> TypeGuard[dict[str, Any]]:
|
702
|
+
"""Verifies the schema is of type dict[str, Any] for mypy type checking."""
|
703
|
+
return isinstance(origin, dict) and all(
|
704
|
+
isinstance(key, str) for key in origin
|
705
|
+
)
|
706
|
+
|
707
|
+
|
691
708
|
def t_schema(
|
692
709
|
client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
|
693
710
|
) -> Optional[types.Schema]:
|
694
711
|
if not origin:
|
695
712
|
return None
|
696
|
-
if isinstance(origin, dict):
|
713
|
+
if isinstance(origin, dict) and _is_type_dict_str_any(origin):
|
697
714
|
process_schema(origin, client, order_properties=False)
|
698
715
|
return types.Schema.model_validate(origin)
|
699
716
|
if isinstance(origin, EnumMeta):
|
@@ -724,7 +741,7 @@ def t_schema(
|
|
724
741
|
):
|
725
742
|
|
726
743
|
class Placeholder(pydantic.BaseModel):
|
727
|
-
placeholder: origin
|
744
|
+
placeholder: origin # type: ignore[valid-type]
|
728
745
|
|
729
746
|
schema = Placeholder.model_json_schema()
|
730
747
|
process_schema(schema, client)
|
@@ -735,7 +752,8 @@ def t_schema(
|
|
735
752
|
|
736
753
|
|
737
754
|
def t_speech_config(
|
738
|
-
_: _api_client.BaseApiClient,
|
755
|
+
_: _api_client.BaseApiClient,
|
756
|
+
origin: Union[types.SpeechConfigUnionDict, Any],
|
739
757
|
) -> Optional[types.SpeechConfig]:
|
740
758
|
if not origin:
|
741
759
|
return None
|
@@ -750,7 +768,10 @@ def t_speech_config(
|
|
750
768
|
if (
|
751
769
|
isinstance(origin, dict)
|
752
770
|
and 'voice_config' in origin
|
771
|
+
and origin['voice_config'] is not None
|
753
772
|
and 'prebuilt_voice_config' in origin['voice_config']
|
773
|
+
and origin['voice_config']['prebuilt_voice_config'] is not None
|
774
|
+
and 'voice_name' in origin['voice_config']['prebuilt_voice_config']
|
754
775
|
):
|
755
776
|
return types.SpeechConfig(
|
756
777
|
voice_config=types.VoiceConfig(
|
@@ -764,7 +785,7 @@ def t_speech_config(
|
|
764
785
|
raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
|
765
786
|
|
766
787
|
|
767
|
-
def t_tool(client: _api_client.BaseApiClient, origin) -> types.Tool:
|
788
|
+
def t_tool(client: _api_client.BaseApiClient, origin) -> Optional[types.Tool]:
|
768
789
|
if not origin:
|
769
790
|
return None
|
770
791
|
if inspect.isfunction(origin) or inspect.ismethod(origin):
|
@@ -790,12 +811,16 @@ def t_tools(
|
|
790
811
|
for tool in origin:
|
791
812
|
transformed_tool = t_tool(client, tool)
|
792
813
|
# All functions should be merged into one tool.
|
793
|
-
if transformed_tool
|
794
|
-
|
814
|
+
if transformed_tool is not None:
|
815
|
+
if (
|
795
816
|
transformed_tool.function_declarations
|
796
|
-
|
797
|
-
|
798
|
-
|
817
|
+
and function_tool.function_declarations is not None
|
818
|
+
):
|
819
|
+
function_tool.function_declarations += (
|
820
|
+
transformed_tool.function_declarations
|
821
|
+
)
|
822
|
+
else:
|
823
|
+
tools.append(transformed_tool)
|
799
824
|
if function_tool.function_declarations:
|
800
825
|
tools.append(function_tool)
|
801
826
|
return tools
|
@@ -883,15 +908,28 @@ def t_resolve_operation(api_client: _api_client.BaseApiClient, struct: dict):
|
|
883
908
|
|
884
909
|
|
885
910
|
def t_file_name(
|
886
|
-
api_client: _api_client.BaseApiClient,
|
911
|
+
api_client: _api_client.BaseApiClient,
|
912
|
+
name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]],
|
887
913
|
):
|
888
914
|
# Remove the files/ prefix since it's added to the url path.
|
889
915
|
if isinstance(name, types.File):
|
890
916
|
name = name.name
|
917
|
+
elif isinstance(name, types.Video):
|
918
|
+
name = name.uri
|
919
|
+
elif isinstance(name, types.GeneratedVideo):
|
920
|
+
if name.video is not None:
|
921
|
+
name = name.video.uri
|
922
|
+
else:
|
923
|
+
name = None
|
891
924
|
|
892
925
|
if name is None:
|
893
926
|
raise ValueError('File name is required.')
|
894
927
|
|
928
|
+
if not isinstance(name, str):
|
929
|
+
raise ValueError(
|
930
|
+
f'Could not convert object of type `{type(name)}` to a file name.'
|
931
|
+
)
|
932
|
+
|
895
933
|
if name.startswith('https://'):
|
896
934
|
suffix = name.split('files/')[1]
|
897
935
|
match = re.match('[a-z0-9]+', suffix)
|
@@ -906,16 +944,19 @@ def t_file_name(
|
|
906
944
|
|
907
945
|
def t_tuning_job_status(
|
908
946
|
api_client: _api_client.BaseApiClient, status: str
|
909
|
-
) -> types.JobState:
|
947
|
+
) -> Union[types.JobState, str]:
|
910
948
|
if status == 'STATE_UNSPECIFIED':
|
911
|
-
return
|
949
|
+
return types.JobState.JOB_STATE_UNSPECIFIED
|
912
950
|
elif status == 'CREATING':
|
913
|
-
return
|
951
|
+
return types.JobState.JOB_STATE_RUNNING
|
914
952
|
elif status == 'ACTIVE':
|
915
|
-
return
|
953
|
+
return types.JobState.JOB_STATE_SUCCEEDED
|
916
954
|
elif status == 'FAILED':
|
917
|
-
return
|
955
|
+
return types.JobState.JOB_STATE_FAILED
|
918
956
|
else:
|
957
|
+
for state in types.JobState:
|
958
|
+
if str(state.value) == status:
|
959
|
+
return state
|
919
960
|
return status
|
920
961
|
|
921
962
|
|