google-genai 0.3.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {google_genai-0.3.0/google_genai.egg-info → google_genai-0.4.0}/PKG-INFO +57 -17
  2. {google_genai-0.3.0 → google_genai-0.4.0}/README.md +55 -15
  3. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/__init__.py +2 -1
  4. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_api_client.py +85 -36
  5. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_automatic_function_calling_util.py +14 -14
  6. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_replay_api_client.py +22 -28
  7. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/batches.py +16 -16
  8. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/caches.py +18 -18
  9. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/chats.py +2 -2
  10. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/client.py +6 -3
  11. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/files.py +22 -22
  12. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/live.py +28 -5
  13. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/models.py +97 -77
  14. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/tunings.py +17 -17
  15. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/types.py +150 -80
  16. google_genai-0.4.0/google/genai/version.py +16 -0
  17. {google_genai-0.3.0 → google_genai-0.4.0/google_genai.egg-info}/PKG-INFO +57 -17
  18. {google_genai-0.3.0 → google_genai-0.4.0}/google_genai.egg-info/SOURCES.txt +1 -0
  19. {google_genai-0.3.0 → google_genai-0.4.0}/pyproject.toml +1 -1
  20. {google_genai-0.3.0 → google_genai-0.4.0}/LICENSE +0 -0
  21. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_common.py +0 -0
  22. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_extra_utils.py +0 -0
  23. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_test_api_client.py +0 -0
  24. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/_transformers.py +0 -0
  25. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/errors.py +0 -0
  26. {google_genai-0.3.0 → google_genai-0.4.0}/google/genai/pagers.py +0 -0
  27. {google_genai-0.3.0 → google_genai-0.4.0}/google_genai.egg-info/dependency_links.txt +0 -0
  28. {google_genai-0.3.0 → google_genai-0.4.0}/google_genai.egg-info/requires.txt +0 -0
  29. {google_genai-0.3.0 → google_genai-0.4.0}/google_genai.egg-info/top_level.txt +0 -0
  30. {google_genai-0.3.0 → google_genai-0.4.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: google-genai
3
- Version: 0.3.0
3
+ Version: 0.4.0
4
4
  Summary: GenAI Python SDK
5
5
  Author-email: Google LLC <googleapis-packages@google.com>
6
6
  License: Apache-2.0
@@ -436,10 +436,10 @@ response1 = client.models.generate_image(
436
436
  model='imagen-3.0-generate-001',
437
437
  prompt='An umbrella in the foreground, and a rainy night sky in the background',
438
438
  config=types.GenerateImageConfig(
439
- negative_prompt= "human",
439
+ negative_prompt= 'human',
440
440
  number_of_images= 1,
441
441
  include_rai_reason= True,
442
- output_mime_type= "image/jpeg"
442
+ output_mime_type= 'image/jpeg'
443
443
  )
444
444
  )
445
445
  response1.generated_images[0].image.show()
@@ -454,7 +454,11 @@ Upscale image is not supported in Google AI.
454
454
  response2 = client.models.upscale_image(
455
455
  model='imagen-3.0-generate-001',
456
456
  image=response1.generated_images[0].image,
457
- config=types.UpscaleImageConfig(upscale_factor="x2")
457
+ upscale_factor='x2',
458
+ config=types.UpscaleImageConfig(
459
+ include_rai_reason= True,
460
+ output_mime_type= 'image/jpeg',
461
+ ),
458
462
  )
459
463
  response2.generated_images[0].image.show()
460
464
  ```
@@ -497,6 +501,42 @@ response3 = client.models.edit_image(
497
501
  response3.generated_images[0].image.show()
498
502
  ```
499
503
 
504
+ ## Chats
505
+
506
+ Create a chat session to start a multi-turn conversation with the model.
507
+
508
+ ### Send Message
509
+
510
+ ```python
511
+ chat = client.chats.create(model='gemini-2.0-flash-exp')
512
+ response = chat.send_message('tell me a story')
513
+ print(response.text)
514
+ ```
515
+
516
+ ### Streaming
517
+
518
+ ```python
519
+ chat = client.chats.create(model='gemini-2.0-flash-exp')
520
+ for chunk in chat.send_message_stream('tell me a story'):
521
+ print(chunk.text)
522
+ ```
523
+
524
+ ### Async
525
+
526
+ ```python
527
+ chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
528
+ response = await chat.send_message('tell me a story')
529
+ print(response.text)
530
+ ```
531
+
532
+ ### Async Streaming
533
+
534
+ ```python
535
+ chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
536
+ async for chunk in chat.send_message_stream('tell me a story'):
537
+ print(chunk.text)
538
+ ```
539
+
500
540
  ## Files (Only Google AI)
501
541
 
502
542
  ``` python
@@ -539,19 +579,19 @@ else:
539
579
 
540
580
  cached_content = client.caches.create(
541
581
  model='gemini-1.5-pro-002',
542
- contents=[
543
- types.Content(
544
- role='user',
545
- parts=[
546
- types.Part.from_uri(
547
- file_uri=file_uris[0],
548
- mime_type='application/pdf'),
549
- types.Part.from_uri(
550
- file_uri=file_uris[1],
551
- mime_type='application/pdf',)])
552
- ],
553
- system_instruction='What is the sum of the two pdfs?',
554
582
  config=types.CreateCachedContentConfig(
583
+ contents=[
584
+ types.Content(
585
+ role='user',
586
+ parts=[
587
+ types.Part.from_uri(
588
+ file_uri=file_uris[0],
589
+ mime_type='application/pdf'),
590
+ types.Part.from_uri(
591
+ file_uri=file_uris[1],
592
+ mime_type='application/pdf',)])
593
+ ],
594
+ system_instruction='What is the sum of the two pdfs?',
555
595
  display_name='test cache',
556
596
  ttl='3600s',
557
597
  ),
@@ -408,10 +408,10 @@ response1 = client.models.generate_image(
408
408
  model='imagen-3.0-generate-001',
409
409
  prompt='An umbrella in the foreground, and a rainy night sky in the background',
410
410
  config=types.GenerateImageConfig(
411
- negative_prompt= "human",
411
+ negative_prompt= 'human',
412
412
  number_of_images= 1,
413
413
  include_rai_reason= True,
414
- output_mime_type= "image/jpeg"
414
+ output_mime_type= 'image/jpeg'
415
415
  )
416
416
  )
417
417
  response1.generated_images[0].image.show()
@@ -426,7 +426,11 @@ Upscale image is not supported in Google AI.
426
426
  response2 = client.models.upscale_image(
427
427
  model='imagen-3.0-generate-001',
428
428
  image=response1.generated_images[0].image,
429
- config=types.UpscaleImageConfig(upscale_factor="x2")
429
+ upscale_factor='x2',
430
+ config=types.UpscaleImageConfig(
431
+ include_rai_reason= True,
432
+ output_mime_type= 'image/jpeg',
433
+ ),
430
434
  )
431
435
  response2.generated_images[0].image.show()
432
436
  ```
@@ -469,6 +473,42 @@ response3 = client.models.edit_image(
469
473
  response3.generated_images[0].image.show()
470
474
  ```
471
475
 
476
+ ## Chats
477
+
478
+ Create a chat session to start a multi-turn conversation with the model.
479
+
480
+ ### Send Message
481
+
482
+ ```python
483
+ chat = client.chats.create(model='gemini-2.0-flash-exp')
484
+ response = chat.send_message('tell me a story')
485
+ print(response.text)
486
+ ```
487
+
488
+ ### Streaming
489
+
490
+ ```python
491
+ chat = client.chats.create(model='gemini-2.0-flash-exp')
492
+ for chunk in chat.send_message_stream('tell me a story'):
493
+ print(chunk.text)
494
+ ```
495
+
496
+ ### Async
497
+
498
+ ```python
499
+ chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
500
+ response = await chat.send_message('tell me a story')
501
+ print(response.text)
502
+ ```
503
+
504
+ ### Async Streaming
505
+
506
+ ```python
507
+ chat = client.aio.chats.create(model='gemini-2.0-flash-exp')
508
+ async for chunk in chat.send_message_stream('tell me a story'):
509
+ print(chunk.text)
510
+ ```
511
+
472
512
  ## Files (Only Google AI)
473
513
 
474
514
  ``` python
@@ -511,19 +551,19 @@ else:
511
551
 
512
552
  cached_content = client.caches.create(
513
553
  model='gemini-1.5-pro-002',
514
- contents=[
515
- types.Content(
516
- role='user',
517
- parts=[
518
- types.Part.from_uri(
519
- file_uri=file_uris[0],
520
- mime_type='application/pdf'),
521
- types.Part.from_uri(
522
- file_uri=file_uris[1],
523
- mime_type='application/pdf',)])
524
- ],
525
- system_instruction='What is the sum of the two pdfs?',
526
554
  config=types.CreateCachedContentConfig(
555
+ contents=[
556
+ types.Content(
557
+ role='user',
558
+ parts=[
559
+ types.Part.from_uri(
560
+ file_uri=file_uris[0],
561
+ mime_type='application/pdf'),
562
+ types.Part.from_uri(
563
+ file_uri=file_uris[1],
564
+ mime_type='application/pdf',)])
565
+ ],
566
+ system_instruction='What is the sum of the two pdfs?',
527
567
  display_name='test cache',
528
568
  ttl='3600s',
529
569
  ),
@@ -16,7 +16,8 @@
16
16
  """Google Gen AI SDK"""
17
17
 
18
18
  from .client import Client
19
+ from . import version
19
20
 
20
- __version__ = '0.3.0'
21
+ __version__ = version.__version__
21
22
 
22
23
  __all__ = ['Client']
@@ -23,35 +23,66 @@ import datetime
23
23
  import json
24
24
  import os
25
25
  import sys
26
- from typing import Any, Optional, TypedDict, Union
26
+ from typing import Any, Optional, Tuple, TypedDict, Union
27
27
  from urllib.parse import urlparse, urlunparse
28
28
 
29
29
  import google.auth
30
30
  import google.auth.credentials
31
31
  from google.auth.transport.requests import AuthorizedSession
32
- from pydantic import BaseModel
32
+ from pydantic import BaseModel, ConfigDict, Field, ValidationError
33
33
  import requests
34
34
 
35
35
  from . import errors
36
+ from . import version
36
37
 
37
38
 
38
- class HttpOptions(TypedDict):
39
+ class HttpOptions(BaseModel):
40
+ """HTTP options for the api client."""
41
+ model_config = ConfigDict(extra='forbid')
42
+
43
+ base_url: Optional[str] = Field(
44
+ default=None,
45
+ description="""The base URL for the AI platform service endpoint.""",
46
+ )
47
+ api_version: Optional[str] = Field(
48
+ default=None,
49
+ description="""Specifies the version of the API to use.""",
50
+ )
51
+ headers: Optional[dict[str, str]] = Field(
52
+ default=None,
53
+ description="""Additional HTTP headers to be sent with the request.""",
54
+ )
55
+ response_payload: Optional[dict] = Field(
56
+ default=None,
57
+ description="""If set, the response payload will be returned int the supplied dict.""",
58
+ )
59
+ timeout: Optional[Union[float, Tuple[float, float]]] = Field(
60
+ default=None,
61
+ description="""Timeout for the request in seconds.""",
62
+ )
63
+
64
+
65
+ class HttpOptionsDict(TypedDict):
39
66
  """HTTP options for the api client."""
40
67
 
41
- base_url: str = None
68
+ base_url: Optional[str] = None
42
69
  """The base URL for the AI platform service endpoint."""
43
- api_version: str = None
70
+ api_version: Optional[str] = None
44
71
  """Specifies the version of the API to use."""
45
- headers: dict[str, dict] = None
72
+ headers: Optional[dict[str, Union[str, list[str]]]] = None
46
73
  """Additional HTTP headers to be sent with the request."""
47
- response_payload: dict = None
74
+ response_payload: Optional[dict] = None
48
75
  """If set, the response payload will be returned int the supplied dict."""
76
+ timeout: Optional[Union[float, Tuple[float, float]]] = None
77
+ """Timeout for the request in seconds."""
78
+
79
+
80
+ HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
49
81
 
50
82
 
51
83
  def _append_library_version_headers(headers: dict[str, str]) -> None:
52
84
  """Appends the telemetry header to the headers dict."""
53
- # TODO: Automate revisions to the SDK library version.
54
- library_label = f'google-genai-sdk/0.3.0'
85
+ library_label = f'google-genai-sdk/{version.__version__}'
55
86
  language_label = 'gl-python/' + sys.version.split()[0]
56
87
  version_header_value = f'{library_label} {language_label}'
57
88
  if (
@@ -71,20 +102,24 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
71
102
 
72
103
 
73
104
  def _patch_http_options(
74
- options: HttpOptions, patch_options: HttpOptions
75
- ) -> HttpOptions:
105
+ options: HttpOptionsDict, patch_options: HttpOptionsDict
106
+ ) -> HttpOptionsDict:
76
107
  # use shallow copy so we don't override the original objects.
77
- copy_option = HttpOptions()
108
+ copy_option = HttpOptionsDict()
78
109
  copy_option.update(options)
79
- for k, v in patch_options.items():
110
+ for patch_key, patch_value in patch_options.items():
80
111
  # if both are dicts, update the copy.
81
112
  # This is to handle cases like merging headers.
82
- if isinstance(v, dict) and isinstance(copy_option.get(k, None), dict):
83
- copy_option[k] = {}
84
- copy_option[k].update(options[k]) # shallow copy from original options.
85
- copy_option[k].update(v)
86
- elif v is not None: # Accept empty values.
87
- copy_option[k] = v
113
+ if isinstance(patch_value, dict) and isinstance(
114
+ copy_option.get(patch_key, None), dict
115
+ ):
116
+ copy_option[patch_key] = {}
117
+ copy_option[patch_key].update(
118
+ options[patch_key]
119
+ ) # shallow copy from original options.
120
+ copy_option[patch_key].update(patch_value)
121
+ elif patch_value is not None: # Accept empty values.
122
+ copy_option[patch_key] = patch_value
88
123
  _append_library_version_headers(copy_option['headers'])
89
124
  return copy_option
90
125
 
@@ -98,10 +133,11 @@ def _join_url_path(base_url: str, path: str) -> str:
98
133
 
99
134
  @dataclass
100
135
  class HttpRequest:
101
- headers: dict[str, str]
136
+ headers: dict[str, Union[str, list[str]]]
102
137
  url: str
103
138
  method: str
104
139
  data: Union[dict[str, object], bytes]
140
+ timeout: Optional[Union[float, Tuple[float, float]]] = None
105
141
 
106
142
 
107
143
  class HttpResponse:
@@ -147,7 +183,7 @@ class ApiClient:
147
183
  credentials: google.auth.credentials.Credentials = None,
148
184
  project: Union[str, None] = None,
149
185
  location: Union[str, None] = None,
150
- http_options: HttpOptions = None,
186
+ http_options: HttpOptionsOrDict = None,
151
187
  ):
152
188
  self.vertexai = vertexai
153
189
  if self.vertexai is None:
@@ -163,11 +199,20 @@ class ApiClient:
163
199
  'Project/location and API key are mutually exclusive in the client initializer.'
164
200
  )
165
201
 
202
+ # Validate http_options if a dict is provided.
203
+ if isinstance(http_options, dict):
204
+ try:
205
+ HttpOptions.model_validate(http_options)
206
+ except ValidationError as e:
207
+ raise ValueError(f'Invalid http_options: {e}')
208
+ elif(isinstance(http_options, HttpOptions)):
209
+ http_options = http_options.model_dump()
210
+
166
211
  self.api_key: Optional[str] = None
167
212
  self.project = project or os.environ.get('GOOGLE_CLOUD_PROJECT', None)
168
213
  self.location = location or os.environ.get('GOOGLE_CLOUD_LOCATION', None)
169
214
  self._credentials = credentials
170
- self._http_options = HttpOptions()
215
+ self._http_options = HttpOptionsDict()
171
216
 
172
217
  if self.vertexai:
173
218
  if not self.project:
@@ -208,7 +253,7 @@ class ApiClient:
208
253
  http_method: str,
209
254
  path: str,
210
255
  request_dict: dict[str, object],
211
- http_options: HttpOptions = None,
256
+ http_options: HttpOptionsDict = None,
212
257
  ) -> HttpRequest:
