videosdk-plugins-google 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- videosdk_plugins_google-0.0.1/.gitignore +7 -0
- videosdk_plugins_google-0.0.1/PKG-INFO +15 -0
- videosdk_plugins_google-0.0.1/README.md +0 -0
- videosdk_plugins_google-0.0.1/pyproject.toml +34 -0
- videosdk_plugins_google-0.0.1/videosdk/plugins/google/__init__.py +6 -0
- videosdk_plugins_google-0.0.1/videosdk/plugins/google/live_api.py +572 -0
- videosdk_plugins_google-0.0.1/videosdk/plugins/google/version.py +1 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: videosdk-plugins-google
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: VideoSDK Agent Framework plugin for google services
|
|
5
|
+
Author: videosdk
|
|
6
|
+
Keywords: ai,audio,google,video,videosdk
|
|
7
|
+
Classifier: Development Status :: 4 - Beta
|
|
8
|
+
Classifier: Intended Audience :: Developers
|
|
9
|
+
Classifier: Topic :: Communications :: Conferencing
|
|
10
|
+
Classifier: Topic :: Multimedia :: Sound/Audio
|
|
11
|
+
Classifier: Topic :: Multimedia :: Video
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
13
|
+
Requires-Python: >=3.11
|
|
14
|
+
Requires-Dist: google-genai>=1.14.0
|
|
15
|
+
Requires-Dist: videosdk-agents>=0.0.4
|
|
File without changes
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "videosdk-plugins-google"
dynamic = ["version"]
description = "VideoSDK Agent Framework plugin for google services"
readme = "README.md"
requires-python = ">=3.11"
authors = [{ name = "videosdk" }]
keywords = ["video", "audio", "ai", "google", "videosdk"]
# Each classifier listed once (the original repeated "Intended Audience :: Developers").
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Topic :: Communications :: Conferencing",
    "Topic :: Multimedia :: Sound/Audio",
    "Topic :: Multimedia :: Video",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
    "videosdk-agents>=0.0.4",
    "google-genai>=1.14.0",
    # live_api.py does `from dotenv import load_dotenv` at import time,
    # so the package providing `dotenv` must be a direct dependency.
    "python-dotenv",
]

[tool.hatch.version]
path = "videosdk/plugins/google/version.py"

[tool.hatch.build.targets.wheel]
packages = ["videosdk"]

