google-genai: google_genai-1.7.0-py3-none-any.whl → google_genai-1.53.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. google/genai/__init__.py +4 -2
  2. google/genai/_adapters.py +55 -0
  3. google/genai/_api_client.py +1301 -299
  4. google/genai/_api_module.py +1 -1
  5. google/genai/_automatic_function_calling_util.py +54 -33
  6. google/genai/_base_transformers.py +26 -0
  7. google/genai/_base_url.py +50 -0
  8. google/genai/_common.py +560 -59
  9. google/genai/_extra_utils.py +371 -38
  10. google/genai/_live_converters.py +1467 -0
  11. google/genai/_local_tokenizer_loader.py +214 -0
  12. google/genai/_mcp_utils.py +117 -0
  13. google/genai/_operations_converters.py +394 -0
  14. google/genai/_replay_api_client.py +204 -92
  15. google/genai/_test_api_client.py +1 -1
  16. google/genai/_tokens_converters.py +520 -0
  17. google/genai/_transformers.py +633 -233
  18. google/genai/batches.py +1733 -538
  19. google/genai/caches.py +678 -1012
  20. google/genai/chats.py +48 -38
  21. google/genai/client.py +142 -15
  22. google/genai/documents.py +532 -0
  23. google/genai/errors.py +141 -35
  24. google/genai/file_search_stores.py +1296 -0
  25. google/genai/files.py +312 -744
  26. google/genai/live.py +617 -367
  27. google/genai/live_music.py +197 -0
  28. google/genai/local_tokenizer.py +395 -0
  29. google/genai/models.py +3598 -3116
  30. google/genai/operations.py +201 -362
  31. google/genai/pagers.py +23 -7
  32. google/genai/py.typed +1 -0
  33. google/genai/tokens.py +362 -0
  34. google/genai/tunings.py +1274 -496
  35. google/genai/types.py +14535 -5454
  36. google/genai/version.py +2 -2
  37. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/METADATA +736 -234
  38. google_genai-1.53.0.dist-info/RECORD +41 -0
  39. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/WHEEL +1 -1
  40. google_genai-1.7.0.dist-info/RECORD +0 -27
  41. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info/licenses}/LICENSE +0 -0
  42. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/top_level.txt +0 -0
google/genai/_common.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Google LLC
1
+ # Copyright 2025 Google LLC
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,22 +16,32 @@
16
16
  """Common utilities for the SDK."""
17
17
 
18
18
  import base64
19
+ import collections.abc
19
20
  import datetime
20
21
  import enum
21
22
  import functools
23
+ import logging
24
+ import re
22
25
  import typing
23
- from typing import Any, Union
26
+ from typing import Any, Callable, FrozenSet, Optional, Union, get_args, get_origin
24
27
  import uuid
25
28
  import warnings
26
-
27
29
  import pydantic
28
30
  from pydantic import alias_generators
31
+ from typing_extensions import TypeAlias
32
+
33
+ logger = logging.getLogger('google_genai._common')
34
+
35
+ StringDict: TypeAlias = dict[str, Any]
36
+
29
37
 
30
- from . import _api_client
31
- from . import errors
38
+ class ExperimentalWarning(Warning):
39
+ """Warning for experimental features."""
32
40
 
33
41
 
