google-genai 1.11.0__tar.gz → 1.12.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_genai-1.11.0/google_genai.egg-info → google_genai-1.12.0}/PKG-INFO +1 -1
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/_api_client.py +25 -24
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/_automatic_function_calling_util.py +4 -24
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/_common.py +40 -37
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/_extra_utils.py +7 -7
- google_genai-1.12.0/google/genai/_live_converters.py +2487 -0
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/_replay_api_client.py +32 -26
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/_transformers.py +46 -81
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/batches.py +45 -45
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/caches.py +126 -126
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/chats.py +13 -9
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/client.py +3 -2
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/errors.py +6 -6
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/files.py +38 -38
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/live.py +69 -17
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/models.py +388 -388
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/operations.py +33 -33
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/pagers.py +2 -2
- google_genai-1.12.0/google/genai/py.typed +1 -0
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/tunings.py +70 -70
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/types.py +455 -24
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/version.py +1 -1
- {google_genai-1.11.0 → google_genai-1.12.0/google_genai.egg-info}/PKG-INFO +1 -1
- {google_genai-1.11.0 → google_genai-1.12.0}/google_genai.egg-info/SOURCES.txt +2 -1
- {google_genai-1.11.0 → google_genai-1.12.0}/pyproject.toml +5 -2
- google_genai-1.11.0/google/genai/live_converters.py +0 -1298
- {google_genai-1.11.0 → google_genai-1.12.0}/LICENSE +0 -0
- {google_genai-1.11.0 → google_genai-1.12.0}/MANIFEST.in +0 -0
- {google_genai-1.11.0 → google_genai-1.12.0}/README.md +0 -0
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/__init__.py +0 -0
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/_api_module.py +0 -0
- {google_genai-1.11.0 → google_genai-1.12.0}/google/genai/_test_api_client.py +0 -0
- {google_genai-1.11.0 → google_genai-1.12.0}/google_genai.egg-info/dependency_links.txt +0 -0
- {google_genai-1.11.0 → google_genai-1.12.0}/google_genai.egg-info/requires.txt +0 -0
- {google_genai-1.11.0 → google_genai-1.12.0}/google_genai.egg-info/top_level.txt +0 -0
- {google_genai-1.11.0 → google_genai-1.12.0}/setup.cfg +0 -0
@@ -20,6 +20,7 @@ The BaseApiClient is intended to be a private module and is subject to change.
|
|
20
20
|
"""
|
21
21
|
|
22
22
|
import asyncio
|
23
|
+
from collections.abc import Awaitable, Generator
|
23
24
|
import copy
|
24
25
|
from dataclasses import dataclass
|
25
26
|
import datetime
|
@@ -129,7 +130,7 @@ def _join_url_path(base_url: str, path: str) -> str:
|
|
129
130
|
|
130
131
|
def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
|
131
132
|
"""Loads google auth credentials and project id."""
|
132
|
-
credentials, loaded_project_id = google.auth.default(
|
133
|
+
credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call]
|
133
134
|
scopes=['https://www.googleapis.com/auth/cloud-platform'],
|
134
135
|
)
|
135
136
|
|
@@ -145,7 +146,7 @@ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
|
|
145
146
|
|
146
147
|
|
147
148
|
def _refresh_auth(credentials: Credentials) -> Credentials:
|
148
|
-
credentials.refresh(Request())
|
149
|
+
credentials.refresh(Request()) # type: ignore[no-untyped-call]
|
149
150
|
return credentials
|
150
151
|
|
151
152
|
|
@@ -191,17 +192,17 @@ class HttpResponse:
|
|
191
192
|
response_stream: Union[Any, str] = None,
|
192
193
|
byte_stream: Union[Any, bytes] = None,
|
193
194
|
):
|
194
|
-
self.status_code = 200
|
195
|
+
self.status_code: int = 200
|
195
196
|
self.headers = headers
|
196
197
|
self.response_stream = response_stream
|
197
198
|
self.byte_stream = byte_stream
|
198
199
|
|
199
200
|
# Async iterator for async streaming.
|
200
|
-
def __aiter__(self):
|
201
|
+
def __aiter__(self) -> 'HttpResponse':
|
201
202
|
self.segment_iterator = self.async_segments()
|
202
203
|
return self
|
203
204
|
|
204
|
-
async def __anext__(self):
|
205
|
+
async def __anext__(self) -> Any:
|
205
206
|
try:
|
206
207
|
return await self.segment_iterator.__anext__()
|
207
208
|
except StopIteration:
|
@@ -213,7 +214,7 @@ class HttpResponse:
|
|
213
214
|
return ''
|
214
215
|
return json.loads(self.response_stream[0])
|
215
216
|
|
216
|
-
def segments(self):
|
217
|
+
def segments(self) -> Generator[Any, None, None]:
|
217
218
|
if isinstance(self.response_stream, list):
|
218
219
|
# list of objects retrieved from replay or from non-streaming API.
|
219
220
|
for chunk in self.response_stream:
|
@@ -222,7 +223,7 @@ class HttpResponse:
|
|
222
223
|
yield from []
|
223
224
|
else:
|
224
225
|
# Iterator of objects retrieved from the API.
|
225
|
-
for chunk in self.response_stream.iter_lines():
|
226
|
+
for chunk in self.response_stream.iter_lines(): # type: ignore[union-attr]
|
226
227
|
if chunk:
|
227
228
|
# In streaming mode, the chunk of JSON is prefixed with "data:" which
|
228
229
|
# we must strip before parsing.
|
@@ -256,7 +257,7 @@ class HttpResponse:
|
|
256
257
|
else:
|
257
258
|
raise ValueError('Error parsing streaming response.')
|
258
259
|
|
259
|
-
def byte_segments(self):
|
260
|
+
def byte_segments(self) -> Generator[Union[bytes, Any], None, None]:
|
260
261
|
if isinstance(self.byte_stream, list):
|
261
262
|
# list of objects retrieved from replay or from non-streaming API.
|
262
263
|
yield from self.byte_stream
|
@@ -267,7 +268,7 @@ class HttpResponse:
|
|
267
268
|
'Byte segments are not supported for streaming responses.'
|
268
269
|
)
|
269
270
|
|
270
|
-
def _copy_to_dict(self, response_payload: dict[str, object]):
|
271
|
+
def _copy_to_dict(self, response_payload: dict[str, object]) -> None:
|
271
272
|
# Cannot pickle 'generator' object.
|
272
273
|
delattr(self, 'segment_iterator')
|
273
274
|
for attribute in dir(self):
|
@@ -504,9 +505,9 @@ class BaseApiClient:
|
|
504
505
|
_maybe_set(async_args, ctx),
|
505
506
|
)
|
506
507
|
|
507
|
-
def _websocket_base_url(self):
|
508
|
+
def _websocket_base_url(self) -> str:
|
508
509
|
url_parts = urlparse(self._http_options.base_url)
|
509
|
-
return url_parts._replace(scheme='wss').geturl()
|
510
|
+
return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
|
510
511
|
|
511
512
|
def _access_token(self) -> str:
|
512
513
|
"""Retrieves the access token for the credentials."""
|
@@ -521,11 +522,11 @@ class BaseApiClient:
|
|
521
522
|
_refresh_auth(self._credentials)
|
522
523
|
if not self._credentials.token:
|
523
524
|
raise RuntimeError('Could not resolve API token from the environment')
|
524
|
-
return self._credentials.token
|
525
|
+
return self._credentials.token # type: ignore[no-any-return]
|
525
526
|
else:
|
526
527
|
raise RuntimeError('Could not resolve API token from the environment')
|
527
528
|
|
528
|
-
async def _async_access_token(self) -> str:
|
529
|
+
async def _async_access_token(self) -> Union[str, Any]:
|
529
530
|
"""Retrieves the access token for the credentials asynchronously."""
|
530
531
|
if not self._credentials:
|
531
532
|
async with self._auth_lock:
|
@@ -675,7 +676,7 @@ class BaseApiClient:
|
|
675
676
|
|
676
677
|
async def _async_request(
|
677
678
|
self, http_request: HttpRequest, stream: bool = False
|
678
|
-
):
|
679
|
+
) -> HttpResponse:
|
679
680
|
data: Optional[Union[str, bytes]] = None
|
680
681
|
if self.vertexai and not self.api_key:
|
681
682
|
http_request.headers['Authorization'] = (
|
@@ -735,7 +736,7 @@ class BaseApiClient:
|
|
735
736
|
path: str,
|
736
737
|
request_dict: dict[str, object],
|
737
738
|
http_options: Optional[HttpOptionsOrDict] = None,
|
738
|
-
):
|
739
|
+
) -> Union[BaseResponse, Any]:
|
739
740
|
http_request = self._build_request(
|
740
741
|
http_method, path, request_dict, http_options
|
741
742
|
)
|
@@ -753,7 +754,7 @@ class BaseApiClient:
|
|
753
754
|
path: str,
|
754
755
|
request_dict: dict[str, object],
|
755
756
|
http_options: Optional[HttpOptionsOrDict] = None,
|
756
|
-
):
|
757
|
+
) -> Generator[Any, None, None]:
|
757
758
|
http_request = self._build_request(
|
758
759
|
http_method, path, request_dict, http_options
|
759
760
|
)
|
@@ -768,7 +769,7 @@ class BaseApiClient:
|
|
768
769
|
path: str,
|
769
770
|
request_dict: dict[str, object],
|
770
771
|
http_options: Optional[HttpOptionsOrDict] = None,
|
771
|
-
) ->
|
772
|
+
) -> Union[BaseResponse, Any]:
|
772
773
|
http_request = self._build_request(
|
773
774
|
http_method, path, request_dict, http_options
|
774
775
|
)
|
@@ -785,18 +786,18 @@ class BaseApiClient:
|
|
785
786
|
path: str,
|
786
787
|
request_dict: dict[str, object],
|
787
788
|
http_options: Optional[HttpOptionsOrDict] = None,
|
788
|
-
):
|
789
|
+
) -> Any:
|
789
790
|
http_request = self._build_request(
|
790
791
|
http_method, path, request_dict, http_options
|
791
792
|
)
|
792
793
|
|
793
794
|
response = await self._async_request(http_request=http_request, stream=True)
|
794
795
|
|
795
|
-
async def async_generator():
|
796
|
+
async def async_generator(): # type: ignore[no-untyped-def]
|
796
797
|
async for chunk in response:
|
797
798
|
yield chunk
|
798
799
|
|
799
|
-
return async_generator()
|
800
|
+
return async_generator() # type: ignore[no-untyped-call]
|
800
801
|
|
801
802
|
def upload_file(
|
802
803
|
self,
|
@@ -908,7 +909,7 @@ class BaseApiClient:
|
|
908
909
|
path: str,
|
909
910
|
*,
|
910
911
|
http_options: Optional[HttpOptionsOrDict] = None,
|
911
|
-
):
|
912
|
+
) -> Union[Any,bytes]:
|
912
913
|
"""Downloads the file data.
|
913
914
|
|
914
915
|
Args:
|
@@ -977,7 +978,7 @@ class BaseApiClient:
|
|
977
978
|
|
978
979
|
async def _async_upload_fd(
|
979
980
|
self,
|
980
|
-
file: Union[io.IOBase, anyio.AsyncFile],
|
981
|
+
file: Union[io.IOBase, anyio.AsyncFile[Any]],
|
981
982
|
upload_url: str,
|
982
983
|
upload_size: int,
|
983
984
|
*,
|
@@ -1056,7 +1057,7 @@ class BaseApiClient:
|
|
1056
1057
|
path: str,
|
1057
1058
|
*,
|
1058
1059
|
http_options: Optional[HttpOptionsOrDict] = None,
|
1059
|
-
):
|
1060
|
+
) -> Union[Any, bytes]:
|
1060
1061
|
"""Downloads the file data.
|
1061
1062
|
|
1062
1063
|
Args:
|
@@ -1093,5 +1094,5 @@ class BaseApiClient:
|
|
1093
1094
|
# This method does nothing in the real api client. It is used in the
|
1094
1095
|
# replay_api_client to verify the response from the SDK method matches the
|
1095
1096
|
# recorded response.
|
1096
|
-
def _verify_response(self, response_model: _common.BaseModel):
|
1097
|
+
def _verify_response(self, response_model: _common.BaseModel) -> None:
|
1097
1098
|
pass
|
{google_genai-1.11.0 → google_genai-1.12.0}/google/genai/_automatic_function_calling_util.py
RENAMED
@@ -46,19 +46,6 @@ def _is_builtin_primitive_or_compound(
|
|
46
46
|
return annotation in _py_builtin_type_to_schema_type.keys()
|
47
47
|
|
48
48
|
|
49
|
-
def _raise_for_default_if_mldev(schema: types.Schema):
|
50
|
-
if schema.default is not None:
|
51
|
-
raise ValueError(
|
52
|
-
'Default value is not supported in function declaration schema for'
|
53
|
-
' the Gemini API.'
|
54
|
-
)
|
55
|
-
|
56
|
-
|
57
|
-
def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
|
58
|
-
if api_option == 'GEMINI_API':
|
59
|
-
_raise_for_default_if_mldev(schema)
|
60
|
-
|
61
|
-
|
62
49
|
def _is_default_value_compatible(
|
63
50
|
default_value: Any, annotation: inspect.Parameter.annotation # type: ignore[valid-type]
|
64
51
|
) -> bool:
|
@@ -72,16 +59,16 @@ def _is_default_value_compatible(
|
|
72
59
|
or isinstance(annotation, VersionedUnionType)
|
73
60
|
):
|
74
61
|
origin = get_origin(annotation)
|
75
|
-
if origin in (Union, VersionedUnionType):
|
62
|
+
if origin in (Union, VersionedUnionType): # type: ignore[comparison-overlap]
|
76
63
|
return any(
|
77
64
|
_is_default_value_compatible(default_value, arg)
|
78
65
|
for arg in get_args(annotation)
|
79
66
|
)
|
80
67
|
|
81
|
-
if origin is dict:
|
68
|
+
if origin is dict: # type: ignore[comparison-overlap]
|
82
69
|
return isinstance(default_value, dict)
|
83
70
|
|
84
|
-
if origin is list:
|
71
|
+
if origin is list: # type: ignore[comparison-overlap]
|
85
72
|
if not isinstance(default_value, list):
|
86
73
|
return False
|
87
74
|
# most tricky case, element in list is union type
|
@@ -97,7 +84,7 @@ def _is_default_value_compatible(
|
|
97
84
|
for item in default_value
|
98
85
|
)
|
99
86
|
|
100
|
-
if origin is Literal:
|
87
|
+
if origin is Literal: # type: ignore[comparison-overlap]
|
101
88
|
return default_value in get_args(annotation)
|
102
89
|
|
103
90
|
# return False for any other unrecognized annotation
|
@@ -125,7 +112,6 @@ def _parse_schema_from_parameter(
|
|
125
112
|
raise ValueError(default_value_error_msg)
|
126
113
|
schema.default = param.default
|
127
114
|
schema.type = _py_builtin_type_to_schema_type[param.annotation]
|
128
|
-
_raise_if_schema_unsupported(api_option, schema)
|
129
115
|
return schema
|
130
116
|
if (
|
131
117
|
isinstance(param.annotation, VersionedUnionType)
|
@@ -166,7 +152,6 @@ def _parse_schema_from_parameter(
|
|
166
152
|
if not _is_default_value_compatible(param.default, param.annotation):
|
167
153
|
raise ValueError(default_value_error_msg)
|
168
154
|
schema.default = param.default
|
169
|
-
_raise_if_schema_unsupported(api_option, schema)
|
170
155
|
return schema
|
171
156
|
if isinstance(param.annotation, _GenericAlias) or isinstance(
|
172
157
|
param.annotation, builtin_types.GenericAlias
|
@@ -179,7 +164,6 @@ def _parse_schema_from_parameter(
|
|
179
164
|
if not _is_default_value_compatible(param.default, param.annotation):
|
180
165
|
raise ValueError(default_value_error_msg)
|
181
166
|
schema.default = param.default
|
182
|
-
_raise_if_schema_unsupported(api_option, schema)
|
183
167
|
return schema
|
184
168
|
if origin is Literal:
|
185
169
|
if not all(isinstance(arg, str) for arg in args):
|
@@ -192,7 +176,6 @@ def _parse_schema_from_parameter(
|
|
192
176
|
if not _is_default_value_compatible(param.default, param.annotation):
|
193
177
|
raise ValueError(default_value_error_msg)
|
194
178
|
schema.default = param.default
|
195
|
-
_raise_if_schema_unsupported(api_option, schema)
|
196
179
|
return schema
|
197
180
|
if origin is list:
|
198
181
|
schema.type = _py_builtin_type_to_schema_type[list]
|
@@ -209,7 +192,6 @@ def _parse_schema_from_parameter(
|
|
209
192
|
if not _is_default_value_compatible(param.default, param.annotation):
|
210
193
|
raise ValueError(default_value_error_msg)
|
211
194
|
schema.default = param.default
|
212
|
-
_raise_if_schema_unsupported(api_option, schema)
|
213
195
|
return schema
|
214
196
|
if origin is Union:
|
215
197
|
schema.any_of = []
|
@@ -259,7 +241,6 @@ def _parse_schema_from_parameter(
|
|
259
241
|
if not _is_default_value_compatible(param.default, param.annotation):
|
260
242
|
raise ValueError(default_value_error_msg)
|
261
243
|
schema.default = param.default
|
262
|
-
_raise_if_schema_unsupported(api_option, schema)
|
263
244
|
return schema
|
264
245
|
# all other generic alias will be invoked in raise branch
|
265
246
|
if (
|
@@ -284,7 +265,6 @@ def _parse_schema_from_parameter(
|
|
284
265
|
func_name,
|
285
266
|
)
|
286
267
|
schema.required = _get_required_fields(schema)
|
287
|
-
_raise_if_schema_unsupported(api_option, schema)
|
288
268
|
return schema
|
289
269
|
raise ValueError(
|
290
270
|
f'Failed to parse the parameter {param} of function {func_name} for'
|
@@ -20,7 +20,7 @@ import datetime
|
|
20
20
|
import enum
|
21
21
|
import functools
|
22
22
|
import typing
|
23
|
-
from typing import Any, Union
|
23
|
+
from typing import Any, Callable, Optional, Union
|
24
24
|
import uuid
|
25
25
|
import warnings
|
26
26
|
|
@@ -31,7 +31,7 @@ from . import _api_client
|
|
31
31
|
from . import errors
|
32
32
|
|
33
33
|
|
34
|
-
def set_value_by_path(data, keys, value):
|
34
|
+
def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None:
|
35
35
|
"""Examples:
|
36
36
|
|
37
37
|
set_value_by_path({}, ['a', 'b'], v)
|
@@ -46,54 +46,57 @@ def set_value_by_path(data, keys, value):
|
|
46
46
|
for i, key in enumerate(keys[:-1]):
|
47
47
|
if key.endswith('[]'):
|
48
48
|
key_name = key[:-2]
|
49
|
-
if key_name not in data:
|
49
|
+
if data is not None and key_name not in data:
|
50
50
|
if isinstance(value, list):
|
51
51
|
data[key_name] = [{} for _ in range(len(value))]
|
52
52
|
else:
|
53
53
|
raise ValueError(
|
54
54
|
f'value {value} must be a list given an array path {key}'
|
55
55
|
)
|
56
|
-
if isinstance(value, list):
|
56
|
+
if isinstance(value, list) and data is not None:
|
57
57
|
for j, d in enumerate(data[key_name]):
|
58
58
|
set_value_by_path(d, keys[i + 1 :], value[j])
|
59
59
|
else:
|
60
|
-
|
61
|
-
|
60
|
+
if data is not None:
|
61
|
+
for d in data[key_name]:
|
62
|
+
set_value_by_path(d, keys[i + 1 :], value)
|
62
63
|
return
|
63
64
|
elif key.endswith('[0]'):
|
64
65
|
key_name = key[:-3]
|
65
|
-
if key_name not in data:
|
66
|
+
if data is not None and key_name not in data:
|
66
67
|
data[key_name] = [{}]
|
67
|
-
|
68
|
+
if data is not None:
|
69
|
+
set_value_by_path(data[key_name][0], keys[i + 1 :], value)
|
68
70
|
return
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
71
|
+
if data is not None:
|
72
|
+
data = data.setdefault(key, {})
|
73
|
+
|
74
|
+
if data is not None:
|
75
|
+
existing_data = data.get(keys[-1])
|
76
|
+
# If there is an existing value, merge, not overwrite.
|
77
|
+
if existing_data is not None:
|
78
|
+
# Don't overwrite existing non-empty value with new empty value.
|
79
|
+
# This is triggered when handling tuning datasets.
|
80
|
+
if not value:
|
81
|
+
pass
|
82
|
+
# Don't fail when overwriting value with same value
|
83
|
+
elif value == existing_data:
|
84
|
+
pass
|
85
|
+
# Instead of overwriting dictionary with another dictionary, merge them.
|
86
|
+
# This is important for handling training and validation datasets in tuning.
|
87
|
+
elif isinstance(existing_data, dict) and isinstance(value, dict):
|
88
|
+
# Merging dictionaries. Consider deep merging in the future.
|
89
|
+
existing_data.update(value)
|
90
|
+
else:
|
91
|
+
raise ValueError(
|
92
|
+
f'Cannot set value for an existing key. Key: {keys[-1]};'
|
93
|
+
f' Existing value: {existing_data}; New value: {value}.'
|
94
|
+
)
|
87
95
|
else:
|
88
|
-
|
89
|
-
f'Cannot set value for an existing key. Key: {keys[-1]};'
|
90
|
-
f' Existing value: {existing_data}; New value: {value}.'
|
91
|
-
)
|
92
|
-
else:
|
93
|
-
data[keys[-1]] = value
|
96
|
+
data[keys[-1]] = value
|
94
97
|
|
95
98
|
|
96
|
-
def get_value_by_path(data: Any, keys: list[str]):
|
99
|
+
def get_value_by_path(data: Any, keys: list[str]) -> Any:
|
97
100
|
"""Examples:
|
98
101
|
|
99
102
|
get_value_by_path({'a': {'b': v}}, ['a', 'b'])
|
@@ -227,7 +230,7 @@ class CaseInSensitiveEnum(str, enum.Enum):
|
|
227
230
|
"""Case insensitive enum."""
|
228
231
|
|
229
232
|
@classmethod
|
230
|
-
def _missing_(cls, value):
|
233
|
+
def _missing_(cls, value: Any) -> Any:
|
231
234
|
try:
|
232
235
|
return cls[value.upper()] # Try to access directly with uppercase
|
233
236
|
except KeyError:
|
@@ -295,12 +298,12 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
|
|
295
298
|
return processed_data
|
296
299
|
|
297
300
|
|
298
|
-
def experimental_warning(message: str):
|
301
|
+
def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
299
302
|
"""Experimental warning, only warns once."""
|
300
|
-
def decorator(func):
|
303
|
+
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
301
304
|
warning_done = False
|
302
305
|
@functools.wraps(func)
|
303
|
-
def wrapper(*args, **kwargs):
|
306
|
+
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
304
307
|
nonlocal warning_done
|
305
308
|
if not warning_done:
|
306
309
|
warning_done = True
|
@@ -79,9 +79,9 @@ def format_destination(
|
|
79
79
|
def get_function_map(
|
80
80
|
config: Optional[types.GenerateContentConfigOrDict] = None,
|
81
81
|
is_caller_method_async: bool = False,
|
82
|
-
) -> dict[str, Callable]:
|
82
|
+
) -> dict[str, Callable[..., Any]]:
|
83
83
|
"""Returns a function map from the config."""
|
84
|
-
function_map: dict[str, Callable] = {}
|
84
|
+
function_map: dict[str, Callable[..., Any]] = {}
|
85
85
|
if not config:
|
86
86
|
return function_map
|
87
87
|
config_model = _create_generate_content_config_model(config)
|
@@ -201,7 +201,7 @@ def convert_if_exist_pydantic_model(
|
|
201
201
|
|
202
202
|
|
203
203
|
def convert_argument_from_function(
|
204
|
-
args: dict[str, Any], function: Callable
|
204
|
+
args: dict[str, Any], function: Callable[..., Any]
|
205
205
|
) -> dict[str, Any]:
|
206
206
|
signature = inspect.signature(function)
|
207
207
|
func_name = function.__name__
|
@@ -218,7 +218,7 @@ def convert_argument_from_function(
|
|
218
218
|
|
219
219
|
|
220
220
|
def invoke_function_from_dict_args(
|
221
|
-
args: Dict[str, Any], function_to_invoke: Callable
|
221
|
+
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
|
222
222
|
) -> Any:
|
223
223
|
converted_args = convert_argument_from_function(args, function_to_invoke)
|
224
224
|
try:
|
@@ -232,7 +232,7 @@ def invoke_function_from_dict_args(
|
|
232
232
|
|
233
233
|
|
234
234
|
async def invoke_function_from_dict_args_async(
|
235
|
-
args: Dict[str, Any], function_to_invoke: Callable
|
235
|
+
args: Dict[str, Any], function_to_invoke: Callable[..., Any]
|
236
236
|
) -> Any:
|
237
237
|
converted_args = convert_argument_from_function(args, function_to_invoke)
|
238
238
|
try:
|
@@ -247,7 +247,7 @@ async def invoke_function_from_dict_args_async(
|
|
247
247
|
|
248
248
|
def get_function_response_parts(
|
249
249
|
response: types.GenerateContentResponse,
|
250
|
-
function_map: dict[str, Callable],
|
250
|
+
function_map: dict[str, Callable[..., Any]],
|
251
251
|
) -> list[types.Part]:
|
252
252
|
"""Returns the function response parts from the response."""
|
253
253
|
func_response_parts = []
|
@@ -280,7 +280,7 @@ def get_function_response_parts(
|
|
280
280
|
|
281
281
|
async def get_function_response_parts_async(
|
282
282
|
response: types.GenerateContentResponse,
|
283
|
-
function_map: dict[str, Callable],
|
283
|
+
function_map: dict[str, Callable[..., Any]],
|
284
284
|
) -> list[types.Part]:
|
285
285
|
"""Returns the function response parts from the response."""
|
286
286
|
func_response_parts = []
|