livekit-plugins-google 1.0.0rc8__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
- livekit/plugins/google/beta/realtime/api_proto.py +0 -3
- livekit/plugins/google/beta/realtime/realtime_api.py +400 -419
- livekit/plugins/google/stt.py +6 -12
- livekit/plugins/google/tts.py +4 -7
- livekit/plugins/google/utils.py +21 -3
- livekit/plugins/google/version.py +1 -1
- {livekit_plugins_google-1.0.0rc8.dist-info → livekit_plugins_google-1.0.1.dist-info}/METADATA +2 -2
- livekit_plugins_google-1.0.1.dist-info/RECORD +16 -0
- livekit/plugins/google/beta/realtime/temp.py +0 -10
- livekit/plugins/google/beta/realtime/transcriber.py +0 -254
- livekit_plugins_google-1.0.0rc8.dist-info/RECORD +0 -18
- {livekit_plugins_google-1.0.0rc8.dist-info → livekit_plugins_google-1.0.1.dist-info}/WHEEL +0 -0
@@ -3,21 +3,22 @@ from __future__ import annotations
 import asyncio
 import json
 import os
-
+import weakref
 from dataclasses import dataclass
-from typing import Literal
 
 from google import genai
 from google.genai._api_client import HttpOptions
 from google.genai.types import (
     Blob,
     Content,
-    FunctionResponse,
+    FunctionDeclaration,
     GenerationConfig,
     LiveClientContent,
     LiveClientRealtimeInput,
-    LiveClientToolResponse,
     LiveConnectConfig,
+    LiveServerContent,
+    LiveServerToolCall,
+    LiveServerToolCallCancellation,
     Modality,
     Part,
     PrebuiltVoiceConfig,
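The import block above carries the core of the 1.0.1 migration: optional parameters move from `None` defaults to the `NOT_GIVEN` sentinel imported from `livekit.agents.types`. The sketch below illustrates the general sentinel pattern with stand-in definitions, not the livekit-agents source:

```python
# Illustrative stand-ins mirroring the NotGivenOr/NOT_GIVEN/is_given names.
import os
from typing import TypeVar, Union

T = TypeVar("T")

class NotGiven:
    def __bool__(self) -> bool:
        return False  # NOT_GIVEN is falsy, like None

    def __repr__(self) -> str:
        return "NOT_GIVEN"

NOT_GIVEN = NotGiven()
NotGivenOr = Union[T, NotGiven]

def is_given(value: "NotGivenOr[T]") -> bool:
    return not isinstance(value, NotGiven)

def resolve_api_key(api_key: "NotGivenOr[str]" = NOT_GIVEN) -> str | None:
    # Unlike `api_key: str | None = None`, the sentinel distinguishes
    # "caller omitted the argument" from "caller explicitly passed None".
    return api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
```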
@@ -27,44 +28,16 @@ from google.genai.types import (
 )
 from livekit import rtc
 from livekit.agents import llm, utils
-from livekit.agents.
+from livekit.agents.types import NOT_GIVEN, NotGivenOr
+from livekit.agents.utils import is_given
 
 from ...log import logger
-from
-    ClientEvents,
-    LiveAPIModels,
-    Voice,
-)
-
-# temporary placeholders
-from .temp import _build_gemini_ctx, _build_tools, _create_ai_function_info
-from .transcriber import ModelTranscriber, TranscriberSession, TranscriptionContent
+from ...utils import _build_gemini_fnc, get_tool_results_for_realtime, to_chat_ctx
+from .api_proto import ClientEvents, LiveAPIModels, Voice
 
-
-
-EventTypes = Literal[
-    "response_content_added",
-    "response_content_done",
-    "function_calls_collected",
-    "function_calls_finished",
-    "function_calls_cancelled",
-    "input_speech_transcription_completed",
-    "agent_speech_transcription_completed",
-    "agent_speech_stopped",
-]
-
-
-@dataclass
-class GeminiContent:
-    response_id: str
-    item_id: str
-    output_index: int
-    content_index: int
-    text: str
-    audio: list[rtc.AudioFrame]
-    text_stream: AsyncIterable[str]
-    audio_stream: AsyncIterable[rtc.AudioFrame]
-    content_type: Literal["text", "audio"]
+INPUT_AUDIO_SAMPLE_RATE = 16000
+OUTPUT_AUDIO_SAMPLE_RATE = 24000
+NUM_CHANNELS = 1
 
 
 @dataclass
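The new module constants replace the removed `EventTypes`/`GeminiContent` plumbing and pin down the PCM format assumed later when output frames are built (`samples_per_channel=len(frame_data) // 2`). A small self-contained sketch of that arithmetic, assuming 16-bit mono PCM as the frame math implies:

```python
# Sizing arithmetic implied by the new constants (16-bit mono PCM assumed).
OUTPUT_AUDIO_SAMPLE_RATE = 24000  # Hz, Gemini Live output audio
NUM_CHANNELS = 1
BYTES_PER_SAMPLE = 2  # int16 PCM

def samples_per_channel(pcm: bytes) -> int:
    # Matches the frame construction in _handle_server_content: len(data) // 2
    return len(pcm) // (BYTES_PER_SAMPLE * NUM_CHANNELS)

chunk = bytes(2 * 480)  # 480 int16 samples of silence
assert samples_per_channel(chunk) == 480
# 480 samples at 24 kHz is a 20 ms frame:
assert samples_per_channel(chunk) / OUTPUT_AUDIO_SAMPLE_RATE == 0.02
```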
@@ -74,55 +47,59 @@ class InputTranscription:
 
 
 @dataclass
-class
-    supports_truncate: bool
-    input_audio_sample_rate: int | None = None
-
-
-@dataclass
-class ModelOptions:
+class _RealtimeOptions:
     model: LiveAPIModels | str
     api_key: str | None
     voice: Voice | str
-    response_modalities: list[Modality]
+    response_modalities: NotGivenOr[list[Modality]]
     vertexai: bool
     project: str | None
     location: str | None
     candidate_count: int
-    temperature: float
-    max_output_tokens: int
-    top_p: float
-    top_k: int
-    presence_penalty: float
-    frequency_penalty: float
-    instructions:
-    enable_user_audio_transcription: bool
-    enable_agent_audio_transcription: bool
-
-
-class RealtimeModel:
+    temperature: NotGivenOr[float]
+    max_output_tokens: NotGivenOr[int]
+    top_p: NotGivenOr[float]
+    top_k: NotGivenOr[int]
+    presence_penalty: NotGivenOr[float]
+    frequency_penalty: NotGivenOr[float]
+    instructions: NotGivenOr[str]
+
+
+@dataclass
+class _MessageGeneration:
+    message_id: str
+    text_ch: utils.aio.Chan[str]
+    audio_ch: utils.aio.Chan[rtc.AudioFrame]
+
+
+@dataclass
+class _ResponseGeneration:
+    message_ch: utils.aio.Chan[llm.MessageGeneration]
+    function_ch: utils.aio.Chan[llm.FunctionCall]
+
+    messages: dict[str, _MessageGeneration]
+
+
+class RealtimeModel(llm.RealtimeModel):
     def __init__(
         self,
         *,
-        instructions: str
+        instructions: NotGivenOr[str] = NOT_GIVEN,
         model: LiveAPIModels | str = "gemini-2.0-flash-exp",
-        api_key: str
+        api_key: NotGivenOr[str] = NOT_GIVEN,
         voice: Voice | str = "Puck",
-        modalities: list[Modality] =
-        enable_user_audio_transcription: bool = True,
-        enable_agent_audio_transcription: bool = True,
+        modalities: NotGivenOr[list[Modality]] = NOT_GIVEN,
         vertexai: bool = False,
-        project: str
-        location: str
+        project: NotGivenOr[str] = NOT_GIVEN,
+        location: NotGivenOr[str] = NOT_GIVEN,
         candidate_count: int = 1,
-        temperature: float
-        max_output_tokens: int
-        top_p: float
-        top_k: int
-        presence_penalty: float
-        frequency_penalty: float
-        loop: asyncio.AbstractEventLoop | None = None,
-    ):
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+        max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
+        top_p: NotGivenOr[float] = NOT_GIVEN,
+        top_k: NotGivenOr[int] = NOT_GIVEN,
+        presence_penalty: NotGivenOr[float] = NOT_GIVEN,
+        frequency_penalty: NotGivenOr[float] = NOT_GIVEN,
+    ) -> None:
         """
         Initializes a RealtimeModel instance for interacting with Google's Realtime API.
 
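`_MessageGeneration` and `_ResponseGeneration` organize a turn as a set of channels that consumers drain asynchronously. The sketch below mimics that flow with a minimal stand-in for `utils.aio.Chan` (only `send_nowait`, `close`, and async iteration, the operations this diff relies on); it is illustrative, not the library implementation:

```python
import asyncio

_CLOSED = object()

class Chan:
    """Minimal channel stand-in: producers push, consumers async-iterate."""

    def __init__(self) -> None:
        self._q: asyncio.Queue = asyncio.Queue()

    def send_nowait(self, item) -> None:
        self._q.put_nowait(item)

    def close(self) -> None:
        self._q.put_nowait(_CLOSED)

    def __aiter__(self):
        return self

    async def __anext__(self):
        item = await self._q.get()
        if item is _CLOSED:
            raise StopAsyncIteration
        return item

async def main() -> None:
    text_ch = Chan()  # plays the role of _MessageGeneration.text_ch
    # Producer side (_handle_server_content): push partial text, then
    # _finalize_response closes the channel when the turn completes.
    for token in ("Hello", ", ", "world"):
        text_ch.send_nowait(token)
    text_ch.close()
    # Consumer side: drain the stream until it closes.
    print("".join([tok async for tok in text_ch]))  # -> Hello, world

asyncio.run(main())
```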
@@ -135,66 +112,57 @@ class RealtimeModel:
 
         Args:
             instructions (str, optional): Initial system instructions for the model. Defaults to "".
-            api_key (str
+            api_key (str, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
             modalities (list[Modality], optional): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
-            model (str
+            model (str, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
             voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
-            enable_user_audio_transcription (bool, optional): Whether to enable user audio transcription. Defaults to True
-            enable_agent_audio_transcription (bool, optional): Whether to enable agent audio transcription. Defaults to True
             temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
             vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
-            project (str
-            location (str
+            project (str, optional): The project id to use for the API. Defaults to None. (for vertexai)
+            location (str, optional): The location to use for the API. Defaults to None. (for vertexai)
             candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1.
             top_p (float, optional): The top-p value for response generation
             top_k (int, optional): The top-k value for response generation
             presence_penalty (float, optional): The presence penalty for response generation
             frequency_penalty (float, optional): The frequency penalty for response generation
-            loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used.
 
         Raises:
-            ValueError: If the API key is
+            ValueError: If the API key is required but not found.
         """  # noqa: E501
-
-
-
-
-
-
+        super().__init__(
+            capabilities=llm.RealtimeCapabilities(
+                message_truncation=False,
+                turn_detection=True,
+                user_transcription=False,
+            )
         )
-
-
-
-
-        self._location = location or os.environ.get("GOOGLE_CLOUD_LOCATION")
+
+        gemini_api_key = api_key if is_given(api_key) else os.environ.get("GOOGLE_API_KEY")
+        gcp_project = project if is_given(project) else os.environ.get("GOOGLE_CLOUD_PROJECT")
+        gcp_location = location if is_given(location) else os.environ.get("GOOGLE_CLOUD_LOCATION")
         if vertexai:
-            if not
+            if not gcp_project or not gcp_location:
                 raise ValueError(
                     "Project and location are required for VertexAI either via project and location or GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION environment variables"  # noqa: E501
                 )
-
+            gemini_api_key = None  # VertexAI does not require an API key
 
         else:
-
-
-            if not
+            gcp_project = None
+            gcp_location = None
+            if not gemini_api_key:
                 raise ValueError(
                     "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"  # noqa: E501
                 )
 
-
-
-        self._rt_sessions: list[GeminiRealtimeSession] = []
-        self._opts = ModelOptions(
+        self._opts = _RealtimeOptions(
             model=model,
-            api_key=
+            api_key=gemini_api_key,
             voice=voice,
-            enable_user_audio_transcription=enable_user_audio_transcription,
-            enable_agent_audio_transcription=enable_agent_audio_transcription,
             response_modalities=modalities,
             vertexai=vertexai,
-            project=
-            location=
+            project=gcp_project,
+            location=gcp_location,
             candidate_count=candidate_count,
             temperature=temperature,
             max_output_tokens=max_output_tokens,
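For reference, the reworked constructor resolves credentials in two mutually exclusive ways, as the code above shows. A hypothetical usage sketch (the import path follows the plugin's package layout; all values are placeholders):

```python
from livekit.plugins import google

# Gemini API: the key comes from api_key= or the GOOGLE_API_KEY env var.
model = google.beta.realtime.RealtimeModel(
    instructions="You are a helpful assistant.",
    voice="Puck",
    temperature=0.8,
)

# Vertex AI: project and location are required (as arguments or via the
# GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION env vars); the API key is
# forced to None because Vertex AI does not use one.
vertex_model = google.beta.realtime.RealtimeModel(
    vertexai=True,
    project="my-gcp-project",  # hypothetical project id
    location="us-central1",
)
```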
@@ -202,88 +170,39 @@ class RealtimeModel:
             top_k=top_k,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
-            instructions=
+            instructions=instructions,
         )
 
-    @property
-    def sessions(self) -> list[GeminiRealtimeSession]:
-        return self._rt_sessions
-
-    @property
-    def capabilities(self) -> Capabilities:
-        return self._capabilities
+        self._sessions = weakref.WeakSet[RealtimeSession]()
 
-    def session(
-        self
-
-
-        fnc_ctx: llm.FunctionContext | None = None,
-    ) -> GeminiRealtimeSession:
-        session = GeminiRealtimeSession(
-            opts=self._opts,
-            chat_ctx=chat_ctx or llm.ChatContext(),
-            fnc_ctx=fnc_ctx,
-            loop=self._loop,
-        )
-        self._rt_sessions.append(session)
+    def session(self) -> RealtimeSession:
+        sess = RealtimeSession(self)
+        self._sessions.add(sess)
+        return sess
 
-        return session
+    def update_options(
+        self, *, voice: NotGivenOr[str] = NOT_GIVEN, temperature: NotGivenOr[float] = NOT_GIVEN
+    ) -> None:
+        if is_given(voice):
+            self._opts.voice = voice
 
-    async def aclose(self) -> None:
-        for session in self._rt_sessions:
-            await session.aclose()
+        if is_given(temperature):
+            self._opts.temperature = temperature
 
+        for sess in self._sessions:
+            sess.update_options(voice=self._opts.voice, temperature=self._opts.temperature)
 
-class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
-    def __init__(
-        self,
-        *,
-        opts: ModelOptions,
-        chat_ctx: llm.ChatContext,
-        fnc_ctx: llm.FunctionContext | None,
-        loop: asyncio.AbstractEventLoop,
-    ):
-        """
-        Initializes a GeminiRealtimeSession instance for interacting with Google's Realtime API.
+    async def aclose(self) -> None: ...
 
-        Args:
-            opts (ModelOptions): The model options for the session.
-            chat_ctx (llm.ChatContext): The chat context for the session.
-            fnc_ctx (llm.FunctionContext or None): The function context for the session.
-            loop (asyncio.AbstractEventLoop): The event loop for the session.
-        """
-        super().__init__()
-        self._loop = loop
-        self._opts = opts
-        self._chat_ctx = chat_ctx
-        self._fnc_ctx = fnc_ctx
-        self._fnc_tasks = utils.aio.TaskSet()
-        self._is_interrupted = False
 
-
-
-
-
-
-        self.
-
-
-                candidate_count=self._opts.candidate_count,
-                temperature=self._opts.temperature,
-                max_output_tokens=self._opts.max_output_tokens,
-                top_p=self._opts.top_p,
-                top_k=self._opts.top_k,
-                presence_penalty=self._opts.presence_penalty,
-                frequency_penalty=self._opts.frequency_penalty,
-            ),
-            system_instruction=self._opts.instructions,
-            speech_config=SpeechConfig(
-                voice_config=VoiceConfig(
-                    prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
-                )
-            ),
-            tools=tools,
-        )
+class RealtimeSession(llm.RealtimeSession):
+    def __init__(self, realtime_model: RealtimeModel) -> None:
+        super().__init__(realtime_model)
+        self._opts = realtime_model._opts
+        self._tools = llm.ToolContext.empty()
+        self._chat_ctx = llm.ChatContext.empty()
+        self._msg_ch = utils.aio.Chan[ClientEvents]()
+        self._gemini_tools: list[Tool] = []
         self._client = genai.Client(
             http_options=HttpOptions(api_version="v1alpha"),
             api_key=self._opts.api_key,
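`RealtimeModel` now tracks its sessions in a `weakref.WeakSet`, so `update_options` can fan out to live sessions without keeping dropped ones alive. A reduced sketch of that pattern (stand-in classes, not the plugin's):

```python
import weakref

class Session:
    def __init__(self) -> None:
        self.voice = "Puck"

    def update_options(self, *, voice: str) -> None:
        self.voice = voice

class Model:
    def __init__(self) -> None:
        # Weak references: a session dropped by its owner disappears from
        # the set instead of leaking here.
        self._sessions: "weakref.WeakSet[Session]" = weakref.WeakSet()

    def session(self) -> Session:
        sess = Session()
        self._sessions.add(sess)
        return sess

    def update_options(self, *, voice: str) -> None:
        for sess in self._sessions:
            sess.update_options(voice=voice)

model = Model()
sess = model.session()
model.update_options(voice="Charon")
assert sess.voice == "Charon"
```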
@@ -292,278 +211,340 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
             location=self._opts.location,
         )
         self._main_atask = asyncio.create_task(self._main_task(), name="gemini-realtime-session")
-        if self._opts.enable_user_audio_transcription:
-            self._transcriber = TranscriberSession(client=self._client, model=self._opts.model)
-            self._transcriber.on("input_speech_done", self._on_input_speech_done)
-        if self._opts.enable_agent_audio_transcription:
-            self._agent_transcriber = ModelTranscriber(client=self._client, model=self._opts.model)
-            self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done)
-        # init dummy task
-        self._init_sync_task = asyncio.create_task(asyncio.sleep(0))
-        self._send_ch = utils.aio.Chan[ClientEvents]()
-        self._active_response_id = None
 
-
-        if self._send_ch.closed:
-            return
+        self._current_generation: _ResponseGeneration | None = None
 
-        self.
-
+        self._is_interrupted = False
+        self._active_response_id = None
+        self._session = None
+        self._update_chat_ctx_lock = asyncio.Lock()
+        self._update_fnc_ctx_lock = asyncio.Lock()
+        self._response_created_futures: dict[str, asyncio.Future[llm.GenerationCreatedEvent]] = {}
+        self._pending_generation_event_id = None
+
+        self._reconnect_event = asyncio.Event()
+        self._session_lock = asyncio.Lock()
+        self._gemini_close_task: asyncio.Task | None = None
+
+    def _schedule_gemini_session_close(self) -> None:
+        if self._session is not None:
+            self._gemini_close_task = asyncio.create_task(self._close_gemini_session())
+
+    async def _close_gemini_session(self) -> None:
+        async with self._session_lock:
+            if self._session:
+                try:
+                    await self._session.close()
+                finally:
+                    self._session = None
+
+    def update_options(
+        self,
+        *,
+        voice: NotGivenOr[str] = NOT_GIVEN,
+        tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+    ) -> None:
+        if is_given(voice):
+            self._opts.voice = voice
+
+        if is_given(temperature):
+            self._opts.temperature = temperature
+
+        if self._session:
+            logger.warning("Updating options; triggering Gemini session reconnect.")
+            self._reconnect_event.set()
+            self._schedule_gemini_session_close()
+
+    async def update_instructions(self, instructions: str) -> None:
+        self._opts.instructions = instructions
+        if self._session:
+            logger.warning("Updating instructions; triggering Gemini session reconnect.")
+            self._reconnect_event.set()
+            self._schedule_gemini_session_close()
+
+    async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
+        async with self._update_chat_ctx_lock:
+            self._chat_ctx = chat_ctx
+            turns, _ = to_chat_ctx(self._chat_ctx, id(self), ignore_functions=True)
+            tool_results = get_tool_results_for_realtime(self._chat_ctx)
+            if turns:
+                self._msg_ch.send_nowait(LiveClientContent(turns=turns, turn_complete=False))
+            if tool_results:
+                self._msg_ch.send_nowait(tool_results)
+
+    async def update_tools(self, tools: list[llm.FunctionTool]) -> None:
+        async with self._update_fnc_ctx_lock:
+            retained_tools: list[llm.FunctionTool] = []
+            gemini_function_declarations: list[FunctionDeclaration] = []
+
+            for tool in tools:
+                gemini_function = _build_gemini_fnc(tool)
+                gemini_function_declarations.append(gemini_function)
+                retained_tools.append(tool)
+
+            self._tools = llm.ToolContext(retained_tools)
+            self._gemini_tools = [Tool(function_declarations=gemini_function_declarations)]
+            if self._session and gemini_function_declarations:
+                logger.warning("Updating tools; triggering Gemini session reconnect.")
+                self._reconnect_event.set()
+                self._schedule_gemini_session_close()
 
     @property
-    def
-        return self.
+    def chat_ctx(self) -> llm.ChatContext:
+        return self._chat_ctx
 
-    @
-    def
-        self.
+    @property
+    def tools(self) -> llm.ToolContext:
+        return self._tools
 
-    def
+    def push_audio(self, frame: rtc.AudioFrame) -> None:
         realtime_input = LiveClientRealtimeInput(
-            media_chunks=[Blob(data=data, mime_type=
+            media_chunks=[Blob(data=frame.data.tobytes(), mime_type="audio/pcm")],
         )
-        self.
+        self._msg_ch.send_nowait(realtime_input)
 
-
-
-
-
-        )
+    def generate_reply(
+        self, *, instructions: NotGivenOr[str] = NOT_GIVEN
+    ) -> asyncio.Future[llm.GenerationCreatedEvent]:
+        fut = asyncio.Future()
 
-
-        self
-
-        encode_options: images.EncodeOptions = DEFAULT_ENCODE_OPTIONS,
-    ) -> None:
-        """Push a video frame to the Gemini Multimodal Live session.
+        event_id = utils.shortuuid("gemini-response-")
+        self._response_created_futures[event_id] = fut
+        self._pending_generation_event_id = event_id
 
-
-
-
+        instructions_content = instructions if is_given(instructions) else "."
+        ctx = [Content(parts=[Part(text=instructions_content)], role="user")]
+        self._msg_ch.send_nowait(LiveClientContent(turns=ctx, turn_complete=True))
 
-
-
-
-
-
-
-        )
-        mime_type = (
-            "image/jpeg"
-            if encode_options.format == "JPEG"
-            else "image/png"
-            if encode_options.format == "PNG"
-            else "image/jpeg"
-        )
-        self._push_media_chunk(encoded_data, mime_type)
+        def _on_timeout() -> None:
+            if event_id in self._response_created_futures and not fut.done():
+                fut.set_exception(llm.RealtimeError("generate_reply timed out."))
+                self._response_created_futures.pop(event_id, None)
+                if self._pending_generation_event_id == event_id:
+                    self._pending_generation_event_id = None
 
-
-
-        self._transcriber._push_audio(frame)
+        handle = asyncio.get_event_loop().call_later(5.0, _on_timeout)
+        fut.add_done_callback(lambda _: handle.cancel())
 
-
+        return fut
 
-    def
-
+    def interrupt(self) -> None:
+        logger.warning("interrupt() - no direct cancellation in Gemini")
 
-    def
-
+    def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
+        logger.warning(f"truncate(...) called for {message_id}, ignoring for Gemini")
 
-    async def
-        self.
+    async def aclose(self) -> None:
+        self._msg_ch.close()
 
-
-
+        for fut in self._response_created_futures.values():
+            if not fut.done():
+                fut.set_exception(llm.RealtimeError("Session closed"))
 
-
-
-        on_duplicate: Literal["cancel_existing", "cancel_new", "keep_both"] = "keep_both",
-    ) -> None:
-        turns, _ = _build_gemini_ctx(self._chat_ctx, id(self))
-        ctx = [self._opts.instructions] + turns if self._opts.instructions else turns
+        if self._main_atask:
+            await utils.aio.cancel_and_wait(self._main_atask)
 
-        if
-
-
+        if self._gemini_close_task:
+            await utils.aio.cancel_and_wait(self._gemini_close_task)
+
+    @utils.log_exceptions(logger=logger)
+    async def _main_task(self):
+        while True:
+            config = LiveConnectConfig(
+                response_modalities=self._opts.response_modalities
+                if is_given(self._opts.response_modalities)
+                else [Modality.AUDIO],
+                generation_config=GenerationConfig(
+                    candidate_count=self._opts.candidate_count,
+                    temperature=self._opts.temperature
+                    if is_given(self._opts.temperature)
+                    else None,
+                    max_output_tokens=self._opts.max_output_tokens
+                    if is_given(self._opts.max_output_tokens)
+                    else None,
+                    top_p=self._opts.top_p if is_given(self._opts.top_p) else None,
+                    top_k=self._opts.top_k if is_given(self._opts.top_k) else None,
+                    presence_penalty=self._opts.presence_penalty
+                    if is_given(self._opts.presence_penalty)
+                    else None,
+                    frequency_penalty=self._opts.frequency_penalty
+                    if is_given(self._opts.frequency_penalty)
+                    else None,
+                ),
+                system_instruction=Content(parts=[Part(text=self._opts.instructions)])
+                if is_given(self._opts.instructions)
+                else None,
+                speech_config=SpeechConfig(
+                    voice_config=VoiceConfig(
+                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=self._opts.voice)
+                    )
+                ),
+                tools=self._gemini_tools,
             )
-        ctx = [Content(parts=[Part(text=".")])]
 
-
+            async with self._client.aio.live.connect(
+                model=self._opts.model, config=config
+            ) as session:
+                async with self._session_lock:
+                    self._session = session
+
+                @utils.log_exceptions(logger=logger)
+                async def _send_task():
+                    async for msg in self._msg_ch:
+                        if isinstance(msg, LiveClientContent):
+                            await session.send(input=msg, end_of_turn=True)
+
+                        await session.send(input=msg)
+                    await session.send(input=".", end_of_turn=True)
+
+                @utils.log_exceptions(logger=logger)
+                async def _recv_task():
+                    while True:
+                        async for response in session.receive():
+                            if self._active_response_id is None:
+                                self._start_new_generation()
+                            if response.server_content:
+                                self._handle_server_content(response.server_content)
+                            if response.tool_call:
+                                self._handle_tool_calls(response.tool_call)
+                            if response.tool_call_cancellation:
+                                self._handle_tool_call_cancellation(response.tool_call_cancellation)
+
+                send_task = asyncio.create_task(_send_task(), name="gemini-realtime-send")
+                recv_task = asyncio.create_task(_recv_task(), name="gemini-realtime-recv")
+                reconnect_task = asyncio.create_task(
+                    self._reconnect_event.wait(), name="reconnect-wait"
+                )
 
-
-
+                try:
+                    done, _ = await asyncio.wait(
+                        [send_task, recv_task, reconnect_task],
+                        return_when=asyncio.FIRST_COMPLETED,
+                    )
+                    for task in done:
+                        if task != reconnect_task:
+                            task.result()
 
-
-
+                    if reconnect_task not in done:
+                        break
 
-
-
-
-
-
-
-
-
+                    self._reconnect_event.clear()
+                finally:
+                    await utils.aio.cancel_and_wait(send_task, recv_task, reconnect_task)
+
+    def _start_new_generation(self):
+        self._is_interrupted = False
+        self._active_response_id = utils.shortuuid("gemini-turn-")
+        self._current_generation = _ResponseGeneration(
+            message_ch=utils.aio.Chan[llm.MessageGeneration](),
+            function_ch=utils.aio.Chan[llm.FunctionCall](),
+            messages={},
+        )
+
+        # We'll assume each chunk belongs to a single message ID self._active_response_id
+        item_generation = _MessageGeneration(
+            message_id=self._active_response_id,
+            text_ch=utils.aio.Chan[str](),
+            audio_ch=utils.aio.Chan[rtc.AudioFrame](),
+        )
+
+        self._current_generation.message_ch.send_nowait(
+            llm.MessageGeneration(
+                message_id=self._active_response_id,
+                text_stream=item_generation.text_ch,
+                audio_stream=item_generation.audio_ch,
             )
+        )
 
-
-
+        generation_event = llm.GenerationCreatedEvent(
+            message_stream=self._current_generation.message_ch,
+            function_stream=self._current_generation.function_ch,
+            user_initiated=False,
+        )
 
-
-        if
-        self.
-
-
-
-
-
+        # Resolve any pending future from generate_reply()
+        if self._pending_generation_event_id and (
+            fut := self._response_created_futures.pop(self._pending_generation_event_id, None)
+        ):
+            fut.set_result(generation_event)
+
+        self._pending_generation_event_id = None
+        self.emit("generation_created", generation_event)
+
+        self._current_generation.messages[self._active_response_id] = item_generation
+
+    def _handle_server_content(self, server_content: LiveServerContent):
+        if not self._current_generation or not self._active_response_id:
+            logger.warning(
+                "gemini-realtime-session: No active response ID, skipping server content"
             )
-
+            return
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                text_stream = utils.aio.Chan[str]()
-                audio_stream = utils.aio.Chan[rtc.AudioFrame]()
-                content = GeminiContent(
-                    response_id=self._active_response_id,
-                    item_id=self._active_response_id,
-                    output_index=0,
-                    content_index=0,
-                    text="",
-                    audio=[],
-                    text_stream=text_stream,
-                    audio_stream=audio_stream,
-                    content_type="audio",
-                )
-                self.emit("response_content_added", content)
-
-                server_content = response.server_content
-                if server_content:
-                    model_turn = server_content.model_turn
-                    if model_turn:
-                        for part in model_turn.parts:
-                            if part.text:
-                                content.text_stream.send_nowait(part.text)
-                            if part.inline_data:
-                                frame = rtc.AudioFrame(
-                                    data=part.inline_data.data,
-                                    sample_rate=24000,
-                                    num_channels=1,
-                                    samples_per_channel=len(part.inline_data.data) // 2,
-                                )
-                                if self._opts.enable_agent_audio_transcription:
-                                    content.audio.append(frame)
-                                content.audio_stream.send_nowait(frame)
-
-                    if server_content.interrupted or server_content.turn_complete:
-                        if self._opts.enable_agent_audio_transcription:
-                            self._agent_transcriber._push_audio(content.audio)
-                        for stream in (content.text_stream, content.audio_stream):
-                            if isinstance(stream, utils.aio.Chan):
-                                stream.close()
-
-                        self.emit("agent_speech_stopped")
-                        self._is_interrupted = True
-
-                    self._active_response_id = None
-
-                if response.tool_call:
-                    if self._fnc_ctx is None:
-                        raise ValueError("Function context is not set")
-                    fnc_calls = []
-                    for fnc_call in response.tool_call.function_calls:
-                        fnc_call_info = _create_ai_function_info(
-                            self._fnc_ctx,
-                            fnc_call.id,
-                            fnc_call.name,
-                            json.dumps(fnc_call.args),
-                        )
-                        fnc_calls.append(fnc_call_info)
-
-                    self.emit("function_calls_collected", fnc_calls)
-
-                    for fnc_call_info in fnc_calls:
-                        self._fnc_tasks.create_task(
-                            self._run_fnc_task(fnc_call_info, content.item_id)
-                        )
-
-                # Handle function call cancellations
-                if response.tool_call_cancellation:
-                    logger.warning(
-                        "function call cancelled",
-                        extra={
-                            "function_call_ids": response.tool_call_cancellation.function_call_ids,  # noqa: E501
-                        },
-                    )
-                    self.emit(
-                        "function_calls_cancelled",
-                        response.tool_call_cancellation.function_call_ids,
-                    )
-
-        async with self._client.aio.live.connect(
-            model=self._opts.model, config=self._config
-        ) as session:
-            self._session = session
-            tasks = [
-                asyncio.create_task(_send_task(), name="gemini-realtime-send"),
-                asyncio.create_task(_recv_task(), name="gemini-realtime-recv"),
-            ]
-
-            try:
-                await asyncio.gather(*tasks)
-            finally:
-                await utils.aio.gracefully_cancel(*tasks)
-                await self._session.close()
-                if self._opts.enable_user_audio_transcription:
-                    await self._transcriber.aclose()
-                if self._opts.enable_agent_audio_transcription:
-                    await self._agent_transcriber.aclose()
+        item_generation = self._current_generation.messages[self._active_response_id]
+
+        model_turn = server_content.model_turn
+        if model_turn:
+            for part in model_turn.parts:
+                if part.text:
+                    item_generation.text_ch.send_nowait(part.text)
+                if part.inline_data:
+                    frame_data = part.inline_data.data
+                    frame = rtc.AudioFrame(
+                        data=frame_data,
+                        sample_rate=OUTPUT_AUDIO_SAMPLE_RATE,
+                        num_channels=NUM_CHANNELS,
+                        samples_per_channel=len(frame_data) // 2,
+                    )
+                    item_generation.audio_ch.send_nowait(frame)
 
-
-
-
-
+        if server_content.interrupted or server_content.turn_complete:
+            self._finalize_response()
+
+    def _finalize_response(self) -> None:
+        if not self._current_generation:
+            return
+
+        for item_generation in self._current_generation.messages.values():
+            item_generation.text_ch.close()
+            item_generation.audio_ch.close()
+
+        self._current_generation.function_ch.close()
+        self._current_generation.message_ch.close()
+        self._current_generation = None
+        self._is_interrupted = True
+        self._active_response_id = None
+        self.emit("agent_speech_stopped")
+
+    def _handle_tool_calls(self, tool_call: LiveServerToolCall):
+        if not self._current_generation:
+            return
+        for fnc_call in tool_call.function_calls:
+            self._current_generation.function_ch.send_nowait(
+                llm.FunctionCall(
+                    call_id=fnc_call.id,
+                    name=fnc_call.name,
+                    arguments=json.dumps(fnc_call.args),
+                )
+            )
+        self._finalize_response()
+
+    def _handle_tool_call_cancellation(
+        self, tool_call_cancellation: LiveServerToolCallCancellation
+    ):
+        logger.warning(
+            "function call cancelled",
             extra={
-                "
+                "function_call_ids": tool_call_cancellation.ids,
             },
         )
+        self.emit("function_calls_cancelled", tool_call_cancellation.ids)
 
-
-
-                await called_fnc.task
-        except Exception as e:
-            logger.exception(
-                "error executing ai function",
-                extra={
-                    "function": fnc_call_info.function_info.name,
-                },
-                exc_info=e,
-            )
-        tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc)
-        if tool_call.content is not None:
-            tool_response = LiveClientToolResponse(
-                function_responses=[
-                    FunctionResponse(
-                        name=tool_call.name,
-                        id=tool_call.tool_call_id,
-                        response={"result": tool_call.content},
-                    )
-                ]
-            )
-            await self._session.send(input=tool_response)
+    def commit_audio(self) -> None:
+        raise NotImplementedError("commit_audio_buffer is not supported yet")
 
-
+    def clear_audio(self) -> None:
+        raise NotImplementedError("clear_audio is not supported yet")
+
+    def server_vad_enabled(self) -> bool:
+        return True