google-genai 1.5.0__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,16 +29,14 @@ import json
29
29
  import logging
30
30
  import os
31
31
  import sys
32
- from typing import Any, AsyncIterator, Optional, Tuple, TypedDict, Union
32
+ from typing import Any, AsyncIterator, Optional, Tuple, Union
33
33
  from urllib.parse import urlparse, urlunparse
34
34
  import google.auth
35
35
  import google.auth.credentials
36
36
  from google.auth.credentials import Credentials
37
- from google.auth.transport.requests import AuthorizedSession
38
37
  from google.auth.transport.requests import Request
39
38
  import httpx
40
- from pydantic import BaseModel, ConfigDict, Field, ValidationError
41
- import requests
39
+ from pydantic import BaseModel, Field, ValidationError
42
40
  from . import _common
43
41
  from . import errors
44
42
  from . import version
@@ -88,7 +86,8 @@ def _patch_http_options(
88
86
  copy_option[patch_key].update(patch_value)
89
87
  elif patch_value is not None: # Accept empty values.
90
88
  copy_option[patch_key] = patch_value
91
- _append_library_version_headers(copy_option['headers'])
89
+ if copy_option['headers']:
90
+ _append_library_version_headers(copy_option['headers'])
92
91
  return copy_option
93
92
 
94
93
 
@@ -103,7 +102,7 @@ def _join_url_path(base_url: str, path: str) -> str:
103
102
  return urlunparse(parsed_base._replace(path=base_path + '/' + path))
104
103
 
105
104
 
106
- def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
105
+ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
107
106
  """Loads google auth credentials and project id."""
108
107
  credentials, loaded_project_id = google.auth.default(
109
108
  scopes=['https://www.googleapis.com/auth/cloud-platform'],
@@ -189,9 +188,11 @@ class HttpResponse:
189
188
  if chunk:
190
189
  # In streaming mode, the chunk of JSON is prefixed with "data:" which
191
190
  # we must strip before parsing.
192
- if chunk.startswith(b'data: '):
193
- chunk = chunk[len(b'data: ') :]
194
- yield json.loads(str(chunk, 'utf-8'))
191
+ if not isinstance(chunk, str):
192
+ chunk = chunk.decode('utf-8')
193
+ if chunk.startswith('data: '):
194
+ chunk = chunk[len('data: ') :]
195
+ yield json.loads(chunk)
195
196
 
196
197
  async def async_segments(self) -> AsyncIterator[Any]:
197
198
  if isinstance(self.response_stream, list):
@@ -207,8 +208,10 @@ class HttpResponse:
207
208
  async for chunk in self.response_stream.aiter_lines():
208
209
  # This is httpx.Response.
209
210
  if chunk:
210
- # In async streaming mode, the chunk of JSON is prefixed with "data:"
211
- # which we must strip before parsing.
211
+ # In async streaming mode, the chunk of JSON is prefixed with
212
+ # "data:" which we must strip before parsing.
213
+ if not isinstance(chunk, str):
214
+ chunk = chunk.decode('utf-8')
212
215
  if chunk.startswith('data: '):
213
216
  chunk = chunk[len('data: ') :]
214
217
  yield json.loads(chunk)
@@ -235,6 +238,41 @@ class HttpResponse:
235
238
  response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
236
239
 
237
240
 
241
class SyncHttpxClient(httpx.Client):
  """Sync httpx client that follows redirects by default.

  The client tries to close itself when it is garbage-collected, so a client
  that is dropped without an explicit ``close()`` releases its resources.
  """

  def __init__(self, **kwargs: Any) -> None:
    """Initializes the httpx client.

    Args:
      **kwargs: Forwarded unchanged to ``httpx.Client``. ``follow_redirects``
        defaults to ``True`` unless the caller passes it explicitly.
    """
    kwargs.setdefault('follow_redirects', True)
    super().__init__(**kwargs)

  def __del__(self) -> None:
    """Best-effort close of the httpx client during garbage collection."""
    try:
      # `is_closed` depends on state set up by httpx.Client.__init__. If
      # __init__ raised (e.g. on an invalid keyword argument), this lookup
      # can itself fail, so it is guarded together with close().
      if self.is_closed:
        return
      self.close()
    except Exception:
      # A finalizer must never propagate errors; the interpreter would only
      # print an "Exception ignored in __del__" warning.
      pass
257
+
258
+
259
class AsyncHttpxClient(httpx.AsyncClient):
  """Async httpx client that follows redirects by default.

  The client tries to schedule its own ``aclose()`` when it is
  garbage-collected while an event loop is running in the current thread.
  """

  # Strong references to close tasks scheduled from __del__. The event loop
  # keeps only weak references to tasks, so an otherwise-unreferenced task can
  # be garbage-collected before aclose() finishes. Shared across instances on
  # purpose; each task removes itself when it completes.
  _pending_close_tasks: 'set[asyncio.Task[None]]' = set()

  def __init__(self, **kwargs: Any) -> None:
    """Initializes the httpx client.

    Args:
      **kwargs: Forwarded unchanged to ``httpx.AsyncClient``.
        ``follow_redirects`` defaults to ``True`` unless the caller passes it
        explicitly.
    """
    kwargs.setdefault('follow_redirects', True)
    super().__init__(**kwargs)

  def __del__(self) -> None:
    """Best-effort asynchronous close during garbage collection.

    ``aclose()`` is a coroutine, so it can only be scheduled when an event loop
    is running in the current thread. Otherwise ``get_running_loop()`` raises
    (before ``aclose()`` is even called) and nothing is done.
    """
    try:
      # Guarded: `is_closed` can fail if httpx.AsyncClient.__init__ raised.
      if self.is_closed:
        return
      task = asyncio.get_running_loop().create_task(self.aclose())
      self._pending_close_tasks.add(task)
      task.add_done_callback(self._pending_close_tasks.discard)
    except Exception:
      # A finalizer must never propagate errors.
      pass
274
+
275
+
238
276
  class BaseApiClient:
239
277
  """Client for calling HTTP APIs sending and receiving JSON."""
240
278
 
@@ -273,7 +311,9 @@ class BaseApiClient:
273
311
  validated_http_options: dict[str, Any]
274
312
  if isinstance(http_options, dict):
275
313
  try:
276
- validated_http_options = HttpOptions.model_validate(http_options).model_dump()
314
+ validated_http_options = HttpOptions.model_validate(
315
+ http_options
316
+ ).model_dump()
277
317
  except ValidationError as e:
278
318
  raise ValueError(f'Invalid http_options: {e}')
279
319
  elif isinstance(http_options, HttpOptions):
@@ -359,16 +399,40 @@ class BaseApiClient:
359
399
  self._http_options['headers']['x-goog-api-key'] = self.api_key
360
400
  # Update the http options with the user provided http options.
361
401
  if http_options:
362
- self._http_options = _patch_http_options(self._http_options, validated_http_options)
402
+ self._http_options = _patch_http_options(
403
+ self._http_options, validated_http_options
404
+ )
363
405
  else:
364
406
  _append_library_version_headers(self._http_options['headers'])
407
+ # Initialize the httpx client.
408
+ self._httpx_client = SyncHttpxClient()
409
+ self._async_httpx_client = AsyncHttpxClient()
365
410
 
366
411
  def _websocket_base_url(self):
367
412
  url_parts = urlparse(self._http_options['base_url'])
368
413
  return url_parts._replace(scheme='wss').geturl()
369
414
 
370
- async def _async_access_token(self) -> str:
415
+ def _access_token(self) -> str:
371
416
  """Retrieves the access token for the credentials."""
417
+ if not self._credentials:
418
+ self._credentials, project = _load_auth(project=self.project)
419
+ if not self.project:
420
+ self.project = project
421
+
422
+ if self._credentials:
423
+ if (
424
+ self._credentials.expired or not self._credentials.token
425
+ ):
426
+ # Only refresh when it needs to. Default expiration is 3600 seconds.
427
+ _refresh_auth(self._credentials)
428
+ if not self._credentials.token:
429
+ raise RuntimeError('Could not resolve API token from the environment')
430
+ return self._credentials.token
431
+ else:
432
+ raise RuntimeError('Could not resolve API token from the environment')
433
+
434
+ async def _async_access_token(self) -> str:
435
+ """Retrieves the access token for the credentials asynchronously."""
372
436
  if not self._credentials:
373
437
  async with self._auth_lock:
374
438
  # This ensures that only one coroutine can execute the auth logic at a
@@ -437,8 +501,8 @@ class BaseApiClient:
437
501
  ):
438
502
  path = f'projects/{self.project}/locations/{self.location}/' + path
439
503
  url = _join_url_path(
440
- patched_http_options['base_url'],
441
- patched_http_options['api_version'] + '/' + path,
504
+ patched_http_options.get('base_url', ''),
505
+ patched_http_options.get('api_version', '') + '/' + path,
442
506
  )
443
507
 
444
508
  timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
@@ -464,59 +528,54 @@ class BaseApiClient:
464
528
  http_request: HttpRequest,
465
529
  stream: bool = False,
466
530
  ) -> HttpResponse:
531
+ data: Optional[Union[str, bytes]] = None
467
532
  if self.vertexai and not self.api_key:
468
- if not self._credentials:
469
- self._credentials, _ = _load_auth(project=self.project)
470
- if self._credentials.quota_project_id:
533
+ http_request.headers['Authorization'] = (
534
+ f'Bearer {self._access_token()}'
535
+ )
536
+ if self._credentials and self._credentials.quota_project_id:
471
537
  http_request.headers['x-goog-user-project'] = (
472
538
  self._credentials.quota_project_id
473
539
  )
474
- authed_session = AuthorizedSession(self._credentials)
475
- authed_session.stream = stream
476
- response = authed_session.request(
477
- http_request.method.upper(),
478
- http_request.url,
540
+ data = json.dumps(http_request.data) if http_request.data else None
541
+ else:
542
+ if http_request.data:
543
+ if not isinstance(http_request.data, bytes):
544
+ data = json.dumps(http_request.data) if http_request.data else None
545
+ else:
546
+ data = http_request.data
547
+
548
+ if stream:
549
+ httpx_request = self._httpx_client.build_request(
550
+ method=http_request.method,
551
+ url=http_request.url,
552
+ content=data,
479
553
  headers=http_request.headers,
480
- data=json.dumps(http_request.data) if http_request.data else None,
481
554
  timeout=http_request.timeout,
482
555
  )
556
+ response = self._httpx_client.send(httpx_request, stream=stream)
483
557
  errors.APIError.raise_for_response(response)
484
558
  return HttpResponse(
485
559
  response.headers, response if stream else [response.text]
486
560
  )
487
561
  else:
488
- return self._request_unauthorized(http_request, stream)
489
-
490
- def _request_unauthorized(
491
- self,
492
- http_request: HttpRequest,
493
- stream: bool = False,
494
- ) -> HttpResponse:
495
- data: Optional[Union[str, bytes]] = None
496
- if http_request.data:
497
- if not isinstance(http_request.data, bytes):
498
- data = json.dumps(http_request.data)
499
- else:
500
- data = http_request.data
501
-
502
- http_session = requests.Session()
503
- response = http_session.request(
504
- method=http_request.method,
505
- url=http_request.url,
506
- headers=http_request.headers,
507
- data=data,
508
- timeout=http_request.timeout,
509
- stream=stream,
510
- )
511
- errors.APIError.raise_for_response(response)
512
- return HttpResponse(
513
- response.headers, response if stream else [response.text]
514
- )
562
+ response = self._httpx_client.request(
563
+ method=http_request.method,
564
+ url=http_request.url,
565
+ headers=http_request.headers,
566
+ content=data,
567
+ timeout=http_request.timeout,
568
+ )
569
+ errors.APIError.raise_for_response(response)
570
+ return HttpResponse(
571
+ response.headers, response if stream else [response.text]
572
+ )
515
573
 
516
574
  async def _async_request(
517
575
  self, http_request: HttpRequest, stream: bool = False
518
576
  ):
519
- if self.vertexai:
577
+ data: Optional[Union[str, bytes]] = None
578
+ if self.vertexai and not self.api_key:
520
579
  http_request.headers['Authorization'] = (
521
580
  f'Bearer {await self._async_access_token()}'
522
581
  )
@@ -524,16 +583,23 @@ class BaseApiClient:
524
583
  http_request.headers['x-goog-user-project'] = (
525
584
  self._credentials.quota_project_id
526
585
  )
586
+ data = json.dumps(http_request.data) if http_request.data else None
587
+ else:
588
+ if http_request.data:
589
+ if not isinstance(http_request.data, bytes):
590
+ data = json.dumps(http_request.data) if http_request.data else None
591
+ else:
592
+ data = http_request.data
593
+
527
594
  if stream:
528
- aclient = httpx.AsyncClient()
529
- httpx_request = aclient.build_request(
595
+ httpx_request = self._async_httpx_client.build_request(
530
596
  method=http_request.method,
531
597
  url=http_request.url,
532
- content=json.dumps(http_request.data),
598
+ content=data,
533
599
  headers=http_request.headers,
534
600
  timeout=http_request.timeout,
535
601
  )
536
- response = await aclient.send(
602
+ response = await self._async_httpx_client.send(
537
603
  httpx_request,
538
604
  stream=stream,
539
605
  )
@@ -542,18 +608,17 @@ class BaseApiClient:
542
608
  response.headers, response if stream else [response.text]
543
609
  )
544
610
  else:
545
- async with httpx.AsyncClient() as aclient:
546
- response = await aclient.request(
547
- method=http_request.method,
548
- url=http_request.url,
549
- headers=http_request.headers,
550
- content=json.dumps(http_request.data) if http_request.data else None,
551
- timeout=http_request.timeout,
552
- )
553
- errors.APIError.raise_for_response(response)
554
- return HttpResponse(
555
- response.headers, response if stream else [response.text]
556
- )
611
+ response = await self._async_httpx_client.request(
612
+ method=http_request.method,
613
+ url=http_request.url,
614
+ headers=http_request.headers,
615
+ content=data,
616
+ timeout=http_request.timeout,
617
+ )
618
+ errors.APIError.raise_for_response(response)
619
+ return HttpResponse(
620
+ response.headers, response if stream else [response.text]
621
+ )
557
622
 
558
623
  def get_read_only_http_options(self) -> HttpOptionsDict:
559
624
  copied = HttpOptionsDict()
@@ -633,7 +698,7 @@ class BaseApiClient:
633
698
 
634
699
  def upload_file(
635
700
  self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
636
- ) -> str:
701
+ ) -> dict[str, str]:
637
702
  """Transfers a file to the given URL.
638
703
 
639
704
  Args:
@@ -655,7 +720,7 @@ class BaseApiClient:
655
720
 
656
721
  def _upload_fd(
657
722
  self, file: io.IOBase, upload_url: str, upload_size: int
658
- ) -> str:
723
+ ) -> dict[str, str]:
659
724
  """Transfers a file to the given URL.
660
725
 
661
726
  Args:
@@ -678,7 +743,7 @@ class BaseApiClient:
678
743
  # If last chunk, finalize the upload.
679
744
  if chunk_size + offset >= upload_size:
680
745
  upload_command += ', finalize'
681
- request = HttpRequest(
746
+ response = self._httpx_client.request(
682
747
  method='POST',
683
748
  url=upload_url,
684
749
  headers={
@@ -686,25 +751,22 @@ class BaseApiClient:
686
751
  'X-Goog-Upload-Offset': str(offset),
687
752
  'Content-Length': str(chunk_size),
688
753
  },
689
- data=file_chunk,
754
+ content=file_chunk,
690
755
  )
691
-
692
- response = self._request_unauthorized(request, stream=False)
693
756
  offset += chunk_size
694
- if response.headers['X-Goog-Upload-Status'] != 'active':
757
+ if response.headers['x-goog-upload-status'] != 'active':
695
758
  break # upload is complete or it has been interrupted.
696
-
697
759
  if upload_size <= offset: # Status is not finalized.
698
760
  raise ValueError(
699
761
  'All content has been uploaded, but the upload status is not'
700
762
  f' finalized.'
701
763
  )
702
764
 
703
- if response.headers['X-Goog-Upload-Status'] != 'final':
765
+ if response.headers['x-goog-upload-status'] != 'final':
704
766
  raise ValueError(
705
767
  'Failed to upload file: Upload status is not finalized.'
706
768
  )
707
- return response.json
769
+ return response.json()
708
770
 
709
771
  def download_file(self, path: str, http_options):
710
772
  """Downloads the file data.
@@ -719,12 +781,7 @@ class BaseApiClient:
719
781
  http_request = self._build_request(
720
782
  'get', path=path, request_dict={}, http_options=http_options
721
783
  )
722
- return self._download_file_request(http_request).byte_stream[0]
723
784
 
724
- def _download_file_request(
725
- self,
726
- http_request: HttpRequest,
727
- ) -> HttpResponse:
728
785
  data: Optional[Union[str, bytes]] = None
729
786
  if http_request.data:
730
787
  if not isinstance(http_request.data, bytes):
@@ -732,25 +789,25 @@ class BaseApiClient:
732
789
  else:
733
790
  data = http_request.data
734
791
 
735
- http_session = requests.Session()
736
- response = http_session.request(
792
+ response = self._httpx_client.request(
737
793
  method=http_request.method,
738
794
  url=http_request.url,
739
795
  headers=http_request.headers,
740
- data=data,
796
+ content=data,
741
797
  timeout=http_request.timeout,
742
- stream=False,
743
798
  )
744
799
 
745
800
  errors.APIError.raise_for_response(response)
746
- return HttpResponse(response.headers, byte_stream=[response.content])
801
+ return HttpResponse(
802
+ response.headers, byte_stream=[response.read()]
803
+ ).byte_stream[0]
747
804
 
748
805
  async def async_upload_file(
749
806
  self,
750
807
  file_path: Union[str, io.IOBase],
751
808
  upload_url: str,
752
809
  upload_size: int,
753
- ) -> str:
810
+ ) -> dict[str, str]:
754
811
  """Transfers a file asynchronously to the given URL.
755
812
 
756
813
  Args:
@@ -776,7 +833,7 @@ class BaseApiClient:
776
833
  file: Union[io.IOBase, anyio.AsyncFile],
777
834
  upload_url: str,
778
835
  upload_size: int,
779
- ) -> str:
836
+ ) -> dict[str, str]:
780
837
  """Transfers a file asynchronously to the given URL.
781
838
 
782
839
  Args:
@@ -788,45 +845,44 @@ class BaseApiClient:
788
845
  returns:
789
846
  The response json object from the finalize request.
790
847
  """
