google-genai 1.27.0__py3-none-any.whl → 1.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +1 -0
- google/genai/_api_client.py +281 -92
- google/genai/_automatic_function_calling_util.py +35 -7
- google/genai/_common.py +9 -6
- google/genai/_extra_utils.py +8 -8
- google/genai/_live_converters.py +14 -0
- google/genai/_mcp_utils.py +4 -1
- google/genai/_replay_api_client.py +6 -2
- google/genai/_transformers.py +13 -12
- google/genai/batches.py +19 -2
- google/genai/errors.py +6 -3
- google/genai/live.py +2 -3
- google/genai/models.py +472 -22
- google/genai/pagers.py +5 -5
- google/genai/tokens.py +3 -3
- google/genai/tunings.py +33 -6
- google/genai/types.py +395 -39
- google/genai/version.py +1 -1
- {google_genai-1.27.0.dist-info → google_genai-1.29.0.dist-info}/METADATA +2 -2
- google_genai-1.29.0.dist-info/RECORD +35 -0
- google_genai-1.27.0.dist-info/RECORD +0 -35
- {google_genai-1.27.0.dist-info → google_genai-1.29.0.dist-info}/WHEEL +0 -0
- {google_genai-1.27.0.dist-info → google_genai-1.29.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.27.0.dist-info → google_genai-1.29.0.dist-info}/top_level.txt +0 -0
google/genai/__init__.py
CHANGED
google/genai/_api_client.py
CHANGED
@@ -20,22 +20,21 @@ The BaseApiClient is intended to be a private module and is subject to change.
|
|
20
20
|
"""
|
21
21
|
|
22
22
|
import asyncio
|
23
|
-
from collections.abc import
|
23
|
+
from collections.abc import Generator
|
24
24
|
import copy
|
25
25
|
from dataclasses import dataclass
|
26
|
-
import datetime
|
27
|
-
import http
|
28
26
|
import inspect
|
29
27
|
import io
|
30
28
|
import json
|
31
29
|
import logging
|
32
30
|
import math
|
33
31
|
import os
|
32
|
+
import random
|
34
33
|
import ssl
|
35
34
|
import sys
|
36
35
|
import threading
|
37
36
|
import time
|
38
|
-
from typing import Any, AsyncIterator,
|
37
|
+
from typing import Any, AsyncIterator, Iterator, Optional, Tuple, TYPE_CHECKING, Union
|
39
38
|
from urllib.parse import urlparse
|
40
39
|
from urllib.parse import urlunparse
|
41
40
|
|
@@ -47,7 +46,6 @@ from google.auth.credentials import Credentials
|
|
47
46
|
from google.auth.transport.requests import Request
|
48
47
|
import httpx
|
49
48
|
from pydantic import BaseModel
|
50
|
-
from pydantic import Field
|
51
49
|
from pydantic import ValidationError
|
52
50
|
import tenacity
|
53
51
|
|
@@ -55,11 +53,11 @@ from . import _common
|
|
55
53
|
from . import errors
|
56
54
|
from . import version
|
57
55
|
from .types import HttpOptions
|
58
|
-
from .types import HttpOptionsDict
|
59
56
|
from .types import HttpOptionsOrDict
|
60
57
|
from .types import HttpResponse as SdkHttpResponse
|
61
58
|
from .types import HttpRetryOptions
|
62
59
|
|
60
|
+
|
63
61
|
try:
|
64
62
|
from websockets.asyncio.client import connect as ws_connect
|
65
63
|
except ModuleNotFoundError:
|
@@ -81,6 +79,7 @@ if TYPE_CHECKING:
|
|
81
79
|
|
82
80
|
logger = logging.getLogger('google_genai._api_client')
|
83
81
|
CHUNK_SIZE = 8 * 1024 * 1024 # 8 MB chunk size
|
82
|
+
READ_BUFFER_SIZE = 2**20
|
84
83
|
MAX_RETRY_COUNT = 3
|
85
84
|
INITIAL_RETRY_DELAY = 1 # second
|
86
85
|
DELAY_MULTIPLIER = 2
|
@@ -236,12 +235,12 @@ class HttpResponse:
|
|
236
235
|
self.headers = headers
|
237
236
|
elif isinstance(headers, httpx.Headers):
|
238
237
|
self.headers = {
|
239
|
-
|
240
|
-
|
238
|
+
key: ', '.join(headers.get_list(key)) for key in headers.keys()
|
239
|
+
}
|
241
240
|
elif type(headers).__name__ == 'CIMultiDictProxy':
|
242
241
|
self.headers = {
|
243
|
-
key: ', '.join(headers.getall(key))
|
244
|
-
|
242
|
+
key: ', '.join(headers.getall(key)) for key in headers.keys()
|
243
|
+
}
|
245
244
|
|
246
245
|
self.status_code: int = 200
|
247
246
|
self.response_stream = response_stream
|
@@ -263,68 +262,32 @@ class HttpResponse:
|
|
263
262
|
def json(self) -> Any:
|
264
263
|
if not self.response_stream[0]: # Empty response
|
265
264
|
return ''
|
266
|
-
return
|
265
|
+
return self._load_json_from_response(self.response_stream[0])
|
267
266
|
|
268
267
|
def segments(self) -> Generator[Any, None, None]:
|
269
268
|
if isinstance(self.response_stream, list):
|
270
269
|
# list of objects retrieved from replay or from non-streaming API.
|
271
270
|
for chunk in self.response_stream:
|
272
|
-
yield
|
271
|
+
yield self._load_json_from_response(chunk) if chunk else {}
|
273
272
|
elif self.response_stream is None:
|
274
273
|
yield from []
|
275
274
|
else:
|
276
275
|
# Iterator of objects retrieved from the API.
|
277
|
-
for chunk in self.
|
278
|
-
|
279
|
-
# In streaming mode, the chunk of JSON is prefixed with "data:" which
|
280
|
-
# we must strip before parsing.
|
281
|
-
if not isinstance(chunk, str):
|
282
|
-
chunk = chunk.decode('utf-8')
|
283
|
-
if chunk.startswith('data: '):
|
284
|
-
chunk = chunk[len('data: ') :]
|
285
|
-
yield json.loads(chunk)
|
276
|
+
for chunk in self._iter_response_stream():
|
277
|
+
yield self._load_json_from_response(chunk)
|
286
278
|
|
287
279
|
async def async_segments(self) -> AsyncIterator[Any]:
|
288
280
|
if isinstance(self.response_stream, list):
|
289
281
|
# list of objects retrieved from replay or from non-streaming API.
|
290
282
|
for chunk in self.response_stream:
|
291
|
-
yield
|
283
|
+
yield self._load_json_from_response(chunk) if chunk else {}
|
292
284
|
elif self.response_stream is None:
|
293
285
|
async for c in []: # type: ignore[attr-defined]
|
294
286
|
yield c
|
295
287
|
else:
|
296
288
|
# Iterator of objects retrieved from the API.
|
297
|
-
|
298
|
-
|
299
|
-
# This is httpx.Response.
|
300
|
-
if chunk:
|
301
|
-
# In async streaming mode, the chunk of JSON is prefixed with
|
302
|
-
# "data:" which we must strip before parsing.
|
303
|
-
if not isinstance(chunk, str):
|
304
|
-
chunk = chunk.decode('utf-8')
|
305
|
-
if chunk.startswith('data: '):
|
306
|
-
chunk = chunk[len('data: ') :]
|
307
|
-
yield json.loads(chunk)
|
308
|
-
elif hasattr(self.response_stream, 'content'):
|
309
|
-
# This is aiohttp.ClientResponse.
|
310
|
-
try:
|
311
|
-
while True:
|
312
|
-
chunk = await self.response_stream.content.readline()
|
313
|
-
if not chunk:
|
314
|
-
break
|
315
|
-
# In async streaming mode, the chunk of JSON is prefixed with
|
316
|
-
# "data:" which we must strip before parsing.
|
317
|
-
chunk = chunk.decode('utf-8')
|
318
|
-
if chunk.startswith('data: '):
|
319
|
-
chunk = chunk[len('data: ') :]
|
320
|
-
chunk = chunk.strip()
|
321
|
-
if chunk:
|
322
|
-
yield json.loads(chunk)
|
323
|
-
finally:
|
324
|
-
if hasattr(self, '_session') and self._session:
|
325
|
-
await self._session.close()
|
326
|
-
else:
|
327
|
-
raise ValueError('Error parsing streaming response.')
|
289
|
+
async for chunk in self._aiter_response_stream():
|
290
|
+
yield self._load_json_from_response(chunk)
|
328
291
|
|
329
292
|
def byte_segments(self) -> Generator[Union[bytes, Any], None, None]:
|
330
293
|
if isinstance(self.byte_stream, list):
|
@@ -343,6 +306,130 @@ class HttpResponse:
|
|
343
306
|
for attribute in dir(self):
|
344
307
|
response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
|
345
308
|
|
309
|
+
def _iter_response_stream(self) -> Iterator[str]:
|
310
|
+
"""Iterates over chunks retrieved from the API."""
|
311
|
+
if not isinstance(self.response_stream, httpx.Response):
|
312
|
+
raise TypeError(
|
313
|
+
'Expected self.response_stream to be an httpx.Response object, '
|
314
|
+
f'but got {type(self.response_stream).__name__}.'
|
315
|
+
)
|
316
|
+
|
317
|
+
chunk = ''
|
318
|
+
balance = 0
|
319
|
+
for line in self.response_stream.iter_lines():
|
320
|
+
if not line:
|
321
|
+
continue
|
322
|
+
|
323
|
+
# In streaming mode, the response of JSON is prefixed with "data: " which
|
324
|
+
# we must strip before parsing.
|
325
|
+
if line.startswith('data: '):
|
326
|
+
yield line[len('data: '):]
|
327
|
+
continue
|
328
|
+
|
329
|
+
# When API returns an error message, it comes line by line. So we buffer
|
330
|
+
# the lines until a complete JSON string is read. A complete JSON string
|
331
|
+
# is found when the balance is 0.
|
332
|
+
for c in line:
|
333
|
+
if c == '{':
|
334
|
+
balance += 1
|
335
|
+
elif c == '}':
|
336
|
+
balance -= 1
|
337
|
+
|
338
|
+
chunk += line
|
339
|
+
if balance == 0:
|
340
|
+
yield chunk
|
341
|
+
chunk = ''
|
342
|
+
|
343
|
+
# If there is any remaining chunk, yield it.
|
344
|
+
if chunk:
|
345
|
+
yield chunk
|
346
|
+
|
347
|
+
async def _aiter_response_stream(self) -> AsyncIterator[str]:
|
348
|
+
"""Asynchronously iterates over chunks retrieved from the API."""
|
349
|
+
if not isinstance(
|
350
|
+
self.response_stream, (httpx.Response, aiohttp.ClientResponse)
|
351
|
+
):
|
352
|
+
raise TypeError(
|
353
|
+
'Expected self.response_stream to be an httpx.Response or'
|
354
|
+
' aiohttp.ClientResponse object, but got'
|
355
|
+
f' {type(self.response_stream).__name__}.'
|
356
|
+
)
|
357
|
+
|
358
|
+
chunk = ''
|
359
|
+
balance = 0
|
360
|
+
# httpx.Response has a dedicated async line iterator.
|
361
|
+
if isinstance(self.response_stream, httpx.Response):
|
362
|
+
async for line in self.response_stream.aiter_lines():
|
363
|
+
if not line:
|
364
|
+
continue
|
365
|
+
# In streaming mode, the response of JSON is prefixed with "data: "
|
366
|
+
# which we must strip before parsing.
|
367
|
+
if line.startswith('data: '):
|
368
|
+
yield line[len('data: '):]
|
369
|
+
continue
|
370
|
+
|
371
|
+
# When API returns an error message, it comes line by line. So we buffer
|
372
|
+
# the lines until a complete JSON string is read. A complete JSON string
|
373
|
+
# is found when the balance is 0.
|
374
|
+
for c in line:
|
375
|
+
if c == '{':
|
376
|
+
balance += 1
|
377
|
+
elif c == '}':
|
378
|
+
balance -= 1
|
379
|
+
|
380
|
+
chunk += line
|
381
|
+
if balance == 0:
|
382
|
+
yield chunk
|
383
|
+
chunk = ''
|
384
|
+
|
385
|
+
# aiohttp.ClientResponse uses a content stream that we read line by line.
|
386
|
+
elif isinstance(self.response_stream, aiohttp.ClientResponse):
|
387
|
+
while True:
|
388
|
+
# Read a line from the stream. This returns bytes.
|
389
|
+
line_bytes = await self.response_stream.content.readline()
|
390
|
+
if not line_bytes:
|
391
|
+
break
|
392
|
+
# Decode the bytes and remove trailing whitespace and newlines.
|
393
|
+
line = line_bytes.decode('utf-8').rstrip()
|
394
|
+
if not line:
|
395
|
+
continue
|
396
|
+
|
397
|
+
# In streaming mode, the response of JSON is prefixed with "data: "
|
398
|
+
# which we must strip before parsing.
|
399
|
+
if line.startswith('data: '):
|
400
|
+
yield line[len('data: '):]
|
401
|
+
continue
|
402
|
+
|
403
|
+
# When API returns an error message, it comes line by line. So we buffer
|
404
|
+
# the lines until a complete JSON string is read. A complete JSON string
|
405
|
+
# is found when the balance is 0.
|
406
|
+
for c in line:
|
407
|
+
if c == '{':
|
408
|
+
balance += 1
|
409
|
+
elif c == '}':
|
410
|
+
balance -= 1
|
411
|
+
|
412
|
+
chunk += line
|
413
|
+
if balance == 0:
|
414
|
+
yield chunk
|
415
|
+
chunk = ''
|
416
|
+
|
417
|
+
# If there is any remaining chunk, yield it.
|
418
|
+
if chunk:
|
419
|
+
yield chunk
|
420
|
+
|
421
|
+
if hasattr(self, '_session') and self._session:
|
422
|
+
await self._session.close()
|
423
|
+
|
424
|
+
@classmethod
|
425
|
+
def _load_json_from_response(cls, response: Any) -> Any:
|
426
|
+
"""Loads JSON from the response, or raises an error if the parsing fails."""
|
427
|
+
try:
|
428
|
+
return json.loads(response)
|
429
|
+
except json.JSONDecodeError as e:
|
430
|
+
raise errors.UnknownApiResponseError(
|
431
|
+
f'Failed to parse response as JSON. Raw response: {response}'
|
432
|
+
) from e
|
346
433
|
|
347
434
|
# Default retry options.
|
348
435
|
# The config is based on https://cloud.google.com/storage/docs/retry-strategy.
|
@@ -363,7 +450,7 @@ _RETRY_HTTP_STATUS_CODES = (
|
|
363
450
|
)
|
364
451
|
|
365
452
|
|
366
|
-
def _retry_args(options: Optional[HttpRetryOptions]) ->
|
453
|
+
def _retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
|
367
454
|
"""Returns the retry args for the given http retry options.
|
368
455
|
|
369
456
|
Args:
|
@@ -533,17 +620,30 @@ class BaseApiClient:
|
|
533
620
|
+ ' precedence over the API key from the environment variables.'
|
534
621
|
)
|
535
622
|
self.api_key = None
|
536
|
-
|
623
|
+
|
624
|
+
# Skip fetching project from ADC if base url is provided in http options.
|
625
|
+
if (
|
626
|
+
not self.project
|
627
|
+
and not self.api_key
|
628
|
+
and not validated_http_options.base_url
|
629
|
+
):
|
537
630
|
credentials, self.project = _load_auth(project=None)
|
538
631
|
if not self._credentials:
|
539
632
|
self._credentials = credentials
|
540
|
-
|
633
|
+
|
634
|
+
has_sufficient_auth = (self.project and self.location) or self.api_key
|
635
|
+
|
636
|
+
if (not has_sufficient_auth and not validated_http_options.base_url):
|
637
|
+
# Skip sufficient auth check if base url is provided in http options.
|
541
638
|
raise ValueError(
|
542
639
|
'Project and location or API key must be set when using the Vertex '
|
543
640
|
'AI API.'
|
544
641
|
)
|
545
642
|
if self.api_key or self.location == 'global':
|
546
643
|
self._http_options.base_url = f'https://aiplatform.googleapis.com/'
|
644
|
+
elif validated_http_options.base_url and not has_sufficient_auth:
|
645
|
+
# Avoid setting default base url and api version if base_url provided.
|
646
|
+
self._http_options.base_url = validated_http_options.base_url
|
547
647
|
else:
|
548
648
|
self._http_options.base_url = (
|
549
649
|
f'https://{self.location}-aiplatform.googleapis.com/'
|
@@ -592,7 +692,7 @@ class BaseApiClient:
|
|
592
692
|
@staticmethod
|
593
693
|
def _ensure_httpx_ssl_ctx(
|
594
694
|
options: HttpOptions,
|
595
|
-
) -> Tuple[
|
695
|
+
) -> Tuple[_common.StringDict, _common.StringDict]:
|
596
696
|
"""Ensures the SSL context is present in the HTTPX client args.
|
597
697
|
|
598
698
|
Creates a default SSL context if one is not provided.
|
@@ -626,9 +726,9 @@ class BaseApiClient:
|
|
626
726
|
)
|
627
727
|
|
628
728
|
def _maybe_set(
|
629
|
-
args: Optional[
|
729
|
+
args: Optional[_common.StringDict],
|
630
730
|
ctx: ssl.SSLContext,
|
631
|
-
) ->
|
731
|
+
) -> _common.StringDict:
|
632
732
|
"""Sets the SSL context in the client args if not set.
|
633
733
|
|
634
734
|
Does not override the SSL context if it is already set.
|
@@ -656,7 +756,7 @@ class BaseApiClient:
|
|
656
756
|
)
|
657
757
|
|
658
758
|
@staticmethod
|
659
|
-
def _ensure_aiohttp_ssl_ctx(options: HttpOptions) ->
|
759
|
+
def _ensure_aiohttp_ssl_ctx(options: HttpOptions) -> _common.StringDict:
|
660
760
|
"""Ensures the SSL context is present in the async client args.
|
661
761
|
|
662
762
|
Creates a default SSL context if one is not provided.
|
@@ -684,9 +784,9 @@ class BaseApiClient:
|
|
684
784
|
)
|
685
785
|
|
686
786
|
def _maybe_set(
|
687
|
-
args: Optional[
|
787
|
+
args: Optional[_common.StringDict],
|
688
788
|
ctx: ssl.SSLContext,
|
689
|
-
) ->
|
789
|
+
) -> _common.StringDict:
|
690
790
|
"""Sets the SSL context in the client args if not set.
|
691
791
|
|
692
792
|
Does not override the SSL context if it is already set.
|
@@ -714,7 +814,7 @@ class BaseApiClient:
|
|
714
814
|
return _maybe_set(async_args, ctx)
|
715
815
|
|
716
816
|
@staticmethod
|
717
|
-
def _ensure_websocket_ssl_ctx(options: HttpOptions) ->
|
817
|
+
def _ensure_websocket_ssl_ctx(options: HttpOptions) -> _common.StringDict:
|
718
818
|
"""Ensures the SSL context is present in the async client args.
|
719
819
|
|
720
820
|
Creates a default SSL context if one is not provided.
|
@@ -742,9 +842,9 @@ class BaseApiClient:
|
|
742
842
|
)
|
743
843
|
|
744
844
|
def _maybe_set(
|
745
|
-
args: Optional[
|
845
|
+
args: Optional[_common.StringDict],
|
746
846
|
ctx: ssl.SSLContext,
|
747
|
-
) ->
|
847
|
+
) -> _common.StringDict:
|
748
848
|
"""Sets the SSL context in the client args if not set.
|
749
849
|
|
750
850
|
Does not override the SSL context if it is already set.
|
@@ -864,7 +964,7 @@ class BaseApiClient:
|
|
864
964
|
self.vertexai
|
865
965
|
and not path.startswith('projects/')
|
866
966
|
and not query_vertex_base_models
|
867
|
-
and
|
967
|
+
and (self.project or self.location)
|
868
968
|
):
|
869
969
|
path = f'projects/{self.project}/locations/{self.location}/' + path
|
870
970
|
|
@@ -920,7 +1020,8 @@ class BaseApiClient:
|
|
920
1020
|
stream: bool = False,
|
921
1021
|
) -> HttpResponse:
|
922
1022
|
data: Optional[Union[str, bytes]] = None
|
923
|
-
|
1023
|
+
# If using proj/location, fetch ADC
|
1024
|
+
if self.vertexai and (self.project or self.location):
|
924
1025
|
http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
|
925
1026
|
if self._credentials and self._credentials.quota_project_id:
|
926
1027
|
http_request.headers['x-goog-user-project'] = (
|
@@ -963,8 +1064,21 @@ class BaseApiClient:
|
|
963
1064
|
def _request(
|
964
1065
|
self,
|
965
1066
|
http_request: HttpRequest,
|
1067
|
+
http_options: Optional[HttpOptionsOrDict] = None,
|
966
1068
|
stream: bool = False,
|
967
1069
|
) -> HttpResponse:
|
1070
|
+
if http_options:
|
1071
|
+
parameter_model = (
|
1072
|
+
HttpOptions(**http_options)
|
1073
|
+
if isinstance(http_options, dict)
|
1074
|
+
else http_options
|
1075
|
+
)
|
1076
|
+
# Support per request retry options.
|
1077
|
+
if parameter_model.retry_options:
|
1078
|
+
retry_kwargs = _retry_args(parameter_model.retry_options)
|
1079
|
+
retry = tenacity.Retrying(**retry_kwargs)
|
1080
|
+
return retry(self._request_once, http_request, stream) # type: ignore[no-any-return]
|
1081
|
+
|
968
1082
|
return self._retry(self._request_once, http_request, stream) # type: ignore[no-any-return]
|
969
1083
|
|
970
1084
|
async def _async_request_once(
|
@@ -972,7 +1086,8 @@ class BaseApiClient:
|
|
972
1086
|
) -> HttpResponse:
|
973
1087
|
data: Optional[Union[str, bytes]] = None
|
974
1088
|
|
975
|
-
|
1089
|
+
# If using proj/location, fetch ADC
|
1090
|
+
if self.vertexai and (self.project or self.location):
|
976
1091
|
http_request.headers['Authorization'] = (
|
977
1092
|
f'Bearer {await self._async_access_token()}'
|
978
1093
|
)
|
@@ -993,15 +1108,43 @@ class BaseApiClient:
|
|
993
1108
|
session = aiohttp.ClientSession(
|
994
1109
|
headers=http_request.headers,
|
995
1110
|
trust_env=True,
|
1111
|
+
read_bufsize=READ_BUFFER_SIZE,
|
996
1112
|
)
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1003
|
-
|
1004
|
-
|
1113
|
+
try:
|
1114
|
+
response = await session.request(
|
1115
|
+
method=http_request.method,
|
1116
|
+
url=http_request.url,
|
1117
|
+
headers=http_request.headers,
|
1118
|
+
data=data,
|
1119
|
+
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
|
1120
|
+
**self._async_client_session_request_args,
|
1121
|
+
)
|
1122
|
+
except (
|
1123
|
+
aiohttp.ClientConnectorError,
|
1124
|
+
aiohttp.ClientConnectorDNSError,
|
1125
|
+
aiohttp.ClientOSError,
|
1126
|
+
aiohttp.ServerDisconnectedError,
|
1127
|
+
) as e:
|
1128
|
+
await asyncio.sleep(1 + random.randint(0, 9))
|
1129
|
+
logger.info('Retrying due to aiohttp error: %s' % e)
|
1130
|
+
# Retrieve the SSL context from the session.
|
1131
|
+
self._async_client_session_request_args = (
|
1132
|
+
self._ensure_aiohttp_ssl_ctx(self._http_options)
|
1133
|
+
)
|
1134
|
+
# Instantiate a new session with the updated SSL context.
|
1135
|
+
session = aiohttp.ClientSession(
|
1136
|
+
headers=http_request.headers,
|
1137
|
+
trust_env=True,
|
1138
|
+
read_bufsize=READ_BUFFER_SIZE,
|
1139
|
+
)
|
1140
|
+
response = await session.request(
|
1141
|
+
method=http_request.method,
|
1142
|
+
url=http_request.url,
|
1143
|
+
headers=http_request.headers,
|
1144
|
+
data=data,
|
1145
|
+
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
|
1146
|
+
**self._async_client_session_request_args,
|
1147
|
+
)
|
1005
1148
|
|
1006
1149
|
await errors.APIError.raise_for_async_response(response)
|
1007
1150
|
return HttpResponse(response.headers, response, session=session)
|
@@ -1022,20 +1165,50 @@ class BaseApiClient:
|
|
1022
1165
|
return HttpResponse(client_response.headers, client_response)
|
1023
1166
|
else:
|
1024
1167
|
if self._use_aiohttp():
|
1025
|
-
|
1026
|
-
|
1027
|
-
trust_env=True,
|
1028
|
-
) as session:
|
1029
|
-
response = await session.request(
|
1030
|
-
method=http_request.method,
|
1031
|
-
url=http_request.url,
|
1168
|
+
try:
|
1169
|
+
async with aiohttp.ClientSession(
|
1032
1170
|
headers=http_request.headers,
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1171
|
+
trust_env=True,
|
1172
|
+
read_bufsize=READ_BUFFER_SIZE,
|
1173
|
+
) as session:
|
1174
|
+
response = await session.request(
|
1175
|
+
method=http_request.method,
|
1176
|
+
url=http_request.url,
|
1177
|
+
headers=http_request.headers,
|
1178
|
+
data=data,
|
1179
|
+
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
|
1180
|
+
**self._async_client_session_request_args,
|
1181
|
+
)
|
1182
|
+
await errors.APIError.raise_for_async_response(response)
|
1183
|
+
return HttpResponse(response.headers, [await response.text()])
|
1184
|
+
except (
|
1185
|
+
aiohttp.ClientConnectorError,
|
1186
|
+
aiohttp.ClientConnectorDNSError,
|
1187
|
+
aiohttp.ClientOSError,
|
1188
|
+
aiohttp.ServerDisconnectedError,
|
1189
|
+
) as e:
|
1190
|
+
await asyncio.sleep(1 + random.randint(0, 9))
|
1191
|
+
logger.info('Retrying due to aiohttp error: %s' % e)
|
1192
|
+
# Retrieve the SSL context from the session.
|
1193
|
+
self._async_client_session_request_args = (
|
1194
|
+
self._ensure_aiohttp_ssl_ctx(self._http_options)
|
1036
1195
|
)
|
1037
|
-
|
1038
|
-
|
1196
|
+
# Instantiate a new session with the updated SSL context.
|
1197
|
+
async with aiohttp.ClientSession(
|
1198
|
+
headers=http_request.headers,
|
1199
|
+
trust_env=True,
|
1200
|
+
read_bufsize=READ_BUFFER_SIZE,
|
1201
|
+
) as session:
|
1202
|
+
response = await session.request(
|
1203
|
+
method=http_request.method,
|
1204
|
+
url=http_request.url,
|
1205
|
+
headers=http_request.headers,
|
1206
|
+
data=data,
|
1207
|
+
timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
|
1208
|
+
**self._async_client_session_request_args,
|
1209
|
+
)
|
1210
|
+
await errors.APIError.raise_for_async_response(response)
|
1211
|
+
return HttpResponse(response.headers, [await response.text()])
|
1039
1212
|
else:
|
1040
1213
|
# aiohttp is not available. Fall back to httpx.
|
1041
1214
|
client_response = await self._async_httpx_client.request(
|
@@ -1051,13 +1224,25 @@ class BaseApiClient:
|
|
1051
1224
|
async def _async_request(
|
1052
1225
|
self,
|
1053
1226
|
http_request: HttpRequest,
|
1227
|
+
http_options: Optional[HttpOptionsOrDict] = None,
|
1054
1228
|
stream: bool = False,
|
1055
1229
|
) -> HttpResponse:
|
1230
|
+
if http_options:
|
1231
|
+
parameter_model = (
|
1232
|
+
HttpOptions(**http_options)
|
1233
|
+
if isinstance(http_options, dict)
|
1234
|
+
else http_options
|
1235
|
+
)
|
1236
|
+
# Support per request retry options.
|
1237
|
+
if parameter_model.retry_options:
|
1238
|
+
retry_kwargs = _retry_args(parameter_model.retry_options)
|
1239
|
+
retry = tenacity.AsyncRetrying(**retry_kwargs)
|
1240
|
+
return await retry(self._async_request_once, http_request, stream) # type: ignore[no-any-return]
|
1056
1241
|
return await self._async_retry( # type: ignore[no-any-return]
|
1057
1242
|
self._async_request_once, http_request, stream
|
1058
1243
|
)
|
1059
1244
|
|
1060
|
-
def get_read_only_http_options(self) ->
|
1245
|
+
def get_read_only_http_options(self) -> _common.StringDict:
|
1061
1246
|
if isinstance(self._http_options, BaseModel):
|
1062
1247
|
copied = self._http_options.model_dump()
|
1063
1248
|
else:
|
@@ -1074,7 +1259,7 @@ class BaseApiClient:
|
|
1074
1259
|
http_request = self._build_request(
|
1075
1260
|
http_method, path, request_dict, http_options
|
1076
1261
|
)
|
1077
|
-
response = self._request(http_request, stream=False)
|
1262
|
+
response = self._request(http_request, http_options, stream=False)
|
1078
1263
|
response_body = (
|
1079
1264
|
response.response_stream[0] if response.response_stream else ''
|
1080
1265
|
)
|
@@ -1091,7 +1276,7 @@ class BaseApiClient:
|
|
1091
1276
|
http_method, path, request_dict, http_options
|
1092
1277
|
)
|
1093
1278
|
|
1094
|
-
session_response = self._request(http_request, stream=True)
|
1279
|
+
session_response = self._request(http_request, http_options, stream=True)
|
1095
1280
|
for chunk in session_response.segments():
|
1096
1281
|
yield SdkHttpResponse(
|
1097
1282
|
headers=session_response.headers, body=json.dumps(chunk)
|
@@ -1108,7 +1293,9 @@ class BaseApiClient:
|
|
1108
1293
|
http_method, path, request_dict, http_options
|
1109
1294
|
)
|
1110
1295
|
|
1111
|
-
result = await self._async_request(
|
1296
|
+
result = await self._async_request(
|
1297
|
+
http_request=http_request, http_options=http_options, stream=False
|
1298
|
+
)
|
1112
1299
|
response_body = result.response_stream[0] if result.response_stream else ''
|
1113
1300
|
return SdkHttpResponse(headers=result.headers, body=response_body)
|
1114
1301
|
|
@@ -1340,6 +1527,7 @@ class BaseApiClient:
|
|
1340
1527
|
async with aiohttp.ClientSession(
|
1341
1528
|
headers=self._http_options.headers,
|
1342
1529
|
trust_env=True,
|
1530
|
+
read_bufsize=READ_BUFFER_SIZE,
|
1343
1531
|
) as session:
|
1344
1532
|
while True:
|
1345
1533
|
if isinstance(file, io.IOBase):
|
@@ -1523,6 +1711,7 @@ class BaseApiClient:
|
|
1523
1711
|
async with aiohttp.ClientSession(
|
1524
1712
|
headers=http_request.headers,
|
1525
1713
|
trust_env=True,
|
1714
|
+
read_bufsize=READ_BUFFER_SIZE,
|
1526
1715
|
) as session:
|
1527
1716
|
response = await session.request(
|
1528
1717
|
method=http_request.method,
|
@@ -41,6 +41,39 @@ _py_builtin_type_to_schema_type = {
|
|
41
41
|
}
|
42
42
|
|
43
43
|
|
44
|
+
def _raise_for_unsupported_param(
|
45
|
+
param: inspect.Parameter, func_name: str, exception: Union[Exception, type[Exception]]
|
46
|
+
) -> None:
|
47
|
+
raise ValueError(
|
48
|
+
f'Failed to parse the parameter {param} of function {func_name} for'
|
49
|
+
' automatic function calling.Automatic function calling works best with'
|
50
|
+
' simpler function signature schema, consider manually parsing your'
|
51
|
+
f' function declaration for function {func_name}.'
|
52
|
+
) from exception
|
53
|
+
|
54
|
+
|
55
|
+
def _handle_params_as_deferred_annotations(param: inspect.Parameter, annotation_under_future: dict[str, Any], name: str) -> inspect.Parameter:
|
56
|
+
"""Catches the case when type hints are stored as strings."""
|
57
|
+
if isinstance(param.annotation, str):
|
58
|
+
param = param.replace(annotation=annotation_under_future[name])
|
59
|
+
return param
|
60
|
+
|
61
|
+
|
62
|
+
def _add_unevaluated_items_to_fixed_len_tuple_schema(
|
63
|
+
json_schema: dict[str, Any]
|
64
|
+
) -> dict[str, Any]:
|
65
|
+
if (
|
66
|
+
json_schema.get('maxItems')
|
67
|
+
and (
|
68
|
+
json_schema.get('prefixItems')
|
69
|
+
and len(json_schema['prefixItems']) == json_schema['maxItems']
|
70
|
+
)
|
71
|
+
and json_schema.get('type') == 'array'
|
72
|
+
):
|
73
|
+
json_schema['unevaluatedItems'] = False
|
74
|
+
return json_schema
|
75
|
+
|
76
|
+
|
44
77
|
def _is_builtin_primitive_or_compound(
|
45
78
|
annotation: inspect.Parameter.annotation, # type: ignore[valid-type]
|
46
79
|
) -> bool:
|
@@ -92,7 +125,7 @@ def _is_default_value_compatible(
|
|
92
125
|
return False
|
93
126
|
|
94
127
|
|
95
|
-
def _parse_schema_from_parameter(
|
128
|
+
def _parse_schema_from_parameter( # type: ignore[return]
|
96
129
|
api_option: Literal['VERTEX_AI', 'GEMINI_API'],
|
97
130
|
param: inspect.Parameter,
|
98
131
|
func_name: str,
|
@@ -267,12 +300,7 @@ def _parse_schema_from_parameter(
|
|
267
300
|
)
|
268
301
|
schema.required = _get_required_fields(schema)
|
269
302
|
return schema
|
270
|
-
|
271
|
-
f'Failed to parse the parameter {param} of function {func_name} for'
|
272
|
-
' automatic function calling.Automatic function calling works best with'
|
273
|
-
' simpler function signature schema, consider manually parsing your'
|
274
|
-
f' function declaration for function {func_name}.'
|
275
|
-
)
|
303
|
+
_raise_for_unsupported_param(param, func_name, ValueError)
|
276
304
|
|
277
305
|
|
278
306
|
def _get_required_fields(schema: types.Schema) -> Optional[list[str]]:
|