livekit-plugins-aws 1.1.3__py3-none-any.whl → 1.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of livekit-plugins-aws might be problematic. Click here for more details.
- livekit/plugins/aws/__init__.py +32 -8
- livekit/plugins/aws/experimental/realtime/__init__.py +15 -0
- livekit/plugins/aws/experimental/realtime/events.py +521 -0
- livekit/plugins/aws/experimental/realtime/pretty_printer.py +49 -0
- livekit/plugins/aws/experimental/realtime/realtime_model.py +1208 -0
- livekit/plugins/aws/experimental/realtime/turn_tracker.py +172 -0
- livekit/plugins/aws/log.py +4 -0
- livekit/plugins/aws/tts.py +0 -2
- livekit/plugins/aws/version.py +1 -1
- {livekit_plugins_aws-1.1.3.dist-info → livekit_plugins_aws-1.1.5.dist-info}/METADATA +11 -5
- livekit_plugins_aws-1.1.5.dist-info/RECORD +17 -0
- livekit_plugins_aws-1.1.3.dist-info/RECORD +0 -12
- {livekit_plugins_aws-1.1.3.dist-info → livekit_plugins_aws-1.1.5.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,1208 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import base64
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import time
|
|
8
|
+
import uuid
|
|
9
|
+
import weakref
|
|
10
|
+
from collections.abc import Iterator
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from datetime import datetime
|
|
13
|
+
from typing import Any, Literal
|
|
14
|
+
|
|
15
|
+
import boto3
|
|
16
|
+
from aws_sdk_bedrock_runtime.client import (
|
|
17
|
+
BedrockRuntimeClient,
|
|
18
|
+
InvokeModelWithBidirectionalStreamOperationInput,
|
|
19
|
+
)
|
|
20
|
+
from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme
|
|
21
|
+
from aws_sdk_bedrock_runtime.models import (
|
|
22
|
+
BidirectionalInputPayloadPart,
|
|
23
|
+
InvokeModelWithBidirectionalStreamInputChunk,
|
|
24
|
+
ModelErrorException,
|
|
25
|
+
ModelNotReadyException,
|
|
26
|
+
ModelTimeoutException,
|
|
27
|
+
ThrottlingException,
|
|
28
|
+
ValidationException,
|
|
29
|
+
)
|
|
30
|
+
from smithy_aws_core.identity import AWSCredentialsIdentity
|
|
31
|
+
from smithy_core.aio.interfaces.identity import IdentityResolver
|
|
32
|
+
|
|
33
|
+
from livekit import rtc
|
|
34
|
+
from livekit.agents import (
|
|
35
|
+
APIStatusError,
|
|
36
|
+
ToolError,
|
|
37
|
+
llm,
|
|
38
|
+
utils,
|
|
39
|
+
)
|
|
40
|
+
from livekit.agents.llm.realtime import RealtimeSession
|
|
41
|
+
from livekit.agents.metrics import RealtimeModelMetrics
|
|
42
|
+
from livekit.agents.types import NOT_GIVEN, NotGivenOr
|
|
43
|
+
from livekit.agents.utils import is_given
|
|
44
|
+
from livekit.plugins.aws.experimental.realtime.turn_tracker import _TurnTracker
|
|
45
|
+
|
|
46
|
+
from ...log import logger
|
|
47
|
+
from .events import (
|
|
48
|
+
VOICE_ID,
|
|
49
|
+
SonicEventBuilder as seb,
|
|
50
|
+
Tool,
|
|
51
|
+
ToolConfiguration,
|
|
52
|
+
ToolInputSchema,
|
|
53
|
+
ToolSpec,
|
|
54
|
+
)
|
|
55
|
+
from .pretty_printer import AnsiColors, log_event_data, log_message
|
|
56
|
+
|
|
57
|
+
# PCM format of the microphone audio pushed to Sonic (see the AudioByteStream in RealtimeSession)
DEFAULT_INPUT_SAMPLE_RATE = 16000
# sample rate requested for Sonic's audio output in the prompt-start block
DEFAULT_OUTPUT_SAMPLE_RATE = 24000
DEFAULT_SAMPLE_SIZE_BITS = 16
DEFAULT_CHANNELS = 1
# samples per channel in each chunk produced by the input AudioByteStream
DEFAULT_CHUNK_SIZE = 512
# note: Sonic's temperature / top_p semantics differ from other providers (see RealtimeModel)
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9
DEFAULT_MAX_TOKENS = 1024
MAX_MESSAGE_SIZE = 1024
# Sonic limits the seeded chat context to this many items; older items are truncated
MAX_MESSAGES = 40
DEFAULT_MAX_SESSION_RESTART_ATTEMPTS = 3
DEFAULT_MAX_SESSION_RESTART_DELAY = 10
# Each sentence ends with a trailing space: adjacent string literals are concatenated
# verbatim, so without it the prompt would read "assistant.The user ...".
DEFAULT_SYSTEM_PROMPT = (
    "Your name is Sonic. You are a friendly and eagerly helpful assistant. "
    "The user and you will engage in a spoken dialog exchanging the transcripts of a natural real-time conversation. "  # noqa: E501
    "Keep your responses short and concise unless the user asks you to elaborate or you are explicitly asked to be verbose and chatty. "  # noqa: E501
    "Do not repeat yourself. Do not ask the user to repeat themselves. "
    "Do ask the user to confirm or clarify their response if you are not sure what they mean. "
    "If after asking the user for clarification you still do not understand, be honest and tell them that you do not understand. "  # noqa: E501
    "Do not make up information or make assumptions. If you do not know the answer, tell the user that you do not know the answer. "  # noqa: E501
    "If the user makes a request of you that you cannot fulfill, tell them why you cannot fulfill it. "  # noqa: E501
    "When making tool calls, inform the user that you are using a tool to generate the response. "  # noqa: E501
    "Avoid formatted lists or numbering and keep your output as a spoken transcript to be acted out. "  # noqa: E501
    "Be appropriately emotive when responding to the user. Use American English as the language for your responses."  # noqa: E501
)

# integer debug flag read from the LK_BEDROCK_DEBUG environment variable (0 when unset)
lk_bedrock_debug = int(os.getenv("LK_BEDROCK_DEBUG", 0))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclass
class _RealtimeOptions:
    """Inference settings shared by every session created from one RealtimeModel.

    Attributes:
        voice (VOICE_ID): Identifier of the voice Sonic speaks with.
        temperature (float): Sampling temperature. Sonic defines this slightly
            differently from other providers; 1.0 is the most deterministic setting.
        top_p (float): Nucleus-sampling parameter, also using Sonic's own convention.
        max_tokens (int): Upper bound on tokens generated for a single response.
        tool_choice (llm.ToolChoice | None): How the model is allowed or required to call tools.
        region (str): AWS region whose Bedrock runtime endpoint serves the model.
    """  # noqa: E501

    # field order is the positional __init__ order; keep it stable
    voice: VOICE_ID
    temperature: float
    top_p: float
    max_tokens: int
    tool_choice: llm.ToolChoice | None
    region: str
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@dataclass
class _MessageGeneration:
    """Paired text and audio streams for one assistant message.

    Attributes:
        message_id (str): Id shared by the text and audio of a single assistant message.
        text_ch (utils.aio.Chan[str]): Receives text fragments as Sonic produces them.
        audio_ch (utils.aio.Chan[rtc.AudioFrame]): Receives synthesized audio frames for the same message.
    """  # noqa: E501

    message_id: str
    text_ch: utils.aio.Chan[str]
    audio_ch: utils.aio.Chan[rtc.AudioFrame]
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@dataclass
class _ResponseGeneration:
    """Book-keeping dataclass tracking the lifecycle of a Sonic turn.

    Instances are created lazily by RealtimeSession._create_response_generation
    (on a *completion_start* event or when user text starts) and describe one
    user input plus the assistant reply that follows it.

    Attributes:
        message_ch (utils.aio.Chan[llm.MessageGeneration]): Multiplexed stream for all assistant messages.
        function_ch (utils.aio.Chan[llm.FunctionCall]): Stream that emits function tool calls.
        input_id (str): Synthetic message id for the user input of the current turn.
        response_id (str): Synthetic message id for the assistant reply of the current turn.
        messages (dict[str, _MessageGeneration]): Map of message_id -> per-message stream containers.
        user_messages (dict[str, str]): Map Bedrock content_id -> input_id.
        speculative_messages (dict[str, str]): Map Bedrock content_id -> response_id (assistant side).
        tool_messages (dict[str, str]): Map Bedrock content_id -> response_id for tool calls.
        output_text (str): Accumulated assistant text (only used for metrics / debugging).
        _created_timestamp (str): ISO-8601 timestamp of when this record was created.
        _first_token_timestamp (float | None): Wall-clock time of first token emission.
        _completed_timestamp (float | None): Wall-clock time when the turn fully completed.
    """  # noqa: E501

    message_ch: utils.aio.Chan[llm.MessageGeneration]
    function_ch: utils.aio.Chan[llm.FunctionCall]
    input_id: str  # corresponds to user's portion of the turn
    response_id: str  # corresponds to agent's portion of the turn
    messages: dict[str, _MessageGeneration] = field(default_factory=dict)
    user_messages: dict[str, str] = field(default_factory=dict)
    speculative_messages: dict[str, str] = field(default_factory=dict)
    tool_messages: dict[str, str] = field(default_factory=dict)
    output_text: str = ""  # agent ASR text
    # default_factory must be a callable evaluated per instance; passing
    # datetime.now().isoformat() directly handed dataclasses a *string* (computed once at
    # import) and made construction without this argument raise TypeError.
    _created_timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    _first_token_timestamp: float | None = None
    _completed_timestamp: float | None = None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class Boto3CredentialsResolver(IdentityResolver):
    """IdentityResolver implementation that sources AWS credentials from boto3.

    The resolver delegates to the default boto3.Session() credential chain which
    checks environment variables, shared credentials files, EC2 instance profiles, etc.
    The credentials are then wrapped in an AWSCredentialsIdentity so they can be
    passed into Bedrock runtime clients.
    """

    def __init__(self):
        self.session = boto3.Session()

    async def get_identity(self, **kwargs):
        """Asynchronously resolve AWS credentials.

        This method is invoked by the Bedrock runtime client whenever a new request needs to be
        signed. It converts the static or temporary credentials returned by boto3
        into an AWSCredentialsIdentity instance.

        Args:
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            AWSCredentialsIdentity: Identity containing the
                AWS access key, secret key and optional session token.

        Raises:
            ValueError: If no credentials could be found by boto3, or if loading them failed
                (the original error is chained as ``__cause__``).
        """
        logger.debug("Attempting to load AWS credentials")
        try:
            credentials = self.session.get_credentials()
            creds = credentials.get_frozen_credentials() if credentials else None
        except Exception as e:
            logger.error(f"Failed to load AWS credentials: {e}")
            raise ValueError(f"Failed to load AWS credentials: {e}") from e

        # raised outside the try so it is logged once and not re-wrapped as "Failed to ..."
        if creds is None:
            logger.error("Unable to load AWS credentials")
            raise ValueError("Unable to load AWS credentials")

        logger.debug(
            f"AWS credentials loaded successfully. AWS_ACCESS_KEY_ID: {creds.access_key[:4]}***"
        )
        # NOTE(review): expiration is always None, even for temporary (session-token)
        # credentials; confirm the Smithy client re-resolves them rather than caching.
        return AWSCredentialsIdentity(
            access_key_id=creds.access_key,
            secret_access_key=creds.secret_key,
            session_token=creds.token or None,
            expiration=None,
        )
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class RealtimeModel(llm.RealtimeModel):
    """High-level entry point that conforms to the LiveKit RealtimeModel interface.

    The object is very lightweight: it mainly stores default inference options and
    spawns a RealtimeSession when session() is invoked.
    """

    def __init__(
        self,
        *,
        voice: NotGivenOr[VOICE_ID] = NOT_GIVEN,
        temperature: NotGivenOr[float] = NOT_GIVEN,
        top_p: NotGivenOr[float] = NOT_GIVEN,
        max_tokens: NotGivenOr[int] = NOT_GIVEN,
        tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
        region: NotGivenOr[str] = NOT_GIVEN,
    ):
        """Instantiate a new RealtimeModel.

        Args:
            voice (VOICE_ID | NotGiven): Preferred voice id for Sonic TTS output. Falls back to "tiffany".
            temperature (float | NotGiven): Sampling temperature (0-1). Defaults to DEFAULT_TEMPERATURE.
            top_p (float | NotGiven): Nucleus sampling probability mass. Defaults to DEFAULT_TOP_P.
            max_tokens (int | NotGiven): Upper bound for tokens emitted by the model. Defaults to DEFAULT_MAX_TOKENS.
            tool_choice (llm.ToolChoice | None | NotGiven): Strategy for tool invocation ("auto", "required", or explicit function).
            region (str | NotGiven): AWS region of the Bedrock runtime endpoint. Defaults to "us-east-1".
        """  # noqa: E501
        super().__init__(
            capabilities=llm.RealtimeCapabilities(
                message_truncation=False,
                turn_detection=True,
                user_transcription=True,
                auto_tool_reply_generation=True,
            )
        )
        self.model_id = "amazon.nova-sonic-v1:0"
        # note: temperature and top_p do not follow industry standards and are defined slightly differently for Sonic # noqa: E501
        # temperature ranges from 0.0 to 1.0, where 0.0 is the most random and 1.0 is the most deterministic # noqa: E501
        # top_p ranges from 0.0 to 1.0, where 0.0 is the most random and 1.0 is the most deterministic # noqa: E501
        # The raw (possibly NOT_GIVEN) values are kept so sessions can tell whether the caller
        # set them explicitly (RealtimeSession._serialize_tool_config checks them).
        self.temperature = temperature
        self.top_p = top_p
        self._opts = _RealtimeOptions(
            voice=voice if is_given(voice) else "tiffany",
            temperature=temperature if is_given(temperature) else DEFAULT_TEMPERATURE,
            top_p=top_p if is_given(top_p) else DEFAULT_TOP_P,
            max_tokens=max_tokens if is_given(max_tokens) else DEFAULT_MAX_TOKENS,
            tool_choice=tool_choice or None,
            region=region if is_given(region) else "us-east-1",
        )
        self._sessions = weakref.WeakSet[RealtimeSession]()
        # The event loop only keeps weak references to tasks, so an unreferenced
        # initialize_streams() task could be garbage-collected before it finishes.
        # Hold strong references here until each task completes.
        self._init_tasks: set[asyncio.Task] = set()

    def session(self) -> RealtimeSession:
        """Return a new RealtimeSession bound to this model instance.

        Must be called from within a running event loop: the session's stream
        initialization is scheduled as a background task.
        """
        sess = RealtimeSession(self)

        # note: this is a hack to get the session to initialize itself
        # TODO: change how RealtimeSession is initialized by creating a single task main_atask that spawns subtasks # noqa: E501
        init_task = asyncio.create_task(sess.initialize_streams())
        self._init_tasks.add(init_task)
        init_task.add_done_callback(self._init_tasks.discard)

        self._sessions.add(sess)
        return sess

    # stub b/c RealtimeSession.aclose() is invoked directly
    async def aclose(self) -> None:
        pass
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class RealtimeSession( # noqa: F811
|
|
275
|
+
llm.RealtimeSession[Literal["bedrock_server_event_received", "bedrock_client_event_queued"]]
|
|
276
|
+
):
|
|
277
|
+
"""Bidirectional streaming session against the Nova Sonic Bedrock runtime.
|
|
278
|
+
|
|
279
|
+
The session owns two asynchronous tasks:
|
|
280
|
+
|
|
281
|
+
1. _process_audio_input – pushes user mic audio and tool results to Bedrock.
|
|
282
|
+
2. _process_responses – receives server events from Bedrock and converts them into
|
|
283
|
+
LiveKit abstractions such as llm.MessageGeneration.
|
|
284
|
+
|
|
285
|
+
A set of helper handlers (_handle_*) transform the low-level Bedrock
|
|
286
|
+
JSON payloads into higher-level application events and keep
|
|
287
|
+
_ResponseGeneration state in sync.
|
|
288
|
+
"""
|
|
289
|
+
|
|
290
|
+
def __init__(self, realtime_model: RealtimeModel) -> None:
    """Prepare a session for *realtime_model* without opening any connection.

    All Bedrock I/O happens later in ``initialize_streams``. This constructor only
    sets up local state:
    - the event builder, with fresh prompt / audio-content names;
    - the input audio byte stream;
    - readiness futures for tools, instructions and chat context;
    - the server-event dispatch table and the turn tracker.

    Args:
        realtime_model (RealtimeModel): Owning model whose options (voice,
            sampling parameters, region, tool choice) this session uses.
    """
    super().__init__(realtime_model)
    self._realtime_model = realtime_model

    # fresh identifiers per session for the Sonic prompt and the audio content block
    prompt_name = str(uuid.uuid4())
    audio_content_name = str(uuid.uuid4())
    self._event_builder = seb(prompt_name=prompt_name, audio_content_name=audio_content_name)

    self._input_resampler: rtc.AudioResampler | None = None
    self._bstream = utils.audio.AudioByteStream(
        DEFAULT_INPUT_SAMPLE_RATE, DEFAULT_CHANNELS, samples_per_channel=DEFAULT_CHUNK_SIZE
    )

    # tasks, stream handle and client are populated by initialize_streams()
    self._response_task = None
    self._audio_input_task = None
    self._stream_response = None
    self._bedrock_client = None
    self._is_sess_active = asyncio.Event()

    self._chat_ctx = llm.ChatContext.empty()
    self._tools = llm.ToolContext.empty()
    self._tool_type_map = {}
    self._tool_results_ch = utils.aio.Chan[dict[str, str]]()

    # initialize_streams() waits briefly on these before sending the prompt-start block;
    # presumably resolved by the update_* methods elsewhere in this class
    loop = asyncio.get_running_loop()
    self._tools_ready = loop.create_future()
    self._instructions_ready = loop.create_future()
    self._chat_ctx_ready = loop.create_future()

    self._instructions = DEFAULT_SYSTEM_PROMPT
    self._audio_input_chan = utils.aio.Chan[bytes]()
    self._current_generation: _ResponseGeneration | None = None

    # note: currently tracks session restart attempts across all sessions
    # TODO: track restart attempts per turn
    self._session_restart_attempts = 0

    # Bedrock event type -> handler coroutine, dispatched by _handle_event
    self._event_handlers = {
        "completion_start": self._handle_completion_start_event,
        "audio_output_content_start": self._handle_audio_output_content_start_event,
        "audio_output_content": self._handle_audio_output_content_event,
        "audio_output_content_end": self._handle_audio_output_content_end_event,
        "text_output_content_start": self._handle_text_output_content_start_event,
        "text_output_content": self._handle_text_output_content_event,
        "text_output_content_end": self._handle_text_output_content_end_event,
        "tool_output_content_start": self._handle_tool_output_content_start_event,
        "tool_output_content": self._handle_tool_output_content_event,
        "tool_output_content_end": self._handle_tool_output_content_end_event,
        "completion_end": self._handle_completion_end_event,
        "usage": self._handle_usage_event,
        "other_event": self._handle_other_event,
    }
    self._turn_tracker = _TurnTracker(self.emit, streams_provider=self._current_generation_streams)
|
|
346
|
+
|
|
347
|
+
def _current_generation_streams(
    self,
) -> tuple[utils.aio.Chan[llm.MessageGeneration], utils.aio.Chan[llm.FunctionCall]]:
    """Return the (message, function-call) channels of the in-flight generation.

    Passed to _TurnTracker as its ``streams_provider``. It assumes a generation is
    active: if ``_current_generation`` is None, an AttributeError is raised.
    """
    generation = self._current_generation
    return generation.message_ch, generation.function_ch
|
|
351
|
+
|
|
352
|
+
@utils.log_exceptions(logger=logger)
def _initialize_client(self):
    """Create the Smithy Bedrock runtime client for the model's configured region.

    Requests are signed with SigV4. Credentials come from Boto3CredentialsResolver.
    """
    aws_region = self._realtime_model._opts.region
    self._bedrock_client = BedrockRuntimeClient(
        config=Config(
            endpoint_uri=f"https://bedrock-runtime.{aws_region}.amazonaws.com",
            region=aws_region,
            aws_credentials_identity_resolver=Boto3CredentialsResolver(),
            http_auth_scheme_resolver=HTTPAuthSchemeResolver(),
            http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()},
        )
    )
|
|
363
|
+
|
|
364
|
+
@utils.log_exceptions(logger=logger)
async def _send_raw_event(self, event_json):
    """Serialise event_json and forward it to the bidirectional stream.

    Args:
        event_json (dict | str): The payload, already in Bedrock wire format. Strings
            are sent as-is; dicts are JSON-encoded first. Previously a dict crashed on
            ``.encode``, despite the documented ``dict | str`` contract.

    Raises:
        Exception: Re-raises any failure returned by the Bedrock runtime client, after
            emitting a non-recoverable "error" event.
    """
    if not self._stream_response:
        logger.warning("stream not initialized; dropping event (this should never occur)")
        return

    payload = event_json if isinstance(event_json, str) else json.dumps(event_json)
    event = InvokeModelWithBidirectionalStreamInputChunk(
        value=BidirectionalInputPayloadPart(bytes_=payload.encode("utf-8"))
    )

    try:
        await self._stream_response.input_stream.send(event)
    except Exception as e:
        logger.exception("Error sending event")
        err_msg = getattr(e, "message", str(e))

        # Best-effort request-id extraction: take the text after "=" in the first
        # whitespace-separated token (presumably a "RequestId=<id>" prefix; TODO confirm).
        request_id = None
        if isinstance(err_msg, str):
            key_value = err_msg.split(" ")[0].split("=")
            if len(key_value) > 1:
                request_id = key_value[1]

        self.emit(
            "error",
            llm.RealtimeModelError(
                timestamp=time.monotonic(),
                label=self._realtime_model._label,
                error=APIStatusError(
                    message=err_msg,
                    status_code=500,
                    request_id=request_id,
                    body=e,
                    retryable=False,
                ),
                recoverable=False,
            ),
        )
        raise
