google-genai 1.11.0__py3-none-any.whl → 1.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,6 +20,7 @@ The BaseApiClient is intended to be a private module and is subject to change.
20
20
  """
21
21
 
22
22
  import asyncio
23
+ from collections.abc import Awaitable, Generator
23
24
  import copy
24
25
  from dataclasses import dataclass
25
26
  import datetime
@@ -129,7 +130,7 @@ def _join_url_path(base_url: str, path: str) -> str:
129
130
 
130
131
  def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
131
132
  """Loads google auth credentials and project id."""
132
- credentials, loaded_project_id = google.auth.default(
133
+ credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call]
133
134
  scopes=['https://www.googleapis.com/auth/cloud-platform'],
134
135
  )
135
136
 
@@ -145,7 +146,7 @@ def _load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]:
145
146
 
146
147
 
147
148
  def _refresh_auth(credentials: Credentials) -> Credentials:
148
- credentials.refresh(Request())
149
+ credentials.refresh(Request()) # type: ignore[no-untyped-call]
149
150
  return credentials
150
151
 
151
152
 
@@ -191,17 +192,17 @@ class HttpResponse:
191
192
  response_stream: Union[Any, str] = None,
192
193
  byte_stream: Union[Any, bytes] = None,
193
194
  ):
194
- self.status_code = 200
195
+ self.status_code: int = 200
195
196
  self.headers = headers
196
197
  self.response_stream = response_stream
197
198
  self.byte_stream = byte_stream
198
199
 
199
200
  # Async iterator for async streaming.
200
- def __aiter__(self):
201
+ def __aiter__(self) -> 'HttpResponse':
201
202
  self.segment_iterator = self.async_segments()
202
203
  return self
203
204
 
204
- async def __anext__(self):
205
+ async def __anext__(self) -> Any:
205
206
  try:
206
207
  return await self.segment_iterator.__anext__()
207
208
  except StopIteration:
@@ -213,7 +214,7 @@ class HttpResponse:
213
214
  return ''
214
215
  return json.loads(self.response_stream[0])
215
216
 
216
- def segments(self):
217
+ def segments(self) -> Generator[Any, None, None]:
217
218
  if isinstance(self.response_stream, list):
218
219
  # list of objects retrieved from replay or from non-streaming API.
219
220
  for chunk in self.response_stream:
@@ -222,7 +223,7 @@ class HttpResponse:
222
223
  yield from []
223
224
  else:
224
225
  # Iterator of objects retrieved from the API.
225
- for chunk in self.response_stream.iter_lines():
226
+ for chunk in self.response_stream.iter_lines(): # type: ignore[union-attr]
226
227
  if chunk:
227
228
  # In streaming mode, the chunk of JSON is prefixed with "data:" which
228
229
  # we must strip before parsing.
@@ -256,7 +257,7 @@ class HttpResponse:
256
257
  else:
257
258
  raise ValueError('Error parsing streaming response.')
258
259
 
259
- def byte_segments(self):
260
+ def byte_segments(self) -> Generator[Union[bytes, Any], None, None]:
260
261
  if isinstance(self.byte_stream, list):
261
262
  # list of objects retrieved from replay or from non-streaming API.
262
263
  yield from self.byte_stream
@@ -267,7 +268,7 @@ class HttpResponse:
267
268
  'Byte segments are not supported for streaming responses.'
268
269
  )
269
270
 
270
- def _copy_to_dict(self, response_payload: dict[str, object]):
271
+ def _copy_to_dict(self, response_payload: dict[str, object]) -> None:
271
272
  # Cannot pickle 'generator' object.
272
273
  delattr(self, 'segment_iterator')
273
274
  for attribute in dir(self):
@@ -282,15 +283,6 @@ class SyncHttpxClient(httpx.Client):
282
283
  kwargs.setdefault('follow_redirects', True)
283
284
  super().__init__(**kwargs)
284
285
 
285
- def __del__(self) -> None:
286
- """Closes the httpx client."""
287
- if self.is_closed:
288
- return
289
- try:
290
- self.close()
291
- except Exception:
292
- pass
293
-
294
286
 
295
287
  class AsyncHttpxClient(httpx.AsyncClient):
296
288
  """Async httpx client."""
@@ -300,14 +292,6 @@ class AsyncHttpxClient(httpx.AsyncClient):
300
292
  kwargs.setdefault('follow_redirects', True)
301
293
  super().__init__(**kwargs)
302
294
 
303
- def __del__(self) -> None:
304
- if self.is_closed:
305
- return
306
- try:
307
- asyncio.get_running_loop().create_task(self.aclose())
308
- except Exception:
309
- pass
310
-
311
295
 
312
296
  class BaseApiClient:
313
297
  """Client for calling HTTP APIs sending and receiving JSON."""
@@ -504,9 +488,9 @@ class BaseApiClient:
504
488
  _maybe_set(async_args, ctx),
505
489
  )
506
490
 
507
- def _websocket_base_url(self):
491
+ def _websocket_base_url(self) -> str:
508
492
  url_parts = urlparse(self._http_options.base_url)
509
- return url_parts._replace(scheme='wss').geturl()
493
+ return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
510
494
 
511
495
  def _access_token(self) -> str:
512
496
  """Retrieves the access token for the credentials."""
@@ -521,11 +505,11 @@ class BaseApiClient:
521
505
  _refresh_auth(self._credentials)
522
506
  if not self._credentials.token:
523
507
  raise RuntimeError('Could not resolve API token from the environment')
524
- return self._credentials.token
508
+ return self._credentials.token # type: ignore[no-any-return]
525
509
  else:
526
510
  raise RuntimeError('Could not resolve API token from the environment')
527
511
 
528
- async def _async_access_token(self) -> str:
512
+ async def _async_access_token(self) -> Union[str, Any]:
529
513
  """Retrieves the access token for the credentials asynchronously."""
530
514
  if not self._credentials:
531
515
  async with self._auth_lock:
@@ -675,7 +659,7 @@ class BaseApiClient:
675
659
 
676
660
  async def _async_request(
677
661
  self, http_request: HttpRequest, stream: bool = False
678
- ):
662
+ ) -> HttpResponse:
679
663
  data: Optional[Union[str, bytes]] = None
680
664
  if self.vertexai and not self.api_key:
681
665
  http_request.headers['Authorization'] = (
@@ -735,7 +719,7 @@ class BaseApiClient:
735
719
  path: str,
736
720
  request_dict: dict[str, object],
737
721
  http_options: Optional[HttpOptionsOrDict] = None,
738
- ):
722
+ ) -> Union[BaseResponse, Any]:
739
723
  http_request = self._build_request(
740
724
  http_method, path, request_dict, http_options
741
725
  )
@@ -753,7 +737,7 @@ class BaseApiClient:
753
737
  path: str,
754
738
  request_dict: dict[str, object],
755
739
  http_options: Optional[HttpOptionsOrDict] = None,
756
- ):
740
+ ) -> Generator[Any, None, None]:
757
741
  http_request = self._build_request(
758
742
  http_method, path, request_dict, http_options
759
743
  )
@@ -768,7 +752,7 @@ class BaseApiClient:
768
752
  path: str,
769
753
  request_dict: dict[str, object],
770
754
  http_options: Optional[HttpOptionsOrDict] = None,
771
- ) -> dict[str, object]:
755
+ ) -> Union[BaseResponse, Any]:
772
756
  http_request = self._build_request(
773
757
  http_method, path, request_dict, http_options
774
758
  )
@@ -785,18 +769,18 @@ class BaseApiClient:
785
769
  path: str,
786
770
  request_dict: dict[str, object],
787
771
  http_options: Optional[HttpOptionsOrDict] = None,
788
- ):
772
+ ) -> Any:
789
773
  http_request = self._build_request(
790
774
  http_method, path, request_dict, http_options
791
775
  )
792
776
 
793
777
  response = await self._async_request(http_request=http_request, stream=True)
794
778
 
795
- async def async_generator():
779
+ async def async_generator(): # type: ignore[no-untyped-def]
796
780
  async for chunk in response:
797
781
  yield chunk
798
782
 
799
- return async_generator()
783
+ return async_generator() # type: ignore[no-untyped-call]
800
784
 
801
785
  def upload_file(
802
786
  self,
@@ -908,7 +892,7 @@ class BaseApiClient:
908
892
  path: str,
909
893
  *,
910
894
  http_options: Optional[HttpOptionsOrDict] = None,
911
- ):
895
+ ) -> Union[Any,bytes]:
912
896
  """Downloads the file data.
