google-genai 1.6.0__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -68,26 +68,30 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
68
68
 
69
69
 
70
70
  def _patch_http_options(
71
- options: HttpOptionsDict, patch_options: dict[str, Any]
72
- ) -> HttpOptionsDict:
73
- # use shallow copy so we don't override the original objects.
74
- copy_option = HttpOptionsDict()
75
- copy_option.update(options)
76
- for patch_key, patch_value in patch_options.items():
77
- # if both are dicts, update the copy.
78
- # This is to handle cases like merging headers.
79
- if isinstance(patch_value, dict) and isinstance(
80
- copy_option.get(patch_key, None), dict
81
- ):
82
- copy_option[patch_key] = {}
83
- copy_option[patch_key].update(
84
- options[patch_key]
85
- ) # shallow copy from original options.
86
- copy_option[patch_key].update(patch_value)
87
- elif patch_value is not None: # Accept empty values.
88
- copy_option[patch_key] = patch_value
89
- if copy_option['headers']:
90
- _append_library_version_headers(copy_option['headers'])
71
+ options: HttpOptions, patch_options: HttpOptions
72
+ ) -> HttpOptions:
73
+ copy_option = options.model_copy()
74
+
75
+ options_headers = copy_option.headers or {}
76
+ patch_options_headers = patch_options.headers or {}
77
+ copy_option.headers = {
78
+ **options_headers,
79
+ **patch_options_headers,
80
+ }
81
+
82
+ http_options_keys = HttpOptions.model_fields.keys()
83
+
84
+ for key in http_options_keys:
85
+ if key == 'headers':
86
+ continue
87
+ patch_value = getattr(patch_options, key, None)
88
+ if patch_value is not None:
89
+ setattr(copy_option, key, patch_value)
90
+ else:
91
+ setattr(copy_option, key, getattr(options, key))
92
+
93
+ if copy_option.headers is not None:
94
+ _append_library_version_headers(copy_option.headers)
91
95
  return copy_option
92
96
 
93
97
 
@@ -188,9 +192,11 @@ class HttpResponse:
188
192
  if chunk:
189
193
  # In streaming mode, the chunk of JSON is prefixed with "data:" which
190
194
  # we must strip before parsing.
191
- if chunk.startswith(b'data: '):
192
- chunk = chunk[len(b'data: ') :]
193
- yield json.loads(str(chunk, 'utf-8'))
195
+ if not isinstance(chunk, str):
196
+ chunk = chunk.decode('utf-8')
197
+ if chunk.startswith('data: '):
198
+ chunk = chunk[len('data: ') :]
199
+ yield json.loads(chunk)
194
200
 
195
201
  async def async_segments(self) -> AsyncIterator[Any]:
196
202
  if isinstance(self.response_stream, list):
@@ -198,7 +204,7 @@ class HttpResponse:
198
204
  for chunk in self.response_stream:
199
205
  yield json.loads(chunk) if chunk else {}
200
206
  elif self.response_stream is None:
201
- async for c in []:
207
+ async for c in []: # type: ignore[attr-defined]
202
208
  yield c
203
209
  else:
204
210
  # Iterator of objects retrieved from the API.
@@ -206,15 +212,15 @@ class HttpResponse:
206
212
  async for chunk in self.response_stream.aiter_lines():
207
213
  # This is httpx.Response.
208
214
  if chunk:
209
- # In async streaming mode, the chunk of JSON is prefixed with "data:"
210
- # which we must strip before parsing.
215
+ # In async streaming mode, the chunk of JSON is prefixed with
216
+ # "data:" which we must strip before parsing.
217
+ if not isinstance(chunk, str):
218
+ chunk = chunk.decode('utf-8')
211
219
  if chunk.startswith('data: '):
212
220
  chunk = chunk[len('data: ') :]
213
221
  yield json.loads(chunk)
214
222
  else:
215
- raise ValueError(
216
- 'Error parsing streaming response.'
217
- )
223
+ raise ValueError('Error parsing streaming response.')
218
224
 
219
225
  def byte_segments(self):
220
226
  if isinstance(self.byte_stream, list):
@@ -234,6 +240,41 @@ class HttpResponse:
234
240
  response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
