google-genai 1.7.0__py3-none-any.whl → 1.53.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. google/genai/__init__.py +4 -2
  2. google/genai/_adapters.py +55 -0
  3. google/genai/_api_client.py +1301 -299
  4. google/genai/_api_module.py +1 -1
  5. google/genai/_automatic_function_calling_util.py +54 -33
  6. google/genai/_base_transformers.py +26 -0
  7. google/genai/_base_url.py +50 -0
  8. google/genai/_common.py +560 -59
  9. google/genai/_extra_utils.py +371 -38
  10. google/genai/_live_converters.py +1467 -0
  11. google/genai/_local_tokenizer_loader.py +214 -0
  12. google/genai/_mcp_utils.py +117 -0
  13. google/genai/_operations_converters.py +394 -0
  14. google/genai/_replay_api_client.py +204 -92
  15. google/genai/_test_api_client.py +1 -1
  16. google/genai/_tokens_converters.py +520 -0
  17. google/genai/_transformers.py +633 -233
  18. google/genai/batches.py +1733 -538
  19. google/genai/caches.py +678 -1012
  20. google/genai/chats.py +48 -38
  21. google/genai/client.py +142 -15
  22. google/genai/documents.py +532 -0
  23. google/genai/errors.py +141 -35
  24. google/genai/file_search_stores.py +1296 -0
  25. google/genai/files.py +312 -744
  26. google/genai/live.py +617 -367
  27. google/genai/live_music.py +197 -0
  28. google/genai/local_tokenizer.py +395 -0
  29. google/genai/models.py +3598 -3116
  30. google/genai/operations.py +201 -362
  31. google/genai/pagers.py +23 -7
  32. google/genai/py.typed +1 -0
  33. google/genai/tokens.py +362 -0
  34. google/genai/tunings.py +1274 -496
  35. google/genai/types.py +14535 -5454
  36. google/genai/version.py +2 -2
  37. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/METADATA +736 -234
  38. google_genai-1.53.0.dist-info/RECORD +41 -0
  39. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/WHEEL +1 -1
  40. google_genai-1.7.0.dist-info/RECORD +0 -27
  41. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info/licenses}/LICENSE +0 -0
  42. {google_genai-1.7.0.dist-info → google_genai-1.53.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Google LLC
1
+ # Copyright 2025 Google LLC
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -26,7 +26,9 @@ import sys
26
26
  import time
27
27
  import types as builtin_types
28
28
  import typing
29
- from typing import Any, GenericAlias, Optional, Union # type: ignore[attr-defined]
29
+ from typing import Any, GenericAlias, List, Optional, Sequence, Union # type: ignore[attr-defined]
30
+ from ._mcp_utils import mcp_to_gemini_tool
31
+ from ._common import get_value_by_path as getv
30
32
 
31
33
  if typing.TYPE_CHECKING:
32
34
  import PIL.Image
@@ -34,6 +36,7 @@ if typing.TYPE_CHECKING:
34
36
  import pydantic
35
37
 
36
38
  from . import _api_client
39
+ from . import _common
37
40
  from . import types
38
41
 
39
42
  logger = logging.getLogger('google_genai._transformers')
@@ -43,17 +46,70 @@ if sys.version_info >= (3, 10):
43
46
  _UNION_TYPES = (typing.Union, builtin_types.UnionType)
44
47
  from typing import TypeGuard
45
48
  else:
46
- VersionedUnionType = typing._UnionGenericAlias
49
+ VersionedUnionType = typing._UnionGenericAlias # type: ignore[attr-defined]
47
50
  _UNION_TYPES = (typing.Union,)
48
51
  from typing_extensions import TypeGuard
49
52
 
53
+ if typing.TYPE_CHECKING:
54
+ from mcp import ClientSession as McpClientSession
55
+ from mcp.types import Tool as McpTool
56
+ else:
57
+ McpClientSession: typing.Type = Any
58
+ McpTool: typing.Type = Any
59
+ try:
60
+ from mcp import ClientSession as McpClientSession
61
+ from mcp.types import Tool as McpTool
62
+ except ImportError:
63
+ McpClientSession = None
64
+ McpTool = None
65
+
66
+
67
+ metric_name_sdk_api_map = {
68
+ 'exact_match': 'exactMatchSpec',
69
+ 'bleu': 'bleuSpec',
70
+ 'rouge_spec': 'rougeSpec',
71
+ }
72
+ metric_name_api_sdk_map = {v: k for k, v in metric_name_sdk_api_map.items()}
73
+
74
+
75
+ def _is_duck_type_of(obj: Any, cls: type[pydantic.BaseModel]) -> bool:
76
+ """Checks if an object has all of the fields of a Pydantic model.
77
+
78
+ This is a duck-typing alternative to `isinstance` to solve dual-import
79
+ problems. It returns False for dictionaries, which should be handled by
80
+ `isinstance(obj, dict)`.
81
+
82
+ Args:
83
+ obj: The object to check.
84
+ cls: The Pydantic model class to duck-type against.
85
+
86
+ Returns:
87
+ True if the object has all the fields defined in the Pydantic model, False
88
+ otherwise.
89
+ """
90
+ if isinstance(obj, dict) or not hasattr(cls, 'model_fields'):
91
+ return False
92
+
93
+ # Check if the object has all of the Pydantic model's defined fields.
94
+ all_matched = all(hasattr(obj, field) for field in cls.model_fields)
95
+ if not all_matched and isinstance(obj, pydantic.BaseModel):
96
+ # Check the other way around if obj is a Pydantic model.
97
+ # Check if the Pydantic model has all of the object's defined fields.
98
+ try:
99
+ obj_private = cls()
100
+ all_matched = all(hasattr(obj_private, f) for f in type(obj).model_fields)
101
+ except ValueError:
102
+ return False
103
+ return all_matched
104
+
105
+
50
106
  def _resource_name(
51
107
  client: _api_client.BaseApiClient,
52
108
  resource_name: str,
53
109
  *,
54
110
  collection_identifier: str,
55
111
  collection_hierarchy_depth: int = 2,
56
- ):
112
+ ) -> str:
57
113
  # pylint: disable=line-too-long
