google-genai 1.33.0__py3-none-any.whl → 1.53.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/_common.py CHANGED
@@ -21,6 +21,7 @@ import datetime
21
21
  import enum
22
22
  import functools
23
23
  import logging
24
+ import re
24
25
  import typing
25
26
  from typing import Any, Callable, FrozenSet, Optional, Union, get_args, get_origin
26
27
  import uuid
@@ -38,7 +39,9 @@ class ExperimentalWarning(Warning):
38
39
  """Warning for experimental features."""
39
40
 
40
41
 
41
- def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None:
42
+ def set_value_by_path(
43
+ data: Optional[dict[Any, Any]], keys: list[str], value: Any
44
+ ) -> None:
42
45
  """Examples:
43
46
 
44
47
  set_value_by_path({}, ['a', 'b'], v)
@@ -100,10 +103,19 @@ def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: An
100
103
  f' Existing value: {existing_data}; New value: {value}.'
101
104
  )
102
105
  else:
103
- data[keys[-1]] = value
106
+ if (
107
+ keys[-1] == '_self'
108
+ and isinstance(data, dict)
109
+ and isinstance(value, dict)
110
+ ):
111
+ data.update(value)
112
+ else:
113
+ data[keys[-1]] = value
104
114
 
105
115
 
106
- def get_value_by_path(data: Any, keys: list[str]) -> Any:
116
+ def get_value_by_path(
117
+ data: Any, keys: list[str], *, default_value: Any = None
118
+ ) -> Any:
107
119
  """Examples:
108
120
 
109
121
  get_value_by_path({'a': {'b': v}}, ['a', 'b'])
@@ -115,36 +127,141 @@ def get_value_by_path(data: Any, keys: list[str]) -> Any:
115
127
  return data
116
128
  for i, key in enumerate(keys):
117
129
  if not data:
118
- return None
130
+ return default_value
119
131
  if key.endswith('[]'):
120
132
  key_name = key[:-2]
121
133
  if key_name in data:
122
- return [get_value_by_path(d, keys[i + 1 :]) for d in data[key_name]]
134
+ return [
135
+ get_value_by_path(d, keys[i + 1 :], default_value=default_value)
136
+ for d in data[key_name]
137
+ ]
123
138
  else:
124
- return None
139
+ return default_value
125
140
  elif key.endswith('[0]'):
126
141
  key_name = key[:-3]
127
142
  if key_name in data and data[key_name]:
128
- return get_value_by_path(data[key_name][0], keys[i + 1 :])
143
+ return get_value_by_path(
144
+ data[key_name][0], keys[i + 1 :], default_value=default_value
145
+ )
129
146
  else:
130
- return None
147
+ return default_value
131
148
  else:
132
149
  if key in data:
133
150
  data = data[key]
134
151
  elif isinstance(data, BaseModel) and hasattr(data, key):
135
152
  data = getattr(data, key)
136
153
  else:
137
- return None
154
+ return default_value
138
155
  return data
139
156
 
140
157
 
141
- def convert_to_dict(obj: object) -> Any:
158
+ def move_value_by_path(data: Any, paths: dict[str, str]) -> None:
159
+ """Moves values from source paths to destination paths.
160
+
161
+ Examples:
162
+ move_value_by_path(
163
+ {'requests': [{'content': v1}, {'content': v2}]},
164
+ {'requests[].*': 'requests[].request.*'}
165
+ )
166
+ -> {'requests': [{'request': {'content': v1}}, {'request': {'content':
167
+ v2}}]}
168
+ """
169
+ for source_path, dest_path in paths.items():
170
+ source_keys = source_path.split('.')
171
+ dest_keys = dest_path.split('.')
172
+
173
+ # Determine keys to exclude from wildcard to avoid cyclic references
174
+ exclude_keys = set()
175
+ wildcard_idx = -1
176
+ for i, key in enumerate(source_keys):
177
+ if key == '*':
178
+ wildcard_idx = i
179
+ break
180
+
181
+ if wildcard_idx != -1 and len(dest_keys) > wildcard_idx:
182
+ # Extract the intermediate key between source and dest paths
183
+ # Example: source=['requests[]', '*'], dest=['requests[]', 'request', '*']
184
+ # We want to exclude 'request'
185
+ for i in range(wildcard_idx, len(dest_keys)):
186
+ key = dest_keys[i]
187
+ if key != '*' and not key.endswith('[]') and not key.endswith('[0]'):
188
+ exclude_keys.add(key)
189
+
190
+ # Move values recursively
191
+ _move_value_recursive(data, source_keys, dest_keys, 0, exclude_keys)
192
+
193
+
194
+ def _move_value_recursive(
195
+ data: Any,
196
+ source_keys: list[str],
197
+ dest_keys: list[str],
198
+ key_idx: int,
199
+ exclude_keys: set[str],
200
+ ) -> None:
201
+ """Recursively moves values from source path to destination path."""
202
+ if key_idx >= len(source_keys):
203
+ return
204
+
205
+ key = source_keys[key_idx]
206
+
207
+ if key.endswith('[]'):
208
+ # Handle array iteration
209
+ key_name = key[:-2]
210
+ if key_name in data and isinstance(data[key_name], list):
211
+ for item in data[key_name]:
212
+ _move_value_recursive(
213
+ item, source_keys, dest_keys, key_idx + 1, exclude_keys
214
+ )
215
+ elif key == '*':
216
+ # Handle wildcard - move all fields
217
+ if isinstance(data, dict):
218
+ # Get all keys to move (excluding specified keys)
219
+ keys_to_move = [
220
+ k
221
+ for k in list(data.keys())
222
+ if not k.startswith('_') and k not in exclude_keys
223
+ ]
224
+
225
+ # Collect values to move
226
+ values_to_move = {k: data[k] for k in keys_to_move}
227
+
228
+ # Set values at destination
229
+ for k, v in values_to_move.items():
230
+ # Build destination keys with the field name
231
+ new_dest_keys = []
232
+ for dk in dest_keys[key_idx:]:
233
+ if dk == '*':
234
+ new_dest_keys.append(k)
235
+ else:
236
+ new_dest_keys.append(dk)
237
+ set_value_by_path(data, new_dest_keys, v)
238
+
239
+ # Delete from source
240
+ for k in keys_to_move:
241
+ del data[k]
242
+ else:
243
+ # Navigate to next level
244
+ if key in data:
245
+ _move_value_recursive(
246
+ data[key], source_keys, dest_keys, key_idx + 1, exclude_keys
247
+ )
248
+
249
+
250
+ def maybe_snake_to_camel(snake_str: str, convert: bool = True) -> str:
251
+ """Converts a snake_case string to CamelCase, if convert is True."""
252
+ if not convert:
253
+ return snake_str
254
+ return re.sub(r'_([a-zA-Z])', lambda match: match.group(1).upper(), snake_str)
255
+
256
+
257
+ def convert_to_dict(obj: object, convert_keys: bool = False) -> Any:
142
258
  """Recursively converts a given object to a dictionary.
143
259
 
144
260
  If the object is a Pydantic model, it uses the model's `model_dump()` method.
145
261
 
146
262
  Args:
147
263
  obj: The object to convert.
264
+ convert_keys: Whether to convert the keys from snake case to camel case.
148
265
 
149
266
  Returns:
150
267
  A dictionary representation of the object, a list of objects if a list is
@@ -152,17 +269,21 @@ def convert_to_dict(obj: object) -> Any:
152
269
  model.
153
270
  """
154
271
  if isinstance(obj, pydantic.BaseModel):
155
- return obj.model_dump(exclude_none=True)
272
+ return convert_to_dict(obj.model_dump(exclude_none=True), convert_keys)
156
273
  elif isinstance(obj, dict):
157
- return {key: convert_to_dict(value) for key, value in obj.items()}
274
+ return {
275
+ maybe_snake_to_camel(key, convert_keys): convert_to_dict(value)
276
+ for key, value in obj.items()
277
+ }
158
278
  elif isinstance(obj, list):
159
- return [convert_to_dict(item) for item in obj]
279
+ return [convert_to_dict(item, convert_keys) for item in obj]
160
280
  else:
161
281
  return obj
162
282
 
163
283
 
164
284
  def _is_struct_type(annotation: type) -> bool:
165
285
  """Checks if the given annotation is list[dict[str, typing.Any]]
286
+
166
287
  or typing.List[typing.Dict[str, typing.Any]].
167
288
 
168
289
  This maps to Struct type in the API.
@@ -170,7 +291,7 @@ def _is_struct_type(annotation: type) -> bool:
170
291
  outer_origin = get_origin(annotation)
171
292
  outer_args = get_args(annotation)
172
293
 
173
- if outer_origin is not list: # Python 3.9+ normalizes list
294
+ if outer_origin is not list: # Python 3.9+ normalizes list
174
295
  return False
175
296
 
176
297
  if not outer_args or len(outer_args) != 1:
@@ -181,7 +302,7 @@ def _is_struct_type(annotation: type) -> bool:
181
302
  inner_origin = get_origin(inner_annotation)
182
303
  inner_args = get_args(inner_annotation)
183
304
 
184
- if inner_origin is not dict: # Python 3.9+ normalizes to dict
305
+ if inner_origin is not dict: # Python 3.9+ normalizes to dict
185
306
  return False
186
307
 
187
308
  if not inner_args or len(inner_args) != 2:
@@ -193,9 +314,7 @@ def _is_struct_type(annotation: type) -> bool:
193
314
  return key_type is str and value_type is typing.Any
194
315
 
195
316
 
196
- def _remove_extra_fields(
197
- model: Any, response: dict[str, object]
198
- ) -> None:
317
+ def _remove_extra_fields(model: Any, response: dict[str, object]) -> None:
199
318
  """Removes extra fields from the response that are not in the model.
200
319
 
201
320
  Mutates the response in place.
@@ -235,6 +354,7 @@ def _remove_extra_fields(
235
354
  if isinstance(item, dict):
236
355
  _remove_extra_fields(typing.get_args(annotation)[0], item)
237
356
 
357
+
238
358
  T = typing.TypeVar('T', bound='BaseModel')
239
359
 
240
360
 
@@ -310,8 +430,15 @@ def _pretty_repr(
310
430
  elif isinstance(obj, collections.abc.Mapping):
311
431
  if not obj:
312
432
  return '{}'
433
+
434
+ # Check if the next level of recursion for keys/values will exceed the depth limit.
435
+ if depth <= 0:
436
+ item_count_str = f"{len(obj)} item{'s' if len(obj) != 1 else ''}"
437
+ return f'{{<... {item_count_str} at Max depth ...>}}'
438
+
313
439
  if len(obj) > max_items:
314
440
  return f'<dict len={len(obj)}>'
441
+
315
442
  items = []
316
443
  try:
317
444
  sorted_keys = sorted(obj.keys(), key=str)
@@ -367,47 +494,56 @@ def _format_collection(
367
494
  depth: int,
368
495
  visited: FrozenSet[int],
369
496
  ) -> str:
370
- """Formats a collection (list, tuple, set)."""
371
- if isinstance(obj, list):
372
- brackets = ('[', ']')
373
- elif isinstance(obj, tuple):
374
- brackets = ('(', ')')
375
- elif isinstance(obj, set):
376
- obj = list(obj)
377
- if obj:
378
- brackets = ('{', '}')
379
- else:
380
- brackets = ('set(', ')')
497
+ """Formats a collection (list, tuple, set)."""
498
+ if isinstance(obj, list):
499
+ brackets = ('[', ']')
500
+ internal_obj = obj
501
+ elif isinstance(obj, tuple):
502
+ brackets = ('(', ')')
503
+ internal_obj = list(obj)
504
+ elif isinstance(obj, set):
505
+ internal_obj = list(obj)
506
+ if obj:
507
+ brackets = ('{', '}')
381
508
  else:
382
- raise ValueError(f"Unsupported collection type: {type(obj)}")
509
+ brackets = ('set(', ')')
510
+ else:
511
+ raise ValueError(f'Unsupported collection type: {type(obj)}')
383
512
 
384
- if not obj:
385
- return brackets[0] + brackets[1]
386
-
387
- indent = ' ' * indent_level
388
- next_indent_str = ' ' * (indent_level + indent_delta)
389
- elements = []
390
- for i, elem in enumerate(obj):
391
- if i >= max_items:
392
- elements.append(
393
- f'{next_indent_str}<... {len(obj) - max_items} more items ...>'
394
- )
395
- break
396
- # Each element starts on a new line, fully indented
397
- elements.append(
398
- next_indent_str
399
- + _pretty_repr(
400
- elem,
401
- indent_level=indent_level + indent_delta,
402
- indent_delta=indent_delta,
403
- max_len=max_len,
404
- max_items=max_items,
405
- depth=depth - 1,
406
- visited=visited,
407
- )
513
+ if not internal_obj:
514
+ return brackets[0] + brackets[1]
515
+
516
+ # If the call to _pretty_repr for elements will have depth < 0
517
+ if depth <= 0:
518
+ item_count_str = f"{len(internal_obj)} item{'s'*(len(internal_obj)!=1)}"
519
+ return f'{brackets[0]}<... {item_count_str} at Max depth ...>{brackets[1]}'
520
+
521
+ indent = ' ' * indent_level
522
+ next_indent_str = ' ' * (indent_level + indent_delta)
523
+ elements = []
524
+ num_to_show = min(len(internal_obj), max_items)
525
+
526
+ for i in range(num_to_show):
527
+ elem = internal_obj[i]
528
+ elements.append(
529
+ next_indent_str
530
+ + _pretty_repr(
531
+ elem,
532
+ indent_level=indent_level + indent_delta,
533
+ indent_delta=indent_delta,
534
+ max_len=max_len,
535
+ max_items=max_items,
536
+ depth=depth - 1,
537
+ visited=visited,
408
538
  )
539
+ )
409
540
 
410
- return f'{brackets[0]}\n' + ',\n'.join(elements) + "," + f'\n{indent}{brackets[1]}'
541
+ if len(internal_obj) > max_items:
542
+ elements.append(
543
+ f'{next_indent_str}<... {len(internal_obj) - max_items} more items ...>'
544
+ )
545
+
546
+ return f'{brackets[0]}\n' + ',\n'.join(elements) + f',\n{indent}{brackets[1]}'
411
547
 
412
548
 
413
549
  class BaseModel(pydantic.BaseModel):
@@ -422,9 +558,47 @@ class BaseModel(pydantic.BaseModel):
422
558
  arbitrary_types_allowed=True,
423
559
  ser_json_bytes='base64',
424
560
  val_json_bytes='base64',
425
- ignored_types=(typing.TypeVar,)
561
+ ignored_types=(typing.TypeVar,),
426
562
  )
427
563
 
564
+ @pydantic.model_validator(mode='before')
565
+ @classmethod
566
+ def _check_field_type_mismatches(cls, data: Any) -> Any:
567
+ """Check for type mismatches and warn before Pydantic processes the data."""
568
+ # Handle both dict and Pydantic model inputs
569
+ if not isinstance(data, (dict, pydantic.BaseModel)):
570
+ return data
571
+
572
+ for field_name, field_info in cls.model_fields.items():
573
+ if isinstance(data, dict):
574
+ value = data.get(field_name)
575
+ else:
576
+ value = getattr(data, field_name, None)
577
+
578
+ if value is None:
579
+ continue
580
+
581
+ expected_type = field_info.annotation
582
+ origin = get_origin(expected_type)
583
+
584
+ if origin is Union:
585
+ args = get_args(expected_type)
586
+ non_none_types = [arg for arg in args if arg is not type(None)]
587
+ if len(non_none_types) == 1:
588
+ expected_type = non_none_types[0]
589
+
590
+ if (isinstance(expected_type, type) and
591
+ get_origin(expected_type) is None and
592
+ issubclass(expected_type, pydantic.BaseModel) and
593
+ isinstance(value, pydantic.BaseModel) and
594
+ not isinstance(value, expected_type)):
595
+ logger.warning(
596
+ f"Type mismatch in {cls.__name__}.{field_name}: "
597
+ f"expected {expected_type.__name__}, got {type(value).__name__}"
598
+ )
599
+
600
+ return data
601
+
428
602
  def __repr__(self) -> str:
429
603
  try:
430
604
  return _pretty_repr(self)
@@ -433,7 +607,10 @@ class BaseModel(pydantic.BaseModel):
433
607
 
434
608
  @classmethod
435
609
  def _from_response(
436
- cls: typing.Type[T], *, response: dict[str, object], kwargs: dict[str, object]
610
+ cls: typing.Type[T],
611
+ *,
612
+ response: dict[str, object],
613
+ kwargs: dict[str, object],
437
614
  ) -> T:
438
615
  # To maintain forward compatibility, we need to remove extra fields from
439
616
  # the response.
@@ -443,11 +620,11 @@ class BaseModel(pydantic.BaseModel):
443
620
  # user may pass a dict that is not a subclass of BaseModel.
444
621
  # If more modules require we skip this, we may want a different approach
445
622
  should_skip_removing_fields = (
446
- kwargs is not None and
447
- 'config' in kwargs and
448
- kwargs['config'] is not None and
449
- isinstance(kwargs['config'], dict) and
450
- 'include_all_fields' in kwargs['config']
623
+ kwargs is not None
624
+ and 'config' in kwargs
625
+ and kwargs['config'] is not None
626
+ and isinstance(kwargs['config'], dict)
627
+ and 'include_all_fields' in kwargs['config']
451
628
  and kwargs['config']['include_all_fields']
452
629
  )
453
630
 
@@ -471,7 +648,7 @@ class CaseInSensitiveEnum(str, enum.Enum):
471
648
  try:
472
649
  return cls[value.lower()] # Try to access directly with lowercase
473
650
  except KeyError:
474
- warnings.warn(f"{value} is not a valid {cls.__name__}")
651
+ warnings.warn(f'{value} is not a valid {cls.__name__}')
475
652
  try:
476
653
  # Creating a enum instance based on the value
477
654
  # We need to use super() to avoid infinite recursion.
@@ -532,10 +709,14 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
532
709
  return processed_data
533
710
 
534
711
 
535
- def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
712
+ def experimental_warning(
713
+ message: str,
714
+ ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
536
715
  """Experimental warning, only warns once."""
716
+
537
717
  def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
538
718
  warning_done = False
719
+
539
720
  @functools.wraps(func)
540
721
  def wrapper(*args: Any, **kwargs: Any) -> Any:
541
722
  nonlocal warning_done
@@ -547,13 +728,15 @@ def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callabl
547
728
  stacklevel=2,
548
729
  )
549
730
  return func(*args, **kwargs)
731
+
550
732
  return wrapper
733
+
551
734
  return decorator
552
735
 
553
736
 
554
737
  def _normalize_key_for_matching(key_str: str) -> str:
555
738
  """Normalizes a key for case-insensitive and snake/camel matching."""
556
- return key_str.replace("_", "").lower()
739
+ return key_str.replace('_', '').lower()
557
740
 
558
741
 
559
742
  def align_key_case(
@@ -569,7 +752,9 @@ def align_key_case(
569
752
  A new dictionary with keys aligned to target_dict's key casing.
570
753
  """
571
754
  aligned_update_dict: StringDict = {}
572
- target_keys_map = {_normalize_key_for_matching(key): key for key in target_dict.keys()}
755
+ target_keys_map = {
756
+ _normalize_key_for_matching(key): key for key in target_dict.keys()
757
+ }
573
758
 
574
759
  for key, value in update_dict.items():
575
760
  normalized_update_key = _normalize_key_for_matching(key)
@@ -579,9 +764,15 @@ def align_key_case(
579
764
  else:
580
765
  aligned_key = key
581
766
 
582
- if isinstance(value, dict) and isinstance(target_dict.get(aligned_key), dict):
583
- aligned_update_dict[aligned_key] = align_key_case(target_dict[aligned_key], value)
584
- elif isinstance(value, list) and isinstance(target_dict.get(aligned_key), list):
767
+ if isinstance(value, dict) and isinstance(
768
+ target_dict.get(aligned_key), dict
769
+ ):
770
+ aligned_update_dict[aligned_key] = align_key_case(
771
+ target_dict[aligned_key], value
772
+ )
773
+ elif isinstance(value, list) and isinstance(
774
+ target_dict.get(aligned_key), list
775
+ ):
585
776
  # Direct assign as we treat update_dict list values as golden source.
586
777
  aligned_update_dict[aligned_key] = value
587
778
  else: