google-genai 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +20 -0
- google/genai/_api_client.py +467 -0
- google/genai/_automatic_function_calling_util.py +341 -0
- google/genai/_common.py +256 -0
- google/genai/_extra_utils.py +295 -0
- google/genai/_replay_api_client.py +478 -0
- google/genai/_test_api_client.py +149 -0
- google/genai/_transformers.py +438 -0
- google/genai/batches.py +1041 -0
- google/genai/caches.py +1830 -0
- google/genai/chats.py +184 -0
- google/genai/client.py +277 -0
- google/genai/errors.py +110 -0
- google/genai/files.py +1211 -0
- google/genai/live.py +629 -0
- google/genai/models.py +5307 -0
- google/genai/pagers.py +245 -0
- google/genai/tunings.py +1366 -0
- google/genai/types.py +7639 -0
- google_genai-0.0.1.dist-info/LICENSE +202 -0
- google_genai-0.0.1.dist-info/METADATA +763 -0
- google_genai-0.0.1.dist-info/RECORD +24 -0
- google_genai-0.0.1.dist-info/WHEEL +5 -0
- google_genai-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,341 @@
|
|
1
|
+
# Copyright 2024 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
import inspect
|
17
|
+
import types as typing_types
|
18
|
+
from typing import Any, Callable, Literal, Union, _GenericAlias, get_args, get_origin
|
19
|
+
import pydantic
|
20
|
+
from . import types
|
21
|
+
|
22
|
+
_py_builtin_type_to_schema_type = {
|
23
|
+
str: 'STRING',
|
24
|
+
int: 'INTEGER',
|
25
|
+
float: 'NUMBER',
|
26
|
+
bool: 'BOOLEAN',
|
27
|
+
list: 'ARRAY',
|
28
|
+
dict: 'OBJECT',
|
29
|
+
}
|
30
|
+
|
31
|
+
|
32
|
+
def _is_builtin_primitive_or_compound(
|
33
|
+
annotation: inspect.Parameter.annotation,
|
34
|
+
) -> bool:
|
35
|
+
return annotation in _py_builtin_type_to_schema_type.keys()
|
36
|
+
|
37
|
+
|
38
|
+
def _raise_for_any_of_if_mldev(schema: types.Schema):
  """Rejects schemas that use `any_of`, which Google AI does not accept.

  Raises:
    ValueError: If `schema.any_of` is set and non-empty.
  """
  if not schema.any_of:
    return
  raise ValueError(
      'AnyOf is not supported in function declaration schema for Google AI.'
  )
|
43
|
+
|
44
|
+
|
45
|
+
def _raise_for_default_if_mldev(schema: types.Schema):
  """Rejects schemas that carry a default value, unsupported by Google AI.

  Raises:
    ValueError: If `schema.default` is not None.
  """
  if schema.default is None:
    return
  raise ValueError(
      'Default value is not supported in function declaration schema for'
      ' Google AI.'
  )
|
51
|
+
|
52
|
+
|
53
|
+
def _raise_for_nullable_if_mldev(schema: types.Schema):
  """Rejects nullable schemas, which Google AI does not accept.

  Raises:
    ValueError: If `schema.nullable` is truthy.
  """
  if not schema.nullable:
    return
  raise ValueError(
      'Nullable is not supported in function declaration schema for'
      ' Google AI.'
  )
|
59
|
+
|
60
|
+
|
61
|
+
def _raise_if_schema_unsupported(client, schema: types.Schema):
  """Checks `schema` for features the Google AI backend does not support.

  No check is performed when `client.vertexai` is truthy. Otherwise `any_of`,
  `default` and `nullable` are rejected, checked in that order.

  Raises:
    ValueError: From the first failing check.
  """
  if client.vertexai:
    return
  for check in (
      _raise_for_any_of_if_mldev,
      _raise_for_default_if_mldev,
      _raise_for_nullable_if_mldev,
  ):
    check(schema)
