google-genai 1.7.0__py3-none-any.whl → 1.53.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42):
  1. google/genai/__init__.py +4 -2
  2. google/genai/_adapters.py +55 -0
  3. google/genai/_api_client.py +1301 -299
  4. google/genai/_api_module.py +1 -1
  5. google/genai/_automatic_function_calling_util.py +54 -33
  6. google/genai/_base_transformers.py +26 -0
  7. google/genai/_base_url.py +50 -0
  8. google/genai/_common.py +560 -59
  9. google/genai/_extra_utils.py +371 -38
  10. google/genai/_live_converters.py +1467 -0
  11. google/genai/_local_tokenizer_loader.py +214 -0
  12. google/genai/_mcp_utils.py +117 -0
  13. google/genai/_operations_converters.py +394 -0
  14. google/genai/_replay_api_client.py +204 -92
  15. google/genai/_test_api_client.py +1 -1
  16. google/genai/_tokens_converters.py +520 -0
  17. google/genai/_transformers.py +633 -233
  18. google/genai/batches.py +1733 -538
  19. google/genai/caches.py +678 -1012
  20. google/genai/chats.py +48 -38
  21. google/genai/client.py +142 -15
  22. google/genai/documents.py +532 -0
  23. google/genai/errors.py +141 -35
  24. google/genai/file_search_stores.py +1296 -0
  25. google/genai/files.py +312 -744
  26. google/genai/live.py +617 -367
  27. google/genai/live_music.py +197 -0
  28. google/genai/local_tokenizer.py +395 -0
  29. google/genai/models.py +3598 -3116
  30. google/genai/operations.py +201 -362
  31. google/genai/pagers.py +23 -7
  32. google/genai/py.typed +1 -0
  33. google/genai/tokens.py +362 -0
  34. google/genai/tunings.py +1274 -496
  35. google/genai/types.py +14535 -5454
  36. google/genai/version.py +2 -2
  37. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/METADATA +736 -234
  38. google_genai-1.53.0.dist-info/RECORD +41 -0
  39. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/WHEEL +1 -1
  40. google_genai-1.7.0.dist-info/RECORD +0 -27
  41. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info/licenses}/LICENSE +0 -0
  42. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Google LLC
1
+ # Copyright 2025 Google LLC
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -19,34 +19,97 @@
19
19
  The BaseApiClient is intended to be a private module and is subject to change.
