livekit-plugins-aws 1.0.22.tar.gz → 1.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of livekit-plugins-aws might be problematic; see the registry's advisory page for this release for more details.

@@ -146,6 +146,9 @@ venv.bak/
146
146
  .dmypy.json
147
147
  dmypy.json
148
148
 
149
+ # trunk
150
+ .trunk/
151
+
149
152
  # Pyre type checker
150
153
  .pyre/
151
154
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: livekit-plugins-aws
3
- Version: 1.0.22
3
+ Version: 1.1.0
4
4
  Summary: LiveKit Agents Plugin for services from AWS
5
5
  Project-URL: Documentation, https://docs.livekit.io
6
6
  Project-URL: Website, https://livekit.io/
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
20
  Requires-Python: >=3.9.0
21
21
  Requires-Dist: aioboto3>=14.1.0
22
22
  Requires-Dist: amazon-transcribe>=0.6.2
23
- Requires-Dist: livekit-agents>=1.0.22
23
+ Requires-Dist: livekit-agents>=1.1.0
24
24
  Description-Content-Type: text/markdown
25
25
 
26
26
  # AWS plugin for LiveKit Agents
@@ -16,12 +16,18 @@ from __future__ import annotations
16
16
 
17
17
  import os
18
18
  from dataclasses import dataclass
19
- from typing import Any, Literal
19
+ from typing import Any, cast
20
20
 
21
- import aioboto3
21
+ import aioboto3 # type: ignore
22
22
 
23
23
  from livekit.agents import APIConnectionError, APIStatusError, llm
24
- from livekit.agents.llm import ChatContext, FunctionTool, FunctionToolCall, ToolChoice
24
+ from livekit.agents.llm import (
25
+ ChatContext,
26
+ FunctionTool,
27
+ FunctionToolCall,
28
+ RawFunctionTool,
29
+ ToolChoice,
30
+ )
25
31
  from livekit.agents.types import (
26
32
  DEFAULT_API_CONNECT_OPTIONS,
27
33
  NOT_GIVEN,
@@ -31,14 +37,14 @@ from livekit.agents.types import (
31
37
  from livekit.agents.utils import is_given
32
38
 
33
39
  from .log import logger
34
- from .utils import to_chat_ctx, to_fnc_ctx
40
+ from .utils import to_fnc_ctx
35
41
 
36
- TEXT_MODEL = Literal["anthropic.claude-3-5-sonnet-20241022-v2:0"]
42
+ DEFAULT_TEXT_MODEL = "anthropic.claude-3-5-sonnet-20240620-v1:0"
37
43
 
38
44
 
39
45
  @dataclass
40
46
  class _LLMOptions:
41
- model: str | TEXT_MODEL
47
+ model: str
42
48
  temperature: NotGivenOr[float]
43
49
  tool_choice: NotGivenOr[ToolChoice]
44
50
  max_output_tokens: NotGivenOr[int]
@@ -50,10 +56,10 @@ class LLM(llm.LLM):
50
56
  def __init__(
51
57
  self,
52
58
  *,
53
- model: NotGivenOr[str | TEXT_MODEL] = NOT_GIVEN,
59
+ model: NotGivenOr[str] = DEFAULT_TEXT_MODEL,
54
60
  api_key: NotGivenOr[str] = NOT_GIVEN,
55
61
  api_secret: NotGivenOr[str] = NOT_GIVEN,
56
- region: NotGivenOr[str] = NOT_GIVEN,
62
+ region: NotGivenOr[str] = "us-east-1",
57
63
  temperature: NotGivenOr[float] = NOT_GIVEN,
58
64
  max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
59
65
  top_p: NotGivenOr[float] = NOT_GIVEN,
@@ -70,7 +76,8 @@ class LLM(llm.LLM):
70
76
  See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html for more details on the AWS Bedrock Runtime API.
71
77
 
72
78
  Args:
73
- model (TEXT_MODEL, optional): model or inference profile arn to use(https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-use.html). Defaults to 'anthropic.claude-3-5-sonnet-20240620-v1:0'.
79
+ model (str, optional): model or inference profile arn to use(https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-use.html).
80
+ Defaults to 'anthropic.claude-3-5-sonnet-20240620-v1:0'.
74
81
  api_key(str, optional): AWS access key id.
75
82
  api_secret(str, optional): AWS secret access key
76
83
  region (str, optional): The region to use for AWS API requests. Defaults value is "us-east-1".
@@ -89,13 +96,15 @@ class LLM(llm.LLM):
89
96
  region_name=region if is_given(region) else None,
90
97
  )
91
98
 
92
- model = model if is_given(model) else os.environ.get("BEDROCK_INFERENCE_PROFILE_ARN")
93
- if not model:
99
+ bedrock_model = (
100
+ model if is_given(model) else os.environ.get("BEDROCK_INFERENCE_PROFILE_ARN")
101
+ )
102
+ if not bedrock_model:
94
103
  raise ValueError(
95
104
  "model or inference profile arn must be set using the argument or by setting the BEDROCK_INFERENCE_PROFILE_ARN environment variable." # noqa: E501
96
105
  )
97
106
  self._opts = _LLMOptions(
98
- model=model,
107
+ model=bedrock_model,
99
108
  temperature=temperature,
100
109
  tool_choice=tool_choice,
101
110
  max_output_tokens=max_output_tokens,
@@ -107,12 +116,15 @@ class LLM(llm.LLM):
107
116
  self,
108
117
  *,
109
118
  chat_ctx: ChatContext,
110
- tools: list[FunctionTool] | None = None,
119
+ tools: list[FunctionTool | RawFunctionTool] | None = None,
120
+ parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
111
121
  conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
112
- temperature: NotGivenOr[float] = NOT_GIVEN,
113
122
  tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
123
+ temperature: NotGivenOr[float] = NOT_GIVEN,
124
+ extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
114
125
  ) -> LLMStream:
115
- opts = {}
126
+ opts: dict[str, Any] = {}
127
+ extra_kwargs = extra_kwargs if is_given(extra_kwargs) else {}
116
128
 
117
129
  if is_given(self._opts.model):
118
130
  opts["modelId"] = self._opts.model
@@ -124,7 +136,9 @@ class LLM(llm.LLM):
124
136
  return None
125
137
 
126
138
  tool_config: dict[str, Any] = {"tools": to_fnc_ctx(tools)}
127
- tool_choice = tool_choice if is_given(tool_choice) else self._opts.tool_choice
139
+ tool_choice = (
140
+ cast(ToolChoice, tool_choice) if is_given(tool_choice) else self._opts.tool_choice
141
+ )
128
142
  if is_given(tool_choice):
129
143
  if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
130
144
  tool_config["toolChoice"] = {"tool": {"name": tool_choice["function"]["name"]}}
@@ -140,12 +154,12 @@ class LLM(llm.LLM):
140
154
  tool_config = _get_tool_config()
141
155
  if tool_config:
142
156
  opts["toolConfig"] = tool_config
143
- messages, system_message = to_chat_ctx(chat_ctx, id(self))
157
+ messages, extra_data = chat_ctx.to_provider_format(format="aws")
144
158
  opts["messages"] = messages
145
- if system_message:
146
- opts["system"] = [system_message]
159
+ if extra_data.system_messages:
160
+ opts["system"] = [{"text": content} for content in extra_data.system_messages]
147
161
 
148
- inference_config = {}
162
+ inference_config: dict[str, Any] = {}
149
163
  if is_given(self._opts.max_output_tokens):
150
164
  inference_config["maxTokens"] = self._opts.max_output_tokens
151
165
  temperature = temperature if is_given(temperature) else self._opts.temperature
@@ -176,7 +190,7 @@ class LLMStream(llm.LLMStream):
176
190
  chat_ctx: ChatContext,
177
191
  session: aioboto3.Session,
178
192
  conn_options: APIConnectOptions,
179
- tools: list[FunctionTool],
193
+ tools: list[FunctionTool | RawFunctionTool],
180
194
  extra_kwargs: dict[str, Any],
181
195
  ) -> None:
182
196
  super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
@@ -192,7 +206,7 @@ class LLMStream(llm.LLMStream):
192
206
  retryable = True
193
207
  try:
194
208
  async with self._session.client("bedrock-runtime") as client:
195
- response = await client.converse_stream(**self._opts) # type: ignore
209
+ response = await client.converse_stream(**self._opts)
196
210
  request_id = response["ResponseMetadata"]["RequestId"]
197
211
  if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
198
212
  raise APIStatusError(
@@ -240,6 +254,11 @@ class LLMStream(llm.LLMStream):
240
254
  completion_tokens=metadata["usage"]["outputTokens"],
241
255
  prompt_tokens=metadata["usage"]["inputTokens"],
242
256
  total_tokens=metadata["usage"]["totalTokens"],
257
+ prompt_cached_tokens=(
258
+ metadata["usage"]["cacheReadInputTokens"]
259
+ if "cacheReadInputTokens" in metadata["usage"]
260
+ else 0
261
+ ),
243
262
  ),
244
263
  )