|
|
409
|
+
|
|
410
|
+
def _serialize_tool_config(self) -> ToolConfiguration | None:
    """Translate the session's function tools into a Sonic ToolConfiguration.

    Side effects:
    - records each tool's kind ("FunctionTool" or "RawFunctionTool") in
      ``self._tool_type_map``;
    - when any tools exist, sets top_p and temperature to 1.0 (Sonic's greedy
      setting) on the parent model's options, but only for values the caller did
      not set explicitly.

    NOTE(review): ``_opts`` belongs to the RealtimeModel, so this override is shared
    by every session of that model.

    Returns:
        ToolConfiguration | None: None when no function tools are registered, otherwise a complete config block.
    """  # noqa: E501
    function_tools = self.tools.function_tools
    if not function_tools:
        return None

    tool_list = []
    for tool_name, fn in function_tools.items():
        if llm.tool_context.is_function_tool(fn):
            tool_description = llm.tool_context.get_function_info(fn).description
            params_schema = llm.utils.build_legacy_openai_schema(fn, internally_tagged=True)[
                "parameters"
            ]
            self._tool_type_map[tool_name] = "FunctionTool"
        else:
            raw_schema = llm.tool_context.get_raw_function_info(fn).raw_schema
            tool_description = raw_schema.get("description")
            params_schema = raw_schema["parameters"]
            self._tool_type_map[tool_name] = "RawFunctionTool"

        tool_list.append(
            Tool(
                toolSpec=ToolSpec(
                    name=tool_name,
                    description=tool_description,
                    inputSchema=ToolInputSchema(json_=json.dumps(params_schema)),
                )
            )
        )

    chosen = self._tool_choice_adapter(self._realtime_model._opts.tool_choice)
    logger.debug(f"TOOL CHOICE: {chosen}")
    config = ToolConfiguration(tools=tool_list, toolChoice=chosen)

    # recommended to set greedy inference configs for tool calls
    model = self._realtime_model
    if not is_given(model.top_p):
        model._opts.top_p = 1.0
    if not is_given(model.temperature):
        model._opts.temperature = 1.0
    return config
