google-genai 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/live.py CHANGED
@@ -31,7 +31,7 @@ from . import _transformers as t
31
31
  from . import client
32
32
  from . import errors
33
33
  from . import types
34
- from ._api_client import ApiClient
34
+ from ._api_client import BaseApiClient
35
35
  from ._common import experimental_warning
36
36
  from ._common import get_value_by_path as getv
37
37
  from ._common import set_value_by_path as setv
@@ -49,11 +49,12 @@ from .models import _Tool_to_mldev
49
49
  from .models import _Tool_to_vertex
50
50
 
51
51
  try:
52
- from websockets.asyncio.client import ClientConnection
53
- from websockets.asyncio.client import connect
52
+ from websockets.asyncio.client import ClientConnection # type: ignore
53
+ from websockets.asyncio.client import connect # type: ignore
54
54
  except ModuleNotFoundError:
55
- from websockets.client import ClientConnection
56
- from websockets.client import connect
55
+ # This try/except is for TAP, mypy complains about it which is why we have the type: ignore
56
+ from websockets.client import ClientConnection # type: ignore
57
+ from websockets.client import connect # type: ignore
57
58
 
58
59
  logger = logging.getLogger('google_genai.live')
59
60
 
@@ -66,7 +67,9 @@ _FUNCTION_RESPONSE_REQUIRES_ID = (
66
67
  class AsyncSession:
67
68
  """AsyncSession. The live module is experimental."""
68
69
 
69
- def __init__(self, api_client: client.ApiClient, websocket: ClientConnection):
70
+ def __init__(
71
+ self, api_client: client.BaseApiClient, websocket: ClientConnection
72
+ ):
70
73
  self._api_client = api_client
71
74
  self._ws = websocket
72
75
 
@@ -217,7 +220,7 @@ class AsyncSession:
217
220
  response_dict = self._LiveServerMessage_from_mldev(response)
218
221
 
219
222
  return types.LiveServerMessage._from_response(
220
- response_dict, parameter_model
223
+ response=response_dict, kwargs=parameter_model
221
224
  )
222
225
 
223
226
  async def _send_loop(
@@ -238,7 +241,7 @@ class AsyncSession:
238
241
  self,
239
242
  from_object: Union[dict, object],
240
243
  ) -> Dict[str, Any]:
241
- to_object = {}
244
+ to_object: dict[str, Any] = {}
242
245
  if getv(from_object, ['modelTurn']) is not None:
243
246
  setv(
244
247
  to_object,
@@ -258,7 +261,7 @@ class AsyncSession:
258
261
  self,
259
262
  from_object: Union[dict, object],
260
263
  ) -> Dict[str, Any]:
261
- to_object = {}
264
+ to_object: dict[str, Any] = {}
262
265
  if getv(from_object, ['functionCalls']) is not None:
263
266
  setv(
264
267
  to_object,
@@ -271,7 +274,7 @@ class AsyncSession:
271
274
  self,
272
275
  from_object: Union[dict, object],
273
276
  ) -> Dict[str, Any]:
274
- to_object = {}
277
+ to_object: dict[str, Any] = {}
275
278
  if getv(from_object, ['functionCalls']) is not None:
276
279
  setv(
277
280
  to_object,
@@ -284,7 +287,7 @@ class AsyncSession:
284
287
  self,
285
288
  from_object: Union[dict, object],
286
289
  ) -> Dict[str, Any]:
287
- to_object = {}
290
+ to_object: dict[str, Any] = {}
288
291
  if getv(from_object, ['serverContent']) is not None:
289
292
  setv(
290
293
  to_object,
@@ -311,7 +314,7 @@ class AsyncSession:
311
314
  self,
312
315
  from_object: Union[dict, object],
313
316
  ) -> Dict[str, Any]:
314
- to_object = {}
317
+ to_object: dict[str, Any] = {}
315
318
  if getv(from_object, ['modelTurn']) is not None:
316
319
  setv(
317
320
  to_object,
@@ -331,7 +334,7 @@ class AsyncSession:
331
334
  self,
332
335
  from_object: Union[dict, object],
333
336
  ) -> Dict[str, Any]:
334
- to_object = {}
337
+ to_object: dict[str, Any] = {}
335
338
  if getv(from_object, ['serverContent']) is not None:
336
339
  setv(
337
340
  to_object,
@@ -399,7 +402,7 @@ class AsyncSession:
399
402
  raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
400
403
  client_message = {'tool_response': {'function_responses': input}}
401
404
  elif isinstance(input, Sequence) and any(isinstance(c, str) for c in input):
402
- to_object = {}
405
+ to_object: dict[str, Any] = {}
403
406
  if self._api_client.vertexai:
404
407
  contents = [
405
408
  _Content_to_vertex(self._api_client, item, to_object)
@@ -503,39 +506,35 @@ class AsyncLive(_api_module.BaseModule):
503
506
  """AsyncLive. The live module is experimental."""
504
507
 
505
508
  def _LiveSetup_to_mldev(
506
- self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
509
+ self, model: str, config: Optional[types.LiveConnectConfig] = None
507
510
  ):
508
- if isinstance(config, types.LiveConnectConfig):
509
- from_object = config.model_dump(exclude_none=True)
510
- else:
511
- from_object = config
512
511
 
513
- to_object = {}
514
- if getv(from_object, ['generation_config']) is not None:
512
+ to_object: dict[str, Any] = {}
513
+ if getv(config, ['generation_config']) is not None:
515
514
  setv(
516
515
  to_object,
517
516
  ['generationConfig'],
518
517
  _GenerateContentConfig_to_mldev(
519
518
  self._api_client,
520
- getv(from_object, ['generation_config']),
519
+ getv(config, ['generation_config']),
521
520
  to_object,
522
521
  ),
523
522
  )
524
- if getv(from_object, ['response_modalities']) is not None:
523
+ if getv(config, ['response_modalities']) is not None:
525
524
  if getv(to_object, ['generationConfig']) is not None:
526
- to_object['generationConfig']['responseModalities'] = from_object[
527
- 'response_modalities'
528
- ]
525
+ to_object['generationConfig']['responseModalities'] = getv(
526
+ config, ['response_modalities']
527
+ )
529
528
  else:
530
529
  to_object['generationConfig'] = {
531
- 'responseModalities': from_object['response_modalities']
530
+ 'responseModalities': getv(config, ['response_modalities'])
532
531
  }
533
- if getv(from_object, ['speech_config']) is not None:
532
+ if getv(config, ['speech_config']) is not None:
534
533
  if getv(to_object, ['generationConfig']) is not None:
535
534
  to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_mldev(
536
535
  self._api_client,
537
536
  t.t_speech_config(
538
- self._api_client, getv(from_object, ['speech_config'])
537
+ self._api_client, getv(config, ['speech_config'])
539
538
  ),
540
539
  to_object,
541
540
  )
@@ -544,31 +543,33 @@ class AsyncLive(_api_module.BaseModule):
544
543
  'speechConfig': _SpeechConfig_to_mldev(
545
544
  self._api_client,
546
545
  t.t_speech_config(
547
- self._api_client, getv(from_object, ['speech_config'])
546
+ self._api_client, getv(config, ['speech_config'])
548
547
  ),
549
548
  to_object,
550
549
  )
551
550
  }
552
551
 
553
- if getv(from_object, ['system_instruction']) is not None:
552
+ if getv(config, ['system_instruction']) is not None:
554
553
  setv(
555
554
  to_object,
556
555
  ['systemInstruction'],
557
556
  _Content_to_mldev(
558
557
  self._api_client,
559
558
  t.t_content(
560
- self._api_client, getv(from_object, ['system_instruction'])
559
+ self._api_client, getv(config, ['system_instruction'])
561
560
  ),
562
561
  to_object,
563
562
  ),
564
563
  )
565
- if getv(from_object, ['tools']) is not None:
564
+ if getv(config, ['tools']) is not None:
566
565
  setv(
567
566
  to_object,
568
567
  ['tools'],
569
568
  [
570
- _Tool_to_mldev(self._api_client, item, to_object)
571
- for item in getv(from_object, ['tools'])
569
+ _Tool_to_mldev(
570
+ self._api_client, t.t_tool(self._api_client, item), to_object
571
+ )
572
+ for item in t.t_tools(self._api_client, getv(config, ['tools']))
572
573
  ],
573
574
  )
574
575
 
@@ -577,33 +578,29 @@ class AsyncLive(_api_module.BaseModule):
577
578
  return return_value
578
579
 
579
580
  def _LiveSetup_to_vertex(
580
- self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
581
+ self, model: str, config: Optional[types.LiveConnectConfig] = None
581
582
  ):
582
- if isinstance(config, types.LiveConnectConfig):
583
- from_object = config.model_dump(exclude_none=True)
584
- else:
585
- from_object = config
586
583
 
587
- to_object = {}
584
+ to_object: dict[str, Any] = {}
588
585
 
589
- if getv(from_object, ['generation_config']) is not None:
586
+ if getv(config, ['generation_config']) is not None:
590
587
  setv(
591
588
  to_object,
592
589
  ['generationConfig'],
593
590
  _GenerateContentConfig_to_vertex(
594
591
  self._api_client,
595
- getv(from_object, ['generation_config']),
592
+ getv(config, ['generation_config']),
596
593
  to_object,
597
594
  ),
598
595
  )
599
- if getv(from_object, ['response_modalities']) is not None:
596
+ if getv(config, ['response_modalities']) is not None:
600
597
  if getv(to_object, ['generationConfig']) is not None:
601
- to_object['generationConfig']['responseModalities'] = from_object[
602
- 'response_modalities'
603
- ]
598
+ to_object['generationConfig']['responseModalities'] = getv(
599
+ config, ['response_modalities']
600
+ )
604
601
  else:
605
602
  to_object['generationConfig'] = {
606
- 'responseModalities': from_object['response_modalities']
603
+ 'responseModalities': getv(config, ['response_modalities'])
607
604
  }
608
605
  else:
609
606
  # Set default to AUDIO to align with MLDev API.
@@ -613,12 +610,12 @@ class AsyncLive(_api_module.BaseModule):
613
610
  to_object.update(
614
611
  {'generationConfig': {'responseModalities': ['AUDIO']}}
615
612
  )
616
- if getv(from_object, ['speech_config']) is not None:
613
+ if getv(config, ['speech_config']) is not None:
617
614
  if getv(to_object, ['generationConfig']) is not None:
618
615
  to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_vertex(
619
616
  self._api_client,
620
617
  t.t_speech_config(
621
- self._api_client, getv(from_object, ['speech_config'])
618
+ self._api_client, getv(config, ['speech_config'])
622
619
  ),
623
620
  to_object,
624
621
  )
@@ -627,30 +624,32 @@ class AsyncLive(_api_module.BaseModule):
627
624
  'speechConfig': _SpeechConfig_to_vertex(
628
625
  self._api_client,
629
626
  t.t_speech_config(
630
- self._api_client, getv(from_object, ['speech_config'])
627
+ self._api_client, getv(config, ['speech_config'])
631
628
  ),
632
629
  to_object,
633
630
  )
634
631
  }
635
- if getv(from_object, ['system_instruction']) is not None:
632
+ if getv(config, ['system_instruction']) is not None:
636
633
  setv(
637
634
  to_object,
638
635
  ['systemInstruction'],
639
636
  _Content_to_vertex(
640
637
  self._api_client,
641
638
  t.t_content(
642
- self._api_client, getv(from_object, ['system_instruction'])
639
+ self._api_client, getv(config, ['system_instruction'])
643
640
  ),
644
641
  to_object,
645
642
  ),
646
643
  )
647
- if getv(from_object, ['tools']) is not None:
644
+ if getv(config, ['tools']) is not None:
648
645
  setv(
649
646
  to_object,
650
647
  ['tools'],
651
648
  [
652
- _Tool_to_vertex(self._api_client, item, to_object)
653
- for item in getv(from_object, ['tools'])
649
+ _Tool_to_vertex(
650
+ self._api_client, t.t_tool(self._api_client, item), to_object
651
+ )
652
+ for item in t.t_tools(self._api_client, getv(config, ['tools']))
654
653
  ],
655
654
  )
656
655
 
@@ -684,16 +683,24 @@ class AsyncLive(_api_module.BaseModule):
684
683
  print(message)
685
684
  """
686
685
  base_url = self._api_client._websocket_base_url()
686
+ transformed_model = t.t_model(self._api_client, model)
687
+ # Ensure the config is a LiveConnectConfig.
688
+ parameter_model = types.LiveConnectConfig(**config) if isinstance(
689
+ config, dict
690
+ ) else config
691
+
687
692
  if self._api_client.api_key:
688
693
  api_key = self._api_client.api_key
689
694
  version = self._api_client._http_options['api_version']
690
695
  uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
691
696
  headers = self._api_client._http_options['headers']
692
-
693
- transformed_model = t.t_model(self._api_client, model)
694
- request = json.dumps(
695
- self._LiveSetup_to_mldev(model=transformed_model, config=config)
697
+ request_dict = _common.convert_to_dict(
698
+ self._LiveSetup_to_mldev(
699
+ model=transformed_model,
700
+ config=parameter_model,
701
+ )
696
702
  )
703
+ request = json.dumps(request_dict)
697
704
  else:
698
705
  # Get bearer token through Application Default Credentials.
699
706
  creds, _ = google.auth.default(
@@ -713,15 +720,17 @@ class AsyncLive(_api_module.BaseModule):
713
720
  uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
714
721
  location = self._api_client.location
715
722
  project = self._api_client.project
716
- transformed_model = t.t_model(self._api_client, model)
717
723
  if transformed_model.startswith('publishers/'):
718
724
  transformed_model = (
719
725
  f'projects/{project}/locations/{location}/' + transformed_model
720
726
  )
721
-
722
- request = json.dumps(
723
- self._LiveSetup_to_vertex(model=transformed_model, config=config)
727
+ request_dict = _common.convert_to_dict(
728
+ self._LiveSetup_to_vertex(
729
+ model=transformed_model,
730
+ config=parameter_model,
731
+ )
724
732
  )
733
+ request = json.dumps(request_dict)
725
734
 
726
735
  async with connect(uri, additional_headers=headers) as ws:
727
736
  await ws.send(request)