google-genai 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/live.py CHANGED
@@ -108,7 +108,7 @@ class AsyncSession:
 
     The method will yield the model responses from the server. The returned
     responses will represent a complete model turn. When the returned message
-    is fuction call, user must call `send` with the function response to
+    is function call, user must call `send` with the function response to
     continue the turn.
 
     Yields:
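
The corrected docstring describes the Live turn protocol: `receive()` yields messages until a turn completes, and a turn that ends in a tool call must be answered through `send` before the model continues. A minimal sketch of that loop, assuming a configured `google.genai.Client`; the model id, the dummy tool result, and the exact message attributes are illustrative placeholders rather than guarantees of this SDK version:

import asyncio
from google import genai
from google.genai import types

client = genai.Client(api_key='YOUR_API_KEY')  # placeholder credential

async def main():
    config = {'response_modalities': ['TEXT']}
    async with client.aio.live.connect(
        model='gemini-2.0-flash-exp', config=config  # placeholder model id
    ) as session:
        await session.send(input='Turn the lights on', end_of_turn=True)
        async for message in session.receive():
            if message.tool_call is not None:
                # Answer the function call so the model can finish the turn;
                # echo each call id back (required on the Developer API).
                responses = [
                    types.FunctionResponse(
                        id=fc.id, name=fc.name, response={'result': 'ok'}
                    )
                    for fc in message.tool_call.function_calls
                ]
                await session.send(
                    input=types.LiveClientToolResponse(
                        function_responses=responses
                    )
                )
            elif message.server_content is not None:
                print(message.server_content)

asyncio.run(main())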
@@ -241,6 +241,8 @@ class AsyncSession:
     )
     if getv(from_object, ['turnComplete']) is not None:
       setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
+    if getv(from_object, ['interrupted']) is not None:
+      setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
     return to_object
 
   def _LiveToolCall_from_mldev(
@@ -312,6 +314,8 @@ class AsyncSession:
     )
     if getv(from_object, ['turnComplete']) is not None:
       setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
+    if getv(from_object, ['interrupted']) is not None:
+      setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
     return to_object
 
   def _LiveServerMessage_from_vertex(
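
Both parsers above (Gemini Developer API and Vertex AI) now surface the server's `interrupted` signal alongside `turn_complete`. A hedged sketch of a consumer reacting to it; `session` comes from `client.aio.live.connect(...)` and `playback_queue` is a hypothetical stand-in for an audio pipeline:

async for message in session.receive():
    content = message.server_content
    if content is None:
        continue
    if content.interrupted:
        # The user barged in: drop any buffered model audio and keep listening.
        playback_queue.clear()
        continue
    if content.turn_complete:
        break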
@@ -358,7 +362,7 @@ class AsyncSession:
   ) -> dict:
     if isinstance(input, str):
       input = [input]
-    elif (isinstance(input, dict) and 'data' in input):
+    elif isinstance(input, dict) and 'data' in input:
       if isinstance(input['data'], bytes):
         decoded_data = base64.b64encode(input['data']).decode('utf-8')
         input['data'] = decoded_data
@@ -376,6 +380,10 @@ class AsyncSession:
         isinstance(c, dict) and 'name' in c and 'response' in c for c in input
     ):
       # ToolResponse.FunctionResponse
+      if not (self._api_client.vertexai):
+        for item in input:
+          if 'id' not in item:
+            raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
       client_message = {'tool_response': {'function_responses': input}}
     elif isinstance(input, Sequence) and any(isinstance(c, str) for c in input):
       to_object = {}
@@ -425,7 +433,9 @@ class AsyncSession:
       client_message = {'client_content': input.model_dump(exclude_none=True)}
     elif isinstance(input, types.LiveClientToolResponse):
       # ToolResponse.FunctionResponse
-      if not (self._api_client.vertexai) and not (input.function_responses[0].id):
+      if not (self._api_client.vertexai) and not (
+          input.function_responses[0].id
+      ):
         raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
       client_message = {'tool_response': input.model_dump(exclude_none=True)}
     elif isinstance(input, types.FunctionResponse):
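
Together with the list-of-dicts check added earlier, this hunk makes the `id` requirement explicit for the Gemini Developer API: every function response, whether passed as a plain dict or as `types.LiveClientToolResponse`, must echo the id of the tool call it answers, otherwise `send` raises `ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)`. A hedged sketch of both forms, run inside the async receive loop, where `fc` is a function call pulled from a received `message.tool_call` and the response payload is a placeholder:

from google.genai import types

# Dict form, matched by the Sequence-of-dicts branch:
await session.send(
    input=[{'id': fc.id, 'name': fc.name, 'response': {'ok': True}}]
)

# Typed form, matched by the LiveClientToolResponse branch:
await session.send(
    input=types.LiveClientToolResponse(
        function_responses=[
            types.FunctionResponse(id=fc.id, name=fc.name, response={'ok': True})
        ]
    )
)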
@@ -477,7 +487,7 @@ class AsyncLive(_common.BaseModule):
           to_object,
           ['generationConfig'],
           _GenerateContentConfig_to_mldev(
-              self.api_client,
+              self._api_client,
               getv(from_object, ['generation_config']),
               to_object,
           ),
@@ -494,17 +504,18 @@ class AsyncLive(_common.BaseModule):
     if getv(from_object, ['speech_config']) is not None:
       if getv(to_object, ['generationConfig']) is not None:
         to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_mldev(
-            self.api_client,
+            self._api_client,
             t.t_speech_config(
-                self.api_client, getv(from_object, ['speech_config'])),
+                self._api_client, getv(from_object, ['speech_config'])
+            ),
             to_object,
         )
       else:
         to_object['generationConfig'] = {
             'speechConfig': _SpeechConfig_to_mldev(
-                self.api_client,
+                self._api_client,
                 t.t_speech_config(
-                    self.api_client, getv(from_object, ['speech_config'])
+                    self._api_client, getv(from_object, ['speech_config'])
                 ),
                 to_object,
             )
@@ -515,9 +526,9 @@ class AsyncLive(_common.BaseModule):
           to_object,
           ['systemInstruction'],
           _Content_to_mldev(
-              self.api_client,
+              self._api_client,
               t.t_content(
-                  self.api_client, getv(from_object, ['system_instruction'])
+                  self._api_client, getv(from_object, ['system_instruction'])
               ),
               to_object,
           ),
@@ -527,7 +538,7 @@ class AsyncLive(_common.BaseModule):
           to_object,
           ['tools'],
           [
-              _Tool_to_mldev(self.api_client, item, to_object)
+              _Tool_to_mldev(self._api_client, item, to_object)
               for item in getv(from_object, ['tools'])
           ],
       )
@@ -551,7 +562,7 @@ class AsyncLive(_common.BaseModule):
           to_object,
           ['generationConfig'],
           _GenerateContentConfig_to_vertex(
-              self.api_client,
+              self._api_client,
               getv(from_object, ['generation_config']),
               to_object,
           ),
@@ -576,17 +587,18 @@ class AsyncLive(_common.BaseModule):
     if getv(from_object, ['speech_config']) is not None:
       if getv(to_object, ['generationConfig']) is not None:
         to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_vertex(
-            self.api_client,
+            self._api_client,
             t.t_speech_config(
-                self.api_client, getv(from_object, ['speech_config'])),
+                self._api_client, getv(from_object, ['speech_config'])
+            ),
             to_object,
         )
       else:
         to_object['generationConfig'] = {
             'speechConfig': _SpeechConfig_to_vertex(
-                self.api_client,
+                self._api_client,
                 t.t_speech_config(
-                    self.api_client, getv(from_object, ['speech_config'])
+                    self._api_client, getv(from_object, ['speech_config'])
                 ),
                 to_object,
             )
@@ -596,9 +608,9 @@ class AsyncLive(_common.BaseModule):
           to_object,
           ['systemInstruction'],
           _Content_to_vertex(
-              self.api_client,
+              self._api_client,
               t.t_content(
-                  self.api_client, getv(from_object, ['system_instruction'])
+                  self._api_client, getv(from_object, ['system_instruction'])
               ),
               to_object,
           ),
@@ -608,7 +620,7 @@ class AsyncLive(_common.BaseModule):
           to_object,
           ['tools'],
           [
-              _Tool_to_vertex(self.api_client, item, to_object)
+              _Tool_to_vertex(self._api_client, item, to_object)
               for item in getv(from_object, ['tools'])
           ],
       )
@@ -637,14 +649,14 @@ class AsyncLive(_common.BaseModule):
         async for message in session.receive():
           print(message)
     """
-    base_url = self.api_client._websocket_base_url()
-    if self.api_client.api_key:
-      api_key = self.api_client.api_key
-      version = self.api_client._http_options['api_version']
+    base_url = self._api_client._websocket_base_url()
+    if self._api_client.api_key:
+      api_key = self._api_client.api_key
+      version = self._api_client._http_options['api_version']
       uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
-      headers = self.api_client._http_options['headers']
+      headers = self._api_client._http_options['headers']
 
-      transformed_model = t.t_model(self.api_client, model)
+      transformed_model = t.t_model(self._api_client, model)
       request = json.dumps(
           self._LiveSetup_to_mldev(model=transformed_model, config=config)
       )
@@ -663,11 +675,11 @@ class AsyncLive(_common.BaseModule):
           'Content-Type': 'application/json',
           'Authorization': 'Bearer {}'.format(bearer_token),
       }
-      version = self.api_client._http_options['api_version']
+      version = self._api_client._http_options['api_version']
       uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
-      location = self.api_client.location
-      project = self.api_client.project
-      transformed_model = t.t_model(self.api_client, model)
+      location = self._api_client.location
+      project = self._api_client.project
+      transformed_model = t.t_model(self._api_client, model)
       if transformed_model.startswith('publishers/'):
         transformed_model = (
             f'projects/{project}/locations/{location}/' + transformed_model
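
The two branches above correspond to the two ways a client can be constructed: an API key routes to the `google.ai.generativelanguage` BidiGenerateContent endpoint, while Vertex AI uses bearer-token auth against `LlmBidiService` and expands the model name with project and location. A short sketch of the two constructors (key, project, and location values are placeholders):

from google import genai

# Gemini Developer API: connect() takes the ?key=... websocket branch.
dev_client = genai.Client(api_key='YOUR_API_KEY')

# Vertex AI: connect() takes the bearer-token LlmBidiService branch and
# prefixes the model with projects/{project}/locations/{location}/.
vertex_client = genai.Client(
    vertexai=True, project='your-gcp-project', location='us-central1'
)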
@@ -681,4 +693,4 @@ class AsyncLive(_common.BaseModule):
       await ws.send(request)
       logging.info(await ws.recv(decode=False))
 
-      yield AsyncSession(api_client=self.api_client, websocket=ws)
+      yield AsyncSession(api_client=self._api_client, websocket=ws)