livekit-plugins-aws 1.1.3__py3-none-any.whl → 1.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of livekit-plugins-aws might be problematic.

@@ -0,0 +1,1208 @@
+ from __future__ import annotations
+ 
+ import asyncio
+ import base64
+ import json
+ import os
+ import time
+ import uuid
+ import weakref
+ from collections.abc import Iterator
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from typing import Any, Literal
+ 
+ import boto3
+ from aws_sdk_bedrock_runtime.client import (
+     BedrockRuntimeClient,
+     InvokeModelWithBidirectionalStreamOperationInput,
+ )
+ from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme
+ from aws_sdk_bedrock_runtime.models import (
+     BidirectionalInputPayloadPart,
+     InvokeModelWithBidirectionalStreamInputChunk,
+     ModelErrorException,
+     ModelNotReadyException,
+     ModelTimeoutException,
+     ThrottlingException,
+     ValidationException,
+ )
+ from smithy_aws_core.identity import AWSCredentialsIdentity
+ from smithy_core.aio.interfaces.identity import IdentityResolver
+ 
+ from livekit import rtc
+ from livekit.agents import (
+     APIStatusError,
+     ToolError,
+     llm,
+     utils,
+ )
+ from livekit.agents.llm.realtime import RealtimeSession
+ from livekit.agents.metrics import RealtimeModelMetrics
+ from livekit.agents.types import NOT_GIVEN, NotGivenOr
+ from livekit.agents.utils import is_given
+ from livekit.plugins.aws.experimental.realtime.turn_tracker import _TurnTracker
+ 
+ from ...log import logger
+ from .events import (
+     VOICE_ID,
+     SonicEventBuilder as seb,
+     Tool,
+     ToolConfiguration,
+     ToolInputSchema,
+     ToolSpec,
+ )
+ from .pretty_printer import AnsiColors, log_event_data, log_message
+ 
+ DEFAULT_INPUT_SAMPLE_RATE = 16000
+ DEFAULT_OUTPUT_SAMPLE_RATE = 24000
+ DEFAULT_SAMPLE_SIZE_BITS = 16
+ DEFAULT_CHANNELS = 1
+ DEFAULT_CHUNK_SIZE = 512
+ DEFAULT_TEMPERATURE = 0.7
+ DEFAULT_TOP_P = 0.9
+ DEFAULT_MAX_TOKENS = 1024
+ MAX_MESSAGE_SIZE = 1024
+ MAX_MESSAGES = 40
+ DEFAULT_MAX_SESSION_RESTART_ATTEMPTS = 3
+ DEFAULT_MAX_SESSION_RESTART_DELAY = 10
+ # note: adjacent string literals are concatenated, so each fragment ends with a
+ # trailing space to keep sentences from fusing together
+ DEFAULT_SYSTEM_PROMPT = (
+     "Your name is Sonic. You are a friendly and eagerly helpful assistant. "
+     "The user and you will engage in a spoken dialog exchanging the transcripts of a natural real-time conversation. "  # noqa: E501
+     "Keep your responses short and concise unless the user asks you to elaborate or you are explicitly asked to be verbose and chatty. "  # noqa: E501
+     "Do not repeat yourself. Do not ask the user to repeat themselves. "
+     "Do ask the user to confirm or clarify their response if you are not sure what they mean. "
+     "If after asking the user for clarification you still do not understand, be honest and tell them that you do not understand. "  # noqa: E501
+     "Do not make up information or make assumptions. If you do not know the answer, tell the user that you do not know the answer. "  # noqa: E501
+     "If the user makes a request of you that you cannot fulfill, tell them why you cannot fulfill it. "  # noqa: E501
+     "When making tool calls, inform the user that you are using a tool to generate the response. "
+     "Avoid formatted lists or numbering and keep your output as a spoken transcript to be acted out. "  # noqa: E501
+     "Be appropriately emotive when responding to the user. Use American English as the language for your responses."  # noqa: E501
+ )
+ 
+ lk_bedrock_debug = int(os.getenv("LK_BEDROCK_DEBUG", "0"))
+ 
+ 
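# --- Editor's example (not part of the package): what the audio constants
# above imply. A 512-sample chunk of 16-bit mono audio at 16 kHz covers 32 ms
# and occupies 1024 bytes before base64 encoding.
chunk_seconds = DEFAULT_CHUNK_SIZE / DEFAULT_INPUT_SAMPLE_RATE  # 512 / 16000 = 0.032
chunk_bytes = DEFAULT_CHUNK_SIZE * DEFAULT_SAMPLE_SIZE_BITS // 8  # 512 * 16 / 8 = 1024
assert (chunk_seconds, chunk_bytes) == (0.032, 1024)
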
+ @dataclass
+ class _RealtimeOptions:
+     """Configuration container for a Sonic realtime session.
+ 
+     Attributes:
+         voice (VOICE_ID): Voice identifier used for TTS output.
+         temperature (float): Sampling temperature; note Sonic's convention that 1.0 is the most deterministic.
+         top_p (float): Nucleus sampling parameter; note Sonic's convention that 1.0 is the most deterministic.
+         max_tokens (int): Maximum number of tokens the model may generate in a single response.
+         tool_choice (llm.ToolChoice | None): Strategy that dictates how the model should invoke tools.
+         region (str): AWS region hosting the Bedrock Sonic model endpoint.
+     """  # noqa: E501
+ 
+     voice: VOICE_ID
+     temperature: float
+     top_p: float
+     max_tokens: int
+     tool_choice: llm.ToolChoice | None
+     region: str
+ 
+ 
+ @dataclass
+ class _MessageGeneration:
+     """Grouping of streams that together represent one assistant message.
+ 
+     Attributes:
+         message_id (str): Unique identifier that ties together text and audio for a single assistant turn.
+         text_ch (utils.aio.Chan[str]): Channel that yields partial text tokens as they arrive.
+         audio_ch (utils.aio.Chan[rtc.AudioFrame]): Channel that yields audio frames for the same assistant turn.
+     """  # noqa: E501
+ 
+     message_id: str
+     text_ch: utils.aio.Chan[str]
+     audio_ch: utils.aio.Chan[rtc.AudioFrame]
+ 
+ 
+ @dataclass
+ class _ResponseGeneration:
+     """Book-keeping dataclass tracking the lifecycle of a Sonic turn.
+ 
+     This object is created whenever we receive a *completion_start* event from the model
+     and is disposed of once the assistant turn finishes (e.g. *END_TURN*).
+ 
+     Attributes:
+         message_ch (utils.aio.Chan[llm.MessageGeneration]): Multiplexed stream for all assistant messages.
+         function_ch (utils.aio.Chan[llm.FunctionCall]): Stream that emits function tool calls.
+         input_id (str): Synthetic message id for the user input of the current turn.
+         response_id (str): Synthetic message id for the assistant reply of the current turn.
+         messages (dict[str, _MessageGeneration]): Map of message_id -> per-message stream containers.
+         user_messages (dict[str, str]): Map Bedrock content_id -> input_id.
+         speculative_messages (dict[str, str]): Map Bedrock content_id -> response_id (assistant side).
+         tool_messages (dict[str, str]): Map Bedrock content_id -> response_id for tool calls.
+         output_text (str): Accumulated assistant text (only used for metrics / debugging).
+         _created_timestamp (str): ISO-8601 timestamp when the generation record was created.
+         _first_token_timestamp (float | None): Wall-clock time of first token emission.
+         _completed_timestamp (float | None): Wall-clock time when the turn fully completed.
+     """  # noqa: E501
+ 
+     message_ch: utils.aio.Chan[llm.MessageGeneration]
+     function_ch: utils.aio.Chan[llm.FunctionCall]
+     input_id: str  # corresponds to user's portion of the turn
+     response_id: str  # corresponds to agent's portion of the turn
+     messages: dict[str, _MessageGeneration] = field(default_factory=dict)
+     user_messages: dict[str, str] = field(default_factory=dict)
+     speculative_messages: dict[str, str] = field(default_factory=dict)
+     tool_messages: dict[str, str] = field(default_factory=dict)
+     output_text: str = ""  # agent ASR text
+     # default_factory must be a zero-arg callable; the lambda evaluates the
+     # timestamp per instance instead of once at class-definition time
+     _created_timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
+     _first_token_timestamp: float | None = None
+     _completed_timestamp: float | None = None
+ 
+ 
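# --- Editor's example (not part of the package): why the _created_timestamp
# fix above wraps the call in a lambda. dataclasses.field(default_factory=...)
# requires a callable invoked per instance; passing datetime.now().isoformat()
# directly bakes in one string at class-definition time and then fails when the
# dataclass machinery tries to call it.
from dataclasses import dataclass, field
from datetime import datetime

@dataclass
class _Stamped:
    created: str = field(default_factory=lambda: datetime.now().isoformat())

assert isinstance(_Stamped().created, str)
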
+ class Boto3CredentialsResolver(IdentityResolver):
+     """IdentityResolver implementation that sources AWS credentials from boto3.
+ 
+     The resolver delegates to the default boto3.Session() credential chain, which
+     checks environment variables, shared credentials files, EC2 instance profiles, etc.
+     The credentials are then wrapped in an AWSCredentialsIdentity so they can be
+     passed into Bedrock runtime clients.
+     """
+ 
+     def __init__(self):
+         self.session = boto3.Session()
+ 
+     async def get_identity(self, **kwargs):
+         """Asynchronously resolve AWS credentials.
+ 
+         This method is invoked by the Bedrock runtime client whenever a new request needs to be
+         signed. It converts the static or temporary credentials returned by boto3
+         into an AWSCredentialsIdentity instance.
+ 
+         Returns:
+             AWSCredentialsIdentity: Identity containing the
+                 AWS access key, secret key and optional session token.
+ 
+         Raises:
+             ValueError: If no credentials could be found by boto3.
+         """
+         try:
+             logger.debug("Attempting to load AWS credentials")
+             credentials = self.session.get_credentials()
+             if not credentials:
+                 logger.error("Unable to load AWS credentials")
+                 raise ValueError("Unable to load AWS credentials")
+ 
+             creds = credentials.get_frozen_credentials()
+             logger.debug(
+                 f"AWS credentials loaded successfully. AWS_ACCESS_KEY_ID: {creds.access_key[:4]}***"
+             )
+ 
+             identity = AWSCredentialsIdentity(
+                 access_key_id=creds.access_key,
+                 secret_access_key=creds.secret_key,
+                 session_token=creds.token if creds.token else None,
+                 expiration=None,
+             )
+             return identity
+         except Exception as e:
+             logger.error(f"Failed to load AWS credentials: {str(e)}")
+             raise ValueError(f"Failed to load AWS credentials: {str(e)}") from e
+ 
+ 
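# --- Editor's sketch (not part of the package): exercising the resolver above
# against the default boto3 credential chain; it requires credentials to be
# discoverable in the environment, and prints the same redacted key prefix the
# resolver logs.
import asyncio

async def _check_credentials() -> None:
    identity = await Boto3CredentialsResolver().get_identity()
    print(f"resolved key: {identity.access_key_id[:4]}***")

asyncio.run(_check_credentials())
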
+ class RealtimeModel(llm.RealtimeModel):
+     """High-level entry point that conforms to the LiveKit RealtimeModel interface.
+ 
+     The object is very lightweight: it mainly stores default inference options and
+     spawns a RealtimeSession when session() is invoked.
+     """
+ 
+     def __init__(
+         self,
+         *,
+         voice: NotGivenOr[VOICE_ID] = NOT_GIVEN,
+         temperature: NotGivenOr[float] = NOT_GIVEN,
+         top_p: NotGivenOr[float] = NOT_GIVEN,
+         max_tokens: NotGivenOr[int] = NOT_GIVEN,
+         tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
+         region: NotGivenOr[str] = NOT_GIVEN,
+     ):
+         """Instantiate a new RealtimeModel.
+ 
+         Args:
+             voice (VOICE_ID | NotGiven): Preferred voice id for Sonic TTS output. Falls back to "tiffany".
+             temperature (float | NotGiven): Sampling temperature (0-1). Defaults to DEFAULT_TEMPERATURE.
+             top_p (float | NotGiven): Nucleus sampling probability mass. Defaults to DEFAULT_TOP_P.
+             max_tokens (int | NotGiven): Upper bound for tokens emitted by the model. Defaults to DEFAULT_MAX_TOKENS.
+             tool_choice (llm.ToolChoice | None | NotGiven): Strategy for tool invocation ("auto", "required", or explicit function).
+             region (str | NotGiven): AWS region of the Bedrock runtime endpoint.
+         """  # noqa: E501
+         super().__init__(
+             capabilities=llm.RealtimeCapabilities(
+                 message_truncation=False,
+                 turn_detection=True,
+                 user_transcription=True,
+                 auto_tool_reply_generation=True,
+             )
+         )
+         self.model_id = "amazon.nova-sonic-v1:0"
+         # note: temperature and top_p do not follow industry standards and are defined slightly differently for Sonic  # noqa: E501
+         # temperature ranges from 0.0 to 1.0, where 0.0 is the most random and 1.0 is the most deterministic  # noqa: E501
+         # top_p ranges from 0.0 to 1.0, where 0.0 is the most random and 1.0 is the most deterministic  # noqa: E501
+         self.temperature = temperature
+         self.top_p = top_p
+         self._opts = _RealtimeOptions(
+             voice=voice if is_given(voice) else "tiffany",
+             temperature=temperature if is_given(temperature) else DEFAULT_TEMPERATURE,
+             top_p=top_p if is_given(top_p) else DEFAULT_TOP_P,
+             max_tokens=max_tokens if is_given(max_tokens) else DEFAULT_MAX_TOKENS,
+             tool_choice=tool_choice or None,
+             region=region if is_given(region) else "us-east-1",
+         )
+         self._sessions = weakref.WeakSet[RealtimeSession]()
+ 
+     def session(self) -> RealtimeSession:
+         """Return a new RealtimeSession bound to this model instance."""
+         sess = RealtimeSession(self)
+ 
+         # note: this is a hack to get the session to initialize itself; the task is
+         # stored on the session so it is not garbage-collected mid-flight
+         # TODO: change how RealtimeSession is initialized by creating a single task main_atask that spawns subtasks  # noqa: E501
+         sess._init_task = asyncio.create_task(sess.initialize_streams())
+         self._sessions.add(sess)
+         return sess
+ 
+     # stub b/c RealtimeSession.aclose() is invoked directly
+     async def aclose(self) -> None:
+         pass
+ 
+ 
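# --- Editor's sketch (not part of the package): constructing the model with
# explicit options. session() must be called from a running event loop because
# it spawns the stream-initialization task in the background.
model = RealtimeModel(
    voice="tiffany",
    temperature=0.7,  # Sonic convention: 1.0 is most deterministic
    top_p=0.9,
    max_tokens=1024,
    region="us-east-1",
)
sess = model.session()
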
+ class RealtimeSession(  # noqa: F811
+     llm.RealtimeSession[Literal["bedrock_server_event_received", "bedrock_client_event_queued"]]
+ ):
+     """Bidirectional streaming session against the Nova Sonic Bedrock runtime.
+ 
+     The session owns two asynchronous tasks:
+ 
+     1. _process_audio_input – pushes user mic audio and tool results to Bedrock.
+     2. _process_responses – receives server events from Bedrock and converts them into
+        LiveKit abstractions such as llm.MessageGeneration.
+ 
+     A set of helper handlers (_handle_*) transform the low-level Bedrock
+     JSON payloads into higher-level application events and keep
+     _ResponseGeneration state in sync.
+     """
+ 
+     def __init__(self, realtime_model: RealtimeModel) -> None:
+         """Create and wire up a new realtime session.
+ 
+         Args:
+             realtime_model (RealtimeModel): Parent model instance that stores static
+                 inference options and the Smithy Bedrock client configuration.
+         """
+         super().__init__(realtime_model)
+         self._realtime_model = realtime_model
+         self._event_builder = seb(
+             prompt_name=str(uuid.uuid4()),
+             audio_content_name=str(uuid.uuid4()),
+         )
+         self._input_resampler: rtc.AudioResampler | None = None
+         self._bstream = utils.audio.AudioByteStream(
+             DEFAULT_INPUT_SAMPLE_RATE, DEFAULT_CHANNELS, samples_per_channel=DEFAULT_CHUNK_SIZE
+         )
+ 
+         self._response_task = None
+         self._audio_input_task = None
+         self._stream_response = None
+         self._bedrock_client = None
+         self._is_sess_active = asyncio.Event()
+         self._chat_ctx = llm.ChatContext.empty()
+         self._tools = llm.ToolContext.empty()
+         self._tool_type_map = {}
+         self._tool_results_ch = utils.aio.Chan[dict[str, str]]()
+         self._tools_ready = asyncio.get_running_loop().create_future()
+         self._instructions_ready = asyncio.get_running_loop().create_future()
+         self._chat_ctx_ready = asyncio.get_running_loop().create_future()
+         self._instructions = DEFAULT_SYSTEM_PROMPT
+         self._audio_input_chan = utils.aio.Chan[bytes]()
+         self._current_generation: _ResponseGeneration | None = None
+ 
+         # note: currently tracks session restart attempts across all sessions
+         # TODO: track restart attempts per turn
+         self._session_restart_attempts = 0
+ 
+         self._event_handlers = {
+             "completion_start": self._handle_completion_start_event,
+             "audio_output_content_start": self._handle_audio_output_content_start_event,
+             "audio_output_content": self._handle_audio_output_content_event,
+             "audio_output_content_end": self._handle_audio_output_content_end_event,
+             "text_output_content_start": self._handle_text_output_content_start_event,
+             "text_output_content": self._handle_text_output_content_event,
+             "text_output_content_end": self._handle_text_output_content_end_event,
+             "tool_output_content_start": self._handle_tool_output_content_start_event,
+             "tool_output_content": self._handle_tool_output_content_event,
+             "tool_output_content_end": self._handle_tool_output_content_end_event,
+             "completion_end": self._handle_completion_end_event,
+             "usage": self._handle_usage_event,
+             "other_event": self._handle_other_event,
+         }
+         self._turn_tracker = _TurnTracker(
+             self.emit, streams_provider=self._current_generation_streams
+         )
+ 
+     def _current_generation_streams(
+         self,
+     ) -> tuple[utils.aio.Chan[llm.MessageGeneration], utils.aio.Chan[llm.FunctionCall]]:
+         return (self._current_generation.message_ch, self._current_generation.function_ch)
+ 
+     @utils.log_exceptions(logger=logger)
+     def _initialize_client(self):
+         """Instantiate the Bedrock runtime client."""
+         config = Config(
+             endpoint_uri=f"https://bedrock-runtime.{self._realtime_model._opts.region}.amazonaws.com",
+             region=self._realtime_model._opts.region,
+             aws_credentials_identity_resolver=Boto3CredentialsResolver(),
+             http_auth_scheme_resolver=HTTPAuthSchemeResolver(),
+             http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()},
+         )
+         self._bedrock_client = BedrockRuntimeClient(config=config)
+ 
+     @utils.log_exceptions(logger=logger)
+     async def _send_raw_event(self, event_json):
+         """Wrap event_json in a payload chunk and forward it to the bidirectional stream.
+ 
+         Args:
+             event_json (str): The JSON payload (already in Bedrock wire format) to queue.
+ 
+         Raises:
+             Exception: Propagates any failures returned by the Bedrock runtime client.
+         """
+         if not self._stream_response:
+             logger.warning("stream not initialized; dropping event (this should never occur)")
+             return
+ 
+         event = InvokeModelWithBidirectionalStreamInputChunk(
+             value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8"))
+         )
+ 
+         try:
+             await self._stream_response.input_stream.send(event)
+         except Exception as e:
+             logger.exception("Error sending event")
+             err_msg = getattr(e, "message", str(e))
+             request_id = None
+             try:
+                 request_id = err_msg.split(" ")[0].split("=")[1]
+             except Exception:
+                 pass
+ 
+             self.emit(
+                 "error",
+                 llm.RealtimeModelError(
+                     timestamp=time.monotonic(),
+                     label=self._realtime_model._label,
+                     error=APIStatusError(
+                         message=err_msg,
+                         status_code=500,
+                         request_id=request_id,
+                         body=e,
+                         retryable=False,
+                     ),
+                     recoverable=False,
+                 ),
+             )
+             raise
+ 
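# --- Editor's sketch (not part of the package): the wire format expected by
# _send_raw_event above. Events are JSON strings wrapped into payload-part
# bytes; the sessionStart shape below is illustrative only, not the full
# Sonic event schema.
import json

event_json = json.dumps(
    {"event": {"sessionStart": {"inferenceConfiguration": {"maxTokens": 1024}}}}
)
chunk = InvokeModelWithBidirectionalStreamInputChunk(
    value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8"))
)
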
+     def _serialize_tool_config(self) -> ToolConfiguration | None:
+         """Convert self.tools into the JSON structure expected by Sonic.
+ 
+         If any tools are registered, the method also harmonises temperature and
+         top_p defaults to Sonic's recommended greedy values (1.0).
+ 
+         Returns:
+             ToolConfiguration | None: None when no tools are present, otherwise a complete config block.
+         """  # noqa: E501
+         tool_cfg = None
+         if self.tools.function_tools:
+             tools = []
+             for name, f in self.tools.function_tools.items():
+                 if llm.tool_context.is_function_tool(f):
+                     description = llm.tool_context.get_function_info(f).description
+                     input_schema = llm.utils.build_legacy_openai_schema(f, internally_tagged=True)[
+                         "parameters"
+                     ]
+                     self._tool_type_map[name] = "FunctionTool"
+                 else:
+                     description = llm.tool_context.get_raw_function_info(f).raw_schema.get(
+                         "description"
+                     )
+                     input_schema = llm.tool_context.get_raw_function_info(f).raw_schema[
+                         "parameters"
+                     ]
+                     self._tool_type_map[name] = "RawFunctionTool"
+ 
+                 tool = Tool(
+                     toolSpec=ToolSpec(
+                         name=name,
+                         description=description,
+                         inputSchema=ToolInputSchema(json_=json.dumps(input_schema)),
+                     )
+                 )
+                 tools.append(tool)
+             tool_choice = self._tool_choice_adapter(self._realtime_model._opts.tool_choice)
+             logger.debug(f"TOOL CHOICE: {tool_choice}")
+             tool_cfg = ToolConfiguration(tools=tools, toolChoice=tool_choice)
+ 
+             # recommended to set greedy inference configs for tool calls
+             if not is_given(self._realtime_model.top_p):
+                 self._realtime_model._opts.top_p = 1.0
+             if not is_given(self._realtime_model.temperature):
+                 self._realtime_model._opts.temperature = 1.0
+         return tool_cfg
+ 
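# --- Editor's sketch (not part of the package): the tool configuration the
# method above would produce for a hypothetical "get_weather" function tool,
# using the Tool/ToolSpec/ToolInputSchema dataclasses imported at the top.
schema = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
tool_cfg = ToolConfiguration(
    tools=[
        Tool(
            toolSpec=ToolSpec(
                name="get_weather",
                description="Look up the current weather for a city.",
                inputSchema=ToolInputSchema(json_=json.dumps(schema)),
            )
        )
    ],
    toolChoice={"auto": {}},
)
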
+     @utils.log_exceptions(logger=logger)
+     async def initialize_streams(self, is_restart: bool = False):
+         """Open the Bedrock bidirectional stream and spawn background worker tasks.
+ 
+         This coroutine is idempotent and can be invoked again when recoverable
+         errors (e.g. timeout, throttling) require a fresh session.
+ 
+         Args:
+             is_restart (bool, optional): Marks whether we are re-initialising an
+                 existing session after an error. Defaults to False.
+         """
+         try:
+             if not self._bedrock_client:
+                 logger.info("Creating Bedrock client")
+                 self._initialize_client()
+ 
+             logger.info("Initializing Bedrock stream")
+             self._stream_response = (
+                 await self._bedrock_client.invoke_model_with_bidirectional_stream(
+                     InvokeModelWithBidirectionalStreamOperationInput(
+                         model_id=self._realtime_model.model_id
+                     )
+                 )
+             )
+ 
+             if not is_restart:
+                 pending_events: list[asyncio.Future] = []
+                 if not self.tools.function_tools:
+                     pending_events.append(self._tools_ready)
+                 if not self._instructions_ready.done():
+                     pending_events.append(self._instructions_ready)
+                 if not self._chat_ctx_ready.done():
+                     pending_events.append(self._chat_ctx_ready)
+ 
+                 # note: during session init we can't tell whether no tools were registered
+                 # or whether registered tools simply haven't been propagated yet;
+                 # when there are no tools we therefore wait out the entire timeout
+                 try:
+                     if pending_events:
+                         await asyncio.wait_for(asyncio.gather(*pending_events), timeout=0.5)
+                 except asyncio.TimeoutError:
+                     if not self._tools_ready.done():
+                         logger.warning("Tools not ready after 500ms, continuing without them")
+ 
+                     if not self._instructions_ready.done():
+                         logger.warning(
+                             "Instructions not received after 500ms, proceeding with default instructions"  # noqa: E501
+                         )
+                     if not self._chat_ctx_ready.done():
+                         logger.warning(
+                             "Chat context not received after 500ms, proceeding with empty chat context"  # noqa: E501
+                         )
+ 
+             logger.info(
+                 f"Initializing Bedrock session with realtime options: {self._realtime_model._opts}"
+             )
+             # there is a 40-message limit on the chat context
+             if len(self._chat_ctx.items) > MAX_MESSAGES:
+                 logger.warning(
+                     f"Chat context has {len(self._chat_ctx.items)} messages, truncating to {MAX_MESSAGES}"  # noqa: E501
+                 )
+                 self._chat_ctx.truncate(max_items=MAX_MESSAGES)
+             init_events = self._event_builder.create_prompt_start_block(
+                 voice_id=self._realtime_model._opts.voice,
+                 sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
+                 system_content=self._instructions,
+                 chat_ctx=self.chat_ctx,
+                 tool_configuration=self._serialize_tool_config(),
+                 max_tokens=self._realtime_model._opts.max_tokens,
+                 top_p=self._realtime_model._opts.top_p,
+                 temperature=self._realtime_model._opts.temperature,
+             )
+ 
+             for event in init_events:
+                 await self._send_raw_event(event)
+                 logger.debug(f"Sent event: {event}")
+ 
+             if not is_restart:
+                 self._audio_input_task = asyncio.create_task(
+                     self._process_audio_input(), name="RealtimeSession._process_audio_input"
+                 )
+ 
+             self._response_task = asyncio.create_task(
+                 self._process_responses(), name="RealtimeSession._process_responses"
+             )
+             self._is_sess_active.set()
+             logger.debug("Stream initialized successfully")
+         except Exception as e:
+             # note: asyncio.Event has no set_exception(); leave _is_sess_active
+             # cleared and surface the failure to the caller instead
+             logger.error(f"Failed to initialize stream: {str(e)}")
+             raise
+         return self
+ 
+     @utils.log_exceptions(logger=logger)
+     def _emit_generation_event(self) -> None:
+         """Publish a llm.GenerationCreatedEvent to external subscribers."""
+         logger.debug("Emitting generation event")
+         generation_ev = llm.GenerationCreatedEvent(
+             message_stream=self._current_generation.message_ch,
+             function_stream=self._current_generation.function_ch,
+             user_initiated=False,
+         )
+         self.emit("generation_created", generation_ev)
+ 
+     @utils.log_exceptions(logger=logger)
+     async def _handle_event(self, event_data: dict) -> None:
+         """Dispatch a raw Bedrock event to the corresponding _handle_* method."""
+         event_type = self._event_builder.get_event_type(event_data)
+         event_handler = self._event_handlers.get(event_type)
+         if event_handler:
+             await event_handler(event_data)
+             self._turn_tracker.feed(event_data)
+         else:
+             logger.warning(f"No event handler found for event type: {event_type}")
+ 
+     async def _handle_completion_start_event(self, event_data: dict) -> None:
+         log_event_data(event_data)
+         self._create_response_generation()
+ 
+     def _create_response_generation(self) -> None:
+         """Instantiate _ResponseGeneration and emit the GenerationCreated event."""
+         if self._current_generation is None:
+             self._current_generation = _ResponseGeneration(
+                 message_ch=utils.aio.Chan(),
+                 function_ch=utils.aio.Chan(),
+                 input_id=str(uuid.uuid4()),
+                 response_id=str(uuid.uuid4()),
+                 messages={},
+                 user_messages={},
+                 speculative_messages={},
+                 _created_timestamp=datetime.now().isoformat(),
+             )
+             msg_gen = _MessageGeneration(
+                 message_id=self._current_generation.response_id,
+                 text_ch=utils.aio.Chan(),
+                 audio_ch=utils.aio.Chan(),
+             )
+             self._current_generation.message_ch.send_nowait(
+                 llm.MessageGeneration(
+                     message_id=msg_gen.message_id,
+                     text_stream=msg_gen.text_ch,
+                     audio_stream=msg_gen.audio_ch,
+                 )
+             )
+             self._current_generation.messages[self._current_generation.response_id] = msg_gen
+ 
+     # will be completely ignoring post-ASR text events
+     async def _handle_text_output_content_start_event(self, event_data: dict) -> None:
+         """Handle text_output_content_start for both user and assistant roles."""
+         log_event_data(event_data)
+         role = event_data["event"]["contentStart"]["role"]
+ 
+         # note: does not work if you emit llm.GCE too early (for some reason)
+         if role == "USER":
+             self._create_response_generation()
+             content_id = event_data["event"]["contentStart"]["contentId"]
+             self._current_generation.user_messages[content_id] = self._current_generation.input_id
+ 
+         elif (
+             role == "ASSISTANT"
+             and "SPECULATIVE" in event_data["event"]["contentStart"]["additionalModelFields"]
+         ):
+             text_content_id = event_data["event"]["contentStart"]["contentId"]
+             self._current_generation.speculative_messages[text_content_id] = (
+                 self._current_generation.response_id
+             )
+ 
+     async def _handle_text_output_content_event(self, event_data: dict) -> None:
+         """Stream partial text tokens into the current _MessageGeneration."""
+         log_event_data(event_data)
+         text_content_id = event_data["event"]["textOutput"]["contentId"]
+         text_content = f"{event_data['event']['textOutput']['content']}\n"
+ 
+         # currently only the agent can be interrupted
+         if text_content == '{ "interrupted" : true }\n':
+             # the interrupted flag is not being set correctly in chat_ctx
+             # this is b/c audio playback is desynced from text transcription
+             # TODO: fix this; possibly via a playback timer
+             idx = self._chat_ctx.find_insertion_index(created_at=time.time()) - 1
+             logger.debug(
+                 f"BARGE-IN DETECTED using idx: {idx} and chat_msg: {self._chat_ctx.items[idx]}"
+             )
+             self._chat_ctx.items[idx].interrupted = True
+             self._close_current_generation()
+             return
+ 
+         # ignore events until turn starts
+         if self._current_generation is not None:
+             # TODO: rename event to llm.InputTranscriptionUpdated
+             if (
+                 self._current_generation.user_messages.get(text_content_id)
+                 == self._current_generation.input_id
+             ):
+                 logger.debug(f"INPUT TRANSCRIPTION UPDATED: {text_content}")
+                 # note: user ASR text is slightly different than what is sent to LiveKit (newline vs whitespace)  # noqa: E501
+                 # TODO: fix this
+                 self._update_chat_ctx(role="user", text_content=text_content)
+ 
+             elif (
+                 self._current_generation.speculative_messages.get(text_content_id)
+                 == self._current_generation.response_id
+             ):
+                 curr_gen = self._current_generation.messages[self._current_generation.response_id]
+                 curr_gen.text_ch.send_nowait(text_content)
+                 # note: this update is per utterance, not per turn
+                 self._update_chat_ctx(role="assistant", text_content=text_content)
+ 
+     def _update_chat_ctx(self, role: str, text_content: str) -> None:
+         """
+         Update the chat context with the latest ASR text while guarding against model limitations:
+             a) 40 total messages limit
+             b) 1kB message size limit
+         """
+         prev_utterance = self._chat_ctx.items[-1] if self._chat_ctx.items else None
+         if (
+             prev_utterance is not None
+             and prev_utterance.role == role
+             and len(prev_utterance.content[0].encode("utf-8")) + len(text_content.encode("utf-8"))
+             < MAX_MESSAGE_SIZE
+         ):
+             prev_utterance.content[0] = "\n".join([prev_utterance.content[0], text_content])
+         else:
+             self._chat_ctx.add_message(role=role, content=text_content)
+             if len(self._chat_ctx.items) > MAX_MESSAGES:
+                 self._chat_ctx.truncate(max_items=MAX_MESSAGES)
+ 
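# --- Editor's example (not part of the package): the 1 kB merge guard above.
# Consecutive same-role utterances are merged until the byte budget is hit,
# after which a new chat message is started and the 40-message cap re-enforced.
existing, incoming = "a" * 1000, "b" * 100
fits = len(existing.encode("utf-8")) + len(incoming.encode("utf-8")) < 1024
assert not fits  # 1100 bytes: start a new message rather than merging
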
+     # cannot rely on this event for user b/c stopReason=PARTIAL_TURN always for user
+     async def _handle_text_output_content_end_event(self, event_data: dict) -> None:
+         """Mark the assistant message closed when Bedrock signals END_TURN."""
+         stop_reason = event_data["event"]["contentEnd"]["stopReason"]
+         text_content_id = event_data["event"]["contentEnd"]["contentId"]
+         if (
+             # a None generation means the first utterance in the turn was an interrupt
+             self._current_generation is not None
+             and self._current_generation.speculative_messages.get(text_content_id)
+             == self._current_generation.response_id
+             and stop_reason == "END_TURN"
+         ):
+             log_event_data(event_data)
+             self._close_current_generation()
+ 
+     async def _handle_tool_output_content_start_event(self, event_data: dict) -> None:
+         """Track mapping content_id -> response_id for upcoming tool use."""
+         log_event_data(event_data)
+         tool_use_content_id = event_data["event"]["contentStart"]["contentId"]
+         self._current_generation.tool_messages[tool_use_content_id] = (
+             self._current_generation.response_id
+         )
+ 
+     # note: tool calls are synchronous for now
+     async def _handle_tool_output_content_event(self, event_data: dict) -> None:
+         """Execute the referenced tool locally and forward results back to Bedrock."""
+         log_event_data(event_data)
+         tool_use_content_id = event_data["event"]["toolUse"]["contentId"]
+         tool_use_id = event_data["event"]["toolUse"]["toolUseId"]
+         tool_name = event_data["event"]["toolUse"]["toolName"]
+         if (
+             self._current_generation.tool_messages.get(tool_use_content_id)
+             == self._current_generation.response_id
+         ):
+             args = event_data["event"]["toolUse"]["content"]
+             self._current_generation.function_ch.send_nowait(
+                 llm.FunctionCall(
+                     call_id=tool_use_id,
+                     name=tool_name,
+                     arguments=args,
+                 )
+             )
+ 
+             # note: may need to inject RunContext here...
+             tool_type = self._tool_type_map[tool_name]
+             if tool_type == "FunctionTool":
+                 tool_result = await self.tools.function_tools[tool_name](**json.loads(args))
+             elif tool_type == "RawFunctionTool":
+                 tool_result = await self.tools.function_tools[tool_name](json.loads(args))
+             else:
+                 raise ValueError(f"Unknown tool type: {tool_type}")
+             logger.debug(f"TOOL ARGS: {args}\nTOOL RESULT: {tool_result}")
+ 
+             # Sonic only accepts Structured Output for tool results
+             # therefore, must JSON stringify ToolError
+             if isinstance(tool_result, ToolError):
+                 logger.warning(f"TOOL ERROR: {tool_name} {tool_result.message}")
+                 tool_result = {"error": tool_result.message}
+             self._tool_results_ch.send_nowait(
+                 {
+                     "tool_use_id": tool_use_id,
+                     "tool_result": tool_result,
+                 }
+             )
+ 
+     async def _handle_tool_output_content_end_event(self, event_data: dict) -> None:
+         log_event_data(event_data)
+ 
+     async def _handle_audio_output_content_start_event(self, event_data: dict) -> None:
+         """Associate the upcoming audio chunk with the active assistant message."""
+         if self._current_generation is not None:
+             log_event_data(event_data)
+             audio_content_id = event_data["event"]["contentStart"]["contentId"]
+             self._current_generation.speculative_messages[audio_content_id] = (
+                 self._current_generation.response_id
+             )
+ 
+     async def _handle_audio_output_content_event(self, event_data: dict) -> None:
+         """Decode base64 audio from Bedrock and forward it to the audio stream."""
+         if (
+             self._current_generation is not None
+             and self._current_generation.speculative_messages.get(
+                 event_data["event"]["audioOutput"]["contentId"]
+             )
+             == self._current_generation.response_id
+         ):
+             audio_content = event_data["event"]["audioOutput"]["content"]
+             audio_bytes = base64.b64decode(audio_content)
+             curr_gen = self._current_generation.messages[self._current_generation.response_id]
+             curr_gen.audio_ch.send_nowait(
+                 rtc.AudioFrame(
+                     data=audio_bytes,
+                     sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
+                     num_channels=DEFAULT_CHANNELS,
+                     samples_per_channel=len(audio_bytes) // 2,
+                 )
+             )
+ 
+     async def _handle_audio_output_content_end_event(self, event_data: dict) -> None:
+         """Close the assistant message streams once Bedrock finishes audio for the turn."""
+         if (
+             self._current_generation is not None
+             and event_data["event"]["contentEnd"]["stopReason"] == "END_TURN"
+             and self._current_generation.speculative_messages.get(
+                 event_data["event"]["contentEnd"]["contentId"]
+             )
+             == self._current_generation.response_id
+         ):
+             log_event_data(event_data)
+             self._close_current_generation()
+ 
+     def _close_current_generation(self) -> None:
+         """Helper that closes all channels of the active _ResponseGeneration."""
+         if self._current_generation is not None:
+             if self._current_generation.response_id in self._current_generation.messages:
+                 curr_gen = self._current_generation.messages[self._current_generation.response_id]
+                 if not curr_gen.audio_ch.closed:
+                     curr_gen.audio_ch.close()
+                 if not curr_gen.text_ch.closed:
+                     curr_gen.text_ch.close()
+ 
+             if not self._current_generation.message_ch.closed:
+                 self._current_generation.message_ch.close()
+             if not self._current_generation.function_ch.closed:
+                 self._current_generation.function_ch.close()
+ 
+             self._current_generation = None
+ 
+     async def _handle_completion_end_event(self, event_data: dict) -> None:
+         log_event_data(event_data)
+ 
+     async def _handle_other_event(self, event_data: dict) -> None:
+         log_event_data(event_data)
+ 
+     async def _handle_usage_event(self, event_data: dict) -> None:
+         # log_event_data(event_data)
+         # TODO: implement duration and ttft
+         input_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["input"]
+         output_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["output"]
+         # Q: should we be counting per turn or utterance?
+         metrics = RealtimeModelMetrics(
+             label=self._realtime_model._label,
+             # TODO: pass in the correct request_id
+             request_id=event_data["event"]["usageEvent"]["completionId"],
+             timestamp=time.monotonic(),
+             duration=0,
+             ttft=0,
+             cancelled=False,
+             input_tokens=input_tokens["speechTokens"] + input_tokens["textTokens"],
+             output_tokens=output_tokens["speechTokens"] + output_tokens["textTokens"],
+             total_tokens=input_tokens["speechTokens"]
+             + input_tokens["textTokens"]
+             + output_tokens["speechTokens"]
+             + output_tokens["textTokens"],
+             # need duration to calculate this
+             tokens_per_second=0,
+             input_token_details=RealtimeModelMetrics.InputTokenDetails(
+                 text_tokens=input_tokens["textTokens"],
+                 audio_tokens=input_tokens["speechTokens"],
+                 image_tokens=0,
+                 cached_tokens=0,
+                 cached_tokens_details=None,
+             ),
+             output_token_details=RealtimeModelMetrics.OutputTokenDetails(
+                 text_tokens=output_tokens["textTokens"],
+                 audio_tokens=output_tokens["speechTokens"],
+                 image_tokens=0,
+             ),
+         )
+         self.emit("metrics_collected", metrics)
+ 
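# --- Editor's example (not part of the package): the token accounting above,
# applied to a made-up usage delta.
input_tokens = {"speechTokens": 120, "textTokens": 30}
output_tokens = {"speechTokens": 200, "textTokens": 45}
total = sum(input_tokens.values()) + sum(output_tokens.values())
assert total == 395  # 150 input + 245 output
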
+     @utils.log_exceptions(logger=logger)
+     async def _process_responses(self):
+         """Background task that drains Bedrock's output stream and feeds the event handlers."""
+         try:
+             await self._is_sess_active.wait()
+ 
+             # note: may need another signal here to block input task until bedrock is ready
+             # TODO: save this as a field so we're not re-awaiting it every time
+             _, output_stream = await self._stream_response.await_output()
+             while self._is_sess_active.is_set():
+                 # and not self.stream_response.output_stream.closed:
+                 try:
+                     result = await output_stream.receive()
+                     if result.value and result.value.bytes_:
+                         try:
+                             response_data = result.value.bytes_.decode("utf-8")
+                             json_data = json.loads(response_data)
+                             # logger.debug(f"Received event: {json_data}")
+                             await self._handle_event(json_data)
+                         except json.JSONDecodeError:
+                             logger.warning(f"JSON decode error: {response_data}")
+                     else:
+                         logger.warning("No response received")
+                 except asyncio.CancelledError:
+                     logger.info("Response processing task cancelled")
+                     self._close_current_generation()
+                     raise
+                 except ValidationException as ve:
+                     # there is a 3min no-activity (e.g. silence) timeout on the stream, after which the stream is closed  # noqa: E501
+                     if (
+                         "InternalErrorCode=531::RST_STREAM closed stream. HTTP/2 error code: NO_ERROR"  # noqa: E501
+                         in ve.message
+                     ):
+                         logger.warning(f"Validation error: {ve}\nAttempting to recover...")
+                         await self._restart_session(ve)
+                         break
+                     else:
+                         logger.error(f"Validation error: {ve}")
+                         request_id = None
+                         try:
+                             request_id = ve.message.split(" ")[0].split("=")[1]
+                         except Exception:
+                             pass
+                         self.emit(
+                             "error",
+                             llm.RealtimeModelError(
+                                 timestamp=time.monotonic(),
+                                 label=self._realtime_model._label,
+                                 error=APIStatusError(
+                                     message=ve.message,
+                                     status_code=400,
+                                     request_id=request_id,
+                                     body=ve,
+                                     retryable=False,
+                                 ),
+                                 recoverable=False,
+                             ),
+                         )
+                         raise
+                 except (ThrottlingException, ModelNotReadyException, ModelErrorException) as re:
+                     logger.warning(f"Retryable error: {re}\nAttempting to recover...")
+                     await self._restart_session(re)
+                     break
+                 except ModelTimeoutException as mte:
+                     logger.warning(f"Model timeout error: {mte}\nAttempting to recover...")
+                     await self._restart_session(mte)
+                     break
+                 except ValueError as val_err:
+                     if val_err.args and val_err.args[0] == "I/O operation on closed file.":
+                         logger.info("initiating graceful shutdown of session")
+                         break
+                     raise
+                 except OSError:
+                     logger.info("stream already closed, exiting")
+                     break
+                 except Exception as e:
+                     err_msg = getattr(e, "message", str(e))
+                     logger.error(f"Response processing error: {err_msg} (type: {type(e)})")
+                     request_id = None
+                     try:
+                         request_id = err_msg.split(" ")[0].split("=")[1]
+                     except Exception:
+                         pass
+ 
+                     self.emit(
+                         "error",
+                         llm.RealtimeModelError(
+                             timestamp=time.monotonic(),
+                             label=self._realtime_model._label,
+                             error=APIStatusError(
+                                 message=err_msg,
+                                 status_code=500,
+                                 request_id=request_id,
+                                 body=e,
+                                 retryable=False,
+                             ),
+                             recoverable=False,
+                         ),
+                     )
+                     raise
+ 
+         finally:
+             logger.info("main output response stream processing task exiting")
+             self._is_sess_active.clear()
+ 
+     async def _restart_session(self, ex: Exception) -> None:
+         if self._session_restart_attempts >= DEFAULT_MAX_SESSION_RESTART_ATTEMPTS:
+             logger.error("Max session restart attempts reached, exiting")
+             err_msg = getattr(ex, "message", str(ex))
+             request_id = None
+             try:
+                 request_id = err_msg.split(" ")[0].split("=")[1]
+             except Exception:
+                 pass
+             self.emit(
+                 "error",
+                 llm.RealtimeModelError(
+                     timestamp=time.monotonic(),
+                     label=self._realtime_model._label,
+                     error=APIStatusError(
+                         message=f"Max restart attempts exceeded: {err_msg}",
+                         status_code=500,
+                         request_id=request_id,
+                         body=ex,
+                         retryable=False,
+                     ),
+                     recoverable=False,
+                 ),
+             )
+             self._is_sess_active.clear()
+             return
+         self._session_restart_attempts += 1
+         self._is_sess_active.clear()
+         delay = 2 ** (self._session_restart_attempts - 1) - 1
+         await asyncio.sleep(min(delay, DEFAULT_MAX_SESSION_RESTART_DELAY))
+         await self.initialize_streams(is_restart=True)
+         logger.info(
+             f"Session restarted successfully ({self._session_restart_attempts}/{DEFAULT_MAX_SESSION_RESTART_ATTEMPTS})"  # noqa: E501
+         )
+ 
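# --- Editor's example (not part of the package): the restart back-off schedule
# produced by the formula above, capped at DEFAULT_MAX_SESSION_RESTART_DELAY.
delays = [min(2 ** (attempt - 1) - 1, 10) for attempt in (1, 2, 3)]
assert delays == [0, 1, 3]  # seconds slept before each restart attempt
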
+     @property
+     def chat_ctx(self) -> llm.ChatContext:
+         return self._chat_ctx.copy()
+ 
+     @property
+     def tools(self) -> llm.ToolContext:
+         return self._tools.copy()
+ 
+     async def update_instructions(self, instructions: str) -> None:
+         """Injects the system prompt at the start of the session."""
+         self._instructions = instructions
+         if not self._instructions_ready.done():
+             self._instructions_ready.set_result(True)
+         logger.debug(f"Instructions updated: {instructions}")
+ 
+     async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
+         """Inject an initial chat history once during the very first session startup."""
+         # this callback sometimes fires more than once, so guard it to ensure the
+         # chat context is only updated on the very first session initialization
+         if not self._chat_ctx_ready.done():
+             self._chat_ctx = chat_ctx.copy()
+             logger.debug(f"Chat context updated: {self._chat_ctx.items}")
+             self._chat_ctx_ready.set_result(True)
+ 
+     async def _send_tool_events(self, tool_use_id: str, tool_result: str) -> None:
+         """Send tool_result back to Bedrock, grouped under tool_use_id."""
+         tool_content_name = str(uuid.uuid4())
+         tool_events = self._event_builder.create_tool_content_block(
+             content_name=tool_content_name,
+             tool_use_id=tool_use_id,
+             content=tool_result,
+         )
+         for event in tool_events:
+             await self._send_raw_event(event)
+             # logger.debug(f"Sent tool event: {event}")
+ 
+     def _tool_choice_adapter(self, tool_choice: llm.ToolChoice) -> dict[str, dict[str, str]] | None:
+         """Translate the LiveKit ToolChoice enum into Sonic's JSON schema."""
+         if tool_choice == "auto":
+             return {"auto": {}}
+         elif tool_choice == "required":
+             return {"any": {}}
+         elif isinstance(tool_choice, dict) and tool_choice["type"] == "function":
+             return {"tool": {"name": tool_choice["function"]["name"]}}
+         else:
+             return None
+ 
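# --- Editor's example (not part of the package): the LiveKit -> Sonic
# tool_choice translation performed by _tool_choice_adapter above;
# "get_weather" is a hypothetical tool name.
#   "auto"     -> {"auto": {}}
#   "required" -> {"any": {}}
#   {"type": "function", "function": {"name": "get_weather"}} -> {"tool": {"name": "get_weather"}}
choice = {"type": "function", "function": {"name": "get_weather"}}
mapped = {"tool": {"name": choice["function"]["name"]}}
assert mapped == {"tool": {"name": "get_weather"}}
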
+     # note: return value from tool functions registered to Sonic must be Structured Output (a dict that is JSON serializable)  # noqa: E501
+     async def update_tools(self, tools: list[llm.FunctionTool | llm.RawFunctionTool | Any]) -> None:
+         """Replace the active tool set with tools and notify Sonic if necessary."""
+         logger.debug(f"Updating tools: {tools}")
+         retained_tools: list[llm.FunctionTool | llm.RawFunctionTool] = list(tools)
+ 
+         self._tools = llm.ToolContext(retained_tools)
+         if retained_tools and not self._tools_ready.done():
+             self._tools_ready.set_result(True)
+             logger.debug("Tool list has been injected")
+ 
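# --- Editor's sketch (not part of the package): a tool whose return value is
# Structured Output (a JSON-serializable dict), as the note above requires.
# The function_tool decorator import path assumes livekit-agents v1.x.
from livekit.agents.llm import function_tool

@function_tool
async def get_weather(city: str) -> dict:
    """Look up the current weather for a city."""
    return {"city": city, "temp_c": 21}  # must remain JSON serializable
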
+     def update_options(self, *, tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN) -> None:
+         """Live update of inference options is not supported by Sonic yet."""
+         logger.warning(
+             "updating inference configuration options is not yet supported by Nova Sonic's Realtime API"  # noqa: E501
+         )
+ 
+     @utils.log_exceptions(logger=logger)
+     def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
+         """Ensure mic audio matches Sonic's required sample rate & channels."""
+         if self._input_resampler:
+             if frame.sample_rate != self._input_resampler._input_rate:
+                 self._input_resampler = None
+ 
+         if self._input_resampler is None and (
+             frame.sample_rate != DEFAULT_INPUT_SAMPLE_RATE or frame.num_channels != DEFAULT_CHANNELS
+         ):
+             self._input_resampler = rtc.AudioResampler(
+                 input_rate=frame.sample_rate,
+                 output_rate=DEFAULT_INPUT_SAMPLE_RATE,
+                 num_channels=DEFAULT_CHANNELS,
+             )
+ 
+         if self._input_resampler:
+             # flush the resampler when the input source is changed
+             yield from self._input_resampler.push(frame)
+         else:
+             yield frame
+ 
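# --- Editor's sketch (not part of the package): downsampling a 48 kHz mono
# frame to the 16 kHz input rate used above, mirroring _resample_audio; the
# AudioResampler/AudioFrame signatures are the ones used in this module.
resampler = rtc.AudioResampler(input_rate=48000, output_rate=16000, num_channels=1)
frame = rtc.AudioFrame(
    data=b"\x00" * 960 * 2,  # 20 ms of 16-bit mono silence at 48 kHz
    sample_rate=48000,
    num_channels=1,
    samples_per_channel=960,
)
resampled = list(resampler.push(frame))  # ~20 ms of 16 kHz audio frames
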
+     @utils.log_exceptions(logger=logger)
+     async def _process_audio_input(self):
+         """Background task that feeds audio and tool results into the Bedrock stream."""
+         await self._send_raw_event(self._event_builder.create_audio_content_start_event())
+         logger.info("Starting audio input processing loop")
+         while self._is_sess_active.is_set():
+             try:
+                 # note: could potentially pull this out into a separate task
+                 try:
+                     val = self._tool_results_ch.recv_nowait()
+                     tool_result = val["tool_result"]
+                     tool_use_id = val["tool_use_id"]
+                     await self._send_tool_events(tool_use_id, tool_result)
+ 
+                 except utils.aio.channel.ChanEmpty:
+                     pass
+                 except utils.aio.channel.ChanClosed:
+                     logger.warning(
+                         "tool results channel closed, exiting audio input processing loop"
+                     )
+                     break
+ 
+                 try:
+                     audio_bytes = await self._audio_input_chan.recv()
+                     blob = base64.b64encode(audio_bytes)
+                     audio_event = self._event_builder.create_audio_input_event(
+                         audio_content=blob.decode("utf-8"),
+                     )
+ 
+                     await self._send_raw_event(audio_event)
+                 except utils.aio.channel.ChanEmpty:
+                     pass
+                 except utils.aio.channel.ChanClosed:
+                     logger.warning(
+                         "audio input channel closed, exiting audio input processing loop"
+                     )
+                     break
+ 
+             except asyncio.CancelledError:
+                 logger.info("Audio processing loop cancelled")
+                 self._audio_input_chan.close()
+                 self._tool_results_ch.close()
+                 raise
+             except Exception:
+                 logger.exception("Error processing audio")
+ 
+     # for debugging purposes only
+     def _log_significant_audio(self, audio_bytes: bytes) -> None:
+         """Utility that prints a debug message when the audio chunk has non-trivial RMS energy.
+ 
+         Note: the RMS is computed over raw byte values rather than decoded 16-bit
+         samples, which is a rough but sufficient heuristic for debug gating.
+         """
+         if not audio_bytes:
+             return
+         squared_sum = sum(sample**2 for sample in audio_bytes)
+         if (squared_sum / len(audio_bytes)) ** 0.5 > 200 and lk_bedrock_debug:
+             log_message("Enqueuing significant audio chunk", AnsiColors.BLUE)
+ 
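# --- Editor's sketch (not part of the package): a sample-accurate RMS for
# comparison with the byte-level heuristic above; assumes 16-bit little-endian
# mono PCM, as produced elsewhere in this module.
import array

def rms_16bit(audio_bytes: bytes) -> float:
    samples = array.array("h", audio_bytes)
    return (sum(s * s for s in samples) / len(samples)) ** 0.5 if samples else 0.0
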
+     @utils.log_exceptions(logger=logger)
+     def push_audio(self, frame: rtc.AudioFrame) -> None:
+         """Enqueue an incoming mic rtc.AudioFrame for transcription."""
+         if not self._audio_input_chan.closed:
+             # logger.debug(f"Raw audio received: samples={len(frame.data)} rate={frame.sample_rate} channels={frame.num_channels}")  # noqa: E501
+             for f in self._resample_audio(frame):
+                 # logger.debug(f"Resampled audio: samples={len(frame.data)} rate={frame.sample_rate} channels={frame.num_channels}")  # noqa: E501
+ 
+                 for nf in self._bstream.write(f.data.tobytes()):
+                     self._log_significant_audio(nf.data)
+                     self._audio_input_chan.send_nowait(nf.data)
+         else:
+             logger.warning("audio input channel closed, skipping audio")
+ 
+     def generate_reply(
+         self,
+         *,
+         instructions: NotGivenOr[str] = NOT_GIVEN,
+     ) -> asyncio.Future[llm.GenerationCreatedEvent]:
+         logger.warning("unprompted generation is not supported by Nova Sonic's Realtime API")
+ 
+     def commit_audio(self) -> None:
+         logger.warning("commit_audio is not supported by Nova Sonic's Realtime API")
+ 
+     def clear_audio(self) -> None:
+         logger.warning("clear_audio is not supported by Nova Sonic's Realtime API")
+ 
+     def push_video(self, frame: rtc.VideoFrame) -> None:
+         logger.warning("video is not supported by Nova Sonic's Realtime API")
+ 
+     def interrupt(self) -> None:
+         logger.warning("interrupt is not supported by Nova Sonic's Realtime API")
+ 
+     def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
+         logger.warning("truncate is not supported by Nova Sonic's Realtime API")
+ 
+     @utils.log_exceptions(logger=logger)
+     async def aclose(self) -> None:
+         """Gracefully shut down the realtime session and release network resources."""
+         logger.info("attempting to shutdown agent session")
+         if not self._is_sess_active.is_set():
+             logger.info("agent session already inactive")
+             return
+ 
+         for event in self._event_builder.create_prompt_end_block():
+             await self._send_raw_event(event)
+         # allow event loops to fall out naturally
+         # otherwise, the smithy layer will raise an InvalidStateError during cancellation
+         self._is_sess_active.clear()
+ 
+         if self._stream_response and not self._stream_response.output_stream.closed:
+             await self._stream_response.output_stream.close()
+ 
+         # note: even after self._is_sess_active is cleared and the output stream is closed,
+         # there is a future inside output_stream.receive() at the AWS-CRT C layer that blocks,
+         # resulting in an error after cancellation
+         # however, it's mostly cosmetic: the event loop will still exit
+         # TODO: fix this nit
+         if self._response_task:
+             try:
+                 await asyncio.wait_for(self._response_task, timeout=1.0)
+             except asyncio.TimeoutError:
+                 logger.warning("shutdown of output event loop timed out; cancelling")
+                 self._response_task.cancel()
+ 
+         # must cancel the audio input task before closing the input stream
+         if self._audio_input_task and not self._audio_input_task.done():
+             self._audio_input_task.cancel()
+         if self._stream_response and not self._stream_response.input_stream.closed:
+             await self._stream_response.input_stream.close()
+ 
+         pending = [t for t in (self._response_task, self._audio_input_task) if t]
+         if pending:
+             await asyncio.gather(*pending, return_exceptions=True)
+         logger.debug(f"CHAT CONTEXT: {self._chat_ctx.items}")
+         logger.info("Session end")
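
# --- Editor's sketch (not part of the package): typical session lifecycle when
# driving the class directly; agent frameworks normally handle these calls.
async def _run() -> None:
    model = RealtimeModel(region="us-east-1")
    session = model.session()  # spawns initialize_streams() in the background
    # ... feed session.push_audio(frame) as microphone frames arrive ...
    await session.aclose()  # sends prompt-end events and closes both streams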