google-genai 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/live.py CHANGED
@@ -33,21 +33,12 @@ from . import _transformers as t
33
33
  from . import client
34
34
  from . import types
35
35
  from ._api_client import BaseApiClient
36
- from ._common import experimental_warning
37
36
  from ._common import get_value_by_path as getv
38
37
  from ._common import set_value_by_path as setv
39
- from .models import _Content_from_mldev
40
- from .models import _Content_from_vertex
38
+ from . import live_converters
41
39
  from .models import _Content_to_mldev
42
40
  from .models import _Content_to_vertex
43
- from .models import _GenerateContentConfig_to_mldev
44
- from .models import _GenerateContentConfig_to_vertex
45
- from .models import _SafetySetting_to_mldev
46
- from .models import _SafetySetting_to_vertex
47
- from .models import _SpeechConfig_to_mldev
48
- from .models import _SpeechConfig_to_vertex
49
- from .models import _Tool_to_mldev
50
- from .models import _Tool_to_vertex
41
+
51
42
 
52
43
  try:
53
44
  from websockets.asyncio.client import ClientConnection # type: ignore
@@ -65,66 +56,6 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
65
56
  )
66
57
 
67
58
 
68
- def _ClientContent_to_mldev(
69
- api_client: BaseApiClient,
70
- from_object: types.LiveClientContent,
71
- ) -> dict:
72
- client_content = from_object.model_dump(exclude_none=True, mode='json')
73
- if 'turns' in client_content:
74
- client_content['turns'] = [
75
- _Content_to_mldev(api_client=api_client, from_object=item)
76
- for item in client_content['turns']
77
- ]
78
- return client_content
79
-
80
-
81
- def _ClientContent_to_vertex(
82
- api_client: BaseApiClient,
83
- from_object: types.LiveClientContent,
84
- ) -> dict:
85
- client_content = from_object.model_dump(exclude_none=True, mode='json')
86
- if 'turns' in client_content:
87
- client_content['turns'] = [
88
- _Content_to_vertex(api_client=api_client, from_object=item)
89
- for item in client_content['turns']
90
- ]
91
- return client_content
92
-
93
-
94
- def _ToolResponse_to_mldev(
95
- api_client: BaseApiClient,
96
- from_object: types.LiveClientToolResponse,
97
- ) -> dict:
98
- tool_response = from_object.model_dump(exclude_none=True, mode='json')
99
- for response in tool_response.get('function_responses', []):
100
- if response.get('id') is None:
101
- raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
102
- return tool_response
103
-
104
-
105
- def _ToolResponse_to_vertex(
106
- api_client: BaseApiClient,
107
- from_object: types.LiveClientToolResponse,
108
- ) -> dict:
109
- tool_response = from_object.model_dump(exclude_none=True, mode='json')
110
- return tool_response
111
-
112
- def _AudioTranscriptionConfig_to_mldev(
113
- api_client: BaseApiClient,
114
- from_object: types.AudioTranscriptionConfig,
115
- ) -> dict:
116
- audio_transcription: dict[str, Any] = {}
117
- return audio_transcription
118
-
119
-
120
- def _AudioTranscriptionConfig_to_vertex(
121
- api_client: BaseApiClient,
122
- from_object: types.AudioTranscriptionConfig,
123
- ) -> dict:
124
- audio_transcription: dict[str, Any] = {}
125
- return audio_transcription
126
-
127
-
128
59
  class AsyncSession:
129
60
  """[Preview] AsyncSession."""
130
61
 
@@ -238,8 +169,14 @@ class AsyncSession:
238
169
  ```
239
170
  import google.genai
240
171
  from google.genai import types
172
+ import os
173
+
174
+ if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
175
+ MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
176
+ else:
177
+ MODEL_NAME = 'gemini-2.0-flash-live-001';
241
178
 
242
- client = genai.Client(http_options={'api_version': 'v1alpha'})
179
+ client = genai.Client()
243
180
  async with client.aio.live.connect(
244
181
  model=MODEL_NAME,
245
182
  config={"response_modalities": ["TEXT"]}
@@ -253,14 +190,14 @@ class AsyncSession:
253
190
  print(msg.text)
254
191
  ```
255
192
  """
256
- client_content = _t_client_content(turns, turn_complete)
193
+ client_content = t.t_client_content(turns, turn_complete)
257
194
 
258
195
  if self._api_client.vertexai:
259
- client_content_dict = _ClientContent_to_vertex(
196
+ client_content_dict = live_converters._LiveClientContent_to_vertex(
260
197
  api_client=self._api_client, from_object=client_content
261
198
  )
262
199
  else:
263
- client_content_dict = _ClientContent_to_mldev(
200
+ client_content_dict = live_converters._LiveClientContent_to_mldev(
264
201
  api_client=self._api_client, from_object=client_content
265
202
  )
266
203
 
@@ -290,8 +227,16 @@ class AsyncSession:
290
227
  from google.genai import types
291
228
 
292
229
  import PIL.Image
230
+
231
+ import os
293
232
 
294
- client = genai.Client(http_options= {'api_version': 'v1alpha'})
233
+ if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
234
+ MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
235
+ else:
236
+ MODEL_NAME = 'gemini-2.0-flash-live-001';
237
+
238
+
239
+ client = genai.Client()
295
240
 
296
241
  async with client.aio.live.connect(
297
242
  model=MODEL_NAME,
@@ -309,7 +254,7 @@ class AsyncSession:
309
254
  print(f'{msg.text}')
310
255
  ```
311
256
  """
312
- realtime_input = _t_realtime_input(media)
257
+ realtime_input = t.t_realtime_input(media)
313
258
  realtime_input_dict = realtime_input.model_dump(
314
259
  exclude_none=True, mode='json'
315
260
  )
@@ -340,7 +285,14 @@ class AsyncSession:
340
285
  from google import genai
341
286
  from google.genai import types
342
287
 
343
- client = genai.Client(http_options={'api_version': 'v1alpha'})
288
+ import os
289
+
290
+ if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
291
+ MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
292
+ else:
293
+ MODEL_NAME = 'gemini-2.0-flash-live-001';
294
+
295
+ client = genai.Client()
344
296
 
345
297
  tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
346
298
  config = {
@@ -349,13 +301,13 @@ class AsyncSession:
349
301
  }
350
302
 
351
303
  async with client.aio.live.connect(
352
- model='gemini-2.0-flash-exp',
304
+ model='models/gemini-2.0-flash-live-001',
353
305
  config=config
354
306
  ) as session:
355
307
  prompt = "Turn on the lights please"
356
308
  await session.send_client_content(
357
- turns=prompt,
358
- turn_complete=True)
309
+ turns={"parts": [{'text': prompt}]}
310
+ )
359
311
 
360
312
  async for chunk in session.receive():
361
313
  if chunk.server_content:
@@ -376,13 +328,13 @@ class AsyncSession:
376
328
 
377
329
  print('_'*80)
378
330
  """
379
- tool_response = _t_tool_response(function_responses)
331
+ tool_response = t.t_tool_response(function_responses)
380
332
  if self._api_client.vertexai:
381
- tool_response_dict = _ToolResponse_to_vertex(
333
+ tool_response_dict = live_converters._LiveClientToolResponse_to_vertex(
382
334
  api_client=self._api_client, from_object=tool_response
383
335
  )
384
336
  else:
385
- tool_response_dict = _ToolResponse_to_mldev(
337
+ tool_response_dict = live_converters._LiveClientToolResponse_to_mldev(
386
338
  api_client=self._api_client, from_object=tool_response
387
339
  )
388
340
  await self._ws.send(json.dumps({'tool_response': tool_response_dict}))
@@ -503,9 +455,9 @@ class AsyncSession:
503
455
  response = {}
504
456
 
505
457
  if self._api_client.vertexai:
506
- response_dict = self._LiveServerMessage_from_vertex(response)
458
+ response_dict = live_converters._LiveServerMessage_from_vertex(self._api_client, response)
507
459
  else:
508
- response_dict = self._LiveServerMessage_from_mldev(response)
460
+ response_dict = live_converters._LiveServerMessage_from_mldev(self._api_client, response)
509
461
 
510
462
  return types.LiveServerMessage._from_response(
511
463
  response=response_dict, kwargs=parameter_model.model_dump()
@@ -527,464 +479,6 @@ class AsyncSession:
527
479
  # Give a chance for the receiver to process the last response.
528
480
  stop_event.set()
529
481
 
530
- def _LiveServerContent_from_mldev(
531
- self,
532
- from_object: Union[dict, object],
533
- ) -> Dict[str, Any]:
534
- to_object: dict[str, Any] = {}
535
- if getv(from_object, ['modelTurn']) is not None:
536
- setv(
537
- to_object,
538
- ['model_turn'],
539
- _Content_from_mldev(
540
- self._api_client,
541
- getv(from_object, ['modelTurn']),
542
- ),
543
- )
544
- if getv(from_object, ['turnComplete']) is not None:
545
- setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
546
- if getv(from_object, ['generationComplete']) is not None:
547
- setv(
548
- to_object,
549
- ['generation_complete'],
550
- getv(from_object, ['generationComplete']),
551
- )
552
- if getv(from_object, ['inputTranscription']) is not None:
553
- setv(
554
- to_object,
555
- ['input_transcription'],
556
- getv(from_object, ['inputTranscription']),
557
- )
558
- if getv(from_object, ['outputTranscription']) is not None:
559
- setv(
560
- to_object,
561
- ['output_transcription'],
562
- getv(from_object, ['outputTranscription']),
563
- )
564
- if getv(from_object, ['interrupted']) is not None:
565
- setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
566
- return to_object
567
-
568
- def _LiveToolCall_from_mldev(
569
- self,
570
- from_object: Union[dict, object],
571
- ) -> Dict[str, Any]:
572
- to_object: dict[str, Any] = {}
573
- if getv(from_object, ['functionCalls']) is not None:
574
- setv(
575
- to_object,
576
- ['function_calls'],
577
- getv(from_object, ['functionCalls']),
578
- )
579
- return to_object
580
-
581
- def _LiveToolCall_from_vertex(
582
- self,
583
- from_object: Union[dict, object],
584
- ) -> Dict[str, Any]:
585
- to_object: dict[str, Any] = {}
586
- if getv(from_object, ['functionCalls']) is not None:
587
- setv(
588
- to_object,
589
- ['function_calls'],
590
- getv(from_object, ['functionCalls']),
591
- )
592
- return to_object
593
-
594
- def _LiveServerGoAway_from_mldev(
595
- self,
596
- from_object: Union[dict, object],
597
- parent_object: Optional[dict] = None,
598
- ) -> dict:
599
- to_object: dict[str, Any] = {}
600
- if getv(from_object, ['timeLeft']) is not None:
601
- setv(to_object, ['time_left'], getv(from_object, ['timeLeft']))
602
-
603
- return to_object
604
-
605
- def _LiveServerSessionResumptionUpdate_from_mldev(
606
- self,
607
- from_object: Union[dict, object],
608
- parent_object: Optional[dict] = None,
609
- ) -> dict:
610
- to_object: dict[str, Any] = {}
611
- if getv(from_object, ['newHandle']) is not None:
612
- setv(to_object, ['new_handle'], getv(from_object, ['newHandle']))
613
-
614
- if getv(from_object, ['resumable']) is not None:
615
- setv(to_object, ['resumable'], getv(from_object, ['resumable']))
616
-
617
- if getv(from_object, ['lastConsumedClientMessageIndex']) is not None:
618
- setv(
619
- to_object,
620
- ['last_consumed_client_message_index'],
621
- getv(from_object, ['lastConsumedClientMessageIndex']),
622
- )
623
-
624
- return to_object
625
-
626
- def _ModalityTokenCount_from_mldev(
627
- self,
628
- from_object: Union[dict, object],
629
- ) -> Dict[str, Any]:
630
- to_object: Dict[str, Any] = {}
631
- if getv(from_object, ['modality']) is not None:
632
- setv(to_object, ['modality'], getv(from_object, ['modality']))
633
- if getv(from_object, ['tokenCount']) is not None:
634
- setv(to_object, ['token_count'], getv(from_object, ['tokenCount']))
635
- return to_object
636
-
637
- def _UsageMetadata_from_mldev(
638
- self,
639
- from_object: Union[dict, object],
640
- ) -> Dict[str, Any]:
641
- to_object: dict[str, Any] = {}
642
- if getv(from_object, ['promptTokenCount']) is not None:
643
- setv(
644
- to_object,
645
- ['prompt_token_count'],
646
- getv(from_object, ['promptTokenCount']),
647
- )
648
- if getv(from_object, ['cachedContentTokenCount']) is not None:
649
- setv(
650
- to_object,
651
- ['cached_content_token_count'],
652
- getv(from_object, ['cachedContentTokenCount']),
653
- )
654
- if getv(from_object, ['responseTokenCount']) is not None:
655
- setv(
656
- to_object,
657
- ['response_token_count'],
658
- getv(from_object, ['responseTokenCount']),
659
- )
660
- if getv(from_object, ['toolUsePromptTokenCount']) is not None:
661
- setv(
662
- to_object,
663
- ['tool_use_prompt_token_count'],
664
- getv(from_object, ['toolUsePromptTokenCount']),
665
- )
666
- if getv(from_object, ['thoughtsTokenCount']) is not None:
667
- setv(
668
- to_object,
669
- ['thoughts_token_count'],
670
- getv(from_object, ['thoughtsTokenCount']),
671
- )
672
- if getv(from_object, ['totalTokenCount']) is not None:
673
- setv(
674
- to_object,
675
- ['total_token_count'],
676
- getv(from_object, ['totalTokenCount']),
677
- )
678
- if getv(from_object, ['promptTokensDetails']) is not None:
679
- setv(
680
- to_object,
681
- ['prompt_tokens_details'],
682
- [
683
- self._ModalityTokenCount_from_mldev(item)
684
- for item in getv(from_object, ['promptTokensDetails'])
685
- ],
686
- )
687
- if getv(from_object, ['cacheTokensDetails']) is not None:
688
- setv(
689
- to_object,
690
- ['cache_tokens_details'],
691
- [
692
- self._ModalityTokenCount_from_mldev(item)
693
- for item in getv(from_object, ['cacheTokensDetails'])
694
- ],
695
- )
696
- if getv(from_object, ['responseTokensDetails']) is not None:
697
- setv(
698
- to_object,
699
- ['response_tokens_details'],
700
- [
701
- self._ModalityTokenCount_from_mldev(item)
702
- for item in getv(from_object, ['responseTokensDetails'])
703
- ],
704
- )
705
- if getv(from_object, ['toolUsePromptTokensDetails']) is not None:
706
- setv(
707
- to_object,
708
- ['tool_use_prompt_tokens_details'],
709
- [
710
- self._ModalityTokenCount_from_mldev(item)
711
- for item in getv(from_object, ['toolUsePromptTokensDetails'])
712
- ],
713
- )
714
- return to_object
715
-
716
- def _LiveServerMessage_from_mldev(
717
- self,
718
- from_object: Union[dict, object],
719
- ) -> Dict[str, Any]:
720
- to_object: dict[str, Any] = {}
721
- if getv(from_object, ['serverContent']) is not None:
722
- setv(
723
- to_object,
724
- ['server_content'],
725
- self._LiveServerContent_from_mldev(
726
- getv(from_object, ['serverContent'])
727
- ),
728
- )
729
- if getv(from_object, ['toolCall']) is not None:
730
- setv(
731
- to_object,
732
- ['tool_call'],
733
- self._LiveToolCall_from_mldev(getv(from_object, ['toolCall'])),
734
- )
735
- if getv(from_object, ['toolCallCancellation']) is not None:
736
- setv(
737
- to_object,
738
- ['tool_call_cancellation'],
739
- getv(from_object, ['toolCallCancellation']),
740
- )
741
-
742
- if getv(from_object, ['goAway']) is not None:
743
- setv(
744
- to_object,
745
- ['go_away'],
746
- self._LiveServerGoAway_from_mldev(
747
- getv(from_object, ['goAway']), to_object
748
- ),
749
- )
750
-
751
- if getv(from_object, ['sessionResumptionUpdate']) is not None:
752
- setv(
753
- to_object,
754
- ['session_resumption_update'],
755
- self._LiveServerSessionResumptionUpdate_from_mldev(
756
- getv(from_object, ['sessionResumptionUpdate']),
757
- to_object,
758
- ),
759
- )
760
-
761
- return to_object
762
-
763
- if getv(from_object, ['usageMetadata']) is not None:
764
- setv(
765
- to_object,
766
- ['usage_metadata'],
767
- self._UsageMetadata_from_mldev(getv(from_object, ['usageMetadata'])),
768
- )
769
- return to_object
770
-
771
- def _LiveServerContent_from_vertex(
772
- self,
773
- from_object: Union[dict, object],
774
- ) -> Dict[str, Any]:
775
- to_object: dict[str, Any] = {}
776
- if getv(from_object, ['modelTurn']) is not None:
777
- setv(
778
- to_object,
779
- ['model_turn'],
780
- _Content_from_vertex(
781
- self._api_client,
782
- getv(from_object, ['modelTurn']),
783
- ),
784
- )
785
- if getv(from_object, ['turnComplete']) is not None:
786
- setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
787
- if getv(from_object, ['generationComplete']) is not None:
788
- setv(
789
- to_object,
790
- ['generation_complete'],
791
- getv(from_object, ['generationComplete']),
792
- )
793
- if getv(from_object, ['inputTranscription']) is not None:
794
- setv(
795
- to_object,
796
- ['input_transcription'],
797
- getv(from_object, ['inputTranscription']),
798
- )
799
- if getv(from_object, ['outputTranscription']) is not None:
800
- setv(
801
- to_object,
802
- ['output_transcription'],
803
- getv(from_object, ['outputTranscription']),
804
- )
805
- if getv(from_object, ['interrupted']) is not None:
806
- setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
807
- return to_object
808
-
809
- def _LiveServerGoAway_from_vertex(
810
- self,
811
- from_object: Union[dict, object],
812
- ) -> dict:
813
- to_object: dict[str, Any] = {}
814
- if getv(from_object, ['timeLeft']) is not None:
815
- setv(to_object, ['time_left'], getv(from_object, ['timeLeft']))
816
-
817
- return to_object
818
-
819
- def _LiveServerSessionResumptionUpdate_from_vertex(
820
- self,
821
- from_object: Union[dict, object],
822
- ) -> dict:
823
- to_object: dict[str, Any] = {}
824
- if getv(from_object, ['newHandle']) is not None:
825
- setv(to_object, ['new_handle'], getv(from_object, ['newHandle']))
826
-
827
- if getv(from_object, ['resumable']) is not None:
828
- setv(to_object, ['resumable'], getv(from_object, ['resumable']))
829
-
830
- if getv(from_object, ['lastConsumedClientMessageIndex']) is not None:
831
- setv(
832
- to_object,
833
- ['last_consumed_client_message_index'],
834
- getv(from_object, ['lastConsumedClientMessageIndex']),
835
- )
836
-
837
- return to_object
838
-
839
-
840
- def _ModalityTokenCount_from_vertex(
841
- self,
842
- from_object: Union[dict, object],
843
- ) -> Dict[str, Any]:
844
- to_object: Dict[str, Any] = {}
845
- if getv(from_object, ['modality']) is not None:
846
- setv(to_object, ['modality'], getv(from_object, ['modality']))
847
- if getv(from_object, ['tokenCount']) is not None:
848
- setv(to_object, ['token_count'], getv(from_object, ['tokenCount']))
849
- return to_object
850
-
851
- def _UsageMetadata_from_vertex(
852
- self,
853
- from_object: Union[dict, object],
854
- ) -> Dict[str, Any]:
855
- to_object: dict[str, Any] = {}
856
- if getv(from_object, ['promptTokenCount']) is not None:
857
- setv(
858
- to_object,
859
- ['prompt_token_count'],
860
- getv(from_object, ['promptTokenCount']),
861
- )
862
- if getv(from_object, ['cachedContentTokenCount']) is not None:
863
- setv(
864
- to_object,
865
- ['cached_content_token_count'],
866
- getv(from_object, ['cachedContentTokenCount']),
867
- )
868
- if getv(from_object, ['candidatesTokenCount']) is not None:
869
- setv(
870
- to_object,
871
- ['response_token_count'],
872
- getv(from_object, ['candidatesTokenCount']),
873
- )
874
- if getv(from_object, ['toolUsePromptTokenCount']) is not None:
875
- setv(
876
- to_object,
877
- ['tool_use_prompt_token_count'],
878
- getv(from_object, ['toolUsePromptTokenCount']),
879
- )
880
- if getv(from_object, ['thoughtsTokenCount']) is not None:
881
- setv(
882
- to_object,
883
- ['thoughts_token_count'],
884
- getv(from_object, ['thoughtsTokenCount']),
885
- )
886
- if getv(from_object, ['totalTokenCount']) is not None:
887
- setv(
888
- to_object,
889
- ['total_token_count'],
890
- getv(from_object, ['totalTokenCount']),
891
- )
892
- if getv(from_object, ['promptTokensDetails']) is not None:
893
- setv(
894
- to_object,
895
- ['prompt_tokens_details'],
896
- [
897
- self._ModalityTokenCount_from_vertex(item)
898
- for item in getv(from_object, ['promptTokensDetails'])
899
- ],
900
- )
901
- if getv(from_object, ['cacheTokensDetails']) is not None:
902
- setv(
903
- to_object,
904
- ['cache_tokens_details'],
905
- [
906
- self._ModalityTokenCount_from_vertex(item)
907
- for item in getv(from_object, ['cacheTokensDetails'])
908
- ],
909
- )
910
- if getv(from_object, ['toolUsePromptTokensDetails']) is not None:
911
- setv(
912
- to_object,
913
- ['tool_use_prompt_tokens_details'],
914
- [
915
- self._ModalityTokenCount_from_vertex(item)
916
- for item in getv(from_object, ['toolUsePromptTokensDetails'])
917
- ],
918
- )
919
- if getv(from_object, ['candidatesTokensDetails']) is not None:
920
- setv(
921
- to_object,
922
- ['response_tokens_details'],
923
- [
924
- self._ModalityTokenCount_from_vertex(item)
925
- for item in getv(from_object, ['candidatesTokensDetails'])
926
- ],
927
- )
928
- if getv(from_object, ['trafficType']) is not None:
929
- setv(
930
- to_object,
931
- ['traffic_type'],
932
- getv(from_object, ['trafficType']),
933
- )
934
- return to_object
935
-
936
- def _LiveServerMessage_from_vertex(
937
- self,
938
- from_object: Union[dict, object],
939
- ) -> Dict[str, Any]:
940
- to_object: dict[str, Any] = {}
941
- if getv(from_object, ['serverContent']) is not None:
942
- setv(
943
- to_object,
944
- ['server_content'],
945
- self._LiveServerContent_from_vertex(
946
- getv(from_object, ['serverContent'])
947
- ),
948
- )
949
- if getv(from_object, ['toolCall']) is not None:
950
- setv(
951
- to_object,
952
- ['tool_call'],
953
- self._LiveToolCall_from_vertex(getv(from_object, ['toolCall'])),
954
- )
955
- if getv(from_object, ['toolCallCancellation']) is not None:
956
- setv(
957
- to_object,
958
- ['tool_call_cancellation'],
959
- getv(from_object, ['toolCallCancellation']),
960
- )
961
-
962
- if getv(from_object, ['goAway']) is not None:
963
- setv(
964
- to_object,
965
- ['go_away'],
966
- self._LiveServerGoAway_from_vertex(
967
- getv(from_object, ['goAway'])
968
- ),
969
- )
970
-
971
- if getv(from_object, ['sessionResumptionUpdate']) is not None:
972
- setv(
973
- to_object,
974
- ['session_resumption_update'],
975
- self._LiveServerSessionResumptionUpdate_from_vertex(
976
- getv(from_object, ['sessionResumptionUpdate']),
977
- ),
978
- )
979
-
980
- if getv(from_object, ['usageMetadata']) is not None:
981
- setv(
982
- to_object,
983
- ['usage_metadata'],
984
- self._UsageMetadata_from_vertex(getv(from_object, ['usageMetadata'])),
985
- )
986
- return to_object
987
-
988
482
  def _parse_client_message(
989
483
  self,
990
484
  input: Optional[
@@ -1288,472 +782,15 @@ class AsyncSession:
1288
782
 
1289
783
  return client_message
1290
784
 
1291
- async def close(self):
785
+ async def close(self) -> None:
1292
786
  # Close the websocket connection.
1293
787
  await self._ws.close()
1294
788
 
1295
789
 
1296
- def _t_content_strict(content: types.ContentOrDict):
1297
- if isinstance(content, dict):
1298
- return types.Content.model_validate(content)
1299
- elif isinstance(content, types.Content):
1300
- return content
1301
- else:
1302
- raise ValueError(
1303
- f'Could not convert input (type "{type(content)}") to '
1304
- '`types.Content`'
1305
- )
1306
-
1307
-
1308
- def _t_contents_strict(
1309
- contents: Union[Sequence[types.ContentOrDict], types.ContentOrDict]):
1310
- if isinstance(contents, Sequence):
1311
- return [_t_content_strict(content) for content in contents]
1312
- else:
1313
- return [_t_content_strict(contents)]
1314
-
1315
-
1316
- def _t_client_content(
1317
- turns: Optional[
1318
- Union[Sequence[types.ContentOrDict], types.ContentOrDict]
1319
- ] = None,
1320
- turn_complete: bool = True,
1321
- ) -> types.LiveClientContent:
1322
- if turns is None:
1323
- return types.LiveClientContent(turn_complete=turn_complete)
1324
-
1325
- try:
1326
- return types.LiveClientContent(
1327
- turns=_t_contents_strict(contents=turns),
1328
- turn_complete=turn_complete,
1329
- )
1330
- except Exception as e:
1331
- raise ValueError(
1332
- f'Could not convert input (type "{type(turns)}") to '
1333
- '`types.LiveClientContent`'
1334
- ) from e
1335
-
1336
-
1337
- def _t_realtime_input(
1338
- media: t.BlobUnion,
1339
- ) -> types.LiveClientRealtimeInput:
1340
- try:
1341
- return types.LiveClientRealtimeInput(media_chunks=[t.t_blob(blob=media)])
1342
- except Exception as e:
1343
- raise ValueError(
1344
- f'Could not convert input (type "{type(input)}") to '
1345
- '`types.LiveClientRealtimeInput`'
1346
- ) from e
1347
-
1348
-
1349
- def _t_tool_response(
1350
- input: Union[
1351
- types.FunctionResponseOrDict,
1352
- Sequence[types.FunctionResponseOrDict],
1353
- ],
1354
- ) -> types.LiveClientToolResponse:
1355
- if not input:
1356
- raise ValueError(f'A tool response is required, got: \n{input}')
1357
-
1358
- try:
1359
- return types.LiveClientToolResponse(
1360
- function_responses=t.t_function_responses(function_responses=input)
1361
- )
1362
- except Exception as e:
1363
- raise ValueError(
1364
- f'Could not convert input (type "{type(input)}") to '
1365
- '`types.LiveClientToolResponse`'
1366
- ) from e
1367
-
1368
790
 
1369
791
  class AsyncLive(_api_module.BaseModule):
1370
792
  """[Preview] AsyncLive."""
1371
793
 
1372
- def _LiveSetup_to_mldev(
1373
- self, model: str, config: Optional[types.LiveConnectConfig] = None
1374
- ):
1375
-
1376
- to_object: dict[str, Any] = {}
1377
- if getv(config, ['generation_config']) is not None:
1378
- setv(
1379
- to_object,
1380
- ['generationConfig'],
1381
- _GenerateContentConfig_to_mldev(
1382
- self._api_client,
1383
- getv(config, ['generation_config']),
1384
- to_object,
1385
- ),
1386
- )
1387
- if getv(config, ['response_modalities']) is not None:
1388
- if getv(to_object, ['generationConfig']) is not None:
1389
- to_object['generationConfig']['responseModalities'] = getv(
1390
- config, ['response_modalities']
1391
- )
1392
- else:
1393
- to_object['generationConfig'] = {
1394
- 'responseModalities': getv(config, ['response_modalities'])
1395
- }
1396
- if getv(config, ['speech_config']) is not None:
1397
- if getv(to_object, ['generationConfig']) is not None:
1398
- to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_mldev(
1399
- self._api_client,
1400
- t.t_speech_config(
1401
- self._api_client, getv(config, ['speech_config'])
1402
- ),
1403
- to_object,
1404
- )
1405
- else:
1406
- to_object['generationConfig'] = {
1407
- 'speechConfig': _SpeechConfig_to_mldev(
1408
- self._api_client,
1409
- t.t_speech_config(
1410
- self._api_client, getv(config, ['speech_config'])
1411
- ),
1412
- to_object,
1413
- )
1414
- }
1415
- if getv(config, ['temperature']) is not None:
1416
- if getv(to_object, ['generationConfig']) is not None:
1417
- to_object['generationConfig']['temperature'] = getv(
1418
- config, ['temperature']
1419
- )
1420
- else:
1421
- to_object['generationConfig'] = {
1422
- 'temperature': getv(config, ['temperature'])
1423
- }
1424
- if getv(config, ['top_p']) is not None:
1425
- if getv(to_object, ['generationConfig']) is not None:
1426
- to_object['generationConfig']['topP'] = getv(config, ['top_p'])
1427
- else:
1428
- to_object['generationConfig'] = {'topP': getv(config, ['top_p'])}
1429
- if getv(config, ['top_k']) is not None:
1430
- if getv(to_object, ['generationConfig']) is not None:
1431
- to_object['generationConfig']['topK'] = getv(config, ['top_k'])
1432
- else:
1433
- to_object['generationConfig'] = {'topK': getv(config, ['top_k'])}
1434
- if getv(config, ['max_output_tokens']) is not None:
1435
- if getv(to_object, ['generationConfig']) is not None:
1436
- to_object['generationConfig']['maxOutputTokens'] = getv(
1437
- config, ['max_output_tokens']
1438
- )
1439
- else:
1440
- to_object['generationConfig'] = {
1441
- 'maxOutputTokens': getv(config, ['max_output_tokens'])
1442
- }
1443
- if getv(config, ['media_resolution']) is not None:
1444
- if getv(to_object, ['generationConfig']) is not None:
1445
- to_object['generationConfig']['mediaResolution'] = getv(
1446
- config, ['media_resolution']
1447
- )
1448
- else:
1449
- to_object['generationConfig'] = {
1450
- 'mediaResolution': getv(config, ['media_resolution'])
1451
- }
1452
- if getv(config, ['seed']) is not None:
1453
- if getv(to_object, ['generationConfig']) is not None:
1454
- to_object['generationConfig']['seed'] = getv(config, ['seed'])
1455
- else:
1456
- to_object['generationConfig'] = {'seed': getv(config, ['seed'])}
1457
- if getv(config, ['system_instruction']) is not None:
1458
- setv(
1459
- to_object,
1460
- ['systemInstruction'],
1461
- _Content_to_mldev(
1462
- self._api_client,
1463
- t.t_content(
1464
- self._api_client, getv(config, ['system_instruction'])
1465
- ),
1466
- to_object,
1467
- ),
1468
- )
1469
- if getv(config, ['tools']) is not None:
1470
- setv(
1471
- to_object,
1472
- ['tools'],
1473
- [
1474
- _Tool_to_mldev(
1475
- self._api_client, t.t_tool(self._api_client, item), to_object
1476
- )
1477
- for item in t.t_tools(self._api_client, getv(config, ['tools']))
1478
- ],
1479
- )
1480
- if getv(config, ['input_audio_transcription']) is not None:
1481
- raise ValueError('input_audio_transcription is not supported in MLDev '
1482
- 'API.')
1483
- if getv(config, ['output_audio_transcription']) is not None:
1484
- setv(
1485
- to_object,
1486
- ['outputAudioTranscription'],
1487
- _AudioTranscriptionConfig_to_mldev(
1488
- self._api_client,
1489
- getv(config, ['output_audio_transcription']),
1490
- ),
1491
- )
1492
-
1493
- if getv(config, ['session_resumption']) is not None:
1494
- setv(
1495
- to_object,
1496
- ['sessionResumption'],
1497
- self._LiveClientSessionResumptionConfig_to_mldev(
1498
- getv(config, ['session_resumption'])
1499
- ),
1500
- )
1501
-
1502
- if getv(config, ['context_window_compression']) is not None:
1503
- setv(
1504
- to_object,
1505
- ['contextWindowCompression'],
1506
- self._ContextWindowCompressionConfig_to_mldev(
1507
- getv(config, ['context_window_compression']),
1508
- ),
1509
- )
1510
-
1511
- return_value = {'setup': {'model': model}}
1512
- return_value['setup'].update(to_object)
1513
- return return_value
1514
-
1515
- def _SlidingWindow_to_mldev(
1516
- self,
1517
- from_object: Union[dict, object],
1518
- ) -> dict:
1519
- to_object: dict[str, Any] = {}
1520
- if getv(from_object, ['target_tokens']) is not None:
1521
- setv(to_object, ['targetTokens'], getv(from_object, ['target_tokens']))
1522
-
1523
- return to_object
1524
-
1525
-
1526
- def _ContextWindowCompressionConfig_to_mldev(
1527
- self,
1528
- from_object: Union[dict, object],
1529
- ) -> dict:
1530
- to_object: dict[str, Any] = {}
1531
- if getv(from_object, ['trigger_tokens']) is not None:
1532
- setv(to_object, ['triggerTokens'], getv(from_object, ['trigger_tokens']))
1533
-
1534
- if getv(from_object, ['sliding_window']) is not None:
1535
- setv(
1536
- to_object,
1537
- ['slidingWindow'],
1538
- self._SlidingWindow_to_mldev(
1539
- getv(from_object, ['sliding_window'])
1540
- ),
1541
- )
1542
-
1543
- return to_object
1544
-
1545
- def _LiveClientSessionResumptionConfig_to_mldev(
1546
- self,
1547
- from_object: Union[dict, object]
1548
- ) -> dict:
1549
- to_object: dict[str, Any] = {}
1550
- if getv(from_object, ['handle']) is not None:
1551
- setv(to_object, ['handle'], getv(from_object, ['handle']))
1552
-
1553
- if getv(from_object, ['transparent']) is not None:
1554
- raise ValueError('The `transparent` field is not supported in MLDev API')
1555
-
1556
- return to_object
1557
-
1558
- def _LiveSetup_to_vertex(
1559
- self, model: str, config: Optional[types.LiveConnectConfig] = None
1560
- ):
1561
-
1562
- to_object: dict[str, Any] = {}
1563
-
1564
- if getv(config, ['generation_config']) is not None:
1565
- setv(
1566
- to_object,
1567
- ['generationConfig'],
1568
- _GenerateContentConfig_to_vertex(
1569
- self._api_client,
1570
- getv(config, ['generation_config']),
1571
- to_object,
1572
- ),
1573
- )
1574
- if getv(config, ['response_modalities']) is not None:
1575
- if getv(to_object, ['generationConfig']) is not None:
1576
- to_object['generationConfig']['responseModalities'] = getv(
1577
- config, ['response_modalities']
1578
- )
1579
- else:
1580
- to_object['generationConfig'] = {
1581
- 'responseModalities': getv(config, ['response_modalities'])
1582
- }
1583
- else:
1584
- # Set default to AUDIO to align with MLDev API.
1585
- if getv(to_object, ['generationConfig']) is not None:
1586
- to_object['generationConfig'].update({'responseModalities': ['AUDIO']})
1587
- else:
1588
- to_object.update(
1589
- {'generationConfig': {'responseModalities': ['AUDIO']}}
1590
- )
1591
- if getv(config, ['speech_config']) is not None:
1592
- if getv(to_object, ['generationConfig']) is not None:
1593
- to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_vertex(
1594
- self._api_client,
1595
- t.t_speech_config(
1596
- self._api_client, getv(config, ['speech_config'])
1597
- ),
1598
- to_object,
1599
- )
1600
- else:
1601
- to_object['generationConfig'] = {
1602
- 'speechConfig': _SpeechConfig_to_vertex(
1603
- self._api_client,
1604
- t.t_speech_config(
1605
- self._api_client, getv(config, ['speech_config'])
1606
- ),
1607
- to_object,
1608
- )
1609
- }
1610
- if getv(config, ['temperature']) is not None:
1611
- if getv(to_object, ['generationConfig']) is not None:
1612
- to_object['generationConfig']['temperature'] = getv(
1613
- config, ['temperature']
1614
- )
1615
- else:
1616
- to_object['generationConfig'] = {
1617
- 'temperature': getv(config, ['temperature'])
1618
- }
1619
- if getv(config, ['top_p']) is not None:
1620
- if getv(to_object, ['generationConfig']) is not None:
1621
- to_object['generationConfig']['topP'] = getv(config, ['top_p'])
1622
- else:
1623
- to_object['generationConfig'] = {'topP': getv(config, ['top_p'])}
1624
- if getv(config, ['top_k']) is not None:
1625
- if getv(to_object, ['generationConfig']) is not None:
1626
- to_object['generationConfig']['topK'] = getv(config, ['top_k'])
1627
- else:
1628
- to_object['generationConfig'] = {'topK': getv(config, ['top_k'])}
1629
- if getv(config, ['max_output_tokens']) is not None:
1630
- if getv(to_object, ['generationConfig']) is not None:
1631
- to_object['generationConfig']['maxOutputTokens'] = getv(
1632
- config, ['max_output_tokens']
1633
- )
1634
- else:
1635
- to_object['generationConfig'] = {
1636
- 'maxOutputTokens': getv(config, ['max_output_tokens'])
1637
- }
1638
- if getv(config, ['media_resolution']) is not None:
1639
- if getv(to_object, ['generationConfig']) is not None:
1640
- to_object['generationConfig']['mediaResolution'] = getv(
1641
- config, ['media_resolution']
1642
- )
1643
- else:
1644
- to_object['generationConfig'] = {
1645
- 'mediaResolution': getv(config, ['media_resolution'])
1646
- }
1647
- if getv(config, ['seed']) is not None:
1648
- if getv(to_object, ['generationConfig']) is not None:
1649
- to_object['generationConfig']['seed'] = getv(config, ['seed'])
1650
- else:
1651
- to_object['generationConfig'] = {'seed': getv(config, ['seed'])}
1652
- if getv(config, ['system_instruction']) is not None:
1653
- setv(
1654
- to_object,
1655
- ['systemInstruction'],
1656
- _Content_to_vertex(
1657
- self._api_client,
1658
- t.t_content(
1659
- self._api_client, getv(config, ['system_instruction'])
1660
- ),
1661
- to_object,
1662
- ),
1663
- )
1664
- if getv(config, ['tools']) is not None:
1665
- setv(
1666
- to_object,
1667
- ['tools'],
1668
- [
1669
- _Tool_to_vertex(
1670
- self._api_client, t.t_tool(self._api_client, item), to_object
1671
- )
1672
- for item in t.t_tools(self._api_client, getv(config, ['tools']))
1673
- ],
1674
- )
1675
- if getv(config, ['input_audio_transcription']) is not None:
1676
- setv(
1677
- to_object,
1678
- ['inputAudioTranscription'],
1679
- _AudioTranscriptionConfig_to_vertex(
1680
- self._api_client,
1681
- getv(config, ['input_audio_transcription']),
1682
- ),
1683
- )
1684
- if getv(config, ['output_audio_transcription']) is not None:
1685
- setv(
1686
- to_object,
1687
- ['outputAudioTranscription'],
1688
- _AudioTranscriptionConfig_to_vertex(
1689
- self._api_client,
1690
- getv(config, ['output_audio_transcription']),
1691
- ),
1692
- )
1693
-
1694
- if getv(config, ['session_resumption']) is not None:
1695
- setv(
1696
- to_object,
1697
- ['sessionResumption'],
1698
- self._LiveClientSessionResumptionConfig_to_vertex(
1699
- getv(config, ['session_resumption'])
1700
- ),
1701
- )
1702
-
1703
- if getv(config, ['context_window_compression']) is not None:
1704
- setv(
1705
- to_object,
1706
- ['contextWindowCompression'],
1707
- self._ContextWindowCompressionConfig_to_vertex(
1708
- getv(config, ['context_window_compression']),
1709
- ),
1710
- )
1711
-
1712
- return_value = {'setup': {'model': model}}
1713
- return_value['setup'].update(to_object)
1714
- return return_value
1715
-
1716
- def _SlidingWindow_to_vertex(
1717
- self,
1718
- from_object: Union[dict, object],
1719
- ) -> dict:
1720
- to_object: dict[str, Any] = {}
1721
- if getv(from_object, ['target_tokens']) is not None:
1722
- setv(to_object, ['targetTokens'], getv(from_object, ['target_tokens']))
1723
-
1724
- return to_object
1725
-
1726
- def _ContextWindowCompressionConfig_to_vertex(
1727
- self,
1728
- from_object: Union[dict, object],
1729
- ) -> dict:
1730
- to_object: dict[str, Any] = {}
1731
- if getv(from_object, ['trigger_tokens']) is not None:
1732
- setv(to_object, ['triggerTokens'], getv(from_object, ['trigger_tokens']))
1733
-
1734
- if getv(from_object, ['sliding_window']) is not None:
1735
- setv(
1736
- to_object,
1737
- ['slidingWindow'],
1738
- self._SlidingWindow_to_mldev(
1739
- getv(from_object, ['sliding_window'])
1740
- ),
1741
- )
1742
-
1743
- return to_object
1744
-
1745
- def _LiveClientSessionResumptionConfig_to_vertex(
1746
- self,
1747
- from_object: Union[dict, object]
1748
- ) -> dict:
1749
- to_object: dict[str, Any] = {}
1750
- if getv(from_object, ['handle']) is not None:
1751
- setv(to_object, ['handle'], getv(from_object, ['handle']))
1752
-
1753
- if getv(from_object, ['transparent']) is not None:
1754
- setv(to_object, ['transparent'], getv(from_object, ['transparent']))
1755
-
1756
- return to_object
1757
794
 
1758
795
  @contextlib.asynccontextmanager
1759
796
  async def connect(
@@ -1787,12 +824,20 @@ class AsyncLive(_api_module.BaseModule):
1787
824
  version = self._api_client._http_options.api_version
1788
825
  uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
1789
826
  headers = self._api_client._http_options.headers
827
+
1790
828
  request_dict = _common.convert_to_dict(
1791
- self._LiveSetup_to_mldev(
1792
- model=transformed_model,
1793
- config=parameter_model,
829
+ live_converters._LiveConnectParameters_to_mldev(
830
+ api_client=self._api_client,
831
+ from_object=types.LiveConnectParameters(
832
+ model=transformed_model,
833
+ config=parameter_model,
834
+ ).model_dump(exclude_none=True)
1794
835
  )
1795
836
  )
837
+ del request_dict['config']
838
+
839
+ setv(request_dict, ['setup', 'model'], transformed_model)
840
+
1796
841
  request = json.dumps(request_dict)
1797
842
  else:
1798
843
  # Get bearer token through Application Default Credentials.
@@ -1819,11 +864,19 @@ class AsyncLive(_api_module.BaseModule):
1819
864
  f'projects/{project}/locations/{location}/' + transformed_model
1820
865
  )
1821
866
  request_dict = _common.convert_to_dict(
1822
- self._LiveSetup_to_vertex(
1823
- model=transformed_model,
1824
- config=parameter_model,
867
+ live_converters._LiveConnectParameters_to_vertex(
868
+ api_client=self._api_client,
869
+ from_object=types.LiveConnectParameters(
870
+ model=transformed_model,
871
+ config=parameter_model,
872
+ ).model_dump(exclude_none=True)
1825
873
  )
1826
874
  )
875
+ del request_dict['config']
876
+
877
+ if getv(request_dict, ['setup', 'generationConfig', 'responseModalities']) is None:
878
+ setv(request_dict, ['setup', 'generationConfig', 'responseModalities'], ['AUDIO'])
879
+
1827
880
  request = json.dumps(request_dict)
1828
881
 
1829
882
  try:
@@ -1849,19 +902,23 @@ def _t_live_connect_config(
1849
902
  if config is None:
1850
903
  parameter_model = types.LiveConnectConfig()
1851
904
  elif isinstance(config, dict):
1852
- system_instruction = config.pop('system_instruction', None)
1853
- if system_instruction is not None:
905
+ if getv(config, ['system_instruction']) is not None:
1854
906
  converted_system_instruction = t.t_content(
1855
- api_client, content=system_instruction
907
+ api_client, getv(config, ['system_instruction'])
1856
908
  )
1857
909
  else:
1858
910
  converted_system_instruction = None
1859
- parameter_model = types.LiveConnectConfig(
1860
- system_instruction=converted_system_instruction,
1861
- **config
1862
- ) # type: ignore
911
+ parameter_model = types.LiveConnectConfig(**config)
912
+ parameter_model.system_instruction = converted_system_instruction
1863
913
  else:
914
+ if config.system_instruction is None:
915
+ system_instruction = None
916
+ else:
917
+ system_instruction = t.t_content(
918
+ api_client, getv(config, ['system_instruction'])
919
+ )
1864
920
  parameter_model = config
921
+ parameter_model.system_instruction = system_instruction
1865
922
 
1866
923
  if parameter_model.generation_config is not None:
1867
924
  warnings.warn(