google-genai 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,11 +14,18 @@
14
14
  #
15
15
 
16
16
  import inspect
17
- import types as typing_types
17
+ import sys
18
+ import types as builtin_types
19
+ import typing
18
20
  from typing import Any, Callable, Literal, Union, _GenericAlias, get_args, get_origin
19
21
  import pydantic
20
22
  from . import types
21
23
 
24
# PEP 604 unions written as `X | Y` are instances of types.UnionType, which is
# only available on Python 3.10+. Older interpreters fall back to the private
# typing class so that isinstance()/membership checks still have a target.
if sys.version_info < (3, 10):
  UnionType = typing._UnionGenericAlias
else:
  UnionType = builtin_types.UnionType
28
+
22
29
  _py_builtin_type_to_schema_type = {
23
30
  str: 'STRING',
24
31
  int: 'INTEGER',
@@ -58,8 +65,8 @@ def _raise_for_nullable_if_mldev(schema: types.Schema):
58
65
  )
59
66
 
60
67
 
61
- def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
62
- if not variant == 'VERTEX_AI':
68
+ def _raise_if_schema_unsupported(client, schema: types.Schema):
69
+ if not client.vertexai:
63
70
  _raise_for_any_of_if_mldev(schema)
64
71
  _raise_for_default_if_mldev(schema)
65
72
  _raise_for_nullable_if_mldev(schema)
