livekit-plugins-google 1.0.22__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- livekit/plugins/google/__init__.py +3 -2
- livekit/plugins/google/beta/realtime/api_proto.py +8 -2
- livekit/plugins/google/beta/realtime/realtime_api.py +316 -117
- livekit/plugins/google/llm.py +62 -32
- livekit/plugins/google/models.py +1 -0
- livekit/plugins/google/stt.py +19 -12
- livekit/plugins/google/tools.py +11 -0
- livekit/plugins/google/tts.py +109 -136
- livekit/plugins/google/utils.py +39 -88
- livekit/plugins/google/version.py +1 -1
- {livekit_plugins_google-1.0.22.dist-info → livekit_plugins_google-1.1.0.dist-info}/METADATA +3 -3
- livekit_plugins_google-1.1.0.dist-info/RECORD +17 -0
- livekit_plugins_google-1.0.22.dist-info/RECORD +0 -16
- {livekit_plugins_google-1.0.22.dist-info → livekit_plugins_google-1.1.0.dist-info}/WHEEL +0 -0
livekit/plugins/google/beta/realtime/realtime_api.py:

@@ -10,42 +10,23 @@ from collections.abc import Iterator
 from dataclasses import dataclass, field
 
 from google import genai
+from google.genai import types
 from google.genai.live import AsyncSession
-from google.genai.types import (
-    AudioTranscriptionConfig,
-    AutomaticActivityDetection,
-    Blob,
-    Content,
-    FunctionDeclaration,
-    GenerationConfig,
-    LiveClientContent,
-    LiveClientRealtimeInput,
-    LiveClientToolResponse,
-    LiveConnectConfig,
-    LiveServerContent,
-    LiveServerGoAway,
-    LiveServerToolCall,
-    LiveServerToolCallCancellation,
-    Modality,
-    ModalityTokenCount,
-    Part,
-    PrebuiltVoiceConfig,
-    RealtimeInputConfig,
-    SessionResumptionConfig,
-    SpeechConfig,
-    Tool,
-    UsageMetadata,
-    VoiceConfig,
-)
 from livekit import rtc
-from livekit.agents import llm, utils
+from livekit.agents import APIConnectionError, llm, utils
 from livekit.agents.metrics import RealtimeModelMetrics
-from livekit.agents.types import NOT_GIVEN, NotGivenOr
+from livekit.agents.types import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    NOT_GIVEN,
+    APIConnectOptions,
+    NotGivenOr,
+)
 from livekit.agents.utils import audio as audio_utils, images, is_given
 from livekit.plugins.google.beta.realtime.api_proto import ClientEvents, LiveAPIModels, Voice
 
 from ...log import logger
-from ...utils import get_tool_results_for_realtime, to_chat_ctx, to_fnc_ctx
+from ...tools import _LLMTool
+from ...utils import create_tools_config, get_tool_results_for_realtime, to_fnc_ctx
 
 INPUT_AUDIO_SAMPLE_RATE = 16000
 INPUT_AUDIO_CHANNELS = 1
@@ -71,7 +52,7 @@ class _RealtimeOptions:
     api_key: str | None
     voice: Voice | str
     language: NotGivenOr[str]
-    response_modalities: NotGivenOr[list[Modality]]
+    response_modalities: NotGivenOr[list[types.Modality]]
     vertexai: bool
     project: str | None
     location: str | None
@@ -83,9 +64,16 @@ class _RealtimeOptions:
     presence_penalty: NotGivenOr[float]
     frequency_penalty: NotGivenOr[float]
     instructions: NotGivenOr[str]
-    input_audio_transcription: AudioTranscriptionConfig | None
-    output_audio_transcription: AudioTranscriptionConfig | None
+    input_audio_transcription: types.AudioTranscriptionConfig | None
+    output_audio_transcription: types.AudioTranscriptionConfig | None
     image_encode_options: NotGivenOr[images.EncodeOptions]
+    conn_options: APIConnectOptions
+    enable_affective_dialog: NotGivenOr[bool] = NOT_GIVEN
+    proactivity: NotGivenOr[bool] = NOT_GIVEN
+    realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN
+    context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN
+    api_version: NotGivenOr[str] = NOT_GIVEN
+    gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN
 
 
 @dataclass
@@ -93,10 +81,13 @@ class _ResponseGeneration:
     message_ch: utils.aio.Chan[llm.MessageGeneration]
     function_ch: utils.aio.Chan[llm.FunctionCall]
 
+    input_id: str
     response_id: str
     text_ch: utils.aio.Chan[str]
     audio_ch: utils.aio.Chan[rtc.AudioFrame]
+
     input_transcription: str = ""
+    output_text: str = ""
 
     _created_timestamp: float = field(default_factory=time.time)
     """The timestamp when the generation is created"""
@@ -107,6 +98,14 @@ class _ResponseGeneration:
     _done: bool = False
     """Whether the generation is done (set when the turn is complete)"""
 
+    def push_text(self, text: str) -> None:
+        if self.output_text:
+            self.output_text += text
+        else:
+            self.output_text = text
+
+        self.text_ch.send_nowait(text)
+
 
 class RealtimeModel(llm.RealtimeModel):
     def __init__(
@@ -117,7 +116,7 @@ class RealtimeModel(llm.RealtimeModel):
         api_key: NotGivenOr[str] = NOT_GIVEN,
         voice: Voice | str = "Puck",
         language: NotGivenOr[str] = NOT_GIVEN,
-        modalities: NotGivenOr[list[Modality]] = NOT_GIVEN,
+        modalities: NotGivenOr[list[types.Modality]] = NOT_GIVEN,
         vertexai: NotGivenOr[bool] = NOT_GIVEN,
         project: NotGivenOr[str] = NOT_GIVEN,
         location: NotGivenOr[str] = NOT_GIVEN,
@@ -128,9 +127,16 @@ class RealtimeModel(llm.RealtimeModel):
         top_k: NotGivenOr[int] = NOT_GIVEN,
         presence_penalty: NotGivenOr[float] = NOT_GIVEN,
         frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
-        input_audio_transcription: NotGivenOr[AudioTranscriptionConfig | None] = NOT_GIVEN,
-        output_audio_transcription: NotGivenOr[AudioTranscriptionConfig | None] = NOT_GIVEN,
+        input_audio_transcription: NotGivenOr[types.AudioTranscriptionConfig | None] = NOT_GIVEN,
+        output_audio_transcription: NotGivenOr[types.AudioTranscriptionConfig | None] = NOT_GIVEN,
         image_encode_options: NotGivenOr[images.EncodeOptions] = NOT_GIVEN,
+        enable_affective_dialog: NotGivenOr[bool] = NOT_GIVEN,
+        proactivity: NotGivenOr[bool] = NOT_GIVEN,
+        realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN,
+        context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN,
+        api_version: NotGivenOr[str] = NOT_GIVEN,
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+        _gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
     ) -> None:
         """
         Initializes a RealtimeModel instance for interacting with Google's Realtime API.
@@ -161,19 +167,33 @@ class RealtimeModel(llm.RealtimeModel):
             input_audio_transcription (AudioTranscriptionConfig | None, optional): The configuration for input audio transcription. Defaults to None.)
             output_audio_transcription (AudioTranscriptionConfig | None, optional): The configuration for output audio transcription. Defaults to AudioTranscriptionConfig().
             image_encode_options (images.EncodeOptions, optional): The configuration for image encoding. Defaults to DEFAULT_ENCODE_OPTIONS.
+            enable_affective_dialog (bool, optional): Whether to enable affective dialog. Defaults to False.
+            proactivity (bool, optional): Whether to enable proactive audio. Defaults to False.
+            realtime_input_config (RealtimeInputConfig, optional): The configuration for realtime input. Defaults to None.
+            context_window_compression (ContextWindowCompressionConfig, optional): The configuration for context window compression. Defaults to None.
+            conn_options (APIConnectOptions, optional): The configuration for the API connection. Defaults to DEFAULT_API_CONNECT_OPTIONS.
+            _gemini_tools (list[LLMTool], optional): Gemini-specific tools to use for the session. This parameter is experimental and may change.
 
         Raises:
             ValueError: If the API key is required but not found.
         """ # noqa: E501
         if not is_given(input_audio_transcription):
-            input_audio_transcription = AudioTranscriptionConfig()
+            input_audio_transcription = types.AudioTranscriptionConfig()
         if not is_given(output_audio_transcription):
-            output_audio_transcription = AudioTranscriptionConfig()
+            output_audio_transcription = types.AudioTranscriptionConfig()
+
+        server_turn_detection = True
+        if (
+            is_given(realtime_input_config)
+            and realtime_input_config.automatic_activity_detection
+            and realtime_input_config.automatic_activity_detection.disabled
+        ):
+            server_turn_detection = False
 
         super().__init__(
             capabilities=llm.RealtimeCapabilities(
                 message_truncation=False,
-                turn_detection=True,
+                turn_detection=server_turn_detection,
                 user_transcription=input_audio_transcription is not None,
                 auto_tool_reply_generation=True,
             )
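Taken together with the new `_RealtimeOptions` fields above, the constructor now exposes affective dialog, proactive audio, context-window compression, and an explicit API version. A minimal usage sketch, assuming the plugin is installed and `GOOGLE_API_KEY` is set (the voice and token budgets are illustrative values, not defaults):

```python
from google.genai import types
from livekit.plugins import google

model = google.beta.realtime.RealtimeModel(
    voice="Puck",
    proactivity=True,                 # proactive audio; selects v1alpha (see session setup below)
    enable_affective_dialog=True,
    context_window_compression=types.ContextWindowCompressionConfig(
        trigger_tokens=25_600,        # illustrative budget
        sliding_window=types.SlidingWindow(target_tokens=12_800),
    ),
)
```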
@@ -187,7 +207,7 @@ class RealtimeModel(llm.RealtimeModel):
 
         gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
         gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
-        gcp_location = (
+        gcp_location: str | None = (
             location
             if is_given(location)
             else os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1"
@@ -232,6 +252,13 @@ class RealtimeModel(llm.RealtimeModel):
             output_audio_transcription=output_audio_transcription,
             language=language,
             image_encode_options=image_encode_options,
+            enable_affective_dialog=enable_affective_dialog,
+            proactivity=proactivity,
+            realtime_input_config=realtime_input_config,
+            context_window_compression=context_window_compression,
+            api_version=api_version,
+            gemini_tools=_gemini_tools,
+            conn_options=conn_options,
         )
 
         self._sessions = weakref.WeakSet[RealtimeSession]()
@@ -242,8 +269,19 @@ class RealtimeModel(llm.RealtimeModel):
         return sess
 
     def update_options(
-        self,
+        self,
+        *,
+        voice: NotGivenOr[str] = NOT_GIVEN,
+        temperature: NotGivenOr[float] = NOT_GIVEN,
     ) -> None:
+        """
+        Update the options for the RealtimeModel.
+
+        Args:
+            voice (str, optional): The voice to use for the session.
+            temperature (float, optional): The temperature to use for the session.
+            tools (list[LLMTool], optional): The tools to use for the session.
+        """
         if is_given(voice):
             self._opts.voice = voice
 
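`update_options` is now keyword-only and forwards the new values to every open session, as the next hunk shows. A usage sketch (the voice name is illustrative):

```python
# applies to the model and fans out to all active realtime sessions
model.update_options(voice="Kore", temperature=0.8)
```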
@@ -251,7 +289,10 @@ class RealtimeModel(llm.RealtimeModel):
             self._opts.temperature = temperature
 
         for sess in self._sessions:
-            sess.update_options(voice=self._opts.voice, temperature=self._opts.temperature)
+            sess.update_options(
+                voice=self._opts.voice,
+                temperature=self._opts.temperature,
+            )
 
     async def aclose(self) -> None:
         pass
@@ -262,7 +303,7 @@ class RealtimeSession(llm.RealtimeSession):
         super().__init__(realtime_model)
         self._opts = realtime_model._opts
         self._tools = llm.ToolContext.empty()
-        self._gemini_declarations: list[FunctionDeclaration] = []
+        self._gemini_declarations: list[types.FunctionDeclaration] = []
         self._chat_ctx = llm.ChatContext.empty()
         self._msg_ch = utils.aio.Chan[ClientEvents]()
         self._input_resampler: rtc.AudioResampler | None = None
@@ -274,11 +315,20 @@ class RealtimeSession(llm.RealtimeSession):
             samples_per_channel=INPUT_AUDIO_SAMPLE_RATE // 20,
         )
 
+        api_version = self._opts.api_version
+        if not api_version and (self._opts.enable_affective_dialog or self._opts.proactivity):
+            api_version = "v1alpha"
+
+        http_options = types.HttpOptions(timeout=int(self._opts.conn_options.timeout * 1000))
+        if api_version:
+            http_options.api_version = api_version
+
         self._client = genai.Client(
             api_key=self._opts.api_key,
             vertexai=self._opts.vertexai,
             project=self._opts.project,
             location=self._opts.location,
+            http_options=http_options,
         )
 
         self._main_atask = asyncio.create_task(self._main_task(), name="gemini-realtime-session")
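Session construction now builds the `google-genai` HTTP options from the connection options and quietly falls back to the `v1alpha` API version when affective dialog or proactivity is requested without an explicit `api_version`. The equivalent stand-alone client setup, as a sketch (`HttpOptions.timeout` is in milliseconds, hence the `* 1000` above; the 10-second value is illustrative):

```python
from google import genai
from google.genai import types

# affective dialog / proactive audio are only served on v1alpha
http_options = types.HttpOptions(timeout=10 * 1000, api_version="v1alpha")
client = genai.Client(api_key="YOUR_KEY", http_options=http_options)
```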
@@ -291,8 +341,9 @@ class RealtimeSession(llm.RealtimeSession):
         self._pending_generation_fut: asyncio.Future[llm.GenerationCreatedEvent] | None = None
 
         self._session_resumption_handle: str | None = None
-
+        self._in_user_activity = False
         self._session_lock = asyncio.Lock()
+        self._num_retries = 0
 
     async def _close_active_session(self) -> None:
         async with self._session_lock:
@@ -304,7 +355,7 @@ class RealtimeSession(llm.RealtimeSession):
             finally:
                 self._active_session = None
 
-    def _mark_restart_needed(self):
+    def _mark_restart_needed(self) -> None:
         if not self._session_should_close.is_set():
             self._session_should_close.set()
             # reset the msg_ch, do not send messages from previous session
|
|
335
386
|
self._mark_restart_needed()
|
336
387
|
|
337
388
|
async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
|
389
|
+
async with self._session_lock:
|
390
|
+
if not self._active_session:
|
391
|
+
self._chat_ctx = chat_ctx.copy()
|
392
|
+
return
|
393
|
+
|
338
394
|
diff_ops = llm.utils.compute_chat_ctx_diff(self._chat_ctx, chat_ctx)
|
339
395
|
|
340
396
|
if diff_ops.to_remove:
|
@@ -347,15 +403,23 @@ class RealtimeSession(llm.RealtimeSession):
                 append_ctx.items.append(item)
 
         if append_ctx.items:
-            turns, _ = to_chat_ctx(append_ctx, id(self), ignore_functions=True)
+            turns_dict, _ = append_ctx.copy(
+                exclude_function_call=True,
+            ).to_provider_format(format="google", inject_dummy_user_message=False)
+            # we are not generating, and do not need to inject
+            turns = [types.Content.model_validate(turn) for turn in turns_dict]
             tool_results = get_tool_results_for_realtime(append_ctx, vertexai=self._opts.vertexai)
             if turns:
-                self._send_client_event(LiveClientContent(turns=turns, turn_complete=False))
+                self._send_client_event(types.LiveClientContent(turns=turns, turn_complete=False))
             if tool_results:
                 self._send_client_event(tool_results)
 
-    async def update_tools(self, tools: list[llm.FunctionTool | llm.RawFunctionTool]) -> None:
-        new_declarations: list[FunctionDeclaration] = to_fnc_ctx(tools)
+        # since we don't have a view of the history on the server side, we'll assume
+        # the current state is accurate. this isn't perfect because removals aren't done.
+        self._chat_ctx = chat_ctx.copy()
+
+    async def update_tools(self, tools: list[llm.FunctionTool | llm.RawFunctionTool]) -> None:
+        new_declarations: list[types.FunctionDeclaration] = to_fnc_ctx(tools)
         current_tool_names = {f.name for f in self._gemini_declarations}
         new_tool_names = {f.name for f in new_declarations}
 
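History is now serialized through `ChatContext.to_provider_format` rather than the plugin's removed `to_chat_ctx` helper (see the `utils.py` summary above, -88 lines). The conversion in isolation, using only calls that appear in this diff:

```python
from google.genai import types
from livekit.agents import llm

chat_ctx = llm.ChatContext.empty()
chat_ctx.add_message(role="user", content="Hello!")

# strip function calls; no dummy user message is injected since this is
# a history sync, not a generation request
turns_dict, _ = chat_ctx.copy(exclude_function_call=True).to_provider_format(
    format="google", inject_dummy_user_message=False
)
turns = [types.Content.model_validate(turn) for turn in turns_dict]
```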
@@ -372,11 +436,21 @@ class RealtimeSession(llm.RealtimeSession):
     def tools(self) -> llm.ToolContext:
         return self._tools.copy()
 
+    @property
+    def _manual_activity_detection(self) -> bool:
+        if (
+            is_given(self._opts.realtime_input_config)
+            and self._opts.realtime_input_config.automatic_activity_detection is not None
+            and self._opts.realtime_input_config.automatic_activity_detection.disabled
+        ):
+            return True
+        return False
+
     def push_audio(self, frame: rtc.AudioFrame) -> None:
         for f in self._resample_audio(frame):
             for nf in self._bstream.write(f.data.tobytes()):
-                realtime_input = LiveClientRealtimeInput(
-                    media_chunks=[Blob(data=nf.data.tobytes(), mime_type="audio/pcm")]
+                realtime_input = types.LiveClientRealtimeInput(
+                    media_chunks=[types.Blob(data=nf.data.tobytes(), mime_type="audio/pcm")]
                 )
                 self._send_client_event(realtime_input)
 
@@ -384,8 +458,8 @@ class RealtimeSession(llm.RealtimeSession):
         encoded_data = images.encode(
             frame, self._opts.image_encode_options or DEFAULT_IMAGE_ENCODE_OPTIONS
         )
-        realtime_input = LiveClientRealtimeInput(
-            media_chunks=[Blob(data=encoded_data, mime_type="image/jpeg")]
+        realtime_input = types.LiveClientRealtimeInput(
+            media_chunks=[types.Blob(data=encoded_data, mime_type="image/jpeg")]
         )
         self._send_client_event(realtime_input)
 
@@ -402,16 +476,24 @@ class RealtimeSession(llm.RealtimeSession):
             )
             self._pending_generation_fut.cancel("Superseded by new generate_reply call")
 
-        fut = asyncio.Future()
+        fut = asyncio.Future[llm.GenerationCreatedEvent]()
         self._pending_generation_fut = fut
 
+        if self._in_user_activity:
+            self._send_client_event(
+                types.LiveClientRealtimeInput(
+                    activity_end=types.ActivityEnd(),
+                )
+            )
+            self._in_user_activity = False
+
         # Gemini requires the last message to end with user's turn
         # so we need to add a placeholder user turn in order to trigger a new generation
-
+        turns = []
         if is_given(instructions):
-
-
-            self._send_client_event(
+            turns.append(types.Content(parts=[types.Part(text=instructions)], role="model"))
+        turns.append(types.Content(parts=[types.Part(text=".")], role="user"))
+        self._send_client_event(types.LiveClientContent(turns=turns, turn_complete=True))
 
         def _on_timeout() -> None:
             if not fut.done():
@@ -428,8 +510,28 @@ class RealtimeSession(llm.RealtimeSession):
 
         return fut
 
+    def start_user_activity(self) -> None:
+        if not self._manual_activity_detection:
+            return
+
+        if not self._in_user_activity:
+            self._in_user_activity = True
+            self._send_client_event(
+                types.LiveClientRealtimeInput(
+                    activity_start=types.ActivityStart(),
+                )
+            )
+
     def interrupt(self) -> None:
-        pass
+        # Gemini Live treats activity start as interruption, so we rely on start_user_activity
+        # notifications to handle it
+        if (
+            self._opts.realtime_input_config
+            and self._opts.realtime_input_config.activity_handling
+            == types.ActivityHandling.NO_INTERRUPTION
+        ):
+            return
+        self.start_user_activity()
 
     def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
         logger.warning("truncate is not supported by the Google Realtime API.")
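`start_user_activity` only emits `ActivityStart` when automatic activity detection is disabled, i.e. when the application brackets speech itself. A sketch of opting into manual turn handling (the config types come from `google.genai.types` and are both referenced in this diff):

```python
from google.genai import types
from livekit.plugins import google

model = google.beta.realtime.RealtimeModel(
    realtime_input_config=types.RealtimeInputConfig(
        # disable server-side VAD; the client must now send
        # ActivityStart/ActivityEnd (start_user_activity above)
        automatic_activity_detection=types.AutomaticActivityDetection(disabled=True),
    ),
)
```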
@@ -456,7 +558,9 @@ class RealtimeSession(llm.RealtimeSession):
         self._mark_current_generation_done()
 
     @utils.log_exceptions(logger=logger)
-    async def _main_task(self):
+    async def _main_task(self) -> None:
+        max_retries = self._opts.conn_options.max_retry
+
         while not self._msg_ch.closed:
             # previous session might not be closed yet, we'll do it here.
             await self._close_active_session()
@@ -471,7 +575,15 @@ class RealtimeSession(llm.RealtimeSession):
                 ) as session:
                     async with self._session_lock:
                         self._active_session = session
-
+                    turns_dict, _ = self._chat_ctx.copy(
+                        exclude_function_call=True,
+                    ).to_provider_format(format="google", inject_dummy_user_message=False)
+                    if turns_dict:
+                        turns = [types.Content.model_validate(turn) for turn in turns_dict]
+                        await session.send_client_content(
+                            turns=turns,  # type: ignore
+                            turn_complete=False,
+                        )
                     # queue up existing chat context
                     send_task = asyncio.create_task(
                         self._send_task(session), name="gemini-realtime-send"
@@ -504,12 +616,30 @@ class RealtimeSession(llm.RealtimeSession):
             except Exception as e:
                 logger.error(f"Gemini Realtime API error: {e}", exc_info=e)
                 if not self._msg_ch.closed:
-
-
+                    # we shouldn't retry when it's not connected, usually this means incorrect
+                    # parameters or setup
+                    if not session or max_retries == 0:
+                        self._emit_error(e, recoverable=False)
+                        raise APIConnectionError(message="Failed to connect to Gemini Live") from e
+
+                    if self._num_retries == max_retries:
+                        self._emit_error(e, recoverable=False)
+                        raise APIConnectionError(
+                            message=f"Failed to connect to Gemini Live after {max_retries} attempts"
+                        ) from e
+
+                    retry_interval = self._opts.conn_options._interval_for_retry(self._num_retries)
+                    logger.warning(
+                        f"Gemini Realtime API connection failed, retrying in {retry_interval}s",
+                        exc_info=e,
+                        extra={"attempt": self._num_retries, "max_retries": max_retries},
+                    )
+                    await asyncio.sleep(retry_interval)
+                    self._num_retries += 1
             finally:
                 await self._close_active_session()
 
-    async def _send_task(self, session: AsyncSession):
+    async def _send_task(self, session: AsyncSession) -> None:
         try:
             async for msg in self._msg_ch:
                 async with self._session_lock:
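The reconnect loop above is a capped retry whose per-attempt delay comes from `APIConnectOptions._interval_for_retry`. A self-contained sketch of the same pattern; the jittered exponential formula here is a stand-in, not livekit's actual schedule:

```python
import asyncio
import random


async def connect_with_retries(connect, max_retries: int, base: float = 2.0):
    """Retry an async connect() with growing, jittered delays, re-raising
    once the retry budget is exhausted (mirrors the loop above)."""
    attempt = 0
    while True:
        try:
            return await connect()
        except Exception:
            if attempt >= max_retries:
                raise
            delay = base * (2**attempt) * (0.5 + random.random() / 2)
            await asyncio.sleep(delay)
            attempt += 1
```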
@@ -517,15 +647,21 @@ class RealtimeSession(llm.RealtimeSession):
                         not self._active_session or self._active_session != session
                     ):
                         break
-                if isinstance(msg, LiveClientContent):
+                if isinstance(msg, types.LiveClientContent):
                     await session.send_client_content(
-                        turns=msg.turns,
+                        turns=msg.turns,  # type: ignore
+                        turn_complete=msg.turn_complete or True,
                     )
-                elif isinstance(msg, LiveClientToolResponse):
+                elif isinstance(msg, types.LiveClientToolResponse) and msg.function_responses:
                     await session.send_tool_response(function_responses=msg.function_responses)
-                elif isinstance(msg, LiveClientRealtimeInput):
-
-
+                elif isinstance(msg, types.LiveClientRealtimeInput):
+                    if msg.media_chunks:
+                        for media_chunk in msg.media_chunks:
+                            await session.send_realtime_input(media=media_chunk)
+                    elif msg.activity_start:
+                        await session.send_realtime_input(activity_start=msg.activity_start)
+                    elif msg.activity_end:
+                        await session.send_realtime_input(activity_end=msg.activity_end)
                 else:
                     logger.warning(f"Warning: Received unhandled message type: {type(msg)}")
 
@@ -536,7 +672,7 @@ class RealtimeSession(llm.RealtimeSession):
         finally:
             logger.debug("send task finished.")
 
-    async def _recv_task(self, session: AsyncSession):
+    async def _recv_task(self, session: AsyncSession) -> None:
         try:
             while True:
                 async with self._session_lock:
@@ -572,6 +708,9 @@ class RealtimeSession(llm.RealtimeSession):
                 if response.go_away:
                     self._handle_go_away(response.go_away)
 
+                if self._num_retries > 0:
+                    self._num_retries = 0  # reset the retry counter
+
                 # TODO(dz): a server-side turn is complete
         except Exception as e:
             if not self._session_should_close.is_set():
@@ -580,14 +719,18 @@ class RealtimeSession(llm.RealtimeSession):
         finally:
             self._mark_current_generation_done()
 
-    def _build_connect_config(self) -> LiveConnectConfig:
+    def _build_connect_config(self) -> types.LiveConnectConfig:
         temp = self._opts.temperature if is_given(self._opts.temperature) else None
 
-        return LiveConnectConfig(
+        tools_config = create_tools_config(
+            function_tools=self._gemini_declarations,
+            gemini_tools=self._opts.gemini_tools if is_given(self._opts.gemini_tools) else None,
+        )
+        conf = types.LiveConnectConfig(
             response_modalities=self._opts.response_modalities
             if is_given(self._opts.response_modalities)
-            else [Modality.AUDIO],
-            generation_config=GenerationConfig(
+            else [types.Modality.AUDIO],
+            generation_config=types.GenerationConfig(
                 candidate_count=self._opts.candidate_count,
                 temperature=temp,
                 max_output_tokens=self._opts.max_output_tokens
@@ -602,34 +745,45 @@ class RealtimeSession(llm.RealtimeSession):
                 if is_given(self._opts.frequency_penalty)
                 else None,
             ),
-            system_instruction=Content(parts=[Part(text=self._opts.instructions)])
+            system_instruction=types.Content(parts=[types.Part(text=self._opts.instructions)])
             if is_given(self._opts.instructions)
             else None,
-            speech_config=SpeechConfig(
-                voice_config=VoiceConfig(
-                    prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
+            speech_config=types.SpeechConfig(
+                voice_config=types.VoiceConfig(
+                    prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=self._opts.voice)
                 ),
                 language_code=self._opts.language if is_given(self._opts.language) else None,
             ),
-            tools=[Tool(function_declarations=self._gemini_declarations)],
+            tools=tools_config,  # type: ignore
             input_audio_transcription=self._opts.input_audio_transcription,
             output_audio_transcription=self._opts.output_audio_transcription,
-            session_resumption=SessionResumptionConfig(handle=self._session_resumption_handle),
-            realtime_input_config=RealtimeInputConfig(
-                automatic_activity_detection=AutomaticActivityDetection(),
+            session_resumption=types.SessionResumptionConfig(
+                handle=self._session_resumption_handle
             ),
         )
 
-    def _start_new_generation(self):
+        if is_given(self._opts.proactivity):
+            conf.proactivity = types.ProactivityConfig(proactive_audio=self._opts.proactivity)
+        if is_given(self._opts.enable_affective_dialog):
+            conf.enable_affective_dialog = self._opts.enable_affective_dialog
+        if is_given(self._opts.realtime_input_config):
+            conf.realtime_input_config = self._opts.realtime_input_config
+        if is_given(self._opts.context_window_compression):
+            conf.context_window_compression = self._opts.context_window_compression
+
+        return conf
+
+    def _start_new_generation(self) -> None:
         if self._current_generation and not self._current_generation._done:
             logger.warning("starting new generation while another is active. Finalizing previous.")
             self._mark_current_generation_done()
 
-        response_id = utils.shortuuid("
+        response_id = utils.shortuuid("GR_")
         self._current_generation = _ResponseGeneration(
             message_ch=utils.aio.Chan[llm.MessageGeneration](),
             function_ch=utils.aio.Chan[llm.FunctionCall](),
             response_id=response_id,
+            input_id=utils.shortuuid("GI_"),
             text_ch=utils.aio.Chan[str](),
             audio_ch=utils.aio.Chan[rtc.AudioFrame](),
             _created_timestamp=time.time(),
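`_build_connect_config` now delegates tool assembly to the new `create_tools_config` helper so function declarations can coexist with Gemini-native tools (the `_gemini_tools` plumbing added earlier). The helper's output shape is an assumption here, but a plausible `tools` value built from `google.genai` types would look like:

```python
from google.genai import types

tools = [
    # declarations converted from livekit function tools via to_fnc_ctx
    types.Tool(
        function_declarations=[
            types.FunctionDeclaration(
                name="get_weather",  # hypothetical tool
                description="Look up the current weather for a city.",
                parameters=types.Schema(type=types.Type.OBJECT),
            )
        ]
    ),
    # a Gemini-native tool passed through as-is
    types.Tool(google_search=types.GoogleSearch()),
]
```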
@@ -656,21 +810,23 @@ class RealtimeSession(llm.RealtimeSession):
 
         self.emit("generation_created", generation_event)
 
-    def _handle_server_content(self, server_content: LiveServerContent):
+    def _handle_server_content(self, server_content: types.LiveServerContent) -> None:
         current_gen = self._current_generation
         if not current_gen:
             logger.warning("received server content but no active generation.")
             return
 
         if model_turn := server_content.model_turn:
-            for part in model_turn.parts:
+            for part in model_turn.parts or []:
                 if part.text:
-                    current_gen.text_ch.send_nowait(part.text)
+                    current_gen.push_text(part.text)
                 if part.inline_data:
                     if not current_gen._first_token_timestamp:
                         current_gen._first_token_timestamp = time.time()
                     frame_data = part.inline_data.data
                     try:
+                        if not isinstance(frame_data, bytes):
+                            raise ValueError("frame_data is not bytes")
                         frame = rtc.AudioFrame(
                             data=frame_data,
                             sample_rate=OUTPUT_AUDIO_SAMPLE_RATE,
@@ -692,7 +848,7 @@ class RealtimeSession(llm.RealtimeSession):
                 self.emit(
                     "input_audio_transcription_completed",
                     llm.InputTranscriptionCompleted(
-                        item_id=current_gen.response_id,
+                        item_id=current_gen.input_id,
                         transcript=current_gen.input_transcription,
                         is_final=False,
                     ),
@@ -701,20 +857,9 @@ class RealtimeSession(llm.RealtimeSession):
         if output_transcription := server_content.output_transcription:
             text = output_transcription.text
             if text:
-                current_gen.text_ch.send_nowait(text)
+                current_gen.push_text(text)
 
-        if server_content.generation_complete:
-            # The only way we'd know that the transcription is complete is by when they are
-            # done with generation
-            if current_gen.input_transcription:
-                self.emit(
-                    "input_audio_transcription_completed",
-                    llm.InputTranscriptionCompleted(
-                        item_id=current_gen.response_id,
-                        transcript=current_gen.input_transcription,
-                        is_final=True,
-                    ),
-                )
+        if server_content.generation_complete or server_content.turn_complete:
             current_gen._completed_timestamp = time.time()
 
         if server_content.interrupted:
@@ -724,10 +869,38 @@ class RealtimeSession(llm.RealtimeSession):
             self._mark_current_generation_done()
 
     def _mark_current_generation_done(self) -> None:
-        if not self._current_generation:
+        if not self._current_generation or self._current_generation._done:
            return
 
         gen = self._current_generation
+
+        # The only way we'd know that the transcription is complete is by when they are
+        # done with generation
+        if gen.input_transcription:
+            self.emit(
+                "input_audio_transcription_completed",
+                llm.InputTranscriptionCompleted(
+                    item_id=gen.input_id,
+                    transcript=gen.input_transcription,
+                    is_final=True,
+                ),
+            )
+
+        # since gemini doesn't give us a view of the chat history on the server side,
+        # we would handle it manually here
+        self._chat_ctx.add_message(
+            role="user",
+            content=gen.input_transcription,
+            id=gen.input_id,
+        )
+
+        if gen.output_text:
+            self._chat_ctx.add_message(
+                role="assistant",
+                content=gen.output_text,
+                id=gen.response_id,
+            )
+
         if not gen.text_ch.closed:
             gen.text_ch.close()
         if not gen.audio_ch.closed:
@@ -737,36 +910,37 @@ class RealtimeSession(llm.RealtimeSession):
             gen.message_ch.close()
         gen._done = True
 
-    def _handle_input_speech_started(self):
+    def _handle_input_speech_started(self) -> None:
         self.emit("input_speech_started", llm.InputSpeechStartedEvent())
 
-    def _handle_tool_calls(self, tool_call: LiveServerToolCall):
+    def _handle_tool_calls(self, tool_call: types.LiveServerToolCall) -> None:
         if not self._current_generation:
             logger.warning("received tool call but no active generation.")
             return
 
         gen = self._current_generation
-        for fnc_call in tool_call.function_calls:
+        for fnc_call in tool_call.function_calls or []:
             arguments = json.dumps(fnc_call.args)
 
             gen.function_ch.send_nowait(
                 llm.FunctionCall(
                     call_id=fnc_call.id or utils.shortuuid("fnc-call-"),
-                    name=fnc_call.name,
+                    name=fnc_call.name,  # type: ignore
                     arguments=arguments,
                 )
             )
+        self._on_final_input_audio_transcription()
         self._mark_current_generation_done()
 
     def _handle_tool_call_cancellation(
-        self, tool_call_cancellation: LiveServerToolCallCancellation
-    ):
+        self, tool_call_cancellation: types.LiveServerToolCallCancellation
+    ) -> None:
         logger.warning(
             "server cancelled tool calls",
             extra={"function_call_ids": tool_call_cancellation.ids},
         )
 
-    def _handle_usage_metadata(self, usage_metadata: UsageMetadata):
+    def _handle_usage_metadata(self, usage_metadata: types.UsageMetadata) -> None:
         current_gen = self._current_generation
         if not current_gen:
             logger.warning("no active generation to report metrics for")
@@ -782,18 +956,21 @@ class RealtimeSession(llm.RealtimeSession):
         ) - current_gen._created_timestamp
 
         def _token_details_map(
-            token_details: list[ModalityTokenCount] | None,
-        ) -> dict[
+            token_details: list[types.ModalityTokenCount] | None,
+        ) -> dict[str, int]:
             token_details_map = {"audio_tokens": 0, "text_tokens": 0, "image_tokens": 0}
             if not token_details:
                 return token_details_map
 
             for token_detail in token_details:
-                if token_detail.
+                if not token_detail.token_count:
+                    continue
+
+                if token_detail.modality == types.MediaModality.AUDIO:
                     token_details_map["audio_tokens"] += token_detail.token_count
-                elif token_detail.modality ==
+                elif token_detail.modality == types.MediaModality.TEXT:
                     token_details_map["text_tokens"] += token_detail.token_count
-                elif token_detail.modality ==
+                elif token_detail.modality == types.MediaModality.IMAGE:
                     token_details_map["image_tokens"] += token_detail.token_count
             return token_details_map
 
@@ -807,7 +984,9 @@ class RealtimeSession(llm.RealtimeSession):
             input_tokens=usage_metadata.prompt_token_count or 0,
             output_tokens=usage_metadata.response_token_count or 0,
             total_tokens=usage_metadata.total_token_count or 0,
-            tokens_per_second=(usage_metadata.response_token_count or 0) / duration
+            tokens_per_second=(usage_metadata.response_token_count or 0) / duration
+            if duration > 0
+            else 0,
             input_token_details=RealtimeModelMetrics.InputTokenDetails(
                 **_token_details_map(usage_metadata.prompt_tokens_details),
                 cached_tokens=sum(
@@ -824,18 +1003,27 @@ class RealtimeSession(llm.RealtimeSession):
         )
         self.emit("metrics_collected", metrics)
 
-    def _handle_go_away(self, go_away: LiveServerGoAway):
+    def _handle_go_away(self, go_away: types.LiveServerGoAway) -> None:
         logger.warning(
             f"Gemini server indicates disconnection soon. Time left: {go_away.time_left}"
         )
         # TODO(dz): this isn't a seamless reconnection just yet
         self._session_should_close.set()
 
+    def _on_final_input_audio_transcription(self) -> None:
+        if (gen := self._current_generation) and gen.input_transcription:
+            self.emit(
+                "input_audio_transcription_completed",
+                llm.InputTranscriptionCompleted(
+                    item_id=gen.response_id, transcript=gen.input_transcription, is_final=True
+                ),
+            )
+
     def commit_audio(self) -> None:
         pass
 
     def clear_audio(self) -> None:
-
+        pass
 
     def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
         if self._input_resampler:
|
@@ -858,3 +1046,14 @@ class RealtimeSession(llm.RealtimeSession):
|
|
858
1046
|
yield from self._input_resampler.push(frame)
|
859
1047
|
else:
|
860
1048
|
yield frame
|
1049
|
+
|
1050
|
+
def _emit_error(self, error: Exception, recoverable: bool) -> None:
|
1051
|
+
self.emit(
|
1052
|
+
"error",
|
1053
|
+
llm.RealtimeModelError(
|
1054
|
+
timestamp=time.time(),
|
1055
|
+
label=self._realtime_model._label,
|
1056
|
+
error=error,
|
1057
|
+
recoverable=recoverable,
|
1058
|
+
),
|
1059
|
+
)
|
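With `_emit_error` in place, connection failures surface as `error` events carrying `llm.RealtimeModelError` before `APIConnectionError` is raised from the task. A sketch of observing them, assuming the usual livekit event-emitter interface:

```python
session = model.session()


def on_error(err) -> None:
    # err is the llm.RealtimeModelError built in _emit_error above
    if not err.recoverable:
        print(f"realtime session failed: {err.error}")


session.on("error", on_error)
```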