google-genai 1.4.0__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/live.py CHANGED
@@ -20,9 +20,10 @@ import base64
20
20
  import contextlib
21
21
  import json
22
22
  import logging
23
- from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union
23
+ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, get_args
24
24
 
25
25
  import google.auth
26
+ import pydantic
26
27
  from websockets import ConnectionClosed
27
28
 
28
29
  from . import _api_module
@@ -49,12 +50,12 @@ from .models import _Tool_to_mldev
49
50
  from .models import _Tool_to_vertex
50
51
 
51
52
  try:
52
- from websockets.asyncio.client import ClientConnection # type: ignore
53
- from websockets.asyncio.client import connect # type: ignore
53
+ from websockets.asyncio.client import ClientConnection # type: ignore
54
+ from websockets.asyncio.client import connect # type: ignore
54
55
  except ModuleNotFoundError:
55
56
  # This try/except is for TAP, mypy complains about it which is why we have the type: ignore
56
- from websockets.client import ClientConnection # type: ignore
57
- from websockets.client import connect # type: ignore
57
+ from websockets.client import ClientConnection # type: ignore
58
+ from websockets.client import connect # type: ignore
58
59
 
59
60
  logger = logging.getLogger('google_genai.live')
60
61
 
@@ -171,7 +172,7 @@ class AsyncSession:
171
172
  stream = read_audio()
172
173
  for data in stream:
173
174
  yield data
174
- async with client.aio.live.connect(model='...') as session:
175
+ async with client.aio.live.connect(model='...', config=config) as session:
175
176
  for audio in session.start_stream(stream = audio_stream(),
176
177
  mime_type = 'audio/pcm'):
177
178
  play_audio_chunk(audio.data)
@@ -211,7 +212,7 @@ class AsyncSession:
211
212
  try:
212
213
  response = json.loads(raw_response)
213
214
  except json.decoder.JSONDecodeError:
214
- raise ValueError(f'Failed to parse response: {raw_response}')
215
+ raise ValueError(f'Failed to parse response: {raw_response!r}')
215
216
  else:
216
217
  response = {}
217
218
  if self._api_client.vertexai:
@@ -220,7 +221,7 @@ class AsyncSession:
220
221
  response_dict = self._LiveServerMessage_from_mldev(response)
221
222
 
222
223
  return types.LiveServerMessage._from_response(
223
- response=response_dict, kwargs=parameter_model
224
+ response=response_dict, kwargs=parameter_model.model_dump()
224
225
  )
225
226
 
226
227
  async def _send_loop(
@@ -230,8 +231,10 @@ class AsyncSession:
230
231
  stop_event: asyncio.Event,
231
232
  ):
232
233
  async for data in data_stream:
233
- input = {'data': data, 'mimeType': mime_type}
234
- await self.send(input=input)
234
+ model_input = types.LiveClientRealtimeInput(
235
+ media_chunks=[types.Blob(data=data, mime_type=mime_type)]
236
+ )
237
+ await self.send(input=model_input)
235
238
  # Give a chance for the receive loop to process responses.
236
239
  await asyncio.sleep(10**-12)
237
240
  # Give a chance for the receiver to process the last response.
@@ -372,124 +375,288 @@ class AsyncSession:
372
375
  ]
373
376
  ] = None,
374
377
  end_of_turn: Optional[bool] = False,
375
- ) -> Dict[str, Any]:
378
+ ) -> types.LiveClientMessageDict:
379
+
380
+ formatted_input: Any = input
376
381
 
377
382
  if not input:
378
383
  logging.info('No input provided. Assume it is the end of turn.')
379
384
  return {'client_content': {'turn_complete': True}}
380
385
  if isinstance(input, str):
381
- input = [input]
386
+ formatted_input = [input]
382
387
  elif isinstance(input, dict) and 'data' in input:
383
- if isinstance(input['data'], bytes):
384
- decoded_data = base64.b64encode(input['data']).decode('utf-8')
385
- input['data'] = decoded_data
386
- input = [input]
388
+ try:
389
+ blob_input = types.Blob(**input)
390
+ except pydantic.ValidationError:
391
+ raise ValueError(
392
+ f'Unsupported input type "{type(input)}" or input content "{input}"'
393
+ )
394
+ if (
395
+ isinstance(blob_input, types.Blob)
396
+ and isinstance(blob_input.data, bytes)
397
+ ):
398
+ formatted_input = [
399
+ blob_input.model_dump(mode='json', exclude_none=True)
400
+ ]
387
401
  elif isinstance(input, types.Blob):
