livekit-plugins-aws 1.0.0rc6__py3-none-any.whl → 1.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2106 @@
1
+ # mypy: disable-error-code=unused-ignore
2
+
3
+ from __future__ import annotations
4
+
5
+ import ast
6
+ import asyncio
7
+ import base64
8
+ import concurrent.futures
9
+ import json
10
+ import os
11
+ import time
12
+ import uuid
13
+ import weakref
14
+ from collections.abc import AsyncIterator, Iterator
15
+ from dataclasses import dataclass, field
16
+ from typing import Any, Callable, Literal, cast
17
+
18
+ import boto3
19
+ from aws_sdk_bedrock_runtime.client import (
20
+ BedrockRuntimeClient,
21
+ InvokeModelWithBidirectionalStreamOperationInput,
22
+ )
23
+ from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme
24
+ from aws_sdk_bedrock_runtime.models import (
25
+ BidirectionalInputPayloadPart,
26
+ InvokeModelWithBidirectionalStreamInputChunk,
27
+ ModelErrorException,
28
+ ModelNotReadyException,
29
+ ModelStreamErrorException,
30
+ ModelTimeoutException,
31
+ ThrottlingException,
32
+ ValidationException,
33
+ )
34
+ from smithy_aws_core.identity import AWSCredentialsIdentity
35
+ from smithy_core.aio.interfaces.identity import IdentityResolver
36
+
37
+ from livekit import rtc
38
+ from livekit.agents import (
39
+ APIStatusError,
40
+ llm,
41
+ utils,
42
+ )
43
+ from livekit.agents.metrics import RealtimeModelMetrics
44
+ from livekit.agents.metrics.base import Metadata
45
+ from livekit.agents.types import NOT_GIVEN, NotGivenOr
46
+ from livekit.agents.utils import is_given
47
+ from livekit.plugins.aws.experimental.realtime.turn_tracker import _TurnTracker
48
+
49
+ from ...log import logger
50
+ from .events import (
51
+ SonicEventBuilder as seb,
52
+ Tool,
53
+ ToolConfiguration,
54
+ ToolInputSchema,
55
+ ToolSpec,
56
+ )
57
+ from .pretty_printer import AnsiColors, log_event_data, log_message
58
+ from .types import MODALITIES, REALTIME_MODELS, SONIC1_VOICES, SONIC2_VOICES, TURN_DETECTION
59
+
60
# Audio format defaults for the bidirectional PCM streams.
DEFAULT_INPUT_SAMPLE_RATE = 16000  # Hz - microphone audio sent to the model
DEFAULT_OUTPUT_SAMPLE_RATE = 24000  # Hz - TTS audio received from the model
DEFAULT_SAMPLE_SIZE_BITS = 16  # PCM sample width
DEFAULT_CHANNELS = 1  # mono
DEFAULT_CHUNK_SIZE = 512  # samples per channel per audio chunk
# Default inference parameters. NOTE: Sonic's temperature/top_p semantics are
# non-standard (see RealtimeModel.__init__ comments).
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9
DEFAULT_MAX_TOKENS = 1024
MAX_MESSAGE_SIZE = 1024
MAX_MESSAGES = 40
DEFAULT_MAX_SESSION_RESTART_ATTEMPTS = 3
DEFAULT_MAX_SESSION_RESTART_DELAY = 10  # seconds
# Session recycling: restart before 8-min AWS limit or credential expiry
# Override with LK_SESSION_MAX_DURATION env var for testing (e.g., "60" for 1 minute)
MAX_SESSION_DURATION_SECONDS = int(os.getenv("LK_SESSION_MAX_DURATION", 6 * 60))
CREDENTIAL_EXPIRY_BUFFER_SECONDS = 3 * 60  # Restart 3 min before credential expiry
BARGE_IN_SIGNAL = '{ "interrupted" : true }\n'  # Nova Sonic's barge-in detection signal
# System prompt used when the application does not supply instructions.
DEFAULT_SYSTEM_PROMPT = (
    "Your name is Sonic, and you are a friendly and enthusiastic voice assistant. "
    "You love helping people and having natural conversations. "
    "Be warm, conversational, and engaging. "
    "Keep your responses natural and concise for voice interaction. "
    "Do not repeat yourself. "
    "If you are not sure what the user means, ask them to confirm or clarify. "
    "If after asking for clarification you still do not understand, be honest and tell them you do not understand. "
    "Do not make up information or make assumptions. If you do not know the answer, say so. "
    "When making tool calls, inform the user that you are using a tool to generate the response. "
    "Avoid formatted lists or numbering and keep your output as a spoken transcript. "
    "\n\n"
    "CRITICAL LANGUAGE MIRRORING RULES:\n"
    "- Always reply in the language the user speaks. DO NOT mix with English unless the user does.\n"
    "- If the user talks in English, reply in English.\n"
    "- Please respond in the language the user is talking to you in. If you have a question or suggestion, ask it in the language the user is talking in.\n"
    "- Ensure that our communication remains in the same language as the user."
)

# LK_BEDROCK_DEBUG env var parsed as int (unset -> 0).
lk_bedrock_debug = int(os.getenv("LK_BEDROCK_DEBUG", 0))

# Shared credentials resolver instance to preserve cache across all sessions
_shared_credentials_resolver: Boto3CredentialsResolver | None = None
100
+
101
+
102
def _get_credentials_resolver() -> Boto3CredentialsResolver:
    """Return the process-wide Boto3CredentialsResolver, creating it on first use.

    A single shared resolver is kept so credential caching works across all
    RealtimeSession instances.
    """
    global _shared_credentials_resolver
    resolver = _shared_credentials_resolver
    if resolver is None:
        resolver = Boto3CredentialsResolver()
        _shared_credentials_resolver = resolver
    return resolver
111
+
112
+
113
@dataclass
class _RealtimeOptions:
    """Configuration container for a Sonic realtime session.

    Attributes:
        voice (str): Voice identifier used for TTS output.
        temperature (float): Sampling temperature controlling randomness; 1.0 is most deterministic.
        top_p (float): Nucleus sampling parameter; 0.0 considers all tokens.
        max_tokens (int): Maximum number of tokens the model may generate in a single response.
        tool_choice (llm.ToolChoice | None): Strategy that dictates how the model should invoke tools.
        region (str): AWS region hosting the Bedrock Sonic model endpoint.
        turn_detection (TURN_DETECTION): Turn-taking sensitivity - "HIGH", "MEDIUM" (default), or "LOW".
        modalities (MODALITIES): Input/output mode - "audio" for audio-only, "mixed" for audio + text input.
    """  # noqa: E501

    voice: str
    temperature: float
    top_p: float
    max_tokens: int
    tool_choice: llm.ToolChoice | None
    region: str
    turn_detection: TURN_DETECTION
    modalities: MODALITIES
136
+
137
+
138
@dataclass
class _MessageGeneration:
    """Grouping of streams that together represent one assistant message.

    Attributes:
        message_id (str): Unique identifier that ties together text and audio for a single assistant turn.
        text_ch (utils.aio.Chan[str]): Channel that yields partial text tokens as they arrive.
        audio_ch (utils.aio.Chan[rtc.AudioFrame]): Channel that yields audio frames for the same assistant turn.
    """  # noqa: E501

    message_id: str
    text_ch: utils.aio.Chan[str]
    audio_ch: utils.aio.Chan[rtc.AudioFrame]
151
+
152
+
153
@dataclass
class _ResponseGeneration:
    """Book-keeping dataclass tracking the lifecycle of a Nova Sonic completion.

    Nova Sonic uses a completion model where one completionStart event begins a cycle
    that may contain multiple content blocks (USER ASR, TOOL, ASSISTANT text/audio).
    This generation stays open for the entire completion cycle.

    Attributes:
        completion_id (str): Nova Sonic's completionId that ties all events together.
        message_ch (utils.aio.Chan[llm.MessageGeneration]): Stream for assistant messages.
        function_ch (utils.aio.Chan[llm.FunctionCall]): Stream that emits function tool calls.
        response_id (str): LiveKit response_id for the assistant's response.
        message_gen (_MessageGeneration | None): Current message generation for assistant output.
        content_id_map (dict[str, str]): Map Nova Sonic contentId -> type (USER/ASSISTANT/TOOL).
        _created_timestamp (float): Wall-clock time when the generation record was created.
        _first_token_timestamp (float | None): Wall-clock time of first token emission.
        _completed_timestamp (float | None): Wall-clock time when the turn fully completed.
        _restart_attempts (int): Number of restart attempts for this specific completion.
    """  # noqa: E501

    completion_id: str
    message_ch: utils.aio.Chan[llm.MessageGeneration]
    function_ch: utils.aio.Chan[llm.FunctionCall]
    response_id: str
    message_gen: _MessageGeneration | None = None
    content_id_map: dict[str, str] = field(default_factory=dict)
    _created_timestamp: float = field(default_factory=time.time)
    _first_token_timestamp: float | None = None
    _completed_timestamp: float | None = None
    _restart_attempts: int = 0
    _done_fut: asyncio.Future[None] | None = None  # Resolved when generation completes
    _emitted: bool = False  # Track if generation_created event was emitted
186
+
187
+
188
class Boto3CredentialsResolver(IdentityResolver):  # type: ignore[misc]
    """IdentityResolver implementation that sources AWS credentials from boto3.

    The resolver delegates to the default boto3.Session() credential chain which
    checks environment variables, shared credentials files, EC2 instance profiles, etc.
    The credentials are then wrapped in an AWSCredentialsIdentity so they can be
    passed into Bedrock runtime clients.
    """

    def __init__(self) -> None:
        self.session = boto3.Session()  # type: ignore[attr-defined]
        # Cached identity/expiry so repeated request signing does not re-walk
        # the boto3 credential chain.
        self._cached_identity: AWSCredentialsIdentity | None = None
        self._cached_expiry: float | None = None

    async def get_identity(self, **kwargs: Any) -> AWSCredentialsIdentity:
        """Asynchronously resolve AWS credentials.

        This method is invoked by the Bedrock runtime client whenever a new request needs to be
        signed. It converts the static or temporary credentials returned by boto3
        into an AWSCredentialsIdentity instance.

        Returns:
            AWSCredentialsIdentity: Identity containing the
                AWS access key, secret key and optional session token.

        Raises:
            ValueError: If no credentials could be found by boto3, if they are
                incomplete, or if the boto3 lookup itself fails (original
                exception attached as __cause__).
        """
        # Return cached credentials if available
        # Session recycling will close the connection and get fresh credentials before these expire
        if self._cached_identity:
            return self._cached_identity

        try:
            logger.debug("[CREDS] Attempting to load AWS credentials")
            credentials = self.session.get_credentials()
            if not credentials:
                logger.error("[CREDS] Unable to load AWS credentials")
                raise ValueError("Unable to load AWS credentials")

            creds = credentials.get_frozen_credentials()

            # Ensure credentials are valid
            if not creds.access_key or not creds.secret_key:
                logger.error("AWS credentials are incomplete")
                raise ValueError("AWS credentials are incomplete")

            logger.debug(
                f"[CREDS] AWS credentials loaded successfully. AWS_ACCESS_KEY_ID: {creds.access_key[:4]}***"
            )

            # Get expiration time if available (for temporary credentials)
            expiry_time = getattr(credentials, "_expiry_time", None)

            identity = AWSCredentialsIdentity(
                access_key_id=creds.access_key,
                secret_access_key=creds.secret_key,
                session_token=creds.token if creds.token else None,
                expiration=expiry_time,
            )

            # Cache the identity and expiry
            self._cached_identity = identity
            if expiry_time:
                # Session will restart 3 minutes before expiration
                self._cached_expiry = expiry_time.timestamp() - 180
                logger.debug(
                    f"[CREDS] Cached credentials with expiry. "
                    f"expiry_time={expiry_time}, restart_before={self._cached_expiry}"
                )
            else:
                # Static credentials don't have an inherent expiration attribute, cache indefinitely
                self._cached_expiry = None
                logger.debug("[CREDS] Cached static credentials (no expiry)")

            return identity
        except ValueError:
            # The validation errors raised above already carry a precise
            # message; re-raise them untouched instead of double-wrapping.
            raise
        except Exception as e:
            logger.error(f"[CREDS] Failed to load AWS credentials: {str(e)}")
            # Chain the original exception so the root cause is preserved.
            raise ValueError(f"Failed to load AWS credentials: {str(e)}") from e

    def get_credential_expiry_time(self) -> float | None:
        """Get the credential expiry timestamp synchronously.

        This loads credentials if not cached and returns the expiry time.
        Used for calculating session duration before the async stream starts.

        Returns:
            float | None: Unix timestamp when credentials expire, or None for static credentials.
        """
        try:
            # NOTE(review): a fresh Session (rather than self.session) is created
            # here — presumably so a rotated credential chain is observed; confirm.
            session = boto3.Session()  # type: ignore[attr-defined]
            credentials = session.get_credentials()
            if not credentials:
                return None

            expiry_time = getattr(credentials, "_expiry_time", None)
            if expiry_time:
                return float(expiry_time.timestamp())
            return None
        except Exception as e:
            logger.warning(f"[CREDS] Failed to get credential expiry: {e}")
            return None
290
+
291
+
292
class RealtimeModel(llm.RealtimeModel):
    """LiveKit RealtimeModel implementation backed by Amazon Nova Sonic.

    Instances are deliberately lightweight: they hold default inference
    options and hand out a fresh RealtimeSession whenever session() is called.
    """

    def __init__(
        self,
        *,
        model: REALTIME_MODELS | str = "amazon.nova-2-sonic-v1:0",
        modalities: MODALITIES = "mixed",
        voice: NotGivenOr[SONIC1_VOICES | SONIC2_VOICES | str] = NOT_GIVEN,
        temperature: NotGivenOr[float] = NOT_GIVEN,
        top_p: NotGivenOr[float] = NOT_GIVEN,
        max_tokens: NotGivenOr[int] = NOT_GIVEN,
        tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
        region: NotGivenOr[str] = NOT_GIVEN,
        turn_detection: TURN_DETECTION = "MEDIUM",
        generate_reply_timeout: float = 10.0,
    ):
        """Instantiate a new RealtimeModel.

        Args:
            model: Bedrock model ID for realtime inference. Defaults to "amazon.nova-2-sonic-v1:0".
            modalities: "audio" for audio-only (Sonic 1.0), "mixed" for audio + text input (Sonic 2.0).
            voice: Voice id for TTS output. Defaults to "tiffany".
            temperature: Sampling temperature (0-1); see the note below on Sonic's semantics.
            top_p: Nucleus sampling probability mass; see the note below on Sonic's semantics.
            max_tokens: Upper bound for tokens emitted by the model. Defaults to DEFAULT_MAX_TOKENS.
            tool_choice: Strategy for tool invocation ("auto", "required", or explicit function).
            region: AWS region of the Bedrock runtime endpoint. Defaults to "us-east-1".
            turn_detection: Turn-taking sensitivity. HIGH detects pauses quickly, LOW waits longer.
            generate_reply_timeout: Timeout in seconds for generate_reply() calls. Defaults to 10.0.
        """
        capabilities = llm.RealtimeCapabilities(
            message_truncation=False,
            turn_detection=True,
            user_transcription=True,
            auto_tool_reply_generation=True,
            audio_output=True,
            manual_function_calls=False,
        )
        super().__init__(capabilities=capabilities)

        self._model = model
        self._generate_reply_timeout = generate_reply_timeout

        # NOTE: Sonic defines these parameters unconventionally — for BOTH
        # temperature and top_p, 0.0 is the most random and 1.0 the most
        # deterministic.
        self.temperature = temperature
        self.top_p = top_p

        # Resolve NOT_GIVEN placeholders to concrete defaults.
        resolved_voice = voice if is_given(voice) else "tiffany"
        resolved_temperature = temperature if is_given(temperature) else DEFAULT_TEMPERATURE
        resolved_top_p = top_p if is_given(top_p) else DEFAULT_TOP_P
        resolved_max_tokens = max_tokens if is_given(max_tokens) else DEFAULT_MAX_TOKENS
        resolved_region = region if is_given(region) else "us-east-1"

        self._opts = _RealtimeOptions(
            voice=resolved_voice,
            temperature=resolved_temperature,
            top_p=resolved_top_p,
            max_tokens=resolved_max_tokens,
            tool_choice=tool_choice or None,
            region=resolved_region,
            turn_detection=turn_detection,
            modalities=modalities,
        )
        # Weak references only — sessions are owned by their creators.
        self._sessions = weakref.WeakSet[RealtimeSession]()

    @classmethod
    def with_nova_sonic_1(
        cls,
        *,
        voice: NotGivenOr[SONIC1_VOICES | str] = NOT_GIVEN,
        temperature: NotGivenOr[float] = NOT_GIVEN,
        top_p: NotGivenOr[float] = NOT_GIVEN,
        max_tokens: NotGivenOr[int] = NOT_GIVEN,
        tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
        region: NotGivenOr[str] = NOT_GIVEN,
        turn_detection: TURN_DETECTION = "MEDIUM",
        generate_reply_timeout: float = 10.0,
    ) -> RealtimeModel:
        """Create a RealtimeModel configured for Nova Sonic 1.0 (audio-only).

        Args:
            voice: Voice id for TTS output; import SONIC1_VOICES from
                livekit.plugins.aws.experimental.realtime for supported values.
            temperature: Sampling temperature (0-1).
            top_p: Nucleus sampling probability mass.
            max_tokens: Upper bound for tokens emitted.
            tool_choice: Strategy for tool invocation.
            region: AWS region. Defaults to "us-east-1".
            turn_detection: Turn-taking sensitivity. Defaults to "MEDIUM".
            generate_reply_timeout: Timeout for generate_reply() calls.

        Returns:
            RealtimeModel: Configured for Nova Sonic 1.0 with audio-only modalities.

        Example:
            model = RealtimeModel.with_nova_sonic_1(voice="matthew", tool_choice="auto")
        """
        return cls(
            model="amazon.nova-sonic-v1:0",
            modalities="audio",
            voice=voice,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            tool_choice=tool_choice,
            region=region,
            turn_detection=turn_detection,
            generate_reply_timeout=generate_reply_timeout,
        )

    @classmethod
    def with_nova_sonic_2(
        cls,
        *,
        voice: NotGivenOr[SONIC2_VOICES | str] = NOT_GIVEN,
        temperature: NotGivenOr[float] = NOT_GIVEN,
        top_p: NotGivenOr[float] = NOT_GIVEN,
        max_tokens: NotGivenOr[int] = NOT_GIVEN,
        tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN,
        region: NotGivenOr[str] = NOT_GIVEN,
        turn_detection: TURN_DETECTION = "MEDIUM",
        generate_reply_timeout: float = 10.0,
    ) -> RealtimeModel:
        """Create a RealtimeModel configured for Nova Sonic 2.0 (audio + text input).

        Args:
            voice: Voice id for TTS output; import SONIC2_VOICES from
                livekit.plugins.aws.experimental.realtime for supported values.
            temperature: Sampling temperature (0-1).
            top_p: Nucleus sampling probability mass.
            max_tokens: Upper bound for tokens emitted.
            tool_choice: Strategy for tool invocation.
            region: AWS region. Defaults to "us-east-1".
            turn_detection: Turn-taking sensitivity. Defaults to "MEDIUM".
            generate_reply_timeout: Timeout for generate_reply() calls.

        Returns:
            RealtimeModel: Configured for Nova Sonic 2.0 with mixed modalities (audio + text input).

        Example:
            model = RealtimeModel.with_nova_sonic_2(voice="tiffany", max_tokens=10_000)
        """
        return cls(
            model="amazon.nova-2-sonic-v1:0",
            modalities="mixed",
            voice=voice,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            tool_choice=tool_choice,
            region=region,
            turn_detection=turn_detection,
            generate_reply_timeout=generate_reply_timeout,
        )

    @property
    def model(self) -> str:
        """Bedrock model ID used for realtime inference."""
        return self._model

    @property
    def modalities(self) -> MODALITIES:
        """Input/output mode: "audio" for audio-only, "mixed" for audio + text input."""
        return self._opts.modalities

    @property
    def provider(self) -> str:
        """Name of the inference provider."""
        return "Amazon"

    def session(self) -> RealtimeSession:
        """Return a new RealtimeSession bound to this model instance."""
        new_session = RealtimeSession(self)
        self._sessions.add(new_session)
        return new_session

    async def aclose(self) -> None:
        """Close all active sessions."""
        # Currently a no-op: sessions are tracked weakly and are expected to
        # be closed by their owners.
        pass
466
+
467
+
468
+ class RealtimeSession( # noqa: F811
469
+ llm.RealtimeSession[Literal["bedrock_server_event_received", "bedrock_client_event_queued"]]
470
+ ):
471
+ """Bidirectional streaming session against the Nova Sonic Bedrock runtime.
472
+
473
+ The session owns two asynchronous tasks:
474
+
475
+ 1. _process_audio_input – pushes user mic audio and tool results to Bedrock.
476
+ 2. _process_responses – receives server events from Bedrock and converts them into
477
+ LiveKit abstractions such as llm.MessageGeneration.
478
+
479
+ A set of helper handlers (_handle_*) transform the low-level Bedrock
480
+ JSON payloads into higher-level application events and keep
481
+ _ResponseGeneration state in sync.
482
+ """
483
+
484
    def __init__(self, realtime_model: RealtimeModel) -> None:
        """Create and wire-up a new realtime session.

        Args:
            realtime_model (RealtimeModel): Parent model instance that stores static
                inference options and the Smithy Bedrock client configuration.
        """
        super().__init__(realtime_model)
        self._realtime_model: RealtimeModel = realtime_model
        # Event builder owns the prompt/content identifiers for this stream.
        self._event_builder = seb(
            prompt_name=str(uuid.uuid4()),
            audio_content_name=str(uuid.uuid4()),
        )
        self._input_resampler: rtc.AudioResampler | None = None
        # Re-chunks incoming mic audio into fixed-size frames for Bedrock.
        self._bstream = utils.audio.AudioByteStream(
            DEFAULT_INPUT_SAMPLE_RATE, DEFAULT_CHANNELS, samples_per_channel=DEFAULT_CHUNK_SIZE
        )

        # Stream/task handles are created later in initialize_streams().
        self._response_task = None
        self._audio_input_task = None
        self._stream_response = None
        self._bedrock_client = None
        self._pending_tools: set[str] = set()  # tool_use_ids awaiting results
        self._is_sess_active = asyncio.Event()
        self._chat_ctx = llm.ChatContext.empty()
        self._tools = llm.ToolContext.empty()
        self._tool_results_ch = utils.aio.Chan[dict[str, str]]()
        # CRITICAL: Initialize futures as None for lazy creation
        # Creating futures in __init__ causes race conditions during session restart.
        # Futures are created in initialize_streams() when the event loop is guaranteed to exist.
        self._tools_ready: asyncio.Future[bool] | None = None
        self._instructions_ready: asyncio.Future[bool] | None = None
        self._chat_ctx_ready: asyncio.Future[bool] | None = None
        self._instructions = DEFAULT_SYSTEM_PROMPT
        self._audio_input_chan = utils.aio.Chan[bytes]()
        self._current_generation: _ResponseGeneration | None = None
        # Session recycling: proactively restart before credential expiry or 8-min limit
        self._session_start_time: float | None = None
        self._session_recycle_task: asyncio.Task[None] | None = None
        self._last_audio_output_time: float = 0.0  # Track when assistant last produced audio
        self._audio_end_turn_received: bool = False  # Track when assistant finishes speaking
        self._pending_generation_fut: asyncio.Future[llm.GenerationCreatedEvent] | None = None
        self._sent_message_ids: set[str] = set()
        self._audio_message_ids: set[str] = set()

        # Dispatch table mapping decoded Bedrock event names to handlers.
        self._event_handlers = {
            "completion_start": self._handle_completion_start_event,
            "audio_output_content_start": self._handle_audio_output_content_start_event,
            "audio_output_content": self._handle_audio_output_content_event,
            "audio_output_content_end": self._handle_audio_output_content_end_event,
            "text_output_content_start": self._handle_text_output_content_start_event,
            "text_output_content": self._handle_text_output_content_event,
            "text_output_content_end": self._handle_text_output_content_end_event,
            "tool_output_content_start": self._handle_tool_output_content_start_event,
            "tool_output_content": self._handle_tool_output_content_event,
            "tool_output_content_end": self._handle_tool_output_content_end_event,
            "completion_end": self._handle_completion_end_event,
            "usage": self._handle_usage_event,
            "other_event": self._handle_other_event,
        }
        self._turn_tracker = _TurnTracker(
            cast(Callable[[str, Any], None], self.emit),
            cast(Callable[[], None], self.emit_generation_event),
        )

        # Create main task to manage session lifecycle
        self._main_atask = asyncio.create_task(
            self.initialize_streams(), name="RealtimeSession.initialize_streams"
        )
