google-genai 1.4.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +207 -111
- google/genai/_automatic_function_calling_util.py +6 -16
- google/genai/_common.py +5 -2
- google/genai/_extra_utils.py +62 -47
- google/genai/_replay_api_client.py +70 -2
- google/genai/_transformers.py +98 -57
- google/genai/batches.py +14 -10
- google/genai/caches.py +30 -36
- google/genai/client.py +3 -2
- google/genai/errors.py +11 -19
- google/genai/files.py +28 -15
- google/genai/live.py +276 -93
- google/genai/models.py +201 -112
- google/genai/operations.py +40 -12
- google/genai/pagers.py +17 -10
- google/genai/tunings.py +40 -30
- google/genai/types.py +146 -58
- google/genai/version.py +1 -1
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/METADATA +194 -24
- google_genai-1.6.0.dist-info/RECORD +27 -0
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/WHEEL +1 -1
- google_genai-1.4.0.dist-info/RECORD +0 -27
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/LICENSE +0 -0
- {google_genai-1.4.0.dist-info → google_genai-1.6.0.dist-info}/top_level.txt +0 -0
google/genai/live.py
CHANGED
@@ -20,9 +20,10 @@ import base64
|
|
20
20
|
import contextlib
|
21
21
|
import json
|
22
22
|
import logging
|
23
|
-
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union
|
23
|
+
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, get_args
|
24
24
|
|
25
25
|
import google.auth
|
26
|
+
import pydantic
|
26
27
|
from websockets import ConnectionClosed
|
27
28
|
|
28
29
|
from . import _api_module
|
@@ -49,12 +50,12 @@ from .models import _Tool_to_mldev
|
|
49
50
|
from .models import _Tool_to_vertex
|
50
51
|
|
51
52
|
try:
|
52
|
-
from websockets.asyncio.client import ClientConnection
|
53
|
-
from websockets.asyncio.client import connect
|
53
|
+
from websockets.asyncio.client import ClientConnection # type: ignore
|
54
|
+
from websockets.asyncio.client import connect # type: ignore
|
54
55
|
except ModuleNotFoundError:
|
55
56
|
# This try/except is for TAP, mypy complains about it which is why we have the type: ignore
|
56
|
-
from websockets.client import ClientConnection
|
57
|
-
from websockets.client import connect
|
57
|
+
from websockets.client import ClientConnection # type: ignore
|
58
|
+
from websockets.client import connect # type: ignore
|
58
59
|
|
59
60
|
logger = logging.getLogger('google_genai.live')
|
60
61
|
|
@@ -171,7 +172,7 @@ class AsyncSession:
|
|
171
172
|
stream = read_audio()
|
172
173
|
for data in stream:
|
173
174
|
yield data
|
174
|
-
async with client.aio.live.connect(model='...') as session:
|
175
|
+
async with client.aio.live.connect(model='...', config=config) as session:
|
175
176
|
for audio in session.start_stream(stream = audio_stream(),
|
176
177
|
mime_type = 'audio/pcm'):
|
177
178
|
play_audio_chunk(audio.data)
|
@@ -211,7 +212,7 @@ class AsyncSession:
|
|
211
212
|
try:
|
212
213
|
response = json.loads(raw_response)
|
213
214
|
except json.decoder.JSONDecodeError:
|
214
|
-
raise ValueError(f'Failed to parse response: {raw_response}')
|
215
|
+
raise ValueError(f'Failed to parse response: {raw_response!r}')
|
215
216
|
else:
|
216
217
|
response = {}
|
217
218
|
if self._api_client.vertexai:
|
@@ -220,7 +221,7 @@ class AsyncSession:
|
|
220
221
|
response_dict = self._LiveServerMessage_from_mldev(response)
|
221
222
|
|
222
223
|
return types.LiveServerMessage._from_response(
|
223
|
-
response=response_dict, kwargs=parameter_model
|
224
|
+
response=response_dict, kwargs=parameter_model.model_dump()
|
224
225
|
)
|
225
226
|
|
226
227
|
async def _send_loop(
|
@@ -230,8 +231,10 @@ class AsyncSession:
|
|
230
231
|
stop_event: asyncio.Event,
|
231
232
|
):
|
232
233
|
async for data in data_stream:
|
233
|
-
|
234
|
-
|
234
|
+
model_input = types.LiveClientRealtimeInput(
|
235
|
+
media_chunks=[types.Blob(data=data, mime_type=mime_type)]
|
236
|
+
)
|
237
|
+
await self.send(input=model_input)
|
235
238
|
# Give a chance for the receive loop to process responses.
|
236
239
|
await asyncio.sleep(10**-12)
|
237
240
|
# Give a chance for the receiver to process the last response.
|
@@ -372,124 +375,288 @@ class AsyncSession:
|
|
372
375
|
]
|
373
376
|
] = None,
|
374
377
|
end_of_turn: Optional[bool] = False,
|
375
|
-
) ->
|
378
|
+
) -> types.LiveClientMessageDict:
|
379
|
+
|
380
|
+
formatted_input: Any = input
|
376
381
|
|
377
382
|
if not input:
|
378
383
|
logging.info('No input provided. Assume it is the end of turn.')
|
379
384
|
return {'client_content': {'turn_complete': True}}
|
380
385
|
if isinstance(input, str):
|
381
|
-
|
386
|
+
formatted_input = [input]
|
382
387
|
elif isinstance(input, dict) and 'data' in input:
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
388
|
+
try:
|
389
|
+
blob_input = types.Blob(**input)
|
390
|
+
except pydantic.ValidationError:
|
391
|
+
raise ValueError(
|
392
|
+
f'Unsupported input type "{type(input)}" or input content "{input}"'
|
393
|
+
)
|
394
|
+
if (
|
395
|
+
isinstance(blob_input, types.Blob)
|
396
|
+
and isinstance(blob_input.data, bytes)
|
397
|
+
):
|
398
|
+
formatted_input = [
|
399
|
+
blob_input.model_dump(mode='json', exclude_none=True)
|
400
|
+
]
|
387
401
|
elif isinstance(input, types.Blob):
|
388
|
-
|
402
|
+
formatted_input = [input]
|
389
403
|
elif isinstance(input, dict) and 'name' in input and 'response' in input:
|
390
404
|
# ToolResponse.FunctionResponse
|
391
405
|
if not (self._api_client.vertexai) and 'id' not in input:
|
392
406
|
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
393
|
-
|
407
|
+
formatted_input = [input]
|
394
408
|
|
395
|
-
if isinstance(
|
396
|
-
isinstance(c, dict) and 'name' in c and 'response' in c
|
409
|
+
if isinstance(formatted_input, Sequence) and any(
|
410
|
+
isinstance(c, dict) and 'name' in c and 'response' in c
|
411
|
+
for c in formatted_input
|
397
412
|
):
|
398
413
|
# ToolResponse.FunctionResponse
|
399
|
-
|
400
|
-
|
401
|
-
|
414
|
+
function_responses_input = []
|
415
|
+
for item in formatted_input:
|
416
|
+
if isinstance(item, dict):
|
417
|
+
try:
|
418
|
+
function_response_input = types.FunctionResponse(**item)
|
419
|
+
except pydantic.ValidationError:
|
420
|
+
raise ValueError(
|
421
|
+
f'Unsupported input type "{type(input)}" or input content'
|
422
|
+
f' "{input}"'
|
423
|
+
)
|
424
|
+
if (
|
425
|
+
function_response_input.id is None
|
426
|
+
and not self._api_client.vertexai
|
427
|
+
):
|
402
428
|
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
403
|
-
|
404
|
-
|
429
|
+
else:
|
430
|
+
function_response_dict = function_response_input.model_dump(
|
431
|
+
exclude_none=True, mode='json'
|
432
|
+
)
|
433
|
+
function_response_typeddict = types.FunctionResponseDict(
|
434
|
+
name=function_response_dict.get('name'),
|
435
|
+
response=function_response_dict.get('response'),
|
436
|
+
)
|
437
|
+
if function_response_dict.get('id'):
|
438
|
+
function_response_typeddict['id'] = function_response_dict.get(
|
439
|
+
'id'
|
440
|
+
)
|
441
|
+
function_responses_input.append(function_response_typeddict)
|
442
|
+
client_message = types.LiveClientMessageDict(
|
443
|
+
tool_response=types.LiveClientToolResponseDict(
|
444
|
+
function_responses=function_responses_input
|
445
|
+
)
|
446
|
+
)
|
447
|
+
elif isinstance(formatted_input, Sequence) and any(
|
448
|
+
isinstance(c, str) for c in formatted_input
|
449
|
+
):
|
405
450
|
to_object: dict[str, Any] = {}
|
451
|
+
content_input_parts: list[types.PartUnion] = []
|
452
|
+
for item in formatted_input:
|
453
|
+
if isinstance(item, get_args(types.PartUnion)):
|
454
|
+
content_input_parts.append(item)
|
406
455
|
if self._api_client.vertexai:
|
407
456
|
contents = [
|
408
457
|
_Content_to_vertex(self._api_client, item, to_object)
|
409
|
-
for item in t.t_contents(self._api_client,
|
458
|
+
for item in t.t_contents(self._api_client, content_input_parts)
|
410
459
|
]
|
411
460
|
else:
|
412
461
|
contents = [
|
413
462
|
_Content_to_mldev(self._api_client, item, to_object)
|
414
|
-
for item in t.t_contents(self._api_client,
|
463
|
+
for item in t.t_contents(self._api_client, content_input_parts)
|
415
464
|
]
|
416
465
|
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
466
|
+
content_dict_list: list[types.ContentDict] = []
|
467
|
+
for item in contents:
|
468
|
+
try:
|
469
|
+
content_input = types.Content(**item)
|
470
|
+
except pydantic.ValidationError:
|
471
|
+
raise ValueError(
|
472
|
+
f'Unsupported input type "{type(input)}" or input content'
|
473
|
+
f' "{input}"'
|
474
|
+
)
|
475
|
+
content_dict_list.append(
|
476
|
+
types.ContentDict(
|
477
|
+
parts=content_input.model_dump(exclude_none=True, mode='json')[
|
478
|
+
'parts'
|
479
|
+
],
|
480
|
+
role=content_input.role,
|
481
|
+
)
|
482
|
+
)
|
483
|
+
|
484
|
+
client_message = types.LiveClientMessageDict(
|
485
|
+
client_content=types.LiveClientContentDict(
|
486
|
+
turns=content_dict_list, turn_complete=end_of_turn
|
487
|
+
)
|
488
|
+
)
|
489
|
+
elif isinstance(formatted_input, Sequence):
|
490
|
+
if any((isinstance(b, dict) and 'data' in b) for b in formatted_input):
|
422
491
|
pass
|
423
|
-
elif any(isinstance(b, types.Blob) for b in
|
424
|
-
|
492
|
+
elif any(isinstance(b, types.Blob) for b in formatted_input):
|
493
|
+
formatted_input = [
|
494
|
+
b.model_dump(exclude_none=True, mode='json')
|
495
|
+
for b in formatted_input
|
496
|
+
]
|
425
497
|
else:
|
426
498
|
raise ValueError(
|
427
499
|
f'Unsupported input type "{type(input)}" or input content "{input}"'
|
428
500
|
)
|
429
501
|
|
430
|
-
client_message =
|
502
|
+
client_message = types.LiveClientMessageDict(
|
503
|
+
realtime_input=types.LiveClientRealtimeInputDict(
|
504
|
+
media_chunks=formatted_input
|
505
|
+
)
|
506
|
+
)
|
431
507
|
|
432
|
-
elif isinstance(
|
433
|
-
if 'content' in
|
508
|
+
elif isinstance(formatted_input, dict):
|
509
|
+
if 'content' in formatted_input or 'turns' in formatted_input:
|
434
510
|
# TODO(b/365983264) Add validation checks for content_update input_dict.
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
client_message =
|
511
|
+
if 'turns' in formatted_input:
|
512
|
+
content_turns = formatted_input['turns']
|
513
|
+
else:
|
514
|
+
content_turns = formatted_input['content']
|
515
|
+
client_message = types.LiveClientMessageDict(
|
516
|
+
client_content=types.LiveClientContentDict(
|
517
|
+
turns=content_turns,
|
518
|
+
turn_complete=formatted_input.get('turn_complete'),
|
519
|
+
)
|
520
|
+
)
|
521
|
+
elif 'media_chunks' in formatted_input:
|
522
|
+
try:
|
523
|
+
realtime_input = types.LiveClientRealtimeInput(**formatted_input)
|
524
|
+
except pydantic.ValidationError:
|
525
|
+
raise ValueError(
|
526
|
+
f'Unsupported input type "{type(input)}" or input content'
|
527
|
+
f' "{input}"'
|
528
|
+
)
|
529
|
+
client_message = types.LiveClientMessageDict(
|
530
|
+
realtime_input=types.LiveClientRealtimeInputDict(
|
531
|
+
media_chunks=realtime_input.model_dump(
|
532
|
+
exclude_none=True, mode='json'
|
533
|
+
)['media_chunks']
|
534
|
+
)
|
535
|
+
)
|
536
|
+
elif 'function_responses' in formatted_input:
|
537
|
+
try:
|
538
|
+
tool_response_input = types.LiveClientToolResponse(**formatted_input)
|
539
|
+
except pydantic.ValidationError:
|
540
|
+
raise ValueError(
|
541
|
+
f'Unsupported input type "{type(input)}" or input content'
|
542
|
+
f' "{input}"'
|
543
|
+
)
|
544
|
+
client_message = types.LiveClientMessageDict(
|
545
|
+
tool_response=types.LiveClientToolResponseDict(
|
546
|
+
function_responses=tool_response_input.model_dump(
|
547
|
+
exclude_none=True, mode='json'
|
548
|
+
)['function_responses']
|
549
|
+
)
|
550
|
+
)
|
440
551
|
else:
|
441
552
|
raise ValueError(
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
553
|
+
f'Unsupported input type "{type(input)}" or input content "{input}"'
|
554
|
+
)
|
555
|
+
elif isinstance(formatted_input, types.LiveClientRealtimeInput):
|
556
|
+
realtime_input_dict = formatted_input.model_dump(
|
557
|
+
exclude_none=True, mode='json'
|
558
|
+
)
|
559
|
+
client_message = types.LiveClientMessageDict(
|
560
|
+
realtime_input=types.LiveClientRealtimeInputDict(
|
561
|
+
media_chunks=realtime_input_dict.get('media_chunks')
|
562
|
+
)
|
563
|
+
)
|
564
|
+
if (
|
565
|
+
client_message['realtime_input'] is not None
|
566
|
+
and client_message['realtime_input']['media_chunks'] is not None
|
567
|
+
and isinstance(
|
568
|
+
client_message['realtime_input']['media_chunks'][0]['data'], bytes
|
569
|
+
)
|
449
570
|
):
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
571
|
+
formatted_media_chunks: list[types.BlobDict] = []
|
572
|
+
for item in client_message['realtime_input']['media_chunks']:
|
573
|
+
if isinstance(item, dict):
|
574
|
+
try:
|
575
|
+
blob_input = types.Blob(**item)
|
576
|
+
except pydantic.ValidationError:
|
577
|
+
raise ValueError(
|
578
|
+
f'Unsupported input type "{type(input)}" or input content'
|
579
|
+
f' "{input}"'
|
580
|
+
)
|
581
|
+
if (
|
582
|
+
isinstance(blob_input, types.Blob)
|
583
|
+
and isinstance(blob_input.data, bytes)
|
584
|
+
and blob_input.data is not None
|
585
|
+
):
|
586
|
+
formatted_media_chunks.append(
|
587
|
+
types.BlobDict(
|
588
|
+
data=base64.b64decode(blob_input.data),
|
589
|
+
mime_type=blob_input.mime_type,
|
590
|
+
)
|
591
|
+
)
|
592
|
+
|
593
|
+
client_message['realtime_input'][
|
594
|
+
'media_chunks'
|
595
|
+
] = formatted_media_chunks
|
457
596
|
|
458
|
-
elif isinstance(
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
597
|
+
elif isinstance(formatted_input, types.LiveClientContent):
|
598
|
+
client_content_dict = formatted_input.model_dump(
|
599
|
+
exclude_none=True, mode='json'
|
600
|
+
)
|
601
|
+
client_message = types.LiveClientMessageDict(
|
602
|
+
client_content=types.LiveClientContentDict(
|
603
|
+
turns=client_content_dict.get('turns'),
|
604
|
+
turn_complete=client_content_dict.get('turn_complete'),
|
605
|
+
)
|
606
|
+
)
|
607
|
+
elif isinstance(formatted_input, types.LiveClientToolResponse):
|
463
608
|
# ToolResponse.FunctionResponse
|
464
|
-
if
|
465
|
-
|
609
|
+
if (
|
610
|
+
not (self._api_client.vertexai)
|
611
|
+
and formatted_input.function_responses is not None
|
612
|
+
and not (formatted_input.function_responses[0].id)
|
466
613
|
):
|
467
614
|
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
468
|
-
client_message =
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
615
|
+
client_message = types.LiveClientMessageDict(
|
616
|
+
tool_response=types.LiveClientToolResponseDict(
|
617
|
+
function_responses=formatted_input.model_dump(
|
618
|
+
exclude_none=True, mode='json'
|
619
|
+
).get('function_responses')
|
620
|
+
)
|
621
|
+
)
|
622
|
+
elif isinstance(formatted_input, types.FunctionResponse):
|
623
|
+
if not (self._api_client.vertexai) and not (formatted_input.id):
|
473
624
|
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
474
|
-
|
475
|
-
'
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
625
|
+
function_response_dict = formatted_input.model_dump(
|
626
|
+
exclude_none=True, mode='json'
|
627
|
+
)
|
628
|
+
function_response_typeddict = types.FunctionResponseDict(
|
629
|
+
name=function_response_dict.get('name'),
|
630
|
+
response=function_response_dict.get('response'),
|
631
|
+
)
|
632
|
+
if function_response_dict.get('id'):
|
633
|
+
function_response_typeddict['id'] = function_response_dict.get('id')
|
634
|
+
client_message = types.LiveClientMessageDict(
|
635
|
+
tool_response=types.LiveClientToolResponseDict(
|
636
|
+
function_responses=[function_response_typeddict]
|
637
|
+
)
|
638
|
+
)
|
639
|
+
elif isinstance(formatted_input, Sequence) and isinstance(
|
640
|
+
formatted_input[0], types.FunctionResponse
|
483
641
|
):
|
484
|
-
if not (self._api_client.vertexai) and not (
|
642
|
+
if not (self._api_client.vertexai) and not (formatted_input[0].id):
|
485
643
|
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
644
|
+
function_response_list: list[types.FunctionResponseDict] = []
|
645
|
+
for item in formatted_input:
|
646
|
+
function_response_dict = item.model_dump(exclude_none=True, mode='json')
|
647
|
+
function_response_typeddict = types.FunctionResponseDict(
|
648
|
+
name=function_response_dict.get('name'),
|
649
|
+
response=function_response_dict.get('response'),
|
650
|
+
)
|
651
|
+
if function_response_dict.get('id'):
|
652
|
+
function_response_typeddict['id'] = function_response_dict.get('id')
|
653
|
+
function_response_list.append(function_response_typeddict)
|
654
|
+
client_message = types.LiveClientMessageDict(
|
655
|
+
tool_response=types.LiveClientToolResponseDict(
|
656
|
+
function_responses=function_response_list
|
657
|
+
)
|
658
|
+
)
|
659
|
+
|
493
660
|
else:
|
494
661
|
raise ValueError(
|
495
662
|
f'Unsupported input type "{type(input)}" or input content "{input}"'
|
@@ -658,7 +825,7 @@ class AsyncLive(_api_module.BaseModule):
|
|
658
825
|
return return_value
|
659
826
|
|
660
827
|
@experimental_warning(
|
661
|
-
|
828
|
+
'The live API is experimental and may change in future versions.',
|
662
829
|
)
|
663
830
|
@contextlib.asynccontextmanager
|
664
831
|
async def connect(
|
@@ -666,7 +833,7 @@ class AsyncLive(_api_module.BaseModule):
|
|
666
833
|
*,
|
667
834
|
model: str,
|
668
835
|
config: Optional[types.LiveConnectConfigOrDict] = None,
|
669
|
-
) -> AsyncSession:
|
836
|
+
) -> AsyncIterator[AsyncSession]:
|
670
837
|
"""Connect to the live server.
|
671
838
|
|
672
839
|
The live module is experimental.
|
@@ -685,9 +852,24 @@ class AsyncLive(_api_module.BaseModule):
|
|
685
852
|
base_url = self._api_client._websocket_base_url()
|
686
853
|
transformed_model = t.t_model(self._api_client, model)
|
687
854
|
# Ensure the config is a LiveConnectConfig.
|
688
|
-
|
689
|
-
|
690
|
-
|
855
|
+
if config is None:
|
856
|
+
parameter_model = types.LiveConnectConfig()
|
857
|
+
elif isinstance(config, dict):
|
858
|
+
if config.get('system_instruction') is None:
|
859
|
+
system_instruction = None
|
860
|
+
else:
|
861
|
+
system_instruction = t.t_content(
|
862
|
+
self._api_client, config.get('system_instruction')
|
863
|
+
)
|
864
|
+
parameter_model = types.LiveConnectConfig(
|
865
|
+
generation_config=config.get('generation_config'),
|
866
|
+
response_modalities=config.get('response_modalities'),
|
867
|
+
speech_config=config.get('speech_config'),
|
868
|
+
system_instruction=system_instruction,
|
869
|
+
tools=config.get('tools'),
|
870
|
+
)
|
871
|
+
else:
|
872
|
+
parameter_model = config
|
691
873
|
|
692
874
|
if self._api_client.api_key:
|
693
875
|
api_key = self._api_client.api_key
|
@@ -713,9 +895,10 @@ class AsyncLive(_api_module.BaseModule):
|
|
713
895
|
creds.refresh(auth_req)
|
714
896
|
bearer_token = creds.token
|
715
897
|
headers = self._api_client._http_options['headers']
|
716
|
-
headers
|
717
|
-
|
718
|
-
|
898
|
+
if headers is not None:
|
899
|
+
headers.update({
|
900
|
+
'Authorization': 'Bearer {}'.format(bearer_token),
|
901
|
+
})
|
719
902
|
version = self._api_client._http_options['api_version']
|
720
903
|
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
|
721
904
|
location = self._api_client.location
|