20
20
  """
21
21
 
22
- import anyio
23
22
  import asyncio
23
+ from collections.abc import Generator
24
24
  import copy
25
25
  from dataclasses import dataclass
26
- import datetime
26
+ import inspect
27
27
  import io
28
28
  import json
29
29
  import logging
30
+ import math
30
31
  import os
32
+ import random
33
+ import ssl
31
34
  import sys
32
- from typing import Any, AsyncIterator, Optional, Tuple, Union
33
- from urllib.parse import urlparse, urlunparse
35
+ import threading
36
+ import time
37
+ from typing import Any, AsyncIterator, Iterator, Optional, Tuple, TYPE_CHECKING, Union
38
+ from urllib.parse import urlparse
39
+ from urllib.parse import urlunparse
40
+ import warnings
41
+
42
+ import anyio
43
+ import certifi
34
44
  import google.auth
35
45
  import google.auth.credentials
36
46
  from google.auth.credentials import Credentials
37
- from google.auth.transport.requests import Request
38
47
  import httpx
39
- from pydantic import BaseModel, Field, ValidationError
48
+ from pydantic import BaseModel
49
+ from pydantic import ValidationError
50
+ import tenacity
51
+
40
52
  from . import _common
41
53
  from . import errors
42
54
  from . import version
43
- from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
55
+ from .types import HttpOptions
56
+ from .types import HttpOptionsOrDict
57
+ from .types import HttpResponse as SdkHttpResponse
58
+ from .types import HttpRetryOptions
59
+ from .types import ResourceScope
60
+
61
+
62
+ try:
63
+ from websockets.asyncio.client import connect as ws_connect
64
+ except ModuleNotFoundError:
65
+ # This try/except is for TAP, mypy complains about it which is why we have the type: ignore
66
+ from websockets.client import connect as ws_connect # type: ignore
67
+
68
+ has_aiohttp = False
69
+ try:
70
+ import aiohttp
71
+
72
+ has_aiohttp = True
73
+ except ImportError:
74
+ pass
75
+
76
+
77
+ if TYPE_CHECKING:
78
+ from multidict import CIMultiDictProxy
79
+
44
80
 
45
81
  logger = logging.getLogger('google_genai._api_client')
46
82
  CHUNK_SIZE = 8 * 1024 * 1024 # 8 MB chunk size
83
+ READ_BUFFER_SIZE = 2**22
84
+ MAX_RETRY_COUNT = 3
85
+ INITIAL_RETRY_DELAY = 1 # second
86
+ DELAY_MULTIPLIER = 2
87
+
88
+
89
+ class EphemeralTokenAPIKeyError(ValueError):
90
+ """Error raised when the API key is invalid."""
91
+
92
+
93
+ # This method checks for the API key in the environment variables. Google API
94
+ # key is precedenced over Gemini API key.
95
+ def get_env_api_key() -> Optional[str]:
96
+ """Gets the API key from environment variables, prioritizing GOOGLE_API_KEY.
97
+
98
+ Returns:
99
+ The API key string if found, otherwise None. Empty string is considered
100
+ invalid.
101
+ """
102
+ env_google_api_key = os.environ.get('GOOGLE_API_KEY', None)
103
+ env_gemini_api_key = os.environ.get('GEMINI_API_KEY', None)
104
+ if env_google_api_key and env_gemini_api_key:
105
+ logger.warning(
106
+ 'Both GOOGLE_API_KEY and GEMINI_API_KEY are set. Using GOOGLE_API_KEY.'
107
+ )
108
+
109
+ return env_google_api_key or env_gemini_api_key or None
47
110
 
48
111
 
49
- def _append_library_version_headers(headers: dict[str, str]) -> None:
112
+ def append_library_version_headers(headers: dict[str, str]) -> None:
50
113
  """Appends the telemetry header to the headers dict."""
51
114
  library_label = f'google-genai-sdk/{version.__version__}'
52
115
  language_label = 'gl-python/' + sys.version.split()[0]
@@ -55,43 +118,57 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
55
118
  'user-agent' in headers
56
119
  and version_header_value not in headers['user-agent']
57
120
  ):
58
- headers['user-agent'] += f' {version_header_value}'
121
+ headers['user-agent'] = f'{version_header_value} ' + headers['user-agent']
59
122
  elif 'user-agent' not in headers:
60
123
  headers['user-agent'] = version_header_value
61
124
  if (
62
125
  'x-goog-api-client' in headers
63
126
  and version_header_value not in headers['x-goog-api-client']
64
127
  ):
65
- headers['x-goog-api-client'] += f' {version_header_value}'
128
+ headers['x-goog-api-client'] = (
129
+ f'{version_header_value} ' + headers['x-goog-api-client']
130
+ )
66
131
  elif 'x-goog-api-client' not in headers:
67
132
  headers['x-goog-api-client'] = version_header_value
68
133
 
69
134
 
70
- def _patch_http_options(
71
- options: HttpOptionsDict, patch_options: dict[str, Any]
72
- ) -> HttpOptionsDict:
73
- # use shallow copy so we don't override the original objects.
74
- copy_option = HttpOptionsDict()
75
- copy_option.update(options)
76
- for patch_key, patch_value in patch_options.items():
77
- # if both are dicts, update the copy.
78
- # This is to handle cases like merging headers.
79
- if isinstance(patch_value, dict) and isinstance(
80
- copy_option.get(patch_key, None), dict
81
- ):
82
- copy_option[patch_key] = {}
83
- copy_option[patch_key].update(
84
- options[patch_key]
85
- ) # shallow copy from original options.
86
- copy_option[patch_key].update(patch_value)
87
- elif patch_value is not None: # Accept empty values.
88
- copy_option[patch_key] = patch_value
89
- if copy_option['headers']:
90
- _append_library_version_headers(copy_option['headers'])
135
+ def patch_http_options(
136
+ options: HttpOptions, patch_options: HttpOptions
137
+ ) -> HttpOptions:
138
+ copy_option = options.model_copy()
139
+
140
+ options_headers = copy_option.headers or {}
141
+ patch_options_headers = patch_options.headers or {}
142
+ copy_option.headers = {
143
+ **options_headers,
144
+ **patch_options_headers,
145
+ }
146
+
147
+ http_options_keys = HttpOptions.model_fields.keys()
148
+
149
+ for key in http_options_keys:
150
+ if key == 'headers':
151
+ continue
152
+ patch_value = getattr(patch_options, key, None)
153
+ if patch_value is not None:
154
+ setattr(copy_option, key, patch_value)
155
+ else:
156
+ setattr(copy_option, key, getattr(options, key))
157
+
158
+ if copy_option.headers is not None:
159
+ append_library_version_headers(copy_option.headers)
91
160
  return copy_option
92
161
 
93
162
 
94
- def _join_url_path(base_url: str, path: str) -> str:
163
+ def populate_server_timeout_header(
164
+ headers: dict[str, str], timeout_in_seconds: Optional[Union[float, int]]
165
+ ) -> None:
166
+ """Populates the server timeout header in the headers dict."""
167
+ if timeout_in_seconds and 'X-Server-Timeout' not in headers:
168
+ headers['X-Server-Timeout'] = str(math.ceil(timeout_in_seconds))
169
+
170
+
171
+ def join_url_path(base_url: str, path: str) -> str:
95
172
  parsed_base = urlparse(base_url)
96
173
  base_path = (
97
174
  parsed_base.path[:-1]
@@ -102,9 +179,9 @@ def _join_url_path(base_url: str, path: str) -> str:
102
179
  return urlunparse(parsed_base._replace(path=base_path + '/' + path))
103
180
 
104
181
 
105
- def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
182
+ def load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
106
183
  """Loads google auth credentials and project id."""
107
- credentials, loaded_project_id = google.auth.default(
184
+ credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call]
108
185
  scopes=['https://www.googleapis.com/auth/cloud-platform'],
109
186
  )
110
187
 
@@ -119,11 +196,25 @@ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
119
196
  return credentials, project
120
197
 
121
198
 
122
- def _refresh_auth(credentials: Credentials) -> Credentials:
123
- credentials.refresh(Request())
199
+ def refresh_auth(credentials: Credentials) -> Credentials:
200
+ from google.auth.transport.requests import Request
201
+ credentials.refresh(Request()) # type: ignore[no-untyped-call]
124
202
  return credentials
125
203
 
126
204
 
205
+ def get_timeout_in_seconds(
206
+ timeout: Optional[Union[float, int]],
207
+ ) -> Optional[float]:
208
+ """Converts the timeout to seconds."""
209
+ if timeout:
210
+ # HttpOptions.timeout is in milliseconds. But httpx.Client.request()
211
+ # expects seconds.
212
+ timeout_in_seconds = timeout / 1000.0
213
+ else:
214
+ timeout_in_seconds = None
215
+ return timeout_in_seconds
216
+
217
+
127
218
  @dataclass
128
219
  class HttpRequest:
129
220
  headers: dict[str, str]
@@ -133,37 +224,35 @@ class HttpRequest:
133
224
  timeout: Optional[float] = None
134
225
 
135
226
 
136
- # TODO(b/394358912): Update this class to use a SDKResponse class that can be
137
- # generated and used for all languages.
138
- class BaseResponse(_common.BaseModel):
139
- http_headers: Optional[dict[str, str]] = Field(
140
- default=None, description='The http headers of the response.'
141
- )
142
-
143
- json_payload: Optional[Any] = Field(
144
- default=None, description='The json payload of the response.'
145
- )
146
-
147
-
148
227
  class HttpResponse:
149
228
 
150
229
  def __init__(
151
230
  self,
152
- headers: Union[dict[str, str], httpx.Headers],
231
+ headers: Union[dict[str, str], httpx.Headers, 'CIMultiDictProxy[str]'],
153
232
  response_stream: Union[Any, str] = None,
154
233
  byte_stream: Union[Any, bytes] = None,
155
234
  ):
156
- self.status_code = 200
157
- self.headers = headers
235
+ if isinstance(headers, dict):
236
+ self.headers = headers
237
+ elif isinstance(headers, httpx.Headers):
238
+ self.headers = {
239
+ key: ', '.join(headers.get_list(key)) for key in headers.keys()
240
+ }
241
+ elif type(headers).__name__ == 'CIMultiDictProxy':
242
+ self.headers = {
243
+ key: ', '.join(headers.getall(key)) for key in headers.keys()
244
+ }
245
+
246
+ self.status_code: int = 200
158
247
  self.response_stream = response_stream
159
248
  self.byte_stream = byte_stream
160
249
 
161
250
  # Async iterator for async streaming.
162
- def __aiter__(self):
251
+ def __aiter__(self) -> 'HttpResponse':
163
252
  self.segment_iterator = self.async_segments()
164
253
  return self
165
254
 
166
- async def __anext__(self):
255
+ async def __anext__(self) -> Any:
167
256
  try:
168
257
  return await self.segment_iterator.__anext__()
169
258
  except StopIteration:
@@ -173,54 +262,34 @@ class HttpResponse:
173
262
  def json(self) -> Any:
174
263
  if not self.response_stream[0]: # Empty response
175
264
  return ''
176
- return json.loads(self.response_stream[0])
265
+ return self._load_json_from_response(self.response_stream[0])
177
266
 
178
- def segments(self):
267
+ def segments(self) -> Generator[Any, None, None]:
179
268
  if isinstance(self.response_stream, list):
180
269
  # list of objects retrieved from replay or from non-streaming API.
181
270
  for chunk in self.response_stream:
182
- yield json.loads(chunk) if chunk else {}
271
+ yield self._load_json_from_response(chunk) if chunk else {}
183
272
  elif self.response_stream is None:
184
273
  yield from []
185
274
  else:
186
275
  # Iterator of objects retrieved from the API.
187
- for chunk in self.response_stream.iter_lines():
188
- if chunk:
189
- # In streaming mode, the chunk of JSON is prefixed with "data:" which
190
- # we must strip before parsing.
191
- if not isinstance(chunk, str):
192
- chunk = chunk.decode('utf-8')
193
- if chunk.startswith('data: '):
194
- chunk = chunk[len('data: ') :]
195
- yield json.loads(chunk)
276
+ for chunk in self._iter_response_stream():
277
+ yield self._load_json_from_response(chunk)
196
278
 
197
279
  async def async_segments(self) -> AsyncIterator[Any]:
198
280
  if isinstance(self.response_stream, list):
199
281
  # list of objects retrieved from replay or from non-streaming API.
200
282
  for chunk in self.response_stream:
201
- yield json.loads(chunk) if chunk else {}
283
+ yield self._load_json_from_response(chunk) if chunk else {}
202
284
  elif self.response_stream is None:
203
- async for c in []:
285
+ async for c in []: # type: ignore[attr-defined]
204
286
  yield c
205
287
  else:
206
288
  # Iterator of objects retrieved from the API.
207
- if hasattr(self.response_stream, 'aiter_lines'):
208
- async for chunk in self.response_stream.aiter_lines():
209
- # This is httpx.Response.
210
- if chunk:
211
- # In async streaming mode, the chunk of JSON is prefixed with
212
- # "data:" which we must strip before parsing.
213
- if not isinstance(chunk, str):
214
- chunk = chunk.decode('utf-8')
215
- if chunk.startswith('data: '):
216
- chunk = chunk[len('data: ') :]
217
- yield json.loads(chunk)
218
- else:
219
- raise ValueError(
220
- 'Error parsing streaming response.'
221
- )
289
+ async for chunk in self._aiter_response_stream():
290
+ yield self._load_json_from_response(chunk)
222
291
 
223
- def byte_segments(self):
292
+ def byte_segments(self) -> Generator[Union[bytes, Any], None, None]:
224
293
  if isinstance(self.byte_stream, list):
225
294
  # list of objects retrieved from replay or from non-streaming API.
226
295
  yield from self.byte_stream
@@ -231,12 +300,199 @@ class HttpResponse:
231
300
  'Byte segments are not supported for streaming responses.'
232
301
  )
233
302
 
234
- def _copy_to_dict(self, response_payload: dict[str, object]):
303
+ def _copy_to_dict(self, response_payload: dict[str, object]) -> None:
235
304
  # Cannot pickle 'generator' object.
236
305
  delattr(self, 'segment_iterator')
237
306
  for attribute in dir(self):
238
307
  response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
239
308
 
309
+ def _iter_response_stream(self) -> Iterator[str]:
310
+ """Iterates over chunks retrieved from the API."""
311
+ if not isinstance(self.response_stream, httpx.Response):
312
+ raise TypeError(
313
+ 'Expected self.response_stream to be an httpx.Response object, '
314
+ f'but got {type(self.response_stream).__name__}.'
315
+ )
316
+
317
+ chunk = ''
318
+ balance = 0
319
+ for line in self.response_stream.iter_lines():
320
+ if not line:
321
+ continue
322
+
323
+ # In streaming mode, the response of JSON is prefixed with "data: " which
324
+ # we must strip before parsing.
325
+ if line.startswith('data: '):
326
+ yield line[len('data: '):]
327
+ continue
328
+
329
+ # When API returns an error message, it comes line by line. So we buffer
330
+ # the lines until a complete JSON string is read. A complete JSON string
331
+ # is found when the balance is 0.
332
+ for c in line:
333
+ if c == '{':
334
+ balance += 1
335
+ elif c == '}':
336
+ balance -= 1
337
+
338
+ chunk += line
339
+ if balance == 0:
340
+ yield chunk
341
+ chunk = ''
342
+
343
+ # If there is any remaining chunk, yield it.
344
+ if chunk:
345
+ yield chunk
346
+
347
+ async def _aiter_response_stream(self) -> AsyncIterator[str]:
348
+ """Asynchronously iterates over chunks retrieved from the API."""
349
+ is_valid_response = isinstance(self.response_stream, httpx.Response) or (
350
+ has_aiohttp and isinstance(self.response_stream, aiohttp.ClientResponse)
351
+ )
352
+ if not is_valid_response:
353
+ raise TypeError(
354
+ 'Expected self.response_stream to be an httpx.Response or'
355
+ ' aiohttp.ClientResponse object, but got'
356
+ f' {type(self.response_stream).__name__}.'
357
+ )
358
+
359
+ chunk = ''
360
+ balance = 0
361
+ # httpx.Response has a dedicated async line iterator.
362
+ if isinstance(self.response_stream, httpx.Response):
363
+ try:
364
+ async for line in self.response_stream.aiter_lines():
365
+ if not line:
366
+ continue
367
+ # In streaming mode, the response of JSON is prefixed with "data: "
368
+ # which we must strip before parsing.
369
+ if line.startswith('data: '):
370
+ yield line[len('data: '):]
371
+ continue
372
+
373
+ # When API returns an error message, it comes line by line. So we buffer
374
+ # the lines until a complete JSON string is read. A complete JSON string
375
+ # is found when the balance is 0.
376
+ for c in line:
377
+ if c == '{':
378
+ balance += 1
379
+ elif c == '}':
380
+ balance -= 1
381
+
382
+ chunk += line
383
+ if balance == 0:
384
+ yield chunk
385
+ chunk = ''
386
+ # If there is any remaining chunk, yield it.
387
+ if chunk:
388
+ yield chunk
389
+ finally:
390
+ # Close the response and release the connection.
391
+ await self.response_stream.aclose()
392
+
393
+ # aiohttp.ClientResponse uses a content stream that we read line by line.
394
+ elif has_aiohttp and isinstance(
395
+ self.response_stream, aiohttp.ClientResponse
396
+ ):
397
+ try:
398
+ while True:
399
+ # Read a line from the stream. This returns bytes.
400
+ line_bytes = await self.response_stream.content.readline()
401
+ if not line_bytes:
402
+ break
403
+ # Decode the bytes and remove trailing whitespace and newlines.
404
+ line = line_bytes.decode('utf-8').rstrip()
405
+ if not line:
406
+ continue
407
+
408
+ # In streaming mode, the response of JSON is prefixed with "data: "
409
+ # which we must strip before parsing.
410
+ if line.startswith('data: '):
411
+ yield line[len('data: '):]
412
+ continue
413
+
414
+ # When API returns an error message, it comes line by line. So we
415
+ # buffer the lines until a complete JSON string is read. A complete
416
+ # JSON strings found when the balance is 0.
417
+ for c in line:
418
+ if c == '{':
419
+ balance += 1
420
+ elif c == '}':
421
+ balance -= 1
422
+
423
+ chunk += line
424
+ if balance == 0:
425
+ yield chunk
426
+ chunk = ''
427
+ # If there is any remaining chunk, yield it.
428
+ if chunk:
429
+ yield chunk
430
+ finally:
431
+ # Release the connection back to the pool for potential reuse.
432
+ self.response_stream.release()
433
+
434
+ @classmethod
435
+ def _load_json_from_response(cls, response: Any) -> Any:
436
+ """Loads JSON from the response, or raises an error if the parsing fails."""
437
+ try:
438
+ return json.loads(response)
439
+ except json.JSONDecodeError as e:
440
+ raise errors.UnknownApiResponseError(
441
+ f'Failed to parse response as JSON. Raw response: {response}'
442
+ ) from e
443
+
444
+
445
+ # Default retry options.
446
+ # The config is based on https://cloud.google.com/storage/docs/retry-strategy.
447
+ # By default, the client will retry 4 times with approximately 1.0, 2.0, 4.0,
448
+ # 8.0 seconds between each attempt.
449
+ _RETRY_ATTEMPTS = 5 # including the initial call.
450
+ _RETRY_INITIAL_DELAY = 1.0 # seconds
451
+ _RETRY_MAX_DELAY = 60.0 # seconds
452
+ _RETRY_EXP_BASE = 2
453
+ _RETRY_JITTER = 1
454
+ _RETRY_HTTP_STATUS_CODES = (
455
+ 408, # Request timeout.
456
+ 429, # Too many requests.
457
+ 500, # Internal server error.
458
+ 502, # Bad gateway.
459
+ 503, # Service unavailable.
460
+ 504, # Gateway timeout
461
+ )
462
+
463
+
464
+ def retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
465
+ """Returns the retry args for the given http retry options.
466
+
467
+ Args:
468
+ options: The http retry options to use for the retry configuration. If None,
469
+ the 'never retry' stop strategy will be used.
470
+
471
+ Returns:
472
+ The arguments passed to the tenacity.(Async)Retrying constructor.
473
+ """
474
+ if options is None:
475
+ return {'stop': tenacity.stop_after_attempt(1), 'reraise': True}
476
+
477
+ stop = tenacity.stop_after_attempt(options.attempts or _RETRY_ATTEMPTS)
478
+ retriable_codes = options.http_status_codes or _RETRY_HTTP_STATUS_CODES
479
+ retry = tenacity.retry_if_exception(
480
+ lambda e: isinstance(e, errors.APIError) and e.code in retriable_codes,
481
+ )
482
+ wait = tenacity.wait_exponential_jitter(
483
+ initial=options.initial_delay or _RETRY_INITIAL_DELAY,
484
+ max=options.max_delay or _RETRY_MAX_DELAY,
485
+ exp_base=options.exp_base or _RETRY_EXP_BASE,
486
+ jitter=options.jitter or _RETRY_JITTER,
487
+ )
488
+ return {
489
+ 'stop': stop,
490
+ 'retry': retry,
491
+ 'reraise': True,
492
+ 'wait': wait,
493
+ 'before_sleep': tenacity.before_sleep_log(logger, logging.INFO),
494
+ }
495
+
240
496
 
241
497
  class SyncHttpxClient(httpx.Client):
242
498
  """Sync httpx client."""
@@ -248,8 +504,11 @@ class SyncHttpxClient(httpx.Client):
248
504
 
249
505
  def __del__(self) -> None:
250
506
  """Closes the httpx client."""
251
- if self.is_closed:
252
- return
507
+ try:
508
+ if self.is_closed:
509
+ return
510
+ except Exception:
511
+ pass
253
512
  try:
254
513
  self.close()
255
514
  except Exception:
@@ -265,8 +524,11 @@ class AsyncHttpxClient(httpx.AsyncClient):
265
524
  super().__init__(**kwargs)
266
525
 
267
526
  def __del__(self) -> None:
268
- if self.is_closed:
269
- return
527
+ try:
528
+ if self.is_closed:
529
+ return
530
+ except Exception:
531
+ pass
270
532
  try:
271
533
  asyncio.get_running_loop().create_task(self.aclose())
272
534
  except Exception:
@@ -286,6 +548,7 @@ class BaseApiClient:
286
548
  http_options: Optional[HttpOptionsOrDict] = None,
287
549
  ):
288
550
  self.vertexai = vertexai
551
+ self.custom_base_url = None
289
552
  if self.vertexai is None:
290
553
  if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [
291
554
  'true',
@@ -308,36 +571,42 @@ class BaseApiClient:
308
571
  )
309
572
 
310
573
  # Validate http_options if it is provided.
311
- validated_http_options: dict[str, Any]
574
+ validated_http_options = HttpOptions()
312
575
  if isinstance(http_options, dict):
313
576
  try:
314
- validated_http_options = HttpOptions.model_validate(
315
- http_options
316
- ).model_dump()
577
+ validated_http_options = HttpOptions.model_validate(http_options)
317
578
  except ValidationError as e:
318
- raise ValueError(f'Invalid http_options: {e}')
579
+ raise ValueError('Invalid http_options') from e
319
580
  elif isinstance(http_options, HttpOptions):
320
- validated_http_options = http_options.model_dump()
581
+ validated_http_options = http_options
582
+
583
+ if validated_http_options.base_url_resource_scope and not validated_http_options.base_url:
584
+ # base_url_resource_scope is only valid when base_url is set.
585
+ raise ValueError(
586
+ 'base_url must be set when base_url_resource_scope is set.'
587
+ )
321
588
 
322
589
  # Retrieve implicitly set values from the environment.
323
590
  env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
324
591
  env_location = os.environ.get('GOOGLE_CLOUD_LOCATION', None)
325
- env_api_key = os.environ.get('GOOGLE_API_KEY', None)
592
+ env_api_key = get_env_api_key()
326
593
  self.project = project or env_project
327
594
  self.location = location or env_location
328
595
  self.api_key = api_key or env_api_key
329
596
 
330
597
  self._credentials = credentials
331
- self._http_options = HttpOptionsDict()
598
+ self._http_options = HttpOptions()
332
599
  # Initialize the lock. This lock will be used to protect access to the
333
600
  # credentials. This is crucial for thread safety when multiple coroutines
334
601
  # might be accessing the credentials at the same time.
335
- self._auth_lock = asyncio.Lock()
602
+ self._sync_auth_lock = threading.Lock()
603
+ self._async_auth_lock: Optional[asyncio.Lock] = None
604
+ self._async_auth_lock_creation_lock: Optional[asyncio.Lock] = None
336
605
 
337
606
  # Handle when to use Vertex AI in express mode (api key).
338
607
  # Explicit initializer arguments are already validated above.
339
608
  if self.vertexai:
340
- if credentials:
609
+ if credentials and env_api_key:
341
610
  # Explicit credentials take precedence over implicit api_key.
342
611
  logger.info(
343
612
  'The user provided Google Cloud credentials will take precedence'
@@ -366,94 +635,412 @@ class BaseApiClient:
366
635
  + ' precedence over the API key from the environment variables.'
367
636
  )
368
637
  self.api_key = None
369
- if not self.project and not self.api_key:
370
- credentials, self.project = _load_auth(project=None)
638
+
639
+ self.custom_base_url = (
640
+ validated_http_options.base_url
641
+ if validated_http_options.base_url
642
+ else None
643
+ )
644
+
645
+ if not self.location and not self.api_key and not self.custom_base_url:
646
+ self.location = 'global'
647
+
648
+ # Skip fetching project from ADC if base url is provided in http options.
649
+ if (
650
+ not self.project
651
+ and not self.api_key
652
+ and not self.custom_base_url
653
+ ):
654
+ credentials, self.project = load_auth(project=None)
371
655
  if not self._credentials:
372
656
  self._credentials = credentials
373
- if not ((self.project and self.location) or self.api_key):
657
+
658
+ has_sufficient_auth = (self.project and self.location) or self.api_key
659
+
660
+ if not has_sufficient_auth and not self.custom_base_url:
661
+ # Skip sufficient auth check if base url is provided in http options.
374
662
  raise ValueError(
375
- 'Project and location or API key must be set when using the Vertex '
663
+ 'Project or API key must be set when using the Vertex '
376
664
  'AI API.'
377
665
  )
378
666
  if self.api_key or self.location == 'global':
379
- self._http_options['base_url'] = f'https://aiplatform.googleapis.com/'
667
+ self._http_options.base_url = f'https://aiplatform.googleapis.com/'
668
+ elif self.custom_base_url and not ((project and location) or api_key):
669
+ # Avoid setting default base url and api version if base_url provided.
670
+ # API gateway proxy can use the auth in custom headers, not url.
671
+ # Enable custom url if auth is not sufficient.
672
+ self._http_options.base_url = self.custom_base_url
673
+ # Clear project and location if base_url is provided.
674
+ self.project = None
675
+ self.location = None
380
676
  else:
381
- self._http_options['base_url'] = (
677
+ self._http_options.base_url = (
382
678
  f'https://{self.location}-aiplatform.googleapis.com/'
383
679
  )
384
- self._http_options['api_version'] = 'v1beta1'
680
+ self._http_options.api_version = 'v1beta1'
385
681
  else: # Implicit initialization or missing arguments.
386
682
  if not self.api_key:
387
683
  raise ValueError(
388
684
  'Missing key inputs argument! To use the Google AI API,'
389
- 'provide (`api_key`) arguments. To use the Google Cloud API,'
685
+ ' provide (`api_key`) arguments. To use the Google Cloud API,'
390
686
  ' provide (`vertexai`, `project` & `location`) arguments.'
391
687
  )
392
- self._http_options['base_url'] = (
393
- 'https://generativelanguage.googleapis.com/'
394
- )
395
- self._http_options['api_version'] = 'v1beta'
688
+ self._http_options.base_url = 'https://generativelanguage.googleapis.com/'
689
+ self._http_options.api_version = 'v1beta'
396
690
  # Default options for both clients.
397
- self._http_options['headers'] = {'Content-Type': 'application/json'}
691
+ self._http_options.headers = {'Content-Type': 'application/json'}
398
692
  if self.api_key:
399
- self._http_options['headers']['x-goog-api-key'] = self.api_key
693
+ self.api_key = self.api_key.strip()
694
+ if self._http_options.headers is not None:
695
+ self._http_options.headers['x-goog-api-key'] = self.api_key
400
696
  # Update the http options with the user provided http options.
401
697
  if http_options:
402
- self._http_options = _patch_http_options(
698
+ self._http_options = patch_http_options(
403
699
  self._http_options, validated_http_options
404
700
  )
405
701
  else:
406
- _append_library_version_headers(self._http_options['headers'])
407
- # Initialize the httpx client.
408
- self._httpx_client = SyncHttpxClient()
409
- self._async_httpx_client = AsyncHttpxClient()
702
+ if self._http_options.headers is not None:
703
+ append_library_version_headers(self._http_options.headers)
704
+
705
+ client_args, async_client_args = self._ensure_httpx_ssl_ctx(
706
+ self._http_options
707
+ )
708
+ self._async_httpx_client_args = async_client_args
709
+
710
+ if self._http_options.httpx_client:
711
+ self._httpx_client = self._http_options.httpx_client
712
+ else:
713
+ self._httpx_client = SyncHttpxClient(**client_args)
714
+ if self._http_options.httpx_async_client:
715
+ self._async_httpx_client = self._http_options.httpx_async_client
716
+ else:
717
+ self._async_httpx_client = AsyncHttpxClient(**async_client_args)
718
+ if self._use_aiohttp():
719
+ try:
720
+ import aiohttp # pylint: disable=g-import-not-at-top
721
+ # Do it once at the genai.Client level. Share among all requests.
722
+ self._async_client_session_request_args = self._ensure_aiohttp_ssl_ctx(
723
+ self._http_options
724
+ )
725
+ except ImportError:
726
+ pass
727
+
728
+ # Initialize the aiohttp client session.
729
+ self._aiohttp_session: Optional[aiohttp.ClientSession] = None
730
+
731
+ retry_kwargs = retry_args(self._http_options.retry_options)
732
+ self._websocket_ssl_ctx = self._ensure_websocket_ssl_ctx(self._http_options)
733
+ self._retry = tenacity.Retrying(**retry_kwargs)
734
+ self._async_retry = tenacity.AsyncRetrying(**retry_kwargs)
735
+
736
+ async def _get_aiohttp_session(self) -> 'aiohttp.ClientSession':
737
+ """Returns the aiohttp client session."""
738
+ if (
739
+ self._aiohttp_session is None
740
+ or self._aiohttp_session.closed
741
+ or self._aiohttp_session._loop.is_closed() # pylint: disable=protected-access
742
+ ):
743
+ # Initialize the aiohttp client session if it's not set up or closed.
744
+ class AiohttpClientSession(aiohttp.ClientSession): # type: ignore[misc]
745
+
746
+ def __del__(self, _warnings: Any = warnings) -> None:
747
+ if not self.closed:
748
+ context = {
749
+ 'client_session': self,
750
+ 'message': 'Unclosed client session',
751
+ }
752
+ if self._source_traceback is not None:
753
+ context['source_traceback'] = self._source_traceback
754
+ # Remove this self._loop.call_exception_handler(context)
755
+
756
+ class AiohttpTCPConnector(aiohttp.TCPConnector): # type: ignore[misc]
757
+
758
+ def __del__(self, _warnings: Any = warnings) -> None:
759
+ if self._closed:
760
+ return
761
+ if not self._conns:
762
+ return
763
+ conns = [repr(c) for c in self._conns.values()]
764
+ # After v3.13.2, it may change to self._close_immediately()
765
+ self._close()
766
+ context = {
767
+ 'connector': self,
768
+ 'connections': conns,
769
+ 'message': 'Unclosed connector',
770
+ }
771
+ if self._source_traceback is not None:
772
+ context['source_traceback'] = self._source_traceback
773
+ # Remove this self._loop.call_exception_handler(context)
774
+ self._aiohttp_session = AiohttpClientSession(
775
+ connector=AiohttpTCPConnector(limit=0),
776
+ trust_env=True,
777
+ read_bufsize=READ_BUFFER_SIZE,
778
+ )
779
+ return self._aiohttp_session
780
+
781
+ @staticmethod
782
+ def _ensure_httpx_ssl_ctx(
783
+ options: HttpOptions,
784
+ ) -> Tuple[_common.StringDict, _common.StringDict]:
785
+ """Ensures the SSL context is present in the HTTPX client args.
410
786
 
