google-genai 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +2 -1
- google/genai/_api_client.py +161 -52
- google/genai/_automatic_function_calling_util.py +14 -14
- google/genai/_common.py +14 -29
- google/genai/_replay_api_client.py +13 -54
- google/genai/_transformers.py +38 -0
- google/genai/batches.py +80 -78
- google/genai/caches.py +112 -98
- google/genai/chats.py +7 -10
- google/genai/client.py +6 -3
- google/genai/files.py +91 -90
- google/genai/live.py +65 -34
- google/genai/models.py +374 -297
- google/genai/tunings.py +87 -85
- google/genai/types.py +167 -82
- google/genai/version.py +16 -0
- {google_genai-0.3.0.dist-info → google_genai-0.5.0.dist-info}/METADATA +57 -17
- google_genai-0.5.0.dist-info/RECORD +25 -0
- {google_genai-0.3.0.dist-info → google_genai-0.5.0.dist-info}/WHEEL +1 -1
- google_genai-0.3.0.dist-info/RECORD +0 -24
- {google_genai-0.3.0.dist-info → google_genai-0.5.0.dist-info}/LICENSE +0 -0
- {google_genai-0.3.0.dist-info → google_genai-0.5.0.dist-info}/top_level.txt +0 -0
google/genai/__init__.py
CHANGED
google/genai/_api_client.py
CHANGED
@@ -21,37 +21,74 @@ import copy
|
|
21
21
|
from dataclasses import dataclass
|
22
22
|
import datetime
|
23
23
|
import json
|
24
|
+
import logging
|
24
25
|
import os
|
25
26
|
import sys
|
26
|
-
from typing import Any, Optional, TypedDict, Union
|
27
|
+
from typing import Any, Optional, Tuple, TypedDict, Union
|
27
28
|
from urllib.parse import urlparse, urlunparse
|
28
29
|
|
29
30
|
import google.auth
|
30
31
|
import google.auth.credentials
|
31
32
|
from google.auth.transport.requests import AuthorizedSession
|
32
|
-
from pydantic import BaseModel
|
33
|
+
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
33
34
|
import requests
|
34
35
|
|
35
36
|
from . import errors
|
37
|
+
from . import version
|
36
38
|
|
37
39
|
|
38
|
-
class HttpOptions(
|
40
|
+
class HttpOptions(BaseModel):
|
41
|
+
"""HTTP options for the api client."""
|
42
|
+
model_config = ConfigDict(extra='forbid')
|
43
|
+
|
44
|
+
base_url: Optional[str] = Field(
|
45
|
+
default=None,
|
46
|
+
description="""The base URL for the AI platform service endpoint.""",
|
47
|
+
)
|
48
|
+
api_version: Optional[str] = Field(
|
49
|
+
default=None,
|
50
|
+
description="""Specifies the version of the API to use.""",
|
51
|
+
)
|
52
|
+
headers: Optional[dict[str, str]] = Field(
|
53
|
+
default=None,
|
54
|
+
description="""Additional HTTP headers to be sent with the request.""",
|
55
|
+
)
|
56
|
+
response_payload: Optional[dict] = Field(
|
57
|
+
default=None,
|
58
|
+
description="""If set, the response payload will be returned int the supplied dict.""",
|
59
|
+
)
|
60
|
+
timeout: Optional[Union[float, Tuple[float, float]]] = Field(
|
61
|
+
default=None,
|
62
|
+
description="""Timeout for the request in seconds.""",
|
63
|
+
)
|
64
|
+
skip_project_and_location_in_path: bool = Field(
|
65
|
+
default=False,
|
66
|
+
description="""If set to True, the project and location will not be appended to the path.""",
|
67
|
+
)
|
68
|
+
|
69
|
+
|
70
|
+
class HttpOptionsDict(TypedDict):
|
39
71
|
"""HTTP options for the api client."""
|
40
72
|
|
41
|
-
base_url: str = None
|
73
|
+
base_url: Optional[str] = None
|
42
74
|
"""The base URL for the AI platform service endpoint."""
|
43
|
-
api_version: str = None
|
75
|
+
api_version: Optional[str] = None
|
44
76
|
"""Specifies the version of the API to use."""
|
45
|
-
headers: dict[str,
|
77
|
+
headers: Optional[dict[str, str]] = None
|
46
78
|
"""Additional HTTP headers to be sent with the request."""
|
47
|
-
response_payload: dict = None
|
79
|
+
response_payload: Optional[dict] = None
|
48
80
|
"""If set, the response payload will be returned int the supplied dict."""
|
81
|
+
timeout: Optional[Union[float, Tuple[float, float]]] = None
|
82
|
+
"""Timeout for the request in seconds."""
|
83
|
+
skip_project_and_location_in_path: bool = False
|
84
|
+
"""If set to True, the project and location will not be appended to the path."""
|
85
|
+
|
86
|
+
HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
|
49
87
|
|
50
88
|
|
51
89
|
def _append_library_version_headers(headers: dict[str, str]) -> None:
|
52
90
|
"""Appends the telemetry header to the headers dict."""
|
53
|
-
|
54
|
-
library_label = f'google-genai-sdk/0.3.0'
|
91
|
+
library_label = f'google-genai-sdk/{version.__version__}'
|
55
92
|
language_label = 'gl-python/' + sys.version.split()[0]
|
56
93
|
version_header_value = f'{library_label} {language_label}'
|
57
94
|
if (
|
@@ -71,20 +108,24 @@ def _append_library_version_headers(headers: dict[str, str]) -> None:
|
|
71
108
|
|
72
109
|
|
73
110
|
def _patch_http_options(
|
74
|
-
options:
|
75
|
-
) ->
|
111
|
+
options: HttpOptionsDict, patch_options: HttpOptionsDict
|
112
|
+
) -> HttpOptionsDict:
|
76
113
|
# use shallow copy so we don't override the original objects.
|
77
|
-
copy_option =
|
114
|
+
copy_option = HttpOptionsDict()
|
78
115
|
copy_option.update(options)
|
79
|
-
for
|
116
|
+
for patch_key, patch_value in patch_options.items():
|
80
117
|
# if both are dicts, update the copy.
|
81
118
|
# This is to handle cases like merging headers.
|
82
|
-
if isinstance(
|
83
|
-
|
84
|
-
|
85
|
-
copy_option[
|
86
|
-
|
87
|
-
|
119
|
+
if isinstance(patch_value, dict) and isinstance(
|
120
|
+
copy_option.get(patch_key, None), dict
|
121
|
+
):
|
122
|
+
copy_option[patch_key] = {}
|
123
|
+
copy_option[patch_key].update(
|
124
|
+
options[patch_key]
|
125
|
+
) # shallow copy from original options.
|
126
|
+
copy_option[patch_key].update(patch_value)
|
127
|
+
elif patch_value is not None: # Accept empty values.
|
128
|
+
copy_option[patch_key] = patch_value
|
88
129
|
_append_library_version_headers(copy_option['headers'])
|
89
130
|
return copy_option
|
90
131
|
|
@@ -102,6 +143,7 @@ class HttpRequest:
|
|
102
143
|
url: str
|
103
144
|
method: str
|
104
145
|
data: Union[dict[str, object], bytes]
|
146
|
+
timeout: Optional[Union[float, Tuple[float, float]]] = None
|
105
147
|
|
106
148
|
|
107
149
|
class HttpResponse:
|
@@ -147,7 +189,7 @@ class ApiClient:
|
|
147
189
|
credentials: google.auth.credentials.Credentials = None,
|
148
190
|
project: Union[str, None] = None,
|
149
191
|
location: Union[str, None] = None,
|
150
|
-
http_options:
|
192
|
+
http_options: HttpOptionsOrDict = None,
|
151
193
|
):
|
152
194
|
self.vertexai = vertexai
|
153
195
|
if self.vertexai is None:
|
@@ -159,30 +201,84 @@ class ApiClient:
|
|
159
201
|
|
160
202
|
# Validate explicitly set initializer values.
|
161
203
|
if (project or location) and api_key:
|
204
|
+
# API cannot consume both project/location and api_key.
|
162
205
|
raise ValueError(
|
163
206
|
'Project/location and API key are mutually exclusive in the client initializer.'
|
164
207
|
)
|
208
|
+
elif credentials and api_key:
|
209
|
+
# API cannot consume both credentials and api_key.
|
210
|
+
raise ValueError(
|
211
|
+
'Credentials and API key are mutually exclusive in the client initializer.'
|
212
|
+
)
|
213
|
+
|
214
|
+
# Validate http_options if a dict is provided.
|
215
|
+
if isinstance(http_options, dict):
|
216
|
+
try:
|
217
|
+
HttpOptions.model_validate(http_options)
|
218
|
+
except ValidationError as e:
|
219
|
+
raise ValueError(f'Invalid http_options: {e}')
|
220
|
+
elif(isinstance(http_options, HttpOptions)):
|
221
|
+
http_options = http_options.model_dump()
|
222
|
+
|
223
|
+
# Retrieve implicitly set values from the environment.
|
224
|
+
env_project = os.environ.get('GOOGLE_CLOUD_PROJECT', None)
|
225
|
+
env_location = os.environ.get('GOOGLE_CLOUD_LOCATION', None)
|
226
|
+
env_api_key = os.environ.get('GOOGLE_API_KEY', None)
|
227
|
+
self.project = project or env_project
|
228
|
+
self.location = location or env_location
|
229
|
+
self.api_key = api_key or env_api_key
|
165
230
|
|
166
|
-
self.api_key: Optional[str] = None
|
167
|
-
self.project = project or os.environ.get('GOOGLE_CLOUD_PROJECT', None)
|
168
|
-
self.location = location or os.environ.get('GOOGLE_CLOUD_LOCATION', None)
|
169
231
|
self._credentials = credentials
|
170
|
-
self._http_options =
|
232
|
+
self._http_options = HttpOptionsDict()
|
171
233
|
|
234
|
+
# Handle when to use Vertex AI in express mode (api key).
|
235
|
+
# Explicit initializer arguments are already validated above.
|
172
236
|
if self.vertexai:
|
173
|
-
if
|
237
|
+
if credentials:
|
238
|
+
# Explicit credentials take precedence over implicit api_key.
|
239
|
+
logging.info(
|
240
|
+
'The user provided Google Cloud credentials will take precedence'
|
241
|
+
+ ' over the API key from the environment variable.'
|
242
|
+
)
|
243
|
+
self.api_key = None
|
244
|
+
elif (env_location or env_project) and api_key:
|
245
|
+
# Explicit api_key takes precedence over implicit project/location.
|
246
|
+
logging.info(
|
247
|
+
'The user provided Vertex AI API key will take precedence over the'
|
248
|
+
+ ' project/location from the environment variables.'
|
249
|
+
)
|
250
|
+
self.project = None
|
251
|
+
self.location = None
|
252
|
+
elif (project or location) and env_api_key:
|
253
|
+
# Explicit project/location takes precedence over implicit api_key.
|
254
|
+
logging.info(
|
255
|
+
'The user provided project/location will take precedence over the'
|
256
|
+
+ ' Vertex AI API key from the environment variable.'
|
257
|
+
)
|
258
|
+
self.api_key = None
|
259
|
+
elif (env_location or env_project) and env_api_key:
|
260
|
+
# Implicit project/location takes precedence over implicit api_key.
|
261
|
+
logging.info(
|
262
|
+
'The project/location from the environment variables will take'
|
263
|
+
+ ' precedence over the API key from the environment variables.'
|
264
|
+
)
|
265
|
+
self.api_key = None
|
266
|
+
if not self.project and not self.api_key:
|
174
267
|
self.project = google.auth.default()[1]
|
175
|
-
|
176
|
-
if not self.project or not self.location:
|
268
|
+
if not (self.project or self.location) and not self.api_key:
|
177
269
|
raise ValueError(
|
178
|
-
'Project
|
270
|
+
'Project/location or API key must be set when using the Vertex AI API.'
|
271
|
+
)
|
272
|
+
if self.api_key:
|
273
|
+
self._http_options['base_url'] = (
|
274
|
+
f'https://aiplatform.googleapis.com/'
|
275
|
+
)
|
276
|
+
else:
|
277
|
+
self._http_options['base_url'] = (
|
278
|
+
f'https://{self.location}-aiplatform.googleapis.com/'
|
179
279
|
)
|
180
|
-
self._http_options['base_url'] = (
|
181
|
-
f'https://{self.location}-aiplatform.googleapis.com/'
|
182
|
-
)
|
183
280
|
self._http_options['api_version'] = 'v1beta1'
|
184
281
|
else: # ML Dev API
|
185
|
-
self.api_key = api_key or os.environ.get('GOOGLE_API_KEY', None)
|
186
282
|
if not self.api_key:
|
187
283
|
raise ValueError('API key must be set when using the Google AI API.')
|
188
284
|
self._http_options['base_url'] = (
|
@@ -191,7 +287,7 @@ class ApiClient:
|
|
191
287
|
self._http_options['api_version'] = 'v1beta'
|
192
288
|
# Default options for both clients.
|
193
289
|
self._http_options['headers'] = {'Content-Type': 'application/json'}
|
194
|
-
if self.api_key:
|
290
|
+
if self.api_key and not self.vertexai:
|
195
291
|
self._http_options['headers']['x-goog-api-key'] = self.api_key
|
196
292
|
# Update the http options with the user provided http options.
|
197
293
|
if http_options:
|
@@ -208,7 +304,7 @@ class ApiClient:
|
|
208
304
|
http_method: str,
|
209
305
|
path: str,
|
210
306
|
request_dict: dict[str, object],
|
211
|
-
http_options:
|
307
|
+
http_options: HttpOptionsDict = None,
|
212
308
|
) -> HttpRequest:
|
213
309
|
# Remove all special dict keys such as _url and _query.
|
214
310
|
keys_to_delete = [key for key in request_dict.keys() if key.startswith('_')]
|
@@ -221,8 +317,18 @@ class ApiClient:
|
|
221
317
|
)
|
222
318
|
else:
|
223
319
|
patched_http_options = self._http_options
|
224
|
-
|
320
|
+
skip_project_and_location_in_path_val = patched_http_options.get(
|
321
|
+
'skip_project_and_location_in_path', False
|
322
|
+
)
|
323
|
+
if (
|
324
|
+
self.vertexai
|
325
|
+
and not path.startswith('projects/')
|
326
|
+
and not skip_project_and_location_in_path_val
|
327
|
+
and not self.api_key
|
328
|
+
):
|
225
329
|
path = f'projects/{self.project}/locations/{self.location}/' + path
|
330
|
+
elif self.vertexai and self.api_key:
|
331
|
+
path = f'{path}?key={self.api_key}'
|
226
332
|
url = _join_url_path(
|
227
333
|
patched_http_options['base_url'],
|
228
334
|
patched_http_options['api_version'] + '/' + path,
|
@@ -232,6 +338,7 @@ class ApiClient:
|
|
232
338
|
url=url,
|
233
339
|
headers=patched_http_options['headers'],
|
234
340
|
data=request_dict,
|
341
|
+
timeout=patched_http_options.get('timeout', None),
|
235
342
|
)
|
236
343
|
|
237
344
|
def _request(
|
@@ -239,7 +346,7 @@ class ApiClient:
|
|
239
346
|
http_request: HttpRequest,
|
240
347
|
stream: bool = False,
|
241
348
|
) -> HttpResponse:
|
242
|
-
if self.vertexai:
|
349
|
+
if self.vertexai and not self.api_key:
|
243
350
|
if not self._credentials:
|
244
351
|
self._credentials, _ = google.auth.default(
|
245
352
|
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
@@ -250,10 +357,10 @@ class ApiClient:
|
|
250
357
|
http_request.method.upper(),
|
251
358
|
http_request.url,
|
252
359
|
headers=http_request.headers,
|
253
|
-
data=json.dumps(http_request.data, cls=RequestJsonEncoder)
|
254
|
-
|
255
|
-
|
256
|
-
timeout=
|
360
|
+
data=json.dumps(http_request.data, cls=RequestJsonEncoder)
|
361
|
+
if http_request.data
|
362
|
+
else None,
|
363
|
+
timeout=http_request.timeout,
|
257
364
|
)
|
258
365
|
errors.APIError.raise_for_response(response)
|
259
366
|
return HttpResponse(
|
@@ -275,13 +382,14 @@ class ApiClient:
|
|
275
382
|
data = http_request.data
|
276
383
|
|
277
384
|
http_session = requests.Session()
|
278
|
-
|
385
|
+
response = http_session.request(
|
279
386
|
method=http_request.method,
|
280
387
|
url=http_request.url,
|
281
388
|
headers=http_request.headers,
|
282
389
|
data=data,
|
283
|
-
|
284
|
-
|
390
|
+
timeout=http_request.timeout,
|
391
|
+
stream=stream,
|
392
|
+
)
|
285
393
|
errors.APIError.raise_for_response(response)
|
286
394
|
return HttpResponse(
|
287
395
|
response.headers, response if stream else [response.text]
|
@@ -307,8 +415,10 @@ class ApiClient:
|
|
307
415
|
stream=stream,
|
308
416
|
)
|
309
417
|
|
310
|
-
def get_read_only_http_options(self) ->
|
311
|
-
copied =
|
418
|
+
def get_read_only_http_options(self) -> HttpOptionsDict:
|
419
|
+
copied = HttpOptionsDict()
|
420
|
+
if isinstance(self._http_options, BaseModel):
|
421
|
+
self._http_options = self._http_options.model_dump()
|
312
422
|
copied.update(self._http_options)
|
313
423
|
return copied
|
314
424
|
|
@@ -317,7 +427,7 @@ class ApiClient:
|
|
317
427
|
http_method: str,
|
318
428
|
path: str,
|
319
429
|
request_dict: dict[str, object],
|
320
|
-
http_options:
|
430
|
+
http_options: HttpOptionsDict = None,
|
321
431
|
):
|
322
432
|
http_request = self._build_request(
|
323
433
|
http_method, path, request_dict, http_options
|
@@ -332,7 +442,7 @@ class ApiClient:
|
|
332
442
|
http_method: str,
|
333
443
|
path: str,
|
334
444
|
request_dict: dict[str, object],
|
335
|
-
http_options:
|
445
|
+
http_options: HttpOptionsDict = None,
|
336
446
|
):
|
337
447
|
http_request = self._build_request(
|
338
448
|
http_method, path, request_dict, http_options
|
@@ -349,7 +459,7 @@ class ApiClient:
|
|
349
459
|
http_method: str,
|
350
460
|
path: str,
|
351
461
|
request_dict: dict[str, object],
|
352
|
-
http_options:
|
462
|
+
http_options: HttpOptionsDict = None,
|
353
463
|
) -> dict[str, object]:
|
354
464
|
http_request = self._build_request(
|
355
465
|
http_method, path, request_dict, http_options
|
@@ -365,7 +475,7 @@ class ApiClient:
|
|
365
475
|
http_method: str,
|
366
476
|
path: str,
|
367
477
|
request_dict: dict[str, object],
|
368
|
-
http_options:
|
478
|
+
http_options: HttpOptionsDict = None,
|
369
479
|
):
|
370
480
|
http_request = self._build_request(
|
371
481
|
http_method, path, request_dict, http_options
|
@@ -464,13 +574,12 @@ class ApiClient:
|
|
464
574
|
pass
|
465
575
|
|
466
576
|
|
577
|
+
# TODO(b/389693448): Cleanup datetime hacks.
|
467
578
|
class RequestJsonEncoder(json.JSONEncoder):
|
468
579
|
"""Encode bytes as strings without modify its content."""
|
469
580
|
|
470
581
|
def default(self, o):
|
471
|
-
if isinstance(o,
|
472
|
-
return o.decode()
|
473
|
-
elif isinstance(o, datetime.datetime):
|
582
|
+
if isinstance(o, datetime.datetime):
|
474
583
|
# This Zulu time format is used by the Vertex AI API and the test recorder
|
475
584
|
# Using strftime works well, but we want to align with the replay encoder.
|
476
585
|
# o.astimezone(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%fZ')
|
google/genai/_automatic_function_calling_util.py
CHANGED
@@ -58,8 +58,8 @@ def _raise_for_nullable_if_mldev(schema: types.Schema):
|
|
58
58
|
)
|
59
59
|
|
60
60
|
|
61
|
-
def _raise_if_schema_unsupported(
|
62
|
-
if not
|
61
|
+
def _raise_if_schema_unsupported(variant: str, schema: types.Schema):
|
62
|
+
if not variant == 'VERTEX_AI':
|
63
63
|
_raise_for_any_of_if_mldev(schema)
|
64
64
|
_raise_for_default_if_mldev(schema)
|
65
65
|
_raise_for_nullable_if_mldev(schema)
|
@@ -112,7 +112,7 @@ def _is_default_value_compatible(
|
|
112
112
|
|
113
113
|
|
114
114
|
def _parse_schema_from_parameter(
|
115
|
-
|
115
|
+
variant: str, param: inspect.Parameter, func_name: str
|
116
116
|
) -> types.Schema:
|
117
117
|
"""parse schema from parameter.
|
118
118
|
|
@@ -130,7 +130,7 @@ def _parse_schema_from_parameter(
|
|
130
130
|
raise ValueError(default_value_error_msg)
|
131
131
|
schema.default = param.default
|
132
132
|
schema.type = _py_builtin_type_to_schema_type[param.annotation]
|
133
|
-
_raise_if_schema_unsupported(
|
133
|
+
_raise_if_schema_unsupported(variant, schema)
|
134
134
|
return schema
|
135
135
|
if (
|
136
136
|
isinstance(param.annotation, typing_types.UnionType)
|
@@ -149,7 +149,7 @@ def _parse_schema_from_parameter(
|
|
149
149
|
schema.nullable = True
|
150
150
|
continue
|
151
151
|
schema_in_any_of = _parse_schema_from_parameter(
|
152
|
-
|
152
|
+
variant,
|
153
153
|
inspect.Parameter(
|
154
154
|
'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
|
155
155
|
),
|
@@ -171,7 +171,7 @@ def _parse_schema_from_parameter(
|
|
171
171
|
if not _is_default_value_compatible(param.default, param.annotation):
|
172
172
|
raise ValueError(default_value_error_msg)
|
173
173
|
schema.default = param.default
|
174
|
-
_raise_if_schema_unsupported(
|
174
|
+
_raise_if_schema_unsupported(variant, schema)
|
175
175
|
return schema
|
176
176
|
if isinstance(param.annotation, _GenericAlias) or isinstance(
|
177
177
|
param.annotation, typing_types.GenericAlias
|
@@ -184,7 +184,7 @@ def _parse_schema_from_parameter(
|
|
184
184
|
if not _is_default_value_compatible(param.default, param.annotation):
|
185
185
|
raise ValueError(default_value_error_msg)
|
186
186
|
schema.default = param.default
|
187
|
-
_raise_if_schema_unsupported(
|
187
|
+
_raise_if_schema_unsupported(variant, schema)
|
188
188
|
return schema
|
189
189
|
if origin is Literal:
|
190
190
|
if not all(isinstance(arg, str) for arg in args):
|
@@ -197,12 +197,12 @@ def _parse_schema_from_parameter(
|
|
197
197
|
if not _is_default_value_compatible(param.default, param.annotation):
|
198
198
|
raise ValueError(default_value_error_msg)
|
199
199
|
schema.default = param.default
|
200
|
-
_raise_if_schema_unsupported(
|
200
|
+
_raise_if_schema_unsupported(variant, schema)
|
201
201
|
return schema
|
202
202
|
if origin is list:
|
203
203
|
schema.type = 'ARRAY'
|
204
204
|
schema.items = _parse_schema_from_parameter(
|
205
|
-
|
205
|
+
variant,
|
206
206
|
inspect.Parameter(
|
207
207
|
'item',
|
208
208
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -214,7 +214,7 @@ def _parse_schema_from_parameter(
|
|
214
214
|
if not _is_default_value_compatible(param.default, param.annotation):
|
215
215
|
raise ValueError(default_value_error_msg)
|
216
216
|
schema.default = param.default
|
217
|
-
_raise_if_schema_unsupported(
|
217
|
+
_raise_if_schema_unsupported(variant, schema)
|
218
218
|
return schema
|
219
219
|
if origin is Union:
|
220
220
|
schema.any_of = []
|
@@ -225,7 +225,7 @@ def _parse_schema_from_parameter(
|
|
225
225
|
schema.nullable = True
|
226
226
|
continue
|
227
227
|
schema_in_any_of = _parse_schema_from_parameter(
|
228
|
-
|
228
|
+
variant,
|
229
229
|
inspect.Parameter(
|
230
230
|
'item',
|
231
231
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -249,7 +249,7 @@ def _parse_schema_from_parameter(
|
|
249
249
|
if not _is_default_value_compatible(param.default, param.annotation):
|
250
250
|
raise ValueError(default_value_error_msg)
|
251
251
|
schema.default = param.default
|
252
|
-
_raise_if_schema_unsupported(
|
252
|
+
_raise_if_schema_unsupported(variant, schema)
|
253
253
|
return schema
|
254
254
|
# all other generic alias will be invoked in raise branch
|
255
255
|
if (
|
@@ -266,7 +266,7 @@ def _parse_schema_from_parameter(
|
|
266
266
|
schema.properties = {}
|
267
267
|
for field_name, field_info in param.annotation.model_fields.items():
|
268
268
|
schema.properties[field_name] = _parse_schema_from_parameter(
|
269
|
-
|
269
|
+
variant,
|
270
270
|
inspect.Parameter(
|
271
271
|
field_name,
|
272
272
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -274,7 +274,7 @@ def _parse_schema_from_parameter(
|
|
274
274
|
),
|
275
275
|
func_name,
|
276
276
|
)
|
277
|
-
_raise_if_schema_unsupported(
|
277
|
+
_raise_if_schema_unsupported(variant, schema)
|
278
278
|
return schema
|
279
279
|
raise ValueError(
|
280
280
|
f'Failed to parse the parameter {param} of function {func_name} for'
|
google/genai/_common.py
CHANGED
@@ -17,7 +17,6 @@
|
|
17
17
|
|
18
18
|
import base64
|
19
19
|
import datetime
|
20
|
-
import json
|
21
20
|
import typing
|
22
21
|
from typing import Union
|
23
22
|
import uuid
|
@@ -116,7 +115,7 @@ def get_value_by_path(data: object, keys: list[str]):
|
|
116
115
|
class BaseModule:
|
117
116
|
|
118
117
|
def __init__(self, api_client_: _api_client.ApiClient):
|
119
|
-
self.
|
118
|
+
self._api_client = api_client_
|
120
119
|
|
121
120
|
|
122
121
|
def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
|
@@ -190,6 +189,8 @@ class BaseModel(pydantic.BaseModel):
|
|
190
189
|
extra='forbid',
|
191
190
|
# This allows us to use arbitrary types in the model. E.g. PIL.Image.
|
192
191
|
arbitrary_types_allowed=True,
|
192
|
+
ser_json_bytes='base64',
|
193
|
+
val_json_bytes='base64',
|
193
194
|
)
|
194
195
|
|
195
196
|
@classmethod
|
@@ -201,7 +202,10 @@ class BaseModel(pydantic.BaseModel):
|
|
201
202
|
# We will provide another mechanism to allow users to access these fields.
|
202
203
|
_remove_extra_fields(cls, response)
|
203
204
|
validated_response = cls.model_validate(response)
|
204
|
-
return
|
205
|
+
return validated_response
|
206
|
+
|
207
|
+
def to_json_dict(self) -> dict[str, object]:
|
208
|
+
return self.model_dump(exclude_none=True, mode='json')
|
205
209
|
|
206
210
|
|
207
211
|
def timestamped_unique_name() -> str:
|
@@ -217,40 +221,21 @@ def timestamped_unique_name() -> str:
|
|
217
221
|
|
218
222
|
def apply_base64_encoding(data: dict[str, object]) -> dict[str, object]:
|
219
223
|
"""Applies base64 encoding to bytes values in the given data."""
|
220
|
-
return process_bytes_fields(data, encode=True)
|
221
|
-
|
222
|
-
|
223
|
-
def apply_base64_decoding(data: dict[str, object]) -> dict[str, object]:
|
224
|
-
"""Applies base64 decoding to bytes values in the given data."""
|
225
|
-
return process_bytes_fields(data, encode=False)
|
226
|
-
|
227
|
-
|
228
|
-
def apply_base64_decoding_for_model(data: BaseModel) -> BaseModel:
|
229
|
-
d = data.model_dump(exclude_none=True)
|
230
|
-
d = apply_base64_decoding(d)
|
231
|
-
return data.model_validate(d)
|
232
|
-
|
233
|
-
|
234
|
-
def process_bytes_fields(data: dict[str, object], encode=True) -> dict[str, object]:
|
235
224
|
processed_data = {}
|
236
225
|
if not isinstance(data, dict):
|
237
226
|
return data
|
238
227
|
for key, value in data.items():
|
239
228
|
if isinstance(value, bytes):
|
240
|
-
|
241
|
-
processed_data[key] = base64.b64encode(value)
|
242
|
-
else:
|
243
|
-
processed_data[key] = base64.b64decode(value)
|
229
|
+
processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
|
244
230
|
elif isinstance(value, dict):
|
245
|
-
processed_data[key] =
|
231
|
+
processed_data[key] = apply_base64_encoding(value)
|
246
232
|
elif isinstance(value, list):
|
247
|
-
if
|
248
|
-
processed_data[key] = [
|
249
|
-
|
250
|
-
|
233
|
+
if all(isinstance(v, bytes) for v in value):
|
234
|
+
processed_data[key] = [
|
235
|
+
base64.urlsafe_b64encode(v).decode('ascii') for v in value
|
236
|
+
]
|
251
237
|
else:
|
252
|
-
processed_data[key] = [
|
238
|
+
processed_data[key] = [apply_base64_encoding(v) for v in value]
|
253
239
|
else:
|
254
240
|
processed_data[key] = value
|
255
241
|
return processed_data
|
256
|
-
|