google-genai 1.7.0__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/live.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2024 Google LLC
+ # Copyright 2025 Google LLC
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ import contextlib
  import json
  import logging
  from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, get_args
+ import warnings

  import google.auth
  import pydantic
@@ -30,7 +31,6 @@ from . import _api_module
  from . import _common
  from . import _transformers as t
  from . import client
- from . import errors
  from . import types
  from ._api_client import BaseApiClient
  from ._common import experimental_warning
@@ -65,6 +65,59 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
  )


+ def _ClientContent_to_mldev(
+     api_client: BaseApiClient,
+     from_object: types.LiveClientContent,
+ ) -> dict:
+   client_content = from_object.model_dump(exclude_none=True, mode='json')
+   if 'turns' in client_content:
+     client_content['turns'] = [
+         _Content_to_mldev(api_client=api_client, from_object=item)
+         for item in client_content['turns']
+     ]
+   return client_content
+
+
+ def _ClientContent_to_vertex(
+     api_client: BaseApiClient,
+     from_object: types.LiveClientContent,
+ ) -> dict:
+   client_content = from_object.model_dump(exclude_none=True, mode='json')
+   if 'turns' in client_content:
+     client_content['turns'] = [
+         _Content_to_vertex(api_client=api_client, from_object=item)
+         for item in client_content['turns']
+     ]
+   return client_content
+
+
+ def _ToolResponse_to_mldev(
+     api_client: BaseApiClient,
+     from_object: types.LiveClientToolResponse,
+ ) -> dict:
+   tool_response = from_object.model_dump(exclude_none=True, mode='json')
+   for response in tool_response.get('function_responses', []):
+     if response.get('id') is None:
+       raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
+   return tool_response
+
+
+ def _ToolResponse_to_vertex(
+     api_client: BaseApiClient,
+     from_object: types.LiveClientToolResponse,
+ ) -> dict:
+   tool_response = from_object.model_dump(exclude_none=True, mode='json')
+   return tool_response
+
+
+ def _AudioTranscriptionConfig_to_vertex(
+     api_client: BaseApiClient,
+     from_object: types.AudioTranscriptionConfig,
+ ) -> dict:
+   audio_transcription: dict[str, Any] = {}
+   return audio_transcription
+
+
  class AsyncSession:
    """AsyncSession. The live module is experimental."""

@@ -112,6 +165,208 @@ class AsyncSession:
      client_message = self._parse_client_message(input, end_of_turn)
      await self._ws.send(json.dumps(client_message))

+   async def send_client_content(
+       self,
+       *,
+       turns: Optional[
+           Union[
+               types.Content,
+               types.ContentDict,
+               list[Union[types.Content, types.ContentDict]]
+           ]
+       ] = None,
+       turn_complete: bool = True,
+   ):
+     """Send non-realtime, turn-based content to the model.
+
+     There are two ways to send messages to the live API:
+     `send_client_content` and `send_realtime_input`.
+
+     `send_client_content` messages are added to the model context **in order**.
+     Having a conversation using `send_client_content` messages is roughly
+     equivalent to using the `Chat.send_message_stream` method, except that the
+     state of the `chat` history is stored on the API server.
+
+     Because of `send_client_content`'s order guarantee, the model cannot
+     respond as quickly to `send_client_content` messages as to
+     `send_realtime_input` messages. This makes the biggest difference when
+     sending objects that have significant preprocessing time (typically
+     images).
+
+     The `send_client_content` message sends a list of `Content` objects,
+     which offers more options than the `media: Blob` sent by
+     `send_realtime_input`.
+
+     The main use cases for `send_client_content` over `send_realtime_input`
+     are:
+
+     - Prefilling a conversation context (including sending anything that
+       can't be represented as a realtime message) before starting a realtime
+       conversation.
+     - Conducting a non-realtime conversation, similar to `client.chat`,
+       using the live API.
+
+     Caution: Interleaving `send_client_content` and `send_realtime_input`
+     in the same conversation is not recommended and can lead to unexpected
+     results.
+
+     Args:
+       turns: A `Content` object or list of `Content` objects (or equivalent
+         dicts).
+       turn_complete: if true (the default) the model will reply immediately.
+         If false, the model will wait for you to send additional
+         client_content, and will not return until you send
+         `turn_complete=True`.
+
+     Example:
+       ```
+       from google import genai
+       from google.genai import types
+
+       client = genai.Client(http_options={'api_version': 'v1alpha'})
+       async with client.aio.live.connect(
+           model=MODEL_NAME,
+           config={"response_modalities": ["TEXT"]}
+       ) as session:
+         await session.send_client_content(
+             turns=types.Content(
+                 role='user',
+                 parts=[types.Part(text="Hello world!")]))
+         async for msg in session.receive():
+           if msg.text:
+             print(msg.text)
+       ```
+     """
+     client_content = _t_client_content(turns, turn_complete)
+
+     if self._api_client.vertexai:
+       client_content_dict = _ClientContent_to_vertex(
+           api_client=self._api_client, from_object=client_content
+       )
+     else:
+       client_content_dict = _ClientContent_to_mldev(
+           api_client=self._api_client, from_object=client_content
+       )
+
+     await self._ws.send(json.dumps({'client_content': client_content_dict}))
+
+   async def send_realtime_input(self, *, media: t.BlobUnion):
+     """Send realtime media chunks to the model.
+
+     Use `send_realtime_input` for realtime audio chunks and video
+     frames (images).
+
+     With `send_realtime_input` the API will respond to audio automatically
+     based on voice activity detection (VAD).
+
+     `send_realtime_input` is optimized for responsiveness at the expense of
+     deterministic ordering. Audio and video tokens are added to the
+     context when they become available.
+
+     Args:
+       media: A `Blob`-like object, the realtime media to send.
+
+     Example:
+       ```
+       from pathlib import Path
+
+       from google import genai
+       from google.genai import types
+
+       import PIL.Image
+
+       client = genai.Client(http_options={'api_version': 'v1alpha'})
+
+       async with client.aio.live.connect(
+           model=MODEL_NAME,
+           config={"response_modalities": ["TEXT"]},
+       ) as session:
+         await session.send_realtime_input(
+             media=PIL.Image.open('image.jpg'))
+
+         audio_bytes = Path('audio.pcm').read_bytes()
+         await session.send_realtime_input(
+             media=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000'))
+
+         async for msg in session.receive():
+           if msg.text is not None:
+             print(f'{msg.text}')
+       ```
+     """
+     realtime_input = _t_realtime_input(media)
+     realtime_input_dict = realtime_input.model_dump(
+         exclude_none=True, mode='json'
+     )
+     await self._ws.send(json.dumps({'realtime_input': realtime_input_dict}))
+
+   async def send_tool_response(
+       self,
+       *,
+       function_responses: Union[
+           types.FunctionResponseOrDict,
+           Sequence[types.FunctionResponseOrDict],
+       ],
+   ):
+     """Send a tool response to the session.
+
+     Use `send_tool_response` to reply to `LiveServerToolCall` messages
+     from the server.
+
+     To set the available tools, use the `config.tools` argument
+     when you connect to the session (`client.live.connect`).
+
+     Args:
+       function_responses: A `FunctionResponse`-like object or list of
+         `FunctionResponse`-like objects.
+
+     Example:
+       ```
+       from google import genai
+       from google.genai import types
+
+       client = genai.Client(http_options={'api_version': 'v1alpha'})
+
+       tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
+       config = {
+           "tools": tools,
+           "response_modalities": ['TEXT']
+       }
+
+       async with client.aio.live.connect(
+           model='gemini-2.0-flash-exp',
+           config=config
+       ) as session:
+         prompt = "Turn on the lights please"
+         await session.send_client_content(
+             turns=prompt,
+             turn_complete=True)
+
+         async for chunk in session.receive():
+           if chunk.server_content:
+             if chunk.text is not None:
+               print(chunk.text)
+           elif chunk.tool_call:
+             print(chunk.tool_call)
+             print('_' * 80)
+             function_response = types.FunctionResponse(
+                 name='turn_on_the_lights',
+                 response={'result': 'ok'},
+                 id=chunk.tool_call.function_calls[0].id,
+             )
+             print(function_response)
+             await session.send_tool_response(
+                 function_responses=function_response
+             )
+
+             print('_' * 80)
+       ```
+     """
+     tool_response = _t_tool_response(function_responses)
+     if self._api_client.vertexai:
+       tool_response_dict = _ToolResponse_to_vertex(
+           api_client=self._api_client, from_object=tool_response
+       )
+     else:
+       tool_response_dict = _ToolResponse_to_mldev(
+           api_client=self._api_client, from_object=tool_response
+       )
+     await self._ws.send(json.dumps({'tool_response': tool_response_dict}))
+
    async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
      """Receive model responses from the server.

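The `send_client_content` docstring above mentions prefilling a conversation context before starting a realtime conversation. A minimal sketch of that use case, assuming `GOOGLE_API_KEY` is set in the environment and that `gemini-2.0-flash-exp` is a live-capable model:

```python
import asyncio

from google import genai
from google.genai import types

client = genai.Client(http_options={'api_version': 'v1alpha'})


async def main():
  async with client.aio.live.connect(
      model='gemini-2.0-flash-exp',  # assumed live-capable model
      config={'response_modalities': ['TEXT']},
  ) as session:
    # Prefill two turns of history; turn_complete=False defers the reply.
    await session.send_client_content(
        turns=[
            types.Content(role='user', parts=[types.Part(text='My name is Ada.')]),
            types.Content(role='model', parts=[types.Part(text='Nice to meet you, Ada.')]),
        ],
        turn_complete=False,
    )
    # Now send the turn that should actually be answered.
    await session.send_client_content(
        turns=types.Content(role='user', parts=[types.Part(text="What's my name?")]),
        turn_complete=True,
    )
    async for msg in session.receive():
      if msg.text:
        print(msg.text, end='')


asyncio.run(main())
```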
@@ -207,7 +462,10 @@ class AsyncSession:

    async def _receive(self) -> types.LiveServerMessage:
      parameter_model = types.LiveServerMessage()
-     raw_response = await self._ws.recv(decode=False)
+     try:
+       raw_response = await self._ws.recv(decode=False)
+     except TypeError:
+       raw_response = await self._ws.recv()
      if raw_response:
        try:
          response = json.loads(raw_response)
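The `try`/`except TypeError` above is a version shim: the newer `websockets` asyncio implementation accepts `recv(decode=False)` and returns raw `bytes`, while the legacy implementation rejects the keyword. The same duck-typing idiom, written standalone as a sketch (assuming `ws` is a connected client of either vintage):

```python
async def recv_raw(ws):
  """Receive one frame, preferring the newer bytes-returning fast path."""
  try:
    # Newer websockets asyncio API: skip UTF-8 decoding, return bytes.
    return await ws.recv(decode=False)
  except TypeError:
    # Legacy API: `decode` is not an accepted argument.
    return await ws.recv()
```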
@@ -258,6 +516,12 @@ class AsyncSession:
        setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
      if getv(from_object, ['interrupted']) is not None:
        setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
+     if getv(from_object, ['generationComplete']) is not None:
+       setv(
+           to_object,
+           ['generation_complete'],
+           getv(from_object, ['generationComplete']),
+       )
      return to_object

    def _LiveToolCall_from_mldev(
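With `generationComplete` now mapped through, a receive loop can distinguish "the model finished generating" from "the server closed out the turn". A sketch of a consumer, assuming a connected `session` as in the examples above:

```python
async def drain_turn(session) -> str:
  """Collect one model turn, logging the generation-complete signal."""
  chunks = []
  async for msg in session.receive():
    if msg.text:
      chunks.append(msg.text)
    content = msg.server_content
    if content and content.generation_complete:
      print('[generation complete]')  # model is done producing tokens
    if content and content.turn_complete:
      break  # turn is fully closed out
  return ''.join(chunks)
```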
@@ -329,6 +593,25 @@ class AsyncSession:
      )
      if getv(from_object, ['turnComplete']) is not None:
        setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
+     if getv(from_object, ['generationComplete']) is not None:
+       setv(
+           to_object,
+           ['generation_complete'],
+           getv(from_object, ['generationComplete']),
+       )
+     # Vertex supports transcription.
+     if getv(from_object, ['inputTranscription']) is not None:
+       setv(
+           to_object,
+           ['input_transcription'],
+           getv(from_object, ['inputTranscription']),
+       )
+     if getv(from_object, ['outputTranscription']) is not None:
+       setv(
+           to_object,
+           ['output_transcription'],
+           getv(from_object, ['outputTranscription']),
+       )
      if getv(from_object, ['interrupted']) is not None:
        setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
      return to_object
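As the `# Vertex supports transcription.` comment notes, only the Vertex mapping carries the transcription fields through to `LiveServerContent`. A sketch of reading them, assuming transcription was enabled at connect time (see the config example further down) and that each transcription payload exposes a `text` field:

```python
async def print_transcripts(session):
  """Print input/output audio transcripts as they arrive (Vertex only)."""
  async for msg in session.receive():
    content = msg.server_content
    if content is None:
      continue
    if content.input_transcription:   # what the API heard the user say
      print('user said:', content.input_transcription.text)
    if content.output_transcription:  # transcript of the model's audio reply
      print('model said:', content.output_transcription.text)
```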
@@ -669,6 +952,79 @@ class AsyncSession:
      await self._ws.close()


+ def _t_content_strict(content: types.ContentOrDict):
+   if isinstance(content, dict):
+     return types.Content.model_validate(content)
+   elif isinstance(content, types.Content):
+     return content
+   else:
+     raise ValueError(
+         f'Could not convert input (type "{type(content)}") to '
+         '`types.Content`'
+     )
+
+
+ def _t_contents_strict(
+     contents: Union[Sequence[types.ContentOrDict], types.ContentOrDict]):
+   if isinstance(contents, Sequence):
+     return [_t_content_strict(content) for content in contents]
+   else:
+     return [_t_content_strict(contents)]
+
+
+ def _t_client_content(
+     turns: Optional[
+         Union[Sequence[types.ContentOrDict], types.ContentOrDict]
+     ] = None,
+     turn_complete: bool = True,
+ ) -> types.LiveClientContent:
+   if turns is None:
+     return types.LiveClientContent(turn_complete=turn_complete)
+
+   try:
+     return types.LiveClientContent(
+         turns=_t_contents_strict(contents=turns),
+         turn_complete=turn_complete,
+     )
+   except Exception as e:
+     raise ValueError(
+         f'Could not convert input (type "{type(turns)}") to '
+         '`types.LiveClientContent`'
+     ) from e
+
+
+ def _t_realtime_input(
+     media: t.BlobUnion,
+ ) -> types.LiveClientRealtimeInput:
+   try:
+     return types.LiveClientRealtimeInput(media_chunks=[t.t_blob(blob=media)])
+   except Exception as e:
+     raise ValueError(
+         f'Could not convert input (type "{type(media)}") to '
+         '`types.LiveClientRealtimeInput`'
+     ) from e
+
+
+ def _t_tool_response(
+     input: Union[
+         types.FunctionResponseOrDict,
+         Sequence[types.FunctionResponseOrDict],
+     ],
+ ) -> types.LiveClientToolResponse:
+   if not input:
+     raise ValueError(f'A tool response is required, got: \n{input}')
+
+   try:
+     return types.LiveClientToolResponse(
+         function_responses=t.t_function_responses(function_responses=input)
+     )
+   except Exception as e:
+     raise ValueError(
+         f'Could not convert input (type "{type(input)}") to '
+         '`types.LiveClientToolResponse`'
+     ) from e
+
+
  class AsyncLive(_api_module.BaseModule):
    """AsyncLive. The live module is experimental."""

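Note the asymmetry between `_ToolResponse_to_mldev` and `_ToolResponse_to_vertex` above: the Gemini API path raises `_FUNCTION_RESPONSE_REQUIRES_ID` when any function response lacks an `id`, while the Vertex path passes responses through unchecked. The safe pattern is to echo back the `id` from the incoming tool call; a sketch:

```python
from google.genai import types


async def answer_tool_call(session, tool_call):
  """Reply to a LiveServerToolCall, echoing each function call's id."""
  responses = [
      types.FunctionResponse(
          name=call.name,
          id=call.id,  # required on the Gemini API path; harmless on Vertex
          response={'result': 'ok'},  # placeholder tool result
      )
      for call in tool_call.function_calls
  ]
  await session.send_tool_response(function_responses=responses)
```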
@@ -715,7 +1071,39 @@ class AsyncLive(_api_module.BaseModule):
                  to_object,
              )
          }
-
+     if getv(config, ['temperature']) is not None:
+       if getv(to_object, ['generationConfig']) is not None:
+         to_object['generationConfig']['temperature'] = getv(
+             config, ['temperature']
+         )
+       else:
+         to_object['generationConfig'] = {
+             'temperature': getv(config, ['temperature'])
+         }
+     if getv(config, ['top_p']) is not None:
+       if getv(to_object, ['generationConfig']) is not None:
+         to_object['generationConfig']['topP'] = getv(config, ['top_p'])
+       else:
+         to_object['generationConfig'] = {'topP': getv(config, ['top_p'])}
+     if getv(config, ['top_k']) is not None:
+       if getv(to_object, ['generationConfig']) is not None:
+         to_object['generationConfig']['topK'] = getv(config, ['top_k'])
+       else:
+         to_object['generationConfig'] = {'topK': getv(config, ['top_k'])}
+     if getv(config, ['max_output_tokens']) is not None:
+       if getv(to_object, ['generationConfig']) is not None:
+         to_object['generationConfig']['maxOutputTokens'] = getv(
+             config, ['max_output_tokens']
+         )
+       else:
+         to_object['generationConfig'] = {
+             'maxOutputTokens': getv(config, ['max_output_tokens'])
+         }
+     if getv(config, ['seed']) is not None:
+       if getv(to_object, ['generationConfig']) is not None:
+         to_object['generationConfig']['seed'] = getv(config, ['seed'])
+       else:
+         to_object['generationConfig'] = {'seed': getv(config, ['seed'])}
      if getv(config, ['system_instruction']) is not None:
        setv(
            to_object,
@@ -796,6 +1184,39 @@ class AsyncLive(_api_module.BaseModule):
                  to_object,
              )
          }
+     if getv(config, ['temperature']) is not None:
+       if getv(to_object, ['generationConfig']) is not None:
+         to_object['generationConfig']['temperature'] = getv(
+             config, ['temperature']
+         )
+       else:
+         to_object['generationConfig'] = {
+             'temperature': getv(config, ['temperature'])
+         }
+     if getv(config, ['top_p']) is not None:
+       if getv(to_object, ['generationConfig']) is not None:
+         to_object['generationConfig']['topP'] = getv(config, ['top_p'])
+       else:
+         to_object['generationConfig'] = {'topP': getv(config, ['top_p'])}
+     if getv(config, ['top_k']) is not None:
+       if getv(to_object, ['generationConfig']) is not None:
+         to_object['generationConfig']['topK'] = getv(config, ['top_k'])
+       else:
+         to_object['generationConfig'] = {'topK': getv(config, ['top_k'])}
+     if getv(config, ['max_output_tokens']) is not None:
+       if getv(to_object, ['generationConfig']) is not None:
+         to_object['generationConfig']['maxOutputTokens'] = getv(
+             config, ['max_output_tokens']
+         )
+       else:
+         to_object['generationConfig'] = {
+             'maxOutputTokens': getv(config, ['max_output_tokens'])
+         }
+     if getv(config, ['seed']) is not None:
+       if getv(to_object, ['generationConfig']) is not None:
+         to_object['generationConfig']['seed'] = getv(config, ['seed'])
+       else:
+         to_object['generationConfig'] = {'seed': getv(config, ['seed'])}
      if getv(config, ['system_instruction']) is not None:
        setv(
            to_object,
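Both `_LiveSetup_to_mldev` and `_LiveSetup_to_vertex` now fold top-level sampling parameters into the wire-level `generationConfig`, so they can be set directly on the connect config rather than through a nested `generation_config`. A sketch of the resulting API surface:

```python
from google import genai
from google.genai import types

client = genai.Client(http_options={'api_version': 'v1alpha'})

config = types.LiveConnectConfig(
    response_modalities=['TEXT'],
    # New in this release: merged into generationConfig by _LiveSetup_to_*.
    temperature=0.7,
    top_p=0.95,
    top_k=40,
    max_output_tokens=1024,
    seed=42,
)
# Used as: async with client.aio.live.connect(model=..., config=config) as session: ...
```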
@@ -819,6 +1240,24 @@ class AsyncLive(_api_module.BaseModule):
                for item in t.t_tools(self._api_client, getv(config, ['tools']))
            ],
        )
+     if getv(config, ['input_audio_transcription']) is not None:
+       setv(
+           to_object,
+           ['inputAudioTranscription'],
+           _AudioTranscriptionConfig_to_vertex(
+               self._api_client,
+               getv(config, ['input_audio_transcription']),
+           ),
+       )
+     if getv(config, ['output_audio_transcription']) is not None:
+       setv(
+           to_object,
+           ['outputAudioTranscription'],
+           _AudioTranscriptionConfig_to_vertex(
+               self._api_client,
+               getv(config, ['output_audio_transcription']),
+           ),
+       )

      return_value = {'setup': {'model': model}}
      return_value['setup'].update(to_object)
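Since only the Vertex setup mapping emits `inputAudioTranscription`/`outputAudioTranscription`, this diff makes audio transcription a Vertex-only option. A sketch of enabling it, assuming Vertex credentials and a project/location are configured:

```python
from google import genai
from google.genai import types

client = genai.Client(
    vertexai=True, project='my-project', location='us-central1'  # assumed values
)

config = types.LiveConnectConfig(
    response_modalities=['AUDIO'],
    # AudioTranscriptionConfig currently carries no fields; its mere presence
    # switches transcription on (see _AudioTranscriptionConfig_to_vertex above).
    input_audio_transcription=types.AudioTranscriptionConfig(),
    output_audio_transcription=types.AudioTranscriptionConfig(),
)
```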
@@ -865,17 +1304,24 @@ class AsyncLive(_api_module.BaseModule):
          generation_config=config.get('generation_config'),
          response_modalities=config.get('response_modalities'),
          speech_config=config.get('speech_config'),
+         temperature=config.get('temperature'),
+         top_p=config.get('top_p'),
+         top_k=config.get('top_k'),
+         max_output_tokens=config.get('max_output_tokens'),
+         seed=config.get('seed'),
          system_instruction=system_instruction,
          tools=config.get('tools'),
+         input_audio_transcription=config.get('input_audio_transcription'),
+         output_audio_transcription=config.get('output_audio_transcription'),
      )
    else:
      parameter_model = config

    if self._api_client.api_key:
      api_key = self._api_client.api_key
-     version = self._api_client._http_options['api_version']
+     version = self._api_client._http_options.api_version
      uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
-     headers = self._api_client._http_options['headers']
+     headers = self._api_client._http_options.headers
      request_dict = _common.convert_to_dict(
          self._LiveSetup_to_mldev(
              model=transformed_model,
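The change from `_http_options['api_version']` to `_http_options.api_version` tracks the client's HTTP options becoming a typed object rather than a plain dict. Callers can still pass a dict at construction time; a typed `types.HttpOptions` should work equally well:

```python
from google import genai
from google.genai import types

# Two equivalent ways to pin the API version for the live endpoint.
client_a = genai.Client(http_options={'api_version': 'v1alpha'})
client_b = genai.Client(http_options=types.HttpOptions(api_version='v1alpha'))
```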
@@ -894,12 +1340,12 @@ class AsyncLive(_api_module.BaseModule):
      auth_req = google.auth.transport.requests.Request()
      creds.refresh(auth_req)
      bearer_token = creds.token
-     headers = self._api_client._http_options['headers']
+     headers = self._api_client._http_options.headers
      if headers is not None:
        headers.update({
            'Authorization': 'Bearer {}'.format(bearer_token),
        })
-     version = self._api_client._http_options['api_version']
+     version = self._api_client._http_options.api_version
      uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
      location = self._api_client.location
      project = self._api_client.project
@@ -915,8 +1361,16 @@ class AsyncLive(_api_module.BaseModule):
      )
      request = json.dumps(request_dict)

-     async with connect(uri, additional_headers=headers) as ws:
-       await ws.send(request)
-       logger.info(await ws.recv(decode=False))
+     try:
+       async with connect(uri, additional_headers=headers) as ws:
+         await ws.send(request)
+         logger.info(await ws.recv(decode=False))
+
+         yield AsyncSession(api_client=self._api_client, websocket=ws)
+     except TypeError:
+       # Try with the older websockets API
+       async with connect(uri, extra_headers=headers) as ws:
+         await ws.send(request)
+         logger.info(await ws.recv())

-       yield AsyncSession(api_client=self._api_client, websocket=ws)
+         yield AsyncSession(api_client=self._api_client, websocket=ws)
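This connect-time shim mirrors the `recv` shim earlier in the diff: the modern `websockets` asyncio client takes `additional_headers`, while the legacy client spells it `extra_headers` and raises `TypeError` for the unknown name. The idiom as a standalone sketch:

```python
import contextlib


@contextlib.asynccontextmanager
async def open_ws(connect, uri, headers):
  """Open a websocket under either header-kwarg spelling."""
  try:
    # Modern websockets asyncio client.
    async with connect(uri, additional_headers=headers) as ws:
      yield ws
  except TypeError:
    # Legacy client: same call, older kwarg name.
    async with connect(uri, extra_headers=headers) as ws:
      yield ws
```

As in the diff itself, a `TypeError` raised from inside the `async with` body would also reach the fallback path, so this pattern assumes the error comes from the rejected keyword rather than from user code.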