google-genai 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +160 -59
- google/genai/_api_module.py +6 -1
- google/genai/_automatic_function_calling_util.py +12 -12
- google/genai/_common.py +14 -2
- google/genai/_extra_utils.py +14 -8
- google/genai/_replay_api_client.py +35 -3
- google/genai/_test_api_client.py +8 -8
- google/genai/_transformers.py +169 -48
- google/genai/batches.py +176 -127
- google/genai/caches.py +315 -214
- google/genai/chats.py +179 -35
- google/genai/client.py +16 -6
- google/genai/errors.py +19 -5
- google/genai/files.py +161 -115
- google/genai/live.py +137 -105
- google/genai/models.py +1553 -734
- google/genai/operations.py +635 -0
- google/genai/pagers.py +5 -5
- google/genai/tunings.py +166 -103
- google/genai/types.py +590 -142
- google/genai/version.py +1 -1
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/METADATA +94 -12
- google_genai-1.4.0.dist-info/RECORD +27 -0
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/WHEEL +1 -1
- google/genai/_operations.py +0 -365
- google_genai-1.2.0.dist-info/RECORD +0 -27
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/LICENSE +0 -0
- {google_genai-1.2.0.dist-info → google_genai-1.4.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -14,7 +14,10 @@
|
|
14
14
|
#
|
15
15
|
|
16
16
|
|
17
|
-
"""Base client for calling HTTP APIs sending and receiving JSON.
|
17
|
+
"""Base client for calling HTTP APIs sending and receiving JSON.
|
18
|
+
|
19
|
+
The BaseApiClient is intended to be a private module and is subject to change.
|
20
|
+
"""
|
18
21
|
|
19
22
|
import asyncio
|
20
23
|
import copy
|
@@ -25,19 +28,23 @@ import json
|
|
25
28
|
import logging
|
26
29
|
import os
|
27
30
|
import sys
|
28
|
-
from typing import Any, Optional, Tuple, TypedDict, Union
|
31
|
+
from typing import Any, AsyncIterator, Optional, Tuple, TypedDict, Union
|
29
32
|
from urllib.parse import urlparse, urlunparse
|
30
|
-
|
31
33
|
import google.auth
|
32
34
|
import google.auth.credentials
|
35
|
+
from google.auth.credentials import Credentials
|
33
36
|
from google.auth.transport.requests import AuthorizedSession
|
37
|
+
from google.auth.transport.requests import Request
|
38
|
+
import httpx
|
34
39
|
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
35
40
|
import requests
|
36
|
-
|
41
|
+
from . import _common
|
37
42
|
from . import errors
|
38
43
|
from . import version
|
39
44
|
from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
|
40
45
|
|
46
|
+
logger = logging.getLogger('google_genai._api_client')
|
47
|
+
|
41
48
|
|
42
49
|
def _append_library_version_headers(headers: dict[str, str]) -> None:
|
43
50
|
"""Appends the telemetry header to the headers dict."""
|
@@ -61,7 +68,7 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
|
|
61
68
|
|
62
69
|
|
63
70
|
def _patch_http_options(
|
64
|
-
options: HttpOptionsDict, patch_options:
|
71
|
+
options: HttpOptionsDict, patch_options: dict[str, Any]
|
65
72
|
) -> HttpOptionsDict:
|
66
73
|
# use shallow copy so we don't override the original objects.
|
67
74
|
copy_option = HttpOptionsDict()
|
@@ -94,6 +101,27 @@ def _join_url_path(base_url: str, path: str) -> str:
|
|
94
101
|
return urlunparse(parsed_base._replace(path=base_path + '/' + path))
|
95
102
|
|
96
103
|
|
104
|
+
def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
|
105
|
+
"""Loads google auth credentials and project id."""
|
106
|
+
credentials, loaded_project_id = google.auth.default(
|
107
|
+
scopes=['https://www.googleapis.com/auth/cloud-platform'],
|
108
|
+
)
|
109
|
+
|
110
|
+
if not project:
|
111
|
+
project = loaded_project_id
|
112
|
+
|
113
|
+
if not project:
|
114
|
+
raise ValueError(
|
115
|
+
'Could not resolve project using application default credentials.'
|
116
|
+
)
|
117
|
+
|
118
|
+
return credentials, project
|
119
|
+
|
120
|
+
|
121
|
+
def _refresh_auth(credentials: Credentials) -> None:
|
122
|
+
credentials.refresh(Request())
|
123
|
+
|
124
|
+
|
97
125
|
@dataclass
|
98
126
|
class HttpRequest:
|
99
127
|
headers: dict[str, str]
|
@@ -105,15 +133,14 @@ class HttpRequest:
|
|
105
133
|
|
106
134
|
# TODO(b/394358912): Update this class to use a SDKResponse class that can be
|
107
135
|
# generated and used for all languages.
|
108
|
-
|
109
|
-
|
110
|
-
|
136
|
+
class BaseResponse(_common.BaseModel):
|
137
|
+
http_headers: Optional[dict[str, str]] = Field(
|
138
|
+
default=None, description='The http headers of the response.'
|
139
|
+
)
|
111
140
|
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
return self
|
116
|
-
return {'httpHeaders': self.http_headers}
|
141
|
+
json_payload: Optional[Any] = Field(
|
142
|
+
default=None, description='The json payload of the response.'
|
143
|
+
)
|
117
144
|
|
118
145
|
|
119
146
|
class HttpResponse:
|
@@ -128,15 +155,15 @@ class HttpResponse:
|
|
128
155
|
self.headers = headers
|
129
156
|
self.response_stream = response_stream
|
130
157
|
self.byte_stream = byte_stream
|
131
|
-
self.segment_iterator = self.segments()
|
132
158
|
|
133
159
|
# Async iterator for async streaming.
|
134
160
|
def __aiter__(self):
|
161
|
+
self.segment_iterator = self.async_segments()
|
135
162
|
return self
|
136
163
|
|
137
164
|
async def __anext__(self):
|
138
165
|
try:
|
139
|
-
return
|
166
|
+
return await self.segment_iterator.__anext__()
|
140
167
|
except StopIteration:
|
141
168
|
raise StopAsyncIteration
|
142
169
|
|
@@ -163,6 +190,25 @@ class HttpResponse:
|
|
163
190
|
chunk = chunk[len(b'data: ') :]
|
164
191
|
yield json.loads(str(chunk, 'utf-8'))
|
165
192
|
|
193
|
+
async def async_segments(self) -> AsyncIterator[Any]:
|
194
|
+
if isinstance(self.response_stream, list):
|
195
|
+
# list of objects retrieved from replay or from non-streaming API.
|
196
|
+
for chunk in self.response_stream:
|
197
|
+
yield json.loads(chunk) if chunk else {}
|
198
|
+
elif self.response_stream is None:
|
199
|
+
async for c in []:
|
200
|
+
yield c
|
201
|
+
else:
|
202
|
+
# Iterator of objects retrieved from the API.
|
203
|
+
async for chunk in self.response_stream.aiter_lines():
|
204
|
+
# This is httpx.Response.
|
205
|
+
if chunk:
|
206
|
+
# In async streaming mode, the chunk of JSON is prefixed with "data:"
|
207
|
+
# which we must strip before parsing.
|
208
|
+
if chunk.startswith('data: '):
|
209
|
+
chunk = chunk[len('data: ') :]
|
210
|
+
yield json.loads(chunk)
|
211
|
+
|
166
212
|
def byte_segments(self):
|
167
213
|
if isinstance(self.byte_stream, list):
|
168
214
|
# list of objects retrieved from replay or from non-streaming API.
|
@@ -181,17 +227,17 @@ class HttpResponse:
|
|
181
227
|
response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
|
182
228
|
|
183
229
|
|
184
|
-
class
|
230
|
+
class BaseApiClient:
|
185
231
|
"""Client for calling HTTP APIs sending and receiving JSON."""
|
186
232
|
|
187
233
|
def __init__(
|
188
234
|
self,
|
189
|
-
vertexai:
|
190
|
-
api_key:
|
191
|
-
credentials: google.auth.credentials.Credentials = None,
|
192
|
-
project:
|
193
|
-
location:
|
194
|
-
http_options: HttpOptionsOrDict = None,
|
235
|
+
vertexai: Optional[bool] = None,
|
236
|
+
api_key: Optional[str] = None,
|
237
|
+
credentials: Optional[google.auth.credentials.Credentials] = None,
|
238
|
+
project: Optional[str] = None,
|
239
|
+
location: Optional[str] = None,
|
240
|
+
http_options: Optional[HttpOptionsOrDict] = None,
|
195
241
|
):
|
196
242
|
self.vertexai = vertexai
|
197
243
|
if self.vertexai is None:
|
@@ -215,14 +261,15 @@ class ApiClient:
|
|
215
261
|
' initializer.'
|
216
262
|
)
|
217
263
|
|
218
|
-
# Validate http_options if
|
264
|
+
# Validate http_options if it is provided.
|
265
|
+
validated_http_options: dict[str, Any]
|
219
266
|
if isinstance(http_options, dict):
|
220
267
|
try:
|
221
|
-
HttpOptions.model_validate(http_options)
|
268
|
+
validated_http_options = HttpOptions.model_validate(http_options).model_dump()
|
222
269
|
except ValidationError as e:
|
223
270
|
raise ValueError(f'Invalid http_options: {e}')
|
224
271
|
elif isinstance(http_options, HttpOptions):
|
225
|
-
|
272
|
+
validated_http_options = http_options.model_dump()
|
226
273
|
|
227
274
|
# Retrieve implicitly set values from the environment.
|
228
275
|
env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
|
@@ -234,20 +281,24 @@ class ApiClient:
|
|
234
281
|
|
235
282
|
self._credentials = credentials
|
236
283
|
self._http_options = HttpOptionsDict()
|
284
|
+
# Initialize the lock. This lock will be used to protect access to the
|
285
|
+
# credentials. This is crucial for thread safety when multiple coroutines
|
286
|
+
# might be accessing the credentials at the same time.
|
287
|
+
self._auth_lock = asyncio.Lock()
|
237
288
|
|
238
289
|
# Handle when to use Vertex AI in express mode (api key).
|
239
290
|
# Explicit initializer arguments are already validated above.
|
240
291
|
if self.vertexai:
|
241
292
|
if credentials:
|
242
293
|
# Explicit credentials take precedence over implicit api_key.
|
243
|
-
|
294
|
+
logger.info(
|
244
295
|
'The user provided Google Cloud credentials will take precedence'
|
245
296
|
+ ' over the API key from the environment variable.'
|
246
297
|
)
|
247
298
|
self.api_key = None
|
248
299
|
elif (env_location or env_project) and api_key:
|
249
300
|
# Explicit api_key takes precedence over implicit project/location.
|
250
|
-
|
301
|
+
logger.info(
|
251
302
|
'The user provided Vertex AI API key will take precedence over the'
|
252
303
|
+ ' project/location from the environment variables.'
|
253
304
|
)
|
@@ -255,20 +306,22 @@ class ApiClient:
|
|
255
306
|
self.location = None
|
256
307
|
elif (project or location) and env_api_key:
|
257
308
|
# Explicit project/location takes precedence over implicit api_key.
|
258
|
-
|
309
|
+
logger.info(
|
259
310
|
'The user provided project/location will take precedence over the'
|
260
311
|
+ ' Vertex AI API key from the environment variable.'
|
261
312
|
)
|
262
313
|
self.api_key = None
|
263
314
|
elif (env_location or env_project) and env_api_key:
|
264
315
|
# Implicit project/location takes precedence over implicit api_key.
|
265
|
-
|
316
|
+
logger.info(
|
266
317
|
'The project/location from the environment variables will take'
|
267
318
|
+ ' precedence over the API key from the environment variables.'
|
268
319
|
)
|
269
320
|
self.api_key = None
|
270
321
|
if not self.project and not self.api_key:
|
271
|
-
self.project =
|
322
|
+
credentials, self.project = _load_auth(project=None)
|
323
|
+
if not self._credentials:
|
324
|
+
self._credentials = credentials
|
272
325
|
if not ((self.project and self.location) or self.api_key):
|
273
326
|
raise ValueError(
|
274
327
|
'Project and location or API key must be set when using the Vertex '
|
@@ -298,7 +351,7 @@ class ApiClient:
|
|
298
351
|
self._http_options['headers']['x-goog-api-key'] = self.api_key
|
299
352
|
# Update the http options with the user provided http options.
|
300
353
|
if http_options:
|
301
|
-
self._http_options = _patch_http_options(self._http_options,
|
354
|
+
self._http_options = _patch_http_options(self._http_options, validated_http_options)
|
302
355
|
else:
|
303
356
|
_append_library_version_headers(self._http_options['headers'])
|
304
357
|
|
@@ -306,12 +359,38 @@ class ApiClient:
|
|
306
359
|
url_parts = urlparse(self._http_options['base_url'])
|
307
360
|
return url_parts._replace(scheme='wss').geturl()
|
308
361
|
|
362
|
+
async def _async_access_token(self) -> str:
|
363
|
+
"""Retrieves the access token for the credentials."""
|
364
|
+
if not self._credentials:
|
365
|
+
async with self._auth_lock:
|
366
|
+
# This ensures that only one coroutine can execute the auth logic at a
|
367
|
+
# time for thread safety.
|
368
|
+
if not self._credentials:
|
369
|
+
# Double check that the credentials are not set before loading them.
|
370
|
+
self._credentials, project = await asyncio.to_thread(
|
371
|
+
_load_auth, project=self.project
|
372
|
+
)
|
373
|
+
if not self.project:
|
374
|
+
self.project = project
|
375
|
+
|
376
|
+
if self._credentials.expired or not self._credentials.token:
|
377
|
+
# Only refresh when it needs to. Default expiration is 3600 seconds.
|
378
|
+
async with self._auth_lock:
|
379
|
+
if self._credentials.expired or not self._credentials.token:
|
380
|
+
# Double check that the credentials expired before refreshing.
|
381
|
+
await asyncio.to_thread(_refresh_auth, self._credentials)
|
382
|
+
|
383
|
+
if not self._credentials.token:
|
384
|
+
raise RuntimeError('Could not resolve API token from the environment')
|
385
|
+
|
386
|
+
return self._credentials.token
|
387
|
+
|
309
388
|
def _build_request(
|
310
389
|
self,
|
311
390
|
http_method: str,
|
312
391
|
path: str,
|
313
392
|
request_dict: dict[str, object],
|
314
|
-
http_options: HttpOptionsOrDict = None,
|
393
|
+
http_options: Optional[HttpOptionsOrDict] = None,
|
315
394
|
) -> HttpRequest:
|
316
395
|
# Remove all special dict keys such as _url and _query.
|
317
396
|
keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
|
@@ -349,7 +428,7 @@ class ApiClient:
|
|
349
428
|
patched_http_options['api_version'] + '/' + path,
|
350
429
|
)
|
351
430
|
|
352
|
-
timeout_in_seconds = patched_http_options.get('timeout', None)
|
431
|
+
timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get('timeout', None)
|
353
432
|
if timeout_in_seconds:
|
354
433
|
timeout_in_seconds = timeout_in_seconds / 1000.0
|
355
434
|
else:
|
@@ -370,8 +449,10 @@ class ApiClient:
|
|
370
449
|
) -> HttpResponse:
|
371
450
|
if self.vertexai and not self.api_key:
|
372
451
|
if not self._credentials:
|
373
|
-
self._credentials, _ =
|
374
|
-
|
452
|
+
self._credentials, _ = _load_auth(project=self.project)
|
453
|
+
if self._credentials.quota_project_id:
|
454
|
+
http_request.headers['x-goog-user-project'] = (
|
455
|
+
self._credentials.quota_project_id
|
375
456
|
)
|
376
457
|
authed_session = AuthorizedSession(self._credentials)
|
377
458
|
authed_session.stream = stream
|
@@ -394,7 +475,7 @@ class ApiClient:
|
|
394
475
|
http_request: HttpRequest,
|
395
476
|
stream: bool = False,
|
396
477
|
) -> HttpResponse:
|
397
|
-
data = None
|
478
|
+
data: Optional[Union[str, bytes]] = None
|
398
479
|
if http_request.data:
|
399
480
|
if not isinstance(http_request.data, bytes):
|
400
481
|
data = json.dumps(http_request.data)
|
@@ -419,21 +500,42 @@ class ApiClient:
|
|
419
500
|
self, http_request: HttpRequest, stream: bool = False
|
420
501
|
):
|
421
502
|
if self.vertexai:
|
422
|
-
|
423
|
-
|
424
|
-
|
503
|
+
http_request.headers['Authorization'] = (
|
504
|
+
f'Bearer {await self._async_access_token()}'
|
505
|
+
)
|
506
|
+
if self._credentials.quota_project_id:
|
507
|
+
http_request.headers['x-goog-user-project'] = (
|
508
|
+
self._credentials.quota_project_id
|
425
509
|
)
|
426
|
-
|
427
|
-
|
428
|
-
http_request,
|
429
|
-
|
510
|
+
if stream:
|
511
|
+
httpx_request = httpx.Request(
|
512
|
+
method=http_request.method,
|
513
|
+
url=http_request.url,
|
514
|
+
content=json.dumps(http_request.data),
|
515
|
+
headers=http_request.headers,
|
430
516
|
)
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
http_request,
|
517
|
+
aclient = httpx.AsyncClient()
|
518
|
+
response = await aclient.send(
|
519
|
+
httpx_request,
|
435
520
|
stream=stream,
|
436
521
|
)
|
522
|
+
errors.APIError.raise_for_response(response)
|
523
|
+
return HttpResponse(
|
524
|
+
response.headers, response if stream else [response.text]
|
525
|
+
)
|
526
|
+
else:
|
527
|
+
async with httpx.AsyncClient() as aclient:
|
528
|
+
response = await aclient.request(
|
529
|
+
method=http_request.method,
|
530
|
+
url=http_request.url,
|
531
|
+
headers=http_request.headers,
|
532
|
+
content=json.dumps(http_request.data) if http_request.data else None,
|
533
|
+
timeout=http_request.timeout,
|
534
|
+
)
|
535
|
+
errors.APIError.raise_for_response(response)
|
536
|
+
return HttpResponse(
|
537
|
+
response.headers, response if stream else [response.text]
|
538
|
+
)
|
437
539
|
|
438
540
|
def get_read_only_http_options(self) -> HttpOptionsDict:
|
439
541
|
copied = HttpOptionsDict()
|
@@ -447,7 +549,7 @@ class ApiClient:
|
|
447
549
|
http_method: str,
|
448
550
|
path: str,
|
449
551
|
request_dict: dict[str, object],
|
450
|
-
http_options: HttpOptionsOrDict = None,
|
552
|
+
http_options: Optional[HttpOptionsOrDict] = None,
|
451
553
|
):
|
452
554
|
http_request = self._build_request(
|
453
555
|
http_method, path, request_dict, http_options
|
@@ -455,9 +557,9 @@ class ApiClient:
|
|
455
557
|
response = self._request(http_request, stream=False)
|
456
558
|
json_response = response.json
|
457
559
|
if not json_response:
|
458
|
-
|
459
|
-
|
460
|
-
|
560
|
+
return BaseResponse(http_headers=response.headers).model_dump(
|
561
|
+
by_alias=True
|
562
|
+
)
|
461
563
|
return json_response
|
462
564
|
|
463
565
|
def request_streamed(
|
@@ -465,7 +567,7 @@ class ApiClient:
|
|
465
567
|
http_method: str,
|
466
568
|
path: str,
|
467
569
|
request_dict: dict[str, object],
|
468
|
-
http_options: HttpOptionsDict = None,
|
570
|
+
http_options: Optional[HttpOptionsDict] = None,
|
469
571
|
):
|
470
572
|
http_request = self._build_request(
|
471
573
|
http_method, path, request_dict, http_options
|
@@ -480,7 +582,7 @@ class ApiClient:
|
|
480
582
|
http_method: str,
|
481
583
|
path: str,
|
482
584
|
request_dict: dict[str, object],
|
483
|
-
http_options:
|
585
|
+
http_options: Optional[HttpOptionsOrDict] = None,
|
484
586
|
) -> dict[str, object]:
|
485
587
|
http_request = self._build_request(
|
486
588
|
http_method, path, request_dict, http_options
|
@@ -489,8 +591,7 @@ class ApiClient:
|
|
489
591
|
result = await self._async_request(http_request=http_request, stream=False)
|
490
592
|
json_response = result.json
|
491
593
|
if not json_response:
|
492
|
-
|
493
|
-
return base_response
|
594
|
+
return BaseResponse(http_headers=result.headers).model_dump(by_alias=True)
|
494
595
|
return json_response
|
495
596
|
|
496
597
|
async def async_request_streamed(
|
@@ -498,7 +599,7 @@ class ApiClient:
|
|
498
599
|
http_method: str,
|
499
600
|
path: str,
|
500
601
|
request_dict: dict[str, object],
|
501
|
-
http_options: HttpOptionsDict = None,
|
602
|
+
http_options: Optional[HttpOptionsDict] = None,
|
502
603
|
):
|
503
604
|
http_request = self._build_request(
|
504
605
|
http_method, path, request_dict, http_options
|
@@ -607,10 +708,10 @@ class ApiClient:
|
|
607
708
|
self,
|
608
709
|
http_request: HttpRequest,
|
609
710
|
) -> HttpResponse:
|
610
|
-
data = None
|
711
|
+
data: str | bytes | None = None
|
611
712
|
if http_request.data:
|
612
713
|
if not isinstance(http_request.data, bytes):
|
613
|
-
data = json.dumps(http_request.data
|
714
|
+
data = json.dumps(http_request.data)
|
614
715
|
else:
|
615
716
|
data = http_request.data
|
616
717
|
|
@@ -695,5 +796,5 @@ class ApiClient:
|
|
695
796
|
# This method does nothing in the real api client. It is used in the
|
696
797
|
# replay_api_client to verify the response from the SDK method matches the
|
697
798
|
# recorded response.
|
698
|
-
def _verify_response(self, response_model: BaseModel):
|
799
|
+
def _verify_response(self, response_model: _common.BaseModel):
|
699
800
|
pass
|
google/genai/_api_module.py
CHANGED
@@ -15,10 +15,15 @@
|
|
15
15
|
|
16
16
|
"""Utilities for the API Modules of the Google Gen AI SDK."""
|
17
17
|
|
18
|
+
from typing import Optional
|
18
19
|
from . import _api_client
|
19
20
|
|
20
21
|
|
21
22
|
class BaseModule:
|
22
23
|
|
23
|
-
def __init__(self, api_client_: _api_client.
|
24
|
+
def __init__(self, api_client_: _api_client.BaseApiClient):
|
24
25
|
self._api_client = api_client_
|
26
|
+
|
27
|
+
@property
|
28
|
+
def vertexai(self) -> Optional[bool]:
|
29
|
+
return self._api_client.vertexai
|
@@ -31,12 +31,12 @@ else:
|
|
31
31
|
VersionedUnionType = typing._UnionGenericAlias
|
32
32
|
|
33
33
|
_py_builtin_type_to_schema_type = {
|
34
|
-
str:
|
35
|
-
int:
|
36
|
-
float:
|
37
|
-
bool:
|
38
|
-
list:
|
39
|
-
dict:
|
34
|
+
str: types.Type.STRING,
|
35
|
+
int: types.Type.INTEGER,
|
36
|
+
float: types.Type.NUMBER,
|
37
|
+
bool: types.Type.BOOLEAN,
|
38
|
+
list: types.Type.ARRAY,
|
39
|
+
dict: types.Type.OBJECT,
|
40
40
|
}
|
41
41
|
|
42
42
|
|
@@ -145,7 +145,7 @@ def _parse_schema_from_parameter(
|
|
145
145
|
for arg in get_args(param.annotation)
|
146
146
|
)
|
147
147
|
):
|
148
|
-
schema.type =
|
148
|
+
schema.type = _py_builtin_type_to_schema_type[dict]
|
149
149
|
schema.any_of = []
|
150
150
|
unique_types = set()
|
151
151
|
for arg in get_args(param.annotation):
|
@@ -183,7 +183,7 @@ def _parse_schema_from_parameter(
|
|
183
183
|
origin = get_origin(param.annotation)
|
184
184
|
args = get_args(param.annotation)
|
185
185
|
if origin is dict:
|
186
|
-
schema.type =
|
186
|
+
schema.type = _py_builtin_type_to_schema_type[dict]
|
187
187
|
if param.default is not inspect.Parameter.empty:
|
188
188
|
if not _is_default_value_compatible(param.default, param.annotation):
|
189
189
|
raise ValueError(default_value_error_msg)
|
@@ -195,7 +195,7 @@ def _parse_schema_from_parameter(
|
|
195
195
|
raise ValueError(
|
196
196
|
f'Literal type {param.annotation} must be a list of strings.'
|
197
197
|
)
|
198
|
-
schema.type =
|
198
|
+
schema.type = _py_builtin_type_to_schema_type[str]
|
199
199
|
schema.enum = list(args)
|
200
200
|
if param.default is not inspect.Parameter.empty:
|
201
201
|
if not _is_default_value_compatible(param.default, param.annotation):
|
@@ -204,7 +204,7 @@ def _parse_schema_from_parameter(
|
|
204
204
|
_raise_if_schema_unsupported(api_option, schema)
|
205
205
|
return schema
|
206
206
|
if origin is list:
|
207
|
-
schema.type =
|
207
|
+
schema.type = _py_builtin_type_to_schema_type[list]
|
208
208
|
schema.items = _parse_schema_from_parameter(
|
209
209
|
api_option,
|
210
210
|
inspect.Parameter(
|
@@ -222,7 +222,7 @@ def _parse_schema_from_parameter(
|
|
222
222
|
return schema
|
223
223
|
if origin is Union:
|
224
224
|
schema.any_of = []
|
225
|
-
schema.type =
|
225
|
+
schema.type = _py_builtin_type_to_schema_type[dict]
|
226
226
|
unique_types = set()
|
227
227
|
for arg in args:
|
228
228
|
# The first check is for NoneType in Python 3.9, since the __name__
|
@@ -280,7 +280,7 @@ def _parse_schema_from_parameter(
|
|
280
280
|
and param.default is not None
|
281
281
|
):
|
282
282
|
schema.default = param.default
|
283
|
-
schema.type =
|
283
|
+
schema.type = _py_builtin_type_to_schema_type[dict]
|
284
284
|
schema.properties = {}
|
285
285
|
for field_name, field_info in param.annotation.model_fields.items():
|
286
286
|
schema.properties[field_name] = _parse_schema_from_parameter(
|
google/genai/_common.py
CHANGED
@@ -60,6 +60,12 @@ def set_value_by_path(data, keys, value):
|
|
60
60
|
for d in data[key_name]:
|
61
61
|
set_value_by_path(d, keys[i + 1 :], value)
|
62
62
|
return
|
63
|
+
elif key.endswith('[0]'):
|
64
|
+
key_name = key[:-3]
|
65
|
+
if key_name not in data:
|
66
|
+
data[key_name] = [{}]
|
67
|
+
set_value_by_path(data[key_name][0], keys[i + 1 :], value)
|
68
|
+
return
|
63
69
|
|
64
70
|
data = data.setdefault(key, {})
|
65
71
|
|
@@ -106,6 +112,12 @@ def get_value_by_path(data: object, keys: list[str]):
|
|
106
112
|
return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
|
107
113
|
else:
|
108
114
|
return None
|
115
|
+
elif key.endswith('[0]'):
|
116
|
+
key_name = key[:-3]
|
117
|
+
if key_name in data and data[key_name]:
|
118
|
+
return get_value_by_path(data[key_name][0], keys[i + 1 :])
|
119
|
+
else:
|
120
|
+
return None
|
109
121
|
else:
|
110
122
|
if key in data:
|
111
123
|
data = data[key]
|
@@ -193,7 +205,7 @@ class BaseModel(pydantic.BaseModel):
|
|
193
205
|
|
194
206
|
@classmethod
|
195
207
|
def _from_response(
|
196
|
-
cls, response: dict[str, object], kwargs: dict[str, object]
|
208
|
+
cls, *, response: dict[str, object], kwargs: dict[str, object]
|
197
209
|
) -> 'BaseModel':
|
198
210
|
# To maintain forward compatibility, we need to remove extra fields from
|
199
211
|
# the response.
|
@@ -254,7 +266,7 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
|
|
254
266
|
A dictionary with json.dumps() incompatible type (e.g. bytes datetime)
|
255
267
|
to compatible type (e.g. base64 encoded string, isoformat date string).
|
256
268
|
"""
|
257
|
-
processed_data = {}
|
269
|
+
processed_data: dict[str, object] = {}
|
258
270
|
if not isinstance(data, dict):
|
259
271
|
return data
|
260
272
|
for key, value in data.items():
|
google/genai/_extra_utils.py
CHANGED
@@ -34,6 +34,8 @@ else:
|
|
34
34
|
|
35
35
|
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10
|
36
36
|
|
37
|
+
logger = logging.getLogger('google_genai.models')
|
38
|
+
|
37
39
|
|
38
40
|
def format_destination(
|
39
41
|
src: str,
|
@@ -74,7 +76,7 @@ def get_function_map(
|
|
74
76
|
if config and isinstance(config, dict)
|
75
77
|
else config
|
76
78
|
)
|
77
|
-
function_map = {}
|
79
|
+
function_map: dict[str, object] = {}
|
78
80
|
if not config_model:
|
79
81
|
return function_map
|
80
82
|
if config_model.tools:
|
@@ -218,15 +220,16 @@ def get_function_response_parts(
|
|
218
220
|
func_name = part.function_call.name
|
219
221
|
func = function_map[func_name]
|
220
222
|
args = convert_number_values_for_function_call_args(part.function_call.args)
|
223
|
+
func_response: dict[str, Any]
|
221
224
|
try:
|
222
|
-
|
225
|
+
func_response = {'result': invoke_function_from_dict_args(args, func)}
|
223
226
|
except Exception as e: # pylint: disable=broad-except
|
224
|
-
|
225
|
-
|
226
|
-
name=func_name, response=
|
227
|
+
func_response = {'error': str(e)}
|
228
|
+
func_response_part = types.Part.from_function_response(
|
229
|
+
name=func_name, response=func_response
|
227
230
|
)
|
228
231
|
|
229
|
-
func_response_parts.append(
|
232
|
+
func_response_parts.append(func_response_part)
|
230
233
|
return func_response_parts
|
231
234
|
|
232
235
|
|
@@ -248,7 +251,7 @@ def should_disable_afc(
|
|
248
251
|
is not None
|
249
252
|
and int(config_model.automatic_function_calling.maximum_remote_calls) <= 0
|
250
253
|
):
|
251
|
-
|
254
|
+
logger.warning(
|
252
255
|
'max_remote_calls in automatic_function_calling_config'
|
253
256
|
f' {config_model.automatic_function_calling.maximum_remote_calls} is'
|
254
257
|
' less than or equal to 0. Disabling automatic function calling.'
|
@@ -268,9 +271,12 @@ def should_disable_afc(
|
|
268
271
|
config_model.automatic_function_calling.disable
|
269
272
|
and config_model.automatic_function_calling.maximum_remote_calls
|
270
273
|
is not None
|
274
|
+
# exclude the case where max_remote_calls is set to 10 by default.
|
275
|
+
and 'maximum_remote_calls'
|
276
|
+
in config_model.automatic_function_calling.model_fields_set
|
271
277
|
and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
|
272
278
|
):
|
273
|
-
|
279
|
+
logger.warning(
|
274
280
|
'`automatic_function_calling.disable` is set to `True`. And'
|
275
281
|
' `automatic_function_calling.maximum_remote_calls` is a'
|
276
282
|
' positive number'
|