google-genai 1.4.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +207 -111
- google/genai/_automatic_function_calling_util.py +6 -16
- google/genai/_common.py +5 -2
- google/genai/_extra_utils.py +62 -47
- google/genai/_replay_api_client.py +70 -2
- google/genai/_transformers.py +98 -57
- google/genai/batches.py +14 -10
- google/genai/caches.py +30 -36
- google/genai/client.py +3 -2
- google/genai/errors.py +11 -19
- google/genai/files.py +28 -15
- google/genai/live.py +276 -93
- google/genai/models.py +201 -112
- google/genai/operations.py +40 -12
- google/genai/pagers.py +17 -10
- google/genai/tunings.py +40 -30
- google/genai/types.py +146 -58
- google/genai/version.py +1 -1
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/METADATA +194 -24
- google_genai-1.6.0.dist-info/RECORD +27 -0
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/WHEEL +1 -1
- google_genai-1.4.0.dist-info/RECORD +0 -27
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/LICENSE +0 -0
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -19,6 +19,7 @@
|
|
19
19
|
The BaseApiClient is intended to be a private module and is subject to change.
|
20
20
|
"""
|
21
21
|
|
22
|
+
import anyio
|
22
23
|
import asyncio
|
23
24
|
import copy
|
24
25
|
from dataclasses import dataclass
|
@@ -28,22 +29,21 @@ import json
|
|
28
29
|
import logging
|
29
30
|
import os
|
30
31
|
import sys
|
31
|
-
from typing import Any, AsyncIterator, Optional, Tuple,
|
32
|
+
from typing import Any, AsyncIterator, Optional, Tuple, Union
|
32
33
|
from urllib.parse import urlparse, urlunparse
|
33
34
|
import google.auth
|
34
35
|
import google.auth.credentials
|
35
36
|
from google.auth.credentials import Credentials
|
36
|
-
from google.auth.transport.requests import AuthorizedSession
|
37
37
|
from google.auth.transport.requests import Request
|
38
38
|
import httpx
|
39
|
-
from pydantic import BaseModel,
|
40
|
-
import requests
|
39
|
+
from pydantic import BaseModel, Field, ValidationError
|
41
40
|
from . import _common
|
42
41
|
from . import errors
|
43
42
|
from . import version
|
44
43
|
from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
|
45
44
|
|
46
45
|
logger = logging.getLogger('google_genai._api_client')
|
46
|
+
CHUNK_SIZE = 8 * 1024 * 1024 # 8 MB chunk size
|
47
47
|
|
48
48
|
|
49
49
|
def _append_library_version_headers(headers: dict[str, str]) -> None:
|
@@ -86,7 +86,8 @@ def _patch_http_options(
|
|
86
86
|
copy_option[patch_key].update(patch_value)
|
87
87
|
elif patch_value is not None: # Accept empty values.
|
88
88
|
copy_option[patch_key] = patch_value
|
89
|
-
|
89
|
+
if copy_option['headers']:
|
90
|
+
_append_library_version_headers(copy_option['headers'])
|
90
91
|
return copy_option
|
91
92
|
|
92
93
|
|
@@ -101,7 +102,7 @@ def _join_url_path(base_url: str, path: str) -> str:
|
|
101
102
|
return urlunparse(parsed_base._replace(path=base_path + '/' + path))
|
102
103
|
|
103
104
|
|
104
|
-
def _load_auth(*, project: Union[str, None]) ->
|
105
|
+
def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
|
105
106
|
"""Loads google auth credentials and project id."""
|
106
107
|
credentials, loaded_project_id = google.auth.default(
|
107
108
|
scopes=['https://www.googleapis.com/auth/cloud-platform'],
|
@@ -118,8 +119,9 @@ def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
|
|
118
119
|
return credentials, project
|
119
120
|
|
120
121
|
|
121
|
-
def _refresh_auth(credentials: Credentials) ->
|
122
|
+
def _refresh_auth(credentials: Credentials) -> Credentials:
|
122
123
|
credentials.refresh(Request())
|
124
|
+
return credentials
|
123
125
|
|
124
126
|
|
125
127
|
@dataclass
|
@@ -147,7 +149,7 @@ class HttpResponse:
|
|
147
149
|
|
148
150
|
def __init__(
|
149
151
|
self,
|
150
|
-
headers: dict[str, str],
|
152
|
+
headers: Union[dict[str, str], httpx.Headers],
|
151
153
|
response_stream: Union[Any, str] = None,
|
152
154
|
byte_stream: Union[Any, bytes] = None,
|
153
155
|
):
|
@@ -200,14 +202,19 @@ class HttpResponse:
|
|
200
202
|
yield c
|
201
203
|
else:
|
202
204
|
# Iterator of objects retrieved from the API.
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
205
|
+
if hasattr(self.response_stream, 'aiter_lines'):
|
206
|
+
async for chunk in self.response_stream.aiter_lines():
|
207
|
+
# This is httpx.Response.
|
208
|
+
if chunk:
|
209
|
+
# In async streaming mode, the chunk of JSON is prefixed with "data:"
|
210
|
+
# which we must strip before parsing.
|
211
|
+
if chunk.startswith('data: '):
|
212
|
+
chunk = chunk[len('data: ') :]
|
213
|
+
yield json.loads(chunk)
|
214
|
+
else:
|
215
|
+
raise ValueError(
|
216
|
+
'Error parsing streaming response.'
|
217
|
+
)
|
211
218
|
|
212
219
|
def byte_segments(self):
|
213
220
|
if isinstance(self.byte_stream, list):
|
@@ -265,7 +272,9 @@ class BaseApiClient:
|
|
265
272
|
validated_http_options: dict[str, Any]
|
266
273
|
if isinstance(http_options, dict):
|
267
274
|
try:
|
268
|
-
validated_http_options = HttpOptions.model_validate(
|
275
|
+
validated_http_options = HttpOptions.model_validate(
|
276
|
+
http_options
|
277
|
+
).model_dump()
|
269
278
|
except ValidationError as e:
|
270
279
|
raise ValueError(f'Invalid http_options: {e}')
|
271
280
|
elif isinstance(http_options, HttpOptions):
|
@@ -351,7 +360,9 @@ class BaseApiClient:
|
|
351
360
|
self._http_options['headers']['x-goog-api-key'] = self.api_key
|
352
361
|
# Update the http options with the user provided http options.
|
353
362
|
if http_options:
|
354
|
-
self._http_options = _patch_http_options(
|
363
|
+
self._http_options = _patch_http_options(
|
364
|
+
self._http_options, validated_http_options
|
365
|
+
)
|
355
366
|
else:
|
356
367
|
_append_library_version_headers(self._http_options['headers'])
|
357
368
|
|
@@ -359,8 +370,27 @@ class BaseApiClient:
|
|
359
370
|
url_parts = urlparse(self._http_options['base_url'])
|
360
371
|
return url_parts._replace(scheme='wss').geturl()
|
361
372
|
|
362
|
-
|
373
|
+
def _access_token(self) -> str:
|
363
374
|
"""Retrieves the access token for the credentials."""
|
375
|
+
if not self._credentials:
|
376
|
+
self._credentials, project = _load_auth(project=self.project)
|
377
|
+
if not self.project:
|
378
|
+
self.project = project
|
379
|
+
|
380
|
+
if self._credentials:
|
381
|
+
if (
|
382
|
+
self._credentials.expired or not self._credentials.token
|
383
|
+
):
|
384
|
+
# Only refresh when it needs to. Default expiration is 3600 seconds.
|
385
|
+
_refresh_auth(self._credentials)
|
386
|
+
if not self._credentials.token:
|
387
|
+
raise RuntimeError('Could not resolve API token from the environment')
|
388
|
+
return self._credentials.token
|
389
|
+
else:
|
390
|
+
raise RuntimeError('Could not resolve API token from the environment')
|
391
|
+
|
392
|
+
async def _async_access_token(self) -> str:
|
393
|
+
"""Retrieves the access token for the credentials asynchronously."""
|
364
394
|
if not self._credentials:
|
365
395
|
async with self._auth_lock:
|
366
396
|
# This ensures that only one coroutine can execute the auth logic at a
|
@@ -373,17 +403,22 @@ class BaseApiClient:
|
|
373
403
|
if not self.project:
|
374
404
|
self.project = project
|
375
405
|
|
376
|
-
if self._credentials
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
406
|
+
if self._credentials:
|
407
|
+
if (
|
408
|
+
self._credentials.expired or not self._credentials.token
|
409
|
+
):
|
410
|
+
# Only refresh when it needs to. Default expiration is 3600 seconds.
|
411
|
+
async with self._auth_lock:
|
412
|
+
if self._credentials.expired or not self._credentials.token:
|
413
|
+
# Double check that the credentials expired before refreshing.
|
414
|
+
await asyncio.to_thread(_refresh_auth, self._credentials)
|
382
415
|
|
383
|
-
|
384
|
-
|
416
|
+
if not self._credentials.token:
|
417
|
+
raise RuntimeError('Could not resolve API token from the environment')
|
385
418
|
|
386
|
-
|
419
|
+
return self._credentials.token
|
420
|
+
else:
|
421
|
+
raise RuntimeError('Could not resolve API token from the environment')
|
387
422
|
|
388
423
|
def _build_request(
|
389
424
|
self,
|
@@ -424,12 +459,16 @@ class BaseApiClient:
|
|
424
459
|
):
|
425
460
|
path = f'projects/{self.project}/locations/{self.location}/' + path
|
426
461
|
url = _join_url_path(
|
427
|
-
patched_http_options
|
428
|
-
patched_http_options
|
462
|
+
patched_http_options.get('base_url', ''),
|
463
|
+
patched_http_options.get('api_version', '') + '/' + path,
|
429
464
|
)
|
430
465
|
|
431
|
-
timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
|
466
|
+
timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
|
467
|
+
'timeout', None
|
468
|
+
)
|
432
469
|
if timeout_in_seconds:
|
470
|
+
# HttpOptions.timeout is in milliseconds. But httpx.Client.request()
|
471
|
+
# expects seconds.
|
433
472
|
timeout_in_seconds = timeout_in_seconds / 1000.0
|
434
473
|
else:
|
435
474
|
timeout_in_seconds = None
|
@@ -447,74 +486,80 @@ class BaseApiClient:
|
|
447
486
|
http_request: HttpRequest,
|
448
487
|
stream: bool = False,
|
449
488
|
) -> HttpResponse:
|
489
|
+
data: Optional[Union[str, bytes]] = None
|
450
490
|
if self.vertexai and not self.api_key:
|
451
|
-
|
452
|
-
|
453
|
-
|
491
|
+
http_request.headers['Authorization'] = (
|
492
|
+
f'Bearer {self._access_token()}'
|
493
|
+
)
|
494
|
+
if self._credentials and self._credentials.quota_project_id:
|
454
495
|
http_request.headers['x-goog-user-project'] = (
|
455
496
|
self._credentials.quota_project_id
|
456
497
|
)
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
http_request.
|
498
|
+
data = json.dumps(http_request.data)
|
499
|
+
else:
|
500
|
+
if http_request.data:
|
501
|
+
if not isinstance(http_request.data, bytes):
|
502
|
+
data = json.dumps(http_request.data)
|
503
|
+
else:
|
504
|
+
data = http_request.data
|
505
|
+
|
506
|
+
if stream:
|
507
|
+
client = httpx.Client()
|
508
|
+
httpx_request = client.build_request(
|
509
|
+
method=http_request.method,
|
510
|
+
url=http_request.url,
|
511
|
+
content=data,
|
462
512
|
headers=http_request.headers,
|
463
|
-
data=json.dumps(http_request.data) if http_request.data else None,
|
464
513
|
timeout=http_request.timeout,
|
465
514
|
)
|
515
|
+
response = client.send(httpx_request, stream=stream)
|
466
516
|
errors.APIError.raise_for_response(response)
|
467
517
|
return HttpResponse(
|
468
518
|
response.headers, response if stream else [response.text]
|
469
519
|
)
|
470
520
|
else:
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
data = http_request.data
|
484
|
-
|
485
|
-
http_session = requests.Session()
|
486
|
-
response = http_session.request(
|
487
|
-
method=http_request.method,
|
488
|
-
url=http_request.url,
|
489
|
-
headers=http_request.headers,
|
490
|
-
data=data,
|
491
|
-
timeout=http_request.timeout,
|
492
|
-
stream=stream,
|
493
|
-
)
|
494
|
-
errors.APIError.raise_for_response(response)
|
495
|
-
return HttpResponse(
|
496
|
-
response.headers, response if stream else [response.text]
|
497
|
-
)
|
521
|
+
with httpx.Client() as client:
|
522
|
+
response = client.request(
|
523
|
+
method=http_request.method,
|
524
|
+
url=http_request.url,
|
525
|
+
headers=http_request.headers,
|
526
|
+
content=data,
|
527
|
+
timeout=http_request.timeout,
|
528
|
+
)
|
529
|
+
errors.APIError.raise_for_response(response)
|
530
|
+
return HttpResponse(
|
531
|
+
response.headers, response if stream else [response.text]
|
532
|
+
)
|
498
533
|
|
499
534
|
async def _async_request(
|
500
535
|
self, http_request: HttpRequest, stream: bool = False
|
501
536
|
):
|
502
|
-
|
537
|
+
data: Optional[Union[str, bytes]] = None
|
538
|
+
if self.vertexai and not self.api_key:
|
503
539
|
http_request.headers['Authorization'] = (
|
504
540
|
f'Bearer {await self._async_access_token()}'
|
505
541
|
)
|
506
|
-
if self._credentials.quota_project_id:
|
542
|
+
if self._credentials and self._credentials.quota_project_id:
|
507
543
|
http_request.headers['x-goog-user-project'] = (
|
508
544
|
self._credentials.quota_project_id
|
509
545
|
)
|
546
|
+
data = json.dumps(http_request.data)
|
547
|
+
else:
|
548
|
+
if http_request.data:
|
549
|
+
if not isinstance(http_request.data, bytes):
|
550
|
+
data = json.dumps(http_request.data)
|
551
|
+
else:
|
552
|
+
data = http_request.data
|
553
|
+
|
510
554
|
if stream:
|
511
|
-
|
555
|
+
aclient = httpx.AsyncClient()
|
556
|
+
httpx_request = aclient.build_request(
|
512
557
|
method=http_request.method,
|
513
558
|
url=http_request.url,
|
514
|
-
content=
|
559
|
+
content=data,
|
515
560
|
headers=http_request.headers,
|
561
|
+
timeout=http_request.timeout,
|
516
562
|
)
|
517
|
-
aclient = httpx.AsyncClient()
|
518
563
|
response = await aclient.send(
|
519
564
|
httpx_request,
|
520
565
|
stream=stream,
|
@@ -529,7 +574,7 @@ class BaseApiClient:
|
|
529
574
|
method=http_request.method,
|
530
575
|
url=http_request.url,
|
531
576
|
headers=http_request.headers,
|
532
|
-
content=
|
577
|
+
content=data,
|
533
578
|
timeout=http_request.timeout,
|
534
579
|
)
|
535
580
|
errors.APIError.raise_for_response(response)
|
@@ -615,7 +660,7 @@ class BaseApiClient:
|
|
615
660
|
|
616
661
|
def upload_file(
|
617
662
|
self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
|
618
|
-
) -> str:
|
663
|
+
) -> dict[str, str]:
|
619
664
|
"""Transfers a file to the given URL.
|
620
665
|
|
621
666
|
Args:
|
@@ -637,7 +682,7 @@ class BaseApiClient:
|
|
637
682
|
|
638
683
|
def _upload_fd(
|
639
684
|
self, file: io.IOBase, upload_url: str, upload_size: int
|
640
|
-
) -> str:
|
685
|
+
) -> dict[str, str]:
|
641
686
|
"""Transfers a file to the given URL.
|
642
687
|
|
643
688
|
Args:
|
@@ -652,7 +697,7 @@ class BaseApiClient:
|
|
652
697
|
offset = 0
|
653
698
|
# Upload the file in chunks
|
654
699
|
while True:
|
655
|
-
file_chunk = file.read(
|
700
|
+
file_chunk = file.read(CHUNK_SIZE)
|
656
701
|
chunk_size = 0
|
657
702
|
if file_chunk:
|
658
703
|
chunk_size = len(file_chunk)
|
@@ -671,7 +716,7 @@ class BaseApiClient:
|
|
671
716
|
data=file_chunk,
|
672
717
|
)
|
673
718
|
|
674
|
-
response = self.
|
719
|
+
response = self._request(request, stream=False)
|
675
720
|
offset += chunk_size
|
676
721
|
if response.headers['X-Goog-Upload-Status'] != 'active':
|
677
722
|
break # upload is complete or it has been interrupted.
|
@@ -679,13 +724,12 @@ class BaseApiClient:
|
|
679
724
|
if upload_size <= offset: # Status is not finalized.
|
680
725
|
raise ValueError(
|
681
726
|
'All content has been uploaded, but the upload status is not'
|
682
|
-
f' finalized.
|
727
|
+
f' finalized.'
|
683
728
|
)
|
684
729
|
|
685
730
|
if response.headers['X-Goog-Upload-Status'] != 'final':
|
686
731
|
raise ValueError(
|
687
|
-
'Failed to upload file: Upload status is not finalized.
|
688
|
-
f' {response.headers}, body: {response.json}'
|
732
|
+
'Failed to upload file: Upload status is not finalized.'
|
689
733
|
)
|
690
734
|
return response.json
|
691
735
|
|
@@ -708,32 +752,31 @@ class BaseApiClient:
|
|
708
752
|
self,
|
709
753
|
http_request: HttpRequest,
|
710
754
|
) -> HttpResponse:
|
711
|
-
data: str
|
755
|
+
data: Optional[Union[str, bytes]] = None
|
712
756
|
if http_request.data:
|
713
757
|
if not isinstance(http_request.data, bytes):
|
714
758
|
data = json.dumps(http_request.data)
|
715
759
|
else:
|
716
760
|
data = http_request.data
|
717
761
|
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
)
|
762
|
+
with httpx.Client(follow_redirects=True) as client:
|
763
|
+
response = client.request(
|
764
|
+
method=http_request.method,
|
765
|
+
url=http_request.url,
|
766
|
+
headers=http_request.headers,
|
767
|
+
content=data,
|
768
|
+
timeout=http_request.timeout,
|
769
|
+
)
|
727
770
|
|
728
|
-
|
729
|
-
|
771
|
+
errors.APIError.raise_for_response(response)
|
772
|
+
return HttpResponse(response.headers, byte_stream=[response.read()])
|
730
773
|
|
731
774
|
async def async_upload_file(
|
732
775
|
self,
|
733
776
|
file_path: Union[str, io.IOBase],
|
734
777
|
upload_url: str,
|
735
778
|
upload_size: int,
|
736
|
-
) -> str:
|
779
|
+
) -> dict[str, str]:
|
737
780
|
"""Transfers a file asynchronously to the given URL.
|
738
781
|
|
739
782
|
Args:
|
@@ -746,19 +789,20 @@ class BaseApiClient:
|
|
746
789
|
returns:
|
747
790
|
The response json object from the finalize request.
|
748
791
|
"""
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
792
|
+
if isinstance(file_path, io.IOBase):
|
793
|
+
return await self._async_upload_fd(file_path, upload_url, upload_size)
|
794
|
+
else:
|
795
|
+
file = anyio.Path(file_path)
|
796
|
+
fd = await file.open('rb')
|
797
|
+
async with fd:
|
798
|
+
return await self._async_upload_fd(fd, upload_url, upload_size)
|
755
799
|
|
756
800
|
async def _async_upload_fd(
|
757
801
|
self,
|
758
|
-
file: io.IOBase,
|
802
|
+
file: Union[io.IOBase, anyio.AsyncFile],
|
759
803
|
upload_url: str,
|
760
804
|
upload_size: int,
|
761
|
-
) -> str:
|
805
|
+
) -> dict[str, str]:
|
762
806
|
"""Transfers a file asynchronously to the given URL.
|
763
807
|
|
764
808
|
Args:
|
@@ -770,12 +814,45 @@ class BaseApiClient:
|
|
770
814
|
returns:
|
771
815
|
The response json object from the finalize request.
|
772
816
|
"""
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
817
|
+
async with httpx.AsyncClient() as aclient:
|
818
|
+
offset = 0
|
819
|
+
# Upload the file in chunks
|
820
|
+
while True:
|
821
|
+
if isinstance(file, io.IOBase):
|
822
|
+
file_chunk = file.read(CHUNK_SIZE)
|
823
|
+
else:
|
824
|
+
file_chunk = await file.read(CHUNK_SIZE)
|
825
|
+
chunk_size = 0
|
826
|
+
if file_chunk:
|
827
|
+
chunk_size = len(file_chunk)
|
828
|
+
upload_command = 'upload'
|
829
|
+
# If last chunk, finalize the upload.
|
830
|
+
if chunk_size + offset >= upload_size:
|
831
|
+
upload_command += ', finalize'
|
832
|
+
response = await aclient.request(
|
833
|
+
method='POST',
|
834
|
+
url=upload_url,
|
835
|
+
content=file_chunk,
|
836
|
+
headers={
|
837
|
+
'X-Goog-Upload-Command': upload_command,
|
838
|
+
'X-Goog-Upload-Offset': str(offset),
|
839
|
+
'Content-Length': str(chunk_size),
|
840
|
+
},
|
841
|
+
)
|
842
|
+
offset += chunk_size
|
843
|
+
if response.headers.get('x-goog-upload-status') != 'active':
|
844
|
+
break # upload is complete or it has been interrupted.
|
845
|
+
|
846
|
+
if upload_size <= offset: # Status is not finalized.
|
847
|
+
raise ValueError(
|
848
|
+
'All content has been uploaded, but the upload status is not'
|
849
|
+
f' finalized.'
|
850
|
+
)
|
851
|
+
if response.headers.get('x-goog-upload-status') != 'final':
|
852
|
+
raise ValueError(
|
853
|
+
'Failed to upload file: Upload status is not finalized.'
|
854
|
+
)
|
855
|
+
return response.json()
|
779
856
|
|
780
857
|
async def async_download_file(self, path: str, http_options):
|
781
858
|
"""Downloads the file data.
|
@@ -787,12 +864,31 @@ class BaseApiClient:
|
|
787
864
|
returns:
|
788
865
|
The file bytes
|
789
866
|
"""
|
790
|
-
|
791
|
-
|
792
|
-
path,
|
793
|
-
http_options,
|
867
|
+
http_request = self._build_request(
|
868
|
+
'get', path=path, request_dict={}, http_options=http_options
|
794
869
|
)
|
795
870
|
|
871
|
+
data: Optional[Union[str, bytes]] = None
|
872
|
+
if http_request.data:
|
873
|
+
if not isinstance(http_request.data, bytes):
|
874
|
+
data = json.dumps(http_request.data)
|
875
|
+
else:
|
876
|
+
data = http_request.data
|
877
|
+
|
878
|
+
async with httpx.AsyncClient(follow_redirects=True) as aclient:
|
879
|
+
response = await aclient.request(
|
880
|
+
method=http_request.method,
|
881
|
+
url=http_request.url,
|
882
|
+
headers=http_request.headers,
|
883
|
+
content=data,
|
884
|
+
timeout=http_request.timeout,
|
885
|
+
)
|
886
|
+
errors.APIError.raise_for_response(response)
|
887
|
+
|
888
|
+
return HttpResponse(
|
889
|
+
response.headers, byte_stream=[response.read()]
|
890
|
+
).byte_stream[0]
|
891
|
+
|
796
892
|
# This method does nothing in the real api client. It is used in the
|
797
893
|
# replay_api_client to verify the response from the SDK method matches the
|
798
894
|
# recorded response.
|
google/genai/_automatic_function_calling_util.py
CHANGED
@@ -17,7 +17,7 @@ import inspect
|
|
17
17
|
import sys
|
18
18
|
import types as builtin_types
|
19
19
|
import typing
|
20
|
-
from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Union
|
20
|
+
from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union # type: ignore[attr-defined]
|
21
21
|
|
22
22
|
import pydantic
|
23
23
|
|
@@ -41,19 +41,11 @@ _py_builtin_type_to_schema_type = {
|
|
41
41
|
|
42
42
|
|
43
43
|
def _is_builtin_primitive_or_compound(
|
44
|
-
annotation: inspect.Parameter.annotation,
|
44
|
+
annotation: inspect.Parameter.annotation, # type: ignore[valid-type]
|
45
45
|
) -> bool:
|
46
46
|
return annotation in _py_builtin_type_to_schema_type.keys()
|
47
47
|
|
48
48
|
|
49
|
-
def _raise_for_any_of_if_mldev(schema: types.Schema):
|
50
|
-
if schema.any_of:
|
51
|
-
raise ValueError(
|
52
|
-
'AnyOf is not supported in function declaration schema for'
|
53
|
-
' the Gemini API.'
|
54
|
-
)
|
55
|
-
|
56
|
-
|
57
49
|
def _raise_for_default_if_mldev(schema: types.Schema):
|
58
50
|
if schema.default is not None:
|
59
51
|
raise ValueError(
|
@@ -64,12 +56,11 @@ def _raise_for_default_if_mldev(schema: types.Schema):
|
|
64
56
|
|
65
57
|
def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
|
66
58
|
if api_option == 'GEMINI_API':
|
67
|
-
_raise_for_any_of_if_mldev(schema)
|
68
59
|
_raise_for_default_if_mldev(schema)
|
69
60
|
|
70
61
|
|
71
62
|
def _is_default_value_compatible(
|
72
|
-
default_value: Any, annotation: inspect.Parameter.annotation
|
63
|
+
default_value: Any, annotation: inspect.Parameter.annotation # type: ignore[valid-type]
|
73
64
|
) -> bool:
|
74
65
|
# None type is expected to be handled external to this function
|
75
66
|
if _is_builtin_primitive_or_compound(annotation):
|
@@ -292,8 +283,7 @@ def _parse_schema_from_parameter(
|
|
292
283
|
),
|
293
284
|
func_name,
|
294
285
|
)
|
295
|
-
|
296
|
-
schema.required = _get_required_fields(schema)
|
286
|
+
schema.required = _get_required_fields(schema)
|
297
287
|
_raise_if_schema_unsupported(api_option, schema)
|
298
288
|
return schema
|
299
289
|
raise ValueError(
|
@@ -304,9 +294,9 @@ def _parse_schema_from_parameter(
|
|
304
294
|
)
|
305
295
|
|
306
296
|
|
307
|
-
def _get_required_fields(schema: types.Schema) -> list[str]:
|
297
|
+
def _get_required_fields(schema: types.Schema) -> Optional[list[str]]:
|
308
298
|
if not schema.properties:
|
309
|
-
return
|
299
|
+
return None
|
310
300
|
return [
|
311
301
|
field_name
|
312
302
|
for field_name, field_schema in schema.properties.items()
|
google/genai/_common.py
CHANGED
@@ -188,6 +188,8 @@ def _remove_extra_fields(
|
|
188
188
|
if isinstance(item, dict):
|
189
189
|
_remove_extra_fields(typing.get_args(annotation)[0], item)
|
190
190
|
|
191
|
+
T = typing.TypeVar('T', bound='BaseModel')
|
192
|
+
|
191
193
|
|
192
194
|
class BaseModel(pydantic.BaseModel):
|
193
195
|
|
@@ -201,12 +203,13 @@ class BaseModel(pydantic.BaseModel):
|
|
201
203
|
arbitrary_types_allowed=True,
|
202
204
|
ser_json_bytes='base64',
|
203
205
|
val_json_bytes='base64',
|
206
|
+
ignored_types=(typing.TypeVar,)
|
204
207
|
)
|
205
208
|
|
206
209
|
@classmethod
|
207
210
|
def _from_response(
|
208
|
-
cls, *, response: dict[str, object], kwargs: dict[str, object]
|
209
|
-
) ->
|
211
|
+
cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
|
212
|
+
) -> T:
|
210
213
|
# To maintain forward compatibility, we need to remove extra fields from
|
211
214
|
# the response.
|
212
215
|
# We will provide another mechanism to allow users to access these fields.
|