58
114
  """Prepends resource name with project, location, collection_identifier if needed.
59
115
 
@@ -140,9 +196,11 @@ def _resource_name(
140
196
  return resource_name
141
197
 
142
198
 
143
- def t_model(client: _api_client.BaseApiClient, model: str):
199
+ def t_model(client: _api_client.BaseApiClient, model: str) -> str:
144
200
  if not model:
145
201
  raise ValueError('model is required.')
202
+ if '..' in model or '?' in model or '&' in model:
203
+ raise ValueError('invalid model parameter.')
146
204
  if client.vertexai:
147
205
  if (
148
206
  model.startswith('projects/')
@@ -180,18 +238,26 @@ def t_models_url(
180
238
 
181
239
 
182
240
  def t_extract_models(
183
- api_client: _api_client.BaseApiClient,
184
- response: dict[str, list[types.ModelDict]],
185
- ) -> Optional[list[types.ModelDict]]:
241
+ response: _common.StringDict,
242
+ ) -> list[_common.StringDict]:
186
243
  if not response:
187
244
  return []
188
- elif response.get('models') is not None:
189
- return response.get('models')
190
- elif response.get('tunedModels') is not None:
191
- return response.get('tunedModels')
192
- elif response.get('publisherModels') is not None:
193
- return response.get('publisherModels')
194
- elif (
245
+
246
+ models: Optional[list[_common.StringDict]] = response.get('models')
247
+ if models is not None:
248
+ return models
249
+
250
+ tuned_models: Optional[list[_common.StringDict]] = response.get('tunedModels')
251
+ if tuned_models is not None:
252
+ return tuned_models
253
+
254
+ publisher_models: Optional[list[_common.StringDict]] = response.get(
255
+ 'publisherModels'
256
+ )
257
+ if publisher_models is not None:
258
+ return publisher_models
259
+
260
+ if (
195
261
  response.get('httpHeaders') is not None
196
262
  and response.get('jsonPayload') is None
197
263
  ):
@@ -202,7 +268,9 @@ def t_extract_models(
202
268
  return []
203
269
 
204
270
 
205
- def t_caches_model(api_client: _api_client.BaseApiClient, model: str):
271
+ def t_caches_model(
272
+ api_client: _api_client.BaseApiClient, model: str
273
+ ) -> Optional[str]:
206
274
  model = t_model(api_client, model)
207
275
  if not model:
208
276
  return None
@@ -217,7 +285,7 @@ def t_caches_model(api_client: _api_client.BaseApiClient, model: str):
217
285
  return model
218
286
 
219
287
 
220
- def pil_to_blob(img) -> types.Blob:
288
+ def pil_to_blob(img: Any) -> types.Blob:
221
289
  PngImagePlugin: Optional[builtin_types.ModuleType]
222
290
  try:
223
291
  import PIL.PngImagePlugin
@@ -242,33 +310,119 @@ def pil_to_blob(img) -> types.Blob:
242
310
  return types.Blob(mime_type=mime_type, data=data)
243
311
 
244
312
 
245
- def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
246
- try:
247
- import PIL.Image
313
+ def t_function_response(
314
+ function_response: types.FunctionResponseOrDict,
315
+ ) -> types.FunctionResponse:
316
+ if not function_response:
317
+ raise ValueError('function_response is required.')
318
+ if isinstance(function_response, dict):
319
+ return types.FunctionResponse.model_validate(function_response)
320
+ elif _is_duck_type_of(function_response, types.FunctionResponse):
321
+ return function_response
322
+ else:
323
+ raise TypeError(
324
+ 'Could not parse input as FunctionResponse. Unsupported'
325
+ f' function_response type: {type(function_response)}'
326
+ )
248
327
 
249
- PIL_Image = PIL.Image.Image
250
- except ImportError:
251
- PIL_Image = None
252
328
 
329
+ def t_function_responses(
330
+ function_responses: Union[
331
+ types.FunctionResponseOrDict,
332
+ Sequence[types.FunctionResponseOrDict],
333
+ ],
334
+ ) -> list[types.FunctionResponse]:
335
+ if not function_responses:
336
+ raise ValueError('function_responses are required.')
337
+ if isinstance(function_responses, Sequence):
338
+ return [t_function_response(response) for response in function_responses]
339
+ else:
340
+ return [t_function_response(function_responses)]
341
+
342
+
343
+ def t_blobs(
344
+ blobs: Union[types.BlobImageUnionDict, list[types.BlobImageUnionDict]],
345
+ ) -> list[types.Blob]:
346
+ if isinstance(blobs, list):
347
+ return [t_blob(blob) for blob in blobs]
348
+ else:
349
+ return [t_blob(blobs)]
350
+
351
+
352
+ def t_blob(blob: types.BlobImageUnionDict) -> types.Blob:
353
+ if not blob:
354
+ raise ValueError('blob is required.')
355
+
356
+ if _is_duck_type_of(blob, types.Blob):
357
+ return blob # type: ignore[return-value]
358
+
359
+ if isinstance(blob, dict):
360
+ return types.Blob.model_validate(blob)
361
+
362
+ if 'image' in blob.__class__.__name__.lower():
363
+ try:
364
+ import PIL.Image
365
+
366
+ PIL_Image = PIL.Image.Image
367
+ except ImportError:
368
+ PIL_Image = None
369
+
370
+ if PIL_Image is not None and isinstance(blob, PIL_Image):
371
+ return pil_to_blob(blob)
372
+
373
+ raise TypeError(
374
+ f'Could not parse input as Blob. Unsupported blob type: {type(blob)}'
375
+ )
376
+
377
+
378
+ def t_image_blob(blob: types.BlobImageUnionDict) -> types.Blob:
379
+ blob = t_blob(blob)
380
+ if blob.mime_type and blob.mime_type.startswith('image/'):
381
+ return blob
382
+ raise ValueError(f'Unsupported mime type: {blob.mime_type!r}')
383
+
384
+
385
+ def t_audio_blob(blob: types.BlobOrDict) -> types.Blob:
386
+ blob = t_blob(blob)
387
+ if blob.mime_type and blob.mime_type.startswith('audio/'):
388
+ return blob
389
+ raise ValueError(f'Unsupported mime type: {blob.mime_type!r}')
390
+
391
+
392
+ def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
253
393
  if part is None:
254
394
  raise ValueError('content part is required.')
255
395
  if isinstance(part, str):
256
396
  return types.Part(text=part)
257
- if PIL_Image is not None and isinstance(part, PIL_Image):
258
- return types.Part(inline_data=pil_to_blob(part))
259
- if isinstance(part, types.File):
260
- if not part.uri or not part.mime_type:
397
+ if _is_duck_type_of(part, types.File):
398
+ if not part.uri or not part.mime_type: # type: ignore[union-attr]
261
399
  raise ValueError('file uri and mime_type are required.')
262
- return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
400
+ return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type) # type: ignore[union-attr]
263
401
  if isinstance(part, dict):
264
- return types.Part.model_validate(part)
265
- if isinstance(part, types.Part):
266
- return part
402
+ try:
403
+ return types.Part.model_validate(part)
404
+ except pydantic.ValidationError:
405
+ return types.Part(file_data=types.FileData.model_validate(part))
406
+ if _is_duck_type_of(part, types.Part):
407
+ return part # type: ignore[return-value]
408
+
409
+ if 'image' in part.__class__.__name__.lower():
410
+ try:
411
+ import PIL.Image
412
+
413
+ PIL_Image = PIL.Image.Image
414
+ except ImportError:
415
+ PIL_Image = None
416
+
417
+ if PIL_Image is not None and isinstance(part, PIL_Image):
418
+ return types.Part(inline_data=pil_to_blob(part))
267
419
  raise ValueError(f'Unsupported content part type: {type(part)}')
268
420
 
269
421
 
270
422
  def t_parts(
271
- parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict, list[types.Part]]],
423
+ parts: Optional[
424
+ Union[list[types.PartUnionDict], types.PartUnionDict, list[types.Part]]
425
+ ],
272
426
  ) -> list[types.Part]:
273
427
  #
274
428
  if parts is None or (isinstance(parts, list) and not parts):
@@ -280,7 +434,6 @@ def t_parts(
280
434
 
281
435
 
282
436
  def t_image_predictions(
283
- client: _api_client.BaseApiClient,
284
437
  predictions: Optional[Iterable[Mapping[str, Any]]],
285
438
  ) -> Optional[list[types.GeneratedImage]]:
286
439
  if not predictions:
@@ -303,30 +456,31 @@ ContentType = Union[types.Content, types.ContentDict, types.PartUnionDict]
303
456
 
304
457
 
305
458
  def t_content(
306
- client: _api_client.BaseApiClient,
307
- content: Optional[ContentType],
459
+ content: Union[ContentType, types.ContentDict, None],
308
460
  ) -> types.Content:
309
461
  if content is None:
310
462
  raise ValueError('content is required.')
311
- if isinstance(content, types.Content):
312
- return content
463
+ if _is_duck_type_of(content, types.Content):
464
+ return content # type: ignore[return-value]
313
465
  if isinstance(content, dict):
314
466
  try:
315
467
  return types.Content.model_validate(content)
316
468
  except pydantic.ValidationError:
317
- possible_part = types.Part.model_validate(content)
469
+ possible_part = t_part(content) # type: ignore[arg-type]
318
470
  return (
319
471
  types.ModelContent(parts=[possible_part])
320
472
  if possible_part.function_call
321
473
  else types.UserContent(parts=[possible_part])
322
474
  )
323
- if isinstance(content, types.Part):
475
+ if _is_duck_type_of(content, types.File):
476
+ return types.UserContent(parts=[t_part(content)]) # type: ignore[arg-type]
477
+ if _is_duck_type_of(content, types.Part):
324
478
  return (
325
- types.ModelContent(parts=[content])
326
- if content.function_call
327
- else types.UserContent(parts=[content])
479
+ types.ModelContent(parts=[content]) # type: ignore[arg-type]
480
+ if content.function_call # type: ignore[union-attr]
481
+ else types.UserContent(parts=[content]) # type: ignore[arg-type]
328
482
  )
329
- return types.UserContent(parts=content)
483
+ return types.UserContent(parts=content) # type: ignore[arg-type]
330
484
 
331
485
 
332
486
  def t_contents_for_embed(
@@ -334,9 +488,9 @@ def t_contents_for_embed(
334
488
  contents: Union[list[types.Content], list[types.ContentDict], ContentType],
335
489
  ) -> Union[list[str], list[types.Content]]:
336
490
  if isinstance(contents, list):
337
- transformed_contents = [t_content(client, content) for content in contents]
491
+ transformed_contents = [t_content(content) for content in contents]
338
492
  else:
339
- transformed_contents = [t_content(client, contents)]
493
+ transformed_contents = [t_content(contents)]
340
494
 
341
495
  if client.vertexai:
342
496
  text_parts = []
@@ -349,16 +503,13 @@ def t_contents_for_embed(
349
503
  if part.text:
350
504
  text_parts.append(part.text)
351
505
  else:
352
- logger.warning(
353
- f'Non-text part found, only returning text parts.'
354
- )
506
+ logger.warning(f'Non-text part found, only returning text parts.')
355
507
  return text_parts
356
508
  else:
357
509
  return transformed_contents
358
510
 
359
511
 
360
512
  def t_contents(
361
- client: _api_client.BaseApiClient,
362
513
  contents: Optional[
363
514
  Union[types.ContentListUnion, types.ContentListUnionDict, types.Content]
364
515
  ],
@@ -366,33 +517,45 @@ def t_contents(
366
517
  if contents is None or (isinstance(contents, list) and not contents):
367
518
  raise ValueError('contents are required.')
368
519
  if not isinstance(contents, list):
369
- return [t_content(client, contents)]
370
-
371
- try:
372
- import PIL.Image
373
-
374
- PIL_Image = PIL.Image.Image
375
- except ImportError:
376
- PIL_Image = None
520
+ return [t_content(contents)]
377
521
 
378
522
  result: list[types.Content] = []
379
523
  accumulated_parts: list[types.Part] = []
380
524
 
381
- def _is_part(part: Union[types.PartUnionDict, Any]) -> TypeGuard[types.PartUnionDict]:
525
+ def _is_part(
526
+ part: Union[types.PartUnionDict, Any],
527
+ ) -> TypeGuard[types.PartUnionDict]:
382
528
  if (
383
529
  isinstance(part, str)
384
- or isinstance(part, types.File)
385
- or (PIL_Image is not None and isinstance(part, PIL_Image))
386
- or isinstance(part, types.Part)
530
+ or _is_duck_type_of(part, types.File)
531
+ or _is_duck_type_of(part, types.Part)
387
532
  ):
388
533
  return True
389
534
 
390
535
  if isinstance(part, dict):
536
+ if not part:
537
+ # Empty dict should be considered as Content, not Part.
538
+ return False
391
539
  try:
392
540
  types.Part.model_validate(part)
393
541
  return True
394
542
  except pydantic.ValidationError:
395
- return False
543
+ try:
544
+ types.FileData.model_validate(part)
545
+ return True
546
+ except pydantic.ValidationError:
547
+ return False
548
+
549
+ if 'image' in part.__class__.__name__.lower():
550
+ try:
551
+ import PIL.Image
552
+
553
+ PIL_Image = PIL.Image.Image
554
+ except ImportError:
555
+ PIL_Image = None
556
+
557
+ if PIL_Image is not None and isinstance(part, PIL_Image):
558
+ return True
396
559
 
397
560
  return False
398
561
 
@@ -405,7 +568,7 @@ def t_contents(
405
568
  def _append_accumulated_parts_as_content(
406
569
  result: list[types.Content],
407
570
  accumulated_parts: list[types.Part],
408
- ):
571
+ ) -> None:
409
572
  if not accumulated_parts:
410
573
  return
411
574
  result.append(
@@ -419,7 +582,7 @@ def t_contents(
419
582
  result: list[types.Content],
420
583
  accumulated_parts: list[types.Part],
421
584
  current_part: types.PartUnionDict,
422
- ):
585
+ ) -> None:
423
586
  current_part = t_part(current_part)
424
587
  if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
425
588
  accumulated_parts.append(current_part)
@@ -435,17 +598,13 @@ def t_contents(
435
598
  # append to result
436
599
  # if list, we only accept a list of types.PartUnion
437
600
  for content in contents:
438
- if (
439
- isinstance(content, types.Content)
440
- # only allowed inner list is a list of types.PartUnion
441
- or isinstance(content, list)
442
- ):
601
+ if _is_duck_type_of(content, types.Content) or isinstance(content, list):
443
602
  _append_accumulated_parts_as_content(result, accumulated_parts)
444
603
  if isinstance(content, list):
445
604
  result.append(types.UserContent(parts=content)) # type: ignore[arg-type]
446
605
  else:
447
- result.append(content)
448
- elif (_is_part(content)):
606
+ result.append(content) # type: ignore[arg-type]
607
+ elif _is_part(content):
449
608
  _handle_current_part(result, accumulated_parts, content)
450
609
  elif isinstance(content, dict):
451
610
  # PactDict is already handled in _is_part
@@ -458,7 +617,7 @@ def t_contents(
458
617
  return result
459
618
 
460
619
 
461
- def handle_null_fields(schema: dict[str, Any]):
620
+ def handle_null_fields(schema: _common.StringDict) -> None:
462
621
  """Process null fields in the schema so it is compatible with OpenAPI.