213
258
  # Remove all special dict keys such as _url and _query.
214
259
  keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
@@ -232,6 +277,7 @@ class ApiClient:
232
277
  url=url,
233
278
  headers=patched_http_options['headers'],
234
279
  data=request_dict,
280
+ timeout=patched_http_options.get('timeout', None),
235
281
  )
236
282
 
237
283
  def _request(
@@ -250,10 +296,10 @@ class ApiClient:
250
296
  http_request.method.upper(),
251
297
  http_request.url,
252
298
  headers=http_request.headers,
253
- data=json.dumps(http_request.data, cls=RequestJsonEncoder) if http_request.data else None,
254
- # TODO: support timeout in RequestOptions so it can be configured
255
- # per methods.
256
- timeout=None,
299
+ data=json.dumps(http_request.data, cls=RequestJsonEncoder)
300
+ if http_request.data
301
+ else None,
302
+ timeout=http_request.timeout,
257
303
  )
258
304
  errors.APIError.raise_for_response(response)
259
305
  return HttpResponse(
@@ -275,13 +321,14 @@ class ApiClient:
275
321
  data = http_request.data
276
322
 
277
323
  http_session = requests.Session()
278
- request = requests.Request(
324
+ response = http_session.request(
279
325
  method=http_request.method,
280
326
  url=http_request.url,
281
327
  headers=http_request.headers,
282
328
  data=data,
283
- ).prepare()
284
- response = http_session.send(request, stream=stream)
329
+ timeout=http_request.timeout,
330
+ stream=stream,
331
+ )
285
332
  errors.APIError.raise_for_response(response)
286
333
  return HttpResponse(
287
334
  response.headers, response if stream else [response.text]
@@ -307,8 +354,10 @@ class ApiClient:
307
354
  stream=stream,
308
355
  )
309
356
 
310
- def get_read_only_http_options(self) -> HttpOptions:
311
- copied = HttpOptions()
357
+ def get_read_only_http_options(self) -> HttpOptionsDict:
358
+ copied = HttpOptionsDict()
359
+ if isinstance(self._http_options, BaseModel):
360
+ self._http_options = self._http_options.model_dump()
312
361
  copied.update(self._http_options)
313
362
  return copied
314
363
 
@@ -317,7 +366,7 @@ class ApiClient:
317
366
  http_method: str,
318
367
  path: str,
319
368
  request_dict: dict[str, object],
320
- http_options: HttpOptions = None,
369
+ http_options: HttpOptionsDict = None,
321
370
  ):
