livekit-plugins-google 1.0.0rc8__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
@@ -3,21 +3,22 @@ from __future__ import annotations
 import asyncio
 import json
 import os
-from collections.abc import AsyncIterable
+import weakref
 from dataclasses import dataclass
-from typing import Literal
 
 from google import genai
 from google.genai._api_client import HttpOptions
 from google.genai.types import (
     Blob,
     Content,
-    FunctionResponse,
+    FunctionDeclaration,
     GenerationConfig,
     LiveClientContent,
     LiveClientRealtimeInput,
-    LiveClientToolResponse,
     LiveConnectConfig,
+    LiveServerContent,
+    LiveServerToolCall,
+    LiveServerToolCallCancellation,
     Modality,
     Part,
     PrebuiltVoiceConfig,
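
The import changes summarize the direction of this release: the client-sent tool-response types (`FunctionResponse`, `LiveClientToolResponse`) are dropped from this module, while the server-event types (`LiveServerContent`, `LiveServerToolCall`, `LiveServerToolCallCancellation`) are added so the session can dispatch on what the Live API pushes down. A minimal sketch of that dispatch shape, with stand-in message and handler names rather than the plugin's real ones:

```python
# Illustrative only: FakeServerMessage mirrors the three LiveServer* payload
# slots a Live API message can carry; at most one is set per message.
from dataclasses import dataclass


@dataclass
class FakeServerMessage:
    server_content: object | None = None
    tool_call: object | None = None
    tool_call_cancellation: object | None = None


def dispatch(msg: FakeServerMessage) -> str:
    if msg.server_content is not None:
        return "content"  # streamed text/audio parts of the model turn
    if msg.tool_call is not None:
        return "tool_call"  # model asks the client to execute functions
    if msg.tool_call_cancellation is not None:
        return "cancelled"  # model retracts previously requested calls
    return "ignored"


assert dispatch(FakeServerMessage(tool_call=object())) == "tool_call"
```
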
@@ -27,44 +28,16 @@ from google.genai.types import (
 )
 from livekit import rtc
 from livekit.agents import llm, utils
-from livekit.agents.utils import images
+from livekit.agents.types import NOT_GIVEN, NotGivenOr
+from livekit.agents.utils import is_given
 
 from ...log import logger
-from .api_proto import (
-    ClientEvents,
-    LiveAPIModels,
-    Voice,
-)
-
-# temporary placeholders
-from .temp import _build_gemini_ctx, _build_tools, _create_ai_function_info
-from .transcriber import ModelTranscriber, TranscriberSession, TranscriptionContent
+from ...utils import _build_gemini_fnc, get_tool_results_for_realtime, to_chat_ctx
+from .api_proto import ClientEvents, LiveAPIModels, Voice
 
-EventTypes = Literal[
-    "start_session",
-    "input_speech_started",
-    "response_content_added",
-    "response_content_done",
-    "function_calls_collected",
-    "function_calls_finished",
-    "function_calls_cancelled",
-    "input_speech_transcription_completed",
-    "agent_speech_transcription_completed",
-    "agent_speech_stopped",
-]
-
-
-@dataclass
-class GeminiContent:
-    response_id: str
-    item_id: str
-    output_index: int
-    content_index: int
-    text: str
-    audio: list[rtc.AudioFrame]
-    text_stream: AsyncIterable[str]
-    audio_stream: AsyncIterable[rtc.AudioFrame]
-    content_type: Literal["text", "audio"]
+INPUT_AUDIO_SAMPLE_RATE = 16000
+OUTPUT_AUDIO_SAMPLE_RATE = 24000
+NUM_CHANNELS = 1
 
 
 @dataclass
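
Two of the new imports, `NOT_GIVEN` and `is_given`, carry the release's "omitted argument" sentinel, which replaces the rc-era `None` defaults and lets the code distinguish "caller never passed this" from "caller passed None". The new module constants pin the Live API's PCM formats: 16 kHz mono input, 24 kHz mono output. A standalone sketch of the sentinel idea (livekit.agents ships the real `NotGivenOr`/`is_given`; this is not the plugin's code):

```python
# Standalone re-implementation of the NOT_GIVEN pattern, for illustration.
class NotGiven:
    def __bool__(self) -> bool:
        return False

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = NotGiven()


def is_given(value: object) -> bool:
    # NOT_GIVEN means "argument omitted", which is distinct from an explicit None.
    return not isinstance(value, NotGiven)


def make_config(temperature: "float | NotGiven" = NOT_GIVEN) -> dict:
    return {"temperature": temperature if is_given(temperature) else 0.8}


assert make_config() == {"temperature": 0.8}
assert make_config(temperature=0.2) == {"temperature": 0.2}
```
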
@@ -74,55 +47,59 @@ class InputTranscription:
 
 
 @dataclass
-class Capabilities:
-    supports_truncate: bool
-    input_audio_sample_rate: int | None = None
-
-
-@dataclass
-class ModelOptions:
+class _RealtimeOptions:
     model: LiveAPIModels | str
     api_key: str | None
     voice: Voice | str
-    response_modalities: list[Modality] | None
+    response_modalities: NotGivenOr[list[Modality]]
     vertexai: bool
     project: str | None
    location: str | None
     candidate_count: int
-    temperature: float | None
-    max_output_tokens: int | None
-    top_p: float | None
-    top_k: int | None
-    presence_penalty: float | None
-    frequency_penalty: float | None
-    instructions: Content | None
-    enable_user_audio_transcription: bool
-    enable_agent_audio_transcription: bool
-
-
-class RealtimeModel:
+    temperature: NotGivenOr[float]
+    max_output_tokens: NotGivenOr[int]
+    top_p: NotGivenOr[float]
+    top_k: NotGivenOr[int]
+    presence_penalty: NotGivenOr[float]
+    frequency_penalty: NotGivenOr[float]
+    instructions: NotGivenOr[str]
+
+
+@dataclass
+class _MessageGeneration:
+    message_id: str
+    text_ch: utils.aio.Chan[str]
+    audio_ch: utils.aio.Chan[rtc.AudioFrame]
+
+
+@dataclass
+class _ResponseGeneration:
+    message_ch: utils.aio.Chan[llm.MessageGeneration]
+    function_ch: utils.aio.Chan[llm.FunctionCall]
+
+    messages: dict[str, _MessageGeneration]
+
+
+class RealtimeModel(llm.RealtimeModel):
     def __init__(
         self,
         *,
-        instructions: str | None = None,
+        instructions: NotGivenOr[str] = NOT_GIVEN,
         model: LiveAPIModels | str = "gemini-2.0-flash-exp",
-        api_key: str | None = None,
+        api_key: NotGivenOr[str] = NOT_GIVEN,
         voice: Voice | str = "Puck",
-        modalities: list[Modality] = None,
-        enable_user_audio_transcription: bool = True,
-        enable_agent_audio_transcription: bool = True,
+        modalities: NotGivenOr[list[Modality]] = NOT_GIVEN,
         vertexai: bool = False,
-        project: str | None = None,
-        location: str | None = None,
+        project: NotGivenOr[str] = NOT_GIVEN,
+        location: NotGivenOr[str] = NOT_GIVEN,
         candidate_count: int = 1,
-        temperature: float | None = None,
-        max_output_tokens: int | None = None,
-        top_p: float | None = None,
-        top_k: int | None = None,
-        presence_penalty: float | None = None,
-        frequency_penalty: float | None = None,
-        loop: asyncio.AbstractEventLoop | None = None,
-    ):
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+        max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
+        top_p: NotGivenOr[float] = NOT_GIVEN,
+        top_k: NotGivenOr[int] = NOT_GIVEN,
+        presence_penalty: NotGivenOr[float] = NOT_GIVEN,
+        frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
+    ) -> None:
         """
         Initializes a RealtimeModel instance for interacting with Google's Realtime API.
 
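
With the new signature, every tuning parameter defaults to `NOT_GIVEN` and can simply be omitted. A hypothetical construction call under the 1.0.x API (the import path is an assumption based on this file's location, not something this diff shows):

```python
# Hypothetical usage; values mirror the defaults documented below.
from livekit.plugins.google.beta.realtime import RealtimeModel  # assumed path

model = RealtimeModel(
    instructions="You are a helpful voice assistant.",
    model="gemini-2.0-flash-exp",
    voice="Puck",
    temperature=0.8,
)
# With api_key omitted, the key is read from the GOOGLE_API_KEY environment
# variable; pass vertexai=True plus project/location to use Vertex AI instead.
```
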
@@ -135,66 +112,57 @@ class RealtimeModel:
 
         Args:
             instructions (str, optional): Initial system instructions for the model. Defaults to "".
-            api_key (str or None, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
+            api_key (str, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
             modalities (list[Modality], optional): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
-            model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
+            model (str, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
             voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
-            enable_user_audio_transcription (bool, optional): Whether to enable user audio transcription. Defaults to True
-            enable_agent_audio_transcription (bool, optional): Whether to enable agent audio transcription. Defaults to True
             temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
             vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
-            project (str or None, optional): The project id to use for the API. Defaults to None. (for vertexai)
-            location (str or None, optional): The location to use for the API. Defaults to None. (for vertexai)
+            project (str, optional): The project id to use for the API. Defaults to None. (for vertexai)
+            location (str, optional): The location to use for the API. Defaults to None. (for vertexai)
             candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1.
             top_p (float, optional): The top-p value for response generation
             top_k (int, optional): The top-k value for response generation
             presence_penalty (float, optional): The presence penalty for response generation
             frequency_penalty (float, optional): The frequency penalty for response generation
-            loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used.
 
         Raises:
-            ValueError: If the API key is not provided and cannot be found in environment variables.
+            ValueError: If the API key is required but not found.
         """  # noqa: E501
-        if modalities is None:
-            modalities = ["AUDIO"]
-        super().__init__()
-        self._capabilities = Capabilities(
-            supports_truncate=False,
-            input_audio_sample_rate=16000,
+        super().__init__(
+            capabilities=llm.RealtimeCapabilities(
+                message_truncation=False,
+                turn_detection=True,
+                user_transcription=False,
+            )
         )
-        self._model = model
-        self._loop = loop or asyncio.get_event_loop()
-        self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
-        self._project = project or os.environ.get("GOOGLE_CLOUD_PROJECT")
-        self._location = location or os.environ.get("GOOGLE_CLOUD_LOCATION")
+
+        gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
+        gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
+        gcp_location = location if is_given(location) else os.environ.get("GOOGLE_CLOUD_LOCATION")
         if vertexai:
-            if not self._project or not self._location:
+            if not gcp_project or not gcp_location:
                 raise ValueError(
                     "Project and location are required for VertexAI either via project and location or GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION environment variables"  # noqa: E501
                 )
-            self._api_key = None  # VertexAI does not require an API key
+            gemini_api_key = None  # VertexAI does not require an API key
 
         else:
-            self._project = None
-            self._location = None
-            if not self._api_key:
+            gcp_project = None
+            gcp_location = None
+            if not gemini_api_key:
                 raise ValueError(
                     "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"  # noqa: E501
                 )
 
-        instructions_content = Content(parts=[Part(text=instructions)]) if instructions else None
-
-        self._rt_sessions: list[GeminiRealtimeSession] = []
-        self._opts = ModelOptions(
+        self._opts = _RealtimeOptions(
             model=model,
-            api_key=self._api_key,
+            api_key=gemini_api_key,
             voice=voice,
-            enable_user_audio_transcription=enable_user_audio_transcription,
-            enable_agent_audio_transcription=enable_agent_audio_transcription,
             response_modalities=modalities,
             vertexai=vertexai,
-            project=self._project,
-            location=self._location,
+            project=gcp_project,
+            location=gcp_location,
             candidate_count=candidate_count,
             temperature=temperature,
             max_output_tokens=max_output_tokens,
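
The constructor now resolves credentials into locals (`gemini_api_key`, `gcp_project`, `gcp_location`) instead of instance attributes, and zeroes out whichever set the chosen backend does not use. The same precedence rules, extracted into a pure function for clarity (names here are illustrative, not part of the plugin):

```python
import os


def resolve_credentials(
    *,
    vertexai: bool,
    api_key: str | None = None,
    project: str | None = None,
    location: str | None = None,
) -> tuple[str | None, str | None, str | None]:
    # Explicit arguments win; environment variables are the fallback.
    api_key = api_key or os.environ.get("GOOGLE_API_KEY")
    project = project or os.environ.get("GOOGLE_CLOUD_PROJECT")
    location = location or os.environ.get("GOOGLE_CLOUD_LOCATION")

    if vertexai:
        if not project or not location:
            raise ValueError("project and location are required for VertexAI")
        return None, project, location  # the VertexAI path ignores the API key

    if not api_key:
        raise ValueError("API key is required for the Google API")
    return api_key, None, None  # the non-Vertex path ignores project/location
```
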
@@ -202,88 +170,39 @@ class RealtimeModel:
             top_k=top_k,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
-            instructions=instructions_content,
+            instructions=instructions,
         )
 
-    @property
-    def sessions(self) -> list[GeminiRealtimeSession]:
-        return self._rt_sessions
-
-    @property
-    def capabilities(self) -> Capabilities:
-        return self._capabilities
+        self._sessions = weakref.WeakSet[RealtimeSession]()
 
-    def session(
-        self,
-        *,
-        chat_ctx: llm.ChatContext | None = None,
-        fnc_ctx: llm.FunctionContext | None = None,
-    ) -> GeminiRealtimeSession:
-        session = GeminiRealtimeSession(
-            opts=self._opts,
-            chat_ctx=chat_ctx or llm.ChatContext(),
-            fnc_ctx=fnc_ctx,
-            loop=self._loop,
-        )
-        self._rt_sessions.append(session)
+    def session(self) -> RealtimeSession:
+        sess = RealtimeSession(self)
+        self._sessions.add(sess)
+        return sess
 
-        return session
+    def update_options(
+        self, *, voice: NotGivenOr[str] = NOT_GIVEN, temperature: NotGivenOr[float] = NOT_GIVEN
+    ) -> None:
+        if is_given(voice):
+            self._opts.voice = voice
 
-    async def aclose(self) -> None:
-        for session in self._rt_sessions:
-            await session.aclose()
+        if is_given(temperature):
+            self._opts.temperature = temperature
 
+        for sess in self._sessions:
+            sess.update_options(voice=self._opts.voice, temperature=self._opts.temperature)
 
-class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
-    def __init__(
-        self,
-        *,
-        opts: ModelOptions,
-        chat_ctx: llm.ChatContext,
-        fnc_ctx: llm.FunctionContext | None,
-        loop: asyncio.AbstractEventLoop,
-    ):
-        """
-        Initializes a GeminiRealtimeSession instance for interacting with Google's Realtime API.
+    async def aclose(self) -> None: ...
 
-        Args:
-            opts (ModelOptions): The model options for the session.
-            chat_ctx (llm.ChatContext): The chat context for the session.
-            fnc_ctx (llm.FunctionContext or None): The function context for the session.
-            loop (asyncio.AbstractEventLoop): The event loop for the session.
-        """
-        super().__init__()
-        self._loop = loop
-        self._opts = opts
-        self._chat_ctx = chat_ctx
-        self._fnc_ctx = fnc_ctx
-        self._fnc_tasks = utils.aio.TaskSet()
-        self._is_interrupted = False
 
-        tools = []
-        if self._fnc_ctx is not None:
-            functions = _build_tools(self._fnc_ctx)
-            tools.append(Tool(function_declarations=functions))
-
-        self._config = LiveConnectConfig(
-            response_modalities=self._opts.response_modalities,
-            generation_config=GenerationConfig(
-                candidate_count=self._opts.candidate_count,
-                temperature=self._opts.temperature,
-                max_output_tokens=self._opts.max_output_tokens,
-                top_p=self._opts.top_p,
-                top_k=self._opts.top_k,
-                presence_penalty=self._opts.presence_penalty,
-                frequency_penalty=self._opts.frequency_penalty,
-            ),
-            system_instruction=self._opts.instructions,
-            speech_config=SpeechConfig(
-                voice_config=VoiceConfig(
-                    prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
-                )
-            ),
-            tools=tools,
-        )
+class RealtimeSession(llm.RealtimeSession):
+    def __init__(self, realtime_model: RealtimeModel) -> None:
+        super().__init__(realtime_model)
+        self._opts = realtime_model._opts
+        self._tools = llm.ToolContext.empty()
+        self._chat_ctx = llm.ChatContext.empty()
+        self._msg_ch = utils.aio.Chan[ClientEvents]()
+        self._gemini_tools: list[Tool] = []
         self._client = genai.Client(
             http_options=HttpOptions(api_version="v1alpha"),
             api_key=self._opts.api_key,
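
`RealtimeModel` no longer owns its sessions through a list; it tracks them in a `weakref.WeakSet` so that `update_options` can fan out to every live session without keeping garbage-collected ones alive. A standalone illustration of that pattern (the classes here are stand-ins, not the plugin's types):

```python
import weakref


class Session:
    def __init__(self) -> None:
        self.voice = "Puck"

    def update_options(self, *, voice: str) -> None:
        self.voice = voice


class Model:
    def __init__(self) -> None:
        self._sessions: "weakref.WeakSet[Session]" = weakref.WeakSet()

    def session(self) -> Session:
        sess = Session()
        self._sessions.add(sess)  # weak reference: no ownership taken
        return sess

    def update_options(self, *, voice: str) -> None:
        for sess in self._sessions:  # dropped sessions disappear automatically
            sess.update_options(voice=voice)


model = Model()
sess = model.session()
model.update_options(voice="Charon")
assert sess.voice == "Charon"
```
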
@@ -292,278 +211,340 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
             location=self._opts.location,
         )
         self._main_atask = asyncio.create_task(self._main_task(), name="gemini-realtime-session")
-        if self._opts.enable_user_audio_transcription:
-            self._transcriber = TranscriberSession(client=self._client, model=self._opts.model)
-            self._transcriber.on("input_speech_done", self._on_input_speech_done)
-        if self._opts.enable_agent_audio_transcription:
-            self._agent_transcriber = ModelTranscriber(client=self._client, model=self._opts.model)
-            self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done)
-        # init dummy task
-        self._init_sync_task = asyncio.create_task(asyncio.sleep(0))
-        self._send_ch = utils.aio.Chan[ClientEvents]()
-        self._active_response_id = None
 
-    async def aclose(self) -> None:
-        if self._send_ch.closed:
-            return
+        self._current_generation: _ResponseGeneration | None = None
 
-        self._send_ch.close()
-        await self._main_atask
+        self._is_interrupted = False
+        self._active_response_id = None
+        self._session = None
+        self._update_chat_ctx_lock = asyncio.Lock()
+        self._update_fnc_ctx_lock = asyncio.Lock()
+        self._response_created_futures: dict[str, asyncio.Future[llm.GenerationCreatedEvent]] = {}
+        self._pending_generation_event_id = None
+
+        self._reconnect_event = asyncio.Event()
+        self._session_lock = asyncio.Lock()
+        self._gemini_close_task: asyncio.Task | None = None
+
+    def _schedule_gemini_session_close(self) -> None:
+        if self._session is not None:
+            self._gemini_close_task = asyncio.create_task(self._close_gemini_session())
+
+    async def _close_gemini_session(self) -> None:
+        async with self._session_lock:
+            if self._session:
+                try:
+                    await self._session.close()
+                finally:
+                    self._session = None
+
+    def update_options(
+        self,
+        *,
+        voice: NotGivenOr[str] = NOT_GIVEN,
+        tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+    ) -> None:
+        if is_given(voice):
+            self._opts.voice = voice
+
+        if is_given(temperature):
+            self._opts.temperature = temperature
+
+        if self._session:
+            logger.warning("Updating options; triggering Gemini session reconnect.")
+            self._reconnect_event.set()
+            self._schedule_gemini_session_close()
+
+    async def update_instructions(self, instructions: str) -> None:
+        self._opts.instructions = instructions
+        if self._session:
+            logger.warning("Updating instructions; triggering Gemini session reconnect.")
+            self._reconnect_event.set()
+            self._schedule_gemini_session_close()
+
+    async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
+        async with self._update_chat_ctx_lock:
+            self._chat_ctx = chat_ctx
+            turns, _ = to_chat_ctx(self._chat_ctx, id(self), ignore_functions=True)
+            tool_results = get_tool_results_for_realtime(self._chat_ctx)
+            if turns:
+                self._msg_ch.send_nowait(LiveClientContent(turns=turns, turn_complete=False))
+            if tool_results:
+                self._msg_ch.send_nowait(tool_results)
+
+    async def update_tools(self, tools: list[llm.FunctionTool]) -> None:
+        async with self._update_fnc_ctx_lock:
+            retained_tools: list[llm.FunctionTool] = []
+            gemini_function_declarations: list[FunctionDeclaration] = []
+
+            for tool in tools:
+                gemini_function = _build_gemini_fnc(tool)
+                gemini_function_declarations.append(gemini_function)
+                retained_tools.append(tool)
+
+            self._tools = llm.ToolContext(retained_tools)
+            self._gemini_tools = [Tool(function_declarations=gemini_function_declarations)]
+            if self._session and gemini_function_declarations:
+                logger.warning("Updating tools; triggering Gemini session reconnect.")
+                self._reconnect_event.set()
+                self._schedule_gemini_session_close()
 
     @property
-    def fnc_ctx(self) -> llm.FunctionContext | None:
-        return self._fnc_ctx
+    def chat_ctx(self) -> llm.ChatContext:
+        return self._chat_ctx
 
-    @fnc_ctx.setter
-    def fnc_ctx(self, value: llm.FunctionContext | None) -> None:
-        self._fnc_ctx = value
+    @property
+    def tools(self) -> llm.ToolContext:
+        return self._tools
 
-    def _push_media_chunk(self, data: bytes, mime_type: str) -> None:
+    def push_audio(self, frame: rtc.AudioFrame) -> None:
         realtime_input = LiveClientRealtimeInput(
-            media_chunks=[Blob(data=data, mime_type=mime_type)],
+            media_chunks=[Blob(data=frame.data.tobytes(), mime_type="audio/pcm")],
         )
-        self._queue_msg(realtime_input)
+        self._msg_ch.send_nowait(realtime_input)
 
-    DEFAULT_ENCODE_OPTIONS = images.EncodeOptions(
-        format="JPEG",
-        quality=75,
-        resize_options=images.ResizeOptions(width=1024, height=1024, strategy="scale_aspect_fit"),
-    )
+    def generate_reply(
+        self, *, instructions: NotGivenOr[str] = NOT_GIVEN
+    ) -> asyncio.Future[llm.GenerationCreatedEvent]:
+        fut = asyncio.Future()
 
-    def push_video(
-        self,
-        frame: rtc.VideoFrame,
-        encode_options: images.EncodeOptions = DEFAULT_ENCODE_OPTIONS,
-    ) -> None:
-        """Push a video frame to the Gemini Multimodal Live session.
+        event_id = utils.shortuuid("gemini-response-")
+        self._response_created_futures[event_id] = fut
+        self._pending_generation_event_id = event_id
 
-        Args:
-            frame (rtc.VideoFrame): The video frame to push.
-            encode_options (images.EncodeOptions, optional): The encode options for the video frame. Defaults to 1024x1024 JPEG.
+        instructions_content = instructions if is_given(instructions) else "."
+        ctx = [Content(parts=[Part(text=instructions_content)], role="user")]
+        self._msg_ch.send_nowait(LiveClientContent(turns=ctx, turn_complete=True))
 
-        Notes:
-        - This will be sent immediately so you should use a sampling frame rate that makes sense for your application and Gemini's constraints. 1 FPS is a good starting point.
-        """  # noqa: E501
-        encoded_data = images.encode(
-            frame,
-            encode_options,
-        )
-        mime_type = (
-            "image/jpeg"
-            if encode_options.format == "JPEG"
-            else "image/png"
-            if encode_options.format == "PNG"
-            else "image/jpeg"
-        )
-        self._push_media_chunk(encoded_data, mime_type)
+        def _on_timeout() -> None:
+            if event_id in self._response_created_futures and not fut.done():
+                fut.set_exception(llm.RealtimeError("generate_reply timed out."))
+                self._response_created_futures.pop(event_id, None)
+                if self._pending_generation_event_id == event_id:
+                    self._pending_generation_event_id = None
 
-    def _push_audio(self, frame: rtc.AudioFrame) -> None:
-        if self._opts.enable_user_audio_transcription:
-            self._transcriber._push_audio(frame)
+        handle = asyncio.get_event_loop().call_later(5.0, _on_timeout)
+        fut.add_done_callback(lambda _: handle.cancel())
 
-        self._push_media_chunk(frame.data.tobytes(), "audio/pcm")
+        return fut
 
-    def _queue_msg(self, msg: ClientEvents) -> None:
-        self._send_ch.send_nowait(msg)
+    def interrupt(self) -> None:
+        logger.warning("interrupt() - no direct cancellation in Gemini")
 
-    def chat_ctx_copy(self) -> llm.ChatContext:
-        return self._chat_ctx.copy()
+    def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
+        logger.warning(f"truncate(...) called for {message_id}, ignoring for Gemini")
 
-    async def set_chat_ctx(self, ctx: llm.ChatContext) -> None:
-        self._chat_ctx = ctx.copy()
+    async def aclose(self) -> None:
+        self._msg_ch.close()
 
-    def cancel_response(self) -> None:
-        raise NotImplementedError("cancel_response is not supported yet")
+        for fut in self._response_created_futures.values():
+            if not fut.done():
+                fut.set_exception(llm.RealtimeError("Session closed"))
 
-    def create_response(
-        self,
-        on_duplicate: Literal["cancel_existing", "cancel_new", "keep_both"] = "keep_both",
-    ) -> None:
-        turns, _ = _build_gemini_ctx(self._chat_ctx, id(self))
-        ctx = [self._opts.instructions] + turns if self._opts.instructions else turns
+        if self._main_atask:
+            await utils.aio.cancel_and_wait(self._main_atask)
 
-        if not ctx:
-            logger.warning(
-                "gemini-realtime-session: No chat context to send, sending dummy content."
+        if self._gemini_close_task:
+            await utils.aio.cancel_and_wait(self._gemini_close_task)
+
+    @utils.log_exceptions(logger=logger)
+    async def _main_task(self):
+        while True:
+            config = LiveConnectConfig(
+                response_modalities=self._opts.response_modalities
+                if is_given(self._opts.response_modalities)
+                else [Modality.AUDIO],
+                generation_config=GenerationConfig(
+                    candidate_count=self._opts.candidate_count,
+                    temperature=self._opts.temperature
+                    if is_given(self._opts.temperature)
+                    else None,
+                    max_output_tokens=self._opts.max_output_tokens
+                    if is_given(self._opts.max_output_tokens)
+                    else None,
+                    top_p=self._opts.top_p if is_given(self._opts.top_p) else None,
+                    top_k=self._opts.top_k if is_given(self._opts.top_k) else None,
+                    presence_penalty=self._opts.presence_penalty
+                    if is_given(self._opts.presence_penalty)
+                    else None,
+                    frequency_penalty=self._opts.frequency_penalty
+                    if is_given(self._opts.frequency_penalty)
+                    else None,
+                ),
+                system_instruction=Content(parts=[Part(text=self._opts.instructions)])
+                if is_given(self._opts.instructions)
+                else None,
+                speech_config=SpeechConfig(
+                    voice_config=VoiceConfig(
+                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
+                    )
+                ),
+                tools=self._gemini_tools,
             )
-            ctx = [Content(parts=[Part(text=".")])]
 
-        self._queue_msg(LiveClientContent(turns=ctx, turn_complete=True))
+            async with self._client.aio.live.connect(
+                model=self._opts.model, config=config
+            ) as session:
+                async with self._session_lock:
+                    self._session = session
+
+                @utils.log_exceptions(logger=logger)
+                async def _send_task():
+                    async for msg in self._msg_ch:
+                        if isinstance(msg, LiveClientContent):
+                            await session.send(input=msg, end_of_turn=True)
+
+                        await session.send(input=msg)
+                    await session.send(input=".", end_of_turn=True)
+
+                @utils.log_exceptions(logger=logger)
+                async def _recv_task():
+                    while True:
+                        async for response in session.receive():
+                            if self._active_response_id is None:
+                                self._start_new_generation()
+                            if response.server_content:
+                                self._handle_server_content(response.server_content)
+                            if response.tool_call:
+                                self._handle_tool_calls(response.tool_call)
+                            if response.tool_call_cancellation:
+                                self._handle_tool_call_cancellation(response.tool_call_cancellation)
+
+                send_task = asyncio.create_task(_send_task(), name="gemini-realtime-send")
+                recv_task = asyncio.create_task(_recv_task(), name="gemini-realtime-recv")
+                reconnect_task = asyncio.create_task(
+                    self._reconnect_event.wait(), name="reconnect-wait"
+                )
 
-    def commit_audio_buffer(self) -> None:
-        raise NotImplementedError("commit_audio_buffer is not supported yet")
+                try:
+                    done, _ = await asyncio.wait(
+                        [send_task, recv_task, reconnect_task],
+                        return_when=asyncio.FIRST_COMPLETED,
+                    )
+                    for task in done:
+                        if task != reconnect_task:
+                            task.result()
 
-    def server_vad_enabled(self) -> bool:
-        return True
+                    if reconnect_task not in done:
+                        break
 
-    def _on_input_speech_done(self, content: TranscriptionContent) -> None:
-        if content.response_id and content.text:
-            self.emit(
-                "input_speech_transcription_completed",
-                InputTranscription(
-                    item_id=content.response_id,
-                    transcript=content.text,
-                ),
+                    self._reconnect_event.clear()
+                finally:
+                    await utils.aio.cancel_and_wait(send_task, recv_task, reconnect_task)
+
+    def _start_new_generation(self):
+        self._is_interrupted = False
+        self._active_response_id = utils.shortuuid("gemini-turn-")
+        self._current_generation = _ResponseGeneration(
+            message_ch=utils.aio.Chan[llm.MessageGeneration](),
+            function_ch=utils.aio.Chan[llm.FunctionCall](),
+            messages={},
+        )
+
+        # We'll assume each chunk belongs to a single message ID self._active_response_id
+        item_generation = _MessageGeneration(
+            message_id=self._active_response_id,
+            text_ch=utils.aio.Chan[str](),
+            audio_ch=utils.aio.Chan[rtc.AudioFrame](),
+        )
+
+        self._current_generation.message_ch.send_nowait(
+            llm.MessageGeneration(
+                message_id=self._active_response_id,
+                text_stream=item_generation.text_ch,
+                audio_stream=item_generation.audio_ch,
             )
+        )
 
-        # self._chat_ctx.append(text=content.text, role="user")
-        # TODO: implement sync mechanism to make sure the transcribed user speech is inside the chat_ctx and always before the generated agent speech  # noqa: E501
+        generation_event = llm.GenerationCreatedEvent(
+            message_stream=self._current_generation.message_ch,
+            function_stream=self._current_generation.function_ch,
+            user_initiated=False,
+        )
 
-    def _on_agent_speech_done(self, content: TranscriptionContent) -> None:
-        if content.response_id and content.text:
-            self.emit(
-                "agent_speech_transcription_completed",
-                InputTranscription(
-                    item_id=content.response_id,
-                    transcript=content.text,
-                ),
+        # Resolve any pending future from generate_reply()
+        if self._pending_generation_event_id and (
+            fut := self._response_created_futures.pop(self._pending_generation_event_id, None)
+        ):
+            fut.set_result(generation_event)
+
+        self._pending_generation_event_id = None
+        self.emit("generation_created", generation_event)
+
+        self._current_generation.messages[self._active_response_id] = item_generation
+
+    def _handle_server_content(self, server_content: LiveServerContent):
+        if not self._current_generation or not self._active_response_id:
+            logger.warning(
+                "gemini-realtime-session: No active response ID, skipping server content"
             )
-        # self._chat_ctx.append(text=content.text, role="assistant")
+            return
 
-    @utils.log_exceptions(logger=logger)
-    async def _main_task(self):
-        @utils.log_exceptions(logger=logger)
-        async def _send_task():
-            async for msg in self._send_ch:
-                await self._session.send(input=msg)
-
-            await self._session.send(input=".", end_of_turn=True)
-
-        @utils.log_exceptions(logger=logger)
-        async def _recv_task():
-            while True:
-                async for response in self._session.receive():
-                    if self._active_response_id is None:
-                        self._is_interrupted = False
-                        self._active_response_id = utils.shortuuid()
-                        text_stream = utils.aio.Chan[str]()
-                        audio_stream = utils.aio.Chan[rtc.AudioFrame]()
-                        content = GeminiContent(
-                            response_id=self._active_response_id,
-                            item_id=self._active_response_id,
-                            output_index=0,
-                            content_index=0,
-                            text="",
-                            audio=[],
-                            text_stream=text_stream,
-                            audio_stream=audio_stream,
-                            content_type="audio",
-                        )
-                        self.emit("response_content_added", content)
-
-                    server_content = response.server_content
-                    if server_content:
-                        model_turn = server_content.model_turn
-                        if model_turn:
-                            for part in model_turn.parts:
-                                if part.text:
-                                    content.text_stream.send_nowait(part.text)
-                                if part.inline_data:
-                                    frame = rtc.AudioFrame(
-                                        data=part.inline_data.data,
-                                        sample_rate=24000,
-                                        num_channels=1,
-                                        samples_per_channel=len(part.inline_data.data) // 2,
-                                    )
-                                    if self._opts.enable_agent_audio_transcription:
-                                        content.audio.append(frame)
-                                    content.audio_stream.send_nowait(frame)
-
-                        if server_content.interrupted or server_content.turn_complete:
-                            if self._opts.enable_agent_audio_transcription:
-                                self._agent_transcriber._push_audio(content.audio)
-                            for stream in (content.text_stream, content.audio_stream):
-                                if isinstance(stream, utils.aio.Chan):
-                                    stream.close()
-
-                            self.emit("agent_speech_stopped")
-                            self._is_interrupted = True
-
-                        self._active_response_id = None
-
-                    if response.tool_call:
-                        if self._fnc_ctx is None:
-                            raise ValueError("Function context is not set")
-                        fnc_calls = []
-                        for fnc_call in response.tool_call.function_calls:
-                            fnc_call_info = _create_ai_function_info(
-                                self._fnc_ctx,
-                                fnc_call.id,
-                                fnc_call.name,
-                                json.dumps(fnc_call.args),
-                            )
-                            fnc_calls.append(fnc_call_info)
-
-                        self.emit("function_calls_collected", fnc_calls)
-
-                        for fnc_call_info in fnc_calls:
-                            self._fnc_tasks.create_task(
-                                self._run_fnc_task(fnc_call_info, content.item_id)
-                            )
-
-                    # Handle function call cancellations
-                    if response.tool_call_cancellation:
-                        logger.warning(
-                            "function call cancelled",
-                            extra={
-                                "function_call_ids": response.tool_call_cancellation.function_call_ids,  # noqa: E501
-                            },
-                        )
-                        self.emit(
-                            "function_calls_cancelled",
-                            response.tool_call_cancellation.function_call_ids,
-                        )
-
-        async with self._client.aio.live.connect(
-            model=self._opts.model, config=self._config
-        ) as session:
-            self._session = session
-            tasks = [
-                asyncio.create_task(_send_task(), name="gemini-realtime-send"),
-                asyncio.create_task(_recv_task(), name="gemini-realtime-recv"),
-            ]
-
-            try:
-                await asyncio.gather(*tasks)
-            finally:
-                await utils.aio.gracefully_cancel(*tasks)
-                await self._session.close()
-                if self._opts.enable_user_audio_transcription:
-                    await self._transcriber.aclose()
-                if self._opts.enable_agent_audio_transcription:
-                    await self._agent_transcriber.aclose()
+        item_generation = self._current_generation.messages[self._active_response_id]
+
+        model_turn = server_content.model_turn
+        if model_turn:
+            for part in model_turn.parts:
+                if part.text:
+                    item_generation.text_ch.send_nowait(part.text)
+                if part.inline_data:
+                    frame_data = part.inline_data.data
+                    frame = rtc.AudioFrame(
+                        data=frame_data,
+                        sample_rate=OUTPUT_AUDIO_SAMPLE_RATE,
+                        num_channels=NUM_CHANNELS,
+                        samples_per_channel=len(frame_data) // 2,
+                    )
+                    item_generation.audio_ch.send_nowait(frame)
 
-    @utils.log_exceptions(logger=logger)
-    async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str):
-        logger.debug(
-            "executing ai function",
+        if server_content.interrupted or server_content.turn_complete:
+            self._finalize_response()
+
+    def _finalize_response(self) -> None:
+        if not self._current_generation:
+            return
+
+        for item_generation in self._current_generation.messages.values():
+            item_generation.text_ch.close()
+            item_generation.audio_ch.close()
+
+        self._current_generation.function_ch.close()
+        self._current_generation.message_ch.close()
+        self._current_generation = None
+        self._is_interrupted = True
+        self._active_response_id = None
+        self.emit("agent_speech_stopped")
+
+    def _handle_tool_calls(self, tool_call: LiveServerToolCall):
+        if not self._current_generation:
+            return
+        for fnc_call in tool_call.function_calls:
+            self._current_generation.function_ch.send_nowait(
+                llm.FunctionCall(
+                    call_id=fnc_call.id,
+                    name=fnc_call.name,
+                    arguments=json.dumps(fnc_call.args),
+                )
+            )
+        self._finalize_response()
+
+    def _handle_tool_call_cancellation(
+        self, tool_call_cancellation: LiveServerToolCallCancellation
+    ):
+        logger.warning(
+            "function call cancelled",
             extra={
-                "function": fnc_call_info.function_info.name,
+                "function_call_ids": tool_call_cancellation.ids,
             },
         )
+        self.emit("function_calls_cancelled", tool_call_cancellation.ids)
 
-        called_fnc = fnc_call_info.execute()
-        try:
-            await called_fnc.task
-        except Exception as e:
-            logger.exception(
-                "error executing ai function",
-                extra={
-                    "function": fnc_call_info.function_info.name,
-                },
-                exc_info=e,
-            )
-        tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc)
-        if tool_call.content is not None:
-            tool_response = LiveClientToolResponse(
-                function_responses=[
-                    FunctionResponse(
-                        name=tool_call.name,
-                        id=tool_call.tool_call_id,
-                        response={"result": tool_call.content},
-                    )
-                ]
-            )
-            await self._session.send(input=tool_response)
+    def commit_audio(self) -> None:
+        raise NotImplementedError("commit_audio_buffer is not supported yet")
 
-        self.emit("function_calls_finished", [called_fnc])
+    def clear_audio(self) -> None:
+        raise NotImplementedError("clear_audio is not supported yet")
+
+    def server_vad_enabled(self) -> bool:
+        return True
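
The biggest structural change is `_main_task`: connection setup now lives inside a `while True` loop that races the send and receive tasks against a reconnect event, which `update_options`, `update_instructions`, and `update_tools` set to force a clean reconnect with a freshly built `LiveConnectConfig`. A standalone sketch of that select-and-reconnect shape (stand-in coroutines; the real code awaits cancellation via `utils.aio.cancel_and_wait`):

```python
import asyncio


async def run_with_reconnect(reconnect: asyncio.Event) -> None:
    while True:
        # Stand-ins for _send_task / _recv_task, which stream to/from Gemini.
        send_task = asyncio.create_task(asyncio.sleep(0.1))
        recv_task = asyncio.create_task(asyncio.sleep(0.1))
        reconnect_task = asyncio.create_task(reconnect.wait())
        try:
            done, _ = await asyncio.wait(
                [send_task, recv_task, reconnect_task],
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in done:
                if task is not reconnect_task:
                    task.result()  # surface send/recv failures
            if reconnect_task not in done:
                break  # send/recv finished on their own: leave the loop
            reconnect.clear()  # reconnect requested: loop and rebuild the session
        finally:
            for task in (send_task, recv_task, reconnect_task):
                task.cancel()


asyncio.run(run_with_reconnect(asyncio.Event()))  # completes after one pass
```
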