livekit-plugins-google 0.7.3__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
livekit/plugins/google/__init__.py
@@ -12,12 +12,12 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from . import beta
  from .stt import STT, SpeechStream
  from .tts import TTS
  from .version import __version__

- __all__ = ["STT", "TTS", "SpeechStream", "__version__"]
-
+ __all__ = ["STT", "TTS", "SpeechStream", "__version__", "beta"]
  from livekit.agents import Plugin

  from .log import logger

livekit/plugins/google/beta/__init__.py (new file)
@@ -0,0 +1,3 @@
+ from . import realtime
+
+ __all__ = ["realtime"]

livekit/plugins/google/beta/realtime/__init__.py (new file)
@@ -0,0 +1,15 @@
+ from .api_proto import (
+     ClientEvents,
+     LiveAPIModels,
+     ResponseModality,
+     Voice,
+ )
+ from .realtime_api import RealtimeModel
+
+ __all__ = [
+     "RealtimeModel",
+     "ClientEvents",
+     "LiveAPIModels",
+     "ResponseModality",
+     "Voice",
+ ]

livekit/plugins/google/beta/realtime/api_proto.py (new file)
@@ -0,0 +1,79 @@
+ from __future__ import annotations
+
+ import inspect
+ from typing import Any, Dict, List, Literal, Sequence, Union
+
+ from google.genai import types  # type: ignore
+
+ LiveAPIModels = Literal["gemini-2.0-flash-exp"]
+
+ Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"]
+ ResponseModality = Literal["AUDIO", "TEXT"]
+
+
+ ClientEvents = Union[
+     types.ContentListUnion,
+     types.ContentListUnionDict,
+     types.LiveClientContentOrDict,
+     types.LiveClientRealtimeInput,
+     types.LiveClientRealtimeInputOrDict,
+     types.LiveClientToolResponseOrDict,
+     types.FunctionResponseOrDict,
+     Sequence[types.FunctionResponseOrDict],
+ ]
+
+
+ JSON_SCHEMA_TYPE_MAP = {
+     str: "string",
+     int: "integer",
+     float: "number",
+     bool: "boolean",
+     dict: "object",
+     list: "array",
+ }
+
+
+ def _build_parameters(arguments: Dict[str, Any]) -> types.SchemaDict:
+     properties: Dict[str, types.SchemaDict] = {}
+     required: List[str] = []
+
+     for arg_name, arg_info in arguments.items():
+         py_type = arg_info.type
+         if py_type not in JSON_SCHEMA_TYPE_MAP:
+             raise ValueError(f"Unsupported type: {py_type}")
+
+         prop: types.SchemaDict = {
+             "type": JSON_SCHEMA_TYPE_MAP[py_type],
+             "description": arg_info.description,
+         }
+
+         if arg_info.choices:
+             prop["enum"] = arg_info.choices
+
+         properties[arg_name] = prop
+
+         if arg_info.default is inspect.Parameter.empty:
+             required.append(arg_name)
+
+     parameters: types.SchemaDict = {"type": "object", "properties": properties}
+
+     if required:
+         parameters["required"] = required
+
+     return parameters
+
+
+ def _build_tools(fnc_ctx: Any) -> List[types.FunctionDeclarationDict]:
+     function_declarations: List[types.FunctionDeclarationDict] = []
+     for fnc_info in fnc_ctx.ai_functions.values():
+         parameters = _build_parameters(fnc_info.arguments)
+
+         func_decl: types.FunctionDeclarationDict = {
+             "name": fnc_info.name,
+             "description": fnc_info.description,
+             "parameters": parameters,
+         }
+
+         function_declarations.append(func_decl)
+
+     return function_declarations
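
Note (for illustration only, not part of the package): given a hypothetical AI function registered in the livekit-agents function context that takes a single required string argument named "location", _build_parameters above would produce a schema dict of roughly this shape. The arg_info fields (type, description, choices, default) are taken from the function context shown above:

    # Hypothetical example of the schema produced for one required string
    # argument named "location" with no enum choices.
    expected_parameters = {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city to look up",
            },
        },
        "required": ["location"],
    }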

livekit/plugins/google/beta/realtime/realtime_api.py (new file)
@@ -0,0 +1,424 @@
+ from __future__ import annotations
+
+ import asyncio
+ import base64
+ import json
+ import os
+ from dataclasses import dataclass
+ from typing import AsyncIterable, Literal
+
+ from livekit import rtc
+ from livekit.agents import llm, utils
+ from livekit.agents.llm.function_context import _create_ai_function_info
+
+ from google import genai  # type: ignore
+ from google.genai.types import (  # type: ignore
+     FunctionResponse,
+     GenerationConfigDict,
+     LiveClientToolResponse,
+     LiveConnectConfigDict,
+     PrebuiltVoiceConfig,
+     SpeechConfig,
+     VoiceConfig,
+ )
+
+ from ...log import logger
+ from .api_proto import (
+     ClientEvents,
+     LiveAPIModels,
+     ResponseModality,
+     Voice,
+     _build_tools,
+ )
+
+ EventTypes = Literal[
+     "start_session",
+     "input_speech_started",
+     "response_content_added",
+     "response_content_done",
+     "function_calls_collected",
+     "function_calls_finished",
+     "function_calls_cancelled",
+ ]
+
+
+ @dataclass
+ class GeminiContent:
+     response_id: str
+     item_id: str
+     output_index: int
+     content_index: int
+     text: str
+     audio: list[rtc.AudioFrame]
+     text_stream: AsyncIterable[str]
+     audio_stream: AsyncIterable[rtc.AudioFrame]
+     content_type: Literal["text", "audio"]
+
+
+ @dataclass
+ class Capabilities:
+     supports_truncate: bool
+
+
+ @dataclass
+ class ModelOptions:
+     model: LiveAPIModels | str
+     api_key: str | None
+     voice: Voice | str
+     response_modalities: ResponseModality
+     vertexai: bool
+     project: str | None
+     location: str | None
+     candidate_count: int
+     temperature: float | None
+     max_output_tokens: int | None
+     top_p: float | None
+     top_k: int | None
+     presence_penalty: float | None
+     frequency_penalty: float | None
+     instructions: str
+
+
+ class RealtimeModel:
+     def __init__(
+         self,
+         *,
+         instructions: str = "",
+         model: LiveAPIModels | str = "gemini-2.0-flash-exp",
+         api_key: str | None = None,
+         voice: Voice | str = "Puck",
+         modalities: ResponseModality = "AUDIO",
+         vertexai: bool = False,
+         project: str | None = None,
+         location: str | None = None,
+         candidate_count: int = 1,
+         temperature: float | None = None,
+         max_output_tokens: int | None = None,
+         top_p: float | None = None,
+         top_k: int | None = None,
+         presence_penalty: float | None = None,
+         frequency_penalty: float | None = None,
+         loop: asyncio.AbstractEventLoop | None = None,
+     ):
+         """
+         Initializes a RealtimeModel instance for interacting with Google's Realtime API.
+
+         Args:
+             instructions (str, optional): Initial system instructions for the model. Defaults to "".
+             api_key (str or None, optional): Google API key. If None, it will be read from the GOOGLE_API_KEY environment variable.
+             modalities (ResponseModality, optional): Response modality to use, either "AUDIO" or "TEXT". Defaults to "AUDIO".
+             model (str or None, optional): The name of the model to use. Defaults to "gemini-2.0-flash-exp".
+             voice (api_proto.Voice, optional): Voice setting for audio outputs. Defaults to "Puck".
+             temperature (float or None, optional): Sampling temperature for response generation. Defaults to None.
+             vertexai (bool, optional): Whether to use VertexAI for the API. Defaults to False.
+             project (str or None, optional): The project to use for the API. Defaults to None. (for vertexai)
+             location (str or None, optional): The location to use for the API. Defaults to None. (for vertexai)
+             candidate_count (int, optional): The number of candidate responses to generate. Defaults to 1.
+             top_p (float, optional): The top-p value for response generation
+             top_k (int, optional): The top-k value for response generation
+             presence_penalty (float, optional): The presence penalty for response generation
+             frequency_penalty (float, optional): The frequency penalty for response generation
+             loop (asyncio.AbstractEventLoop or None, optional): Event loop to use for async operations. If None, the current event loop is used.
+
+         Raises:
+             ValueError: If the API key is not provided and cannot be found in environment variables.
+         """
+         super().__init__()
+         self._capabilities = Capabilities(
+             supports_truncate=False,
+         )
+         self._model = model
+         self._loop = loop or asyncio.get_event_loop()
+         self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
+         self._vertexai = vertexai
+         self._project_id = project or os.environ.get("GOOGLE_PROJECT")
+         self._location = location or os.environ.get("GOOGLE_LOCATION")
+         if self._api_key is None and not self._vertexai:
+             raise ValueError("GOOGLE_API_KEY is not set")
+
+         self._rt_sessions: list[GeminiRealtimeSession] = []
+         self._opts = ModelOptions(
+             model=model,
+             api_key=api_key,
+             voice=voice,
+             response_modalities=modalities,
+             vertexai=vertexai,
+             project=project,
+             location=location,
+             candidate_count=candidate_count,
+             temperature=temperature,
+             max_output_tokens=max_output_tokens,
+             top_p=top_p,
+             top_k=top_k,
+             presence_penalty=presence_penalty,
+             frequency_penalty=frequency_penalty,
+             instructions=instructions,
+         )
+
+     @property
+     def sessions(self) -> list[GeminiRealtimeSession]:
+         return self._rt_sessions
+
+     @property
+     def capabilities(self) -> Capabilities:
+         return self._capabilities
+
+     def session(
+         self,
+         *,
+         chat_ctx: llm.ChatContext | None = None,
+         fnc_ctx: llm.FunctionContext | None = None,
+     ) -> GeminiRealtimeSession:
+         session = GeminiRealtimeSession(
+             opts=self._opts,
+             chat_ctx=chat_ctx or llm.ChatContext(),
+             fnc_ctx=fnc_ctx,
+             loop=self._loop,
+         )
+         self._rt_sessions.append(session)
+
+         return session
+
+     async def aclose(self) -> None:
+         for session in self._rt_sessions:
+             await session.aclose()
+
+
+ class GeminiRealtimeSession(utils.EventEmitter[EventTypes]):
+     def __init__(
+         self,
+         *,
+         opts: ModelOptions,
+         chat_ctx: llm.ChatContext,
+         fnc_ctx: llm.FunctionContext | None,
+         loop: asyncio.AbstractEventLoop,
+     ):
+         """
+         Initializes a GeminiRealtimeSession instance for interacting with Google's Realtime API.
+
+         Args:
+             opts (ModelOptions): The model options for the session.
+             chat_ctx (llm.ChatContext): The chat context for the session.
+             fnc_ctx (llm.FunctionContext or None): The function context for the session.
+             loop (asyncio.AbstractEventLoop): The event loop for the session.
+         """
+         super().__init__()
+         self._loop = loop
+         self._opts = opts
+         self._chat_ctx = chat_ctx
+         self._fnc_ctx = fnc_ctx
+         self._fnc_tasks = utils.aio.TaskSet()
+
+         tools = []
+         if self._fnc_ctx is not None:
+             functions = _build_tools(self._fnc_ctx)
+             tools.append({"function_declarations": functions})
+
+         self._config = LiveConnectConfigDict(
+             model=self._opts.model,
+             response_modalities=self._opts.response_modalities,
+             generation_config=GenerationConfigDict(
+                 candidate_count=self._opts.candidate_count,
+                 temperature=self._opts.temperature,
+                 max_output_tokens=self._opts.max_output_tokens,
+                 top_p=self._opts.top_p,
+                 top_k=self._opts.top_k,
+                 presence_penalty=self._opts.presence_penalty,
+                 frequency_penalty=self._opts.frequency_penalty,
+             ),
+             system_instruction=self._opts.instructions,
+             speech_config=SpeechConfig(
+                 voice_config=VoiceConfig(
+                     prebuilt_voice_config=PrebuiltVoiceConfig(
+                         voice_name=self._opts.voice
+                     )
+                 )
+             ),
+             tools=tools,
+         )
+         self._client = genai.Client(
+             http_options={"api_version": "v1alpha"},
+             api_key=self._opts.api_key,
+             vertexai=self._opts.vertexai,
+             project=self._opts.project,
+             location=self._opts.location,
+         )
+         self._main_atask = asyncio.create_task(
+             self._main_task(), name="gemini-realtime-session"
+         )
+         # dummy task to wait for the session to be initialized # TODO: sync chat ctx
+         self._init_sync_task = asyncio.create_task(
+             asyncio.sleep(0), name="gemini-realtime-session-init"
+         )
+         self._send_ch = utils.aio.Chan[ClientEvents]()
+         self._active_response_id = None
+
+     async def aclose(self) -> None:
+         if self._send_ch.closed:
+             return
+
+         self._send_ch.close()
+         await self._main_atask
+
+     @property
+     def fnc_ctx(self) -> llm.FunctionContext | None:
+         return self._fnc_ctx
+
+     @fnc_ctx.setter
+     def fnc_ctx(self, value: llm.FunctionContext | None) -> None:
+         self._fnc_ctx = value
+
+     def _push_audio(self, frame: rtc.AudioFrame) -> None:
+         data = base64.b64encode(frame.data).decode("utf-8")
+         self._queue_msg({"mime_type": "audio/pcm", "data": data})
+
+     def _queue_msg(self, msg: dict) -> None:
+         self._send_ch.send_nowait(msg)
+
+     def chat_ctx_copy(self) -> llm.ChatContext:
+         return self._chat_ctx.copy()
+
+     async def set_chat_ctx(self, ctx: llm.ChatContext) -> None:
+         self._chat_ctx = ctx.copy()
+
+     @utils.log_exceptions(logger=logger)
+     async def _main_task(self):
+         @utils.log_exceptions(logger=logger)
+         async def _send_task():
+             async for msg in self._send_ch:
+                 await self._session.send(msg)
+
+             await self._session.send(".", end_of_turn=True)
+
+         @utils.log_exceptions(logger=logger)
+         async def _recv_task():
+             while True:
+                 async for response in self._session.receive():
+                     if self._active_response_id is None:
+                         self._active_response_id = utils.shortuuid()
+                         text_stream = utils.aio.Chan[str]()
+                         audio_stream = utils.aio.Chan[rtc.AudioFrame]()
+                         content = GeminiContent(
+                             response_id=self._active_response_id,
+                             item_id=self._active_response_id,
+                             output_index=0,
+                             content_index=0,
+                             text="",
+                             audio=[],
+                             text_stream=text_stream,
+                             audio_stream=audio_stream,
+                             content_type=self._opts.response_modalities,
+                         )
+                         self.emit("response_content_added", content)
+
+                     server_content = response.server_content
+                     if server_content:
+                         model_turn = server_content.model_turn
+                         if model_turn:
+                             for part in model_turn.parts:
+                                 if part.text:
+                                     content.text_stream.send_nowait(part.text)
+                                 if part.inline_data:
+                                     frame = rtc.AudioFrame(
+                                         data=part.inline_data.data,
+                                         sample_rate=24000,
+                                         num_channels=1,
+                                         samples_per_channel=len(part.inline_data.data)
+                                         // 2,
+                                     )
+                                     content.audio_stream.send_nowait(frame)
+
+                         if server_content.interrupted or server_content.turn_complete:
+                             for stream in (content.text_stream, content.audio_stream):
+                                 if isinstance(stream, utils.aio.Chan):
+                                     stream.close()
+
+                             if server_content.interrupted:
+                                 self.emit("input_speech_started")
+                             elif server_content.turn_complete:
+                                 self.emit("response_content_done", content)
+
+                             self._active_response_id = None
+
+                     if response.tool_call:
+                         if self._fnc_ctx is None:
+                             raise ValueError("Function context is not set")
+                         fnc_calls = []
+                         for fnc_call in response.tool_call.function_calls:
+                             fnc_call_info = _create_ai_function_info(
+                                 self._fnc_ctx,
+                                 fnc_call.id,
+                                 fnc_call.name,
+                                 json.dumps(fnc_call.args),
+                             )
+                             fnc_calls.append(fnc_call_info)
+
+                         self.emit("function_calls_collected", fnc_calls)
+
+                         for fnc_call_info in fnc_calls:
+                             self._fnc_tasks.create_task(
+                                 self._run_fnc_task(fnc_call_info, content.item_id)
+                             )
+
+                     # Handle function call cancellations
+                     if response.tool_call_cancellation:
+                         logger.warning(
+                             "function call cancelled",
+                             extra={
+                                 "function_call_ids": response.tool_call_cancellation.function_call_ids,
+                             },
+                         )
+                         self.emit(
+                             "function_calls_cancelled",
+                             response.tool_call_cancellation.function_call_ids,
+                         )
+
+         async with self._client.aio.live.connect(
+             model=self._opts.model, config=self._config
+         ) as session:
+             self._session = session
+             tasks = [
+                 asyncio.create_task(_send_task(), name="gemini-realtime-send"),
+                 asyncio.create_task(_recv_task(), name="gemini-realtime-recv"),
+             ]
+
+             try:
+                 await asyncio.gather(*tasks)
+             finally:
+                 await utils.aio.gracefully_cancel(*tasks)
+                 await self._session.close()
+
+     @utils.log_exceptions(logger=logger)
+     async def _run_fnc_task(self, fnc_call_info: llm.FunctionCallInfo, item_id: str):
+         logger.debug(
+             "executing ai function",
+             extra={
+                 "function": fnc_call_info.function_info.name,
+             },
+         )
+
+         called_fnc = fnc_call_info.execute()
+         try:
+             await called_fnc.task
+         except Exception as e:
+             logger.exception(
+                 "error executing ai function",
+                 extra={
+                     "function": fnc_call_info.function_info.name,
+                 },
+                 exc_info=e,
+             )
+         tool_call = llm.ChatMessage.create_tool_from_called_function(called_fnc)
+         if tool_call.content is not None:
+             tool_response = LiveClientToolResponse(
+                 function_responses=[
+                     FunctionResponse(
+                         name=tool_call.name,
+                         id=tool_call.tool_call_id,
+                         response={"result": tool_call.content},
+                     )
+                 ]
+             )
+             await self._session.send(tool_response)
+
+         self.emit("function_calls_finished", [called_fnc])

livekit/plugins/google/models.py
@@ -3,7 +3,13 @@ from typing import Literal
  # Speech to Text v2

  SpeechModels = Literal[
-     "long", "short", "telephony", "medical_dictation", "medical_conversation", "chirp"
+     "long",
+     "short",
+     "telephony",
+     "medical_dictation",
+     "medical_conversation",
+     "chirp",
+     "chirp_2",
  ]

  SpeechLanguages = Literal[

livekit/plugins/google/stt.py
@@ -16,19 +16,23 @@ from __future__ import annotations

  import asyncio
  import dataclasses
+ import weakref
  from dataclasses import dataclass
- from typing import AsyncIterable, List, Union
+ from typing import List, Union

- from livekit import agents, rtc
+ from livekit import rtc
  from livekit.agents import (
+     DEFAULT_API_CONNECT_OPTIONS,
      APIConnectionError,
+     APIConnectOptions,
      APIStatusError,
      APITimeoutError,
      stt,
      utils,
  )

- from google.api_core.exceptions import Aborted, DeadlineExceeded, GoogleAPICallError
+ from google.api_core.client_options import ClientOptions
+ from google.api_core.exceptions import DeadlineExceeded, GoogleAPICallError
  from google.auth import default as gauth_default
  from google.auth.exceptions import DefaultCredentialsError
  from google.cloud.speech_v2 import SpeechAsyncClient
@@ -50,6 +54,7 @@ class STTOptions:
      punctuate: bool
      spoken_punctuation: bool
      model: SpeechModels
+     sample_rate: int
      keywords: List[tuple[str, float]] | None

      def build_adaptation(self) -> cloud_speech.SpeechAdaptation | None:
@@ -81,6 +86,8 @@ class STT(stt.STT):
          punctuate: bool = True,
          spoken_punctuation: bool = True,
          model: SpeechModels = "long",
+         location: str = "global",
+         sample_rate: int = 16000,
          credentials_info: dict | None = None,
          credentials_file: str | None = None,
          keywords: List[tuple[str, float]] | None = None,
@@ -97,6 +104,7 @@ class STT(stt.STT):
          )

          self._client: SpeechAsyncClient | None = None
+         self._location = location
          self._credentials_info = credentials_info
          self._credentials_file = credentials_file

@@ -120,8 +128,10 @@ class STT(stt.STT):
              punctuate=punctuate,
              spoken_punctuation=spoken_punctuation,
              model=model,
+             sample_rate=sample_rate,
              keywords=keywords,
          )
+         self._streams = weakref.WeakSet[SpeechStream]()

      def _ensure_client(self) -> SpeechAsyncClient:
          if self._credentials_info:
@@ -132,9 +142,16 @@ class STT(stt.STT):
              self._client = SpeechAsyncClient.from_service_account_file(
                  self._credentials_file
              )
-         else:
+         elif self._location == "global":
              self._client = SpeechAsyncClient()
-
+         else:
+             # Add support for passing a specific location that matches recognizer
+             # see: https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages
+             self._client = SpeechAsyncClient(
+                 client_options=ClientOptions(
+                     api_endpoint=f"{self._location}-speech.googleapis.com"
+                 )
+             )
          assert self._client is not None
          return self._client

@@ -150,7 +167,7 @@ class STT(stt.STT):
          from google.auth import default as ga_default

          _, project_id = ga_default()
-         return f"projects/{project_id}/locations/global/recognizers/_"
+         return f"projects/{project_id}/locations/{self._location}/recognizers/_"

      def _sanitize_options(self, *, language: str | None = None) -> STTOptions:
          config = dataclasses.replace(self._config)
@@ -173,10 +190,11 @@ class STT(stt.STT):
          self,
          buffer: utils.AudioBuffer,
          *,
-         language: SpeechLanguages | str | None = None,
+         language: SpeechLanguages | str | None,
+         conn_options: APIConnectOptions,
      ) -> stt.SpeechEvent:
          config = self._sanitize_options(language=language)
-         frame = agents.utils.merge_frames(buffer)
+         frame = rtc.combine_audio_frames(buffer)

          config = cloud_speech.RecognitionConfig(
              explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
@@ -200,7 +218,8 @@ class STT(stt.STT):
                      recognizer=self._recognizer,
                      config=config,
                      content=frame.data.tobytes(),
-                 )
+                 ),
+                 timeout=conn_options.timeout,
              )

              return _recognize_response_to_speech_event(raw)
@@ -217,154 +236,223 @@ class STT(stt.STT):
              raise APIConnectionError() from e

      def stream(
-         self, *, language: SpeechLanguages | str | None = None
+         self,
+         *,
+         language: SpeechLanguages | str | None = None,
+         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
      ) -> "SpeechStream":
          config = self._sanitize_options(language=language)
-         return SpeechStream(self, self._ensure_client(), self._recognizer, config)
+         stream = SpeechStream(
+             stt=self,
+             client=self._ensure_client(),
+             recognizer=self._recognizer,
+             config=config,
+             conn_options=conn_options,
+         )
+         self._streams.add(stream)
+         return stream
+
+     def update_options(
+         self,
+         *,
+         languages: LanguageCode | None = None,
+         detect_language: bool | None = None,
+         interim_results: bool | None = None,
+         punctuate: bool | None = None,
+         spoken_punctuation: bool | None = None,
+         model: SpeechModels | None = None,
+         location: str | None = None,
+         keywords: List[tuple[str, float]] | None = None,
+     ):
+         if languages is not None:
+             if isinstance(languages, str):
+                 languages = [languages]
+             self._config.languages = languages
+         if detect_language is not None:
+             self._config.detect_language = detect_language
+         if interim_results is not None:
+             self._config.interim_results = interim_results
+         if punctuate is not None:
+             self._config.punctuate = punctuate
+         if spoken_punctuation is not None:
+             self._config.spoken_punctuation = spoken_punctuation
+         if model is not None:
+             self._config.model = model
+         if keywords is not None:
+             self._config.keywords = keywords
+
+         for stream in self._streams:
+             stream.update_options(
+                 languages=languages,
+                 detect_language=detect_language,
+                 interim_results=interim_results,
+                 punctuate=punctuate,
+                 spoken_punctuation=spoken_punctuation,
+                 model=model,
+                 location=location,
+                 keywords=keywords,
+             )


  class SpeechStream(stt.SpeechStream):
      def __init__(
          self,
+         *,
          stt: STT,
+         conn_options: APIConnectOptions,
          client: SpeechAsyncClient,
          recognizer: str,
          config: STTOptions,
-         sample_rate: int = 48000,
-         num_channels: int = 1,
-         max_retry: int = 32,
      ) -> None:
-         super().__init__(stt)
+         super().__init__(
+             stt=stt, conn_options=conn_options, sample_rate=config.sample_rate
+         )

          self._client = client
          self._recognizer = recognizer
          self._config = config
-         self._sample_rate = sample_rate
-         self._num_channels = num_channels
-         self._max_retry = max_retry
-
-         self._streaming_config = cloud_speech.StreamingRecognitionConfig(
-             config=cloud_speech.RecognitionConfig(
-                 explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
-                     encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
-                     sample_rate_hertz=self._sample_rate,
-                     audio_channel_count=self._num_channels,
-                 ),
-                 adaptation=config.build_adaptation(),
-                 language_codes=self._config.languages,
-                 model=self._config.model,
-                 features=cloud_speech.RecognitionFeatures(
-                     enable_automatic_punctuation=self._config.punctuate,
-                     enable_word_time_offsets=True,
-                 ),
-             ),
-             streaming_features=cloud_speech.StreamingRecognitionFeatures(
-                 enable_voice_activity_events=True,
-                 interim_results=self._config.interim_results,
-             ),
-         )
+         self._reconnect_event = asyncio.Event()

-     @utils.log_exceptions(logger=logger)
-     async def _main_task(self) -> None:
-         await self._run(self._max_retry)
-
-     async def _run(self, max_retry: int) -> None:
-         retry_count = 0
-         while self._input_ch.qsize() or not self._input_ch.closed:
+     def update_options(
+         self,
+         *,
+         languages: LanguageCode | None = None,
+         detect_language: bool | None = None,
+         interim_results: bool | None = None,
+         punctuate: bool | None = None,
+         spoken_punctuation: bool | None = None,
+         model: SpeechModels | None = None,
+         location: str | None = None,
+         keywords: List[tuple[str, float]] | None = None,
+     ):
+         if languages is not None:
+             if isinstance(languages, str):
+                 languages = [languages]
+             self._config.languages = languages
+         if detect_language is not None:
+             self._config.detect_language = detect_language
+         if interim_results is not None:
+             self._config.interim_results = interim_results
+         if punctuate is not None:
+             self._config.punctuate = punctuate
+         if spoken_punctuation is not None:
+             self._config.spoken_punctuation = spoken_punctuation
+         if model is not None:
+             self._config.model = model
+         if keywords is not None:
+             self._config.keywords = keywords
+
+         self._reconnect_event.set()
+
+     async def _run(self) -> None:
+         # google requires a async generator when calling streaming_recognize
+         # this function basically convert the queue into a async generator
+         async def input_generator():
              try:
-                 # google requires a async generator when calling streaming_recognize
-                 # this function basically convert the queue into a async generator
-                 async def input_generator():
-                     try:
-                         # first request should contain the config
+                 # first request should contain the config
+                 yield cloud_speech.StreamingRecognizeRequest(
+                     recognizer=self._recognizer,
+                     streaming_config=self._streaming_config,
+                 )
+
+                 async for frame in self._input_ch:
+                     if isinstance(frame, rtc.AudioFrame):
                          yield cloud_speech.StreamingRecognizeRequest(
-                             recognizer=self._recognizer,
-                             streaming_config=self._streaming_config,
+                             audio=frame.data.tobytes()
                          )

-                         async for frame in self._input_ch:
-                             if isinstance(frame, rtc.AudioFrame):
-                                 frame = frame.remix_and_resample(
-                                     self._sample_rate, self._num_channels
-                                 )
-                                 yield cloud_speech.StreamingRecognizeRequest(
-                                     audio=frame.data.tobytes()
-                                 )
+             except Exception:
+                 logger.exception(
+                     "an error occurred while streaming input to google STT"
+                 )
+
+         async def process_stream(stream):
+             async for resp in stream:
+                 if (
+                     resp.speech_event_type
+                     == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
+                 ):
+                     self._event_ch.send_nowait(
+                         stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
+                     )

-                     except Exception:
-                         logger.exception(
-                             "an error occurred while streaming input to google STT"
+                 if (
+                     resp.speech_event_type
+                     == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED
+                 ):
+                     result = resp.results[0]
+                     speech_data = _streaming_recognize_response_to_speech_data(resp)
+                     if speech_data is None:
+                         continue
+
+                     if not result.is_final:
+                         self._event_ch.send_nowait(
+                             stt.SpeechEvent(
+                                 type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
+                                 alternatives=[speech_data],
+                             )
+                         )
+                     else:
+                         self._event_ch.send_nowait(
+                             stt.SpeechEvent(
+                                 type=stt.SpeechEventType.FINAL_TRANSCRIPT,
+                                 alternatives=[speech_data],
+                             )
                          )

-                 # try to connect
-                 stream = await self._client.streaming_recognize(
-                     requests=input_generator()
-                 )
-                 retry_count = 0  # connection successful, reset retry count
-
-                 await self._run_stream(stream)
-             except Aborted:
-                 logger.error("google stt connection aborted")
-                 break
-             except Exception as e:
-                 if retry_count >= max_retry:
-                     logger.error(
-                         f"failed to connect to google stt after {max_retry} tries",
-                         exc_info=e,
+                 if (
+                     resp.speech_event_type
+                     == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
+                 ):
+                     self._event_ch.send_nowait(
+                         stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                      )
-                     break

-                 retry_delay = min(retry_count * 2, 5)  # max 5s
-                 retry_count += 1
-                 logger.warning(
-                     f"google stt connection failed, retrying in {retry_delay}s",
-                     exc_info=e,
+         while True:
+             try:
+                 self._streaming_config = cloud_speech.StreamingRecognitionConfig(
+                     config=cloud_speech.RecognitionConfig(
+                         explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
+                             encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
+                             sample_rate_hertz=self._config.sample_rate,
+                             audio_channel_count=1,
+                         ),
+                         adaptation=self._config.build_adaptation(),
+                         language_codes=self._config.languages,
+                         model=self._config.model,
+                         features=cloud_speech.RecognitionFeatures(
+                             enable_automatic_punctuation=self._config.punctuate,
+                             enable_word_time_offsets=True,
+                         ),
+                     ),
+                     streaming_features=cloud_speech.StreamingRecognitionFeatures(
+                         enable_voice_activity_events=True,
+                         interim_results=self._config.interim_results,
+                     ),
                  )
-                 await asyncio.sleep(retry_delay)

-     async def _run_stream(
-         self, stream: AsyncIterable[cloud_speech.StreamingRecognizeResponse]
-     ):
-         async for resp in stream:
-             if (
-                 resp.speech_event_type
-                 == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_BEGIN
-             ):
-                 self._event_ch.send_nowait(
-                     stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
+                 stream = await self._client.streaming_recognize(
+                     requests=input_generator(),
                  )

-             if (
-                 resp.speech_event_type
-                 == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_EVENT_TYPE_UNSPECIFIED
-             ):
-                 result = resp.results[0]
-                 speech_data = _streaming_recognize_response_to_speech_data(resp)
-                 if speech_data is None:
-                     continue
-
-                 if not result.is_final:
-                     self._event_ch.send_nowait(
-                         stt.SpeechEvent(
-                             type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
-                             alternatives=[speech_data],
-                         )
+                 process_stream_task = asyncio.create_task(process_stream(stream))
+                 wait_reconnect_task = asyncio.create_task(self._reconnect_event.wait())
+                 try:
+                     done, _ = await asyncio.wait(
+                         [process_stream_task, wait_reconnect_task],
+                         return_when=asyncio.FIRST_COMPLETED,
                      )
-                 else:
-                     self._event_ch.send_nowait(
-                         stt.SpeechEvent(
-                             type=stt.SpeechEventType.FINAL_TRANSCRIPT,
-                             alternatives=[speech_data],
-                         )
+                     for task in done:
+                         if task != wait_reconnect_task:
+                             task.result()
+                 finally:
+                     await utils.aio.gracefully_cancel(
+                         process_stream_task, wait_reconnect_task
                      )
-
-             if (
-                 resp.speech_event_type
-                 == cloud_speech.StreamingRecognizeResponse.SpeechEventType.SPEECH_ACTIVITY_END
-             ):
-                 self._event_ch.send_nowait(
-                     stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
-                 )
+             finally:
+                 if not self._reconnect_event.is_set():
+                     break
+                 self._reconnect_event.clear()


  def _recognize_response_to_speech_event(
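
Note (a usage sketch, not part of the package): it shows the options added to the STT class in this release; the argument values are illustrative only:

    from livekit.plugins.google import STT

    stt = STT(
        model="chirp_2",          # new model literal added in models.py
        location="us-central1",   # non-"global" locations route to <location>-speech.googleapis.com
        sample_rate=16000,        # streams are now created at this sample rate
    )

    stream = stt.stream()  # conn_options defaults to DEFAULT_API_CONNECT_OPTIONS

    # Options can be changed after construction; open streams reconnect with the new settings.
    stt.update_options(languages=["en-US"], punctuate=True)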

livekit/plugins/google/tts.py
@@ -18,7 +18,9 @@ from dataclasses import dataclass

  from livekit import rtc
  from livekit.agents import (
+     DEFAULT_API_CONNECT_OPTIONS,
      APIConnectionError,
+     APIConnectOptions,
      APIStatusError,
      APITimeoutError,
      tts,
@@ -134,7 +136,7 @@ class TTS(tts.TTS):
          self._opts.audio_config.speaking_rate = speaking_rate

      def _ensure_client(self) -> texttospeech.TextToSpeechAsyncClient:
-         if not self._client:
+         if self._client is None:
              if self._credentials_info:
                  self._client = (
                      texttospeech.TextToSpeechAsyncClient.from_service_account_info(
@@ -154,22 +156,35 @@ class TTS(tts.TTS):
          assert self._client is not None
          return self._client

-     def synthesize(self, text: str) -> "ChunkedStream":
-         return ChunkedStream(self, text, self._opts, self._ensure_client())
+     def synthesize(
+         self,
+         text: str,
+         *,
+         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+     ) -> "ChunkedStream":
+         return ChunkedStream(
+             tts=self,
+             input_text=text,
+             conn_options=conn_options,
+             opts=self._opts,
+             client=self._ensure_client(),
+         )


  class ChunkedStream(tts.ChunkedStream):
      def __init__(
          self,
+         *,
          tts: TTS,
-         text: str,
+         input_text: str,
+         conn_options: APIConnectOptions,
          opts: _TTSOptions,
          client: texttospeech.TextToSpeechAsyncClient,
      ) -> None:
-         super().__init__(tts, text)
+         super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
          self._opts, self._client = opts, client

-     async def _main_task(self) -> None:
+     async def _run(self) -> None:
          request_id = utils.shortuuid()

          try:
@@ -177,16 +192,16 @@ class ChunkedStream(tts.ChunkedStream):
                input=texttospeech.SynthesisInput(text=self._input_text),
                voice=self._opts.voice,
                audio_config=self._opts.audio_config,
+               timeout=self._conn_options.timeout,
            )

-           data = response.audio_content
            if self._opts.audio_config.audio_encoding == "mp3":
                decoder = utils.codecs.Mp3StreamDecoder()
                bstream = utils.audio.AudioByteStream(
                    sample_rate=self._opts.audio_config.sample_rate_hertz,
                    num_channels=1,
                )
-               for frame in decoder.decode_chunk(data):
+               for frame in decoder.decode_chunk(response.audio_content):
                    for frame in bstream.write(frame.data.tobytes()):
                        self._event_ch.send_nowait(
                            tts.SynthesizedAudio(request_id=request_id, frame=frame)
@@ -197,7 +212,7 @@ class ChunkedStream(tts.ChunkedStream):
                        tts.SynthesizedAudio(request_id=request_id, frame=frame)
                    )
            else:
-               data = data[44:]  # skip WAV header
+               data = response.audio_content[44:]  # skip WAV header
                self._event_ch.send_nowait(
                    tts.SynthesizedAudio(
                        request_id=request_id,
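
Note (a minimal synthesis sketch, not part of the package): it assumes the TTS constructor keeps its default arguments from earlier releases and that credentials are available via Application Default Credentials:

    import asyncio

    from livekit.plugins.google import TTS


    async def main() -> None:
        tts = TTS()  # assumed: default voice/audio settings, credentials via ADC
        stream = tts.synthesize("Hello from LiveKit")  # conn_options defaults to DEFAULT_API_CONNECT_OPTIONS
        async for audio in stream:  # assumed: ChunkedStream yields SynthesizedAudio items
            frame = audio.frame  # mono PCM frame ready to publish to a room


    asyncio.run(main())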

livekit/plugins/google/version.py
@@ -12,4 +12,4 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- __version__ = "0.7.3"
+ __version__ = "0.9.0"

METADATA (dist-info)
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: livekit-plugins-google
- Version: 0.7.3
+ Version: 0.9.0
  Summary: Agent Framework plugin for services from Google Cloud
  Home-page: https://github.com/livekit/agents
  License: Apache-2.0
@@ -19,10 +19,11 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3 :: Only
  Requires-Python: >=3.9.0
  Description-Content-Type: text/markdown
- Requires-Dist: google-auth <3,>=2
- Requires-Dist: google-cloud-speech <3,>=2
- Requires-Dist: google-cloud-texttospeech <3,>=2
- Requires-Dist: livekit-agents >=0.11
+ Requires-Dist: google-auth<3,>=2
+ Requires-Dist: google-cloud-speech<3,>=2
+ Requires-Dist: google-cloud-texttospeech<3,>=2
+ Requires-Dist: google-genai>=0.3.0
+ Requires-Dist: livekit-agents>=0.12.3

  # LiveKit Plugins Google

@@ -37,3 +38,8 @@ pip install livekit-plugins-google
  ## Pre-requisites

  For credentials, you'll need a Google Cloud account and obtain the correct credentials. Credentials can be passed directly or via Application Default Credentials as specified in [How Application Default Credentials works](https://cloud.google.com/docs/authentication/application-default-credentials).
+
+ To use the STT and TTS APIs, you'll need to enable the respective services for your Google Cloud project.
+
+ - Cloud Speech-to-Text API
+ - Cloud Text-to-Speech API

RECORD (0.9.0 dist-info, new file)
@@ -0,0 +1,15 @@
+ livekit/plugins/google/__init__.py,sha256=TY-5FwEX4Vs7GLO1wSegIxC5W4UPkHBthlr-__yuE4w,1143
+ livekit/plugins/google/log.py,sha256=GI3YWN5YzrafnUccljzPRS_ZALkMNk1i21IRnTl2vNA,69
+ livekit/plugins/google/models.py,sha256=cBXhZGY9bFaSCyL9VeSng9wsxhf3peJi3AUYBKV-8GQ,1343
+ livekit/plugins/google/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ livekit/plugins/google/stt.py,sha256=SfmKgQotIVzk9-Hipo1X5cnLQG4uXLniTUoyM3IynwA,18712
+ livekit/plugins/google/tts.py,sha256=95qXCigVQYWNbcN3pIKBpIah4b31U_MWtXv5Ji0AMc4,9229
+ livekit/plugins/google/version.py,sha256=onRKrcQ35NZG4oEg_95WGeTytHh_6VVAlQKAZhwiEe4,600
+ livekit/plugins/google/beta/__init__.py,sha256=AxRYc7NGG62Tv1MmcZVCDHNvlhbC86hM-_yP01Qb28k,47
+ livekit/plugins/google/beta/realtime/__init__.py,sha256=XnJpNIN6NRm7Y4hH2RNA8Xt-tTmkZEKCs_zzU3_koBI,251
+ livekit/plugins/google/beta/realtime/api_proto.py,sha256=IHYBryuzpfGQD86Twlfq6qxrBhFHptf_IvOk36Wxo1M,2156
+ livekit/plugins/google/beta/realtime/realtime_api.py,sha256=OxrbWnUOT_oFdrMruvLPHgEoXlOr6M5oGym9b2Iqz48,15958
+ livekit_plugins_google-0.9.0.dist-info/METADATA,sha256=tB70OQMa7JtWLqRi1TMDUpv4y0TZEk0L609BN6y0x48,1841
+ livekit_plugins_google-0.9.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+ livekit_plugins_google-0.9.0.dist-info/top_level.txt,sha256=OoDok3xUmXbZRvOrfvvXB-Juu4DX79dlq188E19YHoo,8
+ livekit_plugins_google-0.9.0.dist-info/RECORD,,

WHEEL (dist-info)
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.5.0)
+ Generator: setuptools (75.6.0)
  Root-Is-Purelib: true
  Tag: py3-none-any


RECORD (0.7.3 dist-info, removed)
@@ -1,11 +0,0 @@
- livekit/plugins/google/__init__.py,sha256=rqV6C5mFNDFlrA2IcGJrsebr2VxQwMzoDUjY1JhMBZM,1117
- livekit/plugins/google/log.py,sha256=GI3YWN5YzrafnUccljzPRS_ZALkMNk1i21IRnTl2vNA,69
- livekit/plugins/google/models.py,sha256=n8pgTJ7xyJpPCZJ_y0GzaQq6LqYknL6K6trpi07-AxM,1307
- livekit/plugins/google/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- livekit/plugins/google/stt.py,sha256=WjeqYsunW8jY-WHlnNeks7gR-TiojMRR7LYdAVdCxqY,15268
- livekit/plugins/google/tts.py,sha256=hRN8ul1lDXU8LPVEfbTszgBiRYsifZXCPMwk-Pv2KeA,8793
- livekit/plugins/google/version.py,sha256=yJeG0VwiekDJAk7GHcIAe43ebagJgloe-ZsqEGZnqzE,600
- livekit_plugins_google-0.7.3.dist-info/METADATA,sha256=8UvORpoVunOTq0xKxHEk8M3sexKFnBnu66DkEJCnrRY,1647
- livekit_plugins_google-0.7.3.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
- livekit_plugins_google-0.7.3.dist-info/top_level.txt,sha256=OoDok3xUmXbZRvOrfvvXB-Juu4DX79dlq188E19YHoo,8
- livekit_plugins_google-0.7.3.dist-info/RECORD,,