411
- def _websocket_base_url(self):
412
- url_parts = urlparse(self._http_options['base_url'])
413
- return url_parts._replace(scheme='wss').geturl()
787
+ Creates a default SSL context if one is not provided.
788
+
789
+ Args:
790
+ options: The http options to check for SSL context.
791
+
792
+ Returns:
793
+ A tuple of sync/async httpx client args.
794
+ """
795
+
796
+ verify = 'verify'
797
+ args = options.client_args
798
+ async_args = options.async_client_args
799
+ ctx = (
800
+ args.get(verify)
801
+ if args
802
+ else None or async_args.get(verify)
803
+ if async_args
804
+ else None
805
+ )
806
+
807
+ if not ctx:
808
+ # Initialize the SSL context for the httpx client.
809
+ # Unlike requests, the httpx package does not automatically pull in the
810
+ # environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
811
+ # enabled explicitly.
812
+ ctx = ssl.create_default_context(
813
+ cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
814
+ capath=os.environ.get('SSL_CERT_DIR'),
815
+ )
816
+
817
+ def _maybe_set(
818
+ args: Optional[_common.StringDict],
819
+ ctx: ssl.SSLContext,
820
+ ) -> _common.StringDict:
821
+ """Sets the SSL context in the client args if not set.
822
+
823
+ Does not override the SSL context if it is already set.
824
+
825
+ Args:
826
+ args: The client args to to check for SSL context.
827
+ ctx: The SSL context to set.
828
+
829
+ Returns:
830
+ The client args with the SSL context included.
831
+ """
832
+ if not args or not args.get(verify):
833
+ args = (args or {}).copy()
834
+ args[verify] = ctx
835
+ # Drop the args that isn't used by the httpx client.
836
+ copied_args = args.copy()
837
+ for key in copied_args.copy():
838
+ if key not in inspect.signature(httpx.Client.__init__).parameters:
839
+ del copied_args[key]
840
+ return copied_args
841
+
842
+ return (
843
+ _maybe_set(args, ctx),
844
+ _maybe_set(async_args, ctx),
845
+ )
846
+
847
+ @staticmethod
848
+ def _ensure_aiohttp_ssl_ctx(options: HttpOptions) -> _common.StringDict:
849
+ """Ensures the SSL context is present in the async client args.
850
+
851
+ Creates a default SSL context if one is not provided.
852
+
853
+ Args:
854
+ options: The http options to check for SSL context.
855
+
856
+ Returns:
857
+ An async aiohttp ClientSession._request args.
858
+ """
859
+ verify = 'ssl' # keep it consistent with httpx.
860
+ async_args = options.async_client_args
861
+ ctx = async_args.get(verify) if async_args else None
862
+
863
+ if not ctx:
864
+ # Initialize the SSL context for the httpx client.
865
+ # Unlike requests, the aiohttp package does not automatically pull in the
866
+ # environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
867
+ # enabled explicitly. Instead of 'verify' at client level in httpx,
868
+ # aiohttp uses 'ssl' at request level.
869
+ ctx = ssl.create_default_context(
870
+ cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
871
+ capath=os.environ.get('SSL_CERT_DIR'),
872
+ )
873
+
874
+ def _maybe_set(
875
+ args: Optional[_common.StringDict],
876
+ ctx: ssl.SSLContext,
877
+ ) -> _common.StringDict:
878
+ """Sets the SSL context in the client args if not set.
879
+
880
+ Does not override the SSL context if it is already set.
881
+
882
+ Args:
883
+ args: The client args to to check for SSL context.
884
+ ctx: The SSL context to set.
885
+
886
+ Returns:
887
+ The client args with the SSL context included.
888
+ """
889
+ if not args or not args.get(verify):
890
+ args = (args or {}).copy()
891
+ args[verify] = ctx
892
+ # Drop the args that isn't in the aiohttp RequestOptions.
893
+ copied_args = args.copy()
894
+ for key in copied_args.copy():
895
+ if (
896
+ key
897
+ not in inspect.signature(aiohttp.ClientSession._request).parameters
898
+ ):
899
+ del copied_args[key]
900
+ return copied_args
901
+
902
+ return _maybe_set(async_args, ctx)
903
+
904
+ @staticmethod
905
+ def _ensure_websocket_ssl_ctx(options: HttpOptions) -> _common.StringDict:
906
+ """Ensures the SSL context is present in the async client args.
907
+
908
+ Creates a default SSL context if one is not provided.
909
+
910
+ Args:
911
+ options: The http options to check for SSL context.
912
+
913
+ Returns:
914
+ An async aiohttp ClientSession._request args.
915
+ """
916
+
917
+ verify = 'ssl' # keep it consistent with httpx.
918
+ async_args = options.async_client_args
919
+ ctx = async_args.get(verify) if async_args else None
920
+
921
+ if not ctx:
922
+ # Initialize the SSL context for the httpx client.
923
+ # Unlike requests, the aiohttp package does not automatically pull in the
924
+ # environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
925
+ # enabled explicitly. Instead of 'verify' at client level in httpx,
926
+ # aiohttp uses 'ssl' at request level.
927
+ ctx = ssl.create_default_context(
928
+ cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
929
+ capath=os.environ.get('SSL_CERT_DIR'),
930
+ )
931
+
932
+ def _maybe_set(
933
+ args: Optional[_common.StringDict],
934
+ ctx: ssl.SSLContext,
935
+ ) -> _common.StringDict:
936
+ """Sets the SSL context in the client args if not set.
937
+
938
+ Does not override the SSL context if it is already set.
939
+
940
+ Args:
941
+ args: The client args to to check for SSL context.
942
+ ctx: The SSL context to set.
943
+
944
+ Returns:
945
+ The client args with the SSL context included.
946
+ """
947
+ if not args or not args.get(verify):
948
+ args = (args or {}).copy()
949
+ args[verify] = ctx
950
+ # Drop the args that isn't in the aiohttp RequestOptions.
951
+ copied_args = args.copy()
952
+ for key in copied_args.copy():
953
+ if key not in inspect.signature(ws_connect).parameters and key != 'ssl':
954
+ del copied_args[key]
955
+ return copied_args
956
+
957
+ return _maybe_set(async_args, ctx)
958
+
959
+ def _use_aiohttp(self) -> bool:
960
+ # If the instantiator has passed a custom transport, they want httpx not
961
+ # aiohttp.
962
+ return (
963
+ has_aiohttp
964
+ and (self._http_options.async_client_args or {}).get('transport')
965
+ is None
966
+ and (self._http_options.httpx_async_client is None)
967
+ )
968
+
969
+ def _websocket_base_url(self) -> str:
970
+ has_sufficient_auth = (self.project and self.location) or self.api_key
971
+ if self.custom_base_url and not has_sufficient_auth:
972
+ # API gateway proxy can use the auth in custom headers, not url.
973
+ # Enable custom url if auth is not sufficient.
974
+ return self.custom_base_url
975
+ url_parts = urlparse(self._http_options.base_url)
976
+ return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
414
977
 
415
978
  def _access_token(self) -> str:
416
979
  """Retrieves the access token for the credentials."""
417
- if not self._credentials:
418
- self._credentials, project = _load_auth(project=self.project)
419
- if not self.project:
420
- self.project = project
421
-
422
- if self._credentials:
423
- if (
424
- self._credentials.expired or not self._credentials.token
425
- ):
426
- # Only refresh when it needs to. Default expiration is 3600 seconds.
427
- _refresh_auth(self._credentials)
428
- if not self._credentials.token:
980
+ with self._sync_auth_lock:
981
+ if not self._credentials:
982
+ self._credentials, project = load_auth(project=self.project)
983
+ if not self.project:
984
+ self.project = project
985
+
986
+ if self._credentials:
987
+ if self._credentials.expired or not self._credentials.token:
988
+ # Only refresh when it needs to. Default expiration is 3600 seconds.
989
+ refresh_auth(self._credentials)
990
+ if not self._credentials.token:
991
+ raise RuntimeError('Could not resolve API token from the environment')
992
+ return self._credentials.token # type: ignore[no-any-return]
993
+ else:
429
994
  raise RuntimeError('Could not resolve API token from the environment')
430
- return self._credentials.token
431
- else:
432
- raise RuntimeError('Could not resolve API token from the environment')
433
995
 
434
- async def _async_access_token(self) -> str:
996
+ async def _get_async_auth_lock(self) -> asyncio.Lock:
997
+ """Lazily initializes and returns an asyncio.Lock for async authentication.
998
+
999
+ This method ensures that a single `asyncio.Lock` instance is created and
1000
+ shared among all asynchronous operations that require authentication,
1001
+ preventing race conditions when accessing or refreshing credentials.
1002
+
1003
+ The lock is created on the first call to this method. An internal async lock
1004
+ is used to protect the creation of the main authentication lock to ensure
1005
+ it's a singleton within the client instance.
1006
+
1007
+ Returns:
1008
+ The asyncio.Lock instance for asynchronous authentication operations.
1009
+ """
1010
+ if self._async_auth_lock is None:
1011
+ # Create async creation lock if needed
1012
+ if self._async_auth_lock_creation_lock is None:
1013
+ self._async_auth_lock_creation_lock = asyncio.Lock()
1014
+
1015
+ async with self._async_auth_lock_creation_lock:
1016
+ if self._async_auth_lock is None:
1017
+ self._async_auth_lock = asyncio.Lock()
1018
+
1019
+ return self._async_auth_lock
1020
+
1021
+ async def _async_access_token(self) -> Union[str, Any]:
435
1022
  """Retrieves the access token for the credentials asynchronously."""
436
1023
  if not self._credentials:
437
- async with self._auth_lock:
1024
+ async_auth_lock = await self._get_async_auth_lock()
1025
+ async with async_auth_lock:
438
1026
  # This ensures that only one coroutine can execute the auth logic at a
439
1027
  # time for thread safety.
440
1028
  if not self._credentials:
441
1029
  # Double check that the credentials are not set before loading them.
442
1030
  self._credentials, project = await asyncio.to_thread(
443
- _load_auth, project=self.project
1031
+ load_auth, project=self.project
444
1032
  )
445
1033
  if not self.project:
446
1034
  self.project = project
447
1035
 
448
1036
  if self._credentials:
449
- if (
450
- self._credentials.expired or not self._credentials.token
451
- ):
1037
+ if self._credentials.expired or not self._credentials.token:
452
1038
  # Only refresh when it needs to. Default expiration is 3600 seconds.
453
- async with self._auth_lock:
1039
+ async_auth_lock = await self._get_async_auth_lock()
1040
+ async with async_auth_lock:
454
1041
  if self._credentials.expired or not self._credentials.token:
455
1042
  # Double check that the credentials expired before refreshing.
456
- await asyncio.to_thread(_refresh_auth, self._credentials)
1043
+ await asyncio.to_thread(refresh_auth, self._credentials)
457
1044
 
458
1045
  if not self._credentials.token:
459
1046
  raise RuntimeError('Could not resolve API token from the environment')
@@ -476,12 +1063,13 @@ class BaseApiClient:
476
1063
  # patch the http options with the user provided settings.
477
1064
  if http_options:
478
1065
  if isinstance(http_options, HttpOptions):
479
- patched_http_options = _patch_http_options(
480
- self._http_options, http_options.model_dump()
1066
+ patched_http_options = patch_http_options(
1067
+ self._http_options,
1068
+ http_options,
481
1069
  )
482
1070
  else:
483
- patched_http_options = _patch_http_options(
484
- self._http_options, http_options
1071
+ patched_http_options = patch_http_options(
1072
+ self._http_options, HttpOptions.model_validate(http_options)
485
1073
  )
486
1074
  else:
487
1075
  patched_http_options = self._http_options
@@ -497,42 +1085,86 @@ class BaseApiClient:
497
1085
  self.vertexai
498
1086
  and not path.startswith('projects/')
499
1087
  and not query_vertex_base_models
500
- and not self.api_key
1088
+ and (self.project or self.location)
1089
+ and not (
1090
+ self.custom_base_url
1091
+ and patched_http_options.base_url_resource_scope
1092
+ == ResourceScope.COLLECTION
1093
+ )
501
1094
  ):
502
1095
  path = f'projects/{self.project}/locations/{self.location}/' + path
503
- url = _join_url_path(
504
- patched_http_options.get('base_url', ''),
505
- patched_http_options.get('api_version', '') + '/' + path,
506
- )
507
1096
 
508
- timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
509
- 'timeout', None
510
- )
511
- if timeout_in_seconds:
512
- # HttpOptions.timeout is in milliseconds. But httpx.Client.request()
513
- # expects seconds.
514
- timeout_in_seconds = timeout_in_seconds / 1000.0
1097
+ if patched_http_options.api_version is None:
1098
+ versioned_path = f'/{path}'
1099
+ else:
1100
+ versioned_path = f'{patched_http_options.api_version}/{path}'
1101
+
1102
+ if (
1103
+ patched_http_options.base_url is None
1104
+ or not patched_http_options.base_url
1105
+ ):
1106
+ raise ValueError('Base URL must be set.')
515
1107
  else:
516
- timeout_in_seconds = None
1108
+ base_url = patched_http_options.base_url
1109
+
1110
+ if (
1111
+ hasattr(patched_http_options, 'extra_body')
1112
+ and patched_http_options.extra_body
1113
+ ):
1114
+ _common.recursive_dict_update(
1115
+ request_dict, patched_http_options.extra_body
1116
+ )
1117
+ url = base_url
1118
+ if (
1119
+ not self.custom_base_url
1120
+ or (self.project and self.location)
1121
+ or self.api_key
1122
+ ):
1123
+ if (
1124
+ patched_http_options.base_url_resource_scope
1125
+ == ResourceScope.COLLECTION
1126
+ ):
1127
+ url = join_url_path(base_url, path)
1128
+ else:
1129
+ url = join_url_path(
1130
+ base_url,
1131
+ versioned_path,
1132
+ )
1133
+ elif(
1134
+ self.custom_base_url
1135
+ and patched_http_options.base_url_resource_scope == ResourceScope.COLLECTION
1136
+ ):
1137
+ url = join_url_path(base_url, path)
1138
+
1139
+ if self.api_key and self.api_key.startswith('auth_tokens/'):
1140
+ raise EphemeralTokenAPIKeyError(
1141
+ 'Ephemeral tokens can only be used with the live API.'
1142
+ )
517
1143
 
1144
+ timeout_in_seconds = get_timeout_in_seconds(patched_http_options.timeout)
1145
+
1146
+ if patched_http_options.headers is None:
1147
+ raise ValueError('Request headers must be set.')
1148
+ populate_server_timeout_header(
1149
+ patched_http_options.headers, timeout_in_seconds
1150
+ )
518
1151
  return HttpRequest(
519
1152
  method=http_method,
520
1153
  url=url,
521
- headers=patched_http_options['headers'],
1154
+ headers=patched_http_options.headers,
522
1155
  data=request_dict,
523
1156
  timeout=timeout_in_seconds,
524
1157
  )
525
1158
 
526
- def _request(
1159
+ def _request_once(
527
1160
  self,
528
1161
  http_request: HttpRequest,
529
1162
  stream: bool = False,
530
1163
  ) -> HttpResponse:
531
1164
  data: Optional[Union[str, bytes]] = None
532
- if self.vertexai and not self.api_key:
533
- http_request.headers['Authorization'] = (
534
- f'Bearer {self._access_token()}'
535
- )
1165
+ # If using proj/location, fetch ADC
1166
+ if self.vertexai and (self.project or self.location):
1167
+ http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
536
1168
  if self._credentials and self._credentials.quota_project_id:
537
1169
  http_request.headers['x-goog-user-project'] = (
538
1170
  self._credentials.quota_project_id
@@ -571,11 +1203,33 @@ class BaseApiClient:
571
1203
  response.headers, response if stream else [response.text]
572
1204
  )
573
1205
 
574
- async def _async_request(
1206
+ def _request(
1207
+ self,
1208
+ http_request: HttpRequest,
1209
+ http_options: Optional[HttpOptionsOrDict] = None,
1210
+ stream: bool = False,
1211
+ ) -> HttpResponse:
1212
+ if http_options:
1213
+ parameter_model = (
1214
+ HttpOptions(**http_options)
1215
+ if isinstance(http_options, dict)
1216
+ else http_options
1217
+ )
1218
+ # Support per request retry options.
1219
+ if parameter_model.retry_options:
1220
+ retry_kwargs = retry_args(parameter_model.retry_options)
1221
+ retry = tenacity.Retrying(**retry_kwargs)
1222
+ return retry(self._request_once, http_request, stream) # type: ignore[no-any-return]
1223
+
1224
+ return self._retry(self._request_once, http_request, stream) # type: ignore[no-any-return]
1225
+
1226
+ async def _async_request_once(
575
1227
  self, http_request: HttpRequest, stream: bool = False
576
- ):
1228
+ ) -> HttpResponse:
577
1229
  data: Optional[Union[str, bytes]] = None
578
- if self.vertexai and not self.api_key:
1230
+
1231
+ # If using proj/location, fetch ADC
1232
+ if self.vertexai and (self.project or self.location):
579
1233
  http_request.headers['Authorization'] = (
580
1234
  f'Bearer {await self._async_access_token()}'
581
1235
  )
@@ -592,39 +1246,133 @@ class BaseApiClient:
592
1246
  data = http_request.data
593
1247
 
594
1248
  if stream:
595
- httpx_request = self._async_httpx_client.build_request(
596
- method=http_request.method,
597
- url=http_request.url,
598
- content=data,
599
- headers=http_request.headers,
600
- timeout=http_request.timeout,
601
- )
602
- response = await self._async_httpx_client.send(
603
- httpx_request,
604
- stream=stream,
605
- )
606
- errors.APIError.raise_for_response(response)
607
- return HttpResponse(
608
- response.headers, response if stream else [response.text]
609
- )
1249
+ if self._use_aiohttp():
1250
+ self._aiohttp_session = await self._get_aiohttp_session()
1251
+ try:
1252
+ response = await self._aiohttp_session.request(
1253
+ method=http_request.method,
1254
+ url=http_request.url,
1255
+ headers=http_request.headers,
1256
+ data=data,
1257
+ timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
1258
+ **self._async_client_session_request_args,
1259
+ )
1260
+ except (
1261
+ aiohttp.ClientConnectorError,
1262
+ aiohttp.ClientConnectorDNSError,
1263
+ aiohttp.ClientOSError,
1264
+ aiohttp.ServerDisconnectedError,
1265
+ ) as e:
1266
+ await asyncio.sleep(1 + random.randint(0, 9))
1267
+ logger.info('Retrying due to aiohttp error: %s' % e)
1268
+ # Retrieve the SSL context from the session.
1269
+ self._async_client_session_request_args = (
1270
+ self._ensure_aiohttp_ssl_ctx(self._http_options)
1271
+ )
1272
+ # Instantiate a new session with the updated SSL context.
1273
+ self._aiohttp_session = await self._get_aiohttp_session()
1274
+ response = await self._aiohttp_session.request(
1275
+ method=http_request.method,
1276
+ url=http_request.url,
1277
+ headers=http_request.headers,
1278
+ data=data,
1279
+ timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
1280
+ **self._async_client_session_request_args,
1281
+ )
1282
+
1283
+ await errors.APIError.raise_for_async_response(response)
1284
+ return HttpResponse(response.headers, response)
1285
+ else:
1286
+ # aiohttp is not available. Fall back to httpx.
1287
+ httpx_request = self._async_httpx_client.build_request(
1288
+ method=http_request.method,
1289
+ url=http_request.url,
1290
+ content=data,
1291
+ headers=http_request.headers,
1292
+ timeout=http_request.timeout,
1293
+ )
1294
+ client_response = await self._async_httpx_client.send(
1295
+ httpx_request,
1296
+ stream=stream,
1297
+ )
1298
+ await errors.APIError.raise_for_async_response(client_response)
1299
+ return HttpResponse(client_response.headers, client_response)
610
1300
  else:
611
- response = await self._async_httpx_client.request(
612
- method=http_request.method,
613
- url=http_request.url,
614
- headers=http_request.headers,
615
- content=data,
616
- timeout=http_request.timeout,
617
- )
618
- errors.APIError.raise_for_response(response)
619
- return HttpResponse(
620
- response.headers, response if stream else [response.text]
1301
+ if self._use_aiohttp():
1302
+ self._aiohttp_session = await self._get_aiohttp_session()
1303
+ try:
1304
+ response = await self._aiohttp_session.request(
1305
+ method=http_request.method,
1306
+ url=http_request.url,
1307
+ headers=http_request.headers,
1308
+ data=data,
1309
+ timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
1310
+ **self._async_client_session_request_args,
1311
+ )
1312
+ await errors.APIError.raise_for_async_response(response)
1313
+ return HttpResponse(response.headers, [await response.text()])
1314
+ except (
1315
+ aiohttp.ClientConnectorError,
1316
+ aiohttp.ClientConnectorDNSError,
1317
+ aiohttp.ClientOSError,
1318
+ aiohttp.ServerDisconnectedError,
1319
+ ) as e:
1320
+ await asyncio.sleep(1 + random.randint(0, 9))
1321
+ logger.info('Retrying due to aiohttp error: %s' % e)
1322
+ # Retrieve the SSL context from the session.
1323
+ self._async_client_session_request_args = (
1324
+ self._ensure_aiohttp_ssl_ctx(self._http_options)
1325
+ )
1326
+ # Instantiate a new session with the updated SSL context.
1327
+ self._aiohttp_session = await self._get_aiohttp_session()
1328
+ response = await self._aiohttp_session.request(
1329
+ method=http_request.method,
1330
+ url=http_request.url,
1331
+ headers=http_request.headers,
1332
+ data=data,
1333
+ timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
1334
+ **self._async_client_session_request_args,
1335
+ )
1336
+ await errors.APIError.raise_for_async_response(response)
1337
+ return HttpResponse(response.headers, [await response.text()])
1338
+ else:
1339
+ # aiohttp is not available. Fall back to httpx.
1340
+ client_response = await self._async_httpx_client.request(
1341
+ method=http_request.method,
1342
+ url=http_request.url,
1343
+ headers=http_request.headers,
1344
+ content=data,
1345
+ timeout=http_request.timeout,
1346
+ )
1347
+ await errors.APIError.raise_for_async_response(client_response)
1348
+ return HttpResponse(client_response.headers, [client_response.text])
1349
+
1350
+ async def _async_request(
1351
+ self,
1352
+ http_request: HttpRequest,
1353
+ http_options: Optional[HttpOptionsOrDict] = None,
1354
+ stream: bool = False,
1355
+ ) -> HttpResponse:
1356
+ if http_options:
1357
+ parameter_model = (
1358
+ HttpOptions(**http_options)
1359
+ if isinstance(http_options, dict)
1360
+ else http_options
621
1361
  )
1362
+ # Support per request retry options.
1363
+ if parameter_model.retry_options:
1364
+ retry_kwargs = retry_args(parameter_model.retry_options)
1365
+ retry = tenacity.AsyncRetrying(**retry_kwargs)
1366
+ return await retry(self._async_request_once, http_request, stream) # type: ignore[no-any-return]
1367
+ return await self._async_retry( # type: ignore[no-any-return]
1368
+ self._async_request_once, http_request, stream
1369
+ )
622
1370
 
623
- def get_read_only_http_options(self) -> HttpOptionsDict:
624
- copied = HttpOptionsDict()
1371
+ def get_read_only_http_options(self) -> _common.StringDict:
625
1372
  if isinstance(self._http_options, BaseModel):
626
- self._http_options = self._http_options.model_dump()
627
- copied.update(self._http_options)
1373
+ copied = self._http_options.model_dump()
1374
+ else:
1375
+ copied = self._http_options
628
1376
  return copied
629
1377
 
630
1378
  def request(
@@ -633,32 +1381,44 @@ class BaseApiClient:
633
1381
  path: str,
634
1382
  request_dict: dict[str, object],
635
1383
  http_options: Optional[HttpOptionsOrDict] = None,
636
- ):
1384
+ ) -> SdkHttpResponse:
637
1385
  http_request = self._build_request(
638
1386
  http_method, path, request_dict, http_options
639
1387
  )
640
- response = self._request(http_request, stream=False)
641
- json_response = response.json
642
- if not json_response:
643
- return BaseResponse(http_headers=response.headers).model_dump(
644
- by_alias=True
645
- )
646
- return json_response
1388
+ response = self._request(http_request, http_options, stream=False)
1389
+ response_body = (
1390
+ response.response_stream[0] if response.response_stream else ''
1391
+ )
1392
+ return SdkHttpResponse(headers=response.headers, body=response_body)
647
1393
 
648
1394
  def request_streamed(
649
1395
  self,
650
1396
  http_method: str,
651
1397
  path: str,
652
1398
  request_dict: dict[str, object],
653
- http_options: Optional[HttpOptionsDict] = None,
654
- ):
1399
+ http_options: Optional[HttpOptionsOrDict] = None,
1400
+ ) -> Generator[SdkHttpResponse, None, None]:
655
1401
  http_request = self._build_request(
656
1402
  http_method, path, request_dict, http_options
657
1403
  )
658
1404
 
659
- session_response = self._request(http_request, stream=True)
1405
+ session_response = self._request(http_request, http_options, stream=True)
660
1406
  for chunk in session_response.segments():
661
- yield chunk
1407
+ chunk_dump = json.dumps(chunk)
1408
+ try:
1409
+ if chunk_dump.startswith('{"error":'):
1410
+ chunk_json = json.loads(chunk_dump)
1411
+ errors.APIError.raise_error(
1412
+ chunk_json.get('error', {}).get('code'),
1413
+ chunk_json,
1414
+ session_response,
1415
+ )
1416
+ except json.decoder.JSONDecodeError:
1417
+ logger.debug(
1418
+ 'Failed to decode chunk that contains an error: %s' % chunk_dump
1419
+ )
1420
+ pass
1421
+ yield SdkHttpResponse(headers=session_response.headers, body=chunk_dump)
662
1422
 
663
1423
  async def async_request(
664
1424
  self,
@@ -666,39 +1426,58 @@ class BaseApiClient:
666
1426
  path: str,
667
1427
  request_dict: dict[str, object],
668
1428
  http_options: Optional[HttpOptionsOrDict] = None,
669
- ) -> dict[str, object]:
1429
+ ) -> SdkHttpResponse:
670
1430
  http_request = self._build_request(
671
1431
  http_method, path, request_dict, http_options
672
1432
  )
673
1433
 
674
- result = await self._async_request(http_request=http_request, stream=False)
675
- json_response = result.json
676
- if not json_response:
677
- return BaseResponse(http_headers=result.headers).model_dump(by_alias=True)
678
- return json_response
1434
+ result = await self._async_request(
1435
+ http_request=http_request, http_options=http_options, stream=False
1436
+ )
1437
+ response_body = result.response_stream[0] if result.response_stream else ''
1438
+ return SdkHttpResponse(headers=result.headers, body=response_body)
679
1439
 
680
1440
  async def async_request_streamed(
681
1441
  self,
682
1442
  http_method: str,
683
1443
  path: str,
684
1444
  request_dict: dict[str, object],
685
- http_options: Optional[HttpOptionsDict] = None,
686
- ):
1445
+ http_options: Optional[HttpOptionsOrDict] = None,
1446
+ ) -> Any:
687
1447
  http_request = self._build_request(
688
1448
  http_method, path, request_dict, http_options
689
1449
  )
690
1450
 
691
1451
  response = await self._async_request(http_request=http_request, stream=True)
692
1452
 
693
- async def async_generator():
1453
+ async def async_generator(): # type: ignore[no-untyped-def]
694
1454
  async for chunk in response:
695
- yield chunk
1455
+ chunk_dump = json.dumps(chunk)
1456
+ try:
1457
+ if chunk_dump.startswith('{"error":'):
1458
+ chunk_json = json.loads(chunk_dump)
1459
+ await errors.APIError.raise_error_async(
1460
+ chunk_json.get('error', {}).get('code'),
1461
+ chunk_json,
1462
+ response,
1463
+ )
1464
+ except json.decoder.JSONDecodeError:
1465
+ logger.debug(
1466
+ 'Failed to decode chunk that contains an error: %s' % chunk_dump
1467
+ )
1468
+ pass
1469
+ yield SdkHttpResponse(headers=response.headers, body=chunk_dump)
696
1470
 
697
- return async_generator()
1471
+ return async_generator() # type: ignore[no-untyped-call]
698
1472
 
699
1473
  def upload_file(
700
- self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
701
- ) -> dict[str, str]:
1474
+ self,
1475
+ file_path: Union[str, io.IOBase],
1476
+ upload_url: str,
1477
+ upload_size: int,
1478
+ *,
1479
+ http_options: Optional[HttpOptionsOrDict] = None,
1480
+ ) -> HttpResponse:
702
1481
  """Transfers a file to the given URL.
