google-genai 1.7.0__py3-none-any.whl → 1.53.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +4 -2
- google/genai/_adapters.py +55 -0
- google/genai/_api_client.py +1301 -299
- google/genai/_api_module.py +1 -1
- google/genai/_automatic_function_calling_util.py +54 -33
- google/genai/_base_transformers.py +26 -0
- google/genai/_base_url.py +50 -0
- google/genai/_common.py +560 -59
- google/genai/_extra_utils.py +371 -38
- google/genai/_live_converters.py +1467 -0
- google/genai/_local_tokenizer_loader.py +214 -0
- google/genai/_mcp_utils.py +117 -0
- google/genai/_operations_converters.py +394 -0
- google/genai/_replay_api_client.py +204 -92
- google/genai/_test_api_client.py +1 -1
- google/genai/_tokens_converters.py +520 -0
- google/genai/_transformers.py +633 -233
- google/genai/batches.py +1733 -538
- google/genai/caches.py +678 -1012
- google/genai/chats.py +48 -38
- google/genai/client.py +142 -15
- google/genai/documents.py +532 -0
- google/genai/errors.py +141 -35
- google/genai/file_search_stores.py +1296 -0
- google/genai/files.py +312 -744
- google/genai/live.py +617 -367
- google/genai/live_music.py +197 -0
- google/genai/local_tokenizer.py +395 -0
- google/genai/models.py +3598 -3116
- google/genai/operations.py +201 -362
- google/genai/pagers.py +23 -7
- google/genai/py.typed +1 -0
- google/genai/tokens.py +362 -0
- google/genai/tunings.py +1274 -496
- google/genai/types.py +14535 -5454
- google/genai/version.py +2 -2
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/METADATA +736 -234
- google_genai-1.53.0.dist-info/RECORD +41 -0
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/WHEEL +1 -1
- google_genai-1.7.0.dist-info/RECORD +0 -27
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info/licenses}/LICENSE +0 -0
- {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/top_level.txt +0 -0
google/genai/_common.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -16,22 +16,32 @@
|
|
|
16
16
|
"""Common utilities for the SDK."""
|
|
17
17
|
|
|
18
18
|
import base64
|
|
19
|
+
import collections.abc
|
|
19
20
|
import datetime
|
|
20
21
|
import enum
|
|
21
22
|
import functools
|
|
23
|
+
import logging
|
|
24
|
+
import re
|
|
22
25
|
import typing
|
|
23
|
-
from typing import Any, Union
|
|
26
|
+
from typing import Any, Callable, FrozenSet, Optional, Union, get_args, get_origin
|
|
24
27
|
import uuid
|
|
25
28
|
import warnings
|
|
26
|
-
|
|
27
29
|
import pydantic
|
|
28
30
|
from pydantic import alias_generators
|
|
31
|
+
from typing_extensions import TypeAlias
|
|
32
|
+
|
|
33
|
+
# Module-level logger; the helpers in this file use it to report type
# mismatches (e.g. in recursive_dict_update) instead of raising.
logger = logging.getLogger('google_genai._common')
|
|
34
|
+
|
|
35
|
+
# Alias for string-keyed dictionaries with arbitrary values, used by the
# key-alignment and recursive dict-update helpers.
StringDict: TypeAlias = dict[str, Any]
|
|
36
|
+
|
|
29
37
|
|
|
30
|
-
|
|
31
|
-
|
|
38
|
+
class ExperimentalWarning(Warning):
  """Warning category signalling that an experimental SDK feature was used.

  It is emitted through ``warnings.warn`` by the ``experimental_warning``
  decorator, so it can be silenced with
  ``warnings.simplefilter('ignore', ExperimentalWarning)``.
  """
|
|
32
40
|
|
|
33
41
|
|
|
34
|
-
def set_value_by_path(
|
|
42
|
+
def set_value_by_path(
|
|
43
|
+
data: Optional[dict[Any, Any]], keys: list[str], value: Any
|
|
44
|
+
) -> None:
|
|
35
45
|
"""Examples:
|
|
36
46
|
|
|
37
47
|
set_value_by_path({}, ['a', 'b'], v)
|
|
@@ -46,54 +56,66 @@ def set_value_by_path(data, keys, value):
|
|
|
46
56
|
for i, key in enumerate(keys[:-1]):
|
|
47
57
|
if key.endswith('[]'):
|
|
48
58
|
key_name = key[:-2]
|
|
49
|
-
if key_name not in data:
|
|
59
|
+
if data is not None and key_name not in data:
|
|
50
60
|
if isinstance(value, list):
|
|
51
61
|
data[key_name] = [{} for _ in range(len(value))]
|
|
52
62
|
else:
|
|
53
63
|
raise ValueError(
|
|
54
64
|
f'value {value} must be a list given an array path {key}'
|
|
55
65
|
)
|
|
56
|
-
if isinstance(value, list):
|
|
66
|
+
if isinstance(value, list) and data is not None:
|
|
57
67
|
for j, d in enumerate(data[key_name]):
|
|
58
68
|
set_value_by_path(d, keys[i + 1 :], value[j])
|
|
59
69
|
else:
|
|
60
|
-
|
|
61
|
-
|
|
70
|
+
if data is not None:
|
|
71
|
+
for d in data[key_name]:
|
|
72
|
+
set_value_by_path(d, keys[i + 1 :], value)
|
|
62
73
|
return
|
|
63
74
|
elif key.endswith('[0]'):
|
|
64
75
|
key_name = key[:-3]
|
|
65
|
-
if key_name not in data:
|
|
76
|
+
if data is not None and key_name not in data:
|
|
66
77
|
data[key_name] = [{}]
|
|
67
|
-
|
|
78
|
+
if data is not None:
|
|
79
|
+
set_value_by_path(data[key_name][0], keys[i + 1 :], value)
|
|
68
80
|
return
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
81
|
+
if data is not None:
|
|
82
|
+
data = data.setdefault(key, {})
|
|
83
|
+
|
|
84
|
+
if data is not None:
|
|
85
|
+
existing_data = data.get(keys[-1])
|
|
86
|
+
# If there is an existing value, merge, not overwrite.
|
|
87
|
+
if existing_data is not None:
|
|
88
|
+
# Don't overwrite existing non-empty value with new empty value.
|
|
89
|
+
# This is triggered when handling tuning datasets.
|
|
90
|
+
if not value:
|
|
91
|
+
pass
|
|
92
|
+
# Don't fail when overwriting value with same value
|
|
93
|
+
elif value == existing_data:
|
|
94
|
+
pass
|
|
95
|
+
# Instead of overwriting dictionary with another dictionary, merge them.
|
|
96
|
+
# This is important for handling training and validation datasets in tuning.
|
|
97
|
+
elif isinstance(existing_data, dict) and isinstance(value, dict):
|
|
98
|
+
# Merging dictionaries. Consider deep merging in the future.
|
|
99
|
+
existing_data.update(value)
|
|
100
|
+
else:
|
|
101
|
+
raise ValueError(
|
|
102
|
+
f'Cannot set value for an existing key. Key: {keys[-1]};'
|
|
103
|
+
f' Existing value: {existing_data}; New value: {value}.'
|
|
104
|
+
)
|
|
87
105
|
else:
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
106
|
+
if (
|
|
107
|
+
keys[-1] == '_self'
|
|
108
|
+
and isinstance(data, dict)
|
|
109
|
+
and isinstance(value, dict)
|
|
110
|
+
):
|
|
111
|
+
data.update(value)
|
|
112
|
+
else:
|
|
113
|
+
data[keys[-1]] = value
|
|
94
114
|
|
|
95
115
|
|
|
96
|
-
def get_value_by_path(
|
|
116
|
+
def get_value_by_path(
|
|
117
|
+
data: Any, keys: list[str], *, default_value: Any = None
|
|
118
|
+
) -> Any:
|
|
97
119
|
"""Examples:
|
|
98
120
|
|
|
99
121
|
get_value_by_path({'a': {'b': v}}, ['a', 'b'])
|
|
@@ -105,36 +127,141 @@ def get_value_by_path(data: Any, keys: list[str]):
|
|
|
105
127
|
return data
|
|
106
128
|
for i, key in enumerate(keys):
|
|
107
129
|
if not data:
|
|
108
|
-
return
|
|
130
|
+
return default_value
|
|
109
131
|
if key.endswith('[]'):
|
|
110
132
|
key_name = key[:-2]
|
|
111
133
|
if key_name in data:
|
|
112
|
-
return [
|
|
134
|
+
return [
|
|
135
|
+
get_value_by_path(d, keys[i + 1 :], default_value=default_value)
|
|
136
|
+
for d in data[key_name]
|
|
137
|
+
]
|
|
113
138
|
else:
|
|
114
|
-
return
|
|
139
|
+
return default_value
|
|
115
140
|
elif key.endswith('[0]'):
|
|
116
141
|
key_name = key[:-3]
|
|
117
142
|
if key_name in data and data[key_name]:
|
|
118
|
-
return get_value_by_path(
|
|
143
|
+
return get_value_by_path(
|
|
144
|
+
data[key_name][0], keys[i + 1 :], default_value=default_value
|
|
145
|
+
)
|
|
119
146
|
else:
|
|
120
|
-
return
|
|
147
|
+
return default_value
|
|
121
148
|
else:
|
|
122
149
|
if key in data:
|
|
123
150
|
data = data[key]
|
|
124
151
|
elif isinstance(data, BaseModel) and hasattr(data, key):
|
|
125
152
|
data = getattr(data, key)
|
|
126
153
|
else:
|
|
127
|
-
return
|
|
154
|
+
return default_value
|
|
128
155
|
return data
|
|
129
156
|
|
|
130
157
|
|
|
131
|
-
def
|
|
158
|
+
def move_value_by_path(data: Any, paths: dict[str, str]) -> None:
  """Moves values inside `data` from source paths to destination paths.

  Each path is a dot-separated list of segments: `name[]` iterates the list
  stored under `name`, and `*` stands for every key at that level. `data` is
  mutated in place.

  Examples:
    move_value_by_path(
        {'requests': [{'content': v1}, {'content': v2}]},
        {'requests[].*': 'requests[].request.*'}
    )
    -> {'requests': [{'request': {'content': v1}}, {'request': {'content':
    v2}}]}

  Args:
    data: The (nested) dictionary to rewrite.
    paths: Mapping from source path to destination path.
  """
  for src, dst in paths.items():
    src_parts = src.split('.')
    dst_parts = dst.split('.')

    # Plain destination segments at or after the wildcard position (for the
    # example above: 'request') must not be swept up by the wildcard itself,
    # otherwise the container would be moved into itself.
    skip: set[str] = set()
    star_at = src_parts.index('*') if '*' in src_parts else -1
    if 0 <= star_at < len(dst_parts):
      skip = {
          part
          for part in dst_parts[star_at:]
          if part != '*' and not part.endswith(('[]', '[0]'))
      }

    _move_value_recursive(data, src_parts, dst_parts, 0, skip)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _move_value_recursive(
|
|
195
|
+
data: Any,
|
|
196
|
+
source_keys: list[str],
|
|
197
|
+
dest_keys: list[str],
|
|
198
|
+
key_idx: int,
|
|
199
|
+
exclude_keys: set[str],
|
|
200
|
+
) -> None:
|
|
201
|
+
"""Recursively moves values from source path to destination path."""
|
|
202
|
+
if key_idx >= len(source_keys):
|
|
203
|
+
return
|
|
204
|
+
|
|
205
|
+
key = source_keys[key_idx]
|
|
206
|
+
|
|
207
|
+
if key.endswith('[]'):
|
|
208
|
+
# Handle array iteration
|
|
209
|
+
key_name = key[:-2]
|
|
210
|
+
if key_name in data and isinstance(data[key_name], list):
|
|
211
|
+
for item in data[key_name]:
|
|
212
|
+
_move_value_recursive(
|
|
213
|
+
item, source_keys, dest_keys, key_idx + 1, exclude_keys
|
|
214
|
+
)
|
|
215
|
+
elif key == '*':
|
|
216
|
+
# Handle wildcard - move all fields
|
|
217
|
+
if isinstance(data, dict):
|
|
218
|
+
# Get all keys to move (excluding specified keys)
|
|
219
|
+
keys_to_move = [
|
|
220
|
+
k
|
|
221
|
+
for k in list(data.keys())
|
|
222
|
+
if not k.startswith('_') and k not in exclude_keys
|
|
223
|
+
]
|
|
224
|
+
|
|
225
|
+
# Collect values to move
|
|
226
|
+
values_to_move = {k: data[k] for k in keys_to_move}
|
|
227
|
+
|
|
228
|
+
# Set values at destination
|
|
229
|
+
for k, v in values_to_move.items():
|
|
230
|
+
# Build destination keys with the field name
|
|
231
|
+
new_dest_keys = []
|
|
232
|
+
for dk in dest_keys[key_idx:]:
|
|
233
|
+
if dk == '*':
|
|
234
|
+
new_dest_keys.append(k)
|
|
235
|
+
else:
|
|
236
|
+
new_dest_keys.append(dk)
|
|
237
|
+
set_value_by_path(data, new_dest_keys, v)
|
|
238
|
+
|
|
239
|
+
# Delete from source
|
|
240
|
+
for k in keys_to_move:
|
|
241
|
+
del data[k]
|
|
242
|
+
else:
|
|
243
|
+
# Navigate to next level
|
|
244
|
+
if key in data:
|
|
245
|
+
_move_value_recursive(
|
|
246
|
+
data[key], source_keys, dest_keys, key_idx + 1, exclude_keys
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def maybe_snake_to_camel(snake_str: str, convert: bool = True) -> str:
  """Converts snake_case to lowerCamelCase when `convert` is True.

  Every underscore immediately followed by an ASCII letter is dropped and the
  letter upper-cased ('foo_bar' -> 'fooBar'). Other underscores and the first
  character are left alone. With `convert=False` the input is returned as-is.
  """
  if not convert:
    return snake_str

  def _upper_letter(match: 're.Match[str]') -> str:
    return match[1].upper()

  return re.sub(r'_([a-zA-Z])', _upper_letter, snake_str)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def convert_to_dict(obj: object, convert_keys: bool = False) -> Any:
  """Recursively converts `obj` into plain Python dicts and lists.

  Pydantic models are dumped with `model_dump(exclude_none=True)` and the dump
  is converted in turn. Dicts and lists are converted element by element. Any
  other value is returned unchanged.

  Args:
    obj: The object to convert.
    convert_keys: Whether to convert keys from snake_case to camelCase. The
      flag is forwarded through models and list elements but NOT into dict
      values, so the keys of nested dicts (including nested models inside a
      dump) keep their original spelling.

  Returns:
    A dict for a model or dict, a list for a list, otherwise `obj` itself.
  """
  if isinstance(obj, pydantic.BaseModel):
    dumped = obj.model_dump(exclude_none=True)
    return convert_to_dict(dumped, convert_keys)

  if isinstance(obj, dict):
    converted = {}
    for key, value in obj.items():
      # NOTE(review): convert_keys is not forwarded to the value, so key
      # conversion is shallow. This may be deliberate (e.g. to keep
      # user-defined keys intact). Confirm with callers before changing it.
      converted[maybe_snake_to_camel(key, convert_keys)] = convert_to_dict(
          value
      )
    return converted

  if isinstance(obj, list):
    return [convert_to_dict(element, convert_keys) for element in obj]

  return obj
|
|
152
282
|
|
|
153
283
|
|
|
154
|
-
def
|
|
155
|
-
|
|
156
|
-
|
|
284
|
+
def _is_struct_type(annotation: type) -> bool:
|
|
285
|
+
"""Checks if the given annotation is list[dict[str, typing.Any]]
|
|
286
|
+
|
|
287
|
+
or typing.List[typing.Dict[str, typing.Any]].
|
|
288
|
+
|
|
289
|
+
This maps to Struct type in the API.
|
|
290
|
+
"""
|
|
291
|
+
outer_origin = get_origin(annotation)
|
|
292
|
+
outer_args = get_args(annotation)
|
|
293
|
+
|
|
294
|
+
if outer_origin is not list: # Python 3.9+ normalizes list
|
|
295
|
+
return False
|
|
296
|
+
|
|
297
|
+
if not outer_args or len(outer_args) != 1:
|
|
298
|
+
return False
|
|
299
|
+
|
|
300
|
+
inner_annotation = outer_args[0]
|
|
301
|
+
|
|
302
|
+
inner_origin = get_origin(inner_annotation)
|
|
303
|
+
inner_args = get_args(inner_annotation)
|
|
304
|
+
|
|
305
|
+
if inner_origin is not dict: # Python 3.9+ normalizes to dict
|
|
306
|
+
return False
|
|
307
|
+
|
|
308
|
+
if not inner_args or len(inner_args) != 2:
|
|
309
|
+
# dict should have exactly two type arguments
|
|
310
|
+
return False
|
|
311
|
+
|
|
312
|
+
# Check if the dict arguments are str and typing.Any
|
|
313
|
+
key_type, value_type = inner_args
|
|
314
|
+
return key_type is str and value_type is typing.Any
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def _remove_extra_fields(model: Any, response: dict[str, object]) -> None:
|
|
157
318
|
"""Removes extra fields from the response that are not in the model.
|
|
158
319
|
|
|
159
320
|
Mutates the response in place.
|
|
@@ -185,14 +346,206 @@ def _remove_extra_fields(
|
|
|
185
346
|
if isinstance(value, dict) and typing.get_origin(annotation) is not dict:
|
|
186
347
|
_remove_extra_fields(annotation, value)
|
|
187
348
|
elif isinstance(value, list):
|
|
349
|
+
if _is_struct_type(annotation):
|
|
350
|
+
continue
|
|
351
|
+
|
|
188
352
|
for item in value:
|
|
189
353
|
# assume a list of dict is list of BaseModel
|
|
190
354
|
if isinstance(item, dict):
|
|
191
355
|
_remove_extra_fields(typing.get_args(annotation)[0], item)
|
|
192
356
|
|
|
357
|
+
|
|
193
358
|
# Type variable bound to BaseModel. BaseModel._from_response annotates `cls`
# as typing.Type[T] and returns T, so the result is typed as the subclass the
# classmethod was called on.
T = typing.TypeVar('T', bound='BaseModel')
|
|
194
359
|
|
|
195
360
|
|
|
361
|
+
def _pretty_repr(
    obj: Any,
    *,
    indent_level: int = 0,
    indent_delta: int = 2,
    max_len: int = 100,
    max_items: int = 5,
    depth: int = 6,
    visited: Optional[FrozenSet[int]] = None,
) -> str:
  """Returns a readable, indented, size-bounded representation of `obj`.

  Args:
    obj: The object to render.
    indent_level: Column at which the closing bracket/paren is placed.
    indent_delta: Extra indentation added for each nested level.
    max_len: Maximum number of bytes shown for a `bytes` value before it is
      truncated with `...`.
    max_items: Mappings with more entries are summarized as `<dict len=N>`.
      Longer lists, tuples and sets are truncated.
    depth: Remaining nesting depth. When exhausted, the value (or its
      contents) is summarized instead of rendered.
    visited: ids of the objects on the current recursion path, used to detect
      circular references.

  Returns:
    The formatted string.
  """
  if visited is None:
    visited = frozenset()

  obj_id = id(obj)
  if obj_id in visited:
    return '<... Circular reference ...>'

  if depth < 0:
    return '<... Max depth ...>'

  # Only the current path is tracked. Siblings each receive the parent's set,
  # so a value shared by two fields is not mistaken for a cycle.
  visited = visited | {obj_id}

  indent = ' ' * indent_level
  next_indent_str = ' ' * (indent_level + indent_delta)

  if isinstance(obj, pydantic.BaseModel):
    cls_name = obj.__class__.__name__
    model_fields = type(obj).model_fields
    items = []
    # Sort fields for consistent output
    for field_name in sorted(model_fields):
      if not model_fields[field_name].repr:  # Respect Field(repr=False)
        continue

      try:
        value = getattr(obj, field_name)
      except AttributeError:
        continue

      if value is None:
        continue

      value_repr = _pretty_repr(
          value,
          indent_level=indent_level + indent_delta,
          indent_delta=indent_delta,
          max_len=max_len,
          max_items=max_items,
          depth=depth - 1,
          visited=visited,
      )
      items.append(f'{next_indent_str}{field_name}={value_repr}')

    if not items:
      return f'{cls_name}()'
    return f'{cls_name}(\n' + ',\n'.join(items) + f'\n{indent})'
  elif isinstance(obj, str):
    if '\n' in obj:
      escaped = obj.replace('"""', '\\"\\"\\"')
      return f'"""{escaped}"""'
    return repr(obj)
  elif isinstance(obj, bytes):
    if len(obj) > max_len:
      # repr() opens with either ' or " depending on the content (b"it's"),
      # so close the truncated literal with the same quote it opened with.
      # Clamp the slice so a tiny max_len doesn't turn into a negative index.
      truncated = repr(obj[: max(max_len - 3, 0)])
      return f'{truncated[:-1]}...{truncated[-1]}'
    return repr(obj)
  elif isinstance(obj, collections.abc.Mapping):
    if not obj:
      return '{}'

    # Check if the next level of recursion for keys/values will exceed the
    # depth limit.
    if depth <= 0:
      item_count_str = f"{len(obj)} item{'s' if len(obj) != 1 else ''}"
      return f'{{<... {item_count_str} at Max depth ...>}}'

    if len(obj) > max_items:
      return f'<dict len={len(obj)}>'

    try:
      sorted_keys = sorted(obj.keys(), key=str)
    except TypeError:
      sorted_keys = list(obj.keys())

    items = []
    for k in sorted_keys:
      k_repr = _pretty_repr(
          k,
          indent_level=indent_level + indent_delta,
          indent_delta=indent_delta,
          max_len=max_len,
          max_items=max_items,
          depth=depth - 1,
          visited=visited,
      )
      v_repr = _pretty_repr(
          obj[k],
          indent_level=indent_level + indent_delta,
          indent_delta=indent_delta,
          max_len=max_len,
          max_items=max_items,
          depth=depth - 1,
          visited=visited,
      )
      items.append(f'{next_indent_str}{k_repr}: {v_repr}')
    return '{\n' + ',\n'.join(items) + f'\n{indent}}}'
  elif isinstance(obj, (list, tuple, set)):
    return _format_collection(
        obj,
        indent_level=indent_level,
        indent_delta=indent_delta,
        max_len=max_len,
        max_items=max_items,
        depth=depth,
        visited=visited,
    )
  else:
    # Fallback to standard repr; indent continuation lines so multi-line
    # reprs line up with the surrounding structure.
    raw_repr = repr(obj)
    return raw_repr.replace('\n', f'\n{next_indent_str}')
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
def _format_collection(
|
|
488
|
+
obj: Any,
|
|
489
|
+
*,
|
|
490
|
+
indent_level: int,
|
|
491
|
+
indent_delta: int,
|
|
492
|
+
max_len: int,
|
|
493
|
+
max_items: int,
|
|
494
|
+
depth: int,
|
|
495
|
+
visited: FrozenSet[int],
|
|
496
|
+
) -> str:
|
|
497
|
+
"""Formats a collection (list, tuple, set)."""
|
|
498
|
+
if isinstance(obj, list):
|
|
499
|
+
brackets = ('[', ']')
|
|
500
|
+
internal_obj = obj
|
|
501
|
+
elif isinstance(obj, tuple):
|
|
502
|
+
brackets = ('(', ')')
|
|
503
|
+
internal_obj = list(obj)
|
|
504
|
+
elif isinstance(obj, set):
|
|
505
|
+
internal_obj = list(obj)
|
|
506
|
+
if obj:
|
|
507
|
+
brackets = ('{', '}')
|
|
508
|
+
else:
|
|
509
|
+
brackets = ('set(', ')')
|
|
510
|
+
else:
|
|
511
|
+
raise ValueError(f'Unsupported collection type: {type(obj)}')
|
|
512
|
+
|
|
513
|
+
if not internal_obj:
|
|
514
|
+
return brackets[0] + brackets[1]
|
|
515
|
+
|
|
516
|
+
# If the call to _pretty_repr for elements will have depth < 0
|
|
517
|
+
if depth <= 0:
|
|
518
|
+
item_count_str = f"{len(internal_obj)} item{'s'*(len(internal_obj)!=1)}"
|
|
519
|
+
return f'{brackets[0]}<... {item_count_str} at Max depth ...>{brackets[1]}'
|
|
520
|
+
|
|
521
|
+
indent = ' ' * indent_level
|
|
522
|
+
next_indent_str = ' ' * (indent_level + indent_delta)
|
|
523
|
+
elements = []
|
|
524
|
+
num_to_show = min(len(internal_obj), max_items)
|
|
525
|
+
|
|
526
|
+
for i in range(num_to_show):
|
|
527
|
+
elem = internal_obj[i]
|
|
528
|
+
elements.append(
|
|
529
|
+
next_indent_str
|
|
530
|
+
+ _pretty_repr(
|
|
531
|
+
elem,
|
|
532
|
+
indent_level=indent_level + indent_delta,
|
|
533
|
+
indent_delta=indent_delta,
|
|
534
|
+
max_len=max_len,
|
|
535
|
+
max_items=max_items,
|
|
536
|
+
depth=depth - 1,
|
|
537
|
+
visited=visited,
|
|
538
|
+
)
|
|
539
|
+
)
|
|
540
|
+
|
|
541
|
+
if len(internal_obj) > max_items:
|
|
542
|
+
elements.append(
|
|
543
|
+
f'{next_indent_str}<... {len(internal_obj) - max_items} more items ...>'
|
|
544
|
+
)
|
|
545
|
+
|
|
546
|
+
return f'{brackets[0]}\n' + ',\n'.join(elements) + f',\n{indent}{brackets[1]}'
|
|
547
|
+
|
|
548
|
+
|
|
196
549
|
class BaseModel(pydantic.BaseModel):
|
|
197
550
|
|
|
198
551
|
model_config = pydantic.ConfigDict(
|
|
@@ -205,17 +558,78 @@ class BaseModel(pydantic.BaseModel):
|
|
|
205
558
|
arbitrary_types_allowed=True,
|
|
206
559
|
ser_json_bytes='base64',
|
|
207
560
|
val_json_bytes='base64',
|
|
208
|
-
ignored_types=(typing.TypeVar,)
|
|
561
|
+
ignored_types=(typing.TypeVar,),
|
|
209
562
|
)
|
|
210
563
|
|
|
564
|
+
@pydantic.model_validator(mode='before')
|
|
565
|
+
@classmethod
|
|
566
|
+
def _check_field_type_mismatches(cls, data: Any) -> Any:
|
|
567
|
+
"""Check for type mismatches and warn before Pydantic processes the data."""
|
|
568
|
+
# Handle both dict and Pydantic model inputs
|
|
569
|
+
if not isinstance(data, (dict, pydantic.BaseModel)):
|
|
570
|
+
return data
|
|
571
|
+
|
|
572
|
+
for field_name, field_info in cls.model_fields.items():
|
|
573
|
+
if isinstance(data, dict):
|
|
574
|
+
value = data.get(field_name)
|
|
575
|
+
else:
|
|
576
|
+
value = getattr(data, field_name, None)
|
|
577
|
+
|
|
578
|
+
if value is None:
|
|
579
|
+
continue
|
|
580
|
+
|
|
581
|
+
expected_type = field_info.annotation
|
|
582
|
+
origin = get_origin(expected_type)
|
|
583
|
+
|
|
584
|
+
if origin is Union:
|
|
585
|
+
args = get_args(expected_type)
|
|
586
|
+
non_none_types = [arg for arg in args if arg is not type(None)]
|
|
587
|
+
if len(non_none_types) == 1:
|
|
588
|
+
expected_type = non_none_types[0]
|
|
589
|
+
|
|
590
|
+
if (isinstance(expected_type, type) and
|
|
591
|
+
get_origin(expected_type) is None and
|
|
592
|
+
issubclass(expected_type, pydantic.BaseModel) and
|
|
593
|
+
isinstance(value, pydantic.BaseModel) and
|
|
594
|
+
not isinstance(value, expected_type)):
|
|
595
|
+
logger.warning(
|
|
596
|
+
f"Type mismatch in {cls.__name__}.{field_name}: "
|
|
597
|
+
f"expected {expected_type.__name__}, got {type(value).__name__}"
|
|
598
|
+
)
|
|
599
|
+
|
|
600
|
+
return data
|
|
601
|
+
|
|
602
|
+
def __repr__(self) -> str:
|
|
603
|
+
try:
|
|
604
|
+
return _pretty_repr(self)
|
|
605
|
+
except Exception:
|
|
606
|
+
return super().__repr__()
|
|
607
|
+
|
|
211
608
|
@classmethod
|
|
212
609
|
def _from_response(
|
|
213
|
-
cls: typing.Type[T],
|
|
610
|
+
cls: typing.Type[T],
|
|
611
|
+
*,
|
|
612
|
+
response: dict[str, object],
|
|
613
|
+
kwargs: dict[str, object],
|
|
214
614
|
) -> T:
|
|
215
615
|
# To maintain forward compatibility, we need to remove extra fields from
|
|
216
616
|
# the response.
|
|
217
617
|
# We will provide another mechanism to allow users to access these fields.
|
|
218
|
-
|
|
618
|
+
|
|
619
|
+
# For Agent Engine we don't want to call _remove_all_fields because the
|
|
620
|
+
# user may pass a dict that is not a subclass of BaseModel.
|
|
621
|
+
# If more modules require we skip this, we may want a different approach
|
|
622
|
+
should_skip_removing_fields = (
|
|
623
|
+
kwargs is not None
|
|
624
|
+
and 'config' in kwargs
|
|
625
|
+
and kwargs['config'] is not None
|
|
626
|
+
and isinstance(kwargs['config'], dict)
|
|
627
|
+
and 'include_all_fields' in kwargs['config']
|
|
628
|
+
and kwargs['config']['include_all_fields']
|
|
629
|
+
)
|
|
630
|
+
|
|
631
|
+
if not should_skip_removing_fields:
|
|
632
|
+
_remove_extra_fields(cls, response)
|
|
219
633
|
validated_response = cls.model_validate(response)
|
|
220
634
|
return validated_response
|
|
221
635
|
|
|
@@ -227,14 +641,14 @@ class CaseInSensitiveEnum(str, enum.Enum):
|
|
|
227
641
|
"""Case insensitive enum."""
|
|
228
642
|
|
|
229
643
|
@classmethod
|
|
230
|
-
def _missing_(cls, value):
|
|
644
|
+
def _missing_(cls, value: Any) -> Any:
|
|
231
645
|
try:
|
|
232
646
|
return cls[value.upper()] # Try to access directly with uppercase
|
|
233
647
|
except KeyError:
|
|
234
648
|
try:
|
|
235
649
|
return cls[value.lower()] # Try to access directly with lowercase
|
|
236
650
|
except KeyError:
|
|
237
|
-
warnings.warn(f
|
|
651
|
+
warnings.warn(f'{value} is not a valid {cls.__name__}')
|
|
238
652
|
try:
|
|
239
653
|
# Creating a enum instance based on the value
|
|
240
654
|
# We need to use super() to avoid infinite recursion.
|
|
@@ -295,21 +709,108 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
|
|
|
295
709
|
return processed_data
|
|
296
710
|
|
|
297
711
|
|
|
298
|
-
def experimental_warning(
|
|
712
|
+
def experimental_warning(
    message: str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
  """Builds a decorator that flags a function as experimental.

  The first call to each decorated function emits `message` as an
  ExperimentalWarning (stacklevel=2). Later calls run silently. The flag is
  kept per decorated function.

  Args:
    message: Text of the warning.

  Returns:
    A decorator preserving the wrapped function's metadata.
  """

  def _decorate(fn: Callable[..., Any]) -> Callable[..., Any]:
    already_warned = False

    @functools.wraps(fn)
    def _wrapped(*args: Any, **kwargs: Any) -> Any:
      nonlocal already_warned
      if not already_warned:
        already_warned = True
        warnings.warn(
            message=message,
            category=ExperimentalWarning,
            stacklevel=2,
        )
      return fn(*args, **kwargs)

    return _wrapped

  return _decorate
|
|
315
735
|
|
|
736
|
+
|
|
737
|
+
def _normalize_key_for_matching(key_str: str) -> str:
|
|
738
|
+
"""Normalizes a key for case-insensitive and snake/camel matching."""
|
|
739
|
+
return key_str.replace('_', '').lower()
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
def align_key_case(
    target_dict: StringDict, update_dict: StringDict
) -> StringDict:
  """Aligns the keys of update_dict to the case of target_dict keys.

  Keys are matched after removing underscores and lower-casing. For example,
  'response_mime_type' in update_dict is renamed to 'responseMimeType' when
  that is how target_dict spells it. Keys without a match are kept unchanged.
  When both sides hold a dict under a key, that dict is aligned recursively.
  Every other value, lists included, is copied as-is.

  Args:
    target_dict: The dictionary with the target key casing.
    update_dict: The dictionary whose keys need to be aligned.

  Returns:
    A new dictionary with keys aligned to target_dict's key casing.
  """
  target_keys_map = {
      _normalize_key_for_matching(key): key for key in target_dict
  }

  aligned_update_dict: StringDict = {}
  for key, value in update_dict.items():
    aligned_key = target_keys_map.get(_normalize_key_for_matching(key), key)
    target_value = target_dict.get(aligned_key)

    if isinstance(value, dict) and isinstance(target_value, dict):
      aligned_update_dict[aligned_key] = align_key_case(target_value, value)
    else:
      # Lists (and scalars) are assigned directly: update_dict list values are
      # treated as the golden source and are not aligned element-wise.
      aligned_update_dict[aligned_key] = value
  return aligned_update_dict
|
|
781
|
+
|
|
782
|
+
|
|
783
|
+
def recursive_dict_update(
    target_dict: StringDict, update_dict: StringDict
) -> None:
  """Recursively updates a target dictionary with values from an update dictionary.

  Keys of update_dict are first aligned to target_dict's casing (snake_case vs
  camelCase), so the update lands on the existing keys. When both sides hold a
  dict under a key, the two are merged recursively. Otherwise the value
  overwrites the existing one.

  We don't enforce the updated dict values to have the same type with the
  target_dict values except log warnings. Users providing the update_dict
  should be responsible for constructing correct data.

  Args:
    target_dict (dict): The dictionary to be updated in place.
    update_dict (dict): The dictionary containing updates.
  """
  # Python SDK http request may change in camel case or snake case:
  # If the field is directly set via setv() function, then it is camel case;
  # otherwise it is snake case.
  # Align the update_dict key case to target_dict to ensure correct dict update.
  aligned_update_dict = align_key_case(target_dict, update_dict)
  for key, value in aligned_update_dict.items():
    if key in target_dict:
      existing = target_dict[key]
      if isinstance(existing, dict) and isinstance(value, dict):
        recursive_dict_update(existing, value)
        continue
      if not isinstance(existing, type(value)):
        # Lazy %-style args: the message is only formatted if emitted.
        logger.warning(
            "Type mismatch for key '%s'. Existing type: %s, new type: %s."
            ' Overwriting.',
            key,
            type(existing),
            type(value),
        )
    target_dict[key] = value
|