google-genai 0.3.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_genai-0.3.0/google_genai.egg-info → google_genai-0.4.0}/PKG-INFO +57 -17
- {google_genai-0.3.0 → google_genai-0.4.0}/README.md +55 -15
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/__init__.py +2 -1
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_api_client.py +85 -36
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_automatic_function_calling_util.py +14 -14
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_replay_api_client.py +22 -28
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/batches.py +16 -16
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/caches.py +18 -18
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/chats.py +2 -2
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/client.py +6 -3
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/files.py +22 -22
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/live.py +28 -5
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/models.py +97 -77
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/tunings.py +17 -17
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/types.py +150 -80
- google_genai-0.4.0/google/genai/version.py +16 -0
- {google_genai-0.3.0 → google_genai-0.4.0/google_genai.egg-info}/PKG-INFO +57 -17
- {google_genai-0.3.0 → google_genai-0.4.0}/google_genai.egg-info/SOURCES.txt +1 -0
- {google_genai-0.3.0 → google_genai-0.4.0}/pyproject.toml +1 -1
- {google_genai-0.3.0 → google_genai-0.4.0}/LICENSE +0 -0
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_common.py +0 -0
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_extra_utils.py +0 -0
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_test_api_client.py +0 -0
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_transformers.py +0 -0
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/errors.py +0 -0
- {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/pagers.py +0 -0
- {google_genai-0.3.0 → google_genai-0.4.0}/google_genai.egg-info/dependency_links.txt +0 -0
- {google_genai-0.3.0 → google_genai-0.4.0}/google_genai.egg-info/requires.txt +0 -0
- {google_genai-0.3.0 → google_genai-0.4.0}/google_genai.egg-info/top_level.txt +0 -0
- {google_genai-0.3.0 → google_genai-0.4.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: google-genai
|
3
|
-
Version: 0.3.0
|
3
|
+
Version: 0.4.0
|
4
4
|
Summary: GenAI Python SDK
|
5
5
|
Author-email: Google LLC <googleapis-packages@google.com>
|
6
6
|
License: Apache-2.0
|
@@ -436,10 +436,10 @@ response1 = client.models.generate_image(
|
|
436
436
|
model='imagen-3.0-generate-001',
|
437
437
|
prompt='An umbrella in the foreground, and a rainy night sky in the background',
|
438
438
|
config=types.GenerateImageConfig(
|
439
|
-
negative_prompt=
|
439
|
+
negative_prompt= 'human',
|
440
440
|
number_of_images= 1,
|
441
441
|
include_rai_reason= True,
|
442
|
-
output_mime_type=
|
442
|
+
output_mime_type= 'image/jpeg'
|
443
443
|
)
|
444
444
|
)
|
445
445
|
response1.generated_images[0].image.show()
|
@@ -454,7 +454,11 @@ Upscale image is not supported in Google AI.
|
|
454
454
|
response2 = client.models.upscale_image(
|
455
455
|
model='imagen-3.0-generate-001',
|
456
456
|
image=response1.generated_images[0].image,
|
457
|
-
|
457
|
+
upscale_factor='x2',
|
458
|
+
config=types.UpscaleImageConfig(
|
459
|
+
include_rai_reason= True,
|
460
|
+
output_mime_type= 'image/jpeg',
|
461
|
+
),
|
458
462
|
)
|
459
463
|
response2.generated_images[0].image.show()
|
460
464
|
```
|
@@ -497,6 +501,42 @@ response3 = client.models.edit_image(
|
|
497
501
|
response3.generated_images[0].image.show()
|
498
502
|
```
|
499
503
|
|
504
|
+
## Chats
|
505
|
+
|
506
|
+
Create a chat session to start a multi-turn conversation with the model.
|
507
|
+
|
508
|
+
### Send Message
|
509
|
+
|
510
|
+
```python
|
511
|
+
chat = client.chats.create(model='gemini-2.0-flash-exp')
|
512
|
+
response = chat.send_message('tell me a story')
|
513
|
+
print(response.text)
|
514
|
+
```
|
515
|
+
|
516
|
+
### Streaming
|
517
|
+
|
518
|
+
```python
|
519
|
+
chat = client.chats.create(model='gemini-2.0-flash-exp')
|
520
|
+
for chunk in chat.send_message_stream('tell me a story'):
|
521
|
+
print(chunk.text)
|
522
|
+
```
|
523
|
+
|
524
|
+
### Async
|
525
|
+
|
526
|
+
```python
|
527
|
+
chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
|
528
|
+
response = await chat.send_message('tell me a story')
|
529
|
+
print(response.text)
|
530
|
+
```
|
531
|
+
|
532
|
+
### Async Streaming
|
533
|
+
|
534
|
+
```python
|
535
|
+
chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
|
536
|
+
async for chunk in chat.send_message_stream('tell me a story'):
|
537
|
+
print(chunk.text)
|
538
|
+
```
|
539
|
+
|
500
540
|
## Files (Only Google AI)
|
501
541
|
|
502
542
|
``` python
|
@@ -539,19 +579,19 @@ else:
|
|
539
579
|
|
540
580
|
cached_content = client.caches.create(
|
541
581
|
model='gemini-1.5-pro-002',
|
542
|
-
contents=[
|
543
|
-
types.Content(
|
544
|
-
role='user',
|
545
|
-
parts=[
|
546
|
-
types.Part.from_uri(
|
547
|
-
file_uri=file_uris[0],
|
548
|
-
mime_type='application/pdf'),
|
549
|
-
types.Part.from_uri(
|
550
|
-
file_uri=file_uris[1],
|
551
|
-
mime_type='application/pdf',)])
|
552
|
-
],
|
553
|
-
system_instruction='What is the sum of the two pdfs?',
|
554
582
|
config=types.CreateCachedContentConfig(
|
583
|
+
contents=[
|
584
|
+
types.Content(
|
585
|
+
role='user',
|
586
|
+
parts=[
|
587
|
+
types.Part.from_uri(
|
588
|
+
file_uri=file_uris[0],
|
589
|
+
mime_type='application/pdf'),
|
590
|
+
types.Part.from_uri(
|
591
|
+
file_uri=file_uris[1],
|
592
|
+
mime_type='application/pdf',)])
|
593
|
+
],
|
594
|
+
system_instruction='What is the sum of the two pdfs?',
|
555
595
|
display_name='test cache',
|
556
596
|
ttl='3600s',
|
557
597
|
),
|
@@ -408,10 +408,10 @@ response1 = client.models.generate_image(
|
|
408
408
|
model='imagen-3.0-generate-001',
|
409
409
|
prompt='An umbrella in the foreground, and a rainy night sky in the background',
|
410
410
|
config=types.GenerateImageConfig(
|
411
|
-
negative_prompt=
|
411
|
+
negative_prompt= 'human',
|
412
412
|
number_of_images= 1,
|
413
413
|
include_rai_reason= True,
|
414
|
-
output_mime_type=
|
414
|
+
output_mime_type= 'image/jpeg'
|
415
415
|
)
|
416
416
|
)
|
417
417
|
response1.generated_images[0].image.show()
|
@@ -426,7 +426,11 @@ Upscale image is not supported in Google AI.
|
|
426
426
|
response2 = client.models.upscale_image(
|
427
427
|
model='imagen-3.0-generate-001',
|
428
428
|
image=response1.generated_images[0].image,
|
429
|
-
|
429
|
+
upscale_factor='x2',
|
430
|
+
config=types.UpscaleImageConfig(
|
431
|
+
include_rai_reason= True,
|
432
|
+
output_mime_type= 'image/jpeg',
|
433
|
+
),
|
430
434
|
)
|
431
435
|
response2.generated_images[0].image.show()
|
432
436
|
```
|
@@ -469,6 +473,42 @@ response3 = client.models.edit_image(
|
|
469
473
|
response3.generated_images[0].image.show()
|
470
474
|
```
|
471
475
|
|
476
|
+
## Chats
|
477
|
+
|
478
|
+
Create a chat session to start a multi-turn conversation with the model.
|
479
|
+
|
480
|
+
### Send Message
|
481
|
+
|
482
|
+
```python
|
483
|
+
chat = client.chats.create(model='gemini-2.0-flash-exp')
|
484
|
+
response = chat.send_message('tell me a story')
|
485
|
+
print(response.text)
|
486
|
+
```
|
487
|
+
|
488
|
+
### Streaming
|
489
|
+
|
490
|
+
```python
|
491
|
+
chat = client.chats.create(model='gemini-2.0-flash-exp')
|
492
|
+
for chunk in chat.send_message_stream('tell me a story'):
|
493
|
+
print(chunk.text)
|
494
|
+
```
|
495
|
+
|
496
|
+
### Async
|
497
|
+
|
498
|
+
```python
|
499
|
+
chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
|
500
|
+
response = await chat.send_message('tell me a story')
|
501
|
+
print(response.text)
|
502
|
+
```
|
503
|
+
|
504
|
+
### Async Streaming
|
505
|
+
|
506
|
+
```python
|
507
|
+
chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
|
508
|
+
async for chunk in chat.send_message_stream('tell me a story'):
|
509
|
+
print(chunk.text)
|
510
|
+
```
|
511
|
+
|
472
512
|
## Files (Only Google AI)
|
473
513
|
|
474
514
|
``` python
|
@@ -511,19 +551,19 @@ else:
|
|
511
551
|
|
512
552
|
cached_content = client.caches.create(
|
513
553
|
model='gemini-1.5-pro-002',
|
514
|
-
contents=[
|
515
|
-
types.Content(
|
516
|
-
role='user',
|
517
|
-
parts=[
|
518
|
-
types.Part.from_uri(
|
519
|
-
file_uri=file_uris[0],
|
520
|
-
mime_type='application/pdf'),
|
521
|
-
types.Part.from_uri(
|
522
|
-
file_uri=file_uris[1],
|
523
|
-
mime_type='application/pdf',)])
|
524
|
-
],
|
525
|
-
system_instruction='What is the sum of the two pdfs?',
|
526
554
|
config=types.CreateCachedContentConfig(
|
555
|
+
contents=[
|
556
|
+
types.Content(
|
557
|
+
role='user',
|
558
|
+
parts=[
|
559
|
+
types.Part.from_uri(
|
560
|
+
file_uri=file_uris[0],
|
561
|
+
mime_type='application/pdf'),
|
562
|
+
types.Part.from_uri(
|
563
|
+
file_uri=file_uris[1],
|
564
|
+
mime_type='application/pdf',)])
|
565
|
+
],
|
566
|
+
system_instruction='What is the sum of the two pdfs?',
|
527
567
|
display_name='test cache',
|
528
568
|
ttl='3600s',
|
529
569
|
),
|
@@ -23,35 +23,66 @@ import datetime
|
|
23
23
|
import json
|
24
24
|
import os
|
25
25
|
import sys
|
26
|
-
from typing import Any, Optional, TypedDict, Union
|
26
|
+
from typing import Any, Optional, Tuple, TypedDict, Union
|
27
27
|
from urllib.parse import urlparse, urlunparse
|
28
28
|
|
29
29
|
import google.auth
|
30
30
|
import google.auth.credentials
|
31
31
|
from google.auth.transport.requests import AuthorizedSession
|
32
|
-
from pydantic import BaseModel
|
32
|
+
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
33
33
|
import requests
|
34
34
|
|
35
35
|
from . import errors
|
36
|
+
from . import version
|
36
37
|
|
37
38
|
|
38
|
-
class HttpOptions(
|
39
|
+
class HttpOptions(BaseModel):
|
40
|
+
"""HTTP options for the api client."""
|
41
|
+
model_config = ConfigDict(extra='forbid')
|
42
|
+
|
43
|
+
base_url: Optional[str] = Field(
|
44
|
+
default=None,
|
45
|
+
description="""The base URL for the AI platform service endpoint.""",
|
46
|
+
)
|
47
|
+
api_version: Optional[str] = Field(
|
48
|
+
default=None,
|
49
|
+
description="""Specifies the version of the API to use.""",
|
50
|
+
)
|
51
|
+
headers: Optional[dict[str, str]] = Field(
|
52
|
+
default=None,
|
53
|
+
description="""Additional HTTP headers to be sent with the request.""",
|
54
|
+
)
|
55
|
+
response_payload: Optional[dict] = Field(
|
56
|
+
default=None,
|
57
|
+
description="""If set, the response payload will be returned int the supplied dict.""",
|
58
|
+
)
|
59
|
+
timeout: Optional[Union[float, Tuple[float, float]]] = Field(
|
60
|
+
default=None,
|
61
|
+
description="""Timeout for the request in seconds.""",
|
62
|
+
)
|
63
|
+
|
64
|
+
|
65
|
+
class HttpOptionsDict(TypedDict):
|
39
66
|
"""HTTP options for the api client."""
|
40
67
|
|
41
|
-
base_url: str = None
|
68
|
+
base_url: Optional[str] = None
|
42
69
|
"""The base URL for the AI platform service endpoint."""
|
43
|
-
api_version: str = None
|
70
|
+
api_version: Optional[str] = None
|
44
71
|
"""Specifies the version of the API to use."""
|
45
|
-
headers: dict[str,
|
72
|
+
headers: Optional[dict[str, Union[str, list[str]]]] = None
|
46
73
|
"""Additional HTTP headers to be sent with the request."""
|
47
|
-
response_payload: dict = None
|
74
|
+
response_payload: Optional[dict] = None
|
48
75
|
"""If set, the response payload will be returned in the supplied dict."""
|
76
|
+
timeout: Optional[Union[float, Tuple[float, float]]] = None
|
77
|
+
"""Timeout for the request in seconds."""
|
78
|
+
|
79
|
+
|
80
|
+
HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
|
49
81
|
|
50
82
|
|
51
83
|
def _append_library_version_headers(headers: dict[str, str]) -> None:
|
52
84
|
"""Appends the telemetry header to the headers dict."""
|
53
|
-
|
54
|
-
library_label = f'google-genai-sdk/0.3.0'
|
85
|
+
library_label = f'google-genai-sdk/{version.__version__}'
|
55
86
|
language_label = 'gl-python/' + sys.version.split()[0]
|
56
87
|
version_header_value = f'{library_label} {language_label}'
|
57
88
|
if (
|
@@ -71,20 +102,24 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
|
|
71
102
|
|
72
103
|
|
73
104
|
def _patch_http_options(
|
74
|
-
options:
|
75
|
-
) ->
|
105
|
+
options: HttpOptionsDict, patch_options: HttpOptionsDict
|
106
|
+
) -> HttpOptionsDict:
|
76
107
|
# use shallow copy so we don't override the original objects.
|
77
|
-
copy_option =
|
108
|
+
copy_option = HttpOptionsDict()
|
78
109
|
copy_option.update(options)
|
79
|
-
for
|
110
|
+
for patch_key, patch_value in patch_options.items():
|
80
111
|
# if both are dicts, update the copy.
|
81
112
|
# This is to handle cases like merging headers.
|
82
|
-
if isinstance(
|
83
|
-
|
84
|
-
|
85
|
-
copy_option[
|
86
|
-
|
87
|
-
|
113
|
+
if isinstance(patch_value, dict) and isinstance(
|
114
|
+
copy_option.get(patch_key, None), dict
|
115
|
+
):
|
116
|
+
copy_option[patch_key] = {}
|
117
|
+
copy_option[patch_key].update(
|
118
|
+
options[patch_key]
|
119
|
+
) # shallow copy from original options.
|
120
|
+
copy_option[patch_key].update(patch_value)
|
121
|
+
elif patch_value is not None: # Accept empty values.
|
122
|
+
copy_option[patch_key] = patch_value
|
88
123
|
_append_library_version_headers(copy_option['headers'])
|
89
124
|
return copy_option
|
90
125
|
|
@@ -98,10 +133,11 @@ def _join_url_path(base_url: str, path: str) -> str:
|
|
98
133
|
|
99
134
|
@dataclass
|
100
135
|
class HttpRequest:
|
101
|
-
headers: dict[str, str]
|
136
|
+
headers: dict[str, Union[str, list[str]]]
|
102
137
|
url: str
|
103
138
|
method: str
|
104
139
|
data: Union[dict[str, object], bytes]
|
140
|
+
timeout: Optional[Union[float, Tuple[float, float]]] = None
|
105
141
|
|
106
142
|
|
107
143
|
class HttpResponse:
|
@@ -147,7 +183,7 @@ class ApiClient:
|
|
147
183
|
credentials: google.auth.credentials.Credentials = None,
|
148
184
|
project: Union[str, None] = None,
|
149
185
|
location: Union[str, None] = None,
|
150
|
-
http_options:
|
186
|
+
http_options: HttpOptionsOrDict = None,
|
151
187
|
):
|
152
188
|
self.vertexai = vertexai
|
153
189
|
if self.vertexai is None:
|
@@ -163,11 +199,20 @@ class ApiClient:
|
|
163
199
|
'Project/location and API key are mutually exclusive in the client initializer.'
|
164
200
|
)
|
165
201
|
|
202
|
+
# Validate http_options if a dict is provided.
|
203
|
+
if isinstance(http_options, dict):
|
204
|
+
try:
|
205
|
+
HttpOptions.model_validate(http_options)
|
206
|
+
except ValidationError as e:
|
207
|
+
raise ValueError(f'Invalid http_options: {e}')
|
208
|
+
elif(isinstance(http_options, HttpOptions)):
|
209
|
+
http_options = http_options.model_dump()
|
210
|
+
|
166
211
|
self.api_key: Optional[str] = None
|
167
212
|
self.project = project or os.environ.get('GOOGLE_CLOUD_PROJECT', None)
|
168
213
|
self.location = location or os.environ.get('GOOGLE_CLOUD_LOCATION', None)
|
169
214
|
self._credentials = credentials
|
170
|
-
self._http_options =
|
215
|
+
self._http_options = HttpOptionsDict()
|
171
216
|
|
172
217
|
if self.vertexai:
|
173
218
|
if not self.project:
|
@@ -208,7 +253,7 @@ class ApiClient:
|
|
208
253
|
http_method: str,
|
209
254
|
path: str,
|
210
255
|
request_dict: dict[str, object],
|
211
|
-
http_options:
|
256
|
+
http_options: HttpOptionsDict = None,
|
212
257
|
) -> HttpRequest:
|
213
258
|
# Remove all special dict keys such as _url and _query.
|
214
259
|
keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
|
@@ -232,6 +277,7 @@ class ApiClient:
|
|
232
277
|
url=url,
|
233
278
|
headers=patched_http_options['headers'],
|
234
279
|
data=request_dict,
|
280
|
+
timeout=patched_http_options.get('timeout', None),
|
235
281
|
)
|
236
282
|
|
237
283
|
def _request(
|
@@ -250,10 +296,10 @@ class ApiClient:
|
|
250
296
|
http_request.method.upper(),
|
251
297
|
http_request.url,
|
252
298
|
headers=http_request.headers,
|
253
|
-
data=json.dumps(http_request.data, cls=RequestJsonEncoder)
|
254
|
-
|
255
|
-
|
256
|
-
timeout=
|
299
|
+
data=json.dumps(http_request.data, cls=RequestJsonEncoder)
|
300
|
+
if http_request.data
|
301
|
+
else None,
|
302
|
+
timeout=http_request.timeout,
|
257
303
|
)
|
258
304
|
errors.APIError.raise_for_response(response)
|
259
305
|
return HttpResponse(
|
@@ -275,13 +321,14 @@ class ApiClient:
|
|
275
321
|
data = http_request.data
|
276
322
|
|
277
323
|
http_session = requests.Session()
|
278
|
-
|
324
|
+
response = http_session.request(
|
279
325
|
method=http_request.method,
|
280
326
|
url=http_request.url,
|
281
327
|
headers=http_request.headers,
|
282
328
|
data=data,
|
283
|
-
|
284
|
-
|
329
|
+
timeout=http_request.timeout,
|
330
|
+
stream=stream,
|
331
|
+
)
|
285
332
|
errors.APIError.raise_for_response(response)
|
286
333
|
return HttpResponse(
|
287
334
|
response.headers, response if stream else [response.text]
|
@@ -307,8 +354,10 @@ class ApiClient:
|
|
307
354
|
stream=stream,
|
308
355
|
)
|
309
356
|
|
310
|
-
def get_read_only_http_options(self) ->
|
311
|
-
copied =
|
357
|
+
def get_read_only_http_options(self) -> HttpOptionsDict:
|
358
|
+
copied = HttpOptionsDict()
|
359
|
+
if isinstance(self._http_options, BaseModel):
|
360
|
+
self._http_options = self._http_options.model_dump()
|
312
361
|
copied.update(self._http_options)
|
313
362
|
return copied
|
314
363
|
|
@@ -317,7 +366,7 @@ class ApiClient:
|
|
317
366
|
http_method: str,
|
318
367
|
path: str,
|
319
368
|
request_dict: dict[str, object],
|
320
|
-
http_options:
|
369
|
+
http_options: HttpOptionsDict = None,
|
321
370
|
):
|
322
371
|
http_request = self._build_request(
|
323
372
|
http_method, path, request_dict, http_options
|
@@ -332,7 +381,7 @@ class ApiClient:
|
|
332
381
|
http_method: str,
|
333
382
|
path: str,
|
334
383
|
request_dict: dict[str, object],
|
335
|
-
http_options:
|
384
|
+
http_options: HttpOptionsDict = None,
|
336
385
|
):
|
337
386
|
http_request = self._build_request(
|
338
387
|
http_method, path, request_dict, http_options
|
@@ -349,7 +398,7 @@ class ApiClient:
|
|
349
398
|
http_method: str,
|
350
399
|
path: str,
|
351
400
|
request_dict: dict[str, object],
|
352
|
-
http_options:
|
401
|
+
http_options: HttpOptionsDict = None,
|
353
402
|
) -> dict[str, object]:
|
354
403
|
http_request = self._build_request(
|
355
404
|
http_method, path, request_dict, http_options
|
@@ -365,7 +414,7 @@ class ApiClient:
|
|
365
414
|
http_method: str,
|
366
415
|
path: str,
|
367
416
|
request_dict: dict[str, object],
|
368
|
-
http_options:
|
417
|
+
http_options: HttpOptionsDict = None,
|
369
418
|
):
|
370
419
|
http_request = self._build_request(
|
371
420
|
http_method, path, request_dict, http_options
|
@@ -58,8 +58,8 @@ def _raise_for_nullable_if_mldev(schema: types.Schema):
|
|
58
58
|
)
|
59
59
|
|
60
60
|
|
61
|
-
def _raise_if_schema_unsupported(
|
62
|
-
if not
|
61
|
+
def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
|
62
|
+
if not variant == 'VERTEX_AI':
|
63
63
|
_raise_for_any_of_if_mldev(schema)
|
64
64
|
_raise_for_default_if_mldev(schema)
|
65
65
|
_raise_for_nullable_if_mldev(schema)
|
@@ -112,7 +112,7 @@ def _is_default_value_compatible(
|
|
112
112
|
|
113
113
|
|
114
114
|
def _parse_schema_from_parameter(
|
115
|
-
|
115
|
+
variant: str, param: inspect.Parameter, func_name: str
|
116
116
|
) -> types.Schema:
|
117
117
|
"""parse schema from parameter.
|
118
118
|
|
@@ -130,7 +130,7 @@ def _parse_schema_from_parameter(
|
|
130
130
|
raise ValueError(default_value_error_msg)
|
131
131
|
schema.default = param.default
|
132
132
|
schema.type = _py_builtin_type_to_schema_type[param.annotation]
|
133
|
-
_raise_if_schema_unsupported(
|
133
|
+
_raise_if_schema_unsupported(variant, schema)
|
134
134
|
return schema
|
135
135
|
if (
|
136
136
|
isinstance(param.annotation, typing_types.UnionType)
|
@@ -149,7 +149,7 @@ def _parse_schema_from_parameter(
|
|
149
149
|
schema.nullable = True
|
150
150
|
continue
|
151
151
|
schema_in_any_of = _parse_schema_from_parameter(
|
152
|
-
|
152
|
+
variant,
|
153
153
|
inspect.Parameter(
|
154
154
|
'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
|
155
155
|
),
|
@@ -171,7 +171,7 @@ def _parse_schema_from_parameter(
|
|
171
171
|
if not _is_default_value_compatible(param.default, param.annotation):
|
172
172
|
raise ValueError(default_value_error_msg)
|
173
173
|
schema.default = param.default
|
174
|
-
_raise_if_schema_unsupported(
|
174
|
+
_raise_if_schema_unsupported(variant, schema)
|
175
175
|
return schema
|
176
176
|
if isinstance(param.annotation, _GenericAlias) or isinstance(
|
177
177
|
param.annotation, typing_types.GenericAlias
|
@@ -184,7 +184,7 @@ def _parse_schema_from_parameter(
|
|
184
184
|
if not _is_default_value_compatible(param.default, param.annotation):
|
185
185
|
raise ValueError(default_value_error_msg)
|
186
186
|
schema.default = param.default
|
187
|
-
_raise_if_schema_unsupported(
|
187
|
+
_raise_if_schema_unsupported(variant, schema)
|
188
188
|
return schema
|
189
189
|
if origin is Literal:
|
190
190
|
if not all(isinstance(arg, str) for arg in args):
|
@@ -197,12 +197,12 @@ def _parse_schema_from_parameter(
|
|
197
197
|
if not _is_default_value_compatible(param.default, param.annotation):
|
198
198
|
raise ValueError(default_value_error_msg)
|
199
199
|
schema.default = param.default
|
200
|
-
_raise_if_schema_unsupported(
|
200
|
+
_raise_if_schema_unsupported(variant, schema)
|
201
201
|
return schema
|
202
202
|
if origin is list:
|
203
203
|
schema.type = 'ARRAY'
|
204
204
|
schema.items = _parse_schema_from_parameter(
|
205
|
-
|
205
|
+
variant,
|
206
206
|
inspect.Parameter(
|
207
207
|
'item',
|
208
208
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -214,7 +214,7 @@ def _parse_schema_from_parameter(
|
|
214
214
|
if not _is_default_value_compatible(param.default, param.annotation):
|
215
215
|
raise ValueError(default_value_error_msg)
|
216
216
|
schema.default = param.default
|
217
|
-
_raise_if_schema_unsupported(
|
217
|
+
_raise_if_schema_unsupported(variant, schema)
|
218
218
|
return schema
|
219
219
|
if origin is Union:
|
220
220
|
schema.any_of = []
|
@@ -225,7 +225,7 @@ def _parse_schema_from_parameter(
|
|
225
225
|
schema.nullable = True
|
226
226
|
continue
|
227
227
|
schema_in_any_of = _parse_schema_from_parameter(
|
228
|
-
|
228
|
+
variant,
|
229
229
|
inspect.Parameter(
|
230
230
|
'item',
|
231
231
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -249,7 +249,7 @@ def _parse_schema_from_parameter(
|
|
249
249
|
if not _is_default_value_compatible(param.default, param.annotation):
|
250
250
|
raise ValueError(default_value_error_msg)
|
251
251
|
schema.default = param.default
|
252
|
-
_raise_if_schema_unsupported(
|
252
|
+
_raise_if_schema_unsupported(variant, schema)
|
253
253
|
return schema
|
254
254
|
# all other generic alias will be invoked in raise branch
|
255
255
|
if (
|
@@ -266,7 +266,7 @@ def _parse_schema_from_parameter(
|
|
266
266
|
schema.properties = {}
|
267
267
|
for field_name, field_info in param.annotation.model_fields.items():
|
268
268
|
schema.properties[field_name] = _parse_schema_from_parameter(
|
269
|
-
|
269
|
+
variant,
|
270
270
|
inspect.Parameter(
|
271
271
|
field_name,
|
272
272
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -274,7 +274,7 @@ def _parse_schema_from_parameter(
|
|
274
274
|
),
|
275
275
|
func_name,
|
276
276
|
)
|
277
|
-
_raise_if_schema_unsupported(
|
277
|
+
_raise_if_schema_unsupported(variant, schema)
|
278
278
|
return schema
|
279
279
|
raise ValueError(
|
280
280
|
f'Failed to parse the parameter {param} of function {func_name} for'
|
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
"""Replay API client."""
|
17
17
|
|
18
|
+
import base64
|
18
19
|
import copy
|
19
20
|
import inspect
|
20
21
|
import json
|
@@ -105,28 +106,6 @@ def redact_http_request(http_request: HttpRequest):
|
|
105
106
|
_redact_request_body(http_request.data)
|
106
107
|
|
107
108
|
|
108
|
-
def process_bytes_fields(data: dict[str, object]):
|
109
|
-
"""Converts bytes fields to strings.
|
110
|
-
|
111
|
-
This function doesn't modify the content of data dict.
|
112
|
-
"""
|
113
|
-
if not isinstance(data, dict):
|
114
|
-
return data
|
115
|
-
for key, value in data.items():
|
116
|
-
if isinstance(value, bytes):
|
117
|
-
data[key] = value.decode()
|
118
|
-
elif isinstance(value, dict):
|
119
|
-
process_bytes_fields(value)
|
120
|
-
elif isinstance(value, list):
|
121
|
-
if all(isinstance(v, bytes) for v in value):
|
122
|
-
data[key] = [v.decode() for v in value]
|
123
|
-
else:
|
124
|
-
data[key] = [process_bytes_fields(v) for v in value]
|
125
|
-
else:
|
126
|
-
data[key] = value
|
127
|
-
return data
|
128
|
-
|
129
|
-
|
130
109
|
def _current_file_path_and_line():
|
131
110
|
"""Prints the current file path and line number."""
|
132
111
|
frame = inspect.currentframe().f_back.f_back
|
@@ -185,7 +164,7 @@ class ReplayFile(BaseModel):
|
|
185
164
|
|
186
165
|
|
187
166
|
class ReplayApiClient(ApiClient):
|
188
|
-
"""For integration testing, send recorded
|
167
|
+
"""For integration testing, sends a recorded response or records a response."""
|
189
168
|
|
190
169
|
def __init__(
|
191
170
|
self,
|
@@ -280,9 +259,18 @@ class ReplayApiClient(ApiClient):
|
|
280
259
|
replay_file_path = self._get_replay_file_path()
|
281
260
|
os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
|
282
261
|
with open(replay_file_path, 'w') as f:
|
262
|
+
replay_session_dict = self.replay_session.model_dump()
|
263
|
+
# Use for non-utf-8 bytes in image/video... output.
|
264
|
+
for interaction in replay_session_dict['interactions']:
|
265
|
+
segments = []
|
266
|
+
for response in interaction['response']['sdk_response_segments']:
|
267
|
+
segments.append(json.loads(json.dumps(
|
268
|
+
response, cls=ResponseJsonEncoder
|
269
|
+
)))
|
270
|
+
interaction['response']['sdk_response_segments'] = segments
|
283
271
|
f.write(
|
284
272
|
json.dumps(
|
285
|
-
|
273
|
+
replay_session_dict, indent=2, cls=RequestJsonEncoder
|
286
274
|
)
|
287
275
|
)
|
288
276
|
self.replay_session = None
|
@@ -463,10 +451,16 @@ class ResponseJsonEncoder(json.JSONEncoder):
|
|
463
451
|
"""
|
464
452
|
def default(self, o):
|
465
453
|
if isinstance(o, bytes):
|
466
|
-
#
|
467
|
-
#
|
468
|
-
#
|
469
|
-
|
454
|
+
# Use base64.b64encode() to encode bytes to string so that the media bytes
|
455
|
+
# fields are serializable.
|
456
|
+
# o.decode(encoding='utf-8', errors='replace') doesn't work because it
|
457
|
+
# uses a fixed error string `\ufffd` for all non-utf-8 characters,
|
458
|
+
# which cannot be converted back to original bytes. And other languages
|
459
|
+
# only have the original bytes to compare with.
|
460
|
+
# Since we use base64.b64encoding() in replay test, a change that breaks
|
461
|
+
# native bytes can be captured by
|
462
|
+
# test_compute_tokens.py::test_token_bytes_deserialization.
|
463
|
+
return base64.b64encode(o).decode(encoding='utf-8')
|
470
464
|
elif isinstance(o, datetime.datetime):
|
471
465
|
# dt.isoformat() prints "2024-11-15T23:27:45.624657+00:00"
|
472
466
|
# but replay files want "2024-11-15T23:27:45.624657Z"
|