livekit-plugins-google 0.3.0__py3-none-any.whl → 1.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
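The diff below adds the Gemini Live realtime module (realtime_api.py, 1,249 lines) shipped in 1.3.11. For orientation, here is a minimal sketch of how the RealtimeModel defined in it might be constructed. This is not part of the diff: the keyword arguments come from the __init__ signature below, while the google.realtime.RealtimeModel import path is an assumption based on the livekit.plugins.google.realtime.api_proto import that appears in the diff.

    # Sketch only, not part of the diff. Parameter names come from the
    # RealtimeModel.__init__ signature added below; the import path is
    # assumed from the livekit.plugins.google.realtime.api_proto import.
    from livekit.plugins import google

    # Gemini API: api_key falls back to the GOOGLE_API_KEY environment variable.
    model = google.realtime.RealtimeModel(
        instructions="You are a helpful voice assistant.",
        voice="Puck",
        temperature=0.8,
    )

    # Vertex AI: no API key; project/location come from the kwargs below or from
    # GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION (location defaults to "us-central1").
    vertex_model = google.realtime.RealtimeModel(
        vertexai=True,
        project="my-gcp-project",  # hypothetical project id
        location="us-central1",
    )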
@@ -0,0 +1,1249 @@
+ from __future__ import annotations
+
+ import asyncio
+ import contextlib
+ import json
+ import os
+ import time
+ import weakref
+ from collections.abc import Iterator
+ from dataclasses import dataclass, field
+ from typing import Literal
+
+ from google.auth._default_async import default_async
+ from google.genai import Client as GenAIClient, types
+ from google.genai.live import AsyncSession
+ from livekit import rtc
+ from livekit.agents import APIConnectionError, llm, utils
+ from livekit.agents.metrics import RealtimeModelMetrics
+ from livekit.agents.metrics.base import Metadata
+ from livekit.agents.types import (
+     DEFAULT_API_CONNECT_OPTIONS,
+     NOT_GIVEN,
+     APIConnectOptions,
+     NotGivenOr,
+ )
+ from livekit.agents.utils import audio as audio_utils, images, is_given
+ from livekit.plugins.google.realtime.api_proto import ClientEvents, LiveAPIModels, Voice
+
+ from ..log import logger
+ from ..utils import create_tools_config, get_tool_results_for_realtime
+ from ..version import __version__
+
+ INPUT_AUDIO_SAMPLE_RATE = 16000
+ INPUT_AUDIO_CHANNELS = 1
+ OUTPUT_AUDIO_SAMPLE_RATE = 24000
+ OUTPUT_AUDIO_CHANNELS = 1
+
+ DEFAULT_IMAGE_ENCODE_OPTIONS = images.EncodeOptions(
+     format="JPEG",
+     quality=75,
+     resize_options=images.ResizeOptions(width=1024, height=1024, strategy="scale_aspect_fit"),
+ )
+
+ lk_google_debug = int(os.getenv("LK_GOOGLE_DEBUG", 0))
+
+
+ @dataclass
+ class InputTranscription:
+     item_id: str
+     transcript: str
+
+
+ @dataclass
+ class _RealtimeOptions:
+     model: LiveAPIModels | str
+     api_key: str | None
+     voice: Voice | str
+     language: NotGivenOr[str]
+     response_modalities: list[types.Modality]
+     vertexai: bool
+     project: str | None
+     location: str | None
+     candidate_count: int
+     temperature: NotGivenOr[float]
+     max_output_tokens: NotGivenOr[int]
+     top_p: NotGivenOr[float]
+     top_k: NotGivenOr[int]
+     presence_penalty: NotGivenOr[float]
+     frequency_penalty: NotGivenOr[float]
+     instructions: NotGivenOr[str]
+     input_audio_transcription: types.AudioTranscriptionConfig | None
+     output_audio_transcription: types.AudioTranscriptionConfig | None
+     image_encode_options: NotGivenOr[images.EncodeOptions]
+     conn_options: APIConnectOptions
+     http_options: NotGivenOr[types.HttpOptions]
+     enable_affective_dialog: NotGivenOr[bool] = NOT_GIVEN
+     proactivity: NotGivenOr[bool] = NOT_GIVEN
+     realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN
+     context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN
+     api_version: NotGivenOr[str] = NOT_GIVEN
+     tool_behavior: NotGivenOr[types.Behavior] = NOT_GIVEN
+     tool_response_scheduling: NotGivenOr[types.FunctionResponseScheduling] = NOT_GIVEN
+     thinking_config: NotGivenOr[types.ThinkingConfig] = NOT_GIVEN
+     session_resumption: NotGivenOr[types.SessionResumptionConfig] = NOT_GIVEN
+
+
+ @dataclass
+ class _ResponseGeneration:
+     message_ch: utils.aio.Chan[llm.MessageGeneration]
+     function_ch: utils.aio.Chan[llm.FunctionCall]
+
+     input_id: str
+     response_id: str
+     text_ch: utils.aio.Chan[str]
+     audio_ch: utils.aio.Chan[rtc.AudioFrame]
+
+     input_transcription: str = ""
+     output_text: str = ""
+
+     _created_timestamp: float = field(default_factory=time.time)
+     """The timestamp when the generation is created"""
+     _first_token_timestamp: float | None = None
+     """The timestamp when the first audio token is received"""
+     _completed_timestamp: float | None = None
+     """The timestamp when the generation is completed"""
+     _done: bool = False
+     """Whether the generation is done (set when the turn is complete)"""
+
+     def push_text(self, text: str) -> None:
+         if self.output_text:
+             self.output_text += text
+         else:
+             self.output_text = text
+
+         self.text_ch.send_nowait(text)
+
+
+ class RealtimeModel(llm.RealtimeModel):
+     def __init__(
+         self,
+         *,
+         instructions: NotGivenOr[str] = NOT_GIVEN,
+         model: NotGivenOr[LiveAPIModels | str] = NOT_GIVEN,
+         api_key: NotGivenOr[str] = NOT_GIVEN,
+         voice: Voice | str = "Puck",
+         language: NotGivenOr[str] = NOT_GIVEN,
+         modalities: NotGivenOr[list[types.Modality]] = NOT_GIVEN,
+         vertexai: NotGivenOr[bool] = NOT_GIVEN,
+         project: NotGivenOr[str] = NOT_GIVEN,
+         location: NotGivenOr[str] = NOT_GIVEN,
+         candidate_count: int = 1,
+         temperature: NotGivenOr[float] = NOT_GIVEN,
+         max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
+         top_p: NotGivenOr[float] = NOT_GIVEN,
+         top_k: NotGivenOr[int] = NOT_GIVEN,
+         presence_penalty: NotGivenOr[float] = NOT_GIVEN,
+         frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
+         input_audio_transcription: NotGivenOr[types.AudioTranscriptionConfig | None] = NOT_GIVEN,
+         output_audio_transcription: NotGivenOr[types.AudioTranscriptionConfig | None] = NOT_GIVEN,
+         image_encode_options: NotGivenOr[images.EncodeOptions] = NOT_GIVEN,
+         enable_affective_dialog: NotGivenOr[bool] = NOT_GIVEN,
+         proactivity: NotGivenOr[bool] = NOT_GIVEN,
+         realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN,
+         context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN,
+         tool_behavior: NotGivenOr[types.Behavior] = NOT_GIVEN,
+         tool_response_scheduling: NotGivenOr[types.FunctionResponseScheduling] = NOT_GIVEN,
+         session_resumption: NotGivenOr[types.SessionResumptionConfig] = NOT_GIVEN,
+         api_version: NotGivenOr[str] = NOT_GIVEN,
+         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+         http_options: NotGivenOr[types.HttpOptions] = NOT_GIVEN,
+         thinking_config: NotGivenOr[types.ThinkingConfig] = NOT_GIVEN,
+     ) -> None:
+         """
+         Initializes a RealtimeModel instance for interacting with Google's Realtime API.
+
+         Environment Requirements:
+         - For VertexAI: Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of the service account key file or use any of the other Google Cloud auth methods.
+         The Google Cloud project and location can be set via `project` and `location` arguments or the environment variables
+         `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION`. By default, the project is inferred from the service account key file,
+         and the location defaults to "us-central1".
+         - For Google Gemini API: Set the `api_key` argument or the `GOOGLE_API_KEY` environment variable.
+
+         Args:
+             instructions (str, optional): Initial system instructions for the model. Defaults to "".
+             api_key (str, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
+             modalities (list[Modality], optional): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
+             model (str, optional): The name of the model to use. Defaults to "gemini-2.5-flash-native-audio-preview-12-2025" or "gemini-live-2.5-flash-native-audio" (vertexai).
+             voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
+             language (str, optional): The language (BCP-47 code) to use for the API. Supported languages: https://ai.google.dev/gemini-api/docs/live#supported-languages
+             temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
+             vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
+             project (str, optional): The project id to use for the API. Defaults to None. (for vertexai)
+             location (str, optional): The location to use for the API. Defaults to None. (for vertexai)
+             candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1.
+             top_p (float, optional): The top-p value for response generation
+             top_k (int, optional): The top-k value for response generation
+             presence_penalty (float, optional): The presence penalty for response generation
+             frequency_penalty (float, optional): The frequency penalty for response generation
+             input_audio_transcription (AudioTranscriptionConfig | None, optional): The configuration for input audio transcription. Defaults to AudioTranscriptionConfig().
+             output_audio_transcription (AudioTranscriptionConfig | None, optional): The configuration for output audio transcription. Defaults to AudioTranscriptionConfig().
+             image_encode_options (images.EncodeOptions, optional): The configuration for image encoding. Defaults to DEFAULT_IMAGE_ENCODE_OPTIONS.
+             enable_affective_dialog (bool, optional): Whether to enable affective dialog. Defaults to False.
+             proactivity (bool, optional): Whether to enable proactive audio. Defaults to False.
+             realtime_input_config (RealtimeInputConfig, optional): The configuration for realtime input. Defaults to None.
+             context_window_compression (ContextWindowCompressionConfig, optional): The configuration for context window compression. Defaults to None.
+             tool_behavior (Behavior, optional): The behavior for tool calls. The default behavior is BLOCK in the Gemini Realtime API.
+             tool_response_scheduling (FunctionResponseScheduling, optional): The scheduling for tool responses. The default scheduling is WHEN_IDLE.
+             session_resumption (SessionResumptionConfig, optional): The configuration for session resumption. Defaults to None.
+             thinking_config (ThinkingConfig, optional): Native audio thinking configuration.
+             conn_options (APIConnectOptions, optional): The configuration for the API connection. Defaults to DEFAULT_API_CONNECT_OPTIONS.
+
+         Raises:
+             ValueError: If the API key is required but not found.
+         """  # noqa: E501
+         if not is_given(input_audio_transcription):
+             input_audio_transcription = types.AudioTranscriptionConfig()
+         if not is_given(output_audio_transcription):
+             output_audio_transcription = types.AudioTranscriptionConfig()
+
+         server_turn_detection = True
+         if (
+             is_given(realtime_input_config)
+             and realtime_input_config.automatic_activity_detection
+             and realtime_input_config.automatic_activity_detection.disabled
+         ):
+             server_turn_detection = False
+         modalities = modalities if is_given(modalities) else [types.Modality.AUDIO]
+
+         super().__init__(
+             capabilities=llm.RealtimeCapabilities(
+                 message_truncation=False,
+                 turn_detection=server_turn_detection,
+                 user_transcription=input_audio_transcription is not None,
+                 auto_tool_reply_generation=True,
+                 audio_output=types.Modality.AUDIO in modalities,
+                 manual_function_calls=False,
+             )
+         )
+
+         if not is_given(model):
+             if vertexai:
+                 model = "gemini-live-2.5-flash-native-audio"
+             else:
+                 model = "gemini-2.5-flash-native-audio-preview-12-2025"
+
+         gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
+         gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
+         gcp_location: str | None = (
+             location
+             if is_given(location)
+             else os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1"
+         )
+         use_vertexai = (
+             vertexai
+             if is_given(vertexai)
+             else os.environ.get("GOOGLE_GENAI_USE_VERTEXAI", "0").lower() in ["true", "1"]
+         )
+
+         if use_vertexai:
+             if not gcp_project:
+                 _, gcp_project = default_async(  # type: ignore
+                     scopes=["https://www.googleapis.com/auth/cloud-platform"]
+                 )
+             if not gcp_project or not gcp_location:
+                 raise ValueError(
+                     "Project is required for VertexAI via project kwarg or GOOGLE_CLOUD_PROJECT environment variable"  # noqa: E501
+                 )
+             gemini_api_key = None  # VertexAI does not require an API key
+         else:
+             gcp_project = None
+             gcp_location = None
+             if not gemini_api_key:
+                 raise ValueError(
+                     "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"  # noqa: E501
+                 )
+
+         self._opts = _RealtimeOptions(
+             model=model,
+             api_key=gemini_api_key,
+             voice=voice,
+             response_modalities=modalities,
+             vertexai=use_vertexai,
+             project=gcp_project,
+             location=gcp_location,
+             candidate_count=candidate_count,
+             temperature=temperature,
+             max_output_tokens=max_output_tokens,
+             top_p=top_p,
+             top_k=top_k,
+             presence_penalty=presence_penalty,
+             frequency_penalty=frequency_penalty,
+             instructions=instructions,
+             input_audio_transcription=input_audio_transcription,
+             output_audio_transcription=output_audio_transcription,
+             language=language,
+             image_encode_options=image_encode_options,
+             enable_affective_dialog=enable_affective_dialog,
+             proactivity=proactivity,
+             realtime_input_config=realtime_input_config,
+             context_window_compression=context_window_compression,
+             api_version=api_version,
+             tool_behavior=tool_behavior,
+             tool_response_scheduling=tool_response_scheduling,
+             conn_options=conn_options,
+             http_options=http_options,
+             thinking_config=thinking_config,
+             session_resumption=session_resumption,
+         )
+
+         self._sessions = weakref.WeakSet[RealtimeSession]()
+
+     @property
+     def model(self) -> str:
+         return self._opts.model
+
+     @property
+     def provider(self) -> str:
+         if self._opts.vertexai:
+             return "Vertex AI"
+         else:
+             return "Gemini"
+
+     def session(self) -> RealtimeSession:
+         sess = RealtimeSession(self)
+         self._sessions.add(sess)
+         return sess
+
+     def update_options(
+         self,
+         *,
+         voice: NotGivenOr[str] = NOT_GIVEN,
+         temperature: NotGivenOr[float] = NOT_GIVEN,
+         tool_behavior: NotGivenOr[types.Behavior] = NOT_GIVEN,
+         tool_response_scheduling: NotGivenOr[types.FunctionResponseScheduling] = NOT_GIVEN,
+     ) -> None:
+         """
+         Update the options for the RealtimeModel.
+
+         Args:
+             voice (str, optional): The voice to use for the session.
+             temperature (float, optional): The temperature to use for the session.
+             tool_behavior (types.Behavior, optional): The tool call behavior to use for the session.
+             tool_response_scheduling (types.FunctionResponseScheduling, optional): The tool response scheduling to use for the session.
+         """
+         if is_given(voice):
+             self._opts.voice = voice
+
+         if is_given(temperature):
+             self._opts.temperature = temperature
+
+         if is_given(tool_behavior):
+             self._opts.tool_behavior = tool_behavior
+
+         if is_given(tool_response_scheduling):
+             self._opts.tool_response_scheduling = tool_response_scheduling
+
+         for sess in self._sessions:
+             sess.update_options(
+                 voice=self._opts.voice,
+                 temperature=self._opts.temperature,
+                 tool_behavior=self._opts.tool_behavior,
+                 tool_response_scheduling=self._opts.tool_response_scheduling,
+             )
+
+     async def aclose(self) -> None:
+         pass
+
+
+ class RealtimeSession(llm.RealtimeSession):
+     def __init__(self, realtime_model: RealtimeModel) -> None:
+         super().__init__(realtime_model)
+         self._opts = realtime_model._opts
+         self._tools = llm.ToolContext.empty()
+         self._chat_ctx = llm.ChatContext.empty()
+         self._msg_ch = utils.aio.Chan[ClientEvents]()
+         self._input_resampler: rtc.AudioResampler | None = None
+
+         # 50ms chunks
+         self._bstream = audio_utils.AudioByteStream(
+             INPUT_AUDIO_SAMPLE_RATE,
+             INPUT_AUDIO_CHANNELS,
+             samples_per_channel=INPUT_AUDIO_SAMPLE_RATE // 20,
+         )
+
+         api_version = self._opts.api_version
+         if (
+             not api_version
+             and (self._opts.enable_affective_dialog or self._opts.proactivity)
+             and not self._opts.vertexai
+         ):
+             api_version = "v1alpha"
+
+         http_options = self._opts.http_options or types.HttpOptions(
+             timeout=int(self._opts.conn_options.timeout * 1000)
+         )
+         if api_version:
+             http_options.api_version = api_version
+         if not http_options.headers:
+             http_options.headers = {}
+         http_options.headers["x-goog-api-client"] = f"livekit-agents/{__version__}"
+
+         self._client = GenAIClient(
+             api_key=self._opts.api_key,
+             vertexai=self._opts.vertexai,
+             project=self._opts.project,
+             location=self._opts.location,
+             http_options=http_options,
+         )
+
+         self._main_atask = asyncio.create_task(self._main_task(), name="gemini-realtime-session")
+
+         self._current_generation: _ResponseGeneration | None = None
+         self._active_session: AsyncSession | None = None
+         # indicates if the underlying session should end
+         self._session_should_close = asyncio.Event()
+         self._response_created_futures: dict[str, asyncio.Future[llm.GenerationCreatedEvent]] = {}
+         self._pending_generation_fut: asyncio.Future[llm.GenerationCreatedEvent] | None = None
+
+         self._session_resumption_handle: str | None = (
+             self._opts.session_resumption.handle
+             if is_given(self._opts.session_resumption)
+             else None
+         )
+
+         self._in_user_activity = False
+         self._session_lock = asyncio.Lock()
+         self._num_retries = 0
+
+     async def _close_active_session(self) -> None:
+         async with self._session_lock:
+             if self._active_session:
+                 try:
+                     await self._active_session.close()
+                 except Exception as e:
+                     logger.warning(f"error closing Gemini session: {e}")
+                 finally:
+                     self._active_session = None
+
+     def _mark_restart_needed(self, on_error: bool = False) -> None:
+         if not self._session_should_close.is_set():
+             self._session_should_close.set()
+             # reset the msg_ch, do not send messages from previous session
+             if not on_error:
+                 while not self._msg_ch.empty():
+                     msg = self._msg_ch.recv_nowait()
+                     if isinstance(msg, types.LiveClientContent) and msg.turn_complete is True:
+                         logger.warning(
+                             "discarding client content for turn completion, may cause generate_reply timeout",  # noqa: E501
+                             extra={"content": str(msg)},
+                         )
+
+                 self._msg_ch = utils.aio.Chan[ClientEvents]()
+
+     def update_options(
+         self,
+         *,
+         voice: NotGivenOr[str] = NOT_GIVEN,
+         temperature: NotGivenOr[float] = NOT_GIVEN,
+         tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
+         tool_behavior: NotGivenOr[types.Behavior] = NOT_GIVEN,
+         tool_response_scheduling: NotGivenOr[types.FunctionResponseScheduling] = NOT_GIVEN,
+     ) -> None:
+         should_restart = False
+         if is_given(voice) and self._opts.voice != voice:
+             self._opts.voice = voice
+             should_restart = True
+
+         if is_given(temperature) and self._opts.temperature != temperature:
+             self._opts.temperature = temperature if is_given(temperature) else NOT_GIVEN
+             should_restart = True
+
+         if is_given(tool_behavior) and self._opts.tool_behavior != tool_behavior:
+             self._opts.tool_behavior = tool_behavior
+             should_restart = True
+
+         if (
+             is_given(tool_response_scheduling)
+             and self._opts.tool_response_scheduling != tool_response_scheduling
+         ):
+             self._opts.tool_response_scheduling = tool_response_scheduling
+             # no need to restart
+
+         if is_given(tool_choice):
+             logger.warning("tool_choice is not supported by the Google Realtime API.")
+
+         if should_restart:
+             self._mark_restart_needed()
+
+     async def update_instructions(self, instructions: str) -> None:
+         if not is_given(self._opts.instructions) or self._opts.instructions != instructions:
+             self._opts.instructions = instructions
+             self._mark_restart_needed()
+
+     async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
+         chat_ctx = chat_ctx.copy(
+             exclude_handoff=True, exclude_instructions=True, exclude_empty_message=True
+         )
+         async with self._session_lock:
+             if not self._active_session:
+                 self._chat_ctx = chat_ctx
+                 return
+
+         diff_ops = llm.utils.compute_chat_ctx_diff(self._chat_ctx, chat_ctx)
+
+         if diff_ops.to_remove:
+             logger.warning("Gemini Live does not support removing messages")
+
+         append_ctx = llm.ChatContext.empty()
+         for _, item_id in diff_ops.to_create:
+             item = chat_ctx.get_by_id(item_id)
+             if item:
+                 append_ctx.items.append(item)
+
+         if append_ctx.items:
+             turns_dict, _ = append_ctx.copy(exclude_function_call=True).to_provider_format(
+                 format="google", inject_dummy_user_message=False
+             )
+             # we are not generating, and do not need to inject
+             turns = [types.Content.model_validate(turn) for turn in turns_dict]
+             tool_results = get_tool_results_for_realtime(
+                 append_ctx,
+                 vertexai=self._opts.vertexai,
+                 tool_response_scheduling=self._opts.tool_response_scheduling,
+             )
+             if turns:
+                 self._send_client_event(types.LiveClientContent(turns=turns, turn_complete=False))
+             if tool_results:
+                 self._send_client_event(tool_results)
+
+         # since we don't have a view of the history on the server side, we'll assume
+         # the current state is accurate. this isn't perfect because removals aren't done.
+         self._chat_ctx = chat_ctx
+
+     async def update_tools(self, tools: list[llm.Tool]) -> None:
+         tool_ctx = llm.ToolContext(tools)
+         if self._tools == tool_ctx:
+             return
+
+         self._tools = tool_ctx
+         self._mark_restart_needed()
+
+     @property
+     def chat_ctx(self) -> llm.ChatContext:
+         return self._chat_ctx.copy()
+
+     @property
+     def tools(self) -> llm.ToolContext:
+         return self._tools.copy()
+
+     @property
+     def _manual_activity_detection(self) -> bool:
+         if (
+             is_given(self._opts.realtime_input_config)
+             and self._opts.realtime_input_config.automatic_activity_detection is not None
+             and self._opts.realtime_input_config.automatic_activity_detection.disabled
+         ):
+             return True
+         return False
+
+     @property
+     def session_resumption_handle(self) -> str | None:
+         return self._session_resumption_handle
+
+     def push_audio(self, frame: rtc.AudioFrame) -> None:
+         for f in self._resample_audio(frame):
+             for nf in self._bstream.write(f.data.tobytes()):
+                 realtime_input = types.LiveClientRealtimeInput(
+                     media_chunks=[
+                         types.Blob(
+                             data=nf.data.tobytes(),
+                             mime_type=f"audio/pcm;rate={INPUT_AUDIO_SAMPLE_RATE}",
+                         )
+                     ]
+                 )
+                 self._send_client_event(realtime_input)
+
+     def push_video(self, frame: rtc.VideoFrame) -> None:
+         encoded_data = images.encode(
+             frame, self._opts.image_encode_options or DEFAULT_IMAGE_ENCODE_OPTIONS
+         )
+         realtime_input = types.LiveClientRealtimeInput(
+             media_chunks=[types.Blob(data=encoded_data, mime_type="image/jpeg")]
+         )
+         self._send_client_event(realtime_input)
+
+     def _send_client_event(self, event: ClientEvents) -> None:
+         with contextlib.suppress(utils.aio.channel.ChanClosed):
+             self._msg_ch.send_nowait(event)
+
+     def generate_reply(
+         self, *, instructions: NotGivenOr[str] = NOT_GIVEN
+     ) -> asyncio.Future[llm.GenerationCreatedEvent]:
+         if self._pending_generation_fut and not self._pending_generation_fut.done():
+             logger.warning(
+                 "generate_reply called while another generation is pending, cancelling previous."
+             )
+             self._pending_generation_fut.cancel("Superseded by new generate_reply call")
+
+         fut = asyncio.Future[llm.GenerationCreatedEvent]()
+         self._pending_generation_fut = fut
+
+         if self._in_user_activity:
+             self._send_client_event(
+                 types.LiveClientRealtimeInput(
+                     activity_end=types.ActivityEnd(),
+                 )
+             )
+             self._in_user_activity = False
+
+         # Gemini requires the last message to end with user's turn
+         # so we need to add a placeholder user turn in order to trigger a new generation
+         turns = []
+         if is_given(instructions):
+             turns.append(types.Content(parts=[types.Part(text=instructions)], role="model"))
+         turns.append(types.Content(parts=[types.Part(text=".")], role="user"))
+         self._send_client_event(types.LiveClientContent(turns=turns, turn_complete=True))
+
+         def _on_timeout() -> None:
+             if not fut.done():
+                 fut.set_exception(
+                     llm.RealtimeError(
+                         "generate_reply timed out waiting for generation_created event."
+                     )
+                 )
+                 if self._pending_generation_fut is fut:
+                     self._pending_generation_fut = None
+
+         timeout_handle = asyncio.get_event_loop().call_later(5.0, _on_timeout)
+         fut.add_done_callback(lambda _: timeout_handle.cancel())
+
+         return fut
+
+     def start_user_activity(self) -> None:
+         if not self._manual_activity_detection:
+             return
+
+         if not self._in_user_activity:
+             self._in_user_activity = True
+             self._send_client_event(
+                 types.LiveClientRealtimeInput(
+                     activity_start=types.ActivityStart(),
+                 )
+             )
+
+     def interrupt(self) -> None:
+         # Gemini Live treats activity start as interruption, so we rely on start_user_activity
+         # notifications to handle it
+         if (
+             self._opts.realtime_input_config
+             and self._opts.realtime_input_config.activity_handling
+             == types.ActivityHandling.NO_INTERRUPTION
+         ):
+             return
+         self.start_user_activity()
+
+     def truncate(
+         self,
+         *,
+         message_id: str,
+         modalities: list[Literal["text", "audio"]],
+         audio_end_ms: int,
+         audio_transcript: NotGivenOr[str] = NOT_GIVEN,
+     ) -> None:
+         logger.warning("truncate is not supported by the Google Realtime API.")
+         pass
+
+     async def aclose(self) -> None:
+         self._msg_ch.close()
+         self._session_should_close.set()
+
+         if self._main_atask:
+             await utils.aio.cancel_and_wait(self._main_atask)
+
+         await self._close_active_session()
+
+         if self._pending_generation_fut and not self._pending_generation_fut.done():
+             self._pending_generation_fut.cancel("Session closed")
+
+         for fut in self._response_created_futures.values():
+             if not fut.done():
+                 fut.set_exception(llm.RealtimeError("Session closed before response created"))
+         self._response_created_futures.clear()
+
+         if self._current_generation:
+             self._mark_current_generation_done()
+
+     @utils.log_exceptions(logger=logger)
+     async def _main_task(self) -> None:
+         max_retries = self._opts.conn_options.max_retry
+
+         while not self._msg_ch.closed:
+             # previous session might not be closed yet, we'll do it here.
+             await self._close_active_session()
+
+             self._session_should_close.clear()
+             config = self._build_connect_config()
+             session = None
+             try:
+                 logger.debug("connecting to Gemini Realtime API...")
+                 async with self._client.aio.live.connect(
+                     model=self._opts.model, config=config
+                 ) as session:
+                     async with self._session_lock:
+                         self._active_session = session
+                     turns_dict, _ = self._chat_ctx.copy(
+                         exclude_function_call=True,
+                         exclude_handoff=True,
+                         exclude_instructions=True,
+                         exclude_empty_message=True,
+                     ).to_provider_format(format="google", inject_dummy_user_message=False)
+                     if turns_dict:
+                         turns = [types.Content.model_validate(turn) for turn in turns_dict]
+                         await session.send_client_content(
+                             turns=turns,  # type: ignore
+                             turn_complete=False,
+                         )
+                     # queue up existing chat context
+                     send_task = asyncio.create_task(
+                         self._send_task(session), name="gemini-realtime-send"
+                     )
+                     recv_task = asyncio.create_task(
+                         self._recv_task(session), name="gemini-realtime-recv"
+                     )
+                     restart_wait_task = asyncio.create_task(
+                         self._session_should_close.wait(), name="gemini-restart-wait"
+                     )
+
+                     done, pending = await asyncio.wait(
+                         [send_task, recv_task, restart_wait_task],
+                         return_when=asyncio.FIRST_COMPLETED,
+                     )
+
+                     for task in done:
+                         if task is not restart_wait_task and task.exception():
+                             logger.error(f"error in task {task.get_name()}: {task.exception()}")
+                             raise task.exception() or Exception(f"{task.get_name()} failed")
+
+                     if restart_wait_task not in done and self._msg_ch.closed:
+                         break
+
+                     for task in pending:
+                         await utils.aio.cancel_and_wait(task)
+
+             except asyncio.CancelledError:
+                 break
+             except Exception as e:
+                 logger.error(f"Gemini Realtime API error: {e}", exc_info=e)
+                 if not self._msg_ch.closed:
+                     # we shouldn't retry when it's not connected, usually this means incorrect
+                     # parameters or setup
+                     if not session or max_retries == 0:
+                         self._emit_error(e, recoverable=False)
+                         raise APIConnectionError(message="Failed to connect to Gemini Live") from e
+
+                     if self._num_retries == max_retries:
+                         self._emit_error(e, recoverable=False)
+                         raise APIConnectionError(
+                             message=f"Failed to connect to Gemini Live after {max_retries} attempts"
+                         ) from e
+
+                     retry_interval = self._opts.conn_options._interval_for_retry(self._num_retries)
+                     logger.warning(
+                         f"Gemini Realtime API connection failed, retrying in {retry_interval}s",
+                         exc_info=e,
+                         extra={"attempt": self._num_retries, "max_retries": max_retries},
+                     )
+                     await asyncio.sleep(retry_interval)
+                     self._num_retries += 1
+             finally:
+                 await self._close_active_session()
+
+     async def _send_task(self, session: AsyncSession) -> None:
+         try:
+             async for msg in self._msg_ch:
+                 async with self._session_lock:
+                     if self._session_should_close.is_set() or (
+                         not self._active_session or self._active_session != session
+                     ):
+                         break
+                 if isinstance(msg, types.LiveClientContent):
+                     await session.send_client_content(
+                         turns=msg.turns,  # type: ignore
+                         turn_complete=msg.turn_complete if msg.turn_complete is not None else True,
+                     )
+                 elif isinstance(msg, types.LiveClientToolResponse) and msg.function_responses:
+                     await session.send_tool_response(function_responses=msg.function_responses)
+                 elif isinstance(msg, types.LiveClientRealtimeInput):
+                     if msg.media_chunks:
+                         for media_chunk in msg.media_chunks:
+                             await session.send_realtime_input(media=media_chunk)
+                     elif msg.activity_start:
+                         await session.send_realtime_input(activity_start=msg.activity_start)
+                     elif msg.activity_end:
+                         await session.send_realtime_input(activity_end=msg.activity_end)
+                 else:
+                     logger.warning(f"Warning: Received unhandled message type: {type(msg)}")
+
+                 if lk_google_debug and isinstance(
+                     msg,
+                     (
+                         types.LiveClientContent,
+                         types.LiveClientToolResponse,
+                         types.LiveClientRealtimeInput,
+                     ),
+                 ):
+                     if not isinstance(msg, types.LiveClientRealtimeInput) or not msg.media_chunks:
+                         logger.debug(
+                             f">>> sent {type(msg).__name__}",
+                             extra={"content": msg.model_dump(exclude_defaults=True)},
+                         )
+
+         except Exception as e:
+             if not self._session_should_close.is_set():
+                 logger.error(f"error in send task: {e}", exc_info=e)
+                 self._mark_restart_needed(on_error=True)
+         finally:
+             logger.debug("send task finished.")
+
+     async def _recv_task(self, session: AsyncSession) -> None:
+         try:
+             while True:
+                 async with self._session_lock:
+                     if self._session_should_close.is_set() or (
+                         not self._active_session or self._active_session != session
+                     ):
+                         logger.debug("receive task: Session changed or closed, stopping receive.")
+                         break
+
+                 async for response in session.receive():
+                     if lk_google_debug:
+                         resp_copy = response.model_dump(exclude_defaults=True)
+                         # remove audio from debugging logs
+                         if (
+                             (sc := resp_copy.get("server_content"))
+                             and (mt := sc.get("model_turn"))
+                             and (parts := mt.get("parts"))
+                         ):
+                             for part in parts:
+                                 if part and part.get("inline_data"):
+                                     part["inline_data"] = "<audio>"
+                         logger.debug("<<< received response", extra={"response": resp_copy})
+
+                     if not self._current_generation or self._current_generation._done:
+                         if (sc := response.server_content) and sc.interrupted:
+                             # two cases an interrupted event is sent without an active generation
+                             # 1) the generation is done but playout is not finished (turn_complete -> interrupted)
+                             # 2) the generation is not started (interrupted -> turn_complete)
+                             # for both cases, we interrupt the agent if there is no pending generation from `generate_reply`
+                             # for the second case, the pending generation will be stopped by `turn_complete` event coming later
+                             if not self._pending_generation_fut:
+                                 self._handle_input_speech_started()
+
+                             sc.interrupted = None
+                             sc_copy = sc.model_dump(exclude_none=True)
+                             if not sc_copy:
+                                 # ignore empty server content
+                                 response.server_content = None
+                                 if lk_google_debug:
+                                     logger.debug("ignoring empty server content")
+
+                         if self._is_new_generation(response):
+                             self._start_new_generation()
+                             if lk_google_debug:
+                                 logger.debug(f"new generation started: {self._current_generation}")
+
+                     if response.session_resumption_update:
+                         if (
+                             response.session_resumption_update.resumable
+                             and response.session_resumption_update.new_handle
+                         ):
+                             self._session_resumption_handle = (
+                                 response.session_resumption_update.new_handle
+                             )
+
+                     if response.server_content:
+                         self._handle_server_content(response.server_content)
+                     if response.tool_call:
+                         self._handle_tool_calls(response.tool_call)
+                     if response.tool_call_cancellation:
+                         self._handle_tool_call_cancellation(response.tool_call_cancellation)
+                     if response.usage_metadata:
+                         self._handle_usage_metadata(response.usage_metadata)
+                     if response.go_away:
+                         self._handle_go_away(response.go_away)
+
+                     if self._num_retries > 0:
+                         self._num_retries = 0  # reset the retry counter
+
+                 # TODO(dz): a server-side turn is complete
+         except Exception as e:
+             if not self._session_should_close.is_set():
+                 logger.error(f"error in receive task: {e}", exc_info=e)
+                 self._mark_restart_needed(on_error=True)
+         finally:
+             self._mark_current_generation_done()
+
+     def _build_connect_config(self) -> types.LiveConnectConfig:
+         temp = self._opts.temperature if is_given(self._opts.temperature) else None
+
+         tools_config = create_tools_config(self._tools, tool_behavior=self._opts.tool_behavior)
+         conf = types.LiveConnectConfig(
+             response_modalities=self._opts.response_modalities,
+             generation_config=types.GenerationConfig(
+                 candidate_count=self._opts.candidate_count,
+                 temperature=temp,
+                 max_output_tokens=self._opts.max_output_tokens
+                 if is_given(self._opts.max_output_tokens)
+                 else None,
+                 top_p=self._opts.top_p if is_given(self._opts.top_p) else None,
+                 top_k=self._opts.top_k if is_given(self._opts.top_k) else None,
+                 presence_penalty=self._opts.presence_penalty
+                 if is_given(self._opts.presence_penalty)
+                 else None,
+                 frequency_penalty=self._opts.frequency_penalty
+                 if is_given(self._opts.frequency_penalty)
+                 else None,
+                 thinking_config=self._opts.thinking_config
+                 if is_given(self._opts.thinking_config)
+                 else None,
+             ),
+             system_instruction=types.Content(parts=[types.Part(text=self._opts.instructions)])
+             if is_given(self._opts.instructions)
+             else None,
+             speech_config=types.SpeechConfig(
+                 voice_config=types.VoiceConfig(
+                     prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=self._opts.voice)
+                 ),
+                 language_code=self._opts.language if is_given(self._opts.language) else None,
+             ),
+             tools=tools_config,
+             input_audio_transcription=self._opts.input_audio_transcription,
+             output_audio_transcription=self._opts.output_audio_transcription,
+             session_resumption=types.SessionResumptionConfig(
+                 handle=self._session_resumption_handle
+             ),
+         )
+
+         if is_given(self._opts.proactivity):
+             conf.proactivity = types.ProactivityConfig(proactive_audio=self._opts.proactivity)
+         if is_given(self._opts.enable_affective_dialog):
+             conf.enable_affective_dialog = self._opts.enable_affective_dialog
+         if is_given(self._opts.realtime_input_config):
+             conf.realtime_input_config = self._opts.realtime_input_config
+         if is_given(self._opts.context_window_compression):
+             conf.context_window_compression = self._opts.context_window_compression
+
+         return conf
+
+     def _start_new_generation(self) -> None:
+         if self._current_generation and not self._current_generation._done:
+             logger.warning("starting new generation while another is active. Finalizing previous.")
+             self._mark_current_generation_done()
+
+         response_id = utils.shortuuid("GR_")
+         self._current_generation = _ResponseGeneration(
+             message_ch=utils.aio.Chan[llm.MessageGeneration](),
+             function_ch=utils.aio.Chan[llm.FunctionCall](),
+             response_id=response_id,
+             input_id=utils.shortuuid("GI_"),
+             text_ch=utils.aio.Chan[str](),
+             audio_ch=utils.aio.Chan[rtc.AudioFrame](),
+             _created_timestamp=time.time(),
+         )
+         if not self._realtime_model.capabilities.audio_output:
+             self._current_generation.audio_ch.close()
+
+         msg_modalities = asyncio.Future[list[Literal["text", "audio"]]]()
+         msg_modalities.set_result(
+             ["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"]
+         )
+         self._current_generation.message_ch.send_nowait(
+             llm.MessageGeneration(
+                 message_id=response_id,
+                 text_stream=self._current_generation.text_ch,
+                 audio_stream=self._current_generation.audio_ch,
+                 modalities=msg_modalities,
+             )
+         )
+
+         generation_event = llm.GenerationCreatedEvent(
+             message_stream=self._current_generation.message_ch,
+             function_stream=self._current_generation.function_ch,
+             user_initiated=False,
+             response_id=self._current_generation.response_id,
+         )
+
+         if self._pending_generation_fut and not self._pending_generation_fut.done():
+             generation_event.user_initiated = True
+             self._pending_generation_fut.set_result(generation_event)
+             self._pending_generation_fut = None
+         else:
+             # emit input_speech_started event before starting an agent initiated generation
+             # to interrupt the previous audio playout if any
+             self._handle_input_speech_started()
+
+         self.emit("generation_created", generation_event)
+
+     def _handle_server_content(self, server_content: types.LiveServerContent) -> None:
+         current_gen = self._current_generation
+         if not current_gen:
+             logger.warning("received server content but no active generation.")
+             return
+
+         if model_turn := server_content.model_turn:
+             for part in model_turn.parts or []:
+                 if part.thought:
+                     # bypass reasoning output
+                     continue
+                 if part.text:
+                     current_gen.push_text(part.text)
+                 if part.inline_data:
+                     if not current_gen._first_token_timestamp:
+                         current_gen._first_token_timestamp = time.time()
+                     frame_data = part.inline_data.data
+                     try:
+                         if not isinstance(frame_data, bytes):
+                             raise ValueError("frame_data is not bytes")
+                         frame = rtc.AudioFrame(
+                             data=frame_data,
+                             sample_rate=OUTPUT_AUDIO_SAMPLE_RATE,
+                             num_channels=OUTPUT_AUDIO_CHANNELS,
+                             samples_per_channel=len(frame_data) // (2 * OUTPUT_AUDIO_CHANNELS),
+                         )
+                         current_gen.audio_ch.send_nowait(frame)
+                     except ValueError as e:
+                         logger.error(f"Error creating audio frame from Gemini data: {e}")
+
+         if input_transcription := server_content.input_transcription:
+             text = input_transcription.text
+             if text:
+                 if current_gen.input_transcription == "":
+                     # gemini would start with a space, which doesn't make sense
+                     # at beginning of the transcript
+                     text = text.lstrip()
+                 current_gen.input_transcription += text
+                 self.emit(
+                     "input_audio_transcription_completed",
+                     llm.InputTranscriptionCompleted(
+                         item_id=current_gen.input_id,
+                         transcript=current_gen.input_transcription,
+                         is_final=False,
+                     ),
+                 )
+
+         if output_transcription := server_content.output_transcription:
+             text = output_transcription.text
+             if text:
+                 current_gen.push_text(text)
+
+         if server_content.generation_complete or server_content.turn_complete:
+             current_gen._completed_timestamp = time.time()
+
+         if server_content.interrupted and not self._pending_generation_fut:
+             # interrupt agent if there is no pending user initiated generation
+             self._handle_input_speech_started()
+
+         if server_content.turn_complete:
+             self._mark_current_generation_done()
+
+     def _mark_current_generation_done(self) -> None:
+         if not self._current_generation or self._current_generation._done:
+             return
+
+         # emit input_speech_stopped event after the generation is done
+         self._handle_input_speech_stopped()
+
+         gen = self._current_generation
+
+         # The only way we'd know that the transcription is complete is by when they are
+         # done with generation
+         if gen.input_transcription:
+             self.emit(
+                 "input_audio_transcription_completed",
+                 llm.InputTranscriptionCompleted(
+                     item_id=gen.input_id,
+                     transcript=gen.input_transcription,
+                     is_final=True,
+                 ),
+             )
+
+             # since gemini doesn't give us a view of the chat history on the server side,
+             # we would handle it manually here
+             self._chat_ctx.add_message(
+                 role="user",
+                 content=gen.input_transcription,
+                 id=gen.input_id,
+             )
+
+         if gen.output_text:
+             self._chat_ctx.add_message(
+                 role="assistant",
+                 content=gen.output_text,
+                 id=gen.response_id,
+             )
+
+         if not gen.text_ch.closed:
+             if self._opts.output_audio_transcription is None:
+                 # close the text data of transcription synchronizer
+                 gen.text_ch.send_nowait("")
+             gen.text_ch.close()
+         if not gen.audio_ch.closed:
+             gen.audio_ch.close()
+
+         gen.function_ch.close()
+         gen.message_ch.close()
+         gen._done = True
+         if lk_google_debug:
+             logger.debug(f"generation done {gen}")
+
+     def _handle_input_speech_started(self) -> None:
+         self.emit("input_speech_started", llm.InputSpeechStartedEvent())
+
+     def _handle_input_speech_stopped(self) -> None:
+         self.emit(
+             "input_speech_stopped",
+             llm.InputSpeechStoppedEvent(user_transcription_enabled=False),
+         )
+
+     def _handle_tool_calls(self, tool_call: types.LiveServerToolCall) -> None:
+         if not self._current_generation:
+             logger.warning("received tool call but no active generation.")
+             return
+
+         gen = self._current_generation
+         for fnc_call in tool_call.function_calls or []:
+             arguments = json.dumps(fnc_call.args)
+
+             gen.function_ch.send_nowait(
+                 llm.FunctionCall(
+                     call_id=fnc_call.id or utils.shortuuid("fnc-call-"),
+                     name=fnc_call.name,
+                     arguments=arguments,
+                 )
+             )
+         self._mark_current_generation_done()
+
+     def _handle_tool_call_cancellation(
+         self, tool_call_cancellation: types.LiveServerToolCallCancellation
+     ) -> None:
+         logger.warning(
+             "server cancelled tool calls",
+             extra={"function_call_ids": tool_call_cancellation.ids},
+         )
+
+     def _handle_usage_metadata(self, usage_metadata: types.UsageMetadata) -> None:
+         current_gen = self._current_generation
+         if not current_gen:
+             logger.warning("no active generation to report metrics for")
+             return
+
+         ttft = (
+             current_gen._first_token_timestamp - current_gen._created_timestamp
+             if current_gen._first_token_timestamp
+             else -1
+         )
+         duration = (
+             current_gen._completed_timestamp or time.time()
+         ) - current_gen._created_timestamp
+
+         def _token_details_map(
+             token_details: list[types.ModalityTokenCount] | None,
+         ) -> dict[str, int]:
+             token_details_map = {"audio_tokens": 0, "text_tokens": 0, "image_tokens": 0}
+             if not token_details:
+                 return token_details_map
+
+             for token_detail in token_details:
+                 if not token_detail.token_count:
+                     continue
+
+                 if token_detail.modality == types.MediaModality.AUDIO:
+                     token_details_map["audio_tokens"] += token_detail.token_count
+                 elif token_detail.modality == types.MediaModality.TEXT:
+                     token_details_map["text_tokens"] += token_detail.token_count
+                 elif token_detail.modality == types.MediaModality.IMAGE:
+                     token_details_map["image_tokens"] += token_detail.token_count
+             return token_details_map
+
+         metrics = RealtimeModelMetrics(
+             label=self._realtime_model.label,
+             request_id=current_gen.response_id,
+             timestamp=current_gen._created_timestamp,
+             duration=duration,
+             ttft=ttft,
+             cancelled=False,
+             input_tokens=usage_metadata.prompt_token_count or 0,
+             output_tokens=usage_metadata.response_token_count or 0,
+             total_tokens=usage_metadata.total_token_count or 0,
+             tokens_per_second=(usage_metadata.response_token_count or 0) / duration
+             if duration > 0
+             else 0,
+             input_token_details=RealtimeModelMetrics.InputTokenDetails(
+                 **_token_details_map(usage_metadata.prompt_tokens_details),
+                 cached_tokens=sum(
+                     token_detail.token_count or 0
+                     for token_detail in usage_metadata.cache_tokens_details or []
+                 ),
+                 cached_tokens_details=RealtimeModelMetrics.CachedTokenDetails(
+                     **_token_details_map(usage_metadata.cache_tokens_details),
+                 ),
+             ),
+             output_token_details=RealtimeModelMetrics.OutputTokenDetails(
+                 **_token_details_map(usage_metadata.response_tokens_details),
+             ),
+             metadata=Metadata(
+                 model_name=self._realtime_model.model, model_provider=self._realtime_model.provider
+             ),
+         )
+         self.emit("metrics_collected", metrics)
+
+     def _handle_go_away(self, go_away: types.LiveServerGoAway) -> None:
+         logger.warning(
+             f"Gemini server indicates disconnection soon. Time left: {go_away.time_left}"
+         )
+         # TODO(dz): this isn't a seamless reconnection just yet
+         self._session_should_close.set()
+
+     def commit_audio(self) -> None:
+         pass
+
+     def clear_audio(self) -> None:
+         pass
+
+     def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
+         if self._input_resampler:
+             if frame.sample_rate != self._input_resampler._input_rate:
+                 # input audio changed to a different sample rate
+                 self._input_resampler = None
+
+         if self._input_resampler is None and (
+             frame.sample_rate != INPUT_AUDIO_SAMPLE_RATE
+             or frame.num_channels != INPUT_AUDIO_CHANNELS
+         ):
+             self._input_resampler = rtc.AudioResampler(
+                 input_rate=frame.sample_rate,
+                 output_rate=INPUT_AUDIO_SAMPLE_RATE,
+                 num_channels=INPUT_AUDIO_CHANNELS,
+             )
+
+         if self._input_resampler:
+             # TODO(long): flush the resampler when the input source is changed
+             yield from self._input_resampler.push(frame)
+         else:
+             yield frame
+
+     def _emit_error(self, error: Exception, recoverable: bool) -> None:
+         self.emit(
+             "error",
+             llm.RealtimeModelError(
+                 timestamp=time.time(),
+                 label=self._realtime_model._label,
+                 error=error,
+                 recoverable=recoverable,
+             ),
+         )
+
+     def _is_new_generation(self, resp: types.LiveServerMessage) -> bool:
+         if resp.tool_call:
+             return True
+
+         if (sc := resp.server_content) and (
+             sc.model_turn
+             or (sc.output_transcription and sc.output_transcription is not None)
+             or (sc.input_transcription and sc.input_transcription is not None)
+             or (sc.generation_complete is not None)
+             or (sc.turn_complete is not None)
+         ):
+             return True
+
+         return False
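
Two usage notes on the session machinery above, sketched under the same assumed import path as the earlier example. First, _build_connect_config forwards a stored resumption handle to LiveConnectConfig, and _recv_task keeps session_resumption_handle updated from session_resumption_update messages, which suggests this resume pattern:

    # Sketch: resume a Gemini Live session across reconnects or processes.
    from google.genai import types
    from livekit.plugins import google

    model = google.realtime.RealtimeModel(
        # "stored-handle" is a hypothetical value persisted from a prior run
        session_resumption=types.SessionResumptionConfig(handle="stored-handle"),
    )
    session = model.session()
    # ... later, persist the rotated handle for the next process; the server
    # issues new handles via session_resumption_update while the session runs
    handle = session.session_resumption_handle

Second, RealtimeModel.update_options fans out to every session created via session(), and RealtimeSession.update_options marks a transparent restart only for settings that require a new Live connection (voice, temperature, tool_behavior), while tool_response_scheduling is applied without one:

    # Sketch: changing the voice restarts the underlying Live connection;
    # the session object and queued state survive the restart.
    model.update_options(voice="Charon", temperature=0.6)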