|
|
456
|
+
|
|
457
|
+
@utils.log_exceptions(logger=logger)
async def initialize_streams(self, is_restart: bool = False):
    """Open the Bedrock bidirectional stream and spawn background worker tasks.

    First start (``is_restart=False``):
    1. Wait up to 500 ms for tools, instructions and chat context to arrive.
    2. Send the prompt-start events.
    3. Start the ``_process_audio_input`` and ``_process_responses`` tasks.

    Restart (``is_restart=True``): used after recoverable errors (e.g. timeout,
    throttling). The readiness wait and the audio-input task are skipped; a new
    stream is opened and a fresh response task is started against it.

    Args:
        is_restart (bool, optional): Marks whether we are re-initialising an
            existing session after an error. Defaults to False.

    Returns:
        RealtimeSession: ``self``, once the stream has been initialized.

    Raises:
        Exception: Re-raises any failure while creating the client or stream, or
            while sending the initial events.
    """
    try:
        if not self._bedrock_client:
            logger.info("Creating Bedrock client")
            self._initialize_client()

        logger.info("Initializing Bedrock stream")
        self._stream_response = await self._bedrock_client.invoke_model_with_bidirectional_stream(
            InvokeModelWithBidirectionalStreamOperationInput(model_id=self._realtime_model.model_id)
        )

        if not is_restart:
            pending_events: list[asyncio.Future] = []
            if not self.tools.function_tools:
                pending_events.append(self._tools_ready)
            if not self._instructions_ready.done():
                pending_events.append(self._instructions_ready)
            if not self._chat_ctx_ready.done():
                pending_events.append(self._chat_ctx_ready)

            # note: can't know during sess init whether tools were not added
            # or if they were added haven't yet been updated
            # therefore in the case there are no tools, we wait the entire timeout
            #
            # asyncio.wait() is used instead of wait_for(gather(...)): on timeout,
            # wait_for cancels the gather, which cancels these shared readiness futures.
            # That stops them from being resolved later and makes done() report True,
            # so the warnings below never fired.
            if pending_events:
                _, not_ready = await asyncio.wait(pending_events, timeout=0.5)
                if self._tools_ready in not_ready:
                    logger.warning("Tools not ready after 500ms, continuing without them")
                if self._instructions_ready in not_ready:
                    logger.warning(
                        "Instructions not received after 500ms, proceeding with default instructions"  # noqa: E501
                    )
                if self._chat_ctx_ready in not_ready:
                    logger.warning(
                        "Chat context not received after 500ms, proceeding with empty chat context"  # noqa: E501
                    )

        logger.info(
            f"Initializing Bedrock session with realtime options: {self._realtime_model._opts}"
        )
        # there is a 40-message limit on the chat context
        if len(self._chat_ctx.items) > MAX_MESSAGES:
            logger.warning(
                f"Chat context has {len(self._chat_ctx.items)} messages, truncating to {MAX_MESSAGES}"  # noqa: E501
            )
            self._chat_ctx.truncate(max_items=MAX_MESSAGES)
        init_events = self._event_builder.create_prompt_start_block(
            voice_id=self._realtime_model._opts.voice,
            sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
            system_content=self._instructions,
            chat_ctx=self.chat_ctx,
            tool_configuration=self._serialize_tool_config(),
            max_tokens=self._realtime_model._opts.max_tokens,
            top_p=self._realtime_model._opts.top_p,
            temperature=self._realtime_model._opts.temperature,
        )

        for event in init_events:
            await self._send_raw_event(event)
            logger.debug(f"Sent event: {event}")

        # the audio-input task outlives stream restarts; only the response task is respawned
        if not is_restart:
            self._audio_input_task = asyncio.create_task(
                self._process_audio_input(), name="RealtimeSession._process_audio_input"
            )

        self._response_task = asyncio.create_task(
            self._process_responses(), name="RealtimeSession._process_responses"
        )
        self._is_sess_active.set()
        logger.debug("Stream initialized successfully")
    except Exception as e:
        # _is_sess_active is an asyncio.Event, which has no set_exception(). Calling it
        # raised AttributeError and masked the original error, so just log and re-raise.
        logger.debug(f"Failed to initialize stream: {str(e)}")
        raise
    return self
|
|
549
|
+
|
|
550
|
+
@utils.log_exceptions(logger=logger)
def _emit_generation_event(self) -> None:
    """Announce the current generation to listeners as a "generation_created" event.

    The event carries the generation's message and function-call channels and is
    marked as not user-initiated. Requires ``_current_generation`` to be set.
    """
    logger.debug("Emitting generation event")
    generation = self._current_generation
    self.emit(
        "generation_created",
        llm.GenerationCreatedEvent(
            message_stream=generation.message_ch,
            function_stream=generation.function_ch,
            user_initiated=False,
        ),
    )
|
|
560
|
+
|
|
561
|
+
@utils.log_exceptions(logger=logger)
async def _handle_event(self, event_data: dict) -> None:
    """Route one decoded Bedrock event to its registered ``_handle_*`` coroutine.

    After the handler has run, the same event is fed to the turn tracker. Event types
    without a handler are only logged; they are not passed to the tracker.
    """
    event_type = self._event_builder.get_event_type(event_data)
    handler = self._event_handlers.get(event_type)
    if not handler:
        logger.warning(f"No event handler found for event type: {event_type}")
        return

    await handler(event_data)
    self._turn_tracker.feed(event_data)
|
|
571
|
+
|
|
572
|
+
async def _handle_completion_start_event(self, event_data: dict) -> None:
    """On completionStart: log the payload, then ensure a response generation exists."""
    log_event_data(event_data)
    # no-op when a generation is already in progress
    self._create_response_generation()
|
|
575
|
+
|
|
576
|
+
def _create_response_generation(self) -> None:
    """Lazily start a turn: build a _ResponseGeneration and its assistant message.

    Does nothing while a generation is already in progress. Otherwise it:
    - creates a generation with fresh input/response ids;
    - registers one assistant _MessageGeneration under the response id;
    - pushes a matching llm.MessageGeneration onto ``message_ch``.

    It does not itself emit "generation_created".
    """
    if self._current_generation is not None:
        return

    generation = _ResponseGeneration(
        message_ch=utils.aio.Chan(),
        function_ch=utils.aio.Chan(),
        input_id=str(uuid.uuid4()),
        response_id=str(uuid.uuid4()),
        messages={},
        user_messages={},
        speculative_messages={},
        _created_timestamp=datetime.now().isoformat(),
    )
    self._current_generation = generation

    assistant_msg = _MessageGeneration(
        message_id=generation.response_id,
        text_ch=utils.aio.Chan(),
        audio_ch=utils.aio.Chan(),
    )
    generation.message_ch.send_nowait(
        llm.MessageGeneration(
            message_id=assistant_msg.message_id,
            text_stream=assistant_msg.text_ch,
            audio_stream=assistant_msg.audio_ch,
        )
    )
    generation.messages[generation.response_id] = assistant_msg
|
|
602
|
+
|
|
603
|
+
# will be completely ignoring post-ASR text events
async def _handle_text_output_content_start_event(self, event_data: dict) -> None:
    """Register the contentId of a starting text block with the current turn.

    - USER: ensures a generation exists, then maps contentId -> input_id, so later
      textOutput events for it are treated as the user's transcription.
    - ASSISTANT whose ``additionalModelFields`` mention "SPECULATIVE": maps
      contentId -> response_id, so later textOutput events stream into the assistant
      message. Other assistant text blocks are not registered and are thus ignored.

    Assistant blocks that arrive while no generation is active are ignored, matching the
    "ignore events until turn starts" handling of text content. Previously they raised
    AttributeError. A missing or null ``additionalModelFields`` no longer raises either.
    """
    log_event_data(event_data)
    content_start = event_data["event"]["contentStart"]
    role = content_start["role"]

    # note: does not work if you emit llm.GCE too early (for some reason)
    if role == "USER":
        self._create_response_generation()
        content_id = content_start["contentId"]
        self._current_generation.user_messages[content_id] = self._current_generation.input_id

    elif role == "ASSISTANT" and "SPECULATIVE" in (
        content_start.get("additionalModelFields") or ""
    ):
        if self._current_generation is None:
            logger.debug("ignoring speculative text content start: no active generation")
            return
        text_content_id = content_start["contentId"]
        self._current_generation.speculative_messages[text_content_id] = (
            self._current_generation.response_id
        )
