livekit-plugins-aws 1.2.4__py3-none-any.whl → 1.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of livekit-plugins-aws might be problematic.
- livekit/plugins/aws/experimental/realtime/realtime_model.py +385 -271
- livekit/plugins/aws/llm.py +23 -2
- livekit/plugins/aws/models.py +1 -0
- livekit/plugins/aws/stt.py +56 -14
- livekit/plugins/aws/tts.py +36 -5
- livekit/plugins/aws/version.py +1 -1
- {livekit_plugins_aws-1.2.4.dist-info → livekit_plugins_aws-1.3.2.dist-info}/METADATA +5 -4
- {livekit_plugins_aws-1.2.4.dist-info → livekit_plugins_aws-1.3.2.dist-info}/RECORD +9 -9
- {livekit_plugins_aws-1.2.4.dist-info → livekit_plugins_aws-1.3.2.dist-info}/WHEEL +0 -0
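Most of this release reworks the experimental Nova Sonic realtime model (realtime_model.py). As a rough orientation before the diff, here is a minimal, hypothetical usage sketch: the import path and the argument-free constructor are assumptions and not confirmed by this diff, while the model/provider properties, session(), and aclose() are the members that actually appear in the changes below.

# Hypothetical usage sketch, not a confirmed public API: the import path and the
# argument-free constructor are assumptions; the model/provider properties,
# session(), and aclose() are the members visible in this diff.
import asyncio

from livekit.plugins.aws.experimental.realtime import RealtimeModel  # assumed export path


async def main() -> None:
    model = RealtimeModel()  # assumed defaults; model_id is "amazon.nova-sonic-v1:0" in 1.3.2
    print(model.model, model.provider)  # new 1.3.2 properties -> "amazon.nova-sonic-v1:0", "Amazon"

    # 1.2.4 spawned initialize_streams() via asyncio.create_task() inside session();
    # 1.3.2 instead creates a main task in RealtimeSession.__init__.
    session = model.session()
    _ = session  # feed audio / consume generations here in a real agent

    await model.aclose()  # new (currently no-op) cleanup hook added in 1.3.2


asyncio.run(main())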
@@ -12,7 +12,6 @@ import uuid
 import weakref
 from collections.abc import Iterator
 from dataclasses import dataclass, field
-from datetime import datetime
 from typing import Any, Callable, Literal, cast

 import boto3
@@ -26,6 +25,7 @@ from aws_sdk_bedrock_runtime.models import (
     InvokeModelWithBidirectionalStreamInputChunk,
     ModelErrorException,
     ModelNotReadyException,
+    ModelStreamErrorException,
     ModelTimeoutException,
     ThrottlingException,
     ValidationException,
@@ -40,6 +40,7 @@ from livekit.agents import (
     utils,
 )
 from livekit.agents.metrics import RealtimeModelMetrics
+from livekit.agents.metrics.base import Metadata
 from livekit.agents.types import NOT_GIVEN, NotGivenOr
 from livekit.agents.utils import is_given
 from livekit.plugins.aws.experimental.realtime.turn_tracker import _TurnTracker
@@ -122,38 +123,35 @@ class _MessageGeneration:

 @dataclass
 class _ResponseGeneration:
-    """Book-keeping dataclass tracking the lifecycle of a Sonic
+    """Book-keeping dataclass tracking the lifecycle of a Nova Sonic completion.

-
-
+    Nova Sonic uses a completion model where one completionStart event begins a cycle
+    that may contain multiple content blocks (USER ASR, TOOL, ASSISTANT text/audio).
+    This generation stays open for the entire completion cycle.

     Attributes:
-
+        completion_id (str): Nova Sonic's completionId that ties all events together.
+        message_ch (utils.aio.Chan[llm.MessageGeneration]): Stream for assistant messages.
         function_ch (utils.aio.Chan[llm.FunctionCall]): Stream that emits function tool calls.
-
-
-
-
-        speculative_messages (dict[str, str]): Map Bedrock content_id -> response_id (assistant side).
-        tool_messages (dict[str, str]): Map Bedrock content_id -> response_id for tool calls.
-        output_text (str): Accumulated assistant text (only used for metrics / debugging).
-        _created_timestamp (str): ISO-8601 timestamp when the generation record was created.
+        response_id (str): LiveKit response_id for the assistant's response.
+        message_gen (_MessageGeneration | None): Current message generation for assistant output.
+        content_id_map (dict[str, str]): Map Nova Sonic contentId -> type (USER/ASSISTANT/TOOL).
+        _created_timestamp (float): Wall-clock time when the generation record was created.
         _first_token_timestamp (float | None): Wall-clock time of first token emission.
         _completed_timestamp (float | None): Wall-clock time when the turn fully completed.
+        _restart_attempts (int): Number of restart attempts for this specific completion.
     """  # noqa: E501

+    completion_id: str
     message_ch: utils.aio.Chan[llm.MessageGeneration]
     function_ch: utils.aio.Chan[llm.FunctionCall]
-
-
-
-
-    speculative_messages: dict[str, str] = field(default_factory=dict)
-    tool_messages: dict[str, str] = field(default_factory=dict)
-    output_text: str = ""  # agent ASR text
-    _created_timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
+    response_id: str
+    message_gen: _MessageGeneration | None = None
+    content_id_map: dict[str, str] = field(default_factory=dict)
+    _created_timestamp: float = field(default_factory=time.time)
     _first_token_timestamp: float | None = None
     _completed_timestamp: float | None = None
+    _restart_attempts: int = 0


 class Boto3CredentialsResolver(IdentityResolver):  # type: ignore[misc]
@@ -240,6 +238,7 @@ class RealtimeModel(llm.RealtimeModel):
                 user_transcription=True,
                 auto_tool_reply_generation=True,
                 audio_output=True,
+                manual_function_calls=False,
             )
         )
         self.model_id = "amazon.nova-sonic-v1:0"
@@ -258,19 +257,23 @@ class RealtimeModel(llm.RealtimeModel):
         )
         self._sessions = weakref.WeakSet[RealtimeSession]()

+    @property
+    def model(self) -> str:
+        return self.model_id
+
+    @property
+    def provider(self) -> str:
+        return "Amazon"
+
     def session(self) -> RealtimeSession:
         """Return a new RealtimeSession bound to this model instance."""
         sess = RealtimeSession(self)
-
-        # note: this is a hack to get the session to initialize itself
-        # TODO: change how RealtimeSession is initialized by creating a single task main_atask that spawns subtasks  # noqa: E501
-        asyncio.create_task(sess.initialize_streams())
         self._sessions.add(sess)
         return sess

-
-
-
+    async def aclose(self) -> None:
+        """Close all active sessions."""
+        pass


 class RealtimeSession(  # noqa: F811
@@ -316,17 +319,16 @@ class RealtimeSession( # noqa: F811
         self._chat_ctx = llm.ChatContext.empty()
         self._tools = llm.ToolContext.empty()
         self._tool_results_ch = utils.aio.Chan[dict[str, str]]()
-
-
-
+        # CRITICAL: Initialize futures as None for lazy creation
+        # Creating futures in __init__ causes race conditions during session restart.
+        # Futures are created in initialize_streams() when the event loop is guaranteed to exist.
+        self._tools_ready: asyncio.Future[bool] | None = None
+        self._instructions_ready: asyncio.Future[bool] | None = None
+        self._chat_ctx_ready: asyncio.Future[bool] | None = None
         self._instructions = DEFAULT_SYSTEM_PROMPT
         self._audio_input_chan = utils.aio.Chan[bytes]()
         self._current_generation: _ResponseGeneration | None = None

-        # note: currently tracks session restart attempts across all sessions
-        # TODO: track restart attempts per turn
-        self._session_restart_attempts = 0
-
         self._event_handlers = {
             "completion_start": self._handle_completion_start_event,
             "audio_output_content_start": self._handle_audio_output_content_start_event,
@@ -347,6 +349,11 @@ class RealtimeSession( # noqa: F811
             cast(Callable[[], None], self.emit_generation_event),
         )

+        # Create main task to manage session lifecycle
+        self._main_atask = asyncio.create_task(
+            self.initialize_streams(), name="RealtimeSession.initialize_streams"
+        )
+
     @utils.log_exceptions(logger=logger)
     def _initialize_client(self) -> None:
         """Instantiate the Bedrock runtime client"""
@@ -480,6 +487,14 @@ class RealtimeSession( # noqa: F811
         )

         if not is_restart:
+            # Lazy-initialize futures if needed
+            if self._tools_ready is None:
+                self._tools_ready = asyncio.get_running_loop().create_future()
+            if self._instructions_ready is None:
+                self._instructions_ready = asyncio.get_running_loop().create_future()
+            if self._chat_ctx_ready is None:
+                self._chat_ctx_ready = asyncio.get_running_loop().create_future()
+
             pending_events: list[asyncio.Future] = []
             if not self.tools.function_tools:
                 pending_events.append(self._tools_ready)
@@ -495,14 +510,14 @@ class RealtimeSession( # noqa: F811
                 if pending_events:
                     await asyncio.wait_for(asyncio.gather(*pending_events), timeout=0.5)
             except asyncio.TimeoutError:
-                if not self._tools_ready.done():
+                if self._tools_ready and not self._tools_ready.done():
                     logger.warning("Tools not ready after 500ms, continuing without them")

-                if not self._instructions_ready.done():
+                if self._instructions_ready and not self._instructions_ready.done():
                     logger.warning(
                         "Instructions not received after 500ms, proceeding with default instructions"  # noqa: E501
                     )
-                if not self._chat_ctx_ready.done():
+                if self._chat_ctx_ready and not self._chat_ctx_ready.done():
                     logger.warning(
                         "Chat context not received after 500ms, proceeding with empty chat context"  # noqa: E501
                     )
@@ -549,13 +564,16 @@ class RealtimeSession( # noqa: F811
     @utils.log_exceptions(logger=logger)
     def emit_generation_event(self) -> None:
         """Publish a llm.GenerationCreatedEvent to external subscribers."""
-
-
+        if self._current_generation is None:
+            logger.debug("emit_generation_event called but no generation exists - ignoring")
+            return

+        logger.debug("Emitting generation event")
         generation_ev = llm.GenerationCreatedEvent(
             message_stream=self._current_generation.message_ch,
             function_stream=self._current_generation.function_ch,
             user_initiated=False,
+            response_id=self._current_generation.response_id,
         )
         self.emit("generation_created", generation_ev)

@@ -571,21 +589,38 @@ class RealtimeSession( # noqa: F811
             logger.warning(f"No event handler found for event type: {event_type}")

     async def _handle_completion_start_event(self, event_data: dict) -> None:
+        """Handle completionStart - create new generation for this completion cycle."""
         log_event_data(event_data)
         self._create_response_generation()

     def _create_response_generation(self) -> None:
-        """Instantiate _ResponseGeneration and emit the GenerationCreated event.
+        """Instantiate _ResponseGeneration and emit the GenerationCreated event.
+
+        Can be called multiple times - will reuse existing generation but ensure
+        message structure exists.
+        """
+        generation_created = False
         if self._current_generation is None:
+            completion_id = "unknown"  # Will be set from events
+            response_id = str(uuid.uuid4())
+
+            logger.debug(f"Creating new generation, response_id={response_id}")
             self._current_generation = _ResponseGeneration(
+                completion_id=completion_id,
                 message_ch=utils.aio.Chan(),
                 function_ch=utils.aio.Chan(),
-
-
-
-
-
-
+                response_id=response_id,
+            )
+            generation_created = True
+        else:
+            logger.debug(
+                f"Generation already exists: response_id={self._current_generation.response_id}"
+            )
+
+        # Always ensure message structure exists (even if generation already exists)
+        if self._current_generation.message_gen is None:
+            logger.debug(
+                f"Creating message structure for response_id={self._current_generation.response_id}"
             )
             msg_gen = _MessageGeneration(
                 message_id=self._current_generation.response_id,
@@ -596,6 +631,8 @@ class RealtimeSession( # noqa: F811
             msg_modalities.set_result(
                 ["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"]
             )
+
+            self._current_generation.message_gen = msg_gen
             self._current_generation.message_ch.send_nowait(
                 llm.MessageGeneration(
                     message_id=msg_gen.message_id,
@@ -604,77 +641,97 @@ class RealtimeSession( # noqa: F811
                     modalities=msg_modalities,
                 )
             )
-
+        else:
+            logger.debug(
+                f"Message structure already exists for response_id={self._current_generation.response_id}"
+            )
+
+        # Only emit generation event if we created a new generation
+        if generation_created:
+            self.emit_generation_event()

     # will be completely ignoring post-ASR text events
     async def _handle_text_output_content_start_event(self, event_data: dict) -> None:
-        """Handle text_output_content_start
+        """Handle text_output_content_start - track content type."""
         log_event_data(event_data)
+
         role = event_data["event"]["contentStart"]["role"]
-        self._create_response_generation()

-        #
-
-
+        # CRITICAL: Create NEW generation for each ASSISTANT SPECULATIVE response
+        # Nova Sonic sends ASSISTANT SPECULATIVE for each new assistant turn, including after tool calls.
+        # Without this, audio frames get routed to the wrong generation and don't play.
+        if role == "ASSISTANT":
+            additional_fields = event_data["event"]["contentStart"].get("additionalModelFields", "")
+            if "SPECULATIVE" in additional_fields:
+                # This is a new assistant response - close previous and create new
+                logger.debug("ASSISTANT SPECULATIVE text - creating new generation")
+                if self._current_generation is not None:
+                    logger.debug("Closing previous generation for new assistant response")
+                    self._close_current_generation()
+            self._create_response_generation()
+        else:
+            # For USER and FINAL, just ensure generation exists
+            self._create_response_generation()

-
-
+        # CRITICAL: Check if generation exists before accessing
+        # Barge-in can set _current_generation to None between the creation above and here.
+        # Without this check, we crash on interruptions.
+        if self._current_generation is None:
+            logger.debug("No generation exists - ignoring content_start event")
+            return

-
-            role == "ASSISTANT"
-            and "SPECULATIVE" in event_data["event"]["contentStart"]["additionalModelFields"]
-        ):
-            assert self._current_generation is not None, "current_generation is None"
+        content_id = event_data["event"]["contentStart"]["contentId"]

-
-
-
-
+        # Track what type of content this is
+        if role == "USER":
+            self._current_generation.content_id_map[content_id] = "USER_ASR"
+        elif role == "ASSISTANT":
+            additional_fields = event_data["event"]["contentStart"].get("additionalModelFields", "")
+            if "SPECULATIVE" in additional_fields:
+                self._current_generation.content_id_map[content_id] = "ASSISTANT_TEXT"
+            elif "FINAL" in additional_fields:
+                self._current_generation.content_id_map[content_id] = "ASSISTANT_FINAL"

     async def _handle_text_output_content_event(self, event_data: dict) -> None:
-        """Stream partial text tokens into the current
+        """Stream partial text tokens into the current generation."""
         log_event_data(event_data)
-
+
+        if self._current_generation is None:
+            logger.debug("No generation exists - ignoring text_output event")
+            return
+
+        content_id = event_data["event"]["textOutput"]["contentId"]
         text_content = f"{event_data['event']['textOutput']['content']}\n"

-        #
+        # Nova Sonic's automatic barge-in detection
         if text_content == '{ "interrupted" : true }\n':
-            # the interrupted flag is not being set correctly in chat_ctx
-            # this is b/c audio playback is desynced from text transcription
-            # TODO: fix this; possibly via a playback timer
             idx = self._chat_ctx.find_insertion_index(created_at=time.time()) - 1
-            if idx
-                logger.warning("Barge-in DETECTED but no previous message found")
-                return
-
-            logger.debug(
-                f"BARGE-IN DETECTED using idx: {idx} and chat_msg: {self._chat_ctx.items[idx]}"
-            )
-            if (item := self._chat_ctx.items[idx]).type == "message":
+            if idx >= 0 and (item := self._chat_ctx.items[idx]).type == "message":
                 item.interrupted = True
-
+                logger.debug("Barge-in detected - marked message as interrupted")
+
+            # Close generation on barge-in unless tools are pending
+            if not self._pending_tools:
+                self._close_current_generation()
+            else:
+                logger.debug(f"Keeping generation open - {len(self._pending_tools)} pending tools")
             return

-
-
-
-
-
-
-
-
-
-
-
-
-
-            self._current_generation.
-
-        ):
-            curr_gen = self._current_generation.messages[self._current_generation.response_id]
-            curr_gen.text_ch.send_nowait(text_content)
-            # note: this update is per utterance, not per turn
-            self._update_chat_ctx(role="assistant", text_content=text_content)
+        content_type = self._current_generation.content_id_map.get(content_id)
+
+        if content_type == "USER_ASR":
+            logger.debug(f"INPUT TRANSCRIPTION UPDATED: {text_content}")
+            self._update_chat_ctx(role="user", text_content=text_content)
+
+        elif content_type == "ASSISTANT_TEXT":
+            # Set first token timestamp if not already set
+            if self._current_generation._first_token_timestamp is None:
+                self._current_generation._first_token_timestamp = time.time()
+
+            # Stream text to LiveKit
+            if self._current_generation.message_gen:
+                self._current_generation.message_gen.text_ch.send_nowait(text_content)
+                self._update_chat_ctx(role="assistant", text_content=text_content)

     def _update_chat_ctx(self, role: llm.ChatRole, text_content: str) -> None:
         """
@@ -704,107 +761,72 @@ class RealtimeSession( # noqa: F811

     # cannot rely on this event for user b/c stopReason=PARTIAL_TURN always for user
     async def _handle_text_output_content_end_event(self, event_data: dict) -> None:
-        """
-
-
-
-            self._current_generation
-            is not None  # means that first utterance in the turn was an interrupt
-            and self._current_generation.speculative_messages.get(text_content_id)
-            == self._current_generation.response_id
-            and stop_reason == "END_TURN"
-        ):
-            log_event_data(event_data)
-            self._close_current_generation()
+        """Handle text content end - log but don't close generation yet."""
+        # Nova Sonic sends multiple content blocks within one completion
+        # Don't close generation here - wait for completionEnd or audio_output_content_end
+        log_event_data(event_data)

     async def _handle_tool_output_content_start_event(self, event_data: dict) -> None:
-        """Track
+        """Track tool content start."""
         log_event_data(event_data)
-        assert self._current_generation is not None, "current_generation is None"

-
-        self.
-
-
+        # Ensure generation exists
+        self._create_response_generation()
+
+        if self._current_generation is None:
+            return
+
+        content_id = event_data["event"]["contentStart"]["contentId"]
+        self._current_generation.content_id_map[content_id] = "TOOL"

-    # note: tool calls are synchronous for now
     async def _handle_tool_output_content_event(self, event_data: dict) -> None:
-        """Execute the referenced tool locally and
+        """Execute the referenced tool locally and queue results."""
         log_event_data(event_data)
-        assert self._current_generation is not None, "current_generation is None"

-
+        if self._current_generation is None:
+            logger.warning("tool_output_content received without active generation")
+            return
+
         tool_use_id = event_data["event"]["toolUse"]["toolUseId"]
         tool_name = event_data["event"]["toolUse"]["toolName"]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # therefore we introduce an artificial turn to trigger update_chat_ctx()
-        # TODO: this is messy-- investigate if there is a better way to handle this
-        curr_gen = self._current_generation.messages[self._current_generation.response_id]
-        curr_gen.audio_ch.close()
-        curr_gen.text_ch.close()
-        self._current_generation.message_ch.close()
-        self._current_generation.message_ch = utils.aio.Chan()
-        self._current_generation.function_ch.close()
-        self._current_generation.function_ch = utils.aio.Chan()
-        msg_gen = _MessageGeneration(
-            message_id=self._current_generation.response_id,
-            text_ch=utils.aio.Chan(),
-            audio_ch=utils.aio.Chan(),
-        )
-        self._current_generation.messages[self._current_generation.response_id] = msg_gen
-        msg_modalities = asyncio.Future[list[Literal["text", "audio"]]]()
-        msg_modalities.set_result(
-            ["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"]
-        )
-        self._current_generation.message_ch.send_nowait(
-            llm.MessageGeneration(
-                message_id=msg_gen.message_id,
-                text_stream=msg_gen.text_ch,
-                audio_stream=msg_gen.audio_ch,
-                modalities=msg_modalities,
-            )
-        )
-        self.emit_generation_event()
+        args = event_data["event"]["toolUse"]["content"]
+
+        # Emit function call to LiveKit framework
+        self._current_generation.function_ch.send_nowait(
+            llm.FunctionCall(call_id=tool_use_id, name=tool_name, arguments=args)
+        )
+        self._pending_tools.add(tool_use_id)
+        logger.debug(f"Tool call emitted: {tool_name} (id={tool_use_id})")
+
+        # CRITICAL: Close generation after tool call emission
+        # The LiveKit framework expects the generation to close so it can call update_chat_ctx()
+        # with the tool results. A new generation will be created when Nova Sonic sends the next
+        # ASSISTANT SPECULATIVE text event with the tool response.
+        logger.debug("Closing generation to allow tool result delivery")
+        self._close_current_generation()

     async def _handle_tool_output_content_end_event(self, event_data: dict) -> None:
         log_event_data(event_data)

     async def _handle_audio_output_content_start_event(self, event_data: dict) -> None:
-        """
+        """Track audio content start."""
         if self._current_generation is not None:
             log_event_data(event_data)
-
-            self._current_generation.
-                self._current_generation.response_id
-            )
+            content_id = event_data["event"]["contentStart"]["contentId"]
+            self._current_generation.content_id_map[content_id] = "ASSISTANT_AUDIO"

     async def _handle_audio_output_content_event(self, event_data: dict) -> None:
         """Decode base64 audio from Bedrock and forward it to the audio stream."""
-        if
-
-
-
-
-
-
+        if self._current_generation is None or self._current_generation.message_gen is None:
+            return
+
+        content_id = event_data["event"]["audioOutput"]["contentId"]
+        content_type = self._current_generation.content_id_map.get(content_id)
+
+        if content_type == "ASSISTANT_AUDIO":
             audio_content = event_data["event"]["audioOutput"]["content"]
             audio_bytes = base64.b64decode(audio_content)
-
-            curr_gen.audio_ch.send_nowait(
+            self._current_generation.message_gen.audio_ch.send_nowait(
                 rtc.AudioFrame(
                     data=audio_bytes,
                     sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
@@ -814,62 +836,89 @@ class RealtimeSession( # noqa: F811
             )

     async def _handle_audio_output_content_end_event(self, event_data: dict) -> None:
-        """
-
-
-
-            and self._current_generation.speculative_messages.get(
-                event_data["event"]["contentEnd"]["contentId"]
-            )
-            == self._current_generation.response_id
-        ):
-            log_event_data(event_data)
-            self._close_current_generation()
+        """Handle audio content end - log but don't close generation."""
+        log_event_data(event_data)
+        # Nova Sonic uses one completion for entire session
+        # Don't close generation here - wait for new completionStart or session end

     def _close_current_generation(self) -> None:
-        """Helper that closes all channels of the active
-        if self._current_generation is
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self._current_generation
+        """Helper that closes all channels of the active generation."""
+        if self._current_generation is None:
+            return
+
+        # Set completed timestamp
+        if self._current_generation._completed_timestamp is None:
+            self._current_generation._completed_timestamp = time.time()
+
+        # Close message channels
+        if self._current_generation.message_gen:
+            if not self._current_generation.message_gen.audio_ch.closed:
+                self._current_generation.message_gen.audio_ch.close()
+            if not self._current_generation.message_gen.text_ch.closed:
+                self._current_generation.message_gen.text_ch.close()
+
+        # Close generation channels
+        if not self._current_generation.message_ch.closed:
+            self._current_generation.message_ch.close()
+        if not self._current_generation.function_ch.closed:
+            self._current_generation.function_ch.close()
+
+        logger.debug(
+            f"Closed generation for completion_id={self._current_generation.completion_id}"
+        )
+        self._current_generation = None

     async def _handle_completion_end_event(self, event_data: dict) -> None:
+        """Handle completionEnd - close the generation for this completion cycle."""
         log_event_data(event_data)

+        # Close generation if still open
+        if self._current_generation:
+            logger.debug("completionEnd received, closing generation")
+            self._close_current_generation()
+
     async def _handle_other_event(self, event_data: dict) -> None:
         log_event_data(event_data)

     async def _handle_usage_event(self, event_data: dict) -> None:
         # log_event_data(event_data)
-        # TODO: implement duration and ttft
         input_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["input"]
         output_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["output"]
-
+
+        # Calculate metrics from timestamps
+        duration = 0.0
+        ttft = 0.0
+        tokens_per_second = 0.0
+
+        if self._current_generation is not None:
+            created_ts = self._current_generation._created_timestamp
+            first_token_ts = self._current_generation._first_token_timestamp
+            completed_ts = self._current_generation._completed_timestamp
+
+            # Calculate TTFT (time to first token)
+            if first_token_ts is not None and isinstance(created_ts, (int, float)):
+                ttft = first_token_ts - created_ts
+
+            # Calculate duration (total time from creation to completion)
+            if completed_ts is not None and isinstance(created_ts, (int, float)):
+                duration = completed_ts - created_ts
+
+            # Calculate tokens per second
+            total_tokens = (
+                input_tokens["speechTokens"]
+                + input_tokens["textTokens"]
+                + output_tokens["speechTokens"]
+                + output_tokens["textTokens"]
+            )
+            if duration > 0:
+                tokens_per_second = total_tokens / duration
+
         metrics = RealtimeModelMetrics(
-            label=self._realtime_model.
-            # TODO: pass in the correct request_id
+            label=self._realtime_model.label,
             request_id=event_data["event"]["usageEvent"]["completionId"],
             timestamp=time.monotonic(),
-            duration=
-            ttft=
+            duration=duration,
+            ttft=ttft,
             cancelled=False,
             input_tokens=input_tokens["speechTokens"] + input_tokens["textTokens"],
             output_tokens=output_tokens["speechTokens"] + output_tokens["textTokens"],
@@ -877,8 +926,7 @@ class RealtimeSession( # noqa: F811
             + input_tokens["textTokens"]
             + output_tokens["speechTokens"]
             + output_tokens["textTokens"],
-
-            tokens_per_second=0,
+            tokens_per_second=tokens_per_second,
             input_token_details=RealtimeModelMetrics.InputTokenDetails(
                 text_tokens=input_tokens["textTokens"],
                 audio_tokens=input_tokens["speechTokens"],
@@ -891,6 +939,9 @@ class RealtimeSession( # noqa: F811
                 audio_tokens=output_tokens["speechTokens"],
                 image_tokens=0,
             ),
+            metadata=Metadata(
+                model_name=self._realtime_model.model, model_provider=self._realtime_model.provider
+            ),
         )
         self.emit("metrics_collected", metrics)

@@ -949,7 +1000,12 @@ class RealtimeSession( # noqa: F811
                 ),
             )
             raise
-        except (
+        except (
+            ThrottlingException,
+            ModelNotReadyException,
+            ModelErrorException,
+            ModelStreamErrorException,
+        ) as re:
             logger.warning(f"Retryable error: {re}\nAttempting to recover...")
             await self._restart_session(re)
             break
@@ -996,8 +1052,13 @@ class RealtimeSession( # noqa: F811
         self._is_sess_active.clear()

     async def _restart_session(self, ex: Exception) -> None:
-        if
-
+        # Get restart attempts from current generation, or 0 if no generation
+        restart_attempts = (
+            self._current_generation._restart_attempts if self._current_generation else 0
+        )
+
+        if restart_attempts >= DEFAULT_MAX_SESSION_RESTART_ATTEMPTS:
+            logger.error("Max restart attempts reached for this turn, exiting")
             err_msg = getattr(ex, "message", str(ex))
             request_id = None
             try:
@@ -1021,13 +1082,20 @@ class RealtimeSession( # noqa: F811
             )
             self._is_sess_active.clear()
             return
-
+
+        # Increment restart counter for current generation
+        if self._current_generation:
+            self._current_generation._restart_attempts += 1
+            restart_attempts = self._current_generation._restart_attempts
+        else:
+            restart_attempts = 1
+
         self._is_sess_active.clear()
-        delay = 2 ** (
+        delay = 2 ** (restart_attempts - 1) - 1
         await asyncio.sleep(min(delay, DEFAULT_MAX_SESSION_RESTART_DELAY))
         await self.initialize_streams(is_restart=True)
         logger.info(
-            f"
+            f"Turn restarted successfully ({restart_attempts}/{DEFAULT_MAX_SESSION_RESTART_ATTEMPTS})"
         )

     @property
@@ -1041,7 +1109,10 @@ class RealtimeSession( # noqa: F811
     async def update_instructions(self, instructions: str) -> None:
         """Injects the system prompt at the start of the session."""
         self._instructions = instructions
-        self._instructions_ready
+        if self._instructions_ready is None:
+            self._instructions_ready = asyncio.get_running_loop().create_future()
+        if not self._instructions_ready.done():
+            self._instructions_ready.set_result(True)
         logger.debug(f"Instructions updated: {instructions}")

     async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
@@ -1049,16 +1120,26 @@ class RealtimeSession( # noqa: F811
         # sometimes fires randomly
         # add a guard here to only allow chat_ctx to be updated on
         # the very first session initialization
+        if self._chat_ctx_ready is None:
+            self._chat_ctx_ready = asyncio.get_running_loop().create_future()
+
         if not self._chat_ctx_ready.done():
             self._chat_ctx = chat_ctx.copy()
             logger.debug(f"Chat context updated: {self._chat_ctx.items}")
             self._chat_ctx_ready.set_result(True)

         # for each function tool, send the result to aws
+        logger.debug(
+            f"update_chat_ctx called with {len(chat_ctx.items)} items, pending_tools: {self._pending_tools}"
+        )
         for item in chat_ctx.items:
             if item.type != "function_call_output":
                 continue

+            logger.debug(
+                f"Found function_call_output: call_id={item.call_id}, in_pending={item.call_id in self._pending_tools}"
+            )
+
             if item.call_id not in self._pending_tools:
                 continue

@@ -1108,7 +1189,10 @@ class RealtimeSession( # noqa: F811
                 retained_tools.append(tool)
         self._tools = llm.ToolContext(retained_tools)
         if retained_tools:
-            self._tools_ready
+            if self._tools_ready is None:
+                self._tools_ready = asyncio.get_running_loop().create_future()
+            if not self._tools_ready.done():
+                self._tools_ready.set_result(True)
             logger.debug("Tool list has been injected")

     def update_options(self, *, tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN) -> None:
@@ -1144,54 +1228,62 @@ class RealtimeSession( # noqa: F811
         """Background task that feeds audio and tool results into the Bedrock stream."""
         await self._send_raw_event(self._event_builder.create_audio_content_start_event())
         logger.info("Starting audio input processing loop")
+
+        # Create tasks for both channels so we can wait on either
+        audio_task = asyncio.create_task(self._audio_input_chan.recv())
+        tool_task = asyncio.create_task(self._tool_results_ch.recv())
+        pending = {audio_task, tool_task}
+
         while self._is_sess_active.is_set():
             try:
-
-                try:
-                    val = self._tool_results_ch.recv_nowait()
-                    tool_result = val["tool_result"]
-                    tool_use_id = val["tool_use_id"]
-                    if not isinstance(tool_result, str):
-                        tool_result = json.dumps(tool_result)
-                    else:
-                        try:
-                            json.loads(tool_result)
-                        except json.JSONDecodeError:
-                            try:
-                                tool_result = json.dumps(ast.literal_eval(tool_result))
-                            except Exception:
-                                # return the original value
-                                pass
-
-                    logger.debug(f"Sending tool result: {tool_result}")
-                    await self._send_tool_events(tool_use_id, tool_result)
-
-                except utils.aio.channel.ChanEmpty:
-                    pass
-                except utils.aio.channel.ChanClosed:
-                    logger.warning(
-                        "tool results channel closed, exiting audio input processing loop"
-                    )
-                    break
+                done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                for task in done:
+                    if task == audio_task:
+                        try:
+                            audio_bytes = cast(bytes, task.result())
+                            blob = base64.b64encode(audio_bytes)
+                            audio_event = self._event_builder.create_audio_input_event(
+                                audio_content=blob.decode("utf-8"),
+                            )
+                            await self._send_raw_event(audio_event)
+                            # Create new task for next audio
+                            audio_task = asyncio.create_task(self._audio_input_chan.recv())
+                            pending.add(audio_task)
+                        except utils.aio.channel.ChanClosed:
+                            logger.warning("audio input channel closed")
+                            break
+
+                    elif task == tool_task:
+                        try:
+                            val = cast(dict[str, str], task.result())
+                            tool_result = val["tool_result"]
+                            tool_use_id = val["tool_use_id"]
+                            if not isinstance(tool_result, str):
+                                tool_result = json.dumps(tool_result)
+                            else:
+                                try:
+                                    json.loads(tool_result)
+                                except json.JSONDecodeError:
+                                    try:
+                                        tool_result = json.dumps(ast.literal_eval(tool_result))
+                                    except Exception:
+                                        pass
+
+                            logger.debug(f"Sending tool result: {tool_result}")
+                            await self._send_tool_events(tool_use_id, tool_result)
+                            # Create new task for next tool result
+                            tool_task = asyncio.create_task(self._tool_results_ch.recv())
+                            pending.add(tool_task)
+                        except utils.aio.channel.ChanClosed:
+                            logger.warning("tool results channel closed")
+                            break

             except asyncio.CancelledError:
                 logger.info("Audio processing loop cancelled")
+                # Cancel pending tasks
+                for task in pending:
+                    task.cancel()
                 self._audio_input_chan.close()
                 self._tool_results_ch.close()
                 raise
@@ -1242,7 +1334,24 @@ class RealtimeSession( # noqa: F811
         logger.warning("video is not supported by Nova Sonic's Realtime API")

     def interrupt(self) -> None:
-
+        """Nova Sonic handles interruption automatically via barge-in detection.
+
+        Unlike OpenAI's client-initiated interrupt, Nova Sonic automatically detects
+        when the user starts speaking while the model is generating audio. When this
+        happens, the model:
+        1. Immediately stops generating speech
+        2. Switches to listening mode
+        3. Sends a text event with content: { "interrupted" : true }
+
+        The plugin already handles this event (see _handle_text_output_content_event).
+        No client action is needed - interruption works automatically.
+
+        See AWS docs: https://docs.aws.amazon.com/nova/latest/userguide/output-events.html
+        """
+        logger.info(
+            "Nova Sonic handles interruption automatically via barge-in detection. "
+            "The model detects when users start speaking and stops generation automatically."
+        )

     def truncate(
         self,
@@ -1292,6 +1401,11 @@ class RealtimeSession( # noqa: F811
         if self._stream_response and not self._stream_response.input_stream.closed:
             await self._stream_response.input_stream.close()

+        # cancel main task to prevent pending task warnings
+        if self._main_atask and not self._main_atask.done():
+            self._main_atask.cancel()
+            tasks.append(self._main_atask)
+
         await asyncio.gather(*tasks, return_exceptions=True)
         logger.debug(f"CHAT CONTEXT: {self._chat_ctx.items}")
         logger.info("Session end")
|