235
241
 
236
242
 
243
+ class SyncHttpxClient(httpx.Client):
244
+ """Sync httpx client."""
245
+
246
+ def __init__(self, **kwargs: Any) -> None:
247
+ """Initializes the httpx client."""
248
+ kwargs.setdefault('follow_redirects', True)
249
+ super().__init__(**kwargs)
250
+
251
+ def __del__(self) -> None:
252
+ """Closes the httpx client."""
253
+ if self.is_closed:
254
+ return
255
+ try:
256
+ self.close()
257
+ except Exception:
258
+ pass
259
+
260
+
261
+ class AsyncHttpxClient(httpx.AsyncClient):
262
+ """Async httpx client."""
263
+
264
+ def __init__(self, **kwargs: Any) -> None:
265
+ """Initializes the httpx client."""
266
+ kwargs.setdefault('follow_redirects', True)
267
+ super().__init__(**kwargs)
268
+
269
+ def __del__(self) -> None:
270
+ if self.is_closed:
271
+ return
272
+ try:
273
+ asyncio.get_running_loop().create_task(self.aclose())
274
+ except Exception:
275
+ pass
276
+
277
+
237
278
  class BaseApiClient:
238
279
  """Client for calling HTTP APIs sending and receiving JSON."""
239
280
 
@@ -269,16 +310,14 @@ class BaseApiClient:
269
310
  )
270
311
 
271
312
  # Validate http_options if it is provided.
272
- validated_http_options: dict[str, Any]
313
+ validated_http_options = HttpOptions()
273
314
  if isinstance(http_options, dict):
274
315
  try:
275
- validated_http_options = HttpOptions.model_validate(
276
- http_options
277
- ).model_dump()
316
+ validated_http_options = HttpOptions.model_validate(http_options)
278
317
  except ValidationError as e:
279
318
  raise ValueError(f'Invalid http_options: {e}')
280
319
  elif isinstance(http_options, HttpOptions):
281
- validated_http_options = http_options.model_dump()
320
+ validated_http_options = http_options
282
321
 
283
322
  # Retrieve implicitly set values from the environment.
284
323
  env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
@@ -289,11 +328,15 @@ class BaseApiClient:
289
328
  self.api_key = api_key or env_api_key
290
329
 
291
330
  self._credentials = credentials
292
- self._http_options = HttpOptionsDict()
331
+ self._http_options = HttpOptions()
293
332
  # Initialize the lock. This lock will be used to protect access to the
294
333
  # credentials. This is crucial for thread safety when multiple coroutines
295
334
  # might be accessing the credentials at the same time.
296
- self._auth_lock = asyncio.Lock()
335
+ try:
336
+ self._auth_lock = asyncio.Lock()
337
+ except RuntimeError:
338
+ asyncio.set_event_loop(asyncio.new_event_loop())
339
+ self._auth_lock = asyncio.Lock()
297
340
 
298
341
  # Handle when to use Vertex AI in express mode (api key).
299
342
  # Explicit initializer arguments are already validated above.
@@ -337,12 +380,12 @@ class BaseApiClient:
337
380
  'AI API.'
338
381
  )
339
382
  if self.api_key or self.location == 'global':
340
- self._http_options['base_url'] = f'https://aiplatform.googleapis.com/'
383
+ self._http_options.base_url = f'https://aiplatform.googleapis.com/'
341
384
  else:
