livekit-plugins-google 1.0.22__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,42 +10,23 @@ from collections.abc import Iterator
10
10
  from dataclasses import dataclass, field
11
11
 
12
12
  from google import genai
13
+ from google.genai import types
13
14
  from google.genai.live import AsyncSession
14
- from google.genai.types import (
15
- AudioTranscriptionConfig,
16
- AutomaticActivityDetection,
17
- Blob,
18
- Content,
19
- FunctionDeclaration,
20
- GenerationConfig,
21
- LiveClientContent,
22
- LiveClientRealtimeInput,
23
- LiveClientToolResponse,
24
- LiveConnectConfig,
25
- LiveServerContent,
26
- LiveServerGoAway,
27
- LiveServerToolCall,
28
- LiveServerToolCallCancellation,
29
- Modality,
30
- ModalityTokenCount,
31
- Part,
32
- PrebuiltVoiceConfig,
33
- RealtimeInputConfig,
34
- SessionResumptionConfig,
35
- SpeechConfig,
36
- Tool,
37
- UsageMetadata,
38
- VoiceConfig,
39
- )
40
15
  from livekit import rtc
41
- from livekit.agents import llm, utils
16
+ from livekit.agents import APIConnectionError, llm, utils
42
17
  from livekit.agents.metrics import RealtimeModelMetrics
43
- from livekit.agents.types import NOT_GIVEN, NotGivenOr
18
+ from livekit.agents.types import (
19
+ DEFAULT_API_CONNECT_OPTIONS,
20
+ NOT_GIVEN,
21
+ APIConnectOptions,
22
+ NotGivenOr,
23
+ )
44
24
  from livekit.agents.utils import audio as audio_utils, images, is_given
45
25
  from livekit.plugins.google.beta.realtime.api_proto import ClientEvents, LiveAPIModels, Voice
46
26
 
47
27
  from ...log import logger
48
- from ...utils import get_tool_results_for_realtime, to_chat_ctx, to_fnc_ctx
28
+ from ...tools import _LLMTool
29
+ from ...utils import create_tools_config, get_tool_results_for_realtime, to_fnc_ctx
49
30
 
50
31
  INPUT_AUDIO_SAMPLE_RATE = 16000
51
32
  INPUT_AUDIO_CHANNELS = 1
@@ -71,7 +52,7 @@ class _RealtimeOptions:
71
52
  api_key: str | None
72
53
  voice: Voice | str
73
54
  language: NotGivenOr[str]
74
- response_modalities: NotGivenOr[list[Modality]]
55
+ response_modalities: NotGivenOr[list[types.Modality]]
75
56
  vertexai: bool
76
57
  project: str | None
77
58
  location: str | None
@@ -83,9 +64,16 @@ class _RealtimeOptions:
83
64
  presence_penalty: NotGivenOr[float]
84
65
  frequency_penalty: NotGivenOr[float]
85
66
  instructions: NotGivenOr[str]
86
- input_audio_transcription: AudioTranscriptionConfig | None
87
- output_audio_transcription: AudioTranscriptionConfig | None
67
+ input_audio_transcription: types.AudioTranscriptionConfig | None
68
+ output_audio_transcription: types.AudioTranscriptionConfig | None
88
69
  image_encode_options: NotGivenOr[images.EncodeOptions]
70
+ conn_options: APIConnectOptions
71
+ enable_affective_dialog: NotGivenOr[bool] = NOT_GIVEN
72
+ proactivity: NotGivenOr[bool] = NOT_GIVEN
73
+ realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN
74
+ context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN
75
+ api_version: NotGivenOr[str] = NOT_GIVEN
76
+ gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN
89
77
 
90
78
 
91
79
  @dataclass
@@ -93,10 +81,13 @@ class _ResponseGeneration:
93
81
  message_ch: utils.aio.Chan[llm.MessageGeneration]
94
82
  function_ch: utils.aio.Chan[llm.FunctionCall]
95
83
 
84
+ input_id: str
96
85
  response_id: str
97
86
  text_ch: utils.aio.Chan[str]
98
87
  audio_ch: utils.aio.Chan[rtc.AudioFrame]
88
+
99
89
  input_transcription: str = ""
90
+ output_text: str = ""
100
91
 
101
92
  _created_timestamp: float = field(default_factory=time.time)
102
93
  """The timestamp when the generation is created"""
@@ -107,6 +98,14 @@ class _ResponseGeneration:
107
98
  _done: bool = False
108
99
  """Whether the generation is done (set when the turn is complete)"""
109
100
 
101
+ def push_text(self, text: str) -> None:
102
+ if self.output_text:
103
+ self.output_text += text
104
+ else:
105
+ self.output_text = text
106
+
107
+ self.text_ch.send_nowait(text)
108
+
110
109
 
111
110
  class RealtimeModel(llm.RealtimeModel):