913
897
 
914
898
  Args:
@@ -977,7 +961,7 @@ class BaseApiClient:
977
961
 
978
962
  async def _async_upload_fd(
979
963
  self,
980
- file: Union[io.IOBase, anyio.AsyncFile],
964
+ file: Union[io.IOBase, anyio.AsyncFile[Any]],
981
965
  upload_url: str,
982
966
  upload_size: int,
983
967
  *,
@@ -1056,7 +1040,7 @@ class BaseApiClient:
1056
1040
  path: str,
1057
1041
  *,
1058
1042
  http_options: Optional[HttpOptionsOrDict] = None,
1059
- ):
1043
+ ) -> Union[Any, bytes]:
1060
1044
  """Downloads the file data.
1061
1045
 
1062
1046
  Args:
@@ -1093,5 +1077,5 @@ class BaseApiClient:
1093
1077
  # This method does nothing in the real api client. It is used in the
1094
1078
  # replay_api_client to verify the response from the SDK method matches the
1095
1079
  # recorded response.
1096
- def _verify_response(self, response_model: _common.BaseModel):
1080
+ def _verify_response(self, response_model: _common.BaseModel) -> None:
1097
1081
  pass
@@ -46,19 +46,6 @@ def _is_builtin_primitive_or_compound(
46
46
  return annotation in _py_builtin_type_to_schema_type.keys()
47
47
 
48
48
 
49
- def _raise_for_default_if_mldev(schema: types.Schema):
50
- if schema.default is not None:
51
- raise ValueError(
52
- 'Default value is not supported in function declaration schema for'
53
- ' the Gemini API.'
54
- )
55
-
56
-
57
- def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
58
- if api_option == 'GEMINI_API':
59
- _raise_for_default_if_mldev(schema)
60
-
61
-
62
49
  def _is_default_value_compatible(
63
50
  default_value: Any, annotation: inspect.Parameter.annotation # type: ignore[valid-type]
64
51
  ) -> bool:
@@ -72,16 +59,16 @@ def _is_default_value_compatible(
72
59
  or isinstance(annotation, VersionedUnionType)
73
60
  ):
74
61
  origin = get_origin(annotation)
75
- if origin in (Union, VersionedUnionType):
62
+ if origin in (Union, VersionedUnionType): # type: ignore[comparison-overlap]
76
63
  return any(
77
64
  _is_default_value_compatible(default_value, arg)
78
65
  for arg in get_args(annotation)
79
66
  )
80
67
 
81
- if origin is dict:
68
+ if origin is dict: # type: ignore[comparison-overlap]
82
69
  return isinstance(default_value, dict)
83
70
 
84
- if origin is list:
71
+ if origin is list: # type: ignore[comparison-overlap]
85
72
  if not isinstance(default_value, list):
86
73
  return False
87
74
  # most tricky case, element in list is union type
@@ -97,7 +84,7 @@ def _is_default_value_compatible(
97
84
  for item in default_value
98
85
  )
99
86
 
100
- if origin is Literal:
87
+ if origin is Literal: # type: ignore[comparison-overlap]
101
88
  return default_value in get_args(annotation)
102
89
 
103
90
  # return False for any other unrecognized annotation
@@ -125,7 +112,6 @@ def _parse_schema_from_parameter(
125
112
  raise ValueError(default_value_error_msg)
126
113
  schema.default = param.default
127
114
  schema.type = _py_builtin_type_to_schema_type[param.annotation]
128
- _raise_if_schema_unsupported(api_option, schema)
129
115
  return schema
130
116
  if (
131
117
  isinstance(param.annotation, VersionedUnionType)
@@ -166,7 +152,6 @@ def _parse_schema_from_parameter(
166
152
  if not _is_default_value_compatible(param.default, param.annotation):
167
153
  raise ValueError(default_value_error_msg)
168
154
  schema.default = param.default
169
- _raise_if_schema_unsupported(api_option, schema)
170
155
  return schema
171
156
  if isinstance(param.annotation, _GenericAlias) or isinstance(
172
157
  param.annotation, builtin_types.GenericAlias
@@ -179,7 +164,6 @@ def _parse_schema_from_parameter(
179
164
  if not _is_default_value_compatible(param.default, param.annotation):
180
165
  raise ValueError(default_value_error_msg)
181
166
  schema.default = param.default
182
- _raise_if_schema_unsupported(api_option, schema)
183
167
  return schema
184
168
  if origin is Literal:
185
169
  if not all(isinstance(arg, str) for arg in args):
@@ -192,7 +176,6 @@ def _parse_schema_from_parameter(
192
176
  if not _is_default_value_compatible(param.default, param.annotation):
193
177
  raise ValueError(default_value_error_msg)
194
178
  schema.default = param.default
195
- _raise_if_schema_unsupported(api_option, schema)
196
179
  return schema
197
180
  if origin is list:
198
181
  schema.type = _py_builtin_type_to_schema_type[list]
@@ -209,7 +192,6 @@ def _parse_schema_from_parameter(
209
192
  if not _is_default_value_compatible(param.default, param.annotation):
210
193
  raise ValueError(default_value_error_msg)
211
194
  schema.default = param.default
212
- _raise_if_schema_unsupported(api_option, schema)
213
195
  return schema
214
196
  if origin is Union:
215
197
  schema.any_of = []
@@ -259,7 +241,6 @@ def _parse_schema_from_parameter(
259
241
  if not _is_default_value_compatible(param.default, param.annotation):
260
242
  raise ValueError(default_value_error_msg)
261
243
  schema.default = param.default
262
- _raise_if_schema_unsupported(api_option, schema)
263
244
  return schema
264
245
  # all other generic alias will be invoked in raise branch
265
246
  if (
@@ -284,7 +265,6 @@ def _parse_schema_from_parameter(
284
265
  func_name,
285
266
  )
286
267
  schema.required = _get_required_fields(schema)
287
- _raise_if_schema_unsupported(api_option, schema)
288
268
  return schema
289
269
  raise ValueError(
290
270
  f'Failed to parse the parameter {param} of function {func_name} for'
google/genai/_common.py CHANGED
@@ -20,7 +20,7 @@ import datetime
20
20
  import enum
21
21
  import functools
22
22
  import typing
23
- from typing import Any, Union
23
+ from typing import Any, Callable, Optional, Union
24
24
  import uuid
25
25
  import warnings
26
26
 
@@ -31,7 +31,7 @@ from . import _api_client
31
31
  from . import errors
32
32
 
33
33
 
34
- def set_value_by_path(data, keys, value):
34
+ def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None:
35
35
  """Examples:
