google-genai 1.4.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,9 +17,9 @@
17
17
 
18
18
  import inspect
19
19
  import logging
20
+ import sys
20
21
  import typing
21
22
  from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
22
- import sys
23
23
 
24
24
  import pydantic
25
25
 
@@ -37,6 +37,15 @@ _DEFAULT_MAX_REMOTE_CALLS_AFC = 10
37
37
  logger = logging.getLogger('google_genai.models')
38
38
 
39
39
 
40
+ def _create_generate_content_config_model(
41
+ config: types.GenerateContentConfigOrDict,
42
+ ) -> types.GenerateContentConfig:
43
+ if isinstance(config, dict):
44
+ return types.GenerateContentConfig(**config)
45
+ else:
46
+ return config
47
+
48
+
40
49
  def format_destination(
41
50
  src: str,
42
51
  config: Optional[types.CreateBatchJobConfigOrDict] = None,
@@ -69,16 +78,12 @@ def format_destination(
69
78
 
70
79
  def get_function_map(
71
80
  config: Optional[types.GenerateContentConfigOrDict] = None,
72
- ) -> dict[str, object]:
81
+ ) -> dict[str, Callable]:
73
82
  """Returns a function map from the config."""
74
- config_model = (
75
- types.GenerateContentConfig(**config)
76
- if config and isinstance(config, dict)
77
- else config
78
- )
79
- function_map: dict[str, object] = {}
80
- if not config_model:
83
+ function_map: dict[str, Callable] = {}
84
+ if not config:
81
85
  return function_map
86
+ config_model = _create_generate_content_config_model(config)
82
87
  if config_model.tools:
83
88
  for tool in config_model.tools:
84
89
  if callable(tool):
@@ -92,6 +97,16 @@ def get_function_map(
92
97
  return function_map
93
98
 
94
99
 
100
+ def convert_number_values_for_dict_function_call_args(
101
+ args: dict[str, Any],
102
+ ) -> dict[str, Any]:
103
+ """Converts float values in dict with no decimal to integers."""
104
+ return {
105
+ key: convert_number_values_for_function_call_args(value)
106
+ for key, value in args.items()
107
+ }
108
+
109
+
95
110
  def convert_number_values_for_function_call_args(
96
111
  args: Union[dict[str, object], list[object], object],
97
112
  ) -> Union[dict[str, object], list[object], object]:
@@ -210,26 +225,35 @@ def invoke_function_from_dict_args(
210
225
 
211
226
  def get_function_response_parts(
212
227
  response: types.GenerateContentResponse,
213
- function_map: dict[str, object],
228
+ function_map: dict[str, Callable],
214
229
  ) -> list[types.Part]:
215
230
  """Returns the function response parts from the response."""
216
231
  func_response_parts = []
217
- for part in response.candidates[0].content.parts:
218
- if not part.function_call:
219
- continue
220
- func_name = part.function_call.name
221
- func = function_map[func_name]
222
- args = convert_number_values_for_function_call_args(part.function_call.args)
223
- func_response: dict[str, Any]
224
- try:
225
- func_response = {'result': invoke_function_from_dict_args(args, func)}
226
- except Exception as e: # pylint: disable=broad-except
227
- func_response = {'error': str(e)}
228
- func_response_part = types.Part.from_function_response(
229
- name=func_name, response=func_response
230
- )
231
-
232
- func_response_parts.append(func_response_part)
232
+ if (
233
+ response.candidates is not None
234
+ and isinstance(response.candidates[0].content, types.Content)
235
+ and response.candidates[0].content.parts is not None
236
+ ):
237
+ for part in response.candidates[0].content.parts:
238
+ if not part.function_call:
239
+ continue
240
+ func_name = part.function_call.name
241
+ if func_name is not None and part.function_call.args is not None:
242
+ func = function_map[func_name]
243
+ args = convert_number_values_for_dict_function_call_args(
244
+ part.function_call.args
245
+ )
246
+ func_response: dict[str, Any]
247
+ try:
248
+ func_response = {
249
+ 'result': invoke_function_from_dict_args(args, func)
250
+ }
251
+ except Exception as e: # pylint: disable=broad-except
252
+ func_response = {'error': str(e)}
253
+ func_response_part = types.Part.from_function_response(
254
+ name=func_name, response=func_response
255
+ )
256
+ func_response_parts.append(func_response_part)
233
257
  return func_response_parts
234
258
 
235
259
 
@@ -237,12 +261,9 @@ def should_disable_afc(
237
261
  config: Optional[types.GenerateContentConfigOrDict] = None,
238
262
  ) -> bool:
239
263
  """Returns whether automatic function calling is enabled."""
240
- config_model = (
241
- types.GenerateContentConfig(**config)
242
- if config and isinstance(config, dict)
243
- else config
244
- )
245
-
264
+ if not config:
265
+ return False
266
+ config_model = _create_generate_content_config_model(config)
246
267
  # If max_remote_calls is less or equal to 0, warn and disable AFC.
247
268
  if (
248
269
  config_model
@@ -261,8 +282,7 @@ def should_disable_afc(
261
282
 
262
283
  # Default to enable AFC if not specified.
263
284
  if (
264
- not config_model
265
- or not config_model.automatic_function_calling
285
+ not config_model.automatic_function_calling
266
286
  or config_model.automatic_function_calling.disable is None
267
287
  ):
268
288
  return False
@@ -295,20 +315,17 @@ def should_disable_afc(
295
315
  def get_max_remote_calls_afc(
296
316
  config: Optional[types.GenerateContentConfigOrDict] = None,
297
317
  ) -> int:
318
+ if not config:
319
+ return _DEFAULT_MAX_REMOTE_CALLS_AFC
298
320
  """Returns the remaining remote calls for automatic function calling."""
299
321
  if should_disable_afc(config):
300
322
  raise ValueError(
301
323
  'automatic function calling is not enabled, but SDK is trying to get'
302
324
  ' max remote calls.'
303
325
  )
304
- config_model = (
305
- types.GenerateContentConfig(**config)
306
- if config and isinstance(config, dict)
307
- else config
308
- )
326
+ config_model = _create_generate_content_config_model(config)
309
327
  if (
310
- not config_model
311
- or not config_model.automatic_function_calling
328
+ not config_model.automatic_function_calling
312
329
  or config_model.automatic_function_calling.maximum_remote_calls is None
313
330
  ):
314
331
  return _DEFAULT_MAX_REMOTE_CALLS_AFC
@@ -318,11 +335,9 @@ def get_max_remote_calls_afc(
318
335
  def should_append_afc_history(
319
336
  config: Optional[types.GenerateContentConfigOrDict] = None,
320
337
  ) -> bool:
321
- config_model = (
322
- types.GenerateContentConfig(**config)
323
- if config and isinstance(config, dict)
324
- else config
325
- )
326
- if not config_model or not config_model.automatic_function_calling:
338
+ if not config:
339
+ return True
340
+ config_model = _create_generate_content_config_model(config)
341
+ if not config_model.automatic_function_calling:
327
342
  return True
328
343
  return not config_model.automatic_function_calling.ignore_call_history
@@ -109,7 +109,8 @@ def _redact_project_location_path(path: str) -> str:
109
109
  return path
110
110
 
111
111
 
112
- def _redact_request_body(body: dict[str, object]) -> dict[str, object]:
112
+ def _redact_request_body(body: dict[str, object]):
113
+ """Redacts fields in the request body in place."""
113
114
  for key, value in body.items():
114
115
  if isinstance(value, str):
115
116
  body[key] = _redact_project_location_path(value)
@@ -302,13 +303,24 @@ class ReplayApiClient(BaseApiClient):
302
303
  status_code=http_response.status_code,
303
304
  sdk_response_segments=[],
304
305
  )
305
- else:
306
+ elif isinstance(http_response, errors.APIError):
306
307
  response = ReplayResponse(
307
308
  headers=dict(http_response.response.headers),
308
309
  body_segments=[http_response._to_replay_record()],
309
310
  status_code=http_response.code,
310
311
  sdk_response_segments=[],
311
312
  )
313
+ elif isinstance(http_response, bytes):
314
+ response = ReplayResponse(
315
+ headers={},
316
+ body_segments=[],
317
+ byte_segments=[http_response],
318
+ sdk_response_segments=[],
319
+ )
320
+ else:
321
+ raise ValueError(
322
+ 'Unsupported http_response type: ' + str(type(http_response))
323
+ )
312
324
  self.replay_session.interactions.append(
313
325
  ReplayInteraction(request=request, response=response)
314
326
  )
@@ -471,6 +483,43 @@ class ReplayApiClient(BaseApiClient):
471
483
  else:
472
484
  return self._build_response_from_replay(request).json
473
485
 
486
+ async def async_upload_file(
487
+ self,
488
+ file_path: Union[str, io.IOBase],
489
+ upload_url: str,
490
+ upload_size: int,
491
+ ) -> str:
492
+ if isinstance(file_path, io.IOBase):
493
+ offset = file_path.tell()
494
+ content = file_path.read()
495
+ file_path.seek(offset, os.SEEK_SET)
496
+ request = HttpRequest(
497
+ method='POST',
498
+ url='',
499
+ data={'bytes': base64.b64encode(content).decode('utf-8')},
500
+ headers={},
501
+ )
502
+ else:
503
+ request = HttpRequest(
504
+ method='POST', url='', data={'file_path': file_path}, headers={}
505
+ )
506
+ if self._should_call_api():
507
+ result: Union[str, HttpResponse]
508
+ try:
509
+ result = await super().async_upload_file(
510
+ file_path, upload_url, upload_size
511
+ )
512
+ except HTTPError as e:
513
+ result = HttpResponse(
514
+ e.response.headers, [json.dumps({'reason': e.response.reason})]
515
+ )
516
+ result.status_code = e.response.status_code
517
+ raise e
518
+ self._record_interaction(request, HttpResponse({}, [json.dumps(result)]))
519
+ return result
520
+ else:
521
+ return self._build_response_from_replay(request).json
522
+
474
523
  def _download_file_request(self, request):
475
524
  self._initialize_replay_session_if_not_loaded()
476
525
  if self._should_call_api():
@@ -486,3 +535,22 @@ class ReplayApiClient(BaseApiClient):
486
535
  return result
487
536
  else:
488
537
  return self._build_response_from_replay(request)
538
+
539
+ async def async_download_file(self, path: str, http_options):
540
+ self._initialize_replay_session_if_not_loaded()
541
+ request = self._build_request(
542
+ 'get', path=path, request_dict={}, http_options=http_options
543
+ )
544
+ if self._should_call_api():
545
+ try:
546
+ result = await super().async_download_file(path, http_options)
547
+ except HTTPError as e:
548
+ result = HttpResponse(
549
+ e.response.headers, [json.dumps({'reason': e.response.reason})]
550
+ )
551
+ result.status_code = e.response.status_code
552
+ raise e
553
+ self._record_interaction(request, result)
554
+ return result
555
+ else:
556
+ return self._build_response_from_replay(request).byte_stream[0]
@@ -26,9 +26,7 @@ import sys
26
26
  import time
27
27
  import types as builtin_types
28
28
  import typing
29
- from typing import Any, GenericAlias, Optional, Union
30
-
31
- import types as builtin_types
29
+ from typing import Any, GenericAlias, Optional, Union # type: ignore[attr-defined]
32
30
 
33
31
  if typing.TYPE_CHECKING:
34
32
  import PIL.Image
@@ -43,10 +41,11 @@ logger = logging.getLogger('google_genai._transformers')
43
41
  if sys.version_info >= (3, 10):
44
42
  VersionedUnionType = builtin_types.UnionType
45
43
  _UNION_TYPES = (typing.Union, builtin_types.UnionType)
44
+ from typing import TypeGuard
46
45
  else:
47
46
  VersionedUnionType = typing._UnionGenericAlias
48
47
  _UNION_TYPES = (typing.Union,)
49
-
48
+ from typing_extensions import TypeGuard
50
49
 
51
50
  def _resource_name(
52
51
  client: _api_client.BaseApiClient,
@@ -165,7 +164,9 @@ def t_model(client: _api_client.BaseApiClient, model: str):
165
164
  return f'models/{model}'
166
165
 
167
166
 
168
- def t_models_url(api_client: _api_client.BaseApiClient, base_models: bool) -> str:
167
+ def t_models_url(
168
+ api_client: _api_client.BaseApiClient, base_models: bool
169
+ ) -> str:
169
170
  if api_client.vertexai:
170
171
  if base_models:
171
172
  return 'publishers/google/models'
@@ -179,8 +180,9 @@ def t_models_url(api_client: _api_client.BaseApiClient, base_models: bool) -> st
179
180
 
180
181
 
181
182
  def t_extract_models(
182
- api_client: _api_client.BaseApiClient, response: dict
183
- ) -> list[types.Model]:
183
+ api_client: _api_client.BaseApiClient,
184
+ response: dict[str, list[types.ModelDict]],
185
+ ) -> Optional[list[types.ModelDict]]:
184
186
  if not response:
185
187
  return []
186
188
  elif response.get('models') is not None:
@@ -240,9 +242,7 @@ def pil_to_blob(img) -> types.Blob:
240
242
  return types.Blob(mime_type=mime_type, data=data)
241
243
 
242
244
 
243
- def t_part(
244
- client: _api_client.BaseApiClient, part: Optional[types.PartUnionDict]
245
- ) -> types.Part:
245
+ def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
246
246
  try:
247
247
  import PIL.Image
248
248
 
@@ -268,22 +268,21 @@ def t_part(
268
268
 
269
269
 
270
270
  def t_parts(
271
- client: _api_client.BaseApiClient,
272
- parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict]],
271
+ parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict, list[types.Part]]],
273
272
  ) -> list[types.Part]:
274
273
  #
275
274
  if parts is None or (isinstance(parts, list) and not parts):
276
275
  raise ValueError('content parts are required.')
277
276
  if isinstance(parts, list):
278
- return [t_part(client, part) for part in parts]
277
+ return [t_part(part) for part in parts]
279
278
  else:
280
- return [t_part(client, parts)]
279
+ return [t_part(parts)]
281
280
 
282
281
 
283
282
  def t_image_predictions(
284
283
  client: _api_client.BaseApiClient,
285
284
  predictions: Optional[Iterable[Mapping[str, Any]]],
286
- ) -> list[types.GeneratedImage]:
285
+ ) -> Optional[list[types.GeneratedImage]]:
287
286
  if not predictions:
288
287
  return None
289
288
  images = []
@@ -333,22 +332,35 @@ def t_content(
333
332
  def t_contents_for_embed(
334
333
  client: _api_client.BaseApiClient,
335
334
  contents: Union[list[types.Content], list[types.ContentDict], ContentType],
336
- ):
337
- if client.vertexai and isinstance(contents, list):
338
- # TODO: Assert that only text is supported.
339
- return [t_content(client, content).parts[0].text for content in contents]
340
- elif client.vertexai:
341
- return [t_content(client, contents).parts[0].text]
342
- elif isinstance(contents, list):
343
- return [t_content(client, content) for content in contents]
335
+ ) -> Union[list[str], list[types.Content]]:
336
+ if isinstance(contents, list):
337
+ transformed_contents = [t_content(client, content) for content in contents]
344
338
  else:
345
- return [t_content(client, contents)]
339
+ transformed_contents = [t_content(client, contents)]
340
+
341
+ if client.vertexai:
342
+ text_parts = []
343
+ for content in transformed_contents:
344
+ if content is not None:
345
+ if isinstance(content, dict):
346
+ content = types.Content.model_validate(content)
347
+ if content.parts is not None:
348
+ for part in content.parts:
349
+ if part.text:
350
+ text_parts.append(part.text)
351
+ else:
352
+ logger.warning(
353
+ f'Non-text part found, only returning text parts.'
354
+ )
355
+ return text_parts
356
+ else:
357
+ return transformed_contents
346
358
 
347
359
 
348
360
  def t_contents(
349
361
  client: _api_client.BaseApiClient,
350
362
  contents: Optional[
351
- Union[types.ContentListUnion, types.ContentListUnionDict]
363
+ Union[types.ContentListUnion, types.ContentListUnionDict, types.Content]
352
364
  ],
353
365
  ) -> list[types.Content]:
354
366
  if contents is None or (isinstance(contents, list) and not contents):
@@ -366,7 +378,7 @@ def t_contents(
366
378
  result: list[types.Content] = []
367
379
  accumulated_parts: list[types.Part] = []
368
380
 
369
- def _is_part(part: types.PartUnionDict) -> bool:
381
+ def _is_part(part: Union[types.PartUnionDict, Any]) -> TypeGuard[types.PartUnionDict]:
370
382
  if (
371
383
  isinstance(part, str)
372
384
  or isinstance(part, types.File)
@@ -408,7 +420,7 @@ def t_contents(
408
420
  accumulated_parts: list[types.Part],
409
421
  current_part: types.PartUnionDict,
410
422
  ):
411
- current_part = t_part(client, current_part)
423
+ current_part = t_part(current_part)
412
424
  if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
413
425
  accumulated_parts.append(current_part)
414
426
  else:
@@ -430,11 +442,11 @@ def t_contents(
430
442
  ):
431
443
  _append_accumulated_parts_as_content(result, accumulated_parts)
432
444
  if isinstance(content, list):
433
- result.append(types.UserContent(parts=content))
445
+ result.append(types.UserContent(parts=content)) # type: ignore[arg-type]
434
446
  else:
435
447
  result.append(content)
436
- elif (_is_part(content)): # type: ignore
437
- _handle_current_part(result, accumulated_parts, content) # type: ignore
448
+ elif (_is_part(content)):
449
+ _handle_current_part(result, accumulated_parts, content)
438
450
  elif isinstance(content, dict):
439
451
  # PactDict is already handled in _is_part
440
452
  result.append(types.Content.model_validate(content))
@@ -500,14 +512,14 @@ def handle_null_fields(schema: dict[str, Any]):
500
512
  schema['anyOf'].remove({'type': 'null'})
501
513
  if len(schema['anyOf']) == 1:
502
514
  # If there is only one type left after removing null, remove the anyOf field.
503
- for key,val in schema['anyOf'][0].items():
515
+ for key, val in schema['anyOf'][0].items():
504
516
  schema[key] = val
505
517
  del schema['anyOf']
506
518
 
507
519
 
508
520
  def process_schema(
509
521
  schema: dict[str, Any],
510
- client: Optional[_api_client.BaseApiClient] = None,
522
+ client: _api_client.BaseApiClient,
511
523
  defs: Optional[dict[str, Any]] = None,
512
524
  *,
513
525
  order_properties: bool = True,
@@ -570,12 +582,13 @@ def process_schema(
570
582
  'type': 'array'
571
583
  }
572
584
  """
573
- if client and not client.vertexai:
585
+ if not client.vertexai:
574
586
  schema.pop('title', None)
575
587
 
576
588
  if schema.get('default') is not None:
577
589
  raise ValueError(
578
- 'Default value is not supported in the response schema for the Gemini API.'
590
+ 'Default value is not supported in the response schema for the Gemini'
591
+ ' API.'
579
592
  )
580
593
 
581
594
  if schema.get('title') == 'PlaceholderLiteralEnum':
@@ -596,18 +609,15 @@ def process_schema(
596
609
  # After removing null fields, Optional fields with only one possible type
597
610
  # will have a $ref key that needs to be flattened
598
611
  # For example: {'default': None, 'description': 'Name of the person', 'nullable': True, '$ref': '#/$defs/TestPerson'}
599
- if schema.get('$ref', None):
600
- ref = defs[schema.get('$ref').split('defs/')[-1]]
612
+ schema_ref = schema.get('$ref', None)
613
+ if schema_ref is not None:
614
+ ref = defs[schema_ref.split('defs/')[-1]]
601
615
  for schema_key in list(ref.keys()):
602
616
  schema[schema_key] = ref[schema_key]
603
617
  del schema['$ref']
604
618
 
605
619
  any_of = schema.get('anyOf', None)
606
620
  if any_of is not None:
607
- if not client.vertexai:
608
- raise ValueError(
609
- 'AnyOf is not supported in the response schema for the Gemini API.'
610
- )
611
621
  for sub_schema in any_of:
612
622
  # $ref is present in any_of if the schema is a union of Pydantic classes
613
623
  ref_key = sub_schema.get('$ref', None)
@@ -670,7 +680,7 @@ def process_schema(
670
680
 
671
681
 
672
682
  def _process_enum(
673
- enum: EnumMeta, client: Optional[_api_client.BaseApiClient] = None # type: ignore
683
+ enum: EnumMeta, client: _api_client.BaseApiClient
674
684
  ) -> types.Schema:
675
685
  for member in enum: # type: ignore
676
686
  if not isinstance(member.value, str):
@@ -680,7 +690,7 @@ def _process_enum(
680
690
  )
681
691
 
682
692
  class Placeholder(pydantic.BaseModel):
683
- placeholder: enum
693
+ placeholder: enum # type: ignore[valid-type]
684
694
 
685
695
  enum_schema = Placeholder.model_json_schema()
686
696
  process_schema(enum_schema, client)
@@ -688,12 +698,19 @@ def _process_enum(
688
698
  return types.Schema.model_validate(enum_schema)
689
699
 
690
700
 
701
+ def _is_type_dict_str_any(origin: Union[types.SchemaUnionDict, Any]) -> TypeGuard[dict[str, Any]]:
702
+ """Verifies the schema is of type dict[str, Any] for mypy type checking."""
703
+ return isinstance(origin, dict) and all(
704
+ isinstance(key, str) for key in origin
705
+ )
706
+
707
+
691
708
  def t_schema(
692
709
  client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
693
710
  ) -> Optional[types.Schema]:
694
711
  if not origin:
695
712
  return None
696
- if isinstance(origin, dict):
713
+ if isinstance(origin, dict) and _is_type_dict_str_any(origin):
697
714
  process_schema(origin, client, order_properties=False)
698
715
  return types.Schema.model_validate(origin)
699
716
  if isinstance(origin, EnumMeta):
@@ -724,7 +741,7 @@ def t_schema(
724
741
  ):
725
742
 
726
743
  class Placeholder(pydantic.BaseModel):
727
- placeholder: origin
744
+ placeholder: origin # type: ignore[valid-type]
728
745
 
729
746
  schema = Placeholder.model_json_schema()
730
747
  process_schema(schema, client)
@@ -735,7 +752,8 @@ def t_schema(
735
752
 
736
753
 
737
754
  def t_speech_config(
738
- _: _api_client.BaseApiClient, origin: Union[types.SpeechConfigUnionDict, Any]
755
+ _: _api_client.BaseApiClient,
756
+ origin: Union[types.SpeechConfigUnionDict, Any],
739
757
  ) -> Optional[types.SpeechConfig]:
740
758
  if not origin:
741
759
  return None
@@ -750,7 +768,10 @@ def t_speech_config(
750
768
  if (
751
769
  isinstance(origin, dict)
752
770
  and 'voice_config' in origin
771
+ and origin['voice_config'] is not None
753
772
  and 'prebuilt_voice_config' in origin['voice_config']
773
+ and origin['voice_config']['prebuilt_voice_config'] is not None
774
+ and 'voice_name' in origin['voice_config']['prebuilt_voice_config']
754
775
  ):
755
776
  return types.SpeechConfig(
756
777
  voice_config=types.VoiceConfig(
@@ -764,7 +785,7 @@ def t_speech_config(
764
785
  raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
765
786
 
766
787
 
767
- def t_tool(client: _api_client.BaseApiClient, origin) -> types.Tool:
788
+ def t_tool(client: _api_client.BaseApiClient, origin) -> Optional[types.Tool]:
768
789
  if not origin:
769
790
  return None
770
791
  if inspect.isfunction(origin) or inspect.ismethod(origin):
@@ -790,12 +811,16 @@ def t_tools(
790
811
  for tool in origin:
791
812
  transformed_tool = t_tool(client, tool)
792
813
  # All functions should be merged into one tool.
793
- if transformed_tool.function_declarations:
794
- function_tool.function_declarations += (
814
+ if transformed_tool is not None:
815
+ if (
795
816
  transformed_tool.function_declarations
796
- )
797
- else:
798
- tools.append(transformed_tool)
817
+ and function_tool.function_declarations is not None
818
+ ):
819
+ function_tool.function_declarations += (
820
+ transformed_tool.function_declarations
821
+ )
822
+ else:
823
+ tools.append(transformed_tool)
799
824
  if function_tool.function_declarations:
800
825
  tools.append(function_tool)
801
826
  return tools
@@ -883,15 +908,28 @@ def t_resolve_operation(api_client: _api_client.BaseApiClient, struct: dict):
883
908
 
884
909
 
885
910
  def t_file_name(
886
- api_client: _api_client.BaseApiClient, name: Optional[Union[str, types.File]]
911
+ api_client: _api_client.BaseApiClient,
912
+ name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]],
887
913
  ):
888
914
  # Remove the files/ prefix since it's added to the url path.
889
915
  if isinstance(name, types.File):
890
916
  name = name.name
917
+ elif isinstance(name, types.Video):
918
+ name = name.uri
919
+ elif isinstance(name, types.GeneratedVideo):
920
+ if name.video is not None:
921
+ name = name.video.uri
922
+ else:
923
+ name = None
891
924
 
892
925
  if name is None:
893
926
  raise ValueError('File name is required.')
894
927
 
928
+ if not isinstance(name, str):
929
+ raise ValueError(
930
+ f'Could not convert object of type `{type(name)}` to a file name.'
931
+ )
932
+
895
933
  if name.startswith('https://'):
896
934
  suffix = name.split('files/')[1]
897
935
  match = re.match('[a-z0-9]+', suffix)
@@ -906,16 +944,19 @@ def t_file_name(
906
944
 
907
945
  def t_tuning_job_status(
908
946
  api_client: _api_client.BaseApiClient, status: str
909
- ) -> types.JobState:
947
+ ) -> Union[types.JobState, str]:
910
948
  if status == 'STATE_UNSPECIFIED':
911
- return 'JOB_STATE_UNSPECIFIED'
949
+ return types.JobState.JOB_STATE_UNSPECIFIED
912
950
  elif status == 'CREATING':
913
- return 'JOB_STATE_RUNNING'
951
+ return types.JobState.JOB_STATE_RUNNING
914
952
  elif status == 'ACTIVE':
915
- return 'JOB_STATE_SUCCEEDED'
953
+ return types.JobState.JOB_STATE_SUCCEEDED
916
954
  elif status == 'FAILED':
917
- return 'JOB_STATE_FAILED'
955
+ return types.JobState.JOB_STATE_FAILED
918
956
  else:
957
+ for state in types.JobState:
958
+ if str(state.value) == status:
959
+ return state
919
960
  return status
920
961
 
921
962