livekit-plugins-aws 1.2.15__py3-none-any.whl → 1.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of livekit-plugins-aws was flagged as potentially problematic by the registry's scanner.

@@ -12,7 +12,6 @@ import uuid
  import weakref
  from collections.abc import Iterator
  from dataclasses import dataclass, field
- from datetime import datetime
  from typing import Any, Callable, Literal, cast

  import boto3
@@ -26,6 +25,7 @@ from aws_sdk_bedrock_runtime.models import (
  InvokeModelWithBidirectionalStreamInputChunk,
  ModelErrorException,
  ModelNotReadyException,
+ ModelStreamErrorException,
  ModelTimeoutException,
  ThrottlingException,
  ValidationException,
@@ -123,38 +123,35 @@ class _MessageGeneration:

  @dataclass
  class _ResponseGeneration:
- """Book-keeping dataclass tracking the lifecycle of a Sonic turn.
+ """Book-keeping dataclass tracking the lifecycle of a Nova Sonic completion.

- This object is created whenever we receive a *completion_start* event from the model
- and is disposed of once the assistant turn finishes (e.g. *END_TURN*).
+ Nova Sonic uses a completion model where one completionStart event begins a cycle
+ that may contain multiple content blocks (USER ASR, TOOL, ASSISTANT text/audio).
+ This generation stays open for the entire completion cycle.

  Attributes:
- message_ch (utils.aio.Chan[llm.MessageGeneration]): Multiplexed stream for all assistant messages.
+ completion_id (str): Nova Sonic's completionId that ties all events together.
+ message_ch (utils.aio.Chan[llm.MessageGeneration]): Stream for assistant messages.
  function_ch (utils.aio.Chan[llm.FunctionCall]): Stream that emits function tool calls.
- input_id (str): Synthetic message id for the user input of the current turn.
- response_id (str): Synthetic message id for the assistant reply of the current turn.
- messages (dict[str, _MessageGeneration]): Map of message_id -> per-message stream containers.
- user_messages (dict[str, str]): Map Bedrock content_id -> input_id.
- speculative_messages (dict[str, str]): Map Bedrock content_id -> response_id (assistant side).
- tool_messages (dict[str, str]): Map Bedrock content_id -> response_id for tool calls.
- output_text (str): Accumulated assistant text (only used for metrics / debugging).
- _created_timestamp (str): ISO-8601 timestamp when the generation record was created.
+ response_id (str): LiveKit response_id for the assistant's response.
+ message_gen (_MessageGeneration | None): Current message generation for assistant output.
+ content_id_map (dict[str, str]): Map Nova Sonic contentId -> type (USER/ASSISTANT/TOOL).
+ _created_timestamp (float): Wall-clock time when the generation record was created.
  _first_token_timestamp (float | None): Wall-clock time of first token emission.
  _completed_timestamp (float | None): Wall-clock time when the turn fully completed.
+ _restart_attempts (int): Number of restart attempts for this specific completion.
  """ # noqa: E501

+ completion_id: str
  message_ch: utils.aio.Chan[llm.MessageGeneration]
  function_ch: utils.aio.Chan[llm.FunctionCall]
- input_id: str # corresponds to user's portion of the turn
- response_id: str # corresponds to agent's portion of the turn
- messages: dict[str, _MessageGeneration] = field(default_factory=dict)
- user_messages: dict[str, str] = field(default_factory=dict)
- speculative_messages: dict[str, str] = field(default_factory=dict)
- tool_messages: dict[str, str] = field(default_factory=dict)
- output_text: str = "" # agent ASR text
- _created_timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
+ response_id: str
+ message_gen: _MessageGeneration | None = None
+ content_id_map: dict[str, str] = field(default_factory=dict)
+ _created_timestamp: float = field(default_factory=time.time)
  _first_token_timestamp: float | None = None
  _completed_timestamp: float | None = None
+ _restart_attempts: int = 0


  class Boto3CredentialsResolver(IdentityResolver): # type: ignore[misc]
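The reworked `_ResponseGeneration` above replaces the per-turn id maps with a single `content_id_map` keyed by Nova Sonic's contentId. A minimal sketch of that routing idea, using plain dataclasses and lists instead of the plugin's channel types (the names below are illustrative only):

```python
# Illustrative sketch only: routing Nova Sonic content blocks by contentId.
# GenerationState and route_text_event are hypothetical stand-ins; the real
# plugin uses _ResponseGeneration and utils.aio.Chan instead of plain lists.
from dataclasses import dataclass, field


@dataclass
class GenerationState:
    completion_id: str
    content_id_map: dict[str, str] = field(default_factory=dict)
    assistant_text: list[str] = field(default_factory=list)
    user_transcript: list[str] = field(default_factory=list)


def route_text_event(gen: GenerationState, content_id: str, text: str) -> None:
    """Append text to the right buffer based on the content block's type."""
    kind = gen.content_id_map.get(content_id)
    if kind == "USER_ASR":
        gen.user_transcript.append(text)
    elif kind == "ASSISTANT_TEXT":
        gen.assistant_text.append(text)
    # ASSISTANT_FINAL and unknown content ids are ignored, as in the diff


gen = GenerationState(completion_id="unknown")
gen.content_id_map["abc"] = "ASSISTANT_TEXT"
route_text_event(gen, "abc", "Hello!")
assert gen.assistant_text == ["Hello!"]
```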
@@ -271,16 +268,12 @@ class RealtimeModel(llm.RealtimeModel):
  def session(self) -> RealtimeSession:
  """Return a new RealtimeSession bound to this model instance."""
  sess = RealtimeSession(self)
-
- # note: this is a hack to get the session to initialize itself
- # TODO: change how RealtimeSession is initialized by creating a single task main_atask that spawns subtasks # noqa: E501
- asyncio.create_task(sess.initialize_streams())
  self._sessions.add(sess)
  return sess

- # stub b/c RealtimeSession.aclose() is invoked directly
- async def aclose(self) -> None:
- pass
+ async def aclose(self) -> None:
+ """Close all active sessions."""
+ pass


  class RealtimeSession( # noqa: F811
@@ -326,17 +319,16 @@ class RealtimeSession( # noqa: F811
  self._chat_ctx = llm.ChatContext.empty()
  self._tools = llm.ToolContext.empty()
  self._tool_results_ch = utils.aio.Chan[dict[str, str]]()
- self._tools_ready = asyncio.get_running_loop().create_future()
- self._instructions_ready = asyncio.get_running_loop().create_future()
- self._chat_ctx_ready = asyncio.get_running_loop().create_future()
+ # CRITICAL: Initialize futures as None for lazy creation
+ # Creating futures in __init__ causes race conditions during session restart.
+ # Futures are created in initialize_streams() when the event loop is guaranteed to exist.
+ self._tools_ready: asyncio.Future[bool] | None = None
+ self._instructions_ready: asyncio.Future[bool] | None = None
+ self._chat_ctx_ready: asyncio.Future[bool] | None = None
  self._instructions = DEFAULT_SYSTEM_PROMPT
  self._audio_input_chan = utils.aio.Chan[bytes]()
  self._current_generation: _ResponseGeneration | None = None

- # note: currently tracks session restart attempts across all sessions
- # TODO: track restart attempts per turn
- self._session_restart_attempts = 0
-
  self._event_handlers = {
  "completion_start": self._handle_completion_start_event,
  "audio_output_content_start": self._handle_audio_output_content_start_event,
@@ -357,6 +349,11 @@ class RealtimeSession( # noqa: F811
  cast(Callable[[], None], self.emit_generation_event),
  )

+ # Create main task to manage session lifecycle
+ self._main_atask = asyncio.create_task(
+ self.initialize_streams(), name="RealtimeSession.initialize_streams"
+ )
+
  @utils.log_exceptions(logger=logger)
  def _initialize_client(self) -> None:
  """Instantiate the Bedrock runtime client"""
@@ -490,6 +487,14 @@ class RealtimeSession( # noqa: F811
  )

  if not is_restart:
+ # Lazy-initialize futures if needed
+ if self._tools_ready is None:
+ self._tools_ready = asyncio.get_running_loop().create_future()
+ if self._instructions_ready is None:
+ self._instructions_ready = asyncio.get_running_loop().create_future()
+ if self._chat_ctx_ready is None:
+ self._chat_ctx_ready = asyncio.get_running_loop().create_future()
+
  pending_events: list[asyncio.Future] = []
  if not self.tools.function_tools:
  pending_events.append(self._tools_ready)
@@ -505,14 +510,14 @@ class RealtimeSession( # noqa: F811
  if pending_events:
  await asyncio.wait_for(asyncio.gather(*pending_events), timeout=0.5)
  except asyncio.TimeoutError:
- if not self._tools_ready.done():
+ if self._tools_ready and not self._tools_ready.done():
  logger.warning("Tools not ready after 500ms, continuing without them")

- if not self._instructions_ready.done():
+ if self._instructions_ready and not self._instructions_ready.done():
  logger.warning(
  "Instructions not received after 500ms, proceeding with default instructions" # noqa: E501
  )
- if not self._chat_ctx_ready.done():
+ if self._chat_ctx_ready and not self._chat_ctx_ready.done():
  logger.warning(
  "Chat context not received after 500ms, proceeding with empty chat context" # noqa: E501
  )
@@ -559,9 +564,11 @@ class RealtimeSession( # noqa: F811
  @utils.log_exceptions(logger=logger)
  def emit_generation_event(self) -> None:
  """Publish a llm.GenerationCreatedEvent to external subscribers."""
- logger.debug("Emitting generation event")
- assert self._current_generation is not None, "current_generation is None"
+ if self._current_generation is None:
+ logger.debug("emit_generation_event called but no generation exists - ignoring")
+ return

+ logger.debug("Emitting generation event")
  generation_ev = llm.GenerationCreatedEvent(
  message_stream=self._current_generation.message_ch,
  function_stream=self._current_generation.function_ch,
@@ -582,21 +589,38 @@ class RealtimeSession( # noqa: F811
  logger.warning(f"No event handler found for event type: {event_type}")

  async def _handle_completion_start_event(self, event_data: dict) -> None:
+ """Handle completionStart - create new generation for this completion cycle."""
  log_event_data(event_data)
  self._create_response_generation()

  def _create_response_generation(self) -> None:
- """Instantiate _ResponseGeneration and emit the GenerationCreated event."""
+ """Instantiate _ResponseGeneration and emit the GenerationCreated event.
+
+ Can be called multiple times - will reuse existing generation but ensure
+ message structure exists.
+ """
+ generation_created = False
  if self._current_generation is None:
+ completion_id = "unknown" # Will be set from events
+ response_id = str(uuid.uuid4())
+
+ logger.debug(f"Creating new generation, response_id={response_id}")
  self._current_generation = _ResponseGeneration(
+ completion_id=completion_id,
  message_ch=utils.aio.Chan(),
  function_ch=utils.aio.Chan(),
- input_id=str(uuid.uuid4()),
- response_id=str(uuid.uuid4()),
- messages={},
- user_messages={},
- speculative_messages={},
- _created_timestamp=datetime.now().isoformat(),
+ response_id=response_id,
+ )
+ generation_created = True
+ else:
+ logger.debug(
+ f"Generation already exists: response_id={self._current_generation.response_id}"
+ )
+
+ # Always ensure message structure exists (even if generation already exists)
+ if self._current_generation.message_gen is None:
+ logger.debug(
+ f"Creating message structure for response_id={self._current_generation.response_id}"
  )
  msg_gen = _MessageGeneration(
  message_id=self._current_generation.response_id,
@@ -607,6 +631,8 @@ class RealtimeSession( # noqa: F811
  msg_modalities.set_result(
  ["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"]
  )
+
+ self._current_generation.message_gen = msg_gen
  self._current_generation.message_ch.send_nowait(
  llm.MessageGeneration(
  message_id=msg_gen.message_id,
@@ -615,77 +641,97 @@ class RealtimeSession( # noqa: F811
  modalities=msg_modalities,
  )
  )
- self._current_generation.messages[self._current_generation.response_id] = msg_gen
+ else:
+ logger.debug(
+ f"Message structure already exists for response_id={self._current_generation.response_id}"
+ )
+
+ # Only emit generation event if we created a new generation
+ if generation_created:
+ self.emit_generation_event()

  # will be completely ignoring post-ASR text events
  async def _handle_text_output_content_start_event(self, event_data: dict) -> None:
- """Handle text_output_content_start for both user and assistant roles."""
+ """Handle text_output_content_start - track content type."""
  log_event_data(event_data)
+
  role = event_data["event"]["contentStart"]["role"]
- self._create_response_generation()

- # note: does not work if you emit llm.GCE too early (for some reason)
- if role == "USER":
- assert self._current_generation is not None, "current_generation is None"
+ # CRITICAL: Create NEW generation for each ASSISTANT SPECULATIVE response
+ # Nova Sonic sends ASSISTANT SPECULATIVE for each new assistant turn, including after tool calls.
+ # Without this, audio frames get routed to the wrong generation and don't play.
+ if role == "ASSISTANT":
+ additional_fields = event_data["event"]["contentStart"].get("additionalModelFields", "")
+ if "SPECULATIVE" in additional_fields:
+ # This is a new assistant response - close previous and create new
+ logger.debug("ASSISTANT SPECULATIVE text - creating new generation")
+ if self._current_generation is not None:
+ logger.debug("Closing previous generation for new assistant response")
+ self._close_current_generation()
+ self._create_response_generation()
+ else:
+ # For USER and FINAL, just ensure generation exists
+ self._create_response_generation()

- content_id = event_data["event"]["contentStart"]["contentId"]
- self._current_generation.user_messages[content_id] = self._current_generation.input_id
+ # CRITICAL: Check if generation exists before accessing
+ # Barge-in can set _current_generation to None between the creation above and here.
+ # Without this check, we crash on interruptions.
+ if self._current_generation is None:
+ logger.debug("No generation exists - ignoring content_start event")
+ return

- elif (
- role == "ASSISTANT"
- and "SPECULATIVE" in event_data["event"]["contentStart"]["additionalModelFields"]
- ):
- assert self._current_generation is not None, "current_generation is None"
+ content_id = event_data["event"]["contentStart"]["contentId"]

- text_content_id = event_data["event"]["contentStart"]["contentId"]
- self._current_generation.speculative_messages[text_content_id] = (
- self._current_generation.response_id
- )
+ # Track what type of content this is
+ if role == "USER":
+ self._current_generation.content_id_map[content_id] = "USER_ASR"
+ elif role == "ASSISTANT":
+ additional_fields = event_data["event"]["contentStart"].get("additionalModelFields", "")
+ if "SPECULATIVE" in additional_fields:
+ self._current_generation.content_id_map[content_id] = "ASSISTANT_TEXT"
+ elif "FINAL" in additional_fields:
+ self._current_generation.content_id_map[content_id] = "ASSISTANT_FINAL"

  async def _handle_text_output_content_event(self, event_data: dict) -> None:
- """Stream partial text tokens into the current _MessageGeneration."""
+ """Stream partial text tokens into the current generation."""
  log_event_data(event_data)
- text_content_id = event_data["event"]["textOutput"]["contentId"]
+
+ if self._current_generation is None:
+ logger.debug("No generation exists - ignoring text_output event")
+ return
+
+ content_id = event_data["event"]["textOutput"]["contentId"]
  text_content = f"{event_data['event']['textOutput']['content']}\n"

- # currently only agent can be interrupted
+ # Nova Sonic's automatic barge-in detection
  if text_content == '{ "interrupted" : true }\n':
- # the interrupted flag is not being set correctly in chat_ctx
- # this is b/c audio playback is desynced from text transcription
- # TODO: fix this; possibly via a playback timer
  idx = self._chat_ctx.find_insertion_index(created_at=time.time()) - 1
- if idx < 0:
- logger.warning("Barge-in DETECTED but no previous message found")
- return
-
- logger.debug(
- f"BARGE-IN DETECTED using idx: {idx} and chat_msg: {self._chat_ctx.items[idx]}"
- )
- if (item := self._chat_ctx.items[idx]).type == "message":
+ if idx >= 0 and (item := self._chat_ctx.items[idx]).type == "message":
  item.interrupted = True
- self._close_current_generation()
+ logger.debug("Barge-in detected - marked message as interrupted")
+
+ # Close generation on barge-in unless tools are pending
+ if not self._pending_tools:
+ self._close_current_generation()
+ else:
+ logger.debug(f"Keeping generation open - {len(self._pending_tools)} pending tools")
  return

- # ignore events until turn starts
- if self._current_generation is not None:
- # TODO: rename event to llm.InputTranscriptionUpdated
- if (
- self._current_generation.user_messages.get(text_content_id)
- == self._current_generation.input_id
- ):
- logger.debug(f"INPUT TRANSCRIPTION UPDATED: {text_content}")
- # note: user ASR text is slightly different than what is sent to LiveKit (newline vs whitespace) # noqa: E501
- # TODO: fix this
- self._update_chat_ctx(role="user", text_content=text_content)
-
- elif (
- self._current_generation.speculative_messages.get(text_content_id)
- == self._current_generation.response_id
- ):
- curr_gen = self._current_generation.messages[self._current_generation.response_id]
- curr_gen.text_ch.send_nowait(text_content)
- # note: this update is per utterance, not per turn
- self._update_chat_ctx(role="assistant", text_content=text_content)
+ content_type = self._current_generation.content_id_map.get(content_id)
+
+ if content_type == "USER_ASR":
+ logger.debug(f"INPUT TRANSCRIPTION UPDATED: {text_content}")
+ self._update_chat_ctx(role="user", text_content=text_content)
+
+ elif content_type == "ASSISTANT_TEXT":
+ # Set first token timestamp if not already set
+ if self._current_generation._first_token_timestamp is None:
+ self._current_generation._first_token_timestamp = time.time()
+
+ # Stream text to LiveKit
+ if self._current_generation.message_gen:
+ self._current_generation.message_gen.text_ch.send_nowait(text_content)
+ self._update_chat_ctx(role="assistant", text_content=text_content)

  def _update_chat_ctx(self, role: llm.ChatRole, text_content: str) -> None:
  """
@@ -715,107 +761,72 @@ class RealtimeSession( # noqa: F811

  # cannot rely on this event for user b/c stopReason=PARTIAL_TURN always for user
  async def _handle_text_output_content_end_event(self, event_data: dict) -> None:
- """Mark the assistant message closed when Bedrock signals END_TURN."""
- stop_reason = event_data["event"]["contentEnd"]["stopReason"]
- text_content_id = event_data["event"]["contentEnd"]["contentId"]
- if (
- self._current_generation
- is not None # means that first utterance in the turn was an interrupt
- and self._current_generation.speculative_messages.get(text_content_id)
- == self._current_generation.response_id
- and stop_reason == "END_TURN"
- ):
- log_event_data(event_data)
- self._close_current_generation()
+ """Handle text content end - log but don't close generation yet."""
+ # Nova Sonic sends multiple content blocks within one completion
+ # Don't close generation here - wait for completionEnd or audio_output_content_end
+ log_event_data(event_data)

  async def _handle_tool_output_content_start_event(self, event_data: dict) -> None:
- """Track mapping content_id -> response_id for upcoming tool use."""
+ """Track tool content start."""
  log_event_data(event_data)
- assert self._current_generation is not None, "current_generation is None"

- tool_use_content_id = event_data["event"]["contentStart"]["contentId"]
- self._current_generation.tool_messages[tool_use_content_id] = (
- self._current_generation.response_id
- )
+ # Ensure generation exists
+ self._create_response_generation()
+
+ if self._current_generation is None:
+ return
+
+ content_id = event_data["event"]["contentStart"]["contentId"]
+ self._current_generation.content_id_map[content_id] = "TOOL"

- # note: tool calls are synchronous for now
  async def _handle_tool_output_content_event(self, event_data: dict) -> None:
- """Execute the referenced tool locally and forward results back to Bedrock."""
+ """Execute the referenced tool locally and queue results."""
  log_event_data(event_data)
- assert self._current_generation is not None, "current_generation is None"

- tool_use_content_id = event_data["event"]["toolUse"]["contentId"]
+ if self._current_generation is None:
+ logger.warning("tool_output_content received without active generation")
+ return
+
  tool_use_id = event_data["event"]["toolUse"]["toolUseId"]
  tool_name = event_data["event"]["toolUse"]["toolName"]
- if (
- self._current_generation.tool_messages.get(tool_use_content_id)
- == self._current_generation.response_id
- ):
- args = event_data["event"]["toolUse"]["content"]
- self._current_generation.function_ch.send_nowait(
- llm.FunctionCall(call_id=tool_use_id, name=tool_name, arguments=args)
- )
- self._pending_tools.add(tool_use_id)
-
- # performing these acrobatics in order to release the deadlock
- # LiveKit will not accept a new generation until the previous one is closed
- # the issue is that audio data cannot be generated until toolResult is received
- # however, toolResults only arrive after update_chat_ctx() is invoked
- # which will only occur after agent speech has completed
- # therefore we introduce an artificial turn to trigger update_chat_ctx()
- # TODO: this is messy-- investigate if there is a better way to handle this
- curr_gen = self._current_generation.messages[self._current_generation.response_id]
- curr_gen.audio_ch.close()
- curr_gen.text_ch.close()
- self._current_generation.message_ch.close()
- self._current_generation.message_ch = utils.aio.Chan()
- self._current_generation.function_ch.close()
- self._current_generation.function_ch = utils.aio.Chan()
- msg_gen = _MessageGeneration(
- message_id=self._current_generation.response_id,
- text_ch=utils.aio.Chan(),
- audio_ch=utils.aio.Chan(),
- )
- self._current_generation.messages[self._current_generation.response_id] = msg_gen
- msg_modalities = asyncio.Future[list[Literal["text", "audio"]]]()
- msg_modalities.set_result(
- ["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"]
- )
- self._current_generation.message_ch.send_nowait(
- llm.MessageGeneration(
- message_id=msg_gen.message_id,
- text_stream=msg_gen.text_ch,
- audio_stream=msg_gen.audio_ch,
- modalities=msg_modalities,
- )
- )
- self.emit_generation_event()
+ args = event_data["event"]["toolUse"]["content"]
+
+ # Emit function call to LiveKit framework
+ self._current_generation.function_ch.send_nowait(
+ llm.FunctionCall(call_id=tool_use_id, name=tool_name, arguments=args)
+ )
+ self._pending_tools.add(tool_use_id)
+ logger.debug(f"Tool call emitted: {tool_name} (id={tool_use_id})")
+
+ # CRITICAL: Close generation after tool call emission
+ # The LiveKit framework expects the generation to close so it can call update_chat_ctx()
+ # with the tool results. A new generation will be created when Nova Sonic sends the next
+ # ASSISTANT SPECULATIVE text event with the tool response.
+ logger.debug("Closing generation to allow tool result delivery")
+ self._close_current_generation()

  async def _handle_tool_output_content_end_event(self, event_data: dict) -> None:
  log_event_data(event_data)

  async def _handle_audio_output_content_start_event(self, event_data: dict) -> None:
- """Associate the upcoming audio chunk with the active assistant message."""
+ """Track audio content start."""
  if self._current_generation is not None:
  log_event_data(event_data)
- audio_content_id = event_data["event"]["contentStart"]["contentId"]
- self._current_generation.speculative_messages[audio_content_id] = (
- self._current_generation.response_id
- )
+ content_id = event_data["event"]["contentStart"]["contentId"]
+ self._current_generation.content_id_map[content_id] = "ASSISTANT_AUDIO"

  async def _handle_audio_output_content_event(self, event_data: dict) -> None:
  """Decode base64 audio from Bedrock and forward it to the audio stream."""
- if (
- self._current_generation is not None
- and self._current_generation.speculative_messages.get(
- event_data["event"]["audioOutput"]["contentId"]
- )
- == self._current_generation.response_id
- ):
+ if self._current_generation is None or self._current_generation.message_gen is None:
+ return
+
+ content_id = event_data["event"]["audioOutput"]["contentId"]
+ content_type = self._current_generation.content_id_map.get(content_id)
+
+ if content_type == "ASSISTANT_AUDIO":
  audio_content = event_data["event"]["audioOutput"]["content"]
  audio_bytes = base64.b64decode(audio_content)
- curr_gen = self._current_generation.messages[self._current_generation.response_id]
- curr_gen.audio_ch.send_nowait(
+ self._current_generation.message_gen.audio_ch.send_nowait(
  rtc.AudioFrame(
  data=audio_bytes,
  sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
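The new tool path above emits the `llm.FunctionCall`, records the toolUseId in `_pending_tools`, and closes the generation so the framework can deliver results through `update_chat_ctx()`. A rough sketch of that bookkeeping with plain Python containers (the function names here are hypothetical, not the plugin's API):

```python
# Hypothetical sketch of the pending-tool bookkeeping implied by the diff.
pending_tools: set[str] = set()


def on_tool_use(tool_use_id: str, name: str, arguments: str) -> None:
    # 1. surface the call to the framework, 2. remember we owe Bedrock a result
    pending_tools.add(tool_use_id)
    print(f"tool call emitted: {name} (id={tool_use_id}) args={arguments}")


def on_tool_result(call_id: str, result: str) -> list[tuple[str, str]]:
    # Only results matching a pending toolUseId are forwarded back to Bedrock.
    if call_id not in pending_tools:
        return []
    pending_tools.discard(call_id)
    return [(call_id, result)]


on_tool_use("tu-1", "get_weather", '{"city": "Berlin"}')
assert on_tool_result("tu-1", '{"temp_c": 21}') == [("tu-1", '{"temp_c": 21}')]
assert on_tool_result("tu-2", "{}") == []  # unknown ids are ignored
```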
@@ -825,62 +836,89 @@ class RealtimeSession( # noqa: F811
  )

  async def _handle_audio_output_content_end_event(self, event_data: dict) -> None:
- """Close the assistant message streams once Bedrock finishes audio for the turn."""
- if (
- self._current_generation is not None
- and event_data["event"]["contentEnd"]["stopReason"] == "END_TURN"
- and self._current_generation.speculative_messages.get(
- event_data["event"]["contentEnd"]["contentId"]
- )
- == self._current_generation.response_id
- ):
- log_event_data(event_data)
- self._close_current_generation()
+ """Handle audio content end - log but don't close generation."""
+ log_event_data(event_data)
+ # Nova Sonic uses one completion for entire session
+ # Don't close generation here - wait for new completionStart or session end

  def _close_current_generation(self) -> None:
- """Helper that closes all channels of the active _ResponseGeneration."""
- if self._current_generation is not None:
- if self._current_generation.response_id in self._current_generation.messages:
- curr_gen = self._current_generation.messages[self._current_generation.response_id]
- if not curr_gen.audio_ch.closed:
- curr_gen.audio_ch.close()
- if not curr_gen.text_ch.closed:
- curr_gen.text_ch.close()
-
- # TODO: seems not needed, tool_messages[id] is a str, function_ch is closed below?
- # if self._current_generation.response_id in self._current_generation.tool_messages:
- # curr_gen = self._current_generation.tool_messages[
- # self._current_generation.response_id
- # ]
- # if not curr_gen.function_ch.closed:
- # curr_gen.function_ch.close()
-
- if not self._current_generation.message_ch.closed:
- self._current_generation.message_ch.close()
- if not self._current_generation.function_ch.closed:
- self._current_generation.function_ch.close()
-
- self._current_generation = None
+ """Helper that closes all channels of the active generation."""
+ if self._current_generation is None:
+ return
+
+ # Set completed timestamp
+ if self._current_generation._completed_timestamp is None:
+ self._current_generation._completed_timestamp = time.time()
+
+ # Close message channels
+ if self._current_generation.message_gen:
+ if not self._current_generation.message_gen.audio_ch.closed:
+ self._current_generation.message_gen.audio_ch.close()
+ if not self._current_generation.message_gen.text_ch.closed:
+ self._current_generation.message_gen.text_ch.close()
+
+ # Close generation channels
+ if not self._current_generation.message_ch.closed:
+ self._current_generation.message_ch.close()
+ if not self._current_generation.function_ch.closed:
+ self._current_generation.function_ch.close()
+
+ logger.debug(
+ f"Closed generation for completion_id={self._current_generation.completion_id}"
+ )
+ self._current_generation = None

  async def _handle_completion_end_event(self, event_data: dict) -> None:
+ """Handle completionEnd - close the generation for this completion cycle."""
  log_event_data(event_data)

+ # Close generation if still open
+ if self._current_generation:
+ logger.debug("completionEnd received, closing generation")
+ self._close_current_generation()
+
  async def _handle_other_event(self, event_data: dict) -> None:
  log_event_data(event_data)

  async def _handle_usage_event(self, event_data: dict) -> None:
  # log_event_data(event_data)
- # TODO: implement duration and ttft
  input_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["input"]
  output_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["output"]
- # Q: should we be counting per turn or utterance?
+
+ # Calculate metrics from timestamps
+ duration = 0.0
+ ttft = 0.0
+ tokens_per_second = 0.0
+
+ if self._current_generation is not None:
+ created_ts = self._current_generation._created_timestamp
+ first_token_ts = self._current_generation._first_token_timestamp
+ completed_ts = self._current_generation._completed_timestamp
+
+ # Calculate TTFT (time to first token)
+ if first_token_ts is not None and isinstance(created_ts, (int, float)):
+ ttft = first_token_ts - created_ts
+
+ # Calculate duration (total time from creation to completion)
+ if completed_ts is not None and isinstance(created_ts, (int, float)):
+ duration = completed_ts - created_ts
+
+ # Calculate tokens per second
+ total_tokens = (
+ input_tokens["speechTokens"]
+ + + input_tokens["textTokens"]
+ + + output_tokens["speechTokens"]
+ + + output_tokens["textTokens"]
+ )
+ if duration > 0:
+ tokens_per_second = total_tokens / duration
+
  metrics = RealtimeModelMetrics(
  label=self._realtime_model.label,
- # TODO: pass in the correct request_id
  request_id=event_data["event"]["usageEvent"]["completionId"],
  timestamp=time.monotonic(),
- duration=0,
- ttft=0,
+ duration=duration,
+ ttft=ttft,
  cancelled=False,
  input_tokens=input_tokens["speechTokens"] + input_tokens["textTokens"],
  output_tokens=output_tokens["speechTokens"] + output_tokens["textTokens"],
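The usage handler now derives duration, TTFT, and tokens-per-second from the generation's wall-clock timestamps instead of hard-coding zeros. A worked example of the same arithmetic with made-up numbers:

```python
# Worked example of the metric arithmetic in the new usage handler
# (standalone; timestamps are float epoch seconds as returned by time.time()).
created_ts = 100.00          # generation created
first_token_ts = 100.35      # first assistant token seen
completed_ts = 102.85        # generation closed

input_tokens = {"speechTokens": 180, "textTokens": 20}
output_tokens = {"speechTokens": 320, "textTokens": 40}

ttft = first_token_ts - created_ts       # 0.35 s
duration = completed_ts - created_ts     # 2.85 s
total_tokens = (
    input_tokens["speechTokens"]
    + input_tokens["textTokens"]
    + output_tokens["speechTokens"]
    + output_tokens["textTokens"]
)                                        # 560 tokens
tokens_per_second = total_tokens / duration if duration > 0 else 0.0
print(round(ttft, 2), round(duration, 2), round(tokens_per_second, 1))  # 0.35 2.85 196.5
```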
@@ -888,8 +926,7 @@ class RealtimeSession( # noqa: F811
  + input_tokens["textTokens"]
  + output_tokens["speechTokens"]
  + output_tokens["textTokens"],
- # need duration to calculate this
- tokens_per_second=0,
+ tokens_per_second=tokens_per_second,
  input_token_details=RealtimeModelMetrics.InputTokenDetails(
  text_tokens=input_tokens["textTokens"],
  audio_tokens=input_tokens["speechTokens"],
@@ -963,7 +1000,12 @@ class RealtimeSession( # noqa: F811
  ),
  )
  raise
- except (ThrottlingException, ModelNotReadyException, ModelErrorException) as re:
+ except (
+ ThrottlingException,
+ ModelNotReadyException,
+ ModelErrorException,
+ ModelStreamErrorException,
+ ) as re:
  logger.warning(f"Retryable error: {re}\nAttempting to recover...")
  await self._restart_session(re)
  break
@@ -1010,8 +1052,13 @@ class RealtimeSession( # noqa: F811
  self._is_sess_active.clear()

  async def _restart_session(self, ex: Exception) -> None:
- if self._session_restart_attempts >= DEFAULT_MAX_SESSION_RESTART_ATTEMPTS:
- logger.error("Max session restart attempts reached, exiting")
+ # Get restart attempts from current generation, or 0 if no generation
+ restart_attempts = (
+ self._current_generation._restart_attempts if self._current_generation else 0
+ )
+
+ if restart_attempts >= DEFAULT_MAX_SESSION_RESTART_ATTEMPTS:
+ logger.error("Max restart attempts reached for this turn, exiting")
  err_msg = getattr(ex, "message", str(ex))
  request_id = None
  try:
@@ -1035,13 +1082,20 @@ class RealtimeSession( # noqa: F811
  )
  self._is_sess_active.clear()
  return
- self._session_restart_attempts += 1
+
+ # Increment restart counter for current generation
+ if self._current_generation:
+ self._current_generation._restart_attempts += 1
+ restart_attempts = self._current_generation._restart_attempts
+ else:
+ restart_attempts = 1
+
  self._is_sess_active.clear()
- delay = 2 ** (self._session_restart_attempts - 1) - 1
+ delay = 2 ** (restart_attempts - 1) - 1
  await asyncio.sleep(min(delay, DEFAULT_MAX_SESSION_RESTART_DELAY))
  await self.initialize_streams(is_restart=True)
  logger.info(
- f"Session restarted successfully ({self._session_restart_attempts}/{DEFAULT_MAX_SESSION_RESTART_ATTEMPTS})" # noqa: E501
+ f"Turn restarted successfully ({restart_attempts}/{DEFAULT_MAX_SESSION_RESTART_ATTEMPTS})"
  )

  @property
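Restart attempts are now tracked per generation, and the sleep before each retry follows `2 ** (attempt - 1) - 1` seconds capped by `DEFAULT_MAX_SESSION_RESTART_DELAY`. A quick sketch of that schedule (the 8-second cap below is illustrative, not the plugin's constant):

```python
# The restart delay schedule implied by the diff: 0 s, 1 s, 3 s, 7 s, then capped.
MAX_DELAY = 8.0  # assumption for illustration; the plugin uses DEFAULT_MAX_SESSION_RESTART_DELAY


def restart_delay(attempt: int, max_delay: float = MAX_DELAY) -> float:
    # attempt is 1-based, matching the per-generation _restart_attempts counter
    return min(2 ** (attempt - 1) - 1, max_delay)


print([restart_delay(n) for n in range(1, 6)])  # [0, 1, 3, 7, 8.0]
```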
@@ -1055,7 +1109,10 @@ class RealtimeSession( # noqa: F811
  async def update_instructions(self, instructions: str) -> None:
  """Injects the system prompt at the start of the session."""
  self._instructions = instructions
- self._instructions_ready.set_result(True)
+ if self._instructions_ready is None:
+ self._instructions_ready = asyncio.get_running_loop().create_future()
+ if not self._instructions_ready.done():
+ self._instructions_ready.set_result(True)
  logger.debug(f"Instructions updated: {instructions}")

  async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
@@ -1063,16 +1120,26 @@ class RealtimeSession( # noqa: F811
  # sometimes fires randomly
  # add a guard here to only allow chat_ctx to be updated on
  # the very first session initialization
+ if self._chat_ctx_ready is None:
+ self._chat_ctx_ready = asyncio.get_running_loop().create_future()
+
  if not self._chat_ctx_ready.done():
  self._chat_ctx = chat_ctx.copy()
  logger.debug(f"Chat context updated: {self._chat_ctx.items}")
  self._chat_ctx_ready.set_result(True)

  # for each function tool, send the result to aws
+ logger.debug(
+ f"update_chat_ctx called with {len(chat_ctx.items)} items, pending_tools: {self._pending_tools}"
+ )
  for item in chat_ctx.items:
  if item.type != "function_call_output":
  continue

+ logger.debug(
+ f"Found function_call_output: call_id={item.call_id}, in_pending={item.call_id in self._pending_tools}"
+ )
+
  if item.call_id not in self._pending_tools:
  continue

@@ -1122,7 +1189,10 @@ class RealtimeSession( # noqa: F811
  retained_tools.append(tool)
  self._tools = llm.ToolContext(retained_tools)
  if retained_tools:
- self._tools_ready.set_result(True)
+ if self._tools_ready is None:
+ self._tools_ready = asyncio.get_running_loop().create_future()
+ if not self._tools_ready.done():
+ self._tools_ready.set_result(True)
  logger.debug("Tool list has been injected")

  def update_options(self, *, tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN) -> None:
@@ -1158,54 +1228,62 @@ class RealtimeSession( # noqa: F811
  """Background task that feeds audio and tool results into the Bedrock stream."""
  await self._send_raw_event(self._event_builder.create_audio_content_start_event())
  logger.info("Starting audio input processing loop")
+
+ # Create tasks for both channels so we can wait on either
+ audio_task = asyncio.create_task(self._audio_input_chan.recv())
+ tool_task = asyncio.create_task(self._tool_results_ch.recv())
+ pending = {audio_task, tool_task}
+
  while self._is_sess_active.is_set():
  try:
- # note: could potentially pull this out into a separate task
- try:
- val = self._tool_results_ch.recv_nowait()
- tool_result = val["tool_result"]
- tool_use_id = val["tool_use_id"]
- if not isinstance(tool_result, str):
- tool_result = json.dumps(tool_result)
- else:
- try:
- json.loads(tool_result)
- except json.JSONDecodeError:
- try:
- tool_result = json.dumps(ast.literal_eval(tool_result))
- except Exception:
- # return the original value
- pass
-
- logger.debug(f"Sending tool result: {tool_result}")
- await self._send_tool_events(tool_use_id, tool_result)
-
- except utils.aio.channel.ChanEmpty:
- pass
- except utils.aio.channel.ChanClosed:
- logger.warning(
- "tool results channel closed, exiting audio input processing loop"
- )
- break
+ done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)

- try:
- audio_bytes = await self._audio_input_chan.recv()
- blob = base64.b64encode(audio_bytes)
- audio_event = self._event_builder.create_audio_input_event(
- audio_content=blob.decode("utf-8"),
- )
-
- await self._send_raw_event(audio_event)
- except utils.aio.channel.ChanEmpty:
- pass
- except utils.aio.channel.ChanClosed:
- logger.warning(
- "audio input channel closed, exiting audio input processing loop"
- )
- break
+ for task in done:
+ if task == audio_task:
+ try:
+ audio_bytes = cast(bytes, task.result())
+ blob = base64.b64encode(audio_bytes)
+ audio_event = self._event_builder.create_audio_input_event(
+ audio_content=blob.decode("utf-8"),
+ )
+ await self._send_raw_event(audio_event)
+ # Create new task for next audio
+ audio_task = asyncio.create_task(self._audio_input_chan.recv())
+ pending.add(audio_task)
+ except utils.aio.channel.ChanClosed:
+ logger.warning("audio input channel closed")
+ break
+
+ elif task == tool_task:
+ try:
+ val = cast(dict[str, str], task.result())
+ tool_result = val["tool_result"]
+ tool_use_id = val["tool_use_id"]
+ if not isinstance(tool_result, str):
+ tool_result = json.dumps(tool_result)
+ else:
+ try:
+ json.loads(tool_result)
+ except json.JSONDecodeError:
+ try:
+ tool_result = json.dumps(ast.literal_eval(tool_result))
+ except Exception:
+ pass
+
+ logger.debug(f"Sending tool result: {tool_result}")
+ await self._send_tool_events(tool_use_id, tool_result)
+ # Create new task for next tool result
+ tool_task = asyncio.create_task(self._tool_results_ch.recv())
+ pending.add(tool_task)
+ except utils.aio.channel.ChanClosed:
+ logger.warning("tool results channel closed")
+ break

  except asyncio.CancelledError:
  logger.info("Audio processing loop cancelled")
+ # Cancel pending tasks
+ for task in pending:
+ task.cancel()
  self._audio_input_chan.close()
  self._tool_results_ch.close()
  raise
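The input loop now waits on the audio channel and the tool-result channel concurrently with `asyncio.wait(..., return_when=FIRST_COMPLETED)` and re-arms whichever task finished. A self-contained sketch of the pattern, with `asyncio.Queue` standing in for `utils.aio.Chan` (an assumption made only so the example runs on the standard library):

```python
import asyncio


async def pump(audio_q: asyncio.Queue, tool_q: asyncio.Queue, n_items: int) -> None:
    audio_task = asyncio.create_task(audio_q.get())
    tool_task = asyncio.create_task(tool_q.get())
    pending = {audio_task, tool_task}
    handled = 0
    while handled < n_items:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            handled += 1
            if task is audio_task:
                print("audio frame:", task.result())
                audio_task = asyncio.create_task(audio_q.get())   # re-arm
                pending.add(audio_task)
            else:
                print("tool result:", task.result())
                tool_task = asyncio.create_task(tool_q.get())     # re-arm
                pending.add(tool_task)
    for task in pending:  # mirror the CancelledError cleanup in the diff
        task.cancel()


async def main() -> None:
    audio_q: asyncio.Queue = asyncio.Queue()
    tool_q: asyncio.Queue = asyncio.Queue()
    await audio_q.put(b"\x00\x01")
    await tool_q.put({"tool_use_id": "tu-1", "tool_result": "{}"})
    await pump(audio_q, tool_q, n_items=2)


asyncio.run(main())
```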
@@ -1256,7 +1334,24 @@ class RealtimeSession( # noqa: F811
  logger.warning("video is not supported by Nova Sonic's Realtime API")

  def interrupt(self) -> None:
- logger.warning("interrupt is not supported by Nova Sonic's Realtime API")
+ """Nova Sonic handles interruption automatically via barge-in detection.
+
+ Unlike OpenAI's client-initiated interrupt, Nova Sonic automatically detects
+ when the user starts speaking while the model is generating audio. When this
+ happens, the model:
+ 1. Immediately stops generating speech
+ 2. Switches to listening mode
+ 3. Sends a text event with content: { "interrupted" : true }
+
+ The plugin already handles this event (see _handle_text_output_content_event).
+ No client action is needed - interruption works automatically.
+
+ See AWS docs: https://docs.aws.amazon.com/nova/latest/userguide/output-events.html
+ """
+ logger.info(
+ "Nova Sonic handles interruption automatically via barge-in detection. "
+ "The model detects when users start speaking and stops generation automatically."
+ )

  def truncate(
  self,
@@ -1306,6 +1401,11 @@ class RealtimeSession( # noqa: F811
  if self._stream_response and not self._stream_response.input_stream.closed:
  await self._stream_response.input_stream.close()

+ # cancel main task to prevent pending task warnings
+ if self._main_atask and not self._main_atask.done():
+ self._main_atask.cancel()
+ tasks.append(self._main_atask)
+
  await asyncio.gather(*tasks, return_exceptions=True)
  logger.debug(f"CHAT CONTEXT: {self._chat_ctx.items}")
  logger.info("Session end")
@@ -12,4 +12,4 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- __version__ = "1.2.15"
+ __version__ = "1.2.17"
@@ -1,13 +1,13 @@
  Metadata-Version: 2.4
  Name: livekit-plugins-aws
- Version: 1.2.15
+ Version: 1.2.17
  Summary: LiveKit Agents Plugin for services from AWS
  Project-URL: Documentation, https://docs.livekit.io
  Project-URL: Website, https://livekit.io/
  Project-URL: Source, https://github.com/livekit/agents
  Author-email: LiveKit <hello@livekit.io>
  License-Expression: Apache-2.0
- Keywords: audio,aws,livekit,nova,realtime,sonic,video,webrtc
+ Keywords: ai,audio,aws,livekit,nova,realtime,sonic,video,voice
  Classifier: Intended Audience :: Developers
  Classifier: License :: OSI Approved :: Apache Software License
  Classifier: Programming Language :: Python :: 3
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9.0
  Requires-Dist: aioboto3>=14.1.0
  Requires-Dist: amazon-transcribe>=0.6.4
- Requires-Dist: livekit-agents>=1.2.15
+ Requires-Dist: livekit-agents>=1.2.17
  Provides-Extra: realtime
  Requires-Dist: aws-sdk-bedrock-runtime==0.0.2; (python_version >= '3.12') and extra == 'realtime'
  Requires-Dist: aws-sdk-signers==0.0.3; (python_version >= '3.12') and extra == 'realtime'
@@ -6,12 +6,12 @@ livekit/plugins/aws/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
  livekit/plugins/aws/stt.py,sha256=OwHfGDFW-VtSY68my3u5dyMLRpYVZPlnVbtm-dNj4Q0,9438
  livekit/plugins/aws/tts.py,sha256=oav-XWf9ysVGCmERWej6BgACu8vsLbRo9vFGpo9N6Ec,7184
  livekit/plugins/aws/utils.py,sha256=nA5Ua1f4T-25Loar6EvlrKTXI9N-zpTIH7cdQkwGyGI,1518
- livekit/plugins/aws/version.py,sha256=R5FvTAJuFKBJlKNE37WH1vS6st7RUEFAUNaLi-rjprE,601
+ livekit/plugins/aws/version.py,sha256=ZlvvHSEyo4YT35z0OKDbZvGI4D1lVYqTvemcuSdOS8o,601
  livekit/plugins/aws/experimental/realtime/__init__.py,sha256=mm_TGZc9QAWSO-VOO3PdE8Y5R6xlWckXRZuiFUIHa-Q,287
  livekit/plugins/aws/experimental/realtime/events.py,sha256=ltdGEipE3ZOkjn7K6rKN6WSCUPJkVg-S88mUmQ_V00s,15981
  livekit/plugins/aws/experimental/realtime/pretty_printer.py,sha256=KN7KPrfQu8cU7ff34vFAtfrd1umUSTVNKXQU7D8AMiM,1442
- livekit/plugins/aws/experimental/realtime/realtime_model.py,sha256=vieey2nDEIDaDQsTkmZ7p-NgvBaE9VlDMeQiduCkedI,60846
+ livekit/plugins/aws/experimental/realtime/realtime_model.py,sha256=ksdw7X-wm5wiDoCur9srYTRV2eVadHOjAIIanNS9dUo,64568
  livekit/plugins/aws/experimental/realtime/turn_tracker.py,sha256=bcufaap-coeIYuK3ct1Is9W_UoefGYRmnJu7Mn5DCYU,6002
- livekit_plugins_aws-1.2.15.dist-info/METADATA,sha256=NPvPAcrOSPSZ7A6f7X3RdVz30vR7skXx3XGtxGwk06U,2081
- livekit_plugins_aws-1.2.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- livekit_plugins_aws-1.2.15.dist-info/RECORD,,
+ livekit_plugins_aws-1.2.17.dist-info/METADATA,sha256=53WRXByqiLoDG0B9RoWMYIMWqoxxuWl1XHvMIH3rXsE,2083
+ livekit_plugins_aws-1.2.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ livekit_plugins_aws-1.2.17.dist-info/RECORD,,