google-genai 1.8.0__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/live.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Google LLC
1
+ # Copyright 2025 Google LLC
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -13,14 +13,15 @@
13
13
  # limitations under the License.
14
14
  #
15
15
 
16
- """Live client. The live module is experimental."""
16
+ """[Preview] Live API client."""
17
17
 
18
18
  import asyncio
19
19
  import base64
20
20
  import contextlib
21
21
  import json
22
22
  import logging
23
- from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, get_args
23
+ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, cast, get_args
24
+ import warnings
24
25
 
25
26
  import google.auth
26
27
  import pydantic
@@ -30,7 +31,6 @@ from . import _api_module
30
31
  from . import _common
31
32
  from . import _transformers as t
32
33
  from . import client
33
- from . import errors
34
34
  from . import types
35
35
  from ._api_client import BaseApiClient
36
36
  from ._common import experimental_warning
@@ -65,8 +65,68 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
65
65
  )
66
66
 
67
67
 
68
+ def _ClientContent_to_mldev(
69
+ api_client: BaseApiClient,
70
+ from_object: types.LiveClientContent,
71
+ ) -> dict:
72
+ client_content = from_object.model_dump(exclude_none=True, mode='json')
73
+ if 'turns' in client_content:
74
+ client_content['turns'] = [
75
+ _Content_to_mldev(api_client=api_client, from_object=item)
76
+ for item in client_content['turns']
77
+ ]
78
+ return client_content
79
+
80
+
81
+ def _ClientContent_to_vertex(
82
+ api_client: BaseApiClient,
83
+ from_object: types.LiveClientContent,
84
+ ) -> dict:
85
+ client_content = from_object.model_dump(exclude_none=True, mode='json')
86
+ if 'turns' in client_content:
87
+ client_content['turns'] = [
88
+ _Content_to_vertex(api_client=api_client, from_object=item)
89
+ for item in client_content['turns']
90
+ ]
91
+ return client_content
92
+
93
+
94
+ def _ToolResponse_to_mldev(
95
+ api_client: BaseApiClient,
96
+ from_object: types.LiveClientToolResponse,
97
+ ) -> dict:
98
+ tool_response = from_object.model_dump(exclude_none=True, mode='json')
99
+ for response in tool_response.get('function_responses', []):
100
+ if response.get('id') is None:
101
+ raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
102
+ return tool_response
103
+
104
+
105
+ def _ToolResponse_to_vertex(
106
+ api_client: BaseApiClient,
107
+ from_object: types.LiveClientToolResponse,
108
+ ) -> dict:
109
+ tool_response = from_object.model_dump(exclude_none=True, mode='json')
110
+ return tool_response
111
+
112
+ def _AudioTranscriptionConfig_to_mldev(
113
+ api_client: BaseApiClient,
114
+ from_object: types.AudioTranscriptionConfig,
115
+ ) -> dict:
116
+ audio_transcription: dict[str, Any] = {}
117
+ return audio_transcription
118
+
119
+
120
+ def _AudioTranscriptionConfig_to_vertex(
121
+ api_client: BaseApiClient,
122
+ from_object: types.AudioTranscriptionConfig,
123
+ ) -> dict:
124
+ audio_transcription: dict[str, Any] = {}
125
+ return audio_transcription
126
+
127
+
68
128
  class AsyncSession:
69
- """AsyncSession. The live module is experimental."""
129
+ """[Preview] AsyncSession."""
70
130
 
71
131
  def __init__(
72
132
  self, api_client: client.BaseApiClient, websocket: ClientConnection
@@ -90,7 +150,12 @@ class AsyncSession:
90
150
  ] = None,
91
151
  end_of_turn: Optional[bool] = False,
92
152
  ):
93
- """Send input to the model.
153
+ """[Deprecated] Send input to the model.
154
+
155
+ > **Warning**: This method is deprecated and will be removed in a future
156
+ version (not before Q3 2025). Please use one of the more specific methods:
157
+ `send_client_content`, `send_realtime_input`, or `send_tool_response`
158
+ instead.
94
159
 
95
160
  The method will send the input request to the server.
96
161
 
@@ -109,9 +174,219 @@ class AsyncSession:
109
174
  async for message in session.receive():
110
175
  print(message)
111
176
  """
177
+ warnings.warn(
178
+ 'The `session.send` method is deprecated and will be removed in a '
179
+ 'future version (not before Q3 2025).\n'
180
+ 'Please use one of the more specific methods: `send_client_content`, '
181
+ '`send_realtime_input`, or `send_tool_response` instead.',
182
+ DeprecationWarning,
183
+ stacklevel=2,
184
+ )
112
185
  client_message = self._parse_client_message(input, end_of_turn)
113
186
  await self._ws.send(json.dumps(client_message))
114
187
 
188
+ async def send_client_content(
189
+ self,
190
+ *,
191
+ turns: Optional[
192
+ Union[
193
+ types.Content,
194
+ types.ContentDict,
195
+ list[Union[types.Content, types.ContentDict]]
196
+ ]
197
+ ] = None,
198
+ turn_complete: bool = True,
199
+ ):
200
+ """Send non-realtime, turn based content to the model.
201
+
202
+ There are two ways to send messages to the live API:
203
+ `send_client_content` and `send_realtime_input`.
204
+
205
+ `send_client_content` messages are added to the model context **in order**.
206
+ Having a conversation using `send_client_content` messages is roughly
207
+ equivalent to using the `Chat.send_message_stream` method, except that the
208
+ state of the `chat` history is stored on the API server.
209
+
210
+ Because of `send_client_content`'s order guarantee, the model cannot
211
+ respond as quickly to `send_client_content` messages as to
212
+ `send_realtime_input` messages. This makes the biggest difference when
213
+ sending objects that have significant preprocessing time (typically images).
214
+
215
+ The `send_client_content` message sends a list of `Content` objects,
216
+ which have more options than the `media:Blob` sent by `send_realtime_input`.
217
+
218
+ The main use-cases for `send_client_content` over `send_realtime_input` are:
219
+
220
+ - Prefilling a conversation context (including sending anything that can't
221
+ be represented as a realtime message), before starting a realtime
222
+ conversation.
223
+ - Conducting a non-realtime conversation, similar to `client.chat`, using
224
+ the Live API.
225
+
226
+ Caution: Interleaving `send_client_content` and `send_realtime_input`
227
+ in the same conversation is not recommended and can lead to unexpected
228
+ results.
229
+
230
+ Args:
231
+ turns: A `Content` object or list of `Content` objects (or equivalent
232
+ dicts).
233
+ turn_complete: if true (the default) the model will reply immediately. If
234
+ false, the model will wait for you to send additional client_content,
235
+ and will not return until you send `turn_complete=True`.
236
+
237
+ Example:
238
+ ```
239
+ from google import genai
240
+ from google.genai import types
241
+
242
+ client = genai.Client(http_options={'api_version': 'v1alpha'})
243
+ async with client.aio.live.connect(
244
+ model=MODEL_NAME,
245
+ config={"response_modalities": ["TEXT"]}
246
+ ) as session:
247
+ await session.send_client_content(
248
+ turns=types.Content(
249
+ role='user',
250
+ parts=[types.Part(text="Hello world!")]))
251
+ async for msg in session.receive():
252
+ if msg.text:
253
+ print(msg.text)
254
+ ```
255
+ """
256
+ client_content = _t_client_content(turns, turn_complete)
257
+
258
+ if self._api_client.vertexai:
259
+ client_content_dict = _ClientContent_to_vertex(
260
+ api_client=self._api_client, from_object=client_content
261
+ )
262
+ else:
263
+ client_content_dict = _ClientContent_to_mldev(
264
+ api_client=self._api_client, from_object=client_content
265
+ )
266
+
267
+ await self._ws.send(json.dumps({'client_content': client_content_dict}))
268
+
269
+ async def send_realtime_input(self, *, media: t.BlobUnion):
270
+ """Send realtime media chunks to the model.
271
+
272
+ Use `send_realtime_input` for realtime audio chunks and video
273
+ frames (images).
274
+
275
+ With `send_realtime_input` the API will respond to audio automatically
276
+ based on voice activity detection (VAD).
277
+
278
+ `send_realtime_input` is optimized for responsiveness at the expense of
279
+ deterministic ordering. Audio and video tokens are added to the
280
+ context when they become available.
281
+
282
+ Args:
283
+ media: A `Blob`-like object, the realtime media to send.
284
+
285
+ Example:
286
+ ```
287
+ from pathlib import Path
288
+
289
+ from google import genai
290
+ from google.genai import types
291
+
292
+ import PIL.Image
293
+
294
+ client = genai.Client(http_options={'api_version': 'v1alpha'})
295
+
296
+ async with client.aio.live.connect(
297
+ model=MODEL_NAME,
298
+ config={"response_modalities": ["TEXT"]},
299
+ ) as session:
300
+ await session.send_realtime_input(
301
+ media=PIL.Image.open('image.jpg'))
302
+
303
+ audio_bytes = Path('audio.pcm').read_bytes()
304
+ await session.send_realtime_input(
305
+ media=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000'))
306
+
307
+ async for msg in session.receive():
308
+ if msg.text is not None:
309
+ print(f'{msg.text}')
310
+ ```
311
+ """
312
+ realtime_input = _t_realtime_input(media)
313
+ realtime_input_dict = realtime_input.model_dump(
314
+ exclude_none=True, mode='json'
315
+ )
316
+ await self._ws.send(json.dumps({'realtime_input': realtime_input_dict}))
317
+
318
+ async def send_tool_response(
319
+ self,
320
+ *,
321
+ function_responses: Union[
322
+ types.FunctionResponseOrDict,
323
+ Sequence[types.FunctionResponseOrDict],
324
+ ],
325
+ ):
326
+ """Send a tool response to the session.
327
+
328
+ Use `send_tool_response` to reply to `LiveServerToolCall` messages
329
+ from the server.
330
+
331
+ To set the available tools, use the `config.tools` argument
332
+ when you connect to the session (`client.live.connect`).
333
+
334
+ Args:
335
+ function_responses: A `FunctionResponse`-like object or list of
336
+ `FunctionResponse`-like objects.
337
+
338
+ Example:
339
+ ```
340
+ from google import genai
341
+ from google.genai import types
342
+
343
+ client = genai.Client(http_options={'api_version': 'v1alpha'})
344
+
345
+ tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
346
+ config = {
347
+ "tools": tools,
348
+ "response_modalities": ['TEXT']
349
+ }
350
+
351
+ async with client.aio.live.connect(
352
+ model='gemini-2.0-flash-exp',
353
+ config=config
354
+ ) as session:
355
+ prompt = "Turn on the lights please"
356
+ await session.send_client_content(
357
+ turns=types.Content(role='user', parts=[types.Part(text=prompt)]),
358
+ turn_complete=True)
359
+
360
+ async for chunk in session.receive():
361
+ if chunk.server_content:
362
+ if chunk.text is not None:
363
+ print(chunk.text)
364
+ elif chunk.tool_call:
365
+ print(chunk.tool_call)
366
+ print('_'*80)
367
+ function_response=types.FunctionResponse(
368
+ name='turn_on_the_lights',
369
+ response={'result': 'ok'},
370
+ id=chunk.tool_call.function_calls[0].id,
371
+ )
372
+ print(function_response)
373
+ await session.send_tool_response(
374
+ function_responses=function_response
375
+ )
376
+
377
+ print('_'*80)
378
+ """
379
+ tool_response = _t_tool_response(function_responses)
380
+ if self._api_client.vertexai:
381
+ tool_response_dict = _ToolResponse_to_vertex(
382
+ api_client=self._api_client, from_object=tool_response
383
+ )
384
+ else:
385
+ tool_response_dict = _ToolResponse_to_mldev(
386
+ api_client=self._api_client, from_object=tool_response
387
+ )
388
+ await self._ws.send(json.dumps({'tool_response': tool_response_dict}))
389
+
115
390
  async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
116
391
  """Receive model responses from the server.
117
392
 
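The three methods added above (`send_client_content`, `send_realtime_input`, `send_tool_response`) replace the deprecated catch-all `send`. A minimal migration sketch, assuming an API key is already configured; the model name, prompt, and `asyncio.run` wrapper are illustrative and not taken from the diff:

```
# Hedged sketch: moving from the deprecated `session.send(...)` to the
# specific methods introduced in this release.
import asyncio

from google import genai
from google.genai import types

MODEL_NAME = 'gemini-2.0-flash-exp'  # assumption: any Live-API-capable model


async def main():
    client = genai.Client(http_options={'api_version': 'v1alpha'})
    async with client.aio.live.connect(
        model=MODEL_NAME,
        config={'response_modalities': ['TEXT']},
    ) as session:
        # Before: await session.send(input='Hello world!', end_of_turn=True)
        # After: turn-based content goes through send_client_content.
        await session.send_client_content(
            turns=types.Content(
                role='user', parts=[types.Part(text='Hello world!')]
            ),
            turn_complete=True,
        )
        async for msg in session.receive():
            if msg.text:
                print(msg.text)


asyncio.run(main())
```

Realtime media would instead go through `send_realtime_input`, and replies to `LiveServerToolCall` messages through `send_tool_response`, as the docstrings above illustrate.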
@@ -120,8 +395,6 @@ class AsyncSession:
120
395
  is function call, user must call `send` with the function response to
121
396
  continue the turn.
122
397
 
123
- The live module is experimental.
124
-
125
398
  Yields:
126
399
  The model responses from the server.
127
400
 
@@ -146,15 +419,18 @@ class AsyncSession:
146
419
  async def start_stream(
147
420
  self, *, stream: AsyncIterator[bytes], mime_type: str
148
421
  ) -> AsyncIterator[types.LiveServerMessage]:
149
- """start a live session from a data stream.
422
+ """[Deprecated] Start a live session from a data stream.
423
+
424
+ > **Warning**: This method is deprecated and will be removed in a future
425
+ version (not before Q3 2025). Please use the `receive` and
426
+ `send_realtime_input` methods instead.
150
428
 
151
429
  The interaction terminates when the input stream is complete.
152
430
  This method will start two async tasks. One task will be used to send the
153
431
  input stream to the model and the other task will be used to receive the
154
432
  responses from the model.
155
433
 
156
- The live module is experimental.
157
-
158
434
  Args:
159
435
  stream: An iterator that yields the model response.
160
436
  mime_type: The MIME type of the data in the stream.
@@ -177,6 +453,13 @@ class AsyncSession:
177
453
  mime_type = 'audio/pcm'):
178
454
  play_audio_chunk(audio.data)
179
455
  """
456
+ warnings.warn(
457
+ 'The `AsyncSession.start_stream` method is deprecated, '
458
+ 'and will be removed in a future release (not before Q3 2025). '
459
+ 'Please use the `receive` and `send_realtime_input` methods instead.',
460
+ DeprecationWarning,
461
+ stacklevel=4,
462
+ )
180
463
  stop_event = asyncio.Event()
181
464
  # Start the send loop. When stream is complete stop_event is set.
182
465
  asyncio.create_task(self._send_loop(stream, mime_type, stop_event))
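Because `start_stream` now warns and points at `receive` plus `send_realtime_input`, the send-loop/receive-loop structure it wraps can be written out explicitly. A sketch under the assumption that `audio_chunks` is an async iterator of raw PCM bytes and that `session` comes from `client.aio.live.connect`; both names and the mime type are illustrative:

```
# Hedged sketch: the two-task pattern that replaces the deprecated
# `start_stream`. `audio_chunks` is a hypothetical async iterator of PCM bytes.
import asyncio

from google.genai import types


async def stream_audio(session, audio_chunks):
    async def send_loop():
        async for chunk in audio_chunks:
            await session.send_realtime_input(
                media=types.Blob(data=chunk, mime_type='audio/pcm;rate=16000')
            )

    send_task = asyncio.create_task(send_loop())
    try:
        async for msg in session.receive():
            if msg.text is not None:
                print(msg.text)
    finally:
        send_task.cancel()
```

Keeping the two loops in user code also keeps cancellation and error handling in the caller's hands, which the deprecated helper hides.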
@@ -207,7 +490,10 @@ class AsyncSession:
207
490
 
208
491
  async def _receive(self) -> types.LiveServerMessage:
209
492
  parameter_model = types.LiveServerMessage()
210
- raw_response = await self._ws.recv(decode=False)
493
+ try:
494
+ raw_response = await self._ws.recv(decode=False)
495
+ except TypeError:
496
+ raw_response = await self._ws.recv() # type: ignore[assignment]
211
497
  if raw_response:
212
498
  try:
213
499
  response = json.loads(raw_response)
@@ -215,6 +501,7 @@ class AsyncSession:
215
501
  raise ValueError(f'Failed to parse response: {raw_response!r}')
216
502
  else:
217
503
  response = {}
504
+
218
505
  if self._api_client.vertexai:
219
506
  response_dict = self._LiveServerMessage_from_vertex(response)
220
507
  else:
@@ -256,6 +543,24 @@ class AsyncSession:
256
543
  )
257
544
  if getv(from_object, ['turnComplete']) is not None:
258
545
  setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
546
+ if getv(from_object, ['generationComplete']) is not None:
547
+ setv(
548
+ to_object,
549
+ ['generation_complete'],
550
+ getv(from_object, ['generationComplete']),
551
+ )
552
+ if getv(from_object, ['inputTranscription']) is not None:
553
+ setv(
554
+ to_object,
555
+ ['input_transcription'],
556
+ getv(from_object, ['inputTranscription']),
557
+ )
558
+ if getv(from_object, ['outputTranscription']) is not None:
559
+ setv(
560
+ to_object,
561
+ ['output_transcription'],
562
+ getv(from_object, ['outputTranscription']),
563
+ )
259
564
  if getv(from_object, ['interrupted']) is not None:
260
565
  setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
261
566
  return to_object
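The mapping above threads the new `generationComplete`, `inputTranscription`, and `outputTranscription` fields through to `server_content`. A hedged sketch of enabling output transcription and reading those fields, assuming the corresponding snake_case attributes exist on `types.LiveServerContent` and that an empty dict is accepted as an `AudioTranscriptionConfig` (mirroring the `_AudioTranscriptionConfig_to_*` helpers above, which serialize to `{}`); the model name is illustrative:

```
# Hedged sketch: request audio plus a transcription of the model's speech,
# then read the new server_content fields mapped above.
import asyncio

from google import genai
from google.genai import types


async def main():
    client = genai.Client(http_options={'api_version': 'v1alpha'})
    config = {
        'response_modalities': ['AUDIO'],
        'output_audio_transcription': {},  # assumption: empty config enables it
    }
    async with client.aio.live.connect(
        model='gemini-2.0-flash-exp', config=config
    ) as session:
        await session.send_client_content(
            turns=types.Content(role='user', parts=[types.Part(text='Say hello')])
        )
        async for msg in session.receive():
            sc = msg.server_content
            if sc and sc.output_transcription and sc.output_transcription.text:
                print('transcript:', sc.output_transcription.text)
            if sc and sc.generation_complete:
                break


asyncio.run(main())
```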
@@ -286,6 +591,128 @@ class AsyncSession:
286
591
  )
287
592
  return to_object
288
593
 
594
+ def _LiveServerGoAway_from_mldev(
595
+ self,
596
+ from_object: Union[dict, object],
597
+ parent_object: Optional[dict] = None,
598
+ ) -> dict:
599
+ to_object: dict[str, Any] = {}
600
+ if getv(from_object, ['timeLeft']) is not None:
601
+ setv(to_object, ['time_left'], getv(from_object, ['timeLeft']))
602
+
603
+ return to_object
604
+
605
+ def _LiveServerSessionResumptionUpdate_from_mldev(
606
+ self,
607
+ from_object: Union[dict, object],
608
+ parent_object: Optional[dict] = None,
609
+ ) -> dict:
610
+ to_object: dict[str, Any] = {}
611
+ if getv(from_object, ['newHandle']) is not None:
612
+ setv(to_object, ['new_handle'], getv(from_object, ['newHandle']))
613
+
614
+ if getv(from_object, ['resumable']) is not None:
615
+ setv(to_object, ['resumable'], getv(from_object, ['resumable']))
616
+
617
+ if getv(from_object, ['lastConsumedClientMessageIndex']) is not None:
618
+ setv(
619
+ to_object,
620
+ ['last_consumed_client_message_index'],
621
+ getv(from_object, ['lastConsumedClientMessageIndex']),
622
+ )
623
+
624
+ return to_object
625
+
626
+ def _ModalityTokenCount_from_mldev(
627
+ self,
628
+ from_object: Union[dict, object],
629
+ ) -> Dict[str, Any]:
630
+ to_object: Dict[str, Any] = {}
631
+ if getv(from_object, ['modality']) is not None:
632
+ setv(to_object, ['modality'], getv(from_object, ['modality']))
633
+ if getv(from_object, ['tokenCount']) is not None:
634
+ setv(to_object, ['token_count'], getv(from_object, ['tokenCount']))
635
+ return to_object
636
+
637
+ def _UsageMetadata_from_mldev(
638
+ self,
639
+ from_object: Union[dict, object],
640
+ ) -> Dict[str, Any]:
641
+ to_object: dict[str, Any] = {}
642
+ if getv(from_object, ['promptTokenCount']) is not None:
643
+ setv(
644
+ to_object,
645
+ ['prompt_token_count'],
646
+ getv(from_object, ['promptTokenCount']),
647
+ )
648
+ if getv(from_object, ['cachedContentTokenCount']) is not None:
649
+ setv(
650
+ to_object,
651
+ ['cached_content_token_count'],
652
+ getv(from_object, ['cachedContentTokenCount']),
653
+ )
654
+ if getv(from_object, ['responseTokenCount']) is not None:
655
+ setv(
656
+ to_object,
657
+ ['response_token_count'],
658
+ getv(from_object, ['responseTokenCount']),
659
+ )
660
+ if getv(from_object, ['toolUsePromptTokenCount']) is not None:
661
+ setv(
662
+ to_object,
663
+ ['tool_use_prompt_token_count'],
664
+ getv(from_object, ['toolUsePromptTokenCount']),
665
+ )
666
+ if getv(from_object, ['thoughtsTokenCount']) is not None:
667
+ setv(
668
+ to_object,
669
+ ['thoughts_token_count'],
670
+ getv(from_object, ['thoughtsTokenCount']),
671
+ )
672
+ if getv(from_object, ['totalTokenCount']) is not None:
673
+ setv(
674
+ to_object,
675
+ ['total_token_count'],
676
+ getv(from_object, ['totalTokenCount']),
677
+ )
678
+ if getv(from_object, ['promptTokensDetails']) is not None:
679
+ setv(
680
+ to_object,
681
+ ['prompt_tokens_details'],
682
+ [
683
+ self._ModalityTokenCount_from_mldev(item)
684
+ for item in getv(from_object, ['promptTokensDetails'])
685
+ ],
686
+ )
687
+ if getv(from_object, ['cacheTokensDetails']) is not None:
688
+ setv(
689
+ to_object,
690
+ ['cache_tokens_details'],
691
+ [
692
+ self._ModalityTokenCount_from_mldev(item)
693
+ for item in getv(from_object, ['cacheTokensDetails'])
694
+ ],
695
+ )
696
+ if getv(from_object, ['responseTokensDetails']) is not None:
697
+ setv(
698
+ to_object,
699
+ ['response_tokens_details'],
700
+ [
701
+ self._ModalityTokenCount_from_mldev(item)
702
+ for item in getv(from_object, ['responseTokensDetails'])
703
+ ],
704
+ )
705
+ if getv(from_object, ['toolUsePromptTokensDetails']) is not None:
706
+ setv(
707
+ to_object,
708
+ ['tool_use_prompt_tokens_details'],
709
+ [
710
+ self._ModalityTokenCount_from_mldev(item)
711
+ for item in getv(from_object, ['toolUsePromptTokensDetails'])
712
+ ],
713
+ )
714
+ return to_object
715
+
289
716
  def _LiveServerMessage_from_mldev(
290
717
  self,
291
718
  from_object: Union[dict, object],
@@ -311,6 +738,34 @@ class AsyncSession:
311
738
  ['tool_call_cancellation'],
312
739
  getv(from_object, ['toolCallCancellation']),
313
740
  )
741
+
742
+ if getv(from_object, ['goAway']) is not None:
743
+ setv(
744
+ to_object,
745
+ ['go_away'],
746
+ self._LiveServerGoAway_from_mldev(
747
+ getv(from_object, ['goAway']), to_object
748
+ ),
749
+ )
750
+
751
+ if getv(from_object, ['sessionResumptionUpdate']) is not None:
752
+ setv(
753
+ to_object,
754
+ ['session_resumption_update'],
755
+ self._LiveServerSessionResumptionUpdate_from_mldev(
756
+ getv(from_object, ['sessionResumptionUpdate']),
757
+ to_object,
758
+ ),
759
+ )
760
+
763
+ if getv(from_object, ['usageMetadata']) is not None:
764
+ setv(
765
+ to_object,
766
+ ['usage_metadata'],
767
+ self._UsageMetadata_from_mldev(getv(from_object, ['usageMetadata'])),
768
+ )
314
769
  return to_object
315
770
 
316
771
  def _LiveServerContent_from_vertex(
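The additions above also surface `go_away`, `session_resumption_update`, and `usage_metadata` at the top level of `LiveServerMessage`. A hedged sketch of a receive loop that reacts to them; attribute access assumes the corresponding fields exist on the `types` models as mapped above, and `resumption_handle` is a hypothetical variable you would persist for the next `connect`:

```
# Hedged sketch: consuming the new top-level LiveServerMessage fields inside
# an existing receive loop (`session` as in the earlier sketches).
resumption_handle = None

async for msg in session.receive():
    if msg.go_away is not None:
        # The server intends to close the connection; time_left says how soon.
        print('go_away received, time left:', msg.go_away.time_left)
    if msg.session_resumption_update is not None:
        update = msg.session_resumption_update
        if update.resumable and update.new_handle:
            resumption_handle = update.new_handle  # reuse on the next connect
    if msg.usage_metadata is not None:
        print('total tokens:', msg.usage_metadata.total_token_count)
    if msg.text:
        print(msg.text)
```

On reconnect, the stored handle would go into the `session_resumption` entry of the connect config, which the setup converters later in this diff translate to `sessionResumption.handle`.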
@@ -329,10 +784,155 @@ class AsyncSession:
329
784
  )
330
785
  if getv(from_object, ['turnComplete']) is not None:
331
786
  setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
787
+ if getv(from_object, ['generationComplete']) is not None:
788
+ setv(
789
+ to_object,
790
+ ['generation_complete'],
791
+ getv(from_object, ['generationComplete']),
792
+ )
793
+ if getv(from_object, ['inputTranscription']) is not None:
794
+ setv(
795
+ to_object,
796
+ ['input_transcription'],
797
+ getv(from_object, ['inputTranscription']),
798
+ )
799
+ if getv(from_object, ['outputTranscription']) is not None:
800
+ setv(
801
+ to_object,
802
+ ['output_transcription'],
803
+ getv(from_object, ['outputTranscription']),
804
+ )
332
805
  if getv(from_object, ['interrupted']) is not None:
333
806
  setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
334
807
  return to_object
335
808
 
809
+ def _LiveServerGoAway_from_vertex(
810
+ self,
811
+ from_object: Union[dict, object],
812
+ ) -> dict:
813
+ to_object: dict[str, Any] = {}
814
+ if getv(from_object, ['timeLeft']) is not None:
815
+ setv(to_object, ['time_left'], getv(from_object, ['timeLeft']))
816
+
817
+ return to_object
818
+
819
+ def _LiveServerSessionResumptionUpdate_from_vertex(
820
+ self,
821
+ from_object: Union[dict, object],
822
+ ) -> dict:
823
+ to_object: dict[str, Any] = {}
824
+ if getv(from_object, ['newHandle']) is not None:
825
+ setv(to_object, ['new_handle'], getv(from_object, ['newHandle']))
826
+
827
+ if getv(from_object, ['resumable']) is not None:
828
+ setv(to_object, ['resumable'], getv(from_object, ['resumable']))
829
+
830
+ if getv(from_object, ['lastConsumedClientMessageIndex']) is not None:
831
+ setv(
832
+ to_object,
833
+ ['last_consumed_client_message_index'],
834
+ getv(from_object, ['lastConsumedClientMessageIndex']),
835
+ )
836
+
837
+ return to_object
838
+
839
+
840
+ def _ModalityTokenCount_from_vertex(
841
+ self,
842
+ from_object: Union[dict, object],
843
+ ) -> Dict[str, Any]:
844
+ to_object: Dict[str, Any] = {}
845
+ if getv(from_object, ['modality']) is not None:
846
+ setv(to_object, ['modality'], getv(from_object, ['modality']))
847
+ if getv(from_object, ['tokenCount']) is not None:
848
+ setv(to_object, ['token_count'], getv(from_object, ['tokenCount']))
849
+ return to_object
850
+
851
+ def _UsageMetadata_from_vertex(
852
+ self,
853
+ from_object: Union[dict, object],
854
+ ) -> Dict[str, Any]:
855
+ to_object: dict[str, Any] = {}
856
+ if getv(from_object, ['promptTokenCount']) is not None:
857
+ setv(
858
+ to_object,
859
+ ['prompt_token_count'],
860
+ getv(from_object, ['promptTokenCount']),
861
+ )
862
+ if getv(from_object, ['cachedContentTokenCount']) is not None:
863
+ setv(
864
+ to_object,
865
+ ['cached_content_token_count'],
866
+ getv(from_object, ['cachedContentTokenCount']),
867
+ )
868
+ if getv(from_object, ['candidatesTokenCount']) is not None:
869
+ setv(
870
+ to_object,
871
+ ['response_token_count'],
872
+ getv(from_object, ['candidatesTokenCount']),
873
+ )
874
+ if getv(from_object, ['toolUsePromptTokenCount']) is not None:
875
+ setv(
876
+ to_object,
877
+ ['tool_use_prompt_token_count'],
878
+ getv(from_object, ['toolUsePromptTokenCount']),
879
+ )
880
+ if getv(from_object, ['thoughtsTokenCount']) is not None:
881
+ setv(
882
+ to_object,
883
+ ['thoughts_token_count'],
884
+ getv(from_object, ['thoughtsTokenCount']),
885
+ )
886
+ if getv(from_object, ['totalTokenCount']) is not None:
887
+ setv(
888
+ to_object,
889
+ ['total_token_count'],
890
+ getv(from_object, ['totalTokenCount']),
891
+ )
892
+ if getv(from_object, ['promptTokensDetails']) is not None:
893
+ setv(
894
+ to_object,
895
+ ['prompt_tokens_details'],
896
+ [
897
+ self._ModalityTokenCount_from_vertex(item)
898
+ for item in getv(from_object, ['promptTokensDetails'])
899
+ ],
900
+ )
901
+ if getv(from_object, ['cacheTokensDetails']) is not None:
902
+ setv(
903
+ to_object,
904
+ ['cache_tokens_details'],
905
+ [
906
+ self._ModalityTokenCount_from_vertex(item)
907
+ for item in getv(from_object, ['cacheTokensDetails'])
908
+ ],
909
+ )
910
+ if getv(from_object, ['toolUsePromptTokensDetails']) is not None:
911
+ setv(
912
+ to_object,
913
+ ['tool_use_prompt_tokens_details'],
914
+ [
915
+ self._ModalityTokenCount_from_vertex(item)
916
+ for item in getv(from_object, ['toolUsePromptTokensDetails'])
917
+ ],
918
+ )
919
+ if getv(from_object, ['candidatesTokensDetails']) is not None:
920
+ setv(
921
+ to_object,
922
+ ['response_tokens_details'],
923
+ [
924
+ self._ModalityTokenCount_from_vertex(item)
925
+ for item in getv(from_object, ['candidatesTokensDetails'])
926
+ ],
927
+ )
928
+ if getv(from_object, ['trafficType']) is not None:
929
+ setv(
930
+ to_object,
931
+ ['traffic_type'],
932
+ getv(from_object, ['trafficType']),
933
+ )
934
+ return to_object
935
+
336
936
  def _LiveServerMessage_from_vertex(
337
937
  self,
338
938
  from_object: Union[dict, object],
@@ -346,7 +946,6 @@ class AsyncSession:
346
946
  getv(from_object, ['serverContent'])
347
947
  ),
348
948
  )
349
-
350
949
  if getv(from_object, ['toolCall']) is not None:
351
950
  setv(
352
951
  to_object,
@@ -359,6 +958,31 @@ class AsyncSession:
359
958
  ['tool_call_cancellation'],
360
959
  getv(from_object, ['toolCallCancellation']),
361
960
  )
961
+
962
+ if getv(from_object, ['goAway']) is not None:
963
+ setv(
964
+ to_object,
965
+ ['go_away'],
966
+ self._LiveServerGoAway_from_vertex(
967
+ getv(from_object, ['goAway'])
968
+ ),
969
+ )
970
+
971
+ if getv(from_object, ['sessionResumptionUpdate']) is not None:
972
+ setv(
973
+ to_object,
974
+ ['session_resumption_update'],
975
+ self._LiveServerSessionResumptionUpdate_from_vertex(
976
+ getv(from_object, ['sessionResumptionUpdate']),
977
+ ),
978
+ )
979
+
980
+ if getv(from_object, ['usageMetadata']) is not None:
981
+ setv(
982
+ to_object,
983
+ ['usage_metadata'],
984
+ self._UsageMetadata_from_vertex(getv(from_object, ['usageMetadata'])),
985
+ )
362
986
  return to_object
363
987
 
364
988
  def _parse_client_message(
@@ -669,8 +1293,81 @@ class AsyncSession:
669
1293
  await self._ws.close()
670
1294
 
671
1295
 
1296
+ def _t_content_strict(content: types.ContentOrDict):
1297
+ if isinstance(content, dict):
1298
+ return types.Content.model_validate(content)
1299
+ elif isinstance(content, types.Content):
1300
+ return content
1301
+ else:
1302
+ raise ValueError(
1303
+ f'Could not convert input (type "{type(content)}") to '
1304
+ '`types.Content`'
1305
+ )
1306
+
1307
+
1308
+ def _t_contents_strict(
1309
+ contents: Union[Sequence[types.ContentOrDict], types.ContentOrDict]):
1310
+ if isinstance(contents, Sequence):
1311
+ return [_t_content_strict(content) for content in contents]
1312
+ else:
1313
+ return [_t_content_strict(contents)]
1314
+
1315
+
1316
+ def _t_client_content(
1317
+ turns: Optional[
1318
+ Union[Sequence[types.ContentOrDict], types.ContentOrDict]
1319
+ ] = None,
1320
+ turn_complete: bool = True,
1321
+ ) -> types.LiveClientContent:
1322
+ if turns is None:
1323
+ return types.LiveClientContent(turn_complete=turn_complete)
1324
+
1325
+ try:
1326
+ return types.LiveClientContent(
1327
+ turns=_t_contents_strict(contents=turns),
1328
+ turn_complete=turn_complete,
1329
+ )
1330
+ except Exception as e:
1331
+ raise ValueError(
1332
+ f'Could not convert input (type "{type(turns)}") to '
1333
+ '`types.LiveClientContent`'
1334
+ ) from e
1335
+
1336
+
1337
+ def _t_realtime_input(
1338
+ media: t.BlobUnion,
1339
+ ) -> types.LiveClientRealtimeInput:
1340
+ try:
1341
+ return types.LiveClientRealtimeInput(media_chunks=[t.t_blob(blob=media)])
1342
+ except Exception as e:
1343
+ raise ValueError(
1344
+ f'Could not convert input (type "{type(input)}") to '
1345
+ '`types.LiveClientRealtimeInput`'
1346
+ ) from e
1347
+
1348
+
1349
+ def _t_tool_response(
1350
+ input: Union[
1351
+ types.FunctionResponseOrDict,
1352
+ Sequence[types.FunctionResponseOrDict],
1353
+ ],
1354
+ ) -> types.LiveClientToolResponse:
1355
+ if not input:
1356
+ raise ValueError(f'A tool response is required, got: \n{input}')
1357
+
1358
+ try:
1359
+ return types.LiveClientToolResponse(
1360
+ function_responses=t.t_function_responses(function_responses=input)
1361
+ )
1362
+ except Exception as e:
1363
+ raise ValueError(
1364
+ f'Could not convert input (type "{type(input)}") to '
1365
+ '`types.LiveClientToolResponse`'
1366
+ ) from e
1367
+
1368
+
672
1369
  class AsyncLive(_api_module.BaseModule):
673
- """AsyncLive. The live module is experimental."""
1370
+ """[Preview] AsyncLive."""
674
1371
 
675
1372
  def _LiveSetup_to_mldev(
676
1373
  self, model: str, config: Optional[types.LiveConnectConfig] = None
@@ -715,7 +1412,48 @@ class AsyncLive(_api_module.BaseModule):
715
1412
  to_object,
716
1413
  )
717
1414
  }
718
-
1415
+ if getv(config, ['temperature']) is not None:
1416
+ if getv(to_object, ['generationConfig']) is not None:
1417
+ to_object['generationConfig']['temperature'] = getv(
1418
+ config, ['temperature']
1419
+ )
1420
+ else:
1421
+ to_object['generationConfig'] = {
1422
+ 'temperature': getv(config, ['temperature'])
1423
+ }
1424
+ if getv(config, ['top_p']) is not None:
1425
+ if getv(to_object, ['generationConfig']) is not None:
1426
+ to_object['generationConfig']['topP'] = getv(config, ['top_p'])
1427
+ else:
1428
+ to_object['generationConfig'] = {'topP': getv(config, ['top_p'])}
1429
+ if getv(config, ['top_k']) is not None:
1430
+ if getv(to_object, ['generationConfig']) is not None:
1431
+ to_object['generationConfig']['topK'] = getv(config, ['top_k'])
1432
+ else:
1433
+ to_object['generationConfig'] = {'topK': getv(config, ['top_k'])}
1434
+ if getv(config, ['max_output_tokens']) is not None:
1435
+ if getv(to_object, ['generationConfig']) is not None:
1436
+ to_object['generationConfig']['maxOutputTokens'] = getv(
1437
+ config, ['max_output_tokens']
1438
+ )
1439
+ else:
1440
+ to_object['generationConfig'] = {
1441
+ 'maxOutputTokens': getv(config, ['max_output_tokens'])
1442
+ }
1443
+ if getv(config, ['media_resolution']) is not None:
1444
+ if getv(to_object, ['generationConfig']) is not None:
1445
+ to_object['generationConfig']['mediaResolution'] = getv(
1446
+ config, ['media_resolution']
1447
+ )
1448
+ else:
1449
+ to_object['generationConfig'] = {
1450
+ 'mediaResolution': getv(config, ['media_resolution'])
1451
+ }
1452
+ if getv(config, ['seed']) is not None:
1453
+ if getv(to_object, ['generationConfig']) is not None:
1454
+ to_object['generationConfig']['seed'] = getv(config, ['seed'])
1455
+ else:
1456
+ to_object['generationConfig'] = {'seed': getv(config, ['seed'])}
719
1457
  if getv(config, ['system_instruction']) is not None:
720
1458
  setv(
721
1459
  to_object,
@@ -739,11 +1477,84 @@ class AsyncLive(_api_module.BaseModule):
739
1477
  for item in t.t_tools(self._api_client, getv(config, ['tools']))
740
1478
  ],
741
1479
  )
1480
+ if getv(config, ['input_audio_transcription']) is not None:
1481
+ raise ValueError('input_audio_transcription is not supported in MLDev '
1482
+ 'API.')
1483
+ if getv(config, ['output_audio_transcription']) is not None:
1484
+ setv(
1485
+ to_object,
1486
+ ['outputAudioTranscription'],
1487
+ _AudioTranscriptionConfig_to_mldev(
1488
+ self._api_client,
1489
+ getv(config, ['output_audio_transcription']),
1490
+ ),
1491
+ )
1492
+
1493
+ if getv(config, ['session_resumption']) is not None:
1494
+ setv(
1495
+ to_object,
1496
+ ['sessionResumption'],
1497
+ self._LiveClientSessionResumptionConfig_to_mldev(
1498
+ getv(config, ['session_resumption'])
1499
+ ),
1500
+ )
1501
+
1502
+ if getv(config, ['context_window_compression']) is not None:
1503
+ setv(
1504
+ to_object,
1505
+ ['contextWindowCompression'],
1506
+ self._ContextWindowCompressionConfig_to_mldev(
1507
+ getv(config, ['context_window_compression']),
1508
+ ),
1509
+ )
742
1510
 
743
1511
  return_value = {'setup': {'model': model}}
744
1512
  return_value['setup'].update(to_object)
745
1513
  return return_value
746
1514
 
1515
+ def _SlidingWindow_to_mldev(
1516
+ self,
1517
+ from_object: Union[dict, object],
1518
+ ) -> dict:
1519
+ to_object: dict[str, Any] = {}
1520
+ if getv(from_object, ['target_tokens']) is not None:
1521
+ setv(to_object, ['targetTokens'], getv(from_object, ['target_tokens']))
1522
+
1523
+ return to_object
1524
+
1525
+
1526
+ def _ContextWindowCompressionConfig_to_mldev(
1527
+ self,
1528
+ from_object: Union[dict, object],
1529
+ ) -> dict:
1530
+ to_object: dict[str, Any] = {}
1531
+ if getv(from_object, ['trigger_tokens']) is not None:
1532
+ setv(to_object, ['triggerTokens'], getv(from_object, ['trigger_tokens']))
1533
+
1534
+ if getv(from_object, ['sliding_window']) is not None:
1535
+ setv(
1536
+ to_object,
1537
+ ['slidingWindow'],
1538
+ self._SlidingWindow_to_mldev(
1539
+ getv(from_object, ['sliding_window'])
1540
+ ),
1541
+ )
1542
+
1543
+ return to_object
1544
+
1545
+ def _LiveClientSessionResumptionConfig_to_mldev(
1546
+ self,
1547
+ from_object: Union[dict, object]
1548
+ ) -> dict:
1549
+ to_object: dict[str, Any] = {}
1550
+ if getv(from_object, ['handle']) is not None:
1551
+ setv(to_object, ['handle'], getv(from_object, ['handle']))
1552
+
1553
+ if getv(from_object, ['transparent']) is not None:
1554
+ raise ValueError('The `transparent` field is not supported in MLDev API')
1555
+
1556
+ return to_object
1557
+
747
1558
  def _LiveSetup_to_vertex(
748
1559
  self, model: str, config: Optional[types.LiveConnectConfig] = None
749
1560
  ):
@@ -796,6 +1607,48 @@ class AsyncLive(_api_module.BaseModule):
796
1607
  to_object,
797
1608
  )
798
1609
  }
1610
+ if getv(config, ['temperature']) is not None:
1611
+ if getv(to_object, ['generationConfig']) is not None:
1612
+ to_object['generationConfig']['temperature'] = getv(
1613
+ config, ['temperature']
1614
+ )
1615
+ else:
1616
+ to_object['generationConfig'] = {
1617
+ 'temperature': getv(config, ['temperature'])
1618
+ }
1619
+ if getv(config, ['top_p']) is not None:
1620
+ if getv(to_object, ['generationConfig']) is not None:
1621
+ to_object['generationConfig']['topP'] = getv(config, ['top_p'])
1622
+ else:
1623
+ to_object['generationConfig'] = {'topP': getv(config, ['top_p'])}
1624
+ if getv(config, ['top_k']) is not None:
1625
+ if getv(to_object, ['generationConfig']) is not None:
1626
+ to_object['generationConfig']['topK'] = getv(config, ['top_k'])
1627
+ else:
1628
+ to_object['generationConfig'] = {'topK': getv(config, ['top_k'])}
1629
+ if getv(config, ['max_output_tokens']) is not None:
1630
+ if getv(to_object, ['generationConfig']) is not None:
1631
+ to_object['generationConfig']['maxOutputTokens'] = getv(
1632
+ config, ['max_output_tokens']
1633
+ )
1634
+ else:
1635
+ to_object['generationConfig'] = {
1636
+ 'maxOutputTokens': getv(config, ['max_output_tokens'])
1637
+ }
1638
+ if getv(config, ['media_resolution']) is not None:
1639
+ if getv(to_object, ['generationConfig']) is not None:
1640
+ to_object['generationConfig']['mediaResolution'] = getv(
1641
+ config, ['media_resolution']
1642
+ )
1643
+ else:
1644
+ to_object['generationConfig'] = {
1645
+ 'mediaResolution': getv(config, ['media_resolution'])
1646
+ }
1647
+ if getv(config, ['seed']) is not None:
1648
+ if getv(to_object, ['generationConfig']) is not None:
1649
+ to_object['generationConfig']['seed'] = getv(config, ['seed'])
1650
+ else:
1651
+ to_object['generationConfig'] = {'seed': getv(config, ['seed'])}
799
1652
  if getv(config, ['system_instruction']) is not None:
800
1653
  setv(
801
1654
  to_object,
@@ -819,14 +1672,89 @@ class AsyncLive(_api_module.BaseModule):
819
1672
  for item in t.t_tools(self._api_client, getv(config, ['tools']))
820
1673
  ],
821
1674
  )
1675
+ if getv(config, ['input_audio_transcription']) is not None:
1676
+ setv(
1677
+ to_object,
1678
+ ['inputAudioTranscription'],
1679
+ _AudioTranscriptionConfig_to_vertex(
1680
+ self._api_client,
1681
+ getv(config, ['input_audio_transcription']),
1682
+ ),
1683
+ )
1684
+ if getv(config, ['output_audio_transcription']) is not None:
1685
+ setv(
1686
+ to_object,
1687
+ ['outputAudioTranscription'],
1688
+ _AudioTranscriptionConfig_to_vertex(
1689
+ self._api_client,
1690
+ getv(config, ['output_audio_transcription']),
1691
+ ),
1692
+ )
1693
+
1694
+ if getv(config, ['session_resumption']) is not None:
1695
+ setv(
1696
+ to_object,
1697
+ ['sessionResumption'],
1698
+ self._LiveClientSessionResumptionConfig_to_vertex(
1699
+ getv(config, ['session_resumption'])
1700
+ ),
1701
+ )
1702
+
1703
+ if getv(config, ['context_window_compression']) is not None:
1704
+ setv(
1705
+ to_object,
1706
+ ['contextWindowCompression'],
1707
+ self._ContextWindowCompressionConfig_to_vertex(
1708
+ getv(config, ['context_window_compression']),
1709
+ ),
1710
+ )
822
1711
 
823
1712
  return_value = {'setup': {'model': model}}
824
1713
  return_value['setup'].update(to_object)
825
1714
  return return_value
826
1715
 
827
- @experimental_warning(
828
- 'The live API is experimental and may change in future versions.',
829
- )
1716
+ def _SlidingWindow_to_vertex(
1717
+ self,
1718
+ from_object: Union[dict, object],
1719
+ ) -> dict:
1720
+ to_object: dict[str, Any] = {}
1721
+ if getv(from_object, ['target_tokens']) is not None:
1722
+ setv(to_object, ['targetTokens'], getv(from_object, ['target_tokens']))
1723
+
1724
+ return to_object
1725
+
1726
+ def _ContextWindowCompressionConfig_to_vertex(
1727
+ self,
1728
+ from_object: Union[dict, object],
1729
+ ) -> dict:
1730
+ to_object: dict[str, Any] = {}
1731
+ if getv(from_object, ['trigger_tokens']) is not None:
1732
+ setv(to_object, ['triggerTokens'], getv(from_object, ['trigger_tokens']))
1733
+
1734
+ if getv(from_object, ['sliding_window']) is not None:
1735
+ setv(
1736
+ to_object,
1737
+ ['slidingWindow'],
1738
+ self._SlidingWindow_to_vertex(
1739
+ getv(from_object, ['sliding_window'])
1740
+ ),
1741
+ )
1742
+
1743
+ return to_object
1744
+
1745
+ def _LiveClientSessionResumptionConfig_to_vertex(
1746
+ self,
1747
+ from_object: Union[dict, object]
1748
+ ) -> dict:
1749
+ to_object: dict[str, Any] = {}
1750
+ if getv(from_object, ['handle']) is not None:
1751
+ setv(to_object, ['handle'], getv(from_object, ['handle']))
1752
+
1753
+ if getv(from_object, ['transparent']) is not None:
1754
+ setv(to_object, ['transparent'], getv(from_object, ['transparent']))
1755
+
1756
+ return to_object
1757
+
830
1758
  @contextlib.asynccontextmanager
831
1759
  async def connect(
832
1760
  self,
@@ -834,9 +1762,9 @@ class AsyncLive(_api_module.BaseModule):
834
1762
  model: str,
835
1763
  config: Optional[types.LiveConnectConfigOrDict] = None,
836
1764
  ) -> AsyncIterator[AsyncSession]:
837
- """Connect to the live server.
1765
+ """[Preview] Connect to the live server.
838
1766
 
839
- The live module is experimental.
1767
+ Note: the live API is currently in preview.
840
1768
 
841
1769
  Usage:
842
1770
 
@@ -851,25 +1779,8 @@ class AsyncLive(_api_module.BaseModule):
851
1779
  """
852
1780
  base_url = self._api_client._websocket_base_url()
853
1781
  transformed_model = t.t_model(self._api_client, model)
854
- # Ensure the config is a LiveConnectConfig.
855
- if config is None:
856
- parameter_model = types.LiveConnectConfig()
857
- elif isinstance(config, dict):
858
- if config.get('system_instruction') is None:
859
- system_instruction = None
860
- else:
861
- system_instruction = t.t_content(
862
- self._api_client, config.get('system_instruction')
863
- )
864
- parameter_model = types.LiveConnectConfig(
865
- generation_config=config.get('generation_config'),
866
- response_modalities=config.get('response_modalities'),
867
- speech_config=config.get('speech_config'),
868
- system_instruction=system_instruction,
869
- tools=config.get('tools'),
870
- )
871
- else:
872
- parameter_model = config
1782
+
1783
+ parameter_model = _t_live_connect_config(self._api_client, config)
873
1784
 
874
1785
  if self._api_client.api_key:
875
1786
  api_key = self._api_client.api_key
@@ -915,8 +1826,50 @@ class AsyncLive(_api_module.BaseModule):
915
1826
  )
916
1827
  request = json.dumps(request_dict)
917
1828
 
918
- async with connect(uri, additional_headers=headers) as ws:
919
- await ws.send(request)
920
- logger.info(await ws.recv(decode=False))
1829
+ try:
1830
+ async with connect(uri, additional_headers=headers) as ws:
1831
+ await ws.send(request)
1832
+ logger.info(await ws.recv(decode=False))
1833
+
1834
+ yield AsyncSession(api_client=self._api_client, websocket=ws)
1835
+ except TypeError:
1836
+ # Try with the older websockets API
1837
+ async with connect(uri, extra_headers=headers) as ws:
1838
+ await ws.send(request)
1839
+ logger.info(await ws.recv())
1840
+
1841
+ yield AsyncSession(api_client=self._api_client, websocket=ws)
1842
+
1843
+
1844
+ def _t_live_connect_config(
1845
+ api_client: BaseApiClient,
1846
+ config: Optional[types.LiveConnectConfigOrDict],
1847
+ ) -> types.LiveConnectConfig:
1848
+ # Ensure the config is a LiveConnectConfig.
1849
+ if config is None:
1850
+ parameter_model = types.LiveConnectConfig()
1851
+ elif isinstance(config, dict):
1852
+ system_instruction = config.pop('system_instruction', None)
1853
+ if system_instruction is not None:
1854
+ converted_system_instruction = t.t_content(
1855
+ api_client, content=system_instruction
1856
+ )
1857
+ else:
1858
+ converted_system_instruction = None
1859
+ parameter_model = types.LiveConnectConfig(
1860
+ system_instruction=converted_system_instruction,
1861
+ **config
1862
+ ) # type: ignore
1863
+ else:
1864
+ parameter_model = config
1865
+
1866
+ if parameter_model.generation_config is not None:
1867
+ warnings.warn(
1868
+ 'Setting `LiveConnectConfig.generation_config` is deprecated, '
1869
+ 'please set the fields on `LiveConnectConfig` directly. This will '
1870
+ 'become an error in a future version (not before Q3 2025)',
1871
+ DeprecationWarning,
1872
+ stacklevel=4,
1873
+ )
921
1874
 
922
- yield AsyncSession(api_client=self._api_client, websocket=ws)
1875
+ return parameter_model
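Taken together, the setup-converter changes and the `generation_config` deprecation above mean that sampling parameters and the new session options are now set directly on `LiveConnectConfig`. A hedged sketch; the class names `SessionResumptionConfig`, `ContextWindowCompressionConfig`, and `SlidingWindow` are inferred from the fields the converters read, and the model name and token limits are illustrative:

```
# Hedged sketch: configuring a Live session without the deprecated
# generation_config, using the fields the new setup converters read.
import asyncio

from google import genai
from google.genai import types


async def main():
    client = genai.Client(http_options={'api_version': 'v1alpha'})
    config = types.LiveConnectConfig(
        response_modalities=['TEXT'],
        temperature=0.7,
        top_p=0.95,
        max_output_tokens=1024,
        seed=42,
        # assumption: these config types exist as read by the converters above
        context_window_compression=types.ContextWindowCompressionConfig(
            trigger_tokens=16000,
            sliding_window=types.SlidingWindow(target_tokens=12000),
        ),
        session_resumption=types.SessionResumptionConfig(
            handle=None,  # or a handle saved from a session_resumption_update
        ),
    )
    async with client.aio.live.connect(
        model='gemini-2.0-flash-exp', config=config
    ) as session:
        await session.send_client_content(
            turns=types.Content(role='user', parts=[types.Part(text='Hi')])
        )
        async for msg in session.receive():
            if msg.text:
                print(msg.text)


asyncio.run(main())
```

Passing a plain dict with the same keys also works, since `_t_live_connect_config` now forwards dict entries to `LiveConnectConfig(**config)`.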