google-genai 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +100 -31
- google/genai/_automatic_function_calling_util.py +4 -24
- google/genai/_common.py +40 -37
- google/genai/_extra_utils.py +72 -12
- google/genai/_live_converters.py +2487 -0
- google/genai/_replay_api_client.py +32 -26
- google/genai/_transformers.py +119 -25
- google/genai/batches.py +45 -45
- google/genai/caches.py +126 -126
- google/genai/chats.py +13 -9
- google/genai/client.py +3 -2
- google/genai/errors.py +6 -6
- google/genai/files.py +38 -38
- google/genai/live.py +138 -1029
- google/genai/models.py +455 -387
- google/genai/operations.py +33 -33
- google/genai/pagers.py +2 -2
- google/genai/py.typed +1 -0
- google/genai/tunings.py +70 -70
- google/genai/types.py +964 -45
- google/genai/version.py +1 -1
- {google_genai-1.10.0.dist-info → google_genai-1.12.0.dist-info}/METADATA +1 -1
- google_genai-1.12.0.dist-info/RECORD +29 -0
- {google_genai-1.10.0.dist-info → google_genai-1.12.0.dist-info}/WHEEL +1 -1
- google_genai-1.10.0.dist-info/RECORD +0 -27
- {google_genai-1.10.0.dist-info → google_genai-1.12.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.10.0.dist-info → google_genai-1.12.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -20,6 +20,7 @@ The BaseApiClient is intended to be a private module and is subject to change.
|
|
20
20
|
"""
|
21
21
|
|
22
22
|
import asyncio
|
23
|
+
from collections.abc import Awaitable, Generator
|
23
24
|
import copy
|
24
25
|
from dataclasses import dataclass
|
25
26
|
import datetime
|
@@ -29,21 +30,31 @@ import json
|
|
29
30
|
import logging
|
30
31
|
import math
|
31
32
|
import os
|
33
|
+
import ssl
|
32
34
|
import sys
|
33
35
|
import time
|
34
36
|
from typing import Any, AsyncIterator, Optional, Tuple, Union
|
35
|
-
from urllib.parse import urlparse
|
37
|
+
from urllib.parse import urlparse
|
38
|
+
from urllib.parse import urlunparse
|
39
|
+
|
36
40
|
import anyio
|
41
|
+
import certifi
|
37
42
|
import google.auth
|
38
43
|
import google.auth.credentials
|
39
44
|
from google.auth.credentials import Credentials
|
40
45
|
from google.auth.transport.requests import Request
|
41
46
|
import httpx
|
42
|
-
from pydantic import BaseModel
|
47
|
+
from pydantic import BaseModel
|
48
|
+
from pydantic import Field
|
49
|
+
from pydantic import ValidationError
|
50
|
+
|
43
51
|
from . import _common
|
44
52
|
from . import errors
|
45
53
|
from . import version
|
46
|
-
from .types import HttpOptions
|
54
|
+
from .types import HttpOptions
|
55
|
+
from .types import HttpOptionsDict
|
56
|
+
from .types import HttpOptionsOrDict
|
57
|
+
|
47
58
|
|
48
59
|
logger = logging.getLogger('google_genai._api_client')
|
49
60
|
CHUNK_SIZE = 8 * 1024 * 1024 # 8 MB chunk size
|
@@ -119,7 +130,7 @@ def _join_url_path(base_url: str, path: str) -> str:
|
|
119
130
|
|
120
131
|
def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
|
121
132
|
"""Loads google auth credentials and project id."""
|
122
|
-
credentials, loaded_project_id = google.auth.default(
|
133
|
+
credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call]
|
123
134
|
scopes=['https://www.googleapis.com/auth/cloud-platform'],
|
124
135
|
)
|
125
136
|
|
@@ -135,7 +146,7 @@ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
|
|
135
146
|
|
136
147
|
|
137
148
|
def _refresh_auth(credentials: Credentials) -> Credentials:
|
138
|
-
credentials.refresh(Request())
|
149
|
+
credentials.refresh(Request()) # type: ignore[no-untyped-call]
|
139
150
|
return credentials
|
140
151
|
|
141
152
|
|
@@ -181,17 +192,17 @@ class HttpResponse:
|
|
181
192
|
response_stream: Union[Any, str] = None,
|
182
193
|
byte_stream: Union[Any, bytes] = None,
|
183
194
|
):
|
184
|
-
self.status_code = 200
|
195
|
+
self.status_code: int = 200
|
185
196
|
self.headers = headers
|
186
197
|
self.response_stream = response_stream
|
187
198
|
self.byte_stream = byte_stream
|
188
199
|
|
189
200
|
# Async iterator for async streaming.
|
190
|
-
def __aiter__(self):
|
201
|
+
def __aiter__(self) -> 'HttpResponse':
|
191
202
|
self.segment_iterator = self.async_segments()
|
192
203
|
return self
|
193
204
|
|
194
|
-
async def __anext__(self):
|
205
|
+
async def __anext__(self) -> Any:
|
195
206
|
try:
|
196
207
|
return await self.segment_iterator.__anext__()
|
197
208
|
except StopIteration:
|
@@ -203,7 +214,7 @@ class HttpResponse:
|
|
203
214
|
return ''
|
204
215
|
return json.loads(self.response_stream[0])
|
205
216
|
|
206
|
-
def segments(self):
|
217
|
+
def segments(self) -> Generator[Any, None, None]:
|
207
218
|
if isinstance(self.response_stream, list):
|
208
219
|
# list of objects retrieved from replay or from non-streaming API.
|
209
220
|
for chunk in self.response_stream:
|
@@ -212,7 +223,7 @@ class HttpResponse:
|
|
212
223
|
yield from []
|
213
224
|
else:
|
214
225
|
# Iterator of objects retrieved from the API.
|
215
|
-
for chunk in self.response_stream.iter_lines():
|
226
|
+
for chunk in self.response_stream.iter_lines(): # type: ignore[union-attr]
|
216
227
|
if chunk:
|
217
228
|
# In streaming mode, the chunk of JSON is prefixed with "data:" which
|
218
229
|
# we must strip before parsing.
|
@@ -246,7 +257,7 @@ class HttpResponse:
|
|
246
257
|
else:
|
247
258
|
raise ValueError('Error parsing streaming response.')
|
248
259
|
|
249
|
-
def byte_segments(self):
|
260
|
+
def byte_segments(self) -> Generator[Union[bytes, Any], None, None]:
|
250
261
|
if isinstance(self.byte_stream, list):
|
251
262
|
# list of objects retrieved from replay or from non-streaming API.
|
252
263
|
yield from self.byte_stream
|
@@ -257,7 +268,7 @@ class HttpResponse:
|
|
257
268
|
'Byte segments are not supported for streaming responses.'
|
258
269
|
)
|
259
270
|
|
260
|
-
def _copy_to_dict(self, response_payload: dict[str, object]):
|
271
|
+
def _copy_to_dict(self, response_payload: dict[str, object]) -> None:
|
261
272
|
# Cannot pickle 'generator' object.
|
262
273
|
delattr(self, 'segment_iterator')
|
263
274
|
for attribute in dir(self):
|
@@ -414,7 +425,7 @@ class BaseApiClient:
|
|
414
425
|
if not self.api_key:
|
415
426
|
raise ValueError(
|
416
427
|
'Missing key inputs argument! To use the Google AI API,'
|
417
|
-
'provide (`api_key`) arguments. To use the Google Cloud API,'
|
428
|
+
' provide (`api_key`) arguments. To use the Google Cloud API,'
|
418
429
|
' provide (`vertexai`, `project` & `location`) arguments.'
|
419
430
|
)
|
420
431
|
self._http_options.base_url = 'https://generativelanguage.googleapis.com/'
|
@@ -432,13 +443,71 @@ class BaseApiClient:
|
|
432
443
|
else:
|
433
444
|
if self._http_options.headers is not None:
|
434
445
|
_append_library_version_headers(self._http_options.headers)
|
435
|
-
# Initialize the httpx client.
|
436
|
-
self._httpx_client = SyncHttpxClient()
|
437
|
-
self._async_httpx_client = AsyncHttpxClient()
|
438
446
|
|
439
|
-
|
447
|
+
client_args, async_client_args = self._ensure_ssl_ctx(self._http_options)
|
448
|
+
self._httpx_client = SyncHttpxClient(**client_args)
|
449
|
+
self._async_httpx_client = AsyncHttpxClient(**async_client_args)
|
450
|
+
|
451
|
+
@staticmethod
|
452
|
+
def _ensure_ssl_ctx(options: HttpOptions) -> (
|
453
|
+
Tuple[dict[str, Any], dict[str, Any]]):
|
454
|
+
"""Ensures the SSL context is present in the client args.
|
455
|
+
|
456
|
+
Creates a default SSL context if one is not provided.
|
457
|
+
|
458
|
+
Args:
|
459
|
+
options: The http options to check for SSL context.
|
460
|
+
|
461
|
+
Returns:
|
462
|
+
A tuple of sync/async httpx client args.
|
463
|
+
"""
|
464
|
+
|
465
|
+
verify = 'verify'
|
466
|
+
args = options.client_args
|
467
|
+
async_args = options.async_client_args
|
468
|
+
ctx = (
|
469
|
+
args.get(verify) if args else None
|
470
|
+
or async_args.get(verify) if async_args else None
|
471
|
+
)
|
472
|
+
|
473
|
+
if not ctx:
|
474
|
+
# Initialize the SSL context for the httpx client.
|
475
|
+
# Unlike requests, the httpx package does not automatically pull in the
|
476
|
+
# environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
|
477
|
+
# enabled explicitly.
|
478
|
+
ctx = ssl.create_default_context(
|
479
|
+
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
|
480
|
+
capath=os.environ.get('SSL_CERT_DIR'),
|
481
|
+
)
|
482
|
+
|
483
|
+
def _maybe_set(
|
484
|
+
args: Optional[dict[str, Any]],
|
485
|
+
ctx: ssl.SSLContext,
|
486
|
+
) -> dict[str, Any]:
|
487
|
+
"""Sets the SSL context in the client args if not set.
|
488
|
+
|
489
|
+
Does not override the SSL context if it is already set.
|
490
|
+
|
491
|
+
Args:
|
492
|
+
args: The client args to to check for SSL context.
|
493
|
+
ctx: The SSL context to set.
|
494
|
+
|
495
|
+
Returns:
|
496
|
+
The client args with the SSL context included.
|
497
|
+
"""
|
498
|
+
if not args or not args.get(verify):
|
499
|
+
args = (args or {}).copy()
|
500
|
+
args[verify] = ctx
|
501
|
+
return args
|
502
|
+
|
503
|
+
return (
|
504
|
+
_maybe_set(args, ctx),
|
505
|
+
_maybe_set(async_args, ctx),
|
506
|
+
)
|
507
|
+
|
508
|
+
def _websocket_base_url(self) -> str:
|
440
509
|
url_parts = urlparse(self._http_options.base_url)
|
441
|
-
return url_parts._replace(scheme='wss').geturl()
|
510
|
+
return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
|
442
511
|
|
443
512
|
def _access_token(self) -> str:
|
444
513
|
"""Retrieves the access token for the credentials."""
|
@@ -453,11 +522,11 @@ class BaseApiClient:
|
|
453
522
|
_refresh_auth(self._credentials)
|
454
523
|
if not self._credentials.token:
|
455
524
|
raise RuntimeError('Could not resolve API token from the environment')
|
456
|
-
return self._credentials.token
|
525
|
+
return self._credentials.token # type: ignore[no-any-return]
|
457
526
|
else:
|
458
527
|
raise RuntimeError('Could not resolve API token from the environment')
|
459
528
|
|
460
|
-
async def _async_access_token(self) -> str:
|
529
|
+
async def _async_access_token(self) -> Union[str, Any]:
|
461
530
|
"""Retrieves the access token for the credentials asynchronously."""
|
462
531
|
if not self._credentials:
|
463
532
|
async with self._auth_lock:
|
@@ -607,7 +676,7 @@ class BaseApiClient:
|
|
607
676
|
|
608
677
|
async def _async_request(
|
609
678
|
self, http_request: HttpRequest, stream: bool = False
|
610
|
-
):
|
679
|
+
) -> HttpResponse:
|
611
680
|
data: Optional[Union[str, bytes]] = None
|
612
681
|
if self.vertexai and not self.api_key:
|
613
682
|
http_request.headers['Authorization'] = (
|
@@ -667,7 +736,7 @@ class BaseApiClient:
|
|
667
736
|
path: str,
|
668
737
|
request_dict: dict[str, object],
|
669
738
|
http_options: Optional[HttpOptionsOrDict] = None,
|
670
|
-
):
|
739
|
+
) -> Union[BaseResponse, Any]:
|
671
740
|
http_request = self._build_request(
|
672
741
|
http_method, path, request_dict, http_options
|
673
742
|
)
|
@@ -685,7 +754,7 @@ class BaseApiClient:
|
|
685
754
|
path: str,
|
686
755
|
request_dict: dict[str, object],
|
687
756
|
http_options: Optional[HttpOptionsOrDict] = None,
|
688
|
-
):
|
757
|
+
) -> Generator[Any, None, None]:
|
689
758
|
http_request = self._build_request(
|
690
759
|
http_method, path, request_dict, http_options
|
691
760
|
)
|
@@ -700,7 +769,7 @@ class BaseApiClient:
|
|
700
769
|
path: str,
|
701
770
|
request_dict: dict[str, object],
|
702
771
|
http_options: Optional[HttpOptionsOrDict] = None,
|
703
|
-
) ->
|
772
|
+
) -> Union[BaseResponse, Any]:
|
704
773
|
http_request = self._build_request(
|
705
774
|
http_method, path, request_dict, http_options
|
706
775
|
)
|
@@ -717,18 +786,18 @@ class BaseApiClient:
|
|
717
786
|
path: str,
|
718
787
|
request_dict: dict[str, object],
|
719
788
|
http_options: Optional[HttpOptionsOrDict] = None,
|
720
|
-
):
|
789
|
+
) -> Any:
|
721
790
|
http_request = self._build_request(
|
722
791
|
http_method, path, request_dict, http_options
|
723
792
|
)
|
724
793
|
|
725
794
|
response = await self._async_request(http_request=http_request, stream=True)
|
726
795
|
|
727
|
-
async def async_generator():
|
796
|
+
async def async_generator(): # type: ignore[no-untyped-def]
|
728
797
|
async for chunk in response:
|
729
798
|
yield chunk
|
730
799
|
|
731
|
-
return async_generator()
|
800
|
+
return async_generator() # type: ignore[no-untyped-call]
|
732
801
|
|
733
802
|
def upload_file(
|
734
803
|
self,
|
@@ -840,7 +909,7 @@ class BaseApiClient:
|
|
840
909
|
path: str,
|
841
910
|
*,
|
842
911
|
http_options: Optional[HttpOptionsOrDict] = None,
|
843
|
-
):
|
912
|
+
) -> Union[Any,bytes]:
|
844
913
|
"""Downloads the file data.
|
845
914
|
|
846
915
|
Args:
|
@@ -909,7 +978,7 @@ class BaseApiClient:
|
|
909
978
|
|
910
979
|
async def _async_upload_fd(
|
911
980
|
self,
|
912
|
-
file: Union[io.IOBase, anyio.AsyncFile],
|
981
|
+
file: Union[io.IOBase, anyio.AsyncFile[Any]],
|
913
982
|
upload_url: str,
|
914
983
|
upload_size: int,
|
915
984
|
*,
|
@@ -988,7 +1057,7 @@ class BaseApiClient:
|
|
988
1057
|
path: str,
|
989
1058
|
*,
|
990
1059
|
http_options: Optional[HttpOptionsOrDict] = None,
|
991
|
-
):
|
1060
|
+
) -> Union[Any, bytes]:
|
992
1061
|
"""Downloads the file data.
|
993
1062
|
|
994
1063
|
Args:
|
@@ -1025,5 +1094,5 @@ class BaseApiClient:
|
|
1025
1094
|
# This method does nothing in the real api client. It is used in the
|
1026
1095
|
# replay_api_client to verify the response from the SDK method matches the
|
1027
1096
|
# recorded response.
|
1028
|
-
def _verify_response(self, response_model: _common.BaseModel):
|
1097
|
+
def _verify_response(self, response_model: _common.BaseModel) -> None:
|
1029
1098
|
pass
|
google/genai/_automatic_function_calling_util.py
CHANGED
@@ -46,19 +46,6 @@ def _is_builtin_primitive_or_compound(
|
|
46
46
|
return annotation in _py_builtin_type_to_schema_type.keys()
|
47
47
|
|
48
48
|
|
49
|
-
def _raise_for_default_if_mldev(schema: types.Schema):
|
50
|
-
if schema.default is not None:
|
51
|
-
raise ValueError(
|
52
|
-
'Default value is not supported in function declaration schema for'
|
53
|
-
' the Gemini API.'
|
54
|
-
)
|
55
|
-
|
56
|
-
|
57
|
-
def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
|
58
|
-
if api_option == 'GEMINI_API':
|
59
|
-
_raise_for_default_if_mldev(schema)
|
60
|
-
|
61
|
-
|
62
49
|
def _is_default_value_compatible(
|
63
50
|
default_value: Any, annotation: inspect.Parameter.annotation # type: ignore[valid-type]
|
64
51
|
) -> bool:
|
@@ -72,16 +59,16 @@ def _is_default_value_compatible(
|
|
72
59
|
or isinstance(annotation, VersionedUnionType)
|
73
60
|
):
|
74
61
|
origin = get_origin(annotation)
|
75
|
-
if origin in (Union, VersionedUnionType):
|
62
|
+
if origin in (Union, VersionedUnionType): # type: ignore[comparison-overlap]
|
76
63
|
return any(
|
77
64
|
_is_default_value_compatible(default_value, arg)
|
78
65
|
for arg in get_args(annotation)
|
79
66
|
)
|
80
67
|
|
81
|
-
if origin is dict:
|
68
|
+
if origin is dict: # type: ignore[comparison-overlap]
|
82
69
|
return isinstance(default_value, dict)
|
83
70
|
|
84
|
-
if origin is list:
|
71
|
+
if origin is list: # type: ignore[comparison-overlap]
|
85
72
|
if not isinstance(default_value, list):
|
86
73
|
return False
|
87
74
|
# most tricky case, element in list is union type
|
@@ -97,7 +84,7 @@ def _is_default_value_compatible(
|
|
97
84
|
for item in default_value
|
98
85
|
)
|
99
86
|
|
100
|
-
if origin is Literal:
|
87
|
+
if origin is Literal: # type: ignore[comparison-overlap]
|
101
88
|
return default_value in get_args(annotation)
|
102
89
|
|
103
90
|
# return False for any other unrecognized annotation
|
@@ -125,7 +112,6 @@ def _parse_schema_from_parameter(
|
|
125
112
|
raise ValueError(default_value_error_msg)
|
126
113
|
schema.default = param.default
|
127
114
|
schema.type = _py_builtin_type_to_schema_type[param.annotation]
|
128
|
-
_raise_if_schema_unsupported(api_option, schema)
|
129
115
|
return schema
|
130
116
|
if (
|
131
117
|
isinstance(param.annotation, VersionedUnionType)
|
@@ -166,7 +152,6 @@ def _parse_schema_from_parameter(
|
|
166
152
|
if not _is_default_value_compatible(param.default, param.annotation):
|
167
153
|
raise ValueError(default_value_error_msg)
|
168
154
|
schema.default = param.default
|
169
|
-
_raise_if_schema_unsupported(api_option, schema)
|
170
155
|
return schema
|
171
156
|
if isinstance(param.annotation, _GenericAlias) or isinstance(
|
172
157
|
param.annotation, builtin_types.GenericAlias
|
@@ -179,7 +164,6 @@ def _parse_schema_from_parameter(
|
|
179
164
|
if not _is_default_value_compatible(param.default, param.annotation):
|
180
165
|
raise ValueError(default_value_error_msg)
|
181
166
|
schema.default = param.default
|
182
|
-
_raise_if_schema_unsupported(api_option, schema)
|
183
167
|
return schema
|
184
168
|
if origin is Literal:
|
185
169
|
if not all(isinstance(arg, str) for arg in args):
|
@@ -192,7 +176,6 @@ def _parse_schema_from_parameter(
|
|
192
176
|
if not _is_default_value_compatible(param.default, param.annotation):
|
193
177
|
raise ValueError(default_value_error_msg)
|
194
178
|
schema.default = param.default
|
195
|
-
_raise_if_schema_unsupported(api_option, schema)
|
196
179
|
return schema
|
197
180
|
if origin is list:
|
198
181
|
schema.type = _py_builtin_type_to_schema_type[list]
|
@@ -209,7 +192,6 @@ def _parse_schema_from_parameter(
|
|
209
192
|
if not _is_default_value_compatible(param.default, param.annotation):
|
210
193
|
raise ValueError(default_value_error_msg)
|
211
194
|
schema.default = param.default
|
212
|
-
_raise_if_schema_unsupported(api_option, schema)
|
213
195
|
return schema
|
214
196
|
if origin is Union:
|
215
197
|
schema.any_of = []
|
@@ -259,7 +241,6 @@ def _parse_schema_from_parameter(
|
|
259
241
|
if not _is_default_value_compatible(param.default, param.annotation):
|
260
242
|
raise ValueError(default_value_error_msg)
|
261
243
|
schema.default = param.default
|
262
|
-
_raise_if_schema_unsupported(api_option, schema)
|
263
244
|
return schema
|
264
245
|
# all other generic alias will be invoked in raise branch
|
265
246
|
if (
|
@@ -284,7 +265,6 @@ def _parse_schema_from_parameter(
|
|
284
265
|
func_name,
|
285
266
|
)
|
286
267
|
schema.required = _get_required_fields(schema)
|
287
|
-
_raise_if_schema_unsupported(api_option, schema)
|
288
268
|
return schema
|
289
269
|
raise ValueError(
|
290
270
|
f'Failed to parse the parameter {param} of function {func_name} for'
|
google/genai/_common.py
CHANGED
@@ -20,7 +20,7 @@ import datetime
|
|
20
20
|
import enum
|
21
21
|
import functools
|
22
22
|
import typing
|
23
|
-
from typing import Any, Union
|
23
|
+
from typing import Any, Callable, Optional, Union
|
24
24
|
import uuid
|
25
25
|
import warnings
|
26
26
|
|
@@ -31,7 +31,7 @@ from . import _api_client
|
|
31
31
|
from . import errors
|
32
32
|
|
33
33
|
|
34
|
-
def set_value_by_path(data, keys, value):
|
34
|
+
def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None:
|
35
35
|
"""Examples:
|
36
36
|
|
37
37
|
set_value_by_path({}, ['a', 'b'], v)
|
@@ -46,54 +46,57 @@ def set_value_by_path(data, keys, value):
|
|
46
46
|
for i, key in enumerate(keys[:-1]):
|
47
47
|
if key.endswith('[]'):
|
48
48
|
key_name = key[:-2]
|
49
|
-
if key_name not in data:
|
49
|
+
if data is not None and key_name not in data:
|
50
50
|
if isinstance(value, list):
|
51
51
|
data[key_name] = [{} for _ in range(len(value))]
|
52
52
|
else:
|
53
53
|
raise ValueError(
|
54
54
|
f'value {value} must be a list given an array path {key}'
|
55
55
|
)
|
56
|
-
if isinstance(value, list):
|
56
|
+
if isinstance(value, list) and data is not None:
|
57
57
|
for j, d in enumerate(data[key_name]):
|
58
58
|
set_value_by_path(d, keys[i + 1 :], value[j])
|
59
59
|
else:
|
60
|
-
|
61
|
-
|
60
|
+
if data is not None:
|
61
|
+
for d in data[key_name]:
|
62
|
+
set_value_by_path(d, keys[i + 1 :], value)
|
62
63
|
return
|
63
64
|
elif key.endswith('[0]'):
|
64
65
|
key_name = key[:-3]
|
65
|
-
if key_name not in data:
|
66
|
+
if data is not None and key_name not in data:
|
66
67
|
data[key_name] = [{}]
|
67
|
-
|
68
|
+
if data is not None:
|
69
|
+
set_value_by_path(data[key_name][0], keys[i + 1 :], value)
|
68
70
|
return
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
71
|
+
if data is not None:
|
72
|
+
data = data.setdefault(key, {})
|
73
|
+
|
74
|
+
if data is not None:
|
75
|
+
existing_data = data.get(keys[-1])
|
76
|
+
# If there is an existing value, merge, not overwrite.
|
77
|
+
if existing_data is not None:
|
78
|
+
# Don't overwrite existing non-empty value with new empty value.
|
79
|
+
# This is triggered when handling tuning datasets.
|
80
|
+
if not value:
|
81
|
+
pass
|
82
|
+
# Don't fail when overwriting value with same value
|
83
|
+
elif value == existing_data:
|
84
|
+
pass
|
85
|
+
# Instead of overwriting dictionary with another dictionary, merge them.
|
86
|
+
# This is important for handling training and validation datasets in tuning.
|
87
|
+
elif isinstance(existing_data, dict) and isinstance(value, dict):
|
88
|
+
# Merging dictionaries. Consider deep merging in the future.
|
89
|
+
existing_data.update(value)
|
90
|
+
else:
|
91
|
+
raise ValueError(
|
92
|
+
f'Cannot set value for an existing key. Key: {keys[-1]};'
|
93
|
+
f' Existing value: {existing_data}; New value: {value}.'
|
94
|
+
)
|
87
95
|
else:
|
88
|
-
|
89
|
-
f'Cannot set value for an existing key. Key: {keys[-1]};'
|
90
|
-
f' Existing value: {existing_data}; New value: {value}.'
|
91
|
-
)
|
92
|
-
else:
|
93
|
-
data[keys[-1]] = value
|
96
|
+
data[keys[-1]] = value
|
94
97
|
|
95
98
|
|
96
|
-
def get_value_by_path(data: Any, keys: list[str]):
|
99
|
+
def get_value_by_path(data: Any, keys: list[str]) -> Any:
|
97
100
|
"""Examples:
|
98
101
|
|
99
102
|
get_value_by_path({'a': {'b': v}}, ['a', 'b'])
|
@@ -227,7 +230,7 @@ class CaseInSensitiveEnum(str, enum.Enum):
|
|
227
230
|
"""Case insensitive enum."""
|
228
231
|
|
229
232
|
@classmethod
|
230
|
-
def _missing_(cls, value):
|
233
|
+
def _missing_(cls, value: Any) -> Any:
|
231
234
|
try:
|
232
235
|
return cls[value.upper()] # Try to access directly with uppercase
|
233
236
|
except KeyError:
|
@@ -295,12 +298,12 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
|
|
295
298
|
return processed_data
|
296
299
|
|
297
300
|
|
298
|
-
def experimental_warning(message: str):
|
301
|
+
def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
299
302
|
"""Experimental warning, only warns once."""
|
300
|
-
def decorator(func):
|
303
|
+
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
301
304
|
warning_done = False
|
302
305
|
@functools.wraps(func)
|
303
|
-
def wrapper(*args, **kwargs):
|
306
|
+
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
304
307
|
nonlocal warning_done
|
305
308
|
if not warning_done:
|
306
309
|
warning_done = True
|
google/genai/_extra_utils.py
CHANGED
@@ -78,16 +78,17 @@ def format_destination(
|
|
78
78
|
|
79
79
|
def get_function_map(
|
80
80
|
config: Optional[types.GenerateContentConfigOrDict] = None,
|
81
|
-
|
81
|
+
is_caller_method_async: bool = False,
|
82
|
+
) -> dict[str, Callable[..., Any]]:
|
82
83
|
"""Returns a function map from the config."""
|
83
|
-
function_map: dict[str, Callable] = {}
|
84
|
+
function_map: dict[str, Callable[..., Any]] = {}
|
84
85
|
if not config:
|
85
86
|
return function_map
|
86
87
|
config_model = _create_generate_content_config_model(config)
|
87
88
|
if config_model.tools:
|
88
89
|
for tool in config_model.tools:
|
89
90
|
if callable(tool):
|
90
|
-
if inspect.iscoroutinefunction(tool):
|
91
|
+
if inspect.iscoroutinefunction(tool) and not is_caller_method_async:
|
91
92
|
raise errors.UnsupportedFunctionError(
|
92
93
|
f'Function {tool.__name__} is a coroutine function, which is not'
|
93
94
|
' supported for automatic function calling. Please manually'
|
@@ -199,11 +200,11 @@ def convert_if_exist_pydantic_model(
|
|
199
200
|
return value
|
200
201
|
|
201
202
|
|
202
|
-
def invoke_function_from_dict_args(
|
203
|
-
args:
|
204
|
-
) -> Any:
|
205
|
-
signature = inspect.signature(function_to_invoke)
|
206
|
-
func_name =
|
203
|
+
def convert_argument_from_function(
|
204
|
+
args: dict[str, Any], function: Callable[..., Any]
|
205
|
+
) -> dict[str, Any]:
|
206
|
+
signature = inspect.signature(function)
|
207
|
+
func_name = function.__name__
|
207
208
|
converted_args = {}
|
208
209
|
for param_name, param in signature.parameters.items():
|
209
210
|
if param_name in args:
|
@@ -213,19 +214,40 @@ def invoke_function_from_dict_args(
|
|
213
214
|
param_name,
|
214
215
|
func_name,
|
215
216
|
)
|
217
|
+
return converted_args
|
218
|
+
|
219
|
+
|
220
|
+
def invoke_function_from_dict_args(
|
221
|
+
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
|
222
|
+
) -> Any:
|
223
|
+
converted_args = convert_argument_from_function(args, function_to_invoke)
|
216
224
|
try:
|
217
225
|
return function_to_invoke(**converted_args)
|
218
226
|
except Exception as e:
|
219
227
|
raise errors.FunctionInvocationError(
|
220
|
-
f'Failed to invoke function {
|
221
|
-
f' {converted_args} from model returned function
|
222
|
-
f' {args} because of error {e}'
|
228
|
+
f'Failed to invoke function {function_to_invoke.__name__} with'
|
229
|
+
f' converted arguments {converted_args} from model returned function'
|
230
|
+
f' call argument {args} because of error {e}'
|
231
|
+
)
|
232
|
+
|
233
|
+
|
234
|
+
async def invoke_function_from_dict_args_async(
|
235
|
+
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
|
236
|
+
) -> Any:
|
237
|
+
converted_args = convert_argument_from_function(args, function_to_invoke)
|
238
|
+
try:
|
239
|
+
return await function_to_invoke(**converted_args)
|
240
|
+
except Exception as e:
|
241
|
+
raise errors.FunctionInvocationError(
|
242
|
+
f'Failed to invoke function {function_to_invoke.__name__} with'
|
243
|
+
f' converted arguments {converted_args} from model returned function'
|
244
|
+
f' call argument {args} because of error {e}'
|
223
245
|
)
|
224
246
|
|
225
247
|
|
226
248
|
def get_function_response_parts(
|
227
249
|
response: types.GenerateContentResponse,
|
228
|
-
function_map: dict[str, Callable],
|
250
|
+
function_map: dict[str, Callable[..., Any]],
|
229
251
|
) -> list[types.Part]:
|
230
252
|
"""Returns the function response parts from the response."""
|
231
253
|
func_response_parts = []
|
@@ -256,6 +278,44 @@ def get_function_response_parts(
|
|
256
278
|
func_response_parts.append(func_response_part)
|
257
279
|
return func_response_parts
|
258
280
|
|
281
|
+
async def get_function_response_parts_async(
|
282
|
+
response: types.GenerateContentResponse,
|
283
|
+
function_map: dict[str, Callable[..., Any]],
|
284
|
+
) -> list[types.Part]:
|
285
|
+
"""Returns the function response parts from the response."""
|
286
|
+
func_response_parts = []
|
287
|
+
if (
|
288
|
+
response.candidates is not None
|
289
|
+
and isinstance(response.candidates[0].content, types.Content)
|
290
|
+
and response.candidates[0].content.parts is not None
|
291
|
+
):
|
292
|
+
for part in response.candidates[0].content.parts:
|
293
|
+
if not part.function_call:
|
294
|
+
continue
|
295
|
+
func_name = part.function_call.name
|
296
|
+
if func_name is not None and part.function_call.args is not None:
|
297
|
+
func = function_map[func_name]
|
298
|
+
args = convert_number_values_for_dict_function_call_args(
|
299
|
+
part.function_call.args
|
300
|
+
)
|
301
|
+
func_response: dict[str, Any]
|
302
|
+
try:
|
303
|
+
if inspect.iscoroutinefunction(func):
|
304
|
+
func_response = {
|
305
|
+
'result': await invoke_function_from_dict_args_async(args, func)
|
306
|
+
}
|
307
|
+
else:
|
308
|
+
func_response = {
|
309
|
+
'result': invoke_function_from_dict_args(args, func)
|
310
|
+
}
|
311
|
+
except Exception as e: # pylint: disable=broad-except
|
312
|
+
func_response = {'error': str(e)}
|
313
|
+
func_response_part = types.Part.from_function_response(
|
314
|
+
name=func_name, response=func_response
|
315
|
+
)
|
316
|
+
func_response_parts.append(func_response_part)
|
317
|
+
return func_response_parts
|
318
|
+
|
259
319
|
|
260
320
|
def should_disable_afc(
|
261
321
|
config: Optional[types.GenerateContentConfigOrDict] = None,
|