google-genai 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,6 +20,7 @@ The BaseApiClient is intended to be a private module and is subject to change.
20
20
  """
21
21
 
22
22
  import asyncio
23
+ from collections.abc import Awaitable, Generator
23
24
  import copy
24
25
  from dataclasses import dataclass
25
26
  import datetime
@@ -29,21 +30,31 @@ import json
29
30
  import logging
30
31
  import math
31
32
  import os
33
+ import ssl
32
34
  import sys
33
35
  import time
34
36
  from typing import Any, AsyncIterator, Optional, Tuple, Union
35
- from urllib.parse import urlparse, urlunparse
37
+ from urllib.parse import urlparse
38
+ from urllib.parse import urlunparse
39
+
36
40
  import anyio
41
+ import certifi
37
42
  import google.auth
38
43
  import google.auth.credentials
39
44
  from google.auth.credentials import Credentials
40
45
  from google.auth.transport.requests import Request
41
46
  import httpx
42
- from pydantic import BaseModel, Field, ValidationError
47
+ from pydantic import BaseModel
48
+ from pydantic import Field
49
+ from pydantic import ValidationError
50
+
43
51
  from . import _common
44
52
  from . import errors
45
53
  from . import version
46
- from .types import HttpOptions, HttpOptionsDict, HttpOptionsOrDict
54
+ from .types import HttpOptions
55
+ from .types import HttpOptionsDict
56
+ from .types import HttpOptionsOrDict
57
+
47
58
 
48
59
  logger = logging.getLogger('google_genai._api_client')
49
60
  CHUNK_SIZE = 8 * 1024 * 1024 # 8 MB chunk size
@@ -119,7 +130,7 @@ def _join_url_path(base_url: str, path: str) -> str:
119
130
 
120
131
  def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
121
132
  """Loads google auth credentials and project id."""
122
- credentials, loaded_project_id = google.auth.default(
133
+ credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call]
123
134
  scopes=['https://www.googleapis.com/auth/cloud-platform'],
124
135
  )
125
136
 
@@ -135,7 +146,7 @@ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
135
146
 
136
147
 
137
148
  def _refresh_auth(credentials: Credentials) -> Credentials:
138
- credentials.refresh(Request())
149
+ credentials.refresh(Request()) # type: ignore[no-untyped-call]
139
150
  return credentials
140
151
 
141
152
 
@@ -181,17 +192,17 @@ class HttpResponse:
181
192
  response_stream: Union[Any, str] = None,
182
193
  byte_stream: Union[Any, bytes] = None,
183
194
  ):
184
- self.status_code = 200
195
+ self.status_code: int = 200
185
196
  self.headers = headers
186
197
  self.response_stream = response_stream
187
198
  self.byte_stream = byte_stream
188
199
 
189
200
  # Async iterator for async streaming.
190
- def __aiter__(self):
201
+ def __aiter__(self) -> 'HttpResponse':
191
202
  self.segment_iterator = self.async_segments()
192
203
  return self
193
204
 
194
- async def __anext__(self):
205
+ async def __anext__(self) -> Any:
195
206
  try:
196
207
  return await self.segment_iterator.__anext__()
197
208
  except StopIteration:
@@ -203,7 +214,7 @@ class HttpResponse:
203
214
  return ''
204
215
  return json.loads(self.response_stream[0])
205
216
 
206
- def segments(self):
217
+ def segments(self) -> Generator[Any, None, None]:
207
218
  if isinstance(self.response_stream, list):
208
219
  # list of objects retrieved from replay or from non-streaming API.
209
220
  for chunk in self.response_stream:
@@ -212,7 +223,7 @@ class HttpResponse:
212
223
  yield from []
213
224
  else:
214
225
  # Iterator of objects retrieved from the API.
215
- for chunk in self.response_stream.iter_lines():
226
+ for chunk in self.response_stream.iter_lines(): # type: ignore[union-attr]
216
227
  if chunk:
217
228
  # In streaming mode, the chunk of JSON is prefixed with "data:" which
218
229
  # we must strip before parsing.
@@ -246,7 +257,7 @@ class HttpResponse:
246
257
  else:
247
258
  raise ValueError('Error parsing streaming response.')
248
259
 
249
- def byte_segments(self):
260
+ def byte_segments(self) -> Generator[Union[bytes, Any], None, None]:
250
261
  if isinstance(self.byte_stream, list):
251
262
  # list of objects retrieved from replay or from non-streaming API.
252
263
  yield from self.byte_stream
@@ -257,7 +268,7 @@ class HttpResponse:
257
268
  'Byte segments are not supported for streaming responses.'
258
269
  )
259
270
 
260
- def _copy_to_dict(self, response_payload: dict[str, object]):
271
+ def _copy_to_dict(self, response_payload: dict[str, object]) -> None:
261
272
  # Cannot pickle 'generator' object.
262
273
  delattr(self, 'segment_iterator')
263
274
  for attribute in dir(self):
@@ -414,7 +425,7 @@ class BaseApiClient:
414
425
  if not self.api_key:
415
426
  raise ValueError(
416
427
  'Missing key inputs argument! To use the Google AI API,'
417
- 'provide (`api_key`) arguments. To use the Google Cloud API,'
428
+ ' provide (`api_key`) arguments. To use the Google Cloud API,'
418
429
  ' provide (`vertexai`, `project` & `location`) arguments.'
419
430
  )
420
431
  self._http_options.base_url = 'https://generativelanguage.googleapis.com/'
@@ -432,13 +443,71 @@ class BaseApiClient:
432
443
  else:
433
444
  if self._http_options.headers is not None:
434
445
  _append_library_version_headers(self._http_options.headers)
435
- # Initialize the httpx client.
436
- self._httpx_client = SyncHttpxClient()
437
- self._async_httpx_client = AsyncHttpxClient()
438
446
 
439
- def _websocket_base_url(self):
447
+ client_args, async_client_args = self._ensure_ssl_ctx(self._http_options)
448
+ self._httpx_client = SyncHttpxClient(**client_args)
449
+ self._async_httpx_client = AsyncHttpxClient(**async_client_args)
450
+
451
+ @staticmethod
452
+ def _ensure_ssl_ctx(options: HttpOptions) -> (
453
+ Tuple[dict[str, Any], dict[str, Any]]):
454
+ """Ensures the SSL context is present in the client args.
455
+
456
+ Creates a default SSL context if one is not provided.
457
+
458
+ Args:
459
+ options: The http options to check for SSL context.
460
+
461
+ Returns:
462
+ A tuple of sync/async httpx client args.
463
+ """
464
+
465
+ verify = 'verify'
466
+ args = options.client_args
467
+ async_args = options.async_client_args
468
+ ctx = (
469
+ args.get(verify) if args else None
470
+ or async_args.get(verify) if async_args else None
471
+ )
472
+
473
+ if not ctx:
474
+ # Initialize the SSL context for the httpx client.
475
+ # Unlike requests, the httpx package does not automatically pull in the
476
+ # environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be
477
+ # enabled explicitly.
478
+ ctx = ssl.create_default_context(
479
+ cafile=os.environ.get('SSL_CERT_FILE', certifi.where()),
480
+ capath=os.environ.get('SSL_CERT_DIR'),
481
+ )
482
+
483
+ def _maybe_set(
484
+ args: Optional[dict[str, Any]],
485
+ ctx: ssl.SSLContext,
486
+ ) -> dict[str, Any]:
487
+ """Sets the SSL context in the client args if not set.
488
+
489
+ Does not override the SSL context if it is already set.
490
+
491
+ Args:
492
+ args: The client args to to check for SSL context.
493
+ ctx: The SSL context to set.
494
+
495
+ Returns:
496
+ The client args with the SSL context included.
497
+ """
498
+ if not args or not args.get(verify):
499
+ args = (args or {}).copy()
500
+ args[verify] = ctx
501
+ return args
502
+
503
+ return (
504
+ _maybe_set(args, ctx),
505
+ _maybe_set(async_args, ctx),
506
+ )
507
+
508
+ def _websocket_base_url(self) -> str:
440
509
  url_parts = urlparse(self._http_options.base_url)
441
- return url_parts._replace(scheme='wss').geturl()
510
+ return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
442
511
 
443
512
  def _access_token(self) -> str:
444
513
  """Retrieves the access token for the credentials."""
@@ -453,11 +522,11 @@ class BaseApiClient:
453
522
  _refresh_auth(self._credentials)
454
523
  if not self._credentials.token:
455
524
  raise RuntimeError('Could not resolve API token from the environment')
456
- return self._credentials.token
525
+ return self._credentials.token # type: ignore[no-any-return]
457
526
  else:
458
527
  raise RuntimeError('Could not resolve API token from the environment')
459
528
 
460
- async def _async_access_token(self) -> str:
529
+ async def _async_access_token(self) -> Union[str, Any]:
461
530
  """Retrieves the access token for the credentials asynchronously."""
462
531
  if not self._credentials:
463
532
  async with self._auth_lock:
@@ -607,7 +676,7 @@ class BaseApiClient:
607
676
 
608
677
  async def _async_request(
609
678
  self, http_request: HttpRequest, stream: bool = False
610
- ):
679
+ ) -> HttpResponse:
611
680
  data: Optional[Union[str, bytes]] = None
612
681
  if self.vertexai and not self.api_key:
613
682
  http_request.headers['Authorization'] = (
@@ -667,7 +736,7 @@ class BaseApiClient:
667
736
  path: str,
668
737
  request_dict: dict[str, object],
669
738
  http_options: Optional[HttpOptionsOrDict] = None,
670
- ):
739
+ ) -> Union[BaseResponse, Any]:
671
740
  http_request = self._build_request(
672
741
  http_method, path, request_dict, http_options
673
742
  )
@@ -685,7 +754,7 @@ class BaseApiClient:
685
754
  path: str,
686
755
  request_dict: dict[str, object],
687
756
  http_options: Optional[HttpOptionsOrDict] = None,
688
- ):
757
+ ) -> Generator[Any, None, None]:
689
758
  http_request = self._build_request(
690
759
  http_method, path, request_dict, http_options
691
760
  )
@@ -700,7 +769,7 @@ class BaseApiClient:
700
769
  path: str,
701
770
  request_dict: dict[str, object],
702
771
  http_options: Optional[HttpOptionsOrDict] = None,
703
- ) -> dict[str, object]:
772
+ ) -> Union[BaseResponse, Any]:
704
773
  http_request = self._build_request(
705
774
  http_method, path, request_dict, http_options
706
775
  )
@@ -717,18 +786,18 @@ class BaseApiClient:
717
786
  path: str,
718
787
  request_dict: dict[str, object],
719
788
  http_options: Optional[HttpOptionsOrDict] = None,
720
- ):
789
+ ) -> Any:
721
790
  http_request = self._build_request(
722
791
  http_method, path, request_dict, http_options
723
792
  )
724
793
 
725
794
  response = await self._async_request(http_request=http_request, stream=True)
726
795
 
727
- async def async_generator():
796
+ async def async_generator(): # type: ignore[no-untyped-def]
728
797
  async for chunk in response:
729
798
  yield chunk
730
799
 
731
- return async_generator()
800
+ return async_generator() # type: ignore[no-untyped-call]
732
801
 
733
802
  def upload_file(
734
803
  self,
@@ -840,7 +909,7 @@ class BaseApiClient:
840
909
  path: str,
841
910
  *,
842
911
  http_options: Optional[HttpOptionsOrDict] = None,
843
- ):
912
+ ) -> Union[Any,bytes]:
844
913
  """Downloads the file data.
845
914
 
846
915
  Args:
@@ -909,7 +978,7 @@ class BaseApiClient:
909
978
 
910
979
  async def _async_upload_fd(
911
980
  self,
912
- file: Union[io.IOBase, anyio.AsyncFile],
981
+ file: Union[io.IOBase, anyio.AsyncFile[Any]],
913
982
  upload_url: str,
914
983
  upload_size: int,
915
984
  *,
@@ -988,7 +1057,7 @@ class BaseApiClient:
988
1057
  path: str,
989
1058
  *,
990
1059
  http_options: Optional[HttpOptionsOrDict] = None,
991
- ):
1060
+ ) -> Union[Any, bytes]:
992
1061
  """Downloads the file data.
993
1062
 
994
1063
  Args:
@@ -1025,5 +1094,5 @@ class BaseApiClient:
1025
1094
  # This method does nothing in the real api client. It is used in the
1026
1095
  # replay_api_client to verify the response from the SDK method matches the
1027
1096
  # recorded response.
1028
- def _verify_response(self, response_model: _common.BaseModel):
1097
+ def _verify_response(self, response_model: _common.BaseModel) -> None:
1029
1098
  pass
@@ -46,19 +46,6 @@ def _is_builtin_primitive_or_compound(
46
46
  return annotation in _py_builtin_type_to_schema_type.keys()
47
47
 
48
48
 
49
- def _raise_for_default_if_mldev(schema: types.Schema):
50
- if schema.default is not None:
51
- raise ValueError(
52
- 'Default value is not supported in function declaration schema for'
53
- ' the Gemini API.'
54
- )
55
-
56
-
57
- def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
58
- if api_option == 'GEMINI_API':
59
- _raise_for_default_if_mldev(schema)
60
-
61
-
62
49
  def _is_default_value_compatible(
63
50
  default_value: Any, annotation: inspect.Parameter.annotation # type: ignore[valid-type]
64
51
  ) -> bool:
@@ -72,16 +59,16 @@ def _is_default_value_compatible(
72
59
  or isinstance(annotation, VersionedUnionType)
73
60
  ):
74
61
  origin = get_origin(annotation)
75
- if origin in (Union, VersionedUnionType):
62
+ if origin in (Union, VersionedUnionType): # type: ignore[comparison-overlap]
76
63
  return any(
77
64
  _is_default_value_compatible(default_value, arg)
78
65
  for arg in get_args(annotation)
79
66
  )
80
67
 
81
- if origin is dict:
68
+ if origin is dict: # type: ignore[comparison-overlap]
82
69
  return isinstance(default_value, dict)
83
70
 
84
- if origin is list:
71
+ if origin is list: # type: ignore[comparison-overlap]
85
72
  if not isinstance(default_value, list):
86
73
  return False
87
74
  # most tricky case, element in list is union type
@@ -97,7 +84,7 @@ def _is_default_value_compatible(
97
84
  for item in default_value
98
85
  )
99
86
 
100
- if origin is Literal:
87
+ if origin is Literal: # type: ignore[comparison-overlap]
101
88
  return default_value in get_args(annotation)
102
89
 
103
90
  # return False for any other unrecognized annotation
@@ -125,7 +112,6 @@ def _parse_schema_from_parameter(
125
112
  raise ValueError(default_value_error_msg)
126
113
  schema.default = param.default
127
114
  schema.type = _py_builtin_type_to_schema_type[param.annotation]
128
- _raise_if_schema_unsupported(api_option, schema)
129
115
  return schema
130
116
  if (
131
117
  isinstance(param.annotation, VersionedUnionType)
@@ -166,7 +152,6 @@ def _parse_schema_from_parameter(
166
152
  if not _is_default_value_compatible(param.default, param.annotation):
167
153
  raise ValueError(default_value_error_msg)
168
154
  schema.default = param.default
169
- _raise_if_schema_unsupported(api_option, schema)
170
155
  return schema
171
156
  if isinstance(param.annotation, _GenericAlias) or isinstance(
172
157
  param.annotation, builtin_types.GenericAlias
@@ -179,7 +164,6 @@ def _parse_schema_from_parameter(
179
164
  if not _is_default_value_compatible(param.default, param.annotation):
180
165
  raise ValueError(default_value_error_msg)
181
166
  schema.default = param.default
182
- _raise_if_schema_unsupported(api_option, schema)
183
167
  return schema
184
168
  if origin is Literal:
185
169
  if not all(isinstance(arg, str) for arg in args):
@@ -192,7 +176,6 @@ def _parse_schema_from_parameter(
192
176
  if not _is_default_value_compatible(param.default, param.annotation):
193
177
  raise ValueError(default_value_error_msg)
194
178
  schema.default = param.default
195
- _raise_if_schema_unsupported(api_option, schema)
196
179
  return schema
197
180
  if origin is list:
198
181
  schema.type = _py_builtin_type_to_schema_type[list]
@@ -209,7 +192,6 @@ def _parse_schema_from_parameter(
209
192
  if not _is_default_value_compatible(param.default, param.annotation):
210
193
  raise ValueError(default_value_error_msg)
211
194
  schema.default = param.default
212
- _raise_if_schema_unsupported(api_option, schema)
213
195
  return schema
214
196
  if origin is Union:
215
197
  schema.any_of = []
@@ -259,7 +241,6 @@ def _parse_schema_from_parameter(
259
241
  if not _is_default_value_compatible(param.default, param.annotation):
260
242
  raise ValueError(default_value_error_msg)
261
243
  schema.default = param.default
262
- _raise_if_schema_unsupported(api_option, schema)
263
244
  return schema
264
245
  # all other generic alias will be invoked in raise branch
265
246
  if (
@@ -284,7 +265,6 @@ def _parse_schema_from_parameter(
284
265
  func_name,
285
266
  )
286
267
  schema.required = _get_required_fields(schema)
287
- _raise_if_schema_unsupported(api_option, schema)
288
268
  return schema
289
269
  raise ValueError(
290
270
  f'Failed to parse the parameter {param} of function {func_name} for'
google/genai/_common.py CHANGED
@@ -20,7 +20,7 @@ import datetime
20
20
  import enum
21
21
  import functools
22
22
  import typing
23
- from typing import Any, Union
23
+ from typing import Any, Callable, Optional, Union
24
24
  import uuid
25
25
  import warnings
26
26
 
@@ -31,7 +31,7 @@ from . import _api_client
31
31
  from . import errors
32
32
 
33
33
 
34
- def set_value_by_path(data, keys, value):
34
+ def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None:
35
35
  """Examples:
