google-genai 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/live.py ADDED
@@ -0,0 +1,629 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+
16
+ """Live client."""
17
+
18
+ import asyncio
19
+ import base64
20
+ import contextlib
21
+ import json
22
+ import logging
23
+ from typing import AsyncIterator, Optional, Sequence, Union
24
+
25
+ import google.auth
+ import google.auth.transport.requests
26
+ from websockets import ConnectionClosed
27
+
28
+ from . import _common
29
+ from . import _transformers as t
30
+ from . import client
31
+ from . import types
32
+ from ._api_client import ApiClient
33
+ from ._common import get_value_by_path as getv
34
+ from ._common import set_value_by_path as setv
35
+ from .models import _Content_from_mldev
36
+ from .models import _Content_from_vertex
37
+ from .models import _Content_to_mldev
38
+ from .models import _Content_to_vertex
39
+ from .models import _GenerateContentConfig_to_mldev
40
+ from .models import _GenerateContentConfig_to_vertex
41
+ from .models import _SafetySetting_to_mldev
42
+ from .models import _SafetySetting_to_vertex
43
+ from .models import _SpeechConfig_to_mldev
44
+ from .models import _SpeechConfig_to_vertex
45
+ from .models import _Tool_to_mldev
46
+ from .models import _Tool_to_vertex
47
+
48
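+ # Newer releases of the `websockets` package expose the asyncio client under
+ # `websockets.asyncio.client`; fall back to the legacy `websockets.client`
+ # module when that layout is not available.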
+ try:
49
+ from websockets.asyncio.client import ClientConnection
50
+ from websockets.asyncio.client import connect
51
+ except ModuleNotFoundError:
52
+ from websockets.client import ClientConnection
53
+ from websockets.client import connect
54
+
55
+
56
+ class AsyncSession:
57
+ """AsyncSession."""
58
+
59
+ def __init__(self, api_client: client.ApiClient, websocket: ClientConnection):
60
+ self._api_client = api_client
61
+ self._ws = websocket
62
+
63
+ async def send(
64
+ self,
65
+ input: Union[
66
+ types.ContentListUnion,
67
+ types.ContentListUnionDict,
68
+ types.LiveClientContentOrDict,
69
+ types.LiveClientRealtimeInputOrDict,
70
+ types.LiveClientRealtimeInputOrDict,
71
+ types.LiveClientToolResponseOrDict,
72
+ types.FunctionResponseOrDict,
73
+ Sequence[types.FunctionResponseOrDict],
74
+ ],
75
+ end_of_turn: Optional[bool] = False,
76
+ ):
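+ """Send input to the live session.
+
+ Example usage:
+ ```
+ await session.send(input='Hello world!', end_of_turn=True)
+ ```
+
+ Args:
+ input: The content to send to the model: text or content lists, realtime
+ media blobs, function responses, or the typed live client messages
+ listed in the signature above.
+ end_of_turn: Whether the input marks the end of the client's turn.
+ """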
77
+ client_message = self._parse_client_message(input, end_of_turn)
78
+ await self._ws.send(json.dumps(client_message))
79
+
80
+ async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
81
+ """Receive model responses from the server.
82
+
83
+ The method yields the model responses from the server. The returned
84
+ responses represent a complete model turn.
85
+ When the returned message is a function call, the user must call `send` with
86
+ the function response to continue the turn.
87
+ Example usage:
88
+ ```
89
+ client = genai.Client(api_key=API_KEY)
90
+
91
+ async with client.aio.live.connect(model='...') as session:
92
+ await session.send(input='Hello world!', end_of_turn=True)
93
+ async for message in session.receive():
94
+ print(message)
95
+ ```
96
+ Yields:
97
+ The model responses from the server.
98
+ """
99
+ # TODO(b/365983264) Handle intermittent issues for the user.
100
+ while result := await self._receive():
101
+ if result.server_content and result.server_content.turn_complete:
102
+ yield result
103
+ break
104
+ yield result
105
+
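+ # Illustrative sketch (not part of the published API): answering a tool call
+ # surfaced by `receive`. `compute_function_responses` is a hypothetical
+ # user-defined helper that maps the received tool call to a list of
+ # function-response dicts with 'name' and 'response' keys.
+ #
+ #   async for message in session.receive():
+ #     if message.tool_call:
+ #       responses = compute_function_responses(message.tool_call)
+ #       await session.send(input=responses)
+ #     else:
+ #       print(message)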
106
+ async def start_stream(
107
+ self, stream: AsyncIterator[bytes], mime_type: str
108
+ ) -> AsyncIterator[types.LiveServerMessage]:
109
+ """start a live session from a data stream.
110
+
111
+ The interaction terminates when the input stream is complete.
112
+ This method will start two async tasks. One task will be used to send the
113
+ input stream to the model and the other task will be used to receive the
114
+ responses from the model.
115
+
116
+ Example usage:
117
+ ```
118
+ client = genai.Client(api_key=API_KEY)
119
+ config = {'response_modalities': ['AUDIO']}
120
+
121
+ async def audio_stream():
122
+ stream = read_audio()
123
+ for data in stream:
124
+ yield data
125
+
126
+ async with client.aio.live.connect(model='...', config=config) as session:
127
+ async for audio in session.start_stream(stream=audio_stream(),
128
+ mime_type='audio/pcm'):
129
+ play_audio_chunk(audio.data)
130
+ ```
131
+
132
+ Args:
133
+ stream: An async iterator that yields the input data to send to the model.
134
+ mime_type: The MIME type of the data in the stream.
135
+
136
+ Yields:
137
+ The messages received from the server, including any audio data from the model.
138
+ """
139
+ stop_event = asyncio.Event()
140
+ # Start the send loop. When the stream is complete, stop_event is set.
141
+ asyncio.create_task(self._send_loop(stream, mime_type, stop_event))
142
+ recv_task = None
143
+ while not stop_event.is_set():
144
+ try:
145
+ recv_task = asyncio.create_task(self._receive())
146
+ await asyncio.wait(
147
+ [
148
+ recv_task,
149
+ asyncio.create_task(stop_event.wait()),
150
+ ],
151
+ return_when=asyncio.FIRST_COMPLETED,
152
+ )
153
+ if recv_task.done():
154
+ yield recv_task.result()
155
+ # Give a chance for the send loop to process requests.
156
+ await asyncio.sleep(10**-12)
157
+ except ConnectionClosed:
158
+ break
159
+ if recv_task is not None and not recv_task.done():
160
+ recv_task.cancel()
161
+ # Wait for the task to finish (cancelled or not)
162
+ try:
163
+ await recv_task
164
+ except asyncio.CancelledError:
165
+ pass
166
+
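+ # Illustrative sketch (not part of the published API): driving `start_stream`
+ # with a chunked file reader. The path and chunk size are placeholders.
+ #
+ #   async def file_audio_stream(path='input.pcm', chunk_size=4096):
+ #     with open(path, 'rb') as f:
+ #       while chunk := f.read(chunk_size):
+ #         yield chunk
+ #
+ #   async for message in session.start_stream(
+ #       stream=file_audio_stream(), mime_type='audio/pcm'):
+ #     print(message)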
167
+ async def _receive(self) -> types.LiveServerMessage:
168
+ parameter_model = types.LiveServerMessage()
169
+ raw_response = await self._ws.recv(decode=False)
170
+ if raw_response:
171
+ try:
172
+ response = json.loads(raw_response)
173
+ except json.decoder.JSONDecodeError:
174
+ raise ValueError(f'Failed to parse response: {raw_response}')
175
+ else:
176
+ response = {}
177
+ if self._api_client.vertexai:
178
+ response_dict = self._LiveServerMessage_from_vertex(response)
179
+ else:
180
+ response_dict = self._LiveServerMessage_from_mldev(response)
181
+
182
+ return types.LiveServerMessage._from_response(
183
+ response_dict, parameter_model
184
+ )
185
+
186
+ async def _send_loop(
187
+ self,
188
+ data_stream: AsyncIterator[bytes],
189
+ mime_type: str,
190
+ stop_event: asyncio.Event,
191
+ ):
192
+ async for data in data_stream:
193
+ input = {'data': data, 'mimeType': mime_type}
194
+ await self.send(input)
195
+ # Give a chance for the receive loop to process responses.
196
+ await asyncio.sleep(10**-12)
197
+ # Give a chance for the receiver to process the last response.
198
+ stop_event.set()
199
+
200
+ def _LiveServerContent_from_mldev(
201
+ self,
202
+ from_object: Union[dict, object],
203
+ ) -> dict:
204
+ to_object = {}
205
+ if getv(from_object, ['modelTurn']) is not None:
206
+ setv(
207
+ to_object,
208
+ ['model_turn'],
209
+ _Content_from_mldev(
210
+ self._api_client,
211
+ getv(from_object, ['modelTurn']),
212
+ ),
213
+ )
214
+ if getv(from_object, ['turnComplete']) is not None:
215
+ setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
216
+ return to_object
217
+
218
+ def _LiveToolCall_from_mldev(
219
+ self,
220
+ from_object: Union[dict, object],
221
+ ) -> dict:
222
+ to_object = {}
223
+ if getv(from_object, ['functionCalls']) is not None:
224
+ setv(
225
+ to_object,
226
+ ['function_calls'],
227
+ getv(from_object, ['functionCalls']),
228
+ )
229
+ return to_object
230
+
231
+ def _LiveToolCall_from_vertex(
232
+ self,
233
+ from_object: Union[dict, object],
234
+ ) -> dict:
235
+ to_object = {}
236
+ if getv(from_object, ['functionCalls']) is not None:
237
+ setv(
238
+ to_object,
239
+ ['function_calls'],
240
+ getv(from_object, ['functionCalls']),
241
+ )
242
+ return to_object
243
+
244
+ def _LiveServerMessage_from_mldev(
245
+ self,
246
+ from_object: Union[dict, object],
247
+ ) -> dict:
248
+ to_object = {}
249
+ if getv(from_object, ['serverContent']) is not None:
250
+ setv(
251
+ to_object,
252
+ ['server_content'],
253
+ self._LiveServerContent_from_mldev(
254
+ getv(from_object, ['serverContent'])
255
+ ),
256
+ )
257
+ if getv(from_object, ['toolCall']) is not None:
258
+ setv(
259
+ to_object,
260
+ ['tool_call'],
261
+ self._LiveToolCall_from_mldev(getv(from_object, ['toolCall'])),
262
+ )
263
+ if getv(from_object, ['toolCallCancellation']) is not None:
264
+ setv(
265
+ to_object,
266
+ ['tool_call_cancellation'],
267
+ getv(from_object, ['toolCallCancellation']),
268
+ )
269
+ return to_object
270
+
271
+ def _LiveServerContent_from_vertex(
272
+ self,
273
+ from_object: Union[dict, object],
274
+ ) -> dict:
275
+ to_object = {}
276
+ if getv(from_object, ['modelTurn']) is not None:
277
+ setv(
278
+ to_object,
279
+ ['model_turn'],
280
+ _Content_from_vertex(
281
+ self._api_client,
282
+ getv(from_object, ['modelTurn']),
283
+ ),
284
+ )
285
+ if getv(from_object, ['turnComplete']) is not None:
286
+ setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
287
+ return to_object
288
+
289
+ def _LiveServerMessage_from_vertex(
290
+ self,
291
+ from_object: Union[dict, object],
292
+ ) -> dict:
293
+ to_object = {}
294
+ if getv(from_object, ['serverContent']) is not None:
295
+ setv(
296
+ to_object,
297
+ ['server_content'],
298
+ self._LiveServerContent_from_vertex(
299
+ getv(from_object, ['serverContent'])
300
+ ),
301
+ )
302
+
303
+ if getv(from_object, ['toolCall']) is not None:
304
+ setv(
305
+ to_object,
306
+ ['tool_call'],
307
+ self._LiveToolCall_from_vertex(getv(from_object, ['toolCall'])),
308
+ )
309
+ if getv(from_object, ['toolCallCancellation']) is not None:
310
+ setv(
311
+ to_object,
312
+ ['tool_call_cancellation'],
313
+ getv(from_object, ['toolCallCancellation']),
314
+ )
315
+ return to_object
316
+
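+ # Maps the flexible `input` accepted by `send` onto one of the three wire
+ # message shapes: 'client_content' for turns of content, 'realtime_input'
+ # for media chunks (bytes are base64-encoded), and 'tool_response' for
+ # function responses.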
317
+ def _parse_client_message(
318
+ self,
319
+ input: Union[
320
+ types.ContentListUnion,
321
+ types.ContentListUnionDict,
322
+ types.LiveClientContentOrDict,
323
+ types.LiveClientRealtimeInputOrDict,
324
+ types.LiveClientRealtimeInputOrDict,
325
+ types.LiveClientToolResponseOrDict,
326
+ types.FunctionResponseOrDict,
327
+ Sequence[types.FunctionResponseOrDict],
328
+ ],
329
+ end_of_turn: Optional[bool] = False,
330
+ ) -> dict:
331
+ if isinstance(input, str):
332
+ input = [input]
333
+ elif (isinstance(input, dict) and 'data' in input):
334
+ if isinstance(input['data'], bytes):
335
+ encoded_data = base64.b64encode(input['data']).decode('utf-8')
336
+ input['data'] = encoded_data
337
+ input = [input]
338
+ elif isinstance(input, types.Blob):
339
+ input.data = base64.b64encode(input.data).decode('utf-8')
340
+ input = [input]
341
+ elif isinstance(input, dict) and 'name' in input and 'response' in input:
342
+ # ToolResponse.FunctionResponse
343
+ input = [input]
344
+
345
+ if isinstance(input, Sequence) and any(
346
+ isinstance(c, dict) and 'name' in c and 'response' in c for c in input
347
+ ):
348
+ # ToolResponse.FunctionResponse
349
+ client_message = {'tool_response': {'function_responses': input}}
350
+ elif isinstance(input, Sequence) and any(isinstance(c, str) for c in input):
351
+ to_object = {}
352
+ if self._api_client.vertexai:
353
+ contents = [
354
+ _Content_to_vertex(self._api_client, item, to_object)
355
+ for item in t.t_contents(self._api_client, input)
356
+ ]
357
+ else:
358
+ contents = [
359
+ _Content_to_mldev(self._api_client, item, to_object)
360
+ for item in t.t_contents(self._api_client, input)
361
+ ]
362
+
363
+ client_message = {
364
+ 'client_content': {'turns': contents, 'turn_complete': end_of_turn}
365
+ }
366
+ elif isinstance(input, Sequence):
367
+ if any((isinstance(b, dict) and 'data' in b) for b in input):
368
+ pass
369
+ elif any(isinstance(b, types.Blob) for b in input):
370
+ input = [b.model_dump(exclude_none=True) for b in input]
371
+ else:
372
+ raise ValueError(
373
+ f'Unsupported input type "{type(input)}" or input content "{input}"'
374
+ )
375
+
376
+ client_message = {'realtime_input': {'media_chunks': input}}
377
+
378
+ elif isinstance(input, dict) and 'content' in input:
379
+ # TODO(b/365983264) Add validation checks for content_update input_dict.
380
+ client_message = {'client_content': input}
381
+ elif isinstance(input, types.LiveClientRealtimeInput):
382
+ client_message = {'realtime_input': input.model_dump(exclude_none=True)}
383
+ if isinstance(
384
+ client_message['realtime_input']['media_chunks'][0]['data'], bytes
385
+ ):
386
+ client_message['realtime_input']['media_chunks'] = [
387
+ {
388
+ 'data': base64.b64encode(item['data']).decode('utf-8'),
389
+ 'mime_type': item['mime_type'],
390
+ }
391
+ for item in client_message['realtime_input']['media_chunks']
392
+ ]
393
+
394
+ elif isinstance(input, types.LiveClientContent):
395
+ client_message = {'client_content': input.model_dump(exclude_none=True)}
396
+ elif isinstance(input, types.LiveClientToolResponse):
397
+ # ToolResponse.FunctionResponse
398
+ client_message = {'tool_response': input.model_dump(exclude_none=True)}
399
+ else:
400
+ raise ValueError(
401
+ f'Unsupported input type "{type(input)}" or input content "{input}"'
402
+ )
403
+
404
+ return client_message
405
+
406
+ async def close(self):
407
+ # Close the websocket connection.
408
+ await self._ws.close()
409
+
410
+
411
+ class AsyncLive(_common.BaseModule):
412
+ """AsyncLive."""
413
+
414
+ def _LiveSetup_to_mldev(
415
+ self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
416
+ ):
417
+ if isinstance(config, types.LiveConnectConfig):
418
+ from_object = config.model_dump(exclude_none=True)
419
+ else:
420
+ from_object = config
421
+
422
+ to_object = {}
423
+ if getv(from_object, ['generation_config']) is not None:
424
+ setv(
425
+ to_object,
426
+ ['generationConfig'],
427
+ _GenerateContentConfig_to_mldev(
428
+ self.api_client,
429
+ getv(from_object, ['generation_config']),
430
+ to_object,
431
+ ),
432
+ )
433
+ if getv(from_object, ['response_modalities']) is not None:
434
+ if getv(to_object, ['generationConfig']) is not None:
435
+ to_object['generationConfig']['responseModalities'] = from_object[
436
+ 'response_modalities'
437
+ ]
438
+ else:
439
+ to_object['generationConfig'] = {
440
+ 'responseModalities': from_object['response_modalities']
441
+ }
442
+ if getv(from_object, ['speech_config']) is not None:
443
+ if getv(to_object, ['generationConfig']) is not None:
444
+ to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_mldev(
445
+ self.api_client,
446
+ t.t_speech_config(
447
+ self.api_client, getv(from_object, ['speech_config'])),
448
+ to_object,
449
+ )
450
+ else:
451
+ to_object['generationConfig'] = {
452
+ 'speechConfig': _SpeechConfig_to_mldev(
453
+ self.api_client,
454
+ t.t_speech_config(
455
+ self.api_client, getv(from_object, ['speech_config'])
456
+ ),
457
+ to_object,
458
+ )
459
+ }
460
+
461
+ if getv(from_object, ['system_instruction']) is not None:
462
+ setv(
463
+ to_object,
464
+ ['systemInstruction'],
465
+ _Content_to_mldev(
466
+ self.api_client,
467
+ t.t_content(
468
+ self.api_client, getv(from_object, ['system_instruction'])
469
+ ),
470
+ to_object,
471
+ ),
472
+ )
473
+ if getv(from_object, ['tools']) is not None:
474
+ setv(
475
+ to_object,
476
+ ['tools'],
477
+ [
478
+ _Tool_to_mldev(self.api_client, item, to_object)
479
+ for item in getv(from_object, ['tools'])
480
+ ],
481
+ )
482
+
483
+ return_value = {'setup': {'model': model}}
484
+ return_value['setup'].update(to_object)
485
+ return return_value
486
+
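+ # Illustrative sketch of the mapping above, using a placeholder model name
+ # and a minimal config that only sets `response_modalities`:
+ #
+ #   self._LiveSetup_to_mldev(model='<model>', config={'response_modalities': ['AUDIO']})
+ #   # -> {'setup': {'model': '<model>',
+ #   #               'generationConfig': {'responseModalities': ['AUDIO']}}}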
487
+ def _LiveSetup_to_vertex(
488
+ self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
489
+ ):
490
+ if isinstance(config, types.LiveConnectConfig):
491
+ from_object = config.model_dump(exclude_none=True)
492
+ else:
493
+ from_object = config
494
+
495
+ to_object = {}
496
+
497
+ if getv(from_object, ['generation_config']) is not None:
498
+ setv(
499
+ to_object,
500
+ ['generationConfig'],
501
+ _GenerateContentConfig_to_vertex(
502
+ self.api_client,
503
+ getv(from_object, ['generation_config']),
504
+ to_object,
505
+ ),
506
+ )
507
+ if getv(from_object, ['response_modalities']) is not None:
508
+ if getv(to_object, ['generationConfig']) is not None:
509
+ to_object['generationConfig']['responseModalities'] = from_object[
510
+ 'response_modalities'
511
+ ]
512
+ else:
513
+ to_object['generationConfig'] = {
514
+ 'responseModalities': from_object['response_modalities']
515
+ }
516
+ else:
517
+ # Set default to AUDIO to align with MLDev API.
518
+ if getv(to_object, ['generationConfig']) is not None:
519
+ to_object['generationConfig'].update({'responseModalities': ['AUDIO']})
520
+ else:
521
+ to_object.update(
522
+ {'generationConfig': {'responseModalities': ['AUDIO']}}
523
+ )
524
+ if getv(from_object, ['speech_config']) is not None:
525
+ if getv(to_object, ['generationConfig']) is not None:
526
+ to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_vertex(
527
+ self.api_client,
528
+ t.t_speech_config(
529
+ self.api_client, getv(from_object, ['speech_config'])),
530
+ to_object,
531
+ )
532
+ else:
533
+ to_object['generationConfig'] = {
534
+ 'speechConfig': _SpeechConfig_to_vertex(
535
+ self.api_client,
536
+ t.t_speech_config(
537
+ self.api_client, getv(from_object, ['speech_config'])
538
+ ),
539
+ to_object,
540
+ )
541
+ }
542
+ if getv(from_object, ['system_instruction']) is not None:
543
+ setv(
544
+ to_object,
545
+ ['systemInstruction'],
546
+ _Content_to_vertex(
547
+ self.api_client,
548
+ t.t_content(
549
+ self.api_client, getv(from_object, ['system_instruction'])
550
+ ),
551
+ to_object,
552
+ ),
553
+ )
554
+ if getv(from_object, ['tools']) is not None:
555
+ setv(
556
+ to_object,
557
+ ['tools'],
558
+ [
559
+ _Tool_to_vertex(self.api_client, item, to_object)
560
+ for item in getv(from_object, ['tools'])
561
+ ],
562
+ )
563
+
564
+ return_value = {'setup': {'model': model}}
565
+ return_value['setup'].update(to_object)
566
+ return return_value
567
+
568
+ @contextlib.asynccontextmanager
569
+ async def connect(
570
+ self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
571
+ ) -> AsyncIterator[AsyncSession]:
572
+ """Connect to the live server.
573
+
574
+ Example usage:
575
+ ```
576
+ client = genai.Client(api_key=API_KEY)
577
+ config = {}
578
+
579
+ async with client.aio.live.connect(model='gemini-1.0-pro-002', config=config) as session:
580
+ await session.send(input='Hello world!', end_of_turn=True)
581
+ async for message in session.receive():
582
+ print(message)
583
+ ```
584
+ """
585
+ base_url = self.api_client._websocket_base_url()
586
+ if self.api_client.api_key:
587
+ api_key = self.api_client.api_key
588
+ version = self.api_client._http_options['api_version']
589
+ uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
590
+ headers = self.api_client._http_options['headers']
591
+
592
+ transformed_model = t.t_model(self.api_client, model)
593
+ request = json.dumps(
594
+ self._LiveSetup_to_mldev(model=transformed_model, config=config)
595
+ )
596
+ else:
597
+ # Get bearer token through Application Default Credentials.
598
+ creds, _ = google.auth.default(
599
+ scopes=['https://www.googleapis.com/auth/cloud-platform']
600
+ )
601
+
602
+ # creds.valid is False and creds.token is None at this point;
603
+ # refresh the credentials to populate them.
604
+ auth_req = google.auth.transport.requests.Request()
605
+ creds.refresh(auth_req)
606
+ bearer_token = creds.token
607
+ headers = {
608
+ 'Content-Type': 'application/json',
609
+ 'Authorization': 'Bearer {}'.format(bearer_token),
610
+ }
611
+ version = self.api_client._http_options['api_version']
612
+ uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
613
+ location = self.api_client.location
614
+ project = self.api_client.project
615
+ transformed_model = t.t_model(self.api_client, model)
616
+ if transformed_model.startswith('publishers/'):
617
+ transformed_model = (
618
+ f'projects/{project}/locations/{location}/' + transformed_model
619
+ )
620
+
621
+ request = json.dumps(
622
+ self._LiveSetup_to_vertex(model=transformed_model, config=config)
623
+ )
624
+
625
+ async with connect(uri, additional_headers=headers) as ws:
626
+ await ws.send(request)
627
+ logging.info(await ws.recv(decode=False))
628
+
629
+ yield AsyncSession(api_client=self.api_client, websocket=ws)
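+ # Note on authentication, based on `connect` above: when the client has an
+ # api_key, the MLDev BidiGenerateContent websocket endpoint is used; otherwise
+ # credentials are obtained via Application Default Credentials (cloud-platform
+ # scope) and the Vertex AI LlmBidiService endpoint is used, with publisher
+ # model names expanded to full project/location resource paths.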