|
66
|
+
|
67
|
+
|
68
|
+
def _is_default_value_compatible(
    default_value: Any, annotation: Any
) -> bool:
  """Reports whether `default_value` is a valid value for `annotation`.

  Supported annotations are:
  - the builtin primitives/compounds;
  - `Union[...]` and `X | Y`;
  - `dict[...]`, where only the container type is checked;
  - `list[...]`, where every element is checked;
  - `Literal[...]`.

  Following the PEP 484 numeric tower, an `int` default is accepted for a
  `float` annotation. `bool` is not accepted for `int`/`float`, even though
  `bool` subclasses `int`.

  A `None` default is expected to be handled by the caller.

  Returns:
    True if compatible. False otherwise, including for unrecognized
    annotations; the caller decides whether to raise.
  """
  if _is_builtin_primitive_or_compound(annotation):
    if annotation in (int, float) and isinstance(default_value, bool):
      # isinstance(True, int) is True; reject it so the default matches the
      # numeric schema type.
      return False
    if annotation is float:
      return isinstance(default_value, (int, float))
    return isinstance(default_value, annotation)

  # types.UnionType (the `X | Y` syntax) only exists on Python 3.10+.
  pep604_union = getattr(typing_types, 'UnionType', None)
  is_pep604_union = pep604_union is not None and isinstance(
      annotation, pep604_union
  )
  if not (
      is_pep604_union
      or isinstance(annotation, (_GenericAlias, typing_types.GenericAlias))
  ):
    # Unrecognized annotation; let the caller handle the raise.
    return False

  origin = get_origin(annotation)
  args = get_args(annotation)

  if origin is Union or (pep604_union is not None and origin is pep604_union):
    return any(_is_default_value_compatible(default_value, arg) for arg in args)

  if origin is dict:
    return isinstance(default_value, dict)

  if origin is list:
    if not isinstance(default_value, list):
      return False
    # Every element must match one of the element types, e.g.
    # typing.List[int | str | float | bool] accepts [1, 'a', 1.1, True]
    # (see test_generic_alias_complex_array_with_default_value).
    return all(
        any(_is_default_value_compatible(item, arg) for arg in args)
        for item in default_value
    )

  if origin is Literal:
    return default_value in args

  # Return False for any other generic alias; let the caller handle the raise.
  return False
|
112
|
+
|
113
|
+
|
114
|
+
def _set_default_if_compatible(
    schema: types.Schema,
    param: inspect.Parameter,
    error_msg: str,
    *,
    skip_none: bool = False,
) -> None:
  """Validates `param.default` against `param.annotation` and stores it.

  Does nothing when the parameter has no default. Also does nothing when
  `skip_none` is set and the default is `None`, which is how nullable unions
  are handled.

  Raises:
    ValueError: With `error_msg`, if the default does not match the
      annotation.
  """
  default = param.default
  if default is inspect.Parameter.empty or (skip_none and default is None):
    return
  if not _is_default_value_compatible(default, param.annotation):
    raise ValueError(error_msg)
  schema.default = default


def _parse_union_schema(
    client,
    param: inspect.Parameter,
    func_name: str,
    default_value_error_msg: str,
) -> types.Schema:
  """Parses a `Union[...]` / `X | Y` parameter into an `any_of` schema.

  Handling of the union members:
  - `None` members mark the schema nullable instead of adding an `any_of`
    entry.
  - Structurally identical member schemas are kept only once.
  - If exactly one member schema remains (e.g. `list[int] | None`), that
    member's schema is used directly, keeping its `items`/`enum`/`properties`.
    It is marked nullable when `None` was a member.
  """
  schema = types.Schema()
  schema.type = 'OBJECT'
  schema.any_of = []
  seen_member_json = set()
  for arg in get_args(param.annotation):
    if arg is type(None):  # Optional type
      schema.nullable = True
      continue
    member = _parse_schema_from_parameter(
        client,
        inspect.Parameter(
            'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
        ),
        func_name,
    )
    # Deduplicate by the member's JSON form.
    member_json = member.model_dump_json(exclude_none=True)
    if member_json not in seen_member_json:
      schema.any_of.append(member)
      seen_member_json.add(member_json)
  if len(schema.any_of) == 1:  # e.g. param: list[int] | None -> Array
    nullable = schema.nullable
    schema = schema.any_of[0]
    if nullable:
      schema.nullable = True
  _set_default_if_compatible(
      schema, param, default_value_error_msg, skip_none=True
  )
  _raise_if_schema_unsupported(client, schema)
  return schema