463
622
 
464
623
  The OpenAPI spec does not support 'type: 'null' in the schema. This function
@@ -517,16 +676,34 @@ def handle_null_fields(schema: dict[str, Any]):
517
676
  del schema['anyOf']
518
677
 
519
678
 
679
+ def _raise_for_unsupported_schema_type(origin: Any) -> None:
680
+ """Raises an error if the schema type is unsupported."""
681
+ raise ValueError(f'Unsupported schema type: {origin}')
682
+
683
+
684
+ def _raise_for_unsupported_mldev_properties(
685
+ schema: Any, client: Optional[_api_client.BaseApiClient]
686
+ ) -> None:
687
+ if (
688
+ client
689
+ and not client.vertexai
690
+ and (
691
+ schema.get('additionalProperties')
692
+ or schema.get('additional_properties')
693
+ )
694
+ ):
695
+ raise ValueError('additionalProperties is not supported in the Gemini API.')
696
+
697
+
520
698
  def process_schema(
521
- schema: dict[str, Any],
522
- client: _api_client.BaseApiClient,
523
- defs: Optional[dict[str, Any]] = None,
699
+ schema: _common.StringDict,
700
+ client: Optional[_api_client.BaseApiClient],
701
+ defs: Optional[_common.StringDict] = None,
524
702
  *,
525
703
  order_properties: bool = True,
526
- ):
704
+ ) -> None:
527
705
  """Updates the schema and each sub-schema inplace to be API-compatible.
528
706
 
529
- - Removes the `title` field from the schema if the client is not vertexai.
530
707
  - Inlines the $defs.
531
708
 
532
709
  Example of a schema before and after (with mldev):
@@ -570,73 +747,76 @@ def process_schema(
570
747
  'items': {
571
748
  'properties': {
572
749
  'continent': {
573
- 'type': 'string'
750
+ 'title': 'Continent',
751
+ 'type': 'string'
574
752
  },
575
753
  'gdp': {
576
- 'type': 'integer'}
754
+ 'title': 'Gdp',
755
+ 'type': 'integer'
577
756
  },
578
757
  }
579
758
  'required':['continent', 'gdp'],
759
+ 'title': 'CountryInfo',
580
760
  'type': 'object'
581
761
  },
582
762
  'type': 'array'
583
763
  }
584
764
  """