34
- def set_value_by_path(data, keys, value):
42
+ def set_value_by_path(
43
+ data: Optional[dict[Any, Any]], keys: list[str], value: Any
44
+ ) -> None:
35
45
  """Examples:
36
46
 
37
47
  set_value_by_path({}, ['a', 'b'], v)
@@ -46,54 +56,66 @@ def set_value_by_path(data, keys, value):
46
56
  for i, key in enumerate(keys[:-1]):
47
57
  if key.endswith('[]'):
48
58
  key_name = key[:-2]
49
- if key_name not in data:
59
+ if data is not None and key_name not in data:
50
60
  if isinstance(value, list):
51
61
  data[key_name] = [{} for _ in range(len(value))]
52
62
  else:
53
63
  raise ValueError(
54
64
  f'value {value} must be a list given an array path {key}'
55
65
  )
56
- if isinstance(value, list):
66
+ if isinstance(value, list) and data is not None:
57
67
  for j, d in enumerate(data[key_name]):
58
68
  set_value_by_path(d, keys[i + 1 :], value[j])
59
69
  else:
60
- for d in data[key_name]:
61
- set_value_by_path(d, keys[i + 1 :], value)
70
+ if data is not None:
71
+ for d in data[key_name]:
72
+ set_value_by_path(d, keys[i + 1 :], value)
62
73
  return
63
74
  elif key.endswith('[0]'):
64
75
  key_name = key[:-3]
65
- if key_name not in data:
76
+ if data is not None and key_name not in data:
66
77
  data[key_name] = [{}]
67
- set_value_by_path(data[key_name][0], keys[i + 1 :], value)
78
+ if data is not None:
79
+ set_value_by_path(data[key_name][0], keys[i + 1 :], value)
68
80
  return
69
-
70
- data = data.setdefault(key, {})
71
-
72
- existing_data = data.get(keys[-1])
73
- # If there is an existing value, merge, not overwrite.
74
- if existing_data is not None:
75
- # Don't overwrite existing non-empty value with new empty value.
76
- # This is triggered when handling tuning datasets.
77
- if not value:
78
- pass
79
- # Don't fail when overwriting value with same value
80
- elif value == existing_data:
81
- pass
82
- # Instead of overwriting dictionary with another dictionary, merge them.
83
- # This is important for handling training and validation datasets in tuning.
84
- elif isinstance(existing_data, dict) and isinstance(value, dict):
85
- # Merging dictionaries. Consider deep merging in the future.
86
- existing_data.update(value)
81
+ if data is not None:
82
+ data = data.setdefault(key, {})
83
+
84
+ if data is not None:
85
+ existing_data = data.get(keys[-1])
86
+ # If there is an existing value, merge, not overwrite.
87
+ if existing_data is not None:
88
+ # Don't overwrite existing non-empty value with new empty value.
89
+ # This is triggered when handling tuning datasets.
90
+ if not value:
91
+ pass
92
+ # Don't fail when overwriting value with same value
93
+ elif value == existing_data:
94
+ pass
95
+ # Instead of overwriting dictionary with another dictionary, merge them.
96
+ # This is important for handling training and validation datasets in tuning.
97
+ elif isinstance(existing_data, dict) and isinstance(value, dict):
98
+ # Merging dictionaries. Consider deep merging in the future.
99
+ existing_data.update(value)
100
+ else:
101
+ raise ValueError(
102
+ f'Cannot set value for an existing key. Key: {keys[-1]};'
103
+ f' Existing value: {existing_data}; New value: {value}.'
104
+ )
87
105
  else:
88
- raise ValueError(
89
- f'Cannot set value for an existing key. Key: {keys[-1]};'
90
- f' Existing value: {existing_data}; New value: {value}.'
91
- )
92
- else:
93
- data[keys[-1]] = value
106
+ if (
107
+ keys[-1] == '_self'
108
+ and isinstance(data, dict)
109
+ and isinstance(value, dict)
110
+ ):
111
+ data.update(value)
112
+ else:
113
+ data[keys[-1]] = value
94
114
 
95
115
 
96
- def get_value_by_path(data: Any, keys: list[str]):
116
+ def get_value_by_path(
117
+ data: Any, keys: list[str], *, default_value: Any = None
118
+ ) -> Any:
97
119
  """Examples:
98
120
 
99
121
  get_value_by_path({'a': {'b': v}}, ['a', 'b'])
@@ -105,36 +127,141 @@ def get_value_by_path(data: Any, keys: list[str]):
105
127
  return data
106
128
  for i, key in enumerate(keys):
107
129
  if not data:
108
- return None
130
+ return default_value
109
131
  if key.endswith('[]'):
