google-genai 1.7.0__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/_api_client.py +93 -78
- google/genai/_replay_api_client.py +22 -14
- google/genai/_transformers.py +81 -16
- google/genai/batches.py +61 -295
- google/genai/caches.py +546 -526
- google/genai/chats.py +15 -8
- google/genai/client.py +5 -3
- google/genai/errors.py +47 -24
- google/genai/files.py +89 -305
- google/genai/live.py +466 -12
- google/genai/models.py +1992 -2291
- google/genai/operations.py +104 -124
- google/genai/tunings.py +256 -272
- google/genai/types.py +394 -98
- google/genai/version.py +1 -1
- {google_genai-1.7.0.dist-info → google_genai-1.9.0.dist-info}/METADATA +3 -2
- google_genai-1.9.0.dist-info/RECORD +27 -0
- {google_genai-1.7.0.dist-info → google_genai-1.9.0.dist-info}/WHEEL +1 -1
- google_genai-1.7.0.dist-info/RECORD +0 -27
- {google_genai-1.7.0.dist-info → google_genai-1.9.0.dist-info/licenses}/LICENSE +0 -0
- {google_genai-1.7.0.dist-info → google_genai-1.9.0.dist-info}/top_level.txt +0 -0
google/genai/live.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Google LLC
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -21,6 +21,7 @@ import contextlib
|
|
21
21
|
import json
|
22
22
|
import logging
|
23
23
|
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, get_args
|
24
|
+
import warnings
|
24
25
|
|
25
26
|
import google.auth
|
26
27
|
import pydantic
|
@@ -30,7 +31,6 @@ from . import _api_module
|
|
30
31
|
from . import _common
|
31
32
|
from . import _transformers as t
|
32
33
|
from . import client
|
33
|
-
from . import errors
|
34
34
|
from . import types
|
35
35
|
from ._api_client import BaseApiClient
|
36
36
|
from ._common import experimental_warning
|
@@ -65,6 +65,59 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
|
|
65
65
|
)
|
66
66
|
|
67
67
|
|
68
|
+
def _ClientContent_to_mldev(
    api_client: BaseApiClient,
    from_object: types.LiveClientContent,
) -> dict:
  """Serializes a `LiveClientContent` message for the Gemini Developer API.

  The message is dumped to a JSON-compatible dict with None-valued fields
  omitted. If the dump contains a `turns` entry, every turn in it is passed
  through `_Content_to_mldev`.

  Args:
    api_client: The API client; forwarded to `_Content_to_mldev`.
    from_object: The client content message to serialize.

  Returns:
    The serialized client content dict.
  """
  payload = from_object.model_dump(exclude_none=True, mode='json')
  if 'turns' not in payload:
    return payload

  converted_turns = []
  for turn in payload['turns']:
    converted_turns.append(
        _Content_to_mldev(api_client=api_client, from_object=turn)
    )
  payload['turns'] = converted_turns
  return payload
|
79
|
+
|
80
|
+
|
81
|
+
def _ClientContent_to_vertex(
    api_client: BaseApiClient,
    from_object: types.LiveClientContent,
) -> dict:
  """Serializes a `LiveClientContent` message for the Vertex AI API.

  The message is dumped to a JSON-compatible dict with None-valued fields
  omitted. If the dump contains a `turns` entry, every turn in it is passed
  through `_Content_to_vertex`.

  Args:
    api_client: The API client; forwarded to `_Content_to_vertex`.
    from_object: The client content message to serialize.

  Returns:
    The serialized client content dict.
  """
  payload = from_object.model_dump(exclude_none=True, mode='json')
  if 'turns' not in payload:
    return payload

  converted_turns = []
  for turn in payload['turns']:
    converted_turns.append(
        _Content_to_vertex(api_client=api_client, from_object=turn)
    )
  payload['turns'] = converted_turns
  return payload
