google-genai 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
google/genai/live.py CHANGED
@@ -68,6 +68,7 @@ class AsyncSession:
68
68
 
69
69
  async def send(
70
70
  self,
71
+ *,
71
72
  input: Union[
72
73
  types.ContentListUnion,
73
74
  types.ContentListUnionDict,
@@ -80,6 +81,25 @@ class AsyncSession:
80
81
  ],
81
82
  end_of_turn: Optional[bool] = False,
82
83
  ):
84
+ """Send input to the model.
85
+
86
+ The method will send the input request to the server.
87
+
88
+ Args:
89
+ input: The input request to the model.
90
+ end_of_turn: Whether the input is the last message in a turn.
91
+
92
+ Example usage:
93
+
94
+ .. code-block:: python
95
+
96
+ client = genai.Client(api_key=API_KEY)
97
+
98
+ async with client.aio.live.connect(model='...') as session:
99
+ await session.send(input='Hello world!', end_of_turn=True)
100
+ async for message in session.receive():
101
+ print(message)
102
+ """
83
103
  client_message = self._parse_client_message(input, end_of_turn)
84
104
  await self._ws.send(json.dumps(client_message))
85
105
 
@@ -113,7 +133,7 @@ class AsyncSession:
113
133
  yield result
114
134
 
115
135
  async def start_stream(
116
- self, stream: AsyncIterator[bytes], mime_type: str
136
+ self, *, stream: AsyncIterator[bytes], mime_type: str
117
137
  ) -> AsyncIterator[types.LiveServerMessage]:
118
138
  """start a live session from a data stream.
119
139
 
@@ -199,7 +219,7 @@ class AsyncSession:
199
219
  ):
200
220
  async for data in data_stream:
201
221
  input = {'data': data, 'mimeType': mime_type}
202
- await self.send(input)
222
+ await self.send(input=input)
203
223
  # Give a chance for the receive loop to process responses.
204
224
  await asyncio.sleep(10**-12)
205
225
  # Give a chance for the receiver to process the last response.
@@ -221,6 +241,8 @@ class AsyncSession:
221
241
  )
222
242
  if getv(from_object, ['turnComplete']) is not None:
223
243
  setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
244
+ if getv(from_object, ['interrupted']) is not None:
245
+ setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
224
246
  return to_object
225
247
 
226
248
  def _LiveToolCall_from_mldev(
@@ -292,6 +314,8 @@ class AsyncSession:
292
314
  )
293
315
  if getv(from_object, ['turnComplete']) is not None:
294
316
  setv(to_object, ['turn_complete'], getv(from_object, ['turnComplete']))
317
+ if getv(from_object, ['interrupted']) is not None:
318
+ setv(to_object, ['interrupted'], getv(from_object, ['interrupted']))
295
319
  return to_object
296
320
 
297
321
  def _LiveServerMessage_from_vertex(
@@ -338,7 +362,7 @@ class AsyncSession:
338
362
  ) -> dict:
339
363
  if isinstance(input, str):
340
364
  input = [input]
341
- elif (isinstance(input, dict) and 'data' in input):
365
+ elif isinstance(input, dict) and 'data' in input:
342
366
  if isinstance(input['data'], bytes):
343
367
  decoded_data = base64.b64encode(input['data']).decode('utf-8')
344
368
  input['data'] = decoded_data
@@ -405,7 +429,9 @@ class AsyncSession:
405
429
  client_message = {'client_content': input.model_dump(exclude_none=True)}
406
430
  elif isinstance(input, types.LiveClientToolResponse):
407
431
  # ToolResponse.FunctionResponse
408
- if not (self._api_client.vertexai) and not (input.function_responses[0].id):
432
+ if not (self._api_client.vertexai) and not (
433
+ input.function_responses[0].id
434
+ ):
409
435
  raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
410
436
  client_message = {'tool_response': input.model_dump(exclude_none=True)}
411
437
  elif isinstance(input, types.FunctionResponse):
@@ -457,7 +483,7 @@ class AsyncLive(_common.BaseModule):
457
483
  to_object,
458
484
  ['generationConfig'],
459
485
  _GenerateContentConfig_to_mldev(
460
- self.api_client,
486
+ self._api_client,
461
487
  getv(from_object, ['generation_config']),
462
488
  to_object,
463
489
  ),
@@ -474,17 +500,18 @@ class AsyncLive(_common.BaseModule):
474
500
  if getv(from_object, ['speech_config']) is not None:
475
501
  if getv(to_object, ['generationConfig']) is not None:
476
502
  to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_mldev(
477
- self.api_client,
503
+ self._api_client,
478
504
  t.t_speech_config(
479
- self.api_client, getv(from_object, ['speech_config'])),
505
+ self._api_client, getv(from_object, ['speech_config'])
506
+ ),
480
507
  to_object,
481
508
  )
482
509
  else:
483
510
  to_object['generationConfig'] = {
484
511
  'speechConfig': _SpeechConfig_to_mldev(
485
- self.api_client,
512
+ self._api_client,
486
513
  t.t_speech_config(
487
- self.api_client, getv(from_object, ['speech_config'])
514
+ self._api_client, getv(from_object, ['speech_config'])
488
515
  ),
489
516
  to_object,
490
517
  )
@@ -495,9 +522,9 @@ class AsyncLive(_common.BaseModule):
495
522
  to_object,
496
523
  ['systemInstruction'],
497
524
  _Content_to_mldev(
498
- self.api_client,
525
+ self._api_client,
499
526
  t.t_content(
500
- self.api_client, getv(from_object, ['system_instruction'])
527
+ self._api_client, getv(from_object, ['system_instruction'])
501
528
  ),
502
529
  to_object,
503
530
  ),
@@ -507,7 +534,7 @@ class AsyncLive(_common.BaseModule):
507
534
  to_object,
508
535
  ['tools'],
509
536
  [
510
- _Tool_to_mldev(self.api_client, item, to_object)
537
+ _Tool_to_mldev(self._api_client, item, to_object)
511
538
  for item in getv(from_object, ['tools'])
512
539
  ],
513
540
  )
@@ -531,7 +558,7 @@ class AsyncLive(_common.BaseModule):
531
558
  to_object,
532
559
  ['generationConfig'],
533
560
  _GenerateContentConfig_to_vertex(
534
- self.api_client,
561
+ self._api_client,
535
562
  getv(from_object, ['generation_config']),
536
563
  to_object,
537
564
  ),
@@ -556,17 +583,18 @@ class AsyncLive(_common.BaseModule):
556
583
  if getv(from_object, ['speech_config']) is not None:
557
584
  if getv(to_object, ['generationConfig']) is not None:
558
585
  to_object['generationConfig']['speechConfig'] = _SpeechConfig_to_vertex(
559
- self.api_client,
586
+ self._api_client,
560
587
  t.t_speech_config(
561
- self.api_client, getv(from_object, ['speech_config'])),
588
+ self._api_client, getv(from_object, ['speech_config'])
589
+ ),
562
590
  to_object,
563
591
  )
564
592
  else:
565
593
  to_object['generationConfig'] = {
566
594
  'speechConfig': _SpeechConfig_to_vertex(
567
- self.api_client,
595
+ self._api_client,
568
596
  t.t_speech_config(
569
- self.api_client, getv(from_object, ['speech_config'])
597
+ self._api_client, getv(from_object, ['speech_config'])
570
598
  ),
571
599
  to_object,
572
600
  )
@@ -576,9 +604,9 @@ class AsyncLive(_common.BaseModule):
576
604
  to_object,
577
605
  ['systemInstruction'],
578
606
  _Content_to_vertex(
579
- self.api_client,
607
+ self._api_client,
580
608
  t.t_content(
581
- self.api_client, getv(from_object, ['system_instruction'])
609
+ self._api_client, getv(from_object, ['system_instruction'])
582
610
  ),
583
611
  to_object,
584
612
  ),
@@ -588,7 +616,7 @@ class AsyncLive(_common.BaseModule):
588
616
  to_object,
589
617
  ['tools'],
590
618
  [
591
- _Tool_to_vertex(self.api_client, item, to_object)
619
+ _Tool_to_vertex(self._api_client, item, to_object)
592
620
  for item in getv(from_object, ['tools'])
593
621
  ],
594
622
  )
@@ -599,7 +627,10 @@ class AsyncLive(_common.BaseModule):
599
627
 
600
628
  @contextlib.asynccontextmanager
601
629
  async def connect(
602
- self, model: str, config: Optional[types.LiveConnectConfigOrDict] = None
630
+ self,
631
+ *,
632
+ model: str,
633
+ config: Optional[types.LiveConnectConfigOrDict] = None,
603
634
  ) -> AsyncSession:
604
635
  """Connect to the live server.