112
111
  def __init__(
@@ -117,7 +116,7 @@ class RealtimeModel(llm.RealtimeModel):
117
116
  api_key: NotGivenOr[str] = NOT_GIVEN,
118
117
  voice: Voice | str = "Puck",
119
118
  language: NotGivenOr[str] = NOT_GIVEN,
120
- modalities: NotGivenOr[list[Modality]] = NOT_GIVEN,
119
+ modalities: NotGivenOr[list[types.Modality]] = NOT_GIVEN,
121
120
  vertexai: NotGivenOr[bool] = NOT_GIVEN,
122
121
  project: NotGivenOr[str] = NOT_GIVEN,
123
122
  location: NotGivenOr[str] = NOT_GIVEN,
@@ -128,9 +127,16 @@ class RealtimeModel(llm.RealtimeModel):
128
127
  top_k: NotGivenOr[int] = NOT_GIVEN,
129
128
  presence_penalty: NotGivenOr[float] = NOT_GIVEN,
130
129
  frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
131
- input_audio_transcription: NotGivenOr[AudioTranscriptionConfig | None] = NOT_GIVEN,
132
- output_audio_transcription: NotGivenOr[AudioTranscriptionConfig | None] = NOT_GIVEN,
130
+ input_audio_transcription: NotGivenOr[types.AudioTranscriptionConfig | None] = NOT_GIVEN,
131
+ output_audio_transcription: NotGivenOr[types.AudioTranscriptionConfig | None] = NOT_GIVEN,
133
132
  image_encode_options: NotGivenOr[images.EncodeOptions] = NOT_GIVEN,
133
+ enable_affective_dialog: NotGivenOr[bool] = NOT_GIVEN,
134
+ proactivity: NotGivenOr[bool] = NOT_GIVEN,
135
+ realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN,
136
+ context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN,
137
+ api_version: NotGivenOr[str] = NOT_GIVEN,
138
+ conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
139
+ _gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
134
140
  ) -> None:
135
141
  """
136
142
  Initializes a RealtimeModel instance for interacting with Google's Realtime API.
@@ -161,19 +167,33 @@ class RealtimeModel(llm.RealtimeModel):
161
167
  input_audio_transcription (AudioTranscriptionConfig | None, optional): The configuration for input audio transcription. Defaults to AudioTranscriptionConfig().
162
168
  output_audio_transcription (AudioTranscriptionConfig | None, optional): The configuration for output audio transcription. Defaults to AudioTranscriptionConfig().
163
169
  image_encode_options (images.EncodeOptions, optional): The configuration for image encoding. Defaults to DEFAULT_ENCODE_OPTIONS.
170
+ enable_affective_dialog (bool, optional): Whether to enable affective dialog. Defaults to False.
171
+ proactivity (bool, optional): Whether to enable proactive audio. Defaults to False.
172
+ realtime_input_config (RealtimeInputConfig, optional): The configuration for realtime input. Defaults to None.
173
+ context_window_compression (ContextWindowCompressionConfig, optional): The configuration for context window compression. Defaults to None.
174
+ conn_options (APIConnectOptions, optional): The configuration for the API connection. Defaults to DEFAULT_API_CONNECT_OPTIONS.
175
+ _gemini_tools (list[LLMTool], optional): Gemini-specific tools to use for the session. This parameter is experimental and may change.
164
176
 
165
177
  Raises:
166
178
  ValueError: If the API key is required but not found.
167
179
  """ # noqa: E501
168
180
  if not is_given(input_audio_transcription):
169
- input_audio_transcription = AudioTranscriptionConfig()
181
+ input_audio_transcription = types.AudioTranscriptionConfig()
170
182
  if not is_given(output_audio_transcription):
171
- output_audio_transcription = AudioTranscriptionConfig()
183
+ output_audio_transcription = types.AudioTranscriptionConfig()
184
+
185
+ server_turn_detection = True
186
+ if (
187
+ is_given(realtime_input_config)
188
+ and realtime_input_config.automatic_activity_detection
189
+ and realtime_input_config.automatic_activity_detection.disabled
190
+ ):
191
+ server_turn_detection = False
172
192
 