36
36
 
37
37
  set_value_by_path({}, ['a', 'b'], v)
@@ -46,54 +46,57 @@ def set_value_by_path(data, keys, value):
46
46
  for i, key in enumerate(keys[:-1]):
47
47
  if key.endswith('[]'):
48
48
  key_name = key[:-2]
49
- if key_name not in data:
49
+ if data is not None and key_name not in data:
50
50
  if isinstance(value, list):
51
51
  data[key_name] = [{} for _ in range(len(value))]
52
52
  else:
53
53
  raise ValueError(
54
54
  f'value {value} must be a list given an array path {key}'
55
55
  )
56
- if isinstance(value, list):
56
+ if isinstance(value, list) and data is not None:
57
57
  for j, d in enumerate(data[key_name]):
58
58
  set_value_by_path(d, keys[i + 1 :], value[j])
59
59
  else:
60
- for d in data[key_name]:
61
- set_value_by_path(d, keys[i + 1 :], value)
60
+ if data is not None:
61
+ for d in data[key_name]:
62
+ set_value_by_path(d, keys[i + 1 :], value)
62
63
  return
63
64
  elif key.endswith('[0]'):
64
65
  key_name = key[:-3]
65
- if key_name not in data:
66
+ if data is not None and key_name not in data:
66
67
  data[key_name] = [{}]
67
- set_value_by_path(data[key_name][0], keys[i + 1 :], value)
68
+ if data is not None:
69
+ set_value_by_path(data[key_name][0], keys[i + 1 :], value)
68
70
  return
69
-
70
- data = data.setdefault(key, {})
71
-
72
- existing_data = data.get(keys[-1])
73
- # If there is an existing value, merge, not overwrite.
74
- if existing_data is not None:
75
- # Don't overwrite existing non-empty value with new empty value.
76
- # This is triggered when handling tuning datasets.
77
- if not value:
78
- pass
79
- # Don't fail when overwriting value with same value
80
- elif value == existing_data:
81
- pass
82
- # Instead of overwriting dictionary with another dictionary, merge them.
83
- # This is important for handling training and validation datasets in tuning.
84
- elif isinstance(existing_data, dict) and isinstance(value, dict):
85
- # Merging dictionaries. Consider deep merging in the future.
86
- existing_data.update(value)
71
+ if data is not None:
72
+ data = data.setdefault(key, {})
73
+
74
+ if data is not None:
75
+ existing_data = data.get(keys[-1])
76
+ # If there is an existing value, merge, not overwrite.
77
+ if existing_data is not None:
78
+ # Don't overwrite existing non-empty value with new empty value.
79
+ # This is triggered when handling tuning datasets.
80
+ if not value:
81
+ pass
82
+ # Don't fail when overwriting value with same value
83
+ elif value == existing_data:
84
+ pass
85
+ # Instead of overwriting dictionary with another dictionary, merge them.
86
+ # This is important for handling training and validation datasets in tuning.
87
+ elif isinstance(existing_data, dict) and isinstance(value, dict):
88
+ # Merging dictionaries. Consider deep merging in the future.
89
+ existing_data.update(value)
90
+ else:
91
+ raise ValueError(
92
+ f'Cannot set value for an existing key. Key: {keys[-1]};'
93
+ f' Existing value: {existing_data}; New value: {value}.'
94
+ )
87
95
  else:
88
- raise ValueError(
89
- f'Cannot set value for an existing key. Key: {keys[-1]};'
90
- f' Existing value: {existing_data}; New value: {value}.'
91
- )
92
- else:
93
- data[keys[-1]] = value
96
+ data[keys[-1]] = value
94
97
 
95
98
 
96
- def get_value_by_path(data: Any, keys: list[str]):
99
+ def get_value_by_path(data: Any, keys: list[str]) -> Any:
97
100
  """Examples:
98
101
 
99
102
  get_value_by_path({'a': {'b': v}}, ['a', 'b'])
@@ -227,7 +230,7 @@ class CaseInSensitiveEnum(str, enum.Enum):
227
230
  """Case insensitive enum."""
228
231
 
229
232
  @classmethod
230
- def _missing_(cls, value):
233
+ def _missing_(cls, value: Any) -> Any:
231
234
  try:
232
235
  return cls[value.upper()] # Try to access directly with uppercase
233
236
  except KeyError:
@@ -295,12 +298,12 @@ def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
295
298
  return processed_data
296
299
 
297
300
 
298
- def experimental_warning(message: str):
301
+ def experimental_warning(message: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
299
302
  """Experimental warning, only warns once."""
300
- def decorator(func):
303
+ def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
301
304
  warning_done = False
302
305
  @functools.wraps(func)
303
- def wrapper(*args, **kwargs):
306
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
304
307
  nonlocal warning_done
305
308
  if not warning_done:
306
309
  warning_done = True
@@ -79,9 +79,9 @@ def format_destination(
79
79
  def get_function_map(
80
80
  config: Optional[types.GenerateContentConfigOrDict] = None,
81
81
  is_caller_method_async: bool = False,
82
- ) -> dict[str, Callable]:
82
+ ) -> dict[str, Callable[..., Any]]:
83
83
  """Returns a function map from the config."""
84
- function_map: dict[str, Callable] = {}
84
+ function_map: dict[str, Callable[..., Any]] = {}
85
85
  if not config:
86
86
  return function_map
87
87
  config_model = _create_generate_content_config_model(config)
@@ -201,7 +201,7 @@ def convert_if_exist_pydantic_model(
201
201
 
202
202
 
203
203
  def convert_argument_from_function(
204
- args: dict[str, Any], function: Callable
204
+ args: dict[str, Any], function: Callable[..., Any]
205
205
  ) -> dict[str, Any]:
206
206
  signature = inspect.signature(function)
207
207
  func_name = function.__name__
@@ -218,7 +218,7 @@ def convert_argument_from_function(
218
218
 
219
219
 
220
220
  def invoke_function_from_dict_args(
221
- args: Dict[str, Any], function_to_invoke: Callable
221
+ args: Dict[str, Any], function_to_invoke: Callable[..., Any]
222
222
  ) -> Any:
223
223
  converted_args = convert_argument_from_function(args, function_to_invoke)
224
224
  try:
@@ -232,7 +232,7 @@ def invoke_function_from_dict_args(
232
232
 
233
233
 
234
234
  async def invoke_function_from_dict_args_async(
235
- args: Dict[str, Any], function_to_invoke: Callable
235
+ args: Dict[str, Any], function_to_invoke: Callable[..., Any]
236
236
  ) -> Any:
237
237
  converted_args = convert_argument_from_function(args, function_to_invoke)
238
238
  try:
@@ -247,7 +247,7 @@ async def invoke_function_from_dict_args_async(
247
247
 
248
248
  def get_function_response_parts(
249
249
  response: types.GenerateContentResponse,
250
- function_map: dict[str, Callable],
250
+ function_map: dict[str, Callable[..., Any]],
251
251
  ) -> list[types.Part]:
252
252
  """Returns the function response parts from the response."""
253
253
  func_response_parts = []
@@ -280,7 +280,7 @@ def get_function_response_parts(
280
280
 
281
281
  async def get_function_response_parts_async(
282
282
  response: types.GenerateContentResponse,
283
- function_map: dict[str, Callable],
283
+ function_map: dict[str, Callable[..., Any]],
284
284
  ) -> list[types.Part]:
285
285
  """Returns the function response parts from the response."""
286
286
  func_response_parts = []