livekit-plugins-aws 0.1.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of livekit-plugins-aws might be problematic.
- livekit/plugins/aws/llm.py +160 -239
- livekit/plugins/aws/stt.py +114 -98
- livekit/plugins/aws/tts.py +40 -45
- livekit/plugins/aws/utils.py +144 -0
- livekit/plugins/aws/version.py +1 -1
- {livekit_plugins_aws-0.1.1.dist-info → livekit_plugins_aws-1.0.0.dist-info}/METADATA +14 -24
- livekit_plugins_aws-1.0.0.dist-info/RECORD +12 -0
- {livekit_plugins_aws-0.1.1.dist-info → livekit_plugins_aws-1.0.0.dist-info}/WHEEL +1 -2
- livekit/plugins/aws/_utils.py +0 -216
- livekit_plugins_aws-0.1.1.dist-info/RECORD +0 -13
- livekit_plugins_aws-0.1.1.dist-info/top_level.txt +0 -1
livekit/plugins/aws/stt.py
CHANGED
@@ -14,70 +14,72 @@ from __future__ import annotations
 
 import asyncio
 from dataclasses import dataclass
-from typing import Optional
 
+import aioboto3
+from amazon_transcribe.auth import StaticCredentialResolver
 from amazon_transcribe.client import TranscribeStreamingClient
 from amazon_transcribe.model import Result, TranscriptEvent
+
 from livekit import rtc
-from livekit.agents import (
-    DEFAULT_API_CONNECT_OPTIONS,
-    APIConnectOptions,
-    stt,
-    utils,
-)
-
-from ._utils import _get_aws_credentials
+from livekit.agents import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions, stt, utils
+from livekit.agents.types import NOT_GIVEN, NotGivenOr
+from livekit.agents.utils import is_given
+
 from .log import logger
+from .utils import DEFAULT_REGION, get_aws_async_session
+
+REFRESH_INTERVAL = 1800
 
 
 @dataclass
 class STTOptions:
-    speech_region: str
     sample_rate: int
     language: str
     encoding: str
-    vocabulary_name: Optional[str]
-    session_id: Optional[str]
-    vocab_filter_method: Optional[str]
-    vocab_filter_name: Optional[str]
-    show_speaker_label: Optional[bool]
-    enable_channel_identification: Optional[bool]
-    number_of_channels: Optional[int]
-    enable_partial_results_stabilization: Optional[bool]
-    partial_results_stability: Optional[str]
-    language_model_name: Optional[str]
+    vocabulary_name: NotGivenOr[str]
+    session_id: NotGivenOr[str]
+    vocab_filter_method: NotGivenOr[str]
+    vocab_filter_name: NotGivenOr[str]
+    show_speaker_label: NotGivenOr[bool]
+    enable_channel_identification: NotGivenOr[bool]
+    number_of_channels: NotGivenOr[int]
+    enable_partial_results_stabilization: NotGivenOr[bool]
+    partial_results_stability: NotGivenOr[str]
+    language_model_name: NotGivenOr[str]
 
 
 class STT(stt.STT):
     def __init__(
         self,
         *,
-        speech_region: str = "us-east-1",
-        api_key: Optional[str] = None,
-        api_secret: Optional[str] = None,
+        region: NotGivenOr[str] = NOT_GIVEN,
+        api_key: NotGivenOr[str] = NOT_GIVEN,
+        api_secret: NotGivenOr[str] = NOT_GIVEN,
         sample_rate: int = 48000,
         language: str = "en-US",
         encoding: str = "pcm",
-        vocabulary_name: Optional[str] = None,
-        session_id: Optional[str] = None,
-        vocab_filter_method: Optional[str] = None,
-        vocab_filter_name: Optional[str] = None,
-        show_speaker_label: Optional[bool] = None,
-        enable_channel_identification: Optional[bool] = None,
-        number_of_channels: Optional[int] = None,
-        enable_partial_results_stabilization: Optional[bool] = None,
-        partial_results_stability: Optional[str] = None,
-        language_model_name: Optional[str] = None,
+        vocabulary_name: NotGivenOr[str] = NOT_GIVEN,
+        session_id: NotGivenOr[str] = NOT_GIVEN,
+        vocab_filter_method: NotGivenOr[str] = NOT_GIVEN,
+        vocab_filter_name: NotGivenOr[str] = NOT_GIVEN,
+        show_speaker_label: NotGivenOr[bool] = NOT_GIVEN,
+        enable_channel_identification: NotGivenOr[bool] = NOT_GIVEN,
+        number_of_channels: NotGivenOr[int] = NOT_GIVEN,
+        enable_partial_results_stabilization: NotGivenOr[bool] = NOT_GIVEN,
+        partial_results_stability: NotGivenOr[str] = NOT_GIVEN,
+        language_model_name: NotGivenOr[str] = NOT_GIVEN,
+        session: aioboto3.Session | None = None,
+        refresh_interval: NotGivenOr[int] = NOT_GIVEN,
     ):
-        super().__init__(
-            capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
+        super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))
+        self._region = region if is_given(region) else DEFAULT_REGION
+        self._session = session or get_aws_async_session(
+            api_key=api_key if is_given(api_key) else None,
+            api_secret=api_secret if is_given(api_secret) else None,
+            region=self._region,
         )
 
-        self._api_key, self._api_secret = _get_aws_credentials(
-            api_key, api_secret, speech_region
-        )
         self._config = STTOptions(
-            speech_region=speech_region,
             language=language,
             sample_rate=sample_rate,
             encoding=encoding,
@@ -92,26 +94,47 @@ class STT(stt.STT):
             partial_results_stability=partial_results_stability,
             language_model_name=language_model_name,
         )
+        self._pool = utils.ConnectionPool[TranscribeStreamingClient](
+            connect_cb=self._create_client,
+            max_session_duration=refresh_interval
+            if is_given(refresh_interval)
+            else REFRESH_INTERVAL,
+        )
+
+    async def _create_client(self) -> TranscribeStreamingClient:
+        creds = await self._session.get_credentials()
+        frozen_credentials = await creds.get_frozen_credentials()
+        return TranscribeStreamingClient(
+            region=self._region,
+            credential_resolver=StaticCredentialResolver(
+                access_key_id=frozen_credentials.access_key,
+                secret_access_key=frozen_credentials.secret_key,
+                session_token=frozen_credentials.token,
+            ),
+        )
+
+    async def aclose(self) -> None:
+        await self._pool.aclose()
+        await super().aclose()
 
     async def _recognize_impl(
         self,
         buffer: utils.AudioBuffer,
         *,
-        language: Optional[str] = None,
+        language: NotGivenOr[str] = NOT_GIVEN,
         conn_options: APIConnectOptions,
     ) -> stt.SpeechEvent:
-        raise NotImplementedError(
-            "Amazon Transcribe does not support single frame recognition"
-        )
+        raise NotImplementedError("Amazon Transcribe does not support single frame recognition")
 
     def stream(
         self,
         *,
-        language: Optional[str] = None,
+        language: NotGivenOr[str] = NOT_GIVEN,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-    ) -> "SpeechStream":
+    ) -> SpeechStream:
         return SpeechStream(
             stt=self,
+            pool=self._pool,
             conn_options=conn_options,
             opts=self._config,
         )
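Aside: the new `_create_client` connect callback pins a point-in-time credential snapshot to each pooled Transcribe client. A minimal standalone sketch of that pattern, using only the calls visible in the diff (the helper name is ours, not the plugin's):

import aioboto3
from amazon_transcribe.auth import StaticCredentialResolver
from amazon_transcribe.client import TranscribeStreamingClient

async def make_transcribe_client(session: aioboto3.Session, region: str) -> TranscribeStreamingClient:
    # Resolve the session's credential provider, then freeze a snapshot
    # (access key, secret, optional STS token) for this client instance.
    creds = await session.get_credentials()
    frozen = await creds.get_frozen_credentials()
    return TranscribeStreamingClient(
        region=region,
        credential_resolver=StaticCredentialResolver(
            access_key_id=frozen.access_key,
            secret_access_key=frozen.secret_key,
            session_token=frozen.token,
        ),
    )

Because such a snapshot goes stale once temporary credentials rotate, the pool recreates clients every REFRESH_INTERVAL (1800 s) unless a custom refresh_interval is given.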
@@ -122,54 +145,54 @@ class SpeechStream(stt.SpeechStream):
         self,
         stt: STT,
         opts: STTOptions,
+        pool: utils.ConnectionPool[TranscribeStreamingClient],
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
     ) -> None:
-        super().__init__(
-            stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
-        )
+        super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
         self._opts = opts
-        self.…  # truncated in the source view
+        self._pool = pool
 
     async def _run(self) -> None:
-        …  # old client setup and streaming loop (truncated in the source view)
+        async with self._pool.connection() as client:
+            live_config = {
+                "language_code": self._opts.language,
+                "media_sample_rate_hz": self._opts.sample_rate,
+                "media_encoding": self._opts.encoding,
+                "vocabulary_name": self._opts.vocabulary_name,
+                "session_id": self._opts.session_id,
+                "vocab_filter_method": self._opts.vocab_filter_method,
+                "vocab_filter_name": self._opts.vocab_filter_name,
+                "show_speaker_label": self._opts.show_speaker_label,
+                "enable_channel_identification": self._opts.enable_channel_identification,
+                "number_of_channels": self._opts.number_of_channels,
+                "enable_partial_results_stabilization": self._opts.enable_partial_results_stabilization,  # noqa: E501
+                "partial_results_stability": self._opts.partial_results_stability,
+                "language_model_name": self._opts.language_model_name,
+            }
+            filtered_config = {k: v for k, v in live_config.items() if v and is_given(v)}
+            stream = await client.start_stream_transcription(**filtered_config)
+
+            @utils.log_exceptions(logger=logger)
+            async def input_generator():
+                async for frame in self._input_ch:
+                    if isinstance(frame, rtc.AudioFrame):
+                        await stream.input_stream.send_audio_event(audio_chunk=frame.data.tobytes())
+                await stream.input_stream.end_stream()
+
+            @utils.log_exceptions(logger=logger)
+            async def handle_transcript_events():
+                async for event in stream.output_stream:
+                    if isinstance(event, TranscriptEvent):
+                        self._process_transcript_event(event)
+
+            tasks = [
+                asyncio.create_task(input_generator()),
+                asyncio.create_task(handle_transcript_events()),
+            ]
+            try:
+                await asyncio.gather(*tasks)
+            finally:
+                await utils.aio.gracefully_cancel(*tasks)
 
     def _process_transcript_event(self, transcript_event: TranscriptEvent):
         stream = transcript_event.transcript.results
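Aside: the rewritten `_run` pairs a task that feeds audio into the Transcribe stream with a task that drains transcript events, and tears both down together. Reduced to plain asyncio (the diff's `utils.aio.gracefully_cancel` performs roughly the cancel-and-await step), the pattern is:

import asyncio

async def run_duplex(send, recv) -> None:
    # Run sender and receiver concurrently; however gather exits,
    # cancel both tasks and wait for them before returning.
    tasks = [asyncio.create_task(send()), asyncio.create_task(recv())]
    try:
        await asyncio.gather(*tasks)
    finally:
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)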
@@ -184,9 +207,7 @@ class SpeechStream(stt.SpeechStream):
         self._event_ch.send_nowait(
             stt.SpeechEvent(
                 type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
-                alternatives=[
-                    _streaming_recognize_response_to_speech_data(resp)
-                ],
+                alternatives=[_streaming_recognize_response_to_speech_data(resp)],
             )
         )
 
@@ -194,16 +215,12 @@ class SpeechStream(stt.SpeechStream):
         self._event_ch.send_nowait(
             stt.SpeechEvent(
                 type=stt.SpeechEventType.FINAL_TRANSCRIPT,
-                alternatives=[
-                    _streaming_recognize_response_to_speech_data(resp)
-                ],
+                alternatives=[_streaming_recognize_response_to_speech_data(resp)],
             )
         )
 
         if not resp.is_partial:
-            self._event_ch.send_nowait(
-                stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
-            )
+            self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
 
 
 def _streaming_recognize_response_to_speech_data(resp: Result) -> stt.SpeechData:
@@ -211,7 +228,6 @@ def _streaming_recognize_response_to_speech_data(resp: Result) -> stt.SpeechData
         language="en-US",
         start_time=resp.start_time if resp.start_time else 0.0,
         end_time=resp.end_time if resp.end_time else 0.0,
-        confidence=0.0,
         text=resp.alternatives[0].transcript if resp.alternatives else "",
     )
 
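Aside: a hypothetical construction of the 1.0 STT under its new signature (values illustrative; credentials are resolved from the environment when api_key/api_secret and session are omitted):

from livekit.plugins import aws

stt = aws.STT(
    region="us-east-1",
    language="en-US",
    sample_rate=48000,
    refresh_interval=900,  # recreate pooled Transcribe clients every 15 minutes
)
stream = stt.stream()  # push rtc.AudioFrame objects in, read stt.SpeechEvent objects out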
livekit/plugins/aws/tts.py
CHANGED
@@ -14,10 +14,10 @@ from __future__ import annotations
 
 import asyncio
 from dataclasses import dataclass
-from typing import Any, Callable, Optional
 
+import aioboto3
 import aiohttp
 
 from livekit.agents import (
     APIConnectionError,
     APIConnectOptions,
@@ -26,13 +26,18 @@ from livekit.agents import (
     tts,
     utils,
 )
+from livekit.agents.types import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    NOT_GIVEN,
+    NotGivenOr,
+)
+from livekit.agents.utils import is_given
 
-from ._utils import _get_aws_credentials
 from .models import TTS_LANGUAGE, TTS_SPEECH_ENGINE
+from .utils import _strip_nones, get_aws_async_session
 
 TTS_NUM_CHANNELS: int = 1
 DEFAULT_SPEECH_ENGINE: TTS_SPEECH_ENGINE = "generative"
-DEFAULT_SPEECH_REGION = "us-east-1"
 DEFAULT_VOICE = "Ruth"
 DEFAULT_SAMPLE_RATE = 16000
 
@@ -40,25 +45,25 @@ DEFAULT_SAMPLE_RATE = 16000
 @dataclass
 class _TTSOptions:
     # https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
-    voice: str
-    speech_engine: TTS_SPEECH_ENGINE
-    speech_region: str
+    voice: NotGivenOr[str]
+    speech_engine: NotGivenOr[TTS_SPEECH_ENGINE]
+    region: str
     sample_rate: int
-    language: TTS_LANGUAGE | str
+    language: NotGivenOr[TTS_LANGUAGE | str]
 
 
 class TTS(tts.TTS):
     def __init__(
         self,
         *,
-        voice: str = DEFAULT_VOICE,
-        language: TTS_LANGUAGE | str | None = None,
-        speech_engine: TTS_SPEECH_ENGINE = DEFAULT_SPEECH_ENGINE,
+        voice: NotGivenOr[str] = NOT_GIVEN,
+        language: NotGivenOr[TTS_LANGUAGE | str] = NOT_GIVEN,
+        speech_engine: NotGivenOr[TTS_SPEECH_ENGINE] = NOT_GIVEN,
         sample_rate: int = DEFAULT_SAMPLE_RATE,
-        speech_region: str = DEFAULT_SPEECH_REGION,
-        api_key: Optional[str] = None,
-        api_secret: Optional[str] = None,
-        session: Optional[Any] = None,
+        region: NotGivenOr[str] = NOT_GIVEN,
+        api_key: NotGivenOr[str] = NOT_GIVEN,
+        api_secret: NotGivenOr[str] = NOT_GIVEN,
+        session: aioboto3.Session | None = None,
     ) -> None:
         """
         Create a new instance of AWS Polly TTS.
@@ -73,10 +78,11 @@ class TTS(tts.TTS):
             language (TTS_LANGUAGE, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
             sample_rate(int, optional): The audio frequency specified in Hz. Defaults to 16000.
             speech_engine(TTS_SPEECH_ENGINE, optional): The engine to use for the synthesis. Defaults to "generative".
-            speech_region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
+            region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
             api_key(str, optional): AWS access key id.
             api_secret(str, optional): AWS secret access key.
-        """
+            session(aioboto3.Session, optional): Optional aioboto3 session to use.
+        """  # noqa: E501
         super().__init__(
             capabilities=tts.TTSCapabilities(
                 streaming=False,
@@ -84,40 +90,31 @@ class TTS(tts.TTS):
             sample_rate=sample_rate,
             num_channels=TTS_NUM_CHANNELS,
         )
-
-        self._api_key, self._api_secret = _get_aws_credentials(
-            api_key, api_secret, speech_region
+        self._session = session or get_aws_async_session(
+            api_key=api_key if is_given(api_key) else None,
+            api_secret=api_secret if is_given(api_secret) else None,
+            region=region if is_given(region) else None,
         )
-
         self._opts = _TTSOptions(
             voice=voice,
             speech_engine=speech_engine,
-            speech_region=speech_region,
+            region=region,
             language=language,
             sample_rate=sample_rate,
         )
-        self._session = session or get_session()
-
-    def _get_client(self):
-        return self._session.create_client(
-            "polly",
-            region_name=self._opts.speech_region,
-            aws_access_key_id=self._api_key,
-            aws_secret_access_key=self._api_secret,
-        )
 
     def synthesize(
         self,
         text: str,
         *,
-        conn_options: Optional[APIConnectOptions] = None,
-    ) -> "ChunkedStream":
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+    ) -> ChunkedStream:
         return ChunkedStream(
             tts=self,
             text=text,
             conn_options=conn_options,
+            session=self._session,
             opts=self._opts,
-            get_client=self._get_client,
         )
 
 
@@ -127,28 +124,30 @@ class ChunkedStream(tts.ChunkedStream):
         *,
         tts: TTS,
         text: str,
-        conn_options: Optional[APIConnectOptions] = None,
+        session: aioboto3.Session,
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
         opts: _TTSOptions,
-        get_client: Callable[[], Any],
     ) -> None:
         super().__init__(tts=tts, input_text=text, conn_options=conn_options)
         self._opts = opts
-        self._get_client = get_client
         self._segment_id = utils.shortuuid()
+        self._session = session
 
     async def _run(self):
        request_id = utils.shortuuid()
 
        try:
-            async with self._get_client() as client:
+            async with self._session.client("polly") as client:
                 params = {
                     "Text": self._input_text,
                     "OutputFormat": "mp3",
-                    "Engine": self._opts.speech_engine,
-                    "VoiceId": self._opts.voice,
+                    "Engine": self._opts.speech_engine
+                    if is_given(self._opts.speech_engine)
+                    else DEFAULT_SPEECH_ENGINE,
+                    "VoiceId": self._opts.voice if is_given(self._opts.voice) else DEFAULT_VOICE,
                     "TextType": "text",
                     "SampleRate": str(self._opts.sample_rate),
-                    "LanguageCode": self._opts.language,
+                    "LanguageCode": self._opts.language if is_given(self._opts.language) else None,
                 }
                 response = await client.synthesize_speech(**_strip_nones(params))
                 if "AudioStream" in response:
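Aside: the `_strip_nones(params)` call is what lets unset options fall back to Polly's server-side defaults, e.g.:

params = {"Text": "hi", "OutputFormat": "mp3", "LanguageCode": None}
assert _strip_nones(params) == {"Text": "hi", "OutputFormat": "mp3"}  # LanguageCode never sent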
@@ -194,7 +193,3 @@ class ChunkedStream(tts.ChunkedStream):
             ) from e
         except Exception as e:
             raise APIConnectionError() from e
-
-
-def _strip_nones(d: dict[str, Any]) -> dict[str, Any]:
-    return {k: v for k, v in d.items() if v is not None}
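Aside: a hypothetical construction of the 1.0 TTS under its new signature (values illustrative):

import aioboto3
from livekit.plugins import aws

tts = aws.TTS(voice="Ruth", speech_engine="generative", region="us-east-1")
# or reuse an existing aioboto3 session:
tts = aws.TTS(session=aioboto3.Session(region_name="us-east-1"))
audio_stream = tts.synthesize("Hello from Polly")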
livekit/plugins/aws/utils.py
ADDED
@@ -0,0 +1,144 @@
+from __future__ import annotations
+
+import json
+from typing import Any
+
+import aioboto3
+import boto3
+from botocore.exceptions import NoCredentialsError
+
+from livekit.agents import llm
+from livekit.agents.llm import ChatContext, FunctionTool, ImageContent, utils
+
+__all__ = ["to_fnc_ctx", "to_chat_ctx", "get_aws_async_session"]
+DEFAULT_REGION = "us-east-1"
+
+
+def get_aws_async_session(
+    region: str | None = None,
+    api_key: str | None = None,
+    api_secret: str | None = None,
+) -> aioboto3.Session:
+    _validate_aws_credentials(api_key, api_secret)
+    session = aioboto3.Session(
+        aws_access_key_id=api_key,
+        aws_secret_access_key=api_secret,
+        region_name=region or DEFAULT_REGION,
+    )
+    return session
+
+
+def _validate_aws_credentials(
+    api_key: str | None = None,
+    api_secret: str | None = None,
+) -> None:
+    try:
+        session = boto3.Session(aws_access_key_id=api_key, aws_secret_access_key=api_secret)
+        creds = session.get_credentials()
+        if not creds:
+            raise ValueError("No credentials found")
+    except (NoCredentialsError, Exception) as e:
+        raise ValueError(f"Unable to locate valid AWS credentials: {str(e)}") from e
+
+
+def to_fnc_ctx(fncs: list[FunctionTool]) -> list[dict]:
+    return [_build_tool_spec(fnc) for fnc in fncs]
+
+
+def to_chat_ctx(chat_ctx: ChatContext, cache_key: Any) -> tuple[list[dict], dict | None]:
+    messages: list[dict] = []
+    system_message: dict | None = None
+    current_role: str | None = None
+    current_content: list[dict] = []
+
+    for msg in chat_ctx.items:
+        if msg.type == "message" and msg.role == "system":
+            for content in msg.content:
+                if content and isinstance(content, str):
+                    system_message = {"text": content}
+            continue
+
+        if msg.type == "message":
+            role = "assistant" if msg.role == "assistant" else "user"
+        elif msg.type == "function_call":
+            role = "assistant"
+        elif msg.type == "function_call_output":
+            role = "user"
+
+        # if the effective role changed, finalize the previous turn.
+        if role != current_role:
+            if current_content and current_role is not None:
+                messages.append({"role": current_role, "content": current_content})
+            current_content = []
+            current_role = role
+
+        if msg.type == "message":
+            for content in msg.content:
+                if content and isinstance(content, str):
+                    current_content.append({"text": content})
+                elif isinstance(content, ImageContent):
+                    current_content.append(_build_image(content, cache_key))
+        elif msg.type == "function_call":
+            current_content.append(
+                {
+                    "toolUse": {
+                        "toolUseId": msg.call_id,
+                        "name": msg.name,
+                        "input": json.loads(msg.arguments or "{}"),
+                    }
+                }
+            )
+        elif msg.type == "function_call_output":
+            tool_response = {
+                "toolResult": {
+                    "toolUseId": msg.call_id,
+                    "content": [],
+                    "status": "success",
+                }
+            }
+            if isinstance(msg.output, dict):
+                tool_response["toolResult"]["content"].append({"json": msg.output})
+            elif isinstance(msg.output, str):
+                tool_response["toolResult"]["content"].append({"text": msg.output})
+            current_content.append(tool_response)
+
+    # Finalize the last message if there's any content left
+    if current_role is not None and current_content:
+        messages.append({"role": current_role, "content": current_content})
+
+    # Ensure the message list starts with a "user" message
+    if not messages or messages[0]["role"] != "user":
+        messages.insert(0, {"role": "user", "content": [{"text": "(empty)"}]})
+
+    return messages, system_message
+
+
+def _build_tool_spec(fnc: FunctionTool) -> dict:
+    fnc = llm.utils.build_legacy_openai_schema(fnc, internally_tagged=True)
+    return {
+        "toolSpec": _strip_nones(
+            {
+                "name": fnc["name"],
+                "description": fnc["description"] if fnc["description"] else None,
+                "inputSchema": {"json": fnc["parameters"] if fnc["parameters"] else {}},
+            }
+        )
+    }
+
+
+def _build_image(image: ImageContent, cache_key: Any) -> dict:
+    img = utils.serialize_image(image)
+    if img.external_url:
+        raise ValueError("external_url is not supported by AWS Bedrock.")
+    if cache_key not in image._cache:
+        image._cache[cache_key] = img.data_bytes
+    return {
+        "image": {
+            "format": "jpeg",
+            "source": {"bytes": image._cache[cache_key]},
+        }
+    }
+
+
+def _strip_nones(d: dict) -> dict:
+    return {k: v for k, v in d.items() if v is not None}
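Aside: the Bedrock Converse-style payload `to_chat_ctx` builds looks roughly like this (values illustrative):

messages, system_message = to_chat_ctx(chat_ctx, cache_key="example")
# messages -> [
#     {"role": "user", "content": [{"text": "What's the weather?"}]},
#     {"role": "assistant", "content": [{"toolUse": {"toolUseId": "call_1", "name": "get_weather", "input": {}}}]},
#     {"role": "user", "content": [{"toolResult": {"toolUseId": "call_1", "content": [{"text": "sunny"}], "status": "success"}}]},
# ]
# system_message -> {"text": "You are a helpful assistant"} or None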
livekit/plugins/aws/version.py
CHANGED