livekit-plugins-google 1.0.23__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,41 +10,23 @@ from collections.abc import Iterator
10
10
  from dataclasses import dataclass, field
11
11
 
12
12
  from google import genai
13
+ from google.genai import types
13
14
  from google.genai.live import AsyncSession
14
- from google.genai.types import (
15
- AudioTranscriptionConfig,
16
- Blob,
17
- Content,
18
- FunctionDeclaration,
19
- GenerationConfig,
20
- LiveClientContent,
21
- LiveClientRealtimeInput,
22
- LiveClientToolResponse,
23
- LiveConnectConfig,
24
- LiveServerContent,
25
- LiveServerGoAway,
26
- LiveServerToolCall,
27
- LiveServerToolCallCancellation,
28
- Modality,
29
- ModalityTokenCount,
30
- Part,
31
- PrebuiltVoiceConfig,
32
- RealtimeInputConfig,
33
- SessionResumptionConfig,
34
- SpeechConfig,
35
- Tool,
36
- UsageMetadata,
37
- VoiceConfig,
38
- )
39
15
  from livekit import rtc
40
- from livekit.agents import llm, utils
16
+ from livekit.agents import APIConnectionError, llm, utils
41
17
  from livekit.agents.metrics import RealtimeModelMetrics
42
- from livekit.agents.types import NOT_GIVEN, NotGivenOr
18
+ from livekit.agents.types import (
19
+ DEFAULT_API_CONNECT_OPTIONS,
20
+ NOT_GIVEN,
21
+ APIConnectOptions,
22
+ NotGivenOr,
23
+ )
43
24
  from livekit.agents.utils import audio as audio_utils, images, is_given
44
25
  from livekit.plugins.google.beta.realtime.api_proto import ClientEvents, LiveAPIModels, Voice
45
26
 
46
27
  from ...log import logger
47
- from ...utils import get_tool_results_for_realtime, to_chat_ctx, to_fnc_ctx
28
+ from ...tools import _LLMTool
29
+ from ...utils import create_tools_config, get_tool_results_for_realtime, to_fnc_ctx
48
30
 
49
31
  INPUT_AUDIO_SAMPLE_RATE = 16000
50
32
  INPUT_AUDIO_CHANNELS = 1
@@ -70,7 +52,7 @@ class _RealtimeOptions:
70
52
  api_key: str | None
71
53
  voice: Voice | str
72
54
  language: NotGivenOr[str]
73
- response_modalities: NotGivenOr[list[Modality]]
55
+ response_modalities: NotGivenOr[list[types.Modality]]
74
56
  vertexai: bool
75
57
  project: str | None
76
58
  location: str | None
@@ -82,12 +64,16 @@ class _RealtimeOptions:
82
64
  presence_penalty: NotGivenOr[float]
83
65
  frequency_penalty: NotGivenOr[float]
84
66
  instructions: NotGivenOr[str]
85
- input_audio_transcription: AudioTranscriptionConfig | None
86
- output_audio_transcription: AudioTranscriptionConfig | None
67
+ input_audio_transcription: types.AudioTranscriptionConfig | None
68
+ output_audio_transcription: types.AudioTranscriptionConfig | None
87
69
  image_encode_options: NotGivenOr[images.EncodeOptions]
70
+ conn_options: APIConnectOptions
88
71
  enable_affective_dialog: NotGivenOr[bool] = NOT_GIVEN
89
72
  proactivity: NotGivenOr[bool] = NOT_GIVEN
90
- realtime_input_config: NotGivenOr[RealtimeInputConfig] = NOT_GIVEN
73
+ realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN
74
+ context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN
75
+ api_version: NotGivenOr[str] = NOT_GIVEN
76
+ gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN
91
77
 
92
78
 
93
79
  @dataclass
@@ -95,10 +81,13 @@ class _ResponseGeneration:
95
81
  message_ch: utils.aio.Chan[llm.MessageGeneration]
96
82
  function_ch: utils.aio.Chan[llm.FunctionCall]
97
83
 
84
+ input_id: str
98
85
  response_id: str
99
86
  text_ch: utils.aio.Chan[str]
100
87
  audio_ch: utils.aio.Chan[rtc.AudioFrame]
88
+
101
89
  input_transcription: str = ""
90
+ output_text: str = ""
102
91
 
103
92
  _created_timestamp: float = field(default_factory=time.time)
104
93
  """The timestamp when the generation is created"""
@@ -109,6 +98,14 @@ class _ResponseGeneration:
109
98
  _done: bool = False
110
99
  """Whether the generation is done (set when the turn is complete)"""
111
100
 
101
+ def push_text(self, text: str) -> None:
102
+ if self.output_text:
103
+ self.output_text += text
104
+ else:
105
+ self.output_text = text
106
+
107
+ self.text_ch.send_nowait(text)
108
+
112
109
 
113
110
  class RealtimeModel(llm.RealtimeModel):
