livekit-plugins-google 0.11.2__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit/plugins/google/beta/realtime/__init__.py +1 -5
- livekit/plugins/google/beta/realtime/api_proto.py +2 -4
- livekit/plugins/google/beta/realtime/realtime_api.py +407 -449
- livekit/plugins/google/llm.py +158 -220
- livekit/plugins/google/stt.py +80 -115
- livekit/plugins/google/tts.py +50 -55
- livekit/plugins/google/utils.py +251 -0
- livekit/plugins/google/version.py +1 -1
- {livekit_plugins_google-0.11.2.dist-info → livekit_plugins_google-1.0.0.dist-info}/METADATA +11 -21
- livekit_plugins_google-1.0.0.dist-info/RECORD +16 -0
- {livekit_plugins_google-0.11.2.dist-info → livekit_plugins_google-1.0.0.dist-info}/WHEEL +1 -2
- livekit/plugins/google/_utils.py +0 -199
- livekit/plugins/google/beta/realtime/transcriber.py +0 -270
- livekit_plugins_google-0.11.2.dist-info/RECORD +0 -18
- livekit_plugins_google-0.11.2.dist-info/top_level.txt +0 -1
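
The bulk of the change is in `realtime_api.py`, reconstructed below: the 0.x event-emitter session (`GeminiRealtimeSession(utils.EventEmitter[EventTypes])`) is replaced by the 1.0 `llm.RealtimeModel`/`llm.RealtimeSession` interfaces, plain `None` defaults give way to `NOT_GIVEN` sentinels, and the transcriber-based user/agent transcription options disappear. A minimal before/after sketch of what that means for callers — hypothetical usage inferred from the signatures in this diff, not taken from the package documentation:

```python
from livekit.plugins import google

# 0.11.2: transcription toggles and api_version were constructor arguments,
# and sessions were created per chat/function context:
#
# model = google.beta.realtime.RealtimeModel(
#     instructions="You are a helpful assistant.",
#     api_version="v1alpha",                 # removed in 1.0.0 (now hardcoded)
#     enable_user_audio_transcription=True,  # removed in 1.0.0
# )
# session = model.session(chat_ctx=..., fnc_ctx=...)

# 1.0.0: unset options default to the NOT_GIVEN sentinel and can simply be
# omitted; sessions come from a no-argument factory and are tracked in a WeakSet.
model = google.beta.realtime.RealtimeModel(
    instructions="You are a helpful assistant.",
    voice="Puck",
    temperature=0.8,
)
session = model.session()  # an llm.RealtimeSession subclass
```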
livekit/plugins/google/beta/realtime/realtime_api.py

@@ -3,25 +3,22 @@ from __future__ import annotations
 import asyncio
 import json
 import os
+import weakref
 from dataclasses import dataclass
-from typing import AsyncIterable, Literal
-
-from livekit import rtc
-from livekit.agents import llm, utils
-from livekit.agents.llm.function_context import _create_ai_function_info
-from livekit.agents.utils import images
 
 from google import genai
+from google.genai._api_client import HttpOptions
 from google.genai.types import (
     Blob,
     Content,
-    FunctionResponse,
+    FunctionDeclaration,
     GenerationConfig,
-    HttpOptions,
     LiveClientContent,
     LiveClientRealtimeInput,
-    LiveClientToolResponse,
     LiveConnectConfig,
+    LiveServerContent,
+    LiveServerToolCall,
+    LiveServerToolCallCancellation,
     Modality,
     Part,
     PrebuiltVoiceConfig,
@@ -29,42 +26,18 @@ from google.genai.types import (
     Tool,
     VoiceConfig,
 )
+from livekit import rtc
+from livekit.agents import llm, utils
+from livekit.agents.types import NOT_GIVEN, NotGivenOr
+from livekit.agents.utils import is_given
 
 from ...log import logger
-from .api_proto import (
-    ClientEvents,
-    LiveAPIModels,
-    Voice,
-    _build_gemini_ctx,
-    _build_tools,
-)
-from .transcriber import ModelTranscriber, TranscriberSession, TranscriptionContent
-
-EventTypes = Literal[
-    "start_session",
-    "input_speech_started",
-    "response_content_added",
-    "response_content_done",
-    "function_calls_collected",
-    "function_calls_finished",
-    "function_calls_cancelled",
-    "input_speech_transcription_completed",
-    "agent_speech_transcription_completed",
-    "agent_speech_stopped",
-]
+from ...utils import _build_gemini_fnc, get_tool_results_for_realtime, to_chat_ctx
+from .api_proto import ClientEvents, LiveAPIModels, Voice
 
-
-@dataclass
-class GeminiContent:
-    response_id: str
-    item_id: str
-    output_index: int
-    content_index: int
-    text: str
-    audio: list[rtc.AudioFrame]
-    text_stream: AsyncIterable[str]
-    audio_stream: AsyncIterable[rtc.AudioFrame]
-    content_type: Literal["text", "audio"]
+INPUT_AUDIO_SAMPLE_RATE = 16000
+OUTPUT_AUDIO_SAMPLE_RATE = 24000
+NUM_CHANNELS = 1
 
 
 @dataclass
@@ -74,57 +47,59 @@ class InputTranscription:
 
 
 @dataclass
-class Capabilities:
-    supports_truncate: bool
-    input_audio_sample_rate: int | None = None
-
-
-@dataclass
-class ModelOptions:
+class _RealtimeOptions:
     model: LiveAPIModels | str
     api_key: str | None
-    api_version: str
     voice: Voice | str
-    response_modalities: list[Modality]
+    response_modalities: NotGivenOr[list[Modality]]
     vertexai: bool
     project: str | None
     location: str | None
     candidate_count: int
-    temperature: float
-    max_output_tokens: int
-    top_p: float
-    top_k: int
-    presence_penalty: float
-    frequency_penalty: float
-    instructions: Content | None
-    enable_user_audio_transcription: bool
-    enable_agent_audio_transcription: bool
-
-
-class RealtimeModel:
+    temperature: NotGivenOr[float]
+    max_output_tokens: NotGivenOr[int]
+    top_p: NotGivenOr[float]
+    top_k: NotGivenOr[int]
+    presence_penalty: NotGivenOr[float]
+    frequency_penalty: NotGivenOr[float]
+    instructions: NotGivenOr[str]
+
+
+@dataclass
+class _MessageGeneration:
+    message_id: str
+    text_ch: utils.aio.Chan[str]
+    audio_ch: utils.aio.Chan[rtc.AudioFrame]
+
+
+@dataclass
+class _ResponseGeneration:
+    message_ch: utils.aio.Chan[llm.MessageGeneration]
+    function_ch: utils.aio.Chan[llm.FunctionCall]
+
+    messages: dict[str, _MessageGeneration]
+
+
+class RealtimeModel(llm.RealtimeModel):
     def __init__(
         self,
         *,
-        instructions: str = "",
+        instructions: NotGivenOr[str] = NOT_GIVEN,
         model: LiveAPIModels | str = "gemini-2.0-flash-exp",
-        api_key: str | None = None,
-        api_version: str = "v1alpha",
+        api_key: NotGivenOr[str] = NOT_GIVEN,
         voice: Voice | str = "Puck",
-        modalities: list[Modality] = ["AUDIO"],
-        enable_user_audio_transcription: bool = True,
-        enable_agent_audio_transcription: bool = True,
+        modalities: NotGivenOr[list[Modality]] = NOT_GIVEN,
         vertexai: bool = False,
-        project: str | None = None,
-        location: str | None = None,
+        project: NotGivenOr[str] = NOT_GIVEN,
+        location: NotGivenOr[str] = NOT_GIVEN,
         candidate_count: int = 1,
-        temperature: float | None = 0.8,
-        max_output_tokens: int | None = None,
-        top_p: float | None = None,
-        top_k: int | None = None,
-        presence_penalty: float | None = None,
-        frequency_penalty: float | None = None,
-        loop: asyncio.AbstractEventLoop | None = None,
-    ):
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+        max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
+        top_p: NotGivenOr[float] = NOT_GIVEN,
+        top_k: NotGivenOr[int] = NOT_GIVEN,
+        presence_penalty: NotGivenOr[float] = NOT_GIVEN,
+        frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
+    ) -> None:
         """
         Initializes a RealtimeModel instance for interacting with Google's Realtime API.
 
@@ -137,68 +112,57 @@ class RealtimeModel:
 
         Args:
             instructions (str, optional): Initial system instructions for the model. Defaults to "".
-            api_key (str or None, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
-            api_version (str, optional): The version of the API to use. Defaults to "v1alpha".
+            api_key (str, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
             modalities (list[Modality], optional): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
-            model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
+            model (str, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
             voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
-            enable_user_audio_transcription (bool, optional): Whether to enable user audio transcription. Defaults to True
-            enable_agent_audio_transcription (bool, optional): Whether to enable agent audio transcription. Defaults to True
             temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
             vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
-            project (str or None, optional): The project id to use for the API. Defaults to None. (for vertexai)
-            location (str or None, optional): The location to use for the API. Defaults to None. (for vertexai)
+            project (str, optional): The project id to use for the API. Defaults to None. (for vertexai)
+            location (str, optional): The location to use for the API. Defaults to None. (for vertexai)
             candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1.
             top_p (float, optional): The top-p value for response generation
             top_k (int, optional): The top-k value for response generation
             presence_penalty (float, optional): The presence penalty for response generation
             frequency_penalty (float, optional): The frequency penalty for response generation
-            loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used.
 
         Raises:
-            ValueError: If the API key is not provided and cannot be found in environment variables.
-        """
-        super().__init__(
-            capabilities=Capabilities(
-                supports_truncate=False,
-            )
+            ValueError: If the API key is required but not found.
+        """  # noqa: E501
+        super().__init__(
+            capabilities=llm.RealtimeCapabilities(
+                message_truncation=False,
+                turn_detection=True,
+                user_transcription=False,
+            )
         )
 
-        self._model = model
-        self._loop = loop or asyncio.get_event_loop()
-        self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
-        self._project = project or os.environ.get("GOOGLE_CLOUD_PROJECT")
-        self._location = location or os.environ.get("GOOGLE_CLOUD_LOCATION")
+        gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
+        gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
+        gcp_location = location if is_given(location) else os.environ.get("GOOGLE_CLOUD_LOCATION")
         if vertexai:
-            if not self._project or not self._location:
+            if not gcp_project or not gcp_location:
                 raise ValueError(
-                    "Project and location are required for VertexAI either via project and location or GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION environment variables"
+                    "Project and location are required for VertexAI either via project and location or GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION environment variables"  # noqa: E501
                 )
-            self._api_key = None  # VertexAI does not require an API key
+            gemini_api_key = None  # VertexAI does not require an API key
 
         else:
-            self._project = None
-            self._location = None
-            if not self._api_key:
+            gcp_project = None
+            gcp_location = None
+            if not gemini_api_key:
                 raise ValueError(
-                    "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"
+                    "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"  # noqa: E501
                 )
 
-        self._instructions = (
-            Content(parts=[Part(text=instructions)]) if instructions else None
-        )
-
-        self._rt_sessions: list[GeminiRealtimeSession] = []
-        self._opts = ModelOptions(
+        self._opts = _RealtimeOptions(
             model=model,
-            api_version=api_version,
-            api_key=self._api_key,
+            api_key=gemini_api_key,
             voice=voice,
-            enable_user_audio_transcription=enable_user_audio_transcription,
-            enable_agent_audio_transcription=enable_agent_audio_transcription,
             response_modalities=modalities,
             vertexai=vertexai,
-            project=self._project,
-            location=self._location,
+            project=gcp_project,
+            location=gcp_location,
             candidate_count=candidate_count,
             temperature=temperature,
             max_output_tokens=max_output_tokens,
@@ -206,387 +170,381 @@ class RealtimeModel:
             top_k=top_k,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
-            instructions=self._instructions,
+            instructions=instructions,
         )
 
-    @property
-    def sessions(self) -> list[GeminiRealtimeSession]:
-        return self._rt_sessions
+        self._sessions = weakref.WeakSet[RealtimeSession]()
 
-    @property
-    def capabilities(self) -> Capabilities:
-        return self._capabilities
+    def session(self) -> RealtimeSession:
+        sess = RealtimeSession(self)
+        self._sessions.add(sess)
+        return sess
 
-    def session(
-        self,
-        *,
-        chat_ctx: llm.ChatContext | None = None,
-        fnc_ctx: llm.FunctionContext | None = None,
-    ) -> GeminiRealtimeSession:
-        session = GeminiRealtimeSession(
-            opts=self._opts,
-            chat_ctx=chat_ctx or llm.ChatContext(),
-            fnc_ctx=fnc_ctx,
-            loop=self._loop,
-        )
-        self._rt_sessions.append(session)
-
-        return session
+    def update_options(
+        self, *, voice: NotGivenOr[str] = NOT_GIVEN, temperature: NotGivenOr[float] = NOT_GIVEN
+    ) -> None:
+        if is_given(voice):
+            self._opts.voice = voice
 
-    async def aclose(self) -> None:
-        for session in self._rt_sessions:
-            await session.aclose()
+        if is_given(temperature):
+            self._opts.temperature = temperature
 
+        for sess in self._sessions:
+            sess.update_options(voice=self._opts.voice, temperature=self._opts.temperature)
 
-class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
-    def __init__(
-        self,
-        *,
-        opts: ModelOptions,
-        chat_ctx: llm.ChatContext,
-        fnc_ctx: llm.FunctionContext | None,
-        loop: asyncio.AbstractEventLoop,
-    ):
-        """
-        Initializes a GeminiRealtimeSession instance for interacting with Google's Realtime API.
-
-        Args:
-            opts (ModelOptions): The model options for the session.
-            chat_ctx (llm.ChatContext): The chat context for the session.
-            fnc_ctx (llm.FunctionContext or None): The function context for the session.
-            loop (asyncio.AbstractEventLoop): The event loop for the session.
-        """
-        super().__init__()
-        self._loop = loop
-        self._opts = opts
-        self._chat_ctx = chat_ctx
-        self._fnc_ctx = fnc_ctx
-        self._fnc_tasks = utils.aio.TaskSet()
-        self._is_interrupted = False
-        self._playout_complete = asyncio.Event()
-        self._playout_complete.set()
-
-        tools = []
-        if self._fnc_ctx is not None:
-            functions = _build_tools(self._fnc_ctx)
-            tools.append(Tool(function_declarations=functions))
-
-        self._config = LiveConnectConfig(
-            response_modalities=self._opts.response_modalities,
-            generation_config=GenerationConfig(
-                candidate_count=self._opts.candidate_count,
-                temperature=self._opts.temperature,
-                max_output_tokens=self._opts.max_output_tokens,
-                top_p=self._opts.top_p,
-                top_k=self._opts.top_k,
-                presence_penalty=self._opts.presence_penalty,
-                frequency_penalty=self._opts.frequency_penalty,
-            ),
-            system_instruction=self._opts.instructions,
-            speech_config=SpeechConfig(
-                voice_config=VoiceConfig(
-                    prebuilt_voice_config=PrebuiltVoiceConfig(
-                        voice_name=self._opts.voice
-                    )
-                )
-            ),
-            tools=tools,
-        )
+    async def aclose(self) -> None: ...
+
+
+class RealtimeSession(llm.RealtimeSession):
+    def __init__(self, realtime_model: RealtimeModel) -> None:
+        super().__init__(realtime_model)
+        self._opts = realtime_model._opts
+        self._tools = llm.ToolContext.empty()
+        self._chat_ctx = llm.ChatContext.empty()
+        self._msg_ch = utils.aio.Chan[ClientEvents]()
+        self._gemini_tools: list[Tool] = []
         self._client = genai.Client(
-            http_options=HttpOptions(api_version=self._opts.api_version),
+            http_options=HttpOptions(api_version="v1alpha"),
             api_key=self._opts.api_key,
             vertexai=self._opts.vertexai,
             project=self._opts.project,
             location=self._opts.location,
         )
-        self._main_atask = asyncio.create_task(
-            self._main_task(), name="gemini-realtime-session"
-        )
-        if self._opts.enable_user_audio_transcription:
-            self._transcriber = TranscriberSession(
-                client=self._client, model=self._opts.model
-            )
-            self._transcriber.on("input_speech_done", self._on_input_speech_done)
-        if self._opts.enable_agent_audio_transcription:
-            self._agent_transcriber = ModelTranscriber(
-                client=self._client, model=self._opts.model
-            )
-            self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done)
-        # init dummy task
-        self._init_sync_task = asyncio.create_task(asyncio.sleep(0))
-        self._send_ch = utils.aio.Chan[ClientEvents]()
-        self._active_response_id = None
+        self._main_atask = asyncio.create_task(self._main_task(), name="gemini-realtime-session")
 
-    async def aclose(self) -> None:
-        if self._send_ch.closed:
-            return
+        self._current_generation: _ResponseGeneration | None = None
 
-        self._send_ch.close()
-        await self._main_atask
+        self._is_interrupted = False
+        self._active_response_id = None
+        self._session = None
+        self._update_chat_ctx_lock = asyncio.Lock()
+        self._update_fnc_ctx_lock = asyncio.Lock()
+        self._response_created_futures: dict[str, asyncio.Future[llm.GenerationCreatedEvent]] = {}
+        self._pending_generation_event_id = None
+
+        self._reconnect_event = asyncio.Event()
+        self._session_lock = asyncio.Lock()
+        self._gemini_close_task: asyncio.Task | None = None
+
+    def _schedule_gemini_session_close(self) -> None:
+        if self._session is not None:
+            self._gemini_close_task = asyncio.create_task(self._close_gemini_session())
+
+    async def _close_gemini_session(self) -> None:
+        async with self._session_lock:
+            if self._session:
+                try:
+                    await self._session.close()
+                finally:
+                    self._session = None
+
+    def update_options(
+        self,
+        *,
+        voice: NotGivenOr[str] = NOT_GIVEN,
+        tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+    ) -> None:
+        if is_given(voice):
+            self._opts.voice = voice
+
+        if is_given(temperature):
+            self._opts.temperature = temperature
+
+        if self._session:
+            logger.warning("Updating options; triggering Gemini session reconnect.")
+            self._reconnect_event.set()
+            self._schedule_gemini_session_close()
+
+    async def update_instructions(self, instructions: str) -> None:
+        self._opts.instructions = instructions
+        if self._session:
+            logger.warning("Updating instructions; triggering Gemini session reconnect.")
+            self._reconnect_event.set()
+            self._schedule_gemini_session_close()
+
+    async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
+        async with self._update_chat_ctx_lock:
+            self._chat_ctx = chat_ctx
+            turns, _ = to_chat_ctx(self._chat_ctx, id(self), ignore_functions=True)
+            tool_results = get_tool_results_for_realtime(self._chat_ctx)
+            if turns:
+                self._msg_ch.send_nowait(LiveClientContent(turns=turns, turn_complete=False))
+            if tool_results:
+                self._msg_ch.send_nowait(tool_results)
+
+    async def update_tools(self, tools: list[llm.FunctionTool]) -> None:
+        async with self._update_fnc_ctx_lock:
+            retained_tools: list[llm.FunctionTool] = []
+            gemini_function_declarations: list[FunctionDeclaration] = []
+
+            for tool in tools:
+                gemini_function = _build_gemini_fnc(tool)
+                gemini_function_declarations.append(gemini_function)
+                retained_tools.append(tool)
+
+            self._tools = llm.ToolContext(retained_tools)
+            self._gemini_tools = [Tool(function_declarations=gemini_function_declarations)]
+            if self._session and gemini_function_declarations:
+                logger.warning("Updating tools; triggering Gemini session reconnect.")
+                self._reconnect_event.set()
+                self._schedule_gemini_session_close()
 
     @property
-    def playout_complete(self) -> asyncio.Event | None:
-        return self._playout_complete
+    def chat_ctx(self) -> llm.ChatContext:
+        return self._chat_ctx
 
     @property
-    def fnc_ctx(self) -> llm.FunctionContext | None:
-        return self._fnc_ctx
-
-    @fnc_ctx.setter
-    def fnc_ctx(self, value: llm.FunctionContext | None) -> None:
-        self._fnc_ctx = value
+    def tools(self) -> llm.ToolContext:
+        return self._tools
 
-    def _push_media_chunk(self, data: bytes, mime_type: str) -> None:
+    def push_audio(self, frame: rtc.AudioFrame) -> None:
         realtime_input = LiveClientRealtimeInput(
-            media_chunks=[Blob(data=data, mime_type=mime_type)],
+            media_chunks=[Blob(data=frame.data.tobytes(), mime_type="audio/pcm")],
         )
-        self._queue_msg(realtime_input)
+        self._msg_ch.send_nowait(realtime_input)
 
-    DEFAULT_ENCODE_OPTIONS = images.EncodeOptions(
-        format="JPEG",
-        quality=75,
-        resize_options=images.ResizeOptions(
-            width=1024, height=1024, strategy="scale_aspect_fit"
-        ),
-    )
+    def generate_reply(
+        self, *, instructions: NotGivenOr[str] = NOT_GIVEN
+    ) -> asyncio.Future[llm.GenerationCreatedEvent]:
+        fut = asyncio.Future()
 
-    def push_video(
-        self,
-        frame: rtc.VideoFrame,
-        encode_options: images.EncodeOptions = DEFAULT_ENCODE_OPTIONS,
-    ) -> None:
-        """Push a video frame to the Gemini Multimodal Live session.
+        event_id = utils.shortuuid("gemini-response-")
+        self._response_created_futures[event_id] = fut
+        self._pending_generation_event_id = event_id
 
-        Args:
-            frame: The video frame to push.
-            encode_options: The encode options for the video frame. Defaults to 1024x1024 JPEG.
-        """
-        encoded_data = images.encode(
-            frame,
-            encode_options,
-        )
-        mime_type = (
-            "image/jpeg"
-            if encode_options.format == "JPEG"
-            else "image/png"
-            if encode_options.format == "PNG"
-            else "image/jpeg"
-        )
-        self._push_media_chunk(encoded_data, mime_type)
+        instructions_content = instructions if is_given(instructions) else "."
+        ctx = [Content(parts=[Part(text=instructions_content)], role="user")]
+        self._msg_ch.send_nowait(LiveClientContent(turns=ctx, turn_complete=True))
 
-    def _push_audio(self, frame: rtc.AudioFrame) -> None:
-        if self._opts.enable_user_audio_transcription:
-            self._transcriber._push_audio(frame)
+        def _on_timeout() -> None:
+            if event_id in self._response_created_futures and not fut.done():
+                fut.set_exception(llm.RealtimeError("generate_reply timed out."))
+                self._response_created_futures.pop(event_id, None)
+                if self._pending_generation_event_id == event_id:
+                    self._pending_generation_event_id = None
 
-        self._push_media_chunk(frame.data.tobytes(), "audio/pcm")
+        handle = asyncio.get_event_loop().call_later(5.0, _on_timeout)
+        fut.add_done_callback(lambda _: handle.cancel())
 
-    def _queue_msg(self, msg: ClientEvents) -> None:
-        self._send_ch.send_nowait(msg)
+        return fut
 
-    def chat_ctx_copy(self) -> llm.ChatContext:
-        return self._chat_ctx.copy()
+    def interrupt(self) -> None:
+        logger.warning("interrupt() - no direct cancellation in Gemini")
 
-    async def set_chat_ctx(self, ctx: llm.ChatContext) -> None:
-        self._chat_ctx = ctx.copy()
+    def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
+        logger.warning(f"truncate(...) called for {message_id}, ignoring for Gemini")
 
-    def cancel_response(self) -> None:
-        raise NotImplementedError("cancel_response is not supported yet")
+    async def aclose(self) -> None:
+        self._msg_ch.close()
 
-    def create_response(
-        self,
-        on_duplicate: Literal[
-            "cancel_existing", "cancel_new", "keep_both"
-        ] = "keep_both",
-    ) -> None:
-        turns, _ = _build_gemini_ctx(self._chat_ctx, id(self))
-        ctx = [self._opts.instructions] + turns if self._opts.instructions else turns
+        for fut in self._response_created_futures.values():
+            if not fut.done():
+                fut.set_exception(llm.RealtimeError("Session closed"))
 
-        if not ctx:
-            logger.warning(
-                "gemini-realtime-session: No chat context to send, sending dummy content."
-            )
-            ctx = [Content(parts=[Part(text=".")])]
+        if self._main_atask:
+            await utils.aio.cancel_and_wait(self._main_atask)
 
-        self._queue_msg(LiveClientContent(turns=ctx, turn_complete=True))
+        if self._gemini_close_task:
+            await utils.aio.cancel_and_wait(self._gemini_close_task)
 
-    def commit_audio_buffer(self) -> None:
-        raise NotImplementedError("commit_audio_buffer is not supported yet")
-
-    def server_vad_enabled(self) -> bool:
-        return True
-
-    def _on_input_speech_done(self, content: TranscriptionContent) -> None:
-        if content.response_id and content.text:
-            self.emit(
-                "input_speech_transcription_completed",
-                InputTranscription(
-                    item_id=content.response_id,
-                    transcript=content.text,
-                ),
-            )
-
-            # self._chat_ctx.append(text=content.text, role="user")
-
-    def _on_agent_speech_done(self, content: TranscriptionContent) -> None:
-        if not self._is_interrupted and content.response_id and content.text:
-            self.emit(
-                "agent_speech_transcription_completed",
-                InputTranscription(
-                    item_id=content.response_id,
-                    transcript=content.text,
-                ),
-            )
-
-    @utils.log_exceptions(logger=logger)
-    async def _main_task(self):
-        @utils.log_exceptions(logger=logger)
-        async def _send_task():
-            async for msg in self._send_ch:
-                await self._session.send(input=msg)
-
-            await self._session.send(input=".", end_of_turn=True)
-
-        @utils.log_exceptions(logger=logger)
-        async def _recv_task():
-            while True:
-                async for response in self._session.receive():
-                    if self._active_response_id is None:
-                        self._is_interrupted = False
-                        self._active_response_id = utils.shortuuid()
-                        text_stream = utils.aio.Chan[str]()
-                        audio_stream = utils.aio.Chan[rtc.AudioFrame]()
-                        content = GeminiContent(
-                            response_id=self._active_response_id,
-                            item_id=self._active_response_id,
-                            output_index=0,
-                            content_index=0,
-                            text="",
-                            audio=[],
-                            text_stream=text_stream,
-                            audio_stream=audio_stream,
-                            content_type="audio",
-                        )
-                        self.emit("response_content_added", content)
-
-                    server_content = response.server_content
-                    if server_content:
-                        model_turn = server_content.model_turn
-                        if model_turn:
-                            for part in model_turn.parts:
-                                if part.text:
-                                    content.text_stream.send_nowait(part.text)
-                                if part.inline_data:
-                                    frame = rtc.AudioFrame(
-                                        data=part.inline_data.data,
-                                        sample_rate=24000,
-                                        num_channels=1,
-                                        samples_per_channel=len(part.inline_data.data)
-                                        // 2,
-                                    )
-                                    if self._opts.enable_agent_audio_transcription:
-                                        content.audio.append(frame)
-                                    content.audio_stream.send_nowait(frame)
-
-                        if server_content.interrupted or server_content.turn_complete:
-                            if self._opts.enable_agent_audio_transcription:
-                                self._agent_transcriber._push_audio(content.audio)
-                            for stream in (content.text_stream, content.audio_stream):
-                                if isinstance(stream, utils.aio.Chan):
-                                    stream.close()
-
-                            self.emit("agent_speech_stopped")
-                            self._is_interrupted = True
-
-                            self._active_response_id = None
-
-                    if response.tool_call:
-                        if self._fnc_ctx is None:
-                            raise ValueError("Function context is not set")
-                        fnc_calls = []
-                        for fnc_call in response.tool_call.function_calls:
-                            fnc_call_info = _create_ai_function_info(
-                                self._fnc_ctx,
-                                fnc_call.id,
-                                fnc_call.name,
-                                json.dumps(fnc_call.args),
-                            )
-                            fnc_calls.append(fnc_call_info)
-
-                        self.emit("function_calls_collected", fnc_calls)
-
-                        for fnc_call_info in fnc_calls:
-                            self._fnc_tasks.create_task(
-                                self._run_fnc_task(fnc_call_info, content.item_id)
-                            )
-
-                    # Handle function call cancellations
-                    if response.tool_call_cancellation:
-                        logger.warning(
-                            "function call cancelled",
-                            extra={
-                                "function_call_ids": response.tool_call_cancellation.ids,
-                            },
-                        )
-                        self.emit(
-                            "function_calls_cancelled",
-                            response.tool_call_cancellation.ids,
-                        )
-
-        async with self._client.aio.live.connect(
-            model=self._opts.model, config=self._config
-        ) as session:
-            self._session = session
-            tasks = [
-                asyncio.create_task(_send_task(), name="gemini-realtime-send"),
-                asyncio.create_task(_recv_task(), name="gemini-realtime-recv"),
-            ]
-
-            try:
-                await asyncio.gather(*tasks)
-            finally:
-                await utils.aio.gracefully_cancel(*tasks)
-                await self._session.close()
-                if self._opts.enable_user_audio_transcription:
-                    await self._transcriber.aclose()
-                if self._opts.enable_agent_audio_transcription:
-                    await self._agent_transcriber.aclose()
+    @utils.log_exceptions(logger=logger)
+    async def _main_task(self):
+        while True:
+            config = LiveConnectConfig(
+                response_modalities=self._opts.response_modalities
+                if is_given(self._opts.response_modalities)
+                else [Modality.AUDIO],
+                generation_config=GenerationConfig(
+                    candidate_count=self._opts.candidate_count,
+                    temperature=self._opts.temperature
+                    if is_given(self._opts.temperature)
+                    else None,
+                    max_output_tokens=self._opts.max_output_tokens
+                    if is_given(self._opts.max_output_tokens)
+                    else None,
+                    top_p=self._opts.top_p if is_given(self._opts.top_p) else None,
+                    top_k=self._opts.top_k if is_given(self._opts.top_k) else None,
+                    presence_penalty=self._opts.presence_penalty
+                    if is_given(self._opts.presence_penalty)
+                    else None,
+                    frequency_penalty=self._opts.frequency_penalty
+                    if is_given(self._opts.frequency_penalty)
+                    else None,
+                ),
+                system_instruction=Content(parts=[Part(text=self._opts.instructions)])
+                if is_given(self._opts.instructions)
+                else None,
+                speech_config=SpeechConfig(
+                    voice_config=VoiceConfig(
+                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
+                    )
+                ),
+                tools=self._gemini_tools,
+            )
+
+            async with self._client.aio.live.connect(
+                model=self._opts.model, config=config
+            ) as session:
+                async with self._session_lock:
+                    self._session = session
+
+                @utils.log_exceptions(logger=logger)
+                async def _send_task():
+                    async for msg in self._msg_ch:
+                        if isinstance(msg, LiveClientContent):
+                            await session.send(input=msg, end_of_turn=True)
+                            continue
+                        await session.send(input=msg)
+                    await session.send(input=".", end_of_turn=True)
+
+                @utils.log_exceptions(logger=logger)
+                async def _recv_task():
+                    while True:
+                        async for response in session.receive():
+                            if self._active_response_id is None:
+                                self._start_new_generation()
+                            if response.server_content:
+                                self._handle_server_content(response.server_content)
+                            if response.tool_call:
+                                self._handle_tool_calls(response.tool_call)
+                            if response.tool_call_cancellation:
+                                self._handle_tool_call_cancellation(response.tool_call_cancellation)
+
+                send_task = asyncio.create_task(_send_task(), name="gemini-realtime-send")
+                recv_task = asyncio.create_task(_recv_task(), name="gemini-realtime-recv")
+                reconnect_task = asyncio.create_task(
+                    self._reconnect_event.wait(), name="reconnect-wait"
+                )
+
+                try:
+                    done, _ = await asyncio.wait(
+                        [send_task, recv_task, reconnect_task],
+                        return_when=asyncio.FIRST_COMPLETED,
+                    )
+                    for task in done:
+                        if task != reconnect_task:
+                            task.result()
+
+                    if reconnect_task not in done:
+                        break
+
+                    self._reconnect_event.clear()
+                finally:
+                    await utils.aio.cancel_and_wait(send_task, recv_task, reconnect_task)
+
+    def _start_new_generation(self):
+        self._is_interrupted = False
+        self._active_response_id = utils.shortuuid("gemini-turn-")
+        self._current_generation = _ResponseGeneration(
+            message_ch=utils.aio.Chan[llm.MessageGeneration](),
+            function_ch=utils.aio.Chan[llm.FunctionCall](),
+            messages={},
+        )
+
+        # We'll assume each chunk belongs to a single message ID self._active_response_id
+        item_generation = _MessageGeneration(
+            message_id=self._active_response_id,
+            text_ch=utils.aio.Chan[str](),
+            audio_ch=utils.aio.Chan[rtc.AudioFrame](),
+        )
+
+        self._current_generation.message_ch.send_nowait(
+            llm.MessageGeneration(
+                message_id=self._active_response_id,
+                text_stream=item_generation.text_ch,
+                audio_stream=item_generation.audio_ch,
+            )
+        )
+
+        generation_event = llm.GenerationCreatedEvent(
+            message_stream=self._current_generation.message_ch,
+            function_stream=self._current_generation.function_ch,
+            user_initiated=False,
+        )
+
+        # Resolve any pending future from generate_reply()
+        if self._pending_generation_event_id and (
+            fut := self._response_created_futures.pop(self._pending_generation_event_id, None)
+        ):
+            fut.set_result(generation_event)
+
+        self._pending_generation_event_id = None
+        self.emit("generation_created", generation_event)
+
+        self._current_generation.messages[self._active_response_id] = item_generation
+
+    def _handle_server_content(self, server_content: LiveServerContent):
+        if not self._current_generation or not self._active_response_id:
+            logger.warning(
+                "gemini-realtime-session: No active response ID, skipping server content"
+            )
+            return
+
+        item_generation = self._current_generation.messages[self._active_response_id]
+
+        model_turn = server_content.model_turn
+        if model_turn:
+            for part in model_turn.parts:
+                if part.text:
+                    item_generation.text_ch.send_nowait(part.text)
+                if part.inline_data:
+                    frame_data = part.inline_data.data
+                    frame = rtc.AudioFrame(
+                        data=frame_data,
+                        sample_rate=OUTPUT_AUDIO_SAMPLE_RATE,
+                        num_channels=NUM_CHANNELS,
+                        samples_per_channel=len(frame_data) // 2,
+                    )
+                    item_generation.audio_ch.send_nowait(frame)
+
+        if server_content.interrupted or server_content.turn_complete:
+            self._finalize_response()
+
+    def _finalize_response(self) -> None:
+        if not self._current_generation:
+            return
+
+        for item_generation in self._current_generation.messages.values():
+            item_generation.text_ch.close()
+            item_generation.audio_ch.close()
+
+        self._current_generation.function_ch.close()
+        self._current_generation.message_ch.close()
+        self._current_generation = None
+        self._is_interrupted = True
+        self._active_response_id = None
+        self.emit("agent_speech_stopped")
+
+    def _handle_tool_calls(self, tool_call: LiveServerToolCall):
+        if not self._current_generation:
+            return
+        for fnc_call in tool_call.function_calls:
+            self._current_generation.function_ch.send_nowait(
+                llm.FunctionCall(
+                    call_id=fnc_call.id,
+                    name=fnc_call.name,
+                    arguments=json.dumps(fnc_call.args),
+                )
+            )
+        self._finalize_response()
 
-    @utils.log_exceptions(logger=logger)
-    async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str):
-        logger.debug(
-            "executing ai function",
+    def _handle_tool_call_cancellation(
+        self, tool_call_cancellation: LiveServerToolCallCancellation
+    ):
+        logger.warning(
+            "function call cancelled",
             extra={
-                "function": fnc_call_info.function_info.name,
+                "function_call_ids": tool_call_cancellation.ids,
             },
         )
+        self.emit("function_calls_cancelled", tool_call_cancellation.ids)
 
-        called_fnc = fnc_call_info.execute()
-        try:
-            await called_fnc.task
-        except Exception as e:
-            logger.exception(
-                "error executing ai function",
-                extra={
-                    "function": fnc_call_info.function_info.name,
-                },
-                exc_info=e,
-            )
-        tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc)
-        if tool_call.content is not None:
-            tool_response = LiveClientToolResponse(
-                function_responses=[
-                    FunctionResponse(
-                        name=tool_call.name,
-                        id=tool_call.tool_call_id,
-                        response={"result": tool_call.content},
-                    )
-                ]
-            )
-            await self._session.send(input=tool_response)
+    def commit_audio(self) -> None:
+        raise NotImplementedError("commit_audio_buffer is not supported yet")
 
-        self.emit("function_calls_finished", [called_fnc])
+    def clear_audio(self) -> None:
+        raise NotImplementedError("clear_audio is not supported yet")
+
+    def server_vad_enabled(self) -> bool:
+        return True