google-genai 1.33.0__py3-none-any.whl → 1.53.0__py3-none-any.whl

This diff shows the content changes between two package versions as published to a supported public registry. It is provided for informational purposes only.
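
The hunks below (from google/genai/_api_client.py) add, among other things: caller-supplied httpx clients, a reusable aiohttp session, a 'global' default location for Vertex AI, error propagation for streamed chunks, and routing through a custom base_url with an optional base_url_resource_scope. A minimal usage sketch, assuming the public genai.Client / types.HttpOptions surface exposes the fields referenced in this diff (base_url, base_url_resource_scope, ResourceScope, httpx_client, httpx_async_client); the proxy URL and project name are placeholders:

    # Hypothetical sketch; field names are taken from the diff below, not from official docs.
    import httpx
    from google import genai
    from google.genai import types

    client = genai.Client(
        vertexai=True,
        project='my-project',      # placeholder
        location='global',
        http_options=types.HttpOptions(
            base_url='https://my-gateway.example.com/',              # custom API gateway / proxy
            base_url_resource_scope=types.ResourceScope.COLLECTION,  # only valid together with base_url
            httpx_client=httpx.Client(),                             # reused instead of SDK-created clients
            httpx_async_client=httpx.AsyncClient(),
        ),
    )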
@@ -37,13 +37,13 @@ import time
  from typing import Any, AsyncIterator, Iterator, Optional, Tuple, TYPE_CHECKING, Union
  from urllib.parse import urlparse
  from urllib.parse import urlunparse
+ import warnings

  import anyio
  import certifi
  import google.auth
  import google.auth.credentials
  from google.auth.credentials import Credentials
- from google.auth.transport.requests import Request
  import httpx
  from pydantic import BaseModel
  from pydantic import ValidationError
@@ -56,6 +56,7 @@ from .types import HttpOptions
  from .types import HttpOptionsOrDict
  from .types import HttpResponse as SdkHttpResponse
  from .types import HttpRetryOptions
+ from .types import ResourceScope


  try:
@@ -79,7 +80,7 @@ if TYPE_CHECKING:

  logger = logging.getLogger('google_genai._api_client')
  CHUNK_SIZE = 8 * 1024 * 1024 # 8 MB chunk size
- READ_BUFFER_SIZE = 2**20
+ READ_BUFFER_SIZE = 2**22
  MAX_RETRY_COUNT = 3
  INITIAL_RETRY_DELAY = 1 # second
  DELAY_MULTIPLIER = 2
@@ -196,6 +197,7 @@ def load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:


  def refresh_auth(credentials: Credentials) -> Credentials:
+ from google.auth.transport.requests import Request
  credentials.refresh(Request()) # type: ignore[no-untyped-call]
  return credentials

@@ -229,7 +231,6 @@ class HttpResponse:
  headers: Union[dict[str, str], httpx.Headers, 'CIMultiDictProxy[str]'],
  response_stream: Union[Any, str] = None,
  byte_stream: Union[Any, bytes] = None,
- session: Optional['aiohttp.ClientSession'] = None,
  ):
  if isinstance(headers, dict):
  self.headers = headers
@@ -245,7 +246,6 @@ class HttpResponse:
  self.status_code: int = 200
  self.response_stream = response_stream
  self.byte_stream = byte_stream
- self._session = session

  # Async iterator for async streaming.
  def __aiter__(self) -> 'HttpResponse':
@@ -360,69 +360,76 @@ class HttpResponse:
  balance = 0
  # httpx.Response has a dedicated async line iterator.
  if isinstance(self.response_stream, httpx.Response):
- async for line in self.response_stream.aiter_lines():
- if not line:
- continue
- # In streaming mode, the response of JSON is prefixed with "data: "
- # which we must strip before parsing.
- if line.startswith('data: '):
- yield line[len('data: '):]
- continue
-
- # When API returns an error message, it comes line by line. So we buffer
- # the lines until a complete JSON string is read. A complete JSON string
- # is found when the balance is 0.
- for c in line:
- if c == '{':
- balance += 1
- elif c == '}':
- balance -= 1
-
- chunk += line
- if balance == 0:
+ try:
+ async for line in self.response_stream.aiter_lines():
+ if not line:
+ continue
+ # In streaming mode, the response of JSON is prefixed with "data: "
+ # which we must strip before parsing.
+ if line.startswith('data: '):
+ yield line[len('data: '):]
+ continue
+
+ # When API returns an error message, it comes line by line. So we buffer
+ # the lines until a complete JSON string is read. A complete JSON string
+ # is found when the balance is 0.
+ for c in line:
+ if c == '{':
+ balance += 1
+ elif c == '}':
+ balance -= 1
+
+ chunk += line
+ if balance == 0:
+ yield chunk
+ chunk = ''
+ # If there is any remaining chunk, yield it.
+ if chunk:
  yield chunk
- chunk = ''
+ finally:
+ # Close the response and release the connection.
+ await self.response_stream.aclose()

  # aiohttp.ClientResponse uses a content stream that we read line by line.
  elif has_aiohttp and isinstance(
  self.response_stream, aiohttp.ClientResponse
  ):
- while True:
- # Read a line from the stream. This returns bytes.
- line_bytes = await self.response_stream.content.readline()
- if not line_bytes:
- break
- # Decode the bytes and remove trailing whitespace and newlines.
- line = line_bytes.decode('utf-8').rstrip()
- if not line:
- continue
-
- # In streaming mode, the response of JSON is prefixed with "data: "
- # which we must strip before parsing.
- if line.startswith('data: '):
- yield line[len('data: '):]
- continue
-
- # When API returns an error message, it comes line by line. So we buffer
- # the lines until a complete JSON string is read. A complete JSON string
- # is found when the balance is 0.
- for c in line:
- if c == '{':
- balance += 1
- elif c == '}':
- balance -= 1
-
- chunk += line
- if balance == 0:
+ try:
+ while True:
+ # Read a line from the stream. This returns bytes.
+ line_bytes = await self.response_stream.content.readline()
+ if not line_bytes:
+ break
+ # Decode the bytes and remove trailing whitespace and newlines.
+ line = line_bytes.decode('utf-8').rstrip()
+ if not line:
+ continue
+
+ # In streaming mode, the response of JSON is prefixed with "data: "
+ # which we must strip before parsing.
+ if line.startswith('data: '):
+ yield line[len('data: '):]
+ continue
+
+ # When API returns an error message, it comes line by line. So we
+ # buffer the lines until a complete JSON string is read. A complete
+ # JSON strings found when the balance is 0.
+ for c in line:
+ if c == '{':
+ balance += 1
+ elif c == '}':
+ balance -= 1
+
+ chunk += line
+ if balance == 0:
+ yield chunk
+ chunk = ''
+ # If there is any remaining chunk, yield it.
+ if chunk:
  yield chunk
- chunk = ''
-
- # If there is any remaining chunk, yield it.
- if chunk:
- yield chunk
-
- if hasattr(self, '_session') and self._session:
- await self._session.close()
+ finally:
+ # Release the connection back to the pool for potential reuse.
+ self.response_stream.release()

  @classmethod
  def _load_json_from_response(cls, response: Any) -> Any:
@@ -483,6 +490,7 @@ def retry_args(options: Optional[HttpRetryOptions]) -> _common.StringDict:
  'retry': retry,
  'reraise': True,
  'wait': wait,
+ 'before_sleep': tenacity.before_sleep_log(logger, logging.INFO),
  }


@@ -540,6 +548,7 @@ class BaseApiClient:
  http_options: Optional[HttpOptionsOrDict] = None,
  ):
  self.vertexai = vertexai
+ self.custom_base_url = None
  if self.vertexai is None:
  if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [
  'true',
@@ -571,6 +580,12 @@ class BaseApiClient:
  elif isinstance(http_options, HttpOptions):
  validated_http_options = http_options

+ if validated_http_options.base_url_resource_scope and not validated_http_options.base_url:
+ # base_url_resource_scope is only valid when base_url is set.
+ raise ValueError(
+ 'base_url must be set when base_url_resource_scope is set.'
+ )
+
  # Retrieve implicitly set values from the environment.
  env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
  env_location = os.environ.get('GOOGLE_CLOUD_LOCATION', None)
@@ -591,7 +606,7 @@ class BaseApiClient:
  # Handle when to use Vertex AI in express mode (api key).
  # Explicit initializer arguments are already validated above.
  if self.vertexai:
- if credentials:
+ if credentials and env_api_key:
  # Explicit credentials take precedence over implicit api_key.
  logger.info(
  'The user provided Google Cloud credentials will take precedence'
@@ -621,11 +636,20 @@ class BaseApiClient:
  )
  self.api_key = None

+ self.custom_base_url = (
+ validated_http_options.base_url
+ if validated_http_options.base_url
+ else None
+ )
+
+ if not self.location and not self.api_key and not self.custom_base_url:
+ self.location = 'global'
+
  # Skip fetching project from ADC if base url is provided in http options.
  if (
  not self.project
  and not self.api_key
- and not validated_http_options.base_url
+ and not self.custom_base_url
  ):
  credentials, self.project = load_auth(project=None)
  if not self._credentials:
@@ -633,17 +657,22 @@ class BaseApiClient:

  has_sufficient_auth = (self.project and self.location) or self.api_key

- if not has_sufficient_auth and not validated_http_options.base_url:
+ if not has_sufficient_auth and not self.custom_base_url:
  # Skip sufficient auth check if base url is provided in http options.
  raise ValueError(
- 'Project and location or API key must be set when using the Vertex '
+ 'Project or API key must be set when using the Vertex '
  'AI API.'
  )
  if self.api_key or self.location == 'global':
  self._http_options.base_url = f'https://aiplatform.googleapis.com/'
- elif validated_http_options.base_url and not has_sufficient_auth:
+ elif self.custom_base_url and not ((project and location) or api_key):
  # Avoid setting default base url and api version if base_url provided.
- self._http_options.base_url = validated_http_options.base_url
+ # API gateway proxy can use the auth in custom headers, not url.
+ # Enable custom url if auth is not sufficient.
+ self._http_options.base_url = self.custom_base_url
+ # Clear project and location if base_url is provided.
+ self.project = None
+ self.location = None
  else:
  self._http_options.base_url = (
  f'https://{self.location}-aiplatform.googleapis.com/'
@@ -676,19 +705,79 @@ class BaseApiClient:
  client_args, async_client_args = self._ensure_httpx_ssl_ctx(
  self._http_options
  )
- self._httpx_client = SyncHttpxClient(**client_args)
- self._async_httpx_client = AsyncHttpxClient(**async_client_args)
+ self._async_httpx_client_args = async_client_args
+
+ if self._http_options.httpx_client:
+ self._httpx_client = self._http_options.httpx_client
+ else:
+ self._httpx_client = SyncHttpxClient(**client_args)
+ if self._http_options.httpx_async_client:
+ self._async_httpx_client = self._http_options.httpx_async_client
+ else:
+ self._async_httpx_client = AsyncHttpxClient(**async_client_args)
  if self._use_aiohttp():
- # Do it once at the genai.Client level. Share among all requests.
- self._async_client_session_request_args = self._ensure_aiohttp_ssl_ctx(
- self._http_options
- )
- self._websocket_ssl_ctx = self._ensure_websocket_ssl_ctx(self._http_options)
+ try:
+ import aiohttp # pylint: disable=g-import-not-at-top
+ # Do it once at the genai.Client level. Share among all requests.
+ self._async_client_session_request_args = self._ensure_aiohttp_ssl_ctx(
+ self._http_options
+ )
+ except ImportError:
+ pass
+
+ # Initialize the aiohttp client session.
+ self._aiohttp_session: Optional[aiohttp.ClientSession] = None

  retry_kwargs = retry_args(self._http_options.retry_options)
+ self._websocket_ssl_ctx = self._ensure_websocket_ssl_ctx(self._http_options)
  self._retry = tenacity.Retrying(**retry_kwargs)
  self._async_retry = tenacity.AsyncRetrying(**retry_kwargs)

+ async def _get_aiohttp_session(self) -> 'aiohttp.ClientSession':
+ """Returns the aiohttp client session."""
+ if (
+ self._aiohttp_session is None
+ or self._aiohttp_session.closed
+ or self._aiohttp_session._loop.is_closed() # pylint: disable=protected-access
+ ):
+ # Initialize the aiohttp client session if it's not set up or closed.
+ class AiohttpClientSession(aiohttp.ClientSession): # type: ignore[misc]
+
+ def __del__(self, _warnings: Any = warnings) -> None:
+ if not self.closed:
+ context = {
+ 'client_session': self,
+ 'message': 'Unclosed client session',
+ }
+ if self._source_traceback is not None:
+ context['source_traceback'] = self._source_traceback
+ # Remove this self._loop.call_exception_handler(context)
+
+ class AiohttpTCPConnector(aiohttp.TCPConnector): # type: ignore[misc]
+
+ def __del__(self, _warnings: Any = warnings) -> None:
+ if self._closed:
+ return
+ if not self._conns:
+ return
+ conns = [repr(c) for c in self._conns.values()]
+ # After v3.13.2, it may change to self._close_immediately()
+ self._close()
+ context = {
+ 'connector': self,
+ 'connections': conns,
+ 'message': 'Unclosed connector',
+ }
+ if self._source_traceback is not None:
+ context['source_traceback'] = self._source_traceback
+ # Remove this self._loop.call_exception_handler(context)
+ self._aiohttp_session = AiohttpClientSession(
+ connector=AiohttpTCPConnector(limit=0),
+ trust_env=True,
+ read_bufsize=READ_BUFFER_SIZE,
+ )
+ return self._aiohttp_session
+
  @staticmethod
  def _ensure_httpx_ssl_ctx(
  options: HttpOptions,
@@ -767,7 +856,6 @@ class BaseApiClient:
  Returns:
  An async aiohttp ClientSession._request args.
  """
-
  verify = 'ssl' # keep it consistent with httpx.
  async_args = options.async_client_args
  ctx = async_args.get(verify) if async_args else None
@@ -875,9 +963,15 @@ class BaseApiClient:
  has_aiohttp
  and (self._http_options.async_client_args or {}).get('transport')
  is None
+ and (self._http_options.httpx_async_client is None)
  )

  def _websocket_base_url(self) -> str:
+ has_sufficient_auth = (self.project and self.location) or self.api_key
+ if self.custom_base_url and not has_sufficient_auth:
+ # API gateway proxy can use the auth in custom headers, not url.
+ # Enable custom url if auth is not sufficient.
+ return self.custom_base_url
  url_parts = urlparse(self._http_options.base_url)
  return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]

@@ -992,6 +1086,11 @@ class BaseApiClient:
  and not path.startswith('projects/')
  and not query_vertex_base_models
  and (self.project or self.location)
+ and not (
+ self.custom_base_url
+ and patched_http_options.base_url_resource_scope
+ == ResourceScope.COLLECTION
+ )
  ):
  path = f'projects/{self.project}/locations/{self.location}/' + path

@@ -1015,11 +1114,27 @@ class BaseApiClient:
  _common.recursive_dict_update(
  request_dict, patched_http_options.extra_body
  )
-
- url = join_url_path(
- base_url,
- versioned_path,
- )
+ url = base_url
+ if (
+ not self.custom_base_url
+ or (self.project and self.location)
+ or self.api_key
+ ):
+ if (
+ patched_http_options.base_url_resource_scope
+ == ResourceScope.COLLECTION
+ ):
+ url = join_url_path(base_url, path)
+ else:
+ url = join_url_path(
+ base_url,
+ versioned_path,
+ )
+ elif(
+ self.custom_base_url
+ and patched_http_options.base_url_resource_scope == ResourceScope.COLLECTION
+ ):
+ url = join_url_path(base_url, path)

  if self.api_key and self.api_key.startswith('auth_tokens/'):
  raise EphemeralTokenAPIKeyError(
@@ -1132,13 +1247,9 @@ class BaseApiClient:

  if stream:
  if self._use_aiohttp():
- session = aiohttp.ClientSession(
- headers=http_request.headers,
- trust_env=True,
- read_bufsize=READ_BUFFER_SIZE,
- )
+ self._aiohttp_session = await self._get_aiohttp_session()
  try:
- response = await session.request(
+ response = await self._aiohttp_session.request(
  method=http_request.method,
  url=http_request.url,
  headers=http_request.headers,
@@ -1159,12 +1270,8 @@ class BaseApiClient:
  self._ensure_aiohttp_ssl_ctx(self._http_options)
  )
  # Instantiate a new session with the updated SSL context.
- session = aiohttp.ClientSession(
- headers=http_request.headers,
- trust_env=True,
- read_bufsize=READ_BUFFER_SIZE,
- )
- response = await session.request(
+ self._aiohttp_session = await self._get_aiohttp_session()
+ response = await self._aiohttp_session.request(
  method=http_request.method,
  url=http_request.url,
  headers=http_request.headers,
@@ -1174,7 +1281,7 @@ class BaseApiClient:
  )

  await errors.APIError.raise_for_async_response(response)
- return HttpResponse(response.headers, response, session=session)
+ return HttpResponse(response.headers, response)
  else:
  # aiohttp is not available. Fall back to httpx.
  httpx_request = self._async_httpx_client.build_request(
@@ -1192,22 +1299,18 @@ class BaseApiClient:
  return HttpResponse(client_response.headers, client_response)
  else:
  if self._use_aiohttp():
+ self._aiohttp_session = await self._get_aiohttp_session()
  try:
- async with aiohttp.ClientSession(
+ response = await self._aiohttp_session.request(
+ method=http_request.method,
+ url=http_request.url,
  headers=http_request.headers,
- trust_env=True,
- read_bufsize=READ_BUFFER_SIZE,
- ) as session:
- response = await session.request(
- method=http_request.method,
- url=http_request.url,
- headers=http_request.headers,
- data=data,
- timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
- **self._async_client_session_request_args,
- )
- await errors.APIError.raise_for_async_response(response)
- return HttpResponse(response.headers, [await response.text()])
+ data=data,
+ timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
+ **self._async_client_session_request_args,
+ )
+ await errors.APIError.raise_for_async_response(response)
+ return HttpResponse(response.headers, [await response.text()])
  except (
  aiohttp.ClientConnectorError,
  aiohttp.ClientConnectorDNSError,
@@ -1221,21 +1324,17 @@ class BaseApiClient:
  self._ensure_aiohttp_ssl_ctx(self._http_options)
  )
  # Instantiate a new session with the updated SSL context.
- async with aiohttp.ClientSession(
+ self._aiohttp_session = await self._get_aiohttp_session()
+ response = await self._aiohttp_session.request(
+ method=http_request.method,
+ url=http_request.url,
  headers=http_request.headers,
- trust_env=True,
- read_bufsize=READ_BUFFER_SIZE,
- ) as session:
- response = await session.request(
- method=http_request.method,
- url=http_request.url,
- headers=http_request.headers,
- data=data,
- timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
- **self._async_client_session_request_args,
- )
- await errors.APIError.raise_for_async_response(response)
- return HttpResponse(response.headers, [await response.text()])
+ data=data,
+ timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
+ **self._async_client_session_request_args,
+ )
+ await errors.APIError.raise_for_async_response(response)
+ return HttpResponse(response.headers, [await response.text()])
  else:
  # aiohttp is not available. Fall back to httpx.
  client_response = await self._async_httpx_client.request(
@@ -1305,9 +1404,21 @@ class BaseApiClient:

  session_response = self._request(http_request, http_options, stream=True)
  for chunk in session_response.segments():
- yield SdkHttpResponse(
- headers=session_response.headers, body=json.dumps(chunk)
- )
+ chunk_dump = json.dumps(chunk)
+ try:
+ if chunk_dump.startswith('{"error":'):
+ chunk_json = json.loads(chunk_dump)
+ errors.APIError.raise_error(
+ chunk_json.get('error', {}).get('code'),
+ chunk_json,
+ session_response,
+ )
+ except json.decoder.JSONDecodeError:
+ logger.debug(
+ 'Failed to decode chunk that contains an error: %s' % chunk_dump
+ )
+ pass
+ yield SdkHttpResponse(headers=session_response.headers, body=chunk_dump)

  async def async_request(
  self,
@@ -1341,7 +1452,21 @@ class BaseApiClient:

  async def async_generator(): # type: ignore[no-untyped-def]
  async for chunk in response:
- yield SdkHttpResponse(headers=response.headers, body=json.dumps(chunk))
+ chunk_dump = json.dumps(chunk)
+ try:
+ if chunk_dump.startswith('{"error":'):
+ chunk_json = json.loads(chunk_dump)
+ await errors.APIError.raise_error_async(
+ chunk_json.get('error', {}).get('code'),
+ chunk_json,
+ response,
+ )
+ except json.decoder.JSONDecodeError:
+ logger.debug(
+ 'Failed to decode chunk that contains an error: %s' % chunk_dump
+ )
+ pass
+ yield SdkHttpResponse(headers=response.headers, body=chunk_dump)

  return async_generator() # type: ignore[no-untyped-call]

@@ -1451,7 +1576,7 @@ class BaseApiClient:
  f'All content has been uploaded, but the upload status is not'
  f' finalized.'
  )
-
+ errors.APIError.raise_for_response(response)
  if response.headers.get('x-goog-upload-status') != 'final':
  raise ValueError('Failed to upload file: Upload status is not finalized.')
  return HttpResponse(response.headers, response_stream=[response.text])
@@ -1551,85 +1676,81 @@ class BaseApiClient:
  offset = 0
  # Upload the file in chunks
  if self._use_aiohttp(): # pylint: disable=g-import-not-at-top
- async with aiohttp.ClientSession(
- headers=self._http_options.headers,
- trust_env=True,
- read_bufsize=READ_BUFFER_SIZE,
- ) as session:
- while True:
- if isinstance(file, io.IOBase):
- file_chunk = file.read(CHUNK_SIZE)
- else:
- file_chunk = await file.read(CHUNK_SIZE)
- chunk_size = 0
- if file_chunk:
- chunk_size = len(file_chunk)
- upload_command = 'upload'
- # If last chunk, finalize the upload.
- if chunk_size + offset >= upload_size:
- upload_command += ', finalize'
- http_options = http_options if http_options else self._http_options
+ self._aiohttp_session = await self._get_aiohttp_session()
+ while True:
+ if isinstance(file, io.IOBase):
+ file_chunk = file.read(CHUNK_SIZE)
+ else:
+ file_chunk = await file.read(CHUNK_SIZE)
+ chunk_size = 0
+ if file_chunk:
+ chunk_size = len(file_chunk)
+ upload_command = 'upload'
+ # If last chunk, finalize the upload.
+ if chunk_size + offset >= upload_size:
+ upload_command += ', finalize'
+ http_options = http_options if http_options else self._http_options
+ timeout = (
+ http_options.get('timeout')
+ if isinstance(http_options, dict)
+ else http_options.timeout
+ )
+ if timeout is None:
+ # Per request timeout is not configured. Check the global timeout.
  timeout = (
- http_options.get('timeout')
- if isinstance(http_options, dict)
- else http_options.timeout
+ self._http_options.timeout
+ if isinstance(self._http_options, dict)
+ else self._http_options.timeout
  )
- if timeout is None:
- # Per request timeout is not configured. Check the global timeout.
- timeout = (
- self._http_options.timeout
- if isinstance(self._http_options, dict)
- else self._http_options.timeout
- )
- timeout_in_seconds = get_timeout_in_seconds(timeout)
- upload_headers = {
- 'X-Goog-Upload-Command': upload_command,
- 'X-Goog-Upload-Offset': str(offset),
- 'Content-Length': str(chunk_size),
- }
- populate_server_timeout_header(upload_headers, timeout_in_seconds)
-
- retry_count = 0
- response = None
- while retry_count < MAX_RETRY_COUNT:
- response = await session.request(
- method='POST',
- url=upload_url,
- data=file_chunk,
- headers=upload_headers,
- timeout=aiohttp.ClientTimeout(connect=timeout_in_seconds),
- )
+ timeout_in_seconds = get_timeout_in_seconds(timeout)
+ upload_headers = {
+ 'X-Goog-Upload-Command': upload_command,
+ 'X-Goog-Upload-Offset': str(offset),
+ 'Content-Length': str(chunk_size),
+ }
+ populate_server_timeout_header(upload_headers, timeout_in_seconds)

- if response.headers.get('X-Goog-Upload-Status'):
- break
- delay_seconds = INITIAL_RETRY_DELAY * (
- DELAY_MULTIPLIER**retry_count
- )
- retry_count += 1
- time.sleep(delay_seconds)
+ retry_count = 0
+ response = None
+ while retry_count < MAX_RETRY_COUNT:
+ response = await self._aiohttp_session.request(
+ method='POST',
+ url=upload_url,
+ data=file_chunk,
+ headers=upload_headers,
+ timeout=aiohttp.ClientTimeout(connect=timeout_in_seconds),
+ )

- offset += chunk_size
- if (
- response is not None
- and response.headers.get('X-Goog-Upload-Status') != 'active'
- ):
- break # upload is complete or it has been interrupted.
+ if response.headers.get('X-Goog-Upload-Status'):
+ break
+ delay_seconds = INITIAL_RETRY_DELAY * (DELAY_MULTIPLIER**retry_count)
+ retry_count += 1
+ await asyncio.sleep(delay_seconds)

- if upload_size <= offset: # Status is not finalized.
- raise ValueError(
- f'All content has been uploaded, but the upload status is not'
- f' finalized.'
- )
+ offset += chunk_size
  if (
  response is not None
- and response.headers.get('X-Goog-Upload-Status') != 'final'
+ and response.headers.get('X-Goog-Upload-Status') != 'active'
  ):
+ break # upload is complete or it has been interrupted.
+
+ if upload_size <= offset: # Status is not finalized.
  raise ValueError(
- 'Failed to upload file: Upload status is not finalized.'
+ f'All content has been uploaded, but the upload status is not'
+ f' finalized.'
  )
- return HttpResponse(
- response.headers, response_stream=[await response.text()]
+
+ await errors.APIError.raise_for_async_response(response)
+ if (
+ response is not None
+ and response.headers.get('X-Goog-Upload-Status') != 'final'
+ ):
+ raise ValueError(
+ 'Failed to upload file: Upload status is not finalized.'
  )
+ return HttpResponse(
+ response.headers, response_stream=[await response.text()]
+ )
  else:
  # aiohttp is not available. Fall back to httpx.
  while True:
@@ -1697,6 +1818,8 @@ class BaseApiClient:
  'All content has been uploaded, but the upload status is not'
  ' finalized.'
  )
+
+ await errors.APIError.raise_for_async_response(client_response)
  if (
  client_response is not None
  and client_response.headers.get('x-goog-upload-status') != 'final'
@@ -1735,23 +1858,19 @@ class BaseApiClient:
  data = http_request.data

  if self._use_aiohttp():
- async with aiohttp.ClientSession(
+ self._aiohttp_session = await self._get_aiohttp_session()
+ response = await self._aiohttp_session.request(
+ method=http_request.method,
+ url=http_request.url,
  headers=http_request.headers,
- trust_env=True,
- read_bufsize=READ_BUFFER_SIZE,
- ) as session:
- response = await session.request(
- method=http_request.method,
- url=http_request.url,
- headers=http_request.headers,
- data=data,
- timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
- )
- await errors.APIError.raise_for_async_response(response)
+ data=data,
+ timeout=aiohttp.ClientTimeout(connect=http_request.timeout),
+ )
+ await errors.APIError.raise_for_async_response(response)

- return HttpResponse(
- response.headers, byte_stream=[await response.read()]
- ).byte_stream[0]
+ return HttpResponse(
+ response.headers, byte_stream=[await response.read()]
+ ).byte_stream[0]
  else:
  # aiohttp is not available. Fall back to httpx.
  client_response = await self._async_httpx_client.request(
@@ -1772,3 +1891,37 @@ class BaseApiClient:
  # recorded response.
  def _verify_response(self, response_model: _common.BaseModel) -> None:
  pass
+
+ def close(self) -> None:
+ """Closes the API client."""
+ # Let users close the custom client explicitly by themselves. Otherwise,
+ # close the client when the object is garbage collected.
+ if not self._http_options.httpx_client:
+ self._httpx_client.close()
+
+ async def aclose(self) -> None:
+ """Closes the API async client."""
+ # Let users close the custom client explicitly by themselves. Otherwise,
+ # close the client when the object is garbage collected.
+ if not self._http_options.httpx_async_client:
+ await self._async_httpx_client.aclose()
+ if self._aiohttp_session:
+ await self._aiohttp_session.close()
+
+ def __del__(self) -> None:
+ """Closes the API client when the object is garbage collected.
+
+ ADK uses this client so cannot rely on the genai.[Async]Client.__del__
+ for cleanup.
+ """
+
+ try:
+ if not self._http_options.httpx_client:
+ self.close()
+ except Exception: # pylint: disable=broad-except
+ pass
+
+ try:
+ asyncio.get_running_loop().create_task(self.aclose())
+ except Exception: # pylint: disable=broad-except
+ pass
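
The last hunk adds explicit lifecycle methods on BaseApiClient (close, aclose, and a best-effort __del__ fallback). A short cleanup sketch, assuming the client object reaches BaseApiClient through the internal _api_client attribute used in current releases (an implementation detail, shown only for illustration):

    # Hypothetical sketch; _api_client is a private attribute and may change between versions.
    import asyncio
    from google import genai

    async def main() -> None:
        client = genai.Client()  # reads the API key from the environment
        try:
            ...  # issue requests here
        finally:
            # Release pooled httpx/aiohttp connections instead of relying on __del__.
            await client._api_client.aclose()

    asyncio.run(main())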