|
|
623
|
+
|
|
624
|
+
async def _handle_text_output_content_event(self, event_data: dict) -> None:
    """Stream partial text tokens into the current _MessageGeneration.

    A textOutput event is handled in one of three ways:
    - the barge-in marker ``{ "interrupted" : true }`` flags the latest earlier chat
      message as interrupted and closes the active generation;
    - text whose content id is registered in ``user_messages`` is appended to the chat
      context as user (ASR) text;
    - text whose content id is registered in ``speculative_messages`` is forwarded to the
      message's ``text_ch`` and appended to the chat context as assistant text.
    Anything else (or any text while no generation is active) is ignored.
    """
    log_event_data(event_data)
    text_output = event_data["event"]["textOutput"]
    text_content_id = text_output["contentId"]
    text_content = f"{text_output['content']}\n"

    # currently only agent can be interrupted
    if text_content == '{ "interrupted" : true }\n':
        # the interrupted flag is not being set correctly in chat_ctx
        # this is b/c audio playback is desynced from text transcription
        # TODO: fix this; possibly via a playback timer
        idx = self._chat_ctx.find_insertion_index(created_at=time.time()) - 1
        if idx >= 0:
            logger.debug(
                f"BARGE-IN DETECTED using idx: {idx} and chat_msg: {self._chat_ctx.items[idx]}"
            )
            item = self._chat_ctx.items[idx]
            # only chat messages carry an `interrupted` flag; function calls/outputs do not
            if isinstance(item, llm.ChatMessage):
                item.interrupted = True
        else:
            # idx < 0: no earlier item to flag (items[-1] would raise or pick the wrong item)
            logger.debug("BARGE-IN DETECTED but no earlier chat item to mark as interrupted")
        self._close_current_generation()
        return

    # ignore events until turn starts
    if self._current_generation is None:
        return

    # TODO: rename event to llm.InputTranscriptionUpdated
    if self._current_generation.user_messages.get(text_content_id) == self._current_generation.input_id:
        logger.debug(f"INPUT TRANSCRIPTION UPDATED: {text_content}")
        # note: user ASR text is slightly different than what is sent to LiveKit (newline vs whitespace)  # noqa: E501
        # TODO: fix this
        self._update_chat_ctx(role="user", text_content=text_content)

    elif (
        self._current_generation.speculative_messages.get(text_content_id)
        == self._current_generation.response_id
    ):
        curr_gen = self._current_generation.messages[self._current_generation.response_id]
        curr_gen.text_ch.send_nowait(text_content)
        # note: this update is per utterance, not per turn
        self._update_chat_ctx(role="assistant", text_content=text_content)
|
|
663
|
+
|
|
664
|
+
def _update_chat_ctx(self, role: str, text_content: str) -> None:
    """
    Update the chat context with the latest ASR text while guarding against model limitations:
    a) 40 total messages limit (MAX_MESSAGES)
    b) 1kB message size limit (MAX_MESSAGE_SIZE)

    If the newest chat item is a message from the same ``role`` whose first content part
    is text, ``text_content`` is newline-joined onto it as long as the combined UTF-8 size
    stays below MAX_MESSAGE_SIZE. Otherwise a new message is appended and the context is
    truncated to MAX_MESSAGES items.
    """
    items = self._chat_ctx.items
    # an empty context (no initial history) has no previous utterance to extend
    prev_utterance = items[-1] if items else None

    # function calls / outputs have no `role`, and message content may contain non-text
    # parts; only a plain-text message from the same speaker can be extended in place
    prev_text = None
    if prev_utterance is not None and getattr(prev_utterance, "role", None) == role:
        content = getattr(prev_utterance, "content", None)
        if content and isinstance(content[0], str):
            prev_text = content[0]

    if (
        prev_text is not None
        and len(prev_text.encode("utf-8")) + len(text_content.encode("utf-8"))
        < MAX_MESSAGE_SIZE
    ):
        prev_utterance.content[0] = "\n".join([prev_text, text_content])
        return

    self._chat_ctx.add_message(role=role, content=text_content)
    if len(self._chat_ctx.items) > MAX_MESSAGES:
        self._chat_ctx.truncate(max_items=MAX_MESSAGES)
|
|
685
|
+
|
|
686
|
+
# cannot rely on this event for user b/c stopReason=PARTIAL_TURN always for user
|
|
687
|
+
async def _handle_text_output_content_end_event(self, event_data: dict) -> None:
    """Mark the assistant message closed when Bedrock signals END_TURN.

    The active generation is closed only when the ending content block is speculative
    assistant text of the current response and its stopReason is END_TURN.
    """
    content_end = event_data["event"]["contentEnd"]
    stop_reason = content_end["stopReason"]
    text_content_id = content_end["contentId"]

    generation = self._current_generation
    if generation is None:
        # presumably the first utterance in the turn was an interrupt
        return
    if generation.speculative_messages.get(text_content_id) != generation.response_id:
        return
    if stop_reason != "END_TURN":
        return

    log_event_data(event_data)
    self._close_current_generation()
|
|
700
|
+
|
|
701
|
+
async def _handle_tool_output_content_start_event(self, event_data: dict) -> None:
    """Track mapping content_id -> response_id for upcoming tool use.

    Registers the toolUse content id against the active response so the following
    toolUse event is executed for it. When no generation is active the event is
    ignored with a warning instead of raising AttributeError on ``None``.
    """
    log_event_data(event_data)
    generation = self._current_generation
    if generation is None:
        logger.warning("received tool use content start without an active generation, ignoring")
        return
    tool_use_content_id = event_data["event"]["contentStart"]["contentId"]
    generation.tool_messages[tool_use_content_id] = generation.response_id
|
|
708
|
+
|
|
709
|
+
# note: tool calls are synchronous for now
|
|
710
|
+
async def _handle_tool_output_content_event(self, event_data: dict) -> None:
    """Execute the referenced tool locally and forward results back to Bedrock.

    For a toolUse belonging to the active response this
    1. announces an ``llm.FunctionCall`` on the generation's ``function_ch``,
    2. awaits the registered tool inline, and
    3. queues ``{"tool_use_id", "tool_result"}`` on ``_tool_results_ch``.

    A ``ToolError`` (returned or raised by the tool) is converted to
    ``{"error": message}`` because Sonic only accepts structured tool output.

    Raises:
        ValueError: if the tool's registered type is neither FunctionTool nor
            RawFunctionTool.
    """
    log_event_data(event_data)
    tool_use = event_data["event"]["toolUse"]
    tool_use_content_id = tool_use["contentId"]
    tool_use_id = tool_use["toolUseId"]
    tool_name = tool_use["toolName"]

    generation = self._current_generation
    if generation is None:
        # e.g. the generation was closed before the tool call arrived
        logger.warning(f"ignoring tool call {tool_name} ({tool_use_id}): no active generation")
        return
    if generation.tool_messages.get(tool_use_content_id) != generation.response_id:
        return

    args = tool_use["content"]
    generation.function_ch.send_nowait(
        llm.FunctionCall(
            call_id=tool_use_id,
            name=tool_name,
            arguments=args,
        )
    )

    # note: may need to inject RunContext here...
    tool_type = self._tool_type_map[tool_name]
    # an empty payload means "no arguments" rather than a json.loads failure
    parsed_args = json.loads(args) if args else {}
    try:
        if tool_type == "FunctionTool":
            tool_result = await self.tools.function_tools[tool_name](**parsed_args)
        elif tool_type == "RawFunctionTool":
            tool_result = await self.tools.function_tools[tool_name](parsed_args)
        else:
            raise ValueError(f"Unknown tool type: {tool_type}")
    except ToolError as tool_err:
        # tools report model-facing failures by *raising* ToolError; without this the
        # exception would escape into the response loop and end the session
        tool_result = tool_err
    logger.debug(f"TOOL ARGS: {args}\nTOOL RESULT: {tool_result}")

    # Sonic only accepts Structured Output for tool results
    # therefore, must JSON stringify ToolError
    if isinstance(tool_result, ToolError):
        logger.warning(f"TOOL ERROR: {tool_name} {tool_result.message}")
        tool_result = {"error": tool_result.message}
    self._tool_results_ch.send_nowait(
        {
            "tool_use_id": tool_use_id,
            "tool_result": tool_result,
        }
    )
|
|
750
|
+
|
|
751
|
+
async def _handle_tool_output_content_end_event(self, event_data: dict) -> None:
    """Log the end of a tool-use content block; no other action is taken."""
    payload = event_data
    log_event_data(payload)