553
+
554
+ @utils.log_exceptions(logger=logger)
555
+ def _initialize_client(self) -> None:
556
+ """Instantiate the Bedrock runtime client"""
557
+ config = Config(
558
+ endpoint_uri=f"https://bedrock-runtime.{self._realtime_model._opts.region}.amazonaws.com",
559
+ region=self._realtime_model._opts.region,
560
+ aws_credentials_identity_resolver=_get_credentials_resolver(),
561
+ auth_scheme_resolver=HTTPAuthSchemeResolver(),
562
+ auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")},
563
+ user_agent_extra="x-client-framework:livekit-plugins-aws[realtime]",
564
+ )
565
+ self._bedrock_client = BedrockRuntimeClient(config=config)
566
+
567
+ def _calculate_session_duration(self) -> float:
568
+ """Calculate session duration based on credential expiry and AWS 8-min limit."""
569
+ resolver = _get_credentials_resolver()
570
+ credential_expiry = resolver.get_credential_expiry_time()
571
+
572
+ if credential_expiry is None:
573
+ # Static credentials - just use the max session duration
574
+ logger.info(
575
+ f"[SESSION] Static credentials, using max duration: {MAX_SESSION_DURATION_SECONDS}s"
576
+ )
577
+ return MAX_SESSION_DURATION_SECONDS
578
+
579
+ # Calculate time until we should restart (before credential expiry)
580
+ now = time.time()
581
+ time_until_cred_expiry = credential_expiry - now - CREDENTIAL_EXPIRY_BUFFER_SECONDS
582
+
583
+ # Use the minimum of session limit and credential expiry
584
+ duration = min(MAX_SESSION_DURATION_SECONDS, time_until_cred_expiry)
585
+
586
+ if duration < 30:
587
+ logger.warning(
588
+ f"[SESSION] Very short session duration: {duration:.0f}s. "
589
+ f"Credentials may expire soon."
590
+ )
591
+ duration = max(duration, 10) # At least 10 seconds
592
+
593
+ logger.info(
594
+ f"[SESSION] Session will recycle in {duration:.0f}s "
595
+ f"(max={MAX_SESSION_DURATION_SECONDS}s, time_until_cred_expiry={time_until_cred_expiry:.0f}s)"
596
+ )
597
+
598
+ return duration
599
+
600
+ def _start_session_recycle_timer(self) -> None:
601
+ """Start the session recycling timer."""
602
+ if self._session_recycle_task and not self._session_recycle_task.done():
603
+ self._session_recycle_task.cancel()
604
+
605
+ duration = self._calculate_session_duration()
606
+
607
+ self._session_recycle_task = asyncio.create_task(
608
+ self._session_recycle_timer(duration), name="RealtimeSession._session_recycle_timer"
609
+ )
610
+
611
    async def _session_recycle_timer(self, duration: float) -> None:
        """Background task that triggers session recycling after duration seconds.

        Sleeps for `duration`, then waits for a quiet point in the conversation
        (assistant done speaking, audio output idle, completion cycle closed)
        before handing off to _graceful_session_recycle().
        """
        try:
            logger.info(f"[SESSION] Recycle timer started, will fire in {duration:.0f}s")
            await asyncio.sleep(duration)

            if not self._is_sess_active.is_set():
                logger.debug("[SESSION] Session no longer active, skipping recycle")
                return

            logger.info(
                f"[SESSION] Session duration limit reached ({duration:.0f}s), initiating recycle"
            )

            # Step 1: Wait for assistant to finish speaking (AUDIO contentEnd with END_TURN)
            if not self._audio_end_turn_received:
                logger.info(
                    "[SESSION] Waiting for assistant to finish speaking (AUDIO END_TURN)..."
                )
                # Poll the flag; it is set by the audio contentEnd handler.
                while not self._audio_end_turn_received:
                    await asyncio.sleep(0.1)
                logger.debug("[SESSION] Assistant finished speaking")

            # Step 2: Wait for audio to fully stop (no new audio for 1 second)
            logger.debug("[SESSION] Waiting for audio to fully stop...")
            last_audio_time = self._last_audio_output_time
            while True:
                await asyncio.sleep(0.1)
                if self._last_audio_output_time == last_audio_time:
                    # Timestamp unchanged for 0.1s; confirm over a further 0.9s
                    # so the total quiet window is ~1 second.
                    await asyncio.sleep(0.9)
                    if self._last_audio_output_time == last_audio_time:
                        logger.debug("[SESSION] No new audio for 1s, proceeding with recycle")
                        break
                else:
                    logger.debug("[SESSION] New audio detected, continuing to wait...")
                    last_audio_time = self._last_audio_output_time

            # Step 3: Send close events to trigger completionEnd from Nova Sonic
            # This must happen BEFORE cancelling tasks so response task can receive completionEnd
            logger.info("[SESSION] Sending close events to Nova Sonic...")
            if self._stream_response:
                for event in self._event_builder.create_prompt_end_block():
                    await self._send_raw_event(event)

            # Step 4: Wait for completionEnd and let _done_fut resolve
            if self._current_generation and self._current_generation._done_fut:
                try:
                    await asyncio.wait_for(self._current_generation._done_fut, timeout=2.0)
                    logger.debug("[SESSION] Generation completed (completionEnd received)")
                except asyncio.TimeoutError:
                    logger.warning("[SESSION] Timeout waiting for completionEnd, proceeding anyway")
                    # Force-close so the recycle is not blocked by a stuck generation.
                    self._close_current_generation()

            await self._graceful_session_recycle()

        except asyncio.CancelledError:
            logger.debug("[SESSION] Recycle timer cancelled")
            raise
        except Exception as e:
            logger.error(f"[SESSION] Error in recycle timer: {e}")
671
+
672
    async def _graceful_session_recycle(self) -> None:
        """Gracefully recycle the session, preserving conversation state.

        Tears down the current Bedrock stream and its worker tasks in a fixed
        order, resets per-connection state, then re-runs initialize_streams()
        with is_restart=True so chat context carries over.
        """
        logger.info("[SESSION] Starting graceful session recycle")

        # Step 1: Drain any pending tool results
        logger.debug("[SESSION] Draining pending tool results...")
        while True:
            try:
                tool_result = self._tool_results_ch.recv_nowait()
                logger.debug(f"[TOOL] Draining pending result: {tool_result['tool_use_id']}")
                await self._send_tool_events(tool_result["tool_use_id"], tool_result["tool_result"])
            except utils.aio.channel.ChanEmpty:
                logger.debug("[SESSION] No more pending tool results")
                break
            except Exception as e:
                # Best-effort drain: a failed send should not block the recycle.
                logger.warning(f"[SESSION] Error draining tool result: {e}")
                break

        # Step 2: Signal tasks to stop
        self._is_sess_active.clear()

        # Step 3: Wait for response task to exit naturally, then cancel if needed
        if self._response_task and not self._response_task.done():
            try:
                # TODO: Even waiting for 30 seconds this never just happens.
                # See if we can figure out how to make this more graceful
                await asyncio.wait_for(self._response_task, timeout=1.0)
            except asyncio.TimeoutError:
                logger.debug("[SESSION] Response task timeout, cancelling...")
                self._response_task.cancel()
                try:
                    await self._response_task
                except asyncio.CancelledError:
                    pass

        # Step 4: Cancel audio input task (blocked on channel, won't exit naturally)
        if self._audio_input_task and not self._audio_input_task.done():
            self._audio_input_task.cancel()
            try:
                await self._audio_input_task
            except asyncio.CancelledError:
                pass

        # Step 5: Close the stream (close events already sent in _session_recycle_timer)
        if self._stream_response:
            try:
                if not self._stream_response.input_stream.closed:
                    await self._stream_response.input_stream.close()
            except Exception as e:
                logger.debug(f"[SESSION] Error closing stream (expected): {e}")

        # Step 6: Reset state for new session
        self._stream_response = None
        self._bedrock_client = None
        # Fresh prompt/content identifiers for the new Bedrock connection.
        self._event_builder = seb(
            prompt_name=str(uuid.uuid4()),
            audio_content_name=str(uuid.uuid4()),
        )
        self._tool_results_ch = utils.aio.Chan[dict[str, str]]()
        logger.debug("[SESSION] Created fresh tool results channel")
        self._audio_end_turn_received = False

        # Step 7: Start new session with preserved state
        await self.initialize_streams(is_restart=True)

        logger.info("[SESSION] Session recycled successfully")