342
- self._http_options['base_url'] = (
385
+ self._http_options.base_url = (
343
386
  f'https://{self.location}-aiplatform.googleapis.com/'
344
387
  )
345
- self._http_options['api_version'] = 'v1beta1'
388
+ self._http_options.api_version = 'v1beta1'
346
389
  else: # Implicit initialization or missing arguments.
347
390
  if not self.api_key:
348
391
  raise ValueError(
@@ -350,24 +393,27 @@ class BaseApiClient:
350
393
  'provide (`api_key`) arguments. To use the Google Cloud API,'
351
394
  ' provide (`vertexai`, `project` & `location`) arguments.'
352
395
  )
353
- self._http_options['base_url'] = (
354
- 'https://generativelanguage.googleapis.com/'
355
- )
356
- self._http_options['api_version'] = 'v1beta'
396
+ self._http_options.base_url = 'https://generativelanguage.googleapis.com/'
397
+ self._http_options.api_version = 'v1beta'
357
398
  # Default options for both clients.
358
- self._http_options['headers'] = {'Content-Type': 'application/json'}
399
+ self._http_options.headers = {'Content-Type': 'application/json'}
359
400
  if self.api_key:
360
- self._http_options['headers']['x-goog-api-key'] = self.api_key
401
+ if self._http_options.headers is not None:
402
+ self._http_options.headers['x-goog-api-key'] = self.api_key
361
403
  # Update the http options with the user provided http options.
362
404
  if http_options:
363
405
  self._http_options = _patch_http_options(
364
406
  self._http_options, validated_http_options
365
407
  )
366
408
  else:
367
- _append_library_version_headers(self._http_options['headers'])
409
+ if self._http_options.headers is not None:
410
+ _append_library_version_headers(self._http_options.headers)
411
+ # Initialize the httpx client.
412
+ self._httpx_client = SyncHttpxClient()
413
+ self._async_httpx_client = AsyncHttpxClient()
368
414
 
369
415
  def _websocket_base_url(self):
370
- url_parts = urlparse(self._http_options['base_url'])
416
+ url_parts = urlparse(self._http_options.base_url)
371
417
  return url_parts._replace(scheme='wss').geturl()
372
418
 
373
419
  def _access_token(self) -> str:
@@ -378,9 +424,7 @@ class BaseApiClient:
378
424
  self.project = project
379
425
 
380
426
  if self._credentials:
381
- if (
382
- self._credentials.expired or not self._credentials.token
383
- ):
427
+ if self._credentials.expired or not self._credentials.token:
384
428
  # Only refresh when it needs to. Default expiration is 3600 seconds.
385
429
  _refresh_auth(self._credentials)
386
430
  if not self._credentials.token:
@@ -404,9 +448,7 @@ class BaseApiClient:
404
448
  self.project = project
405
449
 
406
450
  if self._credentials:
407
- if (
408
- self._credentials.expired or not self._credentials.token
409
- ):
451
+ if self._credentials.expired or not self._credentials.token:
410
452
  # Only refresh when it needs to. Default expiration is 3600 seconds.
411
453
  async with self._auth_lock:
412
454
  if self._credentials.expired or not self._credentials.token:
@@ -435,11 +477,12 @@ class BaseApiClient:
435
477
  if http_options:
436
478
  if isinstance(http_options, HttpOptions):
437
479
  patched_http_options = _patch_http_options(
438
- self._http_options, http_options.model_dump()
480
+ self._http_options,
481
+ http_options,
439
482
  )
440
483
  else:
441
484
  patched_http_options = _patch_http_options(
442
- self._http_options, http_options
485
+ self._http_options, HttpOptions.model_validate(http_options)
443
486
  )
444
487
  else:
445
488
  patched_http_options = self._http_options
@@ -458,13 +501,27 @@ class BaseApiClient:
458
501
  and not self.api_key
459
502
  ):
460
503
  path = f'projects/{self.project}/locations/{self.location}/' + path
504
+
505
+ if patched_http_options.api_version is None:
506
+ versioned_path = f'/{path}'
507
+ else:
508
+ versioned_path = f'{patched_http_options.api_version}/{path}'
509
+
510
+ if (
511
+ patched_http_options.base_url is None
512
+ or not patched_http_options.base_url
513
+ ):
514
+ raise ValueError('Base URL must be set.')
515
+ else:
516
+ base_url = patched_http_options.base_url
517
+
461
518
  url = _join_url_path(
462
- patched_http_options.get('base_url', ''),
463
- patched_http_options.get('api_version', '') + '/' + path,
519
+ base_url,
520
+ versioned_path,
464
521
  )
465
522
 
466
- timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
467
- 'timeout', None
523
+ timeout_in_seconds: Optional[Union[float, int]] = (
524
+ patched_http_options.timeout
468
525
  )
469
526
  if timeout_in_seconds:
470
527
  # HttpOptions.timeout is in milliseconds. But httpx.Client.request()