791
- async with httpx.AsyncClient() as aclient:
792
- offset = 0
793
- # Upload the file in chunks
794
- while True:
795
- if isinstance(file, io.IOBase):
796
- file_chunk = file.read(CHUNK_SIZE)
797
- else:
798
- file_chunk = await file.read(CHUNK_SIZE)
799
- chunk_size = 0
800
- if file_chunk:
801
- chunk_size = len(file_chunk)
802
- upload_command = 'upload'
803
- # If last chunk, finalize the upload.
804
- if chunk_size + offset >= upload_size:
805
- upload_command += ', finalize'
806
- response = await aclient.request(
807
- method='POST',
808
- url=upload_url,
809
- content=file_chunk,
810
- headers={
811
- 'X-Goog-Upload-Command': upload_command,
812
- 'X-Goog-Upload-Offset': str(offset),
813
- 'Content-Length': str(chunk_size),
814
- },
815
- )
816
- offset += chunk_size
817
- if response.headers.get('x-goog-upload-status') != 'active':
818
- break # upload is complete or it has been interrupted.
819
-
820
- if upload_size <= offset: # Status is not finalized.
821
- raise ValueError(
822
- 'All content has been uploaded, but the upload status is not'
823
- f' finalized.'
824
- )
825
- if response.headers.get('x-goog-upload-status') != 'final':
848
+ offset = 0
849
+ # Upload the file in chunks
850
+ while True:
851
+ if isinstance(file, io.IOBase):
852
+ file_chunk = file.read(CHUNK_SIZE)
853
+ else:
854
+ file_chunk = await file.read(CHUNK_SIZE)
855
+ chunk_size = 0
856
+ if file_chunk:
857
+ chunk_size = len(file_chunk)
858
+ upload_command = 'upload'
859
+ # If last chunk, finalize the upload.
860
+ if chunk_size + offset >= upload_size:
861
+ upload_command += ', finalize'
862
+ response = await self._async_httpx_client.request(
863
+ method='POST',
864
+ url=upload_url,
865
+ content=file_chunk,
866
+ headers={
867
+ 'X-Goog-Upload-Command': upload_command,
868
+ 'X-Goog-Upload-Offset': str(offset),
869
+ 'Content-Length': str(chunk_size),
870
+ },
871
+ )
872
+ offset += chunk_size
873
+ if response.headers.get('x-goog-upload-status') != 'active':
874
+ break # upload is complete or it has been interrupted.
875
+
876
+ if upload_size <= offset: # Status is not finalized.
826
877
  raise ValueError(
827
- 'Failed to upload file: Upload status is not finalized.'
878
+ 'All content has been uploaded, but the upload status is not'
879
+ f' finalized.'
828
880
  )
