google-genai 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/genai/__init__.py +20 -0
- google/genai/_api_client.py +467 -0
- google/genai/_automatic_function_calling_util.py +341 -0
- google/genai/_common.py +256 -0
- google/genai/_extra_utils.py +295 -0
- google/genai/_replay_api_client.py +478 -0
- google/genai/_test_api_client.py +149 -0
- google/genai/_transformers.py +438 -0
- google/genai/batches.py +1041 -0
- google/genai/caches.py +1830 -0
- google/genai/chats.py +184 -0
- google/genai/client.py +277 -0
- google/genai/errors.py +110 -0
- google/genai/files.py +1211 -0
- google/genai/live.py +629 -0
- google/genai/models.py +5307 -0
- google/genai/pagers.py +245 -0
- google/genai/tunings.py +1366 -0
- google/genai/types.py +7639 -0
- google_genai-0.0.1.dist-info/LICENSE +202 -0
- google_genai-0.0.1.dist-info/METADATA +763 -0
- google_genai-0.0.1.dist-info/RECORD +24 -0
- google_genai-0.0.1.dist-info/WHEEL +5 -0
- google_genai-0.0.1.dist-info/top_level.txt +1 -0
google/genai/live.py
ADDED
@@ -0,0 +1,629 @@
|
|
1
|
+
# Copyright 2024 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
#
|
15
|
+
|
16
|
+
"""Live client."""
|
17
|
+
|
18
|
+
import asyncio
|
19
|
+
import base64
|
20
|
+
import contextlib
|
21
|
+
import json
|
22
|
+
import logging
|
23
|
+
from typing import AsyncIterator, Optional, Sequence, Union
|
24
|
+
|
25
|
+
import google.auth
|
26
|
+
from websockets import ConnectionClosed
|
27
|
+
|
28
|
+
from . import _common
|
29
|
+
from . import _transformers as t
|
30
|
+
from . import client
|
31
|
+
from . import types
|
32
|
+
from ._api_client import ApiClient
|
33
|
+
from ._common import get_value_by_path as getv
|
34
|
+
from ._common import set_value_by_path as setv
|
35
|
+
from .models import _Content_from_mldev
|
36
|
+
from .models import _Content_from_vertex
|
37
|
+
from .models import _Content_to_mldev
|
38
|
+
from .models import _Content_to_vertex
|
39
|
+
from .models import _GenerateContentConfig_to_mldev
|
40
|
+
from .models import _GenerateContentConfig_to_vertex
|
41
|
+
from .models import _SafetySetting_to_mldev
|
42
|
+
from .models import _SafetySetting_to_vertex
|
43
|
+
from .models import _SpeechConfig_to_mldev
|
44
|
+
from .models import _SpeechConfig_to_vertex
|
45
|
+
from .models import _Tool_to_mldev
|
46
|
+
from .models import _Tool_to_vertex
|
47
|
+
|
48
|
+
try:
|
49
|
+
from websockets.asyncio.client import ClientConnection
|
50
|
+
from websockets.asyncio.client import connect
|
51
|
+
except ModuleNotFoundError:
|
52
|
+
from websockets.client import ClientConnection
|
53
|
+
from websockets.client import connect
|
54
|
+
|
55
|
+
|
56
|
+
class AsyncSession:
  """A bidirectional Live API session over an open websocket connection.

  Instances are created by `AsyncLive.connect`. Use `send` to push client
  messages and `receive` / `start_stream` to read server messages.
  """

  def __init__(self, api_client: client.ApiClient, websocket: ClientConnection):
    self._api_client = api_client
    self._ws = websocket

  async def send(
      self,
      input: Union[
          types.ContentListUnion,
          types.ContentListUnionDict,
          types.LiveClientContentOrDict,
          types.LiveClientRealtimeInputOrDict,
          types.LiveClientToolResponseOrDict,
          types.FunctionResponseOrDict,
          Sequence[types.FunctionResponseOrDict],
      ],
      end_of_turn: Optional[bool] = False,
  ):
    """Sends one client message to the server.

    Args:
      input: Text or content (`str`, `Content`, `Part`, or a list of them),
        media chunks (dicts with a `data` key or `Blob` objects, single or in
        a list), function responses (objects or dicts with `name` and
        `response`, single or in a list), or a Live client message
        (`LiveClientContent`, `LiveClientRealtimeInput`,
        `LiveClientToolResponse`, or their dict forms).
      end_of_turn: For text/content input, whether this message completes the
        user's turn. Not used for other input kinds.

    Raises:
      ValueError: If `input` has an unsupported type or shape.
    """
    client_message = self._parse_client_message(input, end_of_turn)
    await self._ws.send(json.dumps(client_message))

  async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
    """Receive model responses from the server.

    Yields server messages until one has `server_content.turn_complete` set;
    that final message is yielded as well, so one call covers one complete
    model turn. When a yielded message is a function call, call `send` with
    the function response to continue the turn.

    Example usage:
    ```
    client = genai.Client(api_key=API_KEY)

    async with client.aio.live.connect(model='...') as session:
      await session.send(input='Hello world!', end_of_turn=True)
      async for message in session.receive():
        print(message)
    ```

    Yields:
      The model responses from the server.
    """
    # TODO(b/365983264) Handle intermittent issues for the user.
    while result := await self._receive():
      if result.server_content and result.server_content.turn_complete:
        yield result
        break
      yield result

  async def start_stream(
      self, stream: AsyncIterator[bytes], mime_type: str
  ) -> AsyncIterator[types.LiveServerMessage]:
    """Runs a live session driven by an input data stream.

    A background task sends every chunk of `stream` to the model as realtime
    input while this generator yields server messages as they arrive. The
    interaction ends when the input stream is exhausted, when the connection
    closes, or when the consumer stops iterating. In the last case the
    background send task is cancelled.

    Example usage:
    ```
    client = genai.Client(api_key=API_KEY)
    config = {'response_modalities': ['AUDIO']}

    async def audio_stream():
      stream = read_audio()
      for data in stream:
        yield data

    async with client.aio.live.connect(model='...', config=config) as session:
      async for audio in session.start_stream(
          stream=audio_stream(), mime_type='audio/pcm'
      ):
        play_audio_chunk(audio.data)
    ```

    Args:
      stream: An async iterator yielding the raw bytes to send to the model.
      mime_type: The MIME type of the data in the stream.

    Yields:
      Server messages received while the input stream is being sent.

    Raises:
      Exception: Whatever the send loop raised (e.g. an error from `stream`)
        if it failed before the input stream was exhausted.
    """
    stop_event = asyncio.Event()
    # Keep a strong reference: the event loop only keeps weak references to
    # tasks, so an unreferenced task may be garbage collected mid-run.
    send_task = asyncio.create_task(
        self._send_loop(stream, mime_type, stop_event)
    )
    recv_task = None
    try:
      while not stop_event.is_set():
        recv_task = asyncio.create_task(self._receive())
        # Wake when a server message arrives or the send loop finishes. On
        # normal completion the send loop sets stop_event; otherwise it failed.
        # Waiting on send_task itself avoids creating a new stop_event.wait()
        # task on every iteration.
        await asyncio.wait(
            [recv_task, send_task], return_when=asyncio.FIRST_COMPLETED
        )
        if recv_task.done():
          try:
            message = recv_task.result()
          except ConnectionClosed:
            break
          yield message
          # Give a chance for the send loop to process requests.
          await asyncio.sleep(10**-12)
        if send_task.done() and not stop_event.is_set():
          # The send loop ended without exhausting the stream: surface its
          # error instead of waiting forever for stop_event.
          send_task.result()
    finally:
      if recv_task is not None and not recv_task.done():
        recv_task.cancel()
        # Wait for the task to finish (cancelled or not).
        try:
          await recv_task
        except asyncio.CancelledError:
          pass
      if not send_task.done():
        # The consumer stopped early or the connection closed: stop sending.
        send_task.cancel()

  async def _receive(self) -> types.LiveServerMessage:
    """Reads one websocket frame and parses it into a `LiveServerMessage`.

    An empty frame is converted as if it were an empty JSON object.

    Raises:
      ValueError: If the frame is not valid JSON.
    """
    parameter_model = types.LiveServerMessage()
    raw_response = await self._ws.recv(decode=False)
    if raw_response:
      try:
        response = json.loads(raw_response)
      except json.decoder.JSONDecodeError as e:
        raise ValueError(f'Failed to parse response: {raw_response}') from e
    else:
      response = {}
    if self._api_client.vertexai:
      response_dict = self._LiveServerMessage_from_vertex(response)
    else:
      response_dict = self._LiveServerMessage_from_mldev(response)

    return types.LiveServerMessage._from_response(
        response_dict, parameter_model
    )

  async def _send_loop(
      self,
      data_stream: AsyncIterator[bytes],
      mime_type: str,
      stop_event: asyncio.Event,
  ):
    """Sends each chunk of `data_stream` as realtime input.

    Sets `stop_event` once the stream is exhausted.
    """
    async for data in data_stream:
      input = {'data': data, 'mimeType': mime_type}
      await self.send(input)
      # Give a chance for the receive loop to process responses.
      await asyncio.sleep(10**-12)
    # Signal start_stream that the input stream is complete.
    stop_event.set()

  def _LiveServerContent_from_mldev(
      self,
      from_object: Union[dict, object],
  ) -> dict:
    """Maps MLDev `serverContent` (camelCase keys) to SDK field names."""
    to_object = {}
    if getv(from_object, ['modelTurn']) is not None:
      setv(
          to_object,
          ['model_turn'],
          _Content_from_mldev(
              self._api_client,
              getv(from_object, ['modelTurn']),
          ),
      )
    if getv(from_object, ['turnComplete']) is not None:
      setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
    return to_object

  def _LiveToolCall_from_mldev(
      self,
      from_object: Union[dict, object],
  ) -> dict:
    """Maps MLDev `toolCall` (camelCase keys) to SDK field names."""
    to_object = {}
    if getv(from_object, ['functionCalls']) is not None:
      setv(
          to_object,
          ['function_calls'],
          getv(from_object, ['functionCalls']),
      )
    return to_object

  def _LiveToolCall_from_vertex(
      self,
      from_object: Union[dict, object],
  ) -> dict:
    """Maps Vertex `toolCall` (camelCase keys) to SDK field names."""
    to_object = {}
    if getv(from_object, ['functionCalls']) is not None:
      setv(
          to_object,
          ['function_calls'],
          getv(from_object, ['functionCalls']),
      )
    return to_object

  def _LiveServerMessage_from_mldev(
      self,
      from_object: Union[dict, object],
  ) -> dict:
    """Maps a raw MLDev server message to `LiveServerMessage` field names."""
    to_object = {}
    if getv(from_object, ['serverContent']) is not None:
      setv(
          to_object,
          ['server_content'],
          self._LiveServerContent_from_mldev(
              getv(from_object, ['serverContent'])
          ),
      )
    if getv(from_object, ['toolCall']) is not None:
      setv(
          to_object,
          ['tool_call'],
          self._LiveToolCall_from_mldev(getv(from_object, ['toolCall'])),
      )
    if getv(from_object, ['toolCallCancellation']) is not None:
      setv(
          to_object,
          ['tool_call_cancellation'],
          getv(from_object, ['toolCallCancellation']),
      )
    return to_object

  def _LiveServerContent_from_vertex(
      self,
      from_object: Union[dict, object],
  ) -> dict:
    """Maps Vertex `serverContent` (camelCase keys) to SDK field names."""
    to_object = {}
    if getv(from_object, ['modelTurn']) is not None:
      setv(
          to_object,
          ['model_turn'],
          _Content_from_vertex(
              self._api_client,
              getv(from_object, ['modelTurn']),
          ),
      )
    if getv(from_object, ['turnComplete']) is not None:
      setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
    return to_object

  def _LiveServerMessage_from_vertex(
      self,
      from_object: Union[dict, object],
  ) -> dict:
    """Maps a raw Vertex server message to `LiveServerMessage` field names."""
    to_object = {}
    if getv(from_object, ['serverContent']) is not None:
      setv(
          to_object,
          ['server_content'],
          self._LiveServerContent_from_vertex(
              getv(from_object, ['serverContent'])
          ),
      )

    if getv(from_object, ['toolCall']) is not None:
      setv(
          to_object,
          ['tool_call'],
          self._LiveToolCall_from_vertex(getv(from_object, ['toolCall'])),
      )
    if getv(from_object, ['toolCallCancellation']) is not None:
      setv(
          to_object,
          ['tool_call_cancellation'],
          getv(from_object, ['toolCallCancellation']),
      )
    return to_object

  @staticmethod
  def _is_function_response(item) -> bool:
    """Returns True for a FunctionResponse object or a dict shaped like one."""
    return isinstance(item, types.FunctionResponse) or (
        isinstance(item, dict) and 'name' in item and 'response' in item
    )

  @staticmethod
  def _function_response_to_dict(item):
    """Dumps a FunctionResponse object to a dict; other items pass through."""
    if isinstance(item, types.FunctionResponse):
      return item.model_dump(exclude_none=True)
    return item

  @staticmethod
  def _encode_media_chunk(chunk):
    """Returns a JSON-ready copy of a media chunk with base64-encoded data.

    Blob objects are dumped to dicts and dicts are shallow-copied, so the
    caller's object is never modified. Items of any other type are returned
    unchanged.
    """
    if isinstance(chunk, types.Blob):
      chunk = chunk.model_dump(exclude_none=True)
    elif isinstance(chunk, dict):
      chunk = dict(chunk)
    else:
      return chunk
    if isinstance(chunk.get('data'), bytes):
      chunk['data'] = base64.b64encode(chunk['data']).decode('utf-8')
    return chunk

  def _parse_client_message(
      self,
      input: Union[
          types.ContentListUnion,
          types.ContentListUnionDict,
          types.LiveClientContentOrDict,
          types.LiveClientRealtimeInputOrDict,
          types.LiveClientToolResponseOrDict,
          types.FunctionResponseOrDict,
          Sequence[types.FunctionResponseOrDict],
      ],
      end_of_turn: Optional[bool] = False,
  ) -> dict:
    """Converts a `send` input into a Live API client message dict.

    The caller's objects are never mutated. Byte payloads are base64-encoded
    into fresh dicts.

    Args:
      input: See `send`.
      end_of_turn: `turn_complete` value for text/content input.

    Returns:
      A dict with exactly one of the keys `client_content`, `realtime_input`
      or `tool_response`.

    Raises:
      ValueError: If the input type or content is not supported.
    """
    # Wrap single items so the sequence branches below handle them.
    if isinstance(
        input,
        (str, types.Blob, types.FunctionResponse, types.Content, types.Part),
    ):
      input = [input]
    elif isinstance(input, dict) and (
        'data' in input or ('name' in input and 'response' in input)
    ):
      input = [input]

    if isinstance(input, Sequence) and any(
        self._is_function_response(c) for c in input
    ):
      # ToolResponse.FunctionResponse
      client_message = {
          'tool_response': {
              'function_responses': [
                  self._function_response_to_dict(c) for c in input
              ]
          }
      }
    elif isinstance(input, Sequence) and any(
        isinstance(c, (str, types.Content, types.Part)) for c in input
    ):
      to_object = {}
      if self._api_client.vertexai:
        content_converter = _Content_to_vertex
      else:
        content_converter = _Content_to_mldev
      contents = [
          content_converter(self._api_client, item, to_object)
          for item in t.t_contents(self._api_client, input)
      ]
      client_message = {
          'client_content': {'turns': contents, 'turn_complete': end_of_turn}
      }
    elif isinstance(input, Sequence):
      if not any(
          isinstance(b, types.Blob) or (isinstance(b, dict) and 'data' in b)
          for b in input
      ):
        raise ValueError(
            f'Unsupported input type "{type(input)}" or input content "{input}"'
        )
      client_message = {
          'realtime_input': {
              'media_chunks': [self._encode_media_chunk(b) for b in input]
          }
      }
    elif isinstance(input, dict) and (
        'content' in input or 'turns' in input or 'turn_complete' in input
    ):
      # LiveClientContentDict.
      # TODO(b/365983264) Add validation checks for content_update input_dict.
      client_message = {'client_content': input}
    elif isinstance(input, dict) and 'media_chunks' in input:
      # LiveClientRealtimeInputDict.
      client_message = {
          'realtime_input': {
              **input,
              'media_chunks': [
                  self._encode_media_chunk(c)
                  for c in input['media_chunks'] or []
              ],
          }
      }
    elif isinstance(input, dict) and 'function_responses' in input:
      # LiveClientToolResponseDict.
      client_message = {
          'tool_response': {
              **input,
              'function_responses': [
                  self._function_response_to_dict(r)
                  for r in input['function_responses'] or []
              ],
          }
      }
    elif isinstance(input, types.LiveClientRealtimeInput):
      realtime_input = input.model_dump(exclude_none=True)
      # Encode every chunk individually; chunks may be absent or empty.
      if 'media_chunks' in realtime_input:
        realtime_input['media_chunks'] = [
            self._encode_media_chunk(c) for c in realtime_input['media_chunks']
        ]
      client_message = {'realtime_input': realtime_input}
    elif isinstance(input, types.LiveClientContent):
      client_message = {'client_content': input.model_dump(exclude_none=True)}
    elif isinstance(input, types.LiveClientToolResponse):
      # ToolResponse.FunctionResponse
      client_message = {'tool_response': input.model_dump(exclude_none=True)}
    else:
      raise ValueError(
          f'Unsupported input type "{type(input)}" or input content "{input}"'
      )

    return client_message

  async def close(self):
    """Closes the websocket connection."""
    await self._ws.close()
