google-genai 1.27.0__py3-none-any.whl → 1.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/_common.py CHANGED
@@ -22,15 +22,17 @@ import enum
22
22
  import functools
23
23
  import logging
24
24
  import typing
25
- from typing import Any, Callable, Optional, FrozenSet, Union, get_args, get_origin
25
+ from typing import Any, Callable, FrozenSet, Optional, Union, get_args, get_origin
26
26
  import uuid
27
27
  import warnings
28
-
29
28
  import pydantic
30
29
  from pydantic import alias_generators
30
+ from typing_extensions import TypeAlias
31
31
 
32
32
  logger = logging.getLogger('google_genai._common')
33
33
 
34
+ StringDict: TypeAlias = dict[str, Any]
35
+
34
36
 
35
37
  class ExperimentalWarning(Warning):
36
38
  """Warning for experimental features."""
@@ -355,7 +357,6 @@ def _pretty_repr(
355
357
  return raw_repr.replace('\n', f'\n{next_indent_str}')
356
358
 
357
359
 
358
-
359
360
  def _format_collection(
360
361
  obj: Any,
361
362
  *,
@@ -555,7 +556,9 @@ def _normalize_key_for_matching(key_str: str) -> str:
555
556
  return key_str.replace("_", "").lower()
556
557
 
557
558
 
558
- def align_key_case(target_dict: dict[str, Any], update_dict: dict[str, Any]) -> dict[str, Any]:
559
+ def align_key_case(
560
+ target_dict: StringDict, update_dict: StringDict
561
+ ) -> StringDict:
559
562
  """Aligns the keys of update_dict to the case of target_dict keys.
560
563
 
561
564
  Args:
@@ -565,7 +568,7 @@ def align_key_case(target_dict: dict[str, Any], update_dict: dict[str, Any]) ->
565
568
  Returns:
566
569
  A new dictionary with keys aligned to target_dict's key casing.
567
570
  """
568
- aligned_update_dict: dict[str, Any] = {}
571
+ aligned_update_dict: StringDict = {}
569
572
  target_keys_map = {_normalize_key_for_matching(key): key for key in target_dict.keys()}
570
573
 
571
574
  for key, value in update_dict.items():
@@ -587,7 +590,7 @@ def align_key_case(target_dict: dict[str, Any], update_dict: dict[str, Any]) ->
587
590
 
588
591
 
589
592
  def recursive_dict_update(
590
- target_dict: dict[str, Any], update_dict: dict[str, Any]
593
+ target_dict: StringDict, update_dict: StringDict
591
594
  ) -> None:
592
595
  """Recursively updates a target dictionary with values from an update dictionary.
593
596
 
@@ -155,8 +155,8 @@ def get_function_map(
155
155
 
156
156
 
157
157
  def convert_number_values_for_dict_function_call_args(
158
- args: dict[str, Any],
159
- ) -> dict[str, Any]:
158
+ args: _common.StringDict,
159
+ ) -> _common.StringDict:
160
160
  """Converts float values in dict with no decimal to integers."""
161
161
  return {
162
162
  key: convert_number_values_for_function_call_args(value)
@@ -257,8 +257,8 @@ def convert_if_exist_pydantic_model(
257
257
 
258
258
 
259
259
  def convert_argument_from_function(
260
- args: dict[str, Any], function: Callable[..., Any]
261
- ) -> dict[str, Any]:
260
+ args: _common.StringDict, function: Callable[..., Any]
261
+ ) -> _common.StringDict:
262
262
  signature = inspect.signature(function)
263
263
  func_name = function.__name__
264
264
  converted_args = {}
@@ -274,7 +274,7 @@ def convert_argument_from_function(
274
274
 
275
275
 
276
276
  def invoke_function_from_dict_args(
277
- args: Dict[str, Any], function_to_invoke: Callable[..., Any]
277
+ args: _common.StringDict, function_to_invoke: Callable[..., Any]
278
278
  ) -> Any:
279
279
  converted_args = convert_argument_from_function(args, function_to_invoke)
280
280
  try:
@@ -288,7 +288,7 @@ def invoke_function_from_dict_args(
288
288
 
289
289
 
290
290
  async def invoke_function_from_dict_args_async(
291
- args: Dict[str, Any], function_to_invoke: Callable[..., Any]
291
+ args: _common.StringDict, function_to_invoke: Callable[..., Any]
292
292
  ) -> Any:
293
293
  converted_args = convert_argument_from_function(args, function_to_invoke)
294
294
  try:
@@ -321,7 +321,7 @@ def get_function_response_parts(
321
321
  args = convert_number_values_for_dict_function_call_args(
322
322
  part.function_call.args
323
323
  )
324
- func_response: dict[str, Any]
324
+ func_response: _common.StringDict
325
325
  try:
326
326
  if not isinstance(func, McpToGenAiToolAdapter):
327
327
  func_response = {
@@ -356,7 +356,7 @@ async def get_function_response_parts_async(
356
356
  args = convert_number_values_for_dict_function_call_args(
357
357
  part.function_call.args
358
358
  )
359
- func_response: dict[str, Any]
359
+ func_response: _common.StringDict
360
360
  try:
361
361
  if isinstance(func, McpToGenAiToolAdapter):
362
362
  mcp_tool_response = await func.call_tool(
@@ -1098,6 +1098,13 @@ def _LiveMusicGenerationConfig_to_mldev(
1098
1098
  getv(from_object, ['only_bass_and_drums']),
1099
1099
  )
1100
1100
 
1101
+ if getv(from_object, ['music_generation_mode']) is not None:
1102
+ setv(
1103
+ to_object,
1104
+ ['musicGenerationMode'],
1105
+ getv(from_object, ['music_generation_mode']),
1106
+ )
1107
+
1101
1108
  return to_object
1102
1109
 
1103
1110
 
@@ -2871,6 +2878,13 @@ def _LiveMusicGenerationConfig_from_mldev(
2871
2878
  getv(from_object, ['onlyBassAndDrums']),
2872
2879
  )
2873
2880
 
2881
+ if getv(from_object, ['musicGenerationMode']) is not None:
2882
+ setv(
2883
+ to_object,
2884
+ ['music_generation_mode'],
2885
+ getv(from_object, ['musicGenerationMode']),
2886
+ )
2887
+
2874
2888
  return to_object
2875
2889
 
2876
2890
 
@@ -19,6 +19,7 @@ from importlib.metadata import PackageNotFoundError, version
19
19
  import typing
20
20
  from typing import Any
21
21
 
22
+ from . import _common
22
23
  from . import types
23
24
 
24
25
  if typing.TYPE_CHECKING:
@@ -89,7 +90,9 @@ def set_mcp_usage_header(headers: dict[str, str]) -> None:
89
90
  ).lstrip()
90
91
 
91
92
 
92
- def _filter_to_supported_schema(schema: dict[str, Any]) -> dict[str, Any]:
93
+ def _filter_to_supported_schema(
94
+ schema: _common.StringDict,
95
+ ) -> _common.StringDict:
93
96
  """Filters the schema to only include fields that are supported by JSONSchema."""
94
97
  supported_fields: set[str] = set(types.JSONSchema.model_fields.keys())
95
98
  schema_field_names: tuple[str] = ("items",) # 'additional_properties' to come
@@ -479,13 +479,14 @@ class ReplayApiClient(BaseApiClient):
479
479
  def _request(
480
480
  self,
481
481
  http_request: HttpRequest,
482
+ http_options: Optional[HttpOptionsOrDict] = None,
482
483
  stream: bool = False,
483
484
  ) -> HttpResponse:
484
485
  self._initialize_replay_session_if_not_loaded()
485
486
  if self._should_call_api():
486
487
  _debug_print('api mode request: %s' % http_request)
487
488
  try:
488
- result = super()._request(http_request, stream)
489
+ result = super()._request(http_request, http_options, stream)
489
490
  except errors.APIError as e:
490
491
  self._record_interaction(http_request, e)
491
492
  raise e
@@ -507,13 +508,16 @@ class ReplayApiClient(BaseApiClient):
507
508
  async def _async_request(
508
509
  self,
509
510
  http_request: HttpRequest,
511
+ http_options: Optional[HttpOptionsOrDict] = None,
510
512
  stream: bool = False,
511
513
  ) -> HttpResponse:
512
514
  self._initialize_replay_session_if_not_loaded()
513
515
  if self._should_call_api():
514
516
  _debug_print('api mode request: %s' % http_request)
515
517
  try:
516
- result = await super()._async_request(http_request, stream)
518
+ result = await super()._async_request(
519
+ http_request, http_options, stream
520
+ )
517
521
  except errors.APIError as e:
518
522
  self._record_interaction(http_request, e)
519
523
  raise e
@@ -35,6 +35,7 @@ if typing.TYPE_CHECKING:
35
35
  import pydantic
36
36
 
37
37
  from . import _api_client
38
+ from . import _common
38
39
  from . import types
39
40
 
40
41
  logger = logging.getLogger('google_genai._transformers')
@@ -195,20 +196,20 @@ def t_models_url(
195
196
 
196
197
 
197
198
  def t_extract_models(
198
- response: dict[str, Any],
199
- ) -> list[dict[str, Any]]:
199
+ response: _common.StringDict,
200
+ ) -> list[_common.StringDict]:
200
201
  if not response:
201
202
  return []
202
203
 
203
- models: Optional[list[dict[str, Any]]] = response.get('models')
204
+ models: Optional[list[_common.StringDict]] = response.get('models')
204
205
  if models is not None:
205
206
  return models
206
207
 
207
- tuned_models: Optional[list[dict[str, Any]]] = response.get('tunedModels')
208
+ tuned_models: Optional[list[_common.StringDict]] = response.get('tunedModels')
208
209
  if tuned_models is not None:
209
210
  return tuned_models
210
211
 
211
- publisher_models: Optional[list[dict[str, Any]]] = response.get(
212
+ publisher_models: Optional[list[_common.StringDict]] = response.get(
212
213
  'publisherModels'
213
214
  )
214
215
  if publisher_models is not None:
@@ -560,7 +561,7 @@ def t_contents(
560
561
  return result
561
562
 
562
563
 
563
- def handle_null_fields(schema: dict[str, Any]) -> None:
564
+ def handle_null_fields(schema: _common.StringDict) -> None:
564
565
  """Process null fields in the schema so it is compatible with OpenAPI.
565
566
 
566
567
  The OpenAPI spec does not support 'type: 'null' in the schema. This function
@@ -639,9 +640,9 @@ def _raise_for_unsupported_mldev_properties(
639
640
 
640
641
 
641
642
  def process_schema(
642
- schema: dict[str, Any],
643
+ schema: _common.StringDict,
643
644
  client: Optional[_api_client.BaseApiClient],
644
- defs: Optional[dict[str, Any]] = None,
645
+ defs: Optional[_common.StringDict] = None,
645
646
  *,
646
647
  order_properties: bool = True,
647
648
  ) -> None:
@@ -740,7 +741,7 @@ def process_schema(
740
741
  if (ref := schema.pop('$ref', None)) is not None:
741
742
  schema.update(defs[ref.split('defs/')[-1]])
742
743
 
743
- def _recurse(sub_schema: dict[str, Any]) -> dict[str, Any]:
744
+ def _recurse(sub_schema: _common.StringDict) -> _common.StringDict:
744
745
  """Returns the processed `sub_schema`, resolving its '$ref' if any."""
745
746
  if (ref := sub_schema.pop('$ref', None)) is not None:
746
747
  sub_schema = defs[ref.split('defs/')[-1]]
@@ -820,7 +821,7 @@ def _process_enum(
820
821
 
821
822
  def _is_type_dict_str_any(
822
823
  origin: Union[types.SchemaUnionDict, Any],
823
- ) -> TypeGuard[dict[str, Any]]:
824
+ ) -> TypeGuard[_common.StringDict]:
824
825
  """Verifies the schema is of type dict[str, Any] for mypy type checking."""
825
826
  return isinstance(origin, dict) and all(
826
827
  isinstance(key, str) for key in origin
@@ -1075,10 +1076,10 @@ LRO_POLLING_MULTIPLIER = 1.5
1075
1076
 
1076
1077
 
1077
1078
  def t_resolve_operation(
1078
- api_client: _api_client.BaseApiClient, struct: dict[str, Any]
1079
+ api_client: _api_client.BaseApiClient, struct: _common.StringDict
1079
1080
  ) -> Any:
1080
1081
  if (name := struct.get('name')) and '/operations/' in name:
1081
- operation: dict[str, Any] = struct
1082
+ operation: _common.StringDict = struct
1082
1083
  total_seconds = 0.0
1083
1084
  delay_seconds = LRO_POLLING_INITIAL_DELAY_SECONDS
1084
1085
  while operation.get('done') != True:
google/genai/batches.py CHANGED
@@ -1492,6 +1492,9 @@ def _GenerateContentResponse_from_mldev(
1492
1492
  if getv(from_object, ['promptFeedback']) is not None:
1493
1493
  setv(to_object, ['prompt_feedback'], getv(from_object, ['promptFeedback']))
1494
1494
 
1495
+ if getv(from_object, ['responseId']) is not None:
1496
+ setv(to_object, ['response_id'], getv(from_object, ['responseId']))
1497
+
1495
1498
  if getv(from_object, ['usageMetadata']) is not None:
1496
1499
  setv(to_object, ['usage_metadata'], getv(from_object, ['usageMetadata']))
1497
1500
 
@@ -1648,6 +1651,11 @@ def _DeleteResourceJob_from_mldev(
1648
1651
  parent_object: Optional[dict[str, Any]] = None,
1649
1652
  ) -> dict[str, Any]:
1650
1653
  to_object: dict[str, Any] = {}
1654
+ if getv(from_object, ['sdkHttpResponse']) is not None:
1655
+ setv(
1656
+ to_object, ['sdk_http_response'], getv(from_object, ['sdkHttpResponse'])
1657
+ )
1658
+
1651
1659
  if getv(from_object, ['name']) is not None:
1652
1660
  setv(to_object, ['name'], getv(from_object, ['name']))
1653
1661
 
@@ -1815,6 +1823,11 @@ def _DeleteResourceJob_from_vertex(
1815
1823
  parent_object: Optional[dict[str, Any]] = None,
1816
1824
  ) -> dict[str, Any]:
1817
1825
  to_object: dict[str, Any] = {}
1826
+ if getv(from_object, ['sdkHttpResponse']) is not None:
1827
+ setv(
1828
+ to_object, ['sdk_http_response'], getv(from_object, ['sdkHttpResponse'])
1829
+ )
1830
+
1818
1831
  if getv(from_object, ['name']) is not None:
1819
1832
  setv(to_object, ['name'], getv(from_object, ['name']))
1820
1833
 
@@ -2186,7 +2199,9 @@ class Batches(_api_module.BaseModule):
2186
2199
  return_value = types.DeleteResourceJob._from_response(
2187
2200
  response=response_dict, kwargs=parameter_model.model_dump()
2188
2201
  )
2189
-
2202
+ return_value.sdk_http_response = types.HttpResponse(
2203
+ headers=response.headers
2204
+ )
2190
2205
  self._api_client._verify_response(return_value)
2191
2206
  return return_value
2192
2207
 
@@ -2619,7 +2634,9 @@ class AsyncBatches(_api_module.BaseModule):
2619
2634
  return_value = types.DeleteResourceJob._from_response(
2620
2635
  response=response_dict, kwargs=parameter_model.model_dump()
2621
2636
  )
2622
-
2637
+ return_value.sdk_http_response = types.HttpResponse(
2638
+ headers=response.headers
2639
+ )
2623
2640
  self._api_client._verify_response(return_value)
2624
2641
  return return_value
2625
2642
 
google/genai/errors.py CHANGED
@@ -65,7 +65,7 @@ class APIError(Exception):
65
65
  'code', response_json.get('error', {}).get('code', None)
66
66
  )
67
67
 
68
- def _to_replay_record(self) -> dict[str, Any]:
68
+ def _to_replay_record(self) -> _common.StringDict:
69
69
  """Returns a dictionary representation of the error for replay recording.
70
70
 
71
71
  details is not included since it may expose internal information in the
@@ -172,18 +172,21 @@ class ServerError(APIError):
172
172
 
173
173
  class UnknownFunctionCallArgumentError(ValueError):
174
174
  """Raised when the function call argument cannot be converted to the parameter annotation."""
175
-
176
175
  pass
177
176
 
178
177
 
179
178
  class UnsupportedFunctionError(ValueError):
180
179
  """Raised when the function is not supported."""
180
+ pass
181
181
 
182
182
 
183
183
  class FunctionInvocationError(ValueError):
184
184
  """Raised when the function cannot be invoked with the given arguments."""
185
-
186
185
  pass
187
186
 
188
187
 
188
+ class UnknownApiResponseError(ValueError):
189
+ """Raised when the response from the API cannot be parsed as JSON."""
190
+ pass
191
+
189
192
  ExperimentalWarning = _common.ExperimentalWarning
google/genai/live.py CHANGED
@@ -33,7 +33,6 @@ from . import _common
33
33
  from . import _live_converters as live_converters
34
34
  from . import _mcp_utils
35
35
  from . import _transformers as t
36
- from . import client
37
36
  from . import errors
38
37
  from . import types
39
38
  from ._api_client import BaseApiClient
@@ -288,7 +287,7 @@ class AsyncSession:
288
287
  print(f'{msg.text}')
289
288
  ```
290
289
  """
291
- kwargs: dict[str, Any] = {}
290
+ kwargs: _common.StringDict = {}
292
291
  if media is not None:
293
292
  kwargs['media'] = media
294
293
  if audio is not None:
@@ -639,7 +638,7 @@ class AsyncSession:
639
638
  elif isinstance(formatted_input, Sequence) and any(
640
639
  isinstance(c, str) for c in formatted_input
641
640
  ):
642
- to_object: dict[str, Any] = {}
641
+ to_object: _common.StringDict = {}
643
642
  content_input_parts: list[types.PartUnion] = []
644
643
  for item in formatted_input:
645
644
  if isinstance(item, get_args(types.PartUnion)):