google-genai 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +74 -82
- google/genai/_api_module.py +24 -0
- google/genai/_automatic_function_calling_util.py +43 -22
- google/genai/_common.py +11 -8
- google/genai/_extra_utils.py +22 -16
- google/genai/_operations.py +365 -0
- google/genai/_replay_api_client.py +7 -2
- google/genai/_test_api_client.py +1 -1
- google/genai/_transformers.py +218 -97
- google/genai/batches.py +194 -155
- google/genai/caches.py +117 -134
- google/genai/chats.py +22 -18
- google/genai/client.py +31 -37
- google/genai/files.py +154 -183
- google/genai/live.py +11 -5
- google/genai/models.py +506 -254
- google/genai/tunings.py +85 -422
- google/genai/types.py +647 -458
- google/genai/version.py +1 -1
- {google_genai-0.6.0.dist-info → google_genai-0.8.0.dist-info}/METADATA +119 -70
- google_genai-0.8.0.dist-info/RECORD +27 -0
- google_genai-0.6.0.dist-info/RECORD +0 -25
- {google_genai-0.6.0.dist-info → google_genai-0.8.0.dist-info}/LICENSE +0 -0
- {google_genai-0.6.0.dist-info → google_genai-0.8.0.dist-info}/WHEEL +0 -0
- {google_genai-0.6.0.dist-info → google_genai-0.8.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -36,55 +36,7 @@ import requests
|
|
36
36
|
|
37
37
|
from . import errors
|
38
38
|
from . import version
|
39
|
-
|
40
|
-
|
41
|
-
class HttpOptions(BaseModel):
|
42
|
-
"""HTTP options for the api client."""
|
43
|
-
model_config = ConfigDict(extra='forbid')
|
44
|
-
|
45
|
-
base_url: Optional[str] = Field(
|
46
|
-
default=None,
|
47
|
-
description="""The base URL for the AI platform service endpoint.""",
|
48
|
-
)
|
49
|
-
api_version: Optional[str] = Field(
|
50
|
-
default=None,
|
51
|
-
description="""Specifies the version of the API to use.""",
|
52
|
-
)
|
53
|
-
headers: Optional[dict[str, str]] = Field(
|
54
|
-
default=None,
|
55
|
-
description="""Additional HTTP headers to be sent with the request.""",
|
56
|
-
)
|
57
|
-
response_payload: Optional[dict] = Field(
|
58
|
-
default=None,
|
59
|
-
description="""If set, the response payload will be returned int the supplied dict.""",
|
60
|
-
)
|
61
|
-
timeout: Optional[Union[float, Tuple[float, float]]] = Field(
|
62
|
-
default=None,
|
63
|
-
description="""Timeout for the request in seconds.""",
|
64
|
-
)
|
65
|
-
skip_project_and_location_in_path: bool = Field(
|
66
|
-
default=False,
|
67
|
-
description="""If set to True, the project and location will not be appended to the path.""",
|
68
|
-
)
|
69
|
-
|
70
|
-
|
71
|
-
class HttpOptionsDict(TypedDict):
|
72
|
-
"""HTTP options for the api client."""
|
73
|
-
|
74
|
-
base_url: Optional[str] = None
|
75
|
-
"""The base URL for the AI platform service endpoint."""
|
76
|
-
api_version: Optional[str] = None
|
77
|
-
"""Specifies the version of the API to use."""
|
78
|
-
headers: Optional[dict[str, str]] = None
|
79
|
-
"""Additional HTTP headers to be sent with the request."""
|
80
|
-
response_payload: Optional[dict] = None
|
81
|
-
"""If set, the response payload will be returned int the supplied dict."""
|
82
|
-
timeout: Optional[Union[float, Tuple[float, float]]] = None
|
83
|
-
"""Timeout for the request in seconds."""
|
84
|
-
skip_project_and_location_in_path: bool = False
|
85
|
-
"""If set to True, the project and location will not be appended to the path."""
|
86
|
-
|
87
|
-
HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
|
39
|
+
from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
|
88
40
|
|
89
41
|
|
90
42
|
def _append_library_version_headers(headers: dict[str, str]) -> None:
|
@@ -144,7 +96,7 @@ class HttpRequest:
|
|
144
96
|
url: str
|
145
97
|
method: str
|
146
98
|
data: Union[dict[str, object], bytes]
|
147
|
-
timeout: Optional[
|
99
|
+
timeout: Optional[float] = None
|
148
100
|
|
149
101
|
|
150
102
|
class HttpResponse:
|
@@ -159,9 +111,20 @@ class HttpResponse:
|
|
159
111
|
self.headers = headers
|
160
112
|
self.response_stream = response_stream
|
161
113
|
self.byte_stream = byte_stream
|
114
|
+
self.segment_iterator = self.segments()
|
115
|
+
|
116
|
+
# Async iterator for async streaming.
|
117
|
+
def __aiter__(self):
|
118
|
+
return self
|
119
|
+
|
120
|
+
async def __anext__(self):
|
121
|
+
try:
|
122
|
+
return next(self.segment_iterator)
|
123
|
+
except StopIteration:
|
124
|
+
raise StopAsyncIteration
|
162
125
|
|
163
126
|
@property
|
164
|
-
def
|
127
|
+
def json(self) -> Any:
|
165
128
|
if not self.response_stream[0]: # Empty response
|
166
129
|
return ''
|
167
130
|
return json.loads(self.response_stream[0])
|
@@ -194,7 +157,9 @@ class HttpResponse:
|
|
194
157
|
'Byte segments are not supported for streaming responses.'
|
195
158
|
)
|
196
159
|
|
197
|
-
def
|
160
|
+
def _copy_to_dict(self, response_payload: dict[str, object]):
|
161
|
+
# Cannot pickle 'generator' object.
|
162
|
+
delattr(self, 'segment_iterator')
|
198
163
|
for attribute in dir(self):
|
199
164
|
response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
|
200
165
|
|
@@ -290,7 +255,7 @@ class ApiClient:
|
|
290
255
|
'Project and location or API key must be set when using the Vertex '
|
291
256
|
'AI API.'
|
292
257
|
)
|
293
|
-
if self.api_key:
|
258
|
+
if self.api_key or self.location == 'global':
|
294
259
|
self._http_options['base_url'] = (
|
295
260
|
f'https://aiplatform.googleapis.com/'
|
296
261
|
)
|
@@ -308,7 +273,7 @@ class ApiClient:
|
|
308
273
|
self._http_options['api_version'] = 'v1beta'
|
309
274
|
# Default options for both clients.
|
310
275
|
self._http_options['headers'] = {'Content-Type': 'application/json'}
|
311
|
-
if self.api_key
|
276
|
+
if self.api_key:
|
312
277
|
self._http_options['headers']['x-goog-api-key'] = self.api_key
|
313
278
|
# Update the http options with the user provided http options.
|
314
279
|
if http_options:
|
@@ -325,7 +290,7 @@ class ApiClient:
|
|
325
290
|
http_method: str,
|
326
291
|
path: str,
|
327
292
|
request_dict: dict[str, object],
|
328
|
-
http_options:
|
293
|
+
http_options: HttpOptionsOrDict = None,
|
329
294
|
) -> HttpRequest:
|
330
295
|
# Remove all special dict keys such as _url and _query.
|
331
296
|
keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
|
@@ -333,33 +298,48 @@ class ApiClient:
|
|
333
298
|
del request_dict[key]
|
334
299
|
# patch the http options with the user provided settings.
|
335
300
|
if http_options:
|
336
|
-
|
337
|
-
|
338
|
-
|
301
|
+
if isinstance(http_options, HttpOptions):
|
302
|
+
patched_http_options = _patch_http_options(
|
303
|
+
self._http_options, http_options.model_dump()
|
304
|
+
)
|
305
|
+
else:
|
306
|
+
patched_http_options = _patch_http_options(
|
307
|
+
self._http_options, http_options
|
308
|
+
)
|
339
309
|
else:
|
340
310
|
patched_http_options = self._http_options
|
341
|
-
|
342
|
-
|
343
|
-
|
311
|
+
# Skip adding project and locations when getting Vertex AI base models.
|
312
|
+
query_vertex_base_models = False
|
313
|
+
if (
|
314
|
+
self.vertexai
|
315
|
+
and http_method == 'get'
|
316
|
+
and path.startswith('publishers/google/models')
|
317
|
+
):
|
318
|
+
query_vertex_base_models = True
|
344
319
|
if (
|
345
320
|
self.vertexai
|
346
321
|
and not path.startswith('projects/')
|
347
|
-
and not
|
322
|
+
and not query_vertex_base_models
|
348
323
|
and not self.api_key
|
349
324
|
):
|
350
325
|
path = f'projects/{self.project}/locations/{self.location}/' + path
|
351
|
-
elif self.vertexai and self.api_key:
|
352
|
-
path = f'{path}?key={self.api_key}'
|
353
326
|
url = _join_url_path(
|
354
327
|
patched_http_options['base_url'],
|
355
328
|
patched_http_options['api_version'] + '/' + path,
|
356
329
|
)
|
330
|
+
|
331
|
+
timeout_in_seconds = patched_http_options.get('timeout', None)
|
332
|
+
if timeout_in_seconds:
|
333
|
+
timeout_in_seconds = timeout_in_seconds / 1000.0
|
334
|
+
else:
|
335
|
+
timeout_in_seconds = None
|
336
|
+
|
357
337
|
return HttpRequest(
|
358
338
|
method=http_method,
|
359
339
|
url=url,
|
360
340
|
headers=patched_http_options['headers'],
|
361
341
|
data=request_dict,
|
362
|
-
timeout=
|
342
|
+
timeout=timeout_in_seconds,
|
363
343
|
)
|
364
344
|
|
365
345
|
def _request(
|
@@ -448,15 +428,24 @@ class ApiClient:
|
|
448
428
|
http_method: str,
|
449
429
|
path: str,
|
450
430
|
request_dict: dict[str, object],
|
451
|
-
http_options:
|
431
|
+
http_options: HttpOptionsOrDict = None,
|
452
432
|
):
|
453
433
|
http_request = self._build_request(
|
454
434
|
http_method, path, request_dict, http_options
|
455
435
|
)
|
456
436
|
response = self._request(http_request, stream=False)
|
457
|
-
if http_options
|
458
|
-
|
459
|
-
|
437
|
+
if http_options:
|
438
|
+
if (
|
439
|
+
isinstance(http_options, HttpOptions)
|
440
|
+
and http_options.deprecated_response_payload is not None
|
441
|
+
):
|
442
|
+
response._copy_to_dict(http_options.deprecated_response_payload)
|
443
|
+
elif (
|
444
|
+
isinstance(http_options, dict)
|
445
|
+
and 'deprecated_response_payload' in http_options
|
446
|
+
):
|
447
|
+
response._copy_to_dict(http_options['deprecated_response_payload'])
|
448
|
+
return response.json
|
460
449
|
|
461
450
|
def request_streamed(
|
462
451
|
self,
|
@@ -470,8 +459,10 @@ class ApiClient:
|
|
470
459
|
)
|
471
460
|
|
472
461
|
session_response = self._request(http_request, stream=True)
|
473
|
-
if http_options and '
|
474
|
-
session_response.
|
462
|
+
if http_options and 'deprecated_response_payload' in http_options:
|
463
|
+
session_response._copy_to_dict(
|
464
|
+
http_options['deprecated_response_payload']
|
465
|
+
)
|
475
466
|
for chunk in session_response.segments():
|
476
467
|
yield chunk
|
477
468
|
|
@@ -487,9 +478,9 @@ class ApiClient:
|
|
487
478
|
)
|
488
479
|
|
489
480
|
result = await self._async_request(http_request=http_request, stream=False)
|
490
|
-
if http_options and '
|
491
|
-
result.
|
492
|
-
return result.
|
481
|
+
if http_options and 'deprecated_response_payload' in http_options:
|
482
|
+
result._copy_to_dict(http_options['deprecated_response_payload'])
|
483
|
+
return result.json
|
493
484
|
|
494
485
|
async def async_request_streamed(
|
495
486
|
self,
|
@@ -504,10 +495,12 @@ class ApiClient:
|
|
504
495
|
|
505
496
|
response = await self._async_request(http_request=http_request, stream=True)
|
506
497
|
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
response
|
498
|
+
if http_options and 'deprecated_response_payload' in http_options:
|
499
|
+
response._copy_to_dict(http_options['deprecated_response_payload'])
|
500
|
+
async def async_generator():
|
501
|
+
async for chunk in response:
|
502
|
+
yield chunk
|
503
|
+
return async_generator()
|
511
504
|
|
512
505
|
def upload_file(
|
513
506
|
self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
|
@@ -575,15 +568,15 @@ class ApiClient:
|
|
575
568
|
if upload_size <= offset: # Status is not finalized.
|
576
569
|
raise ValueError(
|
577
570
|
'All content has been uploaded, but the upload status is not'
|
578
|
-
f' finalized. {response.headers}, body: {response.
|
571
|
+
f' finalized. {response.headers}, body: {response.json}'
|
579
572
|
)
|
580
573
|
|
581
574
|
if response.headers['X-Goog-Upload-Status'] != 'final':
|
582
575
|
raise ValueError(
|
583
576
|
'Failed to upload file: Upload status is not finalized. headers:'
|
584
|
-
f' {response.headers}, body: {response.
|
577
|
+
f' {response.headers}, body: {response.json}'
|
585
578
|
)
|
586
|
-
return response.
|
579
|
+
return response.json
|
587
580
|
|
588
581
|
def download_file(self, path: str, http_options):
|
589
582
|
"""Downloads the file data.
|
@@ -624,7 +617,6 @@ class ApiClient:
|
|
624
617
|
errors.APIError.raise_for_response(response)
|
625
618
|
return HttpResponse(response.headers, byte_stream=[response.content])
|
626
619
|
|
627
|
-
|
628
620
|
async def async_upload_file(
|
629
621
|
self,
|
630
622
|
file_path: Union[str, io.IOBase],
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
"""Utilities for the API Modules of the Google Gen AI SDK."""
|
17
|
+
|
18
|
+
from . import _api_client
|
19
|
+
|
20
|
+
|
21
|
+
class BaseModule:
|
22
|
+
|
23
|
+
def __init__(self, api_client_: _api_client.ApiClient):
|
24
|
+
self._api_client = api_client_
|
@@ -14,11 +14,18 @@
|
|
14
14
|
#
|
15
15
|
|
16
16
|
import inspect
|
17
|
-
import
|
17
|
+
import sys
|
18
|
+
import types as builtin_types
|
19
|
+
import typing
|
18
20
|
from typing import Any, Callable, Literal, Union, _GenericAlias, get_args, get_origin
|
19
21
|
import pydantic
|
20
22
|
from . import types
|
21
23
|
|
24
|
+
if sys.version_info >= (3, 10):
|
25
|
+
UnionType = builtin_types.UnionType
|
26
|
+
else:
|
27
|
+
UnionType = typing._UnionGenericAlias
|
28
|
+
|
22
29
|
_py_builtin_type_to_schema_type = {
|
23
30
|
str: 'STRING',
|
24
31
|
int: 'INTEGER',
|
@@ -58,8 +65,8 @@ def _raise_for_nullable_if_mldev(schema: types.Schema):
|
|
58
65
|
)
|
59
66
|
|
60
67
|
|
61
|
-
def _raise_if_schema_unsupported(
|
62
|
-
if not
|
68
|
+
def _raise_if_schema_unsupported(client, schema: types.Schema):
|
69
|
+
if not client.vertexai:
|
63
70
|
_raise_for_any_of_if_mldev(schema)
|
64
71
|
_raise_for_default_if_mldev(schema)
|
65
72
|
_raise_for_nullable_if_mldev(schema)
|
@@ -74,11 +81,11 @@ def _is_default_value_compatible(
|
|
74
81
|
|
75
82
|
if (
|
76
83
|
isinstance(annotation, _GenericAlias)
|
77
|
-
or isinstance(annotation,
|
78
|
-
or isinstance(annotation,
|
84
|
+
or isinstance(annotation, builtin_types.GenericAlias)
|
85
|
+
or isinstance(annotation, UnionType)
|
79
86
|
):
|
80
87
|
origin = get_origin(annotation)
|
81
|
-
if origin in (Union,
|
88
|
+
if origin in (Union, UnionType):
|
82
89
|
return any(
|
83
90
|
_is_default_value_compatible(default_value, arg)
|
84
91
|
for arg in get_args(annotation)
|
@@ -107,12 +114,13 @@ def _is_default_value_compatible(
|
|
107
114
|
return default_value in get_args(annotation)
|
108
115
|
|
109
116
|
# return False for any other unrecognized annotation
|
110
|
-
# let caller handle the raise
|
111
117
|
return False
|
112
118
|
|
113
119
|
|
114
120
|
def _parse_schema_from_parameter(
|
115
|
-
|
121
|
+
client,
|
122
|
+
param: inspect.Parameter,
|
123
|
+
func_name: str,
|
116
124
|
) -> types.Schema:
|
117
125
|
"""parse schema from parameter.
|
118
126
|
|
@@ -130,12 +138,12 @@ def _parse_schema_from_parameter(
|
|
130
138
|
raise ValueError(default_value_error_msg)
|
131
139
|
schema.default = param.default
|
132
140
|
schema.type = _py_builtin_type_to_schema_type[param.annotation]
|
133
|
-
_raise_if_schema_unsupported(
|
141
|
+
_raise_if_schema_unsupported(client, schema)
|
134
142
|
return schema
|
135
143
|
if (
|
136
|
-
isinstance(param.annotation,
|
144
|
+
isinstance(param.annotation, UnionType)
|
137
145
|
# only parse simple UnionType, example int | str | float | bool
|
138
|
-
# complex
|
146
|
+
# complex UnionType will be invoked in raise branch
|
139
147
|
and all(
|
140
148
|
(_is_builtin_primitive_or_compound(arg) or arg is type(None))
|
141
149
|
for arg in get_args(param.annotation)
|
@@ -149,7 +157,7 @@ def _parse_schema_from_parameter(
|
|
149
157
|
schema.nullable = True
|
150
158
|
continue
|
151
159
|
schema_in_any_of = _parse_schema_from_parameter(
|
152
|
-
|
160
|
+
client,
|
153
161
|
inspect.Parameter(
|
154
162
|
'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
|
155
163
|
),
|
@@ -171,10 +179,10 @@ def _parse_schema_from_parameter(
|
|
171
179
|
if not _is_default_value_compatible(param.default, param.annotation):
|
172
180
|
raise ValueError(default_value_error_msg)
|
173
181
|
schema.default = param.default
|
174
|
-
_raise_if_schema_unsupported(
|
182
|
+
_raise_if_schema_unsupported(client, schema)
|
175
183
|
return schema
|
176
184
|
if isinstance(param.annotation, _GenericAlias) or isinstance(
|
177
|
-
param.annotation,
|
185
|
+
param.annotation, builtin_types.GenericAlias
|
178
186
|
):
|
179
187
|
origin = get_origin(param.annotation)
|
180
188
|
args = get_args(param.annotation)
|
@@ -184,7 +192,7 @@ def _parse_schema_from_parameter(
|
|
184
192
|
if not _is_default_value_compatible(param.default, param.annotation):
|
185
193
|
raise ValueError(default_value_error_msg)
|
186
194
|
schema.default = param.default
|
187
|
-
_raise_if_schema_unsupported(
|
195
|
+
_raise_if_schema_unsupported(client, schema)
|
188
196
|
return schema
|
189
197
|
if origin is Literal:
|
190
198
|
if not all(isinstance(arg, str) for arg in args):
|
@@ -197,12 +205,12 @@ def _parse_schema_from_parameter(
|
|
197
205
|
if not _is_default_value_compatible(param.default, param.annotation):
|
198
206
|
raise ValueError(default_value_error_msg)
|
199
207
|
schema.default = param.default
|
200
|
-
_raise_if_schema_unsupported(
|
208
|
+
_raise_if_schema_unsupported(client, schema)
|
201
209
|
return schema
|
202
210
|
if origin is list:
|
203
211
|
schema.type = 'ARRAY'
|
204
212
|
schema.items = _parse_schema_from_parameter(
|
205
|
-
|
213
|
+
client,
|
206
214
|
inspect.Parameter(
|
207
215
|
'item',
|
208
216
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -214,7 +222,7 @@ def _parse_schema_from_parameter(
|
|
214
222
|
if not _is_default_value_compatible(param.default, param.annotation):
|
215
223
|
raise ValueError(default_value_error_msg)
|
216
224
|
schema.default = param.default
|
217
|
-
_raise_if_schema_unsupported(
|
225
|
+
_raise_if_schema_unsupported(client, schema)
|
218
226
|
return schema
|
219
227
|
if origin is Union:
|
220
228
|
schema.any_of = []
|
@@ -225,7 +233,7 @@ def _parse_schema_from_parameter(
|
|
225
233
|
schema.nullable = True
|
226
234
|
continue
|
227
235
|
schema_in_any_of = _parse_schema_from_parameter(
|
228
|
-
|
236
|
+
client,
|
229
237
|
inspect.Parameter(
|
230
238
|
'item',
|
231
239
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -233,6 +241,17 @@ def _parse_schema_from_parameter(
|
|
233
241
|
),
|
234
242
|
func_name,
|
235
243
|
)
|
244
|
+
if (
|
245
|
+
len(param.annotation.__args__) == 2
|
246
|
+
and type(None) in param.annotation.__args__
|
247
|
+
): # Optional type
|
248
|
+
for optional_arg in param.annotation.__args__:
|
249
|
+
if (
|
250
|
+
hasattr(optional_arg, '__origin__')
|
251
|
+
and optional_arg.__origin__ is list
|
252
|
+
):
|
253
|
+
# Optional type with list, for example Optional[list[str]]
|
254
|
+
schema.items = schema_in_any_of.items
|
236
255
|
if (
|
237
256
|
schema_in_any_of.model_dump_json(exclude_none=True)
|
238
257
|
not in unique_types
|
@@ -249,7 +268,7 @@ def _parse_schema_from_parameter(
|
|
249
268
|
if not _is_default_value_compatible(param.default, param.annotation):
|
250
269
|
raise ValueError(default_value_error_msg)
|
251
270
|
schema.default = param.default
|
252
|
-
_raise_if_schema_unsupported(
|
271
|
+
_raise_if_schema_unsupported(client, schema)
|
253
272
|
return schema
|
254
273
|
# all other generic alias will be invoked in raise branch
|
255
274
|
if (
|
@@ -266,7 +285,7 @@ def _parse_schema_from_parameter(
|
|
266
285
|
schema.properties = {}
|
267
286
|
for field_name, field_info in param.annotation.model_fields.items():
|
268
287
|
schema.properties[field_name] = _parse_schema_from_parameter(
|
269
|
-
|
288
|
+
client,
|
270
289
|
inspect.Parameter(
|
271
290
|
field_name,
|
272
291
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -274,7 +293,9 @@ def _parse_schema_from_parameter(
|
|
274
293
|
),
|
275
294
|
func_name,
|
276
295
|
)
|
277
|
-
|
296
|
+
if client.vertexai:
|
297
|
+
schema.required = _get_required_fields(schema)
|
298
|
+
_raise_if_schema_unsupported(client, schema)
|
278
299
|
return schema
|
279
300
|
raise ValueError(
|
280
301
|
f'Failed to parse the parameter {param} of function {func_name} for'
|
google/genai/_common.py
CHANGED
@@ -21,6 +21,7 @@ import enum
|
|
21
21
|
import typing
|
22
22
|
from typing import Union
|
23
23
|
import uuid
|
24
|
+
import warnings
|
24
25
|
|
25
26
|
import pydantic
|
26
27
|
from pydantic import alias_generators
|
@@ -113,12 +114,6 @@ def get_value_by_path(data: object, keys: list[str]):
|
|
113
114
|
return data
|
114
115
|
|
115
116
|
|
116
|
-
class BaseModule:
|
117
|
-
|
118
|
-
def __init__(self, api_client_: _api_client.ApiClient):
|
119
|
-
self._api_client = api_client_
|
120
|
-
|
121
|
-
|
122
117
|
def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
|
123
118
|
"""Recursively converts a given object to a dictionary.
|
124
119
|
|
@@ -219,8 +214,16 @@ class CaseInSensitiveEnum(str, enum.Enum):
|
|
219
214
|
except KeyError:
|
220
215
|
try:
|
221
216
|
return cls[value.lower()] # Try to access directly with lowercase
|
222
|
-
except KeyError
|
223
|
-
|
217
|
+
except KeyError:
|
218
|
+
warnings.warn(f"{value} is not a valid {cls.__name__}")
|
219
|
+
try:
|
220
|
+
# Creating a enum instance based on the value
|
221
|
+
unknown_enum_val = cls._new_member_(cls) # pylint: disable=protected-access,attribute-error
|
222
|
+
unknown_enum_val._name_ = str(value) # pylint: disable=protected-access
|
223
|
+
unknown_enum_val._value_ = value # pylint: disable=protected-access
|
224
|
+
return unknown_enum_val
|
225
|
+
except:
|
226
|
+
return None
|
224
227
|
|
225
228
|
|
226
229
|
def timestamped_unique_name() -> str:
|
google/genai/_extra_utils.py
CHANGED
@@ -13,12 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
#
|
15
15
|
|
16
|
-
"""Extra utils depending on types that are shared between sync and async modules.
|
17
|
-
"""
|
16
|
+
"""Extra utils depending on types that are shared between sync and async modules."""
|
18
17
|
|
19
18
|
import inspect
|
20
19
|
import logging
|
21
|
-
|
20
|
+
import typing
|
21
|
+
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
|
22
|
+
import sys
|
22
23
|
|
23
24
|
import pydantic
|
24
25
|
|
@@ -26,6 +27,10 @@ from . import _common
|
|
26
27
|
from . import errors
|
27
28
|
from . import types
|
28
29
|
|
30
|
+
if sys.version_info >= (3, 10):
|
31
|
+
from types import UnionType
|
32
|
+
else:
|
33
|
+
UnionType = typing._UnionGenericAlias
|
29
34
|
|
30
35
|
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10
|
31
36
|
|
@@ -78,8 +83,8 @@ def get_function_map(
|
|
78
83
|
if inspect.iscoroutinefunction(tool):
|
79
84
|
raise errors.UnsupportedFunctionError(
|
80
85
|
f'Function {tool.__name__} is a coroutine function, which is not'
|
81
|
-
' supported for automatic function calling. Please manually
|
82
|
-
f' {tool.__name__} to get the function response.'
|
86
|
+
' supported for automatic function calling. Please manually'
|
87
|
+
f' invoke {tool.__name__} to get the function response.'
|
83
88
|
)
|
84
89
|
function_map[tool.__name__] = tool
|
85
90
|
return function_map
|
@@ -135,11 +140,13 @@ def convert_if_exist_pydantic_model(
|
|
135
140
|
for k, v in value.items()
|
136
141
|
}
|
137
142
|
# example 1: typing.Union[int, float]
|
138
|
-
# example 2: int | float equivalent to
|
139
|
-
if get_origin(annotation) in (Union,
|
143
|
+
# example 2: int | float equivalent to UnionType[int, float]
|
144
|
+
if get_origin(annotation) in (Union, UnionType):
|
140
145
|
for arg in get_args(annotation):
|
141
|
-
if
|
142
|
-
|
146
|
+
if (
|
147
|
+
(get_args(arg) and get_origin(arg) is list)
|
148
|
+
or isinstance(value, arg)
|
149
|
+
or (isinstance(value, dict) and _is_annotation_pydantic_model(arg))
|
143
150
|
):
|
144
151
|
try:
|
145
152
|
return convert_if_exist_pydantic_model(
|
@@ -209,7 +216,9 @@ def get_function_response_parts(
|
|
209
216
|
response = {'result': invoke_function_from_dict_args(args, func)}
|
210
217
|
except Exception as e: # pylint: disable=broad-except
|
211
218
|
response = {'error': str(e)}
|
212
|
-
func_response = types.Part.from_function_response(
|
219
|
+
func_response = types.Part.from_function_response(
|
220
|
+
name=func_name, response=response
|
221
|
+
)
|
213
222
|
|
214
223
|
func_response_parts.append(func_response)
|
215
224
|
return func_response_parts
|
@@ -231,8 +240,7 @@ def should_disable_afc(
|
|
231
240
|
and config_model.automatic_function_calling
|
232
241
|
and config_model.automatic_function_calling.maximum_remote_calls
|
233
242
|
is not None
|
234
|
-
and int(config_model.automatic_function_calling.maximum_remote_calls)
|
235
|
-
<= 0
|
243
|
+
and int(config_model.automatic_function_calling.maximum_remote_calls) <= 0
|
236
244
|
):
|
237
245
|
logging.warning(
|
238
246
|
'max_remote_calls in automatic_function_calling_config'
|
@@ -294,6 +302,7 @@ def get_max_remote_calls_afc(
|
|
294
302
|
return _DEFAULT_MAX_REMOTE_CALLS_AFC
|
295
303
|
return int(config_model.automatic_function_calling.maximum_remote_calls)
|
296
304
|
|
305
|
+
|
297
306
|
def should_append_afc_history(
|
298
307
|
config: Optional[types.GenerateContentConfigOrDict] = None,
|
299
308
|
) -> bool:
|
@@ -302,9 +311,6 @@ def should_append_afc_history(
|
|
302
311
|
if config and isinstance(config, dict)
|
303
312
|
else config
|
304
313
|
)
|
305
|
-
if
|
306
|
-
not config_model
|
307
|
-
or not config_model.automatic_function_calling
|
308
|
-
):
|
314
|
+
if not config_model or not config_model.automatic_function_calling:
|
309
315
|
return True
|
310
316
|
return not config_model.automatic_function_calling.ignore_call_history
|