google-genai 1.0.0rc0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +45 -35
- google/genai/_automatic_function_calling_util.py +21 -18
- google/genai/_common.py +24 -1
- google/genai/_extra_utils.py +14 -8
- google/genai/_replay_api_client.py +2 -0
- google/genai/_transformers.py +59 -11
- google/genai/caches.py +4 -2
- google/genai/chats.py +24 -8
- google/genai/client.py +13 -12
- google/genai/errors.py +4 -0
- google/genai/files.py +18 -12
- google/genai/live.py +5 -0
- google/genai/models.py +321 -28
- google/genai/tunings.py +224 -60
- google/genai/types.py +105 -78
- google/genai/version.py +1 -1
- {google_genai-1.0.0rc0.dist-info → google_genai-1.2.0.dist-info}/METADATA +293 -149
- google_genai-1.2.0.dist-info/RECORD +27 -0
- google_genai-1.0.0rc0.dist-info/RECORD +0 -27
- {google_genai-1.0.0rc0.dist-info → google_genai-1.2.0.dist-info}/LICENSE +0 -0
- {google_genai-1.0.0rc0.dist-info → google_genai-1.2.0.dist-info}/WHEEL +0 -0
- {google_genai-1.0.0rc0.dist-info → google_genai-1.2.0.dist-info}/top_level.txt +0 -0
google/genai/_api_client.py
CHANGED
@@ -85,7 +85,11 @@ def _patch_http_options(
|
|
85
85
|
|
86
86
|
def _join_url_path(base_url: str, path: str) -> str:
|
87
87
|
parsed_base = urlparse(base_url)
|
88
|
-
base_path =
|
88
|
+
base_path = (
|
89
|
+
parsed_base.path[:-1]
|
90
|
+
if parsed_base.path.endswith('/')
|
91
|
+
else parsed_base.path
|
92
|
+
)
|
89
93
|
path = path[1:] if path.startswith('/') else path
|
90
94
|
return urlunparse(parsed_base._replace(path=base_path + '/' + path))
|
91
95
|
|
@@ -99,6 +103,19 @@ class HttpRequest:
|
|
99
103
|
timeout: Optional[float] = None
|
100
104
|
|
101
105
|
|
106
|
+
# TODO(b/394358912): Update this class to use a SDKResponse class that can be
|
107
|
+
# generated and used for all languages.
|
108
|
+
@dataclass
|
109
|
+
class BaseResponse:
|
110
|
+
http_headers: dict[str, str]
|
111
|
+
|
112
|
+
@property
|
113
|
+
def dict(self) -> dict[str, Any]:
|
114
|
+
if isinstance(self, dict):
|
115
|
+
return self
|
116
|
+
return {'httpHeaders': self.http_headers}
|
117
|
+
|
118
|
+
|
102
119
|
class HttpResponse:
|
103
120
|
|
104
121
|
def __init__(
|
@@ -188,12 +205,14 @@ class ApiClient:
|
|
188
205
|
if (project or location) and api_key:
|
189
206
|
# API cannot consume both project/location and api_key.
|
190
207
|
raise ValueError(
|
191
|
-
'Project/location and API key are mutually exclusive in the client
|
208
|
+
'Project/location and API key are mutually exclusive in the client'
|
209
|
+
' initializer.'
|
192
210
|
)
|
193
211
|
elif credentials and api_key:
|
194
212
|
# API cannot consume both credentials and api_key.
|
195
213
|
raise ValueError(
|
196
|
-
'Credentials and API key are mutually exclusive in the client
|
214
|
+
'Credentials and API key are mutually exclusive in the client'
|
215
|
+
' initializer.'
|
197
216
|
)
|
198
217
|
|
199
218
|
# Validate http_options if a dict is provided.
|
@@ -202,7 +221,7 @@ class ApiClient:
|
|
202
221
|
HttpOptions.model_validate(http_options)
|
203
222
|
except ValidationError as e:
|
204
223
|
raise ValueError(f'Invalid http_options: {e}')
|
205
|
-
elif
|
224
|
+
elif isinstance(http_options, HttpOptions):
|
206
225
|
http_options = http_options.model_dump()
|
207
226
|
|
208
227
|
# Retrieve implicitly set values from the environment.
|
@@ -256,17 +275,19 @@ class ApiClient:
|
|
256
275
|
'AI API.'
|
257
276
|
)
|
258
277
|
if self.api_key or self.location == 'global':
|
259
|
-
self._http_options['base_url'] =
|
260
|
-
f'https://aiplatform.googleapis.com/'
|
261
|
-
)
|
278
|
+
self._http_options['base_url'] = f'https://aiplatform.googleapis.com/'
|
262
279
|
else:
|
263
280
|
self._http_options['base_url'] = (
|
264
281
|
f'https://{self.location}-aiplatform.googleapis.com/'
|
265
282
|
)
|
266
283
|
self._http_options['api_version'] = 'v1beta1'
|
267
|
-
else: #
|
284
|
+
else: # Implicit initialization or missing arguments.
|
268
285
|
if not self.api_key:
|
269
|
-
raise ValueError(
|
286
|
+
raise ValueError(
|
287
|
+
'Missing key inputs argument! To use the Google AI API,'
|
288
|
+
'provide (`api_key`) arguments. To use the Google Cloud API,'
|
289
|
+
' provide (`vertexai`, `project` & `location`) arguments.'
|
290
|
+
)
|
270
291
|
self._http_options['base_url'] = (
|
271
292
|
'https://generativelanguage.googleapis.com/'
|
272
293
|
)
|
@@ -350,7 +371,7 @@ class ApiClient:
|
|
350
371
|
if self.vertexai and not self.api_key:
|
351
372
|
if not self._credentials:
|
352
373
|
self._credentials, _ = google.auth.default(
|
353
|
-
scopes=[
|
374
|
+
scopes=['https://www.googleapis.com/auth/cloud-platform'],
|
354
375
|
)
|
355
376
|
authed_session = AuthorizedSession(self._credentials)
|
356
377
|
authed_session.stream = stream
|
@@ -358,9 +379,7 @@ class ApiClient:
|
|
358
379
|
http_request.method.upper(),
|
359
380
|
http_request.url,
|
360
381
|
headers=http_request.headers,
|
361
|
-
data=json.dumps(http_request.data)
|
362
|
-
if http_request.data
|
363
|
-
else None,
|
382
|
+
data=json.dumps(http_request.data) if http_request.data else None,
|
364
383
|
timeout=http_request.timeout,
|
365
384
|
)
|
366
385
|
errors.APIError.raise_for_response(response)
|
@@ -402,7 +421,7 @@ class ApiClient:
|
|
402
421
|
if self.vertexai:
|
403
422
|
if not self._credentials:
|
404
423
|
self._credentials, _ = google.auth.default(
|
405
|
-
scopes=[
|
424
|
+
scopes=['https://www.googleapis.com/auth/cloud-platform'],
|
406
425
|
)
|
407
426
|
return await asyncio.to_thread(
|
408
427
|
self._request,
|
@@ -434,18 +453,12 @@ class ApiClient:
|
|
434
453
|
http_method, path, request_dict, http_options
|
435
454
|
)
|
436
455
|
response = self._request(http_request, stream=False)
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
elif (
|
444
|
-
isinstance(http_options, dict)
|
445
|
-
and 'deprecated_response_payload' in http_options
|
446
|
-
):
|
447
|
-
response._copy_to_dict(http_options['deprecated_response_payload'])
|
448
|
-
return response.json
|
456
|
+
json_response = response.json
|
457
|
+
if not json_response:
|
458
|
+
base_response = BaseResponse(response.headers).dict
|
459
|
+
return base_response
|
460
|
+
|
461
|
+
return json_response
|
449
462
|
|
450
463
|
def request_streamed(
|
451
464
|
self,
|
@@ -459,10 +472,6 @@ class ApiClient:
|
|
459
472
|
)
|
460
473
|
|
461
474
|
session_response = self._request(http_request, stream=True)
|
462
|
-
if http_options and 'deprecated_response_payload' in http_options:
|
463
|
-
session_response._copy_to_dict(
|
464
|
-
http_options['deprecated_response_payload']
|
465
|
-
)
|
466
475
|
for chunk in session_response.segments():
|
467
476
|
yield chunk
|
468
477
|
|
@@ -478,9 +487,11 @@ class ApiClient:
|
|
478
487
|
)
|
479
488
|
|
480
489
|
result = await self._async_request(http_request=http_request, stream=False)
|
481
|
-
|
482
|
-
|
483
|
-
|
490
|
+
json_response = result.json
|
491
|
+
if not json_response:
|
492
|
+
base_response = BaseResponse(result.headers).dict
|
493
|
+
return base_response
|
494
|
+
return json_response
|
484
495
|
|
485
496
|
async def async_request_streamed(
|
486
497
|
self,
|
@@ -495,11 +506,10 @@ class ApiClient:
|
|
495
506
|
|
496
507
|
response = await self._async_request(http_request=http_request, stream=True)
|
497
508
|
|
498
|
-
if http_options and 'deprecated_response_payload' in http_options:
|
499
|
-
response._copy_to_dict(http_options['deprecated_response_payload'])
|
500
509
|
async def async_generator():
|
501
510
|
async for chunk in response:
|
502
511
|
yield chunk
|
512
|
+
|
503
513
|
return async_generator()
|
504
514
|
|
505
515
|
def upload_file(
|
@@ -17,10 +17,14 @@ import inspect
|
|
17
17
|
import sys
|
18
18
|
import types as builtin_types
|
19
19
|
import typing
|
20
|
-
from typing import Any, Callable,
|
20
|
+
from typing import _GenericAlias, Any, Callable, get_args, get_origin, Literal, Union
|
21
|
+
|
21
22
|
import pydantic
|
23
|
+
|
24
|
+
from . import _extra_utils
|
22
25
|
from . import types
|
23
26
|
|
27
|
+
|
24
28
|
if sys.version_info >= (3, 10):
|
25
29
|
VersionedUnionType = builtin_types.UnionType
|
26
30
|
else:
|
@@ -58,8 +62,8 @@ def _raise_for_default_if_mldev(schema: types.Schema):
|
|
58
62
|
)
|
59
63
|
|
60
64
|
|
61
|
-
def _raise_if_schema_unsupported(
|
62
|
-
if
|
65
|
+
def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
|
66
|
+
if api_option == 'GEMINI_API':
|
63
67
|
_raise_for_any_of_if_mldev(schema)
|
64
68
|
_raise_for_default_if_mldev(schema)
|
65
69
|
|
@@ -110,7 +114,7 @@ def _is_default_value_compatible(
|
|
110
114
|
|
111
115
|
|
112
116
|
def _parse_schema_from_parameter(
|
113
|
-
|
117
|
+
api_option: Literal['VERTEX_AI', 'GEMINI_API'],
|
114
118
|
param: inspect.Parameter,
|
115
119
|
func_name: str,
|
116
120
|
) -> types.Schema:
|
@@ -130,7 +134,7 @@ def _parse_schema_from_parameter(
|
|
130
134
|
raise ValueError(default_value_error_msg)
|
131
135
|
schema.default = param.default
|
132
136
|
schema.type = _py_builtin_type_to_schema_type[param.annotation]
|
133
|
-
_raise_if_schema_unsupported(
|
137
|
+
_raise_if_schema_unsupported(api_option, schema)
|
134
138
|
return schema
|
135
139
|
if (
|
136
140
|
isinstance(param.annotation, VersionedUnionType)
|
@@ -149,7 +153,7 @@ def _parse_schema_from_parameter(
|
|
149
153
|
schema.nullable = True
|
150
154
|
continue
|
151
155
|
schema_in_any_of = _parse_schema_from_parameter(
|
152
|
-
|
156
|
+
api_option,
|
153
157
|
inspect.Parameter(
|
154
158
|
'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
|
155
159
|
),
|
@@ -171,7 +175,7 @@ def _parse_schema_from_parameter(
|
|
171
175
|
if not _is_default_value_compatible(param.default, param.annotation):
|
172
176
|
raise ValueError(default_value_error_msg)
|
173
177
|
schema.default = param.default
|
174
|
-
_raise_if_schema_unsupported(
|
178
|
+
_raise_if_schema_unsupported(api_option, schema)
|
175
179
|
return schema
|
176
180
|
if isinstance(param.annotation, _GenericAlias) or isinstance(
|
177
181
|
param.annotation, builtin_types.GenericAlias
|
@@ -184,7 +188,7 @@ def _parse_schema_from_parameter(
|
|
184
188
|
if not _is_default_value_compatible(param.default, param.annotation):
|
185
189
|
raise ValueError(default_value_error_msg)
|
186
190
|
schema.default = param.default
|
187
|
-
_raise_if_schema_unsupported(
|
191
|
+
_raise_if_schema_unsupported(api_option, schema)
|
188
192
|
return schema
|
189
193
|
if origin is Literal:
|
190
194
|
if not all(isinstance(arg, str) for arg in args):
|
@@ -197,12 +201,12 @@ def _parse_schema_from_parameter(
|
|
197
201
|
if not _is_default_value_compatible(param.default, param.annotation):
|
198
202
|
raise ValueError(default_value_error_msg)
|
199
203
|
schema.default = param.default
|
200
|
-
_raise_if_schema_unsupported(
|
204
|
+
_raise_if_schema_unsupported(api_option, schema)
|
201
205
|
return schema
|
202
206
|
if origin is list:
|
203
207
|
schema.type = 'ARRAY'
|
204
208
|
schema.items = _parse_schema_from_parameter(
|
205
|
-
|
209
|
+
api_option,
|
206
210
|
inspect.Parameter(
|
207
211
|
'item',
|
208
212
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -214,7 +218,7 @@ def _parse_schema_from_parameter(
|
|
214
218
|
if not _is_default_value_compatible(param.default, param.annotation):
|
215
219
|
raise ValueError(default_value_error_msg)
|
216
220
|
schema.default = param.default
|
217
|
-
_raise_if_schema_unsupported(
|
221
|
+
_raise_if_schema_unsupported(api_option, schema)
|
218
222
|
return schema
|
219
223
|
if origin is Union:
|
220
224
|
schema.any_of = []
|
@@ -229,7 +233,7 @@ def _parse_schema_from_parameter(
|
|
229
233
|
schema.nullable = True
|
230
234
|
continue
|
231
235
|
schema_in_any_of = _parse_schema_from_parameter(
|
232
|
-
|
236
|
+
api_option,
|
233
237
|
inspect.Parameter(
|
234
238
|
'item',
|
235
239
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -264,13 +268,12 @@ def _parse_schema_from_parameter(
|
|
264
268
|
if not _is_default_value_compatible(param.default, param.annotation):
|
265
269
|
raise ValueError(default_value_error_msg)
|
266
270
|
schema.default = param.default
|
267
|
-
_raise_if_schema_unsupported(
|
271
|
+
_raise_if_schema_unsupported(api_option, schema)
|
268
272
|
return schema
|
269
273
|
# all other generic alias will be invoked in raise branch
|
270
274
|
if (
|
271
|
-
inspect.isclass(param.annotation)
|
272
275
|
# for user defined class, we only support pydantic model
|
273
|
-
|
276
|
+
_extra_utils.is_annotation_pydantic_model(param.annotation)
|
274
277
|
):
|
275
278
|
if (
|
276
279
|
param.default is not inspect.Parameter.empty
|
@@ -281,7 +284,7 @@ def _parse_schema_from_parameter(
|
|
281
284
|
schema.properties = {}
|
282
285
|
for field_name, field_info in param.annotation.model_fields.items():
|
283
286
|
schema.properties[field_name] = _parse_schema_from_parameter(
|
284
|
-
|
287
|
+
api_option,
|
285
288
|
inspect.Parameter(
|
286
289
|
field_name,
|
287
290
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
@@ -289,9 +292,9 @@ def _parse_schema_from_parameter(
|
|
289
292
|
),
|
290
293
|
func_name,
|
291
294
|
)
|
292
|
-
if
|
295
|
+
if api_option == 'VERTEX_AI':
|
293
296
|
schema.required = _get_required_fields(schema)
|
294
|
-
_raise_if_schema_unsupported(
|
297
|
+
_raise_if_schema_unsupported(api_option, schema)
|
295
298
|
return schema
|
296
299
|
raise ValueError(
|
297
300
|
f'Failed to parse the parameter {param} of function {func_name} for'
|
google/genai/_common.py
CHANGED
@@ -18,6 +18,7 @@
|
|
18
18
|
import base64
|
19
19
|
import datetime
|
20
20
|
import enum
|
21
|
+
import functools
|
21
22
|
import typing
|
22
23
|
from typing import Union
|
23
24
|
import uuid
|
@@ -27,6 +28,7 @@ import pydantic
|
|
27
28
|
from pydantic import alias_generators
|
28
29
|
|
29
30
|
from . import _api_client
|
31
|
+
from . import errors
|
30
32
|
|
31
33
|
|
32
34
|
def set_value_by_path(data, keys, value):
|
@@ -218,7 +220,8 @@ class CaseInSensitiveEnum(str, enum.Enum):
|
|
218
220
|
warnings.warn(f"{value} is not a valid {cls.__name__}")
|
219
221
|
try:
|
220
222
|
# Creating a enum instance based on the value
|
221
|
-
|
223
|
+
# We need to use super() to avoid infinite recursion.
|
224
|
+
unknown_enum_val = super().__new__(cls, value)
|
222
225
|
unknown_enum_val._name_ = str(value) # pylint: disable=protected-access
|
223
226
|
unknown_enum_val._value_ = value # pylint: disable=protected-access
|
224
227
|
return unknown_enum_val
|
@@ -273,3 +276,23 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
|
|
273
276
|
else:
|
274
277
|
processed_data[key] = value
|
275
278
|
return processed_data
|
279
|
+
|
280
|
+
|
281
|
+
def experimental_warning(message: str):
|
282
|
+
"""Experimental warning, only warns once."""
|
283
|
+
def decorator(func):
|
284
|
+
warning_done = False
|
285
|
+
@functools.wraps(func)
|
286
|
+
def wrapper(*args, **kwargs):
|
287
|
+
nonlocal warning_done
|
288
|
+
if not warning_done:
|
289
|
+
warning_done = True
|
290
|
+
warnings.warn(
|
291
|
+
message=message,
|
292
|
+
category=errors.ExperimentalWarning,
|
293
|
+
stacklevel=2,
|
294
|
+
)
|
295
|
+
return func(*args, **kwargs)
|
296
|
+
return wrapper
|
297
|
+
return decorator
|
298
|
+
|
google/genai/_extra_utils.py
CHANGED
@@ -108,16 +108,22 @@ def convert_number_values_for_function_call_args(
|
|
108
108
|
return args
|
109
109
|
|
110
110
|
|
111
|
-
def
|
112
|
-
|
113
|
-
|
114
|
-
|
111
|
+
def is_annotation_pydantic_model(annotation: Any) -> bool:
|
112
|
+
try:
|
113
|
+
return inspect.isclass(annotation) and issubclass(
|
114
|
+
annotation, pydantic.BaseModel
|
115
|
+
)
|
116
|
+
# for python 3.10 and below, inspect.isclass(annotation) has inconsistent
|
117
|
+
# results with versions above. for example, inspect.isclass(dict[str, int]) is
|
118
|
+
# True in 3.10 and below but False in 3.11 and above.
|
119
|
+
except TypeError:
|
120
|
+
return False
|
115
121
|
|
116
122
|
|
117
123
|
def convert_if_exist_pydantic_model(
|
118
124
|
value: Any, annotation: Any, param_name: str, func_name: str
|
119
125
|
) -> Any:
|
120
|
-
if isinstance(value, dict) and
|
126
|
+
if isinstance(value, dict) and is_annotation_pydantic_model(annotation):
|
121
127
|
try:
|
122
128
|
return annotation(**value)
|
123
129
|
except pydantic.ValidationError as e:
|
@@ -146,7 +152,7 @@ def convert_if_exist_pydantic_model(
|
|
146
152
|
if (
|
147
153
|
(get_args(arg) and get_origin(arg) is list)
|
148
154
|
or isinstance(value, arg)
|
149
|
-
or (isinstance(value, dict) and
|
155
|
+
or (isinstance(value, dict) and is_annotation_pydantic_model(arg))
|
150
156
|
):
|
151
157
|
try:
|
152
158
|
return convert_if_exist_pydantic_model(
|
@@ -265,8 +271,8 @@ def should_disable_afc(
|
|
265
271
|
and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
|
266
272
|
):
|
267
273
|
logging.warning(
|
268
|
-
'`automatic_function_calling.disable` is set to `True`.
|
269
|
-
' `automatic_function_calling.maximum_remote_calls` is
|
274
|
+
'`automatic_function_calling.disable` is set to `True`. And'
|
275
|
+
' `automatic_function_calling.maximum_remote_calls` is a'
|
270
276
|
' positive number'
|
271
277
|
f' {config_model.automatic_function_calling.maximum_remote_calls}.'
|
272
278
|
' Disabling automatic function calling. If you want to enable'
|
@@ -362,6 +362,8 @@ class ReplayApiClient(ApiClient):
|
|
362
362
|
if self._should_update_replay():
|
363
363
|
if isinstance(response_model, list):
|
364
364
|
response_model = response_model[0]
|
365
|
+
if response_model and 'http_headers' in response_model.model_fields:
|
366
|
+
response_model.http_headers.pop('Date', None)
|
365
367
|
interaction.response.sdk_response_segments.append(
|
366
368
|
response_model.model_dump(exclude_none=True)
|
367
369
|
)
|
google/genai/_transformers.py
CHANGED
@@ -21,10 +21,10 @@ from enum import Enum, EnumMeta
|
|
21
21
|
import inspect
|
22
22
|
import io
|
23
23
|
import re
|
24
|
+
import sys
|
24
25
|
import time
|
25
26
|
import typing
|
26
27
|
from typing import Any, GenericAlias, Optional, Union
|
27
|
-
import sys
|
28
28
|
|
29
29
|
if typing.TYPE_CHECKING:
|
30
30
|
import PIL.Image
|
@@ -205,12 +205,17 @@ def t_caches_model(api_client: _api_client.ApiClient, model: str):
|
|
205
205
|
def pil_to_blob(img) -> types.Blob:
|
206
206
|
try:
|
207
207
|
import PIL.PngImagePlugin
|
208
|
+
|
208
209
|
PngImagePlugin = PIL.PngImagePlugin
|
209
210
|
except ImportError:
|
210
211
|
PngImagePlugin = None
|
211
212
|
|
212
213
|
bytesio = io.BytesIO()
|
213
|
-
if
|
214
|
+
if (
|
215
|
+
PngImagePlugin is not None
|
216
|
+
and isinstance(img, PngImagePlugin.PngImageFile)
|
217
|
+
or img.mode == 'RGBA'
|
218
|
+
):
|
214
219
|
img.save(bytesio, format='PNG')
|
215
220
|
mime_type = 'image/png'
|
216
221
|
else:
|
@@ -374,15 +379,18 @@ def handle_null_fields(schema: dict[str, Any]):
|
|
374
379
|
schema['anyOf'].remove({'type': 'null'})
|
375
380
|
if len(schema['anyOf']) == 1:
|
376
381
|
# If there is only one type left after removing null, remove the anyOf field.
|
377
|
-
|
378
|
-
|
382
|
+
for key,val in schema['anyOf'][0].items():
|
383
|
+
schema[key] = val
|
379
384
|
del schema['anyOf']
|
380
385
|
|
381
386
|
|
382
387
|
def process_schema(
|
383
388
|
schema: dict[str, Any],
|
384
389
|
client: Optional[_api_client.ApiClient] = None,
|
385
|
-
defs: Optional[dict[str, Any]]=None
|
390
|
+
defs: Optional[dict[str, Any]] = None,
|
391
|
+
*,
|
392
|
+
order_properties: bool = True,
|
393
|
+
):
|
386
394
|
"""Updates the schema and each sub-schema inplace to be API-compatible.
|
387
395
|
|
388
396
|
- Removes the `title` field from the schema if the client is not vertexai.
|
@@ -446,9 +454,17 @@ def process_schema(
|
|
446
454
|
|
447
455
|
if schema.get('default') is not None:
|
448
456
|
raise ValueError(
|
449
|
-
'Default value is not supported in the response schema for the
|
457
|
+
'Default value is not supported in the response schema for the Gemini API.'
|
450
458
|
)
|
451
459
|
|
460
|
+
if schema.get('title') == 'PlaceholderLiteralEnum':
|
461
|
+
schema.pop('title', None)
|
462
|
+
|
463
|
+
# If a dict is provided directly to response_schema, it may use `any_of`
|
464
|
+
# instead of `anyOf`. Otherwise model_json_schema() uses `anyOf`
|
465
|
+
if schema.get('any_of', None) is not None:
|
466
|
+
schema['anyOf'] = schema.pop('any_of')
|
467
|
+
|
452
468
|
if defs is None:
|
453
469
|
defs = schema.pop('$defs', {})
|
454
470
|
for _, sub_schema in defs.items():
|
@@ -456,6 +472,15 @@ def process_schema(
|
|
456
472
|
|
457
473
|
handle_null_fields(schema)
|
458
474
|
|
475
|
+
# After removing null fields, Optional fields with only one possible type
|
476
|
+
# will have a $ref key that needs to be flattened
|
477
|
+
# For example: {'default': None, 'description': 'Name of the person', 'nullable': True, '$ref': '#/$defs/TestPerson'}
|
478
|
+
if schema.get('$ref', None):
|
479
|
+
ref = defs[schema.get('$ref').split('defs/')[-1]]
|
480
|
+
for schema_key in list(ref.keys()):
|
481
|
+
schema[schema_key] = ref[schema_key]
|
482
|
+
del schema['$ref']
|
483
|
+
|
459
484
|
any_of = schema.get('anyOf', None)
|
460
485
|
if any_of is not None:
|
461
486
|
if not client.vertexai:
|
@@ -478,6 +503,16 @@ def process_schema(
|
|
478
503
|
schema_type = schema_type.value
|
479
504
|
schema_type = schema_type.upper()
|
480
505
|
|
506
|
+
# model_json_schema() returns a schema with a 'const' field when a Literal with one value is provided as a pydantic field
|
507
|
+
# For example `genre: Literal['action']` becomes: {'const': 'action', 'title': 'Genre', 'type': 'string'}
|
508
|
+
const = schema.get('const', None)
|
509
|
+
if const is not None:
|
510
|
+
if schema_type == 'STRING':
|
511
|
+
schema['enum'] = [const]
|
512
|
+
del schema['const']
|
513
|
+
else:
|
514
|
+
raise ValueError('Literal values must be strings.')
|
515
|
+
|
481
516
|
if schema_type == 'OBJECT':
|
482
517
|
properties = schema.get('properties', None)
|
483
518
|
if properties is None:
|
@@ -490,6 +525,16 @@ def process_schema(
|
|
490
525
|
ref = defs[ref_key.split('defs/')[-1]]
|
491
526
|
process_schema(ref, client, defs)
|
492
527
|
properties[name] = ref
|
528
|
+
if (
|
529
|
+
len(properties.items()) > 1
|
530
|
+
and order_properties
|
531
|
+
and all(
|
532
|
+
ordering_key not in schema
|
533
|
+
for ordering_key in ['property_ordering', 'propertyOrdering']
|
534
|
+
)
|
535
|
+
):
|
536
|
+
property_names = list(properties.keys())
|
537
|
+
schema['property_ordering'] = property_names
|
493
538
|
elif schema_type == 'ARRAY':
|
494
539
|
sub_schema = schema.get('items', None)
|
495
540
|
if sub_schema is None:
|
@@ -502,6 +547,7 @@ def process_schema(
|
|
502
547
|
process_schema(ref, client, defs)
|
503
548
|
schema['items'] = ref
|
504
549
|
|
550
|
+
|
505
551
|
def _process_enum(
|
506
552
|
enum: EnumMeta, client: Optional[_api_client.ApiClient] = None
|
507
553
|
) -> types.Schema:
|
@@ -511,6 +557,7 @@ def _process_enum(
|
|
511
557
|
f'Enum member {member.name} value must be a string, got'
|
512
558
|
f' {type(member.value)}'
|
513
559
|
)
|
560
|
+
|
514
561
|
class Placeholder(pydantic.BaseModel):
|
515
562
|
placeholder: enum
|
516
563
|
|
@@ -526,7 +573,7 @@ def t_schema(
|
|
526
573
|
if not origin:
|
527
574
|
return None
|
528
575
|
if isinstance(origin, dict):
|
529
|
-
process_schema(origin, client)
|
576
|
+
process_schema(origin, client, order_properties=False)
|
530
577
|
return types.Schema.model_validate(origin)
|
531
578
|
if isinstance(origin, EnumMeta):
|
532
579
|
return _process_enum(origin, client)
|
@@ -535,15 +582,15 @@ def t_schema(
|
|
535
582
|
# response_schema value was coerced to an empty Schema instance because it did not adhere to the Schema field annotation
|
536
583
|
raise ValueError(f'Unsupported schema type.')
|
537
584
|
schema = origin.model_dump(exclude_unset=True)
|
538
|
-
process_schema(schema, client)
|
585
|
+
process_schema(schema, client, order_properties=False)
|
539
586
|
return types.Schema.model_validate(schema)
|
540
587
|
|
541
588
|
if (
|
542
589
|
# in Python 3.9 Generic alias list[int] counts as a type,
|
543
590
|
# and breaks issubclass because it's not a class.
|
544
|
-
not isinstance(origin, GenericAlias)
|
545
|
-
isinstance(origin, type)
|
546
|
-
issubclass(origin, pydantic.BaseModel)
|
591
|
+
not isinstance(origin, GenericAlias)
|
592
|
+
and isinstance(origin, type)
|
593
|
+
and issubclass(origin, pydantic.BaseModel)
|
547
594
|
):
|
548
595
|
schema = origin.model_json_schema()
|
549
596
|
process_schema(schema, client)
|
@@ -554,6 +601,7 @@ def t_schema(
|
|
554
601
|
or isinstance(origin, VersionedUnionType)
|
555
602
|
or typing.get_origin(origin) in _UNION_TYPES
|
556
603
|
):
|
604
|
+
|
557
605
|
class Placeholder(pydantic.BaseModel):
|
558
606
|
placeholder: origin
|
559
607
|
|
google/genai/caches.py
CHANGED
@@ -172,8 +172,10 @@ def _Schema_to_mldev(
|
|
172
172
|
raise ValueError('example parameter is not supported in Gemini API.')
|
173
173
|
|
174
174
|
if getv(from_object, ['property_ordering']) is not None:
|
175
|
-
|
176
|
-
|
175
|
+
setv(
|
176
|
+
to_object,
|
177
|
+
['propertyOrdering'],
|
178
|
+
getv(from_object, ['property_ordering']),
|
177
179
|
)
|
178
180
|
|
179
181
|
if getv(from_object, ['pattern']) is not None:
|