google-genai 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -62,8 +62,8 @@ def _raise_for_default_if_mldev(schema: types.Schema):
62
62
  )
63
63
 
64
64
 
65
- def _raise_if_schema_unsupported(client, schema: types.Schema):
66
- if not client.vertexai:
65
+ def _raise_if_schema_unsupported(api_option: Literal['VERTEX_AI', 'GEMINI_API'], schema: types.Schema):
66
+ if api_option == 'GEMINI_API':
67
67
  _raise_for_any_of_if_mldev(schema)
68
68
  _raise_for_default_if_mldev(schema)
69
69
 
@@ -114,7 +114,7 @@ def _is_default_value_compatible(
114
114
 
115
115
 
116
116
  def _parse_schema_from_parameter(
117
- client,
117
+ api_option: Literal['VERTEX_AI', 'GEMINI_API'],
118
118
  param: inspect.Parameter,
119
119
  func_name: str,
120
120
  ) -> types.Schema:
@@ -134,7 +134,7 @@ def _parse_schema_from_parameter(
134
134
  raise ValueError(default_value_error_msg)
135
135
  schema.default = param.default
136
136
  schema.type = _py_builtin_type_to_schema_type[param.annotation]
137
- _raise_if_schema_unsupported(client, schema)
137
+ _raise_if_schema_unsupported(api_option, schema)
138
138
  return schema
139
139
  if (
140
140
  isinstance(param.annotation, VersionedUnionType)
@@ -153,7 +153,7 @@ def _parse_schema_from_parameter(
153
153
  schema.nullable = True
154
154
  continue
155
155
  schema_in_any_of = _parse_schema_from_parameter(
156
- client,
156
+ api_option,
157
157
  inspect.Parameter(
158
158
  'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg
159
159
  ),
@@ -175,7 +175,7 @@ def _parse_schema_from_parameter(
175
175
  if not _is_default_value_compatible(param.default, param.annotation):
176
176
  raise ValueError(default_value_error_msg)
177
177
  schema.default = param.default
178
- _raise_if_schema_unsupported(client, schema)
178
+ _raise_if_schema_unsupported(api_option, schema)
179
179
  return schema
180
180
  if isinstance(param.annotation, _GenericAlias) or isinstance(
181
181
  param.annotation, builtin_types.GenericAlias
@@ -188,7 +188,7 @@ def _parse_schema_from_parameter(
188
188
  if not _is_default_value_compatible(param.default, param.annotation):
189
189
  raise ValueError(default_value_error_msg)
190
190
  schema.default = param.default
191
- _raise_if_schema_unsupported(client, schema)
191
+ _raise_if_schema_unsupported(api_option, schema)
192
192
  return schema
193
193
  if origin is Literal:
194
194
  if not all(isinstance(arg, str) for arg in args):
@@ -201,12 +201,12 @@ def _parse_schema_from_parameter(
201
201
  if not _is_default_value_compatible(param.default, param.annotation):
202
202
  raise ValueError(default_value_error_msg)
203
203
  schema.default = param.default
204
- _raise_if_schema_unsupported(client, schema)
204
+ _raise_if_schema_unsupported(api_option, schema)
205
205
  return schema
206
206
  if origin is list:
207
207
  schema.type = 'ARRAY'
208
208
  schema.items = _parse_schema_from_parameter(
209
- client,
209
+ api_option,
210
210
  inspect.Parameter(
211
211
  'item',
212
212
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -218,7 +218,7 @@ def _parse_schema_from_parameter(
218
218
  if not _is_default_value_compatible(param.default, param.annotation):
219
219
  raise ValueError(default_value_error_msg)
220
220
  schema.default = param.default
221
- _raise_if_schema_unsupported(client, schema)
221
+ _raise_if_schema_unsupported(api_option, schema)
222
222
  return schema
223
223
  if origin is Union:
224
224
  schema.any_of = []
@@ -233,7 +233,7 @@ def _parse_schema_from_parameter(
233
233
  schema.nullable = True
234
234
  continue
235
235
  schema_in_any_of = _parse_schema_from_parameter(
236
- client,
236
+ api_option,
237
237
  inspect.Parameter(
238
238
  'item',
239
239
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -268,7 +268,7 @@ def _parse_schema_from_parameter(
268
268
  if not _is_default_value_compatible(param.default, param.annotation):
269
269
  raise ValueError(default_value_error_msg)
270
270
  schema.default = param.default
271
- _raise_if_schema_unsupported(client, schema)
271
+ _raise_if_schema_unsupported(api_option, schema)
272
272
  return schema
273
273
  # all other generic alias will be invoked in raise branch
274
274
  if (
@@ -284,7 +284,7 @@ def _parse_schema_from_parameter(
284
284
  schema.properties = {}
285
285
  for field_name, field_info in param.annotation.model_fields.items():
286
286
  schema.properties[field_name] = _parse_schema_from_parameter(
287
- client,
287
+ api_option,
288
288
  inspect.Parameter(
289
289
  field_name,
290
290
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -292,9 +292,9 @@ def _parse_schema_from_parameter(
292
292
  ),
293
293
  func_name,
294
294
  )
295
- if client.vertexai:
295
+ if api_option == 'VERTEX_AI':
296
296
  schema.required = _get_required_fields(schema)
297
- _raise_if_schema_unsupported(client, schema)
297
+ _raise_if_schema_unsupported(api_option, schema)
298
298
  return schema
299
299
  raise ValueError(
300
300
  f'Failed to parse the parameter {param} of function {func_name} for'
google/genai/_common.py CHANGED
@@ -220,7 +220,8 @@ class CaseInSensitiveEnum(str, enum.Enum):
220
220
  warnings.warn(f"{value} is not a valid {cls.__name__}")
221
221
  try:
222
222
  # Creating a enum instance based on the value
223
- unknown_enum_val = cls._new_member_(cls) # pylint: disable=protected-access,attribute-error
223
+ # We need to use super() to avoid infinite recursion.
224
+ unknown_enum_val = super().__new__(cls, value)
224
225
  unknown_enum_val._name_ = str(value) # pylint: disable=protected-access
225
226
  unknown_enum_val._value_ = value # pylint: disable=protected-access
226
227
  return unknown_enum_val
@@ -271,8 +271,8 @@ def should_disable_afc(
271
271
  and int(config_model.automatic_function_calling.maximum_remote_calls) > 0
272
272
  ):
273
273
  logging.warning(
274
- '`automatic_function_calling.disable` is set to `True`. But'
275
- ' `automatic_function_calling.maximum_remote_calls` is set to be a'
274
+ '`automatic_function_calling.disable` is set to `True`. And'
275
+ ' `automatic_function_calling.maximum_remote_calls` is a'
276
276
  ' positive number'
277
277
  f' {config_model.automatic_function_calling.maximum_remote_calls}.'
278
278
  ' Disabling automatic function calling. If you want to enable'
@@ -374,8 +374,8 @@ def handle_null_fields(schema: dict[str, Any]):
374
374
  schema['anyOf'].remove({'type': 'null'})
375
375
  if len(schema['anyOf']) == 1:
376
376
  # If there is only one type left after removing null, remove the anyOf field.
377
- field_type = schema['anyOf'][0]['type']
378
- schema['type'] = field_type
377
+ for key,val in schema['anyOf'][0].items():
378
+ schema[key] = val
379
379
  del schema['anyOf']
380
380
 
381
381
 
@@ -446,9 +446,17 @@ def process_schema(
446
446
 
447
447
  if schema.get('default') is not None:
448
448
  raise ValueError(
449
- 'Default value is not supported in the response schema for the Gemmini API.'
449
+ 'Default value is not supported in the response schema for the Gemini API.'
450
450
  )
451
451
 
452
+ if schema.get('title') == 'PlaceholderLiteralEnum':
453
+ schema.pop('title', None)
454
+
455
+ # If a dict is provided directly to response_schema, it may use `any_of`
456
+ # instead of `anyOf`. Otherwise model_json_schema() uses `anyOf`
457
+ if schema.get('any_of', None) is not None:
458
+ schema['anyOf'] = schema.pop('any_of')
459
+
452
460
  if defs is None:
453
461
  defs = schema.pop('$defs', {})
454
462
  for _, sub_schema in defs.items():
@@ -456,6 +464,15 @@ def process_schema(
456
464
 
457
465
  handle_null_fields(schema)
458
466
 
467
+ # After removing null fields, Optional fields with only one possible type
468
+ # will have a $ref key that needs to be flattened
469
+ # For example: {'default': None, 'description': 'Name of the person', 'nullable': True, '$ref': '#/$defs/TestPerson'}
470
+ if schema.get('$ref', None):
471
+ ref = defs[schema.get('$ref').split('defs/')[-1]]
472
+ for schema_key in list(ref.keys()):
473
+ schema[schema_key] = ref[schema_key]
474
+ del schema['$ref']
475
+
459
476
  any_of = schema.get('anyOf', None)
460
477
  if any_of is not None:
461
478
  if not client.vertexai:
@@ -478,6 +495,16 @@ def process_schema(
478
495
  schema_type = schema_type.value
479
496
  schema_type = schema_type.upper()
480
497
 
498
+ # model_json_schema() returns a schema with a 'const' field when a Literal with one value is provided as a pydantic field
499
+ # For example `genre: Literal['action']` becomes: {'const': 'action', 'title': 'Genre', 'type': 'string'}
500
+ const = schema.get('const', None)
501
+ if const is not None:
502
+ if schema_type == 'STRING':
503
+ schema['enum'] = [const]
504
+ del schema['const']
505
+ else:
506
+ raise ValueError('Literal values must be strings.')
507
+
481
508
  if schema_type == 'OBJECT':
482
509
  properties = schema.get('properties', None)
483
510
  if properties is None:
@@ -502,6 +529,7 @@ def process_schema(
502
529
  process_schema(ref, client, defs)
503
530
  schema['items'] = ref
504
531
 
532
+
505
533
  def _process_enum(
506
534
  enum: EnumMeta, client: Optional[_api_client.ApiClient] = None
507
535
  ) -> types.Schema:
google/genai/chats.py CHANGED
@@ -57,12 +57,16 @@ class Chat(_BaseChat):
57
57
  """Chat session."""
58
58
 
59
59
  def send_message(
60
- self, message: Union[list[PartUnionDict], PartUnionDict]
60
+ self,
61
+ message: Union[list[PartUnionDict], PartUnionDict],
62
+ config: Optional[GenerateContentConfigOrDict] = None,
61
63
  ) -> GenerateContentResponse:
62
64
  """Sends the conversation history with the additional message and returns the model's response.
63
65
 
64
66
  Args:
65
67
  message: The message to send to the model.
68
+ config: Optional config to override the default Chat config for this
69
+ request.
66
70
 
67
71
  Returns:
68
72
  The model's response.
@@ -79,7 +83,7 @@ class Chat(_BaseChat):
79
83
  response = self._modules.generate_content(
80
84
  model=self._model,
81
85
  contents=self._curated_history + [input_content],
82
- config=self._config,
86
+ config=config if config else self._config,
83
87
  )
84
88
  if _validate_response(response):
85
89
  if response.automatic_function_calling_history:
@@ -92,12 +96,16 @@ class Chat(_BaseChat):
92
96
  return response
93
97
 
94
98
  def send_message_stream(
95
- self, message: Union[list[PartUnionDict], PartUnionDict]
99
+ self,
100
+ message: Union[list[PartUnionDict], PartUnionDict],
101
+ config: Optional[GenerateContentConfigOrDict] = None,
96
102
  ):
97
103
  """Sends the conversation history with the additional message and yields the model's response in chunks.
98
104
 
99
105
  Args:
100
106
  message: The message to send to the model.
107
+ config: Optional config to override the default Chat config for this
108
+ request.
101
109
 
102
110
  Yields:
103
111
  The model's response in chunks.
@@ -117,7 +125,7 @@ class Chat(_BaseChat):
117
125
  for chunk in self._modules.generate_content_stream(
118
126
  model=self._model,
119
127
  contents=self._curated_history + [input_content],
120
- config=self._config,
128
+ config=config if config else self._config,
121
129
  ):
122
130
  if _validate_response(chunk):
123
131
  output_contents.append(chunk.candidates[0].content)
@@ -164,12 +172,16 @@ class AsyncChat(_BaseChat):
164
172
  """Async chat session."""
165
173
 
166
174
  async def send_message(
167
- self, message: Union[list[PartUnionDict], PartUnionDict]
175
+ self,
176
+ message: Union[list[PartUnionDict], PartUnionDict],
177
+ config: Optional[GenerateContentConfigOrDict] = None,
168
178
  ) -> GenerateContentResponse:
169
179
  """Sends the conversation history with the additional message and returns model's response.
170
180
 
171
181
  Args:
172
182
  message: The message to send to the model.
183
+ config: Optional config to override the default Chat config for this
184
+ request.
173
185
 
174
186
  Returns:
175
187
  The model's response.
@@ -186,7 +198,7 @@ class AsyncChat(_BaseChat):
186
198
  response = await self._modules.generate_content(
187
199
  model=self._model,
188
200
  contents=self._curated_history + [input_content],
189
- config=self._config,
201
+ config=config if config else self._config,
190
202
  )
191
203
  if _validate_response(response):
192
204
  if response.automatic_function_calling_history:
@@ -199,12 +211,16 @@ class AsyncChat(_BaseChat):
199
211
  return response
200
212
 
201
213
  async def send_message_stream(
202
- self, message: Union[list[PartUnionDict], PartUnionDict]
214
+ self,
215
+ message: Union[list[PartUnionDict], PartUnionDict],
216
+ config: Optional[GenerateContentConfigOrDict] = None,
203
217
  ) -> Awaitable[AsyncIterator[GenerateContentResponse]]:
204
218
  """Sends the conversation history with the additional message and yields the model's response in chunks.
205
219
 
206
220
  Args:
207
221
  message: The message to send to the model.
222
+ config: Optional config to override the default Chat config for this
223
+ request.
208
224
 
209
225
  Yields:
210
226
  The model's response in chunks.
@@ -225,7 +241,7 @@ class AsyncChat(_BaseChat):
225
241
  async for chunk in await self._modules.generate_content_stream(
226
242
  model=self._model,
227
243
  contents=self._curated_history + [input_content],
228
- config=self._config,
244
+ config=config if config else self._config,
229
245
  ):
230
246
  if _validate_response(chunk):
231
247
  output_contents.append(chunk.candidates[0].content)
google/genai/models.py CHANGED
@@ -4784,15 +4784,18 @@ class Models(_api_module.BaseModule):
4784
4784
  """
4785
4785
 
4786
4786
  if _extra_utils.should_disable_afc(config):
4787
- return self._generate_content_stream(
4787
+ yield from self._generate_content_stream(
4788
4788
  model=model, contents=contents, config=config
4789
4789
  )
4790
+ return
4791
+
4790
4792
  remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config)
4791
4793
  logging.info(
4792
4794
  f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.'
4793
4795
  )
4794
4796
  automatic_function_calling_history = []
4795
4797
  chunk = None
4798
+ func_response_parts = None
4796
4799
  i = 0
4797
4800
  while remaining_remote_calls_afc > 0:
4798
4801
  i += 1
@@ -4828,6 +4831,10 @@ class Models(_api_module.BaseModule):
4828
4831
  automatic_function_calling_history
4829
4832
  )
4830
4833
  yield chunk
4834
+ func_response_parts = _extra_utils.get_function_response_parts(
4835
+ chunk, function_map
4836
+ )
4837
+
4831
4838
  if not chunk:
4832
4839
  break
4833
4840
  if (
@@ -4840,9 +4847,6 @@ class Models(_api_module.BaseModule):
4840
4847
 
4841
4848
  if not function_map:
4842
4849
  break
4843
- func_response_parts = _extra_utils.get_function_response_parts(
4844
- chunk, function_map
4845
- )
4846
4850
  if not func_response_parts:
4847
4851
  break
4848
4852
 
@@ -5930,10 +5934,16 @@ class AsyncModels(_api_module.BaseModule):
5930
5934
  """
5931
5935
 
5932
5936
  if _extra_utils.should_disable_afc(config):
5933
- return self._generate_content_stream(
5937
+ response = await self._generate_content_stream(
5934
5938
  model=model, contents=contents, config=config
5935
5939
  )
5936
5940
 
5941
+ async def base_async_generator(model, contents, config):
5942
+ async for chunk in response:
5943
+ yield chunk
5944
+
5945
+ return base_async_generator(model, contents, config)
5946
+
5937
5947
  async def async_generator(model, contents, config):
5938
5948
  remaining_remote_calls_afc = _extra_utils.get_max_remote_calls_afc(config)
5939
5949
  logging.info(
@@ -5979,7 +5989,9 @@ class AsyncModels(_api_module.BaseModule):
5979
5989
  automatic_function_calling_history
5980
5990
  )
5981
5991
  yield chunk
5982
-
5992
+ func_response_parts = _extra_utils.get_function_response_parts(
5993
+ chunk, function_map
5994
+ )
5983
5995
  if not chunk:
5984
5996
  break
5985
5997
  if (
@@ -5991,9 +6003,7 @@ class AsyncModels(_api_module.BaseModule):
5991
6003
  break
5992
6004
  if not function_map:
5993
6005
  break
5994
- func_response_parts = _extra_utils.get_function_response_parts(
5995
- chunk, function_map
5996
- )
6006
+
5997
6007
  if not func_response_parts:
5998
6008
  break
5999
6009
 
google/genai/tunings.py CHANGED
@@ -15,8 +15,7 @@
15
15
 
16
16
  # Code generated by the Google Gen AI SDK generator DO NOT EDIT.
17
17
 
18
- import time
19
- from typing import Any, Optional, Union
18
+ from typing import Optional, Union
20
19
  from urllib.parse import urlencode
21
20
  from . import _api_module
22
21
  from . import _common
@@ -998,11 +997,14 @@ class Tunings(_api_module.BaseModule):
998
997
  config=config,
999
998
  )
1000
999
  operation_dict = operation.to_json_dict()
1001
- tuned_model_dict = _resolve_operation(self._api_client, operation_dict)
1002
- tuning_job_dict = _TuningJob_from_mldev(
1003
- self._api_client, tuned_model_dict
1000
+ try:
1001
+ tuned_model_name = operation_dict['metadata']['tunedModel']
1002
+ except KeyError:
1003
+ tuned_model_name = operation_dict['name'].partition('/operations/')[0]
1004
+ tuning_job = types.TuningJob(
1005
+ name=tuned_model_name,
1006
+ state=types.JobState.JOB_STATE_QUEUED,
1004
1007
  )
1005
- tuning_job = types.TuningJob._from_response(tuning_job_dict, None)
1006
1008
  if tuning_job.name and self._api_client.vertexai:
1007
1009
  _IpythonUtils.display_model_tuning_button(
1008
1010
  tuning_job_resource=tuning_job.name
@@ -1010,72 +1012,6 @@ class Tunings(_api_module.BaseModule):
1010
1012
  return tuning_job
1011
1013
 
1012
1014
 
1013
- _LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
1014
- _LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
1015
- _LRO_POLLING_TIMEOUT_SECONDS = 900.0
1016
- _LRO_POLLING_MULTIPLIER = 1.5
1017
-
1018
-
1019
- def _resolve_operation(api_client: ApiClient, struct: dict[str, Any]):
1020
- if (name := struct.get('name')) and '/operations/' in name:
1021
- operation: dict[str, Any] = struct
1022
- total_seconds = 0.0
1023
- delay_seconds = _LRO_POLLING_INITIAL_DELAY_SECONDS
1024
- while not operation.get('done'):
1025
- if total_seconds > _LRO_POLLING_TIMEOUT_SECONDS:
1026
- raise RuntimeError(f'Operation {name} timed out.\n{operation}')
1027
- # TODO(b/374433890): Replace with LRO module once it's available.
1028
- operation: dict[str, Any] = api_client.request(
1029
- http_method='GET', path=name, request_dict={}
1030
- )
1031
- if 'ReplayApiClient' not in type(api_client).__name__:
1032
- time.sleep(delay_seconds)
1033
- total_seconds += total_seconds
1034
- # Exponential backoff
1035
- delay_seconds = min(
1036
- delay_seconds * _LRO_POLLING_MULTIPLIER,
1037
- _LRO_POLLING_MAXIMUM_DELAY_SECONDS,
1038
- )
1039
- if error := operation.get('error'):
1040
- raise RuntimeError(
1041
- f'Operation {name} failed with error: {error}.\n{operation}'
1042
- )
1043
- return operation.get('response')
1044
- else:
1045
- return struct
1046
-
1047
-
1048
- async def _resolve_operation_async(
1049
- api_client: ApiClient, struct: dict[str, Any]
1050
- ):
1051
- if (name := struct.get('name')) and '/operations/' in name:
1052
- operation: dict[str, Any] = struct
1053
- total_seconds = 0.0
1054
- delay_seconds = _LRO_POLLING_INITIAL_DELAY_SECONDS
1055
- while not operation.get('done'):
1056
- if total_seconds > _LRO_POLLING_TIMEOUT_SECONDS:
1057
- raise RuntimeError(f'Operation {name} timed out.\n{operation}')
1058
- # TODO(b/374433890): Replace with LRO module once it's available.
1059
- operation: dict[str, Any] = await api_client.async_request(
1060
- http_method='GET', path=name, request_dict={}
1061
- )
1062
- if 'ReplayApiClient' not in type(api_client).__name__:
1063
- time.sleep(delay_seconds)
1064
- total_seconds += total_seconds
1065
- # Exponential backoff
1066
- delay_seconds = min(
1067
- delay_seconds * _LRO_POLLING_MULTIPLIER,
1068
- _LRO_POLLING_MAXIMUM_DELAY_SECONDS,
1069
- )
1070
- if error := operation.get('error'):
1071
- raise RuntimeError(
1072
- f'Operation {name} failed with error: {error}.\n{operation}'
1073
- )
1074
- return operation.get('response')
1075
- else:
1076
- return struct
1077
-
1078
-
1079
1015
  class AsyncTunings(_api_module.BaseModule):
1080
1016
 
1081
1017
  async def _get(
@@ -1370,13 +1306,14 @@ class AsyncTunings(_api_module.BaseModule):
1370
1306
  config=config,
1371
1307
  )
1372
1308
  operation_dict = operation.to_json_dict()
1373
- tuned_model_dict = await _resolve_operation_async(
1374
- self._api_client, operation_dict
1375
- )
1376
- tuning_job_dict = _TuningJob_from_mldev(
1377
- self._api_client, tuned_model_dict
1309
+ try:
1310
+ tuned_model_name = operation_dict['metadata']['tunedModel']
1311
+ except KeyError:
1312
+ tuned_model_name = operation_dict['name'].partition('/operations/')[0]
1313
+ tuning_job = types.TuningJob(
1314
+ name=tuned_model_name,
1315
+ state=types.JobState.JOB_STATE_QUEUED,
1378
1316
  )
1379
- tuning_job = types.TuningJob._from_response(tuning_job_dict, None)
1380
1317
  if tuning_job.name and self._api_client.vertexai:
1381
1318
  _IpythonUtils.display_model_tuning_button(
1382
1319
  tuning_job_resource=tuning_job.name
google/genai/types.py CHANGED
@@ -971,13 +971,24 @@ class FunctionDeclaration(_common.BaseModel):
971
971
  )
972
972
 
973
973
  @classmethod
974
- def from_callable(
974
+ def from_callable_with_api_option(
975
975
  cls,
976
976
  *,
977
- client,
978
977
  callable: Callable,
978
+ api_option: Literal['VERTEX_AI', 'GEMINI_API'] = 'GEMINI_API',
979
979
  ) -> 'FunctionDeclaration':
980
- """Converts a Callable to a FunctionDeclaration based on the client."""
980
+ """Converts a Callable to a FunctionDeclaration based on the API option.
981
+
982
+ Supported API option is 'VERTEX_AI' or 'GEMINI_API'. If api_option is unset,
983
+ it will default to 'GEMINI_API'. If unsupported api_option is provided, it
984
+ will raise ValueError.
985
+ """
986
+ supported_api_options = ['VERTEX_AI', 'GEMINI_API']
987
+ if api_option not in supported_api_options:
988
+ raise ValueError(
989
+ f'Unsupported api_option value: {api_option}. Supported api_option'
990
+ f' value is one of: {supported_api_options}.'
991
+ )
981
992
  from . import _automatic_function_calling_util
982
993
 
983
994
  parameters_properties = {}
@@ -988,7 +999,7 @@ class FunctionDeclaration(_common.BaseModel):
988
999
  inspect.Parameter.POSITIONAL_ONLY,
989
1000
  ):
990
1001
  schema = _automatic_function_calling_util._parse_schema_from_parameter(
991
- client, param, callable.__name__
1002
+ api_option, param, callable.__name__
992
1003
  )
993
1004
  parameters_properties[name] = schema
994
1005
  declaration = FunctionDeclaration(
@@ -1000,13 +1011,13 @@ class FunctionDeclaration(_common.BaseModel):
1000
1011
  type='OBJECT',
1001
1012
  properties=parameters_properties,
1002
1013
  )
1003
- if client.vertexai:
1014
+ if api_option == 'VERTEX_AI':
1004
1015
  declaration.parameters.required = (
1005
1016
  _automatic_function_calling_util._get_required_fields(
1006
1017
  declaration.parameters
1007
1018
  )
1008
1019
  )
1009
- if not client.vertexai:
1020
+ if api_option == 'GEMINI_API':
1010
1021
  return declaration
1011
1022
 
1012
1023
  return_annotation = inspect.signature(callable).return_annotation
@@ -1015,7 +1026,7 @@ class FunctionDeclaration(_common.BaseModel):
1015
1026
 
1016
1027
  declaration.response = (
1017
1028
  _automatic_function_calling_util._parse_schema_from_parameter(
1018
- client,
1029
+ api_option,
1019
1030
  inspect.Parameter(
1020
1031
  'return_value',
1021
1032
  inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -1026,6 +1037,23 @@ class FunctionDeclaration(_common.BaseModel):
1026
1037
  )
1027
1038
  return declaration
1028
1039
 
1040
+ @classmethod
1041
+ def from_callable(
1042
+ cls,
1043
+ *,
1044
+ client,
1045
+ callable: Callable,
1046
+ ) -> 'FunctionDeclaration':
1047
+ """Converts a Callable to a FunctionDeclaration based on the client."""
1048
+ if client.vertexai:
1049
+ return cls.from_callable_with_api_option(
1050
+ callable=callable, api_option='VERTEX_AI'
1051
+ )
1052
+ else:
1053
+ return cls.from_callable_with_api_option(
1054
+ callable=callable, api_option='GEMINI_API'
1055
+ )
1056
+
1029
1057
 
1030
1058
  class FunctionDeclarationDict(TypedDict, total=False):
1031
1059
  """Defines a function that the model can generate JSON inputs for.
@@ -1909,6 +1937,20 @@ class GenerateContentConfig(_common.BaseModel):
1909
1937
  """,
1910
1938
  )
1911
1939
 
1940
+ @pydantic.field_validator('response_schema', mode='before')
1941
+ @classmethod
1942
+ def _convert_literal_to_enum(cls, value):
1943
+ if typing.get_origin(value) is typing.Literal:
1944
+ enum_vals = typing.get_args(value)
1945
+ if not all(isinstance(arg, str) for arg in enum_vals):
1946
+ # This doesn't stop execution, it tells pydantic to raise a ValidationError
1947
+ # when the class is instantiated with an unsupported Literal
1948
+ raise ValueError(f'Literal type {value} must be a list of strings.')
1949
+ # The title 'PlaceholderLiteralEnum' is removed from the generated Schema
1950
+ # before sending the request
1951
+ return Enum('PlaceholderLiteralEnum', {s: s for s in enum_vals})
1952
+ return value
1953
+
1912
1954
 
1913
1955
  class GenerateContentConfigDict(TypedDict, total=False):
1914
1956
  """Optional model configuration parameters.
@@ -2815,7 +2857,7 @@ class GenerateContentResponse(_common.BaseModel):
2815
2857
  text = ''
2816
2858
  any_text_part_text = False
2817
2859
  for part in self.candidates[0].content.parts:
2818
- for field_name, field_value in part.dict(
2860
+ for field_name, field_value in part.model_dump(
2819
2861
  exclude={'text', 'thought'}
2820
2862
  ).items():
2821
2863
  if field_value is not None:
@@ -2883,6 +2925,11 @@ class GenerateContentResponse(_common.BaseModel):
2883
2925
  enum_value = result.text.replace('"', '')
2884
2926
  try:
2885
2927
  result.parsed = response_schema(enum_value)
2928
+ if (
2929
+ hasattr(response_schema, '__name__')
2930
+ and response_schema.__name__ == 'PlaceholderLiteralEnum'
2931
+ ):
2932
+ result.parsed = str(response_schema(enum_value).name)
2886
2933
  except ValueError:
2887
2934
  pass
2888
2935
  elif isinstance(response_schema, GenericAlias) or isinstance(
google/genai/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  #
15
15
 
16
- __version__ = '1.0.0' # x-release-please-version
16
+ __version__ = '1.1.0' # x-release-please-version