google-genai 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/live.py CHANGED
@@ -20,7 +20,7 @@ import base64
20
20
  import contextlib
21
21
  import json
22
22
  import logging
23
- from typing import AsyncIterator, Optional, Sequence, Union
23
+ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union
24
24
 
25
25
  import google.auth
26
26
  from websockets import ConnectionClosed
@@ -31,7 +31,7 @@ from . import _transformers as t
31
31
  from . import client
32
32
  from . import errors
33
33
  from . import types
34
- from ._api_client import ApiClient
34
+ from ._api_client import BaseApiClient
35
35
  from ._common import experimental_warning
36
36
  from ._common import get_value_by_path as getv
37
37
  from ._common import set_value_by_path as setv
@@ -49,12 +49,14 @@ from .models import _Tool_to_mldev
49
49
  from .models import _Tool_to_vertex
50
50
 
51
51
  try:
52
- from websockets.asyncio.client import ClientConnection
53
- from websockets.asyncio.client import connect
52
+ from websockets.asyncio.client import ClientConnection # type: ignore
53
+ from websockets.asyncio.client import connect # type: ignore
54
54
  except ModuleNotFoundError:
55
- from websockets.client import ClientConnection
56
- from websockets.client import connect
55
+ # This try/except is for TAP, mypy complains about it which is why we have the type: ignore
56
+ from websockets.client import ClientConnection # type: ignore
57
+ from websockets.client import connect # type: ignore
57
58
 
59
+ logger = logging.getLogger('google_genai.live')
58
60
 