703
1482
 
704
1483
  Args:
@@ -708,19 +1487,29 @@ class BaseApiClient:
708
1487
  upload_url: The URL to upload the file to.
709
1488
  upload_size: The size of file content to be uploaded, this will have to
710
1489
  match the size requested in the resumable upload request.
1490
+ http_options: The http options to use for the request.
711
1491
 
712
1492
  returns:
713
- The response json object from the finalize request.
1493
+ The HttpResponse object from the finalize request.
714
1494
  """
715
1495
  if isinstance(file_path, io.IOBase):
716
- return self._upload_fd(file_path, upload_url, upload_size)
1496
+ return self._upload_fd(
1497
+ file_path, upload_url, upload_size, http_options=http_options
1498
+ )
717
1499
  else:
718
1500
  with open(file_path, 'rb') as file:
719
- return self._upload_fd(file, upload_url, upload_size)
1501
+ return self._upload_fd(
1502
+ file, upload_url, upload_size, http_options=http_options
1503
+ )
720
1504
 
721
1505
  def _upload_fd(
722
- self, file: io.IOBase, upload_url: str, upload_size: int
723
- ) -> dict[str, str]:
1506
+ self,
1507
+ file: io.IOBase,
1508
+ upload_url: str,
1509
+ upload_size: int,
1510
+ *,
1511
+ http_options: Optional[HttpOptionsOrDict] = None,
1512
+ ) -> HttpResponse:
724
1513
  """Transfers a file to the given URL.