245
264
  elif "contentBlockStop" in chunk:
@@ -1,7 +1,7 @@
1
1
  from typing import Literal
2
2
 
3
- TTS_SPEECH_ENGINE = Literal["standard", "neural", "long-form", "generative"]
4
- TTS_LANGUAGE = Literal[
3
+ TTSSpeechEngine = Literal["standard", "neural", "long-form", "generative"]
4
+ TTSLanguages = Literal[
5
5
  "arb",
6
6
  "cmn-CN",
7
7
  "cy-GB",
@@ -45,4 +45,4 @@ TTS_LANGUAGE = Literal[
45
45
  "de-CH",
46
46
  ]
47
47
 
48
- TTS_OUTPUT_FORMAT = Literal["mp3"]
48
+ TTSEncoding = Literal["mp3"]
@@ -78,7 +78,7 @@ class STT(stt.STT):
78
78
  self._region = region
79
79
  self._client = TranscribeStreamingClient(
80
80
  region=self._region,
81
- credential_resolver=AwsCrtCredentialResolver(None),
81
+ credential_resolver=AwsCrtCredentialResolver(None), # type: ignore
82
82
  )
83
83
 
84
84
  self._config = STTOptions(
@@ -153,15 +153,15 @@ class SpeechStream(stt.SpeechStream):
153
153
  "language_model_name": self._opts.language_model_name,
154
154
  }
155
155
  filtered_config = {k: v for k, v in live_config.items() if v and is_given(v)}
156
- stream = await self._client.start_stream_transcription(**filtered_config)
156
+ stream = await self._client.start_stream_transcription(**filtered_config) # type: ignore
157
157
 
158
- async def input_generator(stream: StartStreamTranscriptionEventStream):
158
+ async def input_generator(stream: StartStreamTranscriptionEventStream) -> None:
159
159
  async for frame in self._input_ch:
160
160
  if isinstance(frame, rtc.AudioFrame):
161
161
  await stream.input_stream.send_audio_event(audio_chunk=frame.data.tobytes())
162
- await stream.input_stream.end_stream()
162
+ await stream.input_stream.end_stream() # type: ignore
163
163
 
164
- async def handle_transcript_events(stream: StartStreamTranscriptionEventStream):
164
+ async def handle_transcript_events(stream: StartStreamTranscriptionEventStream) -> None:
165
165
  async for event in stream.output_stream:
166
166
  if isinstance(event, TranscriptEvent):
167
167
  self._process_transcript_event(event)
@@ -184,7 +184,7 @@ class SpeechStream(stt.SpeechStream):
184
184
  finally:
185
185
  await utils.aio.gracefully_cancel(*tasks)
186
186
 
187
- def _process_transcript_event(self, transcript_event: TranscriptEvent):
187
+ def _process_transcript_event(self, transcript_event: TranscriptEvent) -> None:
188
188
  stream = transcript_event.transcript.results
189
189
  for resp in stream:
190
190
  if resp.start_time and resp.start_time == 0.0:
@@ -0,0 +1,156 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ from __future__ import annotations
14
+
15
+ from dataclasses import dataclass, replace
16
+
17
+ import aioboto3 # type: ignore
18
+ import botocore # type: ignore
19
+ import botocore.exceptions # type: ignore
20
+ from aiobotocore.config import AioConfig # type: ignore
21
+
22
+ from livekit.agents import (
23
+ APIConnectionError,
24
+ APIConnectOptions,
25
+ APITimeoutError,
26
+ tts,
27
+ )
28
+ from livekit.agents.types import (
29
+ DEFAULT_API_CONNECT_OPTIONS,
30
+ NOT_GIVEN,
31
+ NotGivenOr,
32
+ )
33
+ from livekit.agents.utils import is_given
34
+
35
+ from .models import TTSLanguages, TTSSpeechEngine
36
+ from .utils import _strip_nones
37
+
38
+ DEFAULT_SPEECH_ENGINE: TTSSpeechEngine = "generative"
39
+ DEFAULT_VOICE = "Ruth"
40
+
41
+
42
+ @dataclass
43
+ class _TTSOptions:
44
+ # https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
45
+ voice: str
46
+ speech_engine: TTSSpeechEngine
47
+ region: str | None
48
+ sample_rate: int
49
+ language: TTSLanguages | str | None
50
+
51
+
52
+ class TTS(tts.TTS):
53
+ def __init__(
54
+ self,
55
+ *,
56
+ voice: str = "Ruth",
57
+ language: NotGivenOr[TTSLanguages | str] = NOT_GIVEN,
58
+ speech_engine: TTSSpeechEngine = "generative",
59
+ sample_rate: int = 16000,
60
+ region: str | None = None,
61
+ api_key: str | None = None,
62
+ api_secret: str | None = None,
63
+ session: aioboto3.Session | None = None,
64
+ ) -> None:
65
+ """
66
+ Create a new instance of AWS Polly TTS.
67
+
68
+ ``api_key`` and ``api_secret`` must be set to your AWS Access key id and secret access key, either using the argument or by setting the
69
+ ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` environmental variables.
70
+
71
+ See https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html for more details on the the AWS Polly TTS.
72
+
73
+ Args:
74
+ Voice (TTSModels, optional): Voice ID to use for the synthesis. Defaults to "Ruth".
75
+ language (TTS_LANGUAGE, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
76
+ sample_rate(int, optional): The audio frequency specified in Hz. Defaults to 16000.
77
+ speech_engine(TTS_SPEECH_ENGINE, optional): The engine to use for the synthesis. Defaults to "generative".
78
+ region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
79
+ api_key(str, optional): AWS access key id.
80
+ api_secret(str, optional): AWS secret access key.
81
+ session(aioboto3.Session, optional): Optional aioboto3 session to use.
82
+ """ # noqa: E501
83
+ super().__init__(
84
+ capabilities=tts.TTSCapabilities(
85
+ streaming=False,
86
+ ),
87
+ sample_rate=sample_rate,
88
+ num_channels=1,
89
+ )
90
+ self._session = session or aioboto3.Session(
91
+ aws_access_key_id=api_key if is_given(api_key) else None,
92
+ aws_secret_access_key=api_secret if is_given(api_secret) else None,
93
+ region_name=region if is_given(region) else None,
94
+ )
95
+
96
+ self._opts = _TTSOptions(
97
+ voice=voice,
98
+ speech_engine=speech_engine,
99
+ region=region or None,
100
+ language=language or None,
101
+ sample_rate=sample_rate,
102
+ )
103
+
104
+ def synthesize(
105
+ self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
106
+ ) -> ChunkedStream:
107
+ return ChunkedStream(tts=self, text=text, conn_options=conn_options)
108
+
109
+
110
+ class ChunkedStream(tts.ChunkedStream):
111
+ def __init__(
112
+ self, *, tts: TTS, text: str, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
113
+ ) -> None:
114
+ super().__init__(tts=tts, input_text=text, conn_options=conn_options)
115
+ self._tts = tts
116
+ self._opts = replace(tts._opts)
117
+
118
+ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
119
+ try:
120
+ config = AioConfig(
121
+ connect_timeout=self._conn_options.timeout,
122
+ read_timeout=10,
123
+ retries={"mode": "standard", "total_max_attempts": 1},
124
+ )
125
+ async with self._tts._session.client("polly", config=config) as client: # type: ignore
126
+ response = await client.synthesize_speech(
127
+ **_strip_nones(
128
+ {
129
+ "Text": self._input_text,
130
+ "OutputFormat": "mp3",
131
+ "Engine": self._opts.speech_engine,
132
+ "VoiceId": self._opts.voice,
133
+ "TextType": "text",
134
+ "SampleRate": str(self._opts.sample_rate),
135
+ "LanguageCode": self._opts.language,
136
+ }
137
+ )
138
+ )
139
+
140
+ if "AudioStream" in response:
141
+ output_emitter.initialize(
142
+ request_id=response["ResponseMetadata"]["RequestId"],
143
+ sample_rate=self._opts.sample_rate,
144
+ num_channels=1,
145
+ mime_type="audio/mp3",
146
+ )
147
+
148
+ async with response["AudioStream"] as resp:
149
+ async for data, _ in resp.content.iter_chunks():
150
+ output_emitter.push(data)
151
+
152
+ output_emitter.flush()
153
+ except botocore.exceptions.ConnectTimeoutError:
154
+ raise APITimeoutError() from None
155
+ except Exception as e:
156
+ raise APIConnectionError() from e
@@ -0,0 +1,47 @@
1
+ from __future__ import annotations
2
+
3
+ from livekit.agents import llm
4
+ from livekit.agents.llm import FunctionTool, RawFunctionTool
5
+ from livekit.agents.llm.tool_context import (
6
+ get_raw_function_info,
7
+ is_function_tool,
8
+ is_raw_function_tool,
9
+ )
10
+
11
+ __all__ = ["to_fnc_ctx"]
12
+ DEFAULT_REGION = "us-east-1"
13
+
14
+
15
+ def to_fnc_ctx(fncs: list[FunctionTool | RawFunctionTool]) -> list[dict]:
16
+ return [_build_tool_spec(fnc) for fnc in fncs]
17
+
18
+
19
+ def _build_tool_spec(function: FunctionTool | RawFunctionTool) -> dict:
20
+ if is_function_tool(function):
21
+ fnc = llm.utils.build_legacy_openai_schema(function, internally_tagged=True)
22
+ return {
23
+ "toolSpec": _strip_nones(
24
+ {
25
+ "name": fnc["name"],
26
+ "description": fnc["description"] if fnc["description"] else None,
27
+ "inputSchema": {"json": fnc["parameters"] if fnc["parameters"] else {}},
28
+ }
29
+ )
30
+ }
31
+ elif is_raw_function_tool(function):
32
+ info = get_raw_function_info(function)
33
+ return {
34
+ "toolSpec": _strip_nones(
35
+ {
36
+ "name": info.name,
37
+ "description": info.raw_schema.get("description", ""),
38
+ "inputSchema": {"json": info.raw_schema.get("parameters", {})},
39
+ }
40
+ )
41
+ }
42
+ else:
43
+ raise ValueError("Invalid function tool")
44
+
45
+
46
+ def _strip_nones(d: dict) -> dict:
47
+ return {k: v for k, v in d.items() if v is not None}
@@ -12,4 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- __version__ = "1.0.22"
15
+ __version__ = "1.1.0"
@@ -23,7 +23,7 @@ classifiers = [
23
23
  "Programming Language :: Python :: 3 :: Only",
24
24
  ]
25
25
  dependencies = [
26
- "livekit-agents>=1.0.22",
26
+ "livekit-agents>=1.1.0",
27
27
  "aioboto3>=14.1.0",
28
28
  "amazon-transcribe>=0.6.2",
29
29
  ]
@@ -1,195 +0,0 @@
1
- # Licensed under the Apache License, Version 2.0 (the "License");
2
- # you may not use this file except in compliance with the License.
3
- # You may obtain a copy of the License at
4
- #
5
- # http://www.apache.org/licenses/LICENSE-2.0
6
- #
7
- # Unless required by applicable law or agreed to in writing, software
8
- # distributed under the License is distributed on an "AS IS" BASIS,
9
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
- # See the License for the specific language governing permissions and
11
- # limitations under the License.
12
-
13
- from __future__ import annotations
14
-
15
- import asyncio
16
- from dataclasses import dataclass
17
-
18
- import aioboto3
19
- import aiohttp
20
-
21
- from livekit.agents import (
22
- APIConnectionError,
23
- APIConnectOptions,
24
- APIStatusError,
25
- APITimeoutError,
26
- tts,
27
- utils,
28
- )
29
- from livekit.agents.types import (
30
- DEFAULT_API_CONNECT_OPTIONS,
31
- NOT_GIVEN,
32
- NotGivenOr,
33
- )
34
- from livekit.agents.utils import is_given
35
-
36
- from .models import TTS_LANGUAGE, TTS_SPEECH_ENGINE
37
- from .utils import _strip_nones
38
-
39
- TTS_NUM_CHANNELS: int = 1
40
- DEFAULT_SPEECH_ENGINE: TTS_SPEECH_ENGINE = "generative"
41
- DEFAULT_VOICE = "Ruth"
42
- DEFAULT_SAMPLE_RATE = 16000
43
-
44
-
45
- @dataclass
46
- class _TTSOptions:
47
- # https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
48
- voice: NotGivenOr[str]
49
- speech_engine: NotGivenOr[TTS_SPEECH_ENGINE]
50
- region: str
51
- sample_rate: int
52
- language: NotGivenOr[TTS_LANGUAGE | str]
53
-
54
-
55
- class TTS(tts.TTS):
56
- def __init__(
57
- self,
58
- *,
59
- voice: NotGivenOr[str] = NOT_GIVEN,
60
- language: NotGivenOr[TTS_LANGUAGE | str] = NOT_GIVEN,
61
- speech_engine: NotGivenOr[TTS_SPEECH_ENGINE] = NOT_GIVEN,
62
- sample_rate: int = DEFAULT_SAMPLE_RATE,
63
- region: NotGivenOr[str] = NOT_GIVEN,
64
- api_key: NotGivenOr[str] = NOT_GIVEN,
65
- api_secret: NotGivenOr[str] = NOT_GIVEN,
66
- session: aioboto3.Session | None = None,
67
- ) -> None:
68
- """
69
- Create a new instance of AWS Polly TTS.
70
-
71
- ``api_key`` and ``api_secret`` must be set to your AWS Access key id and secret access key, either using the argument or by setting the
72
- ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` environmental variables.
73
-
74
- See https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html for more details on the the AWS Polly TTS.
75
-
76
- Args:
77
- Voice (TTSModels, optional): Voice ID to use for the synthesis. Defaults to "Ruth".
78
- language (TTS_LANGUAGE, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
79
- sample_rate(int, optional): The audio frequency specified in Hz. Defaults to 16000.
80
- speech_engine(TTS_SPEECH_ENGINE, optional): The engine to use for the synthesis. Defaults to "generative".
81
- region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
82
- api_key(str, optional): AWS access key id.
83
- api_secret(str, optional): AWS secret access key.
84
- session(aioboto3.Session, optional): Optional aioboto3 session to use.
85
- """ # noqa: E501
86
- super().__init__(
87
- capabilities=tts.TTSCapabilities(
88
- streaming=False,
89
- ),
90
- sample_rate=sample_rate,
91
- num_channels=TTS_NUM_CHANNELS,
92
- )
93
- self._session = session or aioboto3.Session(
94
- aws_access_key_id=api_key if is_given(api_key) else None,
95
- aws_secret_access_key=api_secret if is_given(api_secret) else None,
96
- region_name=region if is_given(region) else None,
97
- )
98
- self._opts = _TTSOptions(
99
- voice=voice,
100
- speech_engine=speech_engine,
101
- region=region,
102
- language=language,
103
- sample_rate=sample_rate,
104
- )
105
-
106
- def synthesize(
107
- self,
108
- text: str,
109
- *,
110
- conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
111
- ) -> ChunkedStream:
112
- return ChunkedStream(
113
- tts=self,
114
- text=text,
115
- conn_options=conn_options,
116
- session=self._session,
117
- opts=self._opts,
118
- )
119
-
120
-
121
- class ChunkedStream(tts.ChunkedStream):
122
- def __init__(
123
- self,
124
- *,
125
- tts: TTS,
126
- text: str,
127
- session: aioboto3.Session,
128
- conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
129
- opts: _TTSOptions,
130
- ) -> None:
131
- super().__init__(tts=tts, input_text=text, conn_options=conn_options)
132
- self._opts = opts
133
- self._segment_id = utils.shortuuid()
134
- self._session = session
135
-
136
- async def _run(self):
137
- request_id = utils.shortuuid()
138
-
139
- try:
140
- async with self._session.client("polly") as client:
141
- params = {
142
- "Text": self._input_text,
143
- "OutputFormat": "mp3",
144
- "Engine": self._opts.speech_engine
145
- if is_given(self._opts.speech_engine)
146
- else DEFAULT_SPEECH_ENGINE,
147
- "VoiceId": self._opts.voice if is_given(self._opts.voice) else DEFAULT_VOICE,
148
- "TextType": "text",
149
- "SampleRate": str(self._opts.sample_rate),
150
- "LanguageCode": self._opts.language if is_given(self._opts.language) else None,
151
- }
152
- response = await client.synthesize_speech(**_strip_nones(params))
153
- if "AudioStream" in response:
154
- decoder = utils.codecs.AudioStreamDecoder(
155
- sample_rate=self._opts.sample_rate,
156
- num_channels=1,
157
- )
158
-
159
- # Create a task to push data to the decoder
160
- async def push_data():
161
- try:
162
- async with response["AudioStream"] as resp:
163
- async for data, _ in resp.content.iter_chunks():
164
- decoder.push(data)
165
- finally:
166
- decoder.end_input()
167
-
168
- # Start pushing data to the decoder
169
- push_task = asyncio.create_task(push_data())
170
-
171
- try:
172
- # Create emitter and process decoded frames
173
- emitter = tts.SynthesizedAudioEmitter(
174
- event_ch=self._event_ch,
175
- request_id=request_id,
176
- segment_id=self._segment_id,
177
- )
178
- async for frame in decoder:
179
- emitter.push(frame)
180
- emitter.flush()
181
- await push_task
182
- finally:
183
- await utils.aio.gracefully_cancel(push_task)
184
-
185
- except asyncio.TimeoutError:
186
- raise APITimeoutError() from None
187
- except aiohttp.ClientResponseError as e:
188
- raise APIStatusError(
189
- message=e.message,
190
- status_code=e.status,
191
- request_id=request_id,
192
- body=None,
193
- ) from None
194
- except Exception as e:
195
- raise APIConnectionError() from e
@@ -1,113 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- from typing import Any
5
-
6
- from livekit.agents import llm
7
- from livekit.agents.llm import ChatContext, FunctionTool, ImageContent, utils
8
-
9
- __all__ = ["to_fnc_ctx", "to_chat_ctx"]
10
- DEFAULT_REGION = "us-east-1"
11
-
12
-
13
- def to_fnc_ctx(fncs: list[FunctionTool]) -> list[dict]:
14
- return [_build_tool_spec(fnc) for fnc in fncs]
15
-
16
-
17
- def to_chat_ctx(chat_ctx: ChatContext, cache_key: Any) -> tuple[list[dict], dict | None]:
18
- messages: list[dict] = []
19
- system_message: dict | None = None
20
- current_role: str | None = None
21
- current_content: list[dict] = []
22
-
23
- for msg in chat_ctx.items:
24
- if msg.type == "message" and msg.role == "system":
25
- for content in msg.content:
26
- if content and isinstance(content, str):
27
- system_message = {"text": content}
28
- continue
29
-
30
- if msg.type == "message":
31
- role = "assistant" if msg.role == "assistant" else "user"
32
- elif msg.type == "function_call":
33
- role = "assistant"
34
- elif msg.type == "function_call_output":
35
- role = "user"
36
-
37
- # if the effective role changed, finalize the previous turn.
38
- if role != current_role:
39
- if current_content and current_role is not None:
40
- messages.append({"role": current_role, "content": current_content})
41
- current_content = []
42
- current_role = role
43
-
44
- if msg.type == "message":
45
- for content in msg.content:
46
- if content and isinstance(content, str):
47
- current_content.append({"text": content})
48
- elif isinstance(content, ImageContent):
49
- current_content.append(_build_image(content, cache_key))
50
- elif msg.type == "function_call":
51
- current_content.append(
52
- {
53
- "toolUse": {
54
- "toolUseId": msg.call_id,
55
- "name": msg.name,
56
- "input": json.loads(msg.arguments or "{}"),
57
- }
58
- }
59
- )
60
- elif msg.type == "function_call_output":
61
- tool_response = {
62
- "toolResult": {
63
- "toolUseId": msg.call_id,
64
- "content": [],
65
- "status": "success",
66
- }
67
- }
68
- if isinstance(msg.output, dict):
69
- tool_response["toolResult"]["content"].append({"json": msg.output})
70
- elif isinstance(msg.output, str):
71
- tool_response["toolResult"]["content"].append({"text": msg.output})
72
- current_content.append(tool_response)
73
-
74
- # Finalize the last message if there’s any content left
75
- if current_role is not None and current_content:
76
- messages.append({"role": current_role, "content": current_content})
77
-
78
- # Ensure the message list starts with a "user" message
79
- if not messages or messages[0]["role"] != "user":
80
- messages.insert(0, {"role": "user", "content": [{"text": "(empty)"}]})
81
-
82
- return messages, system_message
83
-
84
-
85
- def _build_tool_spec(fnc: FunctionTool) -> dict:
86
- fnc = llm.utils.build_legacy_openai_schema(fnc, internally_tagged=True)
87
- return {
88
- "toolSpec": _strip_nones(
89
- {
90
- "name": fnc["name"],
91
- "description": fnc["description"] if fnc["description"] else None,
92
- "inputSchema": {"json": fnc["parameters"] if fnc["parameters"] else {}},
93
- }
94
- )
95
- }
96
-
97
-
98
- def _build_image(image: ImageContent, cache_key: Any) -> dict:
99
- img = utils.serialize_image(image)
100
- if img.external_url:
101
- raise ValueError("external_url is not supported by AWS Bedrock.")
102
- if cache_key not in image._cache:
103
- image._cache[cache_key] = img.data_bytes
104
- return {
105
- "image": {
106
- "format": "jpeg",
107
- "source": {"bytes": image._cache[cache_key]},
108
- }
109
- }
110
-
111
-
112
- def _strip_nones(d: dict) -> dict:
113
- return {k: v for k, v in d.items() if v is not None}