738
+
739
+ @utils.log_exceptions(logger=logger)
740
+ async def _send_raw_event(self, event_json: str) -> None:
741
+ """Low-level helper that serialises event_json and forwards it to the bidirectional stream.
742
+
743
+ Args:
744
+ event_json (str): The JSON payload (already in Bedrock wire format) to queue.
745
+
746
+ Raises:
747
+ Exception: Propagates any failures returned by the Bedrock runtime client.
748
+ """
749
+ if not self._stream_response:
750
+ logger.warning("stream not initialized; dropping event (this should never occur)")
751
+ return
752
+
753
+ # Log the full JSON being sent (skip audio events to avoid log spam)
754
+ if '"audioInput"' not in event_json:
755
+ logger.debug(f"[SEND] {event_json}")
756
+
757
+ event = InvokeModelWithBidirectionalStreamInputChunk(
758
+ value=BidirectionalInputPayloadPart(bytes_=event_json.encode("utf-8"))
759
+ )
760
+
761
+ try:
762
+ await self._stream_response.input_stream.send(event)
763
+ except Exception as e:
764
+ logger.exception("Error sending event")
765
+ err_msg = getattr(e, "message", str(e))
766
+ request_id = None
767
+ try:
768
+ request_id = err_msg.split(" ")[0].split("=")[1]
769
+ except Exception:
770
+ pass
771
+
772
+ self.emit(
773
+ "error",
774
+ llm.RealtimeModelError(
775
+ timestamp=time.monotonic(),
776
+ label=self._realtime_model._label,
777
+ error=APIStatusError(
778
+ message=err_msg,
779
+ status_code=500,
780
+ request_id=request_id,
781
+ body=e,
782
+ retryable=False,
783
+ ),
784
+ recoverable=False,
785
+ ),
786
+ )
787
+ raise
788
+
789
    def _serialize_tool_config(self) -> ToolConfiguration | None:
        """Convert self.tools into the JSON structure expected by Sonic.

        If any tools are registered, the method also harmonises temperature and
        top_p defaults to Sonic's recommended greedy values (1.0).

        Returns:
            ToolConfiguration | None: None when no tools are present, otherwise a complete config block.
        """  # noqa: E501
        tool_cfg = None
        if self.tools.function_tools:
            tools = []
            for name, f in self.tools.function_tools.items():
                if llm.tool_context.is_function_tool(f):
                    description = llm.tool_context.get_function_info(f).description
                    # internally_tagged=True nests the JSON schema under "parameters"
                    input_schema = llm.utils.build_legacy_openai_schema(f, internally_tagged=True)[
                        "parameters"
                    ]
                elif llm.tool_context.is_raw_function_tool(f):
                    description = llm.tool_context.get_raw_function_info(f).raw_schema.get(
                        "description"
                    )
                    raw_schema = llm.tool_context.get_raw_function_info(f).raw_schema
                    # Safely access parameters with fallback
                    input_schema = raw_schema.get(
                        "parameters",
                        raw_schema.get("input_schema", {"type": "object", "properties": {}}),
                    )
                else:
                    # Not a recognised tool kind: skip rather than send a malformed spec.
                    continue

                tool = Tool(
                    toolSpec=ToolSpec(
                        name=name,
                        description=description or "No description provided",
                        inputSchema=ToolInputSchema(json_=json.dumps(input_schema)),  # type: ignore
                    )
                )
                tools.append(tool)
            tool_choice = self._tool_choice_adapter(self._realtime_model._opts.tool_choice)
            logger.debug(f"TOOL CHOICE: {tool_choice}")
            tool_cfg = ToolConfiguration(tools=tools, toolChoice=tool_choice)

            # recommended to set greedy inference configs for tool calls
            if not is_given(self._realtime_model.top_p):
                self._realtime_model._opts.top_p = 1.0
            if not is_given(self._realtime_model.temperature):
                self._realtime_model._opts.temperature = 1.0
        return tool_cfg
    @utils.log_exceptions(logger=logger)
    async def initialize_streams(self, is_restart: bool = False) -> None:
        """Open the Bedrock bidirectional stream and spawn background worker tasks.

        This coroutine is idempotent and can be invoked again when recoverable
        errors (e.g. timeout, throttling) require a fresh session.

        Args:
            is_restart (bool, optional): Marks whether we are re-initialising an
                existing session after an error. Defaults to False.

        NOTE(review): despite the ``-> None`` annotation the final statement
        returns ``self`` -- confirm whether any caller chains on the result
        before tightening the annotation or dropping the return.
        """
        try:
            if not self._bedrock_client:
                logger.info("Creating Bedrock client")
                self._initialize_client()
                assert self._bedrock_client is not None, "bedrock_client is None"

            logger.info("Initializing Bedrock stream")
            self._stream_response = (
                await self._bedrock_client.invoke_model_with_bidirectional_stream(
                    InvokeModelWithBidirectionalStreamOperationInput(
                        model_id=self._realtime_model.model
                    )
                )
            )

            if not is_restart:
                # Lazy-initialize futures if needed; they are resolved elsewhere
                # when tools / instructions / chat context arrive.
                if self._tools_ready is None:
                    self._tools_ready = asyncio.get_running_loop().create_future()
                if self._instructions_ready is None:
                    self._instructions_ready = asyncio.get_running_loop().create_future()
                if self._chat_ctx_ready is None:
                    self._chat_ctx_ready = asyncio.get_running_loop().create_future()

                pending_events: list[asyncio.Future] = []
                if not self.tools.function_tools:
                    pending_events.append(self._tools_ready)
                if not self._instructions_ready.done():
                    pending_events.append(self._instructions_ready)
                if not self._chat_ctx_ready.done():
                    pending_events.append(self._chat_ctx_ready)

                # note: can't know during sess init whether tools were not added
                # or if they were added haven't yet been updated
                # therefore in the case there are no tools, we wait the entire timeout
                try:
                    if pending_events:
                        await asyncio.wait_for(asyncio.gather(*pending_events), timeout=0.5)
                except asyncio.TimeoutError:
                    if self._tools_ready and not self._tools_ready.done():
                        logger.warning("Tools not ready after 500ms, continuing without them")

                    if self._instructions_ready and not self._instructions_ready.done():
                        logger.warning(
                            "Instructions not received after 500ms, proceeding with default instructions"  # noqa: E501
                        )
                    if self._chat_ctx_ready and not self._chat_ctx_ready.done():
                        logger.warning(
                            "Chat context not received after 500ms, proceeding with empty chat context"  # noqa: E501
                        )

            logger.info(
                f"Initializing Bedrock session with realtime options: {self._realtime_model._opts}"
            )
            # there is a 40-message limit on the chat context
            if len(self._chat_ctx.items) > MAX_MESSAGES:
                logger.warning(
                    f"Chat context has {len(self._chat_ctx.items)} messages, truncating to {MAX_MESSAGES}"  # noqa: E501
                )
                self._chat_ctx.truncate(max_items=MAX_MESSAGES)

            # On restart, ensure chat history starts with USER (Nova Sonic requirement)
            restart_ctx = self._chat_ctx
            if is_restart and self._chat_ctx.items:
                first_item = self._chat_ctx.items[0]
                if first_item.type == "message" and first_item.role == "assistant":
                    # Work on a copy so the live chat context stays untouched.
                    restart_ctx = self._chat_ctx.copy()
                    dummy_msg = llm.ChatMessage(role="user", content=["[Resuming conversation]"])
                    restart_ctx.items.insert(0, dummy_msg)
                    logger.debug("[SESSION] Added dummy USER message to start of chat history")

            init_events = self._event_builder.create_prompt_start_block(
                voice_id=self._realtime_model._opts.voice,
                sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,  # type: ignore
                system_content=self._instructions,
                chat_ctx=restart_ctx,
                tool_configuration=self._serialize_tool_config(),
                max_tokens=self._realtime_model._opts.max_tokens,
                top_p=self._realtime_model._opts.top_p,
                temperature=self._realtime_model._opts.temperature,
                endpointing_sensitivity=self._realtime_model._opts.turn_detection,
            )

            for event in init_events:
                await self._send_raw_event(event)
                logger.debug(f"Sent event: {event}")

            # Always create audio input task (even on restart)
            self._audio_input_task = asyncio.create_task(
                self._process_audio_input(), name="RealtimeSession._process_audio_input"
            )

            self._response_task = asyncio.create_task(
                self._process_responses(), name="RealtimeSession._process_responses"
            )
            self._is_sess_active.set()

            # Start session recycling timer
            self._session_start_time = time.time()
            self._start_session_recycle_timer()

            logger.debug("Stream initialized successfully")
        except Exception as e:
            logger.debug(f"Failed to initialize stream: {str(e)}")
            raise
        return self
+ @utils.log_exceptions(logger=logger)
958
+ def emit_generation_event(self) -> None:
959
+ """Publish a llm.GenerationCreatedEvent to external subscribers.
960
+
961
+ This can be called multiple times for the same generation:
962
+ - Once from _create_response_generation() when a NEW generation is created
963
+ - Once from TurnTracker when TOOL_OUTPUT_CONTENT_START or ASSISTANT_SPEC_START arrives
964
+
965
+ The TurnTracker emission is critical for tool calls - it happens at the right moment
966
+ for the framework to start listening before the tool call is emitted.
967
+ """
968
+ if self._current_generation is None:
969
+ logger.debug("[GEN] emit_generation_event called but no generation exists - ignoring")
970
+ return
971
+
972
+ # Log whether this is first or re-emission for tool call
973
+ if self._current_generation._emitted:
974
+ logger.debug(
975
+ f"[GEN] EMITTING generation_created (re-emit for tool call) for response_id={self._current_generation.response_id}"
976
+ )
977
+ else:
978
+ logger.debug(
979
+ f"[GEN] EMITTING generation_created for response_id={self._current_generation.response_id}"
980
+ )
981
+
982
+ self._current_generation._emitted = True
983
+ generation_ev = llm.GenerationCreatedEvent(
984
+ message_stream=self._current_generation.message_ch,
985
+ function_stream=self._current_generation.function_ch,
986
+ user_initiated=False,
987
+ response_id=self._current_generation.response_id,
988
+ )
989
+ self.emit("generation_created", generation_ev)
990
+
991
+ # Resolve pending generate_reply future if exists
992
+ if self._pending_generation_fut and not self._pending_generation_fut.done():
993
+ self._pending_generation_fut.set_result(generation_ev)
994
+ self._pending_generation_fut = None
995
+
996
+ @utils.log_exceptions(logger=logger)
997
+ async def _handle_event(self, event_data: dict) -> None:
998
+ """Dispatch a raw Bedrock event to the corresponding _handle_* method."""
999
+ event_type = self._event_builder.get_event_type(event_data)
1000
+ event_handler = self._event_handlers.get(event_type)
1001
+ if event_handler:
1002
+ await event_handler(event_data)
1003
+ self._turn_tracker.feed(event_data)
1004
+ else:
1005
+ logger.warning(f"No event handler found for event type: {event_type}")
1006
+
1007
+ async def _handle_completion_start_event(self, event_data: dict) -> None:
1008
+ """Handle completionStart - create new generation for this completion cycle."""
1009
+ log_event_data(event_data)
1010
+ self._create_response_generation()
1011
+
1012
    def _create_response_generation(self) -> None:
        """Instantiate _ResponseGeneration and emit the GenerationCreated event.

        Can be called multiple times - will reuse existing generation but ensure
        message structure exists. The GenerationCreated event is only emitted
        when a brand-new generation object is created here.
        """
        generation_created = False
        if self._current_generation is None:
            completion_id = "unknown"  # Will be set from events
            response_id = str(uuid.uuid4())

            logger.debug(f"[GEN] Creating NEW generation, response_id={response_id}")
            self._current_generation = _ResponseGeneration(
                completion_id=completion_id,
                message_ch=utils.aio.Chan(),
                function_ch=utils.aio.Chan(),
                response_id=response_id,
                _done_fut=asyncio.get_running_loop().create_future(),
            )
            generation_created = True
        else:
            logger.debug(
                f"[GEN] Generation already exists: response_id={self._current_generation.response_id}, emitted={self._current_generation._emitted}"
            )

        # Always ensure message structure exists (even if generation already exists)
        if self._current_generation.message_gen is None:
            logger.debug(
                f"[GEN] Creating message structure for response_id={self._current_generation.response_id}"
            )
            msg_gen = _MessageGeneration(
                message_id=self._current_generation.response_id,
                text_ch=utils.aio.Chan(),
                audio_ch=utils.aio.Chan(),
            )
            # Modalities are known up front, so the future is resolved immediately.
            msg_modalities = asyncio.Future[list[Literal["text", "audio"]]]()
            msg_modalities.set_result(
                ["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"]
            )

            self._current_generation.message_gen = msg_gen
            self._current_generation.message_ch.send_nowait(
                llm.MessageGeneration(
                    message_id=msg_gen.message_id,
                    text_stream=msg_gen.text_ch,
                    audio_stream=msg_gen.audio_ch,
                    modalities=msg_modalities,
                )
            )
        else:
            logger.debug(
                f"[GEN] Message structure already exists for response_id={self._current_generation.response_id}"
            )

        # Only emit generation event if we created a new generation
        if generation_created:
            logger.debug("[GEN] New generation created - calling emit_generation_event()")
            self.emit_generation_event()