585
- if not client.vertexai:
586
- schema.pop('title', None)
587
-
588
- if schema.get('default') is not None:
589
- raise ValueError(
590
- 'Default value is not supported in the response schema for the Gemini'
591
- ' API.'
592
- )
593
-
594
765
  if schema.get('title') == 'PlaceholderLiteralEnum':
595
- schema.pop('title', None)
596
-
597
- # If a dict is provided directly to response_schema, it may use `any_of`
598
- # instead of `anyOf`. Otherwise model_json_schema() uses `anyOf`
599
- if schema.get('any_of', None) is not None:
600
- schema['anyOf'] = schema.pop('any_of')
766
+ del schema['title']
767
+
768
+ _raise_for_unsupported_mldev_properties(schema, client)
769
+
770
+ # Standardize spelling for relevant schema fields. For example, if a dict is
771
+ # provided directly to response_schema, it may use `any_of` instead of `anyOf.
772
+ # Otherwise, model_json_schema() uses `anyOf`.
773
+ for from_name, to_name in [
774
+ ('additional_properties', 'additionalProperties'),
775
+ ('any_of', 'anyOf'),
776
+ ('prefix_items', 'prefixItems'),
777
+ ('property_ordering', 'propertyOrdering'),
778
+ ]:
779
+ if (value := schema.pop(from_name, None)) is not None:
780
+ schema[to_name] = value
601
781
 
602
782
  if defs is None:
603
783
  defs = schema.pop('$defs', {})
604
784
  for _, sub_schema in defs.items():
605
- process_schema(sub_schema, client, defs)
785
+ # We can skip the '$ref' check, because JSON schema forbids a '$ref' from
786
+ # directly referencing another '$ref':
787
+ # https://json-schema.org/understanding-json-schema/structuring#recursion
788
+ process_schema(
789
+ sub_schema, client, defs, order_properties=order_properties
790
+ )
606
791
 
607
792
  handle_null_fields(schema)
608
793
 
609
794
  # After removing null fields, Optional fields with only one possible type
610
795
  # will have a $ref key that needs to be flattened
611
796
  # For example: {'default': None, 'description': 'Name of the person', 'nullable': True, '$ref': '#/$defs/TestPerson'}