173
193
  super().__init__(
174
194
  capabilities=llm.RealtimeCapabilities(
175
195
  message_truncation=False,
176
- turn_detection=True,
196
+ turn_detection=server_turn_detection,
177
197
  user_transcription=input_audio_transcription is not None,
178
198
  auto_tool_reply_generation=True,
179
199
  )
@@ -187,7 +207,7 @@ class RealtimeModel(llm.RealtimeModel):
187
207
 
188
208
  gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
189
209
  gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
190
- gcp_location = (
210
+ gcp_location: str | None = (
191
211
  location
192
212
  if is_given(location)
193
213
  else os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1"
@@ -232,6 +252,13 @@ class RealtimeModel(llm.RealtimeModel):
232
252
  output_audio_transcription=output_audio_transcription,
233
253
  language=language,
234
254
  image_encode_options=image_encode_options,
255
+ enable_affective_dialog=enable_affective_dialog,
256
+ proactivity=proactivity,
257
+ realtime_input_config=realtime_input_config,
258
+ context_window_compression=context_window_compression,
259
+ api_version=api_version,
260
+ gemini_tools=_gemini_tools,
261
+ conn_options=conn_options,
235
262
  )
236
263
 
237
264
  self._sessions = weakref.WeakSet[RealtimeSession]()
@@ -242,8 +269,19 @@ class RealtimeModel(llm.RealtimeModel):
242
269
  return sess
243
270
 
244
271
  def update_options(
245
- self, *, voice: NotGivenOr[str] = NOT_GIVEN, temperature: NotGivenOr[float] = NOT_GIVEN
272
+ self,
273
+ *,
274
+ voice: NotGivenOr[str] = NOT_GIVEN,
275
+ temperature: NotGivenOr[float] = NOT_GIVEN,
246
276
  ) -> None:
277
+ """
278
+ Update the options for the RealtimeModel.
279
+
280
+ Args:
281
+ voice (str, optional): The voice to use for the session.
282
+ temperature (float, optional): The temperature to use for the session.
283
+ tools (list[LLMTool], optional): The tools to use for the session.
284
+ """
247
285
  if is_given(voice):
248
286
  self._opts.voice = voice
249
287
 
@@ -251,7 +289,10 @@ class RealtimeModel(llm.RealtimeModel):
251
289
  self._opts.temperature = temperature
252
290
 
253
291
  for sess in self._sessions:
254
- sess.update_options(voice=self._opts.voice, temperature=self._opts.temperature)
292
+ sess.update_options(
293
+ voice=self._opts.voice,
294
+ temperature=self._opts.temperature,
295
+ )
255
296
 
256
297
  async def aclose(self) -> None:
257
298
  pass
@@ -262,7 +303,7 @@ class RealtimeSession(llm.RealtimeSession):
262
303
  super().__init__(realtime_model)
263
304
  self._opts = realtime_model._opts
264
305
  self._tools = llm.ToolContext.empty()
265
- self._gemini_declarations: list[FunctionDeclaration] = []
306
+ self._gemini_declarations: list[types.FunctionDeclaration] = []
266
307
  self._chat_ctx = llm.ChatContext.empty()
267
308
  self._msg_ch = utils.aio.Chan[ClientEvents]()
268
309
  self._input_resampler: rtc.AudioResampler | None = None
@@ -274,11 +315,20 @@ class RealtimeSession(llm.RealtimeSession):
274
315
  samples_per_channel=INPUT_AUDIO_SAMPLE_RATE // 20,
275
316
  )
276
317
 
318
+ api_version = self._opts.api_version
319
+ if not api_version and (self._opts.enable_affective_dialog or self._opts.proactivity):
320
+ api_version = "v1alpha"
321
+
322
+ http_options = types.HttpOptions(timeout=int(self._opts.conn_options.timeout * 1000))
323
+ if api_version:
324
+ http_options.api_version = api_version
325
+
277
326
  self._client = genai.Client(
278
327
  api_key=self._opts.api_key,
279
328
  vertexai=self._opts.vertexai,
280
329
  project=self._opts.project,
281
330
  location=self._opts.location,
331
+ http_options=http_options,
282
332
  )
283
333
 
284
334
  self._main_atask = asyncio.create_task(self._main_task(), name="gemini-realtime-session")
@@ -291,8 +341,9 @@ class RealtimeSession(llm.RealtimeSession):
291
341
  self._pending_generation_fut: asyncio.Future[llm.GenerationCreatedEvent] | None = None
292
342
 
293
343
  self._session_resumption_handle: str | None = None
294
-
344
+ self._in_user_activity = False
295
345
  self._session_lock = asyncio.Lock()
346
+ self._num_retries = 0
296
347
 
297
348
  async def _close_active_session(self) -> None:
298
349
  async with self._session_lock:
@@ -304,7 +355,7 @@ class RealtimeSession(llm.RealtimeSession):
304
355
  finally:
305
356
  self._active_session = None
306
357
 
307
- def _mark_restart_needed(self):
358
+ def _mark_restart_needed(self) -> None:
308
359
  if not self._session_should_close.is_set():
309
360
  self._session_should_close.set()
310
361
  # reset the msg_ch, do not send messages from previous session
@@ -335,6 +386,11 @@ class RealtimeSession(llm.RealtimeSession):
335
386
  self._mark_restart_needed()
336
387
 
337
388
  async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
389
+ async with self._session_lock:
390
+ if not self._active_session:
391
+ self._chat_ctx = chat_ctx.copy()
392
+ return
393
+
338
394
  diff_ops = llm.utils.compute_chat_ctx_diff(self._chat_ctx, chat_ctx)
339
395
 
340
396
  if diff_ops.to_remove:
@@ -347,15 +403,23 @@ class RealtimeSession(llm.RealtimeSession):
347
403
  append_ctx.items.append(item)
348
404
 
349
405
  if append_ctx.items:
350
- turns, _ = to_chat_ctx(append_ctx, id(self), ignore_functions=True)
406
+ turns_dict, _ = append_ctx.copy(
407
+ exclude_function_call=True,
408
+ ).to_provider_format(format="google", inject_dummy_user_message=False)
409
+ # we are not generating, and do not need to inject
410
+ turns = [types.Content.model_validate(turn) for turn in turns_dict]
351
411
  tool_results = get_tool_results_for_realtime(append_ctx, vertexai=self._opts.vertexai)
352
412
  if turns:
353
- self._send_client_event(LiveClientContent(turns=turns, turn_complete=False))
413
+ self._send_client_event(types.LiveClientContent(turns=turns, turn_complete=False))
354
414
  if tool_results:
355
415
  self._send_client_event(tool_results)
356
416
 
357
- async def update_tools(self, tools: list[llm.FunctionTool]) -> None:
358
- new_declarations: list[FunctionDeclaration] = to_fnc_ctx(tools)
417
+ # since we don't have a view of the history on the server side, we'll assume
418
+ # the current state is accurate. this isn't perfect because removals aren't done.
419
+ self._chat_ctx = chat_ctx.copy()
420
+
421
+ async def update_tools(self, tools: list[llm.FunctionTool | llm.RawFunctionTool]) -> None:
422
+ new_declarations: list[types.FunctionDeclaration] = to_fnc_ctx(tools)
359
423
  current_tool_names = {f.name for f in self._gemini_declarations}
360
424
  new_tool_names = {f.name for f in new_declarations}
361
425
 
@@ -372,11 +436,21 @@ class RealtimeSession(llm.RealtimeSession):
372
436
  def tools(self) -> llm.ToolContext:
373
437
  return self._tools.copy()
374
438
 
439
+ @property
440
+ def _manual_activity_detection(self) -> bool:
441
+ if (
442
+ is_given(self._opts.realtime_input_config)
443
+ and self._opts.realtime_input_config.automatic_activity_detection is not None
444
+ and self._opts.realtime_input_config.automatic_activity_detection.disabled
445
+ ):
446
+ return True
447
+ return False
448
+
375
449
  def push_audio(self, frame: rtc.AudioFrame) -> None:
376
450
  for f in self._resample_audio(frame):
377
451
  for nf in self._bstream.write(f.data.tobytes()):
378
- realtime_input = LiveClientRealtimeInput(
379
- media_chunks=[Blob(data=nf.data.tobytes(), mime_type="audio/pcm")]
452
+ realtime_input = types.LiveClientRealtimeInput(
453
+ media_chunks=[types.Blob(data=nf.data.tobytes(), mime_type="audio/pcm")]
380
454
  )
381
455
  self._send_client_event(realtime_input)
382
456
 
@@ -384,8 +458,8 @@ class RealtimeSession(llm.RealtimeSession):
384
458
  encoded_data = images.encode(
385
459
  frame, self._opts.image_encode_options or DEFAULT_IMAGE_ENCODE_OPTIONS
386
460
  )
387
- realtime_input = LiveClientRealtimeInput(
388
- media_chunks=[Blob(data=encoded_data, mime_type="image/jpeg")]
461
+ realtime_input = types.LiveClientRealtimeInput(
462
+ media_chunks=[types.Blob(data=encoded_data, mime_type="image/jpeg")]
389
463
  )
390
464
  self._send_client_event(realtime_input)
391
465
 
@@ -402,16 +476,24 @@ class RealtimeSession(llm.RealtimeSession):
402
476
  )
403
477
  self._pending_generation_fut.cancel("Superseded by new generate_reply call")
404
478
 
405
- fut = asyncio.Future()
479
+ fut = asyncio.Future[llm.GenerationCreatedEvent]()
406
480
  self._pending_generation_fut = fut
407
481
 
482
+ if self._in_user_activity:
483
+ self._send_client_event(
484
+ types.LiveClientRealtimeInput(
485
+ activity_end=types.ActivityEnd(),
486
+ )
487
+ )
488
+ self._in_user_activity = False
489
+
408
490
  # Gemini requires the last message to end with user's turn
409
491
  # so we need to add a placeholder user turn in order to trigger a new generation
410
- event = LiveClientContent(turns=[], turn_complete=True)
492
+ turns = []
411
493
  if is_given(instructions):
412
- event.turns.append(Content(parts=[Part(text=instructions)], role="model"))
413
- event.turns.append(Content(parts=[Part(text=".")], role="user"))
414
- self._send_client_event(event)
494
+ turns.append(types.Content(parts=[types.Part(text=instructions)], role="model"))
495
+ turns.append(types.Content(parts=[types.Part(text=".")], role="user"))
496
+ self._send_client_event(types.LiveClientContent(turns=turns, turn_complete=True))
415
497
 
416
498
  def _on_timeout() -> None:
417
499
  if not fut.done():
@@ -428,8 +510,28 @@ class RealtimeSession(llm.RealtimeSession):
428
510
 
429
511
  return fut
430
512
 
513
+ def start_user_activity(self) -> None:
514
+ if not self._manual_activity_detection:
515
+ return
516
+
517
+ if not self._in_user_activity:
518
+ self._in_user_activity = True
519
+ self._send_client_event(
520
+ types.LiveClientRealtimeInput(
521
+ activity_start=types.ActivityStart(),
522
+ )
523
+ )
524
+
431
525
  def interrupt(self) -> None:
432
- pass
526
+ # Gemini Live treats activity start as interruption, so we rely on start_user_activity
527
+ # notifications to handle it
528
+ if (
529
+ self._opts.realtime_input_config
530
+ and self._opts.realtime_input_config.activity_handling
531
+ == types.ActivityHandling.NO_INTERRUPTION
532
+ ):
533
+ return
534
+ self.start_user_activity()
433
535
 
434
536
  def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
435
537
  logger.warning("truncate is not supported by the Google Realtime API.")
@@ -456,7 +558,9 @@ class RealtimeSession(llm.RealtimeSession):
456
558
  self._mark_current_generation_done()
457
559
 
458
560
  @utils.log_exceptions(logger=logger)
459
- async def _main_task(self):
561
+ async def _main_task(self) -> None:
562
+ max_retries = self._opts.conn_options.max_retry
563
+
460
564
  while not self._msg_ch.closed:
461
565
  # previous session might not be closed yet, we'll do it here.
462
566
  await self._close_active_session()
@@ -471,7 +575,15 @@ class RealtimeSession(llm.RealtimeSession):
471
575
  ) as session:
472
576
  async with self._session_lock:
473
577
  self._active_session = session
474
-
578
+ turns_dict, _ = self._chat_ctx.copy(
579
+ exclude_function_call=True,
580
+ ).to_provider_format(format="google", inject_dummy_user_message=False)
581
+ if turns_dict:
582
+ turns = [types.Content.model_validate(turn) for turn in turns_dict]
583
+ await session.send_client_content(
584
+ turns=turns, # type: ignore
585
+ turn_complete=False,
586
+ )
475
587
  # queue up existing chat context
476
588
  send_task = asyncio.create_task(
477
589
  self._send_task(session), name="gemini-realtime-send"
@@ -504,12 +616,30 @@ class RealtimeSession(llm.RealtimeSession):
504
616
  except Exception as e:
505
617
  logger.error(f"Gemini Realtime API error: {e}", exc_info=e)
506
618
  if not self._msg_ch.closed:
507
- logger.info("attempting to reconnect after 1 seconds...")
508
- await asyncio.sleep(1)
619
+ # we shouldn't retry when it's not connected, usually this means incorrect
620
+ # parameters or setup
621
+ if not session or max_retries == 0:
622
+ self._emit_error(e, recoverable=False)
623
+ raise APIConnectionError(message="Failed to connect to Gemini Live") from e
624
+
625
+ if self._num_retries == max_retries:
626
+ self._emit_error(e, recoverable=False)
627
+ raise APIConnectionError(
628
+ message=f"Failed to connect to Gemini Live after {max_retries} attempts"
629
+ ) from e
630
+
631
+ retry_interval = self._opts.conn_options._interval_for_retry(self._num_retries)
632
+ logger.warning(
633
+ f"Gemini Realtime API connection failed, retrying in {retry_interval}s",
634
+ exc_info=e,
635
+ extra={"attempt": self._num_retries, "max_retries": max_retries},
636
+ )
637
+ await asyncio.sleep(retry_interval)
638
+ self._num_retries += 1
509
639
  finally:
510
640
  await self._close_active_session()
511
641
 
512
- async def _send_task(self, session: AsyncSession):
642
+ async def _send_task(self, session: AsyncSession) -> None:
513
643
  try:
514
644
  async for msg in self._msg_ch:
515
645
  async with self._session_lock:
@@ -517,15 +647,21 @@ class RealtimeSession(llm.RealtimeSession):
517
647
  not self._active_session or self._active_session != session
518
648
  ):
519
649
  break
520
- if isinstance(msg, LiveClientContent):
650
+ if isinstance(msg, types.LiveClientContent):
521
651
  await session.send_client_content(
522
- turns=msg.turns, turn_complete=msg.turn_complete
652
+ turns=msg.turns, # type: ignore
653
+ turn_complete=msg.turn_complete or True,
523
654
  )
524
- elif isinstance(msg, LiveClientToolResponse):
655
+ elif isinstance(msg, types.LiveClientToolResponse) and msg.function_responses:
525
656
  await session.send_tool_response(function_responses=msg.function_responses)
526
- elif isinstance(msg, LiveClientRealtimeInput):
527
- for media_chunk in msg.media_chunks:
528
- await session.send_realtime_input(media=media_chunk)
657
+ elif isinstance(msg, types.LiveClientRealtimeInput):
658
+ if msg.media_chunks:
659
+ for media_chunk in msg.media_chunks:
660
+ await session.send_realtime_input(media=media_chunk)
661
+ elif msg.activity_start:
662
+ await session.send_realtime_input(activity_start=msg.activity_start)
663
+ elif msg.activity_end:
664
+ await session.send_realtime_input(activity_end=msg.activity_end)
529
665
  else:
530
666
  logger.warning(f"Warning: Received unhandled message type: {type(msg)}")
531
667
 
@@ -536,7 +672,7 @@ class RealtimeSession(llm.RealtimeSession):
536
672
  finally:
537
673
  logger.debug("send task finished.")
538
674
 
539
- async def _recv_task(self, session: AsyncSession):
675
+ async def _recv_task(self, session: AsyncSession) -> None:
540
676
  try:
541
677
  while True:
542
678
  async with self._session_lock:
@@ -572,6 +708,9 @@ class RealtimeSession(llm.RealtimeSession):
572
708
  if response.go_away:
573
709
  self._handle_go_away(response.go_away)
574
710
 
711
+ if self._num_retries > 0:
712
+ self._num_retries = 0 # reset the retry counter
713
+
575
714
  # TODO(dz): a server-side turn is complete
576
715
  except Exception as e:
577
716
  if not self._session_should_close.is_set():
@@ -580,14 +719,18 @@ class RealtimeSession(llm.RealtimeSession):
580
719
  finally:
581
720
  self._mark_current_generation_done()
582
721
 
583
- def _build_connect_config(self) -> LiveConnectConfig:
722
+ def _build_connect_config(self) -> types.LiveConnectConfig:
584
723
  temp = self._opts.temperature if is_given(self._opts.temperature) else None
585
724
 
586
- return LiveConnectConfig(
725
+ tools_config = create_tools_config(
726
+ function_tools=self._gemini_declarations,
727
+ gemini_tools=self._opts.gemini_tools if is_given(self._opts.gemini_tools) else None,
728
+ )
729
+ conf = types.LiveConnectConfig(
587
730
  response_modalities=self._opts.response_modalities
588
731
  if is_given(self._opts.response_modalities)
589
- else [Modality.AUDIO],
590
- generation_config=GenerationConfig(
732
+ else [types.Modality.AUDIO],
733
+ generation_config=types.GenerationConfig(
591
734
  candidate_count=self._opts.candidate_count,
592
735
  temperature=temp,
593
736
  max_output_tokens=self._opts.max_output_tokens
@@ -602,34 +745,45 @@ class RealtimeSession(llm.RealtimeSession):
602
745
  if is_given(self._opts.frequency_penalty)
603
746
  else None,
604
747
  ),
605
- system_instruction=Content(parts=[Part(text=self._opts.instructions)])
748
+ system_instruction=types.Content(parts=[types.Part(text=self._opts.instructions)])
606
749
  if is_given(self._opts.instructions)
607
750
  else None,
608
- speech_config=SpeechConfig(
609
- voice_config=VoiceConfig(
610
- prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
751
+ speech_config=types.SpeechConfig(
752
+ voice_config=types.VoiceConfig(
753
+ prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=self._opts.voice)
611
754
  ),
612
755
  language_code=self._opts.language if is_given(self._opts.language) else None,
613
756
  ),
614
- tools=[Tool(function_declarations=self._gemini_declarations)],
757
+ tools=tools_config, # type: ignore
615
758
  input_audio_transcription=self._opts.input_audio_transcription,
616
759
  output_audio_transcription=self._opts.output_audio_transcription,
617
- session_resumption=SessionResumptionConfig(handle=self._session_resumption_handle),
618
- realtime_input_config=RealtimeInputConfig(
619
- automatic_activity_detection=AutomaticActivityDetection(),
760
+ session_resumption=types.SessionResumptionConfig(
761
+ handle=self._session_resumption_handle
620
762
  ),
621
763
  )
622
764
 
623
- def _start_new_generation(self):
765
+ if is_given(self._opts.proactivity):
766
+ conf.proactivity = types.ProactivityConfig(proactive_audio=self._opts.proactivity)
767
+ if is_given(self._opts.enable_affective_dialog):
768
+ conf.enable_affective_dialog = self._opts.enable_affective_dialog
769
+ if is_given(self._opts.realtime_input_config):
770
+ conf.realtime_input_config = self._opts.realtime_input_config
771
+ if is_given(self._opts.context_window_compression):
772
+ conf.context_window_compression = self._opts.context_window_compression
773
+
774
+ return conf
775
+
776
+ def _start_new_generation(self) -> None:
624
777
  if self._current_generation and not self._current_generation._done:
625
778
  logger.warning("starting new generation while another is active. Finalizing previous.")
626
779
  self._mark_current_generation_done()
627
780
 
628
- response_id = utils.shortuuid("gemini-turn-")
781
+ response_id = utils.shortuuid("GR_")
629
782
  self._current_generation = _ResponseGeneration(
630
783
  message_ch=utils.aio.Chan[llm.MessageGeneration](),
631
784
  function_ch=utils.aio.Chan[llm.FunctionCall](),
632
785
  response_id=response_id,
786
+ input_id=utils.shortuuid("GI_"),
633
787
  text_ch=utils.aio.Chan[str](),
634
788
  audio_ch=utils.aio.Chan[rtc.AudioFrame](),
635
789
  _created_timestamp=time.time(),
@@ -656,21 +810,23 @@ class RealtimeSession(llm.RealtimeSession):
656
810
 
657
811
  self.emit("generation_created", generation_event)
658
812
 
659
- def _handle_server_content(self, server_content: LiveServerContent):
813
+ def _handle_server_content(self, server_content: types.LiveServerContent) -> None:
660
814
  current_gen = self._current_generation
661
815
  if not current_gen:
662
816
  logger.warning("received server content but no active generation.")
663
817
  return
664
818
 
665
819
  if model_turn := server_content.model_turn:
666
- for part in model_turn.parts:
820
+ for part in model_turn.parts or []:
667
821
  if part.text:
668
- current_gen.text_ch.send_nowait(part.text)
822
+ current_gen.push_text(part.text)
669
823
  if part.inline_data:
670
824
  if not current_gen._first_token_timestamp:
671
825
  current_gen._first_token_timestamp = time.time()
672
826
  frame_data = part.inline_data.data
673
827
  try:
828
+ if not isinstance(frame_data, bytes):
829
+ raise ValueError("frame_data is not bytes")
674
830
  frame = rtc.AudioFrame(
675
831
  data=frame_data,
676
832
  sample_rate=OUTPUT_AUDIO_SAMPLE_RATE,
@@ -692,7 +848,7 @@ class RealtimeSession(llm.RealtimeSession):
692
848
  self.emit(
693
849
  "input_audio_transcription_completed",
694
850
  llm.InputTranscriptionCompleted(
695
- item_id=current_gen.response_id,
851
+ item_id=current_gen.input_id,
696
852
  transcript=current_gen.input_transcription,
697
853
  is_final=False,
698
854
  ),
@@ -701,20 +857,9 @@ class RealtimeSession(llm.RealtimeSession):
701
857
  if output_transcription := server_content.output_transcription:
702
858
  text = output_transcription.text
703
859
  if text:
704
- current_gen.text_ch.send_nowait(text)
860
+ current_gen.push_text(text)
705
861
 
706
- if server_content.generation_complete:
707
- # The only way we'd know that the transcription is complete is by when they are
708
- # done with generation
709
- if current_gen.input_transcription:
710
- self.emit(
711
- "input_audio_transcription_completed",
712
- llm.InputTranscriptionCompleted(
713
- item_id=current_gen.response_id,
714
- transcript=current_gen.input_transcription,
715
- is_final=True,
716
- ),
717
- )
862
+ if server_content.generation_complete or server_content.turn_complete:
718
863
  current_gen._completed_timestamp = time.time()
719
864
 
720
865
  if server_content.interrupted:
@@ -724,10 +869,38 @@ class RealtimeSession(llm.RealtimeSession):
724
869
  self._mark_current_generation_done()
725
870
 
726
871
  def _mark_current_generation_done(self) -> None:
727
- if not self._current_generation:
872
+ if not self._current_generation or self._current_generation._done:
728
873
  return
729
874
 
730
875
  gen = self._current_generation
876
+
877
+ # The only way we'd know that the transcription is complete is by when they are
878
+ # done with generation
879
+ if gen.input_transcription:
880
+ self.emit(
881
+ "input_audio_transcription_completed",
882
+ llm.InputTranscriptionCompleted(
883
+ item_id=gen.input_id,
884
+ transcript=gen.input_transcription,
885
+ is_final=True,
886
+ ),
887
+ )
888
+
889
+ # since gemini doesn't give us a view of the chat history on the server side,
890
+ # we would handle it manually here
891
+ self._chat_ctx.add_message(
892
+ role="user",
893
+ content=gen.input_transcription,
894
+ id=gen.input_id,
895
+ )
896
+
897
+ if gen.output_text:
898
+ self._chat_ctx.add_message(
899
+ role="assistant",
900
+ content=gen.output_text,
901
+ id=gen.response_id,
902
+ )
903
+
731
904
  if not gen.text_ch.closed:
732
905
  gen.text_ch.close()
733
906
  if not gen.audio_ch.closed:
@@ -737,36 +910,37 @@ class RealtimeSession(llm.RealtimeSession):
737
910
  gen.message_ch.close()
738
911
  gen._done = True
739
912
 
740
- def _handle_input_speech_started(self):
913
+ def _handle_input_speech_started(self) -> None:
741
914
  self.emit("input_speech_started", llm.InputSpeechStartedEvent())
742
915
 
743
- def _handle_tool_calls(self, tool_call: LiveServerToolCall):
916
+ def _handle_tool_calls(self, tool_call: types.LiveServerToolCall) -> None:
744
917
  if not self._current_generation:
745
918
  logger.warning("received tool call but no active generation.")
746
919
  return
747
920
 
748
921
  gen = self._current_generation
749
- for fnc_call in tool_call.function_calls:
922
+ for fnc_call in tool_call.function_calls or []:
750
923
  arguments = json.dumps(fnc_call.args)
751
924
 
752
925
  gen.function_ch.send_nowait(
753
926
  llm.FunctionCall(
754
927
  call_id=fnc_call.id or utils.shortuuid("fnc-call-"),
755
- name=fnc_call.name,
928
+ name=fnc_call.name, # type: ignore
756
929
  arguments=arguments,
757
930
  )
758
931
  )
932
+ self._on_final_input_audio_transcription()
759
933
  self._mark_current_generation_done()
760
934
 
761
935
  def _handle_tool_call_cancellation(
762
- self, tool_call_cancellation: LiveServerToolCallCancellation
763
- ):
936
+ self, tool_call_cancellation: types.LiveServerToolCallCancellation
937
+ ) -> None:
764
938
  logger.warning(
765
939
  "server cancelled tool calls",
766
940
  extra={"function_call_ids": tool_call_cancellation.ids},
767
941
  )
768
942
 
769
- def _handle_usage_metadata(self, usage_metadata: UsageMetadata):
943
+ def _handle_usage_metadata(self, usage_metadata: types.UsageMetadata) -> None:
770
944
  current_gen = self._current_generation
771
945
  if not current_gen:
772
946
  logger.warning("no active generation to report metrics for")
@@ -782,18 +956,21 @@ class RealtimeSession(llm.RealtimeSession):
782
956
  ) - current_gen._created_timestamp
783
957
 
784
958
  def _token_details_map(
785
- token_details: list[ModalityTokenCount] | None,
786
- ) -> dict[Modality, int]:
959
+ token_details: list[types.ModalityTokenCount] | None,
960
+ ) -> dict[str, int]:
787
961
  token_details_map = {"audio_tokens": 0, "text_tokens": 0, "image_tokens": 0}
788
962
  if not token_details:
789
963
  return token_details_map
790
964
 
791
965
  for token_detail in token_details:
792
- if token_detail.modality == Modality.AUDIO:
966
+ if not token_detail.token_count:
967
+ continue
968
+
969
+ if token_detail.modality == types.MediaModality.AUDIO:
793
970
  token_details_map["audio_tokens"] += token_detail.token_count
794
- elif token_detail.modality == Modality.TEXT:
971
+ elif token_detail.modality == types.MediaModality.TEXT:
795
972
  token_details_map["text_tokens"] += token_detail.token_count
796
- elif token_detail.modality == Modality.IMAGE:
973
+ elif token_detail.modality == types.MediaModality.IMAGE:
797
974
  token_details_map["image_tokens"] += token_detail.token_count
798
975
  return token_details_map
799
976
 
@@ -807,7 +984,9 @@ class RealtimeSession(llm.RealtimeSession):
807
984
  input_tokens=usage_metadata.prompt_token_count or 0,
808
985
  output_tokens=usage_metadata.response_token_count or 0,
809
986
  total_tokens=usage_metadata.total_token_count or 0,
810
- tokens_per_second=(usage_metadata.response_token_count or 0) / duration,
987
+ tokens_per_second=(usage_metadata.response_token_count or 0) / duration
988
+ if duration > 0
989
+ else 0,
811
990
  input_token_details=RealtimeModelMetrics.InputTokenDetails(
812
991
  **_token_details_map(usage_metadata.prompt_tokens_details),
813
992
  cached_tokens=sum(
@@ -824,18 +1003,27 @@ class RealtimeSession(llm.RealtimeSession):
824
1003
  )
825
1004
  self.emit("metrics_collected", metrics)
826
1005
 
827
- def _handle_go_away(self, go_away: LiveServerGoAway):
1006
+ def _handle_go_away(self, go_away: types.LiveServerGoAway) -> None:
828
1007
  logger.warning(
829
1008
  f"Gemini server indicates disconnection soon. Time left: {go_away.time_left}"
830
1009
  )
831
1010
  # TODO(dz): this isn't a seamless reconnection just yet
832
1011
  self._session_should_close.set()
833
1012
 
1013
+ def _on_final_input_audio_transcription(self) -> None:
1014
+ if (gen := self._current_generation) and gen.input_transcription:
1015
+ self.emit(
1016
+ "input_audio_transcription_completed",
1017
+ llm.InputTranscriptionCompleted(
1018
+ item_id=gen.response_id, transcript=gen.input_transcription, is_final=True
1019
+ ),
1020
+ )
1021
+
834
1022
  def commit_audio(self) -> None:
835
1023
  pass
836
1024
 
837
1025
  def clear_audio(self) -> None:
838
- self._bstream.clear()
1026
+ pass
839
1027
 
840
1028
  def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
841
1029
  if self._input_resampler:
@@ -858,3 +1046,14 @@ class RealtimeSession(llm.RealtimeSession):
858
1046
  yield from self._input_resampler.push(frame)
859
1047
  else:
860
1048
  yield frame
1049
+
1050
+ def _emit_error(self, error: Exception, recoverable: bool) -> None:
1051
+ self.emit(
1052
+ "error",
1053
+ llm.RealtimeModelError(
1054
+ timestamp=time.time(),
1055
+ label=self._realtime_model._label,
1056
+ error=error,
1057
+ recoverable=recoverable,
1058
+ ),
1059
+ )