59
61
  _FUNCTION_RESPONSE_REQUIRES_ID = (
60
62
  'FunctionResponse request must have an `id` field from the'
@@ -65,22 +67,26 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
65
67
  class AsyncSession:
66
68
  """AsyncSession. The live module is experimental."""
67
69
 
68
- def __init__(self, api_client: client.ApiClient, websocket: ClientConnection):
70
+ def __init__(
71
+ self, api_client: client.BaseApiClient, websocket: ClientConnection
72
+ ):
69
73
  self._api_client = api_client
70
74
  self._ws = websocket
71
75
 
72
76
  async def send(
73
77
  self,
74
78
  *,
75
- input: Union[
76
- types.ContentListUnion,
77
- types.ContentListUnionDict,
78
- types.LiveClientContentOrDict,
79
- types.LiveClientRealtimeInputOrDict,
80
- types.LiveClientToolResponseOrDict,
81
- types.FunctionResponseOrDict,
82
- Sequence[types.FunctionResponseOrDict],
83
- ],
79
+ input: Optional[
80
+ Union[
81
+ types.ContentListUnion,
82
+ types.ContentListUnionDict,
83
+ types.LiveClientContentOrDict,
84
+ types.LiveClientRealtimeInputOrDict,
85
+ types.LiveClientToolResponseOrDict,
86
+ types.FunctionResponseOrDict,
87
+ Sequence[types.FunctionResponseOrDict],
88
+ ]
89
+ ] = None,
84
90
  end_of_turn: Optional[bool] = False,
85
91
  ):
86
92
  """Send input to the model.
@@ -214,7 +220,7 @@ class AsyncSession:
214
220
  response_dict = self._LiveServerMessage_from_mldev(response)
215
221
 
216
222
  return types.LiveServerMessage._from_response(
217
- response_dict, parameter_model
223
+ response=response_dict, kwargs=parameter_model
218
224
  )
219
225
 
220
226
  async def _send_loop(
@@ -234,8 +240,8 @@ class AsyncSession:
234
240
  def _LiveServerContent_from_mldev(
235
241
  self,
236
242
  from_object: Union[dict, object],
237
- ) -> dict:
238
- to_object = {}
243
+ ) -> Dict[str, Any]:
244
+ to_object: dict[str, Any] = {}
239
245
  if getv(from_object, ['modelTurn']) is not None:
240
246
  setv(
241
247
  to_object,
@@ -254,8 +260,8 @@ class AsyncSession:
254
260
  def _LiveToolCall_from_mldev(
255
261
  self,
256
262
  from_object: Union[dict, object],
257
- ) -> dict:
258
- to_object = {}
263
+ ) -> Dict[str, Any]:
264
+ to_object: dict[str, Any] = {}
259
265
  if getv(from_object, ['functionCalls']) is not None:
260
266
  setv(
261
267
  to_object,
@@ -267,8 +273,8 @@ class AsyncSession:
267
273
  def _LiveToolCall_from_vertex(
268
274
  self,
269
275
  from_object: Union[dict, object],
270
- ) -> dict:
271
- to_object = {}
276
+ ) -> Dict[str, Any]:
277
+ to_object: dict[str, Any] = {}
272
278
  if getv(from_object, ['functionCalls']) is not None:
273
279
  setv(
274
280
  to_object,
@@ -280,8 +286,8 @@ class AsyncSession:
280
286
  def _LiveServerMessage_from_mldev(
281
287
  self,
282
288
  from_object: Union[dict, object],
283
- ) -> dict:
284
- to_object = {}
289
+ ) -> Dict[str, Any]:
290
+ to_object: dict[str, Any] = {}
285
291
  if getv(from_object, ['serverContent']) is not None:
286
292
  setv(
287
293
  to_object,
@@ -307,8 +313,8 @@ class AsyncSession:
307
313
  def _LiveServerContent_from_vertex(
308
314
  self,
309
315
  from_object: Union[dict, object],
310
- ) -> dict:
311
- to_object = {}
316
+ ) -> Dict[str, Any]:
317
+ to_object: dict[str, Any] = {}
312
318
  if getv(from_object, ['modelTurn']) is not None:
313
319
  setv(
314
320
  to_object,
@@ -327,8 +333,8 @@ class AsyncSession:
327
333
  def _LiveServerMessage_from_vertex(
328
334
  self,
329
335
  from_object: Union[dict, object],
330
- ) -> dict:
331
- to_object = {}
336
+ ) -> Dict[str, Any]:
337
+ to_object: dict[str, Any] = {}
332
338
  if getv(from_object, ['serverContent']) is not None:
333
339
  setv(
334
340
  to_object,
@@ -354,18 +360,23 @@ class AsyncSession:
354
360
 
355
361
  def _parse_client_message(
356
362
  self,
357
- input: Union[
358
- types.ContentListUnion,
359
- types.ContentListUnionDict,
360
- types.LiveClientContentOrDict,
361
- types.LiveClientRealtimeInputOrDict,
362
- types.LiveClientRealtimeInputOrDict,
363
- types.LiveClientToolResponseOrDict,
364
- types.FunctionResponseOrDict,
365
- Sequence[types.FunctionResponseOrDict],
366
- ],
363
+ input: Optional[
364
+ Union[
365
+ types.ContentListUnion,
366
+ types.ContentListUnionDict,
367
+ types.LiveClientContentOrDict,
368
+ types.LiveClientRealtimeInputOrDict,
369
+ types.LiveClientToolResponseOrDict,
370
+ types.FunctionResponseOrDict,
371
+ Sequence[types.FunctionResponseOrDict],
372
+ ]
373
+ ] = None,
367
374
  end_of_turn: Optional[bool] = False,
368
- ) -> dict:
375
+ ) -> Dict[str, Any]:
376
+
377
+ if not input:
378
+ logging.info('No input provided. Assume it is the end of turn.')
379
+ return {'client_content': {'turn_complete': True}}
369
380
  if isinstance(input, str):
370
381
  input = [input]
371
382
  elif isinstance(input, dict) and 'data' in input:
@@ -374,7 +385,6 @@ class AsyncSession:
374
385
  input['data'] = decoded_data
375
386
  input = [input]
376
387
  elif isinstance(input, types.Blob):
377
- input.data = base64.b64encode(input.data).decode('utf-8')
378
388
  input = [input]
379
389
  elif isinstance(input, dict) and 'name' in input and 'response' in input:
380
390
  # ToolResponse.FunctionResponse
@@ -392,7 +402,7 @@ class AsyncSession:
392
402
  raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
393
403
  client_message = {'tool_response': {'function_responses': input}}
394
404
  elif isinstance(input, Sequence) and any(isinstance(c, str) for c in input):
395
- to_object = {}
405
+ to_object: dict[str, Any] = {}
396
406
  if self._api_client.vertexai:
397
407
  contents = [
398
408
  _Content_to_vertex(self._api_client, item, to_object)
@@ -411,7 +421,7 @@ class AsyncSession:
411
421
  if any((isinstance(b, dict) and 'data' in b) for b in input):
412
422
  pass
413
423
  elif any(isinstance(b, types.Blob) for b in input):
414
- input = [b.model_dump(exclude_none=True) for b in input]
424
+ input = [b.model_dump(exclude_none=True, mode='json') for b in input]
415
425
  else:
416
426
  raise ValueError(
417
427
  f'Unsupported input type "{type(input)}" or input content "{input}"'
@@ -419,11 +429,21 @@ class AsyncSession:
419
429
 
420
430
  client_message = {'realtime_input': {'media_chunks': input}}
421
431
 
422
- elif isinstance(input, dict) and 'content' in input:
423
- # TODO(b/365983264) Add validation checks for content_update input_dict.
424
- client_message = {'client_content': input}
432
+ elif isinstance(input, dict):
433
+ if 'content' in input or 'turns' in input:
434
+ # TODO(b/365983264) Add validation checks for content_update input_dict.
435
+ client_message = {'client_content': input}
436
+ elif 'media_chunks' in input:
437
+ client_message = {'realtime_input': input}
438
+ elif 'function_responses' in input:
439
+ client_message = {'tool_response': input}
440
+ else:
441
+ raise ValueError(
442
+ f'Unsupported input type "{type(input)}" or input content "{input}"')
425
443
  elif isinstance(input, types.LiveClientRealtimeInput):
426
- client_message = {'realtime_input': input.model_dump(exclude_none=True)}
444
+ client_message = {
445
+ 'realtime_input': input.model_dump(exclude_none=True, mode='json')
446
+ }
427
447
  if isinstance(
428
448
  client_message['realtime_input']['media_chunks'][0]['data'], bytes
429
449
  ):
@@ -436,20 +456,26 @@ class AsyncSession:
436
456
  ]
437
457
 
438
458
  elif isinstance(input, types.LiveClientContent):
439
- client_message = {'client_content': input.model_dump(exclude_none=True)}
459
+ client_message = {
460
+ 'client_content': input.model_dump(exclude_none=True, mode='json')
461
+ }
440
462
  elif isinstance(input, types.LiveClientToolResponse):
441
463
  # ToolResponse.FunctionResponse
442
464
  if not (self._api_client.vertexai) and not (
443
465
  input.function_responses[0].id
444
466
  ):
445
467
  raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
446
- client_message = {'tool_response': input.model_dump(exclude_none=True)}
468
+ client_message = {
469
+ 'tool_response': input.model_dump(exclude_none=True, mode='json')
470
+ }
447
471
  elif isinstance(input, types.FunctionResponse):
448
472
  if not (self._api_client.vertexai) and not (input.id):
449
473
  raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
450
474
  client_message = {
451
475
  'tool_response': {
452
- 'function_responses': [input.model_dump(exclude_none=True)]
476
+ 'function_responses': [
477
+ input.model_dump(exclude_none=True, mode='json')
478
+ ]
453
479
  }
454
480
  }
455
481
  elif isinstance(input, Sequence) and isinstance(
@@ -460,7 +486,7 @@ class AsyncSession:
460
486
  client_message = {
461
487
  'tool_response': {
462
488
  'function_responses': [
463
- c.model_dump(exclude_none=True) for c in input
489
+ c.model_dump(exclude_none=True, mode='json') for c in input
464
490
  ]
465
491
  }
466
492
  }
@@ -480,39 +506,35 @@ class AsyncLive(_api_module.BaseModule):
480
506
  """AsyncLive. The live module is experimental."""
481
507
 
482
508
  def _LiveSetup_to_mldev(
483
- self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
509
+ self, model: str, config: Optional[types.LiveConnectConfig] = None
484
510
  ):
485
- if isinstance(config, types.LiveConnectConfig):
486
- from_object = config.model_dump(exclude_none=True)
487
- else:
488
- from_object = config
489
511
 
490
- to_object = {}
491
- if getv(from_object, ['generation_config']) is not None:
512
+ to_object: dict[str, Any] = {}
513
+ if getv(config, ['generation_config']) is not None:
492
514
  setv(
493
515
  to_object,
494
516
  ['generationConfig'],
495
517
  _GenerateContentConfig_to_mldev(
496
518
  self._api_client,
497
- getv(from_object, ['generation_config']),
519
+ getv(config, ['generation_config']),
498
520
  to_object,
499
521
  ),
500
522
  )
501
- if getv(from_object, ['response_modalities']) is not None:
523
+ if getv(config, ['response_modalities']) is not None:
502
524
  if getv(to_object, ['generationConfig']) is not None:
503
- to_object['generationConfig']['responseModalities'] = from_object[
504
- 'response_modalities'
505
- ]
525
+ to_object['generationConfig']['responseModalities'] = getv(
526
+ config, ['response_modalities']
527
+ )
506
528
  else:
507
529
  to_object['generationConfig'] = {
508
- 'responseModalities': from_object['response_modalities']
530
+ 'responseModalities': getv(config, ['response_modalities'])
509
531
  }
510
- if getv(from_object, ['speech_config']) is not None:
532
+ if getv(config, ['speech_config']) is not None:
511
533
  if getv(to_object, ['generationConfig']) is not None:
512
534
  to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_mldev(
513
535
  self._api_client,
514
536
  t.t_speech_config(
515
- self._api_client, getv(from_object, ['speech_config'])
537
+ self._api_client, getv(config, ['speech_config'])
516
538
  ),
517
539
  to_object,
518
540
  )
@@ -521,31 +543,33 @@ class AsyncLive(_api_module.BaseModule):
521
543
  'speechConfig': _SpeechConfig_to_mldev(
522
544
  self._api_client,
523
545
  t.t_speech_config(
524
- self._api_client, getv(from_object, ['speech_config'])
546
+ self._api_client, getv(config, ['speech_config'])
525
547
  ),
526
548
  to_object,
527
549
  )
528
550
  }
529
551
 
530
- if getv(from_object, ['system_instruction']) is not None:
552
+ if getv(config, ['system_instruction']) is not None:
531
553
  setv(
532
554
  to_object,
533
555
  ['systemInstruction'],
534
556
  _Content_to_mldev(
535
557
  self._api_client,
536
558
  t.t_content(
537
- self._api_client, getv(from_object, ['system_instruction'])
559
+ self._api_client, getv(config, ['system_instruction'])
538
560
  ),
539
561
  to_object,
540
562
  ),
541
563
  )
542
- if getv(from_object, ['tools']) is not None:
564
+ if getv(config, ['tools']) is not None:
543
565
  setv(
544
566
  to_object,
545
567
  ['tools'],
546
568
  [
547
- _Tool_to_mldev(self._api_client, item, to_object)
548
- for item in getv(from_object, ['tools'])
569
+ _Tool_to_mldev(
570
+ self._api_client, t.t_tool(self._api_client, item), to_object
571
+ )
572
+ for item in t.t_tools(self._api_client, getv(config, ['tools']))
549
573
  ],
550
574
  )
551
575
 
@@ -554,33 +578,29 @@ class AsyncLive(_api_module.BaseModule):
554
578
  return return_value
555
579
 
556
580
  def _LiveSetup_to_vertex(
557
- self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
581
+ self, model: str, config: Optional[types.LiveConnectConfig] = None
558
582
  ):
559
- if isinstance(config, types.LiveConnectConfig):
560
- from_object = config.model_dump(exclude_none=True)
561
- else:
562
- from_object = config
563
583
 
564
- to_object = {}
584
+ to_object: dict[str, Any] = {}
565
585
 
566
- if getv(from_object, ['generation_config']) is not None:
586
+ if getv(config, ['generation_config']) is not None:
567
587
  setv(
568
588
  to_object,
569
589
  ['generationConfig'],
570
590
  _GenerateContentConfig_to_vertex(
571
591
  self._api_client,
572
- getv(from_object, ['generation_config']),
592
+ getv(config, ['generation_config']),
573
593
  to_object,
574
594
  ),
575
595
  )
576
- if getv(from_object, ['response_modalities']) is not None:
596
+ if getv(config, ['response_modalities']) is not None:
577
597
  if getv(to_object, ['generationConfig']) is not None:
578
- to_object['generationConfig']['responseModalities'] = from_object[
579
- 'response_modalities'
580
- ]
598
+ to_object['generationConfig']['responseModalities'] = getv(
599
+ config, ['response_modalities']
600
+ )
581
601
  else:
582
602
  to_object['generationConfig'] = {
583
- 'responseModalities': from_object['response_modalities']
603
+ 'responseModalities': getv(config, ['response_modalities'])
584
604
  }
585
605
  else:
586
606
  # Set default to AUDIO to align with MLDev API.
@@ -590,12 +610,12 @@ class AsyncLive(_api_module.BaseModule):
590
610
  to_object.update(
591
611
  {'generationConfig': {'responseModalities': ['AUDIO']}}
592
612
  )
593
- if getv(from_object, ['speech_config']) is not None:
613
+ if getv(config, ['speech_config']) is not None:
594
614
  if getv(to_object, ['generationConfig']) is not None:
595
615
  to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_vertex(
596
616
  self._api_client,
597
617
  t.t_speech_config(
598
- self._api_client, getv(from_object, ['speech_config'])
618
+ self._api_client, getv(config, ['speech_config'])
599
619
  ),
600
620
  to_object,
601
621
  )
@@ -604,30 +624,32 @@ class AsyncLive(_api_module.BaseModule):
604
624
  'speechConfig': _SpeechConfig_to_vertex(
605
625
  self._api_client,
606
626
  t.t_speech_config(
607
- self._api_client, getv(from_object, ['speech_config'])
627
+ self._api_client, getv(config, ['speech_config'])
608
628
  ),
609
629
  to_object,
610
630
  )
611
631
  }
612
- if getv(from_object, ['system_instruction']) is not None:
632
+ if getv(config, ['system_instruction']) is not None:
613
633
  setv(
614
634
  to_object,
615
635
  ['systemInstruction'],
616
636
  _Content_to_vertex(
617
637
  self._api_client,
618
638
  t.t_content(
619
- self._api_client, getv(from_object, ['system_instruction'])
639
+ self._api_client, getv(config, ['system_instruction'])
620
640
  ),
621
641
  to_object,
622
642
  ),
623
643
  )
624
- if getv(from_object, ['tools']) is not None:
644
+ if getv(config, ['tools']) is not None:
625
645
  setv(
626
646
  to_object,
627
647
  ['tools'],
628
648
  [
629
- _Tool_to_vertex(self._api_client, item, to_object)
630
- for item in getv(from_object, ['tools'])
649
+ _Tool_to_vertex(
650
+ self._api_client, t.t_tool(self._api_client, item), to_object
651
+ )
652
+ for item in t.t_tools(self._api_client, getv(config, ['tools']))
631
653
  ],
632
654
  )
633
655
 
@@ -661,16 +683,24 @@ class AsyncLive(_api_module.BaseModule):
661
683
  print(message)
662
684
  """
663
685
  base_url = self._api_client._websocket_base_url()
686
+ transformed_model = t.t_model(self._api_client, model)
687
+ # Ensure the config is a LiveConnectConfig.
688
+ parameter_model = types.LiveConnectConfig(**config) if isinstance(
689
+ config, dict
690
+ ) else config
691
+
664
692
  if self._api_client.api_key:
665
693
  api_key = self._api_client.api_key
666
694
  version = self._api_client._http_options['api_version']
667
695
  uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
668
696
  headers = self._api_client._http_options['headers']
669
-
670
- transformed_model = t.t_model(self._api_client, model)
671
- request = json.dumps(
672
- self._LiveSetup_to_mldev(model=transformed_model, config=config)
697
+ request_dict = _common.convert_to_dict(
698
+ self._LiveSetup_to_mldev(
699
+ model=transformed_model,
700
+ config=parameter_model,
701
+ )
673
702
  )
703
+ request = json.dumps(request_dict)
674
704
  else:
675
705
  # Get bearer token through Application Default Credentials.
676
706
  creds, _ = google.auth.default(
@@ -682,26 +712,28 @@ class AsyncLive(_api_module.BaseModule):
682
712
  auth_req = google.auth.transport.requests.Request()
683
713
  creds.refresh(auth_req)
684
714
  bearer_token = creds.token
685
- headers = {
686
- 'Content-Type': 'application/json',
715
+ headers = self._api_client._http_options['headers']
716
+ headers.update({
687
717
  'Authorization': 'Bearer {}'.format(bearer_token),
688
- }
718
+ })
689
719
  version = self._api_client._http_options['api_version']
690
720
  uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
691
721
  location = self._api_client.location
692
722
  project = self._api_client.project
693
- transformed_model = t.t_model(self._api_client, model)
694
723
  if transformed_model.startswith('publishers/'):
695
724
  transformed_model = (
696
725
  f'projects/{project}/locations/{location}/' + transformed_model
697
726
  )
698
-
699
- request = json.dumps(
700
- self._LiveSetup_to_vertex(model=transformed_model, config=config)
727
+ request_dict = _common.convert_to_dict(
728
+ self._LiveSetup_to_vertex(
729
+ model=transformed_model,
730
+ config=parameter_model,
731
+ )
701
732
  )
733
+ request = json.dumps(request_dict)
702
734
 
703
735
  async with connect(uri, additional_headers=headers) as ws:
704
736
  await ws.send(request)
705
- logging.info(await ws.recv(decode=False))
737
+ logger.info(await ws.recv(decode=False))
706
738
 
707
739
  yield AsyncSession(api_client=self._api_client, websocket=ws)