[tool.hatch.build.targets.sdist]
include = ["/videosdk"]
|
|
@@ -0,0 +1,572 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import os
|
|
5
|
+
import logging
|
|
6
|
+
import traceback
|
|
7
|
+
from typing import Any, Dict, Optional, Literal, List
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
import base64
|
|
10
|
+
import time
|
|
11
|
+
from dotenv import load_dotenv
|
|
12
|
+
from videosdk.agents import CustomAudioStreamTrack, RealtimeBaseModel, build_gemini_schema, is_function_tool, FunctionTool, get_tool_info
|
|
13
|
+
|
|
14
|
+
from google import genai
|
|
15
|
+
from google.genai.live import AsyncSession
|
|
16
|
+
from google.genai.types import (
|
|
17
|
+
Blob,
|
|
18
|
+
Content,
|
|
19
|
+
LiveConnectConfig,
|
|
20
|
+
Modality,
|
|
21
|
+
Part,
|
|
22
|
+
PrebuiltVoiceConfig,
|
|
23
|
+
SpeechConfig,
|
|
24
|
+
VoiceConfig,
|
|
25
|
+
FunctionResponse,
|
|
26
|
+
Tool,
|
|
27
|
+
GenerationConfig,
|
|
28
|
+
AudioTranscriptionConfig,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
# Populate os.environ from a local .env file, if present; GeminiRealtime._init_client
# later reads GOOGLE_API_KEY from the environment.
load_dotenv()

logger = logging.getLogger(__name__)

# Sample rate (Hz) advertised in the MIME type of user audio sent by
# GeminiRealtime.handle_audio_input. NOTE(review): the audio pipeline must really
# deliver PCM at this rate — confirm against the caller.
AUDIO_SAMPLE_RATE = 24000

# Names of the events GeminiRealtime subscribes to in its constructor.
GeminiEventTypes = Literal["tools_updated", "instructions_updated"]

# Prebuilt voice names accepted by GeminiLiveConfig.voice.
Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"]
|
|
44
|
+
|
|
45
|
+
@dataclass
class GeminiLiveConfig:
    """Configuration for the Gemini Live API.

    Args:
        voice: Prebuilt voice for audio output: 'Puck', 'Charon', 'Kore', 'Fenrir'
            or 'Aoede'. Defaults to 'Puck'.
        language_code: Language code for speech synthesis. Defaults to 'en-US'.
        temperature: Controls randomness in response generation. Higher values
            (e.g. 0.8) make output more random, lower values (e.g. 0.2) more
            focused. Defaults to None (server default).
        top_p: Nucleus sampling cutoff on cumulative probability, range 0-1.
            Defaults to None.
        top_k: Number of highest-probability tokens considered at each step.
            Defaults to None.
        candidate_count: Number of response candidates to generate. Defaults to 1.
        max_output_tokens: Maximum number of tokens in a model response.
            Defaults to None.
        presence_penalty: Penalizes tokens based on their presence in the text so
            far. Range -2.0 to 2.0. Defaults to None.
        frequency_penalty: Penalizes tokens based on their frequency in the text so
            far. Range -2.0 to 2.0. Defaults to None.
        response_modalities: Response modality for the session. The Live API
            accepts a single modality per session, so this defaults to ["AUDIO"];
            pass ["TEXT"] for text-only responses.
        output_audio_transcription: When set, enables transcription of the
            model's audio output. Defaults to None (disabled).
    """

    voice: Voice | None = "Puck"
    language_code: str | None = "en-US"
    temperature: float | None = None
    top_p: float | None = None
    top_k: float | None = None
    candidate_count: int | None = 1
    max_output_tokens: int | None = None
    presence_penalty: float | None = None
    frequency_penalty: float | None = None
    # A fresh list per instance; ["TEXT", "AUDIO"] is rejected by the Live API
    # because only one response modality is allowed per session.
    response_modalities: List[Modality] | None = field(default_factory=lambda: ["AUDIO"])
    output_audio_transcription: AudioTranscriptionConfig | None = None
|
|
74
|
+
|
|
75
|
+
@dataclass
class GeminiSession:
    """A live Gemini session together with what is needed to tear it down.

    Attributes:
        session: The ``AsyncSession`` obtained by entering ``session_cm``.
        session_cm: The async context manager returned by
            ``client.aio.live.connect``; calling its ``__aexit__`` closes the session.
        tasks: Background tasks bound to this session, cancelled during cleanup.
            Defaults to an empty list so callers may omit it.
    """

    session: AsyncSession
    session_cm: Any
    tasks: list[asyncio.Task] = field(default_factory=list)
|
|
81
|
+
|
|
82
|
+
class GeminiRealtime(RealtimeBaseModel[GeminiEventTypes]):
    """Gemini Live API realtime model for audio-only conversations.

    A background session loop (re)connects to the Live API. While a session is up
    it streams model audio into ``self.audio_track``, executes tool calls, and
    emits transcription events. When the session drops, it reconnects with
    exponential backoff.
    """

    def __init__(
        self,
        *,
        model: str,
        config: GeminiLiveConfig | None = None,
        api_key: str | None = None,
        service_account_path: str | None = None,
    ) -> None:
        """Initialize the Gemini realtime model.

        Args:
            model: Gemini Live model identifier passed to ``client.aio.live.connect``.
            config: Generation and speech settings. See ``GeminiLiveConfig`` for
                the available fields and their defaults. A default
                ``GeminiLiveConfig()`` is used when omitted.
            api_key: Gemini API key. Falls back to the ``GOOGLE_API_KEY``
                environment variable when not provided.
            service_account_path: Path to a Google service account JSON file. When
                given, it is exported as ``GOOGLE_APPLICATION_CREDENTIALS`` and
                ``api_key`` is ignored.

        Raises:
            ValueError: If neither ``service_account_path`` nor an API key
                (argument or ``GOOGLE_API_KEY``) is available.
        """
        super().__init__()

        # Core configuration
        self.model = model

        # Authentication setup
        self._init_client(api_key, service_account_path)

        # Session state
        self._session: Optional[GeminiSession] = None
        self._closing = False
        # Set by any component that detects a broken session; _session_loop then reconnects.
        self._session_should_close = asyncio.Event()
        self._main_task: Optional[asyncio.Task] = None
        # Must be assigned by the caller before connect(); it is used to build the audio track.
        self.loop = None
        self.audio_track = None

        # Audio handling
        self._buffered_audio = bytearray()
        self._is_speaking = False
        self._last_audio_time = 0.0
        self._audio_processing_task = None
        # Strong references to in-flight playback tasks. asyncio keeps only weak
        # references to tasks, so unreferenced ones could be collected mid-run.
        self._pending_audio_tasks: set[asyncio.Task] = set()

        # Tools and instructions
        self.tools = []
        # Initialised here so _create_session() works before any "tools_updated" event.
        self.formatted_tools: List[Tool] = []
        self._instructions: str = "You are a helpful voice assistant that can answer questions and help with tasks."
        self.config: GeminiLiveConfig = config or GeminiLiveConfig()

        self.on("tools_updated", self._handle_tools_updated)
        self.on("instructions_updated", self._handle_instructions_updated)

    def _init_client(self, api_key: str | None, service_account_path: str | None) -> None:
        """Create ``self.client`` from a service account file or an API key.

        Raises:
            ValueError: If no service account path is given and no API key is available.
        """
        if service_account_path:
            # NOTE(review): this mutates the process-wide environment, and whether
            # google-genai uses these credentials without vertexai/project/location
            # arguments depends on its own configuration — confirm.
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = service_account_path
            self.client = genai.Client(http_options={"api_version": "v1beta"})
            return

        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
        if not self.api_key:
            raise ValueError("GOOGLE_API_KEY or service account required")
        self.client = genai.Client(
            api_key=self.api_key,
            http_options={"api_version": "v1beta"},
        )

    async def connect(self) -> None:
        """Connect to the Gemini Live API and start the background session loop.

        Any existing session is closed first. ``self.loop`` must already be set.
        It is used to create the ``CustomAudioStreamTrack`` that receives model
        audio. A failed initial connection is only logged; the session loop keeps
        retrying.

        Raises:
            RuntimeError: If ``self.loop`` has not been set.
        """
        if self._session:
            await self._cleanup_session(self._session)
            self._session = None

        self._closing = False
        self._session_should_close.clear()

        try:
            if not self.audio_track and self.loop:
                self.audio_track = CustomAudioStreamTrack(self.loop)
            elif not self.loop:
                raise RuntimeError("Event loop not initialized. Audio playback will not work.")

            # Best-effort initial session; _session_loop retries on failure.
            try:
                initial_session = await self._create_session()
                if initial_session:
                    self._session = initial_session
            except Exception as e:
                logger.error(f"Initial session creation failed, will retry: {e}")

            if not self._main_task or self._main_task.done():
                self._main_task = asyncio.create_task(self._session_loop(), name="gemini-main-loop")
        except Exception as e:
            logger.exception(f"Error connecting to Gemini Live API: {e}")
            raise

    async def _create_session(self) -> GeminiSession:
        """Open a new Live API session from the current config, instructions and tools.

        Returns:
            A ``GeminiSession`` holding the entered session and its context manager.

        Raises:
            Exception: Any connection error is logged and re-raised.
        """
        cfg = self.config
        live_config = LiveConnectConfig(
            response_modalities=cfg.response_modalities,
            generation_config=GenerationConfig(
                candidate_count=cfg.candidate_count,
                temperature=cfg.temperature,
                top_p=cfg.top_p,
                top_k=cfg.top_k,
                max_output_tokens=cfg.max_output_tokens,
                presence_penalty=cfg.presence_penalty,
                frequency_penalty=cfg.frequency_penalty,
            ),
            system_instruction=self._instructions,
            speech_config=SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=cfg.voice)
                ),
                language_code=cfg.language_code,
            ),
            tools=self.formatted_tools or None,
            output_audio_transcription=cfg.output_audio_transcription or None,
        )
        try:
            session_cm = self.client.aio.live.connect(model=self.model, config=live_config)
            session = await session_cm.__aenter__()
        except Exception as e:
            logger.exception(f"Connection error: {e}")
            raise
        return GeminiSession(session=session, session_cm=session_cm, tasks=[])

    async def _session_loop(self) -> None:
        """Keep a Gemini session running until ``aclose`` sets ``_closing``.

        Session creation is attempted up to 5 consecutive times, with the delay
        doubling each time and capped at 30 s. While a session is up, its receive
        and keep-alive tasks run until ``_session_should_close`` is set. The
        session is then torn down and recreated.
        """
        reconnect_attempts = 0
        max_reconnect_attempts = 5
        reconnect_delay = 1

        while not self._closing:
            if not self._session:
                try:
                    self._session = await self._create_session()
                except Exception as e:
                    reconnect_attempts += 1
                    reconnect_delay = min(30, reconnect_delay * 2)
                    logger.error(f"Session creation attempt {reconnect_attempts} failed: {e}")
                    if reconnect_attempts >= max_reconnect_attempts:
                        logger.error("Max reconnection attempts reached")
                        break
                    await asyncio.sleep(reconnect_delay)
                    continue
                reconnect_attempts = 0
                reconnect_delay = 1
                if self._closing:
                    # aclose() will clean up the session we just created.
                    break
                # A close signal raised while no session existed (e.g. a failed send)
                # refers to the old session; don't let it tear down the new one.
                self._session_should_close.clear()

            session = self._session

            recv_task = asyncio.create_task(self._receive_loop(session), name="gemini_receive")
            keep_alive_task = asyncio.create_task(self._keep_alive(session), name="gemini_keepalive")
            session.tasks.extend([recv_task, keep_alive_task])

            try:
                await self._session_should_close.wait()
            finally:
                for task in session.tasks:
                    if not task.done():
                        task.cancel()
                try:
                    await asyncio.gather(*session.tasks, return_exceptions=True)
                except Exception as e:
                    logger.error(f"Error during task cleanup: {e}")

            if not self._closing:
                await self._cleanup_session(session)
                self._session = None
                await asyncio.sleep(reconnect_delay)
                self._session_should_close.clear()

    def _find_function_tool(self, name: str):
        """Return the first registered function tool declared as ``name``, or None."""
        for tool in self.tools or []:
            if is_function_tool(tool) and get_tool_info(tool).name == name:
                return tool
        return None

    async def _handle_tool_calls(self, response, active_response_id: Optional[str]) -> None:
        """Run the function calls in ``response.tool_call`` and report their results.

        Each call is matched by name against the registered function tools. A
        dict result is sent back as-is; any other result is wrapped as
        ``{"output": result}``. Execution errors and unknown tool names are
        reported as ``{"error": ...}``, so the model is never left waiting for a
        response.

        Args:
            response: A Live API server message.
            active_response_id: Identifier of the response in progress (unused).
        """
        if not response.tool_call:
            return

        for tool_call in response.tool_call.function_calls:
            tool = self._find_function_tool(tool_call.name)
            if tool is None:
                logger.warning(f"No registered function tool named {tool_call.name}")
                payload = {"error": f"Unknown function: {tool_call.name}"}
            else:
                try:
                    # args may be None for zero-argument calls.
                    result = await tool(**(tool_call.args or {}))
                    payload = result if isinstance(result, dict) else {"output": result}
                except Exception as e:
                    logger.exception(f"Error executing function {tool_call.name}: {e}")
                    payload = {"error": str(e)}

            await self.send_tool_response([
                FunctionResponse(
                    id=tool_call.id,
                    name=tool_call.name,
                    response=payload,
                )
            ])

    def _emit_transcriptions(self, server_content) -> None:
        """Emit non-final input/output transcription events found in ``server_content``."""
        try:
            input_transcription = server_content.input_transcription
            if input_transcription and input_transcription.text:
                self.emit("input_transcription", {
                    "text": input_transcription.text,
                    "is_final": False
                })

            output_transcription = server_content.output_transcription
            if output_transcription and output_transcription.text:
                self.emit("output_transcription", {
                    "text": output_transcription.text,
                    "is_final": False
                })
        except Exception as e:
            logger.exception(f"Transcription handling error: {e}")

    def _schedule_audio(self, raw_audio: bytes, chunk_number: int) -> None:
        """Hand ``raw_audio`` to the audio track without blocking the receive loop."""
        task = self.loop.create_task(
            self.audio_track.add_new_bytes(raw_audio),
            name=f"audio_chunk_{chunk_number}"
        )
        # Hold a strong reference until the task finishes.
        self._pending_audio_tasks.add(task)
        task.add_done_callback(self._pending_audio_tasks.discard)

    async def _receive_loop(self, session: GeminiSession) -> None:
        """Consume server messages for ``session`` until closing or a stream error.

        Handles tool calls and emits transcription events. Inline audio is
        forwarded to ``self.audio_track``, and playback is cleared on
        interruption. Any stream error sets ``_session_should_close`` so the
        session loop reconnects.
        """
        try:
            active_response_id = None
            chunk_number = 0

            while not self._closing:
                try:
                    async for response in session.session.receive():
                        if self._closing:
                            break

                        if response.tool_call:
                            await self._handle_tool_calls(response, active_response_id)

                        server_content = response.server_content
                        if not server_content:
                            continue

                        self._emit_transcriptions(server_content)

                        if not active_response_id:
                            active_response_id = f"response_{id(response)}"
                            chunk_number = 0

                        if server_content.interrupted:
                            active_response_id = None
                            # Drop queued audio so the interrupted reply stops playing.
                            if self.audio_track:
                                self.audio_track.interrupt()
                            continue

                        if model_turn := server_content.model_turn:
                            for part in model_turn.parts or []:
                                inline_data = getattr(part, "inline_data", None)
                                if not inline_data:
                                    continue
                                raw_audio = inline_data.data
                                # Skip empty chunks
                                if not raw_audio or len(raw_audio) < 2:
                                    continue

                                chunk_number += 1

                                if self.audio_track and self.loop:
                                    # Pad to an even length so the data splits into whole 16-bit samples.
                                    if len(raw_audio) % 2 != 0:
                                        raw_audio += b'\x00'
                                    self._schedule_audio(raw_audio, chunk_number)

                        if server_content.turn_complete and active_response_id:
                            active_response_id = None

                except Exception as e:
                    if "1000 (OK)" in str(e):
                        logger.info("Normal WebSocket closure")
                    else:
                        logger.exception(f"Error in receive loop: {e}")

                    # Signal for reconnection
                    self._session_should_close.set()
                    break

                # The async-for ended without error (presumably at the end of a turn);
                # pause briefly before calling receive() again.
                await asyncio.sleep(0.1)

        except asyncio.CancelledError:
            logger.debug("Receive loop cancelled")
        except Exception as e:
            logger.exception(f"Fatal error in receive loop: {e}")
            self._session_should_close.set()

    async def _keep_alive(self, session: GeminiSession) -> None:
        """Every 10 seconds, send a minimal client message on ``session``.

        If the send fails with an error mentioning "closed", a reconnect is
        requested.

        NOTE(review): the "." is sent as real user content with
        ``turn_complete=False``, so it may accumulate in the conversation context.
        Confirm whether transport-level pings make this unnecessary.
        """
        try:
            while not self._closing:
                await asyncio.sleep(10)

                if self._closing:
                    break

                try:
                    await session.session.send_client_content(
                        turns=Content(parts=[Part(text=".")], role="user"),
                        turn_complete=False
                    )
                except Exception as e:
                    if "closed" in str(e).lower():
                        self._session_should_close.set()
                        break
                    logger.error(f"Keep-alive error: {e}")
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Error in keep-alive: {e}")
            self._session_should_close.set()

    async def handle_audio_input(self, audio_data: bytes) -> None:
        """Stream a chunk of user PCM audio to the active session.

        Ignored while there is no session or the model is closing. Like the other
        send methods, a failed send is logged and triggers a reconnect; the error
        is not propagated to the caller.
        """
        current = self._session
        if not current or self._closing:
            return

        try:
            await current.session.send_realtime_input(
                audio=Blob(data=audio_data, mime_type=f"audio/pcm;rate={AUDIO_SAMPLE_RATE}")
            )
        except Exception as e:
            logger.error(f"Error sending audio input: {e}")
            self._session_should_close.set()

    async def interrupt(self) -> None:
        """Interrupt the current response.

        Sends a "stop" user turn and clears queued playback on the audio track.
        """
        if not self._session or self._closing:
            return

        try:
            await self._session.session.send_client_content(
                turns=Content(parts=[Part(text="stop")], role="user"),
                turn_complete=True
            )
            if self.audio_track:
                self.audio_track.interrupt()
        except Exception as e:
            logger.error(f"Interrupt error: {e}")

    async def send_message(self, message: str) -> None:
        """Ask the model to speak ``message`` back verbatim.

        Polls once per second, up to 5 times, for an active session. It then
        sends a model-role instruction to repeat ``message``, followed by a user
        turn that triggers the reply. A failed send is logged and triggers a
        reconnect.

        Raises:
            RuntimeError: If no session becomes available within the retries.
        """
        retry_count = 0
        max_retries = 5
        while not self._session or not self._session.session:
            if retry_count >= max_retries:
                raise RuntimeError("No active Gemini session after maximum retries")
            logger.debug("No active session, waiting for connection...")
            await asyncio.sleep(1)
            retry_count += 1

        try:
            await self._session.session.send_client_content(
                turns=[
                    Content(parts=[Part(text="Repeat the user's exact message back to them [DO NOT ADD ANYTHING ELSE]:" + message)], role="model"),
                    Content(parts=[Part(text=".")], role="user")
                ],
                turn_complete=True
            )
            await asyncio.sleep(0.1)
        except Exception as e:
            logger.error(f"Error sending message: {e}")
            self._session_should_close.set()

    async def _cleanup_session(self, session: GeminiSession) -> None:
        """Cancel ``session``'s background tasks and exit its connection context."""
        for task in session.tasks:
            if not task.done():
                task.cancel()

        try:
            await session.session_cm.__aexit__(None, None, None)
        except Exception as e:
            logger.error(f"Error closing session: {e}")

    async def aclose(self) -> None:
        """Stop the session loop and release all resources.

        Returns immediately if a close is already in progress.
        """
        if self._closing:
            return

        self._closing = True
        self._session_should_close.set()

        if self._audio_processing_task and not self._audio_processing_task.done():
            self._audio_processing_task.cancel()
            try:
                await asyncio.wait_for(self._audio_processing_task, timeout=1.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass

        if self._main_task and not self._main_task.done():
            self._main_task.cancel()
            try:
                await asyncio.wait_for(self._main_task, timeout=2.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass
            except Exception as e:
                # Keep going so the session and audio track are still released.
                logger.error(f"Error stopping session loop: {e}")

        # Drop audio chunks that were scheduled but not yet delivered.
        for task in list(self._pending_audio_tasks):
            if not task.done():
                task.cancel()
        self._pending_audio_tasks.clear()

        if self._session:
            await self._cleanup_session(self._session)
            self._session = None

        if self.audio_track and hasattr(self.audio_track, 'cleanup'):
            try:
                await self.audio_track.cleanup()
            except Exception as e:
                logger.error(f"Error cleaning up audio track: {e}")

        self._buffered_audio = bytearray()

    async def _reconnect(self) -> None:
        """Replace the current session (if any) with a freshly created one."""
        if self._session:
            await self._cleanup_session(self._session)
            self._session = None
        self._session = await self._create_session()

    async def send_tool_response(self, function_responses: List[FunctionResponse]) -> None:
        """Send tool results back to Gemini.

        Does nothing when there is no active session. A failed send is logged and
        triggers a reconnect.
        """
        if not self._session or not self._session.session:
            return

        try:
            await self._session.session.send_tool_response(
                function_responses=function_responses
            )
        except Exception as e:
            logger.error(f"Error sending tool response: {e}")
            self._session_should_close.set()

    def _convert_tools_to_gemini_format(self, tools: List[FunctionTool]) -> List[Tool]:
        """Build Gemini ``Tool`` declarations from the function tools in ``tools``.

        Tools that are not function tools, or whose schema cannot be built, are
        skipped. Returns a single-element list, or an empty list when nothing
        converts.
        """
        function_declarations = []

        for tool in tools:
            if not is_function_tool(tool):
                continue

            try:
                function_declarations.append(build_gemini_schema(tool))
            except Exception as e:
                logger.error(f"Failed to format tool {tool}: {e}")
                continue

        return [Tool(function_declarations=function_declarations)] if function_declarations else []

    def _handle_tools_updated(self, data: Dict[str, Any]) -> None:
        """Store the new tool list and its Gemini declarations for the next session."""
        tools = data.get("tools", [])
        self.tools = tools
        self.formatted_tools = self._convert_tools_to_gemini_format(tools)

    def _handle_instructions_updated(self, data: Dict[str, Any]) -> None:
        """Store new system instructions; they are applied when a session is created."""
        self._instructions = data.get("instructions", "")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Package version. pyproject.toml's [tool.hatch.version] reads it from this file,
# so keep it a plain `__version__ = "..."` string assignment the version source can parse.
__version__ = "0.0.1"
|