|
409
|
+
|
410
|
+
|
411
|
+
class AsyncLive(_common.BaseModule):
  """Opens Live API sessions; reached as `client.aio.live`."""

  @staticmethod
  def _set_generation_config_field(to_object: dict, key: str, value) -> None:
    """Stores `value` at to_object['generationConfig'][key].

    Creates the 'generationConfig' dict when it is absent or None.
    """
    generation_config = to_object.get('generationConfig')
    if generation_config is None:
      generation_config = {}
      to_object['generationConfig'] = generation_config
    generation_config[key] = value

  def _LiveSetup_to_mldev(
      self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
  ):
    """Builds the MLDev `setup` message for a Live connection.

    Args:
      model: The (already transformed) model name.
      config: Connection config as a `LiveConnectConfig` or a dict.

    Returns:
      A dict of the form `{'setup': {'model': model, ...}}`.
    """
    if isinstance(config, types.LiveConnectConfig):
      from_object = config.model_dump(exclude_none=True)
    else:
      from_object = config

    to_object = {}
    if getv(from_object, ['generation_config']) is not None:
      setv(
          to_object,
          ['generationConfig'],
          _GenerateContentConfig_to_mldev(
              self.api_client,
              getv(from_object, ['generation_config']),
              to_object,
          ),
      )
    response_modalities = getv(from_object, ['response_modalities'])
    if response_modalities is not None:
      self._set_generation_config_field(
          to_object, 'responseModalities', response_modalities
      )
    if getv(from_object, ['speech_config']) is not None:
      self._set_generation_config_field(
          to_object,
          'speechConfig',
          _SpeechConfig_to_mldev(
              self.api_client,
              t.t_speech_config(
                  self.api_client, getv(from_object, ['speech_config'])
              ),
              to_object,
          ),
      )

    if getv(from_object, ['system_instruction']) is not None:
      setv(
          to_object,
          ['systemInstruction'],
          _Content_to_mldev(
              self.api_client,
              t.t_content(
                  self.api_client, getv(from_object, ['system_instruction'])
              ),
              to_object,
          ),
      )
    if getv(from_object, ['tools']) is not None:
      setv(
          to_object,
          ['tools'],
          [
              _Tool_to_mldev(self.api_client, item, to_object)
              for item in getv(from_object, ['tools'])
          ],
      )

    return_value = {'setup': {'model': model}}
    return_value['setup'].update(to_object)
    return return_value

  def _LiveSetup_to_vertex(
      self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
  ):
    """Builds the Vertex AI `setup` message for a Live connection.

    When no response modalities are given, neither top-level nor via
    `generation_config`, they default to `['AUDIO']`.

    Args:
      model: The (already transformed) model resource name.
      config: Connection config as a `LiveConnectConfig` or a dict.

    Returns:
      A dict of the form `{'setup': {'model': model, ...}}`.
    """
    if isinstance(config, types.LiveConnectConfig):
      from_object = config.model_dump(exclude_none=True)
    else:
      from_object = config

    to_object = {}

    if getv(from_object, ['generation_config']) is not None:
      setv(
          to_object,
          ['generationConfig'],
          _GenerateContentConfig_to_vertex(
              self.api_client,
              getv(from_object, ['generation_config']),
              to_object,
          ),
      )
    response_modalities = getv(from_object, ['response_modalities'])
    if response_modalities is not None:
      self._set_generation_config_field(
          to_object, 'responseModalities', response_modalities
      )
    elif (to_object.get('generationConfig') or {}).get(
        'responseModalities'
    ) is None:
      # Default to AUDIO to align with the MLDev API, but never override
      # modalities that arrived through generation_config.
      self._set_generation_config_field(
          to_object, 'responseModalities', ['AUDIO']
      )
    if getv(from_object, ['speech_config']) is not None:
      self._set_generation_config_field(
          to_object,
          'speechConfig',
          _SpeechConfig_to_vertex(
              self.api_client,
              t.t_speech_config(
                  self.api_client, getv(from_object, ['speech_config'])
              ),
              to_object,
          ),
      )
    if getv(from_object, ['system_instruction']) is not None:
      setv(
          to_object,
          ['systemInstruction'],
          _Content_to_vertex(
              self.api_client,
              t.t_content(
                  self.api_client, getv(from_object, ['system_instruction'])
              ),
              to_object,
          ),
      )
    if getv(from_object, ['tools']) is not None:
      setv(
          to_object,
          ['tools'],
          [
              _Tool_to_vertex(self.api_client, item, to_object)
              for item in getv(from_object, ['tools'])
          ],
      )

    return_value = {'setup': {'model': model}}
    return_value['setup'].update(to_object)
    return return_value

  @contextlib.asynccontextmanager
  async def connect(
      self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
  ) -> AsyncSession:
    """Connect to the live server.

    With an API key this connects to the Gemini Developer API endpoint.
    Otherwise it authenticates with Application Default Credentials and
    connects to the Vertex AI endpoint. The setup message is sent before the
    session is yielded.

    Example usage:
    ```
    client = genai.Client(api_key=API_KEY)
    config = {}

    async with client.aio.live.connect(model='gemini-1.0-pro-002', config=config) as session:
      await session.send(input='Hello world!', end_of_turn=True)
      async for message in session.receive():
        print(message)
    ```

    Args:
      model: The model to connect to.
      config: Optional `LiveConnectConfig` or dict.

    Yields:
      An `AsyncSession` bound to the open websocket.
    """
    base_url = self.api_client._websocket_base_url()
    version = self.api_client._http_options['api_version']
    transformed_model = t.t_model(self.api_client, model)
    if self.api_client.api_key:
      api_key = self.api_client.api_key
      uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
      headers = self.api_client._http_options['headers']
      request = json.dumps(
          self._LiveSetup_to_mldev(model=transformed_model, config=config)
      )
    else:
      # `import google.auth` does not load this submodule. Import it
      # explicitly rather than relying on another module having done so.
      from google.auth.transport import requests as google_auth_requests

      loop = asyncio.get_running_loop()
      # Get bearer token through Application Default Credentials. Both calls
      # may perform blocking I/O, so run them off the event loop.
      creds, _ = await loop.run_in_executor(
          None,
          lambda: google.auth.default(
              scopes=['https://www.googleapis.com/auth/cloud-platform']
          ),
      )
      # creds.valid is False and creds.token is None until refreshed.
      await loop.run_in_executor(
          None, creds.refresh, google_auth_requests.Request()
      )
      headers = {
          'Content-Type': 'application/json',
          'Authorization': 'Bearer {}'.format(creds.token),
      }
      uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
      location = self.api_client.location
      project = self.api_client.project
      if transformed_model.startswith('publishers/'):
        transformed_model = (
            f'projects/{project}/locations/{location}/' + transformed_model
        )
      request = json.dumps(
          self._LiveSetup_to_vertex(model=transformed_model, config=config)
      )

    async with connect(uri, additional_headers=headers) as ws:
      await ws.send(request)
      logging.info(await ws.recv(decode=False))

      yield AsyncSession(api_client=self.api_client, websocket=ws)
|