725
1514
 
726
1515
  Args:
@@ -728,9 +1517,10 @@ class BaseApiClient:
728
1517
  upload_url: The URL to upload the file to.
729
1518
  upload_size: The size of file content to be uploaded, this will have to
730
1519
  match the size requested in the resumable upload request.
1520
+ http_options: The http options to use for the request.
731
1521
 
732
1522
  returns:
733
- The response json object from the finalize request.
1523
+ The HttpResponse object from the finalize request.
734
1524
  """
735
1525
  offset = 0
736
1526
  # Upload the file in chunks
@@ -743,32 +1533,60 @@ class BaseApiClient:
743
1533
  # If last chunk, finalize the upload.
744
1534
  if chunk_size + offset >= upload_size:
745
1535
  upload_command += ', finalize'
746
- response = self._httpx_client.request(
747
- method='POST',
748
- url=upload_url,
749
- headers={
750
- 'X-Goog-Upload-Command': upload_command,
751
- 'X-Goog-Upload-Offset': str(offset),
752
- 'Content-Length': str(chunk_size),
753
- },
754
- content=file_chunk,
1536
+ http_options = http_options if http_options else self._http_options
1537
+ timeout = (
1538
+ http_options.get('timeout')
1539
+ if isinstance(http_options, dict)
1540
+ else http_options.timeout
755
1541
  )
1542
+ if timeout is None:
1543
+ # Per request timeout is not configured. Check the global timeout.
1544
+ timeout = (
1545
+ self._http_options.timeout
1546
+ if isinstance(self._http_options, dict)
1547
+ else self._http_options.timeout
1548
+ )
1549
+ timeout_in_seconds = get_timeout_in_seconds(timeout)
1550
+ upload_headers = {
1551
+ 'X-Goog-Upload-Command': upload_command,
1552
+ 'X-Goog-Upload-Offset': str(offset),
1553
+ 'Content-Length': str(chunk_size),
1554
+ }
1555
+ populate_server_timeout_header(upload_headers, timeout_in_seconds)
1556
+ retry_count = 0
1557
+ while retry_count < MAX_RETRY_COUNT:
1558
+ response = self._httpx_client.request(
1559
+ method='POST',
1560
+ url=upload_url,
1561
+ headers=upload_headers,
1562
+ content=file_chunk,
1563
+ timeout=timeout_in_seconds,
1564
+ )
1565
+ if response.headers.get('x-goog-upload-status'):
1566
+ break
1567
+ delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
1568
+ retry_count += 1
1569
+ time.sleep(delay_seconds)
1570
+
756
1571
  offset += chunk_size
757
- if response.headers['x-goog-upload-status'] != 'active':
1572
+ if response.headers.get('x-goog-upload-status') != 'active':
758
1573
  break # upload is complete or it has been interrupted.
759
1574
  if upload_size <= offset: # Status is not finalized.
760
1575
  raise ValueError(
761
- 'All content has been uploaded, but the upload status is not'
1576
+ f'All content has been uploaded, but the upload status is not'
762
1577
  f' finalized.'
763
1578
  )
1579
+ errors.APIError.raise_for_response(response)
1580
+ if response.headers.get('x-goog-upload-status') != 'final':
1581
+ raise ValueError('Failed to upload file: Upload status is not finalized.')
1582
+ return HttpResponse(response.headers, response_stream=[response.text])
764
1583
 
765
- if response.headers['x-goog-upload-status'] != 'final':
766
- raise ValueError(
767
- 'Failed to upload file: Upload status is not finalized.'
768
- )
769
- return response.json()
770
-
771
- def download_file(self, path: str, http_options):
1584
+ def download_file(
1585
+ self,
1586
+ path: str,
1587
+ *,
1588
+ http_options: Optional[HttpOptionsOrDict] = None,
1589
+ ) -> Union[Any, bytes]:
772
1590
  """Downloads the file data.