322
371
  http_request = self._build_request(
323
372
  http_method, path, request_dict, http_options
@@ -332,7 +381,7 @@ class ApiClient:
332
381
  http_method: str,
333
382
  path: str,
334
383
  request_dict: dict[str, object],
335
- http_options: HttpOptions = None,
384
+ http_options: HttpOptionsDict = None,
336
385
  ):
337
386
  http_request = self._build_request(
338
387
  http_method, path, request_dict, http_options
@@ -349,7 +398,7 @@ class ApiClient:
349
398
  http_method: str,
350
399
  path: str,
351
400
  request_dict: dict[str, object],
352
- http_options: HttpOptions = None,
401
+ http_options: HttpOptionsDict = None,
353
402
  ) -> dict[str, object]:
354
403
  http_request = self._build_request(
355
404
  http_method, path, request_dict, http_options
@@ -365,7 +414,7 @@ class ApiClient:
365
414
  http_method: str,
366
415
  path: str,
367
416
  request_dict: dict[str, object],
368
- http_options: HttpOptions = None,
417
+ http_options: HttpOptionsDict = None,
369
418
  ):
370
419
  http_request = self._build_request(
371
420
  http_method, path, request_dict, http_options
@@ -58,8 +58,8 @@ def _raise_for_nullable_if_mldev(schema: types.Schema):
58
58
  )