def _parse_schema_from_parameter(
    client, param: inspect.Parameter, func_name: str
) -> types.Schema:
  """Parses a Schema from a parameter's annotation and default value.

  Annotations are tried from the simplest case to the most complex one:
  1. builtin primitives/compounds;
  2. simple `X | Y` unions of builtins;
  3. typing generics (`dict`, `Literal`, `list`, `Union`);
  4. Pydantic models.

  Args:
    client: The API client. Its `vertexai` flag decides which schema features
      are allowed.
    param: The parameter to parse, or a synthetic parameter for a nested type.
    func_name: Name of the function being declared; used in error messages.

  Returns:
    The parsed `types.Schema`.

  Raises:
    ValueError: If the annotation is unsupported, the default value does not
      match the annotation, or the schema uses a feature that the current
      backend does not support.
  """
  schema = types.Schema()
  default_value_error_msg = (
      f'Default value {param.default} of parameter {param} of function'
      f' {func_name} is not compatible with the parameter annotation'
      f' {param.annotation}.'
  )
  annotation = param.annotation

  if _is_builtin_primitive_or_compound(annotation):
    _set_default_if_compatible(schema, param, default_value_error_msg)
    schema.type = _py_builtin_type_to_schema_type[annotation]
    _raise_if_schema_unsupported(client, schema)
    return schema

  # types.UnionType (the `X | Y` syntax) only exists on Python 3.10+.
  pep604_union = getattr(typing_types, 'UnionType', None)
  if (
      pep604_union is not None
      and isinstance(annotation, pep604_union)
      # only parse simple UnionType, example int | str | float | bool
      # complex types.UnionType will be invoked in raise branch
      and all(
          (_is_builtin_primitive_or_compound(arg) or arg is type(None))
          for arg in get_args(annotation)
      )
  ):
    return _parse_union_schema(
        client, param, func_name, default_value_error_msg
    )

  if isinstance(annotation, (_GenericAlias, typing_types.GenericAlias)):
    origin = get_origin(annotation)
    args = get_args(annotation)
    if origin is dict:
      schema.type = 'OBJECT'
      _set_default_if_compatible(schema, param, default_value_error_msg)
      _raise_if_schema_unsupported(client, schema)
      return schema
    if origin is Literal:
      if not all(isinstance(arg, str) for arg in args):
        raise ValueError(
            f'Literal type {annotation} must be a list of strings.'
        )
      schema.type = 'STRING'
      schema.enum = list(args)
      _set_default_if_compatible(schema, param, default_value_error_msg)
      _raise_if_schema_unsupported(client, schema)
      return schema
    if origin is list:
      schema.type = 'ARRAY'
      schema.items = _parse_schema_from_parameter(
          client,
          inspect.Parameter(
              'item',
              inspect.Parameter.POSITIONAL_OR_KEYWORD,
              annotation=args[0],
          ),
          func_name,
      )
      _set_default_if_compatible(schema, param, default_value_error_msg)
      _raise_if_schema_unsupported(client, schema)
      return schema
    if origin is Union:
      return _parse_union_schema(
          client, param, func_name, default_value_error_msg
      )
    # all other generic alias will be invoked in raise branch

  if (
      inspect.isclass(annotation)
      # On some Python versions inspect.isclass(list[int]) is True, but
      # issubclass() rejects such aliases with TypeError.
      and not isinstance(annotation, typing_types.GenericAlias)
      # for user defined class, we only support pydantic model
      and issubclass(annotation, pydantic.BaseModel)
  ):
    if param.default is not inspect.Parameter.empty:
      # TODO: b/379715133 - handle pydantic model default value
      raise ValueError(
          f'Default value {param.default} of Pydantic model {param} of'
          f' function {func_name} is not supported.'
      )
    schema.type = 'OBJECT'
    schema.properties = {
        field_name: _parse_schema_from_parameter(
            client,
            inspect.Parameter(
                field_name,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=field_info.annotation,
            ),
            func_name,
        )
        for field_name, field_info in annotation.model_fields.items()
    }
    _raise_if_schema_unsupported(client, schema)
    return schema

  raise ValueError(
      f'Failed to parse the parameter {param} of function {func_name} for'
      ' automatic function calling. Automatic function calling works best with'
      ' simpler function signature schema, consider manually parsing your'
      f' function declaration for function {func_name}.'
  )
|
287
|
+
|
288
|
+
|
289
|
+
def _get_required_fields(schema: types.Schema) -> Union[list[str], None]:
  """Lists the property names of `schema` that must be supplied.

  A property is required when it is not nullable and its schema stores no
  default value.

  Returns:
    The required property names, or None if `schema` has no properties.
  """
  if not schema.properties:
    return None
  return [
      field_name
      for field_name, field_schema in schema.properties.items()
      if not field_schema.nullable and field_schema.default is None
  ]