|
92
|
+
|
93
|
+
|
94
|
+
def _ToolResponse_to_mldev(
    api_client: BaseApiClient,
    from_object: types.LiveClientToolResponse,
) -> dict:
  """Serializes a `LiveClientToolResponse` for the Gemini Developer API.

  Args:
    api_client: The API client (not used by this converter).
    from_object: The tool response message to serialize.

  Returns:
    The JSON-compatible dict dump of `from_object`, with None fields omitted.

  Raises:
    ValueError: If any entry of `function_responses` has no `id`; the message
      is `_FUNCTION_RESPONSE_REQUIRES_ID`.
  """
  payload = from_object.model_dump(exclude_none=True, mode='json')
  responses = payload.get('function_responses', [])
  # This backend path rejects function responses that lack an `id`.
  if any(entry.get('id') is None for entry in responses):
    raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
  return payload
|
103
|
+
|
104
|
+
|
105
|
+
def _ToolResponse_to_vertex(
    api_client: BaseApiClient,
    from_object: types.LiveClientToolResponse,
) -> dict:
  """Serializes a `LiveClientToolResponse` for the Vertex AI API.

  Unlike the Gemini Developer API converter, no `id` validation is done here.

  Args:
    api_client: The API client (unused; kept so the signature matches the
      other converters).
    from_object: The tool response message to serialize.

  Returns:
    The JSON-compatible dict dump of `from_object`, with None fields omitted.
  """
  return from_object.model_dump(exclude_none=True, mode='json')
|
111
|
+
|
112
|
+
|
113
|
+
def _AudioTranscriptionConfig_to_vertex(
    api_client: BaseApiClient,
    from_object: types.AudioTranscriptionConfig,
) -> dict:
  """Serializes an `AudioTranscriptionConfig` for the Vertex AI API.

  No fields of the config are mapped yet, so this always produces a fresh,
  empty dict. NOTE(review): presumably the server treats the presence of the
  (empty) object as the opt-in for transcription; confirm.

  Args:
    api_client: The API client (unused).
    from_object: The transcription config (unused).

  Returns:
    A new empty dict.
  """
  empty_config: dict[str, Any] = dict()
  return empty_config
|
119
|
+
|
120
|
+
|
68
121
|
class AsyncSession:
|
69
122
|
"""AsyncSession. The live module is experimental."""
|
70
123
|
|
@@ -112,6 +165,208 @@ class AsyncSession:
|
|
112
165
|
client_message = self._parse_client_message(input, end_of_turn)
|
113
166
|
await self._ws.send(json.dumps(client_message))
|
114
167
|
|
168
|
+
async def send_client_content(
    self,
    *,
    turns: Optional[
        Union[
            types.Content,
            types.ContentDict,
            list[Union[types.Content, types.ContentDict]]
        ]
    ] = None,
    turn_complete: bool = True,
):
  """Send non-realtime, turn based content to the model.

  There are two ways to send messages to the live API:
  `send_client_content` and `send_realtime_input`.

  `send_client_content` messages are added to the model context **in order**.
  Having a conversation using `send_client_content` messages is roughly
  equivalent to using the `Chat.send_message_stream` method, except that the
  state of the `chat` history is stored on the API server.

  Because of `send_client_content`'s order guarantee, the model cannot
  respond as quickly to `send_client_content` messages as to
  `send_realtime_input` messages. This makes the biggest difference when
  sending objects that have significant preprocessing time (typically images).

  The `send_client_content` message sends a list of `Content` objects,
  which has more options than the `media:Blob` sent by `send_realtime_input`.

  The main use-cases for `send_client_content` over `send_realtime_input` are:

  - Prefilling a conversation context (including sending anything that can't
    be represented as a realtime message), before starting a realtime
    conversation.
  - Conducting a non-realtime conversation, similar to `client.chats`, using
    the Live API.

  Caution: Interleaving `send_client_content` and `send_realtime_input`
    in the same conversation is not recommended and can lead to unexpected
    results.

  Args:
    turns: A `Content` object or list of `Content` objects (or equivalent
      dicts). Plain strings are not accepted. If omitted, only the
      `turn_complete` flag is sent.
    turn_complete: if true (the default) the model will reply immediately. If
      false, the model will wait for you to send additional client_content,
      and will not return until you send `turn_complete=True`.

  Raises:
    ValueError: If `turns` cannot be converted to `types.LiveClientContent`.

  Example:
  ```
  from google import genai
  from google.genai import types

  client = genai.Client(http_options={'api_version': 'v1alpha'})
  async with client.aio.live.connect(
      model=MODEL_NAME,
      config={"response_modalities": ["TEXT"]}
  ) as session:
    await session.send_client_content(
        turns=types.Content(
            role='user',
            parts=[types.Part(text="Hello world!")]))
    async for msg in session.receive():
      if msg.text:
        print(msg.text)
  ```
  """
  client_content = _t_client_content(turns, turn_complete)

  # Each backend has its own serializer for the turns.
  if self._api_client.vertexai:
    client_content_dict = _ClientContent_to_vertex(
        api_client=self._api_client, from_object=client_content
    )
  else:
    client_content_dict = _ClientContent_to_mldev(
        api_client=self._api_client, from_object=client_content
    )

  await self._ws.send(json.dumps({'client_content': client_content_dict}))