114
111
  def __init__(
@@ -119,7 +116,7 @@ class RealtimeModel(llm.RealtimeModel):
119
116
  api_key: NotGivenOr[str] = NOT_GIVEN,
120
117
  voice: Voice | str = "Puck",
121
118
  language: NotGivenOr[str] = NOT_GIVEN,
122
- modalities: NotGivenOr[list[Modality]] = NOT_GIVEN,
119
+ modalities: NotGivenOr[list[types.Modality]] = NOT_GIVEN,
123
120
  vertexai: NotGivenOr[bool] = NOT_GIVEN,
124
121
  project: NotGivenOr[str] = NOT_GIVEN,
125
122
  location: NotGivenOr[str] = NOT_GIVEN,
@@ -130,12 +127,16 @@ class RealtimeModel(llm.RealtimeModel):
130
127
  top_k: NotGivenOr[int] = NOT_GIVEN,
131
128
  presence_penalty: NotGivenOr[float] = NOT_GIVEN,
132
129
  frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
133
- input_audio_transcription: NotGivenOr[AudioTranscriptionConfig | None] = NOT_GIVEN,
134
- output_audio_transcription: NotGivenOr[AudioTranscriptionConfig | None] = NOT_GIVEN,
130
+ input_audio_transcription: NotGivenOr[types.AudioTranscriptionConfig | None] = NOT_GIVEN,
131
+ output_audio_transcription: NotGivenOr[types.AudioTranscriptionConfig | None] = NOT_GIVEN,
135
132
  image_encode_options: NotGivenOr[images.EncodeOptions] = NOT_GIVEN,
136
133
  enable_affective_dialog: NotGivenOr[bool] = NOT_GIVEN,
137
134
  proactivity: NotGivenOr[bool] = NOT_GIVEN,
138
- realtime_input_config: NotGivenOr[RealtimeInputConfig] = NOT_GIVEN,
135
+ realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN,
136
+ context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN,
137
+ api_version: NotGivenOr[str] = NOT_GIVEN,
138
+ conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
139
+ _gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
139
140
  ) -> None:
140
141
  """
141
142
  Initializes a RealtimeModel instance for interacting with Google's Realtime API.
@@ -169,19 +170,30 @@ class RealtimeModel(llm.RealtimeModel):
169
170
  enable_affective_dialog (bool, optional): Whether to enable affective dialog. Defaults to False.
170
171
  proactivity (bool, optional): Whether to enable proactive audio. Defaults to False.
171
172
  realtime_input_config (RealtimeInputConfig, optional): The configuration for realtime input. Defaults to None.
173
+ context_window_compression (ContextWindowCompressionConfig, optional): The configuration for context window compression. Defaults to None.
174
+ conn_options (APIConnectOptions, optional): The configuration for the API connection. Defaults to DEFAULT_API_CONNECT_OPTIONS.
175
+ _gemini_tools (list[LLMTool], optional): Gemini-specific tools to use for the session. This parameter is experimental and may change.
172
176
 
173
177
  Raises:
174
178
  ValueError: If the API key is required but not found.
175
179
  """ # noqa: E501
176
180
  if not is_given(input_audio_transcription):
177
- input_audio_transcription = AudioTranscriptionConfig()
181
+ input_audio_transcription = types.AudioTranscriptionConfig()
178
182
  if not is_given(output_audio_transcription):
179
- output_audio_transcription = AudioTranscriptionConfig()
183
+ output_audio_transcription = types.AudioTranscriptionConfig()
184
+
185
+ server_turn_detection = True
186
+ if (
187
+ is_given(realtime_input_config)
188
+ and realtime_input_config.automatic_activity_detection
189
+ and realtime_input_config.automatic_activity_detection.disabled
190
+ ):
191
+ server_turn_detection = False
180
192
 
