google-genai 1.4.0__tar.gz → 1.5.0__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (33):
  1. {google_genai-1.4.0/google_genai.egg-info → google_genai-1.5.0}/PKG-INFO +2 -1
  2. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/_api_client.py +115 -45
  3. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/_automatic_function_calling_util.py +3 -3
  4. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/_common.py +5 -2
  5. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/_extra_utils.py +62 -47
  6. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/_replay_api_client.py +70 -2
  7. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/_transformers.py +43 -26
  8. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/batches.py +10 -10
  9. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/caches.py +10 -10
  10. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/files.py +22 -9
  11. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/models.py +70 -46
  12. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/operations.py +10 -10
  13. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/pagers.py +14 -5
  14. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/tunings.py +9 -9
  15. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/types.py +59 -26
  16. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/version.py +1 -1
  17. {google_genai-1.4.0 → google_genai-1.5.0/google_genai.egg-info}/PKG-INFO +2 -1
  18. {google_genai-1.4.0 → google_genai-1.5.0}/google_genai.egg-info/requires.txt +1 -0
  19. {google_genai-1.4.0 → google_genai-1.5.0}/pyproject.toml +2 -1
  20. {google_genai-1.4.0 → google_genai-1.5.0}/LICENSE +0 -0
  21. {google_genai-1.4.0 → google_genai-1.5.0}/MANIFEST.in +0 -0
  22. {google_genai-1.4.0 → google_genai-1.5.0}/README.md +0 -0
  23. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/__init__.py +0 -0
  24. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/_api_module.py +0 -0
  25. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/_test_api_client.py +0 -0
  26. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/chats.py +0 -0
  27. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/client.py +0 -0
  28. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/errors.py +0 -0
  29. {google_genai-1.4.0 → google_genai-1.5.0}/google/genai/live.py +0 -0
  30. {google_genai-1.4.0 → google_genai-1.5.0}/google_genai.egg-info/SOURCES.txt +0 -0
  31. {google_genai-1.4.0 → google_genai-1.5.0}/google_genai.egg-info/dependency_links.txt +0 -0
  32. {google_genai-1.4.0 → google_genai-1.5.0}/google_genai.egg-info/top_level.txt +0 -0
  33. {google_genai-1.4.0 → google_genai-1.5.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: google-genai
3
- Version: 1.4.0
3
+ Version: 1.5.0
4
4
  Summary: GenAI Python SDK
5
5
  Author-email: Google LLC <googleapis-packages@google.com>
6
6
  License: Apache-2.0
@@ -20,6 +20,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
20
20
  Requires-Python: >=3.9
21
21
  Description-Content-Type: text/markdown
22
22
  License-File: LICENSE
23
+ Requires-Dist: anyio<5.0.0dev,>=4.8.0
23
24
  Requires-Dist: google-auth<3.0.0dev,>=2.14.1
24
25
  Requires-Dist: httpx<1.0.0dev,>=0.28.1
25
26
  Requires-Dist: pydantic<3.0.0dev,>=2.0.0
@@ -19,6 +19,7 @@
19
19
  The BaseApiClient is intended to be a private module and is subject to change.
20
20
  """
21
21
 
22
+ import anyio
22
23
  import asyncio
23
24
  import copy
24
25
  from dataclasses import dataclass
@@ -44,6 +45,7 @@ from . import version
44
45
  from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
45
46
 
46
47
  logger = logging.getLogger('google_genai._api_client')
48
+ CHUNK_SIZE = 8 * 1024 * 1024 # 8 MB chunk size
47
49
 
48
50
 
49
51
  def _append_library_version_headers(headers: dict[str, str]) -> None:
@@ -118,8 +120,9 @@ def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
118
120
  return credentials, project
119
121
 
120
122
 
121
- def _refresh_auth(credentials: Credentials) -> None:
123
+ def _refresh_auth(credentials: Credentials) -> Credentials:
122
124
  credentials.refresh(Request())
125
+ return credentials
123
126
 
124
127
 
125
128
  @dataclass
@@ -147,7 +150,7 @@ class HttpResponse:
147
150
 
148
151
  def __init__(
149
152
  self,
150
- headers: dict[str, str],
153
+ headers: Union[dict[str, str], httpx.Headers],
151
154
  response_stream: Union[Any, str] = None,
152
155
  byte_stream: Union[Any, bytes] = None,
153
156
  ):
@@ -200,14 +203,19 @@ class HttpResponse:
200
203
  yield c
201
204
  else:
202
205
  # Iterator of objects retrieved from the API.
203
- async for chunk in self.response_stream.aiter_lines():
204
- # This is httpx.Response.
205
- if chunk:
206
- # In async streaming mode, the chunk of JSON is prefixed with "data:"
207
- # which we must strip before parsing.
208
- if chunk.startswith('data: '):
209
- chunk = chunk[len('data: ') :]
210
- yield json.loads(chunk)
206
+ if hasattr(self.response_stream, 'aiter_lines'):
207
+ async for chunk in self.response_stream.aiter_lines():
208
+ # This is httpx.Response.
209
+ if chunk:
210
+ # In async streaming mode, the chunk of JSON is prefixed with "data:"
211
+ # which we must strip before parsing.
212
+ if chunk.startswith('data: '):
213
+ chunk = chunk[len('data: ') :]
214
+ yield json.loads(chunk)
215
+ else:
216
+ raise ValueError(
217
+ 'Error parsing streaming response.'
218
+ )
211
219
 
212
220
  def byte_segments(self):
213
221
  if isinstance(self.byte_stream, list):
@@ -373,17 +381,22 @@ class BaseApiClient:
373
381
  if not self.project:
374
382
  self.project = project
375
383
 
376
- if self._credentials.expired or not self._credentials.token:
377
- # Only refresh when it needs to. Default expiration is 3600 seconds.
378
- async with self._auth_lock:
379
- if self._credentials.expired or not self._credentials.token:
380
- # Double check that the credentials expired before refreshing.
381
- await asyncio.to_thread(_refresh_auth, self._credentials)
384
+ if self._credentials:
385
+ if (
386
+ self._credentials.expired or not self._credentials.token
387
+ ):
388
+ # Only refresh when it needs to. Default expiration is 3600 seconds.
389
+ async with self._auth_lock:
390
+ if self._credentials.expired or not self._credentials.token:
391
+ # Double check that the credentials expired before refreshing.
392
+ await asyncio.to_thread(_refresh_auth, self._credentials)
382
393
 
383
- if not self._credentials.token:
384
- raise RuntimeError('Could not resolve API token from the environment')
394
+ if not self._credentials.token:
395
+ raise RuntimeError('Could not resolve API token from the environment')
385
396
 
386
- return self._credentials.token
397
+ return self._credentials.token
398
+ else:
399
+ raise RuntimeError('Could not resolve API token from the environment')
387
400
 
388
401
  def _build_request(
389
402
  self,
@@ -428,8 +441,12 @@ class BaseApiClient:
428
441
  patched_http_options['api_version'] + '/' + path,
429
442
  )
430
443
 
431
- timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get('timeout', None)
444
+ timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
445
+ 'timeout', None
446
+ )
432
447
  if timeout_in_seconds:
448
+ # HttpOptions.timeout is in milliseconds. But httpx.Client.request()
449
+ # expects seconds.
433
450
  timeout_in_seconds = timeout_in_seconds / 1000.0
434
451
  else:
435
452
  timeout_in_seconds = None
@@ -503,18 +520,19 @@ class BaseApiClient:
503
520
  http_request.headers['Authorization'] = (
504
521
  f'Bearer {await self._async_access_token()}'
505
522
  )
506
- if self._credentials.quota_project_id:
523
+ if self._credentials and self._credentials.quota_project_id:
507
524
  http_request.headers['x-goog-user-project'] = (
508
525
  self._credentials.quota_project_id
509
526
  )
510
527
  if stream:
511
- httpx_request = httpx.Request(
528
+ aclient = httpx.AsyncClient()
529
+ httpx_request = aclient.build_request(
512
530
  method=http_request.method,
513
531
  url=http_request.url,
514
532
  content=json.dumps(http_request.data),
515
533
  headers=http_request.headers,
534
+ timeout=http_request.timeout,
516
535
  )
517
- aclient = httpx.AsyncClient()
518
536
  response = await aclient.send(
519
537
  httpx_request,
520
538
  stream=stream,
@@ -652,7 +670,7 @@ class BaseApiClient:
652
670
  offset = 0
653
671
  # Upload the file in chunks
654
672
  while True:
655
- file_chunk = file.read(1024 * 1024 * 8) # 8 MB chunk size
673
+ file_chunk = file.read(CHUNK_SIZE)
656
674
  chunk_size = 0
657
675
  if file_chunk:
658
676
  chunk_size = len(file_chunk)
@@ -679,13 +697,12 @@ class BaseApiClient:
679
697
  if upload_size <= offset: # Status is not finalized.
680
698
  raise ValueError(
681
699
  'All content has been uploaded, but the upload status is not'
682
- f' finalized. {response.headers}, body: {response.json}'
700
+ f' finalized.'
683
701
  )
684
702
 
685
703
  if response.headers['X-Goog-Upload-Status'] != 'final':
686
704
  raise ValueError(
687
- 'Failed to upload file: Upload status is not finalized. headers:'
688
- f' {response.headers}, body: {response.json}'
705
+ 'Failed to upload file: Upload status is not finalized.'
689
706
  )
690
707
  return response.json
691
708
 
@@ -708,7 +725,7 @@ class BaseApiClient:
708
725
  self,
709
726
  http_request: HttpRequest,
710
727
  ) -> HttpResponse:
711
- data: str | bytes | None = None
728
+ data: Optional[Union[str, bytes]] = None
712
729
  if http_request.data:
713
730
  if not isinstance(http_request.data, bytes):
714
731
  data = json.dumps(http_request.data)
@@ -746,16 +763,17 @@ class BaseApiClient:
746
763
  returns:
747
764
  The response json object from the finalize request.
748
765
  """
749
- return await asyncio.to_thread(
750
- self.upload_file,
751
- file_path,
752
- upload_url,
753
- upload_size,
754
- )
766
+ if isinstance(file_path, io.IOBase):
767
+ return await self._async_upload_fd(file_path, upload_url, upload_size)
768
+ else:
769
+ file = anyio.Path(file_path)
770
+ fd = await file.open('rb')
771
+ async with fd:
772
+ return await self._async_upload_fd(fd, upload_url, upload_size)
755
773
 
756
774
  async def _async_upload_fd(
757
775
  self,
758
- file: io.IOBase,
776
+ file: Union[io.IOBase, anyio.AsyncFile],
759
777
  upload_url: str,
760
778
  upload_size: int,
761
779
  ) -> str:
@@ -770,12 +788,45 @@ class BaseApiClient:
770
788
  returns:
771
789
  The response json object from the finalize request.
772
790
  """
773
- return await asyncio.to_thread(
774
- self._upload_fd,
775
- file,
776
- upload_url,
777
- upload_size,
778
- )
791
+ async with httpx.AsyncClient() as aclient:
792
+ offset = 0
793
+ # Upload the file in chunks
794
+ while True:
795
+ if isinstance(file, io.IOBase):
796
+ file_chunk = file.read(CHUNK_SIZE)
797
+ else:
798
+ file_chunk = await file.read(CHUNK_SIZE)
799
+ chunk_size = 0
800
+ if file_chunk:
801
+ chunk_size = len(file_chunk)
802
+ upload_command = 'upload'
803
+ # If last chunk, finalize the upload.
804
+ if chunk_size + offset >= upload_size:
805
+ upload_command += ', finalize'
806
+ response = await aclient.request(
807
+ method='POST',
808
+ url=upload_url,
809
+ content=file_chunk,
810
+ headers={
811
+ 'X-Goog-Upload-Command': upload_command,
812
+ 'X-Goog-Upload-Offset': str(offset),
813
+ 'Content-Length': str(chunk_size),
814
+ },
815
+ )
816
+ offset += chunk_size
817
+ if response.headers.get('x-goog-upload-status') != 'active':
818
+ break # upload is complete or it has been interrupted.
819
+
820
+ if upload_size <= offset: # Status is not finalized.
821
+ raise ValueError(
822
+ 'All content has been uploaded, but the upload status is not'
823
+ f' finalized.'
824
+ )
825
+ if response.headers.get('x-goog-upload-status') != 'final':
826
+ raise ValueError(
827
+ 'Failed to upload file: Upload status is not finalized.'
828
+ )
829
+ return response.json()
779
830
 
780
831
  async def async_download_file(self, path: str, http_options):
781
832
  """Downloads the file data.
@@ -787,12 +838,31 @@ class BaseApiClient:
787
838
  returns:
788
839
  The file bytes
789
840
  """
790
- return await asyncio.to_thread(
791
- self.download_file,
792
- path,
793
- http_options,
841
+ http_request = self._build_request(
842
+ 'get', path=path, request_dict={}, http_options=http_options
794
843
  )
795
844
 
845
+ data: Optional[Union[str, bytes]]
846
+ if http_request.data:
847
+ if not isinstance(http_request.data, bytes):
848
+ data = json.dumps(http_request.data)
849
+ else:
850
+ data = http_request.data
851
+
852
+ async with httpx.AsyncClient(follow_redirects=True) as aclient:
853
+ response = await aclient.request(
854
+ method=http_request.method,
855
+ url=http_request.url,
856
+ headers=http_request.headers,
857
+ content=data,
858
+ timeout=http_request.timeout,
859
+ )
860
+ errors.APIError.raise_for_response(response)
861
+
862
+ return HttpResponse(
863
+ response.headers, byte_stream=[response.read()]
864
+ ).byte_stream[0]
865
+
796
866
  # This method does nothing in the real api client. It is used in the
797
867
  # replay_api_client to verify the response from the SDK method matches the
798
868
  # recorded response.
@@ -17,7 +17,7 @@ import inspect
17
17
  import sys
18
18
  import types as builtin_types
19
19
  import typing
20
- from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Union
20
+ from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union
21
21
 
22
22
  import pydantic
23
23
 
@@ -304,9 +304,9 @@ def _parse_schema_from_parameter(
304
304
  )
305
305
 
306
306
 
307
- def _get_required_fields(schema: types.Schema) -> list[str]:
307
+ def _get_required_fields(schema: types.Schema) -> Optional[list[str]]:
308
308
  if not schema.properties:
309
- return
309
+ return None
310
310
  return [
311
311
  field_name
312
312
  for field_name, field_schema in schema.properties.items()
@@ -188,6 +188,8 @@ def _remove_extra_fields(
188
188
  if isinstance(item, dict):
189
189
  _remove_extra_fields(typing.get_args(annotation)[0], item)
190
190
 
191
+ T = typing.TypeVar('T', bound='BaseModel')
192
+
191
193
 
192
194
  class BaseModel(pydantic.BaseModel):
193
195
 
@@ -201,12 +203,13 @@ class BaseModel(pydantic.BaseModel):
201
203
  arbitrary_types_allowed=True,
202
204
  ser_json_bytes='base64',
203
205
  val_json_bytes='base64',
206
+ ignored_types=(typing.TypeVar,)
204
207
  )
205
208
 
206
209
  @classmethod
207
210
  def _from_response(
208
- cls, *, response: dict[str, object], kwargs: dict[str, object]
209
- ) -> 'BaseModel':
211
+ cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
212
+ ) -> T:
210
213
  # To maintain forward compatibility, we need to remove extra fields from
211
214
  # the response.
212
215
  # We will provide another mechanism to allow users to access these fields.
@@ -17,9 +17,9 @@
17
17
 
18
18
  import inspect
19
19
  import logging
20
+ import sys
20
21
  import typing
21
22
  from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
22
- import sys
23
23
 
24
24
  import pydantic
25
25
 
@@ -37,6 +37,15 @@ _DEFAULT_MAX_REMOTE_CALLS_AFC = 10
37
37
  logger = logging.getLogger('google_genai.models')
38
38
 
39
39
 
40
+ def _create_generate_content_config_model(
41
+ config: types.GenerateContentConfigOrDict,
42
+ ) -> types.GenerateContentConfig:
43
+ if isinstance(config, dict):
44
+ return types.GenerateContentConfig(**config)
45
+ else:
46
+ return config
47
+
48
+
40
49
  def format_destination(
41
50
  src: str,
42
51
  config: Optional[types.CreateBatchJobConfigOrDict] = None,
@@ -69,16 +78,12 @@ def format_destination(
69
78
 
70
79
  def get_function_map(
71
80
  config: Optional[types.GenerateContentConfigOrDict] = None,
72
- ) -> dict[str, object]:
81
+ ) -> dict[str, Callable]:
73
82
  """Returns a function map from the config."""
74
- config_model = (
75
- types.GenerateContentConfig(**config)
76
- if config and isinstance(config, dict)
77
- else config
78
- )
79
- function_map: dict[str, object] = {}
80
- if not config_model:
83
+ function_map: dict[str, Callable] = {}
84
+ if not config:
81
85
  return function_map
86
+ config_model = _create_generate_content_config_model(config)
82
87
  if config_model.tools:
83
88
  for tool in config_model.tools:
84
89
  if callable(tool):
@@ -92,6 +97,16 @@ def get_function_map(
92
97
  return function_map
93
98
 
94
99
 
100
+ def convert_number_values_for_dict_function_call_args(
101
+ args: dict[str, Any],
102
+ ) -> dict[str, Any]:
103
+ """Converts float values in dict with no decimal to integers."""
104
+ return {
105
+ key: convert_number_values_for_function_call_args(value)
106
+ for key, value in args.items()
107
+ }
108
+
109
+
95
110
  def convert_number_values_for_function_call_args(
96
111
  args: Union[dict[str, object], list[object], object],
97
112
  ) -> Union[dict[str, object], list[object], object]:
@@ -210,26 +225,35 @@ def invoke_function_from_dict_args(
210
225
 
211
226
  def get_function_response_parts(
212
227
  response: types.GenerateContentResponse,
213
- function_map: dict[str, object],
228
+ function_map: dict[str, Callable],
214
229
  ) -> list[types.Part]:
215
230
  """Returns the function response parts from the response."""
216
231
  func_response_parts = []
217
- for part in response.candidates[0].content.parts:
218
- if not part.function_call:
219
- continue
220
- func_name = part.function_call.name
221
- func = function_map[func_name]
222
- args = convert_number_values_for_function_call_args(part.function_call.args)
223
- func_response: dict[str, Any]
224
- try:
225
- func_response = {'result': invoke_function_from_dict_args(args, func)}
226
- except Exception as e: # pylint: disable=broad-except
227
- func_response = {'error': str(e)}
228
- func_response_part = types.Part.from_function_response(
229
- name=func_name, response=func_response
230
- )
231
-
232
- func_response_parts.append(func_response_part)
232
+ if (
233
+ response.candidates is not None
234
+ and isinstance(response.candidates[0].content, types.Content)
235
+ and response.candidates[0].content.parts is not None
236
+ ):
237
+ for part in response.candidates[0].content.parts:
238
+ if not part.function_call:
239
+ continue
240
+ func_name = part.function_call.name
241
+ if func_name is not None and part.function_call.args is not None:
242
+ func = function_map[func_name]
243
+ args = convert_number_values_for_dict_function_call_args(
244
+ part.function_call.args
245
+ )
246
+ func_response: dict[str, Any]
247
+ try:
248
+ func_response = {
249
+ 'result': invoke_function_from_dict_args(args, func)
250
+ }
251
+ except Exception as e: # pylint: disable=broad-except
252
+ func_response = {'error': str(e)}
253
+ func_response_part = types.Part.from_function_response(
254
+ name=func_name, response=func_response
255
+ )
256
+ func_response_parts.append(func_response_part)
233
257
  return func_response_parts
234
258
 
235
259
 
@@ -237,12 +261,9 @@ def should_disable_afc(
237
261
  config: Optional[types.GenerateContentConfigOrDict] = None,
238
262
  ) -> bool:
239
263
  """Returns whether automatic function calling is enabled."""
240
- config_model = (
241
- types.GenerateContentConfig(**config)
242
- if config and isinstance(config, dict)
243
- else config
244
- )
245
-
264
+ if not config:
265
+ return False
266
+ config_model = _create_generate_content_config_model(config)
246
267
  # If max_remote_calls is less or equal to 0, warn and disable AFC.
247
268
  if (
248
269
  config_model
@@ -261,8 +282,7 @@ def should_disable_afc(
261
282
 
262
283
  # Default to enable AFC if not specified.
263
284
  if (
264
- not config_model
265
- or not config_model.automatic_function_calling
285
+ not config_model.automatic_function_calling
266
286
  or config_model.automatic_function_calling.disable is None
267
287
  ):
268
288
  return False
@@ -295,20 +315,17 @@ def should_disable_afc(
295
315
  def get_max_remote_calls_afc(
296
316
  config: Optional[types.GenerateContentConfigOrDict] = None,
297
317
  ) -> int:
318
+ if not config:
319
+ return _DEFAULT_MAX_REMOTE_CALLS_AFC
298
320
  """Returns the remaining remote calls for automatic function calling."""
299
321
  if should_disable_afc(config):
300
322
  raise ValueError(
301
323
  'automatic function calling is not enabled, but SDK is trying to get'
302
324
  ' max remote calls.'
303
325
  )
304
- config_model = (
305
- types.GenerateContentConfig(**config)
306
- if config and isinstance(config, dict)
307
- else config
308
- )
326
+ config_model = _create_generate_content_config_model(config)
309
327
  if (
310
- not config_model
311
- or not config_model.automatic_function_calling
328
+ not config_model.automatic_function_calling
312
329
  or config_model.automatic_function_calling.maximum_remote_calls is None
313
330
  ):
314
331
  return _DEFAULT_MAX_REMOTE_CALLS_AFC
@@ -318,11 +335,9 @@ def get_max_remote_calls_afc(
318
335
  def should_append_afc_history(
319
336
  config: Optional[types.GenerateContentConfigOrDict] = None,
320
337
  ) -> bool:
321
- config_model = (
322
- types.GenerateContentConfig(**config)
323
- if config and isinstance(config, dict)
324
- else config
325
- )
326
- if not config_model or not config_model.automatic_function_calling:
338
+ if not config:
339
+ return True
340
+ config_model = _create_generate_content_config_model(config)
341
+ if not config_model.automatic_function_calling:
327
342
  return True
328
343
  return not config_model.automatic_function_calling.ignore_call_history
@@ -109,7 +109,8 @@ def _redact_project_location_path(path: str) -> str:
109
109
  return path
110
110
 
111
111
 
112
- def _redact_request_body(body: dict[str, object]) -> dict[str, object]:
112
+ def _redact_request_body(body: dict[str, object]):
113
+ """Redacts fields in the request body in place."""
113
114
  for key, value in body.items():
114
115
  if isinstance(value, str):
115
116
  body[key] = _redact_project_location_path(value)
@@ -302,13 +303,24 @@ class ReplayApiClient(BaseApiClient):
302
303
  status_code=http_response.status_code,
303
304
  sdk_response_segments=[],
304
305
  )
305
- else:
306
+ elif isinstance(http_response, errors.APIError):
306
307
  response = ReplayResponse(
307
308
  headers=dict(http_response.response.headers),
308
309
  body_segments=[http_response._to_replay_record()],
309
310
  status_code=http_response.code,
310
311
  sdk_response_segments=[],
311
312
  )
313
+ elif isinstance(http_response, bytes):
314
+ response = ReplayResponse(
315
+ headers={},
316
+ body_segments=[],
317
+ byte_segments=[http_response],
318
+ sdk_response_segments=[],
319
+ )
320
+ else:
321
+ raise ValueError(
322
+ 'Unsupported http_response type: ' + str(type(http_response))
323
+ )
312
324
  self.replay_session.interactions.append(
313
325
  ReplayInteraction(request=request, response=response)
314
326
  )
@@ -471,6 +483,43 @@ class ReplayApiClient(BaseApiClient):
471
483
  else:
472
484
  return self._build_response_from_replay(request).json
473
485
 
486
+ async def async_upload_file(
487
+ self,
488
+ file_path: Union[str, io.IOBase],
489
+ upload_url: str,
490
+ upload_size: int,
491
+ ) -> str:
492
+ if isinstance(file_path, io.IOBase):
493
+ offset = file_path.tell()
494
+ content = file_path.read()
495
+ file_path.seek(offset, os.SEEK_SET)
496
+ request = HttpRequest(
497
+ method='POST',
498
+ url='',
499
+ data={'bytes': base64.b64encode(content).decode('utf-8')},
500
+ headers={},
501
+ )
502
+ else:
503
+ request = HttpRequest(
504
+ method='POST', url='', data={'file_path': file_path}, headers={}
505
+ )
506
+ if self._should_call_api():
507
+ result: Union[str, HttpResponse]
508
+ try:
509
+ result = await super().async_upload_file(
510
+ file_path, upload_url, upload_size
511
+ )
512
+ except HTTPError as e:
513
+ result = HttpResponse(
514
+ e.response.headers, [json.dumps({'reason': e.response.reason})]
515
+ )
516
+ result.status_code = e.response.status_code
517
+ raise e
518
+ self._record_interaction(request, HttpResponse({}, [json.dumps(result)]))
519
+ return result
520
+ else:
521
+ return self._build_response_from_replay(request).json
522
+
474
523
  def _download_file_request(self, request):
475
524
  self._initialize_replay_session_if_not_loaded()
476
525
  if self._should_call_api():
@@ -486,3 +535,22 @@ class ReplayApiClient(BaseApiClient):
486
535
  return result
487
536
  else:
488
537
  return self._build_response_from_replay(request)
538
+
539
+ async def async_download_file(self, path: str, http_options):
540
+ self._initialize_replay_session_if_not_loaded()
541
+ request = self._build_request(
542
+ 'get', path=path, request_dict={}, http_options=http_options
543
+ )
544
+ if self._should_call_api():
545
+ try:
546
+ result = await super().async_download_file(path, http_options)
547
+ except HTTPError as e:
548
+ result = HttpResponse(
549
+ e.response.headers, [json.dumps({'reason': e.response.reason})]
550
+ )
551
+ result.status_code = e.response.status_code
552
+ raise e
553
+ self._record_interaction(request, result)
554
+ return result
555
+ else:
556
+ return self._build_response_from_replay(request).byte_stream[0]