612
- schema_ref = schema.get('$ref', None)
613
- if schema_ref is not None:
614
- ref = defs[schema_ref.split('defs/')[-1]]
615
- for schema_key in list(ref.keys()):
616
- schema[schema_key] = ref[schema_key]
617
- del schema['$ref']
618
-
619
- any_of = schema.get('anyOf', None)
620
- if any_of is not None:
621
- for sub_schema in any_of:
622
- # $ref is present in any_of if the schema is a union of Pydantic classes
623
- ref_key = sub_schema.get('$ref', None)
624
- if ref_key is None:
625
- process_schema(sub_schema, client, defs)
626
- else:
627
- ref = defs[ref_key.split('defs/')[-1]]
628
- any_of.append(ref)
629
- schema['anyOf'] = [item for item in any_of if '$ref' not in item]
797
+ if (ref := schema.pop('$ref', None)) is not None:
798
+ schema.update(defs[ref.split('defs/')[-1]])
799
+
800
+ def _recurse(sub_schema: _common.StringDict) -> _common.StringDict:
801
+ """Returns the processed `sub_schema`, resolving its '$ref' if any."""
802
+ if (ref := sub_schema.pop('$ref', None)) is not None:
803
+ sub_schema = defs[ref.split('defs/')[-1]]
804
+ process_schema(sub_schema, client, defs, order_properties=order_properties)
805
+ return sub_schema
806
+
807
+ if (any_of := schema.get('anyOf')) is not None:
808
+ schema['anyOf'] = [_recurse(sub_schema) for sub_schema in any_of]
630
809
  return
631
810
 
632
- schema_type = schema.get('type', None)
811
+ schema_type = schema.get('type')
633
812
  if isinstance(schema_type, Enum):
634
813
  schema_type = schema_type.value
635
- schema_type = schema_type.upper()
814
+ if isinstance(schema_type, str):
815
+ schema_type = schema_type.upper()
636
816
 
637
817
  # model_json_schema() returns a schema with a 'const' field when a Literal with one value is provided as a pydantic field
638
818
  # For example `genre: Literal['action']` becomes: {'const': 'action', 'title': 'Genre', 'type': 'string'}
639
- const = schema.get('const', None)
819
+ const = schema.get('const')
640
820
  if const is not None:
641
821
  if schema_type == 'STRING':
642
822
  schema['enum'] = [const]
