livekit-plugins-aws 1.1.4__py3-none-any.whl → 1.1.6__py3-none-any.whl

This diff shows the publicly available contents of the two package versions as published to their public registry. It is provided for informational purposes only.


@@ -0,0 +1,1276 @@
1
+ # mypy: disable-error-code=unused-ignore
2
+
3
+ from __future__ import annotations
4
+
5
+ import ast
6
+ import asyncio
7
+ import base64
8
+ import json
9
+ import os
10
+ import time
11
+ import uuid
12
+ import weakref
13
+ from collections.abc import Iterator
14
+ from dataclasses import dataclass, field
15
+ from datetime import datetime
16
+ from typing import Any, Callable, Literal, cast
17
+
18
+ import boto3
19
+ from aws_sdk_bedrock_runtime.client import (
20
+ BedrockRuntimeClient,
21
+ InvokeModelWithBidirectionalStreamOperationInput,
22
+ )
23
+ from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme
24
+ from aws_sdk_bedrock_runtime.models import (
25
+ BidirectionalInputPayloadPart,
26
+ InvokeModelWithBidirectionalStreamInputChunk,
27
+ ModelErrorException,
28
+ ModelNotReadyException,
29
+ ModelTimeoutException,
30
+ ThrottlingException,
31
+ ValidationException,
32
+ )
33
+ from smithy_aws_core.identity import AWSCredentialsIdentity
34
+ from smithy_core.aio.interfaces.identity import IdentityResolver
35
+
36
+ from livekit import rtc
37
+ from livekit.agents import (
38
+ APIStatusError,
39
+ llm,
40
+ utils,
41
+ )
42
+ from livekit.agents.metrics import RealtimeModelMetrics
43
+ from livekit.agents.types import NOT_GIVEN, NotGivenOr
44
+ from livekit.agents.utils import is_given
45
+ from livekit.plugins.aws.experimental.realtime.turn_tracker import _TurnTracker
46
+
47
+ from ...log import logger
48
+ from .events import (
49
+ VOICE_ID,
50
+ SonicEventBuilder as seb,
51
+ Tool,
52
+ ToolConfiguration,
53
+ ToolInputSchema,
54
+ ToolSpec,
55
+ )
56
+ from .pretty_printer import AnsiColors, log_event_data, log_message
57
+
58
+ DEFAULT_INPUT_SAMPLE_RATE = 16000
59
+ DEFAULT_OUTPUT_SAMPLE_RATE = 24000
60
+ DEFAULT_SAMPLE_SIZE_BITS = 16
61
+ DEFAULT_CHANNELS = 1
62
+ DEFAULT_CHUNK_SIZE = 512
63
+ DEFAULT_TEMPERATURE = 0.7
64
+ DEFAULT_TOP_P = 0.9
65
+ DEFAULT_MAX_TOKENS = 1024
66
+ MAX_MESSAGE_SIZE = 1024
67
+ MAX_MESSAGES = 40
68
+ DEFAULT_MAX_SESSION_RESTART_ATTEMPTS = 3
69
+ DEFAULT_MAX_SESSION_RESTART_DELAY = 10
70
+ DEFAULT_SYSTEM_PROMPT = (
71
+ "Your name is Sonic. You are a friend and eagerly helpful assistant."
72
+ "The user and you will engage in a spoken dialog exchanging the transcripts of a natural real-time conversation." # noqa: E501
73
+ "Keep your responses short and concise unless the user asks you to elaborate or you are explicitly asked to be verbose and chatty." # noqa: E501
74
+ "Do not repeat yourself. Do not ask the user to repeat themselves."
75
+ "Do ask the user to confirm or clarify their response if you are not sure what they mean."
76
+ "If after asking the user for clarification you still do not understand, be honest and tell them that you do not understand." # noqa: E501
77
+ "Do not make up information or make assumptions. If you do not know the answer, tell the user that you do not know the answer." # noqa: E501
78
+ "If the user makes a request of you that you cannot fulfill, tell them why you cannot fulfill it." # noqa: E501
79
+ "When making tool calls, inform the user that you are using a tool to generate the response."
80
+ "Avoid formatted lists or numbering and keep your output as a spoken transcript to be acted out." # noqa: E501
81
+ "Be appropriately emotive when responding to the user. Use American English as the language for your responses." # noqa: E501
82
+ )
83
+
84
+ lk_bedrock_debug = int(os.getenv("LK_BEDROCK_DEBUG", "0"))
85
+
86
+
87
+ @dataclass
88
+ class _RealtimeOptions:
89
+ """Configuration container for a Sonic realtime session.
90
+
91
+ Attributes:
92
+ voice (VOICE_ID): Voice identifier used for TTS output.
93
+ temperature (float): Sampling temperature controlling randomness; 1.0 is most deterministic.
94
+ top_p (float): Nucleus sampling parameter; 0.0 considers all tokens.
95
+ max_tokens (int): Maximum number of tokens the model may generate in a single response.
96
+ tool_choice (llm.ToolChoice | None): Strategy that dictates how the model should invoke tools.
97
+ region (str): AWS region hosting the Bedrock Sonic model endpoint.
98
+ """ # noqa: E501
99
+
100
+ voice: VOICE_ID
101
+ temperature: float
102
+ top_p: float
103
+ max_tokens: int
104
+ tool_choice: llm.ToolChoice | None
105
+ region: str
106
+
107
+
108
+ @dataclass
109
+ class _MessageGeneration:
110
+ """Grouping of streams that together represent one assistant message.
111
+
112
+ Attributes:
113
+ message_id (str): Unique identifier that ties together text and audio for a single assistant turn.
114
+ text_ch (utils.aio.Chan[str]): Channel that yields partial text tokens as they arrive.
115
+ audio_ch (utils.aio.Chan[rtc.AudioFrame]): Channel that yields audio frames for the same assistant turn.
116
+ """ # noqa: E501
117
+
118
+ message_id: str
119
+ text_ch: utils.aio.Chan[str]
120
+ audio_ch: utils.aio.Chan[rtc.AudioFrame]
121
+
122
+
123
+ @dataclass
124
+ class _ResponseGeneration:
125
+ """Book-keeping dataclass tracking the lifecycle of a Sonic turn.
126
+
127
+ This object is created whenever we receive a *completion_start* event from the model
128
+ and is disposed of once the assistant turn finishes (e.g. *END_TURN*).
129
+
130
+ Attributes:
131
+ message_ch (utils.aio.Chan[llm.MessageGeneration]): Multiplexed stream for all assistant messages.
132
+ function_ch (utils.aio.Chan[llm.FunctionCall]): Stream that emits function tool calls.
133
+ input_id (str): Synthetic message id for the user input of the current turn.
134
+ response_id (str): Synthetic message id for the assistant reply of the current turn.
135
+ messages (dict[str, _MessageGeneration]): Map of message_id -> per-message stream containers.
136
+ user_messages (dict[str, str]): Map Bedrock content_id -> input_id.
137
+ speculative_messages (dict[str, str]): Map Bedrock content_id -> response_id (assistant side).
138
+ tool_messages (dict[str, str]): Map Bedrock content_id -> response_id for tool calls.
139
+ output_text (str): Accumulated assistant text (only used for metrics / debugging).
140
+ _created_timestamp (str): ISO-8601 timestamp when the generation record was created.
141
+ _first_token_timestamp (float | None): Wall-clock time of first token emission.
142
+ _completed_timestamp (float | None): Wall-clock time when the turn fully completed.
143
+ """ # noqa: E501
144
+
145
+ message_ch: utils.aio.Chan[llm.MessageGeneration]
146
+ function_ch: utils.aio.Chan[llm.FunctionCall]
147
+ input_id: str # corresponds to user's portion of the turn
148
+ response_id: str # corresponds to agent's portion of the turn
149
+ messages: dict[str, _MessageGeneration] = field(default_factory=dict)
150
+ user_messages: dict[str, str] = field(default_factory=dict)
151
+ speculative_messages: dict[str, str] = field(default_factory=dict)
152
+ tool_messages: dict[str, str] = field(default_factory=dict)
153
+ output_text: str = "" # agent ASR text
154
+ _created_timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
155
+ _first_token_timestamp: float | None = None
156
+ _completed_timestamp: float | None = None
157
+
158
+
159
+ class Boto3CredentialsResolver(IdentityResolver): # type: ignore[misc]
160
+ """IdentityResolver implementation that sources AWS credentials from boto3.
161
+
162
+ The resolver delegates to the default boto3.Session() credential chain which
163
+ checks environment variables, shared credentials files, EC2 instance profiles, etc.
164
+ The credentials are then wrapped in an AWSCredentialsIdentity so they can be
165
+ passed into Bedrock runtime clients.
166
+ """
167
+
168
+ def __init__(self) -> None:
169
+ self.session = boto3.Session() # type: ignore[attr-defined]
170
+
171
+ async def get_identity(self, **kwargs: Any) -> AWSCredentialsIdentity:
172
+ """Asynchronously resolve AWS credentials.
173
+
174
+ This method is invoked by the Bedrock runtime client whenever a new request needs to be
175
+ signed. It converts the static or temporary credentials returned by boto3
176
+ into an AWSCredentialsIdentity instance.
177
+
178
+ Returns:
179
+ AWSCredentialsIdentity: Identity containing the
180
+ AWS access key, secret key and optional session token.
181
+
182
+ Raises:
183
+ ValueError: If no credentials could be found by boto3.
184
+ """
185
+ try:
186
+ logger.debug("Attempting to load AWS credentials")
187
+ credentials = self.session.get_credentials()
188
+ if not credentials:
189
+ logger.error("Unable to load AWS credentials")
190
+ raise ValueError("Unable to load AWS credentials")
191
+
192
+ creds = credentials.get_frozen_credentials()
193
+ logger.debug(
194
+ f"AWS credentials loaded successfully. AWS_ACCESS_KEY_ID: {creds.access_key[:4]}***"
195
+ )
196
+
197
+ identity = AWSCredentialsIdentity(
198
+ access_key_id=creds.access_key,
199
+ secret_access_key=creds.secret_key,
200
+ session_token=creds.token if creds.token else None,
201
+ expiration=None,
202
+ )
203
+ return identity
204
+ except Exception as e:
205
+ logger.error(f"Failed to load AWS credentials: {str(e)}")
206
+ raise ValueError(f"Failed to load AWS credentials: {str(e)}") # noqa: B904
207
+
208
+
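As a sanity check, credential resolution can be exercised on its own. A minimal sketch, assuming a default AWS profile or environment credentials are configured:

```python
import asyncio

async def _demo() -> None:
    resolver = Boto3CredentialsResolver()
    identity = await resolver.get_identity()
    # mirror the prefix-only logging above; never print full secrets
    print(f"resolved AWS_ACCESS_KEY_ID: {identity.access_key_id[:4]}***")

asyncio.run(_demo())
```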
209
+ class RealtimeModel(llm.RealtimeModel):
210
+ """High-level entry point that conforms to the LiveKit RealtimeModel interface.
211
+
212
+ The object is very lightweight; it mainly stores default inference options and
213
+ spawns a RealtimeSession when session() is invoked.
214
+ """
215
+
216
+ def __init__(
217
+ self,
218
+ *,
219
+ voice: NotGivenOr[VOICE_ID] = NOT_GIVEN,
220
+ temperature: NotGivenOr[float] = NOT_GIVEN,
221
+ top_p: NotGivenOr[float] = NOT_GIVEN,
222
+ max_tokens: NotGivenOr[int] = NOT_GIVEN,
223
+ tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
224
+ region: NotGivenOr[str] = NOT_GIVEN,
225
+ ):
226
+ """Instantiate a new RealtimeModel.
227
+
228
+ Args:
229
+ voice (VOICE_ID | NotGiven): Preferred voice id for Sonic TTS output. Falls back to "tiffany".
230
+ temperature (float | NotGiven): Sampling temperature (0-1). Defaults to DEFAULT_TEMPERATURE.
231
+ top_p (float | NotGiven): Nucleus sampling probability mass. Defaults to DEFAULT_TOP_P.
232
+ max_tokens (int | NotGiven): Upper bound for tokens emitted by the model. Defaults to DEFAULT_MAX_TOKENS.
233
+ tool_choice (llm.ToolChoice | None | NotGiven): Strategy for tool invocation ("auto", "required", or explicit function).
234
+ region (str | NotGiven): AWS region of the Bedrock runtime endpoint.
235
+ """ # noqa: E501
236
+ super().__init__(
237
+ capabilities=llm.RealtimeCapabilities(
238
+ message_truncation=False,
239
+ turn_detection=True,
240
+ user_transcription=True,
241
+ auto_tool_reply_generation=True,
242
+ )
243
+ )
244
+ self.model_id = "amazon.nova-sonic-v1:0"
245
+ # note: temperature and top_p do not follow industry standards and are defined slightly differently for Sonic # noqa: E501
246
+ # temperature ranges from 0.0 to 1.0, where 0.0 is the most random and 1.0 is the most deterministic # noqa: E501
247
+ # top_p ranges from 0.0 to 1.0, where 0.0 is the most random and 1.0 is the most deterministic # noqa: E501
248
+ self.temperature = temperature
249
+ self.top_p = top_p
250
+ self._opts = _RealtimeOptions(
251
+ voice=cast(VOICE_ID, voice) if is_given(voice) else "tiffany",
252
+ temperature=temperature if is_given(temperature) else DEFAULT_TEMPERATURE,
253
+ top_p=top_p if is_given(top_p) else DEFAULT_TOP_P,
254
+ max_tokens=max_tokens if is_given(max_tokens) else DEFAULT_MAX_TOKENS,
255
+ tool_choice=tool_choice if is_given(tool_choice) else None,
256
+ region=region if is_given(region) else "us-east-1",
257
+ )
258
+ self._sessions = weakref.WeakSet[RealtimeSession]()
259
+
260
+ def session(self) -> RealtimeSession:
261
+ """Return a new RealtimeSession bound to this model instance."""
262
+ sess = RealtimeSession(self)
263
+
264
+ # note: this is a hack to get the session to initialize itself
265
+ # TODO: change how RealtimeSession is initialized by creating a single task main_atask that spawns subtasks # noqa: E501
266
+ asyncio.create_task(sess.initialize_streams())
267
+ self._sessions.add(sess)
268
+ return sess
269
+
270
+ # stub b/c RealtimeSession.aclose() is invoked directly
271
+ async def aclose(self) -> None:
272
+ pass
273
+
274
+
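A hedged usage sketch of this entry point. The import path is an assumption based on this file's own imports, and session() must be called from a running event loop since it schedules initialize_streams() as a task:

```python
import asyncio

# assumed export path, mirroring livekit.plugins.aws.experimental.realtime above
from livekit.plugins.aws.experimental.realtime import RealtimeModel

async def main() -> None:
    model = RealtimeModel(voice="tiffany", temperature=0.7, region="us-east-1")
    sess = model.session()  # spawns initialize_streams() in the background
    # ... wire sess into an agent, push audio, etc. ...
    await sess.aclose()

asyncio.run(main())
```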
275
+ class RealtimeSession( # noqa: F811
276
+ llm.RealtimeSession[Literal["bedrock_server_event_received", "bedrock_client_event_queued"]]
277
+ ):
278
+ """Bidirectional streaming session against the Nova Sonic Bedrock runtime.
279
+
280
+ The session owns two asynchronous tasks:
281
+
282
+ 1. _process_audio_input – pushes user mic audio and tool results to Bedrock.
283
+ 2. _process_responses – receives server events from Bedrock and converts them into
284
+ LiveKit abstractions such as llm.MessageGeneration.
285
+
286
+ A set of helper handlers (_handle_*) transform the low-level Bedrock
287
+ JSON payloads into higher-level application events and keep
288
+ _ResponseGeneration state in sync.
289
+ """
290
+
291
+ def __init__(self, realtime_model: RealtimeModel) -> None:
292
+ """Create and wire-up a new realtime session.
293
+
294
+ Args:
295
+ realtime_model (RealtimeModel): Parent model instance that stores static
296
+ inference options and the Smithy Bedrock client configuration.
297
+ """
298
+ super().__init__(realtime_model)
299
+ self._realtime_model: RealtimeModel = realtime_model
300
+ self._event_builder = seb(
301
+ prompt_name=str(uuid.uuid4()),
302
+ audio_content_name=str(uuid.uuid4()),
303
+ )
304
+ self._input_resampler: rtc.AudioResampler | None = None
305
+ self._bstream = utils.audio.AudioByteStream(
306
+ DEFAULT_INPUT_SAMPLE_RATE, DEFAULT_CHANNELS, samples_per_channel=DEFAULT_CHUNK_SIZE
307
+ )
308
+
309
+ self._response_task = None
310
+ self._audio_input_task = None
311
+ self._stream_response = None
312
+ self._bedrock_client = None
313
+ self._pending_tools: set[str] = set()
314
+ self._is_sess_active = asyncio.Event()
315
+ self._chat_ctx = llm.ChatContext.empty()
316
+ self._tools = llm.ToolContext.empty()
317
+ self._tool_results_ch = utils.aio.Chan[dict[str, str]]()
318
+ self._tools_ready = asyncio.get_running_loop().create_future()
319
+ self._instructions_ready = asyncio.get_running_loop().create_future()
320
+ self._chat_ctx_ready = asyncio.get_running_loop().create_future()
321
+ self._instructions = DEFAULT_SYSTEM_PROMPT
322
+ self._audio_input_chan = utils.aio.Chan[bytes]()
323
+ self._current_generation: _ResponseGeneration | None = None
324
+
325
+ # note: currently tracks session restart attempts across all sessions
326
+ # TODO: track restart attempts per turn
327
+ self._session_restart_attempts = 0
328
+
329
+ self._event_handlers = {
330
+ "completion_start": self._handle_completion_start_event,
331
+ "audio_output_content_start": self._handle_audio_output_content_start_event,
332
+ "audio_output_content": self._handle_audio_output_content_event,
333
+ "audio_output_content_end": self._handle_audio_output_content_end_event,
334
+ "text_output_content_start": self._handle_text_output_content_start_event,
335
+ "text_output_content": self._handle_text_output_content_event,
336
+ "text_output_content_end": self._handle_text_output_content_end_event,
337
+ "tool_output_content_start": self._handle_tool_output_content_start_event,
338
+ "tool_output_content": self._handle_tool_output_content_event,
339
+ "tool_output_content_end": self._handle_tool_output_content_end_event,
340
+ "completion_end": self._handle_completion_end_event,
341
+ "usage": self._handle_usage_event,
342
+ "other_event": self._handle_other_event,
343
+ }
344
+ self._turn_tracker = _TurnTracker(
345
+ cast(Callable[[str, Any], None], self.emit),
346
+ cast(Callable[[], None], self.emit_generation_event),
347
+ )
348
+
349
+ @utils.log_exceptions(logger=logger)
350
+ def _initialize_client(self) -> None:
351
+ """Instantiate the Bedrock runtime client"""
352
+ config = Config(
353
+ endpoint_uri=f"https://bedrock-runtime.{self._realtime_model._opts.region}.amazonaws.com",
354
+ region=self._realtime_model._opts.region,
355
+ aws_credentials_identity_resolver=Boto3CredentialsResolver(),
356
+ http_auth_scheme_resolver=HTTPAuthSchemeResolver(),
357
+ http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()},
358
+ )
359
+ self._bedrock_client = BedrockRuntimeClient(config=config)
360
+
361
+ @utils.log_exceptions(logger=logger)
362
+ async def _send_raw_event(self, event_json: str) -> None:
363
+ """Low-level helper that serialises event_json and forwards it to the bidirectional stream.
364
+
365
+ Args:
366
+ event_json (str): The JSON payload (already in Bedrock wire format) to queue.
367
+
368
+ Raises:
369
+ Exception: Propagates any failures returned by the Bedrock runtime client.
370
+ """
371
+ if not self._stream_response:
372
+ logger.warning("stream not initialized; dropping event (this should never occur)")
373
+ return
374
+
375
+ event = InvokeModelWithBidirectionalStreamInputChunk(
376
+ value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8"))
377
+ )
378
+
379
+ try:
380
+ await self._stream_response.input_stream.send(event)
381
+ except Exception as e:
382
+ logger.exception("Error sending event")
383
+ err_msg = getattr(e, "message", str(e))
384
+ request_id = None
385
+ try:
386
+ request_id = err_msg.split(" ")[0].split("=")[1]
387
+ except Exception:
388
+ pass
389
+
390
+ self.emit(
391
+ "error",
392
+ llm.RealtimeModelError(
393
+ timestamp=time.monotonic(),
394
+ label=self._realtime_model._label,
395
+ error=APIStatusError(
396
+ message=err_msg,
397
+ status_code=500,
398
+ request_id=request_id,
399
+ body=e,
400
+ retryable=False,
401
+ ),
402
+ recoverable=False,
403
+ ),
404
+ )
405
+ raise
406
+
407
+ def _serialize_tool_config(self) -> ToolConfiguration | None:
408
+ """Convert self.tools into the JSON structure expected by Sonic.
409
+
410
+ If any tools are registered, the method also harmonises temperature and
411
+ top_p defaults to Sonic's recommended greedy values (1.0).
412
+
413
+ Returns:
414
+ ToolConfiguration | None: None when no tools are present, otherwise a complete config block.
415
+ """ # noqa: E501
416
+ tool_cfg = None
417
+ if self.tools.function_tools:
418
+ tools = []
419
+ for name, f in self.tools.function_tools.items():
420
+ if llm.tool_context.is_function_tool(f):
421
+ description = llm.tool_context.get_function_info(f).description
422
+ input_schema = llm.utils.build_legacy_openai_schema(f, internally_tagged=True)[
423
+ "parameters"
424
+ ]
425
+ elif llm.tool_context.is_raw_function_tool(f):
426
+ description = llm.tool_context.get_raw_function_info(f).raw_schema.get(
427
+ "description"
428
+ )
429
+ input_schema = llm.tool_context.get_raw_function_info(f).raw_schema[
430
+ "parameters"
431
+ ]
432
+ else:
433
+ continue
434
+
435
+ tool = Tool(
436
+ toolSpec=ToolSpec(
437
+ name=name,
438
+ description=description or "No description provided",
439
+ inputSchema=ToolInputSchema(json_=json.dumps(input_schema)), # type: ignore
440
+ )
441
+ )
442
+ tools.append(tool)
443
+ tool_choice = self._tool_choice_adapter(self._realtime_model._opts.tool_choice)
444
+ logger.debug(f"TOOL CHOICE: {tool_choice}")
445
+ tool_cfg = ToolConfiguration(tools=tools, toolChoice=tool_choice)
446
+
447
+ # recommended to set greedy inference configs for tool calls
448
+ if not is_given(self._realtime_model.top_p):
449
+ self._realtime_model._opts.top_p = 1.0
450
+ if not is_given(self._realtime_model.temperature):
451
+ self._realtime_model._opts.temperature = 1.0
452
+ return tool_cfg
453
+
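Illustratively, for one registered function tool this method produces a block shaped as follows (the tool name and schema are hypothetical; the dataclasses are the ones imported from .events above):

```python
ToolConfiguration(
    tools=[
        Tool(
            toolSpec=ToolSpec(
                name="get_weather",  # hypothetical tool
                description="Look up the current weather",
                inputSchema=ToolInputSchema(
                    json_='{"type": "object", "properties": {"city": {"type": "string"}}}'
                ),
            )
        )
    ],
    toolChoice={"auto": {}},  # via _tool_choice_adapter("auto")
)
```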
454
+ @utils.log_exceptions(logger=logger)
455
+ async def initialize_streams(self, is_restart: bool = False) -> RealtimeSession:
456
+ """Open the Bedrock bidirectional stream and spawn background worker tasks.
457
+
458
+ This coroutine is idempotent and can be invoked again when recoverable
459
+ errors (e.g. timeout, throttling) require a fresh session.
460
+
461
+ Args:
462
+ is_restart (bool, optional): Marks whether we are re-initialising an
463
+ existing session after an error. Defaults to False.
464
+ """
465
+ try:
466
+ if not self._bedrock_client:
467
+ logger.info("Creating Bedrock client")
468
+ self._initialize_client()
469
+ assert self._bedrock_client is not None, "bedrock_client is None"
470
+
471
+ logger.info("Initializing Bedrock stream")
472
+ self._stream_response = (
473
+ await self._bedrock_client.invoke_model_with_bidirectional_stream(
474
+ InvokeModelWithBidirectionalStreamOperationInput(
475
+ model_id=self._realtime_model.model_id
476
+ )
477
+ )
478
+ )
479
+
480
+ if not is_restart:
481
+ pending_events: list[asyncio.Future] = []
482
+ if not self.tools.function_tools:
483
+ pending_events.append(self._tools_ready)
484
+ if not self._instructions_ready.done():
485
+ pending_events.append(self._instructions_ready)
486
+ if not self._chat_ctx_ready.done():
487
+ pending_events.append(self._chat_ctx_ready)
488
+
489
+ # note: during session init we can't know whether tools were never added
490
+ # or were added but simply haven't been propagated yet,
491
+ # so when no tools are present we wait out the entire timeout
492
+ try:
493
+ if pending_events:
494
+ await asyncio.wait_for(asyncio.gather(*pending_events), timeout=0.5)
495
+ except asyncio.TimeoutError:
496
+ if not self._tools_ready.done():
497
+ logger.warning("Tools not ready after 500ms, continuing without them")
498
+
499
+ if not self._instructions_ready.done():
500
+ logger.warning(
501
+ "Instructions not received after 500ms, proceeding with default instructions" # noqa: E501
502
+ )
503
+ if not self._chat_ctx_ready.done():
504
+ logger.warning(
505
+ "Chat context not received after 500ms, proceeding with empty chat context" # noqa: E501
506
+ )
507
+
508
+ logger.info(
509
+ f"Initializing Bedrock session with realtime options: {self._realtime_model._opts}"
510
+ )
511
+ # there is a 40-message limit on the chat context
512
+ if len(self._chat_ctx.items) > MAX_MESSAGES:
513
+ logger.warning(
514
+ f"Chat context has {len(self._chat_ctx.items)} messages, truncating to {MAX_MESSAGES}" # noqa: E501
515
+ )
516
+ self._chat_ctx.truncate(max_items=MAX_MESSAGES)
517
+ init_events = self._event_builder.create_prompt_start_block(
518
+ voice_id=self._realtime_model._opts.voice,
519
+ sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
520
+ system_content=self._instructions,
521
+ chat_ctx=self.chat_ctx,
522
+ tool_configuration=self._serialize_tool_config(),
523
+ max_tokens=self._realtime_model._opts.max_tokens,
524
+ top_p=self._realtime_model._opts.top_p,
525
+ temperature=self._realtime_model._opts.temperature,
526
+ )
527
+
528
+ for event in init_events:
529
+ await self._send_raw_event(event)
530
+ logger.debug(f"Sent event: {event}")
531
+
532
+ if not is_restart:
533
+ self._audio_input_task = asyncio.create_task(
534
+ self._process_audio_input(), name="RealtimeSession._process_audio_input"
535
+ )
536
+
537
+ self._response_task = asyncio.create_task(
538
+ self._process_responses(), name="RealtimeSession._process_responses"
539
+ )
540
+ self._is_sess_active.set()
541
+ logger.debug("Stream initialized successfully")
542
+ except Exception as e:
543
+ logger.debug(f"Failed to initialize stream: {str(e)}")
544
+ raise
545
+ return self
546
+
547
+ @utils.log_exceptions(logger=logger)
548
+ def emit_generation_event(self) -> None:
549
+ """Publish a llm.GenerationCreatedEvent to external subscribers."""
550
+ logger.debug("Emitting generation event")
551
+ assert self._current_generation is not None, "current_generation is None"
552
+
553
+ generation_ev = llm.GenerationCreatedEvent(
554
+ message_stream=self._current_generation.message_ch,
555
+ function_stream=self._current_generation.function_ch,
556
+ user_initiated=False,
557
+ )
558
+ self.emit("generation_created", generation_ev)
559
+
560
+ @utils.log_exceptions(logger=logger)
561
+ async def _handle_event(self, event_data: dict) -> None:
562
+ """Dispatch a raw Bedrock event to the corresponding _handle_* method."""
563
+ event_type = self._event_builder.get_event_type(event_data)
564
+ event_handler = self._event_handlers.get(event_type)
565
+ if event_handler:
566
+ await event_handler(event_data)
567
+ self._turn_tracker.feed(event_data)
568
+ else:
569
+ logger.warning(f"No event handler found for event type: {event_type}")
570
+
571
+ async def _handle_completion_start_event(self, event_data: dict) -> None:
572
+ log_event_data(event_data)
573
+ self._create_response_generation()
574
+
575
+ def _create_response_generation(self) -> None:
576
+ """Instantiate _ResponseGeneration and emit the GenerationCreated event."""
577
+ if self._current_generation is None:
578
+ self._current_generation = _ResponseGeneration(
579
+ message_ch=utils.aio.Chan(),
580
+ function_ch=utils.aio.Chan(),
581
+ input_id=str(uuid.uuid4()),
582
+ response_id=str(uuid.uuid4()),
583
+ messages={},
584
+ user_messages={},
585
+ speculative_messages={},
586
+ _created_timestamp=datetime.now().isoformat(),
587
+ )
588
+ msg_gen = _MessageGeneration(
589
+ message_id=self._current_generation.response_id,
590
+ text_ch=utils.aio.Chan(),
591
+ audio_ch=utils.aio.Chan(),
592
+ )
593
+ self._current_generation.message_ch.send_nowait(
594
+ llm.MessageGeneration(
595
+ message_id=msg_gen.message_id,
596
+ text_stream=msg_gen.text_ch,
597
+ audio_stream=msg_gen.audio_ch,
598
+ )
599
+ )
600
+ self._current_generation.messages[self._current_generation.response_id] = msg_gen
601
+
602
+ # note: post-ASR text events are ignored entirely; only SPECULATIVE assistant text is consumed
603
+ async def _handle_text_output_content_start_event(self, event_data: dict) -> None:
604
+ """Handle text_output_content_start for both user and assistant roles."""
605
+ log_event_data(event_data)
606
+ role = event_data["event"]["contentStart"]["role"]
607
+ self._create_response_generation()
608
+
609
+ # note: does not work if you emit llm.GCE too early (for some reason)
610
+ if role == "USER":
611
+ assert self._current_generation is not None, "current_generation is None"
612
+
613
+ content_id = event_data["event"]["contentStart"]["contentId"]
614
+ self._current_generation.user_messages[content_id] = self._current_generation.input_id
615
+
616
+ elif (
617
+ role == "ASSISTANT"
618
+ and "SPECULATIVE" in event_data["event"]["contentStart"]["additionalModelFields"]
619
+ ):
620
+ assert self._current_generation is not None, "current_generation is None"
621
+
622
+ text_content_id = event_data["event"]["contentStart"]["contentId"]
623
+ self._current_generation.speculative_messages[text_content_id] = (
624
+ self._current_generation.response_id
625
+ )
626
+
627
+ async def _handle_text_output_content_event(self, event_data: dict) -> None:
628
+ """Stream partial text tokens into the current _MessageGeneration."""
629
+ log_event_data(event_data)
630
+ text_content_id = event_data["event"]["textOutput"]["contentId"]
631
+ text_content = f"{event_data['event']['textOutput']['content']}\n"
632
+
633
+ # currently only agent can be interrupted
634
+ if text_content == '{ "interrupted" : true }\n':
635
+ # the interrupted flag is not being set correctly in chat_ctx
636
+ # this is b/c audio playback is desynced from text transcription
637
+ # TODO: fix this; possibly via a playback timer
638
+ idx = self._chat_ctx.find_insertion_index(created_at=time.time()) - 1
639
+ if idx < 0:
640
+ logger.warning("Barge-in DETECTED but no previous message found")
641
+ return
642
+
643
+ logger.debug(
644
+ f"BARGE-IN DETECTED using idx: {idx} and chat_msg: {self._chat_ctx.items[idx]}"
645
+ )
646
+ if (item := self._chat_ctx.items[idx]).type == "message":
647
+ item.interrupted = True
648
+ self._close_current_generation()
649
+ return
650
+
651
+ # ignore events until turn starts
652
+ if self._current_generation is not None:
653
+ # TODO: rename event to llm.InputTranscriptionUpdated
654
+ if (
655
+ self._current_generation.user_messages.get(text_content_id)
656
+ == self._current_generation.input_id
657
+ ):
658
+ logger.debug(f"INPUT TRANSCRIPTION UPDATED: {text_content}")
659
+ # note: user ASR text is slightly different than what is sent to LiveKit (newline vs whitespace) # noqa: E501
660
+ # TODO: fix this
661
+ self._update_chat_ctx(role="user", text_content=text_content)
662
+
663
+ elif (
664
+ self._current_generation.speculative_messages.get(text_content_id)
665
+ == self._current_generation.response_id
666
+ ):
667
+ curr_gen = self._current_generation.messages[self._current_generation.response_id]
668
+ curr_gen.text_ch.send_nowait(text_content)
669
+ # note: this update is per utterance, not per turn
670
+ self._update_chat_ctx(role="assistant", text_content=text_content)
671
+
672
+ def _update_chat_ctx(self, role: llm.ChatRole, text_content: str) -> None:
673
+ """
674
+ Update the chat context with the latest ASR text while guarding against model limitations:
675
+ a) 40 total messages limit
676
+ b) 1kB message size limit
677
+ """
678
+ logger.debug(f"Updating chat context with role: {role} and text_content: {text_content}")
679
+ if len(self._chat_ctx.items) == 0:
680
+ self._chat_ctx.add_message(role=role, content=text_content)
681
+ else:
682
+ prev_utterance = self._chat_ctx.items[-1]
683
+ if prev_utterance.type == "message" and prev_utterance.role == role:
684
+ if isinstance(prev_content := prev_utterance.content[0], str) and (
685
+ len(prev_content.encode("utf-8")) + len(text_content.encode("utf-8"))
686
+ < MAX_MESSAGE_SIZE
687
+ ):
688
+ prev_utterance.content[0] = "\n".join([prev_content, text_content])
689
+ else:
690
+ self._chat_ctx.add_message(role=role, content=text_content)
691
+ if len(self._chat_ctx.items) > MAX_MESSAGES:
692
+ self._chat_ctx.truncate(max_items=MAX_MESSAGES)
693
+ else:
694
+ self._chat_ctx.add_message(role=role, content=text_content)
695
+ if len(self._chat_ctx.items) > MAX_MESSAGES:
696
+ self._chat_ctx.truncate(max_items=MAX_MESSAGES)
697
+
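The same size-guarded merge rule as a standalone sketch, assuming nothing beyond the MAX_MESSAGE_SIZE (1 kB) and MAX_MESSAGES (40) limits defined above:

```python
def merge_or_append(history: list[tuple[str, str]], role: str, text: str) -> None:
    # consecutive same-role utterances are newline-joined while the combined
    # UTF-8 size stays under 1 kB; otherwise a new message is appended and
    # the history is truncated to the most recent 40 entries
    if history and history[-1][0] == role:
        prev = history[-1][1]
        if len(prev.encode("utf-8")) + len(text.encode("utf-8")) < 1024:
            history[-1] = (role, prev + "\n" + text)
            return
    history.append((role, text))
    del history[:-40]
```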
698
+ # cannot rely on this event for user b/c stopReason=PARTIAL_TURN always for user
699
+ async def _handle_text_output_content_end_event(self, event_data: dict) -> None:
700
+ """Mark the assistant message closed when Bedrock signals END_TURN."""
701
+ stop_reason = event_data["event"]["contentEnd"]["stopReason"]
702
+ text_content_id = event_data["event"]["contentEnd"]["contentId"]
703
+ if (
704
+ self._current_generation
705
+ is not None  # None here means the first utterance in the turn was an interrupt
706
+ and self._current_generation.speculative_messages.get(text_content_id)
707
+ == self._current_generation.response_id
708
+ and stop_reason == "END_TURN"
709
+ ):
710
+ log_event_data(event_data)
711
+ self._close_current_generation()
712
+
713
+ async def _handle_tool_output_content_start_event(self, event_data: dict) -> None:
714
+ """Track mapping content_id -> response_id for upcoming tool use."""
715
+ log_event_data(event_data)
716
+ assert self._current_generation is not None, "current_generation is None"
717
+
718
+ tool_use_content_id = event_data["event"]["contentStart"]["contentId"]
719
+ self._current_generation.tool_messages[tool_use_content_id] = (
720
+ self._current_generation.response_id
721
+ )
722
+
723
+ # note: tool calls are synchronous for now
724
+ async def _handle_tool_output_content_event(self, event_data: dict) -> None:
725
+ """Execute the referenced tool locally and forward results back to Bedrock."""
726
+ log_event_data(event_data)
727
+ assert self._current_generation is not None, "current_generation is None"
728
+
729
+ tool_use_content_id = event_data["event"]["toolUse"]["contentId"]
730
+ tool_use_id = event_data["event"]["toolUse"]["toolUseId"]
731
+ tool_name = event_data["event"]["toolUse"]["toolName"]
732
+ if (
733
+ self._current_generation.tool_messages.get(tool_use_content_id)
734
+ == self._current_generation.response_id
735
+ ):
736
+ args = event_data["event"]["toolUse"]["content"]
737
+ self._current_generation.function_ch.send_nowait(
738
+ llm.FunctionCall(call_id=tool_use_id, name=tool_name, arguments=args)
739
+ )
740
+ self._pending_tools.add(tool_use_id)
741
+
742
+ # performing these acrobatics in order to release the deadlock
743
+ # LiveKit will not accept a new generation until the previous one is closed
744
+ # the issue is that audio data cannot be generated until toolResult is received
745
+ # however, toolResults only arrive after update_chat_ctx() is invoked
746
+ # which will only occur after agent speech has completed
747
+ # therefore we introduce an artificial turn to trigger update_chat_ctx()
748
+ # TODO: this is messy-- investigate if there is a better way to handle this
749
+ curr_gen = self._current_generation.messages[self._current_generation.response_id]
750
+ curr_gen.audio_ch.close()
751
+ curr_gen.text_ch.close()
752
+ self._current_generation.message_ch.close()
753
+ self._current_generation.message_ch = utils.aio.Chan()
754
+ self._current_generation.function_ch.close()
755
+ self._current_generation.function_ch = utils.aio.Chan()
756
+ msg_gen = _MessageGeneration(
757
+ message_id=self._current_generation.response_id,
758
+ text_ch=utils.aio.Chan(),
759
+ audio_ch=utils.aio.Chan(),
760
+ )
761
+ self._current_generation.messages[self._current_generation.response_id] = msg_gen
762
+ self._current_generation.message_ch.send_nowait(
763
+ llm.MessageGeneration(
764
+ message_id=msg_gen.message_id,
765
+ text_stream=msg_gen.text_ch,
766
+ audio_stream=msg_gen.audio_ch,
767
+ )
768
+ )
769
+ self.emit_generation_event()
770
+
771
+ async def _handle_tool_output_content_end_event(self, event_data: dict) -> None:
772
+ log_event_data(event_data)
773
+
774
+ async def _handle_audio_output_content_start_event(self, event_data: dict) -> None:
775
+ """Associate the upcoming audio chunk with the active assistant message."""
776
+ if self._current_generation is not None:
777
+ log_event_data(event_data)
778
+ audio_content_id = event_data["event"]["contentStart"]["contentId"]
779
+ self._current_generation.speculative_messages[audio_content_id] = (
780
+ self._current_generation.response_id
781
+ )
782
+
783
+ async def _handle_audio_output_content_event(self, event_data: dict) -> None:
784
+ """Decode base64 audio from Bedrock and forward it to the audio stream."""
785
+ if (
786
+ self._current_generation is not None
787
+ and self._current_generation.speculative_messages.get(
788
+ event_data["event"]["audioOutput"]["contentId"]
789
+ )
790
+ == self._current_generation.response_id
791
+ ):
792
+ audio_content = event_data["event"]["audioOutput"]["content"]
793
+ audio_bytes = base64.b64decode(audio_content)
794
+ curr_gen = self._current_generation.messages[self._current_generation.response_id]
795
+ curr_gen.audio_ch.send_nowait(
796
+ rtc.AudioFrame(
797
+ data=audio_bytes,
798
+ sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
799
+ num_channels=DEFAULT_CHANNELS,
800
+ samples_per_channel=len(audio_bytes) // 2,
801
+ )
802
+ )
803
+
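The samples_per_channel arithmetic above follows from the wire format: Sonic returns base64-encoded 16-bit mono PCM at 24 kHz, i.e. two bytes per sample. A quick illustration (payload is a hypothetical base64 chunk):

```python
chunk = base64.b64decode(payload)
assert len(chunk) % 2 == 0                 # 16-bit PCM: 2 bytes per sample
n_samples = len(chunk) // 2
duration_ms = 1000 * n_samples / DEFAULT_OUTPUT_SAMPLE_RATE
```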
804
+ async def _handle_audio_output_content_end_event(self, event_data: dict) -> None:
805
+ """Close the assistant message streams once Bedrock finishes audio for the turn."""
806
+ if (
807
+ self._current_generation is not None
808
+ and event_data["event"]["contentEnd"]["stopReason"] == "END_TURN"
809
+ and self._current_generation.speculative_messages.get(
810
+ event_data["event"]["contentEnd"]["contentId"]
811
+ )
812
+ == self._current_generation.response_id
813
+ ):
814
+ log_event_data(event_data)
815
+ self._close_current_generation()
816
+
817
+ def _close_current_generation(self) -> None:
818
+ """Helper that closes all channels of the active _ResponseGeneration."""
819
+ if self._current_generation is not None:
820
+ if self._current_generation.response_id in self._current_generation.messages:
821
+ curr_gen = self._current_generation.messages[self._current_generation.response_id]
822
+ if not curr_gen.audio_ch.closed:
823
+ curr_gen.audio_ch.close()
824
+ if not curr_gen.text_ch.closed:
825
+ curr_gen.text_ch.close()
826
+ # note: tool_messages maps content_id -> response_id (str -> str), so
827
+ # indexing it by a response_id could never match, and the mapped value
828
+ # is a plain str with no function channel to close. The lookup that
829
+ # lived here was dead code that would have raised AttributeError had it
830
+ # ever matched; the generation-level function_ch is closed below, which
831
+ # is sufficient for tool-call turns.
832
+
833
+ if not self._current_generation.message_ch.closed:
834
+ self._current_generation.message_ch.close()
835
+ if not self._current_generation.function_ch.closed:
836
+ self._current_generation.function_ch.close()
837
+
838
+ self._current_generation = None
839
+
840
+ async def _handle_completion_end_event(self, event_data: dict) -> None:
841
+ log_event_data(event_data)
842
+
843
+ async def _handle_other_event(self, event_data: dict) -> None:
844
+ log_event_data(event_data)
845
+
846
+ async def _handle_usage_event(self, event_data: dict) -> None:
847
+ # log_event_data(event_data)
848
+ # TODO: implement duration and ttft
849
+ input_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["input"]
850
+ output_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["output"]
851
+ # Q: should we be counting per turn or utterance?
852
+ metrics = RealtimeModelMetrics(
853
+ label=self._realtime_model._label,
854
+ # TODO: pass in the correct request_id
855
+ request_id=event_data["event"]["usageEvent"]["completionId"],
856
+ timestamp=time.monotonic(),
857
+ duration=0,
858
+ ttft=0,
859
+ cancelled=False,
860
+ input_tokens=input_tokens["speechTokens"] + input_tokens["textTokens"],
861
+ output_tokens=output_tokens["speechTokens"] + output_tokens["textTokens"],
862
+ total_tokens=input_tokens["speechTokens"]
863
+ + input_tokens["textTokens"]
864
+ + output_tokens["speechTokens"]
865
+ + output_tokens["textTokens"],
866
+ # need duration to calculate this
867
+ tokens_per_second=0,
868
+ input_token_details=RealtimeModelMetrics.InputTokenDetails(
869
+ text_tokens=input_tokens["textTokens"],
870
+ audio_tokens=input_tokens["speechTokens"],
871
+ image_tokens=0,
872
+ cached_tokens=0,
873
+ cached_tokens_details=None,
874
+ ),
875
+ output_token_details=RealtimeModelMetrics.OutputTokenDetails(
876
+ text_tokens=output_tokens["textTokens"],
877
+ audio_tokens=output_tokens["speechTokens"],
878
+ image_tokens=0,
879
+ ),
880
+ )
881
+ self.emit("metrics_collected", metrics)
882
+
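A hedged sketch of consuming these metrics; the "metrics_collected" event name comes from the emit call above, and on() assumes the EventEmitter interface that livekit-agents sessions expose:

```python
def on_metrics(m: RealtimeModelMetrics) -> None:
    # duration, ttft and tokens_per_second are currently reported as 0 (see TODOs)
    print(f"in={m.input_tokens} out={m.output_tokens} total={m.total_tokens}")

sess.on("metrics_collected", on_metrics)  # sess: an active RealtimeSession
```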
883
+ @utils.log_exceptions(logger=logger)
884
+ async def _process_responses(self) -> None:
885
+ """Background task that drains Bedrock's output stream and feeds the event handlers."""
886
+ try:
887
+ await self._is_sess_active.wait()
888
+ assert self._stream_response is not None, "stream_response is None"
889
+
890
+ # note: may need another signal here to block input task until bedrock is ready
891
+ # TODO: save this as a field so we're not re-awaiting it every time
892
+ _, output_stream = await self._stream_response.await_output()
893
+ while self._is_sess_active.is_set():
894
+ # and not self.stream_response.output_stream.closed:
895
+ try:
896
+ result = await output_stream.receive()
897
+ if result.value and result.value.bytes_:
898
+ try:
899
+ response_data = result.value.bytes_.decode("utf-8")
900
+ json_data = json.loads(response_data)
901
+ # logger.debug(f"Received event: {json_data}")
902
+ await self._handle_event(json_data)
903
+ except json.JSONDecodeError:
904
+ logger.warning(f"JSON decode error: {response_data}")
905
+ else:
906
+ logger.warning("No response received")
907
+ except asyncio.CancelledError:
908
+ logger.info("Response processing task cancelled")
909
+ self._close_current_generation()
910
+ raise
911
+ except ValidationException as ve:
912
+ # there is a 3min no-activity (e.g. silence) timeout on the stream, after which the stream is closed # noqa: E501
913
+ if (
914
+ "InternalErrorCode=531::RST_STREAM closed stream. HTTP/2 error code: NO_ERROR" # noqa: E501
915
+ in ve.message
916
+ ):
917
+ logger.warning(f"Validation error: {ve}\nAttempting to recover...")
918
+ await self._restart_session(ve)
919
+
920
+ else:
921
+ logger.error(f"Validation error: {ve}")
922
+ self.emit(
923
+ "error",
924
+ llm.RealtimeModelError(
925
+ timestamp=time.monotonic(),
926
+ label=self._realtime_model._label,
927
+ error=APIStatusError(
928
+ message=ve.message,
929
+ status_code=400,
930
+ request_id="",
931
+ body=ve,
932
+ retryable=False,
933
+ ),
934
+ recoverable=False,
935
+ ),
936
+ )
937
+ raise
938
+ except (ThrottlingException, ModelNotReadyException, ModelErrorException) as re:
939
+ logger.warning(f"Retryable error: {re}\nAttempting to recover...")
940
+ await self._restart_session(re)
941
+ break
942
+ except ModelTimeoutException as mte:
943
+ logger.warning(f"Model timeout error: {mte}\nAttempting to recover...")
944
+ await self._restart_session(mte)
945
+ break
946
+ except ValueError as val_err:
947
+ if "I/O operation on closed file." == val_err.args[0]:
948
+ logger.info("initiating graceful shutdown of session")
949
+ break
950
+ raise
951
+ except OSError:
952
+ logger.info("stream already closed, exiting")
953
+ break
954
+ except Exception as e:
955
+ err_msg = getattr(e, "message", str(e))
956
+ logger.error(f"Response processing error: {err_msg} (type: {type(e)})")
957
+ request_id = None
958
+ try:
959
+ request_id = err_msg.split(" ")[0].split("=")[1]
960
+ except Exception:
961
+ pass
962
+
963
+ self.emit(
964
+ "error",
965
+ llm.RealtimeModelError(
966
+ timestamp=time.monotonic(),
967
+ label=self._realtime_model._label,
968
+ error=APIStatusError(
969
+ message=err_msg,
970
+ status_code=500,
971
+ request_id=request_id,
972
+ body=e,
973
+ retryable=False,
974
+ ),
975
+ recoverable=False,
976
+ ),
977
+ )
978
+ raise
979
+
980
+ finally:
981
+ logger.info("main output response stream processing task exiting")
982
+ self._is_sess_active.clear()
983
+
984
+ async def _restart_session(self, ex: Exception) -> None:
985
+ if self._session_restart_attempts >= DEFAULT_MAX_SESSION_RESTART_ATTEMPTS:
986
+ logger.error("Max session restart attempts reached, exiting")
987
+ err_msg = getattr(ex, "message", str(ex))
988
+ request_id = None
989
+ try:
990
+ request_id = err_msg.split(" ")[0].split("=")[1]
991
+ except Exception:
992
+ pass
993
+ self.emit(
994
+ "error",
995
+ llm.RealtimeModelError(
996
+ timestamp=time.monotonic(),
997
+ label=self._realtime_model._label,
998
+ error=APIStatusError(
999
+ message=f"Max restart attempts exceeded: {err_msg}",
1000
+ status_code=500,
1001
+ request_id=request_id,
1002
+ body=ex,
1003
+ retryable=False,
1004
+ ),
1005
+ recoverable=False,
1006
+ ),
1007
+ )
1008
+ self._is_sess_active.clear()
1009
+ return
1010
+ self._session_restart_attempts += 1
1011
+ self._is_sess_active.clear()
1012
+ delay = 2 ** (self._session_restart_attempts - 1) - 1
1013
+ await asyncio.sleep(min(delay, DEFAULT_MAX_SESSION_RESTART_DELAY))
1014
+ await self.initialize_streams(is_restart=True)
1015
+ logger.info(
1016
+ f"Session restarted successfully ({self._session_restart_attempts}/{DEFAULT_MAX_SESSION_RESTART_ATTEMPTS})" # noqa: E501
1017
+ )
1018
+
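The backoff schedule implemented above, spelled out: delay = 2**(attempt - 1) - 1 seconds, capped at DEFAULT_MAX_SESSION_RESTART_DELAY:

```python
for attempt in (1, 2, 3):  # DEFAULT_MAX_SESSION_RESTART_ATTEMPTS = 3
    delay = min(2 ** (attempt - 1) - 1, 10)
    print(attempt, delay)  # -> (1, 0), (2, 1), (3, 3) seconds
```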
1019
+ @property
1020
+ def chat_ctx(self) -> llm.ChatContext:
1021
+ return self._chat_ctx.copy()
1022
+
1023
+ @property
1024
+ def tools(self) -> llm.ToolContext:
1025
+ return self._tools.copy()
1026
+
1027
+ async def update_instructions(self, instructions: str) -> None:
1028
+ """Injects the system prompt at the start of the session."""
1029
+ self._instructions = instructions
1030
+ if not self._instructions_ready.done(): self._instructions_ready.set_result(True)  # guard repeat calls
1031
+ logger.debug(f"Instructions updated: {instructions}")
1032
+
1033
+ async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
1034
+ """Inject an initial chat history once during the very first session startup."""
1035
+ # sometimes fires randomly
1036
+ # add a guard here to only allow chat_ctx to be updated on
1037
+ # the very first session initialization
1038
+ if not self._chat_ctx_ready.done():
1039
+ self._chat_ctx = chat_ctx.copy()
1040
+ logger.debug(f"Chat context updated: {self._chat_ctx.items}")
1041
+ self._chat_ctx_ready.set_result(True)
1042
+
1043
+ # for each function tool, send the result to aws
1044
+ for item in chat_ctx.items:
1045
+ if item.type != "function_call_output":
1046
+ continue
1047
+
1048
+ if item.call_id not in self._pending_tools:
1049
+ continue
1050
+
1051
+ logger.debug(f"function call output: {item}")
1052
+ self._pending_tools.discard(item.call_id)
1053
+ self._tool_results_ch.send_nowait(
1054
+ {
1055
+ "tool_use_id": item.call_id,
1056
+ "tool_result": item.output
1057
+ if not item.is_error
1058
+ else f"{{'error': '{item.error}'}}",
1059
+ }
1060
+ )
1061
+
1062
+ async def _send_tool_events(self, tool_use_id: str, tool_result: str) -> None:
1063
+ """Send tool_result back to Bedrock, grouped under tool_use_id."""
1064
+ tool_content_name = str(uuid.uuid4())
1065
+ tool_events = self._event_builder.create_tool_content_block(
1066
+ content_name=tool_content_name,
1067
+ tool_use_id=tool_use_id,
1068
+ content=tool_result,
1069
+ )
1070
+ for event in tool_events:
1071
+ await self._send_raw_event(event)
1072
+ # logger.debug(f"Sent tool event: {event}")
1073
+
1074
+ def _tool_choice_adapter(
1075
+ self, tool_choice: llm.ToolChoice | None
1076
+ ) -> dict[str, dict[str, str]] | None:
1077
+ """Translate the LiveKit ToolChoice enum into Sonic's JSON schema."""
1078
+ if tool_choice == "auto":
1079
+ return {"auto": {}}
1080
+ elif tool_choice == "required":
1081
+ return {"any": {}}
1082
+ elif isinstance(tool_choice, dict) and tool_choice["type"] == "function":
1083
+ return {"tool": {"name": tool_choice["function"]["name"]}}
1084
+ else:
1085
+ return None
1086
+
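The mapping, shown as data (sess is an active RealtimeSession; the dict shapes follow Bedrock's toolChoice schema):

```python
assert sess._tool_choice_adapter("auto") == {"auto": {}}
assert sess._tool_choice_adapter("required") == {"any": {}}
assert sess._tool_choice_adapter(
    {"type": "function", "function": {"name": "get_weather"}}  # hypothetical tool
) == {"tool": {"name": "get_weather"}}
assert sess._tool_choice_adapter(None) is None
```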
1087
+ # note: return value from tool functions registered to Sonic must be Structured Output (a dict that is JSON serializable) # noqa: E501
1088
+ async def update_tools(self, tools: list[llm.FunctionTool | llm.RawFunctionTool | Any]) -> None:
1089
+ """Replace the active tool set with tools and notify Sonic if necessary."""
1090
+ logger.debug(f"Updating tools: {tools}")
1091
+ retained_tools: list[llm.FunctionTool | llm.RawFunctionTool] = []
1092
+
1093
+ for tool in tools:
1094
+ retained_tools.append(tool)
1095
+ self._tools = llm.ToolContext(retained_tools)
1096
+ if retained_tools and not self._tools_ready.done():
1097
+ self._tools_ready.set_result(True)
1098
+ logger.debug("Tool list has been injected")
1099
+
1100
+ def update_options(self, *, tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN) -> None:
1101
+ """Live update of inference options is not supported by Sonic yet."""
1102
+ logger.warning(
1103
+ "updating inference configuration options is not yet supported by Nova Sonic's Realtime API" # noqa: E501
1104
+ )
1105
+
1106
+ @utils.log_exceptions(logger=logger)
1107
+ def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
1108
+ """Ensure mic audio matches Sonic's required sample rate & channels."""
1109
+ if self._input_resampler:
1110
+ if frame.sample_rate != self._input_resampler._input_rate:
1111
+ self._input_resampler = None
1112
+
1113
+ if self._input_resampler is None and (
1114
+ frame.sample_rate != DEFAULT_INPUT_SAMPLE_RATE or frame.num_channels != DEFAULT_CHANNELS
1115
+ ):
1116
+ self._input_resampler = rtc.AudioResampler(
1117
+ input_rate=frame.sample_rate,
1118
+ output_rate=DEFAULT_INPUT_SAMPLE_RATE,
1119
+ num_channels=DEFAULT_CHANNELS,
1120
+ )
1121
+
1122
+ if self._input_resampler:
1123
+ # flush the resampler when the input source is changed
1124
+ yield from self._input_resampler.push(frame)
1125
+ else:
1126
+ yield frame
1127
+
1128
+ @utils.log_exceptions(logger=logger)
1129
+ async def _process_audio_input(self) -> None:
1130
+ """Background task that feeds audio and tool results into the Bedrock stream."""
1131
+ await self._send_raw_event(self._event_builder.create_audio_content_start_event())
1132
+ logger.info("Starting audio input processing loop")
1133
+ while self._is_sess_active.is_set():
1134
+ try:
1135
+ # note: could potentially pull this out into a separate task
1136
+ try:
1137
+ val = self._tool_results_ch.recv_nowait()
1138
+ tool_result = val["tool_result"]
1139
+ tool_use_id = val["tool_use_id"]
1140
+ if not isinstance(tool_result, str):
1141
+ tool_result = json.dumps(tool_result)
1142
+ else:
1143
+ try:
1144
+ json.loads(tool_result)
1145
+ except json.JSONDecodeError:
1146
+ try:
1147
+ tool_result = json.dumps(ast.literal_eval(tool_result))
1148
+ except Exception:
1149
+ # return the original value
1150
+ pass
1151
+
1152
+ logger.debug(f"Sending tool result: {tool_result}")
1153
+ await self._send_tool_events(tool_use_id, tool_result)
1154
+
1155
+ except utils.aio.channel.ChanEmpty:
1156
+ pass
1157
+ except utils.aio.channel.ChanClosed:
1158
+ logger.warning(
1159
+ "tool results channel closed, exiting audio input processing loop"
1160
+ )
1161
+ break
1162
+
1163
+ try:
1164
+ audio_bytes = await self._audio_input_chan.recv()
1165
+ blob = base64.b64encode(audio_bytes)
1166
+ audio_event = self._event_builder.create_audio_input_event(
1167
+ audio_content=blob.decode("utf-8"),
1168
+ )
1169
+
1170
+ await self._send_raw_event(audio_event)
1171
+ except utils.aio.channel.ChanEmpty:
1172
+ pass
1173
+ except utils.aio.channel.ChanClosed:
1174
+ logger.warning(
1175
+ "audio input channel closed, exiting audio input processing loop"
1176
+ )
1177
+ break
1178
+
1179
+ except asyncio.CancelledError:
1180
+ logger.info("Audio processing loop cancelled")
1181
+ self._audio_input_chan.close()
1182
+ self._tool_results_ch.close()
1183
+ raise
1184
+ except Exception:
1185
+ logger.exception("Error processing audio")
1186
+
1187
+ # for debugging purposes only
1188
+ def _log_significant_audio(self, audio_bytes: bytes) -> None:
1189
+ """Utility that prints a debug message when the audio chunk has non-trivial RMS energy."""
1190
+ squared_sum = sum(sample**2 for sample in audio_bytes)
1191
+ if (squared_sum / len(audio_bytes)) ** 0.5 > 200:
1192
+ if lk_bedrock_debug:
1193
+ log_message("Enqueuing significant audio chunk", AnsiColors.BLUE)
1194
+
1195
+ @utils.log_exceptions(logger=logger)
1196
+ def push_audio(self, frame: rtc.AudioFrame) -> None:
1197
+ """Enqueue an incoming mic rtc.AudioFrame for transcription."""
1198
+ if not self._audio_input_chan.closed:
1199
+ # logger.debug(f"Raw audio received: samples={len(frame.data)} rate={frame.sample_rate} channels={frame.num_channels}") # noqa: E501
1200
+ for f in self._resample_audio(frame):
1201
+ # logger.debug(f"Resampled audio: samples={len(frame.data)} rate={frame.sample_rate} channels={frame.num_channels}") # noqa: E501
1202
+
1203
+ for nf in self._bstream.write(f.data.tobytes()):
1204
+ self._log_significant_audio(nf.data)
1205
+ self._audio_input_chan.send_nowait(nf.data)
1206
+ else:
1207
+ logger.warning("audio input channel closed, skipping audio")
1208
+
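A minimal sketch of feeding audio in; a real agent receives frames from a LiveKit audio track, so the silent frame below is purely illustrative:

```python
frame = rtc.AudioFrame(
    data=b"\x00\x00" * 480,   # 10 ms of 16-bit mono silence at 48 kHz
    sample_rate=48000,        # resampled internally to 16 kHz
    num_channels=1,
    samples_per_channel=480,
)
sess.push_audio(frame)  # buffered into 512-sample chunks before upload
```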
1209
+ def generate_reply(
1210
+ self,
1211
+ *,
1212
+ instructions: NotGivenOr[str] = NOT_GIVEN,
1213
+ ) -> asyncio.Future[llm.GenerationCreatedEvent]:
1214
+ logger.warning("unprompted generation is not supported by Nova Sonic's Realtime API")
1215
+ fut = asyncio.Future[llm.GenerationCreatedEvent]()
1216
+ fut.set_exception(
1217
+ llm.RealtimeError("unprompted generation is not supported by Nova Sonic's Realtime API")
1218
+ )
1219
+ return fut
1220
+
1221
+ def commit_audio(self) -> None:
1222
+ logger.warning("commit_audio is not supported by Nova Sonic's Realtime API")
1223
+
1224
+ def clear_audio(self) -> None:
1225
+ logger.warning("clear_audio is not supported by Nova Sonic's Realtime API")
1226
+
1227
+ def push_video(self, frame: rtc.VideoFrame) -> None:
1228
+ logger.warning("video is not supported by Nova Sonic's Realtime API")
1229
+
1230
+ def interrupt(self) -> None:
1231
+ logger.warning("interrupt is not supported by Nova Sonic's Realtime API")
1232
+
1233
+ def truncate(self, *, message_id: str, audio_end_ms: int) -> None:
1234
+ logger.warning("truncate is not supported by Nova Sonic's Realtime API")
1235
+
1236
+ @utils.log_exceptions(logger=logger)
1237
+ async def aclose(self) -> None:
1238
+ """Gracefully shut down the realtime session and release network resources."""
1239
+ logger.info("attempting to shutdown agent session")
1240
+ if not self._is_sess_active.is_set():
1241
+ logger.info("agent session already inactive")
1242
+ return
1243
+
1244
+ for event in self._event_builder.create_prompt_end_block():
1245
+ await self._send_raw_event(event)
1246
+ # allow event loops to fall out naturally
1247
+ # otherwise, the smithy layer will raise an InvalidStateError during cancellation
1248
+ self._is_sess_active.clear()
1249
+
1250
+ if self._stream_response and not self._stream_response.output_stream.closed:
1251
+ await self._stream_response.output_stream.close()
1252
+
1253
+ # note: even after the self._is_sess_active flag is cleared and the output stream is closed,
1254
+ # there is a future inside output_stream.receive() at the AWS-CRT C layer that blocks
1255
+ # resulting in an error after cancellation
1256
+ # however, it's mostly cosmetic-- the event loop will still exit
1257
+ # TODO: fix this nit
1258
+ tasks: list[asyncio.Task[Any]] = []
1259
+ if self._response_task:
1260
+ try:
1261
+ await asyncio.wait_for(self._response_task, timeout=1.0)
1262
+ except asyncio.TimeoutError:
1263
+ logger.warning("shutdown of output event loop timed out-- cancelling")
1264
+ self._response_task.cancel()
1265
+ tasks.append(self._response_task)
1266
+
1267
+ # must cancel the audio input task before closing the input stream
1268
+ if self._audio_input_task and not self._audio_input_task.done():
1269
+ self._audio_input_task.cancel()
1270
+ tasks.append(self._audio_input_task)
1271
+ if self._stream_response and not self._stream_response.input_stream.closed:
1272
+ await self._stream_response.input_stream.close()
1273
+
1274
+ await asyncio.gather(*tasks, return_exceptions=True)
1275
+ logger.debug(f"CHAT CONTEXT: {self._chat_ctx.items}")
1276
+ logger.info("Session end")