|
297
|
+
|
298
|
+
|
299
|
+
def function_to_declaration(
    client, func: Callable
) -> types.FunctionDeclaration:
  """Converts a function to a FunctionDeclaration.

  Each positional-only, positional-or-keyword and keyword-only parameter
  becomes a property of an OBJECT `parameters` schema. `*args` and `**kwargs`
  are skipped.

  When `client.vertexai` is truthy, the declaration also gets:
  - the list of required parameters;
  - a `response` schema, when the function has a return annotation other
    than `None`.

  Raises:
    ValueError: If a parameter or return annotation cannot be converted.
  """
  signature = inspect.signature(func)
  parameters_properties = {}
  for name, param in signature.parameters.items():
    if param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.POSITIONAL_ONLY,
    ):
      parameters_properties[name] = _parse_schema_from_parameter(
          client, param, func.__name__
      )
  declaration = types.FunctionDeclaration(
      name=func.__name__,
      description=func.__doc__,
  )
  if parameters_properties:
    declaration.parameters = types.Schema(
        type='OBJECT',
        properties=parameters_properties,
    )
    if client.vertexai:
      declaration.parameters.required = _get_required_fields(
          declaration.parameters
      )
  if not client.vertexai:
    return declaration

  return_annotation = signature.return_annotation
  # No annotation, or an explicit `-> None`, leaves no response to describe.
  # Without this check, `None` would hit the unsupported-annotation error.
  if (
      return_annotation is inspect.Signature.empty
      or return_annotation is None
      or return_annotation is type(None)
  ):
    return declaration

  declaration.response = _parse_schema_from_parameter(
      client,
      inspect.Parameter(
          'return_value',
          inspect.Parameter.POSITIONAL_OR_KEYWORD,
          annotation=return_annotation,
      ),
      func.__name__,
  )
  return declaration
|
google/genai/_common.py
ADDED
@@ -0,0 +1,256 @@
|
|
1
|
+
# Copyright 2024 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
"""Common utilities for the SDK."""
|
17
|
+
|
18
|
+
import base64
|
19
|
+
import datetime
|
20
|
+
import json
|
21
|
+
import typing
|
22
|
+
from typing import Union
|
23
|
+
import uuid
|
24
|
+
|
25
|
+
import pydantic
|
26
|
+
from pydantic import alias_generators
|
27
|
+
|
28
|
+
from . import _api_client
|
29
|
+
|
30
|
+
|
31
|
+
def set_value_by_path(data, keys, value):
  """Writes `value` into the nested dict `data` along the path `keys`.

  Intermediate dicts are created as needed. A `None` value is ignored.

  A key ending in `[]` addresses a list of dicts, and the remaining path is
  applied to every element. When `value` is a list, the j-th element receives
  `value[j]`; otherwise every element receives `value`.

  Examples:

    set_value_by_path({}, ['a', 'b'], v)
      -> {'a': {'b': v}}
    set_value_by_path({}, ['a', 'b[]', 'c'], [v1, v2])
      -> {'a': {'b': [{'c': v1}, {'c': v2}]}}
    set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
      -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}

  If the final key already holds a non-None value, the new value is merged
  instead of overwriting it:
  - an empty (falsy) or equal new value is ignored;
  - two dicts are merged shallowly;
  - any other combination raises.

  Raises:
    ValueError: If a missing `[]` list would have to be created from a
      non-list value, or the final key already holds a different value that
      cannot be merged.
  """
  if value is None:
    return
  node = data
  for depth, key in enumerate(keys[:-1]):
    if not key.endswith('[]'):
      node = node.setdefault(key, {})
      continue

    key_name = key[:-2]
    rest = keys[depth + 1 :]
    if key_name not in node:
      if not isinstance(value, list):
        raise ValueError(
            f'value {value} must be a list given an array path {key}'
        )
      node[key_name] = [{} for _ in value]
    elements = node[key_name]
    if isinstance(value, list):
      # Pair the j-th element with the j-th item of value.
      for j, element in enumerate(elements):
        set_value_by_path(element, rest, value[j])
    else:
      # Broadcast the same value into every element.
      for element in elements:
        set_value_by_path(element, rest, value)
    return

  last_key = keys[-1]
  existing_data = node.get(last_key)
  if existing_data is None:
    node[last_key] = value
  elif not value or value == existing_data:
    # Keep the existing value: the new one is empty (triggered when handling
    # tuning datasets) or identical.
    pass
  elif isinstance(existing_data, dict) and isinstance(value, dict):
    # Merge instead of overwriting; needed for training and validation
    # datasets in tuning. Shallow merge only.
    existing_data.update(value)
  else:
    raise ValueError(
        f'Cannot set value for an existing key. Key: {keys[-1]};'
        f' Existing value: {existing_data}; New value: {value}.'
    )