388
- input = [input]
402
+ formatted_input = [input]
389
403
  elif isinstance(input, dict) and 'name' in input and 'response' in input:
390
404
  # ToolResponse.FunctionResponse
391
405
  if not (self._api_client.vertexai) and 'id' not in input:
392
406
  raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
393
- input = [input]
407
+ formatted_input = [input]
394
408
 
395
- if isinstance(input, Sequence) and any(
396
- isinstance(c, dict) and 'name' in c and 'response' in c for c in input
409
+ if isinstance(formatted_input, Sequence) and any(
410
+ isinstance(c, dict) and 'name' in c and 'response' in c
411
+ for c in formatted_input
397
412
  ):
398
413
  # ToolResponse.FunctionResponse
399
- if not (self._api_client.vertexai):
400
- for item in input:
401
- if 'id' not in item:
414
+ function_responses_input = []
415
+ for item in formatted_input:
416
+ if isinstance(item, dict):
417
+ try:
418
+ function_response_input = types.FunctionResponse(**item)
419
+ except pydantic.ValidationError:
420
+ raise ValueError(
421
+ f'Unsupported input type "{type(input)}" or input content'
422
+ f' "{input}"'
423
+ )
424
+ if (
425
+ function_response_input.id is None
426
+ and not self._api_client.vertexai
427
+ ):
402
428
  raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
403
- client_message = {'tool_response': {'function_responses': input}}
404
- elif isinstance(input, Sequence) and any(isinstance(c, str) for c in input):
429
+ else:
430
+ function_response_dict = function_response_input.model_dump(
431
+ exclude_none=True, mode='json'
432
+ )
433
+ function_response_typeddict = types.FunctionResponseDict(
434
+ name=function_response_dict.get('name'),
435
+ response=function_response_dict.get('response'),
436
+ )
437
+ if function_response_dict.get('id'):
438
+ function_response_typeddict['id'] = function_response_dict.get(
439
+ 'id'
440
+ )
441
+ function_responses_input.append(function_response_typeddict)
442
+ client_message = types.LiveClientMessageDict(
443
+ tool_response=types.LiveClientToolResponseDict(
444
+ function_responses=function_responses_input
445
+ )
446
+ )
447
+ elif isinstance(formatted_input, Sequence) and any(
448
+ isinstance(c, str) for c in formatted_input
449
+ ):
405
450
  to_object: dict[str, Any] = {}
451
+ content_input_parts: list[types.PartUnion] = []
452
+ for item in formatted_input:
453
+ if isinstance(item, get_args(types.PartUnion)):
454
+ content_input_parts.append(item)
406
455
  if self._api_client.vertexai:
407
456
  contents = [
408
457
  _Content_to_vertex(self._api_client, item, to_object)
409
- for item in t.t_contents(self._api_client, input)
458
+ for item in t.t_contents(self._api_client, content_input_parts)
410
459
  ]
411
460
  else:
412
461
  contents = [
413
462
  _Content_to_mldev(self._api_client, item, to_object)
414
- for item in t.t_contents(self._api_client, input)
463
+ for item in t.t_contents(self._api_client, content_input_parts)
415
464
  ]
416
465
 
417
- client_message = {
418
- 'client_content': {'turns': contents, 'turn_complete': end_of_turn}
419
- }
420
- elif isinstance(input, Sequence):
421
- if any((isinstance(b, dict) and 'data' in b) for b in input):
466
+ content_dict_list: list[types.ContentDict] = []
467
+ for item in contents:
468
+ try:
469
+ content_input = types.Content(**item)
470
+ except pydantic.ValidationError:
471
+ raise ValueError(
472
+ f'Unsupported input type "{type(input)}" or input content'
473
+ f' "{input}"'
474
+ )
475
+ content_dict_list.append(
476
+ types.ContentDict(
477
+ parts=content_input.model_dump(exclude_none=True, mode='json')[
478
+ 'parts'
479
+ ],
480
+ role=content_input.role,
481
+ )
482
+ )
483
+
484
+ client_message = types.LiveClientMessageDict(
485
+ client_content=types.LiveClientContentDict(
486
+ turns=content_dict_list, turn_complete=end_of_turn
487
+ )
488
+ )
489
+ elif isinstance(formatted_input, Sequence):
490
+ if any((isinstance(b, dict) and 'data' in b) for b in formatted_input):
422
491
  pass
423
- elif any(isinstance(b, types.Blob) for b in input):
424
- input = [b.model_dump(exclude_none=True, mode='json') for b in input]
492
+ elif any(isinstance(b, types.Blob) for b in formatted_input):
493
+ formatted_input = [
494
+ b.model_dump(exclude_none=True, mode='json')
495
+ for b in formatted_input
496
+ ]
425
497
  else:
426
498
  raise ValueError(
427
499
  f'Unsupported input type "{type(input)}" or input content "{input}"'
428
500
  )
429
501
 
430
- client_message = {'realtime_input': {'media_chunks': input}}
502
+ client_message = types.LiveClientMessageDict(
503
+ realtime_input=types.LiveClientRealtimeInputDict(
504
+ media_chunks=formatted_input
505
+ )
506
+ )
431
507
 
432
- elif isinstance(input, dict):
433
- if 'content' in input or 'turns' in input:
508
+ elif isinstance(formatted_input, dict):
509
+ if 'content' in formatted_input or 'turns' in formatted_input:
434
510
  # TODO(b/365983264) Add validation checks for content_update input_dict.
435
- client_message = {'client_content': input}
436
- elif 'media_chunks' in input:
437
- client_message = {'realtime_input': input}
438
- elif 'function_responses' in input:
439
- client_message = {'tool_response': input}
511
+ if 'turns' in formatted_input:
512
+ content_turns = formatted_input['turns']
513
+ else:
514
+ content_turns = formatted_input['content']
515
+ client_message = types.LiveClientMessageDict(
516
+ client_content=types.LiveClientContentDict(
517
+ turns=content_turns,
518
+ turn_complete=formatted_input.get('turn_complete'),
519
+ )
520
+ )
521
+ elif 'media_chunks' in formatted_input:
522
+ try:
523
+ realtime_input = types.LiveClientRealtimeInput(**formatted_input)
524
+ except pydantic.ValidationError:
525
+ raise ValueError(
526
+ f'Unsupported input type "{type(input)}" or input content'
527
+ f' "{input}"'
528
+ )
529
+ client_message = types.LiveClientMessageDict(
530
+ realtime_input=types.LiveClientRealtimeInputDict(
531
+ media_chunks=realtime_input.model_dump(
532
+ exclude_none=True, mode='json'
533
+ )['media_chunks']
534
+ )
535
+ )
536
+ elif 'function_responses' in formatted_input:
537
+ try:
538
+ tool_response_input = types.LiveClientToolResponse(**formatted_input)
539
+ except pydantic.ValidationError:
540
+ raise ValueError(
541
+ f'Unsupported input type "{type(input)}" or input content'
542
+ f' "{input}"'
543
+ )
544
+ client_message = types.LiveClientMessageDict(
545
+ tool_response=types.LiveClientToolResponseDict(
546
+ function_responses=tool_response_input.model_dump(
547
+ exclude_none=True, mode='json'
548
+ )['function_responses']
549
+ )
550
+ )
440
551
  else:
441
552
  raise ValueError(
442
- f'Unsupported input type "{type(input)}" or input content "{input}"')
443
- elif isinstance(input, types.LiveClientRealtimeInput):
444
- client_message = {
445
- 'realtime_input': input.model_dump(exclude_none=True, mode='json')
446
- }
447
- if isinstance(
448
- client_message['realtime_input']['media_chunks'][0]['data'], bytes
553
+ f'Unsupported input type "{type(input)}" or input content "{input}"'
554
+ )
555
+ elif isinstance(formatted_input, types.LiveClientRealtimeInput):
556
+ realtime_input_dict = formatted_input.model_dump(
557
+ exclude_none=True, mode='json'
558
+ )
559
+ client_message = types.LiveClientMessageDict(
560
+ realtime_input=types.LiveClientRealtimeInputDict(
561
+ media_chunks=realtime_input_dict.get('media_chunks')
562
+ )
563
+ )
564
+ if (
565
+ client_message['realtime_input'] is not None
566
+ and client_message['realtime_input']['media_chunks'] is not None
567
+ and isinstance(
568
+ client_message['realtime_input']['media_chunks'][0]['data'], bytes
569
+ )
449
570
  ):
450
- client_message['realtime_input']['media_chunks'] = [
451
- {
452
- 'data': base64.b64encode(item['data']).decode('utf-8'),
453
- 'mime_type': item['mime_type'],
454
- }
455
- for item in client_message['realtime_input']['media_chunks']
456
- ]
571
+ formatted_media_chunks: list[types.BlobDict] = []
572
+ for item in client_message['realtime_input']['media_chunks']:
573
+ if isinstance(item, dict):
574
+ try:
575
+ blob_input = types.Blob(**item)
576
+ except pydantic.ValidationError:
577
+ raise ValueError(
578
+ f'Unsupported input type "{type(input)}" or input content'
579
+ f' "{input}"'
580
+ )
581
+ if (
582
+ isinstance(blob_input, types.Blob)
583
+ and isinstance(blob_input.data, bytes)
584
+ and blob_input.data is not None
585
+ ):
586
+ formatted_media_chunks.append(
587
+ types.BlobDict(
588
+ data=base64.b64decode(blob_input.data),
589
+ mime_type=blob_input.mime_type,
590
+ )
591
+ )
592
+
593
+ client_message['realtime_input'][
594
+ 'media_chunks'
595
+ ] = formatted_media_chunks
457
596
 
458
- elif isinstance(input, types.LiveClientContent):
459
- client_message = {
460
- 'client_content': input.model_dump(exclude_none=True, mode='json')
461
- }
462
- elif isinstance(input, types.LiveClientToolResponse):
597
+ elif isinstance(formatted_input, types.LiveClientContent):
598
+ client_content_dict = formatted_input.model_dump(
599
+ exclude_none=True, mode='json'
600
+ )
601
+ client_message = types.LiveClientMessageDict(
602
+ client_content=types.LiveClientContentDict(
603
+ turns=client_content_dict.get('turns'),
604
+ turn_complete=client_content_dict.get('turn_complete'),
605
+ )
606
+ )
607
+ elif isinstance(formatted_input, types.LiveClientToolResponse):
463
608
  # ToolResponse.FunctionResponse
464
- if not (self._api_client.vertexai) and not (
465
- input.function_responses[0].id
609
+ if (
610
+ not (self._api_client.vertexai)
611
+ and formatted_input.function_responses is not None
612
+ and not (formatted_input.function_responses[0].id)
466
613
  ):
467
614
  raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
468
- client_message = {
469
- 'tool_response': input.model_dump(exclude_none=True, mode='json')
470
- }
471
- elif isinstance(input, types.FunctionResponse):
472
- if not (self._api_client.vertexai) and not (input.id):
615
+ client_message = types.LiveClientMessageDict(
616
+ tool_response=types.LiveClientToolResponseDict(
617
+ function_responses=formatted_input.model_dump(
618
+ exclude_none=True, mode='json'
619
+ ).get('function_responses')
620
+ )
621
+ )
622
+ elif isinstance(formatted_input, types.FunctionResponse):
623
+ if not (self._api_client.vertexai) and not (formatted_input.id):
473
624
  raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
474
- client_message = {
475
- 'tool_response': {
476
- 'function_responses': [
477
- input.model_dump(exclude_none=True, mode='json')
478
- ]
479
- }
480
- }
481
- elif isinstance(input, Sequence) and isinstance(
482
- input[0], types.FunctionResponse
625
+ function_response_dict = formatted_input.model_dump(
626
+ exclude_none=True, mode='json'
627
+ )
628
+ function_response_typeddict = types.FunctionResponseDict(
629
+ name=function_response_dict.get('name'),
630
+ response=function_response_dict.get('response'),
631
+ )
632
+ if function_response_dict.get('id'):
633
+ function_response_typeddict['id'] = function_response_dict.get('id')
634
+ client_message = types.LiveClientMessageDict(
635
+ tool_response=types.LiveClientToolResponseDict(
636
+ function_responses=[function_response_typeddict]
637
+ )
638
+ )
639
+ elif isinstance(formatted_input, Sequence) and isinstance(
640
+ formatted_input[0], types.FunctionResponse
483
641
  ):
484
- if not (self._api_client.vertexai) and not (input[0].id):
642
+ if not (self._api_client.vertexai) and not (formatted_input[0].id):
485
643
  raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
486
- client_message = {
487
- 'tool_response': {
488
- 'function_responses': [
489
- c.model_dump(exclude_none=True, mode='json') for c in input
490
- ]
491
- }
492
- }
644
+ function_response_list: list[types.FunctionResponseDict] = []
645
+ for item in formatted_input:
646
+ function_response_dict = item.model_dump(exclude_none=True, mode='json')
647
+ function_response_typeddict = types.FunctionResponseDict(
648
+ name=function_response_dict.get('name'),
649
+ response=function_response_dict.get('response'),
650
+ )
651
+ if function_response_dict.get('id'):
652
+ function_response_typeddict['id'] = function_response_dict.get('id')
653
+ function_response_list.append(function_response_typeddict)
654
+ client_message = types.LiveClientMessageDict(
655
+ tool_response=types.LiveClientToolResponseDict(
656
+ function_responses=function_response_list
657
+ )
658
+ )
659
+
493
660
  else:
494
661
  raise ValueError(
495
662
  f'Unsupported input type "{type(input)}" or input content "{input}"'
@@ -658,7 +825,7 @@ class AsyncLive(_api_module.BaseModule):
658
825
  return return_value
659
826
 
660
827
  @experimental_warning(
661
- "The live API is experimental and may change in future versions.",
828
+ 'The live API is experimental and may change in future versions.',
662
829
  )
663
830
  @contextlib.asynccontextmanager
664
831
  async def connect(
@@ -666,7 +833,7 @@ class AsyncLive(_api_module.BaseModule):
666
833
  *,
667
834
  model: str,
668
835
  config: Optional[types.LiveConnectConfigOrDict] = None,
669
- ) -> AsyncSession:
836
+ ) -> AsyncIterator[AsyncSession]:
670
837
  """Connect to the live server.
671
838
 
672
839
  The live module is experimental.
@@ -685,9 +852,24 @@ class AsyncLive(_api_module.BaseModule):
685
852
  base_url = self._api_client._websocket_base_url()
686
853
  transformed_model = t.t_model(self._api_client, model)
687
854
  # Ensure the config is a LiveConnectConfig.
688
- parameter_model = types.LiveConnectConfig(**config) if isinstance(
689
- config, dict
690
- ) else config
855
+ if config is None:
856
+ parameter_model = types.LiveConnectConfig()
857
+ elif isinstance(config, dict):
858
+ if config.get('system_instruction') is None:
859
+ system_instruction = None
860
+ else:
861
+ system_instruction = t.t_content(
862
+ self._api_client, config.get('system_instruction')
863
+ )
864
+ parameter_model = types.LiveConnectConfig(
865
+ generation_config=config.get('generation_config'),
866
+ response_modalities=config.get('response_modalities'),
867
+ speech_config=config.get('speech_config'),
868
+ system_instruction=system_instruction,
869
+ tools=config.get('tools'),
870
+ )
871
+ else:
872
+ parameter_model = config
691
873
 
692
874
  if self._api_client.api_key:
693
875
  api_key = self._api_client.api_key
@@ -713,9 +895,10 @@ class AsyncLive(_api_module.BaseModule):
713
895
  creds.refresh(auth_req)
714
896
  bearer_token = creds.token
715
897
  headers = self._api_client._http_options['headers']
716
- headers.update({
717
- 'Authorization': 'Bearer {}'.format(bearer_token),
718
- })
898
+ if headers is not None:
899
+ headers.update({
900
+ 'Authorization': 'Bearer {}'.format(bearer_token),
901
+ })
719
902
  version = self._api_client._http_options['api_version']
720
903
  uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
721
904
  location = self._api_client.location