181
193
  super().__init__(
182
194
  capabilities=llm.RealtimeCapabilities(
183
195
  message_truncation=False,
184
- turn_detection=True,
196
+ turn_detection=server_turn_detection,
185
197
  user_transcription=input_audio_transcription is not None,
186
198
  auto_tool_reply_generation=True,
187
199
  )
@@ -195,7 +207,7 @@ class RealtimeModel(llm.RealtimeModel):
195
207
 
196
208
  gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
197
209
  gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
198
- gcp_location = (
210
+ gcp_location: str | None = (
199
211
  location
200
212
  if is_given(location)
201
213
  else os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1"
@@ -243,6 +255,10 @@ class RealtimeModel(llm.RealtimeModel):
243
255
  enable_affective_dialog=enable_affective_dialog,
244
256
  proactivity=proactivity,
245
257
  realtime_input_config=realtime_input_config,
258
+ context_window_compression=context_window_compression,
259
+ api_version=api_version,
260
+ gemini_tools=_gemini_tools,
261
+ conn_options=conn_options,
246
262
  )
247
263
 
248
264
  self._sessions = weakref.WeakSet[RealtimeSession]()
@@ -253,8 +269,19 @@ class RealtimeModel(llm.RealtimeModel):
253
269
  return sess
254
270
 
255
271
  def update_options(
256
- self, *, voice: NotGivenOr[str] = NOT_GIVEN, temperature: NotGivenOr[float] = NOT_GIVEN
272
+ self,
273
+ *,
274
+ voice: NotGivenOr[str] = NOT_GIVEN,
275
+ temperature: NotGivenOr[float] = NOT_GIVEN,
257
276
  ) -> None:
277
+ """
278
+ Update the options for the RealtimeModel.
279
+
280
+ Args:
281
+ voice (str, optional): The voice to use for the session.
282
+ temperature (float, optional): The temperature to use for the session.
283
+ tools (list[LLMTool], optional): The tools to use for the session.
284
+ """
258
285
  if is_given(voice):
259
286
  self._opts.voice = voice
260
287
 
@@ -262,7 +289,10 @@ class RealtimeModel(llm.RealtimeModel):
262
289
  self._opts.temperature = temperature
263
290
 
264
291
  for sess in self._sessions:
265
- sess.update_options(voice=self._opts.voice, temperature=self._opts.temperature)
292
+ sess.update_options(
293
+ voice=self._opts.voice,
294
+ temperature=self._opts.temperature,
295
+ )
266
296
 
267
297
  async def aclose(self) -> None:
268
298
  pass
@@ -273,7 +303,7 @@ class RealtimeSession(llm.RealtimeSession):
273
303
  super().__init__(realtime_model)
274
304
  self._opts = realtime_model._opts
275
305
  self._tools = llm.ToolContext.empty()
276
- self._gemini_declarations: list[FunctionDeclaration] = []
306
+ self._gemini_declarations: list[types.FunctionDeclaration] = []
277
307
  self._chat_ctx = llm.ChatContext.empty()
278
308
  self._msg_ch = utils.aio.Chan[ClientEvents]()
279
309
  self._input_resampler: rtc.AudioResampler | None = None
@@ -285,11 +315,20 @@ class RealtimeSession(llm.RealtimeSession):
285
315
  samples_per_channel=INPUT_AUDIO_SAMPLE_RATE // 20,
286
316
  )
287
317
 
318
+ api_version = self._opts.api_version
319
+ if not api_version and (self._opts.enable_affective_dialog or self._opts.proactivity):
320
+ api_version = "v1alpha"
321
+
322
+ http_options = types.HttpOptions(timeout=int(self._opts.conn_options.timeout * 1000))
323
+ if api_version:
324
+ http_options.api_version = api_version
325
+
288
326
  self._client = genai.Client(
289
327
  api_key=self._opts.api_key,
290
328
  vertexai=self._opts.vertexai,
291
329
  project=self._opts.project,
292
330
  location=self._opts.location,
331
+ http_options=http_options,
293
332
  )
294
333
 
295
334
  self._main_atask = asyncio.create_task(self._main_task(), name="gemini-realtime-session")
@@ -302,8 +341,9 @@ class RealtimeSession(llm.RealtimeSession):
302
341
  self._pending_generation_fut: asyncio.Future[llm.GenerationCreatedEvent] | None = None
303
342
 
304
343
  self._session_resumption_handle: str | None = None
305
-
344
+ self._in_user_activity = False
306
345
  self._session_lock = asyncio.Lock()
346
+ self._num_retries = 0
307
347
 
308
348
  async def _close_active_session(self) -> None:
309
349
  async with self._session_lock:
@@ -315,7 +355,7 @@ class RealtimeSession(llm.RealtimeSession):
315
355
  finally:
316
356
  self._active_session = None
317
357
 
318
- def _mark_restart_needed(self):
358
+ def _mark_restart_needed(self) -> None:
319
359
  if not self._session_should_close.is_set():
320
360
  self._session_should_close.set()
321
361
  # reset the msg_ch, do not send messages from previous session
@@ -346,6 +386,11 @@ class RealtimeSession(llm.RealtimeSession):
346
386
  self._mark_restart_needed()
347
387
 
348
388
  async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
389
+ async with self._session_lock:
390
+ if not self._active_session:
391
+ self._chat_ctx = chat_ctx.copy()
392
+ return
393
+
349
394
  diff_ops = llm.utils.compute_chat_ctx_diff(self._chat_ctx, chat_ctx)
350
395
 
351
396
  if diff_ops.to_remove:
@@ -358,15 +403,23 @@ class RealtimeSession(llm.RealtimeSession):
358
403
  append_ctx.items.append(item)
359
404
 
360
405
  if append_ctx.items:
361
- turns, _ = to_chat_ctx(append_ctx, id(self), ignore_functions=True)
406
+ turns_dict, _ = append_ctx.copy(
407
+ exclude_function_call=True,
408
+ ).to_provider_format(format="google", inject_dummy_user_message=False)
409
+ # we are not generating, and do not need to inject
410
+ turns = [types.Content.model_validate(turn) for turn in turns_dict]
362
411
  tool_results = get_tool_results_for_realtime(append_ctx, vertexai=self._opts.vertexai)
363
412
  if turns:
364
- self._send_client_event(LiveClientContent(turns=turns, turn_complete=False))
413
+ self._send_client_event(types.LiveClientContent(turns=turns, turn_complete=False))
365
414
  if tool_results:
366
415
  self._send_client_event(tool_results)
367
416
 
368
- async def update_tools(self, tools: list[llm.FunctionTool]) -> None:
369
- new_declarations: list[FunctionDeclaration] = to_fnc_ctx(tools)
417
+ # since we don't have a view of the history on the server side, we'll assume
418
+ # the current state is accurate. this isn't perfect because removals aren't done.
419
+ self._chat_ctx = chat_ctx.copy()
420
+
421
+ async def update_tools(self, tools: list[llm.FunctionTool | llm.RawFunctionTool]) -> None:
422
+ new_declarations: list[types.FunctionDeclaration] = to_fnc_ctx(tools)
370
423
  current_tool_names = {f.name for f in self._gemini_declarations}
371
424
  new_tool_names = {f.name for f in new_declarations}
372
425
 
@@ -383,11 +436,21 @@ class RealtimeSession(llm.RealtimeSession):
383
436
  def tools(self) -> llm.ToolContext:
384
437
  return self._tools.copy()
385
438
 
439
+ @property
440
+ def _manual_activity_detection(self) -> bool:
441
+ if (
442
+ is_given(self._opts.realtime_input_config)
443
+ and self._opts.realtime_input_config.automatic_activity_detection is not None
444
+ and self._opts.realtime_input_config.automatic_activity_detection.disabled
445
+ ):
446
+ return True
447
+ return False
448
+
386
449
  def push_audio(self, frame: rtc.AudioFrame) -> None:
387
450
  for f in self._resample_audio(frame):
388
451
  for nf in self._bstream.write(f.data.tobytes()):
389
- realtime_input = LiveClientRealtimeInput(
390
- media_chunks=[Blob(data=nf.data.tobytes(), mime_type="audio/pcm")]
452
+ realtime_input = types.LiveClientRealtimeInput(
453
+ media_chunks=[types.Blob(data=nf.data.tobytes(), mime_type="audio/pcm")]
391
454
  )
392
455
  self._send_client_event(realtime_input)
393
456
 
@@ -395,8 +458,8 @@ class RealtimeSession(llm.RealtimeSession):
395
458
  encoded_data = images.encode(
396
459
  frame, self._opts.image_encode_options or DEFAULT_IMAGE_ENCODE_OPTIONS
397
460
  )
398
- realtime_input = LiveClientRealtimeInput(
399
- media_chunks=[Blob(data=encoded_data, mime_type="image/jpeg")]
461
+ realtime_input = types.LiveClientRealtimeInput(
462
+ media_chunks=[types.Blob(data=encoded_data, mime_type="image/jpeg")]
400
463
  )
401
464
  self._send_client_event(realtime_input)
402
465
 
@@ -413,16 +476,24 @@ class RealtimeSession(llm.RealtimeSession):
413
476
  )
414
477
  self._pending_generation_fut.cancel("Superseded by new generate_reply call")
415
478
 
416
- fut = asyncio.Future()
479
+ fut = asyncio.Future[llm.GenerationCreatedEvent]()
417
480
  self._pending_generation_fut = fut
418
481
 
482
+ if self._in_user_activity:
483
+ self._send_client_event(
484
+ types.LiveClientRealtimeInput(
485
+ activity_end=types.ActivityEnd(),
486
+ )
487
+ )
488
+ self._in_user_activity = False
489
+
419
490
  # Gemini requires the last message to end with user's turn
420
491
  # so we need to add a placeholder user turn in order to trigger a new generation
421
- event = LiveClientContent(turns=[], turn_complete=True)
492
+ turns = []
422
493
  if is_given(instructions):
423
- event.turns.append(Content(parts=[Part(text=instructions)], role="model"))
424
- event.turns.append(Content(parts=[Part(text=".")], role="user"))
425
- self._send_client_event(event)
494
+ turns.append(types.Content(parts=[types.Part(text=instructions)], role="model"))
495
+ turns.append(types.Content(parts=[types.Part(text=".")], role="user"))
496
+ self._send_client_event(types.LiveClientContent(turns=turns, turn_complete=True))
426
497
 
427
498
  def _on_timeout() -> None:
428
499
  if not fut.done():
@@ -439,8 +510,28 @@ class RealtimeSession(llm.RealtimeSession):
439
510
 
440
511
  return fut
441
512
 
513
+ def start_user_activity(self) -> None:
514
+ if not self._manual_activity_detection:
515
+ return
516
+
517
+ if not self._in_user_activity:
518
+ self._in_user_activity = True
519
+ self._send_client_event(
520
+ types.LiveClientRealtimeInput(
521
+ activity_start=types.ActivityStart(),
522
+ )
523
+ )
524
+
442
525
  def interrupt(self) -> None:
443
- pass
526
+ # Gemini Live treats activity start as interruption, so we rely on start_user_activity
527
+ # notifications to handle it
528
+ if (
529
+ self._opts.realtime_input_config
530
+ and self._opts.realtime_input_config.activity_handling
531
+ == types.ActivityHandling.NO_INTERRUPTION
532
+ ):
533
+ return
534
+ self.start_user_activity()
444
535
 
445
536
  def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
446
537
  logger.warning("truncate is not supported by the Google Realtime API.")
@@ -467,7 +558,9 @@ class RealtimeSession(llm.RealtimeSession):
467
558
  self._mark_current_generation_done()
468
559
 
469
560
  @utils.log_exceptions(logger=logger)
470
- async def _main_task(self):
561
+ async def _main_task(self) -> None:
562
+ max_retries = self._opts.conn_options.max_retry
563
+
471
564
  while not self._msg_ch.closed:
472
565
  # previous session might not be closed yet, we'll do it here.
473
566
  await self._close_active_session()
@@ -482,7 +575,15 @@ class RealtimeSession(llm.RealtimeSession):
482
575
  ) as session:
483
576
  async with self._session_lock:
484
577
  self._active_session = session
485
-
578
+ turns_dict, _ = self._chat_ctx.copy(
579
+ exclude_function_call=True,
580
+ ).to_provider_format(format="google", inject_dummy_user_message=False)
581
+ if turns_dict:
582
+ turns = [types.Content.model_validate(turn) for turn in turns_dict]
583
+ await session.send_client_content(
584
+ turns=turns, # type: ignore
585
+ turn_complete=False,
586
+ )
486
587
  # queue up existing chat context
487
588
  send_task = asyncio.create_task(
488
589
  self._send_task(session), name="gemini-realtime-send"
@@ -515,12 +616,30 @@ class RealtimeSession(llm.RealtimeSession):
515
616
  except Exception as e:
516
617
  logger.error(f"Gemini Realtime API error: {e}", exc_info=e)
517
618
  if not self._msg_ch.closed:
518
- logger.info("attempting to reconnect after 1 seconds...")
519
- await asyncio.sleep(1)
619
+ # we shouldn't retry when it's not connected, usually this means incorrect
620
+ # parameters or setup
621
+ if not session or max_retries == 0:
622
+ self._emit_error(e, recoverable=False)
623
+ raise APIConnectionError(message="Failed to connect to Gemini Live") from e
624
+
625
+ if self._num_retries == max_retries:
626
+ self._emit_error(e, recoverable=False)
627
+ raise APIConnectionError(
628
+ message=f"Failed to connect to Gemini Live after {max_retries} attempts"
629
+ ) from e
630
+
631
+ retry_interval = self._opts.conn_options._interval_for_retry(self._num_retries)
632
+ logger.warning(
633
+ f"Gemini Realtime API connection failed, retrying in {retry_interval}s",
634
+ exc_info=e,
635
+ extra={"attempt": self._num_retries, "max_retries": max_retries},
636
+ )
637
+ await asyncio.sleep(retry_interval)
638
+ self._num_retries += 1
520
639
  finally:
521
640
  await self._close_active_session()
522
641
 
523
- async def _send_task(self, session: AsyncSession):
642
+ async def _send_task(self, session: AsyncSession) -> None:
524
643
  try:
525
644
  async for msg in self._msg_ch:
526
645
  async with self._session_lock:
@@ -528,15 +647,21 @@ class RealtimeSession(llm.RealtimeSession):
528
647
  not self._active_session or self._active_session != session
529
648
  ):
530
649
  break
531
- if isinstance(msg, LiveClientContent):
650
+ if isinstance(msg, types.LiveClientContent):
532
651
  await session.send_client_content(
533
- turns=msg.turns, turn_complete=msg.turn_complete
652
+ turns=msg.turns, # type: ignore
653
+ turn_complete=msg.turn_complete or True,
534
654
  )
535
- elif isinstance(msg, LiveClientToolResponse):
655
+ elif isinstance(msg, types.LiveClientToolResponse) and msg.function_responses:
536
656
  await session.send_tool_response(function_responses=msg.function_responses)
537
- elif isinstance(msg, LiveClientRealtimeInput):
538
- for media_chunk in msg.media_chunks:
539
- await session.send_realtime_input(media=media_chunk)
657
+ elif isinstance(msg, types.LiveClientRealtimeInput):
658
+ if msg.media_chunks:
659
+ for media_chunk in msg.media_chunks:
660
+ await session.send_realtime_input(media=media_chunk)
661
+ elif msg.activity_start:
662
+ await session.send_realtime_input(activity_start=msg.activity_start)
663
+ elif msg.activity_end:
664
+ await session.send_realtime_input(activity_end=msg.activity_end)
540
665
  else:
541
666
  logger.warning(f"Warning: Received unhandled message type: {type(msg)}")
542
667
 
@@ -547,7 +672,7 @@ class RealtimeSession(llm.RealtimeSession):
547
672
  finally:
548
673
  logger.debug("send task finished.")
549
674
 
550
- async def _recv_task(self, session: AsyncSession):
675
+ async def _recv_task(self, session: AsyncSession) -> None:
551
676
  try:
552
677
  while True:
553
678
  async with self._session_lock:
@@ -583,6 +708,9 @@ class RealtimeSession(llm.RealtimeSession):
583
708
  if response.go_away:
584
709
  self._handle_go_away(response.go_away)
585
710
 
711
+ if self._num_retries > 0:
712
+ self._num_retries = 0 # reset the retry counter
713
+
586
714
  # TODO(dz): a server-side turn is complete
587
715
  except Exception as e:
588
716
  if not self._session_should_close.is_set():
@@ -591,14 +719,18 @@ class RealtimeSession(llm.RealtimeSession):
591
719
  finally:
592
720
  self._mark_current_generation_done()
593
721
 
594
- def _build_connect_config(self) -> LiveConnectConfig:
722
+ def _build_connect_config(self) -> types.LiveConnectConfig:
595
723
  temp = self._opts.temperature if is_given(self._opts.temperature) else None
596
724
 
597
- conf = LiveConnectConfig(
725
+ tools_config = create_tools_config(
726
+ function_tools=self._gemini_declarations,
727
+ gemini_tools=self._opts.gemini_tools if is_given(self._opts.gemini_tools) else None,
728
+ )
729
+ conf = types.LiveConnectConfig(
598
730
  response_modalities=self._opts.response_modalities
599
731
  if is_given(self._opts.response_modalities)
600
- else [Modality.AUDIO],
601
- generation_config=GenerationConfig(
732
+ else [types.Modality.AUDIO],
733
+ generation_config=types.GenerationConfig(
602
734
  candidate_count=self._opts.candidate_count,
603
735
  temperature=temp,
604
736
  max_output_tokens=self._opts.max_output_tokens
@@ -613,41 +745,45 @@ class RealtimeSession(llm.RealtimeSession):
613
745
  if is_given(self._opts.frequency_penalty)
614
746
  else None,
615
747
  ),
616
- system_instruction=Content(parts=[Part(text=self._opts.instructions)])
748
+ system_instruction=types.Content(parts=[types.Part(text=self._opts.instructions)])
617
749
  if is_given(self._opts.instructions)
618
750
  else None,
619
- speech_config=SpeechConfig(
620
- voice_config=VoiceConfig(
621
- prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
751
+ speech_config=types.SpeechConfig(
752
+ voice_config=types.VoiceConfig(
753
+ prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=self._opts.voice)
622
754
  ),
623
755
  language_code=self._opts.language if is_given(self._opts.language) else None,
624
756
  ),
625
- tools=[Tool(function_declarations=self._gemini_declarations)],
757
+ tools=tools_config, # type: ignore
626
758
  input_audio_transcription=self._opts.input_audio_transcription,
627
759
  output_audio_transcription=self._opts.output_audio_transcription,
628
- session_resumption=SessionResumptionConfig(handle=self._session_resumption_handle),
629
- realtime_input_config=self._opts.realtime_input_config,
760
+ session_resumption=types.SessionResumptionConfig(
761
+ handle=self._session_resumption_handle
762
+ ),
630
763
  )
631
764
 
632
765
  if is_given(self._opts.proactivity):
633
- conf.proactivity = {"proactive_audio": self._opts.proactivity}
766
+ conf.proactivity = types.ProactivityConfig(proactive_audio=self._opts.proactivity)
634
767
  if is_given(self._opts.enable_affective_dialog):
635
768
  conf.enable_affective_dialog = self._opts.enable_affective_dialog
636
769
  if is_given(self._opts.realtime_input_config):
637
770
  conf.realtime_input_config = self._opts.realtime_input_config
771
+ if is_given(self._opts.context_window_compression):
772
+ conf.context_window_compression = self._opts.context_window_compression
638
773
 
639
774
  return conf
640
775
 
641
- def _start_new_generation(self):
776
+ def _start_new_generation(self) -> None:
642
777
  if self._current_generation and not self._current_generation._done:
643
778
  logger.warning("starting new generation while another is active. Finalizing previous.")
644
779
  self._mark_current_generation_done()
645
780
 
646
- response_id = utils.shortuuid("gemini-turn-")
781
+ response_id = utils.shortuuid("GR_")
647
782
  self._current_generation = _ResponseGeneration(
648
783
  message_ch=utils.aio.Chan[llm.MessageGeneration](),
649
784
  function_ch=utils.aio.Chan[llm.FunctionCall](),
650
785
  response_id=response_id,
786
+ input_id=utils.shortuuid("GI_"),
651
787
  text_ch=utils.aio.Chan[str](),
652
788
  audio_ch=utils.aio.Chan[rtc.AudioFrame](),
653
789
  _created_timestamp=time.time(),
@@ -674,21 +810,23 @@ class RealtimeSession(llm.RealtimeSession):
674
810
 
675
811
  self.emit("generation_created", generation_event)
676
812
 
677
- def _handle_server_content(self, server_content: LiveServerContent):
813
+ def _handle_server_content(self, server_content: types.LiveServerContent) -> None:
678
814
  current_gen = self._current_generation
679
815
  if not current_gen:
680
816
  logger.warning("received server content but no active generation.")
681
817
  return
682
818
 
683
819
  if model_turn := server_content.model_turn:
684
- for part in model_turn.parts:
820
+ for part in model_turn.parts or []:
685
821
  if part.text:
686
- current_gen.text_ch.send_nowait(part.text)
822
+ current_gen.push_text(part.text)
687
823
  if part.inline_data:
688
824
  if not current_gen._first_token_timestamp:
689
825
  current_gen._first_token_timestamp = time.time()
690
826
  frame_data = part.inline_data.data
691
827
  try:
828
+ if not isinstance(frame_data, bytes):
829
+ raise ValueError("frame_data is not bytes")
692
830
  frame = rtc.AudioFrame(
693
831
  data=frame_data,
694
832
  sample_rate=OUTPUT_AUDIO_SAMPLE_RATE,
@@ -710,7 +848,7 @@ class RealtimeSession(llm.RealtimeSession):
710
848
  self.emit(
711
849
  "input_audio_transcription_completed",
712
850
  llm.InputTranscriptionCompleted(
713
- item_id=current_gen.response_id,
851
+ item_id=current_gen.input_id,
714
852
  transcript=current_gen.input_transcription,
715
853
  is_final=False,
716
854
  ),
@@ -719,20 +857,9 @@ class RealtimeSession(llm.RealtimeSession):
719
857
  if output_transcription := server_content.output_transcription:
720
858
  text = output_transcription.text
721
859
  if text:
722
- current_gen.text_ch.send_nowait(text)
860
+ current_gen.push_text(text)
723
861
 
724
- if server_content.generation_complete:
725
- # The only way we'd know that the transcription is complete is by when they are
726
- # done with generation
727
- if current_gen.input_transcription:
728
- self.emit(
729
- "input_audio_transcription_completed",
730
- llm.InputTranscriptionCompleted(
731
- item_id=current_gen.response_id,
732
- transcript=current_gen.input_transcription,
733
- is_final=True,
734
- ),
735
- )
862
+ if server_content.generation_complete or server_content.turn_complete:
736
863
  current_gen._completed_timestamp = time.time()
737
864
 
738
865
  if server_content.interrupted:
@@ -742,10 +869,38 @@ class RealtimeSession(llm.RealtimeSession):
742
869
  self._mark_current_generation_done()
743
870
 
744
871
  def _mark_current_generation_done(self) -> None:
745
- if not self._current_generation:
872
+ if not self._current_generation or self._current_generation._done:
746
873
  return
747
874
 
748
875
  gen = self._current_generation
876
+
877
+ # The only way we'd know that the transcription is complete is by when they are
878
+ # done with generation
879
+ if gen.input_transcription:
880
+ self.emit(
881
+ "input_audio_transcription_completed",
882
+ llm.InputTranscriptionCompleted(
883
+ item_id=gen.input_id,
884
+ transcript=gen.input_transcription,
885
+ is_final=True,
886
+ ),
887
+ )
888
+
889
+ # since gemini doesn't give us a view of the chat history on the server side,
890
+ # we would handle it manually here
891
+ self._chat_ctx.add_message(
892
+ role="user",
893
+ content=gen.input_transcription,
894
+ id=gen.input_id,
895
+ )
896
+
897
+ if gen.output_text:
898
+ self._chat_ctx.add_message(
899
+ role="assistant",
900
+ content=gen.output_text,
901
+ id=gen.response_id,
902
+ )
903
+
749
904
  if not gen.text_ch.closed:
750
905
  gen.text_ch.close()
751
906
  if not gen.audio_ch.closed:
@@ -755,36 +910,37 @@ class RealtimeSession(llm.RealtimeSession):
755
910
  gen.message_ch.close()
756
911
  gen._done = True
757
912
 
758
- def _handle_input_speech_started(self):
913
+ def _handle_input_speech_started(self) -> None:
759
914
  self.emit("input_speech_started", llm.InputSpeechStartedEvent())
760
915
 
761
- def _handle_tool_calls(self, tool_call: LiveServerToolCall):
916
+ def _handle_tool_calls(self, tool_call: types.LiveServerToolCall) -> None:
762
917
  if not self._current_generation:
763
918
  logger.warning("received tool call but no active generation.")
764
919
  return
765
920
 
766
921
  gen = self._current_generation
767
- for fnc_call in tool_call.function_calls:
922
+ for fnc_call in tool_call.function_calls or []:
768
923
  arguments = json.dumps(fnc_call.args)
769
924
 
770
925
  gen.function_ch.send_nowait(
771
926
  llm.FunctionCall(
772
927
  call_id=fnc_call.id or utils.shortuuid("fnc-call-"),
773
- name=fnc_call.name,
928
+ name=fnc_call.name, # type: ignore
774
929
  arguments=arguments,
775
930
  )
776
931
  )
932
+ self._on_final_input_audio_transcription()
777
933
  self._mark_current_generation_done()
778
934
 
779
935
  def _handle_tool_call_cancellation(
780
- self, tool_call_cancellation: LiveServerToolCallCancellation
781
- ):
936
+ self, tool_call_cancellation: types.LiveServerToolCallCancellation
937
+ ) -> None:
782
938
  logger.warning(
783
939
  "server cancelled tool calls",
784
940
  extra={"function_call_ids": tool_call_cancellation.ids},
785
941
  )
786
942
 
787
- def _handle_usage_metadata(self, usage_metadata: UsageMetadata):
943
+ def _handle_usage_metadata(self, usage_metadata: types.UsageMetadata) -> None:
788
944
  current_gen = self._current_generation
789
945
  if not current_gen:
790
946
  logger.warning("no active generation to report metrics for")
@@ -800,8 +956,8 @@ class RealtimeSession(llm.RealtimeSession):
800
956
  ) - current_gen._created_timestamp
801
957
 
802
958
  def _token_details_map(
803
- token_details: list[ModalityTokenCount] | None,
804
- ) -> dict[Modality, int]:
959
+ token_details: list[types.ModalityTokenCount] | None,
960
+ ) -> dict[str, int]:
805
961
  token_details_map = {"audio_tokens": 0, "text_tokens": 0, "image_tokens": 0}
806
962
  if not token_details:
807
963
  return token_details_map
@@ -810,11 +966,11 @@ class RealtimeSession(llm.RealtimeSession):
810
966
  if not token_detail.token_count:
811
967
  continue
812
968
 
813
- if token_detail.modality == Modality.AUDIO:
969
+ if token_detail.modality == types.MediaModality.AUDIO:
814
970
  token_details_map["audio_tokens"] += token_detail.token_count
815
- elif token_detail.modality == Modality.TEXT:
971
+ elif token_detail.modality == types.MediaModality.TEXT:
816
972
  token_details_map["text_tokens"] += token_detail.token_count
817
- elif token_detail.modality == Modality.IMAGE:
973
+ elif token_detail.modality == types.MediaModality.IMAGE:
818
974
  token_details_map["image_tokens"] += token_detail.token_count
819
975
  return token_details_map
820
976
 
@@ -828,7 +984,9 @@ class RealtimeSession(llm.RealtimeSession):
828
984
  input_tokens=usage_metadata.prompt_token_count or 0,
829
985
  output_tokens=usage_metadata.response_token_count or 0,
830
986
  total_tokens=usage_metadata.total_token_count or 0,
831
- tokens_per_second=(usage_metadata.response_token_count or 0) / duration,
987
+ tokens_per_second=(usage_metadata.response_token_count or 0) / duration
988
+ if duration > 0
989
+ else 0,
832
990
  input_token_details=RealtimeModelMetrics.InputTokenDetails(
833
991
  **_token_details_map(usage_metadata.prompt_tokens_details),
834
992
  cached_tokens=sum(
@@ -845,18 +1003,27 @@ class RealtimeSession(llm.RealtimeSession):
845
1003
  )
846
1004
  self.emit("metrics_collected", metrics)
847
1005
 
848
- def _handle_go_away(self, go_away: LiveServerGoAway):
1006
+ def _handle_go_away(self, go_away: types.LiveServerGoAway) -> None:
849
1007
  logger.warning(
850
1008
  f"Gemini server indicates disconnection soon. Time left: {go_away.time_left}"
851
1009
  )
852
1010
  # TODO(dz): this isn't a seamless reconnection just yet
853
1011
  self._session_should_close.set()
854
1012
 
1013
+ def _on_final_input_audio_transcription(self) -> None:
1014
+ if (gen := self._current_generation) and gen.input_transcription:
1015
+ self.emit(
1016
+ "input_audio_transcription_completed",
1017
+ llm.InputTranscriptionCompleted(
1018
+ item_id=gen.response_id, transcript=gen.input_transcription, is_final=True
1019
+ ),
1020
+ )
1021
+
855
1022
  def commit_audio(self) -> None:
856
1023
  pass
857
1024
 
858
1025
  def clear_audio(self) -> None:
859
- self._bstream.clear()
1026
+ pass
860
1027
 
861
1028
  def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
862
1029
  if self._input_resampler:
@@ -879,3 +1046,14 @@ class RealtimeSession(llm.RealtimeSession):
879
1046
  yield from self._input_resampler.push(frame)
880
1047
  else:
881
1048
  yield frame
1049
+
1050
+ def _emit_error(self, error: Exception, recoverable: bool) -> None:
1051
+ self.emit(
1052
+ "error",
1053
+ llm.RealtimeModelError(
1054
+ timestamp=time.time(),
1055
+ label=self._realtime_model._label,
1056
+ error=error,
1057
+ recoverable=recoverable,
1058
+ ),
1059
+ )