829
- return response.json()
881
+ if response.headers.get('x-goog-upload-status') != 'final':
882
+ raise ValueError(
883
+ 'Failed to upload file: Upload status is not finalized.'
884
+ )
885
+ return response.json()
830
886
 
831
887
  async def async_download_file(self, path: str, http_options):
832
888
  """Downloads the file data.
@@ -842,26 +898,25 @@ class BaseApiClient:
842
898
  'get', path=path, request_dict={}, http_options=http_options
843
899
  )
844
900
 
845
- data: Optional[Union[str, bytes]]
901
+ data: Optional[Union[str, bytes]] = None
846
902
  if http_request.data:
847
903
  if not isinstance(http_request.data, bytes):
848
904
  data = json.dumps(http_request.data)
849
905
  else:
850
906
  data = http_request.data
851
907
 
852
- async with httpx.AsyncClient(follow_redirects=True) as aclient:
853
- response = await aclient.request(
854
- method=http_request.method,
855
- url=http_request.url,
856
- headers=http_request.headers,
857
- content=data,
858
- timeout=http_request.timeout,
859
- )
860
- errors.APIError.raise_for_response(response)
908
+ response = await self._async_httpx_client.request(
909
+ method=http_request.method,
910
+ url=http_request.url,
911
+ headers=http_request.headers,
912
+ content=data,
913
+ timeout=http_request.timeout,
914
+ )
915
+ errors.APIError.raise_for_response(response)
861
916
 
862
- return HttpResponse(
863
- response.headers, byte_stream=[response.read()]
864
- ).byte_stream[0]
917
+ return HttpResponse(
918
+ response.headers, byte_stream=[response.read()]
919
+ ).byte_stream[0]
865
920
 
866
921
  # This method does nothing in the real api client. It is used in the
867
922
  # replay_api_client to verify the response from the SDK method matches the
@@ -17,7 +17,7 @@ import inspect
17
17
  import sys
18
18
  import types as builtin_types
19
19
  import typing
20
- from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union
20
+ from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union # type: ignore[attr-defined]
21
21
 
22
22
  import pydantic
23
23
 
@@ -41,19 +41,11 @@ _py_builtin_type_to_schema_type = {
41
41
 
42
42
 
43
43
  def _is_builtin_primitive_or_compound(
44
- annotation: inspect.Parameter.annotation,
44
+ annotation: inspect.Parameter.annotation, # type: ignore[valid-type]
45
45
  ) -> bool:
46
46
  return annotation in _py_builtin_type_to_schema_type.keys()
47
47
 
48
48
 
49
- def _raise_for_any_of_if_mldev(schema: types.Schema):
50
- if schema.any_of:
51
- raise ValueError(
52
- 'AnyOf is not supported in function declaration schema for'
53
- ' the Gemini API.'
54
- )
55
-
56
-
57
49
  def _raise_for_default_if_mldev(schema: types.Schema):
58
50
  if schema.default is not None:
59
51
  raise ValueError(
@@ -64,12 +56,11 @@ def _raise_for_default_if_mldev(schema: types.Schema):
64
56
 
65
57
  def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
66
58
  if api_option == 'GEMINI_API':
67
- _raise_for_any_of_if_mldev(schema)
68
59
  _raise_for_default_if_mldev(schema)
69
60
 
70
61
 
71
62
  def _is_default_value_compatible(
72
- default_value: Any, annotation: inspect.Parameter.annotation
63
+ default_value: Any, annotation: inspect.Parameter.annotation # type: ignore[valid-type]
73
64
  ) -> bool:
74
65
  # None type is expected to be handled external to this function
75
66
  if _is_builtin_primitive_or_compound(annotation):
@@ -292,8 +283,7 @@ def _parse_schema_from_parameter(
292
283
  ),
293
284
  func_name,
294
285
  )
295
- if api_option == 'VERTEX_AI':
296
- schema.required = _get_required_fields(schema)
286
+ schema.required = _get_required_fields(schema)
297
287
  _raise_if_schema_unsupported(api_option, schema)
298
288
  return schema
299
289
  raise ValueError(
google/genai/_common.py CHANGED
@@ -20,7 +20,7 @@ import datetime
20
20
  import enum
21
21
  import functools
22
22
  import typing
23
- from typing import Union
23
+ from typing import Any, Union
24
24
  import uuid
25
25
  import warnings
26
26
 
@@ -93,7 +93,7 @@ def set_value_by_path(data, keys, value):
93
93
  data[keys[-1]] = value
94
94
 
95
95
 
96
- def get_value_by_path(data: object, keys: list[str]):
96
+ def get_value_by_path(data: Any, keys: list[str]):
97
97
  """Examples:
98
98
 
99
99
  get_value_by_path({'a': {'b': v}}, ['a', 'b'])
@@ -128,7 +128,7 @@ def get_value_by_path(data: object, keys: list[str]):
128
128
  return data
129
129
 
130
130
 
131
- def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
131
+ def convert_to_dict(obj: object) -> Any:
132
132
  """Recursively converts a given object to a dictionary.
133
133
 
134
134
  If the object is a Pydantic model, it uses the model's `model_dump()` method.
@@ -137,7 +137,9 @@ def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
137
137
  obj: The object to convert.
138
138
 
139
139
  Returns:
140
- A dictionary representation of the object.
140
+ A dictionary representation of the object, a list of objects if a list is
141
+ passed, or the object itself if it is not a dictionary, list, or Pydantic
142
+ model.
141
143
  """
142
144
  if isinstance(obj, pydantic.BaseModel):
143
145
  return obj.model_dump(exclude_none=True)
@@ -150,7 +152,7 @@ def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
150
152
 
151
153
 
152
154
  def _remove_extra_fields(
153
- model: pydantic.BaseModel, response: dict[str, object]
155
+ model: Any, response: dict[str, object]
154
156
  ) -> None:
155
157
  """Removes extra fields from the response that are not in the model.
156
158
 
@@ -520,11 +520,14 @@ class ReplayApiClient(BaseApiClient):
520
520
  else:
521
521
  return self._build_response_from_replay(request).json
522
522
 
523
- def _download_file_request(self, request):
523
+ def download_file(self, path: str, http_options: HttpOptions):
524
524
  self._initialize_replay_session_if_not_loaded()
525
+ request = self._build_request(
526
+ 'get', path=path, request_dict={}, http_options=http_options
527
+ )
525
528
  if self._should_call_api():
526
529
  try:
527
- result = super()._download_file_request(request)
530
+ result = super().download_file(path, http_options)
528
531
  except HTTPError as e:
529
532
  result = HttpResponse(
530
533
  e.response.headers, [json.dumps({'reason': e.response.reason})]
@@ -534,7 +537,7 @@ class ReplayApiClient(BaseApiClient):
534
537
  self._record_interaction(request, result)
535
538
  return result
536
539
  else:
537
- return self._build_response_from_replay(request)
540
+ return self._build_response_from_replay(request).byte_stream[0]
538
541
 
539
542
  async def async_download_file(self, path: str, http_options):
540
543
  self._initialize_replay_session_if_not_loaded()