605
636
 
@@ -609,19 +640,19 @@ class AsyncLive(_common.BaseModule):
609
640
 
610
641
  client = genai.Client(api_key=API_KEY)
611
642
  config = {}
612
- async with client.aio.live.connect(model='gemini-1.0-pro-002', config=config) as session:
643
+ async with client.aio.live.connect(model='...', config=config) as session:
613
644
  await session.send(input='Hello world!', end_of_turn=True)
614
- async for message in session:
645
+ async for message in session.receive():
615
646
  print(message)
616
647
  """
617
- base_url = self.api_client._websocket_base_url()
618
- if self.api_client.api_key:
619
- api_key = self.api_client.api_key
620
- version = self.api_client._http_options['api_version']
648
+ base_url = self._api_client._websocket_base_url()
649
+ if self._api_client.api_key:
650
+ api_key = self._api_client.api_key
651
+ version = self._api_client._http_options['api_version']
621
652
  uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
622
- headers = self.api_client._http_options['headers']
653
+ headers = self._api_client._http_options['headers']
623
654
 
624
- transformed_model = t.t_model(self.api_client, model)
655
+ transformed_model = t.t_model(self._api_client, model)
625
656
  request = json.dumps(
626
657
  self._LiveSetup_to_mldev(model=transformed_model, config=config)
627
658
  )
@@ -640,11 +671,11 @@ class AsyncLive(_common.BaseModule):
640
671
  'Content-Type': 'application/json',
641
672
  'Authorization': 'Bearer {}'.format(bearer_token),
642
673
  }
643
- version = self.api_client._http_options['api_version']
674
+ version = self._api_client._http_options['api_version']
644
675
  uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
645
- location = self.api_client.location
646
- project = self.api_client.project
647
- transformed_model = t.t_model(self.api_client, model)
676
+ location = self._api_client.location
677
+ project = self._api_client.project
678
+ transformed_model = t.t_model(self._api_client, model)
648
679
  if transformed_model.startswith('publishers/'):
649
680
  transformed_model = (
650
681
  f'projects/{project}/locations/{location}/' + transformed_model
@@ -658,4 +689,4 @@ class AsyncLive(_common.BaseModule):
658
689
  await ws.send(request)
659
690
  logging.info(await ws.recv(decode=False))
660
691
 
661
- yield AsyncSession(api_client=self.api_client, websocket=ws)
692
+ yield AsyncSession(api_client=self._api_client, websocket=ws)