|
|
753
|
+
|
|
754
|
+
async def _handle_audio_output_content_start_event(self, event_data: dict) -> None:
    """Associate the upcoming audio content block with the active assistant response.

    The block's content id is registered in ``speculative_messages`` so later audio
    chunks with that id are routed to the current message. Ignored when no generation
    is active.
    """
    generation = self._current_generation
    if generation is None:
        return
    log_event_data(event_data)
    content_id = event_data["event"]["contentStart"]["contentId"]
    generation.speculative_messages[content_id] = generation.response_id
|
|
762
|
+
|
|
763
|
+
async def _handle_audio_output_content_event(self, event_data: dict) -> None:
    """Decode base64 audio from Bedrock and forward it to the audio stream.

    Only chunks whose content id was registered as speculative output of the active
    response are forwarded (as an ``rtc.AudioFrame``). Everything else, and every chunk
    while no generation is active, is dropped.
    """
    generation = self._current_generation
    if generation is None:
        return
    audio_output = event_data["event"]["audioOutput"]
    if generation.speculative_messages.get(audio_output["contentId"]) != generation.response_id:
        return

    audio_bytes = base64.b64decode(audio_output["content"])
    curr_gen = generation.messages[generation.response_id]
    curr_gen.audio_ch.send_nowait(
        rtc.AudioFrame(
            data=audio_bytes,
            sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
            num_channels=DEFAULT_CHANNELS,
            # 2 bytes per 16-bit sample, interleaved across all channels; the previous
            # `// 2` was only correct while DEFAULT_CHANNELS == 1
            samples_per_channel=len(audio_bytes) // (2 * DEFAULT_CHANNELS),
        )
    )
|
|
783
|
+
|
|
784
|
+
async def _handle_audio_output_content_end_event(self, event_data: dict) -> None:
    """Close the assistant message streams once Bedrock finishes audio for the turn.

    Requires an active generation, a stopReason of END_TURN and a content id that was
    registered as speculative output of the current response.
    """
    generation = self._current_generation
    if generation is None:
        return
    content_end = event_data["event"]["contentEnd"]
    if content_end["stopReason"] != "END_TURN":
        return
    if generation.speculative_messages.get(content_end["contentId"]) != generation.response_id:
        return
    log_event_data(event_data)
    self._close_current_generation()
|
|
796
|
+
|
|
797
|
+
def _close_current_generation(self) -> None:
    """Close every still-open channel of the active _ResponseGeneration and drop it.

    The current message's audio and text channels are closed first, then the
    generation's message and function channels. Afterwards ``_current_generation`` is
    ``None``. Calling this with no active generation is a no-op.
    """
    generation = self._current_generation
    if generation is None:
        return

    if generation.response_id in generation.messages:
        message = generation.messages[generation.response_id]
        for channel in (message.audio_ch, message.text_ch):
            if not channel.closed:
                channel.close()

    for channel in (generation.message_ch, generation.function_ch):
        if not channel.closed:
            channel.close()

    self._current_generation = None
|
|
813
|
+
|
|
814
|
+
async def _handle_completion_end_event(self, event_data: dict) -> None:
    """Log a completionEnd event; this handler only logs."""
    payload = event_data
    log_event_data(payload)
|
|
816
|
+
|
|
817
|
+
async def _handle_other_event(self, event_data: dict) -> None:
    """Fallback handler for events without a dedicated handler; only logs them."""
    payload = event_data
    log_event_data(payload)
|
|
819
|
+
|
|
820
|
+
async def _handle_usage_event(self, event_data: dict) -> None:
    """Convert a Bedrock ``usageEvent`` token delta into RealtimeModelMetrics and emit it.

    Only the per-event ``delta`` counts are reported. Duration, TTFT and tokens/sec are
    not measured yet and are reported as 0.
    """
    # log_event_data(event_data)
    # TODO: implement duration and ttft
    usage = event_data["event"]["usageEvent"]
    delta = usage["details"]["delta"]
    in_tokens = delta["input"]
    out_tokens = delta["output"]
    in_speech, in_text = in_tokens["speechTokens"], in_tokens["textTokens"]
    out_speech, out_text = out_tokens["speechTokens"], out_tokens["textTokens"]
    input_total = in_speech + in_text
    output_total = out_speech + out_text

    # Q: should we be counting per turn or utterance?
    metrics = RealtimeModelMetrics(
        label=self._realtime_model._label,
        # TODO: pass in the correct request_id
        request_id=usage["completionId"],
        timestamp=time.monotonic(),
        duration=0,
        ttft=0,
        cancelled=False,
        input_tokens=input_total,
        output_tokens=output_total,
        total_tokens=input_total + output_total,
        # need duration to calculate this
        tokens_per_second=0,
        input_token_details=RealtimeModelMetrics.InputTokenDetails(
            text_tokens=in_text,
            audio_tokens=in_speech,
            image_tokens=0,
            cached_tokens=0,
            cached_tokens_details=None,
        ),
        output_token_details=RealtimeModelMetrics.OutputTokenDetails(
            text_tokens=out_text,
            audio_tokens=out_speech,
            image_tokens=0,
        ),
    )
    self.emit("metrics_collected", metrics)
|
|
856
|
+
|
|
857
|
+
@utils.log_exceptions(logger=logger)
async def _process_responses(self):
    """Background task that drains Bedrock's output stream and feeds the event handlers.

    The task waits for ``_is_sess_active``, then repeatedly receives raw chunks from the
    bidirectional stream. Each chunk is decoded as UTF-8 JSON and dispatched to
    ``_handle_event``.

    Stream failures are handled in one of three ways:
    - recovered via ``_restart_session`` (stream timeout, throttling, model not
      ready/error/timeout);
    - treated as a graceful exit (closed file / OSError);
    - reported as an ``"error"`` event and re-raised.

    ``_is_sess_active`` is cleared whenever the task exits.
    """

    def _extract_request_id(err_msg):
        # best effort: assumes the message starts with "<key>=<request id> ..." -- TODO confirm
        try:
            return err_msg.split(" ")[0].split("=")[1]
        except Exception:
            return None

    try:
        await self._is_sess_active.wait()

        # note: may need another signal here to block input task until bedrock is ready
        # TODO: save this as a field so we're not re-awaiting it every time
        _, output_stream = await self._stream_response.await_output()
        while self._is_sess_active.is_set():
            # and not self.stream_response.output_stream.closed:
            try:
                result = await output_stream.receive()
                if result.value and result.value.bytes_:
                    try:
                        response_data = result.value.bytes_.decode("utf-8")
                        json_data = json.loads(response_data)
                        await self._handle_event(json_data)
                    except json.JSONDecodeError:
                        logger.warning(f"JSON decode error: {response_data}")
                else:
                    logger.warning("No response received")
            except asyncio.CancelledError:
                logger.info("Response processing task cancelled")
                self._close_current_generation()
                raise
            except ValidationException as ve:
                # there is a 3min no-activity (e.g. silence) timeout on the stream, after which the stream is closed  # noqa: E501
                if (
                    "InternalErrorCode=531::RST_STREAM closed stream. HTTP/2 error code: NO_ERROR"  # noqa: E501
                    in ve.message
                ):
                    logger.warning(f"Validation error: {ve}\nAttempting to recover...")
                    await self._restart_session(ve)
                    # NOTE(review): unlike the retryable branches below there is no `break`
                    # here, so the loop keeps reading from the `output_stream` captured
                    # before the restart -- confirm against initialize_streams().
                else:
                    logger.error(f"Validation error: {ve}")
                    self.emit(
                        "error",
                        llm.RealtimeModelError(
                            timestamp=time.monotonic(),
                            label=self._realtime_model._label,
                            error=APIStatusError(
                                message=ve.message,
                                status_code=400,
                                # parse the message; the exception object itself has no .split()
                                request_id=_extract_request_id(ve.message),
                                body=ve,
                                retryable=False,
                            ),
                            recoverable=False,
                        ),
                    )
                    raise
            except (ThrottlingException, ModelNotReadyException, ModelErrorException) as retry_err:
                logger.warning(f"Retryable error: {retry_err}\nAttempting to recover...")
                await self._restart_session(retry_err)
                break
            except ModelTimeoutException as mte:
                logger.warning(f"Model timeout error: {mte}\nAttempting to recover...")
                await self._restart_session(mte)
                break
            except ValueError as val_err:
                # a bare ValueError() has no args; don't turn that into an IndexError
                if val_err.args and val_err.args[0] == "I/O operation on closed file.":
                    logger.info("initiating graceful shutdown of session")
                    break
                raise
            except OSError:
                logger.info("stream already closed, exiting")
                break
            except Exception as e:
                err_msg = getattr(e, "message", str(e))
                logger.error(f"Response processing error: {err_msg} (type: {type(e)})")
                self.emit(
                    "error",
                    llm.RealtimeModelError(
                        timestamp=time.monotonic(),
                        label=self._realtime_model._label,
                        error=APIStatusError(
                            # generic exceptions need not define `.message`; use the fallback
                            message=err_msg,
                            status_code=500,
                            request_id=_extract_request_id(err_msg),
                            body=e,
                            retryable=False,
                        ),
                        recoverable=False,
                    ),
                )
                raise

    finally:
        logger.info("main output response stream processing task exiting")
        self._is_sess_active.clear()
