livekit-plugins-aws 1.0.0rc6__tar.gz → 1.0.0rc7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of livekit-plugins-aws might be problematic.
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/PKG-INFO +5 -5
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/livekit/plugins/aws/llm.py +35 -45
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/livekit/plugins/aws/stt.py +79 -48
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/livekit/plugins/aws/tts.py +16 -26
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/livekit/plugins/aws/utils.py +29 -28
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/livekit/plugins/aws/version.py +1 -1
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/pyproject.toml +4 -4
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/.gitignore +0 -0
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/README.md +0 -0
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/livekit/plugins/aws/__init__.py +0 -0
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/livekit/plugins/aws/log.py +0 -0
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/livekit/plugins/aws/models.py +0 -0
- {livekit_plugins_aws-1.0.0rc6 → livekit_plugins_aws-1.0.0rc7}/livekit/plugins/aws/py.typed +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: livekit-plugins-aws
-Version: 1.0.0rc6
+Version: 1.0.0rc7
 Summary: LiveKit Agents Plugin for services from AWS
 Project-URL: Documentation, https://docs.livekit.io
 Project-URL: Website, https://livekit.io/
@@ -18,10 +18,10 @@ Classifier: Topic :: Multimedia :: Sound/Audio
 Classifier: Topic :: Multimedia :: Video
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9.0
-Requires-Dist:
-Requires-Dist: amazon-transcribe
-Requires-Dist: boto3==1.
-Requires-Dist: livekit-agents>=1.0.0.
+Requires-Dist: aioboto3==14.1.0
+Requires-Dist: amazon-transcribe==0.6.2
+Requires-Dist: boto3==1.37.1
+Requires-Dist: livekit-agents>=1.0.0.rc7
 Description-Content-Type: text/markdown
 
 # LiveKit Plugins AWS
livekit/plugins/aws/llm.py

@@ -14,12 +14,11 @@
 # limitations under the License.
 from __future__ import annotations
 
-import asyncio
 import os
 from dataclasses import dataclass
 from typing import Any, Literal
 
-import
+import aioboto3
 
 from livekit.agents import APIConnectionError, APIStatusError, llm
 from livekit.agents.llm import ChatContext, FunctionTool, FunctionToolCall, ToolChoice
@@ -32,7 +31,7 @@ from livekit.agents.types import (
 from livekit.agents.utils import is_given
 
 from .log import logger
-from .utils import
+from .utils import get_aws_async_session, to_chat_ctx, to_fnc_ctx
 
 TEXT_MODEL = Literal["anthropic.claude-3-5-sonnet-20241022-v2:0"]
 
@@ -60,6 +59,7 @@ class LLM(llm.LLM):
         top_p: NotGivenOr[float] = NOT_GIVEN,
         tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
         additional_request_fields: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
+        session: aioboto3.Session | None = None,
     ) -> None:
         """
         Create a new instance of AWS Bedrock LLM.
@@ -79,10 +79,14 @@ class LLM(llm.LLM):
             top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
             tool_choice (ToolChoice, optional): Specifies whether to use tools during response generation. Defaults to "auto".
             additional_request_fields (dict[str, Any], optional): Additional request fields to send to the AWS Bedrock Converse API. Defaults to None.
+            session (aioboto3.Session, optional): Optional aioboto3 session to use.
         """  # noqa: E501
         super().__init__()
-        …
+
+        self._session = session or get_aws_async_session(
+            api_key=api_key if is_given(api_key) else None,
+            api_secret=api_secret if is_given(api_secret) else None,
+            region=region if is_given(region) else None,
         )
 
         model = model if is_given(model) else os.environ.get("BEDROCK_INFERENCE_PROFILE_ARN")
@@ -156,11 +160,9 @@ class LLM(llm.LLM):
 
         return LLMStream(
             self,
-            aws_access_key_id=self._api_key,
-            aws_secret_access_key=self._api_secret,
-            region_name=self._region,
             chat_ctx=chat_ctx,
             tools=tools,
+            session=self._session,
             conn_options=conn_options,
             extra_kwargs=opts,
         )
@@ -171,24 +173,16 @@ class LLMStream(llm.LLMStream):
         self,
         llm: LLM,
         *,
-        aws_access_key_id: str,
-        aws_secret_access_key: str,
-        region_name: str,
         chat_ctx: ChatContext,
+        session: aioboto3.Session,
         conn_options: APIConnectOptions,
         tools: list[FunctionTool] | None,
         extra_kwargs: dict[str, Any],
     ) -> None:
         super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
-        self._client = boto3.client(
-            "bedrock-runtime",
-            region_name=region_name,
-            aws_access_key_id=aws_access_key_id,
-            aws_secret_access_key=aws_secret_access_key,
-        )
         self._llm: LLM = llm
         self._opts = extra_kwargs
-        …
+        self._session = session
         self._tool_call_id: str | None = None
         self._fnc_name: str | None = None
         self._fnc_raw_arguments: str | None = None
@@ -197,23 +191,21 @@ class LLMStream(llm.LLMStream):
     async def _run(self) -> None:
         retryable = True
         try:
-            …
-            # Let other coroutines run
-            await asyncio.sleep(0)
+            async with self._session.client("bedrock-runtime") as client:
+                response = await client.converse_stream(**self._opts)  # type: ignore
+                request_id = response["ResponseMetadata"]["RequestId"]
+                if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+                    raise APIStatusError(
+                        f"aws bedrock llm: error generating content: {response}",
+                        retryable=False,
+                        request_id=request_id,
+                    )
+
+                async for chunk in response["stream"]:
+                    chat_chunk = self._parse_chunk(request_id, chunk)
+                    if chat_chunk is not None:
+                        retryable = False
+                        self._event_ch.send_nowait(chat_chunk)
 
         except Exception as e:
             raise APIConnectionError(
@@ -233,12 +225,17 @@ class LLMStream(llm.LLMStream):
             if "toolUse" in delta:
                 self._fnc_raw_arguments += delta["toolUse"]["input"]
             elif "text" in delta:
-                …
+                return llm.ChatChunk(
+                    id=request_id,
+                    delta=llm.ChoiceDelta(content=delta["text"], role="assistant"),
+                )
+            else:
+                logger.warning(f"aws bedrock llm: unknown chunk type: {chunk}")
 
         elif "metadata" in chunk:
             metadata = chunk["metadata"]
             return llm.ChatChunk(
-                …
+                id=request_id,
                 usage=llm.CompletionUsage(
                     completion_tokens=metadata["usage"]["outputTokens"],
                     prompt_tokens=metadata["usage"]["inputTokens"],
@@ -246,14 +243,7 @@ class LLMStream(llm.LLMStream):
                 ),
             )
         elif "contentBlockStop" in chunk:
-            if self.
-                chat_chunk = llm.ChatChunk(
-                    id=request_id,
-                    delta=llm.ChoiceDelta(content=self._text, role="assistant"),
-                )
-                self._text = ""
-                return chat_chunk
-            elif self._tool_call_id:
+            if self._tool_call_id:
                 if self._tool_call_id is None:
                     logger.warning("aws bedrock llm: no tool call id in the response")
                     return None
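Taken together, the llm.py changes replace the per-stream boto3 client with one aioboto3 session that is created once (or injected by the caller) and reused by every LLMStream. Below is a minimal usage sketch; it assumes the plugin is imported as `livekit.plugins.aws` and that `LLM` exposes the constructor arguments shown in the diff, so treat the exact names as illustrative rather than authoritative.

```python
import aioboto3

from livekit.plugins import aws

# Sketch: pass one shared aioboto3 session instead of raw keys. When `session`
# is omitted, the 1.0.0rc7 constructor builds one via get_aws_async_session().
session = aioboto3.Session(region_name="us-east-1")

bedrock_llm = aws.LLM(
    model="anthropic.claude-3-5-sonnet-20241022-v2:0",  # matches TEXT_MODEL in the diff
    session=session,
)
```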
livekit/plugins/aws/stt.py

@@ -15,6 +15,8 @@ from __future__ import annotations
 import asyncio
 from dataclasses import dataclass
 
+import aioboto3
+from amazon_transcribe.auth import StaticCredentialResolver
 from amazon_transcribe.client import TranscribeStreamingClient
 from amazon_transcribe.model import Result, TranscriptEvent
 
@@ -24,12 +26,13 @@ from livekit.agents.types import NOT_GIVEN, NotGivenOr
 from livekit.agents.utils import is_given
 
 from .log import logger
-from .utils import
+from .utils import DEFAULT_REGION, get_aws_async_session
+
+REFRESH_INTERVAL = 1800
 
 
 @dataclass
 class STTOptions:
-    speech_region: str
     sample_rate: int
     language: str
     encoding: str
@@ -49,7 +52,7 @@ class STT(stt.STT):
     def __init__(
         self,
         *,
-        …
+        region: NotGivenOr[str] = NOT_GIVEN,
         api_key: NotGivenOr[str] = NOT_GIVEN,
         api_secret: NotGivenOr[str] = NOT_GIVEN,
         sample_rate: int = 48000,
@@ -65,14 +68,18 @@ class STT(stt.STT):
         enable_partial_results_stabilization: NotGivenOr[bool] = NOT_GIVEN,
         partial_results_stability: NotGivenOr[str] = NOT_GIVEN,
         language_model_name: NotGivenOr[str] = NOT_GIVEN,
+        session: aioboto3.Session | None = None,
+        refresh_interval: NotGivenOr[int] = NOT_GIVEN,
     ):
         super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))
-        …
-        self.
-        api_key
+        self._region = region if is_given(region) else DEFAULT_REGION
+        self._session = session or get_aws_async_session(
+            api_key=api_key if is_given(api_key) else None,
+            api_secret=api_secret if is_given(api_secret) else None,
+            region=self._region,
         )
+
         self._config = STTOptions(
-            speech_region=self._speech_region,
             language=language,
             sample_rate=sample_rate,
             encoding=encoding,
@@ -87,6 +94,28 @@ class STT(stt.STT):
             partial_results_stability=partial_results_stability,
             language_model_name=language_model_name,
         )
+        self._pool = utils.ConnectionPool[TranscribeStreamingClient](
+            connect_cb=self._create_client,
+            max_session_duration=refresh_interval
+            if is_given(refresh_interval)
+            else REFRESH_INTERVAL,
+        )
+
+    async def _create_client(self) -> TranscribeStreamingClient:
+        creds = await self._session.get_credentials()
+        frozen_credentials = await creds.get_frozen_credentials()
+        return TranscribeStreamingClient(
+            region=self._region,
+            credential_resolver=StaticCredentialResolver(
+                access_key_id=frozen_credentials.access_key,
+                secret_access_key=frozen_credentials.secret_key,
+                session_token=frozen_credentials.token,
+            ),
+        )
+
+    async def aclose(self) -> None:
+        await self._pool.aclose()
+        await super().aclose()
 
     async def _recognize_impl(
         self,
@@ -105,6 +134,7 @@ class STT(stt.STT):
     ) -> SpeechStream:
         return SpeechStream(
             stt=self,
+            pool=self._pool,
             conn_options=conn_options,
             opts=self._config,
         )
@@ -115,52 +145,54 @@ class SpeechStream(stt.SpeechStream):
         self,
         stt: STT,
         opts: STTOptions,
+        pool: utils.ConnectionPool[TranscribeStreamingClient],
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
     ) -> None:
         super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
         self._opts = opts
-        self.
+        self._pool = pool
 
     async def _run(self) -> None:
-        …
-        async
-        …
-        async
-        …
+        async with self._pool.connection() as client:
+            live_config = {
+                "language_code": self._opts.language,
+                "media_sample_rate_hz": self._opts.sample_rate,
+                "media_encoding": self._opts.encoding,
+                "vocabulary_name": self._opts.vocabulary_name,
+                "session_id": self._opts.session_id,
+                "vocab_filter_method": self._opts.vocab_filter_method,
+                "vocab_filter_name": self._opts.vocab_filter_name,
+                "show_speaker_label": self._opts.show_speaker_label,
+                "enable_channel_identification": self._opts.enable_channel_identification,
+                "number_of_channels": self._opts.number_of_channels,
+                "enable_partial_results_stabilization": self._opts.enable_partial_results_stabilization,  # noqa: E501
+                "partial_results_stability": self._opts.partial_results_stability,
+                "language_model_name": self._opts.language_model_name,
+            }
+            filtered_config = {k: v for k, v in live_config.items() if v and is_given(v)}
+            stream = await client.start_stream_transcription(**filtered_config)
+
+            @utils.log_exceptions(logger=logger)
+            async def input_generator():
+                async for frame in self._input_ch:
+                    if isinstance(frame, rtc.AudioFrame):
+                        await stream.input_stream.send_audio_event(audio_chunk=frame.data.tobytes())
+                await stream.input_stream.end_stream()
+
+            @utils.log_exceptions(logger=logger)
+            async def handle_transcript_events():
+                async for event in stream.output_stream:
+                    if isinstance(event, TranscriptEvent):
+                        self._process_transcript_event(event)
+
+            tasks = [
+                asyncio.create_task(input_generator()),
+                asyncio.create_task(handle_transcript_events()),
+            ]
+            try:
+                await asyncio.gather(*tasks)
+            finally:
+                await utils.aio.gracefully_cancel(*tasks)
 
     def _process_transcript_event(self, transcript_event: TranscriptEvent):
         stream = transcript_event.transcript.results
@@ -196,7 +228,6 @@ def _streaming_recognize_response_to_speech_data(resp: Result) -> stt.SpeechData
         language="en-US",
         start_time=resp.start_time if resp.start_time else 0.0,
         end_time=resp.end_time if resp.end_time else 0.0,
-        confidence=0.0,
         text=resp.alternatives[0].transcript if resp.alternatives else "",
     )
 
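The stt.py rewrite moves client creation behind a `utils.ConnectionPool`: `_create_client` freezes the session credentials into a `StaticCredentialResolver`, and the pool rebuilds the `TranscribeStreamingClient` after `refresh_interval` seconds (default `REFRESH_INTERVAL = 1800`), so temporary credentials get refreshed instead of expiring mid-session. A hedged sketch of the new options, assuming `STT` is exported from `livekit.plugins.aws`:

```python
from livekit.plugins import aws

# Sketch: region and refresh_interval are the new 1.0.0rc7 parameters; credentials
# fall back to the default AWS chain when api_key/api_secret are not given.
stt = aws.STT(
    region="us-east-1",
    sample_rate=48000,
    refresh_interval=900,  # rebuild the Transcribe client every 15 minutes
)

# ...stream audio through the agent as usual, then drain the pool:
# await stt.aclose()  # new in 1.0.0rc7
```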
livekit/plugins/aws/tts.py

@@ -14,10 +14,9 @@ from __future__ import annotations
 
 import asyncio
 from dataclasses import dataclass
-from typing import Any, Callable
 
+import aioboto3
 import aiohttp
-from aiobotocore.session import AioSession, get_session
 
 from livekit.agents import (
     APIConnectionError,
@@ -35,11 +34,10 @@ from livekit.agents.types import (
 from livekit.agents.utils import is_given
 
 from .models import TTS_LANGUAGE, TTS_SPEECH_ENGINE
-from .utils import _strip_nones,
+from .utils import _strip_nones, get_aws_async_session
 
 TTS_NUM_CHANNELS: int = 1
 DEFAULT_SPEECH_ENGINE: TTS_SPEECH_ENGINE = "generative"
-DEFAULT_SPEECH_REGION = "us-east-1"
 DEFAULT_VOICE = "Ruth"
 DEFAULT_SAMPLE_RATE = 16000
 
@@ -49,7 +47,7 @@ class _TTSOptions:
     # https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
     voice: NotGivenOr[str]
     speech_engine: NotGivenOr[TTS_SPEECH_ENGINE]
-    …
+    region: str
     sample_rate: int
     language: NotGivenOr[TTS_LANGUAGE | str]
 
@@ -62,10 +60,10 @@ class TTS(tts.TTS):
         language: NotGivenOr[TTS_LANGUAGE | str] = NOT_GIVEN,
         speech_engine: NotGivenOr[TTS_SPEECH_ENGINE] = NOT_GIVEN,
         sample_rate: int = DEFAULT_SAMPLE_RATE,
-        …
+        region: NotGivenOr[str] = NOT_GIVEN,
         api_key: NotGivenOr[str] = NOT_GIVEN,
         api_secret: NotGivenOr[str] = NOT_GIVEN,
-        session:
+        session: aioboto3.Session | None = None,
     ) -> None:
         """
         Create a new instance of AWS Polly TTS.
@@ -80,9 +78,10 @@ class TTS(tts.TTS):
             language (TTS_LANGUAGE, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
             sample_rate(int, optional): The audio frequency specified in Hz. Defaults to 16000.
             speech_engine(TTS_SPEECH_ENGINE, optional): The engine to use for the synthesis. Defaults to "generative".
-            …
+            region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
             api_key(str, optional): AWS access key id.
             api_secret(str, optional): AWS secret access key.
+            session(aioboto3.Session, optional): Optional aioboto3 session to use.
         """  # noqa: E501
         super().__init__(
             capabilities=tts.TTSCapabilities(
@@ -91,27 +90,18 @@ class TTS(tts.TTS):
             sample_rate=sample_rate,
             num_channels=TTS_NUM_CHANNELS,
         )
-        …
+        self._session = session or get_aws_async_session(
+            api_key=api_key if is_given(api_key) else None,
+            api_secret=api_secret if is_given(api_secret) else None,
+            region=region if is_given(region) else None,
         )
-        …
         self._opts = _TTSOptions(
             voice=voice,
             speech_engine=speech_engine,
-            …
+            region=region,
             language=language,
             sample_rate=sample_rate,
         )
-        self._session = session or get_session()
-
-    def _get_client(self):
-        return self._session.create_client(
-            "polly",
-            region_name=self._opts.speech_region,
-            aws_access_key_id=self._api_key,
-            aws_secret_access_key=self._api_secret,
-        )
 
     def synthesize(
         self,
@@ -123,8 +113,8 @@ class TTS(tts.TTS):
             tts=self,
             text=text,
             conn_options=conn_options,
+            session=self._session,
             opts=self._opts,
-            get_client=self._get_client,
         )
 
 
@@ -134,20 +124,20 @@ class ChunkedStream(tts.ChunkedStream):
         *,
         tts: TTS,
         text: str,
+        session: aioboto3.Session,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
         opts: _TTSOptions,
-        get_client: Callable[[], Any],
     ) -> None:
         super().__init__(tts=tts, input_text=text, conn_options=conn_options)
         self._opts = opts
-        self._get_client = get_client
         self._segment_id = utils.shortuuid()
+        self._session = session
 
     async def _run(self):
         request_id = utils.shortuuid()
 
         try:
-            async with self.
+            async with self._session.client("polly") as client:
                 params = {
                     "Text": self._input_text,
                     "OutputFormat": "mp3",
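For tts.py the cached `_get_client` closure is gone: each `ChunkedStream._run` now opens a Polly client directly from the shared aioboto3 session (`async with self._session.client("polly")`). A sketch of constructing the updated TTS, assuming `TTS` is exported from `livekit.plugins.aws`; the voice and engine values are just the defaults visible in the diff:

```python
from livekit.plugins import aws

# Sketch: `region` replaces the old speech_region/DEFAULT_SPEECH_REGION setting,
# and an aioboto3 session can be injected just like for the LLM and STT.
polly_tts = aws.TTS(
    voice="Ruth",                # DEFAULT_VOICE in the diff
    speech_engine="generative",  # DEFAULT_SPEECH_ENGINE in the diff
    region="us-east-1",
    sample_rate=16000,
)
```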
livekit/plugins/aws/utils.py

@@ -1,43 +1,44 @@
 from __future__ import annotations
 
 import json
-import
-from typing import Any, cast
+from typing import Any
 
+import aioboto3
 import boto3
+from botocore.exceptions import NoCredentialsError
 
 from livekit.agents import llm
 from livekit.agents.llm import ChatContext, FunctionTool, ImageContent, utils
-from livekit.agents.types import NotGivenOr
-from livekit.agents.utils import is_given
-
-__all__ = ["to_fnc_ctx", "to_chat_ctx", "get_aws_credentials"]
 
+__all__ = ["to_fnc_ctx", "to_chat_ctx", "get_aws_async_session"]
 DEFAULT_REGION = "us-east-1"
 
 
-def
-…
-):
-…
+def get_aws_async_session(
+    region: str | None = None,
+    api_key: str | None = None,
+    api_secret: str | None = None,
+) -> aioboto3.Session:
+    _validate_aws_credentials(api_key, api_secret)
+    session = aioboto3.Session(
+        aws_access_key_id=api_key,
+        aws_secret_access_key=api_secret,
+        region_name=region or DEFAULT_REGION,
+    )
+    return session
+
+
+def _validate_aws_credentials(
+    api_key: str | None = None,
+    api_secret: str | None = None,
+) -> None:
+    try:
+        session = boto3.Session(aws_access_key_id=api_key, aws_secret_access_key=api_secret)
+        creds = session.get_credentials()
+        if not creds:
+            raise ValueError("No credentials found")
+    except (NoCredentialsError, Exception) as e:
+        raise ValueError(f"Unable to locate valid AWS credentials: {str(e)}") from e
 
 
 def to_fnc_ctx(fncs: list[FunctionTool]) -> list[dict]:
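utils.py replaces `get_aws_credentials` with `get_aws_async_session`, which first checks that boto3 can resolve some credentials (raising `ValueError` otherwise) and then returns an `aioboto3.Session` whose region defaults to `us-east-1`. A small sketch of calling the helper directly; the awaited credential calls mirror what `STT._create_client` does in the diff above:

```python
import asyncio

from livekit.plugins.aws.utils import get_aws_async_session

# Sketch: with no explicit keys, validation falls back to the standard AWS
# credential chain (env vars, shared config, instance profile).
session = get_aws_async_session(region="eu-central-1")

async def main() -> None:
    creds = await session.get_credentials()
    frozen = await creds.get_frozen_credentials()
    print("resolved access key:", frozen.access_key[:4] + "...")

asyncio.run(main())
```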
pyproject.toml

@@ -23,10 +23,10 @@ classifiers = [
     "Programming Language :: Python :: 3 :: Only",
 ]
 dependencies = [
-    "livekit-agents>=1.0.0.
-    "
-    "
-    "
+    "livekit-agents>=1.0.0.rc7",
+    "aioboto3==14.1.0",
+    "amazon-transcribe==0.6.2",
+    "boto3==1.37.1",
 ]
 
 [project.urls]
The remaining files (.gitignore, README.md, livekit/plugins/aws/__init__.py, livekit/plugins/aws/log.py, livekit/plugins/aws/models.py, livekit/plugins/aws/py.typed) are renamed only by the new version prefix; their contents are unchanged.