livekit-plugins-google 1.0.15__py3-none-any.whl → 1.0.17__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their respective public registries.
@@ -1,9 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
+ import contextlib
4
5
  import json
5
6
  import os
6
7
  import weakref
8
+ from collections.abc import Iterator
7
9
  from dataclasses import dataclass
8
10
 
9
11
  from google import genai
@@ -31,15 +33,16 @@ from google.genai.types import (
31
33
  from livekit import rtc
32
34
  from livekit.agents import llm, utils
33
35
  from livekit.agents.types import NOT_GIVEN, NotGivenOr
34
- from livekit.agents.utils import images, is_given
36
+ from livekit.agents.utils import audio as audio_utils, images, is_given
37
+ from livekit.plugins.google.beta.realtime.api_proto import ClientEvents, LiveAPIModels, Voice
35
38
 
36
39
  from ...log import logger
37
40
  from ...utils import _build_gemini_fnc, get_tool_results_for_realtime, to_chat_ctx
38
- from .api_proto import ClientEvents, LiveAPIModels, Voice
39
41
 
40
42
  INPUT_AUDIO_SAMPLE_RATE = 16000
43
+ INPUT_AUDIO_CHANNELS = 1
41
44
  OUTPUT_AUDIO_SAMPLE_RATE = 24000
42
- NUM_CHANNELS = 1
45
+ OUTPUT_AUDIO_CHANNELS = 1
43
46
 
44
47
  DEFAULT_ENCODE_OPTIONS = images.EncodeOptions(
45
48
  format="JPEG",
@@ -126,7 +129,7 @@ class RealtimeModel(llm.RealtimeModel):
126
129
  instructions (str, optional): Initial system instructions for the model. Defaults to "".
127
130
  api_key (str, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
128
131
  modalities (list[Modality], optional): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
129
- model (str, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
132
+ model (str, optional): The name of the model to use. Defaults to "gemini-2.0-flash-live-001".
130
133
  voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
131
134
  temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
132
135
  vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
@@ -147,20 +150,20 @@ class RealtimeModel(llm.RealtimeModel):
147
150
  capabilities=llm.RealtimeCapabilities(
148
151
  message_truncation=False,
149
152
  turn_detection=True,
150
- user_transcription=False,
153
+ user_transcription=is_given(input_audio_transcription),
151
154
  )
152
155
  )
153
156
 
154
157
  gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
155
158
  gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
156
159
  gcp_location = location if is_given(location) else os.environ.get("GOOGLE_CLOUD_LOCATION")
160
+
157
161
  if vertexai:
158
162
  if not gcp_project or not gcp_location:
159
163
  raise ValueError(
160
164
  "Project and location are required for VertexAI either via project and location or GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION environment variables" # noqa: E501
161
165
  )
162
166
  gemini_api_key = None # VertexAI does not require an API key
163
-
164
167
  else:
165
168
  gcp_project = None
166
169
  gcp_location = None
@@ -213,7 +216,8 @@ class RealtimeModel(llm.RealtimeModel):
213
216
  for sess in self._sessions:
214
217
  sess.update_options(voice=self._opts.voice, temperature=self._opts.temperature)
215
218
 
216
- async def aclose(self) -> None: ...
219
+ async def aclose(self) -> None:
220
+ pass
217
221
 
218
222
 
219
223
  class RealtimeSession(llm.RealtimeSession):
@@ -221,138 +225,162 @@ class RealtimeSession(llm.RealtimeSession):
221
225
  super().__init__(realtime_model)
222
226
  self._opts = realtime_model._opts
223
227
  self._tools = llm.ToolContext.empty()
228
+ self._gemini_declarations: list[FunctionDeclaration] = []
224
229
  self._chat_ctx = llm.ChatContext.empty()
225
230
  self._msg_ch = utils.aio.Chan[ClientEvents]()
226
- self._gemini_tools: list[Tool] = []
231
+ self._input_resampler: rtc.AudioResampler | None = None
232
+
233
+ # 50ms chunks
234
+ self._bstream = audio_utils.AudioByteStream(
235
+ INPUT_AUDIO_SAMPLE_RATE,
236
+ INPUT_AUDIO_CHANNELS,
237
+ samples_per_channel=INPUT_AUDIO_SAMPLE_RATE // 20,
238
+ )
239
+
227
240
  self._client = genai.Client(
228
241
  api_key=self._opts.api_key,
229
242
  vertexai=self._opts.vertexai,
230
243
  project=self._opts.project,
231
244
  location=self._opts.location,
232
245
  )
246
+
233
247
  self._main_atask = asyncio.create_task(self._main_task(), name="gemini-realtime-session")
234
248
 
235
249
  self._current_generation: _ResponseGeneration | None = None
236
-
237
- self._is_interrupted = False
238
- self._active_response_id = None
239
- self._session = None
240
- self._update_chat_ctx_lock = asyncio.Lock()
241
- self._update_fnc_ctx_lock = asyncio.Lock()
250
+ self._active_session: genai.LiveSession | None = None
251
+ # indicates if the underlying session should end
252
+ self._session_should_close = asyncio.Event()
242
253
  self._response_created_futures: dict[str, asyncio.Future[llm.GenerationCreatedEvent]] = {}
243
- self._pending_generation_event_id = None
254
+ self._pending_generation_fut: asyncio.Future[llm.GenerationCreatedEvent] | None = None
244
255
 
245
- self._reconnect_event = asyncio.Event()
256
+ self._update_lock = asyncio.Lock()
246
257
  self._session_lock = asyncio.Lock()
247
- self._gemini_close_task: asyncio.Task | None = None
248
-
249
- def _schedule_gemini_session_close(self) -> None:
250
- if self._session is not None:
251
- self._gemini_close_task = asyncio.create_task(self._close_gemini_session())
252
258
 
253
- async def _close_gemini_session(self) -> None:
259
+ async def _close_active_session(self) -> None:
254
260
  async with self._session_lock:
255
- if self._session:
261
+ if self._active_session:
256
262
  try:
257
- await self._session.close()
263
+ await self._active_session.close()
264
+ except Exception as e:
265
+ logger.warning(f"error closing Gemini session: {e}")
258
266
  finally:
259
- self._session = None
267
+ self._active_session = None
260
268
 
261
- def update_options(
269
+ def _mark_restart_needed(self):
270
+ if not self._session_should_close.is_set():
271
+ self._session_should_close.set()
272
+ # reset the msg_ch, do not send messages from previous session
273
+ self._msg_ch = utils.aio.Chan[ClientEvents]()
274
+
275
+ async def update_options(
262
276
  self,
263
277
  *,
264
278
  voice: NotGivenOr[str] = NOT_GIVEN,
265
- tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
266
279
  temperature: NotGivenOr[float] = NOT_GIVEN,
280
+ tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
267
281
  ) -> None:
268
- if is_given(voice):
269
- self._opts.voice = voice
282
+ async with self._update_lock:
283
+ should_restart = False
284
+ if is_given(voice) and self._opts.voice != voice:
285
+ self._opts.voice = voice
286
+ should_restart = True
270
287
 
271
- if is_given(temperature):
272
- self._opts.temperature = temperature
288
+ if is_given(temperature) and self._opts.temperature != temperature:
289
+ self._opts.temperature = temperature if is_given(temperature) else NOT_GIVEN
290
+ should_restart = True
273
291
 
274
- if self._session:
275
- logger.warning("Updating options; triggering Gemini session reconnect.")
276
- self._reconnect_event.set()
277
- self._schedule_gemini_session_close()
292
+ if should_restart:
293
+ self._mark_restart_needed()
278
294
 
279
295
  async def update_instructions(self, instructions: str) -> None:
280
- self._opts.instructions = instructions
281
- if self._session:
282
- logger.warning("Updating instructions; triggering Gemini session reconnect.")
283
- self._reconnect_event.set()
284
- self._schedule_gemini_session_close()
296
+ async with self._update_lock:
297
+ if not is_given(self._opts.instructions) or self._opts.instructions != instructions:
298
+ self._opts.instructions = instructions
299
+ self._mark_restart_needed()
285
300
 
286
301
  async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
287
- async with self._update_chat_ctx_lock:
288
- self._chat_ctx = chat_ctx
302
+ async with self._update_lock:
303
+ self._chat_ctx = chat_ctx.copy()
289
304
  turns, _ = to_chat_ctx(self._chat_ctx, id(self), ignore_functions=True)
290
305
  tool_results = get_tool_results_for_realtime(self._chat_ctx)
306
+ # TODO(dz): need to compute delta and then either append or recreate session
291
307
  if turns:
292
- self._msg_ch.send_nowait(LiveClientContent(turns=turns, turn_complete=False))
308
+ self._send_client_event(LiveClientContent(turns=turns, turn_complete=False))
293
309
  if tool_results:
294
- self._msg_ch.send_nowait(tool_results)
310
+ self._send_client_event(tool_results)
295
311
 
296
312
  async def update_tools(self, tools: list[llm.FunctionTool]) -> None:
297
- async with self._update_fnc_ctx_lock:
298
- retained_tools: list[llm.FunctionTool] = []
299
- gemini_function_declarations: list[FunctionDeclaration] = []
300
-
301
- for tool in tools:
302
- gemini_function = _build_gemini_fnc(tool)
303
- gemini_function_declarations.append(gemini_function)
304
- retained_tools.append(tool)
305
-
306
- self._tools = llm.ToolContext(retained_tools)
307
- self._gemini_tools = [Tool(function_declarations=gemini_function_declarations)]
308
- if self._session and gemini_function_declarations:
309
- logger.warning("Updating tools; triggering Gemini session reconnect.")
310
- self._reconnect_event.set()
311
- self._schedule_gemini_session_close()
313
+ async with self._update_lock:
314
+ new_declarations: list[FunctionDeclaration] = [
315
+ _build_gemini_fnc(tool) for tool in tools
316
+ ]
317
+ current_tool_names = {f.name for f in self._gemini_declarations}
318
+ new_tool_names = {f.name for f in new_declarations}
319
+
320
+ if current_tool_names != new_tool_names:
321
+ self._gemini_declarations = new_declarations
322
+ self._tools = llm.ToolContext(tools)
323
+ self._mark_restart_needed()
312
324
 
313
325
  @property
314
326
  def chat_ctx(self) -> llm.ChatContext:
315
- return self._chat_ctx
327
+ return self._chat_ctx.copy()
316
328
 
317
329
  @property
318
330
  def tools(self) -> llm.ToolContext:
319
- return self._tools
331
+ return self._tools.copy()
320
332
 
321
333
  def push_audio(self, frame: rtc.AudioFrame) -> None:
322
- self.push_media(frame.data.tobytes(), "audio/pcm")
334
+ for f in self._resample_audio(frame):
335
+ for nf in self._bstream.write(f.data.tobytes()):
336
+ realtime_input = LiveClientRealtimeInput(
337
+ media_chunks=[Blob(data=nf.data.tobytes(), mime_type="audio/pcm")]
338
+ )
339
+ self._send_client_event(realtime_input)
323
340
 
324
341
  def push_video(self, frame: rtc.VideoFrame) -> None:
325
342
  encoded_data = images.encode(frame, DEFAULT_ENCODE_OPTIONS)
326
- self.push_media(encoded_data, "image/jpeg")
327
-
328
- def push_media(self, bytes: bytes, mime_type: str) -> None:
329
343
  realtime_input = LiveClientRealtimeInput(
330
- media_chunks=[Blob(data=bytes, mime_type=mime_type)]
344
+ media_chunks=[Blob(data=encoded_data, mime_type="image/jpeg")]
331
345
  )
332
- self._msg_ch.send_nowait(realtime_input)
346
+ self._send_client_event(realtime_input)
347
+
348
+ def _send_client_event(self, event: ClientEvents) -> None:
349
+ with contextlib.suppress(utils.aio.channel.ChanClosed):
350
+ self._msg_ch.send_nowait(event)
333
351
 
334
352
  def generate_reply(
335
353
  self, *, instructions: NotGivenOr[str] = NOT_GIVEN
336
354
  ) -> asyncio.Future[llm.GenerationCreatedEvent]:
337
- fut = asyncio.Future()
355
+ if self._pending_generation_fut and not self._pending_generation_fut.done():
356
+ logger.warning(
357
+ "generate_reply called while another generation is pending, cancelling previous."
358
+ )
359
+ self._pending_generation_fut.cancel("Superseded by new generate_reply call")
338
360
 
339
- event_id = utils.shortuuid("gemini-response-")
340
- self._response_created_futures[event_id] = fut
341
- self._pending_generation_event_id = event_id
361
+ fut = asyncio.Future()
362
+ self._pending_generation_fut = fut
342
363
 
343
- instructions_content = instructions if is_given(instructions) else "."
344
- ctx = [Content(parts=[Part(text=instructions_content)], role="user")]
345
- self._msg_ch.send_nowait(LiveClientContent(turns=ctx, turn_complete=True))
364
+ # Gemini requires the last message to end with user's turn
365
+ # so we need to add a placeholder user turn in order to trigger a new generation
366
+ event = LiveClientContent(turns=[], turn_complete=True)
367
+ if is_given(instructions):
368
+ event.turns.append(Content(parts=[Part(text=instructions)], role="model"))
369
+ event.turns.append(Content(parts=[Part(text=".")], role="user"))
370
+ self._send_client_event(event)
346
371
 
347
372
  def _on_timeout() -> None:
348
- if event_id in self._response_created_futures and not fut.done():
349
- fut.set_exception(llm.RealtimeError("generate_reply timed out."))
350
- self._response_created_futures.pop(event_id, None)
351
- if self._pending_generation_event_id == event_id:
352
- self._pending_generation_event_id = None
373
+ if not fut.done():
374
+ fut.set_exception(
375
+ llm.RealtimeError(
376
+ "generate_reply timed out waiting for generation_created event."
377
+ )
378
+ )
379
+ if self._pending_generation_fut is fut:
380
+ self._pending_generation_fut = None
353
381
 
354
- handle = asyncio.get_event_loop().call_later(5.0, _on_timeout)
355
- fut.add_done_callback(lambda _: handle.cancel())
382
+ timeout_handle = asyncio.get_event_loop().call_later(5.0, _on_timeout)
383
+ fut.add_done_callback(lambda _: timeout_handle.cancel())
356
384
 
357
385
  return fut
358
386
 
@@ -360,133 +388,195 @@ class RealtimeSession(llm.RealtimeSession):
360
388
  pass
361
389
 
362
390
  def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
391
+ logger.warning("truncate is not supported by the Google Realtime API.")
363
392
  pass
364
393
 
365
394
  async def aclose(self) -> None:
366
395
  self._msg_ch.close()
367
-
368
- for fut in self._response_created_futures.values():
369
- if not fut.done():
370
- fut.set_exception(llm.RealtimeError("Session closed"))
396
+ self._session_should_close.set()
371
397
 
372
398
  if self._main_atask:
373
399
  await utils.aio.cancel_and_wait(self._main_atask)
374
400
 
375
- if self._gemini_close_task:
376
- await utils.aio.cancel_and_wait(self._gemini_close_task)
401
+ await self._close_active_session()
402
+
403
+ if self._pending_generation_fut and not self._pending_generation_fut.done():
404
+ self._pending_generation_fut.cancel("Session closed")
405
+
406
+ for fut in self._response_created_futures.values():
407
+ if not fut.done():
408
+ fut.set_exception(llm.RealtimeError("Session closed before response created"))
409
+ self._response_created_futures.clear()
410
+
411
+ if self._current_generation:
412
+ self._finalize_response(closed=True)
377
413
 
378
414
  @utils.log_exceptions(logger=logger)
379
415
  async def _main_task(self):
380
- while True:
381
- config = LiveConnectConfig(
382
- response_modalities=self._opts.response_modalities
383
- if is_given(self._opts.response_modalities)
384
- else [Modality.AUDIO],
385
- generation_config=GenerationConfig(
386
- candidate_count=self._opts.candidate_count,
387
- temperature=self._opts.temperature
388
- if is_given(self._opts.temperature)
389
- else None,
390
- max_output_tokens=self._opts.max_output_tokens
391
- if is_given(self._opts.max_output_tokens)
392
- else None,
393
- top_p=self._opts.top_p if is_given(self._opts.top_p) else None,
394
- top_k=self._opts.top_k if is_given(self._opts.top_k) else None,
395
- presence_penalty=self._opts.presence_penalty
396
- if is_given(self._opts.presence_penalty)
397
- else None,
398
- frequency_penalty=self._opts.frequency_penalty
399
- if is_given(self._opts.frequency_penalty)
400
- else None,
401
- ),
402
- system_instruction=Content(parts=[Part(text=self._opts.instructions)])
403
- if is_given(self._opts.instructions)
404
- else None,
405
- speech_config=SpeechConfig(
406
- voice_config=VoiceConfig(
407
- prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
416
+ while not self._msg_ch.closed:
417
+ # previous session might not be closed yet, we'll do it here.
418
+ await self._close_active_session()
419
+
420
+ self._session_should_close.clear()
421
+ config = self._build_connect_config()
422
+ session = None
423
+ try:
424
+ logger.debug("connecting to Gemini Realtime API...")
425
+ async with self._client.aio.live.connect(
426
+ model=self._opts.model, config=config
427
+ ) as session:
428
+ async with self._session_lock:
429
+ self._active_session = session
430
+
431
+ # queue up existing chat context
432
+ send_task = asyncio.create_task(
433
+ self._send_task(session), name="gemini-realtime-send"
434
+ )
435
+ recv_task = asyncio.create_task(
436
+ self._recv_task(session), name="gemini-realtime-recv"
437
+ )
438
+ restart_wait_task = asyncio.create_task(
439
+ self._session_should_close.wait(), name="gemini-restart-wait"
408
440
  )
409
- ),
410
- tools=self._gemini_tools,
411
- input_audio_transcription=self._opts.input_audio_transcription,
412
- output_audio_transcription=self._opts.output_audio_transcription,
413
- )
414
-
415
- async with self._client.aio.live.connect(
416
- model=self._opts.model, config=config
417
- ) as session:
418
- async with self._session_lock:
419
- self._session = session
420
-
421
- @utils.log_exceptions(logger=logger)
422
- async def _send_task():
423
- async for msg in self._msg_ch:
424
- if isinstance(msg, LiveClientContent):
425
- await session.send(input=msg, end_of_turn=True)
426
- else:
427
- await session.send(input=msg)
428
- await session.send(input=".", end_of_turn=True)
429
-
430
- @utils.log_exceptions(logger=logger)
431
- async def _recv_task():
432
- while True:
433
- async for response in session.receive():
434
- if self._active_response_id is None:
435
- self._start_new_generation()
436
- if response.setup_complete:
437
- logger.info("connection established with gemini live api server")
438
- if response.server_content:
439
- self._handle_server_content(response.server_content)
440
- if response.tool_call:
441
- self._handle_tool_calls(response.tool_call)
442
- if response.tool_call_cancellation:
443
- self._handle_tool_call_cancellation(response.tool_call_cancellation)
444
- if response.usage_metadata:
445
- self._handle_usage_metadata(response.usage_metadata)
446
- if response.go_away:
447
- self._handle_go_away(response.go_away)
448
-
449
- send_task = asyncio.create_task(_send_task(), name="gemini-realtime-send")
450
- recv_task = asyncio.create_task(_recv_task(), name="gemini-realtime-recv")
451
- reconnect_task = asyncio.create_task(
452
- self._reconnect_event.wait(), name="reconnect-wait"
453
- )
454
441
 
455
- try:
456
- done, _ = await asyncio.wait(
457
- [send_task, recv_task, reconnect_task],
442
+ done, pending = await asyncio.wait(
443
+ [send_task, recv_task, restart_wait_task],
458
444
  return_when=asyncio.FIRST_COMPLETED,
459
445
  )
446
+
460
447
  for task in done:
461
- if task != reconnect_task:
462
- task.result()
448
+ if task is not restart_wait_task and task.exception():
449
+ logger.error(f"error in task {task.get_name()}: {task.exception()}")
450
+ raise task.exception() or Exception(f"{task.get_name()} failed")
463
451
 
464
- if reconnect_task not in done:
452
+ if restart_wait_task not in done and self._msg_ch.closed:
465
453
  break
466
454
 
467
- self._reconnect_event.clear()
468
- finally:
469
- await utils.aio.cancel_and_wait(send_task, recv_task, reconnect_task)
455
+ for task in pending:
456
+ await utils.aio.cancel_and_wait(task)
457
+
458
+ except asyncio.CancelledError:
459
+ break
460
+ except Exception as e:
461
+ logger.error(f"Gemini Realtime API error: {e}", exc_info=e)
462
+ if not self._msg_ch.closed:
463
+ logger.info("attempting to reconnect after 1 seconds...")
464
+ await asyncio.sleep(1)
465
+ finally:
466
+ await self._close_active_session()
467
+
468
+ async def _send_task(self, session: genai.LiveSession):
469
+ try:
470
+ async for msg in self._msg_ch:
471
+ async with self._session_lock:
472
+ if self._session_should_close.is_set() or (
473
+ not self._active_session or self._active_session != session
474
+ ):
475
+ break
476
+
477
+ if isinstance(msg, LiveClientContent):
478
+ await session.send(input=msg)
479
+ else:
480
+ await session.send(input=msg)
481
+ except Exception as e:
482
+ if not self._session_should_close.is_set():
483
+ logger.error(f"error in send task: {e}", exc_info=e)
484
+ self._mark_restart_needed()
485
+ finally:
486
+ logger.debug("send task finished.")
487
+
488
+ async def _recv_task(self, session: genai.LiveSession):
489
+ try:
490
+ while True:
491
+ async with self._session_lock:
492
+ if self._session_should_close.is_set() or (
493
+ not self._active_session or self._active_session != session
494
+ ):
495
+ logger.debug("receive task: Session changed or closed, stopping receive.")
496
+ break
497
+
498
+ async for response in session.receive():
499
+ if not self._current_generation and (
500
+ response.server_content or response.tool_call
501
+ ):
502
+ self._start_new_generation()
503
+
504
+ if response.server_content:
505
+ self._handle_server_content(response.server_content)
506
+ if response.tool_call:
507
+ self._handle_tool_calls(response.tool_call)
508
+ if response.tool_call_cancellation:
509
+ self._handle_tool_call_cancellation(response.tool_call_cancellation)
510
+ if response.usage_metadata:
511
+ self._handle_usage_metadata(response.usage_metadata)
512
+ if response.go_away:
513
+ self._handle_go_away(response.go_away)
514
+
515
+ # TODO(dz): a server-side turn is complete
516
+ except Exception as e:
517
+ if not self._session_should_close.is_set():
518
+ logger.error(f"error in receive task: {e}", exc_info=e)
519
+ self._mark_restart_needed()
520
+ finally:
521
+ self._finalize_response(closed=True)
522
+
523
+ def _build_connect_config(self) -> LiveConnectConfig:
524
+ temp = self._opts.temperature if is_given(self._opts.temperature) else None
525
+
526
+ return LiveConnectConfig(
527
+ response_modalities=self._opts.response_modalities
528
+ if is_given(self._opts.response_modalities)
529
+ else [Modality.AUDIO],
530
+ generation_config=GenerationConfig(
531
+ candidate_count=self._opts.candidate_count,
532
+ temperature=temp,
533
+ max_output_tokens=self._opts.max_output_tokens
534
+ if is_given(self._opts.max_output_tokens)
535
+ else None,
536
+ top_p=self._opts.top_p if is_given(self._opts.top_p) else None,
537
+ top_k=self._opts.top_k if is_given(self._opts.top_k) else None,
538
+ presence_penalty=self._opts.presence_penalty
539
+ if is_given(self._opts.presence_penalty)
540
+ else None,
541
+ frequency_penalty=self._opts.frequency_penalty
542
+ if is_given(self._opts.frequency_penalty)
543
+ else None,
544
+ ),
545
+ system_instruction=Content(parts=[Part(text=self._opts.instructions)])
546
+ if is_given(self._opts.instructions)
547
+ else None,
548
+ speech_config=SpeechConfig(
549
+ voice_config=VoiceConfig(
550
+ prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
551
+ )
552
+ ),
553
+ tools=[Tool(function_declarations=self._gemini_declarations)],
554
+ input_audio_transcription=self._opts.input_audio_transcription,
555
+ output_audio_transcription=self._opts.output_audio_transcription,
556
+ )
470
557
 
471
558
  def _start_new_generation(self):
472
- self._is_interrupted = False
473
- self._active_response_id = utils.shortuuid("gemini-turn-")
559
+ if self._current_generation:
560
+ logger.warning("starting new generation while another is active. Finalizing previous.")
561
+ self._finalize_response(closed=True)
562
+
563
+ response_id = utils.shortuuid("gemini-turn-")
474
564
  self._current_generation = _ResponseGeneration(
475
565
  message_ch=utils.aio.Chan[llm.MessageGeneration](),
476
566
  function_ch=utils.aio.Chan[llm.FunctionCall](),
477
567
  messages={},
478
568
  )
479
569
 
480
- # We'll assume each chunk belongs to a single message ID self._active_response_id
481
570
  item_generation = _MessageGeneration(
482
- message_id=self._active_response_id,
571
+ message_id=response_id,
483
572
  text_ch=utils.aio.Chan[str](),
484
573
  audio_ch=utils.aio.Chan[rtc.AudioFrame](),
485
574
  )
575
+ self._current_generation.messages[response_id] = item_generation
486
576
 
487
577
  self._current_generation.message_ch.send_nowait(
488
578
  llm.MessageGeneration(
489
- message_id=self._active_response_id,
579
+ message_id=response_id,
490
580
  text_stream=item_generation.text_ch,
491
581
  audio_stream=item_generation.audio_ch,
492
582
  )
@@ -498,84 +588,92 @@ class RealtimeSession(llm.RealtimeSession):
498
588
  user_initiated=False,
499
589
  )
500
590
 
501
- # Resolve any pending future from generate_reply()
502
- if self._pending_generation_event_id and (
503
- fut := self._response_created_futures.pop(self._pending_generation_event_id, None)
504
- ):
505
- fut.set_result(generation_event)
591
+ if self._pending_generation_fut and not self._pending_generation_fut.done():
592
+ generation_event.user_initiated = True
593
+ self._pending_generation_fut.set_result(generation_event)
594
+ self._pending_generation_fut = None
506
595
 
507
- self._pending_generation_event_id = None
508
596
  self.emit("generation_created", generation_event)
509
597
 
510
- self._current_generation.messages[self._active_response_id] = item_generation
511
-
512
598
  def _handle_server_content(self, server_content: LiveServerContent):
513
- if not self._current_generation or not self._active_response_id:
514
- logger.warning(
515
- "gemini-realtime-session: No active response ID, skipping server content"
516
- )
599
+ if not self._current_generation:
600
+ logger.warning("received server content but no active generation.")
517
601
  return
518
602
 
519
- item_generation = self._current_generation.messages[self._active_response_id]
603
+ response_id = list(self._current_generation.messages.keys())[0]
604
+ item_generation = self._current_generation.messages[response_id]
520
605
 
521
- model_turn = server_content.model_turn
522
- if model_turn:
606
+ if model_turn := server_content.model_turn:
523
607
  for part in model_turn.parts:
524
608
  if part.text:
525
609
  item_generation.text_ch.send_nowait(part.text)
526
610
  if part.inline_data:
527
611
  frame_data = part.inline_data.data
528
- frame = rtc.AudioFrame(
529
- data=frame_data,
530
- sample_rate=OUTPUT_AUDIO_SAMPLE_RATE,
531
- num_channels=NUM_CHANNELS,
532
- samples_per_channel=len(frame_data) // 2,
533
- )
534
- item_generation.audio_ch.send_nowait(frame)
535
- input_transcription = server_content.input_transcription
536
- if input_transcription and input_transcription.text:
537
- self.emit(
538
- "input_audio_transcription_completed",
539
- llm.InputTranscriptionCompleted(
540
- item_id=self._active_response_id, transcript=input_transcription.text
541
- ),
542
- )
543
- output_transcription = server_content.output_transcription
544
- if output_transcription and output_transcription.text:
545
- item_generation.text_ch.send_nowait(output_transcription.text)
612
+ try:
613
+ frame = rtc.AudioFrame(
614
+ data=frame_data,
615
+ sample_rate=OUTPUT_AUDIO_SAMPLE_RATE,
616
+ num_channels=OUTPUT_AUDIO_CHANNELS,
617
+ samples_per_channel=len(frame_data) // (2 * OUTPUT_AUDIO_CHANNELS),
618
+ )
619
+ item_generation.audio_ch.send_nowait(frame)
620
+ except ValueError as e:
621
+ logger.error(f"Error creating audio frame from Gemini data: {e}")
622
+
623
+ if input_transcription := server_content.input_transcription:
624
+ if input_transcription.text:
625
+ self.emit(
626
+ "input_audio_transcription_completed",
627
+ llm.InputTranscriptionCompleted(
628
+ item_id=response_id, transcript=input_transcription.text
629
+ ),
630
+ )
631
+ self._handle_input_speech_started()
632
+
633
+ if output_transcription := server_content.output_transcription:
634
+ if output_transcription.text:
635
+ item_generation.text_ch.send_nowait(output_transcription.text)
636
+
546
637
  if server_content.interrupted:
547
- self._finalize_response()
638
+ self._finalize_response(interrupted=True)
548
639
  self._handle_input_speech_started()
549
640
 
550
641
  if server_content.turn_complete:
551
642
  self._finalize_response()
552
643
 
553
- def _finalize_response(self) -> None:
644
+ def _finalize_response(self, interrupted: bool = False, closed: bool = False) -> None:
554
645
  if not self._current_generation:
555
646
  return
556
647
 
557
- for item_generation in self._current_generation.messages.values():
558
- item_generation.text_ch.close()
559
- item_generation.audio_ch.close()
560
-
561
- self._current_generation.function_ch.close()
562
- self._current_generation.message_ch.close()
648
+ gen = self._current_generation
563
649
  self._current_generation = None
564
- self._is_interrupted = True
565
- self._active_response_id = None
650
+
651
+ for item_generation in gen.messages.values():
652
+ if not item_generation.text_ch.closed:
653
+ item_generation.text_ch.close()
654
+ if not item_generation.audio_ch.closed:
655
+ item_generation.audio_ch.close()
656
+
657
+ gen.function_ch.close()
658
+ gen.message_ch.close()
566
659
 
567
660
  def _handle_input_speech_started(self):
568
661
  self.emit("input_speech_started", llm.InputSpeechStartedEvent())
569
662
 
570
663
  def _handle_tool_calls(self, tool_call: LiveServerToolCall):
571
664
  if not self._current_generation:
665
+ logger.warning("received tool call but no active generation.")
572
666
  return
667
+
668
+ gen = self._current_generation
573
669
  for fnc_call in tool_call.function_calls:
574
- self._current_generation.function_ch.send_nowait(
670
+ arguments = json.dumps(fnc_call.args)
671
+
672
+ gen.function_ch.send_nowait(
575
673
  llm.FunctionCall(
576
- call_id=fnc_call.id or "",
674
+ call_id=fnc_call.id or utils.shortuuid("fnc-call-"),
577
675
  name=fnc_call.name,
578
- arguments=json.dumps(fnc_call.args),
676
+ arguments=arguments,
579
677
  )
580
678
  )
581
679
  self._finalize_response()
@@ -584,28 +682,45 @@ class RealtimeSession(llm.RealtimeSession):
584
682
  self, tool_call_cancellation: LiveServerToolCallCancellation
585
683
  ):
586
684
  logger.warning(
587
- "function call cancelled",
588
- extra={
589
- "function_call_ids": tool_call_cancellation.ids,
590
- },
685
+ "server cancelled tool calls",
686
+ extra={"function_call_ids": tool_call_cancellation.ids},
591
687
  )
592
- self.emit("function_calls_cancelled", tool_call_cancellation.ids)
593
688
 
594
689
  def _handle_usage_metadata(self, usage_metadata: UsageMetadata):
595
- # todo: handle metrics
596
- logger.info("Usage metadata", extra={"usage_metadata": usage_metadata})
690
+ # TODO: handle metrics
691
+ logger.debug("usage metadata", extra={"usage_metadata": usage_metadata})
597
692
 
598
693
  def _handle_go_away(self, go_away: LiveServerGoAway):
599
- # should we reconnect?
600
694
  logger.warning(
601
- f"gemini live api server will soon disconnect. time left: {go_away.time_left}"
695
+ f"Gemini server indicates disconnection soon. Time left: {go_away.time_left}"
602
696
  )
697
+ # TODO(dz): this isn't a seamless reconnection just yet
698
+ self._session_should_close.set()
603
699
 
604
700
  def commit_audio(self) -> None:
605
- raise NotImplementedError("commit_audio_buffer is not supported yet")
701
+ pass
606
702
 
607
703
  def clear_audio(self) -> None:
608
- raise NotImplementedError("clear_audio is not supported yet")
704
+ self._bstream.clear()
609
705
 
610
- def server_vad_enabled(self) -> bool:
611
- return True
706
+ def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
707
+ if self._input_resampler:
708
+ if frame.sample_rate != self._input_resampler._input_rate:
709
+ # input audio changed to a different sample rate
710
+ self._input_resampler = None
711
+
712
+ if self._input_resampler is None and (
713
+ frame.sample_rate != INPUT_AUDIO_SAMPLE_RATE
714
+ or frame.num_channels != INPUT_AUDIO_CHANNELS
715
+ ):
716
+ self._input_resampler = rtc.AudioResampler(
717
+ input_rate=frame.sample_rate,
718
+ output_rate=INPUT_AUDIO_SAMPLE_RATE,
719
+ num_channels=INPUT_AUDIO_CHANNELS,
720
+ )
721
+
722
+ if self._input_resampler:
723
+ # TODO(long): flush the resampler when the input source is changed
724
+ yield from self._input_resampler.push(frame)
725
+ else:
726
+ yield frame
@@ -28,7 +28,7 @@ def get_tool_results_for_realtime(chat_ctx: llm.ChatContext) -> types.LiveClient
28
28
  types.FunctionResponse(
29
29
  id=msg.call_id,
30
30
  name=msg.name,
31
- response={"text": msg.output},
31
+ response={"output": msg.output},
32
32
  )
33
33
  )
34
34
  return (
@@ -99,9 +99,11 @@ def to_chat_ctx(
99
99
  if current_role is not None and parts:
100
100
  turns.append(types.Content(role=current_role, parts=parts))
101
101
 
102
- if not turns:
103
- # if no turns, add a user message with a placeholder
104
- turns = [types.Content(role="user", parts=[types.Part(text=".")])]
102
+ # # Gemini requires the last message to end with user's turn before they can generate
103
+ # # currently not used because to_chat_ctx should not be used to force a new generation
104
+ # if current_role != "user":
105
+ # turns.append(types.Content(role="user", parts=[types.Part(text=".")]))
106
+
105
107
  return turns, system_instruction
106
108
 
107
109
 
@@ -12,4 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- __version__ = "1.0.15"
15
+ __version__ = "1.0.17"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: livekit-plugins-google
3
- Version: 1.0.15
3
+ Version: 1.0.17
4
4
  Summary: Agent Framework plugin for services from Google Cloud
5
5
  Project-URL: Documentation, https://docs.livekit.io
6
6
  Project-URL: Website, https://livekit.io/
@@ -21,8 +21,8 @@ Requires-Python: >=3.9.0
21
21
  Requires-Dist: google-auth<3,>=2
22
22
  Requires-Dist: google-cloud-speech<3,>=2
23
23
  Requires-Dist: google-cloud-texttospeech<3,>=2
24
- Requires-Dist: google-genai>=1.10.0
25
- Requires-Dist: livekit-agents>=1.0.15
24
+ Requires-Dist: google-genai>=1.11.0
25
+ Requires-Dist: livekit-agents>=1.0.17
26
26
  Description-Content-Type: text/markdown
27
27
 
28
28
  # LiveKit Plugins Google
@@ -5,12 +5,12 @@ livekit/plugins/google/models.py,sha256=SGjAumdDK97NNLwMFcqZdKR68f1NoGB2Rk1UP2-i
5
5
  livekit/plugins/google/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  livekit/plugins/google/stt.py,sha256=AG_lh2fuuduJi0jFbA_QKFXLJ6NUdF1W_FfkLUJML_Q,22413
7
7
  livekit/plugins/google/tts.py,sha256=xhINokqY8UutXn85N-cbzq68eptbM6TTtIXmLktE_RM,9004
8
- livekit/plugins/google/utils.py,sha256=pbLSOAdQxInWhgI2Yhsrr9KvgvpFXYDdU2yx2p03pFg,9437
9
- livekit/plugins/google/version.py,sha256=wHPUkZRYx-OB6iDuwTmMNVVQXU9eg5xFSjgmKBqqwd4,601
8
+ livekit/plugins/google/utils.py,sha256=TjjTwMbdJdxr3bZjUXxs-J_fipTTM00goW2-d9KWX6w,9582
9
+ livekit/plugins/google/version.py,sha256=GOfJB-DKZur-i3hrjFbzgpC2NHE96dnWhGLziW1e0_E,601
10
10
  livekit/plugins/google/beta/__init__.py,sha256=AxRYc7NGG62Tv1MmcZVCDHNvlhbC86hM-_yP01Qb28k,47
11
11
  livekit/plugins/google/beta/realtime/__init__.py,sha256=_fW2NMN22F-hnQ4xAJ_g5lPbR7CvM_xXzSWlUQY-E-U,188
12
12
  livekit/plugins/google/beta/realtime/api_proto.py,sha256=Fyrejs3SG0EjOPCCFLEnWXKEUxCff47PMWk2VsKJm5E,594
13
- livekit/plugins/google/beta/realtime/realtime_api.py,sha256=HvPYyQXC9OodWaDNxbRt1UAJ8IVdXZGK-PsIEr7UwbY,25078
14
- livekit_plugins_google-1.0.15.dist-info/METADATA,sha256=wMOLBkgHx_fJ0o5s8URB7Ev6yEg2jhKHhb0OlH1_7p4,3492
15
- livekit_plugins_google-1.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- livekit_plugins_google-1.0.15.dist-info/RECORD,,
13
+ livekit/plugins/google/beta/realtime/realtime_api.py,sha256=2_nPBvPttVudoQswhf19ieJ6wxvHquGJgALJ09afQms,29873
14
+ livekit_plugins_google-1.0.17.dist-info/METADATA,sha256=cKeNSFwiM2A-MJeNA6zNeX7ioqbvkEZO3aFfR8Run2c,3492
15
+ livekit_plugins_google-1.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ livekit_plugins_google-1.0.17.dist-info/RECORD,,