|
248
|
+
|
249
|
+
async def send_realtime_input(self, *, media: t.BlobUnion):
  """Send a chunk of realtime media (audio or video frames) to the model.

  Use `send_realtime_input` to stream audio chunks and video frames (images).
  For audio sent this way, the API responds automatically based on voice
  activity detection (VAD).

  `send_realtime_input` favors responsiveness over deterministic ordering:
  audio and video tokens are added to the context as soon as they become
  available.

  Args:
    media: A `Blob`-like object holding the realtime media to send.

  Raises:
    ValueError: If `media` cannot be converted to
      `types.LiveClientRealtimeInput`.

  Example:
  ```
  from pathlib import Path

  from google import genai
  from google.genai import types

  import PIL.Image

  client = genai.Client(http_options= {'api_version': 'v1alpha'})

  async with client.aio.live.connect(
      model=MODEL_NAME,
      config={"response_modalities": ["TEXT"]},
  ) as session:
    await session.send_realtime_input(
        media=PIL.Image.open('image.jpg'))

    audio_bytes = Path('audio.pcm').read_bytes()
    await session.send_realtime_input(
        media=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000'))

    async for msg in session.receive():
      if msg.text is not None:
        print(f'{msg.text}')
  ```
  """
  payload = _t_realtime_input(media).model_dump(exclude_none=True, mode='json')
  message = json.dumps({'realtime_input': payload})
  await self._ws.send(message)
|
297
|
+
|
298
|
+
async def send_tool_response(
    self,
    *,
    function_responses: Union[
        types.FunctionResponseOrDict,
        Sequence[types.FunctionResponseOrDict],
    ],
):
  """Send a tool response to the session.

  Use `send_tool_response` to reply to `LiveServerToolCall` messages
  from the server.

  To set the available tools, use the `config.tools` argument
  when you connect to the session (`client.aio.live.connect`).

  Args:
    function_responses: A `FunctionResponse`-like object or list of
      `FunctionResponse`-like objects.

  Raises:
    ValueError: If `function_responses` is empty or cannot be converted to
      `types.LiveClientToolResponse`, or (Gemini Developer API only, not
      Vertex AI) if a function response has no `id`.

  Example:
  ```
  from google import genai
  from google.genai import types

  client = genai.Client(http_options={'api_version': 'v1alpha'})

  tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
  config = {
      "tools": tools,
      "response_modalities": ['TEXT']
  }

  async with client.aio.live.connect(
      model='gemini-2.0-flash-exp',
      config=config
  ) as session:
    prompt = "Turn on the lights please"
    # `turns` must be a `Content` (or dict / list of them); a bare str
    # is rejected.
    await session.send_client_content(
        turns=types.Content(role='user', parts=[types.Part(text=prompt)]),
        turn_complete=True)

    async for chunk in session.receive():
      if chunk.server_content:
        if chunk.text is not None:
          print(chunk.text)
      elif chunk.tool_call:
        print(chunk.tool_call)
        print('_'*80)
        function_response = types.FunctionResponse(
            name='turn_on_the_lights',
            response={'result': 'ok'},
            id=chunk.tool_call.function_calls[0].id,
        )
        print(function_response)
        await session.send_tool_response(
            function_responses=function_response
        )

        print('_'*80)
  ```
  """
  tool_response = _t_tool_response(function_responses)
  # Only the Gemini Developer API converter enforces function-response ids.
  if self._api_client.vertexai:
    tool_response_dict = _ToolResponse_to_vertex(
        api_client=self._api_client, from_object=tool_response
    )
  else:
    tool_response_dict = _ToolResponse_to_mldev(
        api_client=self._api_client, from_object=tool_response
    )
  await self._ws.send(json.dumps({'tool_response': tool_response_dict}))
