livekit-plugins-aws 1.0.0rc6__py3-none-any.whl → 1.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit/plugins/aws/__init__.py +47 -7
- livekit/plugins/aws/experimental/realtime/__init__.py +11 -0
- livekit/plugins/aws/experimental/realtime/events.py +545 -0
- livekit/plugins/aws/experimental/realtime/pretty_printer.py +49 -0
- livekit/plugins/aws/experimental/realtime/realtime_model.py +2106 -0
- livekit/plugins/aws/experimental/realtime/turn_tracker.py +171 -0
- livekit/plugins/aws/experimental/realtime/types.py +38 -0
- livekit/plugins/aws/llm.py +109 -71
- livekit/plugins/aws/log.py +4 -0
- livekit/plugins/aws/models.py +4 -3
- livekit/plugins/aws/stt.py +214 -71
- livekit/plugins/aws/tts.py +96 -116
- livekit/plugins/aws/utils.py +29 -125
- livekit/plugins/aws/version.py +1 -1
- livekit_plugins_aws-1.3.9.dist-info/METADATA +385 -0
- livekit_plugins_aws-1.3.9.dist-info/RECORD +18 -0
- {livekit_plugins_aws-1.0.0rc6.dist-info → livekit_plugins_aws-1.3.9.dist-info}/WHEEL +1 -1
- livekit_plugins_aws-1.0.0rc6.dist-info/METADATA +0 -43
- livekit_plugins_aws-1.0.0rc6.dist-info/RECORD +0 -12
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import datetime
|
|
4
|
+
import enum
|
|
5
|
+
import uuid
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any, Callable
|
|
8
|
+
|
|
9
|
+
from livekit.agents import llm
|
|
10
|
+
|
|
11
|
+
from ...log import logger
|
|
12
|
+
|
|
13
|
+
# Nova Sonic's barge-in detection signal (raw content without newline)
|
|
14
|
+
BARGE_IN_CONTENT = '{ "interrupted" : true }'
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class _Phase(enum.Enum):
|
|
18
|
+
IDLE = 0 # waiting for the USER to begin speaking
|
|
19
|
+
USER_SPEAKING = 1 # still receiving USER text+audio blocks
|
|
20
|
+
USER_FINISHED = 2 # first ASSISTANT speculative block observed
|
|
21
|
+
ASSISTANT_RESPONDING = 3 # ASSISTANT audio/text streaming
|
|
22
|
+
DONE = 4 # assistant audio ended (END_TURN) or barge-in (INTERRUPTED)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# note: b/c user ASR text is transcribed server-side, a single turn constitutes
|
|
26
|
+
# both the user and agent's speech
|
|
27
|
+
@dataclass
class _Turn:
    """Mutable record of one turn: the user's transcript plus per-turn bookkeeping.

    Only ``turn_id`` must be supplied; every other field has a default.
    """

    # Sequential number assigned by whoever creates the turn.
    turn_id: int
    # Random UUID4 string generated per turn.
    input_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Naive UTC creation timestamp. ``datetime.utcnow`` is deprecated since Python 3.12,
    # so take an aware "now" in UTC and drop the tzinfo to keep the value naive as before.
    created: datetime.datetime = field(
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc).replace(
            tzinfo=None
        )
    )
    # Partial user transcript fragments, in arrival order.
    transcript: list[str] = field(default_factory=list)

    # Current lifecycle stage of the turn.
    phase: _Phase = _Phase.IDLE
    # One-shot flags recording which events were already emitted for this turn.
    ev_input_started: bool = False
    ev_input_stopped: bool = False
    ev_trans_completed: bool = False
    ev_generation_sent: bool = False

    def add_partial_text(self, text: str) -> None:
        """Append one partial transcript fragment to the turn."""
        self.transcript.append(text)

    @property
    def curr_transcript(self) -> str:
        """Return all fragments received so far joined by single spaces."""
        return " ".join(self.transcript)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class _TurnTracker:
    """Translate raw output events into per-turn session events.

    Every event given to ``feed`` is labelled by ``_classify`` and applied to the
    current ``_Turn``, which is created lazily. Speech and transcription events are
    sent through ``emit_fn``; the start of a generation is signalled by calling
    ``emit_generation_fn``. The ``_maybe_*`` helpers act at most once per turn.
    """

    def __init__(
        self,
        emit_fn: Callable[[str, Any], None],
        emit_generation_fn: Callable[[], None],
    ):
        """
        Args:
            emit_fn: invoked as ``emit_fn(event_name, payload)`` for each emitted event.
            emit_generation_fn: invoked without arguments, at most once per turn,
                when the agent side of the turn begins.
        """
        self._emit = emit_fn
        self._emit_generation_fn = emit_generation_fn
        # Monotonic counter used to number turns; the active turn is created on demand.
        self._turn_idx = 0
        self._curr_turn: _Turn | None = None

    # --------------------------------------------------------
    # PUBLIC ENTRY POINT
    # --------------------------------------------------------
    def feed(self, event: dict) -> None:
        """Apply one raw output event to the current turn and emit what follows from it."""
        current = self._ensure_turn()
        label = _classify(event)

        if label == "USER_TEXT_PARTIAL":
            partial = event["event"]["textOutput"]["content"]
            current.add_partial_text(partial)
            self._maybe_emit_input_started(current)
            self._emit_transcript_updated(current)
            # note: input-stopped cannot be emitted here because nothing in this
            # event tells us whether the user has finished speaking
        elif label in ("TOOL_OUTPUT_CONTENT_START", "ASSISTANT_SPEC_START"):
            # Original author's note: these will always be correlated b/c
            # generate_reply() is a stub. The user's ASR text ends when the agent's
            # speculative text begins, which marks the beginning of the agent's turn.
            # All three calls are "maybe" helpers because the agent can chain
            # multiple tool calls within one turn.
            self._maybe_emit_input_stopped(current)
            self._maybe_emit_transcript_completed(current)
            self._maybe_emit_generation_created(current)
        elif label == "BARGE_IN":
            logger.debug(f"BARGE-IN DETECTED IN TURN TRACKER: {current}")
            # announce new user speech right away so interruptions feel snappier
            self._emit("input_speech_started", llm.InputSpeechStartedEvent())
            current.phase = _Phase.DONE
        elif label == "ASSISTANT_AUDIO_END":
            stop_reason = event["event"]["contentEnd"]["stopReason"]
            if stop_reason == "END_TURN":
                current.phase = _Phase.DONE

        # A finished turn is dropped; the next event starts a fresh one.
        if current.phase is _Phase.DONE:
            self._curr_turn = None

    def _ensure_turn(self) -> _Turn:
        """Return the active turn, creating and numbering a new one if none exists."""
        active = self._curr_turn
        if active is None:
            self._turn_idx += 1
            active = _Turn(turn_id=self._turn_idx)
            self._curr_turn = active
        return active

    def _maybe_emit_input_started(self, turn: _Turn) -> None:
        """Emit ``input_speech_started`` the first time it is requested for ``turn``."""
        if turn.ev_input_started:
            return
        turn.ev_input_started = True
        self._emit("input_speech_started", llm.InputSpeechStartedEvent())
        turn.phase = _Phase.USER_SPEAKING

    def _maybe_emit_input_stopped(self, turn: _Turn) -> None:
        """Emit ``input_speech_stopped`` the first time it is requested for ``turn``."""
        if turn.ev_input_stopped:
            return
        turn.ev_input_stopped = True
        stopped = llm.InputSpeechStoppedEvent(user_transcription_enabled=True)
        self._emit("input_speech_stopped", stopped)
        turn.phase = _Phase.USER_FINISHED

    def _emit_transcript_updated(self, turn: _Turn) -> None:
        """Emit the transcript accumulated so far as a non-final transcription event."""
        interim = llm.InputTranscriptionCompleted(
            item_id=turn.input_id,
            transcript=turn.curr_transcript,
            is_final=False,
        )
        self._emit("input_audio_transcription_completed", interim)

    def _maybe_emit_transcript_completed(self, turn: _Turn) -> None:
        """Emit the final transcription event the first time it is requested for ``turn``."""
        if turn.ev_trans_completed:
            return
        turn.ev_trans_completed = True
        # Q: does input_id need to match /w the _ResponseGeneration.input_id?
        final = llm.InputTranscriptionCompleted(
            item_id=turn.input_id,
            transcript=turn.curr_transcript,
            is_final=True,
        )
        self._emit("input_audio_transcription_completed", final)

    def _maybe_emit_generation_created(self, turn: _Turn) -> None:
        """Call ``emit_generation_fn`` once per turn; later requests are only logged."""
        if turn.ev_generation_sent:
            logger.debug(f"[GEN] TurnTracker SKIPPED - already sent for turn_id={turn.turn_id}")
            return
        turn.ev_generation_sent = True
        logger.debug(
            f"[GEN] TurnTracker calling emit_generation_fn() for turn_id={turn.turn_id}"
        )
        self._emit_generation_fn()
        turn.phase = _Phase.ASSISTANT_RESPONDING
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _classify(ev: dict) -> str:
|
|
151
|
+
e = ev.get("event", {})
|
|
152
|
+
if "textOutput" in e and e["textOutput"]["role"] == "USER":
|
|
153
|
+
return "USER_TEXT_PARTIAL"
|
|
154
|
+
|
|
155
|
+
if "contentStart" in e and e["contentStart"]["type"] == "TOOL":
|
|
156
|
+
return "TOOL_OUTPUT_CONTENT_START"
|
|
157
|
+
|
|
158
|
+
if "contentStart" in e and e["contentStart"]["role"] == "ASSISTANT":
|
|
159
|
+
add = e["contentStart"].get("additionalModelFields", "")
|
|
160
|
+
if "SPECULATIVE" in add:
|
|
161
|
+
return "ASSISTANT_SPEC_START"
|
|
162
|
+
|
|
163
|
+
if "textOutput" in e and e["textOutput"]["content"] == BARGE_IN_CONTENT:
|
|
164
|
+
return "BARGE_IN"
|
|
165
|
+
|
|
166
|
+
# note: there cannot be any audio events for the user in the output event loop
|
|
167
|
+
# therefore, we know that the audio event must be for the assistant
|
|
168
|
+
if "contentEnd" in e and e["contentEnd"]["type"] == "AUDIO":
|
|
169
|
+
return "ASSISTANT_AUDIO_END"
|
|
170
|
+
|
|
171
|
+
return ""
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
TURN_DETECTION = Literal["HIGH", "MEDIUM", "LOW"]
|
|
4
|
+
MODALITIES = Literal["audio", "mixed"]
|
|
5
|
+
REALTIME_MODELS = Literal["amazon.nova-sonic-v1:0", "amazon.nova-2-sonic-v1:0"]
|
|
6
|
+
|
|
7
|
+
SONIC1_VOICES = Literal[
|
|
8
|
+
"matthew", # English (US) - Masculine
|
|
9
|
+
"tiffany", # English (US) - Feminine
|
|
10
|
+
"amy", # English (GB) - Feminine
|
|
11
|
+
"lupe", # Spanish - Feminine
|
|
12
|
+
"carlos", # Spanish - Masculine
|
|
13
|
+
"ambre", # French - Feminine
|
|
14
|
+
"florian", # French - Masculine
|
|
15
|
+
"greta", # German - Feminine
|
|
16
|
+
"lennart", # German - Masculine
|
|
17
|
+
"beatrice", # Italian - Feminine
|
|
18
|
+
"lorenzo", # Italian - Masculine
|
|
19
|
+
]
|
|
20
|
+
|
|
21
|
+
SONIC2_VOICES = Literal[
|
|
22
|
+
"matthew", # English (US) - Masculine - Polyglot
|
|
23
|
+
"tiffany", # English (US) - Feminine - Polyglot
|
|
24
|
+
"amy", # English (GB) - Feminine
|
|
25
|
+
"olivia", # English (US) - Feminine
|
|
26
|
+
"lupe", # Spanish - Feminine
|
|
27
|
+
"carlos", # Spanish - Masculine
|
|
28
|
+
"ambre", # French - Feminine
|
|
29
|
+
"florian", # French - Masculine
|
|
30
|
+
"tina", # German - Feminine
|
|
31
|
+
"lennart", # German - Masculine
|
|
32
|
+
"beatrice", # Italian - Feminine
|
|
33
|
+
"lorenzo", # Italian - Masculine
|
|
34
|
+
"carolina", # Portuguese (Brazilian) - Feminine
|
|
35
|
+
"leo", # Portuguese (Brazilian) - Masculine
|
|
36
|
+
"arjun", # Hindi - Masculine
|
|
37
|
+
"kiara", # Hindi - Feminine
|
|
38
|
+
]
|
livekit/plugins/aws/llm.py
CHANGED
|
@@ -14,15 +14,21 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
|
|
17
|
-
import asyncio
|
|
18
17
|
import os
|
|
19
18
|
from dataclasses import dataclass
|
|
20
|
-
from typing import Any,
|
|
19
|
+
from typing import Any, cast
|
|
21
20
|
|
|
22
|
-
import
|
|
21
|
+
import aioboto3 # type: ignore
|
|
22
|
+
from botocore.config import Config # type: ignore
|
|
23
23
|
|
|
24
24
|
from livekit.agents import APIConnectionError, APIStatusError, llm
|
|
25
|
-
from livekit.agents.llm import
|
|
25
|
+
from livekit.agents.llm import (
|
|
26
|
+
ChatContext,
|
|
27
|
+
FunctionTool,
|
|
28
|
+
FunctionToolCall,
|
|
29
|
+
RawFunctionTool,
|
|
30
|
+
ToolChoice,
|
|
31
|
+
)
|
|
26
32
|
from livekit.agents.types import (
|
|
27
33
|
DEFAULT_API_CONNECT_OPTIONS,
|
|
28
34
|
NOT_GIVEN,
|
|
@@ -32,34 +38,39 @@ from livekit.agents.types import (
|
|
|
32
38
|
from livekit.agents.utils import is_given
|
|
33
39
|
|
|
34
40
|
from .log import logger
|
|
35
|
-
from .utils import
|
|
41
|
+
from .utils import to_fnc_ctx
|
|
36
42
|
|
|
37
|
-
|
|
43
|
+
DEFAULT_TEXT_MODEL = "amazon.nova-2-lite-v1:0"
|
|
38
44
|
|
|
39
45
|
|
|
40
46
|
@dataclass
|
|
41
47
|
class _LLMOptions:
|
|
42
|
-
model: str
|
|
48
|
+
model: str
|
|
43
49
|
temperature: NotGivenOr[float]
|
|
44
50
|
tool_choice: NotGivenOr[ToolChoice]
|
|
45
51
|
max_output_tokens: NotGivenOr[int]
|
|
46
52
|
top_p: NotGivenOr[float]
|
|
47
53
|
additional_request_fields: NotGivenOr[dict[str, Any]]
|
|
54
|
+
cache_system: bool
|
|
55
|
+
cache_tools: bool
|
|
48
56
|
|
|
49
57
|
|
|
50
58
|
class LLM(llm.LLM):
|
|
51
59
|
def __init__(
|
|
52
60
|
self,
|
|
53
61
|
*,
|
|
54
|
-
model: NotGivenOr[str
|
|
62
|
+
model: NotGivenOr[str] = DEFAULT_TEXT_MODEL,
|
|
55
63
|
api_key: NotGivenOr[str] = NOT_GIVEN,
|
|
56
64
|
api_secret: NotGivenOr[str] = NOT_GIVEN,
|
|
57
|
-
region: NotGivenOr[str] =
|
|
65
|
+
region: NotGivenOr[str] = "us-east-1",
|
|
58
66
|
temperature: NotGivenOr[float] = NOT_GIVEN,
|
|
59
67
|
max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
|
|
60
68
|
top_p: NotGivenOr[float] = NOT_GIVEN,
|
|
61
69
|
tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
|
|
62
70
|
additional_request_fields: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
|
|
71
|
+
cache_system: bool = False,
|
|
72
|
+
cache_tools: bool = False,
|
|
73
|
+
session: aioboto3.Session | None = None,
|
|
63
74
|
) -> None:
|
|
64
75
|
"""
|
|
65
76
|
Create a new instance of AWS Bedrock LLM.
|
|
@@ -70,7 +81,8 @@ class LLM(llm.LLM):
|
|
|
70
81
|
See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html for more details on the AWS Bedrock Runtime API.
|
|
71
82
|
|
|
72
83
|
Args:
|
|
73
|
-
model (
|
|
84
|
+
model (str, optional): model or inference profile arn to use(https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-use.html).
|
|
85
|
+
Defaults to 'amazon.nova-2-lite-v1:0'.
|
|
74
86
|
api_key(str, optional): AWS access key id.
|
|
75
87
|
api_secret(str, optional): AWS secret access key
|
|
76
88
|
region (str, optional): The region to use for AWS API requests. Defaults value is "us-east-1".
|
|
@@ -79,36 +91,57 @@ class LLM(llm.LLM):
|
|
|
79
91
|
top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
|
|
80
92
|
tool_choice (ToolChoice, optional): Specifies whether to use tools during response generation. Defaults to "auto".
|
|
81
93
|
additional_request_fields (dict[str, Any], optional): Additional request fields to send to the AWS Bedrock Converse API. Defaults to None.
|
|
94
|
+
cache_system (bool, optional): Caches system messages to reduce token usage. Defaults to False.
|
|
95
|
+
cache_tools (bool, optional): Caches tool definitions to reduce token usage. Defaults to False.
|
|
96
|
+
session (aioboto3.Session, optional): Optional aioboto3 session to use.
|
|
82
97
|
""" # noqa: E501
|
|
83
98
|
super().__init__()
|
|
84
|
-
|
|
85
|
-
|
|
99
|
+
|
|
100
|
+
self._session = session or aioboto3.Session(
|
|
101
|
+
aws_access_key_id=api_key if is_given(api_key) else None,
|
|
102
|
+
aws_secret_access_key=api_secret if is_given(api_secret) else None,
|
|
103
|
+
region_name=region if is_given(region) else None,
|
|
86
104
|
)
|
|
87
105
|
|
|
88
|
-
|
|
89
|
-
|
|
106
|
+
bedrock_model = (
|
|
107
|
+
model if is_given(model) else os.environ.get("BEDROCK_INFERENCE_PROFILE_ARN")
|
|
108
|
+
)
|
|
109
|
+
if not bedrock_model:
|
|
90
110
|
raise ValueError(
|
|
91
111
|
"model or inference profile arn must be set using the argument or by setting the BEDROCK_INFERENCE_PROFILE_ARN environment variable." # noqa: E501
|
|
92
112
|
)
|
|
93
113
|
self._opts = _LLMOptions(
|
|
94
|
-
model=
|
|
114
|
+
model=bedrock_model,
|
|
95
115
|
temperature=temperature,
|
|
96
116
|
tool_choice=tool_choice,
|
|
97
117
|
max_output_tokens=max_output_tokens,
|
|
98
118
|
top_p=top_p,
|
|
99
119
|
additional_request_fields=additional_request_fields,
|
|
120
|
+
cache_system=cache_system,
|
|
121
|
+
cache_tools=cache_tools,
|
|
100
122
|
)
|
|
101
123
|
|
|
124
|
+
@property
|
|
125
|
+
def model(self) -> str:
|
|
126
|
+
return self._opts.model
|
|
127
|
+
|
|
128
|
+
@property
|
|
129
|
+
def provider(self) -> str:
|
|
130
|
+
return "AWS Bedrock"
|
|
131
|
+
|
|
102
132
|
def chat(
|
|
103
133
|
self,
|
|
104
134
|
*,
|
|
105
135
|
chat_ctx: ChatContext,
|
|
106
|
-
tools: list[FunctionTool] | None = None,
|
|
136
|
+
tools: list[FunctionTool | RawFunctionTool] | None = None,
|
|
137
|
+
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
|
|
107
138
|
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
|
108
|
-
temperature: NotGivenOr[float] = NOT_GIVEN,
|
|
109
139
|
tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
|
|
140
|
+
temperature: NotGivenOr[float] = NOT_GIVEN,
|
|
141
|
+
extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
|
|
110
142
|
) -> LLMStream:
|
|
111
|
-
opts = {}
|
|
143
|
+
opts: dict[str, Any] = {}
|
|
144
|
+
extra_kwargs = extra_kwargs if is_given(extra_kwargs) else {}
|
|
112
145
|
|
|
113
146
|
if is_given(self._opts.model):
|
|
114
147
|
opts["modelId"] = self._opts.model
|
|
@@ -119,8 +152,14 @@ class LLM(llm.LLM):
|
|
|
119
152
|
if not tools:
|
|
120
153
|
return None
|
|
121
154
|
|
|
122
|
-
|
|
123
|
-
|
|
155
|
+
tools_list = to_fnc_ctx(tools)
|
|
156
|
+
if self._opts.cache_tools:
|
|
157
|
+
tools_list.append({"cachePoint": {"type": "default"}})
|
|
158
|
+
|
|
159
|
+
tool_config: dict[str, Any] = {"tools": tools_list}
|
|
160
|
+
tool_choice = (
|
|
161
|
+
cast(ToolChoice, tool_choice) if is_given(tool_choice) else self._opts.tool_choice
|
|
162
|
+
)
|
|
124
163
|
if is_given(tool_choice):
|
|
125
164
|
if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
|
|
126
165
|
tool_config["toolChoice"] = {"tool": {"name": tool_choice["function"]["name"]}}
|
|
@@ -136,12 +175,17 @@ class LLM(llm.LLM):
|
|
|
136
175
|
tool_config = _get_tool_config()
|
|
137
176
|
if tool_config:
|
|
138
177
|
opts["toolConfig"] = tool_config
|
|
139
|
-
messages,
|
|
178
|
+
messages, extra_data = chat_ctx.to_provider_format(format="aws")
|
|
140
179
|
opts["messages"] = messages
|
|
141
|
-
if
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
180
|
+
if extra_data.system_messages:
|
|
181
|
+
system_messages: list[dict[str, str | dict]] = [
|
|
182
|
+
{"text": content} for content in extra_data.system_messages
|
|
183
|
+
]
|
|
184
|
+
if self._opts.cache_system:
|
|
185
|
+
system_messages.append({"cachePoint": {"type": "default"}})
|
|
186
|
+
opts["system"] = system_messages
|
|
187
|
+
|
|
188
|
+
inference_config: dict[str, Any] = {}
|
|
145
189
|
if is_given(self._opts.max_output_tokens):
|
|
146
190
|
inference_config["maxTokens"] = self._opts.max_output_tokens
|
|
147
191
|
temperature = temperature if is_given(temperature) else self._opts.temperature
|
|
@@ -156,11 +200,9 @@ class LLM(llm.LLM):
|
|
|
156
200
|
|
|
157
201
|
return LLMStream(
|
|
158
202
|
self,
|
|
159
|
-
aws_access_key_id=self._api_key,
|
|
160
|
-
aws_secret_access_key=self._api_secret,
|
|
161
|
-
region_name=self._region,
|
|
162
203
|
chat_ctx=chat_ctx,
|
|
163
|
-
tools=tools,
|
|
204
|
+
tools=tools or [],
|
|
205
|
+
session=self._session,
|
|
164
206
|
conn_options=conn_options,
|
|
165
207
|
extra_kwargs=opts,
|
|
166
208
|
)
|
|
@@ -171,24 +213,16 @@ class LLMStream(llm.LLMStream):
|
|
|
171
213
|
self,
|
|
172
214
|
llm: LLM,
|
|
173
215
|
*,
|
|
174
|
-
aws_access_key_id: str,
|
|
175
|
-
aws_secret_access_key: str,
|
|
176
|
-
region_name: str,
|
|
177
216
|
chat_ctx: ChatContext,
|
|
217
|
+
session: aioboto3.Session,
|
|
178
218
|
conn_options: APIConnectOptions,
|
|
179
|
-
tools: list[FunctionTool
|
|
219
|
+
tools: list[FunctionTool | RawFunctionTool],
|
|
180
220
|
extra_kwargs: dict[str, Any],
|
|
181
221
|
) -> None:
|
|
182
222
|
super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
|
|
183
|
-
self._client = boto3.client(
|
|
184
|
-
"bedrock-runtime",
|
|
185
|
-
region_name=region_name,
|
|
186
|
-
aws_access_key_id=aws_access_key_id,
|
|
187
|
-
aws_secret_access_key=aws_secret_access_key,
|
|
188
|
-
)
|
|
189
223
|
self._llm: LLM = llm
|
|
190
224
|
self._opts = extra_kwargs
|
|
191
|
-
|
|
225
|
+
self._session = session
|
|
192
226
|
self._tool_call_id: str | None = None
|
|
193
227
|
self._fnc_name: str | None = None
|
|
194
228
|
self._fnc_raw_arguments: str | None = None
|
|
@@ -197,23 +231,22 @@ class LLMStream(llm.LLMStream):
|
|
|
197
231
|
async def _run(self) -> None:
|
|
198
232
|
retryable = True
|
|
199
233
|
try:
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
await asyncio.sleep(0)
|
|
234
|
+
config = Config(user_agent_extra="x-client-framework:livekit-plugins-aws")
|
|
235
|
+
async with self._session.client("bedrock-runtime", config=config) as client:
|
|
236
|
+
response = await client.converse_stream(**self._opts)
|
|
237
|
+
request_id = response["ResponseMetadata"]["RequestId"]
|
|
238
|
+
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
|
|
239
|
+
raise APIStatusError(
|
|
240
|
+
f"aws bedrock llm: error generating content: {response}",
|
|
241
|
+
retryable=False,
|
|
242
|
+
request_id=request_id,
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
async for chunk in response["stream"]:
|
|
246
|
+
chat_chunk = self._parse_chunk(request_id, chunk)
|
|
247
|
+
if chat_chunk is not None:
|
|
248
|
+
retryable = False
|
|
249
|
+
self._event_ch.send_nowait(chat_chunk)
|
|
217
250
|
|
|
218
251
|
except Exception as e:
|
|
219
252
|
raise APIConnectionError(
|
|
@@ -223,37 +256,42 @@ class LLMStream(llm.LLMStream):
|
|
|
223
256
|
|
|
224
257
|
def _parse_chunk(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
|
|
225
258
|
if "contentBlockStart" in chunk:
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
259
|
+
start = chunk["contentBlockStart"]["start"]
|
|
260
|
+
if "toolUse" in start:
|
|
261
|
+
tool_use = start["toolUse"]
|
|
262
|
+
self._tool_call_id = tool_use["toolUseId"]
|
|
263
|
+
self._fnc_name = tool_use["name"]
|
|
264
|
+
self._fnc_raw_arguments = ""
|
|
230
265
|
|
|
231
266
|
elif "contentBlockDelta" in chunk:
|
|
232
267
|
delta = chunk["contentBlockDelta"]["delta"]
|
|
233
268
|
if "toolUse" in delta:
|
|
234
269
|
self._fnc_raw_arguments += delta["toolUse"]["input"]
|
|
235
270
|
elif "text" in delta:
|
|
236
|
-
|
|
271
|
+
return llm.ChatChunk(
|
|
272
|
+
id=request_id,
|
|
273
|
+
delta=llm.ChoiceDelta(content=delta["text"], role="assistant"),
|
|
274
|
+
)
|
|
275
|
+
else:
|
|
276
|
+
logger.warning(f"aws bedrock llm: unknown chunk type: {chunk}")
|
|
237
277
|
|
|
238
278
|
elif "metadata" in chunk:
|
|
239
279
|
metadata = chunk["metadata"]
|
|
240
280
|
return llm.ChatChunk(
|
|
241
|
-
|
|
281
|
+
id=request_id,
|
|
242
282
|
usage=llm.CompletionUsage(
|
|
243
283
|
completion_tokens=metadata["usage"]["outputTokens"],
|
|
244
284
|
prompt_tokens=metadata["usage"]["inputTokens"],
|
|
245
285
|
total_tokens=metadata["usage"]["totalTokens"],
|
|
286
|
+
prompt_cached_tokens=(
|
|
287
|
+
metadata["usage"]["cacheReadInputTokens"]
|
|
288
|
+
if "cacheReadInputTokens" in metadata["usage"]
|
|
289
|
+
else 0
|
|
290
|
+
),
|
|
246
291
|
),
|
|
247
292
|
)
|
|
248
293
|
elif "contentBlockStop" in chunk:
|
|
249
|
-
if self.
|
|
250
|
-
chat_chunk = llm.ChatChunk(
|
|
251
|
-
id=request_id,
|
|
252
|
-
delta=llm.ChoiceDelta(content=self._text, role="assistant"),
|
|
253
|
-
)
|
|
254
|
-
self._text = ""
|
|
255
|
-
return chat_chunk
|
|
256
|
-
elif self._tool_call_id:
|
|
294
|
+
if self._tool_call_id:
|
|
257
295
|
if self._tool_call_id is None:
|
|
258
296
|
logger.warning("aws bedrock llm: no tool call id in the response")
|
|
259
297
|
return None
|
livekit/plugins/aws/log.py
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
|
|
3
3
|
logger = logging.getLogger("livekit.plugins.aws")
|
|
4
|
+
smithy_logger = logging.getLogger("smithy_aws_event_stream.aio")
|
|
5
|
+
smithy_logger.setLevel(logging.INFO)
|
|
6
|
+
bedrock_client_logger = logging.getLogger("aws_sdk_bedrock_runtime.client")
|
|
7
|
+
bedrock_client_logger.setLevel(logging.INFO)
|
livekit/plugins/aws/models.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from typing import Literal
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
3
|
+
TTSSpeechEngine = Literal["standard", "neural", "long-form", "generative"]
|
|
4
|
+
TTSLanguages = Literal[
|
|
5
5
|
"arb",
|
|
6
6
|
"cmn-CN",
|
|
7
7
|
"cy-GB",
|
|
@@ -45,4 +45,5 @@ TTS_LANGUAGE = Literal[
|
|
|
45
45
|
"de-CH",
|
|
46
46
|
]
|
|
47
47
|
|
|
48
|
-
|
|
48
|
+
TTSEncoding = Literal["mp3"]
|
|
49
|
+
TTSTextType = Literal["text", "ssml"]
|