google-genai 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +240 -71
- google/genai/_common.py +47 -31
- google/genai/_extra_utils.py +3 -3
- google/genai/_replay_api_client.py +51 -74
- google/genai/_transformers.py +197 -30
- google/genai/batches.py +74 -72
- google/genai/caches.py +104 -90
- google/genai/chats.py +5 -8
- google/genai/client.py +2 -1
- google/genai/errors.py +1 -1
- google/genai/files.py +302 -102
- google/genai/live.py +42 -30
- google/genai/models.py +379 -250
- google/genai/tunings.py +78 -76
- google/genai/types.py +563 -350
- google/genai/version.py +1 -1
- google_genai-0.6.0.dist-info/METADATA +973 -0
- google_genai-0.6.0.dist-info/RECORD +25 -0
- google_genai-0.4.0.dist-info/METADATA +0 -888
- google_genai-0.4.0.dist-info/RECORD +0 -25
- {google_genai-0.4.0.dist-info → google_genai-0.6.0.dist-info}/LICENSE +0 -0
- {google_genai-0.4.0.dist-info → google_genai-0.6.0.dist-info}/WHEEL +0 -0
- {google_genai-0.4.0.dist-info → google_genai-0.6.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -20,7 +20,9 @@ import asyncio
 import copy
 from dataclasses import dataclass
 import datetime
+import io
 import json
+import logging
 import os
 import sys
 from typing import Any, Optional, Tuple, TypedDict, Union
@@ -60,6 +62,10 @@ class HttpOptions(BaseModel):
       default=None,
       description="""Timeout for the request in seconds.""",
   )
+  skip_project_and_location_in_path: bool = Field(
+      default=False,
+      description="""If set to True, the project and location will not be appended to the path.""",
+  )
 
 
 class HttpOptionsDict(TypedDict):
@@ -69,13 +75,14 @@ class HttpOptionsDict(TypedDict):
   """The base URL for the AI platform service endpoint."""
   api_version: Optional[str] = None
   """Specifies the version of the API to use."""
-  headers: Optional[dict[str,
+  headers: Optional[dict[str, str]] = None
   """Additional HTTP headers to be sent with the request."""
   response_payload: Optional[dict] = None
   """If set, the response payload will be returned int the supplied dict."""
   timeout: Optional[Union[float, Tuple[float, float]]] = None
   """Timeout for the request in seconds."""
-
+  skip_project_and_location_in_path: bool = False
+  """If set to True, the project and location will not be appended to the path."""
 
 HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
 
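Both the Pydantic model and the TypedDict now expose the skip_project_and_location_in_path flag. A hedged sketch of how a caller might set it through http_options (genai.Client is the package's public entry point, but this exact call is illustrative, not taken from the diff):

    from google import genai

    # Illustrative only: ask the client not to prepend
    # `projects/{project}/locations/{location}/` to request paths.
    client = genai.Client(
        vertexai=True,
        project='my-project',      # placeholder project ID
        location='us-central1',    # placeholder location
        http_options={'skip_project_and_location_in_path': True},
    )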
@@ -133,7 +140,7 @@ def _join_url_path(base_url: str, path: str) -> str:
 
 @dataclass
 class HttpRequest:
-  headers: dict[str,
+  headers: dict[str, str]
   url: str
   method: str
   data: Union[dict[str, object], bytes]
@@ -142,10 +149,16 @@ class HttpRequest:
 
 class HttpResponse:
 
-  def __init__(
+  def __init__(
+      self,
+      headers: dict[str, str],
+      response_stream: Union[Any, str] = None,
+      byte_stream: Union[Any, bytes] = None,
+  ):
     self.status_code = 200
     self.headers = headers
     self.response_stream = response_stream
+    self.byte_stream = byte_stream
 
   @property
   def text(self) -> str:
@@ -158,6 +171,8 @@ class HttpResponse:
       # list of objects retrieved from replay or from non-streaming API.
       for chunk in self.response_stream:
         yield json.loads(chunk) if chunk else {}
+    elif self.response_stream is None:
+      yield from []
     else:
       # Iterator of objects retrieved from the API.
       for chunk in self.response_stream.iter_lines():
@@ -168,6 +183,17 @@ class HttpResponse:
           chunk = chunk[len(b'data: ') :]
         yield json.loads(str(chunk, 'utf-8'))
 
+  def byte_segments(self):
+    if isinstance(self.byte_stream, list):
+      # list of objects retrieved from replay or from non-streaming API.
+      yield from self.byte_stream
+    elif self.byte_stream is None:
+      yield from []
+    else:
+      raise ValueError(
+          'Byte segments are not supported for streaming responses.'
+      )
+
   def copy_to_dict(self, response_payload: dict[str, object]):
     for attribute in dir(self):
       response_payload[attribute] = copy.deepcopy(getattr(self, attribute))
@@ -193,11 +219,17 @@ class ApiClient:
     ]:
       self.vertexai = True
 
-    # Validate explicitly set
+    # Validate explicitly set initializer values.
     if (project or location) and api_key:
+      # API cannot consume both project/location and api_key.
       raise ValueError(
           'Project/location and API key are mutually exclusive in the client initializer.'
       )
+    elif credentials and api_key:
+      # API cannot consume both credentials and api_key.
+      raise ValueError(
+          'Credentials and API key are mutually exclusive in the client initializer.'
+      )
 
     # Validate http_options if a dict is provided.
     if isinstance(http_options, dict):
@@ -208,26 +240,66 @@ class ApiClient:
     elif(isinstance(http_options, HttpOptions)):
       http_options = http_options.model_dump()
 
-    [3 old lines (211-213) were not rendered by the diff viewer]
+    # Retrieve implicitly set values from the environment.
+    env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
+    env_location = os.environ.get('GOOGLE_CLOUD_LOCATION', None)
+    env_api_key = os.environ.get('GOOGLE_API_KEY', None)
+    self.project = project or env_project
+    self.location = location or env_location
+    self.api_key = api_key or env_api_key
+
     self._credentials = credentials
     self._http_options = HttpOptionsDict()
 
+    # Handle when to use Vertex AI in express mode (api key).
+    # Explicit initializer arguments are already validated above.
     if self.vertexai:
-      if not self.project:
+      if credentials:
+        # Explicit credentials take precedence over implicit api_key.
+        logging.info(
+            'The user provided Google Cloud credentials will take precedence'
+            + ' over the API key from the environment variable.'
+        )
+        self.api_key = None
+      elif (env_location or env_project) and api_key:
+        # Explicit api_key takes precedence over implicit project/location.
+        logging.info(
+            'The user provided Vertex AI API key will take precedence over the'
+            + ' project/location from the environment variables.'
+        )
+        self.project = None
+        self.location = None
+      elif (project or location) and env_api_key:
+        # Explicit project/location takes precedence over implicit api_key.
+        logging.info(
+            'The user provided project/location will take precedence over the'
+            + ' Vertex AI API key from the environment variable.'
+        )
+        self.api_key = None
+      elif (env_location or env_project) and env_api_key:
+        # Implicit project/location takes precedence over implicit api_key.
+        logging.info(
+            'The project/location from the environment variables will take'
+            + ' precedence over the API key from the environment variables.'
+        )
+        self.api_key = None
+      if not self.project and not self.api_key:
         self.project = google.auth.default()[1]
-
-      if not self.project or not self.location:
+      if not ((self.project and self.location) or self.api_key):
         raise ValueError(
-            'Project and location must be set when using the Vertex AI API.'
+            'Project and location or API key must be set when using the Vertex '
+            'AI API.'
+        )
+      if self.api_key:
+        self._http_options['base_url'] = (
+            f'https://aiplatform.googleapis.com/'
+        )
+      else:
+        self._http_options['base_url'] = (
+            f'https://{self.location}-aiplatform.googleapis.com/'
         )
-      self._http_options['base_url'] = (
-          f'https://{self.location}-aiplatform.googleapis.com/'
-      )
       self._http_options['api_version'] = 'v1beta1'
     else:  # ML Dev API
-      self.api_key = api_key or os.environ.get('GOOGLE_API_KEY', None)
       if not self.api_key:
         raise ValueError('API key must be set when using the Google AI API.')
       self._http_options['base_url'] = (
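Net effect of this hunk: the client gains Vertex AI "express mode", where an API key replaces project/location credentials. Explicit initializer arguments beat environment variables, and credentials or project/location beat an implicit API key; with an API key the client targets the global endpoint instead of a regional one. A hedged sketch of the two initialization styles (the genai.Client constructor is the package's public wrapper around ApiClient; values are placeholders):

    from google import genai

    # Project/location mode: requests go to
    # https://{location}-aiplatform.googleapis.com/
    client = genai.Client(
        vertexai=True, project='my-project', location='us-central1'
    )

    # Express mode: requests go to https://aiplatform.googleapis.com/
    client = genai.Client(vertexai=True, api_key='YOUR_API_KEY')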
@@ -236,7 +308,7 @@ class ApiClient:
     self._http_options['api_version'] = 'v1beta'
     # Default options for both clients.
     self._http_options['headers'] = {'Content-Type': 'application/json'}
-    if self.api_key:
+    if self.api_key and not self.vertexai:
       self._http_options['headers']['x-goog-api-key'] = self.api_key
     # Update the http options with the user provided http options.
     if http_options:
@@ -266,8 +338,18 @@ class ApiClient:
       )
     else:
       patched_http_options = self._http_options
-    if self.vertexai and not path.startswith('projects/'):
+    skip_project_and_location_in_path_val = patched_http_options.get(
+        'skip_project_and_location_in_path', False
+    )
+    if (
+        self.vertexai
+        and not path.startswith('projects/')
+        and not skip_project_and_location_in_path_val
+        and not self.api_key
+    ):
       path = f'projects/{self.project}/locations/{self.location}/' + path
+    elif self.vertexai and self.api_key:
+      path = f'{path}?key={self.api_key}'
     url = _join_url_path(
         patched_http_options['base_url'],
         patched_http_options['api_version'] + '/' + path,
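So in express mode the key travels as a query parameter rather than an OAuth bearer header, and the project/location path prefix is dropped. Roughly, for an illustrative path (URLs below are a sketch assembled from the base URLs and path logic in this diff, not captured output):

    # project/location mode:
    #   https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1/<path>
    # express (api key) mode:
    #   https://aiplatform.googleapis.com/v1beta1/<path>?key=YOUR_API_KEY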
@@ -285,7 +367,7 @@ class ApiClient:
       http_request: HttpRequest,
       stream: bool = False,
   ) -> HttpResponse:
-    if self.vertexai:
+    if self.vertexai and not self.api_key:
       if not self._credentials:
         self._credentials, _ = google.auth.default(
             scopes=["https://www.googleapis.com/auth/cloud-platform"],
@@ -296,7 +378,7 @@ class ApiClient:
           http_request.method.upper(),
           http_request.url,
           headers=http_request.headers,
-          data=json.dumps(http_request.data, cls=RequestJsonEncoder)
+          data=json.dumps(http_request.data)
           if http_request.data
           else None,
           timeout=http_request.timeout,
@@ -316,7 +398,7 @@ class ApiClient:
     data = None
     if http_request.data:
       if not isinstance(http_request.data, bytes):
-        data = json.dumps(http_request.data, cls=RequestJsonEncoder)
+        data = json.dumps(http_request.data)
       else:
         data = http_request.data
 
@@ -427,12 +509,35 @@ class ApiClient:
     if http_options and 'response_payload' in http_options:
       response.copy_to_dict(http_options['response_payload'])
 
-  def upload_file(
+  def upload_file(
+      self, file_path: Union[str, io.IOBase], upload_url: str, upload_size: int
+  ) -> str:
     """Transfers a file to the given URL.
 
     Args:
-      file_path: The full path to the file
-        an error will be
+      file_path: The full path to the file or a file like object inherited from
+        io.BytesIO. If the local file path is not found, an error will be
+        raised.
+      upload_url: The URL to upload the file to.
+      upload_size: The size of file content to be uploaded, this will have to
+        match the size requested in the resumable upload request.
+
+    returns:
+      The response json object from the finalize request.
+    """
+    if isinstance(file_path, io.IOBase):
+      return self._upload_fd(file_path, upload_url, upload_size)
+    else:
+      with open(file_path, 'rb') as file:
+        return self._upload_fd(file, upload_url, upload_size)
+
+  def _upload_fd(
+      self, file: io.IOBase, upload_url: str, upload_size: int
+  ) -> str:
+    """Transfers a file to the given URL.
+
+    Args:
+      file: A file like object inherited from io.BytesIO.
       upload_url: The URL to upload the file to.
       upload_size: The size of file content to be uploaded, this will have to
         match the size requested in the resumable upload request.
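upload_file now accepts either a filesystem path or an already-open file object; both converge on _upload_fd. A hedged sketch of uploading from memory (api_client stands for an ApiClient instance, and the session URL is a placeholder that would normally come from a prior resumable-upload start request):

    import io

    buf = io.BytesIO(b'hello world')
    size = buf.getbuffer().nbytes  # must match the size declared when the upload was started

    result = api_client.upload_file(
        buf, upload_url='https://example.com/upload-session', upload_size=size
    )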
@@ -442,37 +547,36 @@ class ApiClient:
     """
     offset = 0
     # Upload the file in chunks
-    [19 old lines (445-463) were not rendered by the diff viewer]
-          data=file_chunk,
-      )
-      response = self._request_unauthorized(request, stream=False)
-      offset += chunk_size
-      if response.headers['X-Goog-Upload-Status'] != 'active':
-        break  # upload is complete or it has been interrupted.
+    while True:
+      file_chunk = file.read(1024 * 1024 * 8)  # 8 MB chunk size
+      chunk_size = 0
+      if file_chunk:
+        chunk_size = len(file_chunk)
+      upload_command = 'upload'
+      # If last chunk, finalize the upload.
+      if chunk_size + offset >= upload_size:
+        upload_command += ', finalize'
+      request = HttpRequest(
+          method='POST',
+          url=upload_url,
+          headers={
+              'X-Goog-Upload-Command': upload_command,
+              'X-Goog-Upload-Offset': str(offset),
+              'Content-Length': str(chunk_size),
+          },
+          data=file_chunk,
+      )
 
-    [5 old lines (471-475) were not rendered by the diff viewer]
+      response = self._request_unauthorized(request, stream=False)
+      offset += chunk_size
+      if response.headers['X-Goog-Upload-Status'] != 'active':
+        break  # upload is complete or it has been interrupted.
+
+    if upload_size <= offset:  # Status is not finalized.
+      raise ValueError(
+          'All content has been uploaded, but the upload status is not'
+          f' finalized. {response.headers}, body: {response.text}'
+      )
 
     if response.headers['X-Goog-Upload-Status'] != 'final':
       raise ValueError(
@@ -481,12 +585,52 @@ class ApiClient:
       )
     return response.text
 
+  def download_file(self, path: str, http_options):
+    """Downloads the file data.
+
+    Args:
+      path: The request path with query params.
+      http_options: The http options to use for the request.
+
+    returns:
+      The file bytes
+    """
+    http_request = self._build_request(
+        'get', path=path, request_dict={}, http_options=http_options
+    )
+    return self._download_file_request(http_request).byte_stream[0]
+
+  def _download_file_request(
+      self,
+      http_request: HttpRequest,
+  ) -> HttpResponse:
+    data = None
+    if http_request.data:
+      if not isinstance(http_request.data, bytes):
+        data = json.dumps(http_request.data, cls=RequestJsonEncoder)
+      else:
+        data = http_request.data
+
+    http_session = requests.Session()
+    response = http_session.request(
+        method=http_request.method,
+        url=http_request.url,
+        headers=http_request.headers,
+        data=data,
+        timeout=http_request.timeout,
+        stream=False,
+    )
+
+    errors.APIError.raise_for_response(response)
+    return HttpResponse(response.headers, byte_stream=[response.content])
+
+
   async def async_upload_file(
       self,
-      file_path: str,
+      file_path: Union[str, io.IOBase],
       upload_url: str,
       upload_size: int,
-  ):
+  ) -> str:
     """Transfers a file asynchronously to the given URL.
 
     Args:
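download_file issues a GET through _build_request and returns the single element of the response's byte_stream; the async wrappers added in the next hunk simply run the sync implementations in a worker thread via asyncio.to_thread. A sketch (the path is a placeholder, not an endpoint documented by this diff; http_options=None falls back to the client defaults inside _build_request):

    # Synchronous download; returns the raw bytes.
    data = api_client.download_file(
        'files/abc123:download?alt=media', http_options=None
    )

    # Async wrapper (inside an async def) delegates to the same code.
    data = await api_client.async_download_file(
        'files/abc123:download?alt=media', http_options=None
    )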
@@ -506,23 +650,48 @@ class ApiClient:
         upload_size,
     )
 
+  async def _async_upload_fd(
+      self,
+      file: io.IOBase,
+      upload_url: str,
+      upload_size: int,
+  ) -> str:
+    """Transfers a file asynchronously to the given URL.
+
+    Args:
+      file: A file like object inherited from io.BytesIO.
+      upload_url: The URL to upload the file to.
+      upload_size: The size of file content to be uploaded, this will have to
+        match the size requested in the resumable upload request.
+
+    returns:
+      The response json object from the finalize request.
+    """
+    return await asyncio.to_thread(
+        self._upload_fd,
+        file,
+        upload_url,
+        upload_size,
+    )
+
+  async def async_download_file(self, path: str, http_options):
+    """Downloads the file data.
+
+    Args:
+      path: The request path with query params.
+      http_options: The http options to use for the request.
+
+    returns:
+      The file bytes
+    """
+    return await asyncio.to_thread(
+        self.download_file,
+        path,
+        http_options,
+    )
+
   # This method does nothing in the real api client. It is used in the
   # replay_api_client to verify the response from the SDK method matches the
   # recorded response.
   def _verify_response(self, response_model: BaseModel):
     pass
-
-
-class RequestJsonEncoder(json.JSONEncoder):
-  """Encode bytes as strings without modify its content."""
-
-  def default(self, o):
-    if isinstance(o, bytes):
-      return o.decode()
-    elif isinstance(o, datetime.datetime):
-      # This Zulu time format is used by the Vertex AI API and the test recorder
-      # Using strftime works well, but we want to align with the replay encoder.
-      # o.astimezone(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%fZ')
-      return o.isoformat().replace('+00:00', 'Z')
-    else:
-      return super().default(o)
google/genai/_common.py
CHANGED
@@ -17,7 +17,7 @@
 
 import base64
 import datetime
-import
+import enum
 import typing
 from typing import Union
 import uuid
@@ -116,7 +116,7 @@ def get_value_by_path(data: object, keys: list[str]):
 class BaseModule:
 
   def __init__(self, api_client_: _api_client.ApiClient):
-    self.
+    self._api_client = api_client_
 
 
 def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
@@ -145,7 +145,7 @@ def _remove_extra_fields(
 ) -> None:
   """Removes extra fields from the response that are not in the model.
 
-
+  Mutates the response in place.
   """
 
   key_values = list(response.items())
@@ -186,10 +186,12 @@ class BaseModel(pydantic.BaseModel):
       alias_generator=alias_generators.to_camel,
       populate_by_name=True,
       from_attributes=True,
-      protected_namespaces=
+      protected_namespaces=(),
       extra='forbid',
       # This allows us to use arbitrary types in the model. E.g. PIL.Image.
       arbitrary_types_allowed=True,
+      ser_json_bytes='base64',
+      val_json_bytes='base64',
   )
 
   @classmethod
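The two new config keys make Pydantic serialize bytes fields as base64 strings in JSON mode and decode base64 strings back to bytes on JSON validation. A self-contained illustration of that Pydantic behavior (standalone model, not code from this package):

    import pydantic

    class Blob(pydantic.BaseModel):
        model_config = pydantic.ConfigDict(
            ser_json_bytes='base64', val_json_bytes='base64'
        )
        data: bytes

    print(Blob(data=b'\x00\x01').model_dump(mode='json'))     # {'data': 'AAE='}
    print(Blob.model_validate_json('{"data": "AAE="}').data)  # b'\x00\x01'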
@@ -201,7 +203,24 @@ class BaseModel(pydantic.BaseModel):
     # We will provide another mechanism to allow users to access these fields.
     _remove_extra_fields(cls, response)
     validated_response = cls.model_validate(response)
-    return
+    return validated_response
+
+  def to_json_dict(self) -> dict[str, object]:
+    return self.model_dump(exclude_none=True, mode='json')
+
+
+class CaseInSensitiveEnum(str, enum.Enum):
+  """Case insensitive enum."""
+
+  @classmethod
+  def _missing_(cls, value):
+    try:
+      return cls[value.upper()]  # Try to access directly with uppercase
+    except KeyError:
+      try:
+        return cls[value.lower()]  # Try to access directly with lowercase
+      except KeyError as e:
+        raise ValueError(f"{value} is not a valid {cls.__name__}") from e
 
 
 def timestamped_unique_name() -> str:
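CaseInSensitiveEnum falls back to a member-name lookup in either case when the raw value lookup misses. A quick illustration (the Modality enum here is only an example, not taken from the package):

    from google.genai import _common

    class Modality(_common.CaseInSensitiveEnum):
        TEXT = 'TEXT'
        IMAGE = 'IMAGE'

    print(Modality('text'))   # Modality.TEXT — value miss, _missing_ maps to cls['TEXT']
    print(Modality('IMAGE'))  # Modality.IMAGE — direct value match, _missing_ not called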
@@ -215,42 +234,39 @@ def timestamped_unique_name() -> str:
   return f'{timestamp}_{unique_id}'
 
 
-def apply_base64_encoding(data: dict[str, object]) -> dict[str, object]:
-  """Applies base64 encoding to bytes values in the given data."""
-  return process_bytes_fields(data, encode=True)
-
-
-def apply_base64_decoding(data: dict[str, object]) -> dict[str, object]:
-  """Applies base64 decoding to bytes values in the given data."""
-  return process_bytes_fields(data, encode=False)
+def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
+  """Converts unserializable types in dict to json.dumps() compatible types.
 
+  This function is called in models.py after calling convert_to_dict(). The
+  convert_to_dict() can convert pydantic object to dict. However, the input to
+  convert_to_dict() is dict mixed of pydantic object and nested dict(the output
+  of converters). So they may be bytes in the dict and they are out of
+  `ser_json_bytes` control in model_dump(mode='json') called in
+  `convert_to_dict`, as well as datetime deserialization in Pydantic json mode.
 
-[6 old lines (228-233) were not rendered by the diff viewer]
-def process_bytes_fields(data: dict[str, object], encode=True) -> dict[str, object]:
+  Returns:
+    A dictionary with json.dumps() incompatible type (e.g. bytes datetime)
+    to compatible type (e.g. base64 encoded string, isoformat date string).
+  """
   processed_data = {}
   if not isinstance(data, dict):
     return data
   for key, value in data.items():
     if isinstance(value, bytes):
-      [3 old lines (240-242) were not rendered by the diff viewer]
-      processed_data[key] = base64.b64decode(value)
+      processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
+    elif isinstance(value, datetime.datetime):
+      processed_data[key] = value.isoformat()
     elif isinstance(value, dict):
-      processed_data[key] =
+      processed_data[key] = encode_unserializable_types(value)
     elif isinstance(value, list):
-      if
-        processed_data[key] = [
-      [2 old lines (249-250) were not rendered by the diff viewer]
+      if all(isinstance(v, bytes) for v in value):
+        processed_data[key] = [
+            base64.urlsafe_b64encode(v).decode('ascii') for v in value
+        ]
+      if all(isinstance(v, datetime.datetime) for v in value):
+        processed_data[key] = [v.isoformat() for v in value]
      else:
-        processed_data[key] = [
+        processed_data[key] = [encode_unserializable_types(v) for v in value]
    else:
      processed_data[key] = value
  return processed_data
-
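The replacement helper rewrites anything json.dumps() would reject in a nested dict. For instance (expected output shown in comments, following the urlsafe-base64 and isoformat conversions above):

    import datetime
    from google.genai._common import encode_unserializable_types

    payload = {
        'data': b'\x00\x01',
        'created': datetime.datetime(2025, 1, 1, 12, 0),
        'nested': {'inner': b'\x02\x03'},
    }
    print(encode_unserializable_types(payload))
    # {'data': 'AAE=', 'created': '2025-01-01T12:00:00',
    #  'nested': {'inner': 'AgM='}}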
google/genai/_extra_utils.py
CHANGED
@@ -116,7 +116,7 @@ def convert_if_exist_pydantic_model(
     try:
       return annotation(**value)
     except pydantic.ValidationError as e:
-      raise errors.UnkownFunctionCallArgumentError(
+      raise errors.UnknownFunctionCallArgumentError(
          f'Failed to parse parameter {param_name} for function'
          f' {func_name} from function call part because function call argument'
          f' value {value} is not compatible with parameter annotation'
@@ -150,7 +150,7 @@ def convert_if_exist_pydantic_model(
       except pydantic.ValidationError:
        continue
    # if none of the union type is matched, raise error
-    raise errors.UnkownFunctionCallArgumentError(
+    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function'
        f' {func_name} from function call part because function call argument'
        f' value {value} cannot be converted to parameter annotation'
@@ -161,7 +161,7 @@ def convert_if_exist_pydantic_model(
  if isinstance(value, int) and annotation is float:
    return value
  if not isinstance(value, annotation):
-    raise errors.UnkownFunctionCallArgumentError(
+    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function {func_name} from'
        f' function call part because function call argument value {value} is'
        f' not compatible with parameter annotation {annotation}.'
|