google-genai 1.33.0__py3-none-any.whl → 1.53.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +361 -208
- google/genai/_common.py +260 -69
- google/genai/_extra_utils.py +142 -12
- google/genai/_live_converters.py +691 -2746
- google/genai/_local_tokenizer_loader.py +0 -9
- google/genai/_operations_converters.py +186 -99
- google/genai/_replay_api_client.py +48 -51
- google/genai/_tokens_converters.py +169 -489
- google/genai/_transformers.py +193 -90
- google/genai/batches.py +1014 -1307
- google/genai/caches.py +458 -1107
- google/genai/client.py +101 -0
- google/genai/documents.py +532 -0
- google/genai/errors.py +58 -4
- google/genai/file_search_stores.py +1296 -0
- google/genai/files.py +108 -358
- google/genai/live.py +90 -32
- google/genai/live_music.py +24 -27
- google/genai/local_tokenizer.py +36 -3
- google/genai/models.py +2308 -3375
- google/genai/operations.py +129 -21
- google/genai/pagers.py +7 -1
- google/genai/tokens.py +2 -12
- google/genai/tunings.py +770 -436
- google/genai/types.py +4341 -1218
- google/genai/version.py +1 -1
- {google_genai-1.33.0.dist-info → google_genai-1.53.0.dist-info}/METADATA +359 -201
- google_genai-1.53.0.dist-info/RECORD +41 -0
- google_genai-1.33.0.dist-info/RECORD +0 -39
- {google_genai-1.33.0.dist-info → google_genai-1.53.0.dist-info}/WHEEL +0 -0
- {google_genai-1.33.0.dist-info → google_genai-1.53.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.33.0.dist-info → google_genai-1.53.0.dist-info}/top_level.txt +0 -0
google/genai/_common.py
CHANGED
@@ -21,6 +21,7 @@ import datetime
 import enum
 import functools
 import logging
+import re
 import typing
 from typing import Any, Callable, FrozenSet, Optional, Union, get_args, get_origin
 import uuid
@@ -38,7 +39,9 @@ class ExperimentalWarning(Warning):
   """Warning for experimental features."""


-def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None:
+def set_value_by_path(
+    data: Optional[dict[Any, Any]], keys: list[str], value: Any
+) -> None:
   """Examples:

   set_value_by_path({}, ['a', 'b'], v)
@@ -100,10 +103,19 @@ def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None:
           f' Existing value: {existing_data}; New value: {value}.'
       )
   else:
-    data[keys[-1]] = value
+    if (
+        keys[-1] == '_self'
+        and isinstance(data, dict)
+        and isinstance(value, dict)
+    ):
+      data.update(value)
+    else:
+      data[keys[-1]] = value


-def get_value_by_path(data: Any, keys: list[str]) -> Any:
+def get_value_by_path(
+    data: Any, keys: list[str], *, default_value: Any = None
+) -> Any:
   """Examples:

   get_value_by_path({'a': {'b': v}}, ['a', 'b'])
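For illustration, a minimal sketch of the two behavior changes above, assuming the surrounding path-navigation code in `_common` is unchanged from earlier releases; both helpers live in the private `google.genai._common` module and are not public API:

```python
# Illustrative sketch only: set_value_by_path / get_value_by_path are private
# helpers in google.genai._common, not part of the supported public API.
from google.genai import _common

d = {'a': {}}
# A trailing '_self' key now merges a dict value into the parent node
# instead of storing it under a literal '_self' key.
_common.set_value_by_path(d, ['a', '_self'], {'x': 1})
print(d)  # {'a': {'x': 1}}

# get_value_by_path now returns a caller-supplied default instead of None
# when the path is missing.
print(_common.get_value_by_path(d, ['a', 'missing'], default_value=0))  # 0
```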
@@ -115,36 +127,141 @@ def get_value_by_path(data: Any, keys: list[str]) -> Any:
     return data
   for i, key in enumerate(keys):
     if not data:
-      return
+      return default_value
     if key.endswith('[]'):
       key_name = key[:-2]
       if key_name in data:
-        return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
+        return [
+            get_value_by_path(d, keys[i + 1 :], default_value=default_value)
+            for d in data[key_name]
+        ]
       else:
-        return
+        return default_value
     elif key.endswith('[0]'):
       key_name = key[:-3]
       if key_name in data and data[key_name]:
-        return get_value_by_path(data[key_name][0], keys[i + 1 :])
+        return get_value_by_path(
+            data[key_name][0], keys[i + 1 :], default_value=default_value
+        )
       else:
-        return
+        return default_value
     else:
       if key in data:
         data = data[key]
       elif isinstance(data, BaseModel) and hasattr(data, key):
         data = getattr(data, key)
       else:
-        return
+        return default_value
   return data


-def convert_to_dict(obj: object) -> Any:
+def move_value_by_path(data: Any, paths: dict[str, str]) -> None:
+  """Moves values from source paths to destination paths.
+
+  Examples:
+    move_value_by_path(
+        {'requests': [{'content': v1}, {'content': v2}]},
+        {'requests[].*': 'requests[].request.*'}
+    )
+    -> {'requests': [{'request': {'content': v1}}, {'request': {'content':
+    v2}}]}
+  """
+  for source_path, dest_path in paths.items():
+    source_keys = source_path.split('.')
+    dest_keys = dest_path.split('.')
+
+    # Determine keys to exclude from wildcard to avoid cyclic references
+    exclude_keys = set()
+    wildcard_idx = -1
+    for i, key in enumerate(source_keys):
+      if key == '*':
+        wildcard_idx = i
+        break
+
+    if wildcard_idx != -1 and len(dest_keys) > wildcard_idx:
+      # Extract the intermediate key between source and dest paths
+      # Example: source=['requests[]', '*'], dest=['requests[]', 'request', '*']
+      # We want to exclude 'request'
+      for i in range(wildcard_idx, len(dest_keys)):
+        key = dest_keys[i]
+        if key != '*' and not key.endswith('[]') and not key.endswith('[0]'):
+          exclude_keys.add(key)
+
+    # Move values recursively
+    _move_value_recursive(data, source_keys, dest_keys, 0, exclude_keys)
+
+
+def _move_value_recursive(
+    data: Any,
+    source_keys: list[str],
+    dest_keys: list[str],
+    key_idx: int,
+    exclude_keys: set[str],
+) -> None:
+  """Recursively moves values from source path to destination path."""
+  if key_idx >= len(source_keys):
+    return
+
+  key = source_keys[key_idx]
+
+  if key.endswith('[]'):
+    # Handle array iteration
+    key_name = key[:-2]
+    if key_name in data and isinstance(data[key_name], list):
+      for item in data[key_name]:
+        _move_value_recursive(
+            item, source_keys, dest_keys, key_idx + 1, exclude_keys
+        )
+  elif key == '*':
+    # Handle wildcard - move all fields
+    if isinstance(data, dict):
+      # Get all keys to move (excluding specified keys)
+      keys_to_move = [
+          k
+          for k in list(data.keys())
+          if not k.startswith('_') and k not in exclude_keys
+      ]
+
+      # Collect values to move
+      values_to_move = {k: data[k] for k in keys_to_move}
+
+      # Set values at destination
+      for k, v in values_to_move.items():
+        # Build destination keys with the field name
+        new_dest_keys = []
+        for dk in dest_keys[key_idx:]:
+          if dk == '*':
+            new_dest_keys.append(k)
+          else:
+            new_dest_keys.append(dk)
+        set_value_by_path(data, new_dest_keys, v)
+
+      # Delete from source
+      for k in keys_to_move:
+        del data[k]
+  else:
+    # Navigate to next level
+    if key in data:
+      _move_value_recursive(
+          data[key], source_keys, dest_keys, key_idx + 1, exclude_keys
+      )
+
+
+def maybe_snake_to_camel(snake_str: str, convert: bool = True) -> str:
+  """Converts a snake_case string to CamelCase, if convert is True."""
+  if not convert:
+    return snake_str
+  return re.sub(r'_([a-zA-Z])', lambda match: match.group(1).upper(), snake_str)
+
+
+def convert_to_dict(obj: object, convert_keys: bool = False) -> Any:
   """Recursively converts a given object to a dictionary.

   If the object is a Pydantic model, it uses the model's `model_dump()` method.

   Args:
     obj: The object to convert.
+    convert_keys: Whether to convert the keys from snake case to camel case.

   Returns:
     A dictionary representation of the object, a list of objects if a list is
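For illustration, a hedged sketch of the newly added private helpers, mirroring the `move_value_by_path` docstring example above:

```python
# Sketch of the new private helpers in google.genai._common (internal API).
from google.genai import _common

batch = {'requests': [{'content': 'c1'}, {'content': 'c2'}]}
# Nest every field of each request under a 'request' wrapper object.
_common.move_value_by_path(batch, {'requests[].*': 'requests[].request.*'})
print(batch)
# {'requests': [{'request': {'content': 'c1'}}, {'request': {'content': 'c2'}}]}

print(_common.maybe_snake_to_camel('page_size'))                 # 'pageSize'
print(_common.maybe_snake_to_camel('page_size', convert=False))  # 'page_size'
```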
@@ -152,17 +269,21 @@ def convert_to_dict(obj: object) -> Any:
     model.
   """
   if isinstance(obj, pydantic.BaseModel):
-    return obj.model_dump(exclude_none=True)
+    return convert_to_dict(obj.model_dump(exclude_none=True), convert_keys)
   elif isinstance(obj, dict):
-    return {key: convert_to_dict(value) for key, value in obj.items()}
+    return {
+        maybe_snake_to_camel(key, convert_keys): convert_to_dict(value)
+        for key, value in obj.items()
+    }
   elif isinstance(obj, list):
-    return [convert_to_dict(item) for item in obj]
+    return [convert_to_dict(item, convert_keys) for item in obj]
   else:
     return obj


 def _is_struct_type(annotation: type) -> bool:
   """Checks if the given annotation is list[dict[str, typing.Any]]
+
   or typing.List[typing.Dict[str, typing.Any]].

   This maps to Struct type in the API.
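For illustration, `convert_to_dict` can now also camel-case keys while converting; a small sketch against the private helper, kept to a flat dict since the recursion shown above does not forward `convert_keys` into nested dict values:

```python
# Sketch of convert_to_dict with the new convert_keys flag (private helper).
from google.genai import _common

print(_common.convert_to_dict({'display_name': 'doc', 'page_size': 5},
                              convert_keys=True))
# {'displayName': 'doc', 'pageSize': 5}
```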
@@ -170,7 +291,7 @@ def _is_struct_type(annotation: type) -> bool:
   outer_origin = get_origin(annotation)
   outer_args = get_args(annotation)

-  if outer_origin is not list:
+  if outer_origin is not list:  # Python 3.9+ normalizes list
     return False

   if not outer_args or len(outer_args) != 1:
@@ -181,7 +302,7 @@ def _is_struct_type(annotation: type) -> bool:
   inner_origin = get_origin(inner_annotation)
   inner_args = get_args(inner_annotation)

-  if inner_origin is not dict:
+  if inner_origin is not dict:  # Python 3.9+ normalizes to dict
     return False

   if not inner_args or len(inner_args) != 2:
@@ -193,9 +314,7 @@ def _is_struct_type(annotation: type) -> bool:
   return key_type is str and value_type is typing.Any


-def _remove_extra_fields(
-    model: Any, response: dict[str, object]
-) -> None:
+def _remove_extra_fields(model: Any, response: dict[str, object]) -> None:
   """Removes extra fields from the response that are not in the model.

   Mutates the response in place.
@@ -235,6 +354,7 @@ def _remove_extra_fields(
         if isinstance(item, dict):
           _remove_extra_fields(typing.get_args(annotation)[0], item)

+
 T = typing.TypeVar('T', bound='BaseModel')

@@ -310,8 +430,15 @@ def _pretty_repr(
   elif isinstance(obj, collections.abc.Mapping):
     if not obj:
       return '{}'
+
+    # Check if the next level of recursion for keys/values will exceed the depth limit.
+    if depth <= 0:
+      item_count_str = f"{len(obj)} item{'s' if len(obj) != 1 else ''}"
+      return f'{{<... {item_count_str} at Max depth ...>}}'
+
     if len(obj) > max_items:
       return f'<dict len={len(obj)}>'
+
     items = []
     try:
       sorted_keys = sorted(obj.keys(), key=str)
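For illustration, a sketch of the new depth guard. `_pretty_repr` is a private helper and its parameter defaults are not visible in this diff, so every argument is passed explicitly here; the parameter names are taken from the recursive call in the next hunk, and the exact output may differ:

```python
# Sketch only: _pretty_repr is private; argument names are inferred from the
# recursive call shown in the following hunk, and defaults may differ.
from google.genai import _common

print(_common._pretty_repr(
    {'a': {'b': 1}},
    indent_level=0,
    indent_delta=2,
    max_len=80,
    max_items=10,
    depth=0,  # depth budget exhausted -> placeholder instead of recursing
    visited=frozenset(),
))
# Expected, per the branch added above: {<... 1 item at Max depth ...>}
```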
@@ -367,47 +494,56 @@ def _format_collection(
     depth: int,
     visited: FrozenSet[int],
 ) -> str:
-  <11 removed lines not rendered in the source diff>
+  """Formats a collection (list, tuple, set)."""
+  if isinstance(obj, list):
+    brackets = ('[', ']')
+    internal_obj = obj
+  elif isinstance(obj, tuple):
+    brackets = ('(', ')')
+    internal_obj = list(obj)
+  elif isinstance(obj, set):
+    internal_obj = list(obj)
+    if obj:
+      brackets = ('{', '}')
     else:
-  <1 removed line not rendered in the source diff>
+      brackets = ('set(', ')')
+  else:
+    raise ValueError(f'Unsupported collection type: {type(obj)}')

-  <24 removed lines not rendered in the source diff>
+  if not internal_obj:
+    return brackets[0] + brackets[1]
+
+  # If the call to _pretty_repr for elements will have depth < 0
+  if depth <= 0:
+    item_count_str = f"{len(internal_obj)} item{'s'*(len(internal_obj)!=1)}"
+    return f'{brackets[0]}<... {item_count_str} at Max depth ...>{brackets[1]}'
+
+  indent = ' ' * indent_level
+  next_indent_str = ' ' * (indent_level + indent_delta)
+  elements = []
+  num_to_show = min(len(internal_obj), max_items)
+
+  for i in range(num_to_show):
+    elem = internal_obj[i]
+    elements.append(
+        next_indent_str
+        + _pretty_repr(
+            elem,
+            indent_level=indent_level + indent_delta,
+            indent_delta=indent_delta,
+            max_len=max_len,
+            max_items=max_items,
+            depth=depth - 1,
+            visited=visited,
         )
+    )

-  <1 removed line not rendered in the source diff>
+  if len(internal_obj) > max_items:
+    elements.append(
+        f'{next_indent_str}<... {len(internal_obj) - max_items} more items ...>'
+    )
+
+  return f'{brackets[0]}\n' + ',\n'.join(elements) + f',\n{indent}{brackets[1]}'


 class BaseModel(pydantic.BaseModel):
@@ -422,9 +558,47 @@ class BaseModel(pydantic.BaseModel):
       arbitrary_types_allowed=True,
       ser_json_bytes='base64',
       val_json_bytes='base64',
-      ignored_types=(typing.TypeVar,)
+      ignored_types=(typing.TypeVar,),
   )

+  @pydantic.model_validator(mode='before')
+  @classmethod
+  def _check_field_type_mismatches(cls, data: Any) -> Any:
+    """Check for type mismatches and warn before Pydantic processes the data."""
+    # Handle both dict and Pydantic model inputs
+    if not isinstance(data, (dict, pydantic.BaseModel)):
+      return data
+
+    for field_name, field_info in cls.model_fields.items():
+      if isinstance(data, dict):
+        value = data.get(field_name)
+      else:
+        value = getattr(data, field_name, None)
+
+      if value is None:
+        continue
+
+      expected_type = field_info.annotation
+      origin = get_origin(expected_type)
+
+      if origin is Union:
+        args = get_args(expected_type)
+        non_none_types = [arg for arg in args if arg is not type(None)]
+        if len(non_none_types) == 1:
+          expected_type = non_none_types[0]
+
+      if (isinstance(expected_type, type) and
+          get_origin(expected_type) is None and
+          issubclass(expected_type, pydantic.BaseModel) and
+          isinstance(value, pydantic.BaseModel) and
+          not isinstance(value, expected_type)):
+        logger.warning(
+            f"Type mismatch in {cls.__name__}.{field_name}: "
+            f"expected {expected_type.__name__}, got {type(value).__name__}"
+        )
+
+    return data
+
   def __repr__(self) -> str:
     try:
       return _pretty_repr(self)
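For illustration, a self-contained sketch of the new `mode='before'` validator; `Inner`, `Other`, and `Outer` are hypothetical stand-ins, not google-genai types:

```python
# Hypothetical models used only to exercise the new type-mismatch warning.
import logging
from typing import Optional

import pydantic

from google.genai import _common

logging.basicConfig(level=logging.WARNING)


class Inner(_common.BaseModel):
  text: Optional[str] = None


class Other(_common.BaseModel):
  text: Optional[str] = None


class Outer(_common.BaseModel):
  inner: Optional[Inner] = None


try:
  # Logs "Type mismatch in Outer.inner: expected Inner, got Other" before
  # Pydantic runs its own field validation.
  Outer(inner=Other(text='hi'))
except pydantic.ValidationError:
  pass  # Pydantic may still reject the mismatched value afterwards.
```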
@@ -433,7 +607,10 @@ class BaseModel(pydantic.BaseModel):

   @classmethod
   def _from_response(
-      cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
+      cls: typing.Type[T],
+      *,
+      response: dict[str, object],
+      kwargs: dict[str, object],
   ) -> T:
     # To maintain forward compatibility, we need to remove extra fields from
     # the response.
@@ -443,11 +620,11 @@ class BaseModel(pydantic.BaseModel):
     # user may pass a dict that is not a subclass of BaseModel.
     # If more modules require we skip this, we may want a different approach
     should_skip_removing_fields = (
-        kwargs is not None and
-        'config' in kwargs and
-        kwargs['config'] is not None and
-        isinstance(kwargs['config'], dict) and
-        'include_all_fields' in kwargs['config']
+        kwargs is not None
+        and 'config' in kwargs
+        and kwargs['config'] is not None
+        and isinstance(kwargs['config'], dict)
+        and 'include_all_fields' in kwargs['config']
         and kwargs['config']['include_all_fields']
     )

@@ -471,7 +648,7 @@ class CaseInSensitiveEnum(str, enum.Enum):
     try:
       return cls[value.lower()]  # Try to access directly with lowercase
     except KeyError:
-      warnings.warn(f
+      warnings.warn(f'{value} is not a valid {cls.__name__}')
       try:
         # Creating a enum instance based on the value
         # We need to use super() to avoid infinite recursion.
@@ -532,10 +709,14 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
   return processed_data


-def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
+def experimental_warning(
+    message: str,
+) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
   """Experimental warning, only warns once."""
+
   def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
     warning_done = False
+
     @functools.wraps(func)
     def wrapper(*args: Any, **kwargs: Any) -> Any:
       nonlocal warning_done
@@ -547,13 +728,15 @@ def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
             stacklevel=2,
         )
       return func(*args, **kwargs)
+
     return wrapper
+
   return decorator


 def _normalize_key_for_matching(key_str: str) -> str:
   """Normalizes a key for case-insensitive and snake/camel matching."""
-  return key_str.replace(
+  return key_str.replace('_', '').lower()


 def align_key_case(
@@ -569,7 +752,9 @@ def align_key_case(
     A new dictionary with keys aligned to target_dict's key casing.
   """
   aligned_update_dict: StringDict = {}
-  target_keys_map = {_normalize_key_for_matching(key): key for key in target_dict.keys()}
+  target_keys_map = {
+      _normalize_key_for_matching(key): key for key in target_dict.keys()
+  }

   for key, value in update_dict.items():
     normalized_update_key = _normalize_key_for_matching(key)
@@ -579,9 +764,15 @@ def align_key_case(
     else:
       aligned_key = key

-    if isinstance(value, dict) and isinstance(
-    <2 removed lines not rendered in the source diff>
+    if isinstance(value, dict) and isinstance(
+        target_dict.get(aligned_key), dict
+    ):
+      aligned_update_dict[aligned_key] = align_key_case(
+          target_dict[aligned_key], value
+      )
+    elif isinstance(value, list) and isinstance(
+        target_dict.get(aligned_key), list
+    ):
       # Direct assign as we treat update_dict list values as golden source.
       aligned_update_dict[aligned_key] = value
     else: