livekit-plugins-aws 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of livekit-plugins-aws might be problematic. Click here for more details.
- livekit/plugins/aws/__init__.py +30 -0
- livekit/plugins/aws/_utils.py +216 -0
- livekit/plugins/aws/llm.py +350 -0
- livekit/plugins/aws/log.py +3 -0
- livekit/plugins/aws/models.py +48 -0
- livekit/plugins/aws/py.typed +0 -0
- livekit/plugins/aws/stt.py +218 -0
- livekit/plugins/aws/tts.py +202 -0
- livekit/plugins/aws/version.py +15 -0
- livekit_plugins_aws-0.1.0.dist-info/METADATA +53 -0
- livekit_plugins_aws-0.1.0.dist-info/RECORD +13 -0
- livekit_plugins_aws-0.1.0.dist-info/WHEEL +5 -0
- livekit_plugins_aws-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Copyright 2023 LiveKit, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .llm import LLM
|
|
16
|
+
from .stt import STT, SpeechStream
|
|
17
|
+
from .tts import TTS, ChunkedStream
|
|
18
|
+
from .version import __version__
|
|
19
|
+
|
|
20
|
+
__all__ = ["STT", "SpeechStream", "TTS", "ChunkedStream", "LLM", "__version__"]
|
|
21
|
+
|
|
22
|
+
from livekit.agents import Plugin
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AWSPlugin(Plugin):
    """livekit-agents plugin entry point for the AWS integrations (STT/TTS/LLM)."""

    def __init__(self) -> None:
        # Register under this module's name/version so the agents framework
        # can report and manage the plugin.
        super().__init__(__name__, __version__, __package__)