|
|
957
|
+
|
|
958
|
+
async def _restart_session(self, ex: Exception) -> None:
    """Re-establish the Bedrock stream after a recoverable failure, with capped backoff.

    Each call consumes one restart attempt. The delay before re-initializing is
    ``2**(attempt-1) - 1`` seconds, capped at DEFAULT_MAX_SESSION_RESTART_DELAY. Once
    DEFAULT_MAX_SESSION_RESTART_ATTEMPTS is reached, an unrecoverable ``"error"`` event is
    emitted, the session is marked inactive and no restart is made.
    """
    if self._session_restart_attempts < DEFAULT_MAX_SESSION_RESTART_ATTEMPTS:
        self._session_restart_attempts += 1
        self._is_sess_active.clear()
        backoff = 2 ** (self._session_restart_attempts - 1) - 1
        await asyncio.sleep(min(backoff, DEFAULT_MAX_SESSION_RESTART_DELAY))
        await self.initialize_streams(is_restart=True)
        logger.info(
            f"Session restarted successfully ({self._session_restart_attempts}/{DEFAULT_MAX_SESSION_RESTART_ATTEMPTS})"  # noqa: E501
        )
        return

    logger.error("Max session restart attempts reached, exiting")
    message = getattr(ex, "message", str(ex))
    try:
        req_id = message.split(" ")[0].split("=")[1]
    except Exception:
        req_id = None
    self.emit(
        "error",
        llm.RealtimeModelError(
            timestamp=time.monotonic(),
            label=self._realtime_model._label,
            error=APIStatusError(
                message=f"Max restart attempts exceeded: {message}",
                status_code=500,
                request_id=req_id,
                body=ex,
                retryable=False,
            ),
            recoverable=False,
        ),
    )
    self._is_sess_active.clear()
|
|
992
|
+
|
|
993
|
+
@property
def chat_ctx(self) -> llm.ChatContext:
    """A copy of the session's chat history; mutating it does not affect the session."""
    history = self._chat_ctx
    return history.copy()
|
|
996
|
+
|
|
997
|
+
@property
def tools(self) -> llm.ToolContext:
    """A copy of the currently registered tool context."""
    current = self._tools
    return current.copy()
|
|
1000
|
+
|
|
1001
|
+
async def update_instructions(self, instructions: str) -> None:
    """Injects the system prompt at the start of the session.

    The instructions are always stored on ``self._instructions``. ``_instructions_ready``
    is resolved on the first call only. Later calls previously raised
    ``asyncio.InvalidStateError`` from the second ``set_result``. Whether later
    instructions reach an already-running Bedrock session is not handled here.
    """
    self._instructions = instructions
    if not self._instructions_ready.done():
        self._instructions_ready.set_result(True)
    logger.debug(f"Instructions updated: {instructions}")
|
|
1006
|
+
|
|
1007
|
+
async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
    """Inject an initial chat history once during the very first session startup.

    Only the first call takes effect: it stores a copy of ``chat_ctx`` and resolves
    ``_chat_ctx_ready``. Subsequent calls are ignored, since the framework has been
    observed to fire this unexpectedly.
    """
    if self._chat_ctx_ready.done():
        return
    self._chat_ctx = chat_ctx.copy()
    logger.debug(f"Chat context updated: {self._chat_ctx.items}")
    self._chat_ctx_ready.set_result(True)
|
|
1016
|
+
|
|
1017
|
+
async def _send_tool_events(self, tool_use_id: str, tool_result: str) -> None:
    """Send tool_result back to Bedrock, grouped under tool_use_id.

    A fresh UUID is used as the content name. Every event produced by the event
    builder's tool content block is sent in order.
    """
    events = self._event_builder.create_tool_content_block(
        content_name=str(uuid.uuid4()),
        tool_use_id=tool_use_id,
        content=tool_result,
    )
    for ev in events:
        await self._send_raw_event(ev)
|
|
1028
|
+
|
|
1029
|
+
def _tool_choice_adapter(self, tool_choice: llm.ToolChoice) -> dict[str, dict[str, str]] | None:
    """Translate a LiveKit ToolChoice into Sonic's tool-choice JSON.

    - ``"auto"`` -> ``{"auto": {}}``
    - ``"required"`` -> ``{"any": {}}``
    - ``{"type": "function", "function": {"name": N}}`` -> ``{"tool": {"name": N}}``
    - anything else -> ``None``
    """
    if isinstance(tool_choice, dict):
        if tool_choice["type"] != "function":
            return None
        return {"tool": {"name": tool_choice["function"]["name"]}}
    if tool_choice == "auto":
        return {"auto": {}}
    if tool_choice == "required":
        return {"any": {}}
    return None
|
|
1039
|
+
|
|
1040
|
+
# note: return value from tool functions registered to Sonic must be Structured Output (a dict that is JSON serializable) # noqa: E501
|
|
1041
|
+
async def update_tools(self, tools: list[llm.FunctionTool | llm.RawFunctionTool | Any]) -> None:
    """Replace the active tool set with ``tools``.

    ``self._tools`` is rebuilt from the given list on every call. ``_tools_ready`` is
    resolved on the first non-empty update only; a second ``set_result`` would raise
    ``asyncio.InvalidStateError``. Nothing is sent to Bedrock from here.
    """
    logger.debug(f"Updating tools: {tools}")
    retained_tools: list[llm.FunctionTool | llm.RawFunctionTool] = list(tools)
    self._tools = llm.ToolContext(retained_tools)
    if retained_tools and not self._tools_ready.done():
        self._tools_ready.set_result(True)
        logger.debug("Tool list has been injected")
|
|
1052
|
+
|
|
1053
|
+
def update_options(self, *, tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN) -> None:
    """Accept and ignore an options update.

    Nova Sonic cannot change inference options mid-session, so ``tool_choice`` is
    ignored and a warning is logged.
    """
    unsupported_msg = "updating inference configuration options is not yet supported by Nova Sonic's Realtime API"  # noqa: E501
    logger.warning(unsupported_msg)
|
|
1058
|
+
|
|
1059
|
+
@utils.log_exceptions(logger=logger)
def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
    """Yield ``frame`` converted to Sonic's input sample rate and channel count.

    A resampler is cached on ``self._input_resampler`` and rebuilt when the incoming
    sample rate changes. Frames already at DEFAULT_INPUT_SAMPLE_RATE with
    DEFAULT_CHANNELS pass through unchanged.
    """
    current = self._input_resampler
    if current and frame.sample_rate != current._input_rate:
        # the source rate changed: discard the stale resampler so a fresh one is built
        # NOTE(review): the old resampler is not flushed, so anything it buffered is presumably lost
        self._input_resampler = current = None

    if current is None:
        needs_conversion = (
            frame.sample_rate != DEFAULT_INPUT_SAMPLE_RATE
            or frame.num_channels != DEFAULT_CHANNELS
        )
        if needs_conversion:
            current = rtc.AudioResampler(
                input_rate=frame.sample_rate,
                output_rate=DEFAULT_INPUT_SAMPLE_RATE,
                num_channels=DEFAULT_CHANNELS,
            )
            self._input_resampler = current

    if current:
        yield from current.push(frame)
    else:
        yield frame