110
132
  key_name = key[:-2]
111
133
  if key_name in data:
112
- return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
134
+ return [
135
+ get_value_by_path(d, keys[i + 1 :], default_value=default_value)
136
+ for d in data[key_name]
137
+ ]
113
138
  else:
114
- return None
139
+ return default_value
115
140
  elif key.endswith('[0]'):
116
141
  key_name = key[:-3]
117
142
  if key_name in data and data[key_name]:
118
- return get_value_by_path(data[key_name][0], keys[i + 1 :])
143
+ return get_value_by_path(
144
+ data[key_name][0], keys[i + 1 :], default_value=default_value
145
+ )
119
146
  else:
120
- return None
147
+ return default_value
121
148
  else:
122
149
  if key in data:
123
150
  data = data[key]
124
151
  elif isinstance(data, BaseModel) and hasattr(data, key):
125
152
  data = getattr(data, key)
126
153
  else:
127
- return None
154
+ return default_value
128
155
  return data
129
156
 
130
157
 
131
- def convert_to_dict(obj: object) -> Any:
158
+ def move_value_by_path(data: Any, paths: dict[str, str]) -> None:
159
+ """Moves values from source paths to destination paths.
160
+
161
+ Examples:
162
+ move_value_by_path(
163
+ {'requests': [{'content': v1}, {'content': v2}]},
164
+ {'requests[].*': 'requests[].request.*'}
165
+ )
166
+ -> {'requests': [{'request': {'content': v1}}, {'request': {'content':
167
+ v2}}]}
168
+ """
169
+ for source_path, dest_path in paths.items():
170
+ source_keys = source_path.split('.')
171
+ dest_keys = dest_path.split('.')
172
+
173
+ # Determine keys to exclude from wildcard to avoid cyclic references
174
+ exclude_keys = set()
175
+ wildcard_idx = -1
176
+ for i, key in enumerate(source_keys):
177
+ if key == '*':
178
+ wildcard_idx = i
179
+ break
180
+
181
+ if wildcard_idx != -1 and len(dest_keys) > wildcard_idx:
182
+ # Extract the intermediate key between source and dest paths
183
+ # Example: source=['requests[]', '*'], dest=['requests[]', 'request', '*']
184
+ # We want to exclude 'request'
185
+ for i in range(wildcard_idx, len(dest_keys)):
186
+ key = dest_keys[i]
187
+ if key != '*' and not key.endswith('[]') and not key.endswith('[0]'):
188
+ exclude_keys.add(key)
189
+
190
+ # Move values recursively
191
+ _move_value_recursive(data, source_keys, dest_keys, 0, exclude_keys)
192
+
193
+
194
+ def _move_value_recursive(
195
+ data: Any,
196
+ source_keys: list[str],
197
+ dest_keys: list[str],
198
+ key_idx: int,
199
+ exclude_keys: set[str],
200
+ ) -> None:
201
+ """Recursively moves values from source path to destination path."""
202
+ if key_idx >= len(source_keys):
203
+ return
204
+
205
+ key = source_keys[key_idx]
206
+
207
+ if key.endswith('[]'):
208
+ # Handle array iteration
209
+ key_name = key[:-2]
210
+ if key_name in data and isinstance(data[key_name], list):
211
+ for item in data[key_name]:
212
+ _move_value_recursive(
213
+ item, source_keys, dest_keys, key_idx + 1, exclude_keys
214
+ )
215
+ elif key == '*':
216
+ # Handle wildcard - move all fields
217
+ if isinstance(data, dict):
218
+ # Get all keys to move (excluding specified keys)
219
+ keys_to_move = [
220
+ k
221
+ for k in list(data.keys())
222
+ if not k.startswith('_') and k not in exclude_keys
223
+ ]
224
+
225
+ # Collect values to move
226
+ values_to_move = {k: data[k] for k in keys_to_move}
227
+
228
+ # Set values at destination
229
+ for k, v in values_to_move.items():
230
+ # Build destination keys with the field name
231
+ new_dest_keys = []
232
+ for dk in dest_keys[key_idx:]:
233
+ if dk == '*':
234
+ new_dest_keys.append(k)
235
+ else:
236
+ new_dest_keys.append(dk)
237
+ set_value_by_path(data, new_dest_keys, v)
238
+
239
+ # Delete from source
240
+ for k in keys_to_move:
241
+ del data[k]
242
+ else:
243
+ # Navigate to next level
244
+ if key in data:
245
+ _move_value_recursive(
246
+ data[key], source_keys, dest_keys, key_idx + 1, exclude_keys
247
+ )
248
+
249
+
250
+ def maybe_snake_to_camel(snake_str: str, convert: bool = True) -> str:
251
+ """Converts a snake_case string to CamelCase, if convert is True."""
252
+ if not convert:
253
+ return snake_str
254
+ return re.sub(r'_([a-zA-Z])', lambda match: match.group(1).upper(), snake_str)
255
+
256
+
257
+ def convert_to_dict(obj: object, convert_keys: bool = False) -> Any:
132
258
  """Recursively converts a given object to a dictionary.
133
259
 
134
260
  If the object is a Pydantic model, it uses the model's `model_dump()` method.
135
261
 
136
262
  Args:
137
263
  obj: The object to convert.
264
+ convert_keys: Whether to convert the keys from snake case to camel case.
138
265
 
139
266
  Returns:
140
267
  A dictionary representation of the object, a list of objects if a list is
@@ -142,18 +269,52 @@ def convert_to_dict(obj: object) -> Any:
142
269
  model.
143
270
  """
