livekit-plugins-google 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
- import base64
5
4
  import json
6
5
  import os
7
6
  from dataclasses import dataclass
@@ -11,14 +10,22 @@ from livekit import rtc
11
10
  from livekit.agents import llm, utils
12
11
  from livekit.agents.llm.function_context import _create_ai_function_info
13
12
 
14
- from google import genai # type: ignore
15
- from google.genai.types import ( # type: ignore
13
+ from google import genai
14
+ from google.genai._api_client import HttpOptions
15
+ from google.genai.types import (
16
+ Blob,
17
+ Content,
16
18
  FunctionResponse,
17
- GenerationConfigDict,
19
+ GenerationConfig,
20
+ LiveClientContent,
21
+ LiveClientRealtimeInput,
18
22
  LiveClientToolResponse,
19
- LiveConnectConfigDict,
23
+ LiveConnectConfig,
24
+ Modality,
25
+ Part,
20
26
  PrebuiltVoiceConfig,
21
27
  SpeechConfig,
28
+ Tool,
22
29
  VoiceConfig,
23
30
  )
24
31
 
@@ -26,10 +33,11 @@ from ...log import logger
26
33
  from .api_proto import (
27
34
  ClientEvents,
28
35
  LiveAPIModels,
29
- ResponseModality,
30
36
  Voice,
37
+ _build_gemini_ctx,
31
38
  _build_tools,
32
39
  )
40
+ from .transcriber import TranscriberSession, TranscriptionContent
33
41
 
34
42
  EventTypes = Literal[
35
43
  "start_session",
@@ -39,6 +47,9 @@ EventTypes = Literal[
39
47
  "function_calls_collected",
40
48
  "function_calls_finished",
41
49
  "function_calls_cancelled",
50
+ "input_speech_transcription_completed",
51
+ "agent_speech_transcription_completed",
52
+ "agent_speech_stopped",
42
53
  ]
43
54
 
44
55
 
@@ -55,6 +66,12 @@ class GeminiContent:
55
66
  content_type: Literal["text", "audio"]
56
67
 
57
68
 
69
+ @dataclass
70
+ class InputTranscription:
71
+ item_id: str
72
+ transcript: str
73
+
74
+
58
75
  @dataclass
59
76
  class Capabilities:
60
77
  supports_truncate: bool
@@ -65,7 +82,7 @@ class ModelOptions:
65
82
  model: LiveAPIModels | str
66
83
  api_key: str | None
67
84
  voice: Voice | str
68
- response_modalities: ResponseModality
85
+ response_modalities: list[Modality] | None
69
86
  vertexai: bool
70
87
  project: str | None
71
88
  location: str | None
@@ -76,18 +93,22 @@ class ModelOptions:
76
93
  top_k: int | None
77
94
  presence_penalty: float | None
78
95
  frequency_penalty: float | None
79
- instructions: str
96
+ instructions: Content | None
97
+ enable_user_audio_transcription: bool
98
+ enable_agent_audio_transcription: bool
80
99
 
81
100
 
82
101
  class RealtimeModel:
83
102
  def __init__(
84
103
  self,
85
104
  *,
86
- instructions: str = "",
105
+ instructions: str | None = None,
87
106
  model: LiveAPIModels | str = "gemini-2.0-flash-exp",
88
107
  api_key: str | None = None,
89
108
  voice: Voice | str = "Puck",
90
- modalities: ResponseModality = "AUDIO",
109
+ modalities: list[Modality] = ["AUDIO"],
110
+ enable_user_audio_transcription: bool = True,
111
+ enable_agent_audio_transcription: bool = True,
91
112
  vertexai: bool = False,
92
113
  project: str | None = None,
93
114
  location: str | None = None,
@@ -103,15 +124,24 @@ class RealtimeModel:
103
124
  """
104
125
  Initializes a RealtimeModel instance for interacting with Google's Realtime API.
105
126
 
127
+ Environment Requirements:
128
+ - For VertexAI: Set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of the service account key file.
129
+ The Google Cloud project and location can be set via `project` and `location` arguments or the environment variables
130
+ `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION`. By default, the project is inferred from the service account key file,
131
+ and the location defaults to "us-central1".
132
+ - For Google Gemini API: Set the `api_key` argument or the `GOOGLE_API_KEY` environment variable.
133
+
106
134
  Args:
107
135
  instructions (str, optional): Initial system instructions for the model. Defaults to "".
108
136
  api_key (str or None, optional): Google Gemini API key. If None, will attempt to read from the environment variable GOOGLE_API_KEY.
109
- modalities (ResponseModality): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
137
+ modalities (list[Modality], optional): Modalities to use, such as ["TEXT", "AUDIO"]. Defaults to ["AUDIO"].
110
138
  model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
111
139
  voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
140
+ enable_user_audio_transcription (bool, optional): Whether to enable user audio transcription. Defaults to True
141
+ enable_agent_audio_transcription (bool, optional): Whether to enable agent audio transcription. Defaults to True
112
142
  temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
113
143
  vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
114
- project (str or None, optional): The project to use for the API. Defaults to None. (for vertexai)
144
+ project (str or None, optional): The project id to use for the API. Defaults to None. (for vertexai)
115
145
  location (str or None, optional): The location to use for the API. Defaults to None. (for vertexai)
116
146
  candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1.
117
147
  top_p (float, optional): The top-p value for response generation
@@ -130,21 +160,38 @@ class RealtimeModel:
130
160
  self._model = model
131
161
  self._loop = loop or asyncio.get_event_loop()
132
162
  self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
133
- self._vertexai = vertexai
134
- self._project_id = project or os.environ.get("GOOGLE_PROJECT")
135
- self._location = location or os.environ.get("GOOGLE_LOCATION")
136
- if self._api_key is None and not self._vertexai:
137
- raise ValueError("GOOGLE_API_KEY is not set")
163
+ self._project = project or os.environ.get("GOOGLE_CLOUD_PROJECT")
164
+ self._location = location or os.environ.get("GOOGLE_CLOUD_LOCATION")
165
+ if vertexai:
166
+ if not self._project or not self._location:
167
+ raise ValueError(
168
+ "Project and location are required for VertexAI either via project and location or GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION environment variables"
169
+ )
170
+ self._api_key = None # VertexAI does not require an API key
171
+
172
+ else:
173
+ self._project = None
174
+ self._location = None
175
+ if not self._api_key:
176
+ raise ValueError(
177
+ "API key is required for Google API either via api_key or GOOGLE_API_KEY environment variable"
178
+ )
179
+
180
+ instructions_content = (
181
+ Content(parts=[Part(text=instructions)]) if instructions else None
182
+ )
138
183
 
139
184
  self._rt_sessions: list[GeminiRealtimeSession] = []
140
185
  self._opts = ModelOptions(
141
186
  model=model,
142
- api_key=api_key,
187
+ api_key=self._api_key,
143
188
  voice=voice,
189
+ enable_user_audio_transcription=enable_user_audio_transcription,
190
+ enable_agent_audio_transcription=enable_agent_audio_transcription,
144
191
  response_modalities=modalities,
145
192
  vertexai=vertexai,
146
- project=project,
147
- location=location,
193
+ project=self._project,
194
+ location=self._location,
148
195
  candidate_count=candidate_count,
149
196
  temperature=temperature,
150
197
  max_output_tokens=max_output_tokens,
@@ -152,7 +199,7 @@ class RealtimeModel:
152
199
  top_k=top_k,
153
200
  presence_penalty=presence_penalty,
154
201
  frequency_penalty=frequency_penalty,
155
- instructions=instructions,
202
+ instructions=instructions_content,
156
203
  )
157
204
 
158
205
  @property
@@ -208,16 +255,16 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
208
255
  self._chat_ctx = chat_ctx
209
256
  self._fnc_ctx = fnc_ctx
210
257
  self._fnc_tasks = utils.aio.TaskSet()
258
+ self._is_interrupted = False
211
259
 
212
260
  tools = []
213
261
  if self._fnc_ctx is not None:
214
262
  functions = _build_tools(self._fnc_ctx)
215
- tools.append({"function_declarations": functions})
263
+ tools.append(Tool(function_declarations=functions))
216
264
 
217
- self._config = LiveConnectConfigDict(
218
- model=self._opts.model,
265
+ self._config = LiveConnectConfig(
219
266
  response_modalities=self._opts.response_modalities,
220
- generation_config=GenerationConfigDict(
267
+ generation_config=GenerationConfig(
221
268
  candidate_count=self._opts.candidate_count,
222
269
  temperature=self._opts.temperature,
223
270
  max_output_tokens=self._opts.max_output_tokens,
@@ -237,7 +284,7 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
237
284
  tools=tools,
238
285
  )
239
286
  self._client = genai.Client(
240
- http_options={"api_version": "v1alpha"},
287
+ http_options=HttpOptions(api_version="v1alpha"),
241
288
  api_key=self._opts.api_key,
242
289
  vertexai=self._opts.vertexai,
243
290
  project=self._opts.project,
@@ -246,12 +293,22 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
246
293
  self._main_atask = asyncio.create_task(
247
294
  self._main_task(), name="gemini-realtime-session"
248
295
  )
249
- # dummy task to wait for the session to be initialized # TODO: sync chat ctx
250
- self._init_sync_task = asyncio.create_task(
251
- asyncio.sleep(0), name="gemini-realtime-session-init"
252
- )
296
+ if self._opts.enable_user_audio_transcription:
297
+ self._transcriber = TranscriberSession(
298
+ client=self._client, model=self._opts.model
299
+ )
300
+ self._transcriber.on("input_speech_done", self._on_input_speech_done)
301
+ if self._opts.enable_agent_audio_transcription:
302
+ self._agent_transcriber = TranscriberSession(
303
+ client=self._client, model=self._opts.model
304
+ )
305
+ self._agent_transcriber.on("input_speech_done", self._on_agent_speech_done)
306
+ # init dummy task
307
+ self._init_sync_task = asyncio.create_task(asyncio.sleep(0))
253
308
  self._send_ch = utils.aio.Chan[ClientEvents]()
254
309
  self._active_response_id = None
310
+ if chat_ctx:
311
+ self.generate_reply(chat_ctx)
255
312
 
256
313
  async def aclose(self) -> None:
257
314
  if self._send_ch.closed:
@@ -269,32 +326,97 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
269
326
  self._fnc_ctx = value
270
327
 
271
328
  def _push_audio(self, frame: rtc.AudioFrame) -> None:
272
- data = base64.b64encode(frame.data).decode("utf-8")
273
- self._queue_msg({"mime_type": "audio/pcm", "data": data})
329
+ if self._opts.enable_user_audio_transcription:
330
+ self._transcriber._push_audio(frame)
331
+ realtime_input = LiveClientRealtimeInput(
332
+ media_chunks=[Blob(data=frame.data.tobytes(), mime_type="audio/pcm")],
333
+ )
334
+ self._queue_msg(realtime_input)
274
335
 
275
- def _queue_msg(self, msg: dict) -> None:
336
+ def _queue_msg(self, msg: ClientEvents) -> None:
276
337
  self._send_ch.send_nowait(msg)
277
338
 
339
+ def generate_reply(
340
+ self,
341
+ ctx: llm.ChatContext | llm.ChatMessage,
342
+ turn_complete: bool = True,
343
+ ) -> None:
344
+ if isinstance(ctx, llm.ChatMessage) and isinstance(ctx.content, str):
345
+ new_chat_ctx = llm.ChatContext()
346
+ new_chat_ctx.append(text=ctx.content, role=ctx.role)
347
+ elif isinstance(ctx, llm.ChatContext):
348
+ new_chat_ctx = ctx
349
+ else:
350
+ raise ValueError("Invalid chat context")
351
+ turns, _ = _build_gemini_ctx(new_chat_ctx, id(self))
352
+ client_content = LiveClientContent(
353
+ turn_complete=turn_complete,
354
+ turns=turns,
355
+ )
356
+ self._queue_msg(client_content)
357
+
278
358
  def chat_ctx_copy(self) -> llm.ChatContext:
279
359
  return self._chat_ctx.copy()
280
360
 
281
361
  async def set_chat_ctx(self, ctx: llm.ChatContext) -> None:
282
362
  self._chat_ctx = ctx.copy()
283
363
 
364
+ def cancel_response(self) -> None:
365
+ raise NotImplementedError("cancel_response is not supported yet")
366
+
367
+ def create_response(
368
+ self,
369
+ on_duplicate: Literal[
370
+ "cancel_existing", "cancel_new", "keep_both"
371
+ ] = "keep_both",
372
+ ) -> None:
373
+ raise NotImplementedError("create_response is not supported yet")
374
+
375
+ def commit_audio_buffer(self) -> None:
376
+ raise NotImplementedError("commit_audio_buffer is not supported yet")
377
+
378
+ def server_vad_enabled(self) -> bool:
379
+ return True
380
+
381
+ def _on_input_speech_done(self, content: TranscriptionContent) -> None:
382
+ if content.response_id and content.text:
383
+ self.emit(
384
+ "input_speech_transcription_completed",
385
+ InputTranscription(
386
+ item_id=content.response_id,
387
+ transcript=content.text,
388
+ ),
389
+ )
390
+
391
+ # self._chat_ctx.append(text=content.text, role="user")
392
+ # TODO: implement sync mechanism to make sure the transcribed user speech is inside the chat_ctx and always before the generated agent speech
393
+
394
+ def _on_agent_speech_done(self, content: TranscriptionContent) -> None:
395
+ if not self._is_interrupted and content.response_id and content.text:
396
+ self.emit(
397
+ "agent_speech_transcription_completed",
398
+ InputTranscription(
399
+ item_id=content.response_id,
400
+ transcript=content.text,
401
+ ),
402
+ )
403
+ # self._chat_ctx.append(text=content.text, role="assistant")
404
+
284
405
  @utils.log_exceptions(logger=logger)
285
406
  async def _main_task(self):
286
407
  @utils.log_exceptions(logger=logger)
287
408
  async def _send_task():
288
409
  async for msg in self._send_ch:
289
- await self._session.send(msg)
410
+ await self._session.send(input=msg)
290
411
 
291
- await self._session.send(".", end_of_turn=True)
412
+ await self._session.send(input=".", end_of_turn=True)
292
413
 
293
414
  @utils.log_exceptions(logger=logger)
294
415
  async def _recv_task():
295
416
  while True:
296
417
  async for response in self._session.receive():
297
418
  if self._active_response_id is None:
419
+ self._is_interrupted = False
298
420
  self._active_response_id = utils.shortuuid()
299
421
  text_stream = utils.aio.Chan[str]()
300
422
  audio_stream = utils.aio.Chan[rtc.AudioFrame]()
@@ -307,7 +429,7 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
307
429
  audio=[],
308
430
  text_stream=text_stream,
309
431
  audio_stream=audio_stream,
310
- content_type=self._opts.response_modalities,
432
+ content_type="audio",
311
433
  )
312
434
  self.emit("response_content_added", content)
313
435
 
@@ -326,6 +448,8 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
326
448
  samples_per_channel=len(part.inline_data.data)
327
449
  // 2,
328
450
  )
451
+ if self._opts.enable_agent_audio_transcription:
452
+ self._agent_transcriber._push_audio(frame)
329
453
  content.audio_stream.send_nowait(frame)
330
454
 
331
455
  if server_content.interrupted or server_content.turn_complete:
@@ -333,10 +457,8 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
333
457
  if isinstance(stream, utils.aio.Chan):
334
458
  stream.close()
335
459
 
336
- if server_content.interrupted:
337
- self.emit("input_speech_started")
338
- elif server_content.turn_complete:
339
- self.emit("response_content_done", content)
460
+ self.emit("agent_speech_stopped")
461
+ self._is_interrupted = True
340
462
 
341
463
  self._active_response_id = None
342
464
 
@@ -387,6 +509,10 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
387
509
  finally:
388
510
  await utils.aio.gracefully_cancel(*tasks)
389
511
  await self._session.close()
512
+ if self._opts.enable_user_audio_transcription:
513
+ await self._transcriber.aclose()
514
+ if self._opts.enable_agent_audio_transcription:
515
+ await self._agent_transcriber.aclose()
390
516
 
391
517
  @utils.log_exceptions(logger=logger)
392
518
  async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str):
@@ -419,6 +545,6 @@ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
419
545
  )
420
546
  ]
421
547
  )
422
- await self._session.send(tool_response)
548
+ await self._session.send(input=tool_response)
423
549
 
424
550
  self.emit("function_calls_finished", [called_fnc])
@@ -0,0 +1,173 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import re
5
+ from dataclasses import dataclass
6
+ from typing import Literal
7
+
8
+ import websockets
9
+ from livekit import rtc
10
+ from livekit.agents import utils
11
+
12
+ from google import genai
13
+ from google.genai import types
14
+
15
+ from ...log import logger
16
+ from .api_proto import ClientEvents, LiveAPIModels
17
+
18
+ EventTypes = Literal[
19
+ "input_speech_started",
20
+ "input_speech_done",
21
+ ]
22
+
23
+ DEFAULT_LANGUAGE = "English"
24
+
25
+ SYSTEM_INSTRUCTIONS = f"""
26
+ You are an **Audio Transcriber**. Your task is to convert audio content into accurate and precise text.
27
+
28
+ - Transcribe verbatim; exclude non-speech sounds.
29
+ - Provide only transcription; no extra text or explanations.
30
+ - If audio is unclear, respond with: `...`
31
+ - Ensure error-free transcription, preserving meaning and context.
32
+ - Use proper punctuation and formatting.
33
+ - Do not add explanations, comments, or extra information.
34
+ - Do not include timestamps, speaker labels, or annotations unless specified.
35
+
36
+ - Audio Language: {DEFAULT_LANGUAGE}
37
+ """
38
+
39
+
40
+ @dataclass
41
+ class TranscriptionContent:
42
+ response_id: str
43
+ text: str
44
+
45
+
46
+ class TranscriberSession(utils.EventEmitter[EventTypes]):
47
+ def __init__(
48
+ self,
49
+ *,
50
+ client: genai.Client,
51
+ model: LiveAPIModels | str,
52
+ ):
53
+ """
54
+ Initializes a TranscriberSession instance for interacting with Google's Realtime API.
55
+ """
56
+ super().__init__()
57
+ self._client = client
58
+ self._model = model
59
+ self._closed = False
60
+ system_instructions = types.Content(
61
+ parts=[types.Part(text=SYSTEM_INSTRUCTIONS)]
62
+ )
63
+
64
+ self._config = types.LiveConnectConfig(
65
+ response_modalities=["TEXT"],
66
+ system_instruction=system_instructions,
67
+ generation_config=types.GenerationConfig(
68
+ temperature=0.0,
69
+ ),
70
+ )
71
+ self._main_atask = asyncio.create_task(
72
+ self._main_task(), name="gemini-realtime-transcriber"
73
+ )
74
+ self._send_ch = utils.aio.Chan[ClientEvents]()
75
+ self._active_response_id = None
76
+
77
+ def _push_audio(self, frame: rtc.AudioFrame) -> None:
78
+ if self._closed:
79
+ return
80
+ self._queue_msg(
81
+ types.LiveClientRealtimeInput(
82
+ media_chunks=[
83
+ types.Blob(data=frame.data.tobytes(), mime_type="audio/pcm")
84
+ ]
85
+ )
86
+ )
87
+
88
+ def _queue_msg(self, msg: ClientEvents) -> None:
89
+ if not self._closed:
90
+ self._send_ch.send_nowait(msg)
91
+
92
+ async def aclose(self) -> None:
93
+ if self._send_ch.closed:
94
+ return
95
+ self._closed = True
96
+ self._send_ch.close()
97
+ await self._main_atask
98
+
99
+ @utils.log_exceptions(logger=logger)
100
+ async def _main_task(self):
101
+ @utils.log_exceptions(logger=logger)
102
+ async def _send_task():
103
+ try:
104
+ async for msg in self._send_ch:
105
+ if self._closed:
106
+ break
107
+ await self._session.send(input=msg)
108
+ except websockets.exceptions.ConnectionClosedError as e:
109
+ logger.exception(f"Transcriber session closed in _send_task: {e}")
110
+ self._closed = True
111
+ except Exception as e:
112
+ logger.exception(f"Uncaught error in transcriber _send_task: {e}")
113
+ self._closed = True
114
+
115
+ @utils.log_exceptions(logger=logger)
116
+ async def _recv_task():
117
+ try:
118
+ while not self._closed:
119
+ async for response in self._session.receive():
120
+ if self._closed:
121
+ break
122
+ if self._active_response_id is None:
123
+ self._active_response_id = utils.shortuuid()
124
+ content = TranscriptionContent(
125
+ response_id=self._active_response_id,
126
+ text="",
127
+ )
128
+ self.emit("input_speech_started", content)
129
+
130
+ server_content = response.server_content
131
+ if server_content:
132
+ model_turn = server_content.model_turn
133
+ if model_turn:
134
+ for part in model_turn.parts:
135
+ if part.text:
136
+ content.text += part.text
137
+
138
+ if server_content.turn_complete:
139
+ content.text = clean_transcription(content.text)
140
+ self.emit("input_speech_done", content)
141
+ self._active_response_id = None
142
+
143
+ except websockets.exceptions.ConnectionClosedError as e:
144
+ logger.exception(f"Transcriber session closed in _recv_task: {e}")
145
+ self._closed = True
146
+ except Exception as e:
147
+ logger.exception(f"Uncaught error in transcriber _recv_task: {e}")
148
+ self._closed = True
149
+
150
+ async with self._client.aio.live.connect(
151
+ model=self._model, config=self._config
152
+ ) as session:
153
+ self._session = session
154
+ tasks = [
155
+ asyncio.create_task(
156
+ _send_task(), name="gemini-realtime-transcriber-send"
157
+ ),
158
+ asyncio.create_task(
159
+ _recv_task(), name="gemini-realtime-transcriber-recv"
160
+ ),
161
+ ]
162
+
163
+ try:
164
+ await asyncio.gather(*tasks)
165
+ finally:
166
+ await utils.aio.gracefully_cancel(*tasks)
167
+ await self._session.close()
168
+
169
+
170
+ def clean_transcription(text: str) -> str:
171
+ text = text.replace("\n", " ")
172
+ text = re.sub(r"\s+", " ", text)
173
+ return text.strip()