|
|
1080
|
+
|
|
1081
|
+
@utils.log_exceptions(logger=logger)
async def _process_audio_input(self):
    """Background task that feeds audio and tool results into the Bedrock stream.

    Sends the audio content-start event once. Then, while ``_is_sess_active`` is set,
    it repeatedly:
    1. forwards at most one pending tool result from ``_tool_results_ch`` (non-blocking);
    2. waits for the next chunk on ``_audio_input_chan``, base64-encodes it and sends it
       as an audio input event.
    The loop exits when either channel is closed. On cancellation both channels are
    closed and the CancelledError is re-raised. Any other exception is logged and the
    loop continues.
    """
    await self._send_raw_event(self._event_builder.create_audio_content_start_event())
    logger.info("Starting audio input processing loop")
    while self._is_sess_active.is_set():
        try:
            # note: could potentially pull this out into a separate task
            # NOTE(review): tool results are only polled between audio chunks, so a queued
            # result is sent only after the next chunk arrives on _audio_input_chan
            try:
                val = self._tool_results_ch.recv_nowait()
                tool_result = val["tool_result"]
                tool_use_id = val["tool_use_id"]
                await self._send_tool_events(tool_use_id, tool_result)

            except utils.aio.channel.ChanEmpty:
                # no pending tool result this iteration
                pass
            except utils.aio.channel.ChanClosed:
                logger.warning(
                    "tool results channel closed, exiting audio input processing loop"
                )
                break

            try:
                # blocks until push_audio() queues the next chunk
                audio_bytes = await self._audio_input_chan.recv()
                blob = base64.b64encode(audio_bytes)
                audio_event = self._event_builder.create_audio_input_event(
                    audio_content=blob.decode("utf-8"),
                )

                await self._send_raw_event(audio_event)
            except utils.aio.channel.ChanEmpty:
                # presumably unreachable: the awaited recv() waits rather than raising ChanEmpty
                pass
            except utils.aio.channel.ChanClosed:
                logger.warning(
                    "audio input channel closed, exiting audio input processing loop"
                )
                break

        except asyncio.CancelledError:
            logger.info("Audio processing loop cancelled")
            # close both channels so producers (push_audio / tool handler) stop enqueueing
            self._audio_input_chan.close()
            self._tool_results_ch.close()
            raise
        except Exception:
            # log and keep looping; only channel closure, session deactivation or
            # cancellation end the loop
            logger.exception("Error processing audio")
|
|
1126
|
+
|
|
1127
|
+
# for debugging purposes only
|
|
1128
|
+
def _log_significant_audio(self, audio_bytes: bytes) -> None:
    """Debug helper: log when an outgoing audio chunk has non-trivial RMS energy.

    The RMS is computed over the values obtained by iterating ``audio_bytes`` and
    compared with a fixed threshold of 200. Nothing is computed unless
    ``lk_bedrock_debug`` is enabled. Empty chunks are ignored; they previously caused a
    ZeroDivisionError.
    """
    # push_audio() calls this for every chunk; skip the pure-Python RMS loop unless the
    # debug output that consumes it is actually enabled
    if not lk_bedrock_debug or len(audio_bytes) == 0:
        return
    squared_sum = sum(sample**2 for sample in audio_bytes)
    if (squared_sum / len(audio_bytes)) ** 0.5 > 200:
        log_message("Enqueuing significant audio chunk", AnsiColors.BLUE)
|
|
1134
|
+
|
|
1135
|
+
@utils.log_exceptions(logger=logger)
def push_audio(self, frame: rtc.AudioFrame) -> None:
    """Enqueue an incoming mic rtc.AudioFrame for transcription.

    The frame is resampled, re-chunked through ``self._bstream`` and each chunk is
    queued on ``_audio_input_chan``. Frames are dropped with a warning once that
    channel is closed.
    """
    if self._audio_input_chan.closed:
        logger.warning("audio input channel closed, skipping audio")
        return

    for resampled in self._resample_audio(frame):
        for chunk in self._bstream.write(resampled.data.tobytes()):
            self._log_significant_audio(chunk.data)
            self._audio_input_chan.send_nowait(chunk.data)
|
|
1148
|
+
|
|
1149
|
+
def generate_reply(
    self,
    *,
    instructions: NotGivenOr[str] = NOT_GIVEN,
) -> asyncio.Future[llm.GenerationCreatedEvent]:
    """Unprompted generation is not supported by Nova Sonic.

    Logs a warning and returns a Future that has already failed with
    ``llm.RealtimeError``. Callers that ``await`` the result get an error they can
    handle, instead of the TypeError that ``await None`` used to raise.
    """
    logger.warning("unprompted generation is not supported by Nova Sonic's Realtime API")
    fut: asyncio.Future[llm.GenerationCreatedEvent] = asyncio.Future()
    fut.set_exception(
        llm.RealtimeError("unprompted generation is not supported by Nova Sonic's Realtime API")
    )
    return fut
|
|
1155
|
+
|
|
1156
|
+
def commit_audio(self) -> None:
    """Unsupported by Nova Sonic; only logs a warning."""
    unsupported_msg = "commit_audio is not supported by Nova Sonic's Realtime API"
    logger.warning(unsupported_msg)
|
|
1158
|
+
|
|
1159
|
+
def clear_audio(self) -> None:
    """Unsupported by Nova Sonic; only logs a warning."""
    unsupported_msg = "clear_audio is not supported by Nova Sonic's Realtime API"
    logger.warning(unsupported_msg)
|
|
1161
|
+
|
|
1162
|
+
def push_video(self, frame: rtc.VideoFrame) -> None:
    """Video input is unsupported by Nova Sonic; the frame is dropped with a warning."""
    unsupported_msg = "video is not supported by Nova Sonic's Realtime API"
    logger.warning(unsupported_msg)
|
|
1164
|
+
|
|
1165
|
+
def interrupt(self) -> None:
    """Client-initiated interruption is unsupported by Nova Sonic; only logs a warning."""
    unsupported_msg = "interrupt is not supported by Nova Sonic's Realtime API"
    logger.warning(unsupported_msg)
|
|
1167
|
+
|
|
1168
|
+
def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
    """Message truncation is unsupported by Nova Sonic; arguments are ignored."""
    unsupported_msg = "truncate is not supported by Nova Sonic's Realtime API"
    logger.warning(unsupported_msg)
|
|
1170
|
+
|
|
1171
|
+
@utils.log_exceptions(logger=logger)
async def aclose(self) -> None:
    """Gracefully shut down the realtime session and release network resources.

    Returns immediately if the session is already inactive. Otherwise it:
    1. sends the prompt-end events;
    2. clears ``_is_sess_active`` so both background loops stop;
    3. closes the output stream;
    4. gives the response task up to 1s to finish, then cancels it;
    5. cancels the audio input task and closes the input stream;
    6. waits for both tasks.
    Each step runs even if an earlier one failed or a task died with an error.
    """
    logger.info("attempting to shutdown agent session")
    if not self._is_sess_active.is_set():
        logger.info("agent session already inactive")
        return

    try:
        for event in self._event_builder.create_prompt_end_block():
            await self._send_raw_event(event)
    except Exception:
        # best effort: a broken stream must not stop the rest of the teardown
        logger.warning("failed to send prompt end events during shutdown", exc_info=True)
    # allow event loops to fall out naturally
    # otherwise, the smithy layer will raise an InvalidStateError during cancellation
    self._is_sess_active.clear()

    if self._stream_response and not self._stream_response.output_stream.closed:
        await self._stream_response.output_stream.close()

    # note: even after the self.is_active flag is flipped and the output stream is closed,
    # there is a future inside output_stream.receive() at the AWS-CRT C layer that blocks
    # resulting in an error after cancellation
    # however, it's mostly cosmetic-- the event loop will still exit
    # TODO: fix this nit
    if self._response_task:
        # asyncio.wait (unlike wait_for) never re-raises the task's own exception, so a
        # response loop that died with an error cannot abort the remaining shutdown steps
        done, _ = await asyncio.wait({self._response_task}, timeout=1.0)
        if not done:
            logger.warning("shutdown of output event loop timed out-- cancelling")
            self._response_task.cancel()

    # must cancel the audio input task before closing the input stream
    if self._audio_input_task and not self._audio_input_task.done():
        self._audio_input_task.cancel()
    if self._stream_response and not self._stream_response.input_stream.closed:
        await self._stream_response.input_stream.close()

    # the guards above allow either task to be None; asyncio.gather() rejects None
    tasks = [t for t in (self._response_task, self._audio_input_task) if t is not None]
    await asyncio.gather(*tasks, return_exceptions=True)
    logger.debug(f"CHAT CONTEXT: {self._chat_ctx.items}")
    logger.info("Session end")
|