google-genai 1.8.0__py3-none-any.whl → 1.10.0__py3-none-any.whl
- google/genai/_api_client.py +117 -28
- google/genai/_automatic_function_calling_util.py +1 -1
- google/genai/_extra_utils.py +1 -1
- google/genai/_replay_api_client.py +32 -8
- google/genai/_transformers.py +101 -61
- google/genai/batches.py +1 -1
- google/genai/caches.py +1 -1
- google/genai/errors.py +1 -1
- google/genai/files.py +23 -7
- google/genai/live.py +996 -43
- google/genai/models.py +24 -10
- google/genai/operations.py +18 -10
- google/genai/tunings.py +1 -4
- google/genai/types.py +742 -81
- google/genai/version.py +1 -1
- {google_genai-1.8.0.dist-info → google_genai-1.10.0.dist-info}/METADATA +1 -1
- google_genai-1.10.0.dist-info/RECORD +27 -0
- google_genai-1.8.0.dist-info/RECORD +0 -27
- {google_genai-1.8.0.dist-info → google_genai-1.10.0.dist-info}/WHEEL +0 -0
- {google_genai-1.8.0.dist-info → google_genai-1.10.0.dist-info}/licenses/LICENSE +0 -0
- {google_genai-1.8.0.dist-info → google_genai-1.10.0.dist-info}/top_level.txt +0 -0
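The bulk of the change is in `google/genai/live.py` and `google/genai/types.py`: `AsyncSession` gains the more specific `send_client_content`, `send_realtime_input`, and `send_tool_response` methods (with `send` and `start_stream` now deprecated), the server-message parsers learn about `go_away`, `session_resumption_update`, transcription, and usage-metadata fields, and `LiveConnectConfig` accepts sampling fields directly. A minimal sketch of the new turn-based flow, adapted from the docstrings added in this diff (the model name and the `v1alpha` `api_version` are copied from those docstrings and may need adjusting for your setup):

```
import asyncio

from google import genai
from google.genai import types


async def main():
  # Reads the API key from the environment, as usual for genai.Client().
  client = genai.Client(http_options={'api_version': 'v1alpha'})
  async with client.aio.live.connect(
      model='gemini-2.0-flash-exp',
      config={'response_modalities': ['TEXT']},
  ) as session:
    # New in 1.10.0: ordered, turn-based input (replaces the deprecated session.send).
    await session.send_client_content(
        turns=types.Content(role='user', parts=[types.Part(text='Hello world!')]),
        turn_complete=True,
    )
    async for msg in session.receive():
      if msg.text:
        print(msg.text)


asyncio.run(main())
```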
google/genai/live.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,14 +13,15 @@
 # limitations under the License.
 #
 
-"""Live client.
+"""[Preview] Live API client."""
 
 import asyncio
 import base64
 import contextlib
 import json
 import logging
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, get_args
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, cast, get_args
+import warnings
 
 import google.auth
 import pydantic
@@ -30,7 +31,6 @@ from . import _api_module
 from . import _common
 from . import _transformers as t
 from . import client
-from . import errors
 from . import types
 from ._api_client import BaseApiClient
 from ._common import experimental_warning
@@ -65,8 +65,68 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
 )
 
 
+def _ClientContent_to_mldev(
+    api_client: BaseApiClient,
+    from_object: types.LiveClientContent,
+) -> dict:
+  client_content = from_object.model_dump(exclude_none=True, mode='json')
+  if 'turns' in client_content:
+    client_content['turns'] = [
+        _Content_to_mldev(api_client=api_client, from_object=item)
+        for item in client_content['turns']
+    ]
+  return client_content
+
+
+def _ClientContent_to_vertex(
+    api_client: BaseApiClient,
+    from_object: types.LiveClientContent,
+) -> dict:
+  client_content = from_object.model_dump(exclude_none=True, mode='json')
+  if 'turns' in client_content:
+    client_content['turns'] = [
+        _Content_to_vertex(api_client=api_client, from_object=item)
+        for item in client_content['turns']
+    ]
+  return client_content
+
+
+def _ToolResponse_to_mldev(
+    api_client: BaseApiClient,
+    from_object: types.LiveClientToolResponse,
+) -> dict:
+  tool_response = from_object.model_dump(exclude_none=True, mode='json')
+  for response in tool_response.get('function_responses', []):
+    if response.get('id') is None:
+      raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
+  return tool_response
+
+
+def _ToolResponse_to_vertex(
+    api_client: BaseApiClient,
+    from_object: types.LiveClientToolResponse,
+) -> dict:
+  tool_response = from_object.model_dump(exclude_none=True, mode='json')
+  return tool_response
+
+def _AudioTranscriptionConfig_to_mldev(
+    api_client: BaseApiClient,
+    from_object: types.AudioTranscriptionConfig,
+) -> dict:
+  audio_transcription: dict[str, Any] = {}
+  return audio_transcription
+
+
+def _AudioTranscriptionConfig_to_vertex(
+    api_client: BaseApiClient,
+    from_object: types.AudioTranscriptionConfig,
+) -> dict:
+  audio_transcription: dict[str, Any] = {}
+  return audio_transcription
+
+
 class AsyncSession:
-  """AsyncSession.
+  """[Preview] AsyncSession."""
 
   def __init__(
       self, api_client: client.BaseApiClient, websocket: ClientConnection
@@ -90,7 +150,12 @@ class AsyncSession:
       ] = None,
       end_of_turn: Optional[bool] = False,
   ):
-    """Send input to the model.
+    """[Deprecated] Send input to the model.
+
+    > **Warning**: This method is deprecated and will be removed in a future
+    version (not before Q3 2025). Please use one of the more specific methods:
+    `send_client_content`, `send_realtime_input`, or `send_tool_response`
+    instead.
 
     The method will send the input request to the server.
 
@@ -109,9 +174,219 @@ class AsyncSession:
       async for message in session.receive():
         print(message)
     """
+    warnings.warn(
+        'The `session.send` method is deprecated and will be removed in a '
+        'future version (not before Q3 2025).\n'
+        'Please use one of the more specific methods: `send_client_content`, '
+        '`send_realtime_input`, or `send_tool_response` instead.',
+        DeprecationWarning,
+        stacklevel=2,
+    )
     client_message = self._parse_client_message(input, end_of_turn)
     await self._ws.send(json.dumps(client_message))
 
+  async def send_client_content(
+      self,
+      *,
+      turns: Optional[
+          Union[
+              types.Content,
+              types.ContentDict,
+              list[Union[types.Content, types.ContentDict]]
+          ]
+      ] = None,
+      turn_complete: bool = True,
+  ):
+    """Send non-realtime, turn based content to the model.
+
+    There are two ways to send messages to the live API:
+    `send_client_content` and `send_realtime_input`.
+
+    `send_client_content` messages are added to the model context **in order**.
+    Having a conversation using `send_client_content` messages is roughly
+    equivalent to using the `Chat.send_message_stream` method, except that the
+    state of the `chat` history is stored on the API server.
+
+    Because of `send_client_content`'s order guarantee, the model cannot
+    respond as quickly to `send_client_content` messages as to
+    `send_realtime_input` messages. This makes the biggest difference when
+    sending objects that have significant preprocessing time (typically images).
+
+    The `send_client_content` message sends a list of `Content` objects,
+    which has more options than the `media:Blob` sent by `send_realtime_input`.
+
+    The main use-cases for `send_client_content` over `send_realtime_input` are:
+
+    - Prefilling a conversation context (including sending anything that can't
+      be represented as a realtime message), before starting a realtime
+      conversation.
+    - Conducting a non-realtime conversation, similar to `client.chat`, using
+      the live api.
+
+    Caution: Interleaving `send_client_content` and `send_realtime_input`
+    in the same conversation is not recommended and can lead to unexpected
+    results.
+
+    Args:
+      turns: A `Content` object or list of `Content` objects (or equivalent
+        dicts).
+      turn_complete: if true (the default) the model will reply immediately. If
+        false, the model will wait for you to send additional client_content,
+        and will not return until you send `turn_complete=True`.
+
+    Example:
+    ```
+      import google.genai
+      from google.genai import types
+
+      client = genai.Client(http_options={'api_version': 'v1alpha'})
+      async with client.aio.live.connect(
+          model=MODEL_NAME,
+          config={"response_modalities": ["TEXT"]}
+      ) as session:
+        await session.send_client_content(
+            turns=types.Content(
+                role='user',
+                parts=[types.Part(text="Hello world!")]))
+        async for msg in session.receive():
+          if msg.text:
+            print(msg.text)
+    ```
+    """
+    client_content = _t_client_content(turns, turn_complete)
+
+    if self._api_client.vertexai:
+      client_content_dict = _ClientContent_to_vertex(
+          api_client=self._api_client, from_object=client_content
+      )
+    else:
+      client_content_dict = _ClientContent_to_mldev(
+          api_client=self._api_client, from_object=client_content
+      )
+
+    await self._ws.send(json.dumps({'client_content': client_content_dict}))
+
+  async def send_realtime_input(self, *, media: t.BlobUnion):
+    """Send realtime media chunks to the model.
+
+    Use `send_realtime_input` for realtime audio chunks and video
+    frames(images).
+
+    With `send_realtime_input` the api will respond to audio automatically
+    based on voice activity detection (VAD).
+
+    `send_realtime_input` is optimized for responsivness at the expense of
+    deterministic ordering. Audio and video tokens are added to the
+    context when they become available.
+
+    Args:
+      media: A `Blob`-like object, the realtime media to send.
+
+    Example:
+    ```
+      from pathlib import Path
+
+      from google import genai
+      from google.genai import types
+
+      import PIL.Image
+
+      client = genai.Client(http_options= {'api_version': 'v1alpha'})
+
+      async with client.aio.live.connect(
+          model=MODEL_NAME,
+          config={"response_modalities": ["TEXT"]},
+      ) as session:
+        await session.send_realtime_input(
+            media=PIL.Image.open('image.jpg'))
+
+        audio_bytes = Path('audio.pcm').read_bytes()
+        await session.send_realtime_input(
+            media=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000'))
+
+        async for msg in session.receive():
+          if msg.text is not None:
+            print(f'{msg.text}')
+    ```
+    """
+    realtime_input = _t_realtime_input(media)
+    realtime_input_dict = realtime_input.model_dump(
+        exclude_none=True, mode='json'
+    )
+    await self._ws.send(json.dumps({'realtime_input': realtime_input_dict}))
+
+  async def send_tool_response(
+      self,
+      *,
+      function_responses: Union[
+          types.FunctionResponseOrDict,
+          Sequence[types.FunctionResponseOrDict],
+      ],
+  ):
+    """Send a tool response to the session.
+
+    Use `send_tool_response` to reply to `LiveServerToolCall` messages
+    from the server.
+
+    To set the available tools, use the `config.tools` argument
+    when you connect to the session (`client.live.connect`).
+
+    Args:
+      function_responses: A `FunctionResponse`-like object or list of
+        `FunctionResponse`-like objects.
+
+    Example:
+    ```
+      from google import genai
+      from google.genai import types
+
+      client = genai.Client(http_options={'api_version': 'v1alpha'})
+
+      tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
+      config = {
+          "tools": tools,
+          "response_modalities": ['TEXT']
+      }
+
+      async with client.aio.live.connect(
+          model='gemini-2.0-flash-exp',
+          config=config
+      ) as session:
+        prompt = "Turn on the lights please"
+        await session.send_client_content(
+            turns=prompt,
+            turn_complete=True)
+
+        async for chunk in session.receive():
+          if chunk.server_content:
+            if chunk.text is not None:
+              print(chunk.text)
+          elif chunk.tool_call:
+            print(chunk.tool_call)
+            print('_'*80)
+            function_response=types.FunctionResponse(
+                name='turn_on_the_lights',
+                response={'result': 'ok'},
+                id=chunk.tool_call.function_calls[0].id,
+            )
+            print(function_response)
+            await session.send_tool_response(
+                function_responses=function_response
+            )
+
+            print('_'*80)
+    """
+    tool_response = _t_tool_response(function_responses)
+    if self._api_client.vertexai:
+      tool_response_dict = _ToolResponse_to_vertex(
+          api_client=self._api_client, from_object=tool_response
+      )
+    else:
+      tool_response_dict = _ToolResponse_to_mldev(
+          api_client=self._api_client, from_object=tool_response
+      )
+    await self._ws.send(json.dumps({'tool_response': tool_response_dict}))
+
   async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
     """Receive model responses from the server.
 
@@ -120,8 +395,6 @@ class AsyncSession:
     is function call, user must call `send` with the function response to
     continue the turn.
 
-    The live module is experimental.
-
     Yields:
       The model responses from the server.
 
@@ -146,15 +419,18 @@ class AsyncSession:
   async def start_stream(
       self, *, stream: AsyncIterator[bytes], mime_type: str
   ) -> AsyncIterator[types.LiveServerMessage]:
-    """
+    """[Deprecated] Start a live session from a data stream.
+
+    > **Warning**: This method is deprecated and will be removed in a future
+    version (not before Q2 2025). Please use one of the more specific methods:
+    `send_client_content`, `send_realtime_input`, or `send_tool_response`
+    instead.
 
     The interaction terminates when the input stream is complete.
     This method will start two async tasks. One task will be used to send the
     input stream to the model and the other task will be used to receive the
     responses from the model.
 
-    The live module is experimental.
-
     Args:
       stream: An iterator that yields the model response.
       mime_type: The MIME type of the data in the stream.
@@ -177,6 +453,13 @@ class AsyncSession:
           mime_type = 'audio/pcm'):
         play_audio_chunk(audio.data)
     """
+    warnings.warn(
+        'Setting `AsyncSession.start_stream` is deprecated, '
+        'and will be removed in a future release (not before Q3 2025). '
+        'Please use the `receive`, and `send_realtime_input`, methods instead.',
+        DeprecationWarning,
+        stacklevel=4,
+    )
     stop_event = asyncio.Event()
     # Start the send loop. When stream is complete stop_event is set.
     asyncio.create_task(self._send_loop(stream, mime_type, stop_event))
@@ -207,7 +490,10 @@ class AsyncSession:
 
   async def _receive(self) -> types.LiveServerMessage:
     parameter_model = types.LiveServerMessage()
-
+    try:
+      raw_response = await self._ws.recv(decode=False)
+    except TypeError:
+      raw_response = await self._ws.recv()  # type: ignore[assignment]
     if raw_response:
       try:
         response = json.loads(raw_response)
@@ -215,6 +501,7 @@ class AsyncSession:
         raise ValueError(f'Failed to parse response: {raw_response!r}')
     else:
       response = {}
+
     if self._api_client.vertexai:
       response_dict = self._LiveServerMessage_from_vertex(response)
     else:
@@ -256,6 +543,24 @@ class AsyncSession:
     )
     if getv(from_object, ['turnComplete']) is not None:
       setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
+    if getv(from_object, ['generationComplete']) is not None:
+      setv(
+          to_object,
+          ['generation_complete'],
+          getv(from_object, ['generationComplete']),
+      )
+    if getv(from_object, ['inputTranscription']) is not None:
+      setv(
+          to_object,
+          ['input_transcription'],
+          getv(from_object, ['inputTranscription']),
+      )
+    if getv(from_object, ['outputTranscription']) is not None:
+      setv(
+          to_object,
+          ['output_transcription'],
+          getv(from_object, ['outputTranscription']),
+      )
     if getv(from_object, ['interrupted']) is not None:
       setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
     return to_object
@@ -286,6 +591,128 @@ class AsyncSession:
     )
     return to_object
 
+  def _LiveServerGoAway_from_mldev(
+      self,
+      from_object: Union[dict, object],
+      parent_object: Optional[dict] = None,
+  ) -> dict:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['timeLeft']) is not None:
+      setv(to_object, ['time_left'], getv(from_object, ['timeLeft']))
+
+    return to_object
+
+  def _LiveServerSessionResumptionUpdate_from_mldev(
+      self,
+      from_object: Union[dict, object],
+      parent_object: Optional[dict] = None,
+  ) -> dict:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['newHandle']) is not None:
+      setv(to_object, ['new_handle'], getv(from_object, ['newHandle']))
+
+    if getv(from_object, ['resumable']) is not None:
+      setv(to_object, ['resumable'], getv(from_object, ['resumable']))
+
+    if getv(from_object, ['lastConsumedClientMessageIndex']) is not None:
+      setv(
+          to_object,
+          ['last_consumed_client_message_index'],
+          getv(from_object, ['lastConsumedClientMessageIndex']),
+      )
+
+    return to_object
+
+  def _ModalityTokenCount_from_mldev(
+      self,
+      from_object: Union[dict, object],
+  ) -> Dict[str, Any]:
+    to_object: Dict[str, Any] = {}
+    if getv(from_object, ['modality']) is not None:
+      setv(to_object, ['modality'], getv(from_object, ['modality']))
+    if getv(from_object, ['tokenCount']) is not None:
+      setv(to_object, ['token_count'], getv(from_object, ['tokenCount']))
+    return to_object
+
+  def _UsageMetadata_from_mldev(
+      self,
+      from_object: Union[dict, object],
+  ) -> Dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['promptTokenCount']) is not None:
+      setv(
+          to_object,
+          ['prompt_token_count'],
+          getv(from_object, ['promptTokenCount']),
+      )
+    if getv(from_object, ['cachedContentTokenCount']) is not None:
+      setv(
+          to_object,
+          ['cached_content_token_count'],
+          getv(from_object, ['cachedContentTokenCount']),
+      )
+    if getv(from_object, ['responseTokenCount']) is not None:
+      setv(
+          to_object,
+          ['response_token_count'],
+          getv(from_object, ['responseTokenCount']),
+      )
+    if getv(from_object, ['toolUsePromptTokenCount']) is not None:
+      setv(
+          to_object,
+          ['tool_use_prompt_token_count'],
+          getv(from_object, ['toolUsePromptTokenCount']),
+      )
+    if getv(from_object, ['thoughtsTokenCount']) is not None:
+      setv(
+          to_object,
+          ['thoughts_token_count'],
+          getv(from_object, ['thoughtsTokenCount']),
+      )
+    if getv(from_object, ['totalTokenCount']) is not None:
+      setv(
+          to_object,
+          ['total_token_count'],
+          getv(from_object, ['totalTokenCount']),
+      )
+    if getv(from_object, ['promptTokensDetails']) is not None:
+      setv(
+          to_object,
+          ['prompt_tokens_details'],
+          [
+              self._ModalityTokenCount_from_mldev(item)
+              for item in getv(from_object, ['promptTokensDetails'])
+          ],
+      )
+    if getv(from_object, ['cacheTokensDetails']) is not None:
+      setv(
+          to_object,
+          ['cache_tokens_details'],
+          [
+              self._ModalityTokenCount_from_mldev(item)
+              for item in getv(from_object, ['cacheTokensDetails'])
+          ],
+      )
+    if getv(from_object, ['responseTokensDetails']) is not None:
+      setv(
+          to_object,
+          ['response_tokens_details'],
+          [
+              self._ModalityTokenCount_from_mldev(item)
+              for item in getv(from_object, ['responseTokensDetails'])
+          ],
+      )
+    if getv(from_object, ['toolUsePromptTokensDetails']) is not None:
+      setv(
+          to_object,
+          ['tool_use_prompt_tokens_details'],
+          [
+              self._ModalityTokenCount_from_mldev(item)
+              for item in getv(from_object, ['toolUsePromptTokensDetails'])
+          ],
+      )
+    return to_object
+
   def _LiveServerMessage_from_mldev(
       self,
       from_object: Union[dict, object],
@@ -311,6 +738,34 @@ class AsyncSession:
         ['tool_call_cancellation'],
         getv(from_object, ['toolCallCancellation']),
     )
+
+    if getv(from_object, ['goAway']) is not None:
+      setv(
+          to_object,
+          ['go_away'],
+          self._LiveServerGoAway_from_mldev(
+              getv(from_object, ['goAway']), to_object
+          ),
+      )
+
+    if getv(from_object, ['sessionResumptionUpdate']) is not None:
+      setv(
+          to_object,
+          ['session_resumption_update'],
+          self._LiveServerSessionResumptionUpdate_from_mldev(
+              getv(from_object, ['sessionResumptionUpdate']),
+              to_object,
+          ),
+      )
+
+    return to_object
+
+    if getv(from_object, ['usageMetadata']) is not None:
+      setv(
+          to_object,
+          ['usage_metadata'],
+          self._UsageMetadata_from_mldev(getv(from_object, ['usageMetadata'])),
+      )
     return to_object
 
   def _LiveServerContent_from_vertex(
@@ -329,10 +784,155 @@ class AsyncSession:
     )
     if getv(from_object, ['turnComplete']) is not None:
       setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
+    if getv(from_object, ['generationComplete']) is not None:
+      setv(
+          to_object,
+          ['generation_complete'],
+          getv(from_object, ['generationComplete']),
+      )
+    if getv(from_object, ['inputTranscription']) is not None:
+      setv(
+          to_object,
+          ['input_transcription'],
+          getv(from_object, ['inputTranscription']),
+      )
+    if getv(from_object, ['outputTranscription']) is not None:
+      setv(
+          to_object,
+          ['output_transcription'],
+          getv(from_object, ['outputTranscription']),
+      )
     if getv(from_object, ['interrupted']) is not None:
       setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
     return to_object
 
+  def _LiveServerGoAway_from_vertex(
+      self,
+      from_object: Union[dict, object],
+  ) -> dict:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['timeLeft']) is not None:
+      setv(to_object, ['time_left'], getv(from_object, ['timeLeft']))
+
+    return to_object
+
+  def _LiveServerSessionResumptionUpdate_from_vertex(
+      self,
+      from_object: Union[dict, object],
+  ) -> dict:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['newHandle']) is not None:
+      setv(to_object, ['new_handle'], getv(from_object, ['newHandle']))
+
+    if getv(from_object, ['resumable']) is not None:
+      setv(to_object, ['resumable'], getv(from_object, ['resumable']))
+
+    if getv(from_object, ['lastConsumedClientMessageIndex']) is not None:
+      setv(
+          to_object,
+          ['last_consumed_client_message_index'],
+          getv(from_object, ['lastConsumedClientMessageIndex']),
+      )
+
+    return to_object
+
+
+  def _ModalityTokenCount_from_vertex(
+      self,
+      from_object: Union[dict, object],
+  ) -> Dict[str, Any]:
+    to_object: Dict[str, Any] = {}
+    if getv(from_object, ['modality']) is not None:
+      setv(to_object, ['modality'], getv(from_object, ['modality']))
+    if getv(from_object, ['tokenCount']) is not None:
+      setv(to_object, ['token_count'], getv(from_object, ['tokenCount']))
+    return to_object
+
+  def _UsageMetadata_from_vertex(
+      self,
+      from_object: Union[dict, object],
+  ) -> Dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['promptTokenCount']) is not None:
+      setv(
+          to_object,
+          ['prompt_token_count'],
+          getv(from_object, ['promptTokenCount']),
+      )
+    if getv(from_object, ['cachedContentTokenCount']) is not None:
+      setv(
+          to_object,
+          ['cached_content_token_count'],
+          getv(from_object, ['cachedContentTokenCount']),
+      )
+    if getv(from_object, ['candidatesTokenCount']) is not None:
+      setv(
+          to_object,
+          ['response_token_count'],
+          getv(from_object, ['candidatesTokenCount']),
+      )
+    if getv(from_object, ['toolUsePromptTokenCount']) is not None:
+      setv(
+          to_object,
+          ['tool_use_prompt_token_count'],
+          getv(from_object, ['toolUsePromptTokenCount']),
+      )
+    if getv(from_object, ['thoughtsTokenCount']) is not None:
+      setv(
+          to_object,
+          ['thoughts_token_count'],
+          getv(from_object, ['thoughtsTokenCount']),
+      )
+    if getv(from_object, ['totalTokenCount']) is not None:
+      setv(
+          to_object,
+          ['total_token_count'],
+          getv(from_object, ['totalTokenCount']),
+      )
+    if getv(from_object, ['promptTokensDetails']) is not None:
+      setv(
+          to_object,
+          ['prompt_tokens_details'],
+          [
+              self._ModalityTokenCount_from_vertex(item)
+              for item in getv(from_object, ['promptTokensDetails'])
+          ],
+      )
+    if getv(from_object, ['cacheTokensDetails']) is not None:
+      setv(
+          to_object,
+          ['cache_tokens_details'],
+          [
+              self._ModalityTokenCount_from_vertex(item)
+              for item in getv(from_object, ['cacheTokensDetails'])
+          ],
+      )
+    if getv(from_object, ['toolUsePromptTokensDetails']) is not None:
+      setv(
+          to_object,
+          ['tool_use_prompt_tokens_details'],
+          [
+              self._ModalityTokenCount_from_vertex(item)
+              for item in getv(from_object, ['toolUsePromptTokensDetails'])
+          ],
+      )
+    if getv(from_object, ['candidatesTokensDetails']) is not None:
+      setv(
+          to_object,
+          ['response_tokens_details'],
+          [
+              self._ModalityTokenCount_from_vertex(item)
+              for item in getv(from_object, ['candidatesTokensDetails'])
+          ],
+      )
+    if getv(from_object, ['trafficType']) is not None:
+      setv(
+          to_object,
+          ['traffic_type'],
+          getv(from_object, ['trafficType']),
+      )
+    return to_object
+
   def _LiveServerMessage_from_vertex(
       self,
       from_object: Union[dict, object],
@@ -346,7 +946,6 @@ class AsyncSession:
             getv(from_object, ['serverContent'])
         ),
     )
-
     if getv(from_object, ['toolCall']) is not None:
       setv(
           to_object,
@@ -359,6 +958,31 @@ class AsyncSession:
         ['tool_call_cancellation'],
         getv(from_object, ['toolCallCancellation']),
     )
+
+    if getv(from_object, ['goAway']) is not None:
+      setv(
+          to_object,
+          ['go_away'],
+          self._LiveServerGoAway_from_vertex(
+              getv(from_object, ['goAway'])
+          ),
+      )
+
+    if getv(from_object, ['sessionResumptionUpdate']) is not None:
+      setv(
+          to_object,
+          ['session_resumption_update'],
+          self._LiveServerSessionResumptionUpdate_from_vertex(
+              getv(from_object, ['sessionResumptionUpdate']),
+          ),
+      )
+
+    if getv(from_object, ['usageMetadata']) is not None:
+      setv(
+          to_object,
+          ['usage_metadata'],
+          self._UsageMetadata_from_vertex(getv(from_object, ['usageMetadata'])),
+      )
     return to_object
 
   def _parse_client_message(
@@ -669,8 +1293,81 @@ class AsyncSession:
     await self._ws.close()
 
 
+def _t_content_strict(content: types.ContentOrDict):
+  if isinstance(content, dict):
+    return types.Content.model_validate(content)
+  elif isinstance(content, types.Content):
+    return content
+  else:
+    raise ValueError(
+        f'Could not convert input (type "{type(content)}") to '
+        '`types.Content`'
+    )
+
+
+def _t_contents_strict(
+    contents: Union[Sequence[types.ContentOrDict], types.ContentOrDict]):
+  if isinstance(contents, Sequence):
+    return [_t_content_strict(content) for content in contents]
+  else:
+    return [_t_content_strict(contents)]
+
+
+def _t_client_content(
+    turns: Optional[
+        Union[Sequence[types.ContentOrDict], types.ContentOrDict]
+    ] = None,
+    turn_complete: bool = True,
+) -> types.LiveClientContent:
+  if turns is None:
+    return types.LiveClientContent(turn_complete=turn_complete)
+
+  try:
+    return types.LiveClientContent(
+        turns=_t_contents_strict(contents=turns),
+        turn_complete=turn_complete,
+    )
+  except Exception as e:
+    raise ValueError(
+        f'Could not convert input (type "{type(turns)}") to '
+        '`types.LiveClientContent`'
+    ) from e
+
+
+def _t_realtime_input(
+    media: t.BlobUnion,
+) -> types.LiveClientRealtimeInput:
+  try:
+    return types.LiveClientRealtimeInput(media_chunks=[t.t_blob(blob=media)])
+  except Exception as e:
+    raise ValueError(
+        f'Could not convert input (type "{type(input)}") to '
+        '`types.LiveClientRealtimeInput`'
+    ) from e
+
+
+def _t_tool_response(
+    input: Union[
+        types.FunctionResponseOrDict,
+        Sequence[types.FunctionResponseOrDict],
+    ],
+) -> types.LiveClientToolResponse:
+  if not input:
+    raise ValueError(f'A tool response is required, got: \n{input}')
+
+  try:
+    return types.LiveClientToolResponse(
+        function_responses=t.t_function_responses(function_responses=input)
+    )
+  except Exception as e:
+    raise ValueError(
+        f'Could not convert input (type "{type(input)}") to '
+        '`types.LiveClientToolResponse`'
+    ) from e
+
+
 class AsyncLive(_api_module.BaseModule):
-  """AsyncLive.
+  """[Preview] AsyncLive."""
 
   def _LiveSetup_to_mldev(
       self, model: str, config: Optional[types.LiveConnectConfig] = None
@@ -715,7 +1412,48 @@ class AsyncLive(_api_module.BaseModule):
             to_object,
         )
     }
-
+    if getv(config, ['temperature']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['temperature'] = getv(
+            config, ['temperature']
+        )
+      else:
+        to_object['generationConfig'] = {
+            'temperature': getv(config, ['temperature'])
+        }
+    if getv(config, ['top_p']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['topP'] = getv(config, ['top_p'])
+      else:
+        to_object['generationConfig'] = {'topP': getv(config, ['top_p'])}
+    if getv(config, ['top_k']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['topK'] = getv(config, ['top_k'])
+      else:
+        to_object['generationConfig'] = {'topK': getv(config, ['top_k'])}
+    if getv(config, ['max_output_tokens']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['maxOutputTokens'] = getv(
+            config, ['max_output_tokens']
+        )
+      else:
+        to_object['generationConfig'] = {
+            'maxOutputTokens': getv(config, ['max_output_tokens'])
+        }
+    if getv(config, ['media_resolution']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['mediaResolution'] = getv(
+            config, ['media_resolution']
+        )
+      else:
+        to_object['generationConfig'] = {
+            'mediaResolution': getv(config, ['media_resolution'])
+        }
+    if getv(config, ['seed']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['seed'] = getv(config, ['seed'])
+      else:
+        to_object['generationConfig'] = {'seed': getv(config, ['seed'])}
     if getv(config, ['system_instruction']) is not None:
       setv(
           to_object,
@@ -739,11 +1477,84 @@ class AsyncLive(_api_module.BaseModule):
             for item in t.t_tools(self._api_client, getv(config, ['tools']))
         ],
     )
+    if getv(config, ['input_audio_transcription']) is not None:
+      raise ValueError('input_audio_transcription is not supported in MLDev '
+                       'API.')
+    if getv(config, ['output_audio_transcription']) is not None:
+      setv(
+          to_object,
+          ['outputAudioTranscription'],
+          _AudioTranscriptionConfig_to_mldev(
+              self._api_client,
+              getv(config, ['output_audio_transcription']),
+          ),
+      )
+
+    if getv(config, ['session_resumption']) is not None:
+      setv(
+          to_object,
+          ['sessionResumption'],
+          self._LiveClientSessionResumptionConfig_to_mldev(
+              getv(config, ['session_resumption'])
+          ),
+      )
+
+    if getv(config, ['context_window_compression']) is not None:
+      setv(
+          to_object,
+          ['contextWindowCompression'],
+          self._ContextWindowCompressionConfig_to_mldev(
+              getv(config, ['context_window_compression']),
+          ),
+      )
 
     return_value = {'setup': {'model': model}}
     return_value['setup'].update(to_object)
     return return_value
 
+  def _SlidingWindow_to_mldev(
+      self,
+      from_object: Union[dict, object],
+  ) -> dict:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['target_tokens']) is not None:
+      setv(to_object, ['targetTokens'], getv(from_object, ['target_tokens']))
+
+    return to_object
+
+
+  def _ContextWindowCompressionConfig_to_mldev(
+      self,
+      from_object: Union[dict, object],
+  ) -> dict:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['trigger_tokens']) is not None:
+      setv(to_object, ['triggerTokens'], getv(from_object, ['trigger_tokens']))
+
+    if getv(from_object, ['sliding_window']) is not None:
+      setv(
+          to_object,
+          ['slidingWindow'],
+          self._SlidingWindow_to_mldev(
+              getv(from_object, ['sliding_window'])
+          ),
+      )
+
+    return to_object
+
+  def _LiveClientSessionResumptionConfig_to_mldev(
+      self,
+      from_object: Union[dict, object]
+  ) -> dict:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['handle']) is not None:
+      setv(to_object, ['handle'], getv(from_object, ['handle']))
+
+    if getv(from_object, ['transparent']) is not None:
+      raise ValueError('The `transparent` field is not supported in MLDev API')
+
+    return to_object
+
   def _LiveSetup_to_vertex(
       self, model: str, config: Optional[types.LiveConnectConfig] = None
   ):
@@ -796,6 +1607,48 @@ class AsyncLive(_api_module.BaseModule):
             to_object,
         )
     }
+    if getv(config, ['temperature']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['temperature'] = getv(
+            config, ['temperature']
+        )
+      else:
+        to_object['generationConfig'] = {
+            'temperature': getv(config, ['temperature'])
+        }
+    if getv(config, ['top_p']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['topP'] = getv(config, ['top_p'])
+      else:
+        to_object['generationConfig'] = {'topP': getv(config, ['top_p'])}
+    if getv(config, ['top_k']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['topK'] = getv(config, ['top_k'])
+      else:
+        to_object['generationConfig'] = {'topK': getv(config, ['top_k'])}
+    if getv(config, ['max_output_tokens']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['maxOutputTokens'] = getv(
+            config, ['max_output_tokens']
+        )
+      else:
+        to_object['generationConfig'] = {
+            'maxOutputTokens': getv(config, ['max_output_tokens'])
+        }
+    if getv(config, ['media_resolution']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['mediaResolution'] = getv(
+            config, ['media_resolution']
+        )
+      else:
+        to_object['generationConfig'] = {
+            'mediaResolution': getv(config, ['media_resolution'])
+        }
+    if getv(config, ['seed']) is not None:
+      if getv(to_object, ['generationConfig']) is not None:
+        to_object['generationConfig']['seed'] = getv(config, ['seed'])
+      else:
+        to_object['generationConfig'] = {'seed': getv(config, ['seed'])}
     if getv(config, ['system_instruction']) is not None:
       setv(
           to_object,
@@ -819,14 +1672,89 @@ class AsyncLive(_api_module.BaseModule):
             for item in t.t_tools(self._api_client, getv(config, ['tools']))
         ],
     )
+    if getv(config, ['input_audio_transcription']) is not None:
+      setv(
+          to_object,
+          ['inputAudioTranscription'],
+          _AudioTranscriptionConfig_to_vertex(
+              self._api_client,
+              getv(config, ['input_audio_transcription']),
+          ),
+      )
+    if getv(config, ['output_audio_transcription']) is not None:
+      setv(
+          to_object,
+          ['outputAudioTranscription'],
+          _AudioTranscriptionConfig_to_vertex(
+              self._api_client,
+              getv(config, ['output_audio_transcription']),
+          ),
+      )
+
+    if getv(config, ['session_resumption']) is not None:
+      setv(
+          to_object,
+          ['sessionResumption'],
+          self._LiveClientSessionResumptionConfig_to_vertex(
+              getv(config, ['session_resumption'])
+          ),
+      )
+
+    if getv(config, ['context_window_compression']) is not None:
+      setv(
+          to_object,
+          ['contextWindowCompression'],
+          self._ContextWindowCompressionConfig_to_vertex(
+              getv(config, ['context_window_compression']),
+          ),
+      )
 
     return_value = {'setup': {'model': model}}
     return_value['setup'].update(to_object)
     return return_value
 
-
-
-
+  def _SlidingWindow_to_vertex(
+      self,
+      from_object: Union[dict, object],
+  ) -> dict:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['target_tokens']) is not None:
+      setv(to_object, ['targetTokens'], getv(from_object, ['target_tokens']))
+
+    return to_object
+
+  def _ContextWindowCompressionConfig_to_vertex(
+      self,
+      from_object: Union[dict, object],
+  ) -> dict:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['trigger_tokens']) is not None:
+      setv(to_object, ['triggerTokens'], getv(from_object, ['trigger_tokens']))
+
+    if getv(from_object, ['sliding_window']) is not None:
+      setv(
+          to_object,
+          ['slidingWindow'],
+          self._SlidingWindow_to_mldev(
+              getv(from_object, ['sliding_window'])
+          ),
+      )
+
+    return to_object
+
+  def _LiveClientSessionResumptionConfig_to_vertex(
+      self,
+      from_object: Union[dict, object]
+  ) -> dict:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ['handle']) is not None:
+      setv(to_object, ['handle'], getv(from_object, ['handle']))
+
+    if getv(from_object, ['transparent']) is not None:
+      setv(to_object, ['transparent'], getv(from_object, ['transparent']))
+
+    return to_object
+
   @contextlib.asynccontextmanager
   async def connect(
       self,
@@ -834,9 +1762,9 @@ class AsyncLive(_api_module.BaseModule):
       model: str,
       config: Optional[types.LiveConnectConfigOrDict] = None,
   ) -> AsyncIterator[AsyncSession]:
-    """Connect to the live server.
+    """[Preview] Connect to the live server.
 
-
+    Note: the live API is currently in preview.
 
     Usage:
 
@@ -851,25 +1779,8 @@ class AsyncLive(_api_module.BaseModule):
     """
     base_url = self._api_client._websocket_base_url()
     transformed_model = t.t_model(self._api_client, model)
-
-    if config is None:
-      parameter_model = types.LiveConnectConfig()
-    elif isinstance(config, dict):
-      if config.get('system_instruction') is None:
-        system_instruction = None
-      else:
-        system_instruction = t.t_content(
-            self._api_client, config.get('system_instruction')
-        )
-      parameter_model = types.LiveConnectConfig(
-          generation_config=config.get('generation_config'),
-          response_modalities=config.get('response_modalities'),
-          speech_config=config.get('speech_config'),
-          system_instruction=system_instruction,
-          tools=config.get('tools'),
-      )
-    else:
-      parameter_model = config
+
+    parameter_model = _t_live_connect_config(self._api_client, config)
 
     if self._api_client.api_key:
       api_key = self._api_client.api_key
@@ -915,8 +1826,50 @@ class AsyncLive(_api_module.BaseModule):
     )
     request = json.dumps(request_dict)
 
-
-
-
+    try:
+      async with connect(uri, additional_headers=headers) as ws:
+        await ws.send(request)
+        logger.info(await ws.recv(decode=False))
+
+        yield AsyncSession(api_client=self._api_client, websocket=ws)
+    except TypeError:
+      # Try with the older websockets API
+      async with connect(uri, extra_headers=headers) as ws:
+        await ws.send(request)
+        logger.info(await ws.recv())
+
+        yield AsyncSession(api_client=self._api_client, websocket=ws)
+
+
+def _t_live_connect_config(
+    api_client: BaseApiClient,
+    config: Optional[types.LiveConnectConfigOrDict],
+) -> types.LiveConnectConfig:
+  # Ensure the config is a LiveConnectConfig.
+  if config is None:
+    parameter_model = types.LiveConnectConfig()
+  elif isinstance(config, dict):
+    system_instruction = config.pop('system_instruction', None)
+    if system_instruction is not None:
+      converted_system_instruction = t.t_content(
+          api_client, content=system_instruction
+      )
+    else:
+      converted_system_instruction = None
+    parameter_model = types.LiveConnectConfig(
+        system_instruction=converted_system_instruction,
+        **config
+    )  # type: ignore
+  else:
+    parameter_model = config
+
+  if parameter_model.generation_config is not None:
+    warnings.warn(
+        'Setting `LiveConnectConfig.generation_config` is deprecated, '
+        'please set the fields on `LiveConnectConfig` directly. This will '
+        'become an error in a future version (not before Q3 2025)',
+        DeprecationWarning,
+        stacklevel=4,
+    )
 
-
+  return parameter_model