livekit-plugins-google 0.11.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in the public registry.
@@ -3,25 +3,22 @@ from __future__ import annotations
 import asyncio
 import json
 import os
+import weakref
 from dataclasses import dataclass
-from typing import AsyncIterable, Literal
-
-from livekit import rtc
-from livekit.agents import llm, utils
-from livekit.agents.llm.function_context import _create_ai_function_info
-from livekit.agents.utils import images
 
 from google import genai
+from google.genai._api_client import HttpOptions
 from google.genai.types import (
     Blob,
     Content,
-    FunctionResponse,
+    FunctionDeclaration,
     GenerationConfig,
-    HttpOptions,
     LiveClientContent,
     LiveClientRealtimeInput,
-    LiveClientToolResponse,
     LiveConnectConfig,
+    LiveServerContent,
+    LiveServerToolCall,
+    LiveServerToolCallCancellation,
     Modality,
     Part,
     PrebuiltVoiceConfig,
@@ -29,42 +26,18 @@ from google.genai.types import (
     Tool,
     VoiceConfig,
 )
+from livekit import rtc
+from livekit.agents import llm, utils
+from livekit.agents.types import NOT_GIVEN, NotGivenOr
+from livekit.agents.utils import is_given
 
 from ...log import logger
-from .api_proto import (
-    ClientEvents,
-    LiveAPIModels,
-    Voice,
-    _build_gemini_ctx,
-    _build_tools,
-)
-from .transcriber import ModelTranscriber, TranscriberSession, TranscriptionContent
-
-EventTypes = Literal[
-    "start_session",
-    "input_speech_started",
-    "response_content_added",
-    "response_content_done",
-    "function_calls_collected",
-    "function_calls_finished",
-    "function_calls_cancelled",
-    "input_speech_transcription_completed",
-    "agent_speech_transcription_completed",
-    "agent_speech_stopped",
-]
+from ...utils import _build_gemini_fnc, get_tool_results_for_realtime, to_chat_ctx
+from .api_proto import ClientEvents, LiveAPIModels, Voice
 
-
-@dataclass
-class GeminiContent:
-    response_id: str
-    item_id: str
-    output_index: int
-    content_index: int
-    text: str
-    audio: list[rtc.AudioFrame]
-    text_stream: AsyncIterable[str]
-    audio_stream: AsyncIterable[rtc.AudioFrame]
-    content_type: Literal["text", "audio"]
+INPUT_AUDIO_SAMPLE_RATE = 16000
+OUTPUT_AUDIO_SAMPLE_RATE = 24000
+NUM_CHANNELS = 1
 
 
 @dataclass
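
Note: the module-level constants added above pin down Gemini Live's audio formats (16 kHz mono PCM in, 24 kHz mono PCM out). As a minimal sketch of the arithmetic the new server-content handler relies on when it computes `samples_per_channel=len(frame_data) // 2` (the helper name below is hypothetical, not part of the plugin):

    # Sketch only: 16-bit PCM means 2 bytes per sample, hence the // 2 in the diff.
    OUTPUT_AUDIO_SAMPLE_RATE = 24000  # matches the constant added above
    BYTES_PER_SAMPLE = 2  # 16-bit PCM

    def pcm_chunk_duration_s(data: bytes) -> float:
        """Duration of one mono 16-bit PCM chunk at the 24 kHz output rate."""
        samples_per_channel = len(data) // BYTES_PER_SAMPLE  # same division as the diff
        return samples_per_channel / OUTPUT_AUDIO_SAMPLE_RATE

    assert pcm_chunk_duration_s(b"\x00" * 9600) == 0.2  # 4800 samples -> 200 ms
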
@@ -74,57 +47,59 @@ class InputTranscription:
 
 
 @dataclass
-class Capabilities:
-    supports_truncate: bool
-    input_audio_sample_rate: int | None = None
-
-
-@dataclass
-class ModelOptions:
+class _RealtimeOptions:
     model: LiveAPIModels | str
     api_key: str | None
-    api_version: str
     voice: Voice | str
-    response_modalities: list[Modality] | None
+    response_modalities: NotGivenOr[list[Modality]]
     vertexai: bool
     project: str | None
     location: str | None
     candidate_count: int
-    temperature: float | None
-    max_output_tokens: int | None
-    top_p: float | None
-    top_k: int | None
-    presence_penalty: float | None
-    frequency_penalty: float | None
-    instructions: Content | None
-    enable_user_audio_transcription: bool
-    enable_agent_audio_transcription: bool
-
-
-class RealtimeModel:
+    temperature: NotGivenOr[float]
+    max_output_tokens: NotGivenOr[int]
+    top_p: NotGivenOr[float]
+    top_k: NotGivenOr[int]
+    presence_penalty: NotGivenOr[float]
+    frequency_penalty: NotGivenOr[float]
+    instructions: NotGivenOr[str]
+
+
+@dataclass
+class _MessageGeneration:
+    message_id: str
+    text_ch: utils.aio.Chan[str]
+    audio_ch: utils.aio.Chan[rtc.AudioFrame]
+
+
+@dataclass
+class _ResponseGeneration:
+    message_ch: utils.aio.Chan[llm.MessageGeneration]
+    function_ch: utils.aio.Chan[llm.FunctionCall]
+
+    messages: dict[str, _MessageGeneration]
+
+
+class RealtimeModel(llm.RealtimeModel):
     def __init__(
         self,
         *,
-        instructions: str | None = None,
+        instructions: NotGivenOr[str] = NOT_GIVEN,
         model: LiveAPIModels | str = "gemini-2.0-flash-exp",
-        api_key: str | None = None,
-        api_version: str = "v1alpha",
+        api_key: NotGivenOr[str] = NOT_GIVEN,
        voice: Voice | str = "Puck",
-        modalities: list[Modality] = [Modality.AUDIO],
-        enable_user_audio_transcription: bool = True,
-        enable_agent_audio_transcription: bool = True,
+        modalities: NotGivenOr[list[Modality]] = NOT_GIVEN,
         vertexai: bool = False,
-        project: str | None = None,
-        location: str | None = None,
+        project: NotGivenOr[str] = NOT_GIVEN,
+        location: NotGivenOr[str] = NOT_GIVEN,
         candidate_count: int = 1,
-        temperature: float | None = None,
-        max_output_tokens: int | None = None,
-        top_p: float | None = None,
-        top_k: int | None = None,
-        presence_penalty: float | None = None,
-        frequency_penalty: float | None = None,
-        loop: asyncio.AbstractEventLoop | None = None,
-    ):
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+        max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
+        top_p: NotGivenOr[float] = NOT_GIVEN,
+        top_k: NotGivenOr[int] = NOT_GIVEN,
+        presence_penalty: NotGivenOr[float] = NOT_GIVEN,
+        frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
+    ) -> None:
         """
         Initializes a RealtimeModel instance for interacting with Google's Realtime API.
 
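
Note: the new signature above replaces `None` defaults with the `NOT_GIVEN` sentinel imported from livekit.agents. A minimal sketch of how such a sentinel pattern works; the definitions below are illustrative stand-ins, not the library's actual implementation:

    from typing import TypeVar, Union

    class NotGiven:
        def __bool__(self) -> bool:
            return False

    NOT_GIVEN = NotGiven()
    T = TypeVar("T")
    NotGivenOr = Union[T, NotGiven]  # stand-in for livekit.agents.types.NotGivenOr

    def is_given(value: "NotGivenOr[T]") -> bool:
        return not isinstance(value, NotGiven)

    # Unlike a None default, the sentinel distinguishes "argument omitted"
    # from "caller explicitly passed None (or 0, '', False)":
    def update(temperature: "NotGivenOr[float]" = NOT_GIVEN) -> None:
        if is_given(temperature):
            print("apply temperature:", temperature)
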
@@ -137,68 +112,57 @@ class RealtimeModel:
 
         Args:
             instructions (str, optional): Initial system instructions for the model. Defaults to "".
-            api_key (str or None, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
-            api_version (str, optional): The version of the API to use. Defaults to "v1alpha".
+            api_key (str, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
             modalities (list[Modality], optional): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
-            model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
+            model (str, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
             voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
-            enable_user_audio_transcription (bool, optional): Whether to enable user audio transcription. Defaults to True
-            enable_agent_audio_transcription (bool, optional): Whether to enable agent audio transcription. Defaults to True
             temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
             vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
-            project (str or None, optional): The project id to use for the API. Defaults to None. (for vertexai)
-            location (str or None, optional): The location to use for the API. Defaults to None. (for vertexai)
+            project (str, optional): The project id to use for the API. Defaults to None. (for vertexai)
+            location (str, optional): The location to use for the API. Defaults to None. (for vertexai)
             candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1.
             top_p (float, optional): The top-p value for response generation
             top_k (int, optional): The top-k value for response generation
             presence_penalty (float, optional): The presence penalty for response generation
             frequency_penalty (float, optional): The frequency penalty for response generation
-            loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used.
 
         Raises:
-            ValueError: If the API key is not provided and cannot be found in environment variables.
-        """
-        super().__init__()
-        self._capabilities = Capabilities(
-            supports_truncate=False,
-            input_audio_sample_rate=16000,
+            ValueError: If the API key is required but not found.
+        """  # noqa: E501
+        super().__init__(
+            capabilities=llm.RealtimeCapabilities(
+                message_truncation=False,
+                turn_detection=True,
+                user_transcription=False,
+            )
         )
-        self._model = model
-        self._loop = loop or asyncio.get_event_loop()
-        self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
-        self._project = project or os.environ.get("GOOGLE_CLOUD_PROJECT")
-        self._location = location or os.environ.get("GOOGLE_CLOUD_LOCATION")
+
+        gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
+        gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
+        gcp_location = location if is_given(location) else os.environ.get("GOOGLE_CLOUD_LOCATION")
         if vertexai:
-            if not self._project or not self._location:
+            if not gcp_project or not gcp_location:
                 raise ValueError(
-                    "Project and location are required for VertexAI either via project and location or GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION environment variables"
+                    "Project and location are required for VertexAI either via project and location or GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION environment variables"  # noqa: E501
                 )
-            self._api_key = None  # VertexAI does not require an API key
+            gemini_api_key = None  # VertexAI does not require an API key
 
         else:
-            self._project = None
-            self._location = None
-            if not self._api_key:
+            gcp_project = None
+            gcp_location = None
+            if not gemini_api_key:
                 raise ValueError(
-                    "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"
+                    "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"  # noqa: E501
                 )
 
-        instructions_content = (
-            Content(parts=[Part(text=instructions)]) if instructions else None
-        )
-
-        self._rt_sessions: list[GeminiRealtimeSession] = []
-        self._opts = ModelOptions(
+        self._opts = _RealtimeOptions(
             model=model,
-            api_version=api_version,
-            api_key=self._api_key,
+            api_key=gemini_api_key,
             voice=voice,
-            enable_user_audio_transcription=enable_user_audio_transcription,
-            enable_agent_audio_transcription=enable_agent_audio_transcription,
             response_modalities=modalities,
             vertexai=vertexai,
-            project=self._project,
-            location=self._location,
+            project=gcp_project,
+            location=gcp_location,
             candidate_count=candidate_count,
             temperature=temperature,
             max_output_tokens=max_output_tokens,
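
Note: a usage sketch for the 1.0.0 constructor wiring shown above. The import path follows this plugin's package layout but is an assumption, as are the placeholder project values:

    from livekit.plugins.google.beta.realtime import RealtimeModel

    # Gemini API: the key comes from api_key= or the GOOGLE_API_KEY env var
    model = RealtimeModel(
        instructions="You are a helpful assistant.",
        voice="Puck",
        temperature=0.8,
    )

    # Vertex AI: no API key, but project/location (or GOOGLE_CLOUD_PROJECT /
    # GOOGLE_CLOUD_LOCATION) are required, otherwise ValueError is raised
    vertex_model = RealtimeModel(
        vertexai=True,
        project="my-gcp-project",  # placeholder
        location="us-central1",    # placeholder
    )
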
@@ -206,387 +170,381 @@ class RealtimeModel:
             top_k=top_k,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
-            instructions=instructions_content,
+            instructions=instructions,
         )
 
-    @property
-    def sessions(self) -> list[GeminiRealtimeSession]:
-        return self._rt_sessions
+        self._sessions = weakref.WeakSet[RealtimeSession]()
 
-    @property
-    def capabilities(self) -> Capabilities:
-        return self._capabilities
+    def session(self) -> RealtimeSession:
+        sess = RealtimeSession(self)
+        self._sessions.add(sess)
+        return sess
 
-    def session(
-        self,
-        *,
-        chat_ctx: llm.ChatContext | None = None,
-        fnc_ctx: llm.FunctionContext | None = None,
-    ) -> GeminiRealtimeSession:
-        session = GeminiRealtimeSession(
-            opts=self._opts,
-            chat_ctx=chat_ctx or llm.ChatContext(),
-            fnc_ctx=fnc_ctx,
-            loop=self._loop,
-        )
-        self._rt_sessions.append(session)
+    def update_options(
+        self, *, voice: NotGivenOr[str] = NOT_GIVEN, temperature: NotGivenOr[float] = NOT_GIVEN
+    ) -> None:
+        if is_given(voice):
+            self._opts.voice = voice
 
-        return session
+        if is_given(temperature):
+            self._opts.temperature = temperature
 
-    async def aclose(self) -> None:
-        for session in self._rt_sessions:
-            await session.aclose()
+        for sess in self._sessions:
+            sess.update_options(voice=self._opts.voice, temperature=self._opts.temperature)
 
+    async def aclose(self) -> None: ...
 
-class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
-    def __init__(
-        self,
-        *,
-        opts: ModelOptions,
-        chat_ctx: llm.ChatContext,
-        fnc_ctx: llm.FunctionContext | None,
-        loop: asyncio.AbstractEventLoop,
-    ):
-        """
-        Initializes a GeminiRealtimeSession instance for interacting with Google's Realtime API.
 
-        Args:
-            opts (ModelOptions): The model options for the session.
-            chat_ctx (llm.ChatContext): The chat context for the session.
-            fnc_ctx (llm.FunctionContext or None): The function context for the session.
-            loop (asyncio.AbstractEventLoop): The event loop for the session.
-        """
-        super().__init__()
-        self._loop = loop
-        self._opts = opts
-        self._chat_ctx = chat_ctx
-        self._fnc_ctx = fnc_ctx
-        self._fnc_tasks = utils.aio.TaskSet()
-        self._is_interrupted = False
-        self._playout_complete = asyncio.Event()
-        self._playout_complete.set()
-
-        tools = []
-        if self._fnc_ctx is not None:
-            functions = _build_tools(self._fnc_ctx)
-            tools.append(Tool(function_declarations=functions))
-
-        self._config = LiveConnectConfig(
-            response_modalities=self._opts.response_modalities,
-            generation_config=GenerationConfig(
-                candidate_count=self._opts.candidate_count,
-                temperature=self._opts.temperature,
-                max_output_tokens=self._opts.max_output_tokens,
-                top_p=self._opts.top_p,
-                top_k=self._opts.top_k,
-                presence_penalty=self._opts.presence_penalty,
-                frequency_penalty=self._opts.frequency_penalty,
-            ),
-            system_instruction=self._opts.instructions,
-            speech_config=SpeechConfig(
-                voice_config=VoiceConfig(
-                    prebuilt_voice_config=PrebuiltVoiceConfig(
-                        voice_name=self._opts.voice
-                    )
-                )
-            ),
-            tools=tools,
-        )
+class RealtimeSession(llm.RealtimeSession):
+    def __init__(self, realtime_model: RealtimeModel) -> None:
+        super().__init__(realtime_model)
+        self._opts = realtime_model._opts
+        self._tools = llm.ToolContext.empty()
+        self._chat_ctx = llm.ChatContext.empty()
+        self._msg_ch = utils.aio.Chan[ClientEvents]()
+        self._gemini_tools: list[Tool] = []
         self._client = genai.Client(
-            http_options=HttpOptions(api_version=self._opts.api_version),
+            http_options=HttpOptions(api_version="v1alpha"),
             api_key=self._opts.api_key,
             vertexai=self._opts.vertexai,
             project=self._opts.project,
             location=self._opts.location,
         )
-        self._main_atask = asyncio.create_task(
-            self._main_task(), name="gemini-realtime-session"
-        )
-        if self._opts.enable_user_audio_transcription:
-            self._transcriber = TranscriberSession(
-                client=self._client, model=self._opts.model
-            )
-            self._transcriber.on("input_speech_done", self._on_input_speech_done)
-        if self._opts.enable_agent_audio_transcription:
-            self._agent_transcriber = ModelTranscriber(
-                client=self._client, model=self._opts.model
-            )
-            self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done)
-        # init dummy task
-        self._init_sync_task = asyncio.create_task(asyncio.sleep(0))
-        self._send_ch = utils.aio.Chan[ClientEvents]()
-        self._active_response_id = None
+        self._main_atask = asyncio.create_task(self._main_task(), name="gemini-realtime-session")
 
-    async def aclose(self) -> None:
-        if self._send_ch.closed:
-            return
+        self._current_generation: _ResponseGeneration | None = None
 
-        self._send_ch.close()
-        await self._main_atask
+        self._is_interrupted = False
+        self._active_response_id = None
+        self._session = None
+        self._update_chat_ctx_lock = asyncio.Lock()
+        self._update_fnc_ctx_lock = asyncio.Lock()
+        self._response_created_futures: dict[str, asyncio.Future[llm.GenerationCreatedEvent]] = {}
+        self._pending_generation_event_id = None
+
+        self._reconnect_event = asyncio.Event()
+        self._session_lock = asyncio.Lock()
+        self._gemini_close_task: asyncio.Task | None = None
+
+    def _schedule_gemini_session_close(self) -> None:
+        if self._session is not None:
+            self._gemini_close_task = asyncio.create_task(self._close_gemini_session())
+
+    async def _close_gemini_session(self) -> None:
+        async with self._session_lock:
+            if self._session:
+                try:
+                    await self._session.close()
+                finally:
+                    self._session = None
+
+    def update_options(
+        self,
+        *,
+        voice: NotGivenOr[str] = NOT_GIVEN,
+        tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+    ) -> None:
+        if is_given(voice):
+            self._opts.voice = voice
+
+        if is_given(temperature):
+            self._opts.temperature = temperature
+
+        if self._session:
+            logger.warning("Updating options; triggering Gemini session reconnect.")
+            self._reconnect_event.set()
+            self._schedule_gemini_session_close()
+
+    async def update_instructions(self, instructions: str) -> None:
+        self._opts.instructions = instructions
+        if self._session:
+            logger.warning("Updating instructions; triggering Gemini session reconnect.")
+            self._reconnect_event.set()
+            self._schedule_gemini_session_close()
+
+    async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
+        async with self._update_chat_ctx_lock:
+            self._chat_ctx = chat_ctx
+            turns, _ = to_chat_ctx(self._chat_ctx, id(self), ignore_functions=True)
+            tool_results = get_tool_results_for_realtime(self._chat_ctx)
+            if turns:
+                self._msg_ch.send_nowait(LiveClientContent(turns=turns, turn_complete=False))
+            if tool_results:
+                self._msg_ch.send_nowait(tool_results)
+
+    async def update_tools(self, tools: list[llm.FunctionTool]) -> None:
+        async with self._update_fnc_ctx_lock:
+            retained_tools: list[llm.FunctionTool] = []
+            gemini_function_declarations: list[FunctionDeclaration] = []
+
+            for tool in tools:
+                gemini_function = _build_gemini_fnc(tool)
+                gemini_function_declarations.append(gemini_function)
+                retained_tools.append(tool)
+
+            self._tools = llm.ToolContext(retained_tools)
+            self._gemini_tools = [Tool(function_declarations=gemini_function_declarations)]
+            if self._session and gemini_function_declarations:
+                logger.warning("Updating tools; triggering Gemini session reconnect.")
+                self._reconnect_event.set()
+                self._schedule_gemini_session_close()
 
     @property
-    def playout_complete(self) -> asyncio.Event | None:
-        return self._playout_complete
+    def chat_ctx(self) -> llm.ChatContext:
+        return self._chat_ctx
 
     @property
-    def fnc_ctx(self) -> llm.FunctionContext | None:
-        return self._fnc_ctx
-
-    @fnc_ctx.setter
-    def fnc_ctx(self, value: llm.FunctionContext | None) -> None:
-        self._fnc_ctx = value
+    def tools(self) -> llm.ToolContext:
+        return self._tools
 
-    def _push_media_chunk(self, data: bytes, mime_type: str) -> None:
+    def push_audio(self, frame: rtc.AudioFrame) -> None:
         realtime_input = LiveClientRealtimeInput(
-            media_chunks=[Blob(data=data, mime_type=mime_type)],
+            media_chunks=[Blob(data=frame.data.tobytes(), mime_type="audio/pcm")],
         )
-        self._queue_msg(realtime_input)
+        self._msg_ch.send_nowait(realtime_input)
 
-    DEFAULT_ENCODE_OPTIONS = images.EncodeOptions(
-        format="JPEG",
-        quality=75,
-        resize_options=images.ResizeOptions(
-            width=1024, height=1024, strategy="scale_aspect_fit"
-        ),
-    )
+    def generate_reply(
+        self, *, instructions: NotGivenOr[str] = NOT_GIVEN
+    ) -> asyncio.Future[llm.GenerationCreatedEvent]:
+        fut = asyncio.Future()
 
-    def push_video(
-        self,
-        frame: rtc.VideoFrame,
-        encode_options: images.EncodeOptions = DEFAULT_ENCODE_OPTIONS,
-    ) -> None:
-        """Push a video frame to the Gemini Multimodal Live session.
+        event_id = utils.shortuuid("gemini-response-")
+        self._response_created_futures[event_id] = fut
+        self._pending_generation_event_id = event_id
 
-        Args:
-            frame (rtc.VideoFrame): The video frame to push.
-            encode_options (images.EncodeOptions, optional): The encode options for the video frame. Defaults to 1024x1024 JPEG.
+        instructions_content = instructions if is_given(instructions) else "."
+        ctx = [Content(parts=[Part(text=instructions_content)], role="user")]
+        self._msg_ch.send_nowait(LiveClientContent(turns=ctx, turn_complete=True))
 
-        Notes:
-        - This will be sent immediately so you should use a sampling frame rate that makes sense for your application and Gemini's constraints. 1 FPS is a good starting point.
-        """
-        encoded_data = images.encode(
-            frame,
-            encode_options,
-        )
-        mime_type = (
-            "image/jpeg"
-            if encode_options.format == "JPEG"
-            else "image/png"
-            if encode_options.format == "PNG"
-            else "image/jpeg"
-        )
-        self._push_media_chunk(encoded_data, mime_type)
+        def _on_timeout() -> None:
+            if event_id in self._response_created_futures and not fut.done():
+                fut.set_exception(llm.RealtimeError("generate_reply timed out."))
+                self._response_created_futures.pop(event_id, None)
+                if self._pending_generation_event_id == event_id:
+                    self._pending_generation_event_id = None
 
-    def _push_audio(self, frame: rtc.AudioFrame) -> None:
-        if self._opts.enable_user_audio_transcription:
-            self._transcriber._push_audio(frame)
+        handle = asyncio.get_event_loop().call_later(5.0, _on_timeout)
+        fut.add_done_callback(lambda _: handle.cancel())
 
-        self._push_media_chunk(frame.data.tobytes(), "audio/pcm")
+        return fut
 
-    def _queue_msg(self, msg: ClientEvents) -> None:
-        self._send_ch.send_nowait(msg)
+    def interrupt(self) -> None:
+        logger.warning("interrupt() - no direct cancellation in Gemini")
 
-    def chat_ctx_copy(self) -> llm.ChatContext:
-        return self._chat_ctx.copy()
+    def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
+        logger.warning(f"truncate(...) called for {message_id}, ignoring for Gemini")
 
-    async def set_chat_ctx(self, ctx: llm.ChatContext) -> None:
-        self._chat_ctx = ctx.copy()
+    async def aclose(self) -> None:
+        self._msg_ch.close()
 
-    def cancel_response(self) -> None:
-        raise NotImplementedError("cancel_response is not supported yet")
+        for fut in self._response_created_futures.values():
+            if not fut.done():
+                fut.set_exception(llm.RealtimeError("Session closed"))
 
-    def create_response(
-        self,
-        on_duplicate: Literal[
-            "cancel_existing", "cancel_new", "keep_both"
-        ] = "keep_both",
-    ) -> None:
-        turns, _ = _build_gemini_ctx(self._chat_ctx, id(self))
-        ctx = [self._opts.instructions] + turns if self._opts.instructions else turns
+        if self._main_atask:
+            await utils.aio.cancel_and_wait(self._main_atask)
 
-        if not ctx:
-            logger.warning(
-                "gemini-realtime-session: No chat context to send, sending dummy content."
+        if self._gemini_close_task:
+            await utils.aio.cancel_and_wait(self._gemini_close_task)
+
+    @utils.log_exceptions(logger=logger)
+    async def _main_task(self):
+        while True:
+            config = LiveConnectConfig(
+                response_modalities=self._opts.response_modalities
+                if is_given(self._opts.response_modalities)
+                else [Modality.AUDIO],
+                generation_config=GenerationConfig(
+                    candidate_count=self._opts.candidate_count,
+                    temperature=self._opts.temperature
+                    if is_given(self._opts.temperature)
+                    else None,
+                    max_output_tokens=self._opts.max_output_tokens
+                    if is_given(self._opts.max_output_tokens)
+                    else None,
+                    top_p=self._opts.top_p if is_given(self._opts.top_p) else None,
+                    top_k=self._opts.top_k if is_given(self._opts.top_k) else None,
+                    presence_penalty=self._opts.presence_penalty
+                    if is_given(self._opts.presence_penalty)
+                    else None,
+                    frequency_penalty=self._opts.frequency_penalty
+                    if is_given(self._opts.frequency_penalty)
+                    else None,
+                ),
+                system_instruction=Content(parts=[Part(text=self._opts.instructions)])
+                if is_given(self._opts.instructions)
+                else None,
+                speech_config=SpeechConfig(
+                    voice_config=VoiceConfig(
+                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
+                    )
+                ),
+                tools=self._gemini_tools,
             )
-            ctx = [Content(parts=[Part(text=".")])]
 
-        self._queue_msg(LiveClientContent(turns=ctx, turn_complete=True))
+            async with self._client.aio.live.connect(
+                model=self._opts.model, config=config
+            ) as session:
+                async with self._session_lock:
+                    self._session = session
+
+                @utils.log_exceptions(logger=logger)
+                async def _send_task():
+                    async for msg in self._msg_ch:
+                        if isinstance(msg, LiveClientContent):
+                            await session.send(input=msg, end_of_turn=True)
+
+                        await session.send(input=msg)
+                    await session.send(input=".", end_of_turn=True)
+
+                @utils.log_exceptions(logger=logger)
+                async def _recv_task():
+                    while True:
+                        async for response in session.receive():
+                            if self._active_response_id is None:
+                                self._start_new_generation()
+                            if response.server_content:
+                                self._handle_server_content(response.server_content)
+                            if response.tool_call:
+                                self._handle_tool_calls(response.tool_call)
+                            if response.tool_call_cancellation:
+                                self._handle_tool_call_cancellation(response.tool_call_cancellation)
+
+                send_task = asyncio.create_task(_send_task(), name="gemini-realtime-send")
+                recv_task = asyncio.create_task(_recv_task(), name="gemini-realtime-recv")
+                reconnect_task = asyncio.create_task(
+                    self._reconnect_event.wait(), name="reconnect-wait"
+                )
 
-    def commit_audio_buffer(self) -> None:
-        raise NotImplementedError("commit_audio_buffer is not supported yet")
+                try:
+                    done, _ = await asyncio.wait(
+                        [send_task, recv_task, reconnect_task],
+                        return_when=asyncio.FIRST_COMPLETED,
+                    )
+                    for task in done:
+                        if task != reconnect_task:
+                            task.result()
 
-    def server_vad_enabled(self) -> bool:
-        return True
+                    if reconnect_task not in done:
+                        break
 
-    def _on_input_speech_done(self, content: TranscriptionContent) -> None:
-        if content.response_id and content.text:
-            self.emit(
-                "input_speech_transcription_completed",
-                InputTranscription(
-                    item_id=content.response_id,
-                    transcript=content.text,
-                ),
+                    self._reconnect_event.clear()
+                finally:
+                    await utils.aio.cancel_and_wait(send_task, recv_task, reconnect_task)
+
+    def _start_new_generation(self):
+        self._is_interrupted = False
+        self._active_response_id = utils.shortuuid("gemini-turn-")
+        self._current_generation = _ResponseGeneration(
+            message_ch=utils.aio.Chan[llm.MessageGeneration](),
+            function_ch=utils.aio.Chan[llm.FunctionCall](),
+            messages={},
+        )
+
+        # We'll assume each chunk belongs to a single message ID self._active_response_id
+        item_generation = _MessageGeneration(
+            message_id=self._active_response_id,
+            text_ch=utils.aio.Chan[str](),
+            audio_ch=utils.aio.Chan[rtc.AudioFrame](),
+        )
+
+        self._current_generation.message_ch.send_nowait(
+            llm.MessageGeneration(
+                message_id=self._active_response_id,
+                text_stream=item_generation.text_ch,
+                audio_stream=item_generation.audio_ch,
             )
+        )
 
-        # self._chat_ctx.append(text=content.text, role="user")
-        # TODO: implement sync mechanism to make sure the transcribed user speech is inside the chat_ctx and always before the generated agent speech
+        generation_event = llm.GenerationCreatedEvent(
+            message_stream=self._current_generation.message_ch,
+            function_stream=self._current_generation.function_ch,
+            user_initiated=False,
+        )
 
-    def _on_agent_speech_done(self, content: TranscriptionContent) -> None:
-        if content.response_id and content.text:
-            self.emit(
-                "agent_speech_transcription_completed",
-                InputTranscription(
-                    item_id=content.response_id,
-                    transcript=content.text,
-                ),
+        # Resolve any pending future from generate_reply()
+        if self._pending_generation_event_id and (
+            fut := self._response_created_futures.pop(self._pending_generation_event_id, None)
+        ):
+            fut.set_result(generation_event)
+
+        self._pending_generation_event_id = None
+        self.emit("generation_created", generation_event)
+
+        self._current_generation.messages[self._active_response_id] = item_generation
+
+    def _handle_server_content(self, server_content: LiveServerContent):
+        if not self._current_generation or not self._active_response_id:
+            logger.warning(
+                "gemini-realtime-session: No active response ID, skipping server content"
             )
-        # self._chat_ctx.append(text=content.text, role="assistant")
+            return
 
-    @utils.log_exceptions(logger=logger)
-    async def _main_task(self):
-        @utils.log_exceptions(logger=logger)
-        async def _send_task():
-            async for msg in self._send_ch:
-                await self._session.send(input=msg)
-
-            await self._session.send(input=".", end_of_turn=True)
-
-        @utils.log_exceptions(logger=logger)
-        async def _recv_task():
-            while True:
-                async for response in self._session.receive():
-                    if self._active_response_id is None:
-                        self._is_interrupted = False
-                        self._active_response_id = utils.shortuuid()
-                        text_stream = utils.aio.Chan[str]()
-                        audio_stream = utils.aio.Chan[rtc.AudioFrame]()
-                        content = GeminiContent(
-                            response_id=self._active_response_id,
-                            item_id=self._active_response_id,
-                            output_index=0,
-                            content_index=0,
-                            text="",
-                            audio=[],
-                            text_stream=text_stream,
-                            audio_stream=audio_stream,
-                            content_type="audio",
-                        )
-                        self.emit("response_content_added", content)
-
-                    server_content = response.server_content
-                    if server_content:
-                        model_turn = server_content.model_turn
-                        if model_turn:
-                            for part in model_turn.parts:
-                                if part.text:
-                                    content.text_stream.send_nowait(part.text)
-                                if part.inline_data:
-                                    frame = rtc.AudioFrame(
-                                        data=part.inline_data.data,
-                                        sample_rate=24000,
-                                        num_channels=1,
-                                        samples_per_channel=len(part.inline_data.data)
-                                        // 2,
-                                    )
-                                    if self._opts.enable_agent_audio_transcription:
-                                        content.audio.append(frame)
-                                    content.audio_stream.send_nowait(frame)
-
-                        if server_content.interrupted or server_content.turn_complete:
-                            if self._opts.enable_agent_audio_transcription:
-                                self._agent_transcriber._push_audio(content.audio)
-                            for stream in (content.text_stream, content.audio_stream):
-                                if isinstance(stream, utils.aio.Chan):
-                                    stream.close()
-
-                            self.emit("agent_speech_stopped")
-                            self._is_interrupted = True
-
-                            self._active_response_id = None
-
-                    if response.tool_call:
-                        if self._fnc_ctx is None:
-                            raise ValueError("Function context is not set")
-                        fnc_calls = []
-                        for fnc_call in response.tool_call.function_calls:
-                            fnc_call_info = _create_ai_function_info(
-                                self._fnc_ctx,
-                                fnc_call.id,
-                                fnc_call.name,
-                                json.dumps(fnc_call.args),
-                            )
-                            fnc_calls.append(fnc_call_info)
-
-                        self.emit("function_calls_collected", fnc_calls)
-
-                        for fnc_call_info in fnc_calls:
-                            self._fnc_tasks.create_task(
-                                self._run_fnc_task(fnc_call_info, content.item_id)
-                            )
-
-                    # Handle function call cancellations
-                    if response.tool_call_cancellation:
-                        logger.warning(
-                            "function call cancelled",
-                            extra={
-                                "function_call_ids": response.tool_call_cancellation.ids,
-                            },
-                        )
-                        self.emit(
-                            "function_calls_cancelled",
-                            response.tool_call_cancellation.ids,
-                        )
-
-        async with self._client.aio.live.connect(
-            model=self._opts.model, config=self._config
-        ) as session:
-            self._session = session
-            tasks = [
-                asyncio.create_task(_send_task(), name="gemini-realtime-send"),
-                asyncio.create_task(_recv_task(), name="gemini-realtime-recv"),
-            ]
-
-            try:
-                await asyncio.gather(*tasks)
-            finally:
-                await utils.aio.gracefully_cancel(*tasks)
-                await self._session.close()
-                if self._opts.enable_user_audio_transcription:
-                    await self._transcriber.aclose()
-                if self._opts.enable_agent_audio_transcription:
-                    await self._agent_transcriber.aclose()
+        item_generation = self._current_generation.messages[self._active_response_id]
+
+        model_turn = server_content.model_turn
+        if model_turn:
+            for part in model_turn.parts:
+                if part.text:
+                    item_generation.text_ch.send_nowait(part.text)
+                if part.inline_data:
+                    frame_data = part.inline_data.data
+                    frame = rtc.AudioFrame(
+                        data=frame_data,
+                        sample_rate=OUTPUT_AUDIO_SAMPLE_RATE,
+                        num_channels=NUM_CHANNELS,
+                        samples_per_channel=len(frame_data) // 2,
+                    )
+                    item_generation.audio_ch.send_nowait(frame)
 
-    @utils.log_exceptions(logger=logger)
-    async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str):
-        logger.debug(
-            "executing ai function",
+        if server_content.interrupted or server_content.turn_complete:
+            self._finalize_response()
+
+    def _finalize_response(self) -> None:
+        if not self._current_generation:
+            return
+
+        for item_generation in self._current_generation.messages.values():
+            item_generation.text_ch.close()
+            item_generation.audio_ch.close()
+
+        self._current_generation.function_ch.close()
+        self._current_generation.message_ch.close()
+        self._current_generation = None
+        self._is_interrupted = True
+        self._active_response_id = None
+        self.emit("agent_speech_stopped")
+
+    def _handle_tool_calls(self, tool_call: LiveServerToolCall):
+        if not self._current_generation:
+            return
+        for fnc_call in tool_call.function_calls:
+            self._current_generation.function_ch.send_nowait(
+                llm.FunctionCall(
+                    call_id=fnc_call.id,
+                    name=fnc_call.name,
+                    arguments=json.dumps(fnc_call.args),
+                )
+            )
+        self._finalize_response()
+
+    def _handle_tool_call_cancellation(
+        self, tool_call_cancellation: LiveServerToolCallCancellation
+    ):
+        logger.warning(
+            "function call cancelled",
             extra={
-                "function": fnc_call_info.function_info.name,
+                "function_call_ids": tool_call_cancellation.ids,
             },
         )
+        self.emit("function_calls_cancelled", tool_call_cancellation.ids)
 
-        called_fnc = fnc_call_info.execute()
-        try:
-            await called_fnc.task
-        except Exception as e:
-            logger.exception(
-                "error executing ai function",
-                extra={
-                    "function": fnc_call_info.function_info.name,
-                },
-                exc_info=e,
-            )
-        tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc)
-        if tool_call.content is not None:
-            tool_response = LiveClientToolResponse(
-                function_responses=[
-                    FunctionResponse(
-                        name=tool_call.name,
-                        id=tool_call.tool_call_id,
-                        response={"result": tool_call.content},
-                    )
-                ]
-            )
-            await self._session.send(input=tool_response)
+    def commit_audio(self) -> None:
+        raise NotImplementedError("commit_audio_buffer is not supported yet")
 
-        self.emit("function_calls_finished", [called_fnc])
+    def clear_audio(self) -> None:
+        raise NotImplementedError("clear_audio is not supported yet")
+
+    def server_vad_enabled(self) -> bool:
+        return True
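
Note: a hypothetical end-to-end sketch of the new session surface, assembled from the methods in this diff (session(), push_audio(), generate_reply(), aclose(), and the "generation_created" event emitted by _start_new_generation()). How you capture audio frames and consume the resulting streams is application-specific and assumed here:

    import asyncio

    async def drive(model) -> None:
        sess = model.session()

        # emitted once per model turn from _start_new_generation()
        sess.on("generation_created", lambda ev: print("generation started"))

        # caller audio goes in as 16 kHz mono PCM (INPUT_AUDIO_SAMPLE_RATE),
        # e.g. sess.push_audio(frame) for each captured rtc.AudioFrame

        # the returned future resolves with llm.GenerationCreatedEvent, or
        # fails with llm.RealtimeError after the internal 5 s timeout
        event = await sess.generate_reply(instructions="Greet the user.")

        await sess.aclose()
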