773
1591
 
774
1592
  Args:
@@ -807,7 +1625,9 @@ class BaseApiClient:
807
1625
  file_path: Union[str, io.IOBase],
808
1626
  upload_url: str,
809
1627
  upload_size: int,
810
- ) -> dict[str, str]:
1628
+ *,
1629
+ http_options: Optional[HttpOptionsOrDict] = None,
1630
+ ) -> HttpResponse:
811
1631
  """Transfers a file asynchronously to the given URL.
812
1632
 
813
1633
  Args:
@@ -816,24 +1636,31 @@ class BaseApiClient:
816
1636
  upload_url: The URL to upload the file to.
817
1637
  upload_size: The size of file content to be uploaded, this will have to
818
1638
  match the size requested in the resumable upload request.
1639
+ http_options: The http options to use for the request.
819
1640
 
820
1641
  returns:
821
- The response json object from the finalize request.
1642
+ The HttpResponse object from the finalize request.
822
1643
  """
823
1644
  if isinstance(file_path, io.IOBase):
824
- return await self._async_upload_fd(file_path, upload_url, upload_size)
1645
+ return await self._async_upload_fd(
1646
+ file_path, upload_url, upload_size, http_options=http_options
1647
+ )
825
1648
  else:
826
1649
  file = anyio.Path(file_path)