|
85
|
+
|
86
|
+
|
87
|
+
def get_value_by_path(data: object, keys: list[str]):
  """Reads the value at path `keys` from nested dicts and SDK models.

  Examples:

    get_value_by_path({'a': {'b': v}}, ['a', 'b'])
      -> v
    get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
      -> [v1, v2]

  Path rules:
  - The special path `['_self']` returns `data` itself.
  - A key ending in `[]` names a list inside a dict; the remaining keys are
    read from every element.
  - Plain keys are looked up in dicts, or read as attributes of `BaseModel`
    instances.

  Returns:
    The value found, or None if a step is missing or the container reached
    before a key is falsy.
  """
  if keys == ['_self']:
    return data
  for i, key in enumerate(keys):
    if not data:
      return None
    if key.endswith('[]'):
      key_name = key[:-2]
      if isinstance(data, dict) and key_name in data:
        return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
      return None
    if isinstance(data, dict):
      # Only dicts are indexed: `key in data` on a str or list would match
      # and then `data[key]` would raise TypeError.
      if key not in data:
        return None
      data = data[key]
    elif isinstance(data, BaseModel) and hasattr(data, key):
      data = getattr(data, key)
    else:
      return None
  return data
|
114
|
+
|
115
|
+
|
116
|
+
class BaseModule:
  """Base class that keeps the API client it was built with."""

  def __init__(self, api_client_: _api_client.ApiClient):
    """Stores `api_client_` as `self.api_client`."""
    client = api_client_
    self.api_client = client
|
120
|
+
|
121
|
+
|
122
|
+
def convert_to_dict(obj: dict[str, object]) -> dict[str, object]:
  """Recursively converts a given object to plain Python containers.

  Conversion rules:
  - Pydantic models are dumped with `model_dump(exclude_none=True)`.
  - Dicts and lists are rebuilt, converting every value recursively.
  - Any other object is returned unchanged.

  Args:
    obj: The object to convert.

  Returns:
    A dictionary representation of the object. Lists and other objects come
    back as the converted list and the object itself, respectively.
  """
  if isinstance(obj, pydantic.BaseModel):
    return obj.model_dump(exclude_none=True)
  if isinstance(obj, dict):
    converted = {}
    for name, item in obj.items():
      converted[name] = convert_to_dict(item)
    return converted
  if isinstance(obj, list):
    return list(map(convert_to_dict, obj))
  return obj
|
141
|
+
|
142
|
+
|
143
|
+
def _pydantic_model_class(annotation) -> typing.Optional[type]:
  """Returns `annotation` if it is a Pydantic model class, else None."""
  if (
      typing.get_origin(annotation) is None
      and isinstance(annotation, type)
      and issubclass(annotation, pydantic.BaseModel)
  ):
    return annotation
  return None


def _remove_extra_fields(
    model: pydantic.BaseModel, response: dict[str, object]
) -> None:
  """Removes extra fields from the response that are not in the model.

  Mutates `response` in place. A key is kept when it is a field name or a
  field alias of `model`.

  Nested dicts, and dicts inside lists, are cleaned recursively when the
  field's annotation resolves to a Pydantic model class. Other values, e.g.
  `dict[str, Any]` payloads such as `FunctionCall.args`, are left untouched.

  Args:
    model: The Pydantic model class that `response` should conform to.
    response: The raw response dict to prune.
  """
  fields = model.model_fields
  # Response keys may be aliases (e.g. camelCase `usageMetadata`) or field
  # names; map aliases back to field names. Built once, outside the loop.
  alias_map = {
      field_info.alias: field_name
      for field_name, field_info in fields.items()
  }

  for key, value in list(response.items()):
    field_name = alias_map.get(key, key)
    if field_name not in fields:
      response.pop(key)
      continue

    annotation = fields[field_name].annotation
    # Get the BaseModel if Optional: use the first non-None union member.
    if typing.get_origin(annotation) is Union:
      annotation = next(
          (a for a in typing.get_args(annotation) if a is not type(None)),
          None,
      )

    if isinstance(value, dict):
      nested_model = _pydantic_model_class(annotation)
      if nested_model is not None:
        _remove_extra_fields(nested_model, value)
    elif isinstance(value, list):
      item_args = typing.get_args(annotation)
      item_model = _pydantic_model_class(item_args[0]) if item_args else None
      if item_model is not None:
        for item in value:
          if isinstance(item, dict):
            _remove_extra_fields(item_model, item)