|
369
|
+
|
115
370
|
async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
|
116
371
|
"""Receive model responses from the server.
|
117
372
|
|
@@ -207,7 +462,10 @@ class AsyncSession:
|
|
207
462
|
|
208
463
|
async def _receive(self) -> types.LiveServerMessage:
|
209
464
|
parameter_model = types.LiveServerMessage()
|
210
|
-
|
465
|
+
try:
|
466
|
+
raw_response = await self._ws.recv(decode=False)
|
467
|
+
except TypeError:
|
468
|
+
raw_response = await self._ws.recv()
|
211
469
|
if raw_response:
|
212
470
|
try:
|
213
471
|
response = json.loads(raw_response)
|
@@ -258,6 +516,12 @@ class AsyncSession:
|
|
258
516
|
setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
|
259
517
|
if getv(from_object, ['interrupted']) is not None:
|
260
518
|
setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
|
519
|
+
if getv(from_object, ['generationComplete']) is not None:
|
520
|
+
setv(
|
521
|
+
to_object,
|
522
|
+
['generation_complete'],
|
523
|
+
getv(from_object, ['generationComplete']),
|
524
|
+
)
|
261
525
|
return to_object
|
262
526
|
|
263
527
|
def _LiveToolCall_from_mldev(
|
@@ -329,6 +593,25 @@ class AsyncSession:
|
|
329
593
|
)
|
330
594
|
if getv(from_object, ['turnComplete']) is not None:
|
331
595
|
setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
|
596
|
+
if getv(from_object, ['generationComplete']) is not None:
|
597
|
+
setv(
|
598
|
+
to_object,
|
599
|
+
['generation_complete'],
|
600
|
+
getv(from_object, ['generationComplete']),
|
601
|
+
)
|
602
|
+
# Vertex supports transcription.
|
603
|
+
if getv(from_object, ['inputTranscription']) is not None:
|
604
|
+
setv(
|
605
|
+
to_object,
|
606
|
+
['input_transcription'],
|
607
|
+
getv(from_object, ['inputTranscription']),
|
608
|
+
)
|
609
|
+
if getv(from_object, ['outputTranscription']) is not None:
|
610
|
+
setv(
|
611
|
+
to_object,
|
612
|
+
['output_transcription'],
|
613
|
+
getv(from_object, ['outputTranscription']),
|
614
|
+
)
|
332
615
|
if getv(from_object, ['interrupted']) is not None:
|
333
616
|
setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
|
334
617
|
return to_object
|
@@ -669,6 +952,79 @@ class AsyncSession:
|
|
669
952
|
await self._ws.close()
|
670
953
|
|
671
954
|
|
955
|
+
def _t_content_strict(content: types.ContentOrDict):
  """Converts a single `Content` or content dict into `types.Content`.

  Unlike looser transformers, no other input types (e.g. strings) are
  accepted.

  Args:
    content: A `types.Content` instance (returned unchanged) or a dict that is
      validated into one.

  Returns:
    A `types.Content`.

  Raises:
    ValueError: If `content` is neither a `types.Content` nor a dict.
  """
  if isinstance(content, types.Content):
    return content
  if isinstance(content, dict):
    return types.Content.model_validate(content)
  raise ValueError(
      f'Could not convert input (type "{type(content)}") to '
      '`types.Content`'
  )
|
965
|
+
|
966
|
+
|
967
|
+
def _t_contents_strict(
    contents: Union[Sequence[types.ContentOrDict], types.ContentOrDict]):
  """Converts one content item or a sequence of them into a list of `Content`.

  A non-sequence input is treated as a single item. Each item is converted
  with `_t_content_strict`. Note that a `str` is a `Sequence`, so it is
  iterated character by character, and each character is then rejected.

  Args:
    contents: A content item, or a sequence of content items.

  Returns:
    A list of `types.Content`.
  """
  items = contents if isinstance(contents, Sequence) else [contents]
  return [_t_content_strict(item) for item in items]
|
973
|
+
|
974
|
+
|
975
|
+
def _t_client_content(
    turns: Optional[
        Union[Sequence[types.ContentOrDict], types.ContentOrDict]
    ] = None,
    turn_complete: bool = True,
) -> types.LiveClientContent:
  """Builds a `LiveClientContent` message from user-supplied turns.

  Args:
    turns: Content item(s) to send, or None to send only the
      `turn_complete` flag.
    turn_complete: Value for the message's `turn_complete` field.

  Returns:
    The constructed `types.LiveClientContent`.

  Raises:
    ValueError: If `turns` cannot be converted; the original error is chained.
  """
  if turns is not None:
    try:
      strict_turns = _t_contents_strict(contents=turns)
      return types.LiveClientContent(
          turns=strict_turns,
          turn_complete=turn_complete,
      )
    except Exception as e:
      raise ValueError(
          f'Could not convert input (type "{type(turns)}") to '
          '`types.LiveClientContent`'
      ) from e
  return types.LiveClientContent(turn_complete=turn_complete)
|
994
|
+
|
995
|
+
|
996
|
+
def _t_realtime_input(
    media: t.BlobUnion,
) -> types.LiveClientRealtimeInput:
  """Wraps a single media blob in a `LiveClientRealtimeInput` message.

  Args:
    media: A `Blob`-like object; converted with `t.t_blob`.

  Returns:
    A `types.LiveClientRealtimeInput` whose `media_chunks` holds the one blob.

  Raises:
    ValueError: If `media` cannot be converted; the original error is chained.
  """
  try:
    return types.LiveClientRealtimeInput(media_chunks=[t.t_blob(blob=media)])
  except Exception as e:
    # Report the type of the argument actually passed. The previous code
    # used `type(input)`, which is the builtin `input` function.
    raise ValueError(
        f'Could not convert input (type "{type(media)}") to '
        '`types.LiveClientRealtimeInput`'
    ) from e
|
1006
|
+
|
1007
|
+
|
1008
|
+
def _t_tool_response(
    input: Union[
        types.FunctionResponseOrDict,
        Sequence[types.FunctionResponseOrDict],
    ],
) -> types.LiveClientToolResponse:
  """Builds a `LiveClientToolResponse` from one or more function responses.

  Args:
    input: A `FunctionResponse`-like object or a sequence of them. (The name
      shadows the builtin but is kept for interface compatibility.)

  Returns:
    The constructed `types.LiveClientToolResponse`.

  Raises:
    ValueError: If `input` is falsy (e.g. None or an empty list), or cannot be
      converted; conversion errors are chained.
  """
  if not input:
    raise ValueError(f'A tool response is required, got: \n{input}')

  try:
    converted = t.t_function_responses(function_responses=input)
    tool_response = types.LiveClientToolResponse(function_responses=converted)
  except Exception as e:
    raise ValueError(
        f'Could not convert input (type "{type(input)}") to '
        '`types.LiveClientToolResponse`'
    ) from e
  return tool_response
|
1026
|
+
|
1027
|
+
|
672
1028
|
class AsyncLive(_api_module.BaseModule):
|
673
1029
|
"""AsyncLive. The live module is experimental."""
|
674
1030
|
|
@@ -715,7 +1071,39 @@ class AsyncLive(_api_module.BaseModule):
|
|
715
1071
|
to_object,
|
716
1072
|
)
|
717
1073
|
}
|
718
|
-
|
1074
|
+
if getv(config, ['temperature']) is not None:
|
1075
|
+
if getv(to_object, ['generationConfig']) is not None:
|
1076
|
+
to_object['generationConfig']['temperature'] = getv(
|
1077
|
+
config, ['temperature']
|
1078
|
+
)
|
1079
|
+
else:
|
1080
|
+
to_object['generationConfig'] = {
|
1081
|
+
'temperature': getv(config, ['temperature'])
|
1082
|
+
}
|
1083
|
+
if getv(config, ['top_p']) is not None:
|
1084
|
+
if getv(to_object, ['generationConfig']) is not None:
|
1085
|
+
to_object['generationConfig']['topP'] = getv(config, ['top_p'])
|
1086
|
+
else:
|
1087
|
+
to_object['generationConfig'] = {'topP': getv(config, ['top_p'])}
|
1088
|
+
if getv(config, ['top_k']) is not None:
|
1089
|
+
if getv(to_object, ['generationConfig']) is not None:
|
1090
|
+
to_object['generationConfig']['topK'] = getv(config, ['top_k'])
|
1091
|
+
else:
|
1092
|
+
to_object['generationConfig'] = {'topK': getv(config, ['top_k'])}
|
1093
|
+
if getv(config, ['max_output_tokens']) is not None:
|
1094
|
+
if getv(to_object, ['generationConfig']) is not None:
|
1095
|
+
to_object['generationConfig']['maxOutputTokens'] = getv(
|
1096
|
+
config, ['max_output_tokens']
|
1097
|
+
)
|
1098
|
+
else:
|
1099
|
+
to_object['generationConfig'] = {
|
1100
|
+
'maxOutputTokens': getv(config, ['max_output_tokens'])
|
1101
|
+
}
|
1102
|
+
if getv(config, ['seed']) is not None:
|
1103
|
+
if getv(to_object, ['generationConfig']) is not None:
|
1104
|
+
to_object['generationConfig']['seed'] = getv(config, ['seed'])
|
1105
|
+
else:
|
1106
|
+
to_object['generationConfig'] = {'seed': getv(config, ['seed'])}
|
719
1107
|
if getv(config, ['system_instruction']) is not None:
|
720
1108
|
setv(
|
721
1109
|
to_object,
|
@@ -796,6 +1184,39 @@ class AsyncLive(_api_module.BaseModule):
|
|
796
1184
|
to_object,
|
797
1185
|
)
|
798
1186
|
}
|
1187
|
+
if getv(config, ['temperature']) is not None:
|
1188
|
+
if getv(to_object, ['generationConfig']) is not None:
|
1189
|
+
to_object['generationConfig']['temperature'] = getv(
|
1190
|
+
config, ['temperature']
|
1191
|
+
)
|
1192
|
+
else:
|
1193
|
+
to_object['generationConfig'] = {
|
1194
|
+
'temperature': getv(config, ['temperature'])
|
1195
|
+
}
|
1196
|
+
if getv(config, ['top_p']) is not None:
|
1197
|
+
if getv(to_object, ['generationConfig']) is not None:
|
1198
|
+
to_object['generationConfig']['topP'] = getv(config, ['top_p'])
|
1199
|
+
else:
|
1200
|
+
to_object['generationConfig'] = {'topP': getv(config, ['top_p'])}
|
1201
|
+
if getv(config, ['top_k']) is not None:
|
1202
|
+
if getv(to_object, ['generationConfig']) is not None:
|
1203
|
+
to_object['generationConfig']['topK'] = getv(config, ['top_k'])
|
1204
|
+
else:
|
1205
|
+
to_object['generationConfig'] = {'topK': getv(config, ['top_k'])}
|
1206
|
+
if getv(config, ['max_output_tokens']) is not None:
|
1207
|
+
if getv(to_object, ['generationConfig']) is not None:
|
1208
|
+
to_object['generationConfig']['maxOutputTokens'] = getv(
|
1209
|
+
config, ['max_output_tokens']
|
1210
|
+
)
|
1211
|
+
else:
|
1212
|
+
to_object['generationConfig'] = {
|
1213
|
+
'maxOutputTokens': getv(config, ['max_output_tokens'])
|
1214
|
+
}
|
1215
|
+
if getv(config, ['seed']) is not None:
|
1216
|
+
if getv(to_object, ['generationConfig']) is not None:
|
1217
|
+
to_object['generationConfig']['seed'] = getv(config, ['seed'])
|
1218
|
+
else:
|
1219
|
+
to_object['generationConfig'] = {'seed': getv(config, ['seed'])}
|
799
1220
|
if getv(config, ['system_instruction']) is not None:
|
800
1221
|
setv(
|
801
1222
|
to_object,
|
@@ -819,6 +1240,24 @@ class AsyncLive(_api_module.BaseModule):
|
|
819
1240
|
for item in t.t_tools(self._api_client, getv(config, ['tools']))
|
820
1241
|
],
|
821
1242
|
)
|
1243
|
+
if getv(config, ['input_audio_transcription']) is not None:
|
1244
|
+
setv(
|
1245
|
+
to_object,
|
1246
|
+
['inputAudioTranscription'],
|
1247
|
+
_AudioTranscriptionConfig_to_vertex(
|
1248
|
+
self._api_client,
|
1249
|
+
getv(config, ['input_audio_transcription']),
|
1250
|
+
),
|
1251
|
+
)
|
1252
|
+
if getv(config, ['output_audio_transcription']) is not None:
|
1253
|
+
setv(
|
1254
|
+
to_object,
|
1255
|
+
['outputAudioTranscription'],
|
1256
|
+
_AudioTranscriptionConfig_to_vertex(
|
1257
|
+
self._api_client,
|
1258
|
+
getv(config, ['output_audio_transcription']),
|
1259
|
+
),
|
1260
|
+
)
|
822
1261
|
|
823
1262
|
return_value = {'setup': {'model': model}}
|
824
1263
|
return_value['setup'].update(to_object)
|
@@ -865,17 +1304,24 @@ class AsyncLive(_api_module.BaseModule):
|
|
865
1304
|
generation_config=config.get('generation_config'),
|
866
1305
|
response_modalities=config.get('response_modalities'),
|
867
1306
|
speech_config=config.get('speech_config'),
|
1307
|
+
temperature=config.get('temperature'),
|
1308
|
+
top_p=config.get('top_p'),
|
1309
|
+
top_k=config.get('top_k'),
|
1310
|
+
max_output_tokens=config.get('max_output_tokens'),
|
1311
|
+
seed=config.get('seed'),
|
868
1312
|
system_instruction=system_instruction,
|
869
1313
|
tools=config.get('tools'),
|
1314
|
+
input_audio_transcription=config.get('input_audio_transcription'),
|
1315
|
+
output_audio_transcription=config.get('output_audio_transcription'),
|
870
1316
|
)
|
871
1317
|
else:
|
872
1318
|
parameter_model = config
|
873
1319
|
|
874
1320
|
if self._api_client.api_key:
|
875
1321
|
api_key = self._api_client.api_key
|
876
|
-
version = self._api_client._http_options
|
1322
|
+
version = self._api_client._http_options.api_version
|
877
1323
|
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
|
878
|
-
headers = self._api_client._http_options
|
1324
|
+
headers = self._api_client._http_options.headers
|
879
1325
|
request_dict = _common.convert_to_dict(
|
880
1326
|
self._LiveSetup_to_mldev(
|
881
1327
|
model=transformed_model,
|
@@ -894,12 +1340,12 @@ class AsyncLive(_api_module.BaseModule):
|
|
894
1340
|
auth_req = google.auth.transport.requests.Request()
|
895
1341
|
creds.refresh(auth_req)
|
896
1342
|
bearer_token = creds.token
|
897
|
-
headers = self._api_client._http_options
|
1343
|
+
headers = self._api_client._http_options.headers
|
898
1344
|
if headers is not None:
|
899
1345
|
headers.update({
|
900
1346
|
'Authorization': 'Bearer {}'.format(bearer_token),
|
901
1347
|
})
|
902
|
-
version = self._api_client._http_options
|
1348
|
+
version = self._api_client._http_options.api_version
|
903
1349
|
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
|
904
1350
|
location = self._api_client.location
|
905
1351
|
project = self._api_client.project
|
@@ -915,8 +1361,16 @@ class AsyncLive(_api_module.BaseModule):
|
|
915
1361
|
)
|
916
1362
|
request = json.dumps(request_dict)
|
917
1363
|
|
918
|
-
|
919
|
-
|
920
|
-
|
1364
|
+
try:
|
1365
|
+
async with connect(uri, additional_headers=headers) as ws:
|
1366
|
+
await ws.send(request)
|
1367
|
+
logger.info(await ws.recv(decode=False))
|
1368
|
+
|
1369
|
+
yield AsyncSession(api_client=self._api_client, websocket=ws)
|
1370
|
+
except TypeError:
|
1371
|
+
# Try with the older websockets API
|
1372
|
+
async with connect(uri, extra_headers=headers) as ws:
|
1373
|
+
await ws.send(request)
|
1374
|
+
logger.info(await ws.recv())
|
921
1375
|
|
922
|
-
|
1376
|
+
yield AsyncSession(api_client=self._api_client, websocket=ws)
|