827
1650
  fd = await file.open('rb')
828
1651
  async with fd:
829
- return await self._async_upload_fd(fd, upload_url, upload_size)
1652
+ return await self._async_upload_fd(
1653
+ fd, upload_url, upload_size, http_options=http_options
1654
+ )
830
1655
 
831
1656
  async def _async_upload_fd(
832
1657
  self,
833
- file: Union[io.IOBase, anyio.AsyncFile],
1658
+ file: Union[io.IOBase, anyio.AsyncFile[Any]],
834
1659
  upload_url: str,
835
1660
  upload_size: int,
836
- ) -> dict[str, str]:
1661
+ *,
1662
+ http_options: Optional[HttpOptionsOrDict] = None,
1663
+ ) -> HttpResponse:
837
1664
  """Transfers a file asynchronously to the given URL.
838
1665
 
839
1666
  Args:
@@ -841,50 +1668,175 @@ class BaseApiClient:
841
1668
  upload_url: The URL to upload the file to.
842
1669
  upload_size: The size of file content to be uploaded, this will have to
843
1670
  match the size requested in the resumable upload request.
1671
+ http_options: The http options to use for the request.
844
1672
 
845
1673
  returns:
846
- The response json object from the finalize request.
1674
+ The HttpResponse object from the finalized request.
847
1675
  """
848
1676
  offset = 0
849
1677
  # Upload the file in chunks
850
- while True:
851
- if isinstance(file, io.IOBase):
852
- file_chunk = file.read(CHUNK_SIZE)
853
- else:
854
- file_chunk = await file.read(CHUNK_SIZE)
855
- chunk_size = 0
856
- if file_chunk:
857
- chunk_size = len(file_chunk)
858
- upload_command = 'upload'
859
- # If last chunk, finalize the upload.
860
- if chunk_size + offset >= upload_size:
861
- upload_command += ', finalize'
862
- response = await self._async_httpx_client.request(
863
- method='POST',
864
- url=upload_url,
865
- content=file_chunk,
866
- headers={
867
- 'X-Goog-Upload-Command': upload_command,
868
- 'X-Goog-Upload-Offset': str(offset),
869
- 'Content-Length': str(chunk_size),
870
- },
1678
+ if self._use_aiohttp(): # pylint: disable=g-import-not-at-top
1679
+ self._aiohttp_session = await self._get_aiohttp_session()
1680
+ while True:
1681
+ if isinstance(file, io.IOBase):
1682
+ file_chunk = file.read(CHUNK_SIZE)
1683
+ else:
1684
+ file_chunk = await file.read(CHUNK_SIZE)
1685
+ chunk_size = 0
1686
+ if file_chunk:
1687
+ chunk_size = len(file_chunk)
1688
+ upload_command = 'upload'
1689
+ # If last chunk, finalize the upload.
1690
+ if chunk_size + offset >= upload_size:
1691
+ upload_command += ', finalize'
1692
+ http_options = http_options if http_options else self._http_options
1693
+ timeout = (
1694
+ http_options.get('timeout')
1695
+ if isinstance(http_options, dict)
1696
+ else http_options.timeout
1697
+ )
1698
+ if timeout is None:
1699
+ # Per request timeout is not configured. Check the global timeout.
1700
+ timeout = (
1701
+ self._http_options.timeout
1702
+ if isinstance(self._http_options, dict)
1703
+ else self._http_options.timeout
1704
+ )
1705
+ timeout_in_seconds = get_timeout_in_seconds(timeout)
1706
+ upload_headers = {
1707
+ 'X-Goog-Upload-Command': upload_command,
1708
+ 'X-Goog-Upload-Offset': str(offset),
1709
+ 'Content-Length': str(chunk_size),
1710
+ }
1711
+ populate_server_timeout_header(upload_headers, timeout_in_seconds)
1712
+
1713
+ retry_count = 0
1714
+ response = None
1715
+ while retry_count < MAX_RETRY_COUNT:
1716
+ response = await self._aiohttp_session.request(
1717
+ method='POST',
1718
+ url=upload_url,
1719
+ data=file_chunk,
1720
+ headers=upload_headers,
1721
+ timeout=aiohttp.ClientTimeout(connect=timeout_in_seconds),
1722
+ )
1723
+
1724
+ if response.headers.get('X-Goog-Upload-Status'):
1725
+ break
1726
+ delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
1727
+ retry_count += 1
1728
+ await asyncio.sleep(delay_seconds)
1729
+
1730
+ offset += chunk_size
1731
+ if (
1732
+ response is not None
1733
+ and response.headers.get('X-Goog-Upload-Status') != 'active'
1734
+ ):
1735
+ break # upload is complete or it has been interrupted.
1736
+
1737
+ if upload_size <= offset: # Status is not finalized.
1738
+ raise ValueError(
1739
+ f'All content has been uploaded, but the upload status is not'
1740
+ f' finalized.'
1741
+ )
1742
+
1743
+ await errors.APIError.raise_for_async_response(response)
1744
+ if (
1745
+ response is not None
1746
+ and response.headers.get('X-Goog-Upload-Status') != 'final'
1747
+ ):
1748
+ raise ValueError(
1749
+ 'Failed to upload file: Upload status is not finalized.'
1750
+ )
1751
+ return HttpResponse(
1752
+ response.headers, response_stream=[await response.text()]
871
1753
  )
872
- offset += chunk_size
873
- if response.headers.get('x-goog-upload-status') != 'active':
874
- break # upload is complete or it has been interrupted.
1754
+ else:
1755
+ # aiohttp is not available. Fall back to httpx.
1756
+ while True:
1757
+ if isinstance(file, io.IOBase):
1758
+ file_chunk = file.read(CHUNK_SIZE)
1759
+ else:
1760
+ file_chunk = await file.read(CHUNK_SIZE)
1761
+ chunk_size = 0
1762
+ if file_chunk:
1763
+ chunk_size = len(file_chunk)
1764
+ upload_command = 'upload'
1765
+ # If last chunk, finalize the upload.
1766
+ if chunk_size + offset >= upload_size:
1767
+ upload_command += ', finalize'
1768
+ http_options = http_options if http_options else self._http_options
1769
+ timeout = (
1770
+ http_options.get('timeout')
1771
+ if isinstance(http_options, dict)
1772
+ else http_options.timeout
1773
+ )
1774
+ if timeout is None:
1775
+ # Per request timeout is not configured. Check the global timeout.
1776
+ timeout = (
1777
+ self._http_options.timeout
1778
+ if isinstance(self._http_options, dict)
1779
+ else self._http_options.timeout
1780
+ )
1781
+ timeout_in_seconds = get_timeout_in_seconds(timeout)
1782
+ upload_headers = {
1783
+ 'X-Goog-Upload-Command': upload_command,
1784
+ 'X-Goog-Upload-Offset': str(offset),
1785
+ 'Content-Length': str(chunk_size),
1786
+ }
1787
+ populate_server_timeout_header(upload_headers, timeout_in_seconds)
1788
+
1789
+ retry_count = 0
1790
+ client_response = None
1791
+ while retry_count < MAX_RETRY_COUNT:
1792
+ client_response = await self._async_httpx_client.request(
1793
+ method='POST',
1794
+ url=upload_url,
1795
+ content=file_chunk,
1796
+ headers=upload_headers,
1797
+ timeout=timeout_in_seconds,
1798
+ )
1799
+ if (
1800
+ client_response is not None
1801
+ and client_response.headers
1802
+ and client_response.headers.get('x-goog-upload-status')
1803
+ ):
1804
+ break
1805
+ delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
1806
+ retry_count += 1
1807
+ time.sleep(delay_seconds)
1808
+
1809
+ offset += chunk_size
1810
+ if (
1811
+ client_response is not None
1812
+ and client_response.headers.get('x-goog-upload-status') != 'active'
1813
+ ):
1814
+ break # upload is complete or it has been interrupted.
1815
+
1816
+ if upload_size <= offset: # Status is not finalized.
1817
+ raise ValueError(
1818
+ 'All content has been uploaded, but the upload status is not'
1819
+ ' finalized.'
1820
+ )
875
1821
 
876
- if upload_size <= offset: # Status is not finalized.
1822
+ await errors.APIError.raise_for_async_response(client_response)
1823
+ if (
1824
+ client_response is not None
1825
+ and client_response.headers.get('x-goog-upload-status') != 'final'
1826
+ ):
877
1827
  raise ValueError(
878
- 'All content has been uploaded, but the upload status is not'
879
- f' finalized.'
1828
+ 'Failed to upload file: Upload status is not finalized.'
880
1829
  )
881
- if response.headers.get('x-goog-upload-status') != 'final':
882
- raise ValueError(
883
- 'Failed to upload file: Upload status is not finalized.'
1830
+ return HttpResponse(
1831
+ client_response.headers, response_stream=[client_response.text]
884
1832
  )
885
- return response.json()
886
1833
 
887
- async def async_download_file(self, path: str, http_options):
1834
+ async def async_download_file(
1835
+ self,
1836
+ path: str,
1837
+ *,
1838
+ http_options: Optional[HttpOptionsOrDict] = None,
1839
+ ) -> Union[Any, bytes]:
888
1840
  """Downloads the file data.
889
1841
 
890
1842
  Args:
@@ -905,21 +1857,71 @@ class BaseApiClient:
905
1857
  else:
906
1858
  data = http_request.data
907
1859
 
908
- response = await self._async_httpx_client.request(
909
- method=http_request.method,
910
- url=http_request.url,
911
- headers=http_request.headers,
912
- content=data,
913
- timeout=http_request.timeout,
914
- )
915
- errors.APIError.raise_for_response(response)
1860
+ if self._use_aiohttp():
1861
+ self._aiohttp_session = await self._get_aiohttp_session()
1862
+ response = await self._aiohttp_session.request(
1863
+ method=http_request.method,
1864
+ url=http_request.url,
1865
+ headers=http_request.headers,
1866
+ data=data,
1867
+ timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
1868
+ )
1869
+ await errors.APIError.raise_for_async_response(response)
916
1870
 
917
- return HttpResponse(
918
- response.headers, byte_stream=[response.read()]
919
- ).byte_stream[0]
1871
+ return HttpResponse(
1872
+ response.headers, byte_stream=[await response.read()]
1873
+ ).byte_stream[0]
1874
+ else:
1875
+ # aiohttp is not available. Fall back to httpx.
1876
+ client_response = await self._async_httpx_client.request(
1877
+ method=http_request.method,
1878
+ url=http_request.url,
1879
+ headers=http_request.headers,
1880
+ content=data,
1881
+ timeout=http_request.timeout,
1882
+ )
1883
+ await errors.APIError.raise_for_async_response(client_response)
1884
+
1885
+ return HttpResponse(
1886
+ client_response.headers, byte_stream=[client_response.read()]
1887
+ ).byte_stream[0]
920
1888
 
921
1889
  # This method does nothing in the real api client. It is used in the
922
1890
  # replay_api_client to verify the response from the SDK method matches the
923
1891
  # recorded response.
924
- def _verify_response(self, response_model: _common.BaseModel):
1892
+ def _verify_response(self, response_model: _common.BaseModel) -> None:
925
1893
  pass
1894
+
1895
+ def close(self) -> None:
1896
+ """Closes the API client."""
1897
+ # Let users close the custom client explicitly by themselves. Otherwise,
1898
+ # close the client when the object is garbage collected.
1899
+ if not self._http_options.httpx_client:
1900
+ self._httpx_client.close()
1901
+
1902
+ async def aclose(self) -> None:
1903
+ """Closes the API async client."""
1904
+ # Let users close the custom client explicitly by themselves. Otherwise,
1905
+ # close the client when the object is garbage collected.
1906
+ if not self._http_options.httpx_async_client:
1907
+ await self._async_httpx_client.aclose()
1908
+ if self._aiohttp_session:
1909
+ await self._aiohttp_session.close()
1910
+
1911
+ def __del__(self) -> None:
1912
+ """Closes the API client when the object is garbage collected.
1913
+
1914
+ ADK uses this client so cannot rely on the genai.[Async]Client.__del__
1915
+ for cleanup.
1916
+ """
1917
+
1918
+ try:
1919
+ if not self._http_options.httpx_client:
1920
+ self.close()
1921
+ except Exception: # pylint: disable=broad-except
1922
+ pass
1923
+
1924
+ try:
1925
+ asyncio.get_running_loop().create_task(self.aclose())
1926
+ except Exception: # pylint: disable=broad-except
1927
+ pass