google-genai 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +115 -45
- google/genai/_automatic_function_calling_util.py +3 -3
- google/genai/_common.py +5 -2
- google/genai/_extra_utils.py +62 -47
- google/genai/_replay_api_client.py +70 -2
- google/genai/_transformers.py +43 -26
- google/genai/batches.py +10 -10
- google/genai/caches.py +10 -10
- google/genai/files.py +22 -9
- google/genai/models.py +70 -46
- google/genai/operations.py +10 -10
- google/genai/pagers.py +14 -5
- google/genai/tunings.py +9 -9
- google/genai/types.py +59 -26
- google/genai/version.py +1 -1
- {google_genai-1.4.0.dist-info → google_genai-1.5.0.dist-info}/METADATA +2 -1
- google_genai-1.5.0.dist-info/RECORD +27 -0
- google_genai-1.4.0.dist-info/RECORD +0 -27
- {google_genai-1.4.0.dist-info → google_genai-1.5.0.dist-info}/LICENSE +0 -0
- {google_genai-1.4.0.dist-info → google_genai-1.5.0.dist-info}/WHEEL +0 -0
- {google_genai-1.4.0.dist-info → google_genai-1.5.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -19,6 +19,7 @@
 The BaseApiClient is intended to be a private module and is subject to change.
 """
 
+import anyio
 import asyncio
 import copy
 from dataclasses import dataclass
@@ -44,6 +45,7 @@ from . import version
 from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
 
 logger = logging.getLogger('google_genai._api_client')
+CHUNK_SIZE = 8 * 1024 * 1024  # 8 MB chunk size
 
 
 def _append_library_version_headers(headers: dict[str, str]) -> None:
@@ -118,8 +120,9 @@ def _load_auth(*, project: Union[str, None]) -> tuple[Credentials, str]:
   return credentials, project
 
 
-def _refresh_auth(credentials: Credentials) ->
+def _refresh_auth(credentials: Credentials) -> Credentials:
   credentials.refresh(Request())
+  return credentials
 
 
 @dataclass
@@ -147,7 +150,7 @@ class HttpResponse:
 
   def __init__(
       self,
-      headers: dict[str, str],
+      headers: Union[dict[str, str], httpx.Headers],
      response_stream: Union[Any, str] = None,
      byte_stream: Union[Any, bytes] = None,
  ):
@@ -200,14 +203,19 @@ class HttpResponse:
         yield c
     else:
       # Iterator of objects retrieved from the API.
+      if hasattr(self.response_stream, 'aiter_lines'):
+        async for chunk in self.response_stream.aiter_lines():
+          # This is httpx.Response.
+          if chunk:
+            # In async streaming mode, the chunk of JSON is prefixed with "data:"
+            # which we must strip before parsing.
+            if chunk.startswith('data: '):
+              chunk = chunk[len('data: ') :]
+            yield json.loads(chunk)
+      else:
+        raise ValueError(
+            'Error parsing streaming response.'
+        )
 
   def byte_segments(self):
     if isinstance(self.byte_stream, list):
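Note on the new async streaming path above: the response body is now consumed line by line through httpx's aiter_lines(), and each JSON chunk arrives with a server-sent-events-style 'data: ' prefix that has to be stripped before json.loads(). Below is a minimal standalone sketch of just that parsing step; the helper name and sample input are illustrative, not part of the SDK.

import json

def parse_sse_line(line: str):
  """Strips an SSE-style 'data: ' prefix, then parses the JSON payload.

  Illustrative helper only; the SDK performs this inline in HttpResponse.
  """
  if not line:
    return None  # blank keep-alive lines carry no payload
  if line.startswith('data: '):
    line = line[len('data: '):]
  return json.loads(line)

# Example input shaped like a streaming response.
for raw in ['data: {"candidates": []}', '']:
  parsed = parse_sse_line(raw)
  if parsed is not None:
    print(parsed)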
@@ -373,17 +381,22 @@ class BaseApiClient:
           if not self.project:
             self.project = project
 
-    if self._credentials
+    if self._credentials:
+      if (
+          self._credentials.expired or not self._credentials.token
+      ):
+        # Only refresh when it needs to. Default expiration is 3600 seconds.
+        async with self._auth_lock:
+          if self._credentials.expired or not self._credentials.token:
+            # Double check that the credentials expired before refreshing.
+            await asyncio.to_thread(_refresh_auth, self._credentials)
 
+      if not self._credentials.token:
+        raise RuntimeError('Could not resolve API token from the environment')
 
+      return self._credentials.token
+    else:
+      raise RuntimeError('Could not resolve API token from the environment')
 
   def _build_request(
       self,
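The token refresh above is a double-checked pattern: test expiry, take the async lock, test again, then run the blocking refresh in a worker thread so the event loop stays responsive. A minimal sketch of the same pattern with a stand-in credential object (all names here are illustrative, not the SDK's):

import asyncio

class FakeCredentials:
  """Stand-in for google.auth credentials; illustration only."""

  def __init__(self):
    self.token = None
    self.expired = True

  def refresh(self):
    # Blocking network call in the real library; simulated here.
    self.token = 'token'
    self.expired = False

async def get_token(creds: FakeCredentials, lock: asyncio.Lock) -> str:
  if creds.expired or not creds.token:
    async with lock:
      # Double check after acquiring the lock: another task may have
      # refreshed the credentials while this one was waiting.
      if creds.expired or not creds.token:
        # Run the blocking refresh off the event loop.
        await asyncio.to_thread(creds.refresh)
  if not creds.token:
    raise RuntimeError('Could not resolve API token from the environment')
  return creds.token

async def main():
  creds, lock = FakeCredentials(), asyncio.Lock()
  # Concurrent callers share a single refresh thanks to the lock.
  print(await asyncio.gather(get_token(creds, lock), get_token(creds, lock)))

asyncio.run(main())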
@@ -428,8 +441,12 @@ class BaseApiClient:
         patched_http_options['api_version'] + '/' + path,
     )
 
-    timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
+    timeout_in_seconds: Optional[Union[float, int]] = patched_http_options.get(
+        'timeout', None
+    )
     if timeout_in_seconds:
+      # HttpOptions.timeout is in milliseconds. But httpx.Client.request()
+      # expects seconds.
       timeout_in_seconds = timeout_in_seconds / 1000.0
     else:
       timeout_in_seconds = None
@@ -503,18 +520,19 @@ class BaseApiClient:
       http_request.headers['Authorization'] = (
           f'Bearer {await self._async_access_token()}'
       )
-      if self._credentials.quota_project_id:
+      if self._credentials and self._credentials.quota_project_id:
         http_request.headers['x-goog-user-project'] = (
             self._credentials.quota_project_id
         )
     if stream:
+      aclient = httpx.AsyncClient()
+      httpx_request = aclient.build_request(
           method=http_request.method,
           url=http_request.url,
           content=json.dumps(http_request.data),
           headers=http_request.headers,
+          timeout=http_request.timeout,
       )
-      aclient = httpx.AsyncClient()
       response = await aclient.send(
           httpx_request,
           stream=stream,
@@ -652,7 +670,7 @@ class BaseApiClient:
     offset = 0
     # Upload the file in chunks
     while True:
-      file_chunk = file.read(
+      file_chunk = file.read(CHUNK_SIZE)
       chunk_size = 0
       if file_chunk:
         chunk_size = len(file_chunk)
@@ -679,13 +697,12 @@ class BaseApiClient:
     if upload_size <= offset:  # Status is not finalized.
       raise ValueError(
           'All content has been uploaded, but the upload status is not'
-          f' finalized.
+          f' finalized.'
       )
 
     if response.headers['X-Goog-Upload-Status'] != 'final':
       raise ValueError(
-          'Failed to upload file: Upload status is not finalized.
-          f' {response.headers}, body: {response.json}'
+          'Failed to upload file: Upload status is not finalized.'
       )
     return response.json
 
@@ -708,7 +725,7 @@ class BaseApiClient:
       self,
       http_request: HttpRequest,
   ) -> HttpResponse:
-    data: str
+    data: Optional[Union[str, bytes]] = None
     if http_request.data:
       if not isinstance(http_request.data, bytes):
         data = json.dumps(http_request.data)
@@ -746,16 +763,17 @@ class BaseApiClient:
     returns:
       The response json object from the finalize request.
     """
+    if isinstance(file_path, io.IOBase):
+      return await self._async_upload_fd(file_path, upload_url, upload_size)
+    else:
+      file = anyio.Path(file_path)
+      fd = await file.open('rb')
+      async with fd:
+        return await self._async_upload_fd(fd, upload_url, upload_size)
 
   async def _async_upload_fd(
       self,
-      file: io.IOBase,
+      file: Union[io.IOBase, anyio.AsyncFile],
       upload_url: str,
       upload_size: int,
   ) -> str:
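async_upload_file now accepts either an already-open file object or a path; paths are opened asynchronously through anyio.Path before being handed to _async_upload_fd. A small sketch of that open-and-read pattern, assuming anyio is installed; the file name is only an example.

import asyncio

import anyio

async def read_head(path: str, n: int = 16) -> bytes:
  # anyio.Path.open returns an async file object; 'async with' closes it.
  fd = await anyio.Path(path).open('rb')
  async with fd:
    return await fd.read(n)

async def main():
  # Write a throwaway file, then read its first bytes asynchronously.
  await anyio.Path('example.bin').write_bytes(b'hello world')
  print(await read_head('example.bin'))

asyncio.run(main())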
@@ -770,12 +788,45 @@ class BaseApiClient:
     returns:
       The response json object from the finalize request.
     """
+    async with httpx.AsyncClient() as aclient:
+      offset = 0
+      # Upload the file in chunks
+      while True:
+        if isinstance(file, io.IOBase):
+          file_chunk = file.read(CHUNK_SIZE)
+        else:
+          file_chunk = await file.read(CHUNK_SIZE)
+        chunk_size = 0
+        if file_chunk:
+          chunk_size = len(file_chunk)
+        upload_command = 'upload'
+        # If last chunk, finalize the upload.
+        if chunk_size + offset >= upload_size:
+          upload_command += ', finalize'
+        response = await aclient.request(
+            method='POST',
+            url=upload_url,
+            content=file_chunk,
+            headers={
+                'X-Goog-Upload-Command': upload_command,
+                'X-Goog-Upload-Offset': str(offset),
+                'Content-Length': str(chunk_size),
+            },
+        )
+        offset += chunk_size
+        if response.headers.get('x-goog-upload-status') != 'active':
+          break  # upload is complete or it has been interrupted.
+
+      if upload_size <= offset:  # Status is not finalized.
+        raise ValueError(
+            'All content has been uploaded, but the upload status is not'
+            f' finalized.'
+        )
+      if response.headers.get('x-goog-upload-status') != 'final':
+        raise ValueError(
+            'Failed to upload file: Upload status is not finalized.'
+        )
+      return response.json()
 
   async def async_download_file(self, path: str, http_options):
     """Downloads the file data.
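The new _async_upload_fd body drives the chunked resumable-upload protocol with httpx: each chunk is POSTed with X-Goog-Upload-Command ('upload', plus ', finalize' on the last chunk), X-Goog-Upload-Offset, and Content-Length, and the loop stops once x-goog-upload-status is no longer 'active'. A condensed standalone sketch of that loop, with a placeholder chunk size and a generic upload URL (not the SDK method itself):

import io

import httpx

CHUNK = 1024 * 1024  # placeholder chunk size for the sketch

async def upload_in_chunks(fd: io.IOBase, upload_url: str, total_size: int) -> dict:
  async with httpx.AsyncClient() as client:
    offset = 0
    while True:
      chunk = fd.read(CHUNK)
      size = len(chunk) if chunk else 0
      command = 'upload'
      if size + offset >= total_size:
        command += ', finalize'  # the last chunk also finalizes the upload
      response = await client.request(
          method='POST',
          url=upload_url,
          content=chunk,
          headers={
              'X-Goog-Upload-Command': command,
              'X-Goog-Upload-Offset': str(offset),
              'Content-Length': str(size),
          },
      )
      offset += size
      if response.headers.get('x-goog-upload-status') != 'active':
        break  # finished or interrupted
    return response.json()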
@@ -787,12 +838,31 @@ class BaseApiClient:
     returns:
       The file bytes
     """
-        path,
-        http_options,
+    http_request = self._build_request(
+        'get', path=path, request_dict={}, http_options=http_options
     )
 
+    data: Optional[Union[str, bytes]]
+    if http_request.data:
+      if not isinstance(http_request.data, bytes):
+        data = json.dumps(http_request.data)
+      else:
+        data = http_request.data
+
+    async with httpx.AsyncClient(follow_redirects=True) as aclient:
+      response = await aclient.request(
+          method=http_request.method,
+          url=http_request.url,
+          headers=http_request.headers,
+          content=data,
+          timeout=http_request.timeout,
+      )
+      errors.APIError.raise_for_response(response)
+
+      return HttpResponse(
+          response.headers, byte_stream=[response.read()]
+      ).byte_stream[0]
+
   # This method does nothing in the real api client. It is used in the
   # replay_api_client to verify the response from the SDK method matches the
   # recorded response.
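async_download_file now issues the GET itself through an httpx.AsyncClient configured to follow redirects, raises via errors.APIError.raise_for_response on failure, and returns the raw bytes. A stripped-down sketch of the same request shape; the URL is a placeholder and error handling is reduced to raise_for_status rather than the SDK's error classes.

import asyncio
from typing import Optional

import httpx

async def download_bytes(url: str, headers: Optional[dict] = None) -> bytes:
  # follow_redirects=True mirrors the new download path, since file
  # downloads may be served from a redirected location.
  async with httpx.AsyncClient(follow_redirects=True) as client:
    response = await client.request(method='GET', url=url, headers=headers or {})
    response.raise_for_status()
    return response.read()

# Example: asyncio.run(download_bytes('https://example.com/file.bin'))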
google/genai/_automatic_function_calling_util.py
CHANGED
@@ -17,7 +17,7 @@ import inspect
 import sys
 import types as builtin_types
 import typing
-from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Union
+from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Optional, Union
 
 import pydantic
 
@@ -304,9 +304,9 @@ def _parse_schema_from_parameter(
   )
 
 
-def _get_required_fields(schema: types.Schema) -> list[str]:
+def _get_required_fields(schema: types.Schema) -> Optional[list[str]]:
   if not schema.properties:
-    return
+    return None
   return [
       field_name
       for field_name, field_schema in schema.properties.items()
google/genai/_common.py
CHANGED
@@ -188,6 +188,8 @@ def _remove_extra_fields(
     if isinstance(item, dict):
       _remove_extra_fields(typing.get_args(annotation)[0], item)
 
+T = typing.TypeVar('T', bound='BaseModel')
+
 
 class BaseModel(pydantic.BaseModel):
 
@@ -201,12 +203,13 @@ class BaseModel(pydantic.BaseModel):
       arbitrary_types_allowed=True,
       ser_json_bytes='base64',
       val_json_bytes='base64',
+      ignored_types=(typing.TypeVar,)
   )
 
   @classmethod
   def _from_response(
-      cls, *, response: dict[str, object], kwargs: dict[str, object]
-  ) ->
+      cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
+  ) -> T:
     # To maintain forward compatibility, we need to remove extra fields from
     # the response.
     # We will provide another mechanism to allow users to access these fields.
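The _common.py change introduces a TypeVar bound to BaseModel so that _from_response is typed to return the concrete subclass it is called on, and adds the TypeVar to pydantic's ignored_types so it is not mistaken for a model field. A minimal sketch of that typing pattern outside the SDK, assuming pydantic v2; the class and method names here are illustrative.

import typing

import pydantic

T = typing.TypeVar('T', bound='Base')

class Base(pydantic.BaseModel):
  # ignored_types keeps pydantic from treating TypeVar attributes as fields.
  model_config = pydantic.ConfigDict(ignored_types=(typing.TypeVar,))

  @classmethod
  def from_payload(cls: typing.Type[T], payload: dict[str, object]) -> T:
    # Typed to return the subclass it is called on, so type checkers see
    # Child.from_payload(...) as Child rather than Base.
    return cls.model_validate(payload)

class Child(Base):
  name: str = ''

child = Child.from_payload({'name': 'example'})
print(type(child).__name__, child.name)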
google/genai/_extra_utils.py
CHANGED
@@ -17,9 +17,9 @@
 
 import inspect
 import logging
+import sys
 import typing
 from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
-import sys
 
 import pydantic
 
@@ -37,6 +37,15 @@ _DEFAULT_MAX_REMOTE_CALLS_AFC = 10
 logger = logging.getLogger('google_genai.models')
 
 
+def _create_generate_content_config_model(
+    config: types.GenerateContentConfigOrDict,
+) -> types.GenerateContentConfig:
+  if isinstance(config, dict):
+    return types.GenerateContentConfig(**config)
+  else:
+    return config
+
+
 def format_destination(
     src: str,
     config: Optional[types.CreateBatchJobConfigOrDict] = None,
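Several helpers in _extra_utils.py previously inlined the "dict or GenerateContentConfig" normalization; the new _create_generate_content_config_model centralizes it, and callers now return early when config is falsy before normalizing. A generic sketch of the same accept-dict-or-model pattern, using stand-in types rather than the SDK's:

from typing import Union

import pydantic

class Config(pydantic.BaseModel):
  """Stand-in for types.GenerateContentConfig."""
  max_calls: int = 10

def to_config_model(config: Union[Config, dict]) -> Config:
  # Accept either a plain dict or an already-constructed model.
  if isinstance(config, dict):
    return Config(**config)
  return config

print(to_config_model({'max_calls': 3}).max_calls)  # 3
print(to_config_model(Config()).max_calls)          # 10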
@@ -69,16 +78,12 @@ def format_destination(
 
 def get_function_map(
     config: Optional[types.GenerateContentConfigOrDict] = None,
-) -> dict[str,
+) -> dict[str, Callable]:
   """Returns a function map from the config."""
-      if config and isinstance(config, dict)
-      else config
-  )
-  function_map: dict[str, object] = {}
-  if not config_model:
+  function_map: dict[str, Callable] = {}
+  if not config:
     return function_map
+  config_model = _create_generate_content_config_model(config)
   if config_model.tools:
     for tool in config_model.tools:
       if callable(tool):
@@ -92,6 +97,16 @@ def get_function_map(
   return function_map
 
 
+def convert_number_values_for_dict_function_call_args(
+    args: dict[str, Any],
+) -> dict[str, Any]:
+  """Converts float values in dict with no decimal to integers."""
+  return {
+      key: convert_number_values_for_function_call_args(value)
+      for key, value in args.items()
+  }
+
+
 def convert_number_values_for_function_call_args(
     args: Union[dict[str, object], list[object], object],
 ) -> Union[dict[str, object], list[object], object]:
@@ -210,26 +225,35 @@ def invoke_function_from_dict_args(
 
 def get_function_response_parts(
     response: types.GenerateContentResponse,
-    function_map: dict[str,
+    function_map: dict[str, Callable],
 ) -> list[types.Part]:
   """Returns the function response parts from the response."""
   func_response_parts = []
+  if (
+      response.candidates is not None
+      and isinstance(response.candidates[0].content, types.Content)
+      and response.candidates[0].content.parts is not None
+  ):
+    for part in response.candidates[0].content.parts:
+      if not part.function_call:
+        continue
+      func_name = part.function_call.name
+      if func_name is not None and part.function_call.args is not None:
+        func = function_map[func_name]
+        args = convert_number_values_for_dict_function_call_args(
+            part.function_call.args
+        )
+        func_response: dict[str, Any]
+        try:
+          func_response = {
+              'result': invoke_function_from_dict_args(args, func)
+          }
+        except Exception as e:  # pylint: disable=broad-except
+          func_response = {'error': str(e)}
+        func_response_part = types.Part.from_function_response(
+            name=func_name, response=func_response
+        )
+        func_response_parts.append(func_response_part)
   return func_response_parts
 
 
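get_function_response_parts now guards on candidates, content, and parts being present, invokes each mapped function, and wraps either a result or an error into types.Part.from_function_response. A small sketch of the invoke-and-wrap step in isolation; the function map and arguments are made up for the example and the envelope shape mirrors the diff above.

from typing import Any, Callable

def call_and_wrap(func: Callable, args: dict[str, Any]) -> dict[str, Any]:
  # Mirror the result/error envelope used for function responses.
  try:
    return {'result': func(**args)}
  except Exception as e:  # broad by design: failures are reported, not raised
    return {'error': str(e)}

function_map = {'add': lambda a, b: a + b}
print(call_and_wrap(function_map['add'], {'a': 1, 'b': 2}))  # {'result': 3}
print(call_and_wrap(function_map['add'], {'a': 1}))          # {'error': ...}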
@@ -237,12 +261,9 @@ def should_disable_afc(
     config: Optional[types.GenerateContentConfigOrDict] = None,
 ) -> bool:
   """Returns whether automatic function calling is enabled."""
-      else config
-  )
+  if not config:
+    return False
+  config_model = _create_generate_content_config_model(config)
   # If max_remote_calls is less or equal to 0, warn and disable AFC.
   if (
       config_model
@@ -261,8 +282,7 @@ def should_disable_afc(
 
   # Default to enable AFC if not specified.
   if (
-      not config_model
-      or not config_model.automatic_function_calling
+      not config_model.automatic_function_calling
       or config_model.automatic_function_calling.disable is None
   ):
     return False
@@ -295,20 +315,17 @@ def should_disable_afc(
 def get_max_remote_calls_afc(
     config: Optional[types.GenerateContentConfigOrDict] = None,
 ) -> int:
+  if not config:
+    return _DEFAULT_MAX_REMOTE_CALLS_AFC
   """Returns the remaining remote calls for automatic function calling."""
   if should_disable_afc(config):
     raise ValueError(
         'automatic function calling is not enabled, but SDK is trying to get'
         ' max remote calls.'
     )
-  config_model = (
-      types.GenerateContentConfig(**config)
-      if config and isinstance(config, dict)
-      else config
-  )
+  config_model = _create_generate_content_config_model(config)
   if (
-      not config_model
-      or not config_model.automatic_function_calling
+      not config_model.automatic_function_calling
       or config_model.automatic_function_calling.maximum_remote_calls is None
   ):
     return _DEFAULT_MAX_REMOTE_CALLS_AFC
@@ -318,11 +335,9 @@ def get_max_remote_calls_afc(
 def should_append_afc_history(
     config: Optional[types.GenerateContentConfigOrDict] = None,
 ) -> bool:
-  )
-  if not config_model or not config_model.automatic_function_calling:
+  if not config:
+    return True
+  config_model = _create_generate_content_config_model(config)
+  if not config_model.automatic_function_calling:
     return True
   return not config_model.automatic_function_calling.ignore_call_history
google/genai/_replay_api_client.py
CHANGED
@@ -109,7 +109,8 @@ def _redact_project_location_path(path: str) -> str:
   return path
 
 
-def _redact_request_body(body: dict[str, object])
+def _redact_request_body(body: dict[str, object]):
+  """Redacts fields in the request body in place."""
   for key, value in body.items():
     if isinstance(value, str):
       body[key] = _redact_project_location_path(value)
@@ -302,13 +303,24 @@ class ReplayApiClient(BaseApiClient):
           status_code=http_response.status_code,
           sdk_response_segments=[],
       )
+    elif isinstance(http_response, errors.APIError):
       response = ReplayResponse(
           headers=dict(http_response.response.headers),
           body_segments=[http_response._to_replay_record()],
           status_code=http_response.code,
           sdk_response_segments=[],
       )
+    elif isinstance(http_response, bytes):
+      response = ReplayResponse(
+          headers={},
+          body_segments=[],
+          byte_segments=[http_response],
+          sdk_response_segments=[],
+      )
+    else:
+      raise ValueError(
+          'Unsupported http_response type: ' + str(type(http_response))
+      )
     self.replay_session.interactions.append(
         ReplayInteraction(request=request, response=response)
     )
@@ -471,6 +483,43 @@ class ReplayApiClient(BaseApiClient):
     else:
       return self._build_response_from_replay(request).json
 
+  async def async_upload_file(
+      self,
+      file_path: Union[str, io.IOBase],
+      upload_url: str,
+      upload_size: int,
+  ) -> str:
+    if isinstance(file_path, io.IOBase):
+      offset = file_path.tell()
+      content = file_path.read()
+      file_path.seek(offset, os.SEEK_SET)
+      request = HttpRequest(
+          method='POST',
+          url='',
+          data={'bytes': base64.b64encode(content).decode('utf-8')},
+          headers={},
+      )
+    else:
+      request = HttpRequest(
+          method='POST', url='', data={'file_path': file_path}, headers={}
+      )
+    if self._should_call_api():
+      result: Union[str, HttpResponse]
+      try:
+        result = await super().async_upload_file(
+            file_path, upload_url, upload_size
+        )
+      except HTTPError as e:
+        result = HttpResponse(
+            e.response.headers, [json.dumps({'reason': e.response.reason})]
+        )
+        result.status_code = e.response.status_code
+        raise e
+      self._record_interaction(request, HttpResponse({}, [json.dumps(result)]))
+      return result
+    else:
+      return self._build_response_from_replay(request).json
+
   def _download_file_request(self, request):
     self._initialize_replay_session_if_not_loaded()
     if self._should_call_api():
@@ -486,3 +535,22 @@ class ReplayApiClient(BaseApiClient):
       return result
     else:
       return self._build_response_from_replay(request)
+
+  async def async_download_file(self, path: str, http_options):
+    self._initialize_replay_session_if_not_loaded()
+    request = self._build_request(
+        'get', path=path, request_dict={}, http_options=http_options
+    )
+    if self._should_call_api():
+      try:
+        result = await super().async_download_file(path, http_options)
+      except HTTPError as e:
+        result = HttpResponse(
+            e.response.headers, [json.dumps({'reason': e.response.reason})]
+        )
+        result.status_code = e.response.status_code
+        raise e
+      self._record_interaction(request, result)
+      return result
+    else:
+      return self._build_response_from_replay(request).byte_stream[0]