@@ -473,10 +530,12 @@ class BaseApiClient:
473
530
  else:
474
531
  timeout_in_seconds = None
475
532
 
533
+ if patched_http_options.headers is None:
534
+ raise ValueError('Request headers must be set.')
476
535
  return HttpRequest(
477
536
  method=http_method,
478
537
  url=url,
479
- headers=patched_http_options['headers'],
538
+ headers=patched_http_options.headers,
480
539
  data=request_dict,
481
540
  timeout=timeout_in_seconds,
482
541
  )
@@ -488,48 +547,44 @@ class BaseApiClient:
488
547
  ) -> HttpResponse:
489
548
  data: Optional[Union[str, bytes]] = None
490
549
  if self.vertexai and not self.api_key:
491
- http_request.headers['Authorization'] = (
492
- f'Bearer {self._access_token()}'
493
- )
550
+ http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
494
551
  if self._credentials and self._credentials.quota_project_id:
495
552
  http_request.headers['x-goog-user-project'] = (
496
553
  self._credentials.quota_project_id
497
554
  )
498
- data = json.dumps(http_request.data)
555
+ data = json.dumps(http_request.data) if http_request.data else None
499
556
  else:
500
557
  if http_request.data:
501
558
  if not isinstance(http_request.data, bytes):
502
- data = json.dumps(http_request.data)
559
+ data = json.dumps(http_request.data) if http_request.data else None
503
560
  else:
504
561
  data = http_request.data
505
562
 
506
563
  if stream:
507
- client = httpx.Client()
508
- httpx_request = client.build_request(
564
+ httpx_request = self._httpx_client.build_request(
509
565
  method=http_request.method,
510
566
  url=http_request.url,
511
567
  content=data,
512
568
  headers=http_request.headers,
513
569
  timeout=http_request.timeout,
514
570
  )
515
- response = client.send(httpx_request, stream=stream)
571
+ response = self._httpx_client.send(httpx_request, stream=stream)
516
572
  errors.APIError.raise_for_response(response)
517
573
  return HttpResponse(
518
574
  response.headers, response if stream else [response.text]
519
575
  )
520
576
  else:
521
- with httpx.Client() as client:
522
- response = client.request(
523
- method=http_request.method,
524
- url=http_request.url,
525
- headers=http_request.headers,
526
- content=data,
527
- timeout=http_request.timeout,
528
- )
529
- errors.APIError.raise_for_response(response)
530
- return HttpResponse(
531
- response.headers, response if stream else [response.text]
532
- )
577
+ response = self._httpx_client.request(
578
+ method=http_request.method,
579
+ url=http_request.url,
580
+ headers=http_request.headers,
581
+ content=data,
582
+ timeout=http_request.timeout,
583
+ )
584
+ errors.APIError.raise_for_response(response)
585
+ return HttpResponse(
586
+ response.headers, response if stream else [response.text]
587
+ )
533
588
 
534
589
  async def _async_request(
535
590
  self, http_request: HttpRequest, stream: bool = False
@@ -543,50 +598,48 @@ class BaseApiClient:
543
598
  http_request.headers['x-goog-user-project'] = (
544
599
  self._credentials.quota_project_id
545
600
  )
546
- data = json.dumps(http_request.data)
601
+ data = json.dumps(http_request.data) if http_request.data else None
547
602
  else:
548
603
  if http_request.data:
549
604
  if not isinstance(http_request.data, bytes):
550
- data = json.dumps(http_request.data)
605
+ data = json.dumps(http_request.data) if http_request.data else None
551
606
  else:
552
607
  data = http_request.data
553
608
 
554
609
  if stream:
555
- aclient = httpx.AsyncClient()
556
- httpx_request = aclient.build_request(
610
+ httpx_request = self._async_httpx_client.build_request(
557
611
  method=http_request.method,
558
612
  url=http_request.url,
559
613
  content=data,
560
614
  headers=http_request.headers,
561
615
  timeout=http_request.timeout,
562
616
  )
563
- response = await aclient.send(
617
+ response = await self._async_httpx_client.send(
564
618
  httpx_request,
565
619
  stream=stream,
566
620
  )
567
- errors.APIError.raise_for_response(response)
621
+ await errors.APIError.raise_for_async_response(response)
568
622
  return HttpResponse(
569
623
  response.headers, response if stream else [response.text]
570
624
  )
571
625
  else:
572
- async with httpx.AsyncClient() as aclient:
573
- response = await aclient.request(
574
- method=http_request.method,
575
- url=http_request.url,
576
- headers=http_request.headers,
577
- content=data,
578
- timeout=http_request.timeout,
579
- )
580
- errors.APIError.raise_for_response(response)
581
- return HttpResponse(
582
- response.headers, response if stream else [response.text]
583
- )
626
+ response = await self._async_httpx_client.request(
627
+ method=http_request.method,
628
+ url=http_request.url,
629
+ headers=http_request.headers,
630
+ content=data,
631
+ timeout=http_request.timeout,
632
+ )
633
+ await errors.APIError.raise_for_async_response(response)
634
+ return HttpResponse(
635
+ response.headers, response if stream else [response.text]
636
+ )
584
637
 
585
- def get_read_only_http_options(self) -> HttpOptionsDict:
586
- copied = HttpOptionsDict()
638
+ def get_read_only_http_options(self) -> dict[str, Any]:
587
639
  if isinstance(self._http_options, BaseModel):
588
- self._http_options = self._http_options.model_dump()
589
- copied.update(self._http_options)
640
+ copied = self._http_options.model_dump()
641
+ else:
642
+ copied = self._http_options
590
643
  return copied
591
644
 
592
645
  def request(
@@ -612,7 +665,7 @@ class BaseApiClient:
612
665
  http_method: str,
613
666
  path: str,
614
667
  request_dict: dict[str, object],
615
- http_options: Optional[HttpOptionsDict] = None,
668
+ http_options: Optional[HttpOptionsOrDict] = None,
616
669
  ):
617
670
  http_request = self._build_request(
618
671
  http_method, path, request_dict, http_options
@@ -644,7 +697,7 @@ class BaseApiClient:
644
697
  http_method: str,
645
698
  path: str,
646
699
  request_dict: dict[str, object],
647
- http_options: Optional[HttpOptionsDict] = None,
700
+ http_options: Optional[HttpOptionsOrDict] = None,
648
701
  ):
649
702
  http_request = self._build_request(
650
703
  http_method, path, request_dict, http_options
@@ -660,7 +713,7 @@ class BaseApiClient:
660
713
 
661
714
  def upload_file(
662
715
  self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
663
- ) -> dict[str, str]:
716
+ ) -> HttpResponse:
664
717
  """Transfers a file to the given URL.
665
718
 
666
719
  Args:
@@ -672,7 +725,7 @@ class BaseApiClient:
672
725
  match the size requested in the resumable upload request.
673
726
 
674
727
  returns:
675
- The response json object from the finalize request.
728
+ The HttpResponse object from the finalize request.
676
729
  """
677
730
  if isinstance(file_path, io.IOBase):
678
731
  return self._upload_fd(file_path, upload_url, upload_size)
@@ -682,7 +735,7 @@ class BaseApiClient:
682
735
 
683
736
  def _upload_fd(
684
737
  self, file: io.IOBase, upload_url: str, upload_size: int
685
- ) -> dict[str, str]:
738
+ ) -> HttpResponse:
686
739
  """Transfers a file to the given URL.
687
740
 
688
741
  Args:
@@ -692,7 +745,7 @@ class BaseApiClient:
692
745
  match the size requested in the resumable upload request.
693
746
 
694
747
  returns:
695
- The response json object from the finalize request.
748
+ The HttpResponse object from the finalize request.
696
749
  """
697
750
  offset = 0
698
751
  # Upload the file in chunks
@@ -705,7 +758,7 @@ class BaseApiClient:
705
758
  # If last chunk, finalize the upload.
706
759
  if chunk_size + offset >= upload_size:
707
760
  upload_command += ', finalize'
708
- request = HttpRequest(
761
+ response = self._httpx_client.request(
709
762
  method='POST',
710
763
  url=upload_url,
711
764
  headers={
@@ -713,25 +766,22 @@ class BaseApiClient:
713
766
  'X-Goog-Upload-Offset': str(offset),
714
767
  'Content-Length': str(chunk_size),
715
768
  },
716
- data=file_chunk,
769
+ content=file_chunk,
717
770
  )
718
-
719
- response = self._request(request, stream=False)
720
771
  offset += chunk_size
721
- if response.headers['X-Goog-Upload-Status'] != 'active':
772
+ if response.headers['x-goog-upload-status'] != 'active':
722
773
  break # upload is complete or it has been interrupted.
723
-
724
774
  if upload_size <= offset: # Status is not finalized.
725
775
  raise ValueError(
726
- 'All content has been uploaded, but the upload status is not'
776
+ f'All content has been uploaded, but the upload status is not'
727
777
  f' finalized.'
728
778
  )
729
779
 
730
- if response.headers['X-Goog-Upload-Status'] != 'final':
780
+ if response.headers['x-goog-upload-status'] != 'final':
731
781
  raise ValueError(
732
782
  'Failed to upload file: Upload status is not finalized.'
733
783
  )
734
- return response.json
784
+ return HttpResponse(response.headers, response_stream=[response.text])
735
785
 
736
786
  def download_file(self, path: str, http_options):
737
787
  """Downloads the file data.
@@ -746,12 +796,7 @@ class BaseApiClient:
746
796
  http_request = self._build_request(
747
797
  'get', path=path, request_dict={}, http_options=http_options
748
798
  )
749
- return self._download_file_request(http_request).byte_stream[0]
750
799
 
751
- def _download_file_request(
752
- self,
753
- http_request: HttpRequest,
754
- ) -> HttpResponse:
755
800
  data: Optional[Union[str, bytes]] = None
756
801
  if http_request.data:
757
802
  if not isinstance(http_request.data, bytes):
@@ -759,24 +804,25 @@ class BaseApiClient:
759
804
  else:
760
805
  data = http_request.data
761
806
 
762
- with httpx.Client(follow_redirects=True) as client:
763
- response = client.request(
764
- method=http_request.method,
765
- url=http_request.url,
766
- headers=http_request.headers,
767
- content=data,
768
- timeout=http_request.timeout,
769
- )
807
+ response = self._httpx_client.request(
808
+ method=http_request.method,
809
+ url=http_request.url,
810
+ headers=http_request.headers,
811
+ content=data,
812
+ timeout=http_request.timeout,
813
+ )
770
814
 
771
- errors.APIError.raise_for_response(response)
772
- return HttpResponse(response.headers, byte_stream=[response.read()])
815
+ errors.APIError.raise_for_response(response)
816
+ return HttpResponse(
817
+ response.headers, byte_stream=[response.read()]
818
+ ).byte_stream[0]
773
819
 
774
820
  async def async_upload_file(
775
821
  self,
776
822
  file_path: Union[str, io.IOBase],
777
823
  upload_url: str,
778
824
  upload_size: int,
779
- ) -> dict[str, str]:
825
+ ) -> HttpResponse:
780
826
  """Transfers a file asynchronously to the given URL.
781
827
 
782
828
  Args:
@@ -787,7 +833,7 @@ class BaseApiClient:
787
833
  match the size requested in the resumable upload request.
788
834
 
789
835
  returns:
790
- The response json object from the finalize request.
836
+ The HttpResponse object from the finalize request.
791
837
  """
792
838
  if isinstance(file_path, io.IOBase):
793
839
  return await self._async_upload_fd(file_path, upload_url, upload_size)
@@ -802,7 +848,7 @@ class BaseApiClient:
802
848
  file: Union[io.IOBase, anyio.AsyncFile],
803
849
  upload_url: str,
804
850
  upload_size: int,
805
- ) -> dict[str, str]:
851
+ ) -> HttpResponse:
806
852
  """Transfers a file asynchronously to the given URL.
807
853
 
808
854
  Args:
@@ -812,47 +858,46 @@ class BaseApiClient:
812
858
  match the size requested in the resumable upload request.
813
859
 
814
860
  returns:
815
- The response json object from the finalize request.
861
+ The HttpResponse object from the finalized request.
816
862
  """
817
- async with httpx.AsyncClient() as aclient:
818
- offset = 0
819
- # Upload the file in chunks
820
- while True:
821
- if isinstance(file, io.IOBase):
822
- file_chunk = file.read(CHUNK_SIZE)
823
- else:
824
- file_chunk = await file.read(CHUNK_SIZE)
825
- chunk_size = 0
826
- if file_chunk:
827
- chunk_size = len(file_chunk)
828
- upload_command = 'upload'
829
- # If last chunk, finalize the upload.
830
- if chunk_size + offset >= upload_size:
831
- upload_command += ', finalize'
832
- response = await aclient.request(
833
- method='POST',
834
- url=upload_url,
835
- content=file_chunk,
836
- headers={
837
- 'X-Goog-Upload-Command': upload_command,
838
- 'X-Goog-Upload-Offset': str(offset),
839
- 'Content-Length': str(chunk_size),
840
- },
841
- )
842
- offset += chunk_size
843
- if response.headers.get('x-goog-upload-status') != 'active':
844
- break # upload is complete or it has been interrupted.
845
-
846
- if upload_size <= offset: # Status is not finalized.
847
- raise ValueError(
848
- 'All content has been uploaded, but the upload status is not'
849
- f' finalized.'
850
- )
851
- if response.headers.get('x-goog-upload-status') != 'final':
863
+ offset = 0
864
+ # Upload the file in chunks
865
+ while True:
866
+ if isinstance(file, io.IOBase):
867
+ file_chunk = file.read(CHUNK_SIZE)
868
+ else:
869
+ file_chunk = await file.read(CHUNK_SIZE)
870
+ chunk_size = 0
871
+ if file_chunk:
872
+ chunk_size = len(file_chunk)
873
+ upload_command = 'upload'
874
+ # If last chunk, finalize the upload.
875
+ if chunk_size + offset >= upload_size:
876
+ upload_command += ', finalize'
877
+ response = await self._async_httpx_client.request(
878
+ method='POST',
879
+ url=upload_url,
880
+ content=file_chunk,
881
+ headers={
882
+ 'X-Goog-Upload-Command': upload_command,
883
+ 'X-Goog-Upload-Offset': str(offset),
884
+ 'Content-Length': str(chunk_size),
885
+ },
886
+ )
887
+ offset += chunk_size
888
+ if response.headers.get('x-goog-upload-status') != 'active':
889
+ break # upload is complete or it has been interrupted.
890
+
891
+ if upload_size <= offset: # Status is not finalized.
852
892
  raise ValueError(
853
- 'Failed to upload file: Upload status is not finalized.'
893
+ 'All content has been uploaded, but the upload status is not'
894
+ f' finalized.'
854
895
  )
855
- return response.json()
896
+ if response.headers.get('x-goog-upload-status') != 'final':
897
+ raise ValueError(
898
+ 'Failed to upload file: Upload status is not finalized.'
899
+ )
900
+ return HttpResponse(response.headers, response_stream=[response.text])
856
901
 
857
902
  async def async_download_file(self, path: str, http_options):
858
903
  """Downloads the file data.
@@ -875,19 +920,18 @@ class BaseApiClient:
875
920
  else:
876
921
  data = http_request.data
877
922
 
878
- async with httpx.AsyncClient(follow_redirects=True) as aclient:
879
- response = await aclient.request(
880
- method=http_request.method,
881
- url=http_request.url,
882
- headers=http_request.headers,
883
- content=data,
884
- timeout=http_request.timeout,
885
- )
886
- errors.APIError.raise_for_response(response)
923
+ response = await self._async_httpx_client.request(
924
+ method=http_request.method,
925
+ url=http_request.url,
926
+ headers=http_request.headers,
927
+ content=data,
928
+ timeout=http_request.timeout,
929
+ )
930
+ await errors.APIError.raise_for_async_response(response)
887
931
 
888
- return HttpResponse(
889
- response.headers, byte_stream=[response.read()]
890
- ).byte_stream[0]
932
+ return HttpResponse(
933
+ response.headers, byte_stream=[response.read()]
934
+ ).byte_stream[0]
891
935
 
892
936
  # This method does nothing in the real api client. It is used in the
893
937
  # replay_api_client to verify the response from the SDK method matches the