144
271
  if isinstance(obj, pydantic.BaseModel):
145
- return obj.model_dump(exclude_none=True)
272
+ return convert_to_dict(obj.model_dump(exclude_none=True), convert_keys)
146
273
  elif isinstance(obj, dict):
147
- return {key: convert_to_dict(value) for key, value in obj.items()}
274
+ return {
275
+ maybe_snake_to_camel(key, convert_keys): convert_to_dict(value)
276
+ for key, value in obj.items()
277
+ }
148
278
  elif isinstance(obj, list):
149
- return [convert_to_dict(item) for item in obj]
279
+ return [convert_to_dict(item, convert_keys) for item in obj]
150
280
  else:
151
281
  return obj
152
282
 
153
283
 
154
- def _remove_extra_fields(
155
- model: Any, response: dict[str, object]
156
- ) -> None:
284
+ def _is_struct_type(annotation: type) -> bool:
285
+ """Checks if the given annotation is list[dict[str, typing.Any]]
286
+
287
+ or typing.List[typing.Dict[str, typing.Any]].
288
+
289
+ This maps to Struct type in the API.
290
+ """
291
+ outer_origin = get_origin(annotation)
292
+ outer_args = get_args(annotation)
293
+
294
+ if outer_origin is not list: # Python 3.9+ normalizes list
295
+ return False
296
+
297
+ if not outer_args or len(outer_args) != 1:
298
+ return False
299
+
300
+ inner_annotation = outer_args[0]
301
+
302
+ inner_origin = get_origin(inner_annotation)
303
+ inner_args = get_args(inner_annotation)
304
+
305
+ if inner_origin is not dict: # Python 3.9+ normalizes to dict
306
+ return False
307
+
308
+ if not inner_args or len(inner_args) != 2:
309
+ # dict should have exactly two type arguments
310
+ return False
311
+
312
+ # Check if the dict arguments are str and typing.Any
313
+ key_type, value_type = inner_args
314
+ return key_type is str and value_type is typing.Any
315
+
316
+
317
+ def _remove_extra_fields(model: Any, response: dict[str, object]) -> None:
157
318
  """Removes extra fields from the response that are not in the model.
158
319
 
159
320
  Mutates the response in place.
@@ -185,14 +346,206 @@ def _remove_extra_fields(
185
346
  if isinstance(value, dict) and typing.get_origin(annotation) is not dict:
186
347
  _remove_extra_fields(annotation, value)
187
348
  elif isinstance(value, list):
349
+ if _is_struct_type(annotation):
350
+ continue
351
+
188
352
  for item in value:
189
353
  # assume a list of dict is list of BaseModel
190
354
  if isinstance(item, dict):
191
355
  _remove_extra_fields(typing.get_args(annotation)[0], item)
192
356
 
357
+
193
358
  T = typing.TypeVar('T', bound='BaseModel')
194
359
 
195
360
 
361
+ def _pretty_repr(
362
+ obj: Any,
363
+ *,
364
+ indent_level: int = 0,
365
+ indent_delta: int = 2,
366
+ max_len: int = 100,
367
+ max_items: int = 5,
368
+ depth: int = 6,
369
+ visited: Optional[FrozenSet[int]] = None,
370
+ ) -> str:
371
+ """Returns a representation of the given object."""
372
+ if visited is None:
373
+ visited = frozenset()
374
+
375
+ obj_id = id(obj)
376
+ if obj_id in visited:
377
+ return '<... Circular reference ...>'
378
+
379
+ if depth < 0:
380
+ return '<... Max depth ...>'
381
+
382
+ visited = frozenset(list(visited) + [obj_id])
383
+
384
+ indent = ' ' * indent_level
385
+ next_indent_str = ' ' * (indent_level + indent_delta)
386
+
387
+ if isinstance(obj, pydantic.BaseModel):
388
+ cls_name = obj.__class__.__name__
389
+ items = []
390
+ # Sort fields for consistent output
391
+ fields = sorted(type(obj).model_fields)
392
+
393
+ for field_name in fields:
394
+ field_info = type(obj).model_fields[field_name]
395
+ if not field_info.repr: # Respect Field(repr=False)
396
+ continue
397
+
398
+ try:
399
+ value = getattr(obj, field_name)
400
+ except AttributeError:
401
+ continue
402
+
403
+ if value is None:
404
+ continue
405
+
406
+ value_repr = _pretty_repr(
407
+ value,
408
+ indent_level=indent_level + indent_delta,
409
+ indent_delta=indent_delta,
410
+ max_len=max_len,
411
+ max_items=max_items,
412
+ depth=depth - 1,
413
+ visited=visited,
414
+ )
415
+ items.append(f'{next_indent_str}{field_name}={value_repr}')
416
+
417
+ if not items:
418
+ return f'{cls_name}()'
419
+ return f'{cls_name}(\n' + ',\n'.join(items) + f'\n{indent})'
420
+ elif isinstance(obj, str):
421
+ if '\n' in obj:
422
+ escaped = obj.replace('"""', '\\"\\"\\"')
423
+ # Indent the multi-line string block contents
424
+ return f'"""{escaped}"""'
425
+ return repr(obj)
426
+ elif isinstance(obj, bytes):
427
+ if len(obj) > max_len:
428
+ return f"{repr(obj[:max_len-3])[:-1]}...'"
429
+ return repr(obj)
430
+ elif isinstance(obj, collections.abc.Mapping):
431
+ if not obj:
432
+ return '{}'
433
+
434
+ # Check if the next level of recursion for keys/values will exceed the depth limit.
435
+ if depth <= 0:
436
+ item_count_str = f"{len(obj)} item{'s' if len(obj) != 1 else ''}"
437
+ return f'{{<... {item_count_str} at Max depth ...>}}'
438
+
439
+ if len(obj) > max_items:
440
+ return f'<dict len={len(obj)}>'
441
+
442
+ items = []
443
+ try:
444
+ sorted_keys = sorted(obj.keys(), key=str)
445
+ except TypeError:
446
+ sorted_keys = list(obj.keys())
447
+
448
+ for k in sorted_keys:
449
+ v = obj[k]
450
+ k_repr = _pretty_repr(
451
+ k,
452
+ indent_level=indent_level + indent_delta,
453
+ indent_delta=indent_delta,
454
+ max_len=max_len,
455
+ max_items=max_items,
456
+ depth=depth - 1,
457
+ visited=visited,
458
+ )
459
+ v_repr = _pretty_repr(
460
+ v,
461
+ indent_level=indent_level + indent_delta,
462
+ indent_delta=indent_delta,
463
+ max_len=max_len,
464
+ max_items=max_items,
465
+ depth=depth - 1,
466
+ visited=visited,
467
+ )
468
+ items.append(f'{next_indent_str}{k_repr}: {v_repr}')
469
+ return f'{{\n' + ',\n'.join(items) + f'\n{indent}}}'
470
+ elif isinstance(obj, (list, tuple, set)):
471
+ return _format_collection(
472
+ obj,
473
+ indent_level=indent_level,
474
+ indent_delta=indent_delta,
475
+ max_len=max_len,
476
+ max_items=max_items,
477
+ depth=depth,
478
+ visited=visited,
479
+ )
480
+ else:
481
+ # Fallback to standard repr, indenting subsequent lines only
482
+ raw_repr = repr(obj)
483
+ # Replace newlines with newline + indent
484
+ return raw_repr.replace('\n', f'\n{next_indent_str}')
485
+
486
+
487
+ def _format_collection(
488
+ obj: Any,
489
+ *,
490
+ indent_level: int,
491
+ indent_delta: int,
492
+ max_len: int,
493
+ max_items: int,
494
+ depth: int,
495
+ visited: FrozenSet[int],
496
+ ) -> str:
497
+ """Formats a collection (list, tuple, set)."""
498
+ if isinstance(obj, list):
499
+ brackets = ('[', ']')
500
+ internal_obj = obj
501
+ elif isinstance(obj, tuple):
502
+ brackets = ('(', ')')
503
+ internal_obj = list(obj)
504
+ elif isinstance(obj, set):
505
+ internal_obj = list(obj)
506
+ if obj:
507
+ brackets = ('{', '}')
508
+ else:
509
+ brackets = ('set(', ')')
510
+ else:
511
+ raise ValueError(f'Unsupported collection type: {type(obj)}')
512
+
513
+ if not internal_obj:
514
+ return brackets[0] + brackets[1]
515
+
516
+ # If the call to _pretty_repr for elements will have depth < 0
517
+ if depth <= 0:
518
+ item_count_str = f"{len(internal_obj)} item{'s'*(len(internal_obj)!=1)}"
519
+ return f'{brackets[0]}<... {item_count_str} at Max depth ...>{brackets[1]}'
520
+
521
+ indent = ' ' * indent_level
522
+ next_indent_str = ' ' * (indent_level + indent_delta)
523
+ elements = []
524
+ num_to_show = min(len(internal_obj), max_items)
525
+
526
+ for i in range(num_to_show):
527
+ elem = internal_obj[i]
528
+ elements.append(
529
+ next_indent_str
530
+ + _pretty_repr(
531
+ elem,
532
+ indent_level=indent_level + indent_delta,
533
+ indent_delta=indent_delta,
534
+ max_len=max_len,
535
+ max_items=max_items,
536
+ depth=depth - 1,
537
+ visited=visited,
538
+ )
539
+ )
540
+
541
+ if len(internal_obj) > max_items:
542
+ elements.append(
543
+ f'{next_indent_str}<... {len(internal_obj) - max_items} more items ...>'
544
+ )
545
+
546
+ return f'{brackets[0]}\n' + ',\n'.join(elements) + f',\n{indent}{brackets[1]}'
547
+
548
+
196
549
  class BaseModel(pydantic.BaseModel):
