google-genai 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +164 -55
- google/genai/_common.py +37 -6
- google/genai/_extra_utils.py +3 -3
- google/genai/_replay_api_client.py +44 -32
- google/genai/_transformers.py +167 -38
- google/genai/batches.py +10 -10
- google/genai/caches.py +10 -10
- google/genai/client.py +2 -1
- google/genai/errors.py +1 -1
- google/genai/files.py +239 -40
- google/genai/live.py +5 -1
- google/genai/models.py +102 -30
- google/genai/tunings.py +8 -8
- google/genai/types.py +546 -348
- google/genai/version.py +1 -1
- google_genai-0.6.0.dist-info/METADATA +973 -0
- google_genai-0.6.0.dist-info/RECORD +25 -0
- google_genai-0.5.0.dist-info/METADATA +0 -888
- google_genai-0.5.0.dist-info/RECORD +0 -25
- {google_genai-0.5.0.dist-info → google_genai-0.6.0.dist-info}/LICENSE +0 -0
- {google_genai-0.5.0.dist-info → google_genai-0.6.0.dist-info}/WHEEL +0 -0
- {google_genai-0.5.0.dist-info → google_genai-0.6.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -20,6 +20,7 @@ import asyncio
|
|
20
20
|
import copy
|
21
21
|
from dataclasses import dataclass
|
22
22
|
import datetime
|
23
|
+
import io
|
23
24
|
import json
|
24
25
|
import logging
|
25
26
|
import os
|
@@ -148,10 +149,16 @@ class HttpRequest:
|
|
148
149
|
|
149
150
|
class HttpResponse:
|
150
151
|
|
151
|
-
def __init__(
|
152
|
+
def __init__(
|
153
|
+
self,
|
154
|
+
headers: dict[str, str],
|
155
|
+
response_stream: Union[Any, str] = None,
|
156
|
+
byte_stream: Union[Any, bytes] = None,
|
157
|
+
):
|
152
158
|
self.status_code = 200
|
153
159
|
self.headers = headers
|
154
160
|
self.response_stream = response_stream
|
161
|
+
self.byte_stream = byte_stream
|
155
162
|
|
156
163
|
@property
|
157
164
|
def text(self) -> str:
|
@@ -164,6 +171,8 @@ class HttpResponse:
|
|
164
171
|
# list of objects retrieved from replay or from non-streaming API.
|
165
172
|
for chunk in self.response_stream:
|
166
173
|
yield json.loads(chunk) if chunk else {}
|
174
|
+
elif self.response_stream is None:
|
175
|
+
yield from []
|
167
176
|
else:
|
168
177
|
# Iterator of objects retrieved from the API.
|
169
178
|
for chunk in self.response_stream.iter_lines():
|
@@ -174,6 +183,17 @@ class HttpResponse:
|
|
174
183
|
chunk = chunk[len(b'data: ') :]
|
175
184
|
yield json.loads(str(chunk, 'utf-8'))
|
176
185
|
|
186
|
+
def byte_segments(self):
|
187
|
+
if isinstance(self.byte_stream, list):
|
188
|
+
# list of objects retrieved from replay or from non-streaming API.
|
189
|
+
yield from self.byte_stream
|
190
|
+
elif self.byte_stream is None:
|
191
|
+
yield from []
|
192
|
+
else:
|
193
|
+
raise ValueError(
|
194
|
+
'Byte segments are not supported for streaming responses.'
|
195
|
+
)
|
196
|
+
|
177
197
|
def copy_to_dict(self, response_payload: dict[str, object]):
|
178
198
|
for attribute in dir(self):
|
179
199
|
response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
|
@@ -199,7 +219,7 @@ class ApiClient:
|
|
199
219
|
]:
|
200
220
|
self.vertexai = True
|
201
221
|
|
202
|
-
# Validate explicitly set
|
222
|
+
# Validate explicitly set initializer values.
|
203
223
|
if (project or location) and api_key:
|
204
224
|
# API cannot consume both project/location and api_key.
|
205
225
|
raise ValueError(
|
@@ -265,9 +285,10 @@ class ApiClient:
|
|
265
285
|
self.api_key = None
|
266
286
|
if not self.project and not self.api_key:
|
267
287
|
self.project = google.auth.default()[1]
|
268
|
-
if not (self.project
|
288
|
+
if not ((self.project and self.location) or self.api_key):
|
269
289
|
raise ValueError(
|
270
|
-
'Project
|
290
|
+
'Project and location or API key must be set when using the Vertex '
|
291
|
+
'AI API.'
|
271
292
|
)
|
272
293
|
if self.api_key:
|
273
294
|
self._http_options['base_url'] = (
|
@@ -357,7 +378,7 @@ class ApiClient:
|
|
357
378
|
http_request.method.upper(),
|
358
379
|
http_request.url,
|
359
380
|
headers=http_request.headers,
|
360
|
-
data=json.dumps(http_request.data
|
381
|
+
data=json.dumps(http_request.data)
|
361
382
|
if http_request.data
|
362
383
|
else None,
|
363
384
|
timeout=http_request.timeout,
|
@@ -377,7 +398,7 @@ class ApiClient:
|
|
377
398
|
data = None
|
378
399
|
if http_request.data:
|
379
400
|
if not isinstance(http_request.data, bytes):
|
380
|
-
data = json.dumps(http_request.data
|
401
|
+
data = json.dumps(http_request.data)
|
381
402
|
else:
|
382
403
|
data = http_request.data
|
383
404
|
|
@@ -488,12 +509,35 @@ class ApiClient:
|
|
488
509
|
if http_options and 'response_payload' in http_options:
|
489
510
|
response.copy_to_dict(http_options['response_payload'])
|
490
511
|
|
491
|
-
def upload_file(
|
512
|
+
def upload_file(
|
513
|
+
self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
|
514
|
+
) -> str:
|
492
515
|
"""Transfers a file to the given URL.
|
493
516
|
|
494
517
|
Args:
|
495
|
-
file_path: The full path to the file
|
496
|
-
an error will be
|
518
|
+
file_path: The full path to the file or a file like object inherited from
|
519
|
+
io.BytesIO. If the local file path is not found, an error will be
|
520
|
+
raised.
|
521
|
+
upload_url: The URL to upload the file to.
|
522
|
+
upload_size: The size of file content to be uploaded, this will have to
|
523
|
+
match the size requested in the resumable upload request.
|
524
|
+
|
525
|
+
returns:
|
526
|
+
The response json object from the finalize request.
|
527
|
+
"""
|
528
|
+
if isinstance(file_path, io.IOBase):
|
529
|
+
return self._upload_fd(file_path, upload_url, upload_size)
|
530
|
+
else:
|
531
|
+
with open(file_path, 'rb') as file:
|
532
|
+
return self._upload_fd(file, upload_url, upload_size)
|
533
|
+
|
534
|
+
def _upload_fd(
|
535
|
+
self, file: io.IOBase, upload_url: str, upload_size: int
|
536
|
+
) -> str:
|
537
|
+
"""Transfers a file to the given URL.
|
538
|
+
|
539
|
+
Args:
|
540
|
+
file: A file like object inherited from io.BytesIO.
|
497
541
|
upload_url: The URL to upload the file to.
|
498
542
|
upload_size: The size of file content to be uploaded, this will have to
|
499
543
|
match the size requested in the resumable upload request.
|
@@ -503,37 +547,36 @@ class ApiClient:
|
|
503
547
|
"""
|
504
548
|
offset = 0
|
505
549
|
# Upload the file in chunks
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
break # upload is complete or it has been interrupted.
|
550
|
+
while True:
|
551
|
+
file_chunk = file.read(1024 * 1024 * 8) # 8 MB chunk size
|
552
|
+
chunk_size = 0
|
553
|
+
if file_chunk:
|
554
|
+
chunk_size = len(file_chunk)
|
555
|
+
upload_command = 'upload'
|
556
|
+
# If last chunk, finalize the upload.
|
557
|
+
if chunk_size + offset >= upload_size:
|
558
|
+
upload_command += ', finalize'
|
559
|
+
request = HttpRequest(
|
560
|
+
method='POST',
|
561
|
+
url=upload_url,
|
562
|
+
headers={
|
563
|
+
'X-Goog-Upload-Command': upload_command,
|
564
|
+
'X-Goog-Upload-Offset': str(offset),
|
565
|
+
'Content-Length': str(chunk_size),
|
566
|
+
},
|
567
|
+
data=file_chunk,
|
568
|
+
)
|
569
|
+
|
570
|
+
response = self._request_unauthorized(request, stream=False)
|
571
|
+
offset += chunk_size
|
572
|
+
if response.headers['X-Goog-Upload-Status'] != 'active':
|
573
|
+
break # upload is complete or it has been interrupted.
|
531
574
|
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
575
|
+
if upload_size <= offset: # Status is not finalized.
|
576
|
+
raise ValueError(
|
577
|
+
'All content has been uploaded, but the upload status is not'
|
578
|
+
f' finalized. {response.headers}, body: {response.text}'
|
579
|
+
)
|
537
580
|
|
538
581
|
if response.headers['X-Goog-Upload-Status'] != 'final':
|
539
582
|
raise ValueError(
|
@@ -542,12 +585,52 @@ class ApiClient:
|
|
542
585
|
)
|
543
586
|
return response.text
|
544
587
|
|
588
|
+
def download_file(self, path: str, http_options):
|
589
|
+
"""Downloads the file data.
|
590
|
+
|
591
|
+
Args:
|
592
|
+
path: The request path with query params.
|
593
|
+
http_options: The http options to use for the request.
|
594
|
+
|
595
|
+
returns:
|
596
|
+
The file bytes
|
597
|
+
"""
|
598
|
+
http_request = self._build_request(
|
599
|
+
'get', path=path, request_dict={}, http_options=http_options
|
600
|
+
)
|
601
|
+
return self._download_file_request(http_request).byte_stream[0]
|
602
|
+
|
603
|
+
def _download_file_request(
|
604
|
+
self,
|
605
|
+
http_request: HttpRequest,
|
606
|
+
) -> HttpResponse:
|
607
|
+
data = None
|
608
|
+
if http_request.data:
|
609
|
+
if not isinstance(http_request.data, bytes):
|
610
|
+
data = json.dumps(http_request.data, cls=RequestJsonEncoder)
|
611
|
+
else:
|
612
|
+
data = http_request.data
|
613
|
+
|
614
|
+
http_session = requests.Session()
|
615
|
+
response = http_session.request(
|
616
|
+
method=http_request.method,
|
617
|
+
url=http_request.url,
|
618
|
+
headers=http_request.headers,
|
619
|
+
data=data,
|
620
|
+
timeout=http_request.timeout,
|
621
|
+
stream=False,
|
622
|
+
)
|
623
|
+
|
624
|
+
errors.APIError.raise_for_response(response)
|
625
|
+
return HttpResponse(response.headers, byte_stream=[response.content])
|
626
|
+
|
627
|
+
|
545
628
|
async def async_upload_file(
|
546
629
|
self,
|
547
|
-
file_path: str,
|
630
|
+
file_path: Union[str, io.IOBase],
|
548
631
|
upload_url: str,
|
549
632
|
upload_size: int,
|
550
|
-
):
|
633
|
+
) -> str:
|
551
634
|
"""Transfers a file asynchronously to the given URL.
|
552
635
|
|
553
636
|
Args:
|
@@ -567,22 +650,48 @@ class ApiClient:
|
|
567
650
|
upload_size,
|
568
651
|
)
|
569
652
|
|
653
|
+
async def _async_upload_fd(
|
654
|
+
self,
|
655
|
+
file: io.IOBase,
|
656
|
+
upload_url: str,
|
657
|
+
upload_size: int,
|
658
|
+
) -> str:
|
659
|
+
"""Transfers a file asynchronously to the given URL.
|
660
|
+
|
661
|
+
Args:
|
662
|
+
file: A file like object inherited from io.BytesIO.
|
663
|
+
upload_url: The URL to upload the file to.
|
664
|
+
upload_size: The size of file content to be uploaded, this will have to
|
665
|
+
match the size requested in the resumable upload request.
|
666
|
+
|
667
|
+
returns:
|
668
|
+
The response json object from the finalize request.
|
669
|
+
"""
|
670
|
+
return await asyncio.to_thread(
|
671
|
+
self._upload_fd,
|
672
|
+
file,
|
673
|
+
upload_url,
|
674
|
+
upload_size,
|
675
|
+
)
|
676
|
+
|
677
|
+
async def async_download_file(self, path: str, http_options):
|
678
|
+
"""Downloads the file data.
|
679
|
+
|
680
|
+
Args:
|
681
|
+
path: The request path with query params.
|
682
|
+
http_options: The http options to use for the request.
|
683
|
+
|
684
|
+
returns:
|
685
|
+
The file bytes
|
686
|
+
"""
|
687
|
+
return await asyncio.to_thread(
|
688
|
+
self.download_file,
|
689
|
+
path,
|
690
|
+
http_options,
|
691
|
+
)
|
692
|
+
|
570
693
|
# This method does nothing in the real api client. It is used in the
|
571
694
|
# replay_api_client to verify the response from the SDK method matches the
|
572
695
|
# recorded response.
|
573
696
|
def _verify_response(self, response_model: BaseModel):
|
574
697
|
pass
|
575
|
-
|
576
|
-
|
577
|
-
# TODO(b/389693448): Cleanup datetime hacks.
|
578
|
-
class RequestJsonEncoder(json.JSONEncoder):
|
579
|
-
"""Encode bytes as strings without modify its content."""
|
580
|
-
|
581
|
-
def default(self, o):
|
582
|
-
if isinstance(o, datetime.datetime):
|
583
|
-
# This Zulu time format is used by the Vertex AI API and the test recorder
|
584
|
-
# Using strftime works well, but we want to align with the replay encoder.
|
585
|
-
# o.astimezone(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%fZ')
|
586
|
-
return o.isoformat().replace('+00:00', 'Z')
|
587
|
-
else:
|
588
|
-
return super().default(o)
|
google/genai/_common.py
CHANGED
@@ -17,6 +17,7 @@
|
|
17
17
|
|
18
18
|
import base64
|
19
19
|
import datetime
|
20
|
+
import enum
|
20
21
|
import typing
|
21
22
|
from typing import Union
|
22
23
|
import uuid
|
@@ -144,7 +145,7 @@ def _remove_extra_fields(
|
|
144
145
|
) -> None:
|
145
146
|
"""Removes extra fields from the response that are not in the model.
|
146
147
|
|
147
|
-
|
148
|
+
Mutates the response in place.
|
148
149
|
"""
|
149
150
|
|
150
151
|
key_values = list(response.items())
|
@@ -185,7 +186,7 @@ class BaseModel(pydantic.BaseModel):
|
|
185
186
|
alias_generator=alias_generators.to_camel,
|
186
187
|
populate_by_name=True,
|
187
188
|
from_attributes=True,
|
188
|
-
protected_namespaces=
|
189
|
+
protected_namespaces=(),
|
189
190
|
extra='forbid',
|
190
191
|
# This allows us to use arbitrary types in the model. E.g. PIL.Image.
|
191
192
|
arbitrary_types_allowed=True,
|
@@ -208,6 +209,20 @@ class BaseModel(pydantic.BaseModel):
|
|
208
209
|
return self.model_dump(exclude_none=True, mode='json')
|
209
210
|
|
210
211
|
|
212
|
+
class CaseInSensitiveEnum(str, enum.Enum):
|
213
|
+
"""Case insensitive enum."""
|
214
|
+
|
215
|
+
@classmethod
|
216
|
+
def _missing_(cls, value):
|
217
|
+
try:
|
218
|
+
return cls[value.upper()] # Try to access directly with uppercase
|
219
|
+
except KeyError:
|
220
|
+
try:
|
221
|
+
return cls[value.lower()] # Try to access directly with lowercase
|
222
|
+
except KeyError as e:
|
223
|
+
raise ValueError(f"{value} is not a valid {cls.__name__}") from e
|
224
|
+
|
225
|
+
|
211
226
|
def timestamped_unique_name() -> str:
|
212
227
|
"""Composes a timestamped unique name.
|
213
228
|
|
@@ -219,23 +234,39 @@ def timestamped_unique_name() -> str:
|
|
219
234
|
return f'{timestamp}_{unique_id}'
|
220
235
|
|
221
236
|
|
222
|
-
def
|
223
|
-
"""
|
237
|
+
def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
|
238
|
+
"""Converts unserializable types in dict to json.dumps() compatible types.
|
239
|
+
|
240
|
+
This function is called in models.py after calling convert_to_dict(). The
|
241
|
+
convert_to_dict() can convert pydantic object to dict. However, the input to
|
242
|
+
convert_to_dict() is dict mixed of pydantic object and nested dict(the output
|
243
|
+
of converters). So they may be bytes in the dict and they are out of
|
244
|
+
`ser_json_bytes` control in model_dump(mode='json') called in
|
245
|
+
`convert_to_dict`, as well as datetime deserialization in Pydantic json mode.
|
246
|
+
|
247
|
+
Returns:
|
248
|
+
A dictionary with json.dumps() incompatible type (e.g. bytes datetime)
|
249
|
+
to compatible type (e.g. base64 encoded string, isoformat date string).
|
250
|
+
"""
|
224
251
|
processed_data = {}
|
225
252
|
if not isinstance(data, dict):
|
226
253
|
return data
|
227
254
|
for key, value in data.items():
|
228
255
|
if isinstance(value, bytes):
|
229
256
|
processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
|
257
|
+
elif isinstance(value, datetime.datetime):
|
258
|
+
processed_data[key] = value.isoformat()
|
230
259
|
elif isinstance(value, dict):
|
231
|
-
processed_data[key] =
|
260
|
+
processed_data[key] = encode_unserializable_types(value)
|
232
261
|
elif isinstance(value, list):
|
233
262
|
if all(isinstance(v, bytes) for v in value):
|
234
263
|
processed_data[key] = [
|
235
264
|
base64.urlsafe_b64encode(v).decode('ascii') for v in value
|
236
265
|
]
|
266
|
+
if all(isinstance(v, datetime.datetime) for v in value):
|
267
|
+
processed_data[key] = [v.isoformat() for v in value]
|
237
268
|
else:
|
238
|
-
processed_data[key] = [
|
269
|
+
processed_data[key] = [encode_unserializable_types(v) for v in value]
|
239
270
|
else:
|
240
271
|
processed_data[key] = value
|
241
272
|
return processed_data
|
google/genai/_extra_utils.py
CHANGED
@@ -116,7 +116,7 @@ def convert_if_exist_pydantic_model(
|
|
116
116
|
try:
|
117
117
|
return annotation(**value)
|
118
118
|
except pydantic.ValidationError as e:
|
119
|
-
raise errors.
|
119
|
+
raise errors.UnknownFunctionCallArgumentError(
|
120
120
|
f'Failed to parse parameter {param_name} for function'
|
121
121
|
f' {func_name} from function call part because function call argument'
|
122
122
|
f' value {value} is not compatible with parameter annotation'
|
@@ -150,7 +150,7 @@ def convert_if_exist_pydantic_model(
|
|
150
150
|
except pydantic.ValidationError:
|
151
151
|
continue
|
152
152
|
# if none of the union type is matched, raise error
|
153
|
-
raise errors.
|
153
|
+
raise errors.UnknownFunctionCallArgumentError(
|
154
154
|
f'Failed to parse parameter {param_name} for function'
|
155
155
|
f' {func_name} from function call part because function call argument'
|
156
156
|
f' value {value} cannot be converted to parameter annotation'
|
@@ -161,7 +161,7 @@ def convert_if_exist_pydantic_model(
|
|
161
161
|
if isinstance(value, int) and annotation is float:
|
162
162
|
return value
|
163
163
|
if not isinstance(value, annotation):
|
164
|
-
raise errors.
|
164
|
+
raise errors.UnknownFunctionCallArgumentError(
|
165
165
|
f'Failed to parse parameter {param_name} for function {func_name} from'
|
166
166
|
f' function call part because function call argument value {value} is'
|
167
167
|
f' not compatible with parameter annotation {annotation}.'
|
@@ -17,11 +17,12 @@
|
|
17
17
|
|
18
18
|
import base64
|
19
19
|
import copy
|
20
|
+
import datetime
|
20
21
|
import inspect
|
22
|
+
import io
|
21
23
|
import json
|
22
24
|
import os
|
23
25
|
import re
|
24
|
-
import datetime
|
25
26
|
from typing import Any, Literal, Optional, Union
|
26
27
|
|
27
28
|
import google.auth
|
@@ -32,9 +33,9 @@ from ._api_client import ApiClient
|
|
32
33
|
from ._api_client import HttpOptions
|
33
34
|
from ._api_client import HttpRequest
|
34
35
|
from ._api_client import HttpResponse
|
35
|
-
from ._api_client import RequestJsonEncoder
|
36
36
|
from ._common import BaseModel
|
37
37
|
|
38
|
+
|
38
39
|
def _redact_version_numbers(version_string: str) -> str:
|
39
40
|
"""Redacts version numbers in the form x.y.z from a string."""
|
40
41
|
return re.sub(r'\d+\.\d+\.\d+', '{VERSION_NUMBER}', version_string)
|
@@ -145,6 +146,7 @@ class ReplayResponse(BaseModel):
|
|
145
146
|
status_code: int = 200
|
146
147
|
headers: dict[str, str]
|
147
148
|
body_segments: list[dict[str, object]]
|
149
|
+
byte_segments: Optional[list[bytes]] = None
|
148
150
|
sdk_response_segments: list[dict[str, object]]
|
149
151
|
|
150
152
|
def model_post_init(self, __context: Any) -> None:
|
@@ -264,17 +266,13 @@ class ReplayApiClient(ApiClient):
|
|
264
266
|
replay_file_path = self._get_replay_file_path()
|
265
267
|
os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
|
266
268
|
with open(replay_file_path, 'w') as f:
|
267
|
-
f.write(
|
268
|
-
json.dumps(
|
269
|
-
self.replay_session.model_dump(mode='json'), indent=2, cls=ResponseJsonEncoder
|
270
|
-
)
|
271
|
-
)
|
269
|
+
f.write(self.replay_session.model_dump_json(exclude_unset=True, indent=2))
|
272
270
|
self.replay_session = None
|
273
271
|
|
274
272
|
def _record_interaction(
|
275
273
|
self,
|
276
274
|
http_request: HttpRequest,
|
277
|
-
http_response: Union[HttpResponse, errors.APIError],
|
275
|
+
http_response: Union[HttpResponse, errors.APIError, bytes],
|
278
276
|
):
|
279
277
|
if not self._should_update_replay():
|
280
278
|
return
|
@@ -289,6 +287,9 @@ class ReplayApiClient(ApiClient):
|
|
289
287
|
response = ReplayResponse(
|
290
288
|
headers=dict(http_response.headers),
|
291
289
|
body_segments=list(http_response.segments()),
|
290
|
+
byte_segments=[
|
291
|
+
seg[:100] + b'...' for seg in http_response.byte_segments()
|
292
|
+
],
|
292
293
|
status_code=http_response.status_code,
|
293
294
|
sdk_response_segments=[],
|
294
295
|
)
|
@@ -322,11 +323,7 @@ class ReplayApiClient(ApiClient):
|
|
322
323
|
# so that the comparison is fair.
|
323
324
|
_redact_request_body(request_data_copy)
|
324
325
|
|
325
|
-
|
326
|
-
# Because the expected_request_body dict never contains bytes values.
|
327
|
-
actual_request_body = [
|
328
|
-
json.loads(json.dumps(request_data_copy, cls=RequestJsonEncoder))
|
329
|
-
]
|
326
|
+
actual_request_body = [request_data_copy]
|
330
327
|
expected_request_body = interaction.request.body_segments
|
331
328
|
assert actual_request_body == expected_request_body, (
|
332
329
|
'Request body mismatch:\n'
|
@@ -349,6 +346,7 @@ class ReplayApiClient(ApiClient):
|
|
349
346
|
json.dumps(segment)
|
350
347
|
for segment in interaction.response.body_segments
|
351
348
|
],
|
349
|
+
byte_stream=interaction.response.byte_segments,
|
352
350
|
)
|
353
351
|
|
354
352
|
def _verify_response(self, response_model: BaseModel):
|
@@ -368,7 +366,9 @@ class ReplayApiClient(ApiClient):
|
|
368
366
|
response_model = response_model[0]
|
369
367
|
print('response_model: ', response_model.model_dump(exclude_none=True))
|
370
368
|
actual = response_model.model_dump(exclude_none=True, mode='json')
|
371
|
-
expected = interaction.response.sdk_response_segments[
|
369
|
+
expected = interaction.response.sdk_response_segments[
|
370
|
+
self._sdk_response_index
|
371
|
+
]
|
372
372
|
assert (
|
373
373
|
actual == expected
|
374
374
|
), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
|
@@ -402,10 +402,21 @@ class ReplayApiClient(ApiClient):
|
|
402
402
|
else:
|
403
403
|
return self._build_response_from_replay(http_request)
|
404
404
|
|
405
|
-
def upload_file(self, file_path: str, upload_url: str, upload_size: int):
|
406
|
-
|
407
|
-
|
408
|
-
|
405
|
+
def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int):
|
406
|
+
if isinstance(file_path, io.IOBase):
|
407
|
+
offset = file_path.tell()
|
408
|
+
content = file_path.read()
|
409
|
+
file_path.seek(offset, os.SEEK_SET)
|
410
|
+
request = HttpRequest(
|
411
|
+
method='POST',
|
412
|
+
url='',
|
413
|
+
data={'bytes': base64.b64encode(content).decode('utf-8')},
|
414
|
+
headers={}
|
415
|
+
)
|
416
|
+
else:
|
417
|
+
request = HttpRequest(
|
418
|
+
method='POST', url='', data={'file_path': file_path}, headers={}
|
419
|
+
)
|
409
420
|
if self._should_call_api():
|
410
421
|
try:
|
411
422
|
result = super().upload_file(file_path, upload_url, upload_size)
|
@@ -420,18 +431,19 @@ class ReplayApiClient(ApiClient):
|
|
420
431
|
else:
|
421
432
|
return self._build_response_from_replay(request).text
|
422
433
|
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
434
|
+
def _download_file_request(self, request):
|
435
|
+
self._initialize_replay_session_if_not_loaded()
|
436
|
+
if self._should_call_api():
|
437
|
+
try:
|
438
|
+
result = super()._download_file_request(request)
|
439
|
+
except HTTPError as e:
|
440
|
+
result = HttpResponse(
|
441
|
+
e.response.headers, [json.dumps({'reason': e.response.reason})]
|
442
|
+
)
|
443
|
+
result.status_code = e.response.status_code
|
444
|
+
raise e
|
445
|
+
self._record_interaction(request, result)
|
446
|
+
return result
|
436
447
|
else:
|
437
|
-
return
|
448
|
+
return self._build_response_from_replay(request)
|
449
|
+
|