|
181
|
+
|
182
|
+
|
183
|
+
class BaseModel(pydantic.BaseModel):
  """Base class for the SDK's Pydantic data types.

  Field names are snake_case in Python and get camelCase aliases from
  `alias_generators.to_camel`. With `populate_by_name`, either form is
  accepted on input.
  """

  model_config = pydantic.ConfigDict(
      alias_generator=alias_generators.to_camel,
      populate_by_name=True,
      from_attributes=True,
      # Disable Pydantic's protected `model_` prefix so fields may use it.
      # Pydantic documents this option as a tuple of prefixes.
      protected_namespaces=(),
      extra='forbid',
      # This allows us to use arbitrary types in the model. E.g. PIL.Image.
      arbitrary_types_allowed=True,
  )

  @classmethod
  def _from_response(
      cls, response: dict[str, object], kwargs: dict[str, object]
  ) -> 'BaseModel':
    """Builds an instance of `cls` from a raw API response dict.

    Args:
      response: The response payload. It is mutated in place: keys the model
        does not declare are removed.
      kwargs: Not used by this implementation.

    Returns:
      A validated instance whose bytes fields have been base64-decoded.
    """
    # To maintain forward compatibility, we need to remove extra fields from
    # the response; `extra='forbid'` would otherwise reject them.
    # We will provide another mechanism to allow users to access these fields.
    _remove_extra_fields(cls, response)
    validated_response = cls.model_validate(response)
    return apply_base64_decoding_for_model(validated_response)
|
205
|
+
|
206
|
+
|
207
|
+
def timestamped_unique_name() -> str:
  """Composes a name of the form `<YYYYmmddHHMMSS>_<5 hex chars>`.

  The timestamp is the current local time. The suffix is the first five hex
  characters of a random UUID4.

  Returns:
    A string representing a unique name.
  """
  return '_'.join((
      datetime.datetime.now().strftime('%Y%m%d%H%M%S'),
      uuid.uuid4().hex[:5],
  ))
|
216
|
+
|
217
|
+
|
218
|
+
def apply_base64_encoding(data: dict[str, object]) -> dict[str, object]:
  """Base64-encodes the bytes values in `data` via `process_bytes_fields`."""
  encode_mode = True
  return process_bytes_fields(data, encode=encode_mode)
|
221
|
+
|
222
|
+
|
223
|
+
def apply_base64_decoding(data: dict[str, object]) -> dict[str, object]:
  """Base64-decodes the bytes values in `data` via `process_bytes_fields`."""
  encode_mode = False
  return process_bytes_fields(data, encode=encode_mode)
|
226
|
+
|
227
|
+
|
228
|
+
def apply_base64_decoding_for_model(data: BaseModel) -> BaseModel:
  """Returns a new model instance with `data`'s bytes fields base64-decoded.

  The steps are:
  1. Dump the model, dropping None fields.
  2. Decode every bytes value.
  3. Validate the result back into `data`'s model class.
  """
  decoded = apply_base64_decoding(data.model_dump(exclude_none=True))
  return data.model_validate(decoded)
|
232
|
+
|
233
|
+
|
234
|
+
def process_bytes_fields(data: dict[str, object], encode=True) -> dict[str, object]:
  """Base64-encodes or base64-decodes every bytes value nested in `data`.

  Recurses through dict values and list items, including nested lists and
  lists that mix bytes with other values. Non-bytes leaves are kept as is.

  Args:
    data: The dict to process. Any non-dict input is returned unchanged.
    encode: If True, apply `base64.b64encode`; otherwise apply
      `base64.b64decode`.

  Returns:
    A new dict with the processed values. The input dict is not modified.
  """
  if not isinstance(data, dict):
    return data
  transform = base64.b64encode if encode else base64.b64decode

  def _process(value):
    # Handle a single value of any nesting depth.
    if isinstance(value, bytes):
      return transform(value)
    if isinstance(value, dict):
      return process_bytes_fields(value, encode)
    if isinstance(value, list):
      return [_process(item) for item in value]
    return value

  return {key: _process(value) for key, value in data.items()}
|
256
|
+
|