google-genai 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +72 -78
- google/genai/_api_module.py +24 -0
- google/genai/_automatic_function_calling_util.py +43 -22
- google/genai/_common.py +0 -6
- google/genai/_extra_utils.py +22 -16
- google/genai/_replay_api_client.py +2 -2
- google/genai/_test_api_client.py +1 -1
- google/genai/_transformers.py +218 -97
- google/genai/batches.py +194 -155
- google/genai/caches.py +117 -134
- google/genai/chats.py +22 -18
- google/genai/client.py +31 -37
- google/genai/files.py +94 -125
- google/genai/live.py +11 -5
- google/genai/models.py +500 -254
- google/genai/tunings.py +85 -422
- google/genai/types.py +495 -458
- google/genai/version.py +1 -1
- {google_genai-0.6.0.dist-info → google_genai-0.7.0.dist-info}/METADATA +116 -68
- google_genai-0.7.0.dist-info/RECORD +26 -0
- google_genai-0.6.0.dist-info/RECORD +0 -25
- {google_genai-0.6.0.dist-info → google_genai-0.7.0.dist-info}/LICENSE +0 -0
- {google_genai-0.6.0.dist-info → google_genai-0.7.0.dist-info}/WHEEL +0 -0
- {google_genai-0.6.0.dist-info → google_genai-0.7.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -36,55 +36,7 @@ import requests
|
|
36
36
|
|
37
37
|
from . import errors
|
38
38
|
from . import version
|
39
|
-
|
40
|
-
|
41
|
-
class HttpOptions(BaseModel):
|
42
|
-
"""HTTP options for the api client."""
|
43
|
-
model_config = ConfigDict(extra='forbid')
|
44
|
-
|
45
|
-
base_url: Optional[str] = Field(
|
46
|
-
default=None,
|
47
|
-
description="""The base URL for the AI platform service endpoint.""",
|
48
|
-
)
|
49
|
-
api_version: Optional[str] = Field(
|
50
|
-
default=None,
|
51
|
-
description="""Specifies the version of the API to use.""",
|
52
|
-
)
|
53
|
-
headers: Optional[dict[str, str]] = Field(
|
54
|
-
default=None,
|
55
|
-
description="""Additional HTTP headers to be sent with the request.""",
|
56
|
-
)
|
57
|
-
response_payload: Optional[dict] = Field(
|
58
|
-
default=None,
|
59
|
-
description="""If set, the response payload will be returned int the supplied dict.""",
|
60
|
-
)
|
61
|
-
timeout: Optional[Union[float, Tuple[float, float]]] = Field(
|
62
|
-
default=None,
|
63
|
-
description="""Timeout for the request in seconds.""",
|
64
|
-
)
|
65
|
-
skip_project_and_location_in_path: bool = Field(
|
66
|
-
default=False,
|
67
|
-
description="""If set to True, the project and location will not be appended to the path.""",
|
68
|
-
)
|
69
|
-
|
70
|
-
|
71
|
-
class HttpOptionsDict(TypedDict):
|
72
|
-
"""HTTP options for the api client."""
|
73
|
-
|
74
|
-
base_url: Optional[str] = None
|
75
|
-
"""The base URL for the AI platform service endpoint."""
|
76
|
-
api_version: Optional[str] = None
|
77
|
-
"""Specifies the version of the API to use."""
|
78
|
-
headers: Optional[dict[str, str]] = None
|
79
|
-
"""Additional HTTP headers to be sent with the request."""
|
80
|
-
response_payload: Optional[dict] = None
|
81
|
-
"""If set, the response payload will be returned int the supplied dict."""
|
82
|
-
timeout: Optional[Union[float, Tuple[float, float]]] = None
|
83
|
-
"""Timeout for the request in seconds."""
|
84
|
-
skip_project_and_location_in_path: bool = False
|
85
|
-
"""If set to True, the project and location will not be appended to the path."""
|
86
|
-
|
87
|
-
HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
|
39
|
+
from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
|
88
40
|
|
89
41
|
|
90
42
|
def _append_library_version_headers(headers: dict[str, str]) -> None:
|
@@ -144,7 +96,7 @@ class HttpRequest:
|
|
144
96
|
url: str
|
145
97
|
method: str
|
146
98
|
data: Union[dict[str, object], bytes]
|
147
|
-
timeout: Optional[
|
99
|
+
timeout: Optional[float] = None
|
148
100
|
|
149
101
|
|
150
102
|
class HttpResponse:
|
@@ -159,9 +111,20 @@ class HttpResponse:
|
|
159
111
|
self.headers = headers
|
160
112
|
self.response_stream = response_stream
|
161
113
|
self.byte_stream = byte_stream
|
114
|
+
self.segment_iterator = self.segments()
|
115
|
+
|
116
|
+
# Async iterator for async streaming.
|
117
|
+
def __aiter__(self):
|
118
|
+
return self
|
119
|
+
|
120
|
+
async def __anext__(self):
|
121
|
+
try:
|
122
|
+
return next(self.segment_iterator)
|
123
|
+
except StopIteration:
|
124
|
+
raise StopAsyncIteration
|
162
125
|
|
163
126
|
@property
|
164
|
-
def
|
127
|
+
def json(self) -> Any:
|
165
128
|
if not self.response_stream[0]: # Empty response
|
166
129
|
return ''
|
167
130
|
return json.loads(self.response_stream[0])
|
@@ -194,7 +157,9 @@ class HttpResponse:
|
|
194
157
|
'Byte segments are not supported for streaming responses.'
|
195
158
|
)
|
196
159
|
|
197
|
-
def
|
160
|
+
def _copy_to_dict(self, response_payload: dict[str, object]):
|
161
|
+
# Cannot pickle 'generator' object.
|
162
|
+
delattr(self, 'segment_iterator')
|
198
163
|
for attribute in dir(self):
|
199
164
|
response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
|
200
165
|
|
@@ -325,7 +290,7 @@ class ApiClient:
|
|
325
290
|
http_method: str,
|
326
291
|
path: str,
|
327
292
|
request_dict: dict[str, object],
|
328
|
-
http_options:
|
293
|
+
http_options: HttpOptionsOrDict = None,
|
329
294
|
) -> HttpRequest:
|
330
295
|
# Remove all special dict keys such as _url and _query.
|
331
296
|
keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
|
@@ -333,18 +298,28 @@ class ApiClient:
|
|
333
298
|
del request_dict[key]
|
334
299
|
# patch the http options with the user provided settings.
|
335
300
|
if http_options:
|
336
|
-
|
337
|
-
|
338
|
-
|
301
|
+
if isinstance(http_options, HttpOptions):
|
302
|
+
patched_http_options = _patch_http_options(
|
303
|
+
self._http_options, http_options.model_dump()
|
304
|
+
)
|
305
|
+
else:
|
306
|
+
patched_http_options = _patch_http_options(
|
307
|
+
self._http_options, http_options
|
308
|
+
)
|
339
309
|
else:
|
340
310
|
patched_http_options = self._http_options
|
341
|
-
|
342
|
-
|
343
|
-
|
311
|
+
# Skip adding project and locations when getting Vertex AI base models.
|
312
|
+
query_vertex_base_models = False
|
313
|
+
if (
|
314
|
+
self.vertexai
|
315
|
+
and http_method == 'get'
|
316
|
+
and path.startswith('publishers/google/models')
|
317
|
+
):
|
318
|
+
query_vertex_base_models = True
|
344
319
|
if (
|
345
320
|
self.vertexai
|
346
321
|
and not path.startswith('projects/')
|
347
|
-
and not
|
322
|
+
and not query_vertex_base_models
|
348
323
|
and not self.api_key
|
349
324
|
):
|
350
325
|
path = f'projects/{self.project}/locations/{self.location}/' + path
|
@@ -354,12 +329,19 @@ class ApiClient:
|
|
354
329
|
patched_http_options['base_url'],
|
355
330
|
patched_http_options['api_version'] + '/' + path,
|
356
331
|
)
|
332
|
+
|
333
|
+
timeout_in_seconds = patched_http_options.get('timeout', None)
|
334
|
+
if timeout_in_seconds:
|
335
|
+
timeout_in_seconds = timeout_in_seconds / 1000.0
|
336
|
+
else:
|
337
|
+
timeout_in_seconds = None
|
338
|
+
|
357
339
|
return HttpRequest(
|
358
340
|
method=http_method,
|
359
341
|
url=url,
|
360
342
|
headers=patched_http_options['headers'],
|
361
343
|
data=request_dict,
|
362
|
-
timeout=
|
344
|
+
timeout=timeout_in_seconds,
|
363
345
|
)
|
364
346
|
|
365
347
|
def _request(
|
@@ -448,15 +430,24 @@ class ApiClient:
|
|
448
430
|
http_method: str,
|
449
431
|
path: str,
|
450
432
|
request_dict: dict[str, object],
|
451
|
-
http_options:
|
433
|
+
http_options: HttpOptionsOrDict = None,
|
452
434
|
):
|
453
435
|
http_request = self._build_request(
|
454
436
|
http_method, path, request_dict, http_options
|
455
437
|
)
|
456
438
|
response = self._request(http_request, stream=False)
|
457
|
-
if http_options
|
458
|
-
|
459
|
-
|
439
|
+
if http_options:
|
440
|
+
if (
|
441
|
+
isinstance(http_options, HttpOptions)
|
442
|
+
and http_options.deprecated_response_payload is not None
|
443
|
+
):
|
444
|
+
response._copy_to_dict(http_options.deprecated_response_payload)
|
445
|
+
elif (
|
446
|
+
isinstance(http_options, dict)
|
447
|
+
and 'deprecated_response_payload' in http_options
|
448
|
+
):
|
449
|
+
response._copy_to_dict(http_options['deprecated_response_payload'])
|
450
|
+
return response.json
|
460
451
|
|
461
452
|
def request_streamed(
|
462
453
|
self,
|
@@ -470,8 +461,10 @@ class ApiClient:
|
|
470
461
|
)
|
471
462
|
|
472
463
|
session_response = self._request(http_request, stream=True)
|
473
|
-
if http_options and '
|
474
|
-
session_response.
|
464
|
+
if http_options and 'deprecated_response_payload' in http_options:
|
465
|
+
session_response._copy_to_dict(
|
466
|
+
http_options['deprecated_response_payload']
|
467
|
+
)
|
475
468
|
for chunk in session_response.segments():
|
476
469
|
yield chunk
|
477
470
|
|
@@ -487,9 +480,9 @@ class ApiClient:
|
|
487
480
|
)
|
488
481
|
|
489
482
|
result = await self._async_request(http_request=http_request, stream=False)
|
490
|
-
if http_options and '
|
491
|
-
result.
|
492
|
-
return result.
|
483
|
+
if http_options and 'deprecated_response_payload' in http_options:
|
484
|
+
result._copy_to_dict(http_options['deprecated_response_payload'])
|
485
|
+
return result.json
|
493
486
|
|
494
487
|
async def async_request_streamed(
|
495
488
|
self,
|
@@ -504,10 +497,12 @@ class ApiClient:
|
|
504
497
|
|
505
498
|
response = await self._async_request(http_request=http_request, stream=True)
|
506
499
|
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
response
|
500
|
+
if http_options and 'deprecated_response_payload' in http_options:
|
501
|
+
response._copy_to_dict(http_options['deprecated_response_payload'])
|
502
|
+
async def async_generator():
|
503
|
+
async for chunk in response:
|
504
|
+
yield chunk
|
505
|
+
return async_generator()
|
511
506
|
|
512
507
|
def upload_file(
|
513
508
|
self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
|
@@ -575,15 +570,15 @@ class ApiClient:
|
|
575
570
|
if upload_size <= offset: # Status is not finalized.
|
576
571
|
raise ValueError(
|
577
572
|
'All content has been uploaded, but the upload status is not'
|
578
|
-
f' finalized. {response.headers}, body: {response.
|
573
|
+
f' finalized. {response.headers}, body: {response.json}'
|
579
574
|
)
|
580
575
|
|
581
576
|
if response.headers['X-Goog-Upload-Status'] != 'final':
|
582
577
|
raise ValueError(
|
583
578
|
'Failed to upload file: Upload status is not finalized. headers:'
|
584
|
-
f' {response.headers}, body: {response.
|
579
|
+
f' {response.headers}, body: {response.json}'
|
585
580
|
)
|
586
|
-
return response.
|
581
|
+
return response.json
|
587
582
|
|
588
583
|
def download_file(self, path: str, http_options):
|
589
584
|
"""Downloads the file data.
|
@@ -624,7 +619,6 @@ class ApiClient:
|
|
624
619
|
errors.APIError.raise_for_response(response)
|
625
620
|
return HttpResponse(response.headers, byte_stream=[response.content])
|
626
621
|
|
627
|
-
|
628
622
|
async def async_upload_file(
|
629
623
|
self,
|
630
624
|
file_path: Union[str, io.IOBase],
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
"""Utilities for the API Modules of the Google Gen AI SDK."""
|
17
|
+
|
18
|
+
from . import _api_client
|
19
|
+
|
20
|
+
|
21
|
+
class BaseModule:
|
22
|
+
|
23
|
+
def __init__(self, api_client_: _api_client.ApiClient):
|
24
|
+
self._api_client = api_client_
|
@@ -14,11 +14,18 @@
|
|
14
14
|
#
|
15
15
|
|
16
16
|
import inspect
|
17
|
-
import
|
17
|
+
import sys
|
18
|
+
import types as builtin_types
|
19
|
+
import typing
|
18
20
|
from typing import Any, Callable, Literal, Union, _GenericAlias, get_args, get_origin
|
19
21
|
import pydantic
|
20
22
|
from . import types
|
21
23
|
|
24
|
+
if sys.version_info >= (3, 10):
|
25
|
+
UnionType = builtin_types.UnionType
|
26
|
+
else:
|
27
|
+
UnionType = typing._UnionGenericAlias
|
28
|
+
|
22
29
|
_py_builtin_type_to_schema_type = {
|
23
30
|
str: 'STRING',
|
24
31
|
int: 'INTEGER',
|
@@ -58,8 +65,8 @@ def _raise_for_nullable_if_mldev(schema: types.Schema):
|
|
58
65
|
)
|
59
66
|
|
60
67
|
|
61
|
-
def _raise_if_schema_unsupported(
|
62
|
-
if not
|
68
|
+
def _raise_if_schema_unsupported(client, schema: types.Schema):
|
69
|
+
if not client.vertexai:
|
63
70
|
_raise_for_any_of_if_mldev(schema)
|
64
71
|
_raise_for_default_if_mldev(schema)
|
65
72
|
_raise_for_nullable_if_mldev(schema)
|
@@ -74,11 +81,11 @@ def _is_default_value_compatible(
|
|
74
81
|
|
75
82
|
if (
|
76
83
|
isinstance(annotation, _GenericAlias)
|
77
|
-
or isinstance(annotation,
|
78
|
-
or isinstance(annotation,
|
84
|
+
or isinstance(annotation, builtin_types.GenericAlias)
|
85
|
+
or isinstance(annotation, UnionType)
|
79
86
|
):
|
80
87
|
origin = get_origin(annotation)
|
81
|
-
if origin in (Union,
|
88
|
+
if origin in (Union, UnionType):
|
82
89
|
return any(
|
83
90
|
_is_default_value_compatible(default_value, arg)
|
84
91
|
for arg in get_args(annotation)
|
@@ -107,12 +114,13 @@ def _is_default_value_compatible(
|
|
107
114
|
return default_value in get_args(annotation)
|
108
115
|
|
109
116
|
# return False for any other unrecognized annotation
|
110
|
-
# let caller handle the raise
|
111
117
|
return False
|
112
118
|
|
113
119
|
|
114
120
|
def _parse_schema_from_parameter(
|
115
|
-
|
121
|
+
client,
|
122
|
+
param: inspect.Parameter,
|
123
|
+
func_name: str,
|
116
124
|
) -> types.Schema:
|
117
125
|
"""parse schema from parameter.
|
118
126
|
|
@@ -130,12 +138,12 @@ def _parse_schema_from_parameter(
|
|
130
138
|
raise ValueError(default_value_error_msg)
|
131
139
|
schema.default = param.default
|
132
140
|
schema.type = _py_builtin_type_to_schema_type[param.annotation]
|
133
|
-
_raise_if_schema_unsupported(
|
141
|
+
_raise_if_schema_unsupported(client, schema)
|
134
142
|
return schema
|
135
143
|
if (
|
136
|
-
isinstance(param.annotation,
|
144
|
+
isinstance(param.annotation, UnionType)
|
137
145
|
# only parse simple UnionType, example int | str | float | bool
|
138
|
-
# complex
|
146
|
+
# complex UnionType will be invoked in raise branch
|
139
147
|
and all(
|
140
148
|
(_is_builtin_primitive_or_compound(arg) or arg is type(None))
|
141
149
|
for arg in get_args(param.annotation)
|
@@ -149,7 +157,7 @@ def _parse_schema_from_parameter(
|
|
149
157
|
schema.nullable = True
|
150
158
|
continue
|
151
159
|
schema_in_any_of = _parse_schema_from_parameter(
|
152
|
-
|
160
|
+
client,
|
153
161
|
inspect.Parameter(
|
154
162
|
'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
|
155
163
|
),
|
@@ -171,10 +179,10 @@ def _parse_schema_from_parameter(
|
|
171
179
|
if not _is_default_value_compatible(param.default, param.annotation):
|
172
180
|
raise ValueError(default_value_error_msg)
|
173
181
|
schema.default = param.default
|
174
|
-
_raise_if_schema_unsupported(
|
182
|
+
_raise_if_schema_unsupported(client, schema)
|
175
183
|
return schema
|
176
184
|
if isinstance(param.annotation, _GenericAlias) or isinstance(
|
177
|
-
param.annotation,
|
185
|
+
param.annotation, builtin_types.GenericAlias
|
178
186
|
):
|
179
187
|
origin = get_origin(param.annotation)
|
180
188
|
args = get_args(param.annotation)
|
@@ -184,7 +192,7 @@ def _parse_schema_from_parameter(
|
|
184
192
|
if not _is_default_value_compatible(param.default, param.annotation):
|
185
193
|
raise ValueError(default_value_error_msg)
|
186
194
|
schema.default = param.default
|
187
|
-
_raise_if_schema_unsupported(
|
195
|
+
_raise_if_schema_unsupported(client, schema)
|
188
196
|
return schema
|
189
197
|
if origin is Literal:
|
190
198
|
if not all(isinstance(arg, str) for arg in args):
|
@@ -197,12 +205,12 @@ def _parse_schema_from_parameter(
|
|
197
205
|
if not _is_default_value_compatible(param.default, param.annotation):
|
198
206
|
raise ValueError(default_value_error_msg)
|
199
207
|
schema.default = param.default
|
200
|
-
_raise_if_schema_unsupported(
|
208
|
+
_raise_if_schema_unsupported(client, schema)
|
201
209
|
return schema
|
202
210
|
if origin is list:
|
203
211
|
schema.type = 'ARRAY'
|
204
212
|
schema.items = _parse_schema_from_parameter(
|
205
|
-
|
213
|
+
client,
|
206
214
|
inspect.Parameter(
|
207
215
|
'item',
|
208
216
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -214,7 +222,7 @@ def _parse_schema_from_parameter(
|
|
214
222
|
if not _is_default_value_compatible(param.default, param.annotation):
|
215
223
|
raise ValueError(default_value_error_msg)
|
216
224
|
schema.default = param.default
|
217
|
-
_raise_if_schema_unsupported(
|
225
|
+
_raise_if_schema_unsupported(client, schema)
|
218
226
|
return schema
|
219
227
|
if origin is Union:
|
220
228
|
schema.any_of = []
|
@@ -225,7 +233,7 @@ def _parse_schema_from_parameter(
|
|
225
233
|
schema.nullable = True
|
226
234
|
continue
|
227
235
|
schema_in_any_of = _parse_schema_from_parameter(
|
228
|
-
|
236
|
+
client,
|
229
237
|
inspect.Parameter(
|
230
238
|
'item',
|
231
239
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -233,6 +241,17 @@ def _parse_schema_from_parameter(
|
|
233
241
|
),
|
234
242
|
func_name,
|
235
243
|
)
|
244
|
+
if (
|
245
|
+
len(param.annotation.__args__) == 2
|
246
|
+
and type(None) in param.annotation.__args__
|
247
|
+
): # Optional type
|
248
|
+
for optional_arg in param.annotation.__args__:
|
249
|
+
if (
|
250
|
+
hasattr(optional_arg, '__origin__')
|
251
|
+
and optional_arg.__origin__ is list
|
252
|
+
):
|
253
|
+
# Optional type with list, for example Optional[list[str]]
|
254
|
+
schema.items = schema_in_any_of.items
|
236
255
|
if (
|
237
256
|
schema_in_any_of.model_dump_json(exclude_none=True)
|
238
257
|
not in unique_types
|
@@ -249,7 +268,7 @@ def _parse_schema_from_parameter(
|
|
249
268
|
if not _is_default_value_compatible(param.default, param.annotation):
|
250
269
|
raise ValueError(default_value_error_msg)
|
251
270
|
schema.default = param.default
|
252
|
-
_raise_if_schema_unsupported(
|
271
|
+
_raise_if_schema_unsupported(client, schema)
|
253
272
|
return schema
|
254
273
|
# all other generic alias will be invoked in raise branch
|
255
274
|
if (
|
@@ -266,7 +285,7 @@ def _parse_schema_from_parameter(
|
|
266
285
|
schema.properties = {}
|
267
286
|
for field_name, field_info in param.annotation.model_fields.items():
|
268
287
|
schema.properties[field_name] = _parse_schema_from_parameter(
|
269
|
-
|
288
|
+
client,
|
270
289
|
inspect.Parameter(
|
271
290
|
field_name,
|
272
291
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -274,7 +293,9 @@ def _parse_schema_from_parameter(
|
|
274
293
|
),
|
275
294
|
func_name,
|
276
295
|
)
|
277
|
-
|
296
|
+
if client.vertexai:
|
297
|
+
schema.required = _get_required_fields(schema)
|
298
|
+
_raise_if_schema_unsupported(client, schema)
|
278
299
|
return schema
|
279
300
|
raise ValueError(
|
280
301
|
f'Failed to parse the parameter {param} of function {func_name} for'
|
google/genai/_common.py
CHANGED
@@ -113,12 +113,6 @@ def get_value_by_path(data: object, keys: list[str]):
|
|
113
113
|
return data
|
114
114
|
|
115
115
|
|
116
|
-
class BaseModule:
|
117
|
-
|
118
|
-
def __init__(self, api_client_: _api_client.ApiClient):
|
119
|
-
self._api_client = api_client_
|
120
|
-
|
121
|
-
|
122
116
|
def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
|
123
117
|
"""Recursively converts a given object to a dictionary.
|
124
118
|
|
google/genai/_extra_utils.py
CHANGED
@@ -13,12 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
#
|
15
15
|
|
16
|
-
"""Extra utils depending on types that are shared between sync and async modules.
|
17
|
-
"""
|
16
|
+
"""Extra utils depending on types that are shared between sync and async modules."""
|
18
17
|
|
19
18
|
import inspect
|
20
19
|
import logging
|
21
|
-
|
20
|
+
import typing
|
21
|
+
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
|
22
|
+
import sys
|
22
23
|
|
23
24
|
import pydantic
|
24
25
|
|
@@ -26,6 +27,10 @@ from . import _common
|
|
26
27
|
from . import errors
|
27
28
|
from . import types
|
28
29
|
|
30
|
+
if sys.version_info >= (3, 10):
|
31
|
+
from types import UnionType
|
32
|
+
else:
|
33
|
+
UnionType = typing._UnionGenericAlias
|
29
34
|
|
30
35
|
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10
|
31
36
|
|
@@ -78,8 +83,8 @@ def get_function_map(
|
|
78
83
|
if inspect.iscoroutinefunction(tool):
|
79
84
|
raise errors.UnsupportedFunctionError(
|
80
85
|
f'Function {tool.__name__} is a coroutine function, which is not'
|
81
|
-
' supported for automatic function calling. Please manually
|
82
|
-
f' {tool.__name__} to get the function response.'
|
86
|
+
' supported for automatic function calling. Please manually'
|
87
|
+
f' invoke {tool.__name__} to get the function response.'
|
83
88
|
)
|
84
89
|
function_map[tool.__name__] = tool
|
85
90
|
return function_map
|
@@ -135,11 +140,13 @@ def convert_if_exist_pydantic_model(
|
|
135
140
|
for k, v in value.items()
|
136
141
|
}
|
137
142
|
# example 1: typing.Union[int, float]
|
138
|
-
# example 2: int | float equivalent to
|
139
|
-
if get_origin(annotation) in (Union,
|
143
|
+
# example 2: int | float equivalent to UnionType[int, float]
|
144
|
+
if get_origin(annotation) in (Union, UnionType):
|
140
145
|
for arg in get_args(annotation):
|
141
|
-
if
|
142
|
-
|
146
|
+
if (
|
147
|
+
(get_args(arg) and get_origin(arg) is list)
|
148
|
+
or isinstance(value, arg)
|
149
|
+
or (isinstance(value, dict) and _is_annotation_pydantic_model(arg))
|
143
150
|
):
|
144
151
|
try:
|
145
152
|
return convert_if_exist_pydantic_model(
|
@@ -209,7 +216,9 @@ def get_function_response_parts(
|
|
209
216
|
response = {'result': invoke_function_from_dict_args(args, func)}
|
210
217
|
except Exception as e: # pylint: disable=broad-except
|
211
218
|
response = {'error': str(e)}
|
212
|
-
func_response = types.Part.from_function_response(
|
219
|
+
func_response = types.Part.from_function_response(
|
220
|
+
name=func_name, response=response
|
221
|
+
)
|
213
222
|
|
214
223
|
func_response_parts.append(func_response)
|
215
224
|
return func_response_parts
|
@@ -231,8 +240,7 @@ def should_disable_afc(
|
|
231
240
|
and config_model.automatic_function_calling
|
232
241
|
and config_model.automatic_function_calling.maximum_remote_calls
|
233
242
|
is not None
|
234
|
-
and int(config_model.automatic_function_calling.maximum_remote_calls)
|
235
|
-
<= 0
|
243
|
+
and int(config_model.automatic_function_calling.maximum_remote_calls) <= 0
|
236
244
|
):
|
237
245
|
logging.warning(
|
238
246
|
'max_remote_calls in automatic_function_calling_config'
|
@@ -294,6 +302,7 @@ def get_max_remote_calls_afc(
|
|
294
302
|
return _DEFAULT_MAX_REMOTE_CALLS_AFC
|
295
303
|
return int(config_model.automatic_function_calling.maximum_remote_calls)
|
296
304
|
|
305
|
+
|
297
306
|
def should_append_afc_history(
|
298
307
|
config: Optional[types.GenerateContentConfigOrDict] = None,
|
299
308
|
) -> bool:
|
@@ -302,9 +311,6 @@ def should_append_afc_history(
|
|
302
311
|
if config and isinstance(config, dict)
|
303
312
|
else config
|
304
313
|
)
|
305
|
-
if
|
306
|
-
not config_model
|
307
|
-
or not config_model.automatic_function_calling
|
308
|
-
):
|
314
|
+
if not config_model or not config_model.automatic_function_calling:
|
309
315
|
return True
|
310
316
|
return not config_model.automatic_function_calling.ignore_call_history
|
@@ -397,7 +397,7 @@ class ReplayApiClient(ApiClient):
|
|
397
397
|
# segments since the stream has been consumed.
|
398
398
|
else:
|
399
399
|
self._record_interaction(http_request, result)
|
400
|
-
_debug_print('api mode result: %s' % result.
|
400
|
+
_debug_print('api mode result: %s' % result.json)
|
401
401
|
return result
|
402
402
|
else:
|
403
403
|
return self._build_response_from_replay(http_request)
|
@@ -429,7 +429,7 @@ class ReplayApiClient(ApiClient):
|
|
429
429
|
self._record_interaction(request, HttpResponse({}, [json.dumps(result)]))
|
430
430
|
return result
|
431
431
|
else:
|
432
|
-
return self._build_response_from_replay(request).
|
432
|
+
return self._build_response_from_replay(request).json
|
433
433
|
|
434
434
|
def _download_file_request(self, request):
|
435
435
|
self._initialize_replay_session_if_not_loaded()
|
google/genai/_test_api_client.py
CHANGED
@@ -132,7 +132,7 @@ async def test_async_request_streamed_non_blocking(
|
|
132
132
|
|
133
133
|
chunks = []
|
134
134
|
start_time = time.time()
|
135
|
-
async for chunk in api_client.async_request_streamed(
|
135
|
+
async for chunk in await api_client.async_request_streamed(
|
136
136
|
http_method, path, request_dict
|
137
137
|
):
|
138
138
|
chunks.append(chunk)
|