# Registered at import time so importing the package makes the plugin
# discoverable by the agents framework.
Plugin.register_plugin(AWSPlugin())
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import inspect
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from typing import Any, Dict, List, Optional, Tuple, get_args, get_origin
|
|
8
|
+
|
|
9
|
+
import boto3
|
|
10
|
+
from livekit import rtc
|
|
11
|
+
from livekit.agents import llm, utils
|
|
12
|
+
from livekit.agents.llm.function_context import _is_optional_type
|
|
13
|
+
|
|
14
|
+
__all__ = ["_build_aws_ctx", "_build_tools", "_get_aws_credentials"]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _get_aws_credentials(
    api_key: Optional[str], api_secret: Optional[str], region: Optional[str]
):
    """Resolve and validate an AWS access-key pair for the given region.

    When *api_key* and *api_secret* are both provided they are used directly;
    otherwise boto3's default credential chain (env vars, profiles, instance
    metadata, ...) is consulted.

    Returns:
        A ``(access_key, secret_key)`` tuple.

    Raises:
        ValueError: if no region can be resolved or no credentials are found.

    NOTE(review): only the key pair is returned — a session token from
    temporary credentials would be dropped; confirm callers never rely on
    STS/temporary credentials.
    """
    resolved_region = region or os.environ.get("AWS_DEFAULT_REGION")
    if not resolved_region:
        raise ValueError(
            "AWS_DEFAULT_REGION must be set using the argument or by setting the AWS_DEFAULT_REGION environment variable."
        )

    # Prefer explicitly supplied credentials; otherwise let boto3's default
    # provider chain resolve them.
    if api_key and api_secret:
        session = boto3.Session(
            aws_access_key_id=api_key,
            aws_secret_access_key=api_secret,
            region_name=resolved_region,
        )
    else:
        session = boto3.Session(region_name=resolved_region)

    credentials = session.get_credentials()
    if credentials is None or not credentials.access_key or not credentials.secret_key:
        raise ValueError("No valid AWS credentials found.")
    return credentials.access_key, credentials.secret_key
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Mapping from Python scalar/container types to their JSON-Schema type names.
# Used by _build_parameters when converting function signatures into Bedrock
# tool input schemas; types absent from this map raise ValueError there.
JSON_SCHEMA_TYPE_MAP: Dict[type, str] = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    dict: "object",
    list: "array",
}
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _build_parameters(arguments: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Build a Bedrock ``inputSchema`` payload from a function's argument infos.

    Args:
        arguments: mapping of argument name to the framework's argument-info
            objects (expected to expose ``type``, ``default`` and optionally
            ``description`` / ``choices``).

    Returns:
        ``{"json": {"type": "object", "properties": ..., "required": ...}}``
        in JSON-Schema form, or ``None`` when there are no arguments.

    Raises:
        ValueError: for argument types not present in JSON_SCHEMA_TYPE_MAP.
    """
    properties: Dict[str, dict] = {}
    required: List[str] = []

    for arg_name, arg_info in arguments.items():
        prop: dict = {}
        if getattr(arg_info, "description", None):
            prop["description"] = arg_info.description

        # Unwrap Optional[...] to the underlying type before mapping it.
        _, py_type = _is_optional_type(arg_info.type)
        origin = get_origin(py_type)
        if origin is list:
            type_args = get_args(py_type)
            # Fix: a bare ``list`` annotation previously raised IndexError
            # here; now it reports the same "Unsupported type" ValueError.
            item_type = type_args[0] if type_args else None
            if item_type not in JSON_SCHEMA_TYPE_MAP:
                raise ValueError(f"Unsupported type: {item_type}")
            prop["type"] = "array"
            prop["items"] = {"type": JSON_SCHEMA_TYPE_MAP[item_type]}

            if getattr(arg_info, "choices", None):
                prop["items"]["enum"] = list(arg_info.choices)
        else:
            if py_type not in JSON_SCHEMA_TYPE_MAP:
                raise ValueError(f"Unsupported type: {py_type}")

            prop["type"] = JSON_SCHEMA_TYPE_MAP[py_type]

            # Fix: use getattr so arg_info objects without a ``choices``
            # attribute don't raise AttributeError (the array branch already
            # guarded this; the scalar branch did not).
            if getattr(arg_info, "choices", None):
                prop["enum"] = list(arg_info.choices)

        properties[arg_name] = prop

        # Arguments without a declared default are required by the schema.
        if arg_info.default is inspect.Parameter.empty:
            required.append(arg_name)

    if properties:
        parameters: Dict[str, Any] = {
            "json": {"type": "object", "properties": properties}
        }
        if required:
            parameters["json"]["required"] = required

        return parameters

    return None
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _build_tools(fnc_ctx: Any) -> List[dict]:
    """Convert the AI functions in *fnc_ctx* into Bedrock ``toolSpec`` entries.

    Functions with no parameters get an empty object schema, since the
    Converse API requires an inputSchema on every tool.
    """
    specs: List[dict] = []
    for fnc in fnc_ctx.ai_functions.values():
        schema = _build_parameters(fnc.arguments)
        if schema is None:
            schema = {"json": {"type": "object", "properties": {}}}

        specs.append(
            {
                "toolSpec": {
                    "name": fnc.name,
                    "description": fnc.description,
                    "inputSchema": schema,
                }
            }
        )
    return specs
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _build_image(image: llm.ChatImage, cache_key: Any) -> dict:
    """Convert a ChatImage into a Bedrock Converse image content block.

    String inputs: a ``data:image/jpeg;base64,`` URL is decoded to raw bytes;
    any other string is forwarded unchanged under a ``"uri"`` source.
    NOTE(review): only the jpeg data-URL prefix is recognized — png/webp data
    URLs fall through to the ``"uri"`` branch; also verify the Converse API
    accepts a ``"uri"`` image source at all.

    rtc.VideoFrame inputs are JPEG-encoded (resized when inference dimensions
    are set) and the encoded bytes are memoized on the frame under *cache_key*.

    Raises:
        ValueError: for invalid base64 data or an unsupported image type.
    """
    if isinstance(image.image, str):
        if image.image.startswith("data:image/jpeg;base64,"):
            base64_data = image.image.split(",", 1)[1]
            try:
                image_bytes = base64.b64decode(base64_data)
            except Exception as e:
                raise ValueError("Invalid base64 data in image URL") from e

            return {"image": {"format": "jpeg", "source": {"bytes": image_bytes}}}
        else:
            return {"image": {"format": "jpeg", "source": {"uri": image.image}}}

    elif isinstance(image.image, rtc.VideoFrame):
        # Encode once per (frame, cache_key); reuse on subsequent calls.
        if cache_key not in image._cache:
            opts = utils.images.EncodeOptions()
            if image.inference_width and image.inference_height:
                opts.resize_options = utils.images.ResizeOptions(
                    width=image.inference_width,
                    height=image.inference_height,
                    strategy="scale_aspect_fit",
                )
            image._cache[cache_key] = utils.images.encode(image.image, opts)

        return {
            "image": {
                "format": "jpeg",
                "source": {
                    "bytes": image._cache[cache_key],
                },
            }
        }
    raise ValueError(f"Unsupported image type: {type(image.image)}")
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _build_aws_ctx(
    chat_ctx: llm.ChatContext, cache_key: Any
) -> Tuple[List[dict], Optional[dict]]:
    """Convert a livekit ChatContext into Bedrock Converse API messages.

    Returns ``(messages, system)`` where *messages* is the Converse-format
    message list and *system* is a single ``{"text": ...}`` system block or
    ``None``.

    Behavior visible in the code below:
      * Only string-typed system messages are captured, and a later system
        message overwrites an earlier one.
      * ``"tool"`` messages are mapped to the ``"user"`` role (Converse
        expects tool results in a user turn); everything non-assistant is
        grouped under ``"user"``.
      * Consecutive messages mapped to the same role are accumulated into
        one Converse message.
    """
    messages: List[dict] = []
    system: Optional[dict] = None
    current_role: Optional[str] = None
    current_content: List[dict] = []

    for msg in chat_ctx.messages:
        if msg.role == "system":
            if isinstance(msg.content, str):
                system = {"text": msg.content}
            continue

        if msg.role == "assistant":
            role = "assistant"
        else:
            role = "user"

        # Role changed: flush the accumulated content as one message.
        if role != current_role:
            if current_role is not None and current_content:
                messages.append({"role": current_role, "content": current_content})
            current_role = role
            current_content = []

        if msg.tool_calls:
            for fnc in msg.tool_calls:
                current_content.append(
                    {
                        "toolUse": {
                            "toolUseId": fnc.tool_call_id,
                            "name": fnc.function_info.name,
                            "input": fnc.arguments,
                        }
                    }
                )

        if msg.role == "tool":
            # Tool output becomes a toolResult block; status is always
            # "success" here (no error propagation from the tool layer).
            tool_response: dict = {
                "toolResult": {
                    "toolUseId": msg.tool_call_id,
                    "content": [],
                    "status": "success",
                }
            }
            if isinstance(msg.content, dict):
                tool_response["toolResult"]["content"].append({"json": msg.content})
            elif isinstance(msg.content, str):
                tool_response["toolResult"]["content"].append({"text": msg.content})
            current_content.append(tool_response)
        else:
            if msg.content:
                if isinstance(msg.content, str):
                    current_content.append({"text": msg.content})
                elif isinstance(msg.content, dict):
                    # Dict content is serialized as text, not sent as "json".
                    current_content.append({"text": json.dumps(msg.content)})
                elif isinstance(msg.content, list):
                    for item in msg.content:
                        if isinstance(item, str):
                            current_content.append({"text": item})
                        elif isinstance(item, llm.ChatImage):
                            current_content.append(_build_image(item, cache_key))

    # Flush the trailing turn.
    if current_role is not None and current_content:
        messages.append({"role": current_role, "content": current_content})

    return messages, system
|
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
# Copyright 2023 LiveKit, Inc.
|
|
2
|
+
#
|
|
3
|
+
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import os
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from typing import Any, Literal, MutableSet, Union
|
|
21
|
+
|
|
22
|
+
import boto3
|
|
23
|
+
from livekit.agents import (
|
|
24
|
+
APIConnectionError,
|
|
25
|
+
APIStatusError,
|
|
26
|
+
llm,
|
|
27
|
+
)
|
|
28
|
+
from livekit.agents.llm import LLMCapabilities, ToolChoice, _create_ai_function_info
|
|
29
|
+
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
|
|
30
|
+
|
|
31
|
+
from ._utils import _build_aws_ctx, _build_tools, _get_aws_credentials
|
|
32
|
+
from .log import logger
|
|
33
|
+
|
|
34
|
+
# Known-good Bedrock text model id; call sites type parameters as
# ``TEXT_MODEL | str`` so arbitrary model ids / inference-profile ARNs are
# also accepted.
TEXT_MODEL = Literal["anthropic.claude-3-5-sonnet-20241022-v2:0"]
# NOTE(review): DEFAULT_REGION is not referenced anywhere in this module's
# visible code — the "us-east-1" defaults below are hard-coded in the
# signatures instead. Confirm whether this constant should be used there.
DEFAULT_REGION = "us-east-1"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class LLMOptions:
    """Resolved configuration for the Bedrock LLM, as sent with each request."""

    # Bedrock model id or inference-profile ARN, passed as ``modelId``.
    model: TEXT_MODEL | str
    # Sampling temperature; None leaves the service default.
    temperature: float | None
    # Tool-use policy: a specific ToolChoice, or "auto"/"required"/"none".
    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto"
    # Maximum tokens to generate (Converse ``maxTokens``); None = service default.
    max_output_tokens: int | None = None
    # Nucleus sampling probability (Converse ``topP``); None = service default.
    top_p: float | None = None
    # Extra fields forwarded as ``additionalModelRequestFields``.
    additional_request_fields: dict[str, Any] | None = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class LLM(llm.LLM):
    """AWS Bedrock chat-completion LLM using the Converse streaming API."""

    def __init__(
        self,
        *,
        model: TEXT_MODEL | str = "anthropic.claude-3-5-sonnet-20240620-v1:0",
        api_key: str | None = None,
        api_secret: str | None = None,
        region: str = "us-east-1",
        temperature: float = 0.8,
        max_output_tokens: int | None = None,
        top_p: float | None = None,
        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
        additional_request_fields: dict[str, Any] | None = None,
    ) -> None:
        """
        Create a new instance of AWS Bedrock LLM.

        ``api_key`` and ``api_secret`` must be set to your AWS Access key id and secret access key, either using the argument or by setting the
        ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` environmental variables.

        See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html for more details on the AWS Bedrock Runtime API.

        Args:
            model (TEXT_MODEL, optional): model or inference profile arn to use(https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-use.html). Defaults to 'anthropic.claude-3-5-sonnet-20240620-v1:0'.
            api_key(str, optional): AWS access key id.
            api_secret(str, optional): AWS secret access key
            region (str, optional): The region to use for AWS API requests. Defaults value is "us-east-1".
            temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
            max_output_tokens (int, optional): Maximum number of tokens to generate in the output. Defaults to None.
            top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
            tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
            additional_request_fields (dict[str, Any], optional): Additional request fields to send to the AWS Bedrock Converse API. Defaults to None.
        """
        super().__init__(
            capabilities=LLMCapabilities(
                supports_choices_on_int=True,
                requires_persistent_functions=True,
            )
        )
        # Resolve and validate credentials eagerly (raises ValueError on failure).
        self._api_key, self._api_secret = _get_aws_credentials(
            api_key, api_secret, region
        )

        # NOTE(review): because ``model`` has a non-empty default, the
        # BEDROCK_INFERENCE_PROFILE_ARN fallback is only reached when a caller
        # explicitly passes a falsy model (e.g. "" or None) — confirm intent.
        self._model = model or os.environ.get("BEDROCK_INFERENCE_PROFILE_ARN")
        if not self._model:
            raise ValueError(
                "model or inference profile arn must be set using the argument or by setting the BEDROCK_INFERENCE_PROFILE_ARN environment variable."
            )
        self._opts = LLMOptions(
            model=self._model,
            temperature=temperature,
            tool_choice=tool_choice,
            max_output_tokens=max_output_tokens,
            top_p=top_p,
            additional_request_fields=additional_request_fields,
        )
        self._region = region
        self._running_fncs: MutableSet[asyncio.Task[Any]] = set()

    def chat(
        self,
        *,
        chat_ctx: llm.ChatContext,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
        fnc_ctx: llm.FunctionContext | None = None,
        temperature: float | None = None,
        n: int | None = 1,
        parallel_tool_calls: bool | None = None,
        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
        | None = None,
    ) -> "LLMStream":
        """Start a streaming completion for *chat_ctx*.

        Per-call ``temperature`` / ``tool_choice`` override the constructor
        defaults. ``n`` and ``parallel_tool_calls`` are accepted for interface
        compatibility but are not forwarded to the request below.
        """
        if tool_choice is None:
            tool_choice = self._opts.tool_choice

        if temperature is None:
            temperature = self._opts.temperature

        return LLMStream(
            self,
            model=self._opts.model,
            aws_access_key_id=self._api_key,
            aws_secret_access_key=self._api_secret,
            region_name=self._region,
            max_output_tokens=self._opts.max_output_tokens,
            top_p=self._opts.top_p,
            additional_request_fields=self._opts.additional_request_fields,
            chat_ctx=chat_ctx,
            fnc_ctx=fnc_ctx,
            conn_options=conn_options,
            temperature=temperature,
            tool_choice=tool_choice,
        )
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class LLMStream(llm.LLMStream):
    """One streamed Converse request; parses Bedrock stream events into ChatChunks."""

    def __init__(
        self,
        llm: LLM,
        *,
        model: str | TEXT_MODEL,
        aws_access_key_id: str | None,
        aws_secret_access_key: str | None,
        region_name: str,
        chat_ctx: llm.ChatContext,
        conn_options: APIConnectOptions,
        fnc_ctx: llm.FunctionContext | None,
        temperature: float | None,
        max_output_tokens: int | None,
        top_p: float | None,
        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
        additional_request_fields: dict[str, Any] | None,
    ) -> None:
        super().__init__(
            llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
        )
        # Synchronous boto3 client. NOTE(review): converse_stream below is a
        # blocking call made from an async method — it runs on the event loop;
        # confirm this is acceptable or move it to an executor.
        self._client = boto3.client(
            "bedrock-runtime",
            region_name=region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )
        self._model = model
        self._llm: LLM = llm
        self._max_output_tokens = max_output_tokens
        self._top_p = top_p
        self._temperature = temperature
        self._tool_choice = tool_choice
        self._additional_request_fields = additional_request_fields

    async def _run(self) -> None:
        # Per-request accumulation state for the currently open content block.
        self._tool_call_id: str | None = None
        self._fnc_name: str | None = None
        self._fnc_raw_arguments: str | None = None
        self._text: str = ""
        # Becomes False once any chunk was emitted, so partial responses are
        # not retried by the caller.
        retryable = True

        try:
            opts: dict[str, Any] = {}
            messages, system_instruction = _build_aws_ctx(self._chat_ctx, id(self))
            messages = _merge_messages(messages)

            def _get_tool_config() -> dict[str, Any] | None:
                # No tools registered -> no toolConfig at all.
                if not (self._fnc_ctx and self._fnc_ctx.ai_functions):
                    return None

                tools = _build_tools(self._fnc_ctx)
                config: dict[str, Any] = {"tools": tools}

                # Map the plugin's tool_choice to the Converse toolChoice
                # shapes; "none" (the fall-through) disables tools entirely
                # by returning None.
                if isinstance(self._tool_choice, ToolChoice):
                    config["toolChoice"] = {"tool": {"name": self._tool_choice.name}}
                elif self._tool_choice == "required":
                    config["toolChoice"] = {"any": {}}
                elif self._tool_choice == "auto":
                    config["toolChoice"] = {"auto": {}}
                else:
                    return None

                return config

            tool_config = _get_tool_config()
            if tool_config:
                opts["toolConfig"] = tool_config

            if self._additional_request_fields:
                opts["additionalModelRequestFields"] = _strip_nones(
                    self._additional_request_fields
                )
            if system_instruction:
                opts["system"] = [system_instruction]

            inference_config = _strip_nones(
                {
                    "maxTokens": self._max_output_tokens,
                    "temperature": self._temperature,
                    "topP": self._top_p,
                }
            )
            response = self._client.converse_stream(
                modelId=self._model,
                messages=messages,
                inferenceConfig=inference_config,
                **_strip_nones(opts),
            )  # type: ignore

            request_id = response["ResponseMetadata"]["RequestId"]
            if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
                raise APIStatusError(
                    f"aws bedrock llm: error generating content: {response}",
                    retryable=False,
                    request_id=request_id,
                )

            for chunk in response["stream"]:
                chat_chunk = self._parse_chunk(request_id, chunk)
                if chat_chunk is not None:
                    retryable = False
                    self._event_ch.send_nowait(chat_chunk)

                # Let other coroutines run
                await asyncio.sleep(0)

        except Exception as e:
            # NOTE(review): this also re-wraps the APIStatusError raised above
            # into an APIConnectionError — confirm that is intended.
            raise APIConnectionError(
                f"aws bedrock llm: error generating content: {e}",
                retryable=retryable,
            ) from e

    def _parse_chunk(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
        """Fold one Converse stream event into the accumulation state.

        Returns a ChatChunk when a content block completes, else None.
        """
        if "contentBlockStart" in chunk:
            # A tool-use block opens: remember id/name and start accumulating
            # the streamed JSON argument string.
            tool_use = chunk["contentBlockStart"]["start"]["toolUse"]
            self._tool_call_id = tool_use["toolUseId"]
            self._fnc_name = tool_use["name"]
            self._fnc_raw_arguments = ""
        elif "contentBlockDelta" in chunk:
            delta = chunk["contentBlockDelta"]["delta"]
            if "toolUse" in delta:
                self._fnc_raw_arguments += delta["toolUse"]["input"]
            elif "text" in delta:
                self._text += delta["text"]
        elif "contentBlockStop" in chunk:
            # Block closed: flush buffered text first, otherwise the tool call.
            if self._text:
                chat_chunk = llm.ChatChunk(
                    request_id=request_id,
                    choices=[
                        llm.Choice(
                            delta=llm.ChoiceDelta(content=self._text, role="assistant"),
                            index=chunk["contentBlockStop"]["contentBlockIndex"],
                        )
                    ],
                )
                self._text = ""
                return chat_chunk
            elif self._tool_call_id:
                return self._try_build_function(request_id, chunk)

        return None

    def _try_build_function(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
        """Build a tool-call ChatChunk from the accumulated toolUse state.

        Logs and returns None when any piece of the state is missing.
        """
        if self._tool_call_id is None:
            logger.warning("aws bedrock llm: no tool call id in the response")
            return None
        if self._fnc_name is None:
            logger.warning("aws bedrock llm: no function name in the response")
            return None
        if self._fnc_raw_arguments is None:
            logger.warning("aws bedrock llm: no function arguments in the response")
            return None
        if self._fnc_ctx is None:
            logger.warning(
                "aws bedrock llm: stream tried to run function without function context"
            )
            return None

        fnc_info = _create_ai_function_info(
            self._fnc_ctx,
            self._tool_call_id,
            self._fnc_name,
            self._fnc_raw_arguments,
        )

        # Reset the accumulation state for the next tool-use block.
        self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None
        self._function_calls_info.append(fnc_info)

        return llm.ChatChunk(
            request_id=request_id,
            choices=[
                llm.Choice(
                    delta=llm.ChoiceDelta(
                        role="assistant",
                        tool_calls=[fnc_info],
                    ),
                    index=chunk["contentBlockStop"]["contentBlockIndex"],
                )
            ],
        )
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def _merge_messages(
|
|
326
|
+
messages: list[dict],
|
|
327
|
+
) -> list[dict]:
|
|
328
|
+
# Anthropic enforces alternating messages
|
|
329
|
+
combined_messages: list[dict] = []
|
|
330
|
+
for m in messages:
|
|
331
|
+
if len(combined_messages) == 0 or m["role"] != combined_messages[-1]["role"]:
|
|
332
|
+
combined_messages.append(m)
|
|
333
|
+
continue
|
|
334
|
+
last_message = combined_messages[-1]
|
|
335
|
+
if not isinstance(last_message["content"], list) or not isinstance(
|
|
336
|
+
m["content"], list
|
|
337
|
+
):
|
|
338
|
+
logger.error("message content is not a list")
|
|
339
|
+
continue
|
|
340
|
+
|
|
341
|
+
last_message["content"].extend(m["content"])
|
|
342
|
+
|
|
343
|
+
if len(combined_messages) == 0 or combined_messages[0]["role"] != "user":
|
|
344
|
+
combined_messages.insert(0, {"role": "user", "content": [{"text": "(empty)"}]})
|
|
345
|
+
|
|
346
|
+
return combined_messages
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def _strip_nones(d: dict[str, Any]) -> dict[str, Any]:
|
|
350
|
+
return {k: v for k, v in d.items() if v is not None}
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
# Polly speech-synthesis engine identifiers accepted by the TTS plugin.
TTS_SPEECH_ENGINE = Literal["standard", "neural", "long-form", "generative"]
# Language codes accepted for synthesis.
# NOTE(review): presumably these mirror Amazon Polly's voice language codes —
# verify against the Polly documentation when adding entries.
TTS_LANGUAGE = Literal[
    "arb",
    "cmn-CN",
    "cy-GB",
    "da-DK",
    "de-DE",
    "en-AU",
    "en-GB",
    "en-GB-WLS",
    "en-IN",
    "en-US",
    "es-ES",
    "es-MX",
    "es-US",
    "fr-CA",
    "fr-FR",
    "is-IS",
    "it-IT",
    "ja-JP",
    "hi-IN",
    "ko-KR",
    "nb-NO",
    "nl-NL",
    "pl-PL",
    "pt-BR",
    "pt-PT",
    "ro-RO",
    "ru-RU",
    "sv-SE",
    "tr-TR",
    "en-NZ",
    "en-ZA",
    "ca-ES",
    "de-AT",
    "yue-CN",
    "ar-AE",
    "fi-FI",
    "en-IE",
    "nl-BE",
    "fr-BE",
    "cs-CZ",
    "de-CH",
]

# Audio output formats supported by the TTS plugin.
TTS_OUTPUT_FORMAT = Literal["pcm", "mp3"]
|
|
File without changes
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from typing import Optional
|
|
18
|
+
|
|
19
|
+
from amazon_transcribe.client import TranscribeStreamingClient
|
|
20
|
+
from amazon_transcribe.model import Result, TranscriptEvent
|
|
21
|
+
from livekit import rtc
|
|
22
|
+
from livekit.agents import (
|
|
23
|
+
DEFAULT_API_CONNECT_OPTIONS,
|
|
24
|
+
APIConnectOptions,
|
|
25
|
+
stt,
|
|
26
|
+
utils,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
from ._utils import _get_aws_credentials
|
|
30
|
+
from .log import logger
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class STTOptions:
    """Configuration forwarded to Amazon Transcribe's streaming API."""

    # AWS region hosting the Transcribe streaming endpoint.
    speech_region: str
    # Input audio sample rate in Hz.
    sample_rate: int
    # Transcription language code, e.g. "en-US".
    language: str
    # Audio encoding sent to Transcribe (e.g. "pcm").
    encoding: str
    # The remaining fields are optional Transcribe features, forwarded
    # verbatim to start_stream_transcription; see the amazon-transcribe SDK
    # documentation for their semantics. None means "not set".
    vocabulary_name: Optional[str]
    session_id: Optional[str]
    vocab_filter_method: Optional[str]
    vocab_filter_name: Optional[str]
    show_speaker_label: Optional[bool]
    enable_channel_identification: Optional[bool]
    number_of_channels: Optional[int]
    enable_partial_results_stabilization: Optional[bool]
    partial_results_stability: Optional[str]
    language_model_name: Optional[str]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class STT(stt.STT):
    """Amazon Transcribe streaming speech-to-text (streaming only)."""

    def __init__(
        self,
        *,
        speech_region: str = "us-east-1",
        api_key: str | None = None,
        api_secret: str | None = None,
        sample_rate: int = 48000,
        language: str = "en-US",
        encoding: str = "pcm",
        vocabulary_name: Optional[str] = None,
        session_id: Optional[str] = None,
        vocab_filter_method: Optional[str] = None,
        vocab_filter_name: Optional[str] = None,
        show_speaker_label: Optional[bool] = None,
        enable_channel_identification: Optional[bool] = None,
        number_of_channels: Optional[int] = None,
        enable_partial_results_stabilization: Optional[bool] = None,
        partial_results_stability: Optional[str] = None,
        language_model_name: Optional[str] = None,
    ):
        """Create an Amazon Transcribe streaming STT.

        Args:
            speech_region: AWS region for the Transcribe endpoint.
            api_key / api_secret: AWS access key id / secret access key;
                when omitted, boto3's default credential chain is used.
            sample_rate: input audio sample rate in Hz.
            language: transcription language code, e.g. "en-US".
            encoding: audio encoding passed to Transcribe.
            Remaining keyword arguments mirror the optional parameters of
            ``start_stream_transcription`` in the amazon-transcribe SDK.

        Raises:
            ValueError: when no region or credentials can be resolved.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
        )

        # Fail fast if no usable credentials can be resolved for the region.
        self._api_key, self._api_secret = _get_aws_credentials(
            api_key, api_secret, speech_region
        )
        self._config = STTOptions(
            speech_region=speech_region,
            language=language,
            sample_rate=sample_rate,
            encoding=encoding,
            vocabulary_name=vocabulary_name,
            session_id=session_id,
            vocab_filter_method=vocab_filter_method,
            vocab_filter_name=vocab_filter_name,
            show_speaker_label=show_speaker_label,
            enable_channel_identification=enable_channel_identification,
            number_of_channels=number_of_channels,
            enable_partial_results_stabilization=enable_partial_results_stabilization,
            partial_results_stability=partial_results_stability,
            language_model_name=language_model_name,
        )

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: str | None,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        # Transcribe streaming has no one-shot recognition endpoint; only
        # stream() is supported.
        raise NotImplementedError(
            "Amazon Transcribe does not support single frame recognition"
        )

    def stream(
        self,
        *,
        language: str | None = None,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> "SpeechStream":
        """Open a streaming transcription session.

        Args:
            language: optional per-stream language override; falls back to
                the language configured in the constructor.
        """
        config = self._config
        # Fix: the ``language`` argument used to be silently ignored — apply
        # it as a per-stream override without mutating the shared config.
        if language is not None and language != config.language:
            from dataclasses import replace

            config = replace(config, language=language)

        return SpeechStream(
            stt=self,
            conn_options=conn_options,
            opts=config,
        )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class SpeechStream(stt.SpeechStream):
    """Bidirectional streaming transcription backed by Amazon Transcribe.

    Audio frames pushed into the stream are forwarded to the Transcribe
    streaming API, and transcript events coming back are converted into
    livekit-agents speech events on the event channel.
    """

    def __init__(
        self,
        stt: STT,
        opts: STTOptions,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> None:
        super().__init__(
            stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
        )
        self._opts = opts
        # Region comes from the options; credentials are resolved by the
        # amazon-transcribe SDK's default provider chain.
        self._client = TranscribeStreamingClient(region=self._opts.speech_region)

    async def _run(self) -> None:
        """Run one transcription session: pump audio in, stream events out."""
        stream = await self._client.start_stream_transcription(
            language_code=self._opts.language,
            media_sample_rate_hz=self._opts.sample_rate,
            media_encoding=self._opts.encoding,
            vocabulary_name=self._opts.vocabulary_name,
            session_id=self._opts.session_id,
            vocab_filter_method=self._opts.vocab_filter_method,
            vocab_filter_name=self._opts.vocab_filter_name,
            show_speaker_label=self._opts.show_speaker_label,
            enable_channel_identification=self._opts.enable_channel_identification,
            number_of_channels=self._opts.number_of_channels,
            enable_partial_results_stabilization=self._opts.enable_partial_results_stabilization,
            partial_results_stability=self._opts.partial_results_stability,
            language_model_name=self._opts.language_model_name,
        )

        @utils.log_exceptions(logger=logger)
        async def input_generator():
            # Forward raw audio to Transcribe; end_stream() tells the service
            # no more audio is coming so it can flush its final results.
            async for frame in self._input_ch:
                if isinstance(frame, rtc.AudioFrame):
                    await stream.input_stream.send_audio_event(
                        audio_chunk=frame.data.tobytes()
                    )
            await stream.input_stream.end_stream()

        @utils.log_exceptions(logger=logger)
        async def handle_transcript_events():
            async for event in stream.output_stream:
                if isinstance(event, TranscriptEvent):
                    self._process_transcript_event(event)

        tasks = [
            asyncio.create_task(input_generator()),
            asyncio.create_task(handle_transcript_events()),
        ]
        try:
            await asyncio.gather(*tasks)
        finally:
            # Ensure neither task outlives the session, even on error.
            await utils.aio.gracefully_cancel(*tasks)

    def _process_transcript_event(self, transcript_event: TranscriptEvent):
        """Translate a Transcribe TranscriptEvent into speech events."""
        for resp in transcript_event.transcript.results:
            # BUG FIX: the original guard was
            # `if resp.start_time and resp.start_time == 0.0:` — since 0.0 is
            # falsy, the condition could never hold, so START_OF_SPEECH was
            # never emitted. Compare directly instead (None == 0.0 is False,
            # so a missing start_time still skips the event).
            if resp.start_time == 0.0:
                self._event_ch.send_nowait(
                    stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH)
                )

            if resp.end_time and resp.end_time > 0.0:
                # Partial results become interim transcripts; completed
                # results become final transcripts.
                event_type = (
                    stt.SpeechEventType.INTERIM_TRANSCRIPT
                    if resp.is_partial
                    else stt.SpeechEventType.FINAL_TRANSCRIPT
                )
                self._event_ch.send_nowait(
                    stt.SpeechEvent(
                        type=event_type,
                        alternatives=[
                            _streaming_recognize_response_to_speech_data(resp)
                        ],
                    )
                )

            if not resp.is_partial:
                self._event_ch.send_nowait(
                    stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
                )
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _streaming_recognize_response_to_speech_data(
    resp: Result, *, language: str = "en-US"
) -> stt.SpeechData:
    """Convert an Amazon Transcribe streaming ``Result`` into ``stt.SpeechData``.

    Args:
        resp: a single result taken from a TranscriptEvent.
        language: language code to attach to the transcript. The original
            implementation hard-coded "en-US"; it is kept as the default for
            backward compatibility, but callers may now pass the session's
            configured language instead.

    Returns:
        A SpeechData with the first alternative's transcript (empty string if
        the result carries no alternatives). Confidence is reported as 0.0 —
        no score is extracted from the Transcribe result here.
    """
    return stt.SpeechData(
        language=language,
        # Missing timestamps (None) are normalized to 0.0.
        start_time=resp.start_time if resp.start_time else 0.0,
        end_time=resp.end_time if resp.end_time else 0.0,
        confidence=0.0,
        text=resp.alternatives[0].transcript if resp.alternatives else "",
    )
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from typing import Any, Callable, Optional
|
|
18
|
+
|
|
19
|
+
import aiohttp
|
|
20
|
+
from aiobotocore.session import AioSession, get_session
|
|
21
|
+
from livekit import rtc
|
|
22
|
+
from livekit.agents import (
|
|
23
|
+
APIConnectionError,
|
|
24
|
+
APIConnectOptions,
|
|
25
|
+
APIStatusError,
|
|
26
|
+
APITimeoutError,
|
|
27
|
+
tts,
|
|
28
|
+
utils,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
from ._utils import _get_aws_credentials
|
|
32
|
+
from .models import TTS_LANGUAGE, TTS_OUTPUT_FORMAT, TTS_SPEECH_ENGINE
|
|
33
|
+
|
|
34
|
+
# Polly output is emitted as single-channel audio frames.
TTS_NUM_CHANNELS: int = 1
# Raw PCM by default; "mp3" is also handled by ChunkedStream's decoder path.
DEFAULT_OUTPUT_FORMAT: TTS_OUTPUT_FORMAT = "pcm"
DEFAULT_SPEECH_ENGINE: TTS_SPEECH_ENGINE = "generative"
DEFAULT_SPEECH_REGION = "us-east-1"
# Default Polly voice ID.
DEFAULT_VOICE = "Ruth"
# Output audio frequency in Hz.
DEFAULT_SAMPLE_RATE = 16000
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class _TTSOptions:
    """Resolved synthesis options forwarded to Polly's SynthesizeSpeech call."""

    # https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
    voice: str | None  # Polly VoiceId
    output_format: TTS_OUTPUT_FORMAT  # audio encoding requested from Polly
    speech_engine: TTS_SPEECH_ENGINE  # Polly Engine (e.g. "generative")
    speech_region: str  # AWS region used when creating the Polly client
    sample_rate: int  # output audio frequency in Hz
    language: TTS_LANGUAGE | str | None  # optional LanguageCode override
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class TTS(tts.TTS):
    """Text-to-speech plugin backed by AWS Polly's SynthesizeSpeech API."""

    def __init__(
        self,
        *,
        voice: str | None = DEFAULT_VOICE,
        language: TTS_LANGUAGE | str | None = None,
        output_format: TTS_OUTPUT_FORMAT = DEFAULT_OUTPUT_FORMAT,
        speech_engine: TTS_SPEECH_ENGINE = DEFAULT_SPEECH_ENGINE,
        sample_rate: int = DEFAULT_SAMPLE_RATE,
        speech_region: str = DEFAULT_SPEECH_REGION,
        api_key: str | None = None,
        api_secret: str | None = None,
        session: AioSession | None = None,
    ) -> None:
        """
        Create a new instance of AWS Polly TTS.

        ``api_key`` and ``api_secret`` must be set to your AWS Access key id and secret access key, either using the argument or by setting the
        ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` environmental variables.

        See https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html for more details on the AWS Polly TTS.

        Args:
            voice (str, optional): Voice ID to use for the synthesis. Defaults to "Ruth".
            language (TTS_LANGUAGE, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
            output_format (TTS_OUTPUT_FORMAT, optional): The format in which the returned output will be encoded. Defaults to "pcm".
            sample_rate (int, optional): The audio frequency specified in Hz. Defaults to 16000.
            speech_engine (TTS_SPEECH_ENGINE, optional): The engine to use for the synthesis. Defaults to "generative".
            speech_region (str, optional): The region to use for the synthesis. Defaults to "us-east-1".
            api_key (str, optional): AWS access key id.
            api_secret (str, optional): AWS secret access key.
            session (AioSession, optional): Pre-configured aiobotocore session; a default session is created when omitted.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(
                streaming=False,
            ),
            sample_rate=sample_rate,
            num_channels=TTS_NUM_CHANNELS,
        )

        # Resolves explicit arguments / environment variables into a usable
        # credential pair (raising inside the helper if none are found).
        self._api_key, self._api_secret = _get_aws_credentials(
            api_key, api_secret, speech_region
        )

        self._opts = _TTSOptions(
            voice=voice,
            output_format=output_format,
            speech_engine=speech_engine,
            speech_region=speech_region,
            language=language,
            sample_rate=sample_rate,
        )
        self._session = session or get_session()

    def _get_client(self):
        # A fresh client per request; the returned client is used as an async
        # context manager by ChunkedStream._run.
        return self._session.create_client(
            "polly",
            region_name=self._opts.speech_region,
            aws_access_key_id=self._api_key,
            aws_secret_access_key=self._api_secret,
        )

    def synthesize(
        self,
        text: str,
        *,
        conn_options: Optional[APIConnectOptions] = None,
    ) -> "ChunkedStream":
        """Synthesize *text* in one request, returning a chunked audio stream."""
        return ChunkedStream(
            tts=self,
            text=text,
            conn_options=conn_options,
            opts=self._opts,
            get_client=self._get_client,
        )
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class ChunkedStream(tts.ChunkedStream):
    """Non-streaming synthesis: one Polly SynthesizeSpeech call per input text.

    The response audio is emitted as SynthesizedAudio chunks — decoded from
    MP3 when ``output_format == "mp3"``, otherwise forwarded as raw PCM.
    """

    def __init__(
        self,
        *,
        tts: TTS,
        text: str,
        conn_options: Optional[APIConnectOptions] = None,
        opts: _TTSOptions,
        get_client: Callable[[], Any],
    ) -> None:
        super().__init__(tts=tts, input_text=text, conn_options=conn_options)
        self._opts = opts
        self._get_client = get_client
        # One segment per synthesis request.
        self._segment_id = utils.shortuuid()

    async def _run(self):
        """Perform the synthesis request and push audio frames downstream.

        Raises:
            APITimeoutError: on request timeout.
            APIStatusError: on an HTTP error response.
            APIConnectionError: on any other failure.
        """
        request_id = utils.shortuuid()

        try:
            async with self._get_client() as client:
                params = {
                    "Text": self._input_text,
                    "OutputFormat": self._opts.output_format,
                    "Engine": self._opts.speech_engine,
                    "VoiceId": self._opts.voice,
                    "TextType": "text",
                    "SampleRate": str(self._opts.sample_rate),
                    "LanguageCode": self._opts.language,
                }
                # Unset options are dropped rather than sent as nulls.
                response = await client.synthesize_speech(**_strip_nones(params))
                if "AudioStream" in response:
                    # The output format is loop-invariant: decide once, and
                    # construct an MP3 decoder only when it is actually needed
                    # (the original created one unconditionally).
                    is_mp3 = self._opts.output_format == "mp3"
                    decoder = utils.codecs.Mp3StreamDecoder() if is_mp3 else None
                    async with response["AudioStream"] as resp:
                        async for data, _ in resp.content.iter_chunks():
                            if is_mp3:
                                for frame in decoder.decode_chunk(data):
                                    self._send_frame(request_id, frame)
                            else:
                                # Raw PCM path: `len(data) // 2` assumes 16-bit
                                # samples (2 bytes each), single channel.
                                self._send_frame(
                                    request_id,
                                    rtc.AudioFrame(
                                        data=data,
                                        sample_rate=self._opts.sample_rate,
                                        num_channels=1,
                                        samples_per_channel=len(data) // 2,
                                    ),
                                )

        except asyncio.TimeoutError as e:
            raise APITimeoutError() from e
        except aiohttp.ClientResponseError as e:
            raise APIStatusError(
                message=e.message,
                status_code=e.status,
                request_id=request_id,
                body=None,
            ) from e
        except Exception as e:
            raise APIConnectionError() from e

    def _send_frame(self, request_id: str, frame: rtc.AudioFrame) -> None:
        # Wrap a frame in SynthesizedAudio and emit it on the event channel.
        self._event_ch.send_nowait(
            tts.SynthesizedAudio(
                request_id=request_id,
                segment_id=self._segment_id,
                frame=frame,
            )
        )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _strip_nones(d: dict[str, Any]) -> dict[str, Any]:
|
|
202
|
+
return {k: v for k, v in d.items() if v is not None}
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2023 LiveKit, Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# Package version; kept in sync with the distribution metadata.
__version__ = "0.1.0"
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: livekit-plugins-aws
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: LiveKit Agents Plugin for services from AWS
|
|
5
|
+
Home-page: https://github.com/livekit/agents
|
|
6
|
+
License: Apache-2.0
|
|
7
|
+
Project-URL: Documentation, https://docs.livekit.io
|
|
8
|
+
Project-URL: Website, https://livekit.io/
|
|
9
|
+
Project-URL: Source, https://github.com/livekit/agents
|
|
10
|
+
Keywords: webrtc,realtime,audio,video,livekit,aws
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
13
|
+
Classifier: Topic :: Multimedia :: Sound/Audio
|
|
14
|
+
Classifier: Topic :: Multimedia :: Video
|
|
15
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
19
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
20
|
+
Requires-Python: >=3.9.0
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
Requires-Dist: livekit-agents>=0.12.0
|
|
23
|
+
Requires-Dist: aiobotocore==2.19.0
|
|
24
|
+
Requires-Dist: boto3==1.36.3
|
|
25
|
+
Requires-Dist: amazon-transcribe>=0.6.2
|
|
26
|
+
Dynamic: classifier
|
|
27
|
+
Dynamic: description
|
|
28
|
+
Dynamic: description-content-type
|
|
29
|
+
Dynamic: home-page
|
|
30
|
+
Dynamic: keywords
|
|
31
|
+
Dynamic: license
|
|
32
|
+
Dynamic: project-url
|
|
33
|
+
Dynamic: requires-dist
|
|
34
|
+
Dynamic: requires-python
|
|
35
|
+
Dynamic: summary
|
|
36
|
+
|
|
37
|
+
# LiveKit Plugins AWS
|
|
38
|
+
|
|
39
|
+
Agent Framework plugin for services from AWS.
|
|
40
|
+
|
|
41
|
+
- aws polly for tts
|
|
42
|
+
- aws transcribe for stt
|
|
43
|
+
- aws bedrock for llm
|
|
44
|
+
|
|
45
|
+
## Installation
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
pip install livekit-plugins-aws
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
## Pre-requisites
|
|
52
|
+
|
|
53
|
+
You'll need AWS credentials (an access key ID and a secret access key) and a deployment region. They can be provided via the environment variables `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_DEFAULT_REGION`.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
livekit/plugins/aws/__init__.py,sha256=Ea-hK7QdutnwdZvvs9K2fiR8RWJqz2JcONxXnV1kXF0,977
|
|
2
|
+
livekit/plugins/aws/_utils.py,sha256=iuDuQpPta4wLtgW1Wc2rHspZWoa7KZI76tujQIPY898,7411
|
|
3
|
+
livekit/plugins/aws/llm.py,sha256=yUAiBCtb2jRB1_S9BNrILTMmDffvKOpDod802kYnPVM,13527
|
|
4
|
+
livekit/plugins/aws/log.py,sha256=jFief0Xhv0n_F6sp6UFu9VKxs2bXNVGAfYGmEYfR_2Q,66
|
|
5
|
+
livekit/plugins/aws/models.py,sha256=wb7AfN-z7qgtKMZnUbQsELi6wN8ha5exI3DH8z6Gz3M,711
|
|
6
|
+
livekit/plugins/aws/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
|
+
livekit/plugins/aws/stt.py,sha256=eH7gKtdCjwki20Th6PrCsjjtH-zjXa8ZWu-cu_KaT80,7935
|
|
8
|
+
livekit/plugins/aws/tts.py,sha256=miUYrhstJ7tcLkvJ-8Cpv1UCQxRSdOqaSC2tvHBh9WI,7800
|
|
9
|
+
livekit/plugins/aws/version.py,sha256=vQH9cItKAVYAmrLbOntkbLqmxrUZrPiKb1TjkZ8jRKQ,600
|
|
10
|
+
livekit_plugins_aws-0.1.0.dist-info/METADATA,sha256=FUzLRO0YcUvcIidEEq_EK7Lbp6yPYKjzT_BkclYNGhM,1686
|
|
11
|
+
livekit_plugins_aws-0.1.0.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
|
|
12
|
+
livekit_plugins_aws-0.1.0.dist-info/top_level.txt,sha256=OoDok3xUmXbZRvOrfvvXB-Juu4DX79dlq188E19YHoo,8
|
|
13
|
+
livekit_plugins_aws-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
livekit
|