59
59
 
60
60
 
61
- def _raise_if_schema_unsupported(client, schema: types.Schema):
62
- if not client.vertexai:
61
+ def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
62
+ if not variant == 'VERTEX_AI':
63
63
  _raise_for_any_of_if_mldev(schema)
64
64
  _raise_for_default_if_mldev(schema)
65
65
  _raise_for_nullable_if_mldev(schema)
@@ -112,7 +112,7 @@ def _is_default_value_compatible(
112
112
 
113
113
 
114
114
  def _parse_schema_from_parameter(
115
- client, param: inspect.Parameter, func_name: str
115
+ variant: str, param: inspect.Parameter, func_name: str
116
116
  ) -> types.Schema:
117
117
  """parse schema from parameter.
118
118
 
@@ -130,7 +130,7 @@ def _parse_schema_from_parameter(
130
130
  raise ValueError(default_value_error_msg)
131
131
  schema.default = param.default
132
132
  schema.type = _py_builtin_type_to_schema_type[param.annotation]
133
- _raise_if_schema_unsupported(client, schema)
133
+ _raise_if_schema_unsupported(variant, schema)
134
134
  return schema
135
135
  if (
136
136
  isinstance(param.annotation, typing_types.UnionType)
@@ -149,7 +149,7 @@ def _parse_schema_from_parameter(
149
149
  schema.nullable = True
150
150
  continue
151
151
  schema_in_any_of = _parse_schema_from_parameter(
152
- client,
152
+ variant,
153
153
  inspect.Parameter(
154
154
  'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
155
155
  ),
@@ -171,7 +171,7 @@ def _parse_schema_from_parameter(
171
171
  if not _is_default_value_compatible(param.default, param.annotation):
172
172
  raise ValueError(default_value_error_msg)
173
173
  schema.default = param.default
174
- _raise_if_schema_unsupported(client, schema)
174
+ _raise_if_schema_unsupported(variant, schema)
175
175
  return schema
176
176
  if isinstance(param.annotation, _GenericAlias) or isinstance(
177
177
  param.annotation, typing_types.GenericAlias
@@ -184,7 +184,7 @@ def _parse_schema_from_parameter(
184
184
  if not _is_default_value_compatible(param.default, param.annotation):
185
185
  raise ValueError(default_value_error_msg)
186
186
  schema.default = param.default
187
- _raise_if_schema_unsupported(client, schema)
187
+ _raise_if_schema_unsupported(variant, schema)
188
188
  return schema
189
189
  if origin is Literal:
190
190
  if not all(isinstance(arg, str) for arg in args):
@@ -197,12 +197,12 @@ def _parse_schema_from_parameter(
197
197
  if not _is_default_value_compatible(param.default, param.annotation):
198
198
  raise ValueError(default_value_error_msg)
199
199
  schema.default = param.default
200
- _raise_if_schema_unsupported(client, schema)
200
+ _raise_if_schema_unsupported(variant, schema)
201
201
  return schema
202
202
  if origin is list:
203
203
  schema.type = 'ARRAY'
204
204
  schema.items = _parse_schema_from_parameter(
205
- client,
205
+ variant,
206
206
  inspect.Parameter(
207
207
  'item',
208
208
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -214,7 +214,7 @@ def _parse_schema_from_parameter(
214
214
  if not _is_default_value_compatible(param.default, param.annotation):
215
215
  raise ValueError(default_value_error_msg)
216
216
  schema.default = param.default
217
- _raise_if_schema_unsupported(client, schema)
217
+ _raise_if_schema_unsupported(variant, schema)
218
218
  return schema
219
219
  if origin is Union:
220
220
  schema.any_of = []
@@ -225,7 +225,7 @@ def _parse_schema_from_parameter(
225
225
  schema.nullable = True
226
226
  continue
227
227
  schema_in_any_of = _parse_schema_from_parameter(
228
- client,
228
+ variant,
229
229
  inspect.Parameter(
230
230
  'item',
231
231
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -249,7 +249,7 @@ def _parse_schema_from_parameter(
249
249
  if not _is_default_value_compatible(param.default, param.annotation):
250
250
  raise ValueError(default_value_error_msg)
251
251
  schema.default = param.default
252
- _raise_if_schema_unsupported(client, schema)
252
+ _raise_if_schema_unsupported(variant, schema)
253
253
  return schema
254
254
  # all other generic alias will be invoked in raise branch
255
255
  if (
@@ -266,7 +266,7 @@ def _parse_schema_from_parameter(
266
266
  schema.properties = {}
267
267
  for field_name, field_info in param.annotation.model_fields.items():
268
268
  schema.properties[field_name] = _parse_schema_from_parameter(
269
- client,
269
+ variant,
270
270
  inspect.Parameter(
271
271
  field_name,
272
272
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -274,7 +274,7 @@ def _parse_schema_from_parameter(
274
274
  ),
275
275
  func_name,
276
276
  )
277
- _raise_if_schema_unsupported(client, schema)
277
+ _raise_if_schema_unsupported(variant, schema)
278
278
  return schema
279
279
  raise ValueError(
280
280
  f'Failed to parse the parameter {param} of function {func_name} for'
@@ -15,6 +15,7 @@
15
15
 
16
16
  """Replay API client."""
17
17
 
18
+ import base64
18
19
  import copy
19
20
  import inspect
20
21
  import json
@@ -105,28 +106,6 @@ def redact_http_request(http_request: HttpRequest):
105
106
  _redact_request_body(http_request.data)
106
107
 
107
108
 
108
- def process_bytes_fields(data: dict[str, object]):
109
- """Converts bytes fields to strings.
110
-
111
- This function doesn't modify the content of data dict.
112
- """
113
- if not isinstance(data, dict):
114
- return data
115
- for key, value in data.items():
116
- if isinstance(value, bytes):
117
- data[key] = value.decode()
118
- elif isinstance(value, dict):
119
- process_bytes_fields(value)
120
- elif isinstance(value, list):
121
- if all(isinstance(v, bytes) for v in value):
122
- data[key] = [v.decode() for v in value]
123
- else:
124
- data[key] = [process_bytes_fields(v) for v in value]
125
- else:
126
- data[key] = value
127
- return data
128
-
129
-
130
109
  def _current_file_path_and_line():
131
110
  """Prints the current file path and line number."""
132
111
  frame = inspect.currentframe().f_back.f_back
@@ -185,7 +164,7 @@ class ReplayFile(BaseModel):
185
164
 
186
165
 
187
166
  class ReplayApiClient(ApiClient):
188
- """For integration testing, send recorded responese or records a response."""
167
+ """For integration testing, send recorded response or records a response."""
189
168
 
190
169
  def __init__(
191
170
  self,
@@ -280,9 +259,18 @@ class ReplayApiClient(ApiClient):
280
259
  replay_file_path = self._get_replay_file_path()
281
260
  os.makedirs(os.path.dirname(replay_file_path), exist_ok=True)
282
261
  with open(replay_file_path, 'w') as f:
262
+ replay_session_dict = self.replay_session.model_dump()
263
+ # Use for non-utf-8 bytes in image/video... output.
264
+ for interaction in replay_session_dict['interactions']:
265
+ segments = []
266
+ for response in interaction['response']['sdk_response_segments']:
267
+ segments.append(json.loads(json.dumps(
268
+ response, cls=ResponseJsonEncoder
269
+ )))
270
+ interaction['response']['sdk_response_segments'] = segments
283
271
  f.write(
284
272
  json.dumps(
285
- self.replay_session.model_dump(), indent=2, cls=ResponseJsonEncoder
273
+ replay_session_dict, indent=2, cls=RequestJsonEncoder
286
274
  )
287
275
  )
288
276
  self.replay_session = None
@@ -463,10 +451,16 @@ class ResponseJsonEncoder(json.JSONEncoder):
463
451
  """
464
452
  def default(self, o):
465
453
  if isinstance(o, bytes):
466
- # use error replace because response need to be serialized with bytes
467
- # string, not base64 string. Otherwise, we cannot tell the response is
468
- # already decoded from base64 or not from the replay file.
469
- return o.decode(encoding='utf-8', errors='replace')
454
+ # Use base64.b64encode() to encode bytes to string so that the media bytes
455
+ # fields are serializable.
456
+ # o.decode(encoding='utf-8', errors='replace') doesn't work because it
457
+ # uses a fixed error string `\ufffd` for all non-utf-8 characters,
458
+ # which cannot be converted back to original bytes. And other languages
459
+ # only have the original bytes to compare with.
460
+ # Since we use base64.b64encoding() in replay test, a change that breaks
461
+ # native bytes can be captured by
462
+ # test_compute_tokens.py::test_token_bytes_deserialization.
463
+ return base64.b64encode(o).decode(encoding='utf-8')
470
464
  elif isinstance(o, datetime.datetime):
471
465
  # dt.isoformat() prints "2024-11-15T23:27:45.624657+00:00"
472
466
  # but replay files want "2024-11-15T23:27:45.624657Z"