livekit-plugins-google 1.0.23__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit/plugins/google/__init__.py +3 -2
- livekit/plugins/google/beta/realtime/realtime_api.py +296 -118
- livekit/plugins/google/llm.py +60 -27
- livekit/plugins/google/stt.py +19 -12
- livekit/plugins/google/tools.py +11 -0
- livekit/plugins/google/tts.py +109 -136
- livekit/plugins/google/utils.py +39 -88
- livekit/plugins/google/version.py +1 -1
- {livekit_plugins_google-1.0.23.dist-info → livekit_plugins_google-1.1.0.dist-info}/METADATA +2 -2
- livekit_plugins_google-1.1.0.dist-info/RECORD +17 -0
- livekit_plugins_google-1.0.23.dist-info/RECORD +0 -16
- {livekit_plugins_google-1.0.23.dist-info → livekit_plugins_google-1.1.0.dist-info}/WHEEL +0 -0
@@ -10,41 +10,23 @@ from collections.abc import Iterator
|
|
10
10
|
from dataclasses import dataclass, field
|
11
11
|
|
12
12
|
from google import genai
|
13
|
+
from google.genai import types
|
13
14
|
from google.genai.live import AsyncSession
|
14
|
-
from google.genai.types import (
|
15
|
-
AudioTranscriptionConfig,
|
16
|
-
Blob,
|
17
|
-
Content,
|
18
|
-
FunctionDeclaration,
|
19
|
-
GenerationConfig,
|
20
|
-
LiveClientContent,
|
21
|
-
LiveClientRealtimeInput,
|
22
|
-
LiveClientToolResponse,
|
23
|
-
LiveConnectConfig,
|
24
|
-
LiveServerContent,
|
25
|
-
LiveServerGoAway,
|
26
|
-
LiveServerToolCall,
|
27
|
-
LiveServerToolCallCancellation,
|
28
|
-
Modality,
|
29
|
-
ModalityTokenCount,
|
30
|
-
Part,
|
31
|
-
PrebuiltVoiceConfig,
|
32
|
-
RealtimeInputConfig,
|
33
|
-
SessionResumptionConfig,
|
34
|
-
SpeechConfig,
|
35
|
-
Tool,
|
36
|
-
UsageMetadata,
|
37
|
-
VoiceConfig,
|
38
|
-
)
|
39
15
|
from livekit import rtc
|
40
|
-
from livekit.agents import llm, utils
|
16
|
+
from livekit.agents import APIConnectionError, llm, utils
|
41
17
|
from livekit.agents.metrics import RealtimeModelMetrics
|
42
|
-
from livekit.agents.types import
|
18
|
+
from livekit.agents.types import (
|
19
|
+
DEFAULT_API_CONNECT_OPTIONS,
|
20
|
+
NOT_GIVEN,
|
21
|
+
APIConnectOptions,
|
22
|
+
NotGivenOr,
|
23
|
+
)
|
43
24
|
from livekit.agents.utils import audio as audio_utils, images, is_given
|
44
25
|
from livekit.plugins.google.beta.realtime.api_proto import ClientEvents, LiveAPIModels, Voice
|
45
26
|
|
46
27
|
from ...log import logger
|
47
|
-
from ...
|
28
|
+
from ...tools import _LLMTool
|
29
|
+
from ...utils import create_tools_config, get_tool_results_for_realtime, to_fnc_ctx
|
48
30
|
|
49
31
|
INPUT_AUDIO_SAMPLE_RATE = 16000
|
50
32
|
INPUT_AUDIO_CHANNELS = 1
|
@@ -70,7 +52,7 @@ class _RealtimeOptions:
|
|
70
52
|
api_key: str | None
|
71
53
|
voice: Voice | str
|
72
54
|
language: NotGivenOr[str]
|
73
|
-
response_modalities: NotGivenOr[list[Modality]]
|
55
|
+
response_modalities: NotGivenOr[list[types.Modality]]
|
74
56
|
vertexai: bool
|
75
57
|
project: str | None
|
76
58
|
location: str | None
|
@@ -82,12 +64,16 @@ class _RealtimeOptions:
|
|
82
64
|
presence_penalty: NotGivenOr[float]
|
83
65
|
frequency_penalty: NotGivenOr[float]
|
84
66
|
instructions: NotGivenOr[str]
|
85
|
-
input_audio_transcription: AudioTranscriptionConfig | None
|
86
|
-
output_audio_transcription: AudioTranscriptionConfig | None
|
67
|
+
input_audio_transcription: types.AudioTranscriptionConfig | None
|
68
|
+
output_audio_transcription: types.AudioTranscriptionConfig | None
|
87
69
|
image_encode_options: NotGivenOr[images.EncodeOptions]
|
70
|
+
conn_options: APIConnectOptions
|
88
71
|
enable_affective_dialog: NotGivenOr[bool] = NOT_GIVEN
|
89
72
|
proactivity: NotGivenOr[bool] = NOT_GIVEN
|
90
|
-
realtime_input_config: NotGivenOr[RealtimeInputConfig] = NOT_GIVEN
|
73
|
+
realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN
|
74
|
+
context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN
|
75
|
+
api_version: NotGivenOr[str] = NOT_GIVEN
|
76
|
+
gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN
|
91
77
|
|
92
78
|
|
93
79
|
@dataclass
|
@@ -95,10 +81,13 @@ class _ResponseGeneration:
|
|
95
81
|
message_ch: utils.aio.Chan[llm.MessageGeneration]
|
96
82
|
function_ch: utils.aio.Chan[llm.FunctionCall]
|
97
83
|
|
84
|
+
input_id: str
|
98
85
|
response_id: str
|
99
86
|
text_ch: utils.aio.Chan[str]
|
100
87
|
audio_ch: utils.aio.Chan[rtc.AudioFrame]
|
88
|
+
|
101
89
|
input_transcription: str = ""
|
90
|
+
output_text: str = ""
|
102
91
|
|
103
92
|
_created_timestamp: float = field(default_factory=time.time)
|
104
93
|
"""The timestamp when the generation is created"""
|
@@ -109,6 +98,14 @@ class _ResponseGeneration:
|
|
109
98
|
_done: bool = False
|
110
99
|
"""Whether the generation is done (set when the turn is complete)"""
|
111
100
|
|
101
|
+
def push_text(self, text: str) -> None:
|
102
|
+
if self.output_text:
|
103
|
+
self.output_text += text
|
104
|
+
else:
|
105
|
+
self.output_text = text
|
106
|
+
|
107
|
+
self.text_ch.send_nowait(text)
|
108
|
+
|
112
109
|
|
113
110
|
class RealtimeModel(llm.RealtimeModel):
|
114
111
|
def __init__(
|
@@ -119,7 +116,7 @@ class RealtimeModel(llm.RealtimeModel):
|
|
119
116
|
api_key: NotGivenOr[str] = NOT_GIVEN,
|
120
117
|
voice: Voice | str = "Puck",
|
121
118
|
language: NotGivenOr[str] = NOT_GIVEN,
|
122
|
-
modalities: NotGivenOr[list[Modality]] = NOT_GIVEN,
|
119
|
+
modalities: NotGivenOr[list[types.Modality]] = NOT_GIVEN,
|
123
120
|
vertexai: NotGivenOr[bool] = NOT_GIVEN,
|
124
121
|
project: NotGivenOr[str] = NOT_GIVEN,
|
125
122
|
location: NotGivenOr[str] = NOT_GIVEN,
|
@@ -130,12 +127,16 @@ class RealtimeModel(llm.RealtimeModel):
|
|
130
127
|
top_k: NotGivenOr[int] = NOT_GIVEN,
|
131
128
|
presence_penalty: NotGivenOr[float] = NOT_GIVEN,
|
132
129
|
frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
|
133
|
-
input_audio_transcription: NotGivenOr[AudioTranscriptionConfig | None] = NOT_GIVEN,
|
134
|
-
output_audio_transcription: NotGivenOr[AudioTranscriptionConfig | None] = NOT_GIVEN,
|
130
|
+
input_audio_transcription: NotGivenOr[types.AudioTranscriptionConfig | None] = NOT_GIVEN,
|
131
|
+
output_audio_transcription: NotGivenOr[types.AudioTranscriptionConfig | None] = NOT_GIVEN,
|
135
132
|
image_encode_options: NotGivenOr[images.EncodeOptions] = NOT_GIVEN,
|
136
133
|
enable_affective_dialog: NotGivenOr[bool] = NOT_GIVEN,
|
137
134
|
proactivity: NotGivenOr[bool] = NOT_GIVEN,
|
138
|
-
realtime_input_config: NotGivenOr[RealtimeInputConfig] = NOT_GIVEN,
|
135
|
+
realtime_input_config: NotGivenOr[types.RealtimeInputConfig] = NOT_GIVEN,
|
136
|
+
context_window_compression: NotGivenOr[types.ContextWindowCompressionConfig] = NOT_GIVEN,
|
137
|
+
api_version: NotGivenOr[str] = NOT_GIVEN,
|
138
|
+
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
139
|
+
_gemini_tools: NotGivenOr[list[_LLMTool]] = NOT_GIVEN,
|
139
140
|
) -> None:
|
140
141
|
"""
|
141
142
|
Initializes a RealtimeModel instance for interacting with Google's Realtime API.
|
@@ -169,19 +170,30 @@ class RealtimeModel(llm.RealtimeModel):
|
|
169
170
|
enable_affective_dialog (bool, optional): Whether to enable affective dialog. Defaults to False.
|
170
171
|
proactivity (bool, optional): Whether to enable proactive audio. Defaults to False.
|
171
172
|
realtime_input_config (RealtimeInputConfig, optional): The configuration for realtime input. Defaults to None.
|
173
|
+
context_window_compression (ContextWindowCompressionConfig, optional): The configuration for context window compression. Defaults to None.
|
174
|
+
conn_options (APIConnectOptions, optional): The configuration for the API connection. Defaults to DEFAULT_API_CONNECT_OPTIONS.
|
175
|
+
_gemini_tools (list[LLMTool], optional): Gemini-specific tools to use for the session. This parameter is experimental and may change.
|
172
176
|
|
173
177
|
Raises:
|
174
178
|
ValueError: If the API key is required but not found.
|
175
179
|
""" # noqa: E501
|
176
180
|
if not is_given(input_audio_transcription):
|
177
|
-
input_audio_transcription = AudioTranscriptionConfig()
|
181
|
+
input_audio_transcription = types.AudioTranscriptionConfig()
|
178
182
|
if not is_given(output_audio_transcription):
|
179
|
-
output_audio_transcription = AudioTranscriptionConfig()
|
183
|
+
output_audio_transcription = types.AudioTranscriptionConfig()
|
184
|
+
|
185
|
+
server_turn_detection = True
|
186
|
+
if (
|
187
|
+
is_given(realtime_input_config)
|
188
|
+
and realtime_input_config.automatic_activity_detection
|
189
|
+
and realtime_input_config.automatic_activity_detection.disabled
|
190
|
+
):
|
191
|
+
server_turn_detection = False
|
180
192
|
|
181
193
|
super().__init__(
|
182
194
|
capabilities=llm.RealtimeCapabilities(
|
183
195
|
message_truncation=False,
|
184
|
-
turn_detection=
|
196
|
+
turn_detection=server_turn_detection,
|
185
197
|
user_transcription=input_audio_transcription is not None,
|
186
198
|
auto_tool_reply_generation=True,
|
187
199
|
)
|
@@ -195,7 +207,7 @@ class RealtimeModel(llm.RealtimeModel):
|
|
195
207
|
|
196
208
|
gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
|
197
209
|
gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
|
198
|
-
gcp_location = (
|
210
|
+
gcp_location: str | None = (
|
199
211
|
location
|
200
212
|
if is_given(location)
|
201
213
|
else os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
@@ -243,6 +255,10 @@ class RealtimeModel(llm.RealtimeModel):
|
|
243
255
|
enable_affective_dialog=enable_affective_dialog,
|
244
256
|
proactivity=proactivity,
|
245
257
|
realtime_input_config=realtime_input_config,
|
258
|
+
context_window_compression=context_window_compression,
|
259
|
+
api_version=api_version,
|
260
|
+
gemini_tools=_gemini_tools,
|
261
|
+
conn_options=conn_options,
|
246
262
|
)
|
247
263
|
|
248
264
|
self._sessions = weakref.WeakSet[RealtimeSession]()
|
@@ -253,8 +269,19 @@ class RealtimeModel(llm.RealtimeModel):
|
|
253
269
|
return sess
|
254
270
|
|
255
271
|
def update_options(
|
256
|
-
self,
|
272
|
+
self,
|
273
|
+
*,
|
274
|
+
voice: NotGivenOr[str] = NOT_GIVEN,
|
275
|
+
temperature: NotGivenOr[float] = NOT_GIVEN,
|
257
276
|
) -> None:
|
277
|
+
"""
|
278
|
+
Update the options for the RealtimeModel.
|
279
|
+
|
280
|
+
Args:
|
281
|
+
voice (str, optional): The voice to use for the session.
|
282
|
+
temperature (float, optional): The temperature to use for the session.
|
283
|
+
tools (list[LLMTool], optional): The tools to use for the session.
|
284
|
+
"""
|
258
285
|
if is_given(voice):
|
259
286
|
self._opts.voice = voice
|
260
287
|
|
@@ -262,7 +289,10 @@ class RealtimeModel(llm.RealtimeModel):
|
|
262
289
|
self._opts.temperature = temperature
|
263
290
|
|
264
291
|
for sess in self._sessions:
|
265
|
-
sess.update_options(
|
292
|
+
sess.update_options(
|
293
|
+
voice=self._opts.voice,
|
294
|
+
temperature=self._opts.temperature,
|
295
|
+
)
|
266
296
|
|
267
297
|
async def aclose(self) -> None:
|
268
298
|
pass
|
@@ -273,7 +303,7 @@ class RealtimeSession(llm.RealtimeSession):
|
|
273
303
|
super().__init__(realtime_model)
|
274
304
|
self._opts = realtime_model._opts
|
275
305
|
self._tools = llm.ToolContext.empty()
|
276
|
-
self._gemini_declarations: list[FunctionDeclaration] = []
|
306
|
+
self._gemini_declarations: list[types.FunctionDeclaration] = []
|
277
307
|
self._chat_ctx = llm.ChatContext.empty()
|
278
308
|
self._msg_ch = utils.aio.Chan[ClientEvents]()
|
279
309
|
self._input_resampler: rtc.AudioResampler | None = None
|
@@ -285,11 +315,20 @@ class RealtimeSession(llm.RealtimeSession):
|
|
285
315
|
samples_per_channel=INPUT_AUDIO_SAMPLE_RATE // 20,
|
286
316
|
)
|
287
317
|
|
318
|
+
api_version = self._opts.api_version
|
319
|
+
if not api_version and (self._opts.enable_affective_dialog or self._opts.proactivity):
|
320
|
+
api_version = "v1alpha"
|
321
|
+
|
322
|
+
http_options = types.HttpOptions(timeout=int(self._opts.conn_options.timeout * 1000))
|
323
|
+
if api_version:
|
324
|
+
http_options.api_version = api_version
|
325
|
+
|
288
326
|
self._client = genai.Client(
|
289
327
|
api_key=self._opts.api_key,
|
290
328
|
vertexai=self._opts.vertexai,
|
291
329
|
project=self._opts.project,
|
292
330
|
location=self._opts.location,
|
331
|
+
http_options=http_options,
|
293
332
|
)
|
294
333
|
|
295
334
|
self._main_atask = asyncio.create_task(self._main_task(), name="gemini-realtime-session")
|
@@ -302,8 +341,9 @@ class RealtimeSession(llm.RealtimeSession):
|
|
302
341
|
self._pending_generation_fut: asyncio.Future[llm.GenerationCreatedEvent] | None = None
|
303
342
|
|
304
343
|
self._session_resumption_handle: str | None = None
|
305
|
-
|
344
|
+
self._in_user_activity = False
|
306
345
|
self._session_lock = asyncio.Lock()
|
346
|
+
self._num_retries = 0
|
307
347
|
|
308
348
|
async def _close_active_session(self) -> None:
|
309
349
|
async with self._session_lock:
|
@@ -315,7 +355,7 @@ class RealtimeSession(llm.RealtimeSession):
|
|
315
355
|
finally:
|
316
356
|
self._active_session = None
|
317
357
|
|
318
|
-
def _mark_restart_needed(self):
|
358
|
+
def _mark_restart_needed(self) -> None:
|
319
359
|
if not self._session_should_close.is_set():
|
320
360
|
self._session_should_close.set()
|
321
361
|
# reset the msg_ch, do not send messages from previous session
|
@@ -346,6 +386,11 @@ class RealtimeSession(llm.RealtimeSession):
|
|
346
386
|
self._mark_restart_needed()
|
347
387
|
|
348
388
|
async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
|
389
|
+
async with self._session_lock:
|
390
|
+
if not self._active_session:
|
391
|
+
self._chat_ctx = chat_ctx.copy()
|
392
|
+
return
|
393
|
+
|
349
394
|
diff_ops = llm.utils.compute_chat_ctx_diff(self._chat_ctx, chat_ctx)
|
350
395
|
|
351
396
|
if diff_ops.to_remove:
|
@@ -358,15 +403,23 @@ class RealtimeSession(llm.RealtimeSession):
|
|
358
403
|
append_ctx.items.append(item)
|
359
404
|
|
360
405
|
if append_ctx.items:
|
361
|
-
|
406
|
+
turns_dict, _ = append_ctx.copy(
|
407
|
+
exclude_function_call=True,
|
408
|
+
).to_provider_format(format="google", inject_dummy_user_message=False)
|
409
|
+
# we are not generating, and do not need to inject
|
410
|
+
turns = [types.Content.model_validate(turn) for turn in turns_dict]
|
362
411
|
tool_results = get_tool_results_for_realtime(append_ctx, vertexai=self._opts.vertexai)
|
363
412
|
if turns:
|
364
|
-
self._send_client_event(LiveClientContent(turns=turns, turn_complete=False))
|
413
|
+
self._send_client_event(types.LiveClientContent(turns=turns, turn_complete=False))
|
365
414
|
if tool_results:
|
366
415
|
self._send_client_event(tool_results)
|
367
416
|
|
368
|
-
|
369
|
-
|
417
|
+
# since we don't have a view of the history on the server side, we'll assume
|
418
|
+
# the current state is accurate. this isn't perfect because removals aren't done.
|
419
|
+
self._chat_ctx = chat_ctx.copy()
|
420
|
+
|
421
|
+
async def update_tools(self, tools: list[llm.FunctionTool | llm.RawFunctionTool]) -> None:
|
422
|
+
new_declarations: list[types.FunctionDeclaration] = to_fnc_ctx(tools)
|
370
423
|
current_tool_names = {f.name for f in self._gemini_declarations}
|
371
424
|
new_tool_names = {f.name for f in new_declarations}
|
372
425
|
|
@@ -383,11 +436,21 @@ class RealtimeSession(llm.RealtimeSession):
|
|
383
436
|
def tools(self) -> llm.ToolContext:
|
384
437
|
return self._tools.copy()
|
385
438
|
|
439
|
+
@property
|
440
|
+
def _manual_activity_detection(self) -> bool:
|
441
|
+
if (
|
442
|
+
is_given(self._opts.realtime_input_config)
|
443
|
+
and self._opts.realtime_input_config.automatic_activity_detection is not None
|
444
|
+
and self._opts.realtime_input_config.automatic_activity_detection.disabled
|
445
|
+
):
|
446
|
+
return True
|
447
|
+
return False
|
448
|
+
|
386
449
|
def push_audio(self, frame: rtc.AudioFrame) -> None:
|
387
450
|
for f in self._resample_audio(frame):
|
388
451
|
for nf in self._bstream.write(f.data.tobytes()):
|
389
|
-
realtime_input = LiveClientRealtimeInput(
|
390
|
-
media_chunks=[Blob(data=nf.data.tobytes(), mime_type="audio/pcm")]
|
452
|
+
realtime_input = types.LiveClientRealtimeInput(
|
453
|
+
media_chunks=[types.Blob(data=nf.data.tobytes(), mime_type="audio/pcm")]
|
391
454
|
)
|
392
455
|
self._send_client_event(realtime_input)
|
393
456
|
|
@@ -395,8 +458,8 @@ class RealtimeSession(llm.RealtimeSession):
|
|
395
458
|
encoded_data = images.encode(
|
396
459
|
frame, self._opts.image_encode_options or DEFAULT_IMAGE_ENCODE_OPTIONS
|
397
460
|
)
|
398
|
-
realtime_input = LiveClientRealtimeInput(
|
399
|
-
media_chunks=[Blob(data=encoded_data, mime_type="image/jpeg")]
|
461
|
+
realtime_input = types.LiveClientRealtimeInput(
|
462
|
+
media_chunks=[types.Blob(data=encoded_data, mime_type="image/jpeg")]
|
400
463
|
)
|
401
464
|
self._send_client_event(realtime_input)
|
402
465
|
|
@@ -413,16 +476,24 @@ class RealtimeSession(llm.RealtimeSession):
|
|
413
476
|
)
|
414
477
|
self._pending_generation_fut.cancel("Superseded by new generate_reply call")
|
415
478
|
|
416
|
-
fut = asyncio.Future()
|
479
|
+
fut = asyncio.Future[llm.GenerationCreatedEvent]()
|
417
480
|
self._pending_generation_fut = fut
|
418
481
|
|
482
|
+
if self._in_user_activity:
|
483
|
+
self._send_client_event(
|
484
|
+
types.LiveClientRealtimeInput(
|
485
|
+
activity_end=types.ActivityEnd(),
|
486
|
+
)
|
487
|
+
)
|
488
|
+
self._in_user_activity = False
|
489
|
+
|
419
490
|
# Gemini requires the last message to end with user's turn
|
420
491
|
# so we need to add a placeholder user turn in order to trigger a new generation
|
421
|
-
|
492
|
+
turns = []
|
422
493
|
if is_given(instructions):
|
423
|
-
|
424
|
-
|
425
|
-
self._send_client_event(
|
494
|
+
turns.append(types.Content(parts=[types.Part(text=instructions)], role="model"))
|
495
|
+
turns.append(types.Content(parts=[types.Part(text=".")], role="user"))
|
496
|
+
self._send_client_event(types.LiveClientContent(turns=turns, turn_complete=True))
|
426
497
|
|
427
498
|
def _on_timeout() -> None:
|
428
499
|
if not fut.done():
|
@@ -439,8 +510,28 @@ class RealtimeSession(llm.RealtimeSession):
|
|
439
510
|
|
440
511
|
return fut
|
441
512
|
|
513
|
+
def start_user_activity(self) -> None:
|
514
|
+
if not self._manual_activity_detection:
|
515
|
+
return
|
516
|
+
|
517
|
+
if not self._in_user_activity:
|
518
|
+
self._in_user_activity = True
|
519
|
+
self._send_client_event(
|
520
|
+
types.LiveClientRealtimeInput(
|
521
|
+
activity_start=types.ActivityStart(),
|
522
|
+
)
|
523
|
+
)
|
524
|
+
|
442
525
|
def interrupt(self) -> None:
|
443
|
-
|
526
|
+
# Gemini Live treats activity start as interruption, so we rely on start_user_activity
|
527
|
+
# notifications to handle it
|
528
|
+
if (
|
529
|
+
self._opts.realtime_input_config
|
530
|
+
and self._opts.realtime_input_config.activity_handling
|
531
|
+
== types.ActivityHandling.NO_INTERRUPTION
|
532
|
+
):
|
533
|
+
return
|
534
|
+
self.start_user_activity()
|
444
535
|
|
445
536
|
def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
|
446
537
|
logger.warning("truncate is not supported by the Google Realtime API.")
|
@@ -467,7 +558,9 @@ class RealtimeSession(llm.RealtimeSession):
|
|
467
558
|
self._mark_current_generation_done()
|
468
559
|
|
469
560
|
@utils.log_exceptions(logger=logger)
|
470
|
-
async def _main_task(self):
|
561
|
+
async def _main_task(self) -> None:
|
562
|
+
max_retries = self._opts.conn_options.max_retry
|
563
|
+
|
471
564
|
while not self._msg_ch.closed:
|
472
565
|
# previous session might not be closed yet, we'll do it here.
|
473
566
|
await self._close_active_session()
|
@@ -482,7 +575,15 @@ class RealtimeSession(llm.RealtimeSession):
|
|
482
575
|
) as session:
|
483
576
|
async with self._session_lock:
|
484
577
|
self._active_session = session
|
485
|
-
|
578
|
+
turns_dict, _ = self._chat_ctx.copy(
|
579
|
+
exclude_function_call=True,
|
580
|
+
).to_provider_format(format="google", inject_dummy_user_message=False)
|
581
|
+
if turns_dict:
|
582
|
+
turns = [types.Content.model_validate(turn) for turn in turns_dict]
|
583
|
+
await session.send_client_content(
|
584
|
+
turns=turns, # type: ignore
|
585
|
+
turn_complete=False,
|
586
|
+
)
|
486
587
|
# queue up existing chat context
|
487
588
|
send_task = asyncio.create_task(
|
488
589
|
self._send_task(session), name="gemini-realtime-send"
|
@@ -515,12 +616,30 @@ class RealtimeSession(llm.RealtimeSession):
|
|
515
616
|
except Exception as e:
|
516
617
|
logger.error(f"Gemini Realtime API error: {e}", exc_info=e)
|
517
618
|
if not self._msg_ch.closed:
|
518
|
-
|
519
|
-
|
619
|
+
# we shouldn't retry when it's not connected, usually this means incorrect
|
620
|
+
# parameters or setup
|
621
|
+
if not session or max_retries == 0:
|
622
|
+
self._emit_error(e, recoverable=False)
|
623
|
+
raise APIConnectionError(message="Failed to connect to Gemini Live") from e
|
624
|
+
|
625
|
+
if self._num_retries == max_retries:
|
626
|
+
self._emit_error(e, recoverable=False)
|
627
|
+
raise APIConnectionError(
|
628
|
+
message=f"Failed to connect to Gemini Live after {max_retries} attempts"
|
629
|
+
) from e
|
630
|
+
|
631
|
+
retry_interval = self._opts.conn_options._interval_for_retry(self._num_retries)
|
632
|
+
logger.warning(
|
633
|
+
f"Gemini Realtime API connection failed, retrying in {retry_interval}s",
|
634
|
+
exc_info=e,
|
635
|
+
extra={"attempt": self._num_retries, "max_retries": max_retries},
|
636
|
+
)
|
637
|
+
await asyncio.sleep(retry_interval)
|
638
|
+
self._num_retries += 1
|
520
639
|
finally:
|
521
640
|
await self._close_active_session()
|
522
641
|
|
523
|
-
async def _send_task(self, session: AsyncSession):
|
642
|
+
async def _send_task(self, session: AsyncSession) -> None:
|
524
643
|
try:
|
525
644
|
async for msg in self._msg_ch:
|
526
645
|
async with self._session_lock:
|
@@ -528,15 +647,21 @@ class RealtimeSession(llm.RealtimeSession):
|
|
528
647
|
not self._active_session or self._active_session != session
|
529
648
|
):
|
530
649
|
break
|
531
|
-
if isinstance(msg, LiveClientContent):
|
650
|
+
if isinstance(msg, types.LiveClientContent):
|
532
651
|
await session.send_client_content(
|
533
|
-
turns=msg.turns,
|
652
|
+
turns=msg.turns, # type: ignore
|
653
|
+
turn_complete=msg.turn_complete or True,
|
534
654
|
)
|
535
|
-
elif isinstance(msg, LiveClientToolResponse):
|
655
|
+
elif isinstance(msg, types.LiveClientToolResponse) and msg.function_responses:
|
536
656
|
await session.send_tool_response(function_responses=msg.function_responses)
|
537
|
-
elif isinstance(msg, LiveClientRealtimeInput):
|
538
|
-
|
539
|
-
|
657
|
+
elif isinstance(msg, types.LiveClientRealtimeInput):
|
658
|
+
if msg.media_chunks:
|
659
|
+
for media_chunk in msg.media_chunks:
|
660
|
+
await session.send_realtime_input(media=media_chunk)
|
661
|
+
elif msg.activity_start:
|
662
|
+
await session.send_realtime_input(activity_start=msg.activity_start)
|
663
|
+
elif msg.activity_end:
|
664
|
+
await session.send_realtime_input(activity_end=msg.activity_end)
|
540
665
|
else:
|
541
666
|
logger.warning(f"Warning: Received unhandled message type: {type(msg)}")
|
542
667
|
|
@@ -547,7 +672,7 @@ class RealtimeSession(llm.RealtimeSession):
|
|
547
672
|
finally:
|
548
673
|
logger.debug("send task finished.")
|
549
674
|
|
550
|
-
async def _recv_task(self, session: AsyncSession):
|
675
|
+
async def _recv_task(self, session: AsyncSession) -> None:
|
551
676
|
try:
|
552
677
|
while True:
|
553
678
|
async with self._session_lock:
|
@@ -583,6 +708,9 @@ class RealtimeSession(llm.RealtimeSession):
|
|
583
708
|
if response.go_away:
|
584
709
|
self._handle_go_away(response.go_away)
|
585
710
|
|
711
|
+
if self._num_retries > 0:
|
712
|
+
self._num_retries = 0 # reset the retry counter
|
713
|
+
|
586
714
|
# TODO(dz): a server-side turn is complete
|
587
715
|
except Exception as e:
|
588
716
|
if not self._session_should_close.is_set():
|
@@ -591,14 +719,18 @@ class RealtimeSession(llm.RealtimeSession):
|
|
591
719
|
finally:
|
592
720
|
self._mark_current_generation_done()
|
593
721
|
|
594
|
-
def _build_connect_config(self) -> LiveConnectConfig:
|
722
|
+
def _build_connect_config(self) -> types.LiveConnectConfig:
|
595
723
|
temp = self._opts.temperature if is_given(self._opts.temperature) else None
|
596
724
|
|
597
|
-
|
725
|
+
tools_config = create_tools_config(
|
726
|
+
function_tools=self._gemini_declarations,
|
727
|
+
gemini_tools=self._opts.gemini_tools if is_given(self._opts.gemini_tools) else None,
|
728
|
+
)
|
729
|
+
conf = types.LiveConnectConfig(
|
598
730
|
response_modalities=self._opts.response_modalities
|
599
731
|
if is_given(self._opts.response_modalities)
|
600
|
-
else [Modality.AUDIO],
|
601
|
-
generation_config=GenerationConfig(
|
732
|
+
else [types.Modality.AUDIO],
|
733
|
+
generation_config=types.GenerationConfig(
|
602
734
|
candidate_count=self._opts.candidate_count,
|
603
735
|
temperature=temp,
|
604
736
|
max_output_tokens=self._opts.max_output_tokens
|
@@ -613,41 +745,45 @@ class RealtimeSession(llm.RealtimeSession):
|
|
613
745
|
if is_given(self._opts.frequency_penalty)
|
614
746
|
else None,
|
615
747
|
),
|
616
|
-
system_instruction=Content(parts=[Part(text=self._opts.instructions)])
|
748
|
+
system_instruction=types.Content(parts=[types.Part(text=self._opts.instructions)])
|
617
749
|
if is_given(self._opts.instructions)
|
618
750
|
else None,
|
619
|
-
speech_config=SpeechConfig(
|
620
|
-
voice_config=VoiceConfig(
|
621
|
-
prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
|
751
|
+
speech_config=types.SpeechConfig(
|
752
|
+
voice_config=types.VoiceConfig(
|
753
|
+
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=self._opts.voice)
|
622
754
|
),
|
623
755
|
language_code=self._opts.language if is_given(self._opts.language) else None,
|
624
756
|
),
|
625
|
-
tools=
|
757
|
+
tools=tools_config, # type: ignore
|
626
758
|
input_audio_transcription=self._opts.input_audio_transcription,
|
627
759
|
output_audio_transcription=self._opts.output_audio_transcription,
|
628
|
-
session_resumption=SessionResumptionConfig(
|
629
|
-
|
760
|
+
session_resumption=types.SessionResumptionConfig(
|
761
|
+
handle=self._session_resumption_handle
|
762
|
+
),
|
630
763
|
)
|
631
764
|
|
632
765
|
if is_given(self._opts.proactivity):
|
633
|
-
conf.proactivity =
|
766
|
+
conf.proactivity = types.ProactivityConfig(proactive_audio=self._opts.proactivity)
|
634
767
|
if is_given(self._opts.enable_affective_dialog):
|
635
768
|
conf.enable_affective_dialog = self._opts.enable_affective_dialog
|
636
769
|
if is_given(self._opts.realtime_input_config):
|
637
770
|
conf.realtime_input_config = self._opts.realtime_input_config
|
771
|
+
if is_given(self._opts.context_window_compression):
|
772
|
+
conf.context_window_compression = self._opts.context_window_compression
|
638
773
|
|
639
774
|
return conf
|
640
775
|
|
641
|
-
def _start_new_generation(self):
|
776
|
+
def _start_new_generation(self) -> None:
|
642
777
|
if self._current_generation and not self._current_generation._done:
|
643
778
|
logger.warning("starting new generation while another is active. Finalizing previous.")
|
644
779
|
self._mark_current_generation_done()
|
645
780
|
|
646
|
-
response_id = utils.shortuuid("
|
781
|
+
response_id = utils.shortuuid("GR_")
|
647
782
|
self._current_generation = _ResponseGeneration(
|
648
783
|
message_ch=utils.aio.Chan[llm.MessageGeneration](),
|
649
784
|
function_ch=utils.aio.Chan[llm.FunctionCall](),
|
650
785
|
response_id=response_id,
|
786
|
+
input_id=utils.shortuuid("GI_"),
|
651
787
|
text_ch=utils.aio.Chan[str](),
|
652
788
|
audio_ch=utils.aio.Chan[rtc.AudioFrame](),
|
653
789
|
_created_timestamp=time.time(),
|
@@ -674,21 +810,23 @@ class RealtimeSession(llm.RealtimeSession):
|
|
674
810
|
|
675
811
|
self.emit("generation_created", generation_event)
|
676
812
|
|
677
|
-
def _handle_server_content(self, server_content: LiveServerContent):
|
813
|
+
def _handle_server_content(self, server_content: types.LiveServerContent) -> None:
|
678
814
|
current_gen = self._current_generation
|
679
815
|
if not current_gen:
|
680
816
|
logger.warning("received server content but no active generation.")
|
681
817
|
return
|
682
818
|
|
683
819
|
if model_turn := server_content.model_turn:
|
684
|
-
for part in model_turn.parts:
|
820
|
+
for part in model_turn.parts or []:
|
685
821
|
if part.text:
|
686
|
-
current_gen.
|
822
|
+
current_gen.push_text(part.text)
|
687
823
|
if part.inline_data:
|
688
824
|
if not current_gen._first_token_timestamp:
|
689
825
|
current_gen._first_token_timestamp = time.time()
|
690
826
|
frame_data = part.inline_data.data
|
691
827
|
try:
|
828
|
+
if not isinstance(frame_data, bytes):
|
829
|
+
raise ValueError("frame_data is not bytes")
|
692
830
|
frame = rtc.AudioFrame(
|
693
831
|
data=frame_data,
|
694
832
|
sample_rate=OUTPUT_AUDIO_SAMPLE_RATE,
|
@@ -710,7 +848,7 @@ class RealtimeSession(llm.RealtimeSession):
|
|
710
848
|
self.emit(
|
711
849
|
"input_audio_transcription_completed",
|
712
850
|
llm.InputTranscriptionCompleted(
|
713
|
-
item_id=current_gen.
|
851
|
+
item_id=current_gen.input_id,
|
714
852
|
transcript=current_gen.input_transcription,
|
715
853
|
is_final=False,
|
716
854
|
),
|
@@ -719,20 +857,9 @@ class RealtimeSession(llm.RealtimeSession):
|
|
719
857
|
if output_transcription := server_content.output_transcription:
|
720
858
|
text = output_transcription.text
|
721
859
|
if text:
|
722
|
-
current_gen.
|
860
|
+
current_gen.push_text(text)
|
723
861
|
|
724
|
-
if server_content.generation_complete:
|
725
|
-
# The only way we'd know that the transcription is complete is by when they are
|
726
|
-
# done with generation
|
727
|
-
if current_gen.input_transcription:
|
728
|
-
self.emit(
|
729
|
-
"input_audio_transcription_completed",
|
730
|
-
llm.InputTranscriptionCompleted(
|
731
|
-
item_id=current_gen.response_id,
|
732
|
-
transcript=current_gen.input_transcription,
|
733
|
-
is_final=True,
|
734
|
-
),
|
735
|
-
)
|
862
|
+
if server_content.generation_complete or server_content.turn_complete:
|
736
863
|
current_gen._completed_timestamp = time.time()
|
737
864
|
|
738
865
|
if server_content.interrupted:
|
@@ -742,10 +869,38 @@ class RealtimeSession(llm.RealtimeSession):
|
|
742
869
|
self._mark_current_generation_done()
|
743
870
|
|
744
871
|
def _mark_current_generation_done(self) -> None:
|
745
|
-
if not self._current_generation:
|
872
|
+
if not self._current_generation or self._current_generation._done:
|
746
873
|
return
|
747
874
|
|
748
875
|
gen = self._current_generation
|
876
|
+
|
877
|
+
# The only way we'd know that the transcription is complete is by when they are
|
878
|
+
# done with generation
|
879
|
+
if gen.input_transcription:
|
880
|
+
self.emit(
|
881
|
+
"input_audio_transcription_completed",
|
882
|
+
llm.InputTranscriptionCompleted(
|
883
|
+
item_id=gen.input_id,
|
884
|
+
transcript=gen.input_transcription,
|
885
|
+
is_final=True,
|
886
|
+
),
|
887
|
+
)
|
888
|
+
|
889
|
+
# since gemini doesn't give us a view of the chat history on the server side,
|
890
|
+
# we would handle it manually here
|
891
|
+
self._chat_ctx.add_message(
|
892
|
+
role="user",
|
893
|
+
content=gen.input_transcription,
|
894
|
+
id=gen.input_id,
|
895
|
+
)
|
896
|
+
|
897
|
+
if gen.output_text:
|
898
|
+
self._chat_ctx.add_message(
|
899
|
+
role="assistant",
|
900
|
+
content=gen.output_text,
|
901
|
+
id=gen.response_id,
|
902
|
+
)
|
903
|
+
|
749
904
|
if not gen.text_ch.closed:
|
750
905
|
gen.text_ch.close()
|
751
906
|
if not gen.audio_ch.closed:
|
@@ -755,36 +910,37 @@ class RealtimeSession(llm.RealtimeSession):
|
|
755
910
|
gen.message_ch.close()
|
756
911
|
gen._done = True
|
757
912
|
|
758
|
-
def _handle_input_speech_started(self):
|
913
|
+
def _handle_input_speech_started(self) -> None:
|
759
914
|
self.emit("input_speech_started", llm.InputSpeechStartedEvent())
|
760
915
|
|
761
|
-
def _handle_tool_calls(self, tool_call: LiveServerToolCall):
|
916
|
+
def _handle_tool_calls(self, tool_call: types.LiveServerToolCall) -> None:
|
762
917
|
if not self._current_generation:
|
763
918
|
logger.warning("received tool call but no active generation.")
|
764
919
|
return
|
765
920
|
|
766
921
|
gen = self._current_generation
|
767
|
-
for fnc_call in tool_call.function_calls:
|
922
|
+
for fnc_call in tool_call.function_calls or []:
|
768
923
|
arguments = json.dumps(fnc_call.args)
|
769
924
|
|
770
925
|
gen.function_ch.send_nowait(
|
771
926
|
llm.FunctionCall(
|
772
927
|
call_id=fnc_call.id or utils.shortuuid("fnc-call-"),
|
773
|
-
name=fnc_call.name,
|
928
|
+
name=fnc_call.name, # type: ignore
|
774
929
|
arguments=arguments,
|
775
930
|
)
|
776
931
|
)
|
932
|
+
self._on_final_input_audio_transcription()
|
777
933
|
self._mark_current_generation_done()
|
778
934
|
|
779
935
|
def _handle_tool_call_cancellation(
|
780
|
-
self, tool_call_cancellation: LiveServerToolCallCancellation
|
781
|
-
):
|
936
|
+
self, tool_call_cancellation: types.LiveServerToolCallCancellation
|
937
|
+
) -> None:
|
782
938
|
logger.warning(
|
783
939
|
"server cancelled tool calls",
|
784
940
|
extra={"function_call_ids": tool_call_cancellation.ids},
|
785
941
|
)
|
786
942
|
|
787
|
-
def _handle_usage_metadata(self, usage_metadata: UsageMetadata):
|
943
|
+
def _handle_usage_metadata(self, usage_metadata: types.UsageMetadata) -> None:
|
788
944
|
current_gen = self._current_generation
|
789
945
|
if not current_gen:
|
790
946
|
logger.warning("no active generation to report metrics for")
|
@@ -800,8 +956,8 @@ class RealtimeSession(llm.RealtimeSession):
|
|
800
956
|
) - current_gen._created_timestamp
|
801
957
|
|
802
958
|
def _token_details_map(
|
803
|
-
token_details: list[ModalityTokenCount] | None,
|
804
|
-
) -> dict[
|
959
|
+
token_details: list[types.ModalityTokenCount] | None,
|
960
|
+
) -> dict[str, int]:
|
805
961
|
token_details_map = {"audio_tokens": 0, "text_tokens": 0, "image_tokens": 0}
|
806
962
|
if not token_details:
|
807
963
|
return token_details_map
|
@@ -810,11 +966,11 @@ class RealtimeSession(llm.RealtimeSession):
|
|
810
966
|
if not token_detail.token_count:
|
811
967
|
continue
|
812
968
|
|
813
|
-
if token_detail.modality ==
|
969
|
+
if token_detail.modality == types.MediaModality.AUDIO:
|
814
970
|
token_details_map["audio_tokens"] += token_detail.token_count
|
815
|
-
elif token_detail.modality ==
|
971
|
+
elif token_detail.modality == types.MediaModality.TEXT:
|
816
972
|
token_details_map["text_tokens"] += token_detail.token_count
|
817
|
-
elif token_detail.modality ==
|
973
|
+
elif token_detail.modality == types.MediaModality.IMAGE:
|
818
974
|
token_details_map["image_tokens"] += token_detail.token_count
|
819
975
|
return token_details_map
|
820
976
|
|
@@ -828,7 +984,9 @@ class RealtimeSession(llm.RealtimeSession):
|
|
828
984
|
input_tokens=usage_metadata.prompt_token_count or 0,
|
829
985
|
output_tokens=usage_metadata.response_token_count or 0,
|
830
986
|
total_tokens=usage_metadata.total_token_count or 0,
|
831
|
-
tokens_per_second=(usage_metadata.response_token_count or 0) / duration
|
987
|
+
tokens_per_second=(usage_metadata.response_token_count or 0) / duration
|
988
|
+
if duration > 0
|
989
|
+
else 0,
|
832
990
|
input_token_details=RealtimeModelMetrics.InputTokenDetails(
|
833
991
|
**_token_details_map(usage_metadata.prompt_tokens_details),
|
834
992
|
cached_tokens=sum(
|
@@ -845,18 +1003,27 @@ class RealtimeSession(llm.RealtimeSession):
|
|
845
1003
|
)
|
846
1004
|
self.emit("metrics_collected", metrics)
|
847
1005
|
|
848
|
-
def _handle_go_away(self, go_away: LiveServerGoAway):
|
1006
|
+
def _handle_go_away(self, go_away: types.LiveServerGoAway) -> None:
|
849
1007
|
logger.warning(
|
850
1008
|
f"Gemini server indicates disconnection soon. Time left: {go_away.time_left}"
|
851
1009
|
)
|
852
1010
|
# TODO(dz): this isn't a seamless reconnection just yet
|
853
1011
|
self._session_should_close.set()
|
854
1012
|
|
1013
|
+
def _on_final_input_audio_transcription(self) -> None:
|
1014
|
+
if (gen := self._current_generation) and gen.input_transcription:
|
1015
|
+
self.emit(
|
1016
|
+
"input_audio_transcription_completed",
|
1017
|
+
llm.InputTranscriptionCompleted(
|
1018
|
+
item_id=gen.response_id, transcript=gen.input_transcription, is_final=True
|
1019
|
+
),
|
1020
|
+
)
|
1021
|
+
|
855
1022
|
def commit_audio(self) -> None:
|
856
1023
|
pass
|
857
1024
|
|
858
1025
|
def clear_audio(self) -> None:
|
859
|
-
|
1026
|
+
pass
|
860
1027
|
|
861
1028
|
def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
|
862
1029
|
if self._input_resampler:
|
@@ -879,3 +1046,14 @@ class RealtimeSession(llm.RealtimeSession):
|
|
879
1046
|
yield from self._input_resampler.push(frame)
|
880
1047
|
else:
|
881
1048
|
yield frame
|
1049
|
+
|
1050
|
+
def _emit_error(self, error: Exception, recoverable: bool) -> None:
|
1051
|
+
self.emit(
|
1052
|
+
"error",
|
1053
|
+
llm.RealtimeModelError(
|
1054
|
+
timestamp=time.time(),
|
1055
|
+
label=self._realtime_model._label,
|
1056
|
+
error=error,
|
1057
|
+
recoverable=recoverable,
|
1058
|
+
),
|
1059
|
+
)
|