36
36
 
37
37
  set_value_by_path({}, ['a', 'b'], v)
@@ -46,54 +46,57 @@ def set_value_by_path(data, keys, value):
46
46
  for i, key in enumerate(keys[:-1]):
47
47
  if key.endswith('[]'):
48
48
  key_name = key[:-2]
49
- if key_name not in data:
49
+ if data is not None and key_name not in data:
50
50
  if isinstance(value, list):
51
51
  data[key_name] = [{} for _ in range(len(value))]
52
52
  else:
53
53
  raise ValueError(
54
54
  f'value {value} must be a list given an array path {key}'
55
55
  )
56
- if isinstance(value, list):
56
+ if isinstance(value, list) and data is not None:
57
57
  for j, d in enumerate(data[key_name]):
58
58
  set_value_by_path(d, keys[i + 1 :], value[j])
59
59
  else:
60
- for d in data[key_name]:
61
- set_value_by_path(d, keys[i + 1 :], value)
60
+ if data is not None:
61
+ for d in data[key_name]:
62
+ set_value_by_path(d, keys[i + 1 :], value)
62
63
  return
63
64
  elif key.endswith('[0]'):
64
65
  key_name = key[:-3]
65
- if key_name not in data:
66
+ if data is not None and key_name not in data:
66
67
  data[key_name] = [{}]