@@ -645,52 +825,49 @@ def process_schema(
645
825
  raise ValueError('Literal values must be strings.')
646
826
 
647
827
  if schema_type == 'OBJECT':
648
- properties = schema.get('properties', None)
649
- if properties is None:
650
- return
651
- for name, sub_schema in properties.items():
652
- ref_key = sub_schema.get('$ref', None)
653
- if ref_key is None:
654
- process_schema(sub_schema, client, defs)
655
- else:
656
- ref = defs[ref_key.split('defs/')[-1]]
657
- process_schema(ref, client, defs)
658
- properties[name] = ref
659
- if (
660
- len(properties.items()) > 1
661
- and order_properties
662
- and all(
663
- ordering_key not in schema
664
- for ordering_key in ['property_ordering', 'propertyOrdering']
665
- )
666
- ):
667
- property_names = list(properties.keys())
668
- schema['property_ordering'] = property_names
828
+ if (properties := schema.get('properties')) is not None:
829
+ for name, sub_schema in list(properties.items()):
830
+ properties[name] = _recurse(sub_schema)
831
+ if (
832
+ len(properties.items()) > 1
833
+ and order_properties
834
+ and 'propertyOrdering' not in schema
835
+ ):
836
+ schema['property_ordering'] = list(properties.keys())
837
+ if (additional := schema.get('additionalProperties')) is not None:
838
+ # It is legal to set 'additionalProperties' to a bool:
839
+ # https://json-schema.org/understanding-json-schema/reference/object#additionalproperties
840
+ if isinstance(additional, dict):
841
+ schema['additionalProperties'] = _recurse(additional)
669
842
  elif schema_type == 'ARRAY':
670
- sub_schema = schema.get('items', None)
671
- if sub_schema is None:
672
- return
673
- ref_key = sub_schema.get('$ref', None)
674
- if ref_key is None:
675
- process_schema(sub_schema, client, defs)
676
- else:
677
- ref = defs[ref_key.split('defs/')[-1]]
678
- process_schema(ref, client, defs)
679
- schema['items'] = ref
843
+ if (items := schema.get('items')) is not None:
844
+ schema['items'] = _recurse(items)
845
+ if (prefixes := schema.get('prefixItems')) is not None:
846
+ schema['prefixItems'] = [_recurse(prefix) for prefix in prefixes]
680
847
 
681
848
 
682
849
  def _process_enum(
683
- enum: EnumMeta, client: _api_client.BaseApiClient
850
+ enum: EnumMeta, client: Optional[_api_client.BaseApiClient]
684
851
  ) -> types.Schema:
852
+ is_integer_enum = False
853
+
685
854
  for member in enum: # type: ignore
686
- if not isinstance(member.value, str):
855
+ if isinstance(member.value, int):
856
+ is_integer_enum = True
857
+ elif not isinstance(member.value, str):
687
858
  raise TypeError(
688
- f'Enum member {member.name} value must be a string, got'
859
+ f'Enum member {member.name} value must be a string or integer, got'
689
860
  f' {type(member.value)}'
690
861
  )
691
862
 
863
+ enum_to_process = enum
864
+ if is_integer_enum:
865
+ str_members = [str(member.value) for member in enum] # type: ignore
866
+ str_enum = Enum(enum.__name__, str_members, type=str) # type: ignore
867
+ enum_to_process = str_enum
868
+
692
869
  class Placeholder(pydantic.BaseModel):
693
- placeholder: enum # type: ignore[valid-type]
870
+ placeholder: enum_to_process # type: ignore[valid-type]
694
871
 
695
872
  enum_schema = Placeholder.model_json_schema()
696
873
  process_schema(enum_schema, client)
@@ -698,7 +875,9 @@ def _process_enum(
698
875
  return types.Schema.model_validate(enum_schema)
699
876
 
700
877
 
701
- def _is_type_dict_str_any(origin: Union[types.SchemaUnionDict, Any]) -> TypeGuard[dict[str, Any]]:
878
+ def _is_type_dict_str_any(
879
+ origin: Union[types.SchemaUnionDict, Any],
880
+ ) -> TypeGuard[_common.StringDict]:
702
881
  """Verifies the schema is of type dict[str, Any] for mypy type checking."""
703
882
  return isinstance(origin, dict) and all(
704
883
  isinstance(key, str) for key in origin
@@ -706,21 +885,23 @@ def _is_type_dict_str_any(origin: Union[types.SchemaUnionDict, Any]) -> TypeGuar
706
885
 
707
886
 
708
887
  def t_schema(
709
- client: _api_client.BaseApiClient, origin: Union[types.SchemaUnionDict, Any]
888
+ client: Optional[_api_client.BaseApiClient],
889
+ origin: Union[types.SchemaUnionDict, Any],
710
890
  ) -> Optional[types.Schema]:
711
891
  if not origin:
712
892
  return None
713
893
  if isinstance(origin, dict) and _is_type_dict_str_any(origin):
714
- process_schema(origin, client, order_properties=False)
894
+ process_schema(origin, client)
715
895
  return types.Schema.model_validate(origin)
716
896
  if isinstance(origin, EnumMeta):
717
897
  return _process_enum(origin, client)
718
- if isinstance(origin, types.Schema):
719
- if dict(origin) == dict(types.Schema()):
720
- # response_schema value was coerced to an empty Schema instance because it did not adhere to the Schema field annotation
721
- raise ValueError(f'Unsupported schema type.')
722
- schema = origin.model_dump(exclude_unset=True)
723
- process_schema(schema, client, order_properties=False)
898
+ if _is_duck_type_of(origin, types.Schema):
899
+ if dict(origin) == dict(types.Schema()): # type: ignore [arg-type]
900
+ # response_schema value was coerced to an empty Schema instance because
901
+ # it did not adhere to the Schema field annotation
902
+ _raise_for_unsupported_schema_type(origin)
903
+ schema = origin.model_dump(exclude_unset=True) # type: ignore[union-attr]
904
+ process_schema(schema, client)
724
905
  return types.Schema.model_validate(schema)
725
906
 
726
907
  if (
@@ -752,40 +933,43 @@ def t_schema(
752
933
 
753
934
 
754
935
def t_speech_config(
    origin: Union[types.SpeechConfigUnionDict, Any],
) -> Optional[types.SpeechConfig]:
  """Normalizes a speech-config input into a `types.SpeechConfig`.

  Accepts an existing SpeechConfig (duck-typed), a bare voice-name string,
  or a dict validated against the SpeechConfig model. Falsy input yields
  None.

  Raises:
    ValueError: If `origin` is none of the supported input types.
  """
  if not origin:
    return None
  if _is_duck_type_of(origin, types.SpeechConfig):
    return origin  # type: ignore[return-value]
  if isinstance(origin, str):
    # A plain string is shorthand for a prebuilt voice name.
    return types.SpeechConfig(
        voice_config=types.VoiceConfig(
            prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=origin)
        )
    )
  if isinstance(origin, dict):
    return types.SpeechConfig.model_validate(origin)

  raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
786
952
 
787
953
 
788
- def t_tool(client: _api_client.BaseApiClient, origin) -> Optional[types.Tool]:
954
def t_live_speech_config(
    origin: types.SpeechConfigOrDict,
) -> Optional[types.SpeechConfig]:
  """Validates a speech config for the live API.

  Raises:
    ValueError: If `multi_speaker_voice_config` is set — it is not
      supported in the live API.
  """
  if _is_duck_type_of(origin, types.SpeechConfig):
    speech_config = origin
  # NOTE(review): if `origin` is neither a SpeechConfig duck-type nor a
  # dict, `speech_config` is unbound below — callers are expected to pass
  # only SpeechConfigOrDict values; confirm against call sites.
  if isinstance(origin, dict):
    speech_config = types.SpeechConfig.model_validate(origin)

  if speech_config.multi_speaker_voice_config is not None:  # type: ignore[union-attr]
    raise ValueError(
        'multi_speaker_voice_config is not supported in the live API.'
    )

  return speech_config  # type: ignore[return-value]
968
+
969
+
970
+ def t_tool(
971
+ client: _api_client.BaseApiClient, origin: Any
972
+ ) -> Optional[Union[types.Tool, Any]]:
789
973
  if not origin:
790
974
  return None
791
975
  if inspect.isfunction(origin) or inspect.ismethod(origin):
@@ -796,11 +980,14 @@ def t_tool(client: _api_client.BaseApiClient, origin) -> Optional[types.Tool]:
796
980
  )
797
981
  ]
798
982
  )
983
+ elif McpTool is not None and _is_duck_type_of(origin, McpTool):
984
+ return mcp_to_gemini_tool(origin)
985
+ elif isinstance(origin, dict):
986
+ return types.Tool.model_validate(origin)
799
987
  else:
800
988
  return origin
801
989
 
802
990
 
803
- # Only support functions now.
804
991
  def t_tools(
805
992
  client: _api_client.BaseApiClient, origin: list[Any]
806
993
  ) -> list[types.Tool]:
@@ -826,46 +1013,136 @@ def t_tools(
826
1013
  return tools
827
1014
 
828
1015
 
829
def t_cached_content_name(client: _api_client.BaseApiClient, name: str) -> str:
  """Resolves `name` to a full `cachedContents` resource name."""
  return _resource_name(client, name, collection_identifier='cachedContents')
831
1018
 
832
1019
 
833
def t_batch_job_source(
    client: _api_client.BaseApiClient,
    src: types.BatchJobSourceUnionDict,
) -> types.BatchJobSource:
  """Normalizes a batch-job source into a `types.BatchJobSource`.

  Accepts a BatchJobSource (or dict form), a list of inlined requests, or a
  URI string (`gs://`, `bq://`, or `files/`). Validates that exactly one
  source field appropriate to the backend (Vertex AI vs Gemini API) is set.

  Raises:
    ValueError: If the source is unsupported or sets the wrong combination
      of fields for the active backend.
  """
  if isinstance(src, dict):
    src = types.BatchJobSource(**src)
  if _is_duck_type_of(src, types.BatchJobSource):
    vertex_sources = sum(
        [src.gcs_uri is not None, src.bigquery_uri is not None]  # type: ignore[union-attr]
    )
    mldev_sources = sum([
        src.inlined_requests is not None,  # type: ignore[union-attr]
        src.file_name is not None,  # type: ignore[union-attr]
    ])
    if client.vertexai:
      if mldev_sources or vertex_sources != 1:
        raise ValueError(
            'Exactly one of `gcs_uri` or `bigquery_uri` must be set, other '
            'sources are not supported in Vertex AI.'
        )
    else:
      if vertex_sources or mldev_sources != 1:
        raise ValueError(
            'Exactly one of `inlined_requests`, `file_name`, '
            '`inlined_embed_content_requests`, or `embed_content_file_name` '
            'must be set, other sources are not supported in Gemini API.'
        )
    return src  # type: ignore[return-value]

  elif isinstance(src, list):
    # A bare list is treated as inlined requests.
    return types.BatchJobSource(inlined_requests=src)
  elif isinstance(src, str):
    if src.startswith('gs://'):
      return types.BatchJobSource(
          format='jsonl',
          gcs_uri=[src],
      )
    elif src.startswith('bq://'):
      return types.BatchJobSource(
          format='bigquery',
          bigquery_uri=src,
      )
    elif src.startswith('files/'):
      return types.BatchJobSource(
          file_name=src,
      )

  raise ValueError(f'Unsupported source: {src}')
847
1068
 
848
- def t_batch_job_destination(client: _api_client.BaseApiClient, dest: str):
849
- if dest.startswith('gs://'):
850
- return types.BatchJobDestination(
851
- format='jsonl',
852
- gcs_uri=dest,
853
- )
854
- elif dest.startswith('bq://'):
855
- return types.BatchJobDestination(
856
- format='bigquery',
857
- bigquery_uri=dest,
858
- )
1069
+
1070
def t_embedding_batch_job_source(
    client: _api_client.BaseApiClient,
    src: types.EmbeddingsBatchJobSourceOrDict,
) -> types.EmbeddingsBatchJobSource:
  """Normalizes an embeddings batch-job source, enforcing exactly one field.

  Raises:
    ValueError: If `src` is not an EmbeddingsBatchJobSource (or dict form),
      or if it does not set exactly one of `inlined_requests`/`file_name`.
  """
  if isinstance(src, dict):
    src = types.EmbeddingsBatchJobSource(**src)

  if _is_duck_type_of(src, types.EmbeddingsBatchJobSource):
    mldev_sources = sum([
        src.inlined_requests is not None,
        src.file_name is not None,
    ])
    if mldev_sources != 1:
      raise ValueError(
          'Exactly one of `inlined_requests`, `file_name`, '
          '`inlined_embed_content_requests`, or `embed_content_file_name` '
          'must be set, other sources are not supported in Gemini API.'
      )
    return src
  else:
    raise ValueError(f'Unsupported source type: {type(src)}')
1091
+
1092
+
1093
def t_batch_job_destination(
    dest: Union[str, types.BatchJobDestinationOrDict],
) -> types.BatchJobDestination:
  """Normalizes a batch-job destination into a `types.BatchJobDestination`.

  Accepts a BatchJobDestination (or dict form) or a `gs://`/`bq://` URI
  string.

  Raises:
    ValueError: If the destination string scheme or the value type is
      unsupported.
  """
  if isinstance(dest, dict):
    dest = types.BatchJobDestination(**dest)
    return dest
  elif isinstance(dest, str):
    if dest.startswith('gs://'):
      return types.BatchJobDestination(
          format='jsonl',
          gcs_uri=dest,
      )
    elif dest.startswith('bq://'):
      return types.BatchJobDestination(
          format='bigquery',
          bigquery_uri=dest,
      )
    else:
      raise ValueError(f'Unsupported destination: {dest}')
  elif _is_duck_type_of(dest, types.BatchJobDestination):
    return dest
  else:
    raise ValueError(f'Unsupported destination: {dest}')
861
1116
 
862
1117
 
863
- def t_batch_job_name(client: _api_client.BaseApiClient, name: str):
1118
def t_recv_batch_job_destination(dest: dict[str, Any]) -> dict[str, Any]:
  """Renames `inlinedResponses` when the payload looks like embeddings.

  Mutates `dest` in place: if any inlined response contains an `embedding`
  key, the `inlinedResponses` entry is moved to
  `inlinedEmbedContentResponses`. Returns `dest` either way.
  """
  inline_responses = dest.get('inlinedResponses', {}).get(
      'inlinedResponses', []
  )
  if not inline_responses:
    return dest
  for response in inline_responses:
    inner_response = response.get('response', {})
    if not inner_response:
      continue
    if 'embedding' in inner_response:
      dest['inlinedEmbedContentResponses'] = dest.pop('inlinedResponses')
      break
  return dest
1133
+
1134
+
1135
def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
  """Extracts the short batch-job ID from a resource name.

  Gemini API names must look like `batches/<id>`; Vertex AI names must be
  either a full `projects/.../batchPredictionJobs/<id>` path or a bare
  numeric ID.

  Raises:
    ValueError: If `name` does not match the expected format for the
      active backend.
  """
  if not client.vertexai:
    mldev_pattern = r'batches/[^/]+$'
    if re.match(mldev_pattern, name):
      return name.split('/')[-1]
    else:
      raise ValueError(f'Invalid batch job name: {name}.')

  vertex_pattern = r'^projects/[^/]+/locations/[^/]+/batchPredictionJobs/[^/]+$'

  if re.match(vertex_pattern, name):
    return name.split('/')[-1]
  elif name.isdigit():
    return name
  else:
    raise ValueError(f'Invalid batch job name: {name}.')
874
1151
 
875
1152
 
1153
def t_job_state(state: str) -> str:
  """Maps Gemini API `BATCH_STATE_*` values to Vertex `JOB_STATE_*` values.

  Unrecognized states are returned unchanged.
  """
  # Table lookup replaces the original elif chain; same input/output pairs.
  batch_to_job_state = {
      'BATCH_STATE_UNSPECIFIED': 'JOB_STATE_UNSPECIFIED',
      'BATCH_STATE_PENDING': 'JOB_STATE_PENDING',
      'BATCH_STATE_RUNNING': 'JOB_STATE_RUNNING',
      'BATCH_STATE_SUCCEEDED': 'JOB_STATE_SUCCEEDED',
      'BATCH_STATE_FAILED': 'JOB_STATE_FAILED',
      'BATCH_STATE_CANCELLED': 'JOB_STATE_CANCELLED',
      'BATCH_STATE_EXPIRED': 'JOB_STATE_EXPIRED',
  }
  return batch_to_job_state.get(state, state)
1170
+
1171
+
876
1172
# Polling schedule for long-running operations: start at 1s, multiply the
# delay by 1.5 after each poll, cap each delay at 20s, and give up after
# 900s of total waiting.
LRO_POLLING_INITIAL_DELAY_SECONDS = 1.0
LRO_POLLING_MAXIMUM_DELAY_SECONDS = 20.0
LRO_POLLING_TIMEOUT_SECONDS = 900.0
LRO_POLLING_MULTIPLIER = 1.5
880
1176
 
881
1177
 
882
- def t_resolve_operation(api_client: _api_client.BaseApiClient, struct: dict):
1178
+ def t_resolve_operation(
1179
+ api_client: _api_client.BaseApiClient, struct: _common.StringDict
1180
+ ) -> Any:
883
1181
  if (name := struct.get('name')) and '/operations/' in name:
884
- operation: dict[str, Any] = struct
1182
+ operation: _common.StringDict = struct
885
1183
  total_seconds = 0.0
886
1184
  delay_seconds = LRO_POLLING_INITIAL_DELAY_SECONDS
887
1185
  while operation.get('done') != True:
888
1186
  if total_seconds > LRO_POLLING_TIMEOUT_SECONDS:
889
1187
  raise RuntimeError(f'Operation {name} timed out.\n{operation}')
890
1188
  # TODO(b/374433890): Replace with LRO module once it's available.
891
- operation = api_client.request(
1189
+ operation = api_client.request( # type: ignore[assignment]
892
1190
  http_method='GET', path=name, request_dict={}
893
1191
  )
894
1192
  time.sleep(delay_seconds)
@@ -908,17 +1206,16 @@ def t_resolve_operation(api_client: _api_client.BaseApiClient, struct: dict):
908
1206
 
909
1207
 
910
1208
  def t_file_name(
911
- api_client: _api_client.BaseApiClient,
912
1209
  name: Optional[Union[str, types.File, types.Video, types.GeneratedVideo]],
913
- ):
1210
+ ) -> str:
914
1211
  # Remove the files/ prefix since it's added to the url path.
915
- if isinstance(name, types.File):
916
- name = name.name
917
- elif isinstance(name, types.Video):
918
- name = name.uri
919
- elif isinstance(name, types.GeneratedVideo):
920
- if name.video is not None:
921
- name = name.video.uri
1212
+ if _is_duck_type_of(name, types.File):
1213
+ name = name.name # type: ignore[union-attr]
1214
+ elif _is_duck_type_of(name, types.Video):
1215
+ name = name.uri # type: ignore[union-attr]
1216
+ elif _is_duck_type_of(name, types.GeneratedVideo):
1217
+ if name.video is not None: # type: ignore[union-attr]
1218
+ name = name.video.uri # type: ignore[union-attr]
922
1219
  else:
923
1220
  name = None
924
1221
 
@@ -942,9 +1239,7 @@ def t_file_name(
942
1239
  return name
943
1240
 
944
1241
 
945
- def t_tuning_job_status(
946
- api_client: _api_client.BaseApiClient, status: str
947
- ) -> Union[types.JobState, str]:
1242
+ def t_tuning_job_status(status: str) -> Union[types.JobState, str]:
948
1243
  if status == 'STATE_UNSPECIFIED':
949
1244
  return types.JobState.JOB_STATE_UNSPECIFIED
950
1245
  elif status == 'CREATING':
@@ -960,11 +1255,116 @@ def t_tuning_job_status(
960
1255
  return status
961
1256
 
962
1257
 
963
- # Some fields don't accept url safe base64 encoding.
964
- # We shouldn't use this transformer if the backend adhere to Cloud Type
965
- # format https://cloud.google.com/docs/discovery/type-format.
966
- # TODO(b/389133914,b/390320301): Remove the hack after backend fix the issue.
967
- def t_bytes(api_client: _api_client.BaseApiClient, data: bytes) -> str:
968
- if not isinstance(data, bytes):
969
- return data
970
- return base64.b64encode(data).decode('ascii')
1258
def t_content_strict(content: types.ContentOrDict) -> types.Content:
  """Converts a Content (or dict form) to `types.Content`, strictly.

  Unlike lenient transformers, no coercion from strings/parts is attempted.

  Raises:
    ValueError: If `content` is neither a dict nor a Content duck-type.
  """
  if isinstance(content, dict):
    return types.Content.model_validate(content)
  elif _is_duck_type_of(content, types.Content):
    return content
  else:
    raise ValueError(
        f'Could not convert input (type "{type(content)}") to `types.Content`'
    )
1267
+
1268
+
1269
def t_contents_strict(
    contents: Union[Sequence[types.ContentOrDict], types.ContentOrDict],
) -> list[types.Content]:
  """Converts one Content or a sequence of them to a list, strictly.

  A single (non-sequence) item is wrapped in a one-element list.
  """
  if isinstance(contents, Sequence):
    return [t_content_strict(content) for content in contents]
  else:
    return [t_content_strict(contents)]
1276
+
1277
+
1278
def t_client_content(
    turns: Optional[
        Union[Sequence[types.ContentOrDict], types.ContentOrDict]
    ] = None,
    turn_complete: bool = True,
) -> types.LiveClientContent:
  """Builds a `types.LiveClientContent` from optional conversation turns.

  With `turns=None` only the `turn_complete` flag is sent.

  Raises:
    ValueError: If `turns` cannot be converted to Content objects; the
      original conversion error is chained as the cause.
  """
  if turns is None:
    return types.LiveClientContent(turn_complete=turn_complete)

  try:
    return types.LiveClientContent(
        turns=t_contents_strict(contents=turns),
        turn_complete=turn_complete,
    )
  except Exception as e:
    raise ValueError(
        f'Could not convert input (type "{type(turns)}") to '
        '`types.LiveClientContent`'
    ) from e
1297
+
1298
+
1299
def t_tool_response(
    input: Union[
        types.FunctionResponseOrDict,
        Sequence[types.FunctionResponseOrDict],
    ],
) -> types.LiveClientToolResponse:
  """Wraps function response(s) into a `types.LiveClientToolResponse`.

  Raises:
    ValueError: If `input` is falsy, or cannot be converted to function
      responses (original error chained as the cause).
  """
  if not input:
    raise ValueError(f'A tool response is required, got: \n{input}')

  try:
    return types.LiveClientToolResponse(
        function_responses=t_function_responses(function_responses=input)
    )
  except Exception as e:
    raise ValueError(
        f'Could not convert input (type "{type(input)}") to '
        '`types.LiveClientToolResponse`'
    ) from e
1317
+
1318
+
1319
def t_metrics(
    metrics: list[types.MetricSubclass],
) -> list[dict[str, Any]]:
  """Prepares the metric payload for an evaluation request.

  Args:
    metrics: A list of resolved metric objects.

  Returns:
    A list of per-metric payload dicts, each with aggregation settings and
    the spec matching the metric's type.

  Raises:
    ValueError: If a metric is of an unsupported type or has an invalid
      name.
  """
  metrics_payload = []

  for metric in metrics:
    metric_payload_item: dict[str, Any] = {}
    metric_payload_item['aggregation_metrics'] = [
        'AVERAGE',
        'STANDARD_DEVIATION',
    ]

    metric_name = getv(metric, ['name']).lower()

    if metric_name == 'exact_match':
      metric_payload_item['exact_match_spec'] = {}
    elif metric_name == 'bleu':
      metric_payload_item['bleu_spec'] = {}
    elif metric_name.startswith('rouge'):
      # e.g. 'rouge_1' -> 'rouge1', the form the API expects.
      rouge_type = metric_name.replace('_', '')
      metric_payload_item['rouge_spec'] = {'rouge_type': rouge_type}

    elif hasattr(metric, 'prompt_template') and metric.prompt_template:
      # Any metric carrying a prompt template is treated as a pointwise
      # model-judged metric.
      pointwise_spec = {'metric_prompt_template': metric.prompt_template}
      system_instruction = getv(
          metric, ['judge_model_system_instruction']
      )
      if system_instruction:
        pointwise_spec['system_instruction'] = system_instruction
      return_raw_output = getv(
          metric, ['return_raw_output']
      )
      if return_raw_output:
        pointwise_spec['custom_output_format_config'] = {  # type: ignore[assignment]
            'return_raw_output': return_raw_output
        }
      metric_payload_item['pointwise_metric_spec'] = pointwise_spec
    else:
      raise ValueError(
          'Unsupported metric type or invalid metric name:' f' {metric_name}'
      )
    metrics_payload.append(metric_payload_item)
  return metrics_payload