197
550
 
198
551
  model_config = pydantic.ConfigDict(
@@ -205,17 +558,78 @@ class BaseModel(pydantic.BaseModel):
205
558
  arbitrary_types_allowed=True,
206
559
  ser_json_bytes='base64',
207
560
  val_json_bytes='base64',
208
- ignored_types=(typing.TypeVar,)
561
+ ignored_types=(typing.TypeVar,),
209
562
  )
210
563
 
564
+ @pydantic.model_validator(mode='before')
565
+ @classmethod
566
+ def _check_field_type_mismatches(cls, data: Any) -> Any:
567
+ """Check for type mismatches and warn before Pydantic processes the data."""
568
+ # Handle both dict and Pydantic model inputs
569
+ if not isinstance(data, (dict, pydantic.BaseModel)):
570
+ return data
571
+
572
+ for field_name, field_info in cls.model_fields.items():
573
+ if isinstance(data, dict):
574
+ value = data.get(field_name)
575
+ else:
576
+ value = getattr(data, field_name, None)
577
+
578
+ if value is None:
579
+ continue
580
+
581
+ expected_type = field_info.annotation
582
+ origin = get_origin(expected_type)
583
+
584
+ if origin is Union:
585
+ args = get_args(expected_type)
586
+ non_none_types = [arg for arg in args if arg is not type(None)]
587
+ if len(non_none_types) == 1:
588
+ expected_type = non_none_types[0]
589
+
590
+ if (isinstance(expected_type, type) and
591
+ get_origin(expected_type) is None and
592
+ issubclass(expected_type, pydantic.BaseModel) and
593
+ isinstance(value, pydantic.BaseModel) and
594
+ not isinstance(value, expected_type)):
595
+ logger.warning(
596
+ f"Type mismatch in {cls.__name__}.{field_name}: "
597
+ f"expected {expected_type.__name__}, got {type(value).__name__}"
598
+ )
599
+
600
+ return data
601
+
602
+ def __repr__(self) -> str:
603
+ try:
604
+ return _pretty_repr(self)
605
+ except Exception:
606
+ return super().__repr__()
607
+
211
608
  @classmethod
212
609
  def _from_response(
213
- cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
610
+ cls: typing.Type[T],
611
+ *,
612
+ response: dict[str, object],
613
+ kwargs: dict[str, object],
214
614
  ) -> T:
215
615
  # To maintain forward compatibility, we need to remove extra fields from
216
616
  # the response.
217
617
  # We will provide another mechanism to allow users to access these fields.
218
- _remove_extra_fields(cls, response)
618
+
619
+ # For Agent Engine we don't want to call _remove_all_fields because the
620
+ # user may pass a dict that is not a subclass of BaseModel.
621
+ # If more modules require we skip this, we may want a different approach
622
+ should_skip_removing_fields = (
623
+ kwargs is not None
624
+ and 'config' in kwargs
625
+ and kwargs['config'] is not None
626
+ and isinstance(kwargs['config'], dict)
627
+ and 'include_all_fields' in kwargs['config']
628
+ and kwargs['config']['include_all_fields']
629
+ )
630
+
631
+ if not should_skip_removing_fields:
632
+ _remove_extra_fields(cls, response)
219
633
  validated_response = cls.model_validate(response)
220
634
  return validated_response
221
635
 
@@ -227,14 +641,14 @@ class CaseInSensitiveEnum(str, enum.Enum):
227
641
  """Case insensitive enum."""
228
642
 
229
643
  @classmethod
230
- def _missing_(cls, value):
644
+ def _missing_(cls, value: Any) -> Any:
231
645
  try:
232
646
  return cls[value.upper()] # Try to access directly with uppercase
233
647
  except KeyError:
234
648
  try:
235
649
  return cls[value.lower()] # Try to access directly with lowercase
236
650
  except KeyError:
237
- warnings.warn(f"{value} is not a valid {cls.__name__}")
651
+ warnings.warn(f'{value} is not a valid {cls.__name__}')
238
652
  try:
239
653
  # Creating a enum instance based on the value
240
654
  # We need to use super() to avoid infinite recursion.
@@ -295,21 +709,108 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
295
709
  return processed_data
296
710
 
297
711
 
298
- def experimental_warning(message: str):
712
+ def experimental_warning(
713
+ message: str,
714
+ ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
299
715
  """Experimental warning, only warns once."""
300
- def decorator(func):
716
+
717
+ def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
301
718
  warning_done = False
719
+
302
720
  @functools.wraps(func)
303
- def wrapper(*args, **kwargs):
721
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
304
722
  nonlocal warning_done
305
723
  if not warning_done:
306
724
  warning_done = True
307
725
  warnings.warn(
308
726
  message=message,
309
- category=errors.ExperimentalWarning,
727
+ category=ExperimentalWarning,
310
728
  stacklevel=2,
311
729
  )
312
730
  return func(*args, **kwargs)
731
+
313
732
  return wrapper
733
+
314
734
  return decorator
315
735
 
736
+
737
+ def _normalize_key_for_matching(key_str: str) -> str:
738
+ """Normalizes a key for case-insensitive and snake/camel matching."""
739
+ return key_str.replace('_', '').lower()
740
+
741
+
742
+ def align_key_case(
743
+ target_dict: StringDict, update_dict: StringDict
744
+ ) -> StringDict:
745
+ """Aligns the keys of update_dict to the case of target_dict keys.
746
+
747
+ Args:
748
+ target_dict: The dictionary with the target key casing.
749
+ update_dict: The dictionary whose keys need to be aligned.
750
+
751
+ Returns:
752
+ A new dictionary with keys aligned to target_dict's key casing.
753
+ """
754
+ aligned_update_dict: StringDict = {}
755
+ target_keys_map = {
756
+ _normalize_key_for_matching(key): key for key in target_dict.keys()
757
+ }
758
+
759
+ for key, value in update_dict.items():
760
+ normalized_update_key = _normalize_key_for_matching(key)
761
+
762
+ if normalized_update_key in target_keys_map:
763
+ aligned_key = target_keys_map[normalized_update_key]
764
+ else:
765
+ aligned_key = key
766
+
767
+ if isinstance(value, dict) and isinstance(
768
+ target_dict.get(aligned_key), dict
769
+ ):
770
+ aligned_update_dict[aligned_key] = align_key_case(
771
+ target_dict[aligned_key], value
772
+ )
773
+ elif isinstance(value, list) and isinstance(
774
+ target_dict.get(aligned_key), list
775
+ ):
776
+ # Direct assign as we treat update_dict list values as golden source.
777
+ aligned_update_dict[aligned_key] = value
778
+ else:
779
+ aligned_update_dict[aligned_key] = value
780
+ return aligned_update_dict
781
+
782
+
783
+ def recursive_dict_update(
784
+ target_dict: StringDict, update_dict: StringDict
785
+ ) -> None:
786
+ """Recursively updates a target dictionary with values from an update dictionary.
787
+
788
+ We don't enforce the updated dict values to have the same type with the
789
+ target_dict values except log warnings.
790
+ Users providing the update_dict should be responsible for constructing correct
791
+ data.
792
+
793
+ Args:
794
+ target_dict (dict): The dictionary to be updated.
795
+ update_dict (dict): The dictionary containing updates.
796
+ """
797
+ # Python SDK http request may change in camel case or snake case:
798
+ # If the field is directly set via setv() function, then it is camel case;
799
+ # otherwise it is snake case.
800
+ # Align the update_dict key case to target_dict to ensure correct dict update.
801
+ aligned_update_dict = align_key_case(target_dict, update_dict)
802
+ for key, value in aligned_update_dict.items():
803
+ if (
804
+ key in target_dict
805
+ and isinstance(target_dict[key], dict)
806
+ and isinstance(value, dict)
807
+ ):
808
+ recursive_dict_update(target_dict[key], value)
809
+ elif key in target_dict and not isinstance(target_dict[key], type(value)):
810
+ logger.warning(
811
+ f"Type mismatch for key '{key}'. Existing type:"
812
+ f' {type(target_dict[key])}, new type: {type(value)}. Overwriting.'
813
+ )
814
+ target_dict[key] = value
815
+ else:
816
+ target_dict[key] = value