67
- set_value_by_path(data[key_name][0], keys[i + 1 :], value)
68
+ if data is not None:
69
+ set_value_by_path(data[key_name][0], keys[i + 1 :], value)
68
70
  return
69
-
70
- data = data.setdefault(key, {})
71
-
72
- existing_data = data.get(keys[-1])
73
- # If there is an existing value, merge, not overwrite.
74
- if existing_data is not None:
75
- # Don't overwrite existing non-empty value with new empty value.
76
- # This is triggered when handling tuning datasets.
77
- if not value:
78
- pass
79
- # Don't fail when overwriting value with same value
80
- elif value == existing_data:
81
- pass
82
- # Instead of overwriting dictionary with another dictionary, merge them.
83
- # This is important for handling training and validation datasets in tuning.
84
- elif isinstance(existing_data, dict) and isinstance(value, dict):
85
- # Merging dictionaries. Consider deep merging in the future.
86
- existing_data.update(value)
71
+ if data is not None:
72
+ data = data.setdefault(key, {})
73
+
74
+ if data is not None:
75
+ existing_data = data.get(keys[-1])
76
+ # If there is an existing value, merge, not overwrite.
77
+ if existing_data is not None:
78
+ # Don't overwrite existing non-empty value with new empty value.
79
+ # This is triggered when handling tuning datasets.
80
+ if not value:
81
+ pass
82
+ # Don't fail when overwriting value with same value
83
+ elif value == existing_data:
84
+ pass
85
+ # Instead of overwriting dictionary with another dictionary, merge them.
86
+ # This is important for handling training and validation datasets in tuning.
87
+ elif isinstance(existing_data, dict) and isinstance(value, dict):
88
+ # Merging dictionaries. Consider deep merging in the future.
89
+ existing_data.update(value)
90
+ else:
91
+ raise ValueError(
92
+ f'Cannot set value for an existing key. Key: {keys[-1]};'
93
+ f' Existing value: {existing_data}; New value: {value}.'
94
+ )
87
95
  else:
88
- raise ValueError(
89
- f'Cannot set value for an existing key. Key: {keys[-1]};'
90
- f' Existing value: {existing_data}; New value: {value}.'
91
- )
92
- else:
93
- data[keys[-1]] = value
96
+ data[keys[-1]] = value
94
97
 
95
98
 
96
- def get_value_by_path(data: Any, keys: list[str]):
99
+ def get_value_by_path(data: Any, keys: list[str]) -> Any:
97
100
  """Examples:
98
101
 
99
102
  get_value_by_path({'a': {'b': v}}, ['a', 'b'])
@@ -227,7 +230,7 @@ class CaseInSensitiveEnum(str, enum.Enum):
227
230
  """Case insensitive enum."""
228
231
 
229
232
  @classmethod
230
- def _missing_(cls, value):
233
+ def _missing_(cls, value: Any) -> Any:
231
234
  try:
232
235
  return cls[value.upper()] # Try to access directly with uppercase
233
236
  except KeyError:
@@ -295,12 +298,12 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
295
298
  return processed_data
296
299
 
297
300
 
298
- def experimental_warning(message: str):
301
+ def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
299
302
  """Experimental warning, only warns once."""
300
- def decorator(func):
303
+ def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
301
304
  warning_done = False
302
305
  @functools.wraps(func)
303
- def wrapper(*args, **kwargs):
306
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
304
307
  nonlocal warning_done
305
308
  if not warning_done:
306
309
  warning_done = True
@@ -78,16 +78,17 @@ def format_destination(
78
78
 
79
79
  def get_function_map(
80
80
  config: Optional[types.GenerateContentConfigOrDict] = None,
81
- ) -> dict[str, Callable]:
81
+ is_caller_method_async: bool = False,
82
+ ) -> dict[str, Callable[..., Any]]:
82
83
  """Returns a function map from the config."""
83
- function_map: dict[str, Callable] = {}
84
+ function_map: dict[str, Callable[..., Any]] = {}
84
85
  if not config:
85
86
  return function_map
86
87
  config_model = _create_generate_content_config_model(config)
87
88
  if config_model.tools:
88
89
  for tool in config_model.tools:
89
90
  if callable(tool):
90
- if inspect.iscoroutinefunction(tool):
91
+ if inspect.iscoroutinefunction(tool) and not is_caller_method_async:
91
92
  raise errors.UnsupportedFunctionError(
92
93
  f'Function {tool.__name__} is a coroutine function, which is not'
93
94
  ' supported for automatic function calling. Please manually'
@@ -199,11 +200,11 @@ def convert_if_exist_pydantic_model(
199
200
  return value
200
201
 
201
202
 
202
- def invoke_function_from_dict_args(
203
- args: Dict[str, Any], function_to_invoke: Callable
204
- ) -> Any:
205
- signature = inspect.signature(function_to_invoke)
206
- func_name = function_to_invoke.__name__
203
+ def convert_argument_from_function(
204
+ args: dict[str, Any], function: Callable[..., Any]
205
+ ) -> dict[str, Any]:
206
+ signature = inspect.signature(function)
207
+ func_name = function.__name__
207
208
  converted_args = {}
208
209
  for param_name, param in signature.parameters.items():
209
210
  if param_name in args:
@@ -213,19 +214,40 @@ def invoke_function_from_dict_args(
213
214
  param_name,
214
215
  func_name,
215
216
  )
217
+ return converted_args
218
+
219
+
220
+ def invoke_function_from_dict_args(
221
+ args: Dict[str, Any], function_to_invoke: Callable[..., Any]
222
+ ) -> Any:
223
+ converted_args = convert_argument_from_function(args, function_to_invoke)
216
224
  try:
217
225
  return function_to_invoke(**converted_args)
218
226
  except Exception as e:
219
227
  raise errors.FunctionInvocationError(
220
- f'Failed to invoke function {func_name} with converted arguments'
221
- f' {converted_args} from model returned function call argument'
222
- f' {args} because of error {e}'
228
+ f'Failed to invoke function {function_to_invoke.__name__} with'
229
+ f' converted arguments {converted_args} from model returned function'
230
+ f' call argument {args} because of error {e}'
231
+ )
232
+
233
+
234
+ async def invoke_function_from_dict_args_async(
235
+ args: Dict[str, Any], function_to_invoke: Callable[..., Any]
236
+ ) -> Any:
237
+ converted_args = convert_argument_from_function(args, function_to_invoke)
238
+ try:
239
+ return await function_to_invoke(**converted_args)
240
+ except Exception as e:
241
+ raise errors.FunctionInvocationError(
242
+ f'Failed to invoke function {function_to_invoke.__name__} with'
243
+ f' converted arguments {converted_args} from model returned function'
244
+ f' call argument {args} because of error {e}'
223
245
  )
224
246
 
225
247
 
226
248
  def get_function_response_parts(
227
249
  response: types.GenerateContentResponse,
228
- function_map: dict[str, Callable],
250
+ function_map: dict[str, Callable[..., Any]],
229
251
  ) -> list[types.Part]:
230
252
  """Returns the function response parts from the response."""
231
253
  func_response_parts = []
@@ -256,6 +278,44 @@ def get_function_response_parts(
256
278
  func_response_parts.append(func_response_part)
257
279
  return func_response_parts
258
280
 
281
+ async def get_function_response_parts_async(
282
+ response: types.GenerateContentResponse,
283
+ function_map: dict[str, Callable[..., Any]],
284
+ ) -> list[types.Part]:
285
+ """Returns the function response parts from the response."""
286
+ func_response_parts = []
287
+ if (
288
+ response.candidates is not None
289
+ and isinstance(response.candidates[0].content, types.Content)
290
+ and response.candidates[0].content.parts is not None
291
+ ):
292
+ for part in response.candidates[0].content.parts:
293
+ if not part.function_call:
294
+ continue
295
+ func_name = part.function_call.name
296
+ if func_name is not None and part.function_call.args is not None:
297
+ func = function_map[func_name]
298
+ args = convert_number_values_for_dict_function_call_args(
299
+ part.function_call.args
300
+ )
301
+ func_response: dict[str, Any]
302
+ try:
303
+ if inspect.iscoroutinefunction(func):
304
+ func_response = {
305
+ 'result': await invoke_function_from_dict_args_async(args, func)
306
+ }
307
+ else:
308
+ func_response = {
309
+ 'result': invoke_function_from_dict_args(args, func)
310
+ }
311
+ except Exception as e: # pylint: disable=broad-except
312
+ func_response = {'error': str(e)}
313
+ func_response_part = types.Part.from_function_response(
314
+ name=func_name, response=func_response
315
+ )
316
+ func_response_parts.append(func_response_part)
317
+ return func_response_parts
318
+
259
319
 
260
320
  def should_disable_afc(
261
321
  config: Optional[types.GenerateContentConfigOrDict] = None,