+ # will be completely ignoring post-ASR text events
1072
+ async def _handle_text_output_content_start_event(self, event_data: dict) -> None:
1073
+ """Handle text_output_content_start - track content type."""
1074
+ log_event_data(event_data)
1075
+
1076
+ role = event_data["event"]["contentStart"]["role"]
1077
+
1078
+ # CRITICAL: Create NEW generation for each ASSISTANT SPECULATIVE response
1079
+ # Nova Sonic sends ASSISTANT SPECULATIVE for each new assistant turn, including after tool calls.
1080
+ # Without this, audio frames get routed to the wrong generation and don't play.
1081
+ if role == "ASSISTANT":
1082
+ additional_fields = event_data["event"]["contentStart"].get("additionalModelFields", "")
1083
+ if "SPECULATIVE" in additional_fields:
1084
+ # This is a new assistant response - close previous and create new
1085
+ logger.debug("[GEN] ASSISTANT SPECULATIVE text received")
1086
+ if self._current_generation is not None:
1087
+ logger.debug(
1088
+ f"[GEN] Closing previous generation (response_id={self._current_generation.response_id}) for new SPECULATIVE"
1089
+ )
1090
+ self._close_current_generation()
1091
+ self._create_response_generation()
1092
+ else:
1093
+ # For USER and FINAL, just ensure generation exists
1094
+ self._create_response_generation()
1095
+
1096
+ # CRITICAL: Check if generation exists before accessing
1097
+ # Barge-in can set _current_generation to None between the creation above and here.
1098
+ # Without this check, we crash on interruptions.
1099
+ if self._current_generation is None:
1100
+ logger.debug("No generation exists - ignoring content_start event")
1101
+ return
1102
+
1103
+ content_id = event_data["event"]["contentStart"]["contentId"]
1104
+
1105
+ # Track what type of content this is
1106
+ if role == "USER":
1107
+ self._current_generation.content_id_map[content_id] = "USER_ASR"
1108
+ elif role == "ASSISTANT":
1109
+ additional_fields = event_data["event"]["contentStart"].get("additionalModelFields", "")
1110
+ if "SPECULATIVE" in additional_fields:
1111
+ self._current_generation.content_id_map[content_id] = "ASSISTANT_TEXT"
1112
+ elif "FINAL" in additional_fields:
1113
+ self._current_generation.content_id_map[content_id] = "ASSISTANT_FINAL"
1114
+
1115
    async def _handle_text_output_content_event(self, event_data: dict) -> None:
        """Stream partial text tokens into the current generation.

        Routes each chunk by the contentId recorded at content_start time:
        USER_ASR chunks only update the chat context, ASSISTANT_TEXT chunks
        are additionally streamed to the framework. A barge-in sentinel marks
        the latest chat message interrupted and (unless tool calls are still
        pending) closes the generation.
        """
        log_event_data(event_data)

        if self._current_generation is None:
            logger.debug("No generation exists - ignoring text_output event")
            return

        content_id = event_data["event"]["textOutput"]["contentId"]
        text_content = f"{event_data['event']['textOutput']['content']}\n"

        # Nova Sonic's automatic barge-in detection
        if text_content == BARGE_IN_SIGNAL:
            # Mark the most recent chat message as interrupted.
            idx = self._chat_ctx.find_insertion_index(created_at=time.time()) - 1
            if idx >= 0 and (item := self._chat_ctx.items[idx]).type == "message":
                item.interrupted = True
                logger.debug("Barge-in detected - marked message as interrupted")

            # Close generation on barge-in unless tools are pending
            if not self._pending_tools:
                self._close_current_generation()
            else:
                logger.debug(f"Keeping generation open - {len(self._pending_tools)} pending tools")
            return

        content_type = self._current_generation.content_id_map.get(content_id)

        if content_type == "USER_ASR":
            logger.debug(f"INPUT TRANSCRIPTION UPDATED: {text_content}")
            self._update_chat_ctx(role="user", text_content=text_content)

        elif content_type == "ASSISTANT_TEXT":
            # Set first token timestamp if not already set (feeds the TTFT metric)
            if self._current_generation._first_token_timestamp is None:
                self._current_generation._first_token_timestamp = time.time()

            # Stream text to LiveKit
            if self._current_generation.message_gen:
                self._current_generation.message_gen.text_ch.send_nowait(text_content)
            self._update_chat_ctx(role="assistant", text_content=text_content)
+ def _update_chat_ctx(self, role: llm.ChatRole, text_content: str) -> None:
1157
+ """
1158
+ Update the chat context with the latest ASR text while guarding against model limitations:
1159
+ a) 40 total messages limit
1160
+ b) 1kB message size limit
1161
+ """
1162
+ logger.debug(f"Updating chat context with role: {role} and text_content: {text_content}")
1163
+ if len(self._chat_ctx.items) == 0:
1164
+ msg = self._chat_ctx.add_message(role=role, content=text_content)
1165
+ if role == "user":
1166
+ self._audio_message_ids.add(msg.id)
1167
+ else:
1168
+ prev_utterance = self._chat_ctx.items[-1]
1169
+ if prev_utterance.type == "message" and prev_utterance.role == role:
1170
+ if isinstance(prev_content := prev_utterance.content[0], str) and (
1171
+ len(prev_content.encode("utf-8")) + len(text_content.encode("utf-8"))
1172
+ < MAX_MESSAGE_SIZE
1173
+ ):
1174
+ prev_utterance.content[0] = "\n".join([prev_content, text_content])
1175
+ else:
1176
+ msg = self._chat_ctx.add_message(role=role, content=text_content)
1177
+ if role == "user":
1178
+ self._audio_message_ids.add(msg.id)
1179
+ if len(self._chat_ctx.items) > MAX_MESSAGES:
1180
+ self._chat_ctx.truncate(max_items=MAX_MESSAGES)
1181
+ else:
1182
+ msg = self._chat_ctx.add_message(role=role, content=text_content)
1183
+ if role == "user":
1184
+ self._audio_message_ids.add(msg.id)
1185
+ if len(self._chat_ctx.items) > MAX_MESSAGES:
1186
+ self._chat_ctx.truncate(max_items=MAX_MESSAGES)
1187
+
1188
+ # cannot rely on this event for user b/c stopReason=PARTIAL_TURN always for user
1189
+ async def _handle_text_output_content_end_event(self, event_data: dict) -> None:
1190
+ """Handle text content end - log but don't close generation yet."""
1191
+ # Nova Sonic sends multiple content blocks within one completion
1192
+ # Don't close generation here - wait for completionEnd or audio_output_content_end
1193
+ log_event_data(event_data)
1194
+
1195
+ async def _handle_tool_output_content_start_event(self, event_data: dict) -> None:
1196
+ """Track tool content start."""
1197
+ log_event_data(event_data)
1198
+
1199
+ # Ensure generation exists
1200
+ self._create_response_generation()
1201
+
1202
+ if self._current_generation is None:
1203
+ return
1204
+
1205
+ content_id = event_data["event"]["contentStart"]["contentId"]
1206
+ self._current_generation.content_id_map[content_id] = "TOOL"
1207
+
1208
+ async def _handle_tool_output_content_event(self, event_data: dict) -> None:
1209
+ """Execute the referenced tool locally and queue results."""
1210
+ log_event_data(event_data)
1211
+
1212
+ if self._current_generation is None:
1213
+ logger.warning("tool_output_content received without active generation")
1214
+ return
1215
+
1216
+ tool_use_id = event_data["event"]["toolUse"]["toolUseId"]
1217
+ tool_name = event_data["event"]["toolUse"]["toolName"]
1218
+ args = event_data["event"]["toolUse"]["content"]
1219
+
1220
+ # Emit function call to LiveKit framework
1221
+ self._current_generation.function_ch.send_nowait(
1222
+ llm.FunctionCall(call_id=tool_use_id, name=tool_name, arguments=args)
1223
+ )
1224
+ self._pending_tools.add(tool_use_id)
1225
+ logger.debug(f"Tool call emitted: {tool_name} (id={tool_use_id})")
1226
+
1227
+ # CRITICAL: Close generation after tool call emission
1228
+ # The LiveKit framework expects the generation to close so it can call update_chat_ctx()
1229
+ # with the tool results. A new generation will be created when Nova Sonic sends the next
1230
+ # ASSISTANT SPECULATIVE text event with the tool response.
1231
+ logger.debug("Closing generation to allow tool result delivery")
1232
+ self._close_current_generation()
1233
+
1234
+ async def _handle_tool_output_content_end_event(self, event_data: dict) -> None:
1235
+ log_event_data(event_data)
1236
+
1237
+ async def _handle_audio_output_content_start_event(self, event_data: dict) -> None:
1238
+ """Track audio content start."""
1239
+ if self._current_generation is not None:
1240
+ log_event_data(event_data)
1241
+ content_id = event_data["event"]["contentStart"]["contentId"]
1242
+ self._current_generation.content_id_map[content_id] = "ASSISTANT_AUDIO"
1243
+
1244
+ async def _handle_audio_output_content_event(self, event_data: dict) -> None:
1245
+ """Decode base64 audio from Bedrock and forward it to the audio stream."""
1246
+ if self._current_generation is None or self._current_generation.message_gen is None:
1247
+ return
1248
+
1249
+ content_id = event_data["event"]["audioOutput"]["contentId"]
1250
+ content_type = self._current_generation.content_id_map.get(content_id)
1251
+
1252
+ if content_type == "ASSISTANT_AUDIO":
1253
+ audio_content = event_data["event"]["audioOutput"]["content"]
1254
+ audio_bytes = base64.b64decode(audio_content)
1255
+ self._current_generation.message_gen.audio_ch.send_nowait(
1256
+ rtc.AudioFrame(
1257
+ data=audio_bytes,
1258
+ sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
1259
+ num_channels=DEFAULT_CHANNELS,
1260
+ samples_per_channel=len(audio_bytes) // 2,
1261
+ )
1262
+ )
1263
+ # Track when we last received audio output (for session recycling)
1264
+ self._last_audio_output_time = time.time()
1265
+
1266
+ async def _handle_audio_output_content_end_event(self, event_data: dict) -> None:
1267
+ """Handle audio content end - track END_TURN for session recycling."""
1268
+ log_event_data(event_data)
1269
+
1270
+ # Check if this is END_TURN (assistant finished speaking)
1271
+ stop_reason = event_data.get("event", {}).get("contentEnd", {}).get("stopReason")
1272
+ if stop_reason == "END_TURN":
1273
+ self._audio_end_turn_received = True
1274
+ logger.debug("[SESSION] AUDIO END_TURN received - assistant finished speaking")
1275
+
1276
+ # Nova Sonic uses one completion for entire session
1277
+ # Don't close generation here - wait for new completionStart or session end
1278
+
1279
+ def _close_current_generation(self) -> None:
1280
+ """Helper that closes all channels of the active generation."""
1281
+ if self._current_generation is None:
1282
+ return
1283
+
1284
+ response_id = self._current_generation.response_id
1285
+ was_emitted = self._current_generation._emitted
1286
+
1287
+ # Set completed timestamp
1288
+ if self._current_generation._completed_timestamp is None:
1289
+ self._current_generation._completed_timestamp = time.time()
1290
+
1291
+ # Close message channels
1292
+ if self._current_generation.message_gen:
1293
+ if not self._current_generation.message_gen.audio_ch.closed:
1294
+ self._current_generation.message_gen.audio_ch.close()
1295
+ if not self._current_generation.message_gen.text_ch.closed:
1296
+ self._current_generation.message_gen.text_ch.close()
1297
+
1298
+ # Close generation channels
1299
+ if not self._current_generation.message_ch.closed:
1300
+ self._current_generation.message_ch.close()
1301
+ if not self._current_generation.function_ch.closed:
1302
+ self._current_generation.function_ch.close()
1303
+
1304
+ # Resolve _done_fut to signal generation is complete (for session recycling)
1305
+ if self._current_generation._done_fut and not self._current_generation._done_fut.done():
1306
+ self._current_generation._done_fut.set_result(None)
1307
+
1308
+ logger.debug(
1309
+ f"[GEN] CLOSED generation response_id={response_id}, was_emitted={was_emitted}"
1310
+ )
1311
+ self._current_generation = None
1312
+
1313
+ async def _handle_completion_end_event(self, event_data: dict) -> None:
1314
+ """Handle completionEnd - close the generation for this completion cycle."""
1315
+ log_event_data(event_data)
1316
+
1317
+ # Close generation if still open
1318
+ if self._current_generation:
1319
+ logger.debug("completionEnd received, closing generation")
1320
+ self._close_current_generation()
1321
+
1322
+ async def _handle_other_event(self, event_data: dict) -> None:
1323
+ log_event_data(event_data)
1324
+
1325
+ async def _handle_usage_event(self, event_data: dict) -> None:
1326
+ # log_event_data(event_data)
1327
+ input_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["input"]
1328
+ output_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["output"]
1329
+
1330
+ # Calculate metrics from timestamps
1331
+ duration = 0.0
1332
+ ttft = 0.0
1333
+ tokens_per_second = 0.0
1334
+
1335
+ if self._current_generation is not None:
1336
+ created_ts = self._current_generation._created_timestamp
1337
+ first_token_ts = self._current_generation._first_token_timestamp
1338
+ completed_ts = self._current_generation._completed_timestamp
1339
+
1340
+ # Calculate TTFT (time to first token)
1341
+ if first_token_ts is not None and isinstance(created_ts, (int, float)):
1342
+ ttft = first_token_ts - created_ts
1343
+
1344
+ # Calculate duration (total time from creation to completion)
1345
+ if completed_ts is not None and isinstance(created_ts, (int, float)):
1346
+ duration = completed_ts - created_ts
1347
+
1348
+ # Calculate tokens per second
1349
+ total_tokens = (
1350
+ input_tokens["speechTokens"]
1351
+ + input_tokens["textTokens"]
1352
+ + output_tokens["speechTokens"]
1353
+ + output_tokens["textTokens"]
1354
+ )
1355
+ if duration > 0:
1356
+ tokens_per_second = total_tokens / duration
1357
+
1358
+ metrics = RealtimeModelMetrics(
1359
+ label=self._realtime_model.label,
1360
+ request_id=event_data["event"]["usageEvent"]["completionId"],
1361
+ timestamp=time.monotonic(),
1362
+ duration=duration,
1363
+ ttft=ttft,
1364
+ cancelled=False,
1365
+ input_tokens=input_tokens["speechTokens"] + input_tokens["textTokens"],
1366
+ output_tokens=output_tokens["speechTokens"] + output_tokens["textTokens"],
1367
+ total_tokens=input_tokens["speechTokens"]
1368
+ + input_tokens["textTokens"]
1369
+ + output_tokens["speechTokens"]
1370
+ + output_tokens["textTokens"],
1371
+ tokens_per_second=tokens_per_second,
1372
+ input_token_details=RealtimeModelMetrics.InputTokenDetails(
1373
+ text_tokens=input_tokens["textTokens"],
1374
+ audio_tokens=input_tokens["speechTokens"],
1375
+ image_tokens=0,
1376
+ cached_tokens=0,
1377
+ cached_tokens_details=None,
1378
+ ),
1379
+ output_token_details=RealtimeModelMetrics.OutputTokenDetails(
1380
+ text_tokens=output_tokens["textTokens"],
1381
+ audio_tokens=output_tokens["speechTokens"],
1382
+ image_tokens=0,
1383
+ ),
1384
+ metadata=Metadata(
1385
+ model_name=self._realtime_model.model, model_provider=self._realtime_model.provider
1386
+ ),
1387
+ )
1388
+ self.emit("metrics_collected", metrics)
1389
+
1390
+ @utils.log_exceptions(logger=logger)
1391
+ async def _process_responses(self) -> None:
1392
+ """Background task that drains Bedrock's output stream and feeds the event handlers."""
1393
+ try:
1394
+ await self._is_sess_active.wait()
1395
+ assert self._stream_response is not None, "stream_response is None"
1396
+
1397
+ # note: may need another signal here to block input task until bedrock is ready
1398
+ # TODO: save this as a field so we're not re-awaiting it every time
1399
+ _, output_stream = await self._stream_response.await_output()
1400
+ while self._is_sess_active.is_set():
1401
+ # and not self.stream_response.output_stream.closed:
1402
+ try:
1403
+ result = await output_stream.receive()
1404
+ if result is None:
1405
+ # Stream closed, exit gracefully
1406
+ logger.debug("[SESSION] Stream returned None, exiting")
1407
+ break
1408
+ if result.value and result.value.bytes_:
1409
+ try:
1410
+ response_data = result.value.bytes_.decode("utf-8")
1411
+ json_data = json.loads(response_data)
1412
+ # logger.debug(f"Received event: {json_data}")
1413
+ await self._handle_event(json_data)
1414
+ except json.JSONDecodeError:
1415
+ logger.warning(f"JSON decode error: {response_data}")
1416
+ else:
1417
+ logger.warning("No response received")
1418
+ except concurrent.futures.InvalidStateError:
1419
+ # Future was cancelled during shutdown - expected when AWS CRT
1420
+ # tries to deliver data to cancelled futures
1421
+ logger.debug(
1422
+ "[SESSION] Future cancelled during receive (expected during shutdown)"
1423
+ )
1424
+ break
1425
+ except AttributeError as ae:
1426
+ # Result is None during shutdown
1427
+ if "'NoneType' object has no attribute" in str(ae):
1428
+ logger.debug(
1429
+ "[SESSION] Stream closed during receive (expected during shutdown)"
1430
+ )
1431
+ break
1432
+ raise
1433
+ except asyncio.CancelledError:
1434
+ logger.info("Response processing task cancelled")
1435
+ self._close_current_generation()
1436
+ raise
1437
+ except ValidationException as ve:
1438
+ # there is a 3min no-activity (e.g. silence) timeout on the stream, after which the stream is closed # noqa: E501
1439
+ if (
1440
+ "InternalErrorCode=531::RST_STREAM closed stream. HTTP/2 error code: NO_ERROR" # noqa: E501
1441
+ in ve.message
1442
+ ):
1443
+ logger.warning(f"Validation error: {ve}\nAttempting to recover...")
1444
+ await self._restart_session(ve)
1445
+ elif "Tool Response parsing error" in ve.message:
1446
+ # Tool parsing errors are recoverable - log and continue
1447
+ logger.warning(f"Tool response parsing error (recoverable): {ve}")
1448
+
1449
+ # Close current generation to unblock the model
1450
+ if self._current_generation:
1451
+ logger.debug("Closing generation due to tool parsing error")
1452
+ self._close_current_generation()
1453
+
1454
+ # Clear pending tools since they failed
1455
+ if self._pending_tools:
1456
+ logger.debug(f"Clearing {len(self._pending_tools)} pending tools")
1457
+ self._pending_tools.clear()
1458
+
1459
+ self.emit(
1460
+ "error",
1461
+ llm.RealtimeModelError(
1462
+ timestamp=time.monotonic(),
1463
+ label=self._realtime_model._label,
1464
+ error=APIStatusError(
1465
+ message=ve.message,
1466
+ status_code=400,
1467
+ request_id="",
1468
+ body=ve,
1469
+ retryable=False,
1470
+ ),
1471
+ recoverable=True,
1472
+ ),
1473
+ )
1474
+ # Don't raise - continue processing
1475
+ else:
1476
+ logger.error(f"Validation error: {ve}")
1477
+ self.emit(
1478
+ "error",
1479
+ llm.RealtimeModelError(
1480
+ timestamp=time.monotonic(),
1481
+ label=self._realtime_model._label,
1482
+ error=APIStatusError(
1483
+ message=ve.message,
1484
+ status_code=400,
1485
+ request_id="",
1486
+ body=ve,
1487
+ retryable=False,
1488
+ ),
1489
+ recoverable=False,
1490
+ ),
1491
+ )
1492
+ raise
1493
+ except (
1494
+ ThrottlingException,
1495
+ ModelNotReadyException,
1496
+ ModelErrorException,
1497
+ ModelStreamErrorException,
1498
+ ) as re:
1499
+ logger.warning(
1500
+ f"Retryable error: {re}\nAttempting to recover...", exc_info=True
1501
+ )
1502
+ await self._restart_session(re)
1503
+ break
1504
+ except ModelTimeoutException as mte:
1505
+ logger.warning(
1506
+ f"Model timeout error: {mte}\nAttempting to recover...", exc_info=True
1507
+ )
1508
+ await self._restart_session(mte)
1509
+ break
1510
+ except ValueError as val_err:
1511
+ if "I/O operation on closed file." == val_err.args[0]:
1512
+ logger.info("initiating graceful shutdown of session")
1513
+ break
1514
+ raise
1515
+ except OSError:
1516
+ logger.info("stream already closed, exiting")
1517
+ break
1518
+ except Exception as e:
1519
+ err_msg = getattr(e, "message", str(e))
1520
+ logger.error(f"Response processing error: {err_msg} (type: {type(e)})")
1521
+ request_id = None
1522
+ try:
1523
+ request_id = err_msg.split(" ")[0].split("=")[1]
1524
+ except Exception:
1525
+ pass
1526
+
1527
+ self.emit(
1528
+ "error",
1529
+ llm.RealtimeModelError(
1530
+ timestamp=time.monotonic(),
1531
+ label=self._realtime_model._label,
1532
+ error=APIStatusError(
1533
+ message=err_msg,
1534
+ status_code=500,
1535
+ request_id=request_id,
1536
+ body=e,
1537
+ retryable=False,
1538
+ ),
1539
+ recoverable=False,
1540
+ ),
1541
+ )
1542
+ raise
1543
+
1544
+ finally:
1545
+ logger.info("main output response stream processing task exiting")
1546
+ self._is_sess_active.clear()
1547
+
1548
    async def _restart_session(self, ex: Exception) -> None:
        """Attempt to recover from a stream-level error by re-initializing the Bedrock streams.

        Tracks restart attempts on the current generation (falling back to a
        local counter when no generation is active), applies exponential
        backoff between attempts, and gives up with a non-recoverable
        ``RealtimeModelError`` once ``DEFAULT_MAX_SESSION_RESTART_ATTEMPTS``
        is reached.

        Args:
            ex: the exception that triggered the restart; surfaced to listeners
                if the retry budget is exhausted.
        """
        # Get restart attempts from current generation, or 0 if no generation
        restart_attempts = (
            self._current_generation._restart_attempts if self._current_generation else 0
        )

        if restart_attempts >= DEFAULT_MAX_SESSION_RESTART_ATTEMPTS:
            logger.error("Max restart attempts reached for this turn, exiting")
            err_msg = getattr(ex, "message", str(ex))
            request_id = None
            try:
                # best-effort parse of messages shaped like "RequestId=<id> ..." —
                # any failure is deliberately ignored and request_id stays None
                request_id = err_msg.split(" ")[0].split("=")[1]
            except Exception:
                pass
            self.emit(
                "error",
                llm.RealtimeModelError(
                    timestamp=time.monotonic(),
                    label=self._realtime_model._label,
                    error=APIStatusError(
                        message=f"Max restart attempts exceeded: {err_msg}",
                        status_code=500,
                        request_id=request_id,
                        body=ex,
                        retryable=False,
                    ),
                    recoverable=False,
                ),
            )
            # mark the session inactive so the processing loops exit
            self._is_sess_active.clear()
            return

        # Increment restart counter for current generation
        if self._current_generation:
            self._current_generation._restart_attempts += 1
            restart_attempts = self._current_generation._restart_attempts
        else:
            restart_attempts = 1

        self._is_sess_active.clear()
        # exponential backoff: 0s, 1s, 3s, 7s, ... capped at DEFAULT_MAX_SESSION_RESTART_DELAY
        delay = 2 ** (restart_attempts - 1) - 1
        await asyncio.sleep(min(delay, DEFAULT_MAX_SESSION_RESTART_DELAY))
        await self.initialize_streams(is_restart=True)
        logger.info(
            f"Turn restarted successfully ({restart_attempts}/{DEFAULT_MAX_SESSION_RESTART_ATTEMPTS})"
        )
    @property
    def chat_ctx(self) -> llm.ChatContext:
        """Return a copy of the session's chat context (callers cannot mutate session state)."""
        return self._chat_ctx.copy()
    @property
    def tools(self) -> llm.ToolContext:
        """Return a copy of the active tool context (callers cannot mutate session state)."""
        return self._tools.copy()
+ async def update_instructions(self, instructions: str) -> None:
1604
+ """Injects the system prompt at the start of the session."""
1605
+ self._instructions = instructions
1606
+ if self._instructions_ready is None:
1607
+ self._instructions_ready = asyncio.get_running_loop().create_future()
1608
+ if not self._instructions_ready.done():
1609
+ self._instructions_ready.set_result(True)
1610
+ logger.debug(f"Instructions updated: {instructions}")
1611
+
1612
    async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
        """Inject chat history and handle incremental user messages.

        On first call this seeds the session's chat context and resolves
        ``_chat_ctx_ready``. On every call it scans the (filtered) items to:
        - forward ``function_call_output`` items for tools we are waiting on
          into ``_tool_results_ch``; and
        - send not-yet-sent user text messages to the model as interactive
          text (the ``generate_reply(user_input=...)`` flow), skipping
          messages that originated from audio transcription.
        """
        if self._chat_ctx_ready is None:
            self._chat_ctx_ready = asyncio.get_running_loop().create_future()

        # strip items the model should never see
        chat_ctx = chat_ctx.copy(
            exclude_handoff=True, exclude_instructions=True, exclude_empty_message=True
        )

        # Initial context setup (once)
        if not self._chat_ctx_ready.done():
            self._chat_ctx = chat_ctx.copy()
            logger.debug(f"Chat context updated: {self._chat_ctx.items}")
            self._chat_ctx_ready.set_result(True)

        # Process items in context
        for item in chat_ctx.items:
            # Handle tool results
            if item.type == "function_call_output":
                # only forward results for calls the model actually issued
                if item.call_id not in self._pending_tools:
                    continue

                logger.debug(f"function call output: {item}")
                self._pending_tools.discard(item.call_id)

                # Format tool result as proper JSON
                if item.is_error:
                    tool_result = json.dumps({"error": str(item.output)})
                else:
                    tool_result = item.output

                self._tool_results_ch.send_nowait(
                    {
                        "tool_use_id": item.call_id,
                        "tool_result": tool_result,
                    }
                )
                continue

            # Handle new user messages (Nova 2.0 text input)
            # Only send if it's NOT an audio transcription (audio messages are tracked in _audio_message_ids)
            if (
                item.type == "message"
                and item.role == "user"
                and item.id not in self._sent_message_ids
            ):
                # Check if this is an audio message (already transcribed by Nova)
                if item.id not in self._audio_message_ids:
                    if item.text_content:
                        logger.debug(
                            f"Sending user message as interactive text: {item.text_content}"
                        )
                        # Send interactive text to Nova Sonic (triggers generation)
                        # This is the flow for generate_reply(user_input=...) from the framework
                        fut = asyncio.Future[llm.GenerationCreatedEvent]()
                        self._pending_generation_fut = fut

                        # bind current values as defaults so the closure doesn't
                        # capture later loop iterations
                        text = item.text_content

                        async def _send_user_text(
                            text: str = text, fut: asyncio.Future = fut
                        ) -> None:
                            try:
                                # Wait for session to be fully initialized before sending
                                await self._is_sess_active.wait()
                                await self._send_text_message(text, interactive=True)
                            except Exception as e:
                                if not fut.done():
                                    fut.set_exception(e)
                                if self._pending_generation_fut is fut:
                                    self._pending_generation_fut = None

                        # NOTE(review): the task reference is not retained — per asyncio
                        # docs an un-referenced task may be garbage-collected mid-flight;
                        # consider keeping a strong reference. TODO confirm.
                        asyncio.create_task(_send_user_text())

                        self._sent_message_ids.add(item.id)
                        self._chat_ctx.items.append(item)
                else:
                    logger.debug(
                        f"Skipping user message (already in context from audio): {item.text_content}"
                    )
                    self._sent_message_ids.add(item.id)
+ async def _send_tool_events(self, tool_use_id: str, tool_result: str) -> None:
1695
+ """Send tool_result back to Bedrock, grouped under tool_use_id."""
1696
+ tool_content_name = str(uuid.uuid4())
1697
+ tool_events = self._event_builder.create_tool_content_block(
1698
+ content_name=tool_content_name,
1699
+ tool_use_id=tool_use_id,
1700
+ content=tool_result,
1701
+ )
1702
+ for event in tool_events:
1703
+ await self._send_raw_event(event)
1704
+ # logger.debug(f"Sent tool event: {event}")
1705
+
1706
+ def _tool_choice_adapter(
1707
+ self, tool_choice: llm.ToolChoice | None
1708
+ ) -> dict[str, dict[str, str]] | None:
1709
+ """Translate the LiveKit ToolChoice enum into Sonic's JSON schema."""
1710
+ if tool_choice == "auto":
1711
+ return {"auto": {}}
1712
+ elif tool_choice == "required":
1713
+ return {"any": {}}
1714
+ elif isinstance(tool_choice, dict) and tool_choice["type"] == "function":
1715
+ return {"tool": {"name": tool_choice["function"]["name"]}}
1716
+ else:
1717
+ return None
1718
+
1719
+ # note: return value from tool functions registered to Sonic must be Structured Output (a dict that is JSON serializable) # noqa: E501
1720
+ async def update_tools(self, tools: list[llm.FunctionTool | llm.RawFunctionTool | Any]) -> None:
1721
+ """Replace the active tool set with tools and notify Sonic if necessary."""
1722
+ logger.debug(f"Updating tools: {tools}")
1723
+ retained_tools: list[llm.FunctionTool | llm.RawFunctionTool] = []
1724
+
1725
+ for tool in tools:
1726
+ retained_tools.append(tool)
1727
+ self._tools = llm.ToolContext(retained_tools)
1728
+ if retained_tools:
1729
+ if self._tools_ready is None:
1730
+ self._tools_ready = asyncio.get_running_loop().create_future()
1731
+ if not self._tools_ready.done():
1732
+ self._tools_ready.set_result(True)
1733
+ logger.debug("Tool list has been injected")
1734
+
1735
    def update_options(self, *, tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN) -> None:
        """Live update of inference options is not supported by Sonic yet.

        The ``tool_choice`` argument is accepted for interface compatibility
        but ignored; a warning is logged instead.
        """
        logger.warning(
            "updating inference configuration options is not yet supported by Nova Sonic's Realtime API"  # noqa: E501
        )
    @utils.log_exceptions(logger=logger)
    def _resample_audio(self, frame: rtc.AudioFrame) -> Iterator[rtc.AudioFrame]:
        """Ensure mic audio matches Sonic's required sample rate & channels.

        Yields the frame unchanged when it already matches
        ``DEFAULT_INPUT_SAMPLE_RATE``/``DEFAULT_CHANNELS``; otherwise yields
        resampled frames. The resampler is rebuilt whenever the input sample
        rate changes (e.g. the input source switched).
        """
        if self._input_resampler:
            if frame.sample_rate != self._input_resampler._input_rate:
                # input rate changed — discard the stale resampler so a new one
                # is built below with the correct input_rate
                self._input_resampler = None

        if self._input_resampler is None and (
            frame.sample_rate != DEFAULT_INPUT_SAMPLE_RATE or frame.num_channels != DEFAULT_CHANNELS
        ):
            self._input_resampler = rtc.AudioResampler(
                input_rate=frame.sample_rate,
                output_rate=DEFAULT_INPUT_SAMPLE_RATE,
                num_channels=DEFAULT_CHANNELS,
            )

        if self._input_resampler:
            # flush the resampler when the input source is changed
            yield from self._input_resampler.push(frame)
        else:
            yield frame
    @utils.log_exceptions(logger=logger)
    async def _process_audio_input(self) -> None:
        """Background task that feeds audio and tool results into the Bedrock stream.

        Runs a select-style loop over two channels (mic audio and tool
        results) with ``asyncio.wait(FIRST_COMPLETED)``, re-arming whichever
        receive task completed. Exits when the session is deactivated, a
        channel closes, or the task is cancelled.
        """
        await self._send_raw_event(self._event_builder.create_audio_content_start_event())
        logger.info("Starting audio input processing loop")

        # Create tasks for both channels so we can wait on either
        audio_task = asyncio.create_task(self._audio_input_chan.recv())
        tool_task = asyncio.create_task(self._tool_results_ch.recv())
        pending = {audio_task, tool_task}

        while self._is_sess_active.is_set():
            try:
                done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)

                for task in done:
                    if task == audio_task:
                        try:
                            # audio is base64-encoded before being wrapped in an input event
                            audio_bytes = cast(bytes, task.result())
                            blob = base64.b64encode(audio_bytes)
                            audio_event = self._event_builder.create_audio_input_event(
                                audio_content=blob.decode("utf-8"),
                            )
                            await self._send_raw_event(audio_event)
                            # Create new task for next audio
                            audio_task = asyncio.create_task(self._audio_input_chan.recv())
                            pending.add(audio_task)
                        except utils.aio.channel.ChanClosed:
                            logger.warning("audio input channel closed")
                            break

                    elif task == tool_task:
                        try:
                            val = cast(dict[str, str], task.result())
                            tool_result = val["tool_result"]
                            tool_use_id = val["tool_use_id"]
                            # normalize the result to a JSON string: dump non-strings,
                            # and repair python-literal strings (e.g. "{'a': 1}") that
                            # are not valid JSON via ast.literal_eval
                            if not isinstance(tool_result, str):
                                tool_result = json.dumps(tool_result)
                            else:
                                try:
                                    json.loads(tool_result)
                                except json.JSONDecodeError:
                                    try:
                                        tool_result = json.dumps(ast.literal_eval(tool_result))
                                    except Exception:
                                        # unparseable — send the raw string as-is
                                        pass

                            logger.debug(f"Sending tool result: {tool_result}")
                            await self._send_tool_events(tool_use_id, tool_result)
                            # Create new task for next tool result
                            tool_task = asyncio.create_task(self._tool_results_ch.recv())
                            pending.add(tool_task)
                        except utils.aio.channel.ChanClosed:
                            logger.warning("tool results channel closed")
                            break

            except asyncio.CancelledError:
                logger.info("Audio processing loop cancelled")
                # Cancel pending tasks
                for task in pending:
                    task.cancel()
                self._audio_input_chan.close()
                self._tool_results_ch.close()
                raise
            except Exception:
                # keep the loop alive on unexpected per-iteration errors
                logger.exception("Error processing audio")
+ # for debugging purposes only
1831
+ def _log_significant_audio(self, audio_bytes: bytes) -> None:
1832
+ """Utility that prints a debug message when the audio chunk has non-trivial RMS energy."""
1833
+ squared_sum = sum(sample**2 for sample in audio_bytes)
1834
+ if (squared_sum / len(audio_bytes)) ** 0.5 > 200:
1835
+ if lk_bedrock_debug:
1836
+ log_message("Enqueuing significant audio chunk", AnsiColors.BLUE)
1837
+
1838
    @utils.log_exceptions(logger=logger)
    def push_audio(self, frame: rtc.AudioFrame) -> None:
        """Enqueue an incoming mic rtc.AudioFrame for transcription.

        Frames are resampled to the model's expected rate/channels, rechunked
        through ``_bstream``, and the resulting chunks are pushed onto the
        audio input channel. Dropped (with a warning) once the channel closes.
        """
        if not self._audio_input_chan.closed:
            for f in self._resample_audio(frame):
                # rechunk resampled audio into model-sized buffers
                for nf in self._bstream.write(f.data.tobytes()):
                    self._log_significant_audio(nf.data)
                    self._audio_input_chan.send_nowait(nf.data)
        else:
            logger.warning("audio input channel closed, skipping audio")
    def generate_reply(
        self,
        *,
        instructions: NotGivenOr[str] = NOT_GIVEN,
    ) -> asyncio.Future[llm.GenerationCreatedEvent]:
        """Generate a reply from the model.

        This method is called by the LiveKit framework's AgentSession.generate_reply() and
        AgentActivity._realtime_reply_task(). The framework handles user_input by adding it
        to the chat context via update_chat_ctx() before calling this method.

        Flow for user_input:
        1. Framework receives generate_reply with user_input parameter
        2. Framework adds user message to chat context
        3. Framework calls update_chat_ctx() (which sends the message to Nova Sonic)
        4. Framework calls this method with no parameters
        5. This method triggers Nova Sonic's response based on the last context message added

        Flow for instructions:
        1. Framework receives generate_reply with instructions parameter
        2. Framework calls this method with the instructions parameter
        3. This method sends instructions as a prompt to Nova Sonic and triggers a response.

        If both parameters are sent, the same flow will strip the user_input out of the initial call
        and send the instructions on to this method.

        For Nova Sonic 2.0 and any supporting model:
        - Sends instructions as interactive text if provided
        - Triggers model response generation

        For Nova Sonic 1.0:
        - Not supported (no text input capability)
        - Logs warning and returns empty future

        Args:
            instructions (NotGivenOr[str]): Additional instructions to guide the response.
                These are sent as system-level prompts to influence how the model responds.
                User input should be added via update_chat_ctx(), not passed here.

        Returns:
            asyncio.Future[llm.GenerationCreatedEvent]: Future that resolves when generation starts.
                Raises RealtimeError on timeout (default: 10s).

        Note:
            User messages flow through AgentSession.generate_reply(user_input=...) →
            update_chat_ctx() which sends interactive text to Nova Sonic.
            This method handles the instructions parameter for system-level prompts.
        """
        # Check if generate_reply is supported (requires mixed modalities)
        if self._realtime_model.modalities != "mixed":
            logger.warning(
                "generate_reply() is not supported by this model (requires mixed modalities). "
                "Skipping generate_reply call. Use modalities='mixed' or Nova Sonic 2.0 "
                "to enable this feature."
            )

            # Return a completed future with empty streams so the caller doesn't hang
            async def _empty_message_stream() -> AsyncIterator[llm.MessageGeneration]:
                return
                yield  # Make it an async generator

            async def _empty_function_stream() -> AsyncIterator[llm.FunctionCall]:
                return
                yield  # Make it an async generator

            fut = asyncio.Future[llm.GenerationCreatedEvent]()
            fut.set_result(
                llm.GenerationCreatedEvent(
                    message_stream=_empty_message_stream(),
                    function_stream=_empty_function_stream(),
                    user_initiated=True,
                )
            )
            return fut

        # Nova 2.0: Only send if instructions provided
        if is_given(instructions):
            logger.info(f"generate_reply: sending instructions='{instructions}'")

            # Create future that will be resolved when generation starts
            fut = asyncio.Future[llm.GenerationCreatedEvent]()
            self._pending_generation_fut = fut

            # Send text message asynchronously
            async def _send_text() -> None:
                try:
                    # Wait for session to be fully initialized before sending
                    await self._is_sess_active.wait()
                    await self._send_text_message(instructions, interactive=True)
                except Exception as e:
                    if not fut.done():
                        fut.set_exception(e)
                    if self._pending_generation_fut is fut:
                        self._pending_generation_fut = None

            asyncio.create_task(_send_text())

            # Set timeout from model configuration
            def _on_timeout() -> None:
                if not fut.done():
                    fut.set_exception(
                        llm.RealtimeError("generate_reply timed out waiting for generation")
                    )
                    if self._pending_generation_fut is fut:
                        self._pending_generation_fut = None

            timeout_handle = asyncio.get_running_loop().call_later(
                self._realtime_model._generate_reply_timeout, _on_timeout
            )
            # cancel the timeout as soon as the future resolves either way
            fut.add_done_callback(lambda _: timeout_handle.cancel())

            return fut

        # No instructions: Return pending generation if exists, otherwise create empty future that never resolves
        # (Framework will timeout naturally if no generation happens)
        if self._pending_generation_fut is not None:
            logger.debug("generate_reply: no instructions, returning existing pending generation")
            return self._pending_generation_fut

        logger.debug(
            "generate_reply: no instructions and no pending generation, returning empty future"
        )
        return asyncio.Future[llm.GenerationCreatedEvent]()
+ async def _send_text_message(self, text: str, interactive: bool = True) -> None:
1977
+ """Internal method to send text message to Nova Sonic 2.0.
1978
+
1979
+ Args:
1980
+ text (str): The text message to send to the model.
1981
+ interactive (bool): If True, triggers generation. If False, adds to context only.
1982
+ """
1983
+ # Generate unique content_name for this message (required for multi-turn)
1984
+ content_name = str(uuid.uuid4())
1985
+
1986
+ # Choose appropriate event builder based on interactive flag
1987
+ if interactive:
1988
+ event = self._event_builder.create_text_content_start_event_interactive(
1989
+ content_name=content_name, role="USER"
1990
+ )
1991
+ else:
1992
+ event = self._event_builder.create_text_content_start_event(
1993
+ content_name=content_name, role="USER"
1994
+ )
1995
+
1996
+ # Send event sequence: contentStart → textInput → contentEnd
1997
+ await self._send_raw_event(event)
1998
+ await self._send_raw_event(
1999
+ self._event_builder.create_text_content_event(content_name, text)
2000
+ )
2001
+ await self._send_raw_event(self._event_builder.create_content_end_event(content_name))
2002
+ logger.info(
2003
+ f"Sent text message (interactive={interactive}): {text[:50]}{'...' if len(text) > 50 else ''}"
2004
+ )
2005
+
2006
    def commit_audio(self) -> None:
        """No-op: explicit audio commits are unsupported; only a warning is logged."""
        logger.warning("commit_audio is not supported by Nova Sonic's Realtime API")
    def clear_audio(self) -> None:
        """No-op: clearing buffered audio is unsupported; only a warning is logged."""
        logger.warning("clear_audio is not supported by Nova Sonic's Realtime API")
    def push_video(self, frame: rtc.VideoFrame) -> None:
        """No-op: video input is unsupported; the frame is dropped with a warning."""
        logger.warning("video is not supported by Nova Sonic's Realtime API")
    def interrupt(self) -> None:
        """Nova Sonic handles interruption automatically via barge-in detection.

        Unlike OpenAI's client-initiated interrupt, Nova Sonic automatically detects
        when the user starts speaking while the model is generating audio. When this
        happens, the model:
        1. Immediately stops generating speech
        2. Switches to listening mode
        3. Sends a text event with content: { "interrupted" : true }

        The plugin already handles this event (see _handle_text_output_content_event).
        No client action is needed - interruption works automatically.

        See AWS docs: https://docs.aws.amazon.com/nova/latest/userguide/output-events.html
        """
        # Intentionally nothing but a log: the actual interruption happens server-side.
        logger.info(
            "Nova Sonic handles interruption automatically via barge-in detection. "
            "The model detects when users start speaking and stops generation automatically."
        )
    def truncate(
        self,
        *,
        message_id: str,
        modalities: list[Literal["text", "audio"]],
        audio_end_ms: int,
        audio_transcript: NotGivenOr[str] = NOT_GIVEN,
    ) -> None:
        """No-op: message truncation is unsupported; arguments are ignored and a warning is logged."""
        logger.warning("truncate is not supported by Nova Sonic's Realtime API")
    @utils.log_exceptions(logger=logger)
    async def aclose(self) -> None:
        """Gracefully shut down the realtime session and release network resources.

        Ordered shutdown: fail pending generation futures, send promptEnd,
        deactivate the session flag, close the output stream, stop background
        tasks (recycle timer, response loop, audio input loop), close the
        input stream, then cancel the main task and gather everything.
        """
        logger.info("attempting to shutdown agent session")
        if not self._is_sess_active.is_set():
            logger.info("agent session already inactive")
            return

        # Cancel any pending generation futures
        if self._pending_generation_fut and not self._pending_generation_fut.done():
            self._pending_generation_fut.set_exception(
                llm.RealtimeError("Session closed while waiting for generation")
            )
            self._pending_generation_fut = None

        for event in self._event_builder.create_prompt_end_block():
            await self._send_raw_event(event)
        # allow event loops to fall out naturally
        # otherwise, the smithy layer will raise an InvalidStateError during cancellation
        self._is_sess_active.clear()

        if self._stream_response and not self._stream_response.output_stream.closed:
            await self._stream_response.output_stream.close()

        # note: even after the self.is_active flag is flipped and the output stream is closed,
        # there is a future inside output_stream.receive() at the AWS-CRT C layer that blocks
        # resulting in an error after cancellation
        # however, it's mostly cosmetic-- the event loop will still exit
        # TODO: fix this nit
        tasks: list[asyncio.Task[Any]] = []

        # Cancel session recycle timer
        if self._session_recycle_task and not self._session_recycle_task.done():
            self._session_recycle_task.cancel()
            try:
                await self._session_recycle_task
            except asyncio.CancelledError:
                pass

        if self._response_task:
            try:
                # give the response loop a short grace period to exit on its own
                await asyncio.wait_for(self._response_task, timeout=1.0)
            except asyncio.TimeoutError:
                logger.warning("shutdown of output event loop timed out-- cancelling")
                self._response_task.cancel()
                tasks.append(self._response_task)

        # must cancel the audio input task before closing the input stream
        if self._audio_input_task and not self._audio_input_task.done():
            self._audio_input_task.cancel()
            tasks.append(self._audio_input_task)
        if self._stream_response and not self._stream_response.input_stream.closed:
            await self._stream_response.input_stream.close()

        # cancel main task to prevent pending task warnings
        if self._main_atask and not self._main_atask.done():
            self._main_atask.cancel()
            tasks.append(self._main_atask)

        # return_exceptions=True: cancelled tasks must not abort the shutdown
        await asyncio.gather(*tasks, return_exceptions=True)
        logger.debug(f"CHAT CONTEXT: {self._chat_ctx.items}")
        logger.info("Session end")