google-genai 1.7.0__py3-none-any.whl → 1.53.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- google/genai/__init__.py +4 -2
- google/genai/_adapters.py +55 -0
- google/genai/_api_client.py +1301 -299
- google/genai/_api_module.py +1 -1
- google/genai/_automatic_function_calling_util.py +54 -33
- google/genai/_base_transformers.py +26 -0
- google/genai/_base_url.py +50 -0
- google/genai/_common.py +560 -59
- google/genai/_extra_utils.py +371 -38
- google/genai/_live_converters.py +1467 -0
- google/genai/_local_tokenizer_loader.py +214 -0
- google/genai/_mcp_utils.py +117 -0
- google/genai/_operations_converters.py +394 -0
- google/genai/_replay_api_client.py +204 -92
- google/genai/_test_api_client.py +1 -1
- google/genai/_tokens_converters.py +520 -0
- google/genai/_transformers.py +633 -233
- google/genai/batches.py +1733 -538
- google/genai/caches.py +678 -1012
- google/genai/chats.py +48 -38
- google/genai/client.py +142 -15
- google/genai/documents.py +532 -0
- google/genai/errors.py +141 -35
- google/genai/file_search_stores.py +1296 -0
- google/genai/files.py +312 -744
- google/genai/live.py +617 -367
- google/genai/live_music.py +197 -0
- google/genai/local_tokenizer.py +395 -0
- google/genai/models.py +3598 -3116
- google/genai/operations.py +201 -362
- google/genai/pagers.py +23 -7
- google/genai/py.typed +1 -0
- google/genai/tokens.py +362 -0
- google/genai/tunings.py +1274 -496
- google/genai/types.py +14535 -5454
- google/genai/version.py +2 -2
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/METADATA +736 -234
- google_genai-1.53.0.dist-info/RECORD +41 -0
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/WHEEL +1 -1
- google_genai-1.7.0.dist-info/RECORD +0 -27
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info/licenses}/LICENSE +0 -0
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/top_level.txt +0 -0
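
Among the changes to google/genai/_api_client.py shown below, the client gains configurable HTTP retries (driven by tenacity), an optional aiohttp async transport, explicit SSL-context handling, and custom base-URL support. As a minimal sketch of the new retry surface — assuming the public genai.Client and types module re-export the HttpOptions/HttpRetryOptions models that appear in this diff — opting in could look like this:

# Sketch only: the field names mirror the retry defaults visible in the
# _api_client.py diff below; the api_key value is a hypothetical placeholder.
from google import genai
from google.genai import types

client = genai.Client(
    api_key='YOUR_API_KEY',  # placeholder
    http_options=types.HttpOptions(
        retry_options=types.HttpRetryOptions(
            attempts=5,  # total tries, including the initial call
            initial_delay=1.0,  # seconds before the first retry
            max_delay=60.0,  # cap on the exponential backoff
            exp_base=2,
            jitter=1,
            http_status_codes=[408, 429, 500, 502, 503, 504],
        )
    ),
)

When retry_options is left unset, the diff below shows the client falling back to a "never retry" strategy (a single attempt that re-raises).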
google/genai/_api_client.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -19,34 +19,97 @@
|
|
|
19
19
|
The BaseApiClient is intended to be a private module and is subject to change.
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
|
-
import anyio
|
|
23
22
|
import asyncio
|
|
23
|
+
from collections.abc import Generator
|
|
24
24
|
import copy
|
|
25
25
|
from dataclasses import dataclass
|
|
26
|
-
import
|
|
26
|
+
import inspect
|
|
27
27
|
import io
|
|
28
28
|
import json
|
|
29
29
|
import logging
|
|
30
|
+
import math
|
|
30
31
|
import os
|
|
32
|
+
import random
|
|
33
|
+
import ssl
|
|
31
34
|
import sys
|
|
32
|
-
|
|
33
|
-
|
|
35
|
+
import threading
|
|
36
|
+
import time
|
|
37
|
+
from typing import Any, AsyncIterator, Iterator, Optional, Tuple, TYPE_CHECKING, Union
|
|
38
|
+
from urllib.parse import urlparse
|
|
39
|
+
from urllib.parse import urlunparse
|
|
40
|
+
import warnings
|
|
41
|
+
|
|
42
|
+
import anyio
|
|
43
|
+
import certifi
|
|
34
44
|
import google.auth
|
|
35
45
|
import google.auth.credentials
|
|
36
46
|
from google.auth.credentials import Credentials
|
|
37
|
-
from google.auth.transport.requests import Request
|
|
38
47
|
import httpx
|
|
39
|
-
from pydantic import BaseModel
|
|
48
|
+
from pydantic import BaseModel
|
|
49
|
+
from pydantic import ValidationError
|
|
50
|
+
import tenacity
|
|
51
|
+
|
|
40
52
|
from . import _common
|
|
41
53
|
from . import errors
|
|
42
54
|
from . import version
|
|
43
|
-
from .types import HttpOptions
|
|
55
|
+
from .types import HttpOptions
|
|
56
|
+
from .types import HttpOptionsOrDict
|
|
57
|
+
from .types import HttpResponse as SdkHttpResponse
|
|
58
|
+
from .types import HttpRetryOptions
|
|
59
|
+
from .types import ResourceScope
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
try:
|
|
63
|
+
from websockets.asyncio.client import connect as ws_connect
|
|
64
|
+
except ModuleNotFoundError:
|
|
65
|
+
# This try/except is for TAP, mypy complains about it which is why we have the type: ignore
|
|
66
|
+
from websockets.client import connect as ws_connect # type: ignore
|
|
67
|
+
|
|
68
|
+
has_aiohttp = False
|
|
69
|
+
try:
|
|
70
|
+
import aiohttp
|
|
71
|
+
|
|
72
|
+
has_aiohttp = True
|
|
73
|
+
except ImportError:
|
|
74
|
+
pass
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
if TYPE_CHECKING:
|
|
78
|
+
from multidict import CIMultiDictProxy
|
|
79
|
+
|
|
44
80
|
|
|
45
81
|
logger = logging.getLogger('google_genai._api_client')
|
|
46
82
|
CHUNK_SIZE = 8 * 1024 * 1024 # 8 MB chunk size
|
|
83
|
+
READ_BUFFER_SIZE = 2**22
|
|
84
|
+
MAX_RETRY_COUNT = 3
|
|
85
|
+
INITIAL_RETRY_DELAY = 1 # second
|
|
86
|
+
DELAY_MULTIPLIER = 2
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class EphemeralTokenAPIKeyError(ValueError):
|
|
90
|
+
"""Error raised when the API key is invalid."""
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# This method checks for the API key in the environment variables. Google API
|
|
94
|
+
# key is precedenced over Gemini API key.
|
|
95
|
+
def get_env_api_key() -> Optional[str]:
|
|
96
|
+
"""Gets the API key from environment variables, prioritizing GOOGLE_API_KEY.
|
|
97
|
+
|
|
98
|
+
Returns:
|
|
99
|
+
The API key string if found, otherwise None. Empty string is considered
|
|
100
|
+
invalid.
|
|
101
|
+
"""
|
|
102
|
+
env_google_api_key = os.environ.get('GOOGLE_API_KEY', None)
|
|
103
|
+
env_gemini_api_key = os.environ.get('GEMINI_API_KEY', None)
|
|
104
|
+
if env_google_api_key and env_gemini_api_key:
|
|
105
|
+
logger.warning(
|
|
106
|
+
'Both GOOGLE_API_KEY and GEMINI_API_KEY are set. Using GOOGLE_API_KEY.'
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
return env_google_api_key or env_gemini_api_key or None
|
|
47
110
|
|
|
48
111
|
|
|
49
|
-
def
|
|
112
|
+
def append_library_version_headers(headers: dict[str, str]) -> None:
|
|
50
113
|
"""Appends the telemetry header to the headers dict."""
|
|
51
114
|
library_label = f'google-genai-sdk/{version.__version__}'
|
|
52
115
|
language_label = 'gl-python/' + sys.version.split()[0]
|
|
@@ -55,43 +118,57 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
|
|
|
55
118
|
'user-agent' in headers
|
|
56
119
|
and version_header_value not in headers['user-agent']
|
|
57
120
|
):
|
|
58
|
-
headers['user-agent']
|
|
121
|
+
headers['user-agent'] = f'{version_header_value} ' + headers['user-agent']
|
|
59
122
|
elif 'user-agent' not in headers:
|
|
60
123
|
headers['user-agent'] = version_header_value
|
|
61
124
|
if (
|
|
62
125
|
'x-goog-api-client' in headers
|
|
63
126
|
and version_header_value not in headers['x-goog-api-client']
|
|
64
127
|
):
|
|
65
|
-
headers['x-goog-api-client']
|
|
128
|
+
headers['x-goog-api-client'] = (
|
|
129
|
+
f'{version_header_value} ' + headers['x-goog-api-client']
|
|
130
|
+
)
|
|
66
131
|
elif 'x-goog-api-client' not in headers:
|
|
67
132
|
headers['x-goog-api-client'] = version_header_value
|
|
68
133
|
|
|
69
134
|
|
|
70
|
-
def
|
|
71
|
-
options:
|
|
72
|
-
) ->
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
copy_option.
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
135
|
+
def patch_http_options(
|
|
136
|
+
options: HttpOptions, patch_options: HttpOptions
|
|
137
|
+
) -> HttpOptions:
|
|
138
|
+
copy_option = options.model_copy()
|
|
139
|
+
|
|
140
|
+
options_headers = copy_option.headers or {}
|
|
141
|
+
patch_options_headers = patch_options.headers or {}
|
|
142
|
+
copy_option.headers = {
|
|
143
|
+
**options_headers,
|
|
144
|
+
**patch_options_headers,
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
http_options_keys = HttpOptions.model_fields.keys()
|
|
148
|
+
|
|
149
|
+
for key in http_options_keys:
|
|
150
|
+
if key == 'headers':
|
|
151
|
+
continue
|
|
152
|
+
patch_value = getattr(patch_options, key, None)
|
|
153
|
+
if patch_value is not None:
|
|
154
|
+
setattr(copy_option, key, patch_value)
|
|
155
|
+
else:
|
|
156
|
+
setattr(copy_option, key, getattr(options, key))
|
|
157
|
+
|
|
158
|
+
if copy_option.headers is not None:
|
|
159
|
+
append_library_version_headers(copy_option.headers)
|
|
91
160
|
return copy_option
|
|
92
161
|
|
|
93
162
|
|
|
94
|
-
def
|
|
163
|
+
def populate_server_timeout_header(
|
|
164
|
+
headers: dict[str, str], timeout_in_seconds: Optional[Union[float, int]]
|
|
165
|
+
) -> None:
|
|
166
|
+
"""Populates the server timeout header in the headers dict."""
|
|
167
|
+
if timeout_in_seconds and 'X-Server-Timeout' not in headers:
|
|
168
|
+
headers['X-Server-Timeout'] = str(math.ceil(timeout_in_seconds))
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def join_url_path(base_url: str, path: str) -> str:
|
|
95
172
|
parsed_base = urlparse(base_url)
|
|
96
173
|
base_path = (
|
|
97
174
|
parsed_base.path[:-1]
|
|
@@ -102,9 +179,9 @@ def _join_url_path(base_url: str, path: str) -> str:
|
|
|
102
179
|
return urlunparse(parsed_base._replace(path=base_path + '/' + path))
|
|
103
180
|
|
|
104
181
|
|
|
105
|
-
def
|
|
182
|
+
def load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
|
|
106
183
|
"""Loads google auth credentials and project id."""
|
|
107
|
-
credentials, loaded_project_id = google.auth.default(
|
|
184
|
+
credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call]
|
|
108
185
|
scopes=['https://www.googleapis.com/auth/cloud-platform'],
|
|
109
186
|
)
|
|
110
187
|
|
|
@@ -119,11 +196,25 @@ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
|
|
|
119
196
|
return credentials, project
|
|
120
197
|
|
|
121
198
|
|
|
122
|
-
def
|
|
123
|
-
|
|
199
|
+
def refresh_auth(credentials: Credentials) -> Credentials:
|
|
200
|
+
from google.auth.transport.requests import Request
|
|
201
|
+
credentials.refresh(Request()) # type: ignore[no-untyped-call]
|
|
124
202
|
return credentials
|
|
125
203
|
|
|
126
204
|
|
|
205
|
+
def get_timeout_in_seconds(
|
|
206
|
+
timeout: Optional[Union[float, int]],
|
|
207
|
+
) -> Optional[float]:
|
|
208
|
+
"""Converts the timeout to seconds."""
|
|
209
|
+
if timeout:
|
|
210
|
+
# HttpOptions.timeout is in milliseconds. But httpx.Client.request()
|
|
211
|
+
# expects seconds.
|
|
212
|
+
timeout_in_seconds = timeout / 1000.0
|
|
213
|
+
else:
|
|
214
|
+
timeout_in_seconds = None
|
|
215
|
+
return timeout_in_seconds
|
|
216
|
+
|
|
217
|
+
|
|
127
218
|
@dataclass
|
|
128
219
|
class HttpRequest:
|
|
129
220
|
headers: dict[str, str]
|
|
@@ -133,37 +224,35 @@ class HttpRequest:
|
|
|
133
224
|
timeout: Optional[float] = None
|
|
134
225
|
|
|
135
226
|
|
|
136
|
-
# TODO(b/394358912): Update this class to use a SDKResponse class that can be
|
|
137
|
-
# generated and used for all languages.
|
|
138
|
-
class BaseResponse(_common.BaseModel):
|
|
139
|
-
http_headers: Optional[dict[str, str]] = Field(
|
|
140
|
-
default=None, description='The http headers of the response.'
|
|
141
|
-
)
|
|
142
|
-
|
|
143
|
-
json_payload: Optional[Any] = Field(
|
|
144
|
-
default=None, description='The json payload of the response.'
|
|
145
|
-
)
|
|
146
|
-
|
|
147
|
-
|
|
148
227
|
class HttpResponse:
|
|
149
228
|
|
|
150
229
|
def __init__(
|
|
151
230
|
self,
|
|
152
|
-
headers: Union[dict[str, str], httpx.Headers],
|
|
231
|
+
headers: Union[dict[str, str], httpx.Headers, 'CIMultiDictProxy[str]'],
|
|
153
232
|
response_stream: Union[Any, str] = None,
|
|
154
233
|
byte_stream: Union[Any, bytes] = None,
|
|
155
234
|
):
|
|
156
|
-
|
|
157
|
-
|
|
235
|
+
if isinstance(headers, dict):
|
|
236
|
+
self.headers = headers
|
|
237
|
+
elif isinstance(headers, httpx.Headers):
|
|
238
|
+
self.headers = {
|
|
239
|
+
key: ', '.join(headers.get_list(key)) for key in headers.keys()
|
|
240
|
+
}
|
|
241
|
+
elif type(headers).__name__ == 'CIMultiDictProxy':
|
|
242
|
+
self.headers = {
|
|
243
|
+
key: ', '.join(headers.getall(key)) for key in headers.keys()
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
self.status_code: int = 200
|
|
158
247
|
self.response_stream = response_stream
|
|
159
248
|
self.byte_stream = byte_stream
|
|
160
249
|
|
|
161
250
|
# Async iterator for async streaming.
|
|
162
|
-
def __aiter__(self):
|
|
251
|
+
def __aiter__(self) -> 'HttpResponse':
|
|
163
252
|
self.segment_iterator = self.async_segments()
|
|
164
253
|
return self
|
|
165
254
|
|
|
166
|
-
async def __anext__(self):
|
|
255
|
+
async def __anext__(self) -> Any:
|
|
167
256
|
try:
|
|
168
257
|
return await self.segment_iterator.__anext__()
|
|
169
258
|
except StopIteration:
|
|
@@ -173,54 +262,34 @@ class HttpResponse:
|
|
|
173
262
|
def json(self) -> Any:
|
|
174
263
|
if not self.response_stream[0]: # Empty response
|
|
175
264
|
return ''
|
|
176
|
-
return
|
|
265
|
+
return self._load_json_from_response(self.response_stream[0])
|
|
177
266
|
|
|
178
|
-
def segments(self):
|
|
267
|
+
def segments(self) -> Generator[Any, None, None]:
|
|
179
268
|
if isinstance(self.response_stream, list):
|
|
180
269
|
# list of objects retrieved from replay or from non-streaming API.
|
|
181
270
|
for chunk in self.response_stream:
|
|
182
|
-
yield
|
|
271
|
+
yield self._load_json_from_response(chunk) if chunk else {}
|
|
183
272
|
elif self.response_stream is None:
|
|
184
273
|
yield from []
|
|
185
274
|
else:
|
|
186
275
|
# Iterator of objects retrieved from the API.
|
|
187
|
-
for chunk in self.
|
|
188
|
-
|
|
189
|
-
# In streaming mode, the chunk of JSON is prefixed with "data:" which
|
|
190
|
-
# we must strip before parsing.
|
|
191
|
-
if not isinstance(chunk, str):
|
|
192
|
-
chunk = chunk.decode('utf-8')
|
|
193
|
-
if chunk.startswith('data: '):
|
|
194
|
-
chunk = chunk[len('data: ') :]
|
|
195
|
-
yield json.loads(chunk)
|
|
276
|
+
for chunk in self._iter_response_stream():
|
|
277
|
+
yield self._load_json_from_response(chunk)
|
|
196
278
|
|
|
197
279
|
async def async_segments(self) -> AsyncIterator[Any]:
|
|
198
280
|
if isinstance(self.response_stream, list):
|
|
199
281
|
# list of objects retrieved from replay or from non-streaming API.
|
|
200
282
|
for chunk in self.response_stream:
|
|
201
|
-
yield
|
|
283
|
+
yield self._load_json_from_response(chunk) if chunk else {}
|
|
202
284
|
elif self.response_stream is None:
|
|
203
|
-
async for c in []:
|
|
285
|
+
async for c in []: # type: ignore[attr-defined]
|
|
204
286
|
yield c
|
|
205
287
|
else:
|
|
206
288
|
# Iterator of objects retrieved from the API.
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
# This is httpx.Response.
|
|
210
|
-
if chunk:
|
|
211
|
-
# In async streaming mode, the chunk of JSON is prefixed with
|
|
212
|
-
# "data:" which we must strip before parsing.
|
|
213
|
-
if not isinstance(chunk, str):
|
|
214
|
-
chunk = chunk.decode('utf-8')
|
|
215
|
-
if chunk.startswith('data: '):
|
|
216
|
-
chunk = chunk[len('data: ') :]
|
|
217
|
-
yield json.loads(chunk)
|
|
218
|
-
else:
|
|
219
|
-
raise ValueError(
|
|
220
|
-
'Error parsing streaming response.'
|
|
221
|
-
)
|
|
289
|
+
async for chunk in self._aiter_response_stream():
|
|
290
|
+
yield self._load_json_from_response(chunk)
|
|
222
291
|
|
|
223
|
-
def byte_segments(self):
|
|
292
|
+
def byte_segments(self) -> Generator[Union[bytes, Any], None, None]:
|
|
224
293
|
if isinstance(self.byte_stream, list):
|
|
225
294
|
# list of objects retrieved from replay or from non-streaming API.
|
|
226
295
|
yield from self.byte_stream
|
|
@@ -231,12 +300,199 @@ class HttpResponse:
|
|
|
231
300
|
'Byte segments are not supported for streaming responses.'
|
|
232
301
|
)
|
|
233
302
|
|
|
234
|
-
def _copy_to_dict(self, response_payload: dict[str, object]):
|
|
303
|
+
def _copy_to_dict(self, response_payload: dict[str, object]) -> None:
|
|
235
304
|
# Cannot pickle 'generator' object.
|
|
236
305
|
delattr(self, 'segment_iterator')
|
|
237
306
|
for attribute in dir(self):
|
|
238
307
|
response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
|
|
239
308
|
|
|
309
|
+
def _iter_response_stream(self) -> Iterator[str]:
|
|
310
|
+
"""Iterates over chunks retrieved from the API."""
|
|
311
|
+
if not isinstance(self.response_stream, httpx.Response):
|
|
312
|
+
raise TypeError(
|
|
313
|
+
'Expected self.response_stream to be an httpx.Response object, '
|
|
314
|
+
f'but got {type(self.response_stream).__name__}.'
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
chunk = ''
|
|
318
|
+
balance = 0
|
|
319
|
+
for line in self.response_stream.iter_lines():
|
|
320
|
+
if not line:
|
|
321
|
+
continue
|
|
322
|
+
|
|
323
|
+
# In streaming mode, the response of JSON is prefixed with "data: " which
|
|
324
|
+
# we must strip before parsing.
|
|
325
|
+
if line.startswith('data: '):
|
|
326
|
+
yield line[len('data: '):]
|
|
327
|
+
continue
|
|
328
|
+
|
|
329
|
+
# When API returns an error message, it comes line by line. So we buffer
|
|
330
|
+
# the lines until a complete JSON string is read. A complete JSON string
|
|
331
|
+
# is found when the balance is 0.
|
|
332
|
+
for c in line:
|
|
333
|
+
if c == '{':
|
|
334
|
+
balance += 1
|
|
335
|
+
elif c == '}':
|
|
336
|
+
balance -= 1
|
|
337
|
+
|
|
338
|
+
chunk += line
|
|
339
|
+
if balance == 0:
|
|
340
|
+
yield chunk
|
|
341
|
+
chunk = ''
|
|
342
|
+
|
|
343
|
+
# If there is any remaining chunk, yield it.
|
|
344
|
+
if chunk:
|
|
345
|
+
yield chunk
|
|
346
|
+
|
|
347
|
+
async def _aiter_response_stream(self) -> AsyncIterator[str]:
|
|
348
|
+
"""Asynchronously iterates over chunks retrieved from the API."""
|
|
349
|
+
is_valid_response = isinstance(self.response_stream, httpx.Response) or (
|
|
350
|
+
has_aiohttp and isinstance(self.response_stream, aiohttp.ClientResponse)
|
|
351
|
+
)
|
|
352
|
+
if not is_valid_response:
|
|
353
|
+
raise TypeError(
|
|
354
|
+
'Expected self.response_stream to be an httpx.Response or'
|
|
355
|
+
' aiohttp.ClientResponse object, but got'
|
|
356
|
+
f' {type(self.response_stream).__name__}.'
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
chunk = ''
|
|
360
|
+
balance = 0
|
|
361
|
+
# httpx.Response has a dedicated async line iterator.
|
|
362
|
+
if isinstance(self.response_stream, httpx.Response):
|
|
363
|
+
try:
|
|
364
|
+
async for line in self.response_stream.aiter_lines():
|
|
365
|
+
if not line:
|
|
366
|
+
continue
|
|
367
|
+
# In streaming mode, the response of JSON is prefixed with "data: "
|
|
368
|
+
# which we must strip before parsing.
|
|
369
|
+
if line.startswith('data: '):
|
|
370
|
+
yield line[len('data: '):]
|
|
371
|
+
continue
|
|
372
|
+
|
|
373
|
+
# When API returns an error message, it comes line by line. So we buffer
|
|
374
|
+
# the lines until a complete JSON string is read. A complete JSON string
|
|
375
|
+
# is found when the balance is 0.
|
|
376
|
+
for c in line:
|
|
377
|
+
if c == '{':
|
|
378
|
+
balance += 1
|
|
379
|
+
elif c == '}':
|
|
380
|
+
balance -= 1
|
|
381
|
+
|
|
382
|
+
chunk += line
|
|
383
|
+
if balance == 0:
|
|
384
|
+
yield chunk
|
|
385
|
+
chunk = ''
|
|
386
|
+
# If there is any remaining chunk, yield it.
|
|
387
|
+
if chunk:
|
|
388
|
+
yield chunk
|
|
389
|
+
finally:
|
|
390
|
+
# Close the response and release the connection.
|
|
391
|
+
await self.response_stream.aclose()
|
|
392
|
+
|
|
393
|
+
# aiohttp.ClientResponse uses a content stream that we read line by line.
|
|
394
|
+
elif has_aiohttp and isinstance(
|
|
395
|
+
self.response_stream, aiohttp.ClientResponse
|
|
396
|
+
):
|
|
397
|
+
try:
|
|
398
|
+
while True:
|
|
399
|
+
# Read a line from the stream. This returns bytes.
|
|
400
|
+
line_bytes = await self.response_stream.content.readline()
|
|
401
|
+
if not line_bytes:
|
|
402
|
+
break
|
|
403
|
+
# Decode the bytes and remove trailing whitespace and newlines.
|
|
404
|
+
line = line_bytes.decode('utf-8').rstrip()
|
|
405
|
+
if not line:
|
|
406
|
+
continue
|
|
407
|
+
|
|
408
|
+
# In streaming mode, the response of JSON is prefixed with "data: "
|
|
409
|
+
# which we must strip before parsing.
|
|
410
|
+
if line.startswith('data: '):
|
|
411
|
+
yield line[len('data: '):]
|
|
412
|
+
continue
|
|
413
|
+
|
|
414
|
+
# When API returns an error message, it comes line by line. So we
|
|
415
|
+
# buffer the lines until a complete JSON string is read. A complete
|
|
416
|
+
# JSON strings found when the balance is 0.
|
|
417
|
+
for c in line:
|
|
418
|
+
if c == '{':
|
|
419
|
+
balance += 1
|
|
420
|
+
elif c == '}':
|
|
421
|
+
balance -= 1
|
|
422
|
+
|
|
423
|
+
chunk += line
|
|
424
|
+
if balance == 0:
|
|
425
|
+
yield chunk
|
|
426
|
+
chunk = ''
|
|
427
|
+
# If there is any remaining chunk, yield it.
|
|
428
|
+
if chunk:
|
|
429
|
+
yield chunk
|
|
430
|
+
finally:
|
|
431
|
+
# Release the connection back to the pool for potential reuse.
|
|
432
|
+
self.response_stream.release()
|
|
433
|
+
|
|
434
|
+
@classmethod
|
|
435
|
+
def _load_json_from_response(cls, response: Any) -> Any:
|
|
436
|
+
"""Loads JSON from the response, or raises an error if the parsing fails."""
|
|
437
|
+
try:
|
|
438
|
+
return json.loads(response)
|
|
439
|
+
except json.JSONDecodeError as e:
|
|
440
|
+
raise errors.UnknownApiResponseError(
|
|
441
|
+
f'Failed to parse response as JSON. Raw response: {response}'
|
|
442
|
+
) from e
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
# Default retry options.
|
|
446
|
+
# The config is based on https://cloud.google.com/storage/docs/retry-strategy.
|
|
447
|
+
# By default, the client will retry 4 times with approximately 1.0, 2.0, 4.0,
|
|
448
|
+
# 8.0 seconds between each attempt.
|
|
449
|
+
_RETRY_ATTEMPTS = 5 # including the initial call.
|
|
450
|
+
_RETRY_INITIAL_DELAY = 1.0 # seconds
|
|
451
|
+
_RETRY_MAX_DELAY = 60.0 # seconds
|
|
452
|
+
_RETRY_EXP_BASE = 2
|
|
453
|
+
_RETRY_JITTER = 1
|
|
454
|
+
_RETRY_HTTP_STATUS_CODES = (
|
|
455
|
+
408, # Request timeout.
|
|
456
|
+
429, # Too many requests.
|
|
457
|
+
500, # Internal server error.
|
|
458
|
+
502, # Bad gateway.
|
|
459
|
+
503, # Service unavailable.
|
|
460
|
+
504, # Gateway timeout
|
|
461
|
+
)
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
|
|
465
|
+
"""Returns the retry args for the given http retry options.
|
|
466
|
+
|
|
467
|
+
Args:
|
|
468
|
+
options: The http retry options to use for the retry configuration. If None,
|
|
469
|
+
the 'never retry' stop strategy will be used.
|
|
470
|
+
|
|
471
|
+
Returns:
|
|
472
|
+
The arguments passed to the tenacity.(Async)Retrying constructor.
|
|
473
|
+
"""
|
|
474
|
+
if options is None:
|
|
475
|
+
return {'stop': tenacity.stop_after_attempt(1), 'reraise': True}
|
|
476
|
+
|
|
477
|
+
stop = tenacity.stop_after_attempt(options.attempts or _RETRY_ATTEMPTS)
|
|
478
|
+
retriable_codes = options.http_status_codes or _RETRY_HTTP_STATUS_CODES
|
|
479
|
+
retry = tenacity.retry_if_exception(
|
|
480
|
+
lambda e: isinstance(e, errors.APIError) and e.code in retriable_codes,
|
|
481
|
+
)
|
|
482
|
+
wait = tenacity.wait_exponential_jitter(
|
|
483
|
+
initial=options.initial_delay or _RETRY_INITIAL_DELAY,
|
|
484
|
+
max=options.max_delay or _RETRY_MAX_DELAY,
|
|
485
|
+
exp_base=options.exp_base or _RETRY_EXP_BASE,
|
|
486
|
+
jitter=options.jitter or _RETRY_JITTER,
|
|
487
|
+
)
|
|
488
|
+
return {
|
|
489
|
+
'stop': stop,
|
|
490
|
+
'retry': retry,
|
|
491
|
+
'reraise': True,
|
|
492
|
+
'wait': wait,
|
|
493
|
+
'before_sleep': tenacity.before_sleep_log(logger, logging.INFO),
|
|
494
|
+
}
|
|
495
|
+
|
|
240
496
|
|
|
241
497
|
class SyncHttpxClient(httpx.Client):
|
|
242
498
|
"""Sync httpx client."""
|
|
@@ -248,8 +504,11 @@ class SyncHttpxClient(httpx.Client):
|
|
|
248
504
|
|
|
249
505
|
def __del__(self) -> None:
|
|
250
506
|
"""Closes the httpx client."""
|
|
251
|
-
|
|
252
|
-
|
|
507
|
+
try:
|
|
508
|
+
if self.is_closed:
|
|
509
|
+
return
|
|
510
|
+
except Exception:
|
|
511
|
+
pass
|
|
253
512
|
try:
|
|
254
513
|
self.close()
|
|
255
514
|
except Exception:
|
|
@@ -265,8 +524,11 @@ class AsyncHttpxClient(httpx.AsyncClient):
|
|
|
265
524
|
super().__init__(**kwargs)
|
|
266
525
|
|
|
267
526
|
def __del__(self) -> None:
|
|
268
|
-
|
|
269
|
-
|
|
527
|
+
try:
|
|
528
|
+
if self.is_closed:
|
|
529
|
+
return
|
|
530
|
+
except Exception:
|
|
531
|
+
pass
|
|
270
532
|
try:
|
|
271
533
|
asyncio.get_running_loop().create_task(self.aclose())
|
|
272
534
|
except Exception:
|
|
@@ -286,6 +548,7 @@ class BaseApiClient:
|
|
|
286
548
|
http_options: Optional[HttpOptionsOrDict] = None,
|
|
287
549
|
):
|
|
288
550
|
self.vertexai = vertexai
|
|
551
|
+
self.custom_base_url = None
|
|
289
552
|
if self.vertexai is None:
|
|
290
553
|
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [
|
|
291
554
|
'true',
|
|
@@ -308,36 +571,42 @@ class BaseApiClient:
|
|
|
308
571
|
)
|
|
309
572
|
|
|
310
573
|
# Validate http_options if it is provided.
|
|
311
|
-
validated_http_options
|
|
574
|
+
validated_http_options = HttpOptions()
|
|
312
575
|
if isinstance(http_options, dict):
|
|
313
576
|
try:
|
|
314
|
-
validated_http_options = HttpOptions.model_validate(
|
|
315
|
-
http_options
|
|
316
|
-
).model_dump()
|
|
577
|
+
validated_http_options = HttpOptions.model_validate(http_options)
|
|
317
578
|
except ValidationError as e:
|
|
318
|
-
raise ValueError(
|
|
579
|
+
raise ValueError('Invalid http_options') from e
|
|
319
580
|
elif isinstance(http_options, HttpOptions):
|
|
320
|
-
validated_http_options = http_options
|
|
581
|
+
validated_http_options = http_options
|
|
582
|
+
|
|
583
|
+
if validated_http_options.base_url_resource_scope and not validated_http_options.base_url:
|
|
584
|
+
# base_url_resource_scope is only valid when base_url is set.
|
|
585
|
+
raise ValueError(
|
|
586
|
+
'base_url must be set when base_url_resource_scope is set.'
|
|
587
|
+
)
|
|
321
588
|
|
|
322
589
|
# Retrieve implicitly set values from the environment.
|
|
323
590
|
env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
|
|
324
591
|
env_location = os.environ.get('GOOGLE_CLOUD_LOCATION', None)
|
|
325
|
-
env_api_key =
|
|
592
|
+
env_api_key = get_env_api_key()
|
|
326
593
|
self.project = project or env_project
|
|
327
594
|
self.location = location or env_location
|
|
328
595
|
self.api_key = api_key or env_api_key
|
|
329
596
|
|
|
330
597
|
self._credentials = credentials
|
|
331
|
-
self._http_options =
|
|
598
|
+
self._http_options = HttpOptions()
|
|
332
599
|
# Initialize the lock. This lock will be used to protect access to the
|
|
333
600
|
# credentials. This is crucial for thread safety when multiple coroutines
|
|
334
601
|
# might be accessing the credentials at the same time.
|
|
335
|
-
self.
|
|
602
|
+
self._sync_auth_lock = threading.Lock()
|
|
603
|
+
self._async_auth_lock: Optional[asyncio.Lock] = None
|
|
604
|
+
self._async_auth_lock_creation_lock: Optional[asyncio.Lock] = None
|
|
336
605
|
|
|
337
606
|
# Handle when to use Vertex AI in express mode (api key).
|
|
338
607
|
# Explicit initializer arguments are already validated above.
|
|
339
608
|
if self.vertexai:
|
|
340
|
-
if credentials:
|
|
609
|
+
if credentials and env_api_key:
|
|
341
610
|
# Explicit credentials take precedence over implicit api_key.
|
|
342
611
|
logger.info(
|
|
343
612
|
'The user provided Google Cloud credentials will take precedence'
|
|
@@ -366,94 +635,412 @@ class BaseApiClient:
|
|
|
366
635
|
+ ' precedence over the API key from the environment variables.'
|
|
367
636
|
)
|
|
368
637
|
self.api_key = None
|
|
369
|
-
|
|
370
|
-
|
|
638
|
+
|
|
639
|
+
self.custom_base_url = (
|
|
640
|
+
validated_http_options.base_url
|
|
641
|
+
if validated_http_options.base_url
|
|
642
|
+
else None
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
if not self.location and not self.api_key and not self.custom_base_url:
|
|
646
|
+
self.location = 'global'
|
|
647
|
+
|
|
648
|
+
# Skip fetching project from ADC if base url is provided in http options.
|
|
649
|
+
if (
|
|
650
|
+
not self.project
|
|
651
|
+
and not self.api_key
|
|
652
|
+
and not self.custom_base_url
|
|
653
|
+
):
|
|
654
|
+
credentials, self.project = load_auth(project=None)
|
|
371
655
|
if not self._credentials:
|
|
372
656
|
self._credentials = credentials
|
|
373
|
-
|
|
657
|
+
|
|
658
|
+
has_sufficient_auth = (self.project and self.location) or self.api_key
|
|
659
|
+
|
|
660
|
+
if not has_sufficient_auth and not self.custom_base_url:
|
|
661
|
+
# Skip sufficient auth check if base url is provided in http options.
|
|
374
662
|
raise ValueError(
|
|
375
|
-
'Project
|
|
663
|
+
'Project or API key must be set when using the Vertex '
|
|
376
664
|
'AI API.'
|
|
377
665
|
)
|
|
378
666
|
if self.api_key or self.location == 'global':
|
|
379
|
-
self._http_options
|
|
667
|
+
self._http_options.base_url = f'https://aiplatform.googleapis.com/'
|
|
668
|
+
elif self.custom_base_url and not ((project and location) or api_key):
|
|
669
|
+
# Avoid setting default base url and api version if base_url provided.
|
|
670
|
+
# API gateway proxy can use the auth in custom headers, not url.
|
|
671
|
+
# Enable custom url if auth is not sufficient.
|
|
672
|
+
self._http_options.base_url = self.custom_base_url
|
|
673
|
+
# Clear project and location if base_url is provided.
|
|
674
|
+
self.project = None
|
|
675
|
+
self.location = None
|
|
380
676
|
else:
|
|
381
|
-
self._http_options
|
|
677
|
+
self._http_options.base_url = (
|
|
382
678
|
f'https://{self.location}-aiplatform.googleapis.com/'
|
|
383
679
|
)
|
|
384
|
-
self._http_options
|
|
680
|
+
self._http_options.api_version = 'v1beta1'
|
|
385
681
|
else: # Implicit initialization or missing arguments.
|
|
386
682
|
if not self.api_key:
|
|
387
683
|
raise ValueError(
|
|
388
684
|
'Missing key inputs argument! To use the Google AI API,'
|
|
389
|
-
'provide (`api_key`) arguments. To use the Google Cloud API,'
|
|
685
|
+
' provide (`api_key`) arguments. To use the Google Cloud API,'
|
|
390
686
|
' provide (`vertexai`, `project` & `location`) arguments.'
|
|
391
687
|
)
|
|
392
|
-
self._http_options
|
|
393
|
-
|
|
394
|
-
)
|
|
395
|
-
self._http_options['api_version'] = 'v1beta'
|
|
688
|
+
self._http_options.base_url = 'https://generativelanguage.googleapis.com/'
|
|
689
|
+
self._http_options.api_version = 'v1beta'
|
|
396
690
|
# Default options for both clients.
|
|
397
|
-
self._http_options
|
|
691
|
+
self._http_options.headers = {'Content-Type': 'application/json'}
|
|
398
692
|
if self.api_key:
|
|
399
|
-
self.
|
|
693
|
+
self.api_key = self.api_key.strip()
|
|
694
|
+
if self._http_options.headers is not None:
|
|
695
|
+
self._http_options.headers['x-goog-api-key'] = self.api_key
|
|
400
696
|
# Update the http options with the user provided http options.
|
|
401
697
|
if http_options:
|
|
402
|
-
self._http_options =
|
|
698
|
+
self._http_options = patch_http_options(
|
|
403
699
|
self._http_options, validated_http_options
|
|
404
700
|
)
|
|
405
701
|
else:
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
702
|
+
if self._http_options.headers is not None:
|
|
703
|
+
append_library_version_headers(self._http_options.headers)
|
|
704
|
+
|
|
705
|
+
client_args, async_client_args = self._ensure_httpx_ssl_ctx(
|
|
706
|
+
self._http_options
|
|
707
|
+
)
|
|
708
|
+
self._async_httpx_client_args = async_client_args
|
|
709
|
+
|
|
710
|
+
if self._http_options.httpx_client:
|
|
711
|
+
self._httpx_client = self._http_options.httpx_client
|
|
712
|
+
else:
|
|
713
|
+
self._httpx_client = SyncHttpxClient(**client_args)
|
|
714
|
+
if self._http_options.httpx_async_client:
|
|
715
|
+
self._async_httpx_client = self._http_options.httpx_async_client
|
|
716
|
+
else:
|
|
717
|
+
self._async_httpx_client = AsyncHttpxClient(**async_client_args)
|
|
718
|
+
if self._use_aiohttp():
|
|
719
|
+
try:
|
|
720
|
+
import aiohttp # pylint: disable=g-import-not-at-top
|
|
721
|
+
# Do it once at the genai.Client level. Share among all requests.
|
|
722
|
+
self._async_client_session_request_args = self._ensure_aiohttp_ssl_ctx(
|
|
723
|
+
self._http_options
|
|
724
|
+
)
|
|
725
|
+
except ImportError:
|
|
726
|
+
pass
|
|
727
|
+
|
|
728
|
+
# Initialize the aiohttp client session.
|
|
729
|
+
self._aiohttp_session: Optional[aiohttp.ClientSession] = None
|
|
730
|
+
|
|
731
|
+
retry_kwargs = retry_args(self._http_options.retry_options)
|
|
732
|
+
self._websocket_ssl_ctx = self._ensure_websocket_ssl_ctx(self._http_options)
|
|
733
|
+
self._retry = tenacity.Retrying(**retry_kwargs)
|
|
734
|
+
self._async_retry = tenacity.AsyncRetrying(**retry_kwargs)
|
|
735
|
+
|
|
736
|
+
async def _get_aiohttp_session(self) -> 'aiohttp.ClientSession':
|
|
737
|
+
"""Returns the aiohttp client session."""
|
|
738
|
+
if (
|
|
739
|
+
self._aiohttp_session is None
|
|
740
|
+
or self._aiohttp_session.closed
|
|
741
|
+
or self._aiohttp_session._loop.is_closed() # pylint: disable=protected-access
|
|
742
|
+
):
|
|
743
|
+
# Initialize the aiohttp client session if it's not set up or closed.
|
|
744
|
+
class AiohttpClientSession(aiohttp.ClientSession): # type: ignore[misc]
|
|
745
|
+
|
|
746
|
+
def __del__(self, _warnings: Any = warnings) -> None:
|
|
747
|
+
if not self.closed:
|
|
748
|
+
context = {
|
|
749
|
+
'client_session': self,
|
|
750
|
+
'message': 'Unclosed client session',
|
|
751
|
+
}
|
|
752
|
+
if self._source_traceback is not None:
|
|
753
|
+
context['source_traceback'] = self._source_traceback
|
|
754
|
+
# Remove this self._loop.call_exception_handler(context)
|
|
755
|
+
|
|
756
|
+
class AiohttpTCPConnector(aiohttp.TCPConnector): # type: ignore[misc]
|
|
757
|
+
|
|
758
|
+
def __del__(self, _warnings: Any = warnings) -> None:
|
|
759
|
+
if self._closed:
|
|
760
|
+
return
|
|
761
|
+
if not self._conns:
|
|
762
|
+
return
|
|
763
|
+
conns = [repr(c) for c in self._conns.values()]
|
|
764
|
+
# After v3.13.2, it may change to self._close_immediately()
|
|
765
|
+
self._close()
|
|
766
|
+
context = {
|
|
767
|
+
'connector': self,
|
|
768
|
+
'connections': conns,
|
|
769
|
+
'message': 'Unclosed connector',
|
|
770
|
+
}
|
|
771
|
+
if self._source_traceback is not None:
|
|
772
|
+
context['source_traceback'] = self._source_traceback
|
|
773
|
+
# Remove this self._loop.call_exception_handler(context)
|
|
774
|
+
self._aiohttp_session = AiohttpClientSession(
|
|
775
|
+
connector=AiohttpTCPConnector(limit=0),
|
|
776
|
+
trust_env=True,
|
|
777
|
+
read_bufsize=READ_BUFFER_SIZE,
|
|
778
|
+
)
|
|
779
|
+
return self._aiohttp_session
|
|
780
|
+
|
|
781
|
+
@staticmethod
|
|
782
|
+
def _ensure_httpx_ssl_ctx(
|
|
783
|
+
options: HttpOptions,
|
|
784
|
+
) -> Tuple[_common.StringDict, _common.StringDict]:
|
|
785
|
+
"""Ensures the SSL context is present in the HTTPX client args.
|
|
410
786
|
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
787
|
+
Creates a default SSL context if one is not provided.
|
|
788
|
+
|
|
789
|
+
Args:
|
|
790
|
+
options: The http options to check for SSL context.
|
|
791
|
+
|
|
792
|
+
Returns:
|
|
793
|
+
A tuple of sync/async httpx client args.
|
|
794
|
+
"""
|
|
795
|
+
|
|
796
|
+
verify = 'verify'
|
|
797
|
+
args = options.client_args
|
|
798
|
+
async_args = options.async_client_args
|
|
799
|
+
ctx = (
|
|
800
|
+
args.get(verify)
|
|
801
|
+
if args
|
|
802
|
+
else None or async_args.get(verify)
|
|
803
|
+
if async_args
|
|
804
|
+
else None
|
|
805
|
+
)
|
|
806
|
+
|
|
807
|
+
if not ctx:
|
|
808
|
+
# Initialize the SSL context for the httpx client.
|
|
809
|
+
# Unlike requests, the httpx package does not automatically pull in the
|
|
810
|
+
# environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
|
|
811
|
+
# enabled explicitly.
|
|
812
|
+
ctx = ssl.create_default_context(
|
|
813
|
+
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
|
|
814
|
+
capath=os.environ.get('SSL_CERT_DIR'),
|
|
815
|
+
)
|
|
816
|
+
|
|
817
|
+
def _maybe_set(
|
|
818
|
+
args: Optional[_common.StringDict],
|
|
819
|
+
ctx: ssl.SSLContext,
|
|
820
|
+
) -> _common.StringDict:
|
|
821
|
+
"""Sets the SSL context in the client args if not set.
|
|
822
|
+
|
|
823
|
+
Does not override the SSL context if it is already set.
|
|
824
|
+
|
|
825
|
+
Args:
|
|
826
|
+
args: The client args to to check for SSL context.
|
|
827
|
+
ctx: The SSL context to set.
|
|
828
|
+
|
|
829
|
+
Returns:
|
|
830
|
+
The client args with the SSL context included.
|
|
831
|
+
"""
|
|
832
|
+
if not args or not args.get(verify):
|
|
833
|
+
args = (args or {}).copy()
|
|
834
|
+
args[verify] = ctx
|
|
835
|
+
# Drop the args that isn't used by the httpx client.
|
|
836
|
+
copied_args = args.copy()
|
|
837
|
+
for key in copied_args.copy():
|
|
838
|
+
if key not in inspect.signature(httpx.Client.__init__).parameters:
|
|
839
|
+
del copied_args[key]
|
|
840
|
+
return copied_args
|
|
841
|
+
|
|
842
|
+
return (
|
|
843
|
+
_maybe_set(args, ctx),
|
|
844
|
+
_maybe_set(async_args, ctx),
|
|
845
|
+
)
|
|
846
|
+
|
|
847
|
+
@staticmethod
|
|
848
|
+
def _ensure_aiohttp_ssl_ctx(options: HttpOptions) -> _common.StringDict:
|
|
849
|
+
"""Ensures the SSL context is present in the async client args.
|
|
850
|
+
|
|
851
|
+
Creates a default SSL context if one is not provided.
|
|
852
|
+
|
|
853
|
+
Args:
|
|
854
|
+
options: The http options to check for SSL context.
|
|
855
|
+
|
|
856
|
+
Returns:
|
|
857
|
+
An async aiohttp ClientSession._request args.
|
|
858
|
+
"""
|
|
859
|
+
verify = 'ssl' # keep it consistent with httpx.
|
|
860
|
+
async_args = options.async_client_args
|
|
861
|
+
ctx = async_args.get(verify) if async_args else None
|
|
862
|
+
|
|
863
|
+
if not ctx:
|
|
864
|
+
# Initialize the SSL context for the httpx client.
|
|
865
|
+
# Unlike requests, the aiohttp package does not automatically pull in the
|
|
866
|
+
# environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
|
|
867
|
+
# enabled explicitly. Instead of 'verify' at client level in httpx,
|
|
868
|
+
# aiohttp uses 'ssl' at request level.
|
|
869
|
+
ctx = ssl.create_default_context(
|
|
870
|
+
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
|
|
871
|
+
capath=os.environ.get('SSL_CERT_DIR'),
|
|
872
|
+
)
|
|
873
|
+
|
|
874
|
+
def _maybe_set(
|
|
875
|
+
args: Optional[_common.StringDict],
|
|
876
|
+
ctx: ssl.SSLContext,
|
|
877
|
+
) -> _common.StringDict:
|
|
878
|
+
"""Sets the SSL context in the client args if not set.
|
|
879
|
+
|
|
880
|
+
Does not override the SSL context if it is already set.
|
|
881
|
+
|
|
882
|
+
Args:
|
|
883
|
+
args: The client args to to check for SSL context.
|
|
884
|
+
ctx: The SSL context to set.
|
|
885
|
+
|
|
886
|
+
Returns:
|
|
887
|
+
The client args with the SSL context included.
|
|
888
|
+
"""
|
|
889
|
+
if not args or not args.get(verify):
|
|
890
|
+
args = (args or {}).copy()
|
|
891
|
+
args[verify] = ctx
|
|
892
|
+
# Drop the args that isn't in the aiohttp RequestOptions.
|
|
893
|
+
copied_args = args.copy()
|
|
894
|
+
for key in copied_args.copy():
|
|
895
|
+
if (
|
|
896
|
+
key
|
|
897
|
+
not in inspect.signature(aiohttp.ClientSession._request).parameters
|
|
898
|
+
):
|
|
899
|
+
del copied_args[key]
|
|
900
|
+
return copied_args
|
|
901
|
+
|
|
902
|
+
return _maybe_set(async_args, ctx)
|
|
903
|
+
|
|
904
|
+
@staticmethod
|
|
905
|
+
def _ensure_websocket_ssl_ctx(options: HttpOptions) -> _common.StringDict:
|
|
906
|
+
"""Ensures the SSL context is present in the async client args.
|
|
907
|
+
|
|
908
|
+
Creates a default SSL context if one is not provided.
|
|
909
|
+
|
|
910
|
+
Args:
|
|
911
|
+
options: The http options to check for SSL context.
|
|
912
|
+
|
|
913
|
+
Returns:
|
|
914
|
+
An async aiohttp ClientSession._request args.
|
|
915
|
+
"""
|
|
916
|
+
|
|
917
|
+
verify = 'ssl' # keep it consistent with httpx.
|
|
918
|
+
async_args = options.async_client_args
|
|
919
|
+
ctx = async_args.get(verify) if async_args else None
|
|
920
|
+
|
|
921
|
+
if not ctx:
|
|
922
|
+
# Initialize the SSL context for the httpx client.
|
|
923
|
+
# Unlike requests, the aiohttp package does not automatically pull in the
|
|
924
|
+
# environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
|
|
925
|
+
# enabled explicitly. Instead of 'verify' at client level in httpx,
|
|
926
|
+
# aiohttp uses 'ssl' at request level.
|
|
927
|
+
ctx = ssl.create_default_context(
|
|
928
|
+
cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
|
|
929
|
+
capath=os.environ.get('SSL_CERT_DIR'),
|
|
930
|
+
)
|
|
931
|
+
|
|
932
|
+
def _maybe_set(
|
|
933
|
+
args: Optional[_common.StringDict],
|
|
934
|
+
ctx: ssl.SSLContext,
|
|
935
|
+
) -> _common.StringDict:
|
|
936
|
+
"""Sets the SSL context in the client args if not set.
|
|
937
|
+
|
|
938
|
+
Does not override the SSL context if it is already set.
|
|
939
|
+
|
|
940
|
+
Args:
|
|
941
|
+
args: The client args to to check for SSL context.
|
|
942
|
+
ctx: The SSL context to set.
|
|
943
|
+
|
|
944
|
+
Returns:
|
|
945
|
+
The client args with the SSL context included.
|
|
946
|
+
"""
|
|
947
|
+
if not args or not args.get(verify):
|
|
948
|
+
args = (args or {}).copy()
|
|
949
|
+
args[verify] = ctx
|
|
950
|
+
# Drop the args that isn't in the aiohttp RequestOptions.
|
|
951
|
+
copied_args = args.copy()
|
|
952
|
+
for key in copied_args.copy():
|
|
953
|
+
if key not in inspect.signature(ws_connect).parameters and key != 'ssl':
|
|
954
|
+
del copied_args[key]
|
|
955
|
+
return copied_args
|
|
956
|
+
|
|
957
|
+
return _maybe_set(async_args, ctx)
|
|
958
|
+
|
|
959
|
+
def _use_aiohttp(self) -> bool:
|
|
960
|
+
# If the instantiator has passed a custom transport, they want httpx not
|
|
961
|
+
# aiohttp.
|
|
962
|
+
return (
|
|
963
|
+
has_aiohttp
|
|
964
|
+
and (self._http_options.async_client_args or {}).get('transport')
|
|
965
|
+
is None
|
|
966
|
+
and (self._http_options.httpx_async_client is None)
|
|
967
|
+
)
|
|
968
|
+
|
|
969
|
+
def _websocket_base_url(self) -> str:
|
|
970
|
+
has_sufficient_auth = (self.project and self.location) or self.api_key
|
|
971
|
+
if self.custom_base_url and not has_sufficient_auth:
|
|
972
|
+
# API gateway proxy can use the auth in custom headers, not url.
|
|
973
|
+
# Enable custom url if auth is not sufficient.
|
|
974
|
+
return self.custom_base_url
|
|
975
|
+
url_parts = urlparse(self._http_options.base_url)
|
|
976
|
+
return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
|
|
414
977
|
|
|
415
978
|
def _access_token(self) -> str:
|
|
416
979
|
"""Retrieves the access token for the credentials."""
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
self.project
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
if
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
980
|
+
with self._sync_auth_lock:
|
|
981
|
+
if not self._credentials:
|
|
982
|
+
self._credentials, project = load_auth(project=self.project)
|
|
983
|
+
if not self.project:
|
|
984
|
+
self.project = project
|
|
985
|
+
|
|
986
|
+
if self._credentials:
|
|
987
|
+
if self._credentials.expired or not self._credentials.token:
|
|
988
|
+
# Only refresh when it needs to. Default expiration is 3600 seconds.
|
|
989
|
+
refresh_auth(self._credentials)
|
|
990
|
+
if not self._credentials.token:
|
|
991
|
+
raise RuntimeError('Could not resolve API token from the environment')
|
|
992
|
+
return self._credentials.token # type: ignore[no-any-return]
|
|
993
|
+
else:
|
|
429
994
|
raise RuntimeError('Could not resolve API token from the environment')
|
|
430
|
-
return self._credentials.token
|
|
431
|
-
else:
|
|
432
|
-
raise RuntimeError('Could not resolve API token from the environment')
|
|
433
995
|
|
|
434
|
-
async def
|
|
996
|
+
async def _get_async_auth_lock(self) -> asyncio.Lock:
|
|
997
|
+
"""Lazily initializes and returns an asyncio.Lock for async authentication.
|
|
998
|
+
|
|
999
|
+
This method ensures that a single `asyncio.Lock` instance is created and
|
|
1000
|
+
shared among all asynchronous operations that require authentication,
|
|
1001
|
+
preventing race conditions when accessing or refreshing credentials.
|
|
1002
|
+
|
|
1003
|
+
The lock is created on the first call to this method. An internal async lock
|
|
1004
|
+
is used to protect the creation of the main authentication lock to ensure
|
|
1005
|
+
it's a singleton within the client instance.
|
|
1006
|
+
|
|
1007
|
+
Returns:
|
|
1008
|
+
The asyncio.Lock instance for asynchronous authentication operations.
|
|
1009
|
+
"""
|
|
1010
|
+
if self._async_auth_lock is None:
|
|
1011
|
+
# Create async creation lock if needed
|
|
1012
|
+
if self._async_auth_lock_creation_lock is None:
|
|
1013
|
+
self._async_auth_lock_creation_lock = asyncio.Lock()
|
|
1014
|
+
|
|
1015
|
+
async with self._async_auth_lock_creation_lock:
|
|
1016
|
+
if self._async_auth_lock is None:
|
|
1017
|
+
self._async_auth_lock = asyncio.Lock()
|
|
1018
|
+
|
|
1019
|
+
return self._async_auth_lock
|
|
1020
|
+
|
|
1021
|
+
async def _async_access_token(self) -> Union[str, Any]:
|
|
435
1022
|
"""Retrieves the access token for the credentials asynchronously."""
|
|
436
1023
|
if not self._credentials:
|
|
437
|
-
|
|
1024
|
+
async_auth_lock = await self._get_async_auth_lock()
|
|
1025
|
+
async with async_auth_lock:
|
|
438
1026
|
# This ensures that only one coroutine can execute the auth logic at a
|
|
439
1027
|
# time for thread safety.
|
|
440
1028
|
if not self._credentials:
|
|
441
1029
|
# Double check that the credentials are not set before loading them.
|
|
442
1030
|
self._credentials, project = await asyncio.to_thread(
|
|
443
|
-
|
|
1031
|
+
load_auth, project=self.project
|
|
444
1032
|
)
|
|
445
1033
|
if not self.project:
|
|
446
1034
|
self.project = project
|
|
447
1035
|
|
|
448
1036
|
if self._credentials:
|
|
449
|
-
if
|
|
450
|
-
self._credentials.expired or not self._credentials.token
|
|
451
|
-
):
|
|
1037
|
+
if self._credentials.expired or not self._credentials.token:
|
|
452
1038
|
# Only refresh when it needs to. Default expiration is 3600 seconds.
|
|
453
|
-
|
|
1039
|
+
async_auth_lock = await self._get_async_auth_lock()
|
|
1040
|
+
async with async_auth_lock:
|
|
454
1041
|
if self._credentials.expired or not self._credentials.token:
|
|
455
1042
|
# Double check that the credentials expired before refreshing.
|
|
456
|
-
await asyncio.to_thread(
|
|
1043
|
+
await asyncio.to_thread(refresh_auth, self._credentials)
|
|
457
1044
|
|
|
458
1045
|
if not self._credentials.token:
|
|
459
1046
|
raise RuntimeError('Could not resolve API token from the environment')
|
|
@@ -476,12 +1063,13 @@ class BaseApiClient:
|
|
|
476
1063
|
# patch the http options with the user provided settings.
|
|
477
1064
|
if http_options:
|
|
478
1065
|
if isinstance(http_options, HttpOptions):
|
|
479
|
-
patched_http_options =
|
|
480
|
-
self._http_options,
|
|
1066
|
+
patched_http_options = patch_http_options(
|
|
1067
|
+
self._http_options,
|
|
1068
|
+
http_options,
|
|
481
1069
|
)
|
|
482
1070
|
else:
|
|
483
|
-
patched_http_options =
|
|
484
|
-
self._http_options, http_options
|
|
1071
|
+
patched_http_options = patch_http_options(
|
|
1072
|
+
self._http_options, HttpOptions.model_validate(http_options)
|
|
485
1073
|
)
|
|
486
1074
|
else:
|
|
487
1075
|
patched_http_options = self._http_options
|
|
@@ -497,42 +1085,86 @@ class BaseApiClient:
|
|
|
497
1085
|
self.vertexai
|
|
498
1086
|
and not path.startswith('projects/')
|
|
499
1087
|
and not query_vertex_base_models
|
|
500
|
-
and
|
|
1088
|
+
and (self.project or self.location)
|
|
1089
|
+
and not (
|
|
1090
|
+
self.custom_base_url
|
|
1091
|
+
and patched_http_options.base_url_resource_scope
|
|
1092
|
+
== ResourceScope.COLLECTION
|
|
1093
|
+
)
|
|
501
1094
|
):
|
|
502
1095
|
path = f'projects/{self.project}/locations/{self.location}/' + path
|
|
503
|
-
url = _join_url_path(
|
|
504
|
-
patched_http_options.get('base_url', ''),
|
|
505
|
-
patched_http_options.get('api_version', '') + '/' + path,
|
|
506
|
-
)
|
|
507
1096
|
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
1097
|
+
if patched_http_options.api_version is None:
|
|
1098
|
+
versioned_path = f'/{path}'
|
|
1099
|
+
else:
|
|
1100
|
+
versioned_path = f'{patched_http_options.api_version}/{path}'
|
|
1101
|
+
|
|
1102
|
+
if (
|
|
1103
|
+
patched_http_options.base_url is None
|
|
1104
|
+
or not patched_http_options.base_url
|
|
1105
|
+
):
|
|
1106
|
+
raise ValueError('Base URL must be set.')
|
|
515
1107
|
else:
|
|
516
|
-
|
|
1108
|
+
base_url = patched_http_options.base_url
|
|
1109
|
+
|
|
1110
|
+
if (
|
|
1111
|
+
hasattr(patched_http_options, 'extra_body')
|
|
1112
|
+
and patched_http_options.extra_body
|
|
1113
|
+
):
|
|
1114
|
+
_common.recursive_dict_update(
|
|
1115
|
+
request_dict, patched_http_options.extra_body
|
|
1116
|
+
)
|
|
1117
|
+
url = base_url
|
|
1118
|
+
if (
|
|
1119
|
+
not self.custom_base_url
|
|
1120
|
+
or (self.project and self.location)
|
|
1121
|
+
or self.api_key
|
|
1122
|
+
):
|
|
1123
|
+
if (
|
|
1124
|
+
patched_http_options.base_url_resource_scope
|
|
1125
|
+
== ResourceScope.COLLECTION
|
|
1126
|
+
):
|
|
1127
|
+
url = join_url_path(base_url, path)
|
|
1128
|
+
else:
|
|
1129
|
+
url = join_url_path(
|
|
1130
|
+
base_url,
|
|
1131
|
+
versioned_path,
|
|
1132
|
+
)
|
|
1133
|
+
elif(
|
|
1134
|
+
self.custom_base_url
|
|
1135
|
+
and patched_http_options.base_url_resource_scope == ResourceScope.COLLECTION
|
|
1136
|
+
):
|
|
1137
|
+
url = join_url_path(base_url, path)
|
|
1138
|
+
|
|
1139
|
+
if self.api_key and self.api_key.startswith('auth_tokens/'):
|
|
1140
|
+
raise EphemeralTokenAPIKeyError(
|
|
1141
|
+
'Ephemeral tokens can only be used with the live API.'
|
|
1142
|
+
)
|
|
517
1143
|
|
|
1144
|
+
timeout_in_seconds = get_timeout_in_seconds(patched_http_options.timeout)
|
|
1145
|
+
|
|
1146
|
+
if patched_http_options.headers is None:
|
|
1147
|
+
raise ValueError('Request headers must be set.')
|
|
1148
|
+
populate_server_timeout_header(
|
|
1149
|
+
patched_http_options.headers, timeout_in_seconds
|
|
1150
|
+
)
|
|
518
1151
|
return HttpRequest(
|
|
519
1152
|
method=http_method,
|
|
520
1153
|
url=url,
|
|
521
|
-
headers=patched_http_options
|
|
1154
|
+
headers=patched_http_options.headers,
|
|
522
1155
|
data=request_dict,
|
|
523
1156
|
timeout=timeout_in_seconds,
|
|
524
1157
|
)
|
|
525
1158
|
|
|
526
|
-
def
|
|
1159
|
+
def _request_once(
|
|
527
1160
|
self,
|
|
528
1161
|
http_request: HttpRequest,
|
|
529
1162
|
stream: bool = False,
|
|
530
1163
|
) -> HttpResponse:
|
|
531
1164
|
data: Optional[Union[str, bytes]] = None
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
)
|
|
1165
|
+
# If using proj/location, fetch ADC
|
|
1166
|
+
if self.vertexai and (self.project or self.location):
|
|
1167
|
+
http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
|
|
536
1168
|
if self._credentials and self._credentials.quota_project_id:
|
|
537
1169
|
http_request.headers['x-goog-user-project'] = (
|
|
538
1170
|
self._credentials.quota_project_id
|
|
@@ -571,11 +1203,33 @@ class BaseApiClient:
|
|
|
571
1203
|
response.headers, response if stream else [response.text]
|
|
572
1204
|
)
|
|
573
1205
|
|
|
574
|
-
|
|
1206
|
+
def _request(
|
|
1207
|
+
self,
|
|
1208
|
+
http_request: HttpRequest,
|
|
1209
|
+
http_options: Optional[HttpOptionsOrDict] = None,
|
|
1210
|
+
stream: bool = False,
|
|
1211
|
+
) -> HttpResponse:
|
|
1212
|
+
if http_options:
|
|
1213
|
+
parameter_model = (
|
|
1214
|
+
HttpOptions(**http_options)
|
|
1215
|
+
if isinstance(http_options, dict)
|
|
1216
|
+
else http_options
|
|
1217
|
+
)
|
|
1218
|
+
# Support per request retry options.
|
|
1219
|
+
if parameter_model.retry_options:
|
|
1220
|
+
retry_kwargs = retry_args(parameter_model.retry_options)
|
|
1221
|
+
retry = tenacity.Retrying(**retry_kwargs)
|
|
1222
|
+
return retry(self._request_once, http_request, stream) # type: ignore[no-any-return]
|
|
1223
|
+
|
|
1224
|
+
return self._retry(self._request_once, http_request, stream) # type: ignore[no-any-return]
|
|
1225
|
+
|
|
1226
|
+
async def _async_request_once(
|
|
575
1227
|
self, http_request: HttpRequest, stream: bool = False
|
|
576
|
-
):
|
|
1228
|
+
) -> HttpResponse:
|
|
577
1229
|
data: Optional[Union[str, bytes]] = None
|
|
578
|
-
|
|
1230
|
+
|
|
1231
|
+
# If using proj/location, fetch ADC
|
|
1232
|
+
if self.vertexai and (self.project or self.location):
|
|
579
1233
|
http_request.headers['Authorization'] = (
|
|
580
1234
|
f'Bearer {await self._async_access_token()}'
|
|
581
1235
|
)
|
|
@@ -592,39 +1246,133 @@ class BaseApiClient:
|
|
|
592
1246
|
data = http_request.data
|
|
593
1247
|
|
|
594
1248
|
if stream:
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
1249
|
+
if self._use_aiohttp():
|
|
1250
|
+
self._aiohttp_session = await self._get_aiohttp_session()
|
|
1251
|
+
try:
|
|
1252
|
+
+          response = await self._aiohttp_session.request(
+              method=http_request.method,
+              url=http_request.url,
+              headers=http_request.headers,
+              data=data,
+              timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
+              **self._async_client_session_request_args,
+          )
+        except (
+            aiohttp.ClientConnectorError,
+            aiohttp.ClientConnectorDNSError,
+            aiohttp.ClientOSError,
+            aiohttp.ServerDisconnectedError,
+        ) as e:
+          await asyncio.sleep(1 + random.randint(0, 9))
+          logger.info('Retrying due to aiohttp error: %s' % e)
+          # Retrieve the SSL context from the session.
+          self._async_client_session_request_args = (
+              self._ensure_aiohttp_ssl_ctx(self._http_options)
+          )
+          # Instantiate a new session with the updated SSL context.
+          self._aiohttp_session = await self._get_aiohttp_session()
+          response = await self._aiohttp_session.request(
+              method=http_request.method,
+              url=http_request.url,
+              headers=http_request.headers,
+              data=data,
+              timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
+              **self._async_client_session_request_args,
+          )
+
+        await errors.APIError.raise_for_async_response(response)
+        return HttpResponse(response.headers, response)
+      else:
+        # aiohttp is not available. Fall back to httpx.
+        httpx_request = self._async_httpx_client.build_request(
+            method=http_request.method,
+            url=http_request.url,
+            content=data,
+            headers=http_request.headers,
+            timeout=http_request.timeout,
+        )
+        client_response = await self._async_httpx_client.send(
+            httpx_request,
+            stream=stream,
+        )
+        await errors.APIError.raise_for_async_response(client_response)
+        return HttpResponse(client_response.headers, client_response)
     else:
-
-
-
-
-
-
-
-
-
-
+      if self._use_aiohttp():
+        self._aiohttp_session = await self._get_aiohttp_session()
+        try:
+          response = await self._aiohttp_session.request(
+              method=http_request.method,
+              url=http_request.url,
+              headers=http_request.headers,
+              data=data,
+              timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
+              **self._async_client_session_request_args,
+          )
+          await errors.APIError.raise_for_async_response(response)
+          return HttpResponse(response.headers, [await response.text()])
+        except (
+            aiohttp.ClientConnectorError,
+            aiohttp.ClientConnectorDNSError,
+            aiohttp.ClientOSError,
+            aiohttp.ServerDisconnectedError,
+        ) as e:
+          await asyncio.sleep(1 + random.randint(0, 9))
+          logger.info('Retrying due to aiohttp error: %s' % e)
+          # Retrieve the SSL context from the session.
+          self._async_client_session_request_args = (
+              self._ensure_aiohttp_ssl_ctx(self._http_options)
+          )
+          # Instantiate a new session with the updated SSL context.
+          self._aiohttp_session = await self._get_aiohttp_session()
+          response = await self._aiohttp_session.request(
+              method=http_request.method,
+              url=http_request.url,
+              headers=http_request.headers,
+              data=data,
+              timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
+              **self._async_client_session_request_args,
+          )
+          await errors.APIError.raise_for_async_response(response)
+          return HttpResponse(response.headers, [await response.text()])
+      else:
+        # aiohttp is not available. Fall back to httpx.
+        client_response = await self._async_httpx_client.request(
+            method=http_request.method,
+            url=http_request.url,
+            headers=http_request.headers,
+            content=data,
+            timeout=http_request.timeout,
+        )
+        await errors.APIError.raise_for_async_response(client_response)
+        return HttpResponse(client_response.headers, [client_response.text])
+
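The additions above give `_async_request_once` a layered transport strategy: prefer aiohttp when it is installed, retry once after a connection-level failure (rebuilding the session and its SSL context), and otherwise fall back to httpx. A minimal standalone sketch of that optional-dependency fallback, assuming nothing beyond the two libraries themselves; the `fetch` helper and its retry policy are illustrative, not the SDK's:

```python
import asyncio

try:
  import aiohttp  # preferred transport when available
except ImportError:
  aiohttp = None

import httpx  # fallback transport


async def fetch(url: str) -> str:
  """Illustrative helper: try aiohttp with one retry, else use httpx."""
  if aiohttp is not None:
    for attempt in range(2):
      try:
        async with aiohttp.ClientSession() as session:
          async with session.get(url) as resp:
            return await resp.text()
      except aiohttp.ClientError:
        if attempt == 1:
          raise
        await asyncio.sleep(1)  # brief backoff before the single retry
  async with httpx.AsyncClient() as client:
    resp = await client.get(url)
    return resp.text
```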
+  async def _async_request(
+      self,
+      http_request: HttpRequest,
+      http_options: Optional[HttpOptionsOrDict] = None,
+      stream: bool = False,
+  ) -> HttpResponse:
+    if http_options:
+      parameter_model = (
+          HttpOptions(**http_options)
+          if isinstance(http_options, dict)
+          else http_options
       )
+      # Support per request retry options.
+      if parameter_model.retry_options:
+        retry_kwargs = retry_args(parameter_model.retry_options)
+        retry = tenacity.AsyncRetrying(**retry_kwargs)
+        return await retry(self._async_request_once, http_request, stream)  # type: ignore[no-any-return]
+    return await self._async_retry(  # type: ignore[no-any-return]
+        self._async_request_once, http_request, stream
+    )

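`_async_request` now honors per-request retry options by building a `tenacity.AsyncRetrying` from `retry_args(...)` and invoking the single-shot request through it. A rough sketch of that call shape; the coroutine and retry parameters below are illustrative, not the SDK's actual `retry_args` output:

```python
import asyncio

import tenacity


async def flaky_call() -> str:
  # Stand-in for a request that may raise a transient error.
  return 'ok'


async def main() -> None:
  retry = tenacity.AsyncRetrying(
      stop=tenacity.stop_after_attempt(3),
      wait=tenacity.wait_exponential(multiplier=1, max=10),
      reraise=True,
  )
  # AsyncRetrying is callable with the target and its arguments, mirroring
  # `await retry(self._async_request_once, http_request, stream)`.
  print(await retry(flaky_call))


asyncio.run(main())
```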
-  def get_read_only_http_options(self) ->
-    copied = HttpOptionsDict()
+  def get_read_only_http_options(self) -> _common.StringDict:
     if isinstance(self._http_options, BaseModel):
-
-
+      copied = self._http_options.model_dump()
+    else:
+      copied = self._http_options
     return copied

   def request(
@@ -633,32 +1381,44 @@ class BaseApiClient:
       path: str,
       request_dict: dict[str, object],
       http_options: Optional[HttpOptionsOrDict] = None,
-  ):
+  ) -> SdkHttpResponse:
     http_request = self._build_request(
         http_method, path, request_dict, http_options
     )
-    response = self._request(http_request, stream=False)
-
-
-
-    )
-    return json_response
+    response = self._request(http_request, http_options, stream=False)
+    response_body = (
+        response.response_stream[0] if response.response_stream else ''
+    )
+    return SdkHttpResponse(headers=response.headers, body=response_body)

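With this change `request` returns a typed `SdkHttpResponse` whose `body` is the raw response text, rather than a pre-parsed JSON dict. A hedged usage sketch; `api_client` stands for an already configured `BaseApiClient`, and the path is illustrative:

```python
import json


def get_model_info(api_client) -> dict:
  # Assumes `api_client` is a configured BaseApiClient instance.
  response = api_client.request('get', 'models/some-model', request_dict={})
  # Callers now parse the body themselves instead of receiving a dict.
  return json.loads(response.body) if response.body else {}
```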
   def request_streamed(
       self,
       http_method: str,
       path: str,
       request_dict: dict[str, object],
-      http_options: Optional[
-  ):
+      http_options: Optional[HttpOptionsOrDict] = None,
+  ) -> Generator[SdkHttpResponse, None, None]:
     http_request = self._build_request(
         http_method, path, request_dict, http_options
     )

-    session_response = self._request(http_request, stream=True)
+    session_response = self._request(http_request, http_options, stream=True)
     for chunk in session_response.segments():
-
+      chunk_dump = json.dumps(chunk)
+      try:
+        if chunk_dump.startswith('{"error":'):
+          chunk_json = json.loads(chunk_dump)
+          errors.APIError.raise_error(
+              chunk_json.get('error', {}).get('code'),
+              chunk_json,
+              session_response,
+          )
+      except json.decoder.JSONDecodeError:
+        logger.debug(
+            'Failed to decode chunk that contains an error: %s' % chunk_dump
+        )
+        pass
+      yield SdkHttpResponse(headers=session_response.headers, body=chunk_dump)

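Both streaming paths now serialize each chunk and look for an inline `{"error": ...}` payload before yielding it, so server-side errors surface as exceptions mid-stream instead of passing through as ordinary chunks. A simplified sketch of that detection; the chunk list is fabricated for illustration:

```python
import json

chunks = [
    {'candidates': [{'content': 'partial text'}]},
    {'error': {'code': 429, 'message': 'quota exceeded'}},
]

for chunk in chunks:
  chunk_dump = json.dumps(chunk)
  if chunk_dump.startswith('{"error":'):
    chunk_json = json.loads(chunk_dump)
    raise RuntimeError(
        'server error %s: %s'
        % (chunk_json['error']['code'], chunk_json['error']['message'])
    )
  print(chunk_dump)
```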
   async def async_request(
       self,
@@ -666,39 +1426,58 @@ class BaseApiClient:
       path: str,
       request_dict: dict[str, object],
       http_options: Optional[HttpOptionsOrDict] = None,
-  ) ->
+  ) -> SdkHttpResponse:
     http_request = self._build_request(
         http_method, path, request_dict, http_options
     )

-    result = await self._async_request(
-
-
-
-    return
+    result = await self._async_request(
+        http_request=http_request, http_options=http_options, stream=False
+    )
+    response_body = result.response_stream[0] if result.response_stream else ''
+    return SdkHttpResponse(headers=result.headers, body=response_body)

   async def async_request_streamed(
       self,
       http_method: str,
       path: str,
       request_dict: dict[str, object],
-      http_options: Optional[
-  ):
+      http_options: Optional[HttpOptionsOrDict] = None,
+  ) -> Any:
     http_request = self._build_request(
         http_method, path, request_dict, http_options
     )

     response = await self._async_request(http_request=http_request, stream=True)

-    async def async_generator():
+    async def async_generator():  # type: ignore[no-untyped-def]
       async for chunk in response:
-
+        chunk_dump = json.dumps(chunk)
+        try:
+          if chunk_dump.startswith('{"error":'):
+            chunk_json = json.loads(chunk_dump)
+            await errors.APIError.raise_error_async(
+                chunk_json.get('error', {}).get('code'),
+                chunk_json,
+                response,
+            )
+        except json.decoder.JSONDecodeError:
+          logger.debug(
+              'Failed to decode chunk that contains an error: %s' % chunk_dump
+          )
+          pass
+        yield SdkHttpResponse(headers=response.headers, body=chunk_dump)

-    return async_generator()
+    return async_generator()  # type: ignore[no-untyped-call]

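Consuming the generator returned by `async_request_streamed` is a plain `async for`. A hypothetical caller; client construction is elided, and the path and payload shape are illustrative rather than a documented endpoint:

```python
async def print_stream(api_client) -> None:
  stream = await api_client.async_request_streamed(
      'post',
      'models/some-model:streamGenerateContent',
      request_dict={'contents': [{'parts': [{'text': 'Hello'}]}]},
  )
  async for chunk in stream:
    # Each chunk is an SdkHttpResponse wrapping one JSON segment.
    print(chunk.body)
```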
   def upload_file(
-      self,
-
+      self,
+      file_path: Union[str, io.IOBase],
+      upload_url: str,
+      upload_size: int,
+      *,
+      http_options: Optional[HttpOptionsOrDict] = None,
+  ) -> HttpResponse:
     """Transfers a file to the given URL.

     Args:
@@ -708,19 +1487,29 @@ class BaseApiClient:
       upload_url: The URL to upload the file to.
       upload_size: The size of file content to be uploaded, this will have to
         match the size requested in the resumable upload request.
+      http_options: The http options to use for the request.

     returns:
-      The
+      The HttpResponse object from the finalize request.
     """
     if isinstance(file_path, io.IOBase):
-      return self._upload_fd(
+      return self._upload_fd(
+          file_path, upload_url, upload_size, http_options=http_options
+      )
     else:
       with open(file_path, 'rb') as file:
-        return self._upload_fd(
+        return self._upload_fd(
+            file, upload_url, upload_size, http_options=http_options
+        )

   def _upload_fd(
-      self,
-
+      self,
+      file: io.IOBase,
+      upload_url: str,
+      upload_size: int,
+      *,
+      http_options: Optional[HttpOptionsOrDict] = None,
+  ) -> HttpResponse:
     """Transfers a file to the given URL.

     Args:
@@ -728,9 +1517,10 @@ class BaseApiClient:
       upload_url: The URL to upload the file to.
       upload_size: The size of file content to be uploaded, this will have to
         match the size requested in the resumable upload request.
+      http_options: The http options to use for the request.

     returns:
-      The
+      The HttpResponse object from the finalize request.
     """
     offset = 0
     # Upload the file in chunks
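The upload bodies in the next hunk drive Google's resumable-upload protocol with `X-Goog-Upload-*` headers: each POST carries an `upload` command, the current byte offset, and `, finalize` on the last chunk, while the server reports progress through the `x-goog-upload-status` response header. A bare-bones sketch of one round-trip; the session URL and sizes are made up, since a real URL comes from a prior resumable-upload handshake:

```python
import httpx

upload_url = 'https://example.com/resumable-session'  # illustrative only
chunk = b'x' * 1024
offset = 0
is_last_chunk = True

response = httpx.post(
    upload_url,
    content=chunk,
    headers={
        'X-Goog-Upload-Command': (
            'upload, finalize' if is_last_chunk else 'upload'
        ),
        'X-Goog-Upload-Offset': str(offset),
        'Content-Length': str(len(chunk)),
    },
)
# 'active' means keep sending; 'final' means the server accepted everything.
print(response.headers.get('x-goog-upload-status'))
```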
@@ -743,32 +1533,60 @@ class BaseApiClient:
       # If last chunk, finalize the upload.
       if chunk_size + offset >= upload_size:
         upload_command += ', finalize'
-
-
-
-
-
-          'X-Goog-Upload-Offset': str(offset),
-          'Content-Length': str(chunk_size),
-      },
-      content=file_chunk,
+      http_options = http_options if http_options else self._http_options
+      timeout = (
+          http_options.get('timeout')
+          if isinstance(http_options, dict)
+          else http_options.timeout
       )
+      if timeout is None:
+        # Per request timeout is not configured. Check the global timeout.
+        timeout = (
+            self._http_options.timeout
+            if isinstance(self._http_options, dict)
+            else self._http_options.timeout
+        )
+      timeout_in_seconds = get_timeout_in_seconds(timeout)
+      upload_headers = {
+          'X-Goog-Upload-Command': upload_command,
+          'X-Goog-Upload-Offset': str(offset),
+          'Content-Length': str(chunk_size),
+      }
+      populate_server_timeout_header(upload_headers, timeout_in_seconds)
+      retry_count = 0
+      while retry_count < MAX_RETRY_COUNT:
+        response = self._httpx_client.request(
+            method='POST',
+            url=upload_url,
+            headers=upload_headers,
+            content=file_chunk,
+            timeout=timeout_in_seconds,
+        )
+        if response.headers.get('x-goog-upload-status'):
+          break
+        delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
+        retry_count += 1
+        time.sleep(delay_seconds)
+
       offset += chunk_size
-      if response.headers
+      if response.headers.get('x-goog-upload-status') != 'active':
         break  # upload is complete or it has been interrupted.
     if upload_size <= offset:  # Status is not finalized.
       raise ValueError(
-          'All content has been uploaded, but the upload status is not'
+          f'All content has been uploaded, but the upload status is not'
           f' finalized.'
       )
+    errors.APIError.raise_for_response(response)
+    if response.headers.get('x-goog-upload-status') != 'final':
+      raise ValueError('Failed to upload file: Upload status is not finalized.')
+    return HttpResponse(response.headers, response_stream=[response.text])

-
-
-
-
-
-
-    def download_file(self, path: str, http_options):
+  def download_file(
+      self,
+      path: str,
+      *,
+      http_options: Optional[HttpOptionsOrDict] = None,
+  ) -> Union[Any, bytes]:
     """Downloads the file data.

     Args:
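Each chunk POST in the hunk above is retried with exponential backoff until the server returns an `x-goog-upload-status` header: the delay grows as `INITIAL_RETRY_DELAY * DELAY_MULTIPLIER**retry_count`, up to `MAX_RETRY_COUNT` attempts. The constants below are stand-ins for illustration; the module defines its own values:

```python
import time

INITIAL_RETRY_DELAY = 1.0  # stand-in value
DELAY_MULTIPLIER = 2.0     # stand-in value
MAX_RETRY_COUNT = 3        # stand-in value


def send_with_backoff(send_once) -> bool:
  """Retries `send_once` until it succeeds or attempts run out."""
  retry_count = 0
  while retry_count < MAX_RETRY_COUNT:
    if send_once():
      return True
    delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
    retry_count += 1
    time.sleep(delay_seconds)  # waits 1s, then 2s, doubling each attempt
  return False


attempts = {'n': 0}


def fake_send() -> bool:
  attempts['n'] += 1
  return attempts['n'] >= 3  # fails twice, then succeeds


print(send_with_backoff(fake_send))  # True on the third attempt
```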
@@ -807,7 +1625,9 @@ class BaseApiClient:
       file_path: Union[str, io.IOBase],
       upload_url: str,
       upload_size: int,
-
+      *,
+      http_options: Optional[HttpOptionsOrDict] = None,
+  ) -> HttpResponse:
     """Transfers a file asynchronously to the given URL.

     Args:
@@ -816,24 +1636,31 @@ class BaseApiClient:
       upload_url: The URL to upload the file to.
       upload_size: The size of file content to be uploaded, this will have to
         match the size requested in the resumable upload request.
+      http_options: The http options to use for the request.

     returns:
-      The
+      The HttpResponse object from the finalize request.
     """
     if isinstance(file_path, io.IOBase):
-      return await self._async_upload_fd(
+      return await self._async_upload_fd(
+          file_path, upload_url, upload_size, http_options=http_options
+      )
     else:
       file = anyio.Path(file_path)
       fd = await file.open('rb')
       async with fd:
-        return await self._async_upload_fd(
+        return await self._async_upload_fd(
+            fd, upload_url, upload_size, http_options=http_options
+        )

   async def _async_upload_fd(
       self,
-      file: Union[io.IOBase, anyio.AsyncFile],
+      file: Union[io.IOBase, anyio.AsyncFile[Any]],
       upload_url: str,
       upload_size: int,
-
+      *,
+      http_options: Optional[HttpOptionsOrDict] = None,
+  ) -> HttpResponse:
     """Transfers a file asynchronously to the given URL.

     Args:
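`async_upload_file` opens filesystem paths through `anyio.Path`, so reads go through an async file object instead of blocking the event loop. A small sketch of that pattern, assuming only the `anyio` package; the file name is a scratch example:

```python
import asyncio

import anyio


async def main() -> None:
  path = anyio.Path('example.bin')         # scratch file for the demo
  await path.write_bytes(b'hello world')   # create it so the read succeeds
  fd = await path.open('rb')               # returns an anyio.AsyncFile
  async with fd:
    print(await fd.read(1024))


asyncio.run(main())
```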
@@ -841,50 +1668,175 @@ class BaseApiClient:
       upload_url: The URL to upload the file to.
       upload_size: The size of file content to be uploaded, this will have to
         match the size requested in the resumable upload request.
+      http_options: The http options to use for the request.

     returns:
-      The
+      The HttpResponse object from the finalized request.
     """
     offset = 0
     # Upload the file in chunks
-
-
-
-
-
-
-
-      chunk_size =
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if self._use_aiohttp():  # pylint: disable=g-import-not-at-top
+      self._aiohttp_session = await self._get_aiohttp_session()
+      while True:
+        if isinstance(file, io.IOBase):
+          file_chunk = file.read(CHUNK_SIZE)
+        else:
+          file_chunk = await file.read(CHUNK_SIZE)
+        chunk_size = 0
+        if file_chunk:
+          chunk_size = len(file_chunk)
+        upload_command = 'upload'
+        # If last chunk, finalize the upload.
+        if chunk_size + offset >= upload_size:
+          upload_command += ', finalize'
+        http_options = http_options if http_options else self._http_options
+        timeout = (
+            http_options.get('timeout')
+            if isinstance(http_options, dict)
+            else http_options.timeout
+        )
+        if timeout is None:
+          # Per request timeout is not configured. Check the global timeout.
+          timeout = (
+              self._http_options.timeout
+              if isinstance(self._http_options, dict)
+              else self._http_options.timeout
+          )
+        timeout_in_seconds = get_timeout_in_seconds(timeout)
+        upload_headers = {
+            'X-Goog-Upload-Command': upload_command,
+            'X-Goog-Upload-Offset': str(offset),
+            'Content-Length': str(chunk_size),
+        }
+        populate_server_timeout_header(upload_headers, timeout_in_seconds)
+
+        retry_count = 0
+        response = None
+        while retry_count < MAX_RETRY_COUNT:
+          response = await self._aiohttp_session.request(
+              method='POST',
+              url=upload_url,
+              data=file_chunk,
+              headers=upload_headers,
+              timeout=aiohttp.ClientTimeout(connect=timeout_in_seconds),
+          )
+
+          if response.headers.get('X-Goog-Upload-Status'):
+            break
+          delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
+          retry_count += 1
+          await asyncio.sleep(delay_seconds)
+
+        offset += chunk_size
+        if (
+            response is not None
+            and response.headers.get('X-Goog-Upload-Status') != 'active'
+        ):
+          break  # upload is complete or it has been interrupted.
+
+      if upload_size <= offset:  # Status is not finalized.
+        raise ValueError(
+            f'All content has been uploaded, but the upload status is not'
+            f' finalized.'
+        )
+
+      await errors.APIError.raise_for_async_response(response)
+      if (
+          response is not None
+          and response.headers.get('X-Goog-Upload-Status') != 'final'
+      ):
+        raise ValueError(
+            'Failed to upload file: Upload status is not finalized.'
+        )
+      return HttpResponse(
+          response.headers, response_stream=[await response.text()]
       )
-
-
-
+    else:
+      # aiohttp is not available. Fall back to httpx.
+      while True:
+        if isinstance(file, io.IOBase):
+          file_chunk = file.read(CHUNK_SIZE)
+        else:
+          file_chunk = await file.read(CHUNK_SIZE)
+        chunk_size = 0
+        if file_chunk:
+          chunk_size = len(file_chunk)
+        upload_command = 'upload'
+        # If last chunk, finalize the upload.
+        if chunk_size + offset >= upload_size:
+          upload_command += ', finalize'
+        http_options = http_options if http_options else self._http_options
+        timeout = (
+            http_options.get('timeout')
+            if isinstance(http_options, dict)
+            else http_options.timeout
+        )
+        if timeout is None:
+          # Per request timeout is not configured. Check the global timeout.
+          timeout = (
+              self._http_options.timeout
+              if isinstance(self._http_options, dict)
+              else self._http_options.timeout
+          )
+        timeout_in_seconds = get_timeout_in_seconds(timeout)
+        upload_headers = {
+            'X-Goog-Upload-Command': upload_command,
+            'X-Goog-Upload-Offset': str(offset),
+            'Content-Length': str(chunk_size),
+        }
+        populate_server_timeout_header(upload_headers, timeout_in_seconds)
+
+        retry_count = 0
+        client_response = None
+        while retry_count < MAX_RETRY_COUNT:
+          client_response = await self._async_httpx_client.request(
+              method='POST',
+              url=upload_url,
+              content=file_chunk,
+              headers=upload_headers,
+              timeout=timeout_in_seconds,
+          )
+          if (
+              client_response is not None
+              and client_response.headers
+              and client_response.headers.get('x-goog-upload-status')
+          ):
+            break
+          delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
+          retry_count += 1
+          time.sleep(delay_seconds)
+
+        offset += chunk_size
+        if (
+            client_response is not None
+            and client_response.headers.get('x-goog-upload-status') != 'active'
+        ):
+          break  # upload is complete or it has been interrupted.
+
+      if upload_size <= offset:  # Status is not finalized.
+        raise ValueError(
+            'All content has been uploaded, but the upload status is not'
+            ' finalized.'
+        )

-
+      await errors.APIError.raise_for_async_response(client_response)
+      if (
+          client_response is not None
+          and client_response.headers.get('x-goog-upload-status') != 'final'
+      ):
         raise ValueError(
-          '
-          f' finalized.'
+            'Failed to upload file: Upload status is not finalized.'
         )
-
-
-          'Failed to upload file: Upload status is not finalized.'
+      return HttpResponse(
+          client_response.headers, response_stream=[client_response.text]
       )
-    return response.json()

-  async def async_download_file(
+  async def async_download_file(
+      self,
+      path: str,
+      *,
+      http_options: Optional[HttpOptionsOrDict] = None,
+  ) -> Union[Any, bytes]:
     """Downloads the file data.

     Args:
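One subtlety in the hunk above: the two transports bound time differently. The aiohttp calls pass `aiohttp.ClientTimeout(connect=...)`, which limits only connection establishment, while httpx's bare `timeout=` argument applies to connect, read, and write alike. A side-by-side sketch; the ten-second value is arbitrary:

```python
import aiohttp
import httpx

timeout_in_seconds = 10  # arbitrary example value

# aiohttp: only the connect phase is bounded; reads may run longer.
aiohttp_timeout = aiohttp.ClientTimeout(connect=timeout_in_seconds)

# httpx: a bare number bounds connect, read, write, and pool acquisition.
httpx_timeout = httpx.Timeout(timeout_in_seconds)

print(aiohttp_timeout, httpx_timeout)
```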
@@ -905,21 +1857,71 @@ class BaseApiClient:
     else:
       data = http_request.data

-
-
-
-
-
-
-
-
+    if self._use_aiohttp():
+      self._aiohttp_session = await self._get_aiohttp_session()
+      response = await self._aiohttp_session.request(
+          method=http_request.method,
+          url=http_request.url,
+          headers=http_request.headers,
+          data=data,
+          timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
+      )
+      await errors.APIError.raise_for_async_response(response)

-
-
-
+      return HttpResponse(
+          response.headers, byte_stream=[await response.read()]
+      ).byte_stream[0]
+    else:
+      # aiohttp is not available. Fall back to httpx.
+      client_response = await self._async_httpx_client.request(
+          method=http_request.method,
+          url=http_request.url,
+          headers=http_request.headers,
+          content=data,
+          timeout=http_request.timeout,
+      )
+      await errors.APIError.raise_for_async_response(client_response)
+
+      return HttpResponse(
+          client_response.headers, byte_stream=[client_response.read()]
+      ).byte_stream[0]

   # This method does nothing in the real api client. It is used in the
   # replay_api_client to verify the response from the SDK method matches the
   # recorded response.
-  def _verify_response(self, response_model: _common.BaseModel):
+  def _verify_response(self, response_model: _common.BaseModel) -> None:
     pass
+
+  def close(self) -> None:
+    """Closes the API client."""
+    # Let users close the custom client explicitly by themselves. Otherwise,
+    # close the client when the object is garbage collected.
+    if not self._http_options.httpx_client:
+      self._httpx_client.close()
+
+  async def aclose(self) -> None:
+    """Closes the API async client."""
+    # Let users close the custom client explicitly by themselves. Otherwise,
+    # close the client when the object is garbage collected.
+    if not self._http_options.httpx_async_client:
+      await self._async_httpx_client.aclose()
+    if self._aiohttp_session:
+      await self._aiohttp_session.close()
+
+  def __del__(self) -> None:
+    """Closes the API client when the object is garbage collected.
+
+    ADK uses this client so cannot rely on the genai.[Async]Client.__del__
+    for cleanup.
+    """
+
+    try:
+      if not self._http_options.httpx_client:
+        self.close()
+    except Exception:  # pylint: disable=broad-except
+      pass
+
+    try:
+      asyncio.get_running_loop().create_task(self.aclose())
+    except Exception:  # pylint: disable=broad-except
+      pass