@@ -74,11 +81,11 @@ def _is_default_value_compatible(
74
81
 
75
82
  if (
76
83
  isinstance(annotation, _GenericAlias)
77
- or isinstance(annotation, typing_types.GenericAlias)
78
- or isinstance(annotation, typing_types.UnionType)
84
+ or isinstance(annotation, builtin_types.GenericAlias)
85
+ or isinstance(annotation, UnionType)
79
86
  ):
80
87
  origin = get_origin(annotation)
81
- if origin in (Union, typing_types.UnionType):
88
+ if origin in (Union, UnionType):
82
89
  return any(
83
90
  _is_default_value_compatible(default_value, arg)
84
91
  for arg in get_args(annotation)
@@ -107,12 +114,13 @@ def _is_default_value_compatible(
107
114
  return default_value in get_args(annotation)
108
115
 
109
116
  # return False for any other unrecognized annotation
110
- # let caller handle the raise
111
117
  return False
112
118
 
113
119
 
114
120
  def _parse_schema_from_parameter(
115
- variant: str, param: inspect.Parameter, func_name: str
121
+ client,
122
+ param: inspect.Parameter,
123
+ func_name: str,
116
124
  ) -> types.Schema:
117
125
  """parse schema from parameter.
118
126
 
@@ -130,12 +138,12 @@ def _parse_schema_from_parameter(
130
138
  raise ValueError(default_value_error_msg)
131
139
  schema.default = param.default
132
140
  schema.type = _py_builtin_type_to_schema_type[param.annotation]
133
- _raise_if_schema_unsupported(variant, schema)
141
+ _raise_if_schema_unsupported(client, schema)
134
142
  return schema
135
143
  if (
136
- isinstance(param.annotation, typing_types.UnionType)
144
+ isinstance(param.annotation, UnionType)
137
145
  # only parse simple UnionType, example int | str | float | bool
138
- # complex types.UnionType will be invoked in raise branch
146
+ # complex UnionType will be invoked in raise branch
139
147
  and all(
140
148
  (_is_builtin_primitive_or_compound(arg) or arg is type(None))
141
149
  for arg in get_args(param.annotation)
@@ -149,7 +157,7 @@ def _parse_schema_from_parameter(
149
157
  schema.nullable = True
150
158
  continue
151
159
  schema_in_any_of = _parse_schema_from_parameter(
152
- variant,
160
+ client,
153
161
  inspect.Parameter(
154
162
  'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
155
163
  ),
@@ -171,10 +179,10 @@ def _parse_schema_from_parameter(
171
179
  if not _is_default_value_compatible(param.default, param.annotation):
172
180
  raise ValueError(default_value_error_msg)
173
181
  schema.default = param.default
174
- _raise_if_schema_unsupported(variant, schema)
182
+ _raise_if_schema_unsupported(client, schema)
175
183
  return schema
176
184
  if isinstance(param.annotation, _GenericAlias) or isinstance(
177
- param.annotation, typing_types.GenericAlias
185
+ param.annotation, builtin_types.GenericAlias
178
186
  ):
179
187
  origin = get_origin(param.annotation)
180
188
  args = get_args(param.annotation)
@@ -184,7 +192,7 @@ def _parse_schema_from_parameter(
184
192
  if not _is_default_value_compatible(param.default, param.annotation):
185
193
  raise ValueError(default_value_error_msg)
186
194
  schema.default = param.default
187
- _raise_if_schema_unsupported(variant, schema)
195
+ _raise_if_schema_unsupported(client, schema)
188
196
  return schema
189
197
  if origin is Literal:
190
198
  if not all(isinstance(arg, str) for arg in args):
@@ -197,12 +205,12 @@ def _parse_schema_from_parameter(
197
205
  if not _is_default_value_compatible(param.default, param.annotation):
198
206
  raise ValueError(default_value_error_msg)
199
207
  schema.default = param.default
200
- _raise_if_schema_unsupported(variant, schema)
208
+ _raise_if_schema_unsupported(client, schema)
201
209
  return schema
202
210
  if origin is list:
203
211
  schema.type = 'ARRAY'
204
212
  schema.items = _parse_schema_from_parameter(
205
- variant,
213
+ client,
206
214
  inspect.Parameter(
207
215
  'item',
208
216
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -214,7 +222,7 @@ def _parse_schema_from_parameter(
214
222
  if not _is_default_value_compatible(param.default, param.annotation):
215
223
  raise ValueError(default_value_error_msg)
216
224
  schema.default = param.default
217
- _raise_if_schema_unsupported(variant, schema)
225
+ _raise_if_schema_unsupported(client, schema)
218
226
  return schema
219
227
  if origin is Union:
220
228
  schema.any_of = []
@@ -225,7 +233,7 @@ def _parse_schema_from_parameter(
225
233
  schema.nullable = True
226
234
  continue
227
235
  schema_in_any_of = _parse_schema_from_parameter(
228
- variant,
236
+ client,
229
237
  inspect.Parameter(
230
238
  'item',
231
239
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -233,6 +241,17 @@ def _parse_schema_from_parameter(
233
241
  ),
234
242
  func_name,
235
243
  )
244
+ if (
245
+ len(param.annotation.__args__) == 2
246
+ and type(None) in param.annotation.__args__
247
+ ): # Optional type
248
+ for optional_arg in param.annotation.__args__:
249
+ if (
250
+ hasattr(optional_arg, '__origin__')
251
+ and optional_arg.__origin__ is list
252
+ ):
253
+ # Optional type with list, for example Optional[list[str]]
254
+ schema.items = schema_in_any_of.items
236
255
  if (
237
256
  schema_in_any_of.model_dump_json(exclude_none=True)
238
257
  not in unique_types
@@ -249,7 +268,7 @@ def _parse_schema_from_parameter(
249
268
  if not _is_default_value_compatible(param.default, param.annotation):
250
269
  raise ValueError(default_value_error_msg)
251
270
  schema.default = param.default
252
- _raise_if_schema_unsupported(variant, schema)
271
+ _raise_if_schema_unsupported(client, schema)
253
272
  return schema
254
273
  # all other generic alias will be invoked in raise branch
255
274
  if (
@@ -266,7 +285,7 @@ def _parse_schema_from_parameter(
266
285
  schema.properties = {}
267
286
  for field_name, field_info in param.annotation.model_fields.items():
268
287
  schema.properties[field_name] = _parse_schema_from_parameter(
269
- variant,
288
+ client,
270
289
  inspect.Parameter(
271
290
  field_name,
272
291
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -274,7 +293,9 @@ def _parse_schema_from_parameter(
274
293
  ),
275
294
  func_name,
276
295
  )
277
- _raise_if_schema_unsupported(variant, schema)
296
+ if client.vertexai:
297
+ schema.required = _get_required_fields(schema)
298
+ _raise_if_schema_unsupported(client, schema)
278
299
  return schema
279
300
  raise ValueError(
280
301
  f'Failed to parse the parameter {param} of function {func_name} for'
google/genai/_common.py CHANGED
@@ -17,6 +17,7 @@
17
17
 
18
18
  import base64
19
19
  import datetime
20
+ import enum
20
21
  import typing
21
22
  from typing import Union
22
23
  import uuid
@@ -112,12 +113,6 @@ def get_value_by_path(data: object, keys: list[str]):
112
113
  return data
113
114
 
114
115
 
115
- class BaseModule:
116
-
117
- def __init__(self, api_client_: _api_client.ApiClient):
118
- self._api_client = api_client_
119
-
120
-
121
116
  def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
122
117
  """Recursively converts a given object to a dictionary.
123
118
 
@@ -144,7 +139,7 @@ def _remove_extra_fields(
144
139
  ) -> None:
145
140
  """Removes extra fields from the response that are not in the model.
146
141
 
147
- Muates the response in place.
142
+ Mutates the response in place.
148
143
  """
149
144
 
150
145
  key_values = list(response.items())
@@ -185,7 +180,7 @@ class BaseModel(pydantic.BaseModel):
185
180
  alias_generator=alias_generators.to_camel,
186
181
  populate_by_name=True,
187
182
  from_attributes=True,
188
- protected_namespaces={},
183
+ protected_namespaces=(),
189
184
  extra='forbid',
190
185
  # This allows us to use arbitrary types in the model. E.g. PIL.Image.
191
186
  arbitrary_types_allowed=True,
@@ -208,6 +203,20 @@ class BaseModel(pydantic.BaseModel):
208
203
  return self.model_dump(exclude_none=True, mode='json')
209
204
 
210
205
 
206
class CaseInSensitiveEnum(str, enum.Enum):
  """String enum whose failed value lookups retry against member names.

  When `cls(value)` does not match any member value, the lookup is retried
  against the member *names*, first with the value upper-cased and then
  lower-cased. Mixed-case member names are therefore not reachable through
  this fallback.
  """

  @classmethod
  def _missing_(cls, value):
    """Resolves a value that matched no member value.

    Args:
      value: The value passed to `cls(value)`.

    Returns:
      The member named `value.upper()` or, failing that, `value.lower()`.

    Raises:
      ValueError: If `value` is not a string or no member name matches.
    """
    # Only strings have a case to normalize. Without this guard, a lookup such
    # as cls(1) would surface an AttributeError from `.upper()` instead of the
    # ValueError that enum lookups are expected to raise.
    if not isinstance(value, str):
      raise ValueError(f"{value} is not a valid {cls.__name__}")
    try:
      return cls[value.upper()]  # Member names written in upper case.
    except KeyError:
      try:
        return cls[value.lower()]  # Member names written in lower case.
      except KeyError as e:
        raise ValueError(f"{value} is not a valid {cls.__name__}") from e
218
+
219
+
211
220
  def timestamped_unique_name() -> str:
212
221
  """Composes a timestamped unique name.
213
222
 
@@ -219,23 +228,39 @@ def timestamped_unique_name() -> str:
219
228
  return f'{timestamp}_{unique_id}'
220
229
 
221
230
 
222
- def apply_base64_encoding(data: dict[str, object]) -> dict[str, object]:
223
- """Applies base64 encoding to bytes values in the given data."""
231
def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
  """Converts values json.dumps() cannot serialize into compatible types.

  This function is called in models.py after calling convert_to_dict().
  convert_to_dict() turns pydantic objects into dicts, but its input is a dict
  that mixes pydantic objects with plain nested dicts (the output of
  converters). Bytes and datetime values inside those plain dicts are outside
  the control of `ser_json_bytes` and of pydantic's JSON-mode datetime handling
  in the model_dump(mode='json') call made by convert_to_dict(), so they are
  converted here instead.

  Conversion rules, applied to every value of `data`:
    * bytes -> URL-safe base64 encoded ASCII string.
    * datetime.datetime -> ISO 8601 string from `isoformat()`.
    * dict -> converted recursively.
    * list of only bytes, or of only datetimes -> converted element-wise.
    * any other list -> dict elements are converted recursively, every other
      element is kept unchanged.
    * anything else -> kept unchanged.

  Args:
    data: The dictionary to convert. It is not mutated.

  Returns:
    A new dictionary with the conversions above applied. If `data` is not a
    dict it is returned unchanged.
  """
  if not isinstance(data, dict):
    return data
  processed_data = {}
  for key, value in data.items():
    if isinstance(value, bytes):
      processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
    elif isinstance(value, datetime.datetime):
      processed_data[key] = value.isoformat()
    elif isinstance(value, dict):
      processed_data[key] = encode_unserializable_types(value)
    elif isinstance(value, list):
      if all(isinstance(v, bytes) for v in value):
        processed_data[key] = [
            base64.urlsafe_b64encode(v).decode('ascii') for v in value
        ]
      # This must be `elif`: with a plain `if`, an all-bytes list would fall
      # through to the `else` below and the encoded result would be overwritten
      # with the raw bytes.
      elif all(isinstance(v, datetime.datetime) for v in value):
        processed_data[key] = [v.isoformat() for v in value]
      else:
        processed_data[key] = [encode_unserializable_types(v) for v in value]
    else:
      processed_data[key] = value
  return processed_data
@@ -13,12 +13,13 @@
13
13
  # limitations under the License.
14
14
  #
15
15
 
16
- """Extra utils depending on types that are shared between sync and async modules.
17
- """
16
+ """Extra utils depending on types that are shared between sync and async modules."""
18
17
 
19
18
  import inspect
20
19
  import logging
21
- from typing import Any, Callable, Dict, get_args, get_origin, Optional, types as typing_types, Union
20
+ import typing
21
+ from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
22
+ import sys
22
23
 
23
24
  import pydantic
24
25
 
@@ -26,6 +27,10 @@ from . import _common
26
27
  from . import errors
27
28
  from . import types
28
29
 
30
# `X | Y` unions are instances of types.UnionType, which can only be imported
# on Python 3.10+; earlier interpreters use the private typing class instead.
if sys.version_info < (3, 10):
  UnionType = typing._UnionGenericAlias
else:
  from types import UnionType
29
34
 
30
35
  _DEFAULT_MAX_REMOTE_CALLS_AFC = 10
31
36
 
@@ -78,8 +83,8 @@ def get_function_map(
78
83
  if inspect.iscoroutinefunction(tool):
79
84
  raise errors.UnsupportedFunctionError(
80
85
  f'Function {tool.__name__} is a coroutine function, which is not'
81
- ' supported for automatic function calling. Please manually invoke'
82
- f' {tool.__name__} to get the function response.'
86
+ ' supported for automatic function calling. Please manually'
87
+ f' invoke {tool.__name__} to get the function response.'
83
88
  )
84
89
  function_map[tool.__name__] = tool
85
90
  return function_map
@@ -116,7 +121,7 @@ def convert_if_exist_pydantic_model(
116
121
  try:
117
122
  return annotation(**value)
118
123
  except pydantic.ValidationError as e:
119
- raise errors.UnkownFunctionCallArgumentError(
124
+ raise errors.UnknownFunctionCallArgumentError(
120
125
  f'Failed to parse parameter {param_name} for function'
121
126
  f' {func_name} from function call part because function call argument'
122
127
  f' value {value} is not compatible with parameter annotation'
@@ -135,11 +140,13 @@ def convert_if_exist_pydantic_model(
135
140
  for k, v in value.items()
136
141
  }
137
142
  # example 1: typing.Union[int, float]
138
- # example 2: int | float equivalent to typing.types.UnionType[int, float]
139
- if get_origin(annotation) in (Union, typing_types.UnionType):
143
+ # example 2: int | float equivalent to UnionType[int, float]
144
+ if get_origin(annotation) in (Union, UnionType):
140
145
  for arg in get_args(annotation):
141
- if isinstance(value, arg) or (
142
- isinstance(value, dict) and _is_annotation_pydantic_model(arg)
146
+ if (
147
+ (get_args(arg) and get_origin(arg) is list)
148
+ or isinstance(value, arg)
149
+ or (isinstance(value, dict) and _is_annotation_pydantic_model(arg))
143
150
  ):
144
151
  try:
145
152
  return convert_if_exist_pydantic_model(
@@ -150,7 +157,7 @@ def convert_if_exist_pydantic_model(
150
157
  except pydantic.ValidationError:
151
158
  continue
152
159
  # if none of the union type is matched, raise error
153
- raise errors.UnkownFunctionCallArgumentError(
160
+ raise errors.UnknownFunctionCallArgumentError(
154
161
  f'Failed to parse parameter {param_name} for function'
155
162
  f' {func_name} from function call part because function call argument'
156
163
  f' value {value} cannot be converted to parameter annotation'
@@ -161,7 +168,7 @@ def convert_if_exist_pydantic_model(
161
168
  if isinstance(value, int) and annotation is float:
162
169
  return value
163
170
  if not isinstance(value, annotation):
164
- raise errors.UnkownFunctionCallArgumentError(
171
+ raise errors.UnknownFunctionCallArgumentError(
165
172
  f'Failed to parse parameter {param_name} for function {func_name} from'
166
173
  f' function call part because function call argument value {value} is'
167
174
  f' not compatible with parameter annotation {annotation}.'
@@ -209,7 +216,9 @@ def get_function_response_parts(
209
216
  response = {'result': invoke_function_from_dict_args(args, func)}
210
217
  except Exception as e: # pylint: disable=broad-except
211
218
  response = {'error': str(e)}
212
- func_response = types.Part.from_function_response(func_name, response)
219
+ func_response = types.Part.from_function_response(
220
+ name=func_name, response=response
221
+ )
213
222
 
214
223
  func_response_parts.append(func_response)
215
224
  return func_response_parts
@@ -231,8 +240,7 @@ def should_disable_afc(
231
240
  and config_model.automatic_function_calling
232
241
  and config_model.automatic_function_calling.maximum_remote_calls
233
242
  is not None
234
- and int(config_model.automatic_function_calling.maximum_remote_calls)
235
- <= 0
243
+ and int(config_model.automatic_function_calling.maximum_remote_calls) <= 0
236
244
  ):
237
245
  logging.warning(
238
246
  'max_remote_calls in automatic_function_calling_config'
@@ -294,6 +302,7 @@ def get_max_remote_calls_afc(
294
302
  return _DEFAULT_MAX_REMOTE_CALLS_AFC
295
303
  return int(config_model.automatic_function_calling.maximum_remote_calls)
296
304
 
305
+
297
306
  def should_append_afc_history(
298
307
  config: Optional[types.GenerateContentConfigOrDict] = None,
299
308
  ) -> bool:
@@ -302,9 +311,6 @@ def should_append_afc_history(
302
311
  if config and isinstance(config, dict)
303
312
  else config
304
313
  )
305
- if (
306
- not config_model
307
- or not config_model.automatic_function_calling
308
- ):
314
+ if not config_model or not config_model.automatic_function_calling:
309
315
  return True
310
316
  return not config_model.automatic_function_calling.ignore_call_history
@@ -17,11 +17,12 @@
17
17
 
18
18
  import base64
19
19
  import copy
20
+ import datetime
20
21
  import inspect
22
+ import io
21
23
  import json
22
24
  import os
23
25
  import re
24
- import datetime
25
26
  from typing import Any, Literal, Optional, Union
26
27
 
27
28
  import google.auth
@@ -32,9 +33,9 @@ from ._api_client import ApiClient
32
33
  from ._api_client import HttpOptions
33
34
  from ._api_client import HttpRequest
34
35
  from ._api_client import HttpResponse
35
- from ._api_client import RequestJsonEncoder
36
36
  from ._common import BaseModel
37
37
 
38
+
38
39
  def _redact_version_numbers(version_string: str) -> str:
39
40
  """Redacts version numbers in the form x.y.z from a string."""
40
41
  return re.sub(r'\d+\.\d+\.\d+', '{VERSION_NUMBER}', version_string)
@@ -145,6 +146,7 @@ class ReplayResponse(BaseModel):
145
146
  status_code: int = 200
146
147
  headers: dict[str, str]
147
148
  body_segments: list[dict[str, object]]
149
+ byte_segments: Optional[list[bytes]] = None
148
150
  sdk_response_segments: list[dict[str, object]]
149
151
 
150
152
  def model_post_init(self, __context: Any) -> None:
@@ -264,17 +266,13 @@ class ReplayApiClient(ApiClient):
264
266
  replay_file_path = self._get_replay_file_path()
265
267
  os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
266
268
  with open(replay_file_path, 'w') as f:
267
- f.write(
268
- json.dumps(
269
- self.replay_session.model_dump(mode='json'), indent=2, cls=ResponseJsonEncoder
270
- )
271
- )
269
+ f.write(self.replay_session.model_dump_json(exclude_unset=True, indent=2))
272
270
  self.replay_session = None
273
271
 
274
272
  def _record_interaction(
275
273
  self,
276
274
  http_request: HttpRequest,
277
- http_response: Union[HttpResponse, errors.APIError],
275
+ http_response: Union[HttpResponse, errors.APIError, bytes],
278
276
  ):
279
277
  if not self._should_update_replay():
280
278
  return
@@ -289,6 +287,9 @@ class ReplayApiClient(ApiClient):
289
287
  response = ReplayResponse(
290
288
  headers=dict(http_response.headers),
291
289
  body_segments=list(http_response.segments()),
290
+ byte_segments=[
291
+ seg[:100] + b'...' for seg in http_response.byte_segments()
292
+ ],
292
293
  status_code=http_response.status_code,
293
294
  sdk_response_segments=[],
294
295
  )
@@ -322,11 +323,7 @@ class ReplayApiClient(ApiClient):
322
323
  # so that the comparison is fair.
323
324
  _redact_request_body(request_data_copy)
324
325
 
325
- # Need to call dumps() and loads() to convert dict bytes values to strings.
326
- # Because the expected_request_body dict never contains bytes values.
327
- actual_request_body = [
328
- json.loads(json.dumps(request_data_copy, cls=RequestJsonEncoder))
329
- ]
326
+ actual_request_body = [request_data_copy]
330
327
  expected_request_body = interaction.request.body_segments
331
328
  assert actual_request_body == expected_request_body, (
332
329
  'Request body mismatch:\n'
@@ -349,6 +346,7 @@ class ReplayApiClient(ApiClient):
349
346
  json.dumps(segment)
350
347
  for segment in interaction.response.body_segments
351
348
  ],
349
+ byte_stream=interaction.response.byte_segments,
352
350
  )
353
351
 
354
352
  def _verify_response(self, response_model: BaseModel):
@@ -368,7 +366,9 @@ class ReplayApiClient(ApiClient):
368
366
  response_model = response_model[0]
369
367
  print('response_model: ', response_model.model_dump(exclude_none=True))
370
368
  actual = response_model.model_dump(exclude_none=True, mode='json')
371
- expected = interaction.response.sdk_response_segments[self._sdk_response_index]
369
+ expected = interaction.response.sdk_response_segments[
370
+ self._sdk_response_index
371
+ ]
372
372
  assert (
373
373
  actual == expected
374
374
  ), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
@@ -397,15 +397,26 @@ class ReplayApiClient(ApiClient):
397
397
  # segments since the stream has been consumed.
398
398
  else:
399
399
  self._record_interaction(http_request, result)
400
- _debug_print('api mode result: %s' % result.text)
400
+ _debug_print('api mode result: %s' % result.json)
401
401
  return result
402
402
  else:
403
403
  return self._build_response_from_replay(http_request)
404
404
 
405
- def upload_file(self, file_path: str, upload_url: str, upload_size: int):
406
- request = HttpRequest(
407
- method='POST', url='', data={'file_path': file_path}, headers={}
408
- )
405
+ def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int):
406
+ if isinstance(file_path, io.IOBase):
407
+ offset = file_path.tell()
408
+ content = file_path.read()
409
+ file_path.seek(offset, os.SEEK_SET)
410
+ request = HttpRequest(
411
+ method='POST',
412
+ url='',
413
+ data={'bytes': base64.b64encode(content).decode('utf-8')},
414
+ headers={}
415
+ )
416
+ else:
417
+ request = HttpRequest(
418
+ method='POST', url='', data={'file_path': file_path}, headers={}
419
+ )
409
420
  if self._should_call_api():
410
421
  try:
411
422
  result = super().upload_file(file_path, upload_url, upload_size)
@@ -418,20 +429,21 @@ class ReplayApiClient(ApiClient):
418
429
  self._record_interaction(request, HttpResponse({}, [json.dumps(result)]))
419
430
  return result
420
431
  else:
421
- return self._build_response_from_replay(request).text
422
-
423
-
424
- # TODO(b/389693448): Cleanup datetime hacks.
425
- class ResponseJsonEncoder(json.JSONEncoder):
426
- """The replay test json encoder for response.
427
- """
428
- def default(self, o):
429
- if isinstance(o, datetime.datetime):
430
- # dt.isoformat() prints "2024-11-15T23:27:45.624657+00:00"
431
- # but replay files want "2024-11-15T23:27:45.624657Z"
432
- if o.isoformat().endswith('+00:00'):
433
- return o.isoformat().replace('+00:00', 'Z')
434
- else:
435
- return o.isoformat()
432
+ return self._build_response_from_replay(request).json
433
+
434
def _download_file_request(self, request):
  """Performs a file download request, recording or replaying it.

  When the client is set to call the API, the request is forwarded to the
  parent implementation and the successful response is recorded via
  `_record_interaction`. Otherwise the response is rebuilt from the replay
  session.

  Args:
    request: The download request, passed unchanged to the parent class or to
      the replay lookup.

  Returns:
    The response from the parent implementation, or the replayed response.

  Raises:
    Any exception raised by the parent implementation propagates unchanged;
    a failed call is not recorded.
  """
  self._initialize_replay_session_if_not_loaded()
  if not self._should_call_api():
    return self._build_response_from_replay(request)
  # NOTE(review): the previous version caught HTTPError, built an HttpResponse
  # from `e.response`, and then re-raised without ever using that response.
  # The construction was dead code that could only mask the original error
  # (e.g. if `e.response` is None), so errors now propagate directly. If the
  # intent was to record failed downloads, that has to be added explicitly.
  result = super()._download_file_request(request)
  self._record_interaction(request, result)
  return result
449
+
@@ -132,7 +132,7 @@ async def test_async_request_streamed_non_blocking(
132
132
 
133
133
  chunks = []
134
134
  start_time = time.time()
135
- async for chunk in api_client.async_request_streamed(
135
+ async for chunk in await api_client.async_request_streamed(
136
136
  http_method, path, request_dict
137
137
  ):
138
138
  chunks.append(chunk)