google-genai 1.7.0__tar.gz → 1.8.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_genai-1.7.0/google_genai.egg-info → google_genai-1.8.0}/PKG-INFO +3 -2
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/_api_client.py +93 -78
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/_replay_api_client.py +22 -14
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/_transformers.py +23 -14
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/batches.py +60 -294
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/caches.py +545 -525
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/chats.py +15 -8
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/client.py +5 -3
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/errors.py +46 -23
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/files.py +88 -304
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/live.py +4 -4
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/models.py +1991 -2290
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/operations.py +103 -123
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/tunings.py +255 -271
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/types.py +207 -74
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/version.py +1 -1
- {google_genai-1.7.0 → google_genai-1.8.0/google_genai.egg-info}/PKG-INFO +3 -2
- {google_genai-1.7.0 → google_genai-1.8.0}/pyproject.toml +1 -1
- {google_genai-1.7.0 → google_genai-1.8.0}/LICENSE +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/MANIFEST.in +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/README.md +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/__init__.py +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/_api_module.py +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/_automatic_function_calling_util.py +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/_common.py +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/_extra_utils.py +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/_test_api_client.py +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/google/genai/pagers.py +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/google_genai.egg-info/SOURCES.txt +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/google_genai.egg-info/dependency_links.txt +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/google_genai.egg-info/requires.txt +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/google_genai.egg-info/top_level.txt +0 -0
- {google_genai-1.7.0 → google_genai-1.8.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: google-genai
|
3
|
-
Version: 1.7.0
|
3
|
+
Version: 1.8.0
|
4
4
|
Summary: GenAI Python SDK
|
5
5
|
Author-email: Google LLC <googleapis-packages@google.com>
|
6
6
|
License: Apache-2.0
|
@@ -27,6 +27,7 @@ Requires-Dist: pydantic<3.0.0,>=2.0.0
|
|
27
27
|
Requires-Dist: requests<3.0.0,>=2.28.1
|
28
28
|
Requires-Dist: websockets<15.1.0,>=13.0.0
|
29
29
|
Requires-Dist: typing-extensions<5.0.0,>=4.11.0
|
30
|
+
Dynamic: license-file
|
30
31
|
|
31
32
|
# Google Gen AI SDK
|
32
33
|
|
@@ -68,26 +68,30 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
|
|
68
68
|
|
69
69
|
|
70
70
|
def _patch_http_options(
|
71
|
-
options:
|
72
|
-
) ->
|
73
|
-
|
74
|
-
|
75
|
-
copy_option.
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
71
|
+
options: HttpOptions, patch_options: HttpOptions
|
72
|
+
) -> HttpOptions:
|
73
|
+
copy_option = options.model_copy()
|
74
|
+
|
75
|
+
options_headers = copy_option.headers or {}
|
76
|
+
patch_options_headers = patch_options.headers or {}
|
77
|
+
copy_option.headers = {
|
78
|
+
**options_headers,
|
79
|
+
**patch_options_headers,
|
80
|
+
}
|
81
|
+
|
82
|
+
http_options_keys = HttpOptions.model_fields.keys()
|
83
|
+
|
84
|
+
for key in http_options_keys:
|
85
|
+
if key == 'headers':
|
86
|
+
continue
|
87
|
+
patch_value = getattr(patch_options, key, None)
|
88
|
+
if patch_value is not None:
|
89
|
+
setattr(copy_option, key, patch_value)
|
90
|
+
else:
|
91
|
+
setattr(copy_option, key, getattr(options, key))
|
92
|
+
|
93
|
+
if copy_option.headers is not None:
|
94
|
+
_append_library_version_headers(copy_option.headers)
|
91
95
|
return copy_option
|
92
96
|
|
93
97
|
|
@@ -200,7 +204,7 @@ class HttpResponse:
|
|
200
204
|
for chunk in self.response_stream:
|
201
205
|
yield json.loads(chunk) if chunk else {}
|
202
206
|
elif self.response_stream is None:
|
203
|
-
async for c in []:
|
207
|
+
async for c in []: # type: ignore[attr-defined]
|
204
208
|
yield c
|
205
209
|
else:
|
206
210
|
# Iterator of objects retrieved from the API.
|
@@ -216,9 +220,7 @@ class HttpResponse:
|
|
216
220
|
chunk = chunk[len('data: ') :]
|
217
221
|
yield json.loads(chunk)
|
218
222
|
else:
|
219
|
-
raise ValueError(
|
220
|
-
'Error parsing streaming response.'
|
221
|
-
)
|
223
|
+
raise ValueError('Error parsing streaming response.')
|
222
224
|
|
223
225
|
def byte_segments(self):
|
224
226
|
if isinstance(self.byte_stream, list):
|
@@ -308,16 +310,14 @@ class BaseApiClient:
|
|
308
310
|
)
|
309
311
|
|
310
312
|
# Validate http_options if it is provided.
|
311
|
-
validated_http_options
|
313
|
+
validated_http_options = HttpOptions()
|
312
314
|
if isinstance(http_options, dict):
|
313
315
|
try:
|
314
|
-
validated_http_options = HttpOptions.model_validate(
|
315
|
-
http_options
|
316
|
-
).model_dump()
|
316
|
+
validated_http_options = HttpOptions.model_validate(http_options)
|
317
317
|
except ValidationError as e:
|
318
318
|
raise ValueError(f'Invalid http_options: {e}')
|
319
319
|
elif isinstance(http_options, HttpOptions):
|
320
|
-
validated_http_options = http_options
|
320
|
+
validated_http_options = http_options
|
321
321
|
|
322
322
|
# Retrieve implicitly set values from the environment.
|
323
323
|
env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
|
@@ -328,11 +328,15 @@ class BaseApiClient:
|
|
328
328
|
self.api_key = api_key or env_api_key
|
329
329
|
|
330
330
|
self._credentials = credentials
|
331
|
-
self._http_options =
|
331
|
+
self._http_options = HttpOptions()
|
332
332
|
# Initialize the lock. This lock will be used to protect access to the
|
333
333
|
# credentials. This is crucial for thread safety when multiple coroutines
|
334
334
|
# might be accessing the credentials at the same time.
|
335
|
-
|
335
|
+
try:
|
336
|
+
self._auth_lock = asyncio.Lock()
|
337
|
+
except RuntimeError:
|
338
|
+
asyncio.set_event_loop(asyncio.new_event_loop())
|
339
|
+
self._auth_lock = asyncio.Lock()
|
336
340
|
|
337
341
|
# Handle when to use Vertex AI in express mode (api key).
|
338
342
|
# Explicit initializer arguments are already validated above.
|
@@ -376,12 +380,12 @@ class BaseApiClient:
|
|
376
380
|
'AI API.'
|
377
381
|
)
|
378
382
|
if self.api_key or self.location == 'global':
|
379
|
-
self._http_options
|
383
|
+
self._http_options.base_url = f'https://aiplatform.googleapis.com/'
|
380
384
|
else:
|
381
|
-
self._http_options
|
385
|
+
self._http_options.base_url = (
|
382
386
|
f'https://{self.location}-aiplatform.googleapis.com/'
|
383
387
|
)
|
384
|
-
self._http_options
|
388
|
+
self._http_options.api_version = 'v1beta1'
|
385
389
|
else: # Implicit initialization or missing arguments.
|
386
390
|
if not self.api_key:
|
387
391
|
raise ValueError(
|
@@ -389,27 +393,27 @@ class BaseApiClient:
|
|
389
393
|
'provide (`api_key`) arguments. To use the Google Cloud API,'
|
390
394
|
' provide (`vertexai`, `project` & `location`) arguments.'
|
391
395
|
)
|
392
|
-
self._http_options
|
393
|
-
|
394
|
-
)
|
395
|
-
self._http_options['api_version'] = 'v1beta'
|
396
|
+
self._http_options.base_url = 'https://generativelanguage.googleapis.com/'
|
397
|
+
self._http_options.api_version = 'v1beta'
|
396
398
|
# Default options for both clients.
|
397
|
-
self._http_options
|
399
|
+
self._http_options.headers = {'Content-Type': 'application/json'}
|
398
400
|
if self.api_key:
|
399
|
-
self._http_options
|
401
|
+
if self._http_options.headers is not None:
|
402
|
+
self._http_options.headers['x-goog-api-key'] = self.api_key
|
400
403
|
# Update the http options with the user provided http options.
|
401
404
|
if http_options:
|
402
405
|
self._http_options = _patch_http_options(
|
403
406
|
self._http_options, validated_http_options
|
404
407
|
)
|
405
408
|
else:
|
406
|
-
|
409
|
+
if self._http_options.headers is not None:
|
410
|
+
_append_library_version_headers(self._http_options.headers)
|
407
411
|
# Initialize the httpx client.
|
408
412
|
self._httpx_client = SyncHttpxClient()
|
409
413
|
self._async_httpx_client = AsyncHttpxClient()
|
410
414
|
|
411
415
|
def _websocket_base_url(self):
|
412
|
-
url_parts = urlparse(self._http_options
|
416
|
+
url_parts = urlparse(self._http_options.base_url)
|
413
417
|
return url_parts._replace(scheme='wss').geturl()
|
414
418
|
|
415
419
|
def _access_token(self) -> str:
|
@@ -420,9 +424,7 @@ class BaseApiClient:
|
|
420
424
|
self.project = project
|
421
425
|
|
422
426
|
if self._credentials:
|
423
|
-
if
|
424
|
-
self._credentials.expired or not self._credentials.token
|
425
|
-
):
|
427
|
+
if self._credentials.expired or not self._credentials.token:
|
426
428
|
# Only refresh when it needs to. Default expiration is 3600 seconds.
|
427
429
|
_refresh_auth(self._credentials)
|
428
430
|
if not self._credentials.token:
|
@@ -446,9 +448,7 @@ class BaseApiClient:
|
|
446
448
|
self.project = project
|
447
449
|
|
448
450
|
if self._credentials:
|
449
|
-
if
|
450
|
-
self._credentials.expired or not self._credentials.token
|
451
|
-
):
|
451
|
+
if self._credentials.expired or not self._credentials.token:
|
452
452
|
# Only refresh when it needs to. Default expiration is 3600 seconds.
|
453
453
|
async with self._auth_lock:
|
454
454
|
if self._credentials.expired or not self._credentials.token:
|
@@ -477,11 +477,12 @@ class BaseApiClient:
|
|
477
477
|
if http_options:
|
478
478
|
if isinstance(http_options, HttpOptions):
|
479
479
|
patched_http_options = _patch_http_options(
|
480
|
-
self._http_options,
|
480
|
+
self._http_options,
|
481
|
+
http_options,
|
481
482
|
)
|
482
483
|
else:
|
483
484
|
patched_http_options = _patch_http_options(
|
484
|
-
self._http_options, http_options
|
485
|
+
self._http_options, HttpOptions.model_validate(http_options)
|
485
486
|
)
|
486
487
|
else:
|
487
488
|
patched_http_options = self._http_options
|
@@ -500,13 +501,27 @@ class BaseApiClient:
|
|
500
501
|
and not self.api_key
|
501
502
|
):
|
502
503
|
path = f'projects/{self.project}/locations/{self.location}/' + path
|
504
|
+
|
505
|
+
if patched_http_options.api_version is None:
|
506
|
+
versioned_path = f'/{path}'
|
507
|
+
else:
|
508
|
+
versioned_path = f'{patched_http_options.api_version}/{path}'
|
509
|
+
|
510
|
+
if (
|
511
|
+
patched_http_options.base_url is None
|
512
|
+
or not patched_http_options.base_url
|
513
|
+
):
|
514
|
+
raise ValueError('Base URL must be set.')
|
515
|
+
else:
|
516
|
+
base_url = patched_http_options.base_url
|
517
|
+
|
503
518
|
url = _join_url_path(
|
504
|
-
|
505
|
-
|
519
|
+
base_url,
|
520
|
+
versioned_path,
|
506
521
|
)
|
507
522
|
|
508
|
-
timeout_in_seconds: Optional[Union[float, int]] =
|
509
|
-
|
523
|
+
timeout_in_seconds: Optional[Union[float, int]] = (
|
524
|
+
patched_http_options.timeout
|
510
525
|
)
|
511
526
|
if timeout_in_seconds:
|
512
527
|
# HttpOptions.timeout is in milliseconds. But httpx.Client.request()
|
@@ -515,10 +530,12 @@ class BaseApiClient:
|
|
515
530
|
else:
|
516
531
|
timeout_in_seconds = None
|
517
532
|
|
533
|
+
if patched_http_options.headers is None:
|
534
|
+
raise ValueError('Request headers must be set.')
|
518
535
|
return HttpRequest(
|
519
536
|
method=http_method,
|
520
537
|
url=url,
|
521
|
-
headers=patched_http_options
|
538
|
+
headers=patched_http_options.headers,
|
522
539
|
data=request_dict,
|
523
540
|
timeout=timeout_in_seconds,
|
524
541
|
)
|
@@ -530,9 +547,7 @@ class BaseApiClient:
|
|
530
547
|
) -> HttpResponse:
|
531
548
|
data: Optional[Union[str, bytes]] = None
|
532
549
|
if self.vertexai and not self.api_key:
|
533
|
-
http_request.headers['Authorization'] = (
|
534
|
-
f'Bearer {self._access_token()}'
|
535
|
-
)
|
550
|
+
http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
|
536
551
|
if self._credentials and self._credentials.quota_project_id:
|
537
552
|
http_request.headers['x-goog-user-project'] = (
|
538
553
|
self._credentials.quota_project_id
|
@@ -603,7 +618,7 @@ class BaseApiClient:
|
|
603
618
|
httpx_request,
|
604
619
|
stream=stream,
|
605
620
|
)
|
606
|
-
errors.APIError.
|
621
|
+
await errors.APIError.raise_for_async_response(response)
|
607
622
|
return HttpResponse(
|
608
623
|
response.headers, response if stream else [response.text]
|
609
624
|
)
|
@@ -615,16 +630,16 @@ class BaseApiClient:
|
|
615
630
|
content=data,
|
616
631
|
timeout=http_request.timeout,
|
617
632
|
)
|
618
|
-
errors.APIError.
|
633
|
+
await errors.APIError.raise_for_async_response(response)
|
619
634
|
return HttpResponse(
|
620
635
|
response.headers, response if stream else [response.text]
|
621
636
|
)
|
622
637
|
|
623
|
-
def get_read_only_http_options(self) ->
|
624
|
-
copied = HttpOptionsDict()
|
638
|
+
def get_read_only_http_options(self) -> dict[str, Any]:
|
625
639
|
if isinstance(self._http_options, BaseModel):
|
626
|
-
|
627
|
-
|
640
|
+
copied = self._http_options.model_dump()
|
641
|
+
else:
|
642
|
+
copied = self._http_options
|
628
643
|
return copied
|
629
644
|
|
630
645
|
def request(
|
@@ -650,7 +665,7 @@ class BaseApiClient:
|
|
650
665
|
http_method: str,
|
651
666
|
path: str,
|
652
667
|
request_dict: dict[str, object],
|
653
|
-
http_options: Optional[
|
668
|
+
http_options: Optional[HttpOptionsOrDict] = None,
|
654
669
|
):
|
655
670
|
http_request = self._build_request(
|
656
671
|
http_method, path, request_dict, http_options
|
@@ -682,7 +697,7 @@ class BaseApiClient:
|
|
682
697
|
http_method: str,
|
683
698
|
path: str,
|
684
699
|
request_dict: dict[str, object],
|
685
|
-
http_options: Optional[
|
700
|
+
http_options: Optional[HttpOptionsOrDict] = None,
|
686
701
|
):
|
687
702
|
http_request = self._build_request(
|
688
703
|
http_method, path, request_dict, http_options
|
@@ -698,7 +713,7 @@ class BaseApiClient:
|
|
698
713
|
|
699
714
|
def upload_file(
|
700
715
|
self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
|
701
|
-
) ->
|
716
|
+
) -> HttpResponse:
|
702
717
|
"""Transfers a file to the given URL.
|
703
718
|
|
704
719
|
Args:
|
@@ -710,7 +725,7 @@ class BaseApiClient:
|
|
710
725
|
match the size requested in the resumable upload request.
|
711
726
|
|
712
727
|
returns:
|
713
|
-
The
|
728
|
+
The HttpResponse object from the finalize request.
|
714
729
|
"""
|
715
730
|
if isinstance(file_path, io.IOBase):
|
716
731
|
return self._upload_fd(file_path, upload_url, upload_size)
|
@@ -720,7 +735,7 @@ class BaseApiClient:
|
|
720
735
|
|
721
736
|
def _upload_fd(
|
722
737
|
self, file: io.IOBase, upload_url: str, upload_size: int
|
723
|
-
) ->
|
738
|
+
) -> HttpResponse:
|
724
739
|
"""Transfers a file to the given URL.
|
725
740
|
|
726
741
|
Args:
|
@@ -730,7 +745,7 @@ class BaseApiClient:
|
|
730
745
|
match the size requested in the resumable upload request.
|
731
746
|
|
732
747
|
returns:
|
733
|
-
The
|
748
|
+
The HttpResponse object from the finalize request.
|
734
749
|
"""
|
735
750
|
offset = 0
|
736
751
|
# Upload the file in chunks
|
@@ -758,7 +773,7 @@ class BaseApiClient:
|
|
758
773
|
break # upload is complete or it has been interrupted.
|
759
774
|
if upload_size <= offset: # Status is not finalized.
|
760
775
|
raise ValueError(
|
761
|
-
'All content has been uploaded, but the upload status is not'
|
776
|
+
f'All content has been uploaded, but the upload status is not'
|
762
777
|
f' finalized.'
|
763
778
|
)
|
764
779
|
|
@@ -766,7 +781,7 @@ class BaseApiClient:
|
|
766
781
|
raise ValueError(
|
767
782
|
'Failed to upload file: Upload status is not finalized.'
|
768
783
|
)
|
769
|
-
return response.
|
784
|
+
return HttpResponse(response.headers, response_stream=[response.text])
|
770
785
|
|
771
786
|
def download_file(self, path: str, http_options):
|
772
787
|
"""Downloads the file data.
|
@@ -807,7 +822,7 @@ class BaseApiClient:
|
|
807
822
|
file_path: Union[str, io.IOBase],
|
808
823
|
upload_url: str,
|
809
824
|
upload_size: int,
|
810
|
-
) ->
|
825
|
+
) -> HttpResponse:
|
811
826
|
"""Transfers a file asynchronously to the given URL.
|
812
827
|
|
813
828
|
Args:
|
@@ -818,7 +833,7 @@ class BaseApiClient:
|
|
818
833
|
match the size requested in the resumable upload request.
|
819
834
|
|
820
835
|
returns:
|
821
|
-
The
|
836
|
+
The HttpResponse object from the finalize request.
|
822
837
|
"""
|
823
838
|
if isinstance(file_path, io.IOBase):
|
824
839
|
return await self._async_upload_fd(file_path, upload_url, upload_size)
|
@@ -833,7 +848,7 @@ class BaseApiClient:
|
|
833
848
|
file: Union[io.IOBase, anyio.AsyncFile],
|
834
849
|
upload_url: str,
|
835
850
|
upload_size: int,
|
836
|
-
) ->
|
851
|
+
) -> HttpResponse:
|
837
852
|
"""Transfers a file asynchronously to the given URL.
|
838
853
|
|
839
854
|
Args:
|
@@ -843,7 +858,7 @@ class BaseApiClient:
|
|
843
858
|
match the size requested in the resumable upload request.
|
844
859
|
|
845
860
|
returns:
|
846
|
-
The
|
861
|
+
The HttpResponse object from the finalized request.
|
847
862
|
"""
|
848
863
|
offset = 0
|
849
864
|
# Upload the file in chunks
|
@@ -882,7 +897,7 @@ class BaseApiClient:
|
|
882
897
|
raise ValueError(
|
883
898
|
'Failed to upload file: Upload status is not finalized.'
|
884
899
|
)
|
885
|
-
return response.
|
900
|
+
return HttpResponse(response.headers, response_stream=[response.text])
|
886
901
|
|
887
902
|
async def async_download_file(self, path: str, http_options):
|
888
903
|
"""Downloads the file data.
|
@@ -912,7 +927,7 @@ class BaseApiClient:
|
|
912
927
|
content=data,
|
913
928
|
timeout=http_request.timeout,
|
914
929
|
)
|
915
|
-
errors.APIError.
|
930
|
+
await errors.APIError.raise_for_async_response(response)
|
916
931
|
|
917
932
|
return HttpResponse(
|
918
933
|
response.headers, byte_stream=[response.read()]
|
@@ -119,7 +119,8 @@ def _redact_request_body(body: dict[str, object]):
|
|
119
119
|
def redact_http_request(http_request: HttpRequest):
|
120
120
|
http_request.headers = _redact_request_headers(http_request.headers)
|
121
121
|
http_request.url = _redact_request_url(http_request.url)
|
122
|
-
|
122
|
+
if not isinstance(http_request.data, bytes):
|
123
|
+
_redact_request_body(http_request.data)
|
123
124
|
|
124
125
|
|
125
126
|
def _current_file_path_and_line():
|
@@ -321,6 +322,8 @@ class ReplayApiClient(BaseApiClient):
|
|
321
322
|
raise ValueError(
|
322
323
|
'Unsupported http_response type: ' + str(type(http_response))
|
323
324
|
)
|
325
|
+
if self.replay_session is None:
|
326
|
+
raise ValueError('No replay session found.')
|
324
327
|
self.replay_session.interactions.append(
|
325
328
|
ReplayInteraction(request=request, response=response)
|
326
329
|
)
|
@@ -342,7 +345,8 @@ class ReplayApiClient(BaseApiClient):
|
|
342
345
|
request_data_copy = copy.deepcopy(http_request.data)
|
343
346
|
# Both the request and recorded request must be redacted before comparing
|
344
347
|
# so that the comparison is fair.
|
345
|
-
|
348
|
+
if not isinstance(request_data_copy, bytes):
|
349
|
+
_redact_request_body(request_data_copy)
|
346
350
|
|
347
351
|
actual_request_body = [request_data_copy]
|
348
352
|
expected_request_body = interaction.request.body_segments
|
@@ -352,9 +356,11 @@ class ReplayApiClient(BaseApiClient):
|
|
352
356
|
f'Expected: {expected_request_body}'
|
353
357
|
)
|
354
358
|
|
355
|
-
def _build_response_from_replay(self, http_request: HttpRequest):
|
359
|
+
def _build_response_from_replay(self, http_request: HttpRequest) -> HttpResponse:
|
356
360
|
redact_http_request(http_request)
|
357
361
|
|
362
|
+
if self.replay_session is None:
|
363
|
+
raise ValueError('No replay session found.')
|
358
364
|
interaction = self.replay_session.interactions[self._replay_index]
|
359
365
|
# Replay is on the right side of the assert so the diff makes more sense.
|
360
366
|
self._match_request(http_request, interaction)
|
@@ -373,6 +379,8 @@ class ReplayApiClient(BaseApiClient):
|
|
373
379
|
def _verify_response(self, response_model: BaseModel):
|
374
380
|
if self._mode == 'api':
|
375
381
|
return
|
382
|
+
if not self.replay_session:
|
383
|
+
raise ValueError('No replay session found.')
|
376
384
|
# replay_index is advanced in _build_response_from_replay, so we need to -1.
|
377
385
|
interaction = self.replay_session.interactions[self._replay_index - 1]
|
378
386
|
if self._should_update_replay():
|
@@ -453,7 +461,7 @@ class ReplayApiClient(BaseApiClient):
|
|
453
461
|
else:
|
454
462
|
return self._build_response_from_replay(http_request)
|
455
463
|
|
456
|
-
def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int):
|
464
|
+
def upload_file(self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int) -> HttpResponse:
|
457
465
|
if isinstance(file_path, io.IOBase):
|
458
466
|
offset = file_path.tell()
|
459
467
|
content = file_path.read()
|
@@ -474,21 +482,21 @@ class ReplayApiClient(BaseApiClient):
|
|
474
482
|
result = super().upload_file(file_path, upload_url, upload_size)
|
475
483
|
except HTTPError as e:
|
476
484
|
result = HttpResponse(
|
477
|
-
e.response.headers, [json.dumps({'reason': e.response.reason})]
|
485
|
+
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
|
478
486
|
)
|
479
487
|
result.status_code = e.response.status_code
|
480
488
|
raise e
|
481
|
-
self._record_interaction(request,
|
489
|
+
self._record_interaction(request, result)
|
482
490
|
return result
|
483
491
|
else:
|
484
|
-
return self._build_response_from_replay(request)
|
492
|
+
return self._build_response_from_replay(request)
|
485
493
|
|
486
494
|
async def async_upload_file(
|
487
495
|
self,
|
488
496
|
file_path: Union[str, io.IOBase],
|
489
497
|
upload_url: str,
|
490
498
|
upload_size: int,
|
491
|
-
) ->
|
499
|
+
) -> HttpResponse:
|
492
500
|
if isinstance(file_path, io.IOBase):
|
493
501
|
offset = file_path.tell()
|
494
502
|
content = file_path.read()
|
@@ -504,21 +512,21 @@ class ReplayApiClient(BaseApiClient):
|
|
504
512
|
method='POST', url='', data={'file_path': file_path}, headers={}
|
505
513
|
)
|
506
514
|
if self._should_call_api():
|
507
|
-
result:
|
515
|
+
result: HttpResponse
|
508
516
|
try:
|
509
517
|
result = await super().async_upload_file(
|
510
518
|
file_path, upload_url, upload_size
|
511
519
|
)
|
512
520
|
except HTTPError as e:
|
513
521
|
result = HttpResponse(
|
514
|
-
e.response.headers, [json.dumps({'reason': e.response.reason})]
|
522
|
+
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
|
515
523
|
)
|
516
524
|
result.status_code = e.response.status_code
|
517
525
|
raise e
|
518
|
-
self._record_interaction(request,
|
526
|
+
self._record_interaction(request, result)
|
519
527
|
return result
|
520
528
|
else:
|
521
|
-
return self._build_response_from_replay(request)
|
529
|
+
return self._build_response_from_replay(request)
|
522
530
|
|
523
531
|
def download_file(self, path: str, http_options: HttpOptions):
|
524
532
|
self._initialize_replay_session_if_not_loaded()
|
@@ -530,7 +538,7 @@ class ReplayApiClient(BaseApiClient):
|
|
530
538
|
result = super().download_file(path, http_options)
|
531
539
|
except HTTPError as e:
|
532
540
|
result = HttpResponse(
|
533
|
-
e.response.headers, [json.dumps({'reason': e.response.reason})]
|
541
|
+
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
|
534
542
|
)
|
535
543
|
result.status_code = e.response.status_code
|
536
544
|
raise e
|
@@ -549,7 +557,7 @@ class ReplayApiClient(BaseApiClient):
|
|
549
557
|
result = await super().async_download_file(path, http_options)
|
550
558
|
except HTTPError as e:
|
551
559
|
result = HttpResponse(
|
552
|
-
e.response.headers, [json.dumps({'reason': e.response.reason})]
|
560
|
+
dict(e.response.headers), [json.dumps({'reason': e.response.reason})]
|
553
561
|
)
|
554
562
|
result.status_code = e.response.status_code
|
555
563
|
raise e
|
@@ -181,17 +181,26 @@ def t_models_url(
|
|
181
181
|
|
182
182
|
def t_extract_models(
|
183
183
|
api_client: _api_client.BaseApiClient,
|
184
|
-
response: dict[str,
|
185
|
-
) ->
|
184
|
+
response: dict[str, Any],
|
185
|
+
) -> list[dict[str, Any]]:
|
186
186
|
if not response:
|
187
187
|
return []
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
return
|
192
|
-
|
193
|
-
|
194
|
-
|
188
|
+
|
189
|
+
models: Optional[list[dict[str, Any]]] = response.get('models')
|
190
|
+
if models is not None:
|
191
|
+
return models
|
192
|
+
|
193
|
+
tuned_models: Optional[list[dict[str, Any]]] = response.get('tunedModels')
|
194
|
+
if tuned_models is not None:
|
195
|
+
return tuned_models
|
196
|
+
|
197
|
+
publisher_models: Optional[list[dict[str, Any]]] = response.get(
|
198
|
+
'publisherModels'
|
199
|
+
)
|
200
|
+
if publisher_models is not None:
|
201
|
+
return publisher_models
|
202
|
+
|
203
|
+
if (
|
195
204
|
response.get('httpHeaders') is not None
|
196
205
|
and response.get('jsonPayload') is None
|
197
206
|
):
|
@@ -526,7 +535,6 @@ def process_schema(
|
|
526
535
|
):
|
527
536
|
"""Updates the schema and each sub-schema inplace to be API-compatible.
|
528
537
|
|
529
|
-
- Removes the `title` field from the schema if the client is not vertexai.
|
530
538
|
- Inlines the $defs.
|
531
539
|
|
532
540
|
Example of a schema before and after (with mldev):
|
@@ -570,21 +578,22 @@ def process_schema(
|
|
570
578
|
'items': {
|
571
579
|
'properties': {
|
572
580
|
'continent': {
|
573
|
-
|
581
|
+
'title': 'Continent',
|
582
|
+
'type': 'string'
|
574
583
|
},
|
575
584
|
'gdp': {
|
576
|
-
|
585
|
+
'title': 'Gdp',
|
586
|
+
'type': 'integer'
|
577
587
|
},
|
578
588
|
}
|
579
589
|
'required':['continent', 'gdp'],
|
590
|
+
'title': 'CountryInfo',
|
580
591
|
'type': 'object'
|
581
592
|
},
|
582
593
|
'type': 'array'
|
583
594
|
}
|
584
595
|
"""
|
585
596
|
if not client.vertexai:
|
586
|
-
schema.pop('title', None)
|
587
|
-
|
588
597
|
if schema.get('default') is not None:
|
589
598
|
raise ValueError(
|
590
599
|
'Default value is not supported in the response schema for the Gemini'
|