livekit-plugins-aws 0.1.1__py3-none-any.whl → 1.0.0.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of livekit-plugins-aws might be problematic. Click here for more details.
- livekit/plugins/aws/llm.py +119 -199
- livekit/plugins/aws/stt.py +34 -53
- livekit/plugins/aws/tts.py +8 -7
- livekit/plugins/aws/utils.py +135 -0
- livekit/plugins/aws/version.py +1 -1
- {livekit_plugins_aws-0.1.1.dist-info → livekit_plugins_aws-1.0.0.dev4.dist-info}/METADATA +13 -23
- livekit_plugins_aws-1.0.0.dev4.dist-info/RECORD +12 -0
- {livekit_plugins_aws-0.1.1.dist-info → livekit_plugins_aws-1.0.0.dev4.dist-info}/WHEEL +1 -2
- livekit/plugins/aws/_utils.py +0 -216
- livekit_plugins_aws-0.1.1.dist-info/RECORD +0 -13
- livekit_plugins_aws-0.1.1.dist-info/top_level.txt +0 -1
livekit/plugins/aws/llm.py
CHANGED
|
@@ -17,47 +17,50 @@ from __future__ import annotations
|
|
|
17
17
|
import asyncio
|
|
18
18
|
import os
|
|
19
19
|
from dataclasses import dataclass
|
|
20
|
-
from typing import Any, Literal
|
|
20
|
+
from typing import Any, Literal
|
|
21
21
|
|
|
22
22
|
import boto3
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
23
|
+
|
|
24
|
+
from livekit.agents import APIConnectionError, APIStatusError, llm
|
|
25
|
+
from livekit.agents.llm import ChatContext, FunctionTool, FunctionToolCall, ToolChoice
|
|
26
|
+
from livekit.agents.types import (
|
|
27
|
+
DEFAULT_API_CONNECT_OPTIONS,
|
|
28
|
+
NOT_GIVEN,
|
|
29
|
+
APIConnectOptions,
|
|
30
|
+
NotGivenOr,
|
|
27
31
|
)
|
|
28
|
-
from livekit.agents.
|
|
29
|
-
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
|
|
32
|
+
from livekit.agents.utils import is_given
|
|
30
33
|
|
|
31
|
-
from ._utils import _build_aws_ctx, _build_tools, _get_aws_credentials
|
|
32
34
|
from .log import logger
|
|
35
|
+
from .utils import get_aws_credentials, to_chat_ctx, to_fnc_ctx
|
|
33
36
|
|
|
34
37
|
TEXT_MODEL = Literal["anthropic.claude-3-5-sonnet-20241022-v2:0"]
|
|
35
38
|
DEFAULT_REGION = "us-east-1"
|
|
36
39
|
|
|
37
40
|
|
|
38
41
|
@dataclass
|
|
39
|
-
class
|
|
40
|
-
model:
|
|
41
|
-
temperature: float
|
|
42
|
-
tool_choice:
|
|
43
|
-
max_output_tokens: int
|
|
44
|
-
top_p: float
|
|
45
|
-
additional_request_fields: dict[str, Any]
|
|
42
|
+
class _LLMOptions:
|
|
43
|
+
model: str | TEXT_MODEL
|
|
44
|
+
temperature: NotGivenOr[float]
|
|
45
|
+
tool_choice: NotGivenOr[ToolChoice | Literal["auto", "required", "none"]]
|
|
46
|
+
max_output_tokens: NotGivenOr[int]
|
|
47
|
+
top_p: NotGivenOr[float]
|
|
48
|
+
additional_request_fields: NotGivenOr[dict[str, Any]]
|
|
46
49
|
|
|
47
50
|
|
|
48
51
|
class LLM(llm.LLM):
|
|
49
52
|
def __init__(
|
|
50
53
|
self,
|
|
51
54
|
*,
|
|
52
|
-
model:
|
|
53
|
-
api_key: str
|
|
54
|
-
api_secret: str
|
|
55
|
-
region: str =
|
|
56
|
-
temperature: float =
|
|
57
|
-
max_output_tokens: int
|
|
58
|
-
top_p: float
|
|
59
|
-
tool_choice:
|
|
60
|
-
additional_request_fields: dict[str, Any]
|
|
55
|
+
model: NotGivenOr[str | TEXT_MODEL] = NOT_GIVEN,
|
|
56
|
+
api_key: NotGivenOr[str] = NOT_GIVEN,
|
|
57
|
+
api_secret: NotGivenOr[str] = NOT_GIVEN,
|
|
58
|
+
region: NotGivenOr[str] = NOT_GIVEN,
|
|
59
|
+
temperature: NotGivenOr[float] = NOT_GIVEN,
|
|
60
|
+
max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
|
|
61
|
+
top_p: NotGivenOr[float] = NOT_GIVEN,
|
|
62
|
+
tool_choice: NotGivenOr[ToolChoice | Literal["auto", "required", "none"]] = NOT_GIVEN,
|
|
63
|
+
additional_request_fields: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
|
|
61
64
|
) -> None:
|
|
62
65
|
"""
|
|
63
66
|
Create a new instance of AWS Bedrock LLM.
|
|
@@ -65,7 +68,7 @@ class LLM(llm.LLM):
|
|
|
65
68
|
``api_key`` and ``api_secret`` must be set to your AWS Access key id and secret access key, either using the argument or by setting the
|
|
66
69
|
``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` environmental variables.
|
|
67
70
|
|
|
68
|
-
See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html for more details on the
|
|
71
|
+
See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html for more details on the AWS Bedrock Runtime API.
|
|
69
72
|
|
|
70
73
|
Args:
|
|
71
74
|
model (TEXT_MODEL, optional): model or inference profile arn to use(https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-use.html). Defaults to 'anthropic.claude-3-5-sonnet-20240620-v1:0'.
|
|
@@ -78,64 +81,89 @@ class LLM(llm.LLM):
|
|
|
78
81
|
tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
|
|
79
82
|
additional_request_fields (dict[str, Any], optional): Additional request fields to send to the AWS Bedrock Converse API. Defaults to None.
|
|
80
83
|
"""
|
|
81
|
-
super().__init__(
|
|
82
|
-
|
|
83
|
-
supports_choices_on_int=True,
|
|
84
|
-
requires_persistent_functions=True,
|
|
85
|
-
)
|
|
86
|
-
)
|
|
87
|
-
self._api_key, self._api_secret = _get_aws_credentials(
|
|
84
|
+
super().__init__()
|
|
85
|
+
self._api_key, self._api_secret, self._region = get_aws_credentials(
|
|
88
86
|
api_key, api_secret, region
|
|
89
87
|
)
|
|
90
88
|
|
|
91
|
-
|
|
92
|
-
if not
|
|
89
|
+
model = model or os.environ.get("BEDROCK_INFERENCE_PROFILE_ARN")
|
|
90
|
+
if not is_given(model):
|
|
93
91
|
raise ValueError(
|
|
94
92
|
"model or inference profile arn must be set using the argument or by setting the BEDROCK_INFERENCE_PROFILE_ARN environment variable."
|
|
95
93
|
)
|
|
96
|
-
self._opts =
|
|
97
|
-
model=
|
|
94
|
+
self._opts = _LLMOptions(
|
|
95
|
+
model=model,
|
|
98
96
|
temperature=temperature,
|
|
99
97
|
tool_choice=tool_choice,
|
|
100
98
|
max_output_tokens=max_output_tokens,
|
|
101
99
|
top_p=top_p,
|
|
102
100
|
additional_request_fields=additional_request_fields,
|
|
103
101
|
)
|
|
104
|
-
self._region = region
|
|
105
|
-
self._running_fncs: MutableSet[asyncio.Task[Any]] = set()
|
|
106
102
|
|
|
107
103
|
def chat(
|
|
108
104
|
self,
|
|
109
105
|
*,
|
|
110
|
-
chat_ctx:
|
|
106
|
+
chat_ctx: ChatContext,
|
|
107
|
+
tools: list[FunctionTool] | None = None,
|
|
111
108
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
109
|
+
temperature: NotGivenOr[float] = NOT_GIVEN,
|
|
110
|
+
tool_choice: NotGivenOr[ToolChoice | Literal["auto", "required", "none"]] = NOT_GIVEN,
|
|
111
|
+
) -> LLMStream:
|
|
112
|
+
opts = {}
|
|
113
|
+
|
|
114
|
+
if is_given(self._opts.model):
|
|
115
|
+
opts["modelId"] = self._opts.model
|
|
116
|
+
|
|
117
|
+
def _get_tool_config() -> dict[str, Any] | None:
|
|
118
|
+
nonlocal tool_choice
|
|
119
|
+
|
|
120
|
+
if not tools:
|
|
121
|
+
return None
|
|
122
|
+
|
|
123
|
+
tool_config: dict[str, Any] = {"tools": to_fnc_ctx(tools)}
|
|
124
|
+
tool_choice = tool_choice if is_given(tool_choice) else self._opts.tool_choice
|
|
125
|
+
if is_given(tool_choice):
|
|
126
|
+
if isinstance(tool_choice, ToolChoice):
|
|
127
|
+
tool_config["toolChoice"] = {"tool": {"name": tool_choice.name}}
|
|
128
|
+
elif tool_choice == "required":
|
|
129
|
+
tool_config["toolChoice"] = {"any": {}}
|
|
130
|
+
elif tool_choice == "auto":
|
|
131
|
+
tool_config["toolChoice"] = {"auto": {}}
|
|
132
|
+
else:
|
|
133
|
+
return None
|
|
134
|
+
|
|
135
|
+
return tool_config
|
|
136
|
+
|
|
137
|
+
tool_config = _get_tool_config()
|
|
138
|
+
if tool_config:
|
|
139
|
+
opts["toolConfig"] = tool_config
|
|
140
|
+
messages, system_message = to_chat_ctx(chat_ctx, id(self))
|
|
141
|
+
opts["messages"] = messages
|
|
142
|
+
if system_message:
|
|
143
|
+
opts["system"] = [system_message]
|
|
144
|
+
|
|
145
|
+
inference_config = {}
|
|
146
|
+
if is_given(self._opts.max_output_tokens):
|
|
147
|
+
inference_config["maxTokens"] = self._opts.max_output_tokens
|
|
148
|
+
temperature = temperature if is_given(temperature) else self._opts.temperature
|
|
149
|
+
if is_given(temperature):
|
|
150
|
+
inference_config["temperature"] = temperature
|
|
151
|
+
if is_given(self._opts.top_p):
|
|
152
|
+
inference_config["topP"] = self._opts.top_p
|
|
153
|
+
|
|
154
|
+
opts["inferenceConfig"] = inference_config
|
|
155
|
+
if is_given(self._opts.additional_request_fields):
|
|
156
|
+
opts["additionalModelRequestFields"] = self._opts.additional_request_fields
|
|
124
157
|
|
|
125
158
|
return LLMStream(
|
|
126
159
|
self,
|
|
127
|
-
model=self._opts.model,
|
|
128
160
|
aws_access_key_id=self._api_key,
|
|
129
161
|
aws_secret_access_key=self._api_secret,
|
|
130
162
|
region_name=self._region,
|
|
131
|
-
max_output_tokens=self._opts.max_output_tokens,
|
|
132
|
-
top_p=self._opts.top_p,
|
|
133
|
-
additional_request_fields=self._opts.additional_request_fields,
|
|
134
163
|
chat_ctx=chat_ctx,
|
|
135
|
-
|
|
164
|
+
tools=tools,
|
|
136
165
|
conn_options=conn_options,
|
|
137
|
-
|
|
138
|
-
tool_choice=tool_choice,
|
|
166
|
+
extra_kwargs=opts,
|
|
139
167
|
)
|
|
140
168
|
|
|
141
169
|
|
|
@@ -144,91 +172,33 @@ class LLMStream(llm.LLMStream):
|
|
|
144
172
|
self,
|
|
145
173
|
llm: LLM,
|
|
146
174
|
*,
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
aws_secret_access_key: str | None,
|
|
175
|
+
aws_access_key_id: str,
|
|
176
|
+
aws_secret_access_key: str,
|
|
150
177
|
region_name: str,
|
|
151
|
-
chat_ctx:
|
|
178
|
+
chat_ctx: ChatContext,
|
|
152
179
|
conn_options: APIConnectOptions,
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
max_output_tokens: int | None,
|
|
156
|
-
top_p: float | None,
|
|
157
|
-
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
|
|
158
|
-
additional_request_fields: dict[str, Any] | None,
|
|
180
|
+
tools: list[FunctionTool] | None,
|
|
181
|
+
extra_kwargs: dict[str, Any],
|
|
159
182
|
) -> None:
|
|
160
|
-
super().__init__(
|
|
161
|
-
llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
|
|
162
|
-
)
|
|
183
|
+
super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
|
|
163
184
|
self._client = boto3.client(
|
|
164
185
|
"bedrock-runtime",
|
|
165
186
|
region_name=region_name,
|
|
166
187
|
aws_access_key_id=aws_access_key_id,
|
|
167
188
|
aws_secret_access_key=aws_secret_access_key,
|
|
168
189
|
)
|
|
169
|
-
self._model = model
|
|
170
190
|
self._llm: LLM = llm
|
|
171
|
-
self.
|
|
172
|
-
self._top_p = top_p
|
|
173
|
-
self._temperature = temperature
|
|
174
|
-
self._tool_choice = tool_choice
|
|
175
|
-
self._additional_request_fields = additional_request_fields
|
|
191
|
+
self._opts = extra_kwargs
|
|
176
192
|
|
|
177
|
-
async def _run(self) -> None:
|
|
178
193
|
self._tool_call_id: str | None = None
|
|
179
194
|
self._fnc_name: str | None = None
|
|
180
195
|
self._fnc_raw_arguments: str | None = None
|
|
181
196
|
self._text: str = ""
|
|
182
|
-
retryable = True
|
|
183
197
|
|
|
198
|
+
async def _run(self) -> None:
|
|
199
|
+
retryable = True
|
|
184
200
|
try:
|
|
185
|
-
|
|
186
|
-
messages, system_instruction = _build_aws_ctx(self._chat_ctx, id(self))
|
|
187
|
-
messages = _merge_messages(messages)
|
|
188
|
-
|
|
189
|
-
def _get_tool_config() -> dict[str, Any] | None:
|
|
190
|
-
if not (self._fnc_ctx and self._fnc_ctx.ai_functions):
|
|
191
|
-
return None
|
|
192
|
-
|
|
193
|
-
tools = _build_tools(self._fnc_ctx)
|
|
194
|
-
config: dict[str, Any] = {"tools": tools}
|
|
195
|
-
|
|
196
|
-
if isinstance(self._tool_choice, ToolChoice):
|
|
197
|
-
config["toolChoice"] = {"tool": {"name": self._tool_choice.name}}
|
|
198
|
-
elif self._tool_choice == "required":
|
|
199
|
-
config["toolChoice"] = {"any": {}}
|
|
200
|
-
elif self._tool_choice == "auto":
|
|
201
|
-
config["toolChoice"] = {"auto": {}}
|
|
202
|
-
else:
|
|
203
|
-
return None
|
|
204
|
-
|
|
205
|
-
return config
|
|
206
|
-
|
|
207
|
-
tool_config = _get_tool_config()
|
|
208
|
-
if tool_config:
|
|
209
|
-
opts["toolConfig"] = tool_config
|
|
210
|
-
|
|
211
|
-
if self._additional_request_fields:
|
|
212
|
-
opts["additionalModelRequestFields"] = _strip_nones(
|
|
213
|
-
self._additional_request_fields
|
|
214
|
-
)
|
|
215
|
-
if system_instruction:
|
|
216
|
-
opts["system"] = [system_instruction]
|
|
217
|
-
|
|
218
|
-
inference_config = _strip_nones(
|
|
219
|
-
{
|
|
220
|
-
"maxTokens": self._max_output_tokens,
|
|
221
|
-
"temperature": self._temperature,
|
|
222
|
-
"topP": self._top_p,
|
|
223
|
-
}
|
|
224
|
-
)
|
|
225
|
-
response = self._client.converse_stream(
|
|
226
|
-
modelId=self._model,
|
|
227
|
-
messages=messages,
|
|
228
|
-
inferenceConfig=inference_config,
|
|
229
|
-
**_strip_nones(opts),
|
|
230
|
-
) # type: ignore
|
|
231
|
-
|
|
201
|
+
response = self._client.converse_stream(**self._opts) # type: ignore
|
|
232
202
|
request_id = response["ResponseMetadata"]["RequestId"]
|
|
233
203
|
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
|
|
234
204
|
raise APIStatusError(
|
|
@@ -267,84 +237,34 @@ class LLMStream(llm.LLMStream):
|
|
|
267
237
|
elif "contentBlockStop" in chunk:
|
|
268
238
|
if self._text:
|
|
269
239
|
chat_chunk = llm.ChatChunk(
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
llm.Choice(
|
|
273
|
-
delta=llm.ChoiceDelta(content=self._text, role="assistant"),
|
|
274
|
-
index=chunk["contentBlockStop"]["contentBlockIndex"],
|
|
275
|
-
)
|
|
276
|
-
],
|
|
240
|
+
id=request_id,
|
|
241
|
+
delta=llm.ChoiceDelta(content=self._text, role="assistant"),
|
|
277
242
|
)
|
|
278
243
|
self._text = ""
|
|
279
244
|
return chat_chunk
|
|
280
245
|
elif self._tool_call_id:
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
if self._fnc_raw_arguments is None:
|
|
293
|
-
logger.warning("aws bedrock llm: no function arguments in the response")
|
|
294
|
-
return None
|
|
295
|
-
if self._fnc_ctx is None:
|
|
296
|
-
logger.warning(
|
|
297
|
-
"aws bedrock llm: stream tried to run function without function context"
|
|
298
|
-
)
|
|
299
|
-
return None
|
|
300
|
-
|
|
301
|
-
fnc_info = _create_ai_function_info(
|
|
302
|
-
self._fnc_ctx,
|
|
303
|
-
self._tool_call_id,
|
|
304
|
-
self._fnc_name,
|
|
305
|
-
self._fnc_raw_arguments,
|
|
306
|
-
)
|
|
307
|
-
|
|
308
|
-
self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None
|
|
309
|
-
self._function_calls_info.append(fnc_info)
|
|
310
|
-
|
|
311
|
-
return llm.ChatChunk(
|
|
312
|
-
request_id=request_id,
|
|
313
|
-
choices=[
|
|
314
|
-
llm.Choice(
|
|
246
|
+
if self._tool_call_id is None:
|
|
247
|
+
logger.warning("aws bedrock llm: no tool call id in the response")
|
|
248
|
+
return None
|
|
249
|
+
if self._fnc_name is None:
|
|
250
|
+
logger.warning("aws bedrock llm: no function name in the response")
|
|
251
|
+
return None
|
|
252
|
+
if self._fnc_raw_arguments is None:
|
|
253
|
+
logger.warning("aws bedrock llm: no function arguments in the response")
|
|
254
|
+
return None
|
|
255
|
+
chat_chunk = llm.ChatChunk(
|
|
256
|
+
id=request_id,
|
|
315
257
|
delta=llm.ChoiceDelta(
|
|
316
258
|
role="assistant",
|
|
317
|
-
tool_calls=[
|
|
259
|
+
tool_calls=[
|
|
260
|
+
FunctionToolCall(
|
|
261
|
+
arguments=self._fnc_raw_arguments,
|
|
262
|
+
name=self._fnc_name,
|
|
263
|
+
call_id=self._tool_call_id,
|
|
264
|
+
),
|
|
265
|
+
],
|
|
318
266
|
),
|
|
319
|
-
index=chunk["contentBlockStop"]["contentBlockIndex"],
|
|
320
267
|
)
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
def _merge_messages(
|
|
326
|
-
messages: list[dict],
|
|
327
|
-
) -> list[dict]:
|
|
328
|
-
# Anthropic enforces alternating messages
|
|
329
|
-
combined_messages: list[dict] = []
|
|
330
|
-
for m in messages:
|
|
331
|
-
if len(combined_messages) == 0 or m["role"] != combined_messages[-1]["role"]:
|
|
332
|
-
combined_messages.append(m)
|
|
333
|
-
continue
|
|
334
|
-
last_message = combined_messages[-1]
|
|
335
|
-
if not isinstance(last_message["content"], list) or not isinstance(
|
|
336
|
-
m["content"], list
|
|
337
|
-
):
|
|
338
|
-
logger.error("message content is not a list")
|
|
339
|
-
continue
|
|
340
|
-
|
|
341
|
-
last_message["content"].extend(m["content"])
|
|
342
|
-
|
|
343
|
-
if len(combined_messages) == 0 or combined_messages[0]["role"] != "user":
|
|
344
|
-
combined_messages.insert(0, {"role": "user", "content": [{"text": "(empty)"}]})
|
|
345
|
-
|
|
346
|
-
return combined_messages
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
def _strip_nones(d: dict[str, Any]) -> dict[str, Any]:
|
|
350
|
-
return {k: v for k, v in d.items() if v is not None}
|
|
268
|
+
self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None
|
|
269
|
+
return chat_chunk
|
|
270
|
+
return None
|
livekit/plugins/aws/stt.py
CHANGED
|
@@ -14,20 +14,15 @@ from __future__ import annotations
|
|
|
14
14
|
|
|
15
15
|
import asyncio
|
|
16
16
|
from dataclasses import dataclass
|
|
17
|
-
from typing import Optional
|
|
18
17
|
|
|
19
18
|
from amazon_transcribe.client import TranscribeStreamingClient
|
|
20
19
|
from amazon_transcribe.model import Result, TranscriptEvent
|
|
20
|
+
|
|
21
21
|
from livekit import rtc
|
|
22
|
-
from livekit.agents import
|
|
23
|
-
|
|
24
|
-
APIConnectOptions,
|
|
25
|
-
stt,
|
|
26
|
-
utils,
|
|
27
|
-
)
|
|
28
|
-
|
|
29
|
-
from ._utils import _get_aws_credentials
|
|
22
|
+
from livekit.agents import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions, stt, utils
|
|
23
|
+
|
|
30
24
|
from .log import logger
|
|
25
|
+
from .utils import get_aws_credentials
|
|
31
26
|
|
|
32
27
|
|
|
33
28
|
@dataclass
|
|
@@ -36,16 +31,16 @@ class STTOptions:
|
|
|
36
31
|
sample_rate: int
|
|
37
32
|
language: str
|
|
38
33
|
encoding: str
|
|
39
|
-
vocabulary_name:
|
|
40
|
-
session_id:
|
|
41
|
-
vocab_filter_method:
|
|
42
|
-
vocab_filter_name:
|
|
43
|
-
show_speaker_label:
|
|
44
|
-
enable_channel_identification:
|
|
45
|
-
number_of_channels:
|
|
46
|
-
enable_partial_results_stabilization:
|
|
47
|
-
partial_results_stability:
|
|
48
|
-
language_model_name:
|
|
34
|
+
vocabulary_name: str | None
|
|
35
|
+
session_id: str | None
|
|
36
|
+
vocab_filter_method: str | None
|
|
37
|
+
vocab_filter_name: str | None
|
|
38
|
+
show_speaker_label: bool | None
|
|
39
|
+
enable_channel_identification: bool | None
|
|
40
|
+
number_of_channels: int | None
|
|
41
|
+
enable_partial_results_stabilization: bool | None
|
|
42
|
+
partial_results_stability: str | None
|
|
43
|
+
language_model_name: str | None
|
|
49
44
|
|
|
50
45
|
|
|
51
46
|
class STT(stt.STT):
|
|
@@ -58,26 +53,24 @@ class STT(stt.STT):
|
|
|
58
53
|
sample_rate: int = 48000,
|
|
59
54
|
language: str = "en-US",
|
|
60
55
|
encoding: str = "pcm",
|
|
61
|
-
vocabulary_name:
|
|
62
|
-
session_id:
|
|
63
|
-
vocab_filter_method:
|
|
64
|
-
vocab_filter_name:
|
|
65
|
-
show_speaker_label:
|
|
66
|
-
enable_channel_identification:
|
|
67
|
-
number_of_channels:
|
|
68
|
-
enable_partial_results_stabilization:
|
|
69
|
-
partial_results_stability:
|
|
70
|
-
language_model_name:
|
|
56
|
+
vocabulary_name: str | None = None,
|
|
57
|
+
session_id: str | None = None,
|
|
58
|
+
vocab_filter_method: str | None = None,
|
|
59
|
+
vocab_filter_name: str | None = None,
|
|
60
|
+
show_speaker_label: bool | None = None,
|
|
61
|
+
enable_channel_identification: bool | None = None,
|
|
62
|
+
number_of_channels: int | None = None,
|
|
63
|
+
enable_partial_results_stabilization: bool | None = None,
|
|
64
|
+
partial_results_stability: str | None = None,
|
|
65
|
+
language_model_name: str | None = None,
|
|
71
66
|
):
|
|
72
|
-
super().__init__(
|
|
73
|
-
capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
|
|
74
|
-
)
|
|
67
|
+
super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))
|
|
75
68
|
|
|
76
|
-
self._api_key, self._api_secret =
|
|
69
|
+
self._api_key, self._api_secret, self._speech_region = get_aws_credentials(
|
|
77
70
|
api_key, api_secret, speech_region
|
|
78
71
|
)
|
|
79
72
|
self._config = STTOptions(
|
|
80
|
-
speech_region=
|
|
73
|
+
speech_region=self._speech_region,
|
|
81
74
|
language=language,
|
|
82
75
|
sample_rate=sample_rate,
|
|
83
76
|
encoding=encoding,
|
|
@@ -100,16 +93,14 @@ class STT(stt.STT):
|
|
|
100
93
|
language: str | None,
|
|
101
94
|
conn_options: APIConnectOptions,
|
|
102
95
|
) -> stt.SpeechEvent:
|
|
103
|
-
raise NotImplementedError(
|
|
104
|
-
"Amazon Transcribe does not support single frame recognition"
|
|
105
|
-
)
|
|
96
|
+
raise NotImplementedError("Amazon Transcribe does not support single frame recognition")
|
|
106
97
|
|
|
107
98
|
def stream(
|
|
108
99
|
self,
|
|
109
100
|
*,
|
|
110
101
|
language: str | None = None,
|
|
111
102
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
|
112
|
-
) ->
|
|
103
|
+
) -> SpeechStream:
|
|
113
104
|
return SpeechStream(
|
|
114
105
|
stt=self,
|
|
115
106
|
conn_options=conn_options,
|
|
@@ -124,9 +115,7 @@ class SpeechStream(stt.SpeechStream):
|
|
|
124
115
|
opts: STTOptions,
|
|
125
116
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
|
126
117
|
) -> None:
|
|
127
|
-
super().__init__(
|
|
128
|
-
stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
|
|
129
|
-
)
|
|
118
|
+
super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
|
|
130
119
|
self._opts = opts
|
|
131
120
|
self._client = TranscribeStreamingClient(region=self._opts.speech_region)
|
|
132
121
|
|
|
@@ -151,9 +140,7 @@ class SpeechStream(stt.SpeechStream):
|
|
|
151
140
|
async def input_generator():
|
|
152
141
|
async for frame in self._input_ch:
|
|
153
142
|
if isinstance(frame, rtc.AudioFrame):
|
|
154
|
-
await stream.input_stream.send_audio_event(
|
|
155
|
-
audio_chunk=frame.data.tobytes()
|
|
156
|
-
)
|
|
143
|
+
await stream.input_stream.send_audio_event(audio_chunk=frame.data.tobytes())
|
|
157
144
|
await stream.input_stream.end_stream()
|
|
158
145
|
|
|
159
146
|
@utils.log_exceptions(logger=logger)
|
|
@@ -184,9 +171,7 @@ class SpeechStream(stt.SpeechStream):
|
|
|
184
171
|
self._event_ch.send_nowait(
|
|
185
172
|
stt.SpeechEvent(
|
|
186
173
|
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
|
|
187
|
-
alternatives=[
|
|
188
|
-
_streaming_recognize_response_to_speech_data(resp)
|
|
189
|
-
],
|
|
174
|
+
alternatives=[_streaming_recognize_response_to_speech_data(resp)],
|
|
190
175
|
)
|
|
191
176
|
)
|
|
192
177
|
|
|
@@ -194,16 +179,12 @@ class SpeechStream(stt.SpeechStream):
|
|
|
194
179
|
self._event_ch.send_nowait(
|
|
195
180
|
stt.SpeechEvent(
|
|
196
181
|
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
|
|
197
|
-
alternatives=[
|
|
198
|
-
_streaming_recognize_response_to_speech_data(resp)
|
|
199
|
-
],
|
|
182
|
+
alternatives=[_streaming_recognize_response_to_speech_data(resp)],
|
|
200
183
|
)
|
|
201
184
|
)
|
|
202
185
|
|
|
203
186
|
if not resp.is_partial:
|
|
204
|
-
self._event_ch.send_nowait(
|
|
205
|
-
stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
|
|
206
|
-
)
|
|
187
|
+
self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
|
|
207
188
|
|
|
208
189
|
|
|
209
190
|
def _streaming_recognize_response_to_speech_data(resp: Result) -> stt.SpeechData:
|
livekit/plugins/aws/tts.py
CHANGED
|
@@ -14,10 +14,11 @@ from __future__ import annotations
|
|
|
14
14
|
|
|
15
15
|
import asyncio
|
|
16
16
|
from dataclasses import dataclass
|
|
17
|
-
from typing import Any, Callable
|
|
17
|
+
from typing import Any, Callable
|
|
18
18
|
|
|
19
19
|
import aiohttp
|
|
20
20
|
from aiobotocore.session import AioSession, get_session
|
|
21
|
+
|
|
21
22
|
from livekit.agents import (
|
|
22
23
|
APIConnectionError,
|
|
23
24
|
APIConnectOptions,
|
|
@@ -27,8 +28,8 @@ from livekit.agents import (
|
|
|
27
28
|
utils,
|
|
28
29
|
)
|
|
29
30
|
|
|
30
|
-
from ._utils import _get_aws_credentials
|
|
31
31
|
from .models import TTS_LANGUAGE, TTS_SPEECH_ENGINE
|
|
32
|
+
from .utils import get_aws_credentials
|
|
32
33
|
|
|
33
34
|
TTS_NUM_CHANNELS: int = 1
|
|
34
35
|
DEFAULT_SPEECH_ENGINE: TTS_SPEECH_ENGINE = "generative"
|
|
@@ -85,14 +86,14 @@ class TTS(tts.TTS):
|
|
|
85
86
|
num_channels=TTS_NUM_CHANNELS,
|
|
86
87
|
)
|
|
87
88
|
|
|
88
|
-
self._api_key, self._api_secret =
|
|
89
|
+
self._api_key, self._api_secret, self._speech_region = get_aws_credentials(
|
|
89
90
|
api_key, api_secret, speech_region
|
|
90
91
|
)
|
|
91
92
|
|
|
92
93
|
self._opts = _TTSOptions(
|
|
93
94
|
voice=voice,
|
|
94
95
|
speech_engine=speech_engine,
|
|
95
|
-
speech_region=
|
|
96
|
+
speech_region=self._speech_region,
|
|
96
97
|
language=language,
|
|
97
98
|
sample_rate=sample_rate,
|
|
98
99
|
)
|
|
@@ -110,8 +111,8 @@ class TTS(tts.TTS):
|
|
|
110
111
|
self,
|
|
111
112
|
text: str,
|
|
112
113
|
*,
|
|
113
|
-
conn_options:
|
|
114
|
-
) ->
|
|
114
|
+
conn_options: APIConnectOptions | None = None,
|
|
115
|
+
) -> ChunkedStream:
|
|
115
116
|
return ChunkedStream(
|
|
116
117
|
tts=self,
|
|
117
118
|
text=text,
|
|
@@ -127,7 +128,7 @@ class ChunkedStream(tts.ChunkedStream):
|
|
|
127
128
|
*,
|
|
128
129
|
tts: TTS,
|
|
129
130
|
text: str,
|
|
130
|
-
conn_options:
|
|
131
|
+
conn_options: APIConnectOptions | None = None,
|
|
131
132
|
opts: _TTSOptions,
|
|
132
133
|
get_client: Callable[[], Any],
|
|
133
134
|
) -> None:
|
|
livekit/plugins/aws/utils.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any, cast
|
|
6
|
+
|
|
7
|
+
import boto3
|
|
8
|
+
|
|
9
|
+
from livekit.agents import llm
|
|
10
|
+
from livekit.agents.llm import ChatContext, FunctionTool, ImageContent, utils
|
|
11
|
+
|
|
12
|
+
__all__ = ["to_fnc_ctx", "to_chat_ctx", "get_aws_credentials"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_aws_credentials(api_key: str | None, api_secret: str | None, region: str | None):
|
|
16
|
+
region = region or os.environ.get("AWS_DEFAULT_REGION")
|
|
17
|
+
if not region:
|
|
18
|
+
raise ValueError(
|
|
19
|
+
"AWS_DEFAULT_REGION must be set via argument or the AWS_DEFAULT_REGION environment variable."
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
if api_key and api_secret:
|
|
23
|
+
session = boto3.Session(
|
|
24
|
+
aws_access_key_id=api_key,
|
|
25
|
+
aws_secret_access_key=api_secret,
|
|
26
|
+
region_name=region,
|
|
27
|
+
)
|
|
28
|
+
else:
|
|
29
|
+
session = boto3.Session(region_name=region)
|
|
30
|
+
|
|
31
|
+
credentials = session.get_credentials()
|
|
32
|
+
if not credentials or not credentials.access_key or not credentials.secret_key:
|
|
33
|
+
raise ValueError("No valid AWS credentials found.")
|
|
34
|
+
return cast(tuple[str, str, str], (credentials.access_key, credentials.secret_key, region))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def to_fnc_ctx(fncs: list[FunctionTool]) -> list[dict]:
|
|
38
|
+
return [_build_tool_spec(fnc) for fnc in fncs]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def to_chat_ctx(chat_ctx: ChatContext, cache_key: Any) -> tuple[list[dict], dict | None]:
|
|
42
|
+
messages: list[dict] = []
|
|
43
|
+
system_message: dict | None = None
|
|
44
|
+
current_role: str | None = None
|
|
45
|
+
current_content: list[dict] = []
|
|
46
|
+
|
|
47
|
+
for msg in chat_ctx.items:
|
|
48
|
+
if msg.type == "message" and msg.role == "system":
|
|
49
|
+
for content in msg.content:
|
|
50
|
+
if isinstance(content, str):
|
|
51
|
+
system_message = {"text": content}
|
|
52
|
+
continue
|
|
53
|
+
|
|
54
|
+
if msg.type == "message":
|
|
55
|
+
role = "assistant" if msg.role == "assistant" else "user"
|
|
56
|
+
elif msg.type == "function_call":
|
|
57
|
+
role = "assistant"
|
|
58
|
+
elif msg.type == "function_call_output":
|
|
59
|
+
role = "user"
|
|
60
|
+
|
|
61
|
+
# if the effective role changed, finalize the previous turn.
|
|
62
|
+
if role != current_role:
|
|
63
|
+
if current_content and current_role is not None:
|
|
64
|
+
messages.append({"role": current_role, "content": current_content})
|
|
65
|
+
current_content = []
|
|
66
|
+
current_role = role
|
|
67
|
+
|
|
68
|
+
if msg.type == "message":
|
|
69
|
+
for content in msg.content:
|
|
70
|
+
if isinstance(content, str):
|
|
71
|
+
current_content.append({"text": content})
|
|
72
|
+
elif isinstance(content, ImageContent):
|
|
73
|
+
current_content.append(_build_image(content, cache_key))
|
|
74
|
+
elif msg.type == "function_call":
|
|
75
|
+
current_content.append(
|
|
76
|
+
{
|
|
77
|
+
"toolUse": {
|
|
78
|
+
"toolUseId": msg.call_id,
|
|
79
|
+
"name": msg.name,
|
|
80
|
+
"input": json.loads(msg.arguments or "{}"),
|
|
81
|
+
}
|
|
82
|
+
}
|
|
83
|
+
)
|
|
84
|
+
elif msg.type == "function_call_output":
|
|
85
|
+
tool_response = {
|
|
86
|
+
"toolResult": {
|
|
87
|
+
"toolUseId": msg.call_id,
|
|
88
|
+
"content": [],
|
|
89
|
+
"status": "success",
|
|
90
|
+
}
|
|
91
|
+
}
|
|
92
|
+
if isinstance(msg.output, dict):
|
|
93
|
+
tool_response["toolResult"]["content"].append({"json": msg.output})
|
|
94
|
+
elif isinstance(msg.output, str):
|
|
95
|
+
tool_response["toolResult"]["content"].append({"text": msg.output})
|
|
96
|
+
current_content.append(tool_response)
|
|
97
|
+
|
|
98
|
+
# Finalize the last message if there’s any content left
|
|
99
|
+
if current_role is not None and current_content:
|
|
100
|
+
messages.append({"role": current_role, "content": current_content})
|
|
101
|
+
|
|
102
|
+
# Ensure the message list starts with a "user" message
|
|
103
|
+
if not messages or messages[0]["role"] != "user":
|
|
104
|
+
messages.insert(0, {"role": "user", "content": [{"text": "(empty)"}]})
|
|
105
|
+
|
|
106
|
+
return messages, system_message
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _build_tool_spec(fnc: FunctionTool) -> dict:
|
|
110
|
+
fnc = llm.utils.build_legacy_openai_schema(fnc, internally_tagged=True)
|
|
111
|
+
return {
|
|
112
|
+
"toolSpec": _strip_nones(
|
|
113
|
+
{
|
|
114
|
+
"name": fnc["name"],
|
|
115
|
+
"description": fnc["description"] if fnc["description"] else None,
|
|
116
|
+
"inputSchema": {"json": fnc["parameters"] if fnc["parameters"] else {}},
|
|
117
|
+
}
|
|
118
|
+
)
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _build_image(image: ImageContent, cache_key: Any) -> dict:
|
|
123
|
+
img = utils.serialize_image(image)
|
|
124
|
+
if cache_key not in image._cache:
|
|
125
|
+
image._cache[cache_key] = img.data_bytes
|
|
126
|
+
return {
|
|
127
|
+
"image": {
|
|
128
|
+
"format": "jpeg",
|
|
129
|
+
"source": {"bytes": image._cache[cache_key]},
|
|
130
|
+
}
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _strip_nones(d: dict) -> dict:
|
|
135
|
+
return {k: v for k, v in d.items() if v is not None}
|
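livekit/plugins/aws/version.py
CHANGED
|
(The +1 -1 diff body for version.py is not present in this capture.)
{livekit_plugins_aws-0.1.1.dist-info → livekit_plugins_aws-1.0.0.dev4.dist-info}/METADATA
CHANGED
|
@@ -1,38 +1,28 @@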
|
|
|
1
|
-
Metadata-Version: 2.2
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: livekit-plugins-aws
|
|
3
|
-
Version: 0.1.1
|
|
3
|
+
Version: 1.0.0.dev4
|
|
4
4
|
Summary: LiveKit Agents Plugin for services from AWS
|
|
5
|
-
Home-page: https://github.com/livekit/agents
|
|
6
|
-
License: Apache-2.0
|
|
7
5
|
Project-URL: Documentation, https://docs.livekit.io
|
|
8
6
|
Project-URL: Website, https://livekit.io/
|
|
9
7
|
Project-URL: Source, https://github.com/livekit/agents
|
|
10
|
-
|
|
8
|
+
Author-email: LiveKit <support@livekit.io>
|
|
9
|
+
License-Expression: Apache-2.0
|
|
10
|
+
Keywords: audio,aws,livekit,realtime,video,webrtc
|
|
11
11
|
Classifier: Intended Audience :: Developers
|
|
12
12
|
Classifier: License :: OSI Approved :: Apache Software License
|
|
13
|
-
Classifier: Topic :: Multimedia :: Sound/Audio
|
|
14
|
-
Classifier: Topic :: Multimedia :: Video
|
|
15
|
-
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
16
13
|
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
17
15
|
Classifier: Programming Language :: Python :: 3.9
|
|
18
16
|
Classifier: Programming Language :: Python :: 3.10
|
|
19
|
-
Classifier: Programming Language :: Python :: 3 :: Only
|
|
17
|
+
Classifier: Topic :: Multimedia :: Sound/Audio
|
|
18
|
+
Classifier: Topic :: Multimedia :: Video
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
20
|
Requires-Python: >=3.9.0
|
|
21
|
-
Description-Content-Type: text/markdown
|
|
22
|
-
Requires-Dist: livekit-agents[codecs]<1.0.0,>=0.12.16
|
|
23
21
|
Requires-Dist: aiobotocore==2.19.0
|
|
24
|
-
Requires-Dist: boto3==1.36.3
|
|
25
22
|
Requires-Dist: amazon-transcribe>=0.6.2
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
Dynamic: home-page
|
|
30
|
-
Dynamic: keywords
|
|
31
|
-
Dynamic: license
|
|
32
|
-
Dynamic: project-url
|
|
33
|
-
Dynamic: requires-dist
|
|
34
|
-
Dynamic: requires-python
|
|
35
|
-
Dynamic: summary
|
|
23
|
+
Requires-Dist: boto3==1.36.3
|
|
24
|
+
Requires-Dist: livekit-agents>=1.0.0.dev4
|
|
25
|
+
Description-Content-Type: text/markdown
|
|
36
26
|
|
|
37
27
|
# LiveKit Plugins AWS
|
|
38
28
|
|
|
@@ -50,4 +40,4 @@ pip install livekit-plugins-aws
|
|
|
50
40
|
|
|
51
41
|
## Pre-requisites
|
|
52
42
|
|
|
53
|
-
You'll need to specify an AWS Access Key and a Deployment Region. They can be set as environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_DEFAULT_REGION`, respectively.
|
|
43
|
+
You'll need to specify an AWS Access Key and a Deployment Region. They can be set as environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_DEFAULT_REGION`, respectively.
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
livekit/plugins/aws/__init__.py,sha256=Ea-hK7QdutnwdZvvs9K2fiR8RWJqz2JcONxXnV1kXF0,977
|
|
2
|
+
livekit/plugins/aws/llm.py,sha256=Mc910AREP7-FX1yEV1k_rViue_30Gy8qmp42VDAptSE,11011
|
|
3
|
+
livekit/plugins/aws/log.py,sha256=jFief0Xhv0n_F6sp6UFu9VKxs2bXNVGAfYGmEYfR_2Q,66
|
|
4
|
+
livekit/plugins/aws/models.py,sha256=Nf8RFmDulW7h03dG2lERTog3mgDK0TbLvW0eGOncuEE,704
|
|
5
|
+
livekit/plugins/aws/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
|
+
livekit/plugins/aws/stt.py,sha256=hRulbbMXtYqYPuqo359ARWE0fYDy1PzMdpT-h2m1UsY,7575
|
|
7
|
+
livekit/plugins/aws/tts.py,sha256=WA-KtEVF8dq4GZEbPWdY3azdHZRiHFyptesx7kh6Tio,7250
|
|
8
|
+
livekit/plugins/aws/utils.py,sha256=Q62NpoJs3bLerMBlhW22L9xiZHgmtxK3-js7KbL0bkQ,4790
|
|
9
|
+
livekit/plugins/aws/version.py,sha256=koM_bT4QbztrKQ60Gjg7V4oe99CuxgGcpuUtWMOEKqU,605
|
|
10
|
+
livekit_plugins_aws-1.0.0.dev4.dist-info/METADATA,sha256=2GdpNgK-u87T1YW20JWZzR0r8iadqvQLr4NsNnOLUEo,1488
|
|
11
|
+
livekit_plugins_aws-1.0.0.dev4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
12
|
+
livekit_plugins_aws-1.0.0.dev4.dist-info/RECORD,,
|
livekit/plugins/aws/_utils.py
DELETED
|
@@ -1,216 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import base64
|
|
4
|
-
import inspect
|
|
5
|
-
import json
|
|
6
|
-
import os
|
|
7
|
-
from typing import Any, Dict, List, Optional, Tuple, get_args, get_origin
|
|
8
|
-
|
|
9
|
-
import boto3
|
|
10
|
-
from livekit import rtc
|
|
11
|
-
from livekit.agents import llm, utils
|
|
12
|
-
from livekit.agents.llm.function_context import _is_optional_type
|
|
13
|
-
|
|
14
|
-
__all__ = ["_build_aws_ctx", "_build_tools", "_get_aws_credentials"]
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def _get_aws_credentials(
|
|
18
|
-
api_key: Optional[str], api_secret: Optional[str], region: Optional[str]
|
|
19
|
-
):
|
|
20
|
-
region = region or os.environ.get("AWS_DEFAULT_REGION")
|
|
21
|
-
if not region:
|
|
22
|
-
raise ValueError(
|
|
23
|
-
"AWS_DEFAULT_REGION must be set using the argument or by setting the AWS_DEFAULT_REGION environment variable."
|
|
24
|
-
)
|
|
25
|
-
|
|
26
|
-
# If API key and secret are provided, create a session with them
|
|
27
|
-
if api_key and api_secret:
|
|
28
|
-
session = boto3.Session(
|
|
29
|
-
aws_access_key_id=api_key,
|
|
30
|
-
aws_secret_access_key=api_secret,
|
|
31
|
-
region_name=region,
|
|
32
|
-
)
|
|
33
|
-
else:
|
|
34
|
-
session = boto3.Session(region_name=region)
|
|
35
|
-
|
|
36
|
-
credentials = session.get_credentials()
|
|
37
|
-
if not credentials or not credentials.access_key or not credentials.secret_key:
|
|
38
|
-
raise ValueError("No valid AWS credentials found.")
|
|
39
|
-
return credentials.access_key, credentials.secret_key
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
JSON_SCHEMA_TYPE_MAP: Dict[type, str] = {
|
|
43
|
-
str: "string",
|
|
44
|
-
int: "integer",
|
|
45
|
-
float: "number",
|
|
46
|
-
bool: "boolean",
|
|
47
|
-
dict: "object",
|
|
48
|
-
list: "array",
|
|
49
|
-
}
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def _build_parameters(arguments: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
|
53
|
-
properties: Dict[str, dict] = {}
|
|
54
|
-
required: List[str] = []
|
|
55
|
-
|
|
56
|
-
for arg_name, arg_info in arguments.items():
|
|
57
|
-
prop = {}
|
|
58
|
-
if hasattr(arg_info, "description") and arg_info.description:
|
|
59
|
-
prop["description"] = arg_info.description
|
|
60
|
-
|
|
61
|
-
_, py_type = _is_optional_type(arg_info.type)
|
|
62
|
-
origin = get_origin(py_type)
|
|
63
|
-
if origin is list:
|
|
64
|
-
item_type = get_args(py_type)[0]
|
|
65
|
-
if item_type not in JSON_SCHEMA_TYPE_MAP:
|
|
66
|
-
raise ValueError(f"Unsupported type: {item_type}")
|
|
67
|
-
prop["type"] = "array"
|
|
68
|
-
prop["items"] = {"type": JSON_SCHEMA_TYPE_MAP[item_type]}
|
|
69
|
-
|
|
70
|
-
if hasattr(arg_info, "choices") and arg_info.choices:
|
|
71
|
-
prop["items"]["enum"] = list(arg_info.choices)
|
|
72
|
-
else:
|
|
73
|
-
if py_type not in JSON_SCHEMA_TYPE_MAP:
|
|
74
|
-
raise ValueError(f"Unsupported type: {py_type}")
|
|
75
|
-
|
|
76
|
-
prop["type"] = JSON_SCHEMA_TYPE_MAP[py_type]
|
|
77
|
-
|
|
78
|
-
if arg_info.choices:
|
|
79
|
-
prop["enum"] = list(arg_info.choices)
|
|
80
|
-
|
|
81
|
-
properties[arg_name] = prop
|
|
82
|
-
|
|
83
|
-
if arg_info.default is inspect.Parameter.empty:
|
|
84
|
-
required.append(arg_name)
|
|
85
|
-
|
|
86
|
-
if properties:
|
|
87
|
-
parameters = {"json": {"type": "object", "properties": properties}}
|
|
88
|
-
if required:
|
|
89
|
-
parameters["json"]["required"] = required
|
|
90
|
-
|
|
91
|
-
return parameters
|
|
92
|
-
|
|
93
|
-
return None
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def _build_tools(fnc_ctx: Any) -> List[dict]:
|
|
97
|
-
tools: List[dict] = []
|
|
98
|
-
for fnc_info in fnc_ctx.ai_functions.values():
|
|
99
|
-
parameters = _build_parameters(fnc_info.arguments)
|
|
100
|
-
|
|
101
|
-
func_decl = {
|
|
102
|
-
"toolSpec": {
|
|
103
|
-
"name": fnc_info.name,
|
|
104
|
-
"description": fnc_info.description,
|
|
105
|
-
"inputSchema": parameters
|
|
106
|
-
if parameters
|
|
107
|
-
else {"json": {"type": "object", "properties": {}}},
|
|
108
|
-
}
|
|
109
|
-
}
|
|
110
|
-
|
|
111
|
-
tools.append(func_decl)
|
|
112
|
-
return tools
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
def _build_image(image: llm.ChatImage, cache_key: Any) -> dict:
|
|
116
|
-
if isinstance(image.image, str):
|
|
117
|
-
if image.image.startswith("data:image/jpeg;base64,"):
|
|
118
|
-
base64_data = image.image.split(",", 1)[1]
|
|
119
|
-
try:
|
|
120
|
-
image_bytes = base64.b64decode(base64_data)
|
|
121
|
-
except Exception as e:
|
|
122
|
-
raise ValueError("Invalid base64 data in image URL") from e
|
|
123
|
-
|
|
124
|
-
return {"image": {"format": "jpeg", "source": {"bytes": image_bytes}}}
|
|
125
|
-
else:
|
|
126
|
-
return {"image": {"format": "jpeg", "source": {"uri": image.image}}}
|
|
127
|
-
|
|
128
|
-
elif isinstance(image.image, rtc.VideoFrame):
|
|
129
|
-
if cache_key not in image._cache:
|
|
130
|
-
opts = utils.images.EncodeOptions()
|
|
131
|
-
if image.inference_width and image.inference_height:
|
|
132
|
-
opts.resize_options = utils.images.ResizeOptions(
|
|
133
|
-
width=image.inference_width,
|
|
134
|
-
height=image.inference_height,
|
|
135
|
-
strategy="scale_aspect_fit",
|
|
136
|
-
)
|
|
137
|
-
image._cache[cache_key] = utils.images.encode(image.image, opts)
|
|
138
|
-
|
|
139
|
-
return {
|
|
140
|
-
"image": {
|
|
141
|
-
"format": "jpeg",
|
|
142
|
-
"source": {
|
|
143
|
-
"bytes": image._cache[cache_key],
|
|
144
|
-
},
|
|
145
|
-
}
|
|
146
|
-
}
|
|
147
|
-
raise ValueError(f"Unsupported image type: {type(image.image)}")
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
def _build_aws_ctx(
|
|
151
|
-
chat_ctx: llm.ChatContext, cache_key: Any
|
|
152
|
-
) -> Tuple[List[dict], Optional[dict]]:
|
|
153
|
-
messages: List[dict] = []
|
|
154
|
-
system: Optional[dict] = None
|
|
155
|
-
current_role: Optional[str] = None
|
|
156
|
-
current_content: List[dict] = []
|
|
157
|
-
|
|
158
|
-
for msg in chat_ctx.messages:
|
|
159
|
-
if msg.role == "system":
|
|
160
|
-
if isinstance(msg.content, str):
|
|
161
|
-
system = {"text": msg.content}
|
|
162
|
-
continue
|
|
163
|
-
|
|
164
|
-
if msg.role == "assistant":
|
|
165
|
-
role = "assistant"
|
|
166
|
-
else:
|
|
167
|
-
role = "user"
|
|
168
|
-
|
|
169
|
-
if role != current_role:
|
|
170
|
-
if current_role is not None and current_content:
|
|
171
|
-
messages.append({"role": current_role, "content": current_content})
|
|
172
|
-
current_role = role
|
|
173
|
-
current_content = []
|
|
174
|
-
|
|
175
|
-
if msg.tool_calls:
|
|
176
|
-
for fnc in msg.tool_calls:
|
|
177
|
-
current_content.append(
|
|
178
|
-
{
|
|
179
|
-
"toolUse": {
|
|
180
|
-
"toolUseId": fnc.tool_call_id,
|
|
181
|
-
"name": fnc.function_info.name,
|
|
182
|
-
"input": fnc.arguments,
|
|
183
|
-
}
|
|
184
|
-
}
|
|
185
|
-
)
|
|
186
|
-
|
|
187
|
-
if msg.role == "tool":
|
|
188
|
-
tool_response: dict = {
|
|
189
|
-
"toolResult": {
|
|
190
|
-
"toolUseId": msg.tool_call_id,
|
|
191
|
-
"content": [],
|
|
192
|
-
"status": "success",
|
|
193
|
-
}
|
|
194
|
-
}
|
|
195
|
-
if isinstance(msg.content, dict):
|
|
196
|
-
tool_response["toolResult"]["content"].append({"json": msg.content})
|
|
197
|
-
elif isinstance(msg.content, str):
|
|
198
|
-
tool_response["toolResult"]["content"].append({"text": msg.content})
|
|
199
|
-
current_content.append(tool_response)
|
|
200
|
-
else:
|
|
201
|
-
if msg.content:
|
|
202
|
-
if isinstance(msg.content, str):
|
|
203
|
-
current_content.append({"text": msg.content})
|
|
204
|
-
elif isinstance(msg.content, dict):
|
|
205
|
-
current_content.append({"text": json.dumps(msg.content)})
|
|
206
|
-
elif isinstance(msg.content, list):
|
|
207
|
-
for item in msg.content:
|
|
208
|
-
if isinstance(item, str):
|
|
209
|
-
current_content.append({"text": item})
|
|
210
|
-
elif isinstance(item, llm.ChatImage):
|
|
211
|
-
current_content.append(_build_image(item, cache_key))
|
|
212
|
-
|
|
213
|
-
if current_role is not None and current_content:
|
|
214
|
-
messages.append({"role": current_role, "content": current_content})
|
|
215
|
-
|
|
216
|
-
return messages, system
|
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
livekit/plugins/aws/__init__.py,sha256=Ea-hK7QdutnwdZvvs9K2fiR8RWJqz2JcONxXnV1kXF0,977
|
|
2
|
-
livekit/plugins/aws/_utils.py,sha256=iuDuQpPta4wLtgW1Wc2rHspZWoa7KZI76tujQIPY898,7411
|
|
3
|
-
livekit/plugins/aws/llm.py,sha256=yUAiBCtb2jRB1_S9BNrILTMmDffvKOpDod802kYnPVM,13527
|
|
4
|
-
livekit/plugins/aws/log.py,sha256=jFief0Xhv0n_F6sp6UFu9VKxs2bXNVGAfYGmEYfR_2Q,66
|
|
5
|
-
livekit/plugins/aws/models.py,sha256=Nf8RFmDulW7h03dG2lERTog3mgDK0TbLvW0eGOncuEE,704
|
|
6
|
-
livekit/plugins/aws/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
|
-
livekit/plugins/aws/stt.py,sha256=eH7gKtdCjwki20Th6PrCsjjtH-zjXa8ZWu-cu_KaT80,7935
|
|
8
|
-
livekit/plugins/aws/tts.py,sha256=m2Z6VXyWsJebqzGTDqE39KvgkBgdQkZ731fuIjbszAY,7243
|
|
9
|
-
livekit/plugins/aws/version.py,sha256=3-nEcobvIJfZdV4yNIRuYpAGQ3svREnYIv2ivxoIZcQ,600
|
|
10
|
-
livekit_plugins_aws-0.1.1.dist-info/METADATA,sha256=9rnNMyDhecj1fQIbGcxvGo_I0cg7d_lI2xEf-tBMQfc,1702
|
|
11
|
-
livekit_plugins_aws-0.1.1.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
|
|
12
|
-
livekit_plugins_aws-0.1.1.dist-info/top_level.txt,sha256=OoDok3xUmXbZRvOrfvvXB-Juu4DX79dlq188E19YHoo,8
|
|
13
|
-
livekit_plugins_aws-0.1.1.dist-info/RECORD,,
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
livekit
|