google-genai 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +143 -69
- google/genai/_api_module.py +1 -1
- google/genai/_automatic_function_calling_util.py +15 -15
- google/genai/_common.py +6 -3
- google/genai/_extra_utils.py +62 -46
- google/genai/_replay_api_client.py +73 -4
- google/genai/_test_api_client.py +8 -8
- google/genai/_transformers.py +194 -66
- google/genai/batches.py +180 -134
- google/genai/caches.py +316 -216
- google/genai/chats.py +179 -35
- google/genai/client.py +3 -3
- google/genai/errors.py +1 -2
- google/genai/files.py +175 -119
- google/genai/live.py +73 -64
- google/genai/models.py +898 -637
- google/genai/operations.py +96 -66
- google/genai/pagers.py +16 -7
- google/genai/tunings.py +172 -112
- google/genai/types.py +228 -178
- google/genai/version.py +1 -1
- {google_genai-1.3.0.dist-info → google_genai-1.5.0.dist-info}/METADATA +8 -1
- google_genai-1.5.0.dist-info/RECORD +27 -0
- {google_genai-1.3.0.dist-info → google_genai-1.5.0.dist-info}/WHEEL +1 -1
- google_genai-1.3.0.dist-info/RECORD +0 -27
- {google_genai-1.3.0.dist-info → google_genai-1.5.0.dist-info}/LICENSE +0 -0
- {google_genai-1.3.0.dist-info → google_genai-1.5.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -14,8 +14,12 @@
 #
 
 
-"""Base client for calling HTTP APIs sending and receiving JSON."""
+"""Base client for calling HTTP APIs sending and receiving JSON.
 
+The BaseApiClient is intended to be a private module and is subject to change.
+"""
+
+import anyio
 import asyncio
 import copy
 from dataclasses import dataclass
@@ -41,6 +45,7 @@ from . import version
 from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
 
 logger = logging.getLogger('google_genai._api_client')
+CHUNK_SIZE = 8 * 1024 * 1024  # 8 MB chunk size
 
 
 def _append_library_version_headers(headers: dict[str, str]) -> None:
@@ -65,7 +70,7 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
 
 
 def _patch_http_options(
-    options: HttpOptionsDict, patch_options:
+    options: HttpOptionsDict, patch_options: dict[str, Any]
 ) -> HttpOptionsDict:
   # use shallow copy so we don't override the original objects.
   copy_option = HttpOptionsDict()
@@ -115,8 +120,9 @@ def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
   return credentials, project
 
 
-def _refresh_auth(credentials: Credentials) ->
+def _refresh_auth(credentials: Credentials) -> Credentials:
   credentials.refresh(Request())
+  return credentials
 
 
 @dataclass
@@ -131,7 +137,7 @@ class HttpRequest:
 # TODO(b/394358912): Update this class to use a SDKResponse class that can be
 # generated and used for all languages.
 class BaseResponse(_common.BaseModel):
-  http_headers: dict[str, str] = Field(
+  http_headers: Optional[dict[str, str]] = Field(
       default=None, description='The http headers of the response.'
   )
 
@@ -144,7 +150,7 @@ class HttpResponse:
 
   def __init__(
       self,
-      headers: dict[str, str],
+      headers: Union[dict[str, str], httpx.Headers],
       response_stream: Union[Any, str] = None,
       byte_stream: Union[Any, bytes] = None,
   ):
@@ -197,14 +203,19 @@ class HttpResponse:
         yield c
     else:
       # Iterator of objects retrieved from the API.
-      async for chunk in self.response_stream.aiter_lines():
-        # This is httpx.Response.
-        if chunk:
-          # In async streaming mode, the chunk of JSON is prefixed with "data:"
-          # which we must strip before parsing.
-          if chunk.startswith('data: '):
-            chunk = chunk[len('data: ') :]
-          yield json.loads(chunk)
+      if hasattr(self.response_stream, 'aiter_lines'):
+        async for chunk in self.response_stream.aiter_lines():
+          # This is httpx.Response.
+          if chunk:
+            # In async streaming mode, the chunk of JSON is prefixed with "data:"
+            # which we must strip before parsing.
+            if chunk.startswith('data: '):
+              chunk = chunk[len('data: ') :]
+            yield json.loads(chunk)
+      else:
+        raise ValueError(
+            'Error parsing streaming response.'
+        )
 
   def byte_segments(self):
     if isinstance(self.byte_stream, list):
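The rewritten `async_segments` branch above only accepts response objects that expose `aiter_lines()` (httpx responses) and strips the SSE-style `data: ` prefix before JSON-decoding each chunk. A minimal standalone sketch of that parsing step follows; it is generic Python, not the SDK's private API, and `parse_sse_lines`/`fake_stream` are illustrative names only.

```python
import asyncio
import json
from typing import Any, AsyncIterator


async def parse_sse_lines(lines: AsyncIterator[str]) -> AsyncIterator[Any]:
  # Mirrors the branch above: skip empty keep-alive lines, strip the
  # "data: " prefix, then JSON-decode what remains.
  async for line in lines:
    if not line:
      continue
    if line.startswith('data: '):
      line = line[len('data: '):]
    yield json.loads(line)


async def _demo() -> None:
  async def fake_stream() -> AsyncIterator[str]:
    # Stand-in for httpx.Response.aiter_lines() output.
    yield 'data: {"candidates": [1]}'
    yield ''
    yield 'data: {"candidates": [2]}'

  async for obj in parse_sse_lines(fake_stream()):
    print(obj)


if __name__ == '__main__':
  asyncio.run(_demo())
```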
@@ -224,17 +235,17 @@ class HttpResponse:
       response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
 
 
-class ApiClient:
+class BaseApiClient:
   """Client for calling HTTP APIs sending and receiving JSON."""
 
   def __init__(
       self,
-      vertexai:
-      api_key:
-      credentials: google.auth.credentials.Credentials = None,
-      project:
-      location:
-      http_options: HttpOptionsOrDict = None,
+      vertexai: Optional[bool] = None,
+      api_key: Optional[str] = None,
+      credentials: Optional[google.auth.credentials.Credentials] = None,
+      project: Optional[str] = None,
+      location: Optional[str] = None,
+      http_options: Optional[HttpOptionsOrDict] = None,
   ):
     self.vertexai = vertexai
     if self.vertexai is None:
@@ -258,14 +269,15 @@ class ApiClient:
           ' initializer.'
       )
 
-    # Validate http_options if
+    # Validate http_options if it is provided.
+    validated_http_options: dict[str, Any]
     if isinstance(http_options, dict):
       try:
-        HttpOptions.model_validate(http_options)
+        validated_http_options = HttpOptions.model_validate(http_options).model_dump()
       except ValidationError as e:
         raise ValueError(f'Invalid http_options: {e}')
     elif isinstance(http_options, HttpOptions):
-      http_options = http_options.model_dump()
+      validated_http_options = http_options.model_dump()
 
     # Retrieve implicitly set values from the environment.
     env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
@@ -347,7 +359,7 @@ class ApiClient:
       self._http_options['headers']['x-goog-api-key'] = self.api_key
     # Update the http options with the user provided http options.
     if http_options:
-      self._http_options = _patch_http_options(self._http_options,
+      self._http_options = _patch_http_options(self._http_options, validated_http_options)
     else:
       _append_library_version_headers(self._http_options['headers'])
 
@@ -369,24 +381,29 @@ class ApiClient:
     if not self.project:
       self.project = project
 
-    if self._credentials.expired or not self._credentials.token:
-      # Only refresh when it needs to. Default expiration is 3600 seconds.
-      async with self._auth_lock:
-        if self._credentials.expired or not self._credentials.token:
-          # Double check that the credentials expired before refreshing.
-          await asyncio.to_thread(_refresh_auth, self._credentials)
+    if self._credentials:
+      if (
+          self._credentials.expired or not self._credentials.token
+      ):
+        # Only refresh when it needs to. Default expiration is 3600 seconds.
+        async with self._auth_lock:
+          if self._credentials.expired or not self._credentials.token:
+            # Double check that the credentials expired before refreshing.
+            await asyncio.to_thread(_refresh_auth, self._credentials)
 
-    if not self._credentials.token:
-      raise RuntimeError('Could not resolve API token from the environment')
+      if not self._credentials.token:
+        raise RuntimeError('Could not resolve API token from the environment')
 
-    return self._credentials.token
+      return self._credentials.token
+    else:
+      raise RuntimeError('Could not resolve API token from the environment')
 
   def _build_request(
       self,
       http_method: str,
       path: str,
       request_dict: dict[str, object],
-      http_options: HttpOptionsOrDict = None,
+      http_options: Optional[HttpOptionsOrDict] = None,
   ) -> HttpRequest:
     # Remove all special dict keys such as _url and _query.
     keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
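The reworked `_async_access_token` path guards the blocking google-auth refresh with an `asyncio.Lock`, re-checks expiry inside the lock, and runs the refresh in a worker thread via `asyncio.to_thread`, so concurrent coroutines trigger at most one refresh. Below is a standalone sketch of that double-checked pattern; `TokenCache` and `_blocking_refresh` are illustrative stand-ins, not SDK names.

```python
import asyncio
from typing import Optional


class TokenCache:
  """Sketch of the double-checked refresh pattern used above."""

  def __init__(self) -> None:
    self._token: Optional[str] = None
    self._expired = True
    self._lock = asyncio.Lock()

  def _blocking_refresh(self) -> None:
    # Placeholder for credentials.refresh(Request()), which does blocking I/O.
    self._token = 'fake-token'
    self._expired = False

  async def access_token(self) -> str:
    if self._expired or not self._token:
      async with self._lock:
        # Re-check inside the lock: another coroutine may have refreshed
        # the token while we were waiting to acquire it.
        if self._expired or not self._token:
          await asyncio.to_thread(self._blocking_refresh)
    if not self._token:
      raise RuntimeError('Could not resolve API token')
    return self._token


if __name__ == '__main__':
  cache = TokenCache()
  print(asyncio.run(cache.access_token()))
```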
@@ -424,8 +441,12 @@ class ApiClient:
         patched_http_options['api_version'] + '/' + path,
     )
 
-    timeout_in_seconds = patched_http_options.get(
+    timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
+        'timeout', None
+    )
     if timeout_in_seconds:
+      # HttpOptions.timeout is in milliseconds. But httpx.Client.request()
+      # expects seconds.
       timeout_in_seconds = timeout_in_seconds / 1000.0
     else:
       timeout_in_seconds = None
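As the added comment notes, `HttpOptions.timeout` is expressed in milliseconds while httpx expects seconds, hence the division by 1000. A small helper showing the same conversion (the function name is hypothetical, for illustration only):

```python
from typing import Optional, Union


def timeout_ms_to_seconds(
    timeout_ms: Optional[Union[int, float]],
) -> Optional[float]:
  # HttpOptions-style timeouts are in milliseconds; httpx wants seconds.
  return timeout_ms / 1000.0 if timeout_ms else None


assert timeout_ms_to_seconds(30_000) == 30.0
assert timeout_ms_to_seconds(None) is None
```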
@@ -471,7 +492,7 @@
       http_request: HttpRequest,
       stream: bool = False,
   ) -> HttpResponse:
-    data = None
+    data: Optional[Union[str, bytes]] = None
     if http_request.data:
       if not isinstance(http_request.data, bytes):
         data = json.dumps(http_request.data)
@@ -499,18 +520,19 @@
       http_request.headers['Authorization'] = (
           f'Bearer {await self._async_access_token()}'
       )
-      if self._credentials.quota_project_id:
+      if self._credentials and self._credentials.quota_project_id:
         http_request.headers['x-goog-user-project'] = (
             self._credentials.quota_project_id
         )
     if stream:
-
+      aclient = httpx.AsyncClient()
+      httpx_request = aclient.build_request(
          method=http_request.method,
          url=http_request.url,
-
+          content=json.dumps(http_request.data),
          headers=http_request.headers,
+          timeout=http_request.timeout,
      )
-      aclient = httpx.AsyncClient()
       response = await aclient.send(
           httpx_request,
           stream=stream,
@@ -525,7 +547,7 @@
           method=http_request.method,
           url=http_request.url,
           headers=http_request.headers,
-
+          content=json.dumps(http_request.data) if http_request.data else None,
           timeout=http_request.timeout,
       )
       errors.APIError.raise_for_response(response)
@@ -545,7 +567,7 @@
       http_method: str,
       path: str,
       request_dict: dict[str, object],
-      http_options: HttpOptionsOrDict = None,
+      http_options: Optional[HttpOptionsOrDict] = None,
   ):
     http_request = self._build_request(
         http_method, path, request_dict, http_options
@@ -563,7 +585,7 @@
       http_method: str,
       path: str,
       request_dict: dict[str, object],
-      http_options: HttpOptionsDict = None,
+      http_options: Optional[HttpOptionsDict] = None,
   ):
     http_request = self._build_request(
         http_method, path, request_dict, http_options
@@ -578,7 +600,7 @@
       http_method: str,
       path: str,
       request_dict: dict[str, object],
-      http_options:
+      http_options: Optional[HttpOptionsOrDict] = None,
   ) -> dict[str, object]:
     http_request = self._build_request(
         http_method, path, request_dict, http_options
@@ -595,7 +617,7 @@
       http_method: str,
       path: str,
       request_dict: dict[str, object],
-      http_options: HttpOptionsDict = None,
+      http_options: Optional[HttpOptionsDict] = None,
   ):
     http_request = self._build_request(
         http_method, path, request_dict, http_options
@@ -648,7 +670,7 @@
     offset = 0
     # Upload the file in chunks
     while True:
-      file_chunk = file.read(
+      file_chunk = file.read(CHUNK_SIZE)
       chunk_size = 0
       if file_chunk:
         chunk_size = len(file_chunk)
@@ -675,13 +697,12 @@
     if upload_size <= offset:  # Status is not finalized.
       raise ValueError(
           'All content has been uploaded, but the upload status is not'
-          f' finalized.
+          f' finalized.'
       )
 
     if response.headers['X-Goog-Upload-Status'] != 'final':
       raise ValueError(
-          'Failed to upload file: Upload status is not finalized.
-          f' {response.headers}, body: {response.json}'
+          'Failed to upload file: Upload status is not finalized.'
       )
     return response.json
 
@@ -704,10 +725,10 @@
       self,
       http_request: HttpRequest,
   ) -> HttpResponse:
-    data = None
+    data: Optional[Union[str, bytes]] = None
     if http_request.data:
       if not isinstance(http_request.data, bytes):
-        data = json.dumps(http_request.data
+        data = json.dumps(http_request.data)
       else:
         data = http_request.data
 
@@ -742,16 +763,17 @@
     returns:
       The response json object from the finalize request.
     """
-
-
-
-
-
-
+    if isinstance(file_path, io.IOBase):
+      return await self._async_upload_fd(file_path, upload_url, upload_size)
+    else:
+      file = anyio.Path(file_path)
+      fd = await file.open('rb')
+      async with fd:
+        return await self._async_upload_fd(fd, upload_url, upload_size)
 
   async def _async_upload_fd(
       self,
-      file: io.IOBase,
+      file: Union[io.IOBase, anyio.AsyncFile],
       upload_url: str,
       upload_size: int,
   ) -> str:
@@ -766,12 +788,45 @@
     returns:
       The response json object from the finalize request.
     """
-
-
-
-
-
-
+    async with httpx.AsyncClient() as aclient:
+      offset = 0
+      # Upload the file in chunks
+      while True:
+        if isinstance(file, io.IOBase):
+          file_chunk = file.read(CHUNK_SIZE)
+        else:
+          file_chunk = await file.read(CHUNK_SIZE)
+        chunk_size = 0
+        if file_chunk:
+          chunk_size = len(file_chunk)
+        upload_command = 'upload'
+        # If last chunk, finalize the upload.
+        if chunk_size + offset >= upload_size:
+          upload_command += ', finalize'
+        response = await aclient.request(
+            method='POST',
+            url=upload_url,
+            content=file_chunk,
+            headers={
+                'X-Goog-Upload-Command': upload_command,
+                'X-Goog-Upload-Offset': str(offset),
+                'Content-Length': str(chunk_size),
+            },
+        )
+        offset += chunk_size
+        if response.headers.get('x-goog-upload-status') != 'active':
+          break  # upload is complete or it has been interrupted.
+
+      if upload_size <= offset:  # Status is not finalized.
+        raise ValueError(
+            'All content has been uploaded, but the upload status is not'
+            f' finalized.'
+        )
+      if response.headers.get('x-goog-upload-status') != 'final':
+        raise ValueError(
+            'Failed to upload file: Upload status is not finalized.'
+        )
+      return response.json()
 
   async def async_download_file(self, path: str, http_options):
     """Downloads the file data.
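The new `_async_upload_fd` streams the file in `CHUNK_SIZE` (8 MB) pieces, sending `X-Goog-Upload-Command: upload` for intermediate chunks and `upload, finalize` for the last one, then stops once `x-goog-upload-status` leaves `active`. Below is a sketch of just the chunk/finalize bookkeeping, with the HTTP calls replaced by yielded tuples so it runs without a resumable-upload URL; `iter_upload_chunks` is an illustrative helper, not an SDK function.

```python
import io

CHUNK_SIZE = 8 * 1024 * 1024  # 8 MB, matching the constant added in this release


def iter_upload_chunks(fd: io.IOBase, upload_size: int):
  """Yield (upload_command, offset, chunk) tuples for a resumable upload."""
  offset = 0
  while True:
    chunk = fd.read(CHUNK_SIZE)
    chunk_size = len(chunk) if chunk else 0
    upload_command = 'upload'
    if chunk_size + offset >= upload_size:
      # Last chunk: ask the server to finalize the upload.
      upload_command += ', finalize'
    yield upload_command, offset, chunk
    offset += chunk_size
    if 'finalize' in upload_command or not chunk:
      break


if __name__ == '__main__':
  payload = b'x' * (CHUNK_SIZE + 10)  # forces one full chunk plus a short tail
  for command, offset, chunk in iter_upload_chunks(io.BytesIO(payload), len(payload)):
    print(command, offset, len(chunk))
```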
@@ -783,14 +838,33 @@
     returns:
       The file bytes
     """
-
-
-        path,
-        http_options,
+    http_request = self._build_request(
+        'get', path=path, request_dict={}, http_options=http_options
     )
 
+    data: Optional[Union[str, bytes]]
+    if http_request.data:
+      if not isinstance(http_request.data, bytes):
+        data = json.dumps(http_request.data)
+      else:
+        data = http_request.data
+
+    async with httpx.AsyncClient(follow_redirects=True) as aclient:
+      response = await aclient.request(
+          method=http_request.method,
+          url=http_request.url,
+          headers=http_request.headers,
+          content=data,
+          timeout=http_request.timeout,
+      )
+      errors.APIError.raise_for_response(response)
+
+      return HttpResponse(
+          response.headers, byte_stream=[response.read()]
+      ).byte_stream[0]
+
   # This method does nothing in the real api client. It is used in the
   # replay_api_client to verify the response from the SDK method matches the
   # recorded response.
-  def _verify_response(self, response_model: BaseModel):
+  def _verify_response(self, response_model: _common.BaseModel):
     pass
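`async_download_file` now builds the request itself and issues it with `httpx.AsyncClient(follow_redirects=True)`, returning the raw body bytes. A hedged sketch of that download shape using plain httpx follows; `download_bytes` is an illustrative helper, not the SDK call.

```python
import httpx


async def download_bytes(url: str, headers: dict[str, str]) -> bytes:
  # Follow redirects, fail on HTTP errors, and return the raw body,
  # the same shape as the async_download_file change above.
  async with httpx.AsyncClient(follow_redirects=True) as client:
    response = await client.get(url, headers=headers)
    response.raise_for_status()
    return response.read()
```

It could be driven with `asyncio.run(download_bytes(url, headers={}))` against any URL that serves bytes.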
google/genai/_automatic_function_calling_util.py
CHANGED
@@ -17,7 +17,7 @@ import inspect
 import sys
 import types as builtin_types
 import typing
-from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Union
+from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union
 
 import pydantic
 
@@ -31,12 +31,12 @@ else:
   VersionedUnionType = typing._UnionGenericAlias
 
 _py_builtin_type_to_schema_type = {
-    str:
-    int:
-    float:
-    bool:
-    list:
-    dict:
+    str: types.Type.STRING,
+    int: types.Type.INTEGER,
+    float: types.Type.NUMBER,
+    bool: types.Type.BOOLEAN,
+    list: types.Type.ARRAY,
+    dict: types.Type.OBJECT,
 }
 
 
@@ -145,7 +145,7 @@ def _parse_schema_from_parameter(
           for arg in get_args(param.annotation)
       )
   ):
-    schema.type =
+    schema.type = _py_builtin_type_to_schema_type[dict]
    schema.any_of = []
    unique_types = set()
    for arg in get_args(param.annotation):
@@ -183,7 +183,7 @@
  origin = get_origin(param.annotation)
  args = get_args(param.annotation)
  if origin is dict:
-    schema.type =
+    schema.type = _py_builtin_type_to_schema_type[dict]
    if param.default is not inspect.Parameter.empty:
      if not _is_default_value_compatible(param.default, param.annotation):
        raise ValueError(default_value_error_msg)
@@ -195,7 +195,7 @@
      raise ValueError(
          f'Literal type {param.annotation} must be a list of strings.'
      )
-    schema.type =
+    schema.type = _py_builtin_type_to_schema_type[str]
    schema.enum = list(args)
    if param.default is not inspect.Parameter.empty:
      if not _is_default_value_compatible(param.default, param.annotation):
@@ -204,7 +204,7 @@
    _raise_if_schema_unsupported(api_option, schema)
    return schema
  if origin is list:
-    schema.type =
+    schema.type = _py_builtin_type_to_schema_type[list]
    schema.items = _parse_schema_from_parameter(
        api_option,
        inspect.Parameter(
@@ -222,7 +222,7 @@
    return schema
  if origin is Union:
    schema.any_of = []
-    schema.type =
+    schema.type = _py_builtin_type_to_schema_type[dict]
    unique_types = set()
    for arg in args:
      # The first check is for NoneType in Python 3.9, since the __name__
@@ -280,7 +280,7 @@
        and param.default is not None
    ):
      schema.default = param.default
-    schema.type =
+    schema.type = _py_builtin_type_to_schema_type[dict]
    schema.properties = {}
    for field_name, field_info in param.annotation.model_fields.items():
      schema.properties[field_name] = _parse_schema_from_parameter(
@@ -304,9 +304,9 @@
  )
 
 
-def _get_required_fields(schema: types.Schema) -> list[str]:
+def _get_required_fields(schema: types.Schema) -> Optional[list[str]]:
   if not schema.properties:
-    return
+    return None
   return [
       field_name
       for field_name, field_schema in schema.properties.items()
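These changes spell out the `types.Type` enum values in the builtin-to-schema-type map and make `_get_required_fields` return `Optional[list[str]]`. Below is a rough standalone sketch of how such a map plus `typing.get_origin`/`get_args` can classify an annotation; plain strings stand in for the `types.Type` enum, and `schema_type_for` is a hypothetical helper, not the SDK's function.

```python
from typing import Any, Optional, Union, get_args, get_origin

# Illustrative mapping, analogous to _py_builtin_type_to_schema_type above.
_PY_TYPE_TO_SCHEMA_TYPE = {
    str: 'STRING',
    int: 'INTEGER',
    float: 'NUMBER',
    bool: 'BOOLEAN',
    list: 'ARRAY',
    dict: 'OBJECT',
}


def schema_type_for(annotation: Any) -> Optional[str]:
  """Rough sketch: classify a Python annotation into a schema type name."""
  if annotation in _PY_TYPE_TO_SCHEMA_TYPE:
    return _PY_TYPE_TO_SCHEMA_TYPE[annotation]
  origin = get_origin(annotation)
  if origin in (list, dict):
    return _PY_TYPE_TO_SCHEMA_TYPE[origin]
  if origin is Union:
    # The SDK models unions as an OBJECT schema carrying an any_of list.
    return _PY_TYPE_TO_SCHEMA_TYPE[dict]
  return None


assert schema_type_for(str) == 'STRING'
assert schema_type_for(list[int]) == 'ARRAY'
assert schema_type_for(Optional[int]) == 'OBJECT'
```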
google/genai/_common.py
CHANGED
@@ -188,6 +188,8 @@ def _remove_extra_fields(
       if isinstance(item, dict):
         _remove_extra_fields(typing.get_args(annotation)[0], item)
 
+T = typing.TypeVar('T', bound='BaseModel')
+
 
 
 class BaseModel(pydantic.BaseModel):
@@ -201,12 +203,13 @@ class BaseModel(pydantic.BaseModel):
       arbitrary_types_allowed=True,
       ser_json_bytes='base64',
       val_json_bytes='base64',
+      ignored_types=(typing.TypeVar,)
   )
 
   @classmethod
   def _from_response(
-      cls, response: dict[str, object], kwargs: dict[str, object]
-  ) ->
+      cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
+  ) -> T:
     # To maintain forward compatibility, we need to remove extra fields from
     # the response.
     # We will provide another mechanism to allow users to access these fields.
@@ -266,7 +269,7 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
     A dictionary with json.dumps() incompatible type (e.g. bytes datetime)
     to compatible type (e.g. base64 encoded string, isoformat date string).
   """
-  processed_data = {}
+  processed_data: dict[str, object] = {}
   if not isinstance(data, dict):
     return data
   for key, value in data.items():
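`_from_response` is now keyword-only and annotated with a `TypeVar` bound to `BaseModel` (with pydantic told to ignore `TypeVar` attributes via `ignored_types`), so each subclass's constructor returns its own type to static checkers. A minimal sketch of that generic classmethod pattern on a plain pydantic v2 model; `ResponseModel` and `GenerateContentLike` are illustrative names, not SDK classes.

```python
import typing

import pydantic

T = typing.TypeVar('T', bound='ResponseModel')


class ResponseModel(pydantic.BaseModel):
  """Sketch of the typed classmethod-constructor pattern used by _from_response."""

  @classmethod
  def from_response(cls: typing.Type[T], *, response: dict[str, object]) -> T:
    # Returning T rather than ResponseModel lets each subclass keep its own
    # type under a static type checker.
    return cls.model_validate(response)


class GenerateContentLike(ResponseModel):
  text: str = ''


resp = GenerateContentLike.from_response(response={'text': 'hello'})
print(type(resp).__name__, resp.text)
```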