livekit-plugins-aws 1.2.16__py3-none-any.whl → 1.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of livekit-plugins-aws might be problematic. Click here for more details.

@@ -12,7 +12,6 @@ import uuid
12
12
  import weakref
13
13
  from collections.abc import Iterator
14
14
  from dataclasses import dataclass, field
15
- from datetime import datetime
16
15
  from typing import Any, Callable, Literal, cast
17
16
 
18
17
  import boto3
@@ -124,38 +123,35 @@ class _MessageGeneration:
124
123
 
125
124
  @dataclass
126
125
  class _ResponseGeneration:
127
- """Book-keeping dataclass tracking the lifecycle of a Sonic turn.
126
+ """Book-keeping dataclass tracking the lifecycle of a Nova Sonic completion.
128
127
 
129
- This object is created whenever we receive a *completion_start* event from the model
130
- and is disposed of once the assistant turn finishes (e.g. *END_TURN*).
128
+ Nova Sonic uses a completion model where one completionStart event begins a cycle
129
+ that may contain multiple content blocks (USER ASR, TOOL, ASSISTANT text/audio).
130
+ This generation stays open for the entire completion cycle.
131
131
 
132
132
  Attributes:
133
- message_ch (utils.aio.Chan[llm.MessageGeneration]): Multiplexed stream for all assistant messages.
133
+ completion_id (str): Nova Sonic's completionId that ties all events together.
134
+ message_ch (utils.aio.Chan[llm.MessageGeneration]): Stream for assistant messages.
134
135
  function_ch (utils.aio.Chan[llm.FunctionCall]): Stream that emits function tool calls.
135
- input_id (str): Synthetic message id for the user input of the current turn.
136
- response_id (str): Synthetic message id for the assistant reply of the current turn.
137
- messages (dict[str, _MessageGeneration]): Map of message_id -> per-message stream containers.
138
- user_messages (dict[str, str]): Map Bedrock content_id -> input_id.
139
- speculative_messages (dict[str, str]): Map Bedrock content_id -> response_id (assistant side).
140
- tool_messages (dict[str, str]): Map Bedrock content_id -> response_id for tool calls.
141
- output_text (str): Accumulated assistant text (only used for metrics / debugging).
142
- _created_timestamp (str): ISO-8601 timestamp when the generation record was created.
136
+ response_id (str): LiveKit response_id for the assistant's response.
137
+ message_gen (_MessageGeneration | None): Current message generation for assistant output.
138
+ content_id_map (dict[str, str]): Map Nova Sonic contentId -> type (USER/ASSISTANT/TOOL).
139
+ _created_timestamp (float): Wall-clock time when the generation record was created.
143
140
  _first_token_timestamp (float | None): Wall-clock time of first token emission.
144
141
  _completed_timestamp (float | None): Wall-clock time when the turn fully completed.
142
+ _restart_attempts (int): Number of restart attempts for this specific completion.
145
143
  """ # noqa: E501
146
144
 
145
+ completion_id: str
147
146
  message_ch: utils.aio.Chan[llm.MessageGeneration]
148
147
  function_ch: utils.aio.Chan[llm.FunctionCall]
149
- input_id: str # corresponds to user's portion of the turn
150
- response_id: str # corresponds to agent's portion of the turn
151
- messages: dict[str, _MessageGeneration] = field(default_factory=dict)
152
- user_messages: dict[str, str] = field(default_factory=dict)
153
- speculative_messages: dict[str, str] = field(default_factory=dict)
154
- tool_messages: dict[str, str] = field(default_factory=dict)
155
- output_text: str = "" # agent ASR text
156
- _created_timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
148
+ response_id: str
149
+ message_gen: _MessageGeneration | None = None
150
+ content_id_map: dict[str, str] = field(default_factory=dict)
151
+ _created_timestamp: float = field(default_factory=time.time)
157
152
  _first_token_timestamp: float | None = None
158
153
  _completed_timestamp: float | None = None
154
+ _restart_attempts: int = 0
159
155
 
160
156
 
161
157
  class Boto3CredentialsResolver(IdentityResolver): # type: ignore[misc]
@@ -272,16 +268,12 @@ class RealtimeModel(llm.RealtimeModel):
272
268
  def session(self) -> RealtimeSession:
273
269
  """Return a new RealtimeSession bound to this model instance."""
274
270
  sess = RealtimeSession(self)
275
-
276
- # note: this is a hack to get the session to initialize itself
277
- # TODO: change how RealtimeSession is initialized by creating a single task main_atask that spawns subtasks # noqa: E501
278
- asyncio.create_task(sess.initialize_streams())
279
271
  self._sessions.add(sess)
280
272
  return sess
281
273
 
282
- # stub b/c RealtimeSession.aclose() is invoked directly
283
- async def aclose(self) -> None:
284
- pass
274
+ async def aclose(self) -> None:
275
+ """Close all active sessions."""
276
+ pass
285
277
 
286
278
 
287
279
  class RealtimeSession( # noqa: F811
@@ -327,17 +319,16 @@ class RealtimeSession( # noqa: F811
327
319
  self._chat_ctx = llm.ChatContext.empty()
328
320
  self._tools = llm.ToolContext.empty()
329
321
  self._tool_results_ch = utils.aio.Chan[dict[str, str]]()
330
- self._tools_ready = asyncio.get_running_loop().create_future()
331
- self._instructions_ready = asyncio.get_running_loop().create_future()
332
- self._chat_ctx_ready = asyncio.get_running_loop().create_future()
322
+ # CRITICAL: Initialize futures as None for lazy creation
323
+ # Creating futures in __init__ causes race conditions during session restart.
324
+ # Futures are created in initialize_streams() when the event loop is guaranteed to exist.
325
+ self._tools_ready: asyncio.Future[bool] | None = None
326
+ self._instructions_ready: asyncio.Future[bool] | None = None
327
+ self._chat_ctx_ready: asyncio.Future[bool] | None = None
333
328
  self._instructions = DEFAULT_SYSTEM_PROMPT
334
329
  self._audio_input_chan = utils.aio.Chan[bytes]()
335
330
  self._current_generation: _ResponseGeneration | None = None
336
331
 
337
- # note: currently tracks session restart attempts across all sessions
338
- # TODO: track restart attempts per turn
339
- self._session_restart_attempts = 0
340
-
341
332
  self._event_handlers = {
342
333
  "completion_start": self._handle_completion_start_event,
343
334
  "audio_output_content_start": self._handle_audio_output_content_start_event,
@@ -358,6 +349,11 @@ class RealtimeSession( # noqa: F811
358
349
  cast(Callable[[], None], self.emit_generation_event),
359
350
  )
360
351
 
352
+ # Create main task to manage session lifecycle
353
+ self._main_atask = asyncio.create_task(
354
+ self.initialize_streams(), name="RealtimeSession.initialize_streams"
355
+ )
356
+
361
357
  @utils.log_exceptions(logger=logger)
362
358
  def _initialize_client(self) -> None:
363
359
  """Instantiate the Bedrock runtime client"""
@@ -491,6 +487,14 @@ class RealtimeSession( # noqa: F811
491
487
  )
492
488
 
493
489
  if not is_restart:
490
+ # Lazy-initialize futures if needed
491
+ if self._tools_ready is None:
492
+ self._tools_ready = asyncio.get_running_loop().create_future()
493
+ if self._instructions_ready is None:
494
+ self._instructions_ready = asyncio.get_running_loop().create_future()
495
+ if self._chat_ctx_ready is None:
496
+ self._chat_ctx_ready = asyncio.get_running_loop().create_future()
497
+
494
498
  pending_events: list[asyncio.Future] = []
495
499
  if not self.tools.function_tools:
496
500
  pending_events.append(self._tools_ready)
@@ -506,14 +510,14 @@ class RealtimeSession( # noqa: F811
506
510
  if pending_events:
507
511
  await asyncio.wait_for(asyncio.gather(*pending_events), timeout=0.5)
508
512
  except asyncio.TimeoutError:
509
- if not self._tools_ready.done():
513
+ if self._tools_ready and not self._tools_ready.done():
510
514
  logger.warning("Tools not ready after 500ms, continuing without them")
511
515
 
512
- if not self._instructions_ready.done():
516
+ if self._instructions_ready and not self._instructions_ready.done():
513
517
  logger.warning(
514
518
  "Instructions not received after 500ms, proceeding with default instructions" # noqa: E501
515
519
  )
516
- if not self._chat_ctx_ready.done():
520
+ if self._chat_ctx_ready and not self._chat_ctx_ready.done():
517
521
  logger.warning(
518
522
  "Chat context not received after 500ms, proceeding with empty chat context" # noqa: E501
519
523
  )
@@ -560,9 +564,11 @@ class RealtimeSession( # noqa: F811
560
564
  @utils.log_exceptions(logger=logger)
561
565
  def emit_generation_event(self) -> None:
562
566
  """Publish a llm.GenerationCreatedEvent to external subscribers."""
563
- logger.debug("Emitting generation event")
564
- assert self._current_generation is not None, "current_generation is None"
567
+ if self._current_generation is None:
568
+ logger.debug("emit_generation_event called but no generation exists - ignoring")
569
+ return
565
570
 
571
+ logger.debug("Emitting generation event")
566
572
  generation_ev = llm.GenerationCreatedEvent(
567
573
  message_stream=self._current_generation.message_ch,
568
574
  function_stream=self._current_generation.function_ch,
@@ -583,21 +589,38 @@ class RealtimeSession( # noqa: F811
583
589
  logger.warning(f"No event handler found for event type: {event_type}")
584
590
 
585
591
  async def _handle_completion_start_event(self, event_data: dict) -> None:
592
+ """Handle completionStart - create new generation for this completion cycle."""
586
593
  log_event_data(event_data)
587
594
  self._create_response_generation()
588
595
 
589
596
  def _create_response_generation(self) -> None:
590
- """Instantiate _ResponseGeneration and emit the GenerationCreated event."""
597
+ """Instantiate _ResponseGeneration and emit the GenerationCreated event.
598
+
599
+ Can be called multiple times - will reuse existing generation but ensure
600
+ message structure exists.
601
+ """
602
+ generation_created = False
591
603
  if self._current_generation is None:
604
+ completion_id = "unknown" # Will be set from events
605
+ response_id = str(uuid.uuid4())
606
+
607
+ logger.debug(f"Creating new generation, response_id={response_id}")
592
608
  self._current_generation = _ResponseGeneration(
609
+ completion_id=completion_id,
593
610
  message_ch=utils.aio.Chan(),
594
611
  function_ch=utils.aio.Chan(),
595
- input_id=str(uuid.uuid4()),
596
- response_id=str(uuid.uuid4()),
597
- messages={},
598
- user_messages={},
599
- speculative_messages={},
600
- _created_timestamp=datetime.now().isoformat(),
612
+ response_id=response_id,
613
+ )
614
+ generation_created = True
615
+ else:
616
+ logger.debug(
617
+ f"Generation already exists: response_id={self._current_generation.response_id}"
618
+ )
619
+
620
+ # Always ensure message structure exists (even if generation already exists)
621
+ if self._current_generation.message_gen is None:
622
+ logger.debug(
623
+ f"Creating message structure for response_id={self._current_generation.response_id}"
601
624
  )
602
625
  msg_gen = _MessageGeneration(
603
626
  message_id=self._current_generation.response_id,
@@ -608,6 +631,8 @@ class RealtimeSession( # noqa: F811
608
631
  msg_modalities.set_result(
609
632
  ["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"]
610
633
  )
634
+
635
+ self._current_generation.message_gen = msg_gen
611
636
  self._current_generation.message_ch.send_nowait(
612
637
  llm.MessageGeneration(
613
638
  message_id=msg_gen.message_id,
@@ -616,77 +641,97 @@ class RealtimeSession( # noqa: F811
616
641
  modalities=msg_modalities,
617
642
  )
618
643
  )
619
- self._current_generation.messages[self._current_generation.response_id] = msg_gen
644
+ else:
645
+ logger.debug(
646
+ f"Message structure already exists for response_id={self._current_generation.response_id}"
647
+ )
648
+
649
+ # Only emit generation event if we created a new generation
650
+ if generation_created:
651
+ self.emit_generation_event()
620
652
 
621
653
  # will be completely ignoring post-ASR text events
622
654
  async def _handle_text_output_content_start_event(self, event_data: dict) -> None:
623
- """Handle text_output_content_start for both user and assistant roles."""
655
+ """Handle text_output_content_start - track content type."""
624
656
  log_event_data(event_data)
657
+
625
658
  role = event_data["event"]["contentStart"]["role"]
626
- self._create_response_generation()
627
659
 
628
- # note: does not work if you emit llm.GCE too early (for some reason)
629
- if role == "USER":
630
- assert self._current_generation is not None, "current_generation is None"
660
+ # CRITICAL: Create NEW generation for each ASSISTANT SPECULATIVE response
661
+ # Nova Sonic sends ASSISTANT SPECULATIVE for each new assistant turn, including after tool calls.
662
+ # Without this, audio frames get routed to the wrong generation and don't play.
663
+ if role == "ASSISTANT":
664
+ additional_fields = event_data["event"]["contentStart"].get("additionalModelFields", "")
665
+ if "SPECULATIVE" in additional_fields:
666
+ # This is a new assistant response - close previous and create new
667
+ logger.debug("ASSISTANT SPECULATIVE text - creating new generation")
668
+ if self._current_generation is not None:
669
+ logger.debug("Closing previous generation for new assistant response")
670
+ self._close_current_generation()
671
+ self._create_response_generation()
672
+ else:
673
+ # For USER and FINAL, just ensure generation exists
674
+ self._create_response_generation()
631
675
 
632
- content_id = event_data["event"]["contentStart"]["contentId"]
633
- self._current_generation.user_messages[content_id] = self._current_generation.input_id
676
+ # CRITICAL: Check if generation exists before accessing
677
+ # Barge-in can set _current_generation to None between the creation above and here.
678
+ # Without this check, we crash on interruptions.
679
+ if self._current_generation is None:
680
+ logger.debug("No generation exists - ignoring content_start event")
681
+ return
634
682
 
635
- elif (
636
- role == "ASSISTANT"
637
- and "SPECULATIVE" in event_data["event"]["contentStart"]["additionalModelFields"]
638
- ):
639
- assert self._current_generation is not None, "current_generation is None"
683
+ content_id = event_data["event"]["contentStart"]["contentId"]
640
684
 
641
- text_content_id = event_data["event"]["contentStart"]["contentId"]
642
- self._current_generation.speculative_messages[text_content_id] = (
643
- self._current_generation.response_id
644
- )
685
+ # Track what type of content this is
686
+ if role == "USER":
687
+ self._current_generation.content_id_map[content_id] = "USER_ASR"
688
+ elif role == "ASSISTANT":
689
+ additional_fields = event_data["event"]["contentStart"].get("additionalModelFields", "")
690
+ if "SPECULATIVE" in additional_fields:
691
+ self._current_generation.content_id_map[content_id] = "ASSISTANT_TEXT"
692
+ elif "FINAL" in additional_fields:
693
+ self._current_generation.content_id_map[content_id] = "ASSISTANT_FINAL"
645
694
 
646
695
  async def _handle_text_output_content_event(self, event_data: dict) -> None:
647
- """Stream partial text tokens into the current _MessageGeneration."""
696
+ """Stream partial text tokens into the current generation."""
648
697
  log_event_data(event_data)
649
- text_content_id = event_data["event"]["textOutput"]["contentId"]
698
+
699
+ if self._current_generation is None:
700
+ logger.debug("No generation exists - ignoring text_output event")
701
+ return
702
+
703
+ content_id = event_data["event"]["textOutput"]["contentId"]
650
704
  text_content = f"{event_data['event']['textOutput']['content']}\n"
651
705
 
652
- # currently only agent can be interrupted
706
+ # Nova Sonic's automatic barge-in detection
653
707
  if text_content == '{ "interrupted" : true }\n':
654
- # the interrupted flag is not being set correctly in chat_ctx
655
- # this is b/c audio playback is desynced from text transcription
656
- # TODO: fix this; possibly via a playback timer
657
708
  idx = self._chat_ctx.find_insertion_index(created_at=time.time()) - 1
658
- if idx < 0:
659
- logger.warning("Barge-in DETECTED but no previous message found")
660
- return
661
-
662
- logger.debug(
663
- f"BARGE-IN DETECTED using idx: {idx} and chat_msg: {self._chat_ctx.items[idx]}"
664
- )
665
- if (item := self._chat_ctx.items[idx]).type == "message":
709
+ if idx >= 0 and (item := self._chat_ctx.items[idx]).type == "message":
666
710
  item.interrupted = True
667
- self._close_current_generation()
711
+ logger.debug("Barge-in detected - marked message as interrupted")
712
+
713
+ # Close generation on barge-in unless tools are pending
714
+ if not self._pending_tools:
715
+ self._close_current_generation()
716
+ else:
717
+ logger.debug(f"Keeping generation open - {len(self._pending_tools)} pending tools")
668
718
  return
669
719
 
670
- # ignore events until turn starts
671
- if self._current_generation is not None:
672
- # TODO: rename event to llm.InputTranscriptionUpdated
673
- if (
674
- self._current_generation.user_messages.get(text_content_id)
675
- == self._current_generation.input_id
676
- ):
677
- logger.debug(f"INPUT TRANSCRIPTION UPDATED: {text_content}")
678
- # note: user ASR text is slightly different than what is sent to LiveKit (newline vs whitespace) # noqa: E501
679
- # TODO: fix this
680
- self._update_chat_ctx(role="user", text_content=text_content)
681
-
682
- elif (
683
- self._current_generation.speculative_messages.get(text_content_id)
684
- == self._current_generation.response_id
685
- ):
686
- curr_gen = self._current_generation.messages[self._current_generation.response_id]
687
- curr_gen.text_ch.send_nowait(text_content)
688
- # note: this update is per utterance, not per turn
689
- self._update_chat_ctx(role="assistant", text_content=text_content)
720
+ content_type = self._current_generation.content_id_map.get(content_id)
721
+
722
+ if content_type == "USER_ASR":
723
+ logger.debug(f"INPUT TRANSCRIPTION UPDATED: {text_content}")
724
+ self._update_chat_ctx(role="user", text_content=text_content)
725
+
726
+ elif content_type == "ASSISTANT_TEXT":
727
+ # Set first token timestamp if not already set
728
+ if self._current_generation._first_token_timestamp is None:
729
+ self._current_generation._first_token_timestamp = time.time()
730
+
731
+ # Stream text to LiveKit
732
+ if self._current_generation.message_gen:
733
+ self._current_generation.message_gen.text_ch.send_nowait(text_content)
734
+ self._update_chat_ctx(role="assistant", text_content=text_content)
690
735
 
691
736
  def _update_chat_ctx(self, role: llm.ChatRole, text_content: str) -> None:
692
737
  """
@@ -716,107 +761,72 @@ class RealtimeSession( # noqa: F811
716
761
 
717
762
  # cannot rely on this event for user b/c stopReason=PARTIAL_TURN always for user
718
763
  async def _handle_text_output_content_end_event(self, event_data: dict) -> None:
719
- """Mark the assistant message closed when Bedrock signals END_TURN."""
720
- stop_reason = event_data["event"]["contentEnd"]["stopReason"]
721
- text_content_id = event_data["event"]["contentEnd"]["contentId"]
722
- if (
723
- self._current_generation
724
- is not None # means that first utterance in the turn was an interrupt
725
- and self._current_generation.speculative_messages.get(text_content_id)
726
- == self._current_generation.response_id
727
- and stop_reason == "END_TURN"
728
- ):
729
- log_event_data(event_data)
730
- self._close_current_generation()
764
+ """Handle text content end - log but don't close generation yet."""
765
+ # Nova Sonic sends multiple content blocks within one completion
766
+ # Don't close generation here - wait for completionEnd or audio_output_content_end
767
+ log_event_data(event_data)
731
768
 
732
769
  async def _handle_tool_output_content_start_event(self, event_data: dict) -> None:
733
- """Track mapping content_id -> response_id for upcoming tool use."""
770
+ """Track tool content start."""
734
771
  log_event_data(event_data)
735
- assert self._current_generation is not None, "current_generation is None"
736
772
 
737
- tool_use_content_id = event_data["event"]["contentStart"]["contentId"]
738
- self._current_generation.tool_messages[tool_use_content_id] = (
739
- self._current_generation.response_id
740
- )
773
+ # Ensure generation exists
774
+ self._create_response_generation()
775
+
776
+ if self._current_generation is None:
777
+ return
778
+
779
+ content_id = event_data["event"]["contentStart"]["contentId"]
780
+ self._current_generation.content_id_map[content_id] = "TOOL"
741
781
 
742
- # note: tool calls are synchronous for now
743
782
  async def _handle_tool_output_content_event(self, event_data: dict) -> None:
744
- """Execute the referenced tool locally and forward results back to Bedrock."""
783
+ """Execute the referenced tool locally and queue results."""
745
784
  log_event_data(event_data)
746
- assert self._current_generation is not None, "current_generation is None"
747
785
 
748
- tool_use_content_id = event_data["event"]["toolUse"]["contentId"]
786
+ if self._current_generation is None:
787
+ logger.warning("tool_output_content received without active generation")
788
+ return
789
+
749
790
  tool_use_id = event_data["event"]["toolUse"]["toolUseId"]
750
791
  tool_name = event_data["event"]["toolUse"]["toolName"]
751
- if (
752
- self._current_generation.tool_messages.get(tool_use_content_id)
753
- == self._current_generation.response_id
754
- ):
755
- args = event_data["event"]["toolUse"]["content"]
756
- self._current_generation.function_ch.send_nowait(
757
- llm.FunctionCall(call_id=tool_use_id, name=tool_name, arguments=args)
758
- )
759
- self._pending_tools.add(tool_use_id)
760
-
761
- # performing these acrobatics in order to release the deadlock
762
- # LiveKit will not accept a new generation until the previous one is closed
763
- # the issue is that audio data cannot be generated until toolResult is received
764
- # however, toolResults only arrive after update_chat_ctx() is invoked
765
- # which will only occur after agent speech has completed
766
- # therefore we introduce an artificial turn to trigger update_chat_ctx()
767
- # TODO: this is messy-- investigate if there is a better way to handle this
768
- curr_gen = self._current_generation.messages[self._current_generation.response_id]
769
- curr_gen.audio_ch.close()
770
- curr_gen.text_ch.close()
771
- self._current_generation.message_ch.close()
772
- self._current_generation.message_ch = utils.aio.Chan()
773
- self._current_generation.function_ch.close()
774
- self._current_generation.function_ch = utils.aio.Chan()
775
- msg_gen = _MessageGeneration(
776
- message_id=self._current_generation.response_id,
777
- text_ch=utils.aio.Chan(),
778
- audio_ch=utils.aio.Chan(),
779
- )
780
- self._current_generation.messages[self._current_generation.response_id] = msg_gen
781
- msg_modalities = asyncio.Future[list[Literal["text", "audio"]]]()
782
- msg_modalities.set_result(
783
- ["audio", "text"] if self._realtime_model.capabilities.audio_output else ["text"]
784
- )
785
- self._current_generation.message_ch.send_nowait(
786
- llm.MessageGeneration(
787
- message_id=msg_gen.message_id,
788
- text_stream=msg_gen.text_ch,
789
- audio_stream=msg_gen.audio_ch,
790
- modalities=msg_modalities,
791
- )
792
- )
793
- self.emit_generation_event()
792
+ args = event_data["event"]["toolUse"]["content"]
793
+
794
+ # Emit function call to LiveKit framework
795
+ self._current_generation.function_ch.send_nowait(
796
+ llm.FunctionCall(call_id=tool_use_id, name=tool_name, arguments=args)
797
+ )
798
+ self._pending_tools.add(tool_use_id)
799
+ logger.debug(f"Tool call emitted: {tool_name} (id={tool_use_id})")
800
+
801
+ # CRITICAL: Close generation after tool call emission
802
+ # The LiveKit framework expects the generation to close so it can call update_chat_ctx()
803
+ # with the tool results. A new generation will be created when Nova Sonic sends the next
804
+ # ASSISTANT SPECULATIVE text event with the tool response.
805
+ logger.debug("Closing generation to allow tool result delivery")
806
+ self._close_current_generation()
794
807
 
795
808
  async def _handle_tool_output_content_end_event(self, event_data: dict) -> None:
796
809
  log_event_data(event_data)
797
810
 
798
811
  async def _handle_audio_output_content_start_event(self, event_data: dict) -> None:
799
- """Associate the upcoming audio chunk with the active assistant message."""
812
+ """Track audio content start."""
800
813
  if self._current_generation is not None:
801
814
  log_event_data(event_data)
802
- audio_content_id = event_data["event"]["contentStart"]["contentId"]
803
- self._current_generation.speculative_messages[audio_content_id] = (
804
- self._current_generation.response_id
805
- )
815
+ content_id = event_data["event"]["contentStart"]["contentId"]
816
+ self._current_generation.content_id_map[content_id] = "ASSISTANT_AUDIO"
806
817
 
807
818
  async def _handle_audio_output_content_event(self, event_data: dict) -> None:
808
819
  """Decode base64 audio from Bedrock and forward it to the audio stream."""
809
- if (
810
- self._current_generation is not None
811
- and self._current_generation.speculative_messages.get(
812
- event_data["event"]["audioOutput"]["contentId"]
813
- )
814
- == self._current_generation.response_id
815
- ):
820
+ if self._current_generation is None or self._current_generation.message_gen is None:
821
+ return
822
+
823
+ content_id = event_data["event"]["audioOutput"]["contentId"]
824
+ content_type = self._current_generation.content_id_map.get(content_id)
825
+
826
+ if content_type == "ASSISTANT_AUDIO":
816
827
  audio_content = event_data["event"]["audioOutput"]["content"]
817
828
  audio_bytes = base64.b64decode(audio_content)
818
- curr_gen = self._current_generation.messages[self._current_generation.response_id]
819
- curr_gen.audio_ch.send_nowait(
829
+ self._current_generation.message_gen.audio_ch.send_nowait(
820
830
  rtc.AudioFrame(
821
831
  data=audio_bytes,
822
832
  sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
@@ -826,62 +836,89 @@ class RealtimeSession( # noqa: F811
826
836
  )
827
837
 
828
838
  async def _handle_audio_output_content_end_event(self, event_data: dict) -> None:
829
- """Close the assistant message streams once Bedrock finishes audio for the turn."""
830
- if (
831
- self._current_generation is not None
832
- and event_data["event"]["contentEnd"]["stopReason"] == "END_TURN"
833
- and self._current_generation.speculative_messages.get(
834
- event_data["event"]["contentEnd"]["contentId"]
835
- )
836
- == self._current_generation.response_id
837
- ):
838
- log_event_data(event_data)
839
- self._close_current_generation()
839
+ """Handle audio content end - log but don't close generation."""
840
+ log_event_data(event_data)
841
+ # Nova Sonic uses one completion for entire session
842
+ # Don't close generation here - wait for new completionStart or session end
840
843
 
841
844
  def _close_current_generation(self) -> None:
842
- """Helper that closes all channels of the active _ResponseGeneration."""
843
- if self._current_generation is not None:
844
- if self._current_generation.response_id in self._current_generation.messages:
845
- curr_gen = self._current_generation.messages[self._current_generation.response_id]
846
- if not curr_gen.audio_ch.closed:
847
- curr_gen.audio_ch.close()
848
- if not curr_gen.text_ch.closed:
849
- curr_gen.text_ch.close()
850
-
851
- # TODO: seems not needed, tool_messages[id] is a str, function_ch is closed below?
852
- # if self._current_generation.response_id in self._current_generation.tool_messages:
853
- # curr_gen = self._current_generation.tool_messages[
854
- # self._current_generation.response_id
855
- # ]
856
- # if not curr_gen.function_ch.closed:
857
- # curr_gen.function_ch.close()
858
-
859
- if not self._current_generation.message_ch.closed:
860
- self._current_generation.message_ch.close()
861
- if not self._current_generation.function_ch.closed:
862
- self._current_generation.function_ch.close()
863
-
864
- self._current_generation = None
845
+ """Helper that closes all channels of the active generation."""
846
+ if self._current_generation is None:
847
+ return
848
+
849
+ # Set completed timestamp
850
+ if self._current_generation._completed_timestamp is None:
851
+ self._current_generation._completed_timestamp = time.time()
852
+
853
+ # Close message channels
854
+ if self._current_generation.message_gen:
855
+ if not self._current_generation.message_gen.audio_ch.closed:
856
+ self._current_generation.message_gen.audio_ch.close()
857
+ if not self._current_generation.message_gen.text_ch.closed:
858
+ self._current_generation.message_gen.text_ch.close()
859
+
860
+ # Close generation channels
861
+ if not self._current_generation.message_ch.closed:
862
+ self._current_generation.message_ch.close()
863
+ if not self._current_generation.function_ch.closed:
864
+ self._current_generation.function_ch.close()
865
+
866
+ logger.debug(
867
+ f"Closed generation for completion_id={self._current_generation.completion_id}"
868
+ )
869
+ self._current_generation = None
865
870
 
866
871
  async def _handle_completion_end_event(self, event_data: dict) -> None:
872
+ """Handle completionEnd - close the generation for this completion cycle."""
867
873
  log_event_data(event_data)
868
874
 
875
+ # Close generation if still open
876
+ if self._current_generation:
877
+ logger.debug("completionEnd received, closing generation")
878
+ self._close_current_generation()
879
+
869
880
  async def _handle_other_event(self, event_data: dict) -> None:
870
881
  log_event_data(event_data)
871
882
 
872
883
  async def _handle_usage_event(self, event_data: dict) -> None:
873
884
  # log_event_data(event_data)
874
- # TODO: implement duration and ttft
875
885
  input_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["input"]
876
886
  output_tokens = event_data["event"]["usageEvent"]["details"]["delta"]["output"]
877
- # Q: should we be counting per turn or utterance?
887
+
888
+ # Calculate metrics from timestamps
889
+ duration = 0.0
890
+ ttft = 0.0
891
+ tokens_per_second = 0.0
892
+
893
+ if self._current_generation is not None:
894
+ created_ts = self._current_generation._created_timestamp
895
+ first_token_ts = self._current_generation._first_token_timestamp
896
+ completed_ts = self._current_generation._completed_timestamp
897
+
898
+ # Calculate TTFT (time to first token)
899
+ if first_token_ts is not None and isinstance(created_ts, (int, float)):
900
+ ttft = first_token_ts - created_ts
901
+
902
+ # Calculate duration (total time from creation to completion)
903
+ if completed_ts is not None and isinstance(created_ts, (int, float)):
904
+ duration = completed_ts - created_ts
905
+
906
+ # Calculate tokens per second
907
+ total_tokens = (
908
+ input_tokens["speechTokens"]
909
+ + input_tokens["textTokens"]
910
+ + output_tokens["speechTokens"]
911
+ + output_tokens["textTokens"]
912
+ )
913
+ if duration > 0:
914
+ tokens_per_second = total_tokens / duration
915
+
878
916
  metrics = RealtimeModelMetrics(
879
917
  label=self._realtime_model.label,
880
- # TODO: pass in the correct request_id
881
918
  request_id=event_data["event"]["usageEvent"]["completionId"],
882
919
  timestamp=time.monotonic(),
883
- duration=0,
884
- ttft=0,
920
+ duration=duration,
921
+ ttft=ttft,
885
922
  cancelled=False,
886
923
  input_tokens=input_tokens["speechTokens"] + input_tokens["textTokens"],
887
924
  output_tokens=output_tokens["speechTokens"] + output_tokens["textTokens"],
@@ -889,8 +926,7 @@ class RealtimeSession( # noqa: F811
889
926
  + input_tokens["textTokens"]
890
927
  + output_tokens["speechTokens"]
891
928
  + output_tokens["textTokens"],
892
- # need duration to calculate this
893
- tokens_per_second=0,
929
+ tokens_per_second=tokens_per_second,
894
930
  input_token_details=RealtimeModelMetrics.InputTokenDetails(
895
931
  text_tokens=input_tokens["textTokens"],
896
932
  audio_tokens=input_tokens["speechTokens"],
@@ -1016,8 +1052,13 @@ class RealtimeSession( # noqa: F811
1016
1052
  self._is_sess_active.clear()
1017
1053
 
1018
1054
  async def _restart_session(self, ex: Exception) -> None:
1019
- if self._session_restart_attempts >= DEFAULT_MAX_SESSION_RESTART_ATTEMPTS:
1020
- logger.error("Max session restart attempts reached, exiting")
1055
+ # Get restart attempts from current generation, or 0 if no generation
1056
+ restart_attempts = (
1057
+ self._current_generation._restart_attempts if self._current_generation else 0
1058
+ )
1059
+
1060
+ if restart_attempts >= DEFAULT_MAX_SESSION_RESTART_ATTEMPTS:
1061
+ logger.error("Max restart attempts reached for this turn, exiting")
1021
1062
  err_msg = getattr(ex, "message", str(ex))
1022
1063
  request_id = None
1023
1064
  try:
@@ -1041,13 +1082,20 @@ class RealtimeSession( # noqa: F811
1041
1082
  )
1042
1083
  self._is_sess_active.clear()
1043
1084
  return
1044
- self._session_restart_attempts += 1
1085
+
1086
+ # Increment restart counter for current generation
1087
+ if self._current_generation:
1088
+ self._current_generation._restart_attempts += 1
1089
+ restart_attempts = self._current_generation._restart_attempts
1090
+ else:
1091
+ restart_attempts = 1
1092
+
1045
1093
  self._is_sess_active.clear()
1046
- delay = 2 ** (self._session_restart_attempts - 1) - 1
1094
+ delay = 2 ** (restart_attempts - 1) - 1
1047
1095
  await asyncio.sleep(min(delay, DEFAULT_MAX_SESSION_RESTART_DELAY))
1048
1096
  await self.initialize_streams(is_restart=True)
1049
1097
  logger.info(
1050
- f"Session restarted successfully ({self._session_restart_attempts}/{DEFAULT_MAX_SESSION_RESTART_ATTEMPTS})" # noqa: E501
1098
+ f"Turn restarted successfully ({restart_attempts}/{DEFAULT_MAX_SESSION_RESTART_ATTEMPTS})"
1051
1099
  )
1052
1100
 
1053
1101
  @property
@@ -1061,7 +1109,10 @@ class RealtimeSession( # noqa: F811
1061
1109
  async def update_instructions(self, instructions: str) -> None:
1062
1110
  """Injects the system prompt at the start of the session."""
1063
1111
  self._instructions = instructions
1064
- self._instructions_ready.set_result(True)
1112
+ if self._instructions_ready is None:
1113
+ self._instructions_ready = asyncio.get_running_loop().create_future()
1114
+ if not self._instructions_ready.done():
1115
+ self._instructions_ready.set_result(True)
1065
1116
  logger.debug(f"Instructions updated: {instructions}")
1066
1117
 
1067
1118
  async def update_chat_ctx(self, chat_ctx: llm.ChatContext) -> None:
@@ -1069,16 +1120,26 @@ class RealtimeSession( # noqa: F811
1069
1120
  # sometimes fires randomly
1070
1121
  # add a guard here to only allow chat_ctx to be updated on
1071
1122
  # the very first session initialization
1123
+ if self._chat_ctx_ready is None:
1124
+ self._chat_ctx_ready = asyncio.get_running_loop().create_future()
1125
+
1072
1126
  if not self._chat_ctx_ready.done():
1073
1127
  self._chat_ctx = chat_ctx.copy()
1074
1128
  logger.debug(f"Chat context updated: {self._chat_ctx.items}")
1075
1129
  self._chat_ctx_ready.set_result(True)
1076
1130
 
1077
1131
  # for each function tool, send the result to aws
1132
+ logger.debug(
1133
+ f"update_chat_ctx called with {len(chat_ctx.items)} items, pending_tools: {self._pending_tools}"
1134
+ )
1078
1135
  for item in chat_ctx.items:
1079
1136
  if item.type != "function_call_output":
1080
1137
  continue
1081
1138
 
1139
+ logger.debug(
1140
+ f"Found function_call_output: call_id={item.call_id}, in_pending={item.call_id in self._pending_tools}"
1141
+ )
1142
+
1082
1143
  if item.call_id not in self._pending_tools:
1083
1144
  continue
1084
1145
 
@@ -1128,7 +1189,10 @@ class RealtimeSession( # noqa: F811
1128
1189
  retained_tools.append(tool)
1129
1190
  self._tools = llm.ToolContext(retained_tools)
1130
1191
  if retained_tools:
1131
- self._tools_ready.set_result(True)
1192
+ if self._tools_ready is None:
1193
+ self._tools_ready = asyncio.get_running_loop().create_future()
1194
+ if not self._tools_ready.done():
1195
+ self._tools_ready.set_result(True)
1132
1196
  logger.debug("Tool list has been injected")
1133
1197
 
1134
1198
  def update_options(self, *, tool_choice: NotGivenOr[llm.ToolChoice | None] = NOT_GIVEN) -> None:
@@ -1164,54 +1228,62 @@ class RealtimeSession( # noqa: F811
1164
1228
  """Background task that feeds audio and tool results into the Bedrock stream."""
1165
1229
  await self._send_raw_event(self._event_builder.create_audio_content_start_event())
1166
1230
  logger.info("Starting audio input processing loop")
1231
+
1232
+ # Create tasks for both channels so we can wait on either
1233
+ audio_task = asyncio.create_task(self._audio_input_chan.recv())
1234
+ tool_task = asyncio.create_task(self._tool_results_ch.recv())
1235
+ pending = {audio_task, tool_task}
1236
+
1167
1237
  while self._is_sess_active.is_set():
1168
1238
  try:
1169
- # note: could potentially pull this out into a separate task
1170
- try:
1171
- val = self._tool_results_ch.recv_nowait()
1172
- tool_result = val["tool_result"]
1173
- tool_use_id = val["tool_use_id"]
1174
- if not isinstance(tool_result, str):
1175
- tool_result = json.dumps(tool_result)
1176
- else:
1177
- try:
1178
- json.loads(tool_result)
1179
- except json.JSONDecodeError:
1180
- try:
1181
- tool_result = json.dumps(ast.literal_eval(tool_result))
1182
- except Exception:
1183
- # return the original value
1184
- pass
1185
-
1186
- logger.debug(f"Sending tool result: {tool_result}")
1187
- await self._send_tool_events(tool_use_id, tool_result)
1188
-
1189
- except utils.aio.channel.ChanEmpty:
1190
- pass
1191
- except utils.aio.channel.ChanClosed:
1192
- logger.warning(
1193
- "tool results channel closed, exiting audio input processing loop"
1194
- )
1195
- break
1239
+ done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
1196
1240
 
1197
- try:
1198
- audio_bytes = await self._audio_input_chan.recv()
1199
- blob = base64.b64encode(audio_bytes)
1200
- audio_event = self._event_builder.create_audio_input_event(
1201
- audio_content=blob.decode("utf-8"),
1202
- )
1203
-
1204
- await self._send_raw_event(audio_event)
1205
- except utils.aio.channel.ChanEmpty:
1206
- pass
1207
- except utils.aio.channel.ChanClosed:
1208
- logger.warning(
1209
- "audio input channel closed, exiting audio input processing loop"
1210
- )
1211
- break
1241
+ for task in done:
1242
+ if task == audio_task:
1243
+ try:
1244
+ audio_bytes = cast(bytes, task.result())
1245
+ blob = base64.b64encode(audio_bytes)
1246
+ audio_event = self._event_builder.create_audio_input_event(
1247
+ audio_content=blob.decode("utf-8"),
1248
+ )
1249
+ await self._send_raw_event(audio_event)
1250
+ # Create new task for next audio
1251
+ audio_task = asyncio.create_task(self._audio_input_chan.recv())
1252
+ pending.add(audio_task)
1253
+ except utils.aio.channel.ChanClosed:
1254
+ logger.warning("audio input channel closed")
1255
+ break
1256
+
1257
+ elif task == tool_task:
1258
+ try:
1259
+ val = cast(dict[str, str], task.result())
1260
+ tool_result = val["tool_result"]
1261
+ tool_use_id = val["tool_use_id"]
1262
+ if not isinstance(tool_result, str):
1263
+ tool_result = json.dumps(tool_result)
1264
+ else:
1265
+ try:
1266
+ json.loads(tool_result)
1267
+ except json.JSONDecodeError:
1268
+ try:
1269
+ tool_result = json.dumps(ast.literal_eval(tool_result))
1270
+ except Exception:
1271
+ pass
1272
+
1273
+ logger.debug(f"Sending tool result: {tool_result}")
1274
+ await self._send_tool_events(tool_use_id, tool_result)
1275
+ # Create new task for next tool result
1276
+ tool_task = asyncio.create_task(self._tool_results_ch.recv())
1277
+ pending.add(tool_task)
1278
+ except utils.aio.channel.ChanClosed:
1279
+ logger.warning("tool results channel closed")
1280
+ break
1212
1281
 
1213
1282
  except asyncio.CancelledError:
1214
1283
  logger.info("Audio processing loop cancelled")
1284
+ # Cancel pending tasks
1285
+ for task in pending:
1286
+ task.cancel()
1215
1287
  self._audio_input_chan.close()
1216
1288
  self._tool_results_ch.close()
1217
1289
  raise
@@ -1262,7 +1334,24 @@ class RealtimeSession( # noqa: F811
1262
1334
  logger.warning("video is not supported by Nova Sonic's Realtime API")
1263
1335
 
1264
1336
  def interrupt(self) -> None:
1265
- logger.warning("interrupt is not supported by Nova Sonic's Realtime API")
1337
+ """Nova Sonic handles interruption automatically via barge-in detection.
1338
+
1339
+ Unlike OpenAI's client-initiated interrupt, Nova Sonic automatically detects
1340
+ when the user starts speaking while the model is generating audio. When this
1341
+ happens, the model:
1342
+ 1. Immediately stops generating speech
1343
+ 2. Switches to listening mode
1344
+ 3. Sends a text event with content: { "interrupted" : true }
1345
+
1346
+ The plugin already handles this event (see _handle_text_output_content_event).
1347
+ No client action is needed - interruption works automatically.
1348
+
1349
+ See AWS docs: https://docs.aws.amazon.com/nova/latest/userguide/output-events.html
1350
+ """
1351
+ logger.info(
1352
+ "Nova Sonic handles interruption automatically via barge-in detection. "
1353
+ "The model detects when users start speaking and stops generation automatically."
1354
+ )
1266
1355
 
1267
1356
  def truncate(
1268
1357
  self,
@@ -1312,6 +1401,11 @@ class RealtimeSession( # noqa: F811
1312
1401
  if self._stream_response and not self._stream_response.input_stream.closed:
1313
1402
  await self._stream_response.input_stream.close()
1314
1403
 
1404
+ # cancel main task to prevent pending task warnings
1405
+ if self._main_atask and not self._main_atask.done():
1406
+ self._main_atask.cancel()
1407
+ tasks.append(self._main_atask)
1408
+
1315
1409
  await asyncio.gather(*tasks, return_exceptions=True)
1316
1410
  logger.debug(f"CHAT CONTEXT: {self._chat_ctx.items}")
1317
1411
  logger.info("Session end")
@@ -12,4 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- __version__ = "1.2.16"
15
+ __version__ = "1.2.17"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: livekit-plugins-aws
3
- Version: 1.2.16
3
+ Version: 1.2.17
4
4
  Summary: LiveKit Agents Plugin for services from AWS
5
5
  Project-URL: Documentation, https://docs.livekit.io
6
6
  Project-URL: Website, https://livekit.io/
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
20
  Requires-Python: >=3.9.0
21
21
  Requires-Dist: aioboto3>=14.1.0
22
22
  Requires-Dist: amazon-transcribe>=0.6.4
23
- Requires-Dist: livekit-agents>=1.2.16
23
+ Requires-Dist: livekit-agents>=1.2.17
24
24
  Provides-Extra: realtime
25
25
  Requires-Dist: aws-sdk-bedrock-runtime==0.0.2; (python_version >= '3.12') and extra == 'realtime'
26
26
  Requires-Dist: aws-sdk-signers==0.0.3; (python_version >= '3.12') and extra == 'realtime'
@@ -6,12 +6,12 @@ livekit/plugins/aws/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
6
6
  livekit/plugins/aws/stt.py,sha256=OwHfGDFW-VtSY68my3u5dyMLRpYVZPlnVbtm-dNj4Q0,9438
7
7
  livekit/plugins/aws/tts.py,sha256=oav-XWf9ysVGCmERWej6BgACu8vsLbRo9vFGpo9N6Ec,7184
8
8
  livekit/plugins/aws/utils.py,sha256=nA5Ua1f4T-25Loar6EvlrKTXI9N-zpTIH7cdQkwGyGI,1518
9
- livekit/plugins/aws/version.py,sha256=6RxW2Q7KoSNRlDtulIUp5F0_o0atksX-Xpp45NaSCaI,601
9
+ livekit/plugins/aws/version.py,sha256=ZlvvHSEyo4YT35z0OKDbZvGI4D1lVYqTvemcuSdOS8o,601
10
10
  livekit/plugins/aws/experimental/realtime/__init__.py,sha256=mm_TGZc9QAWSO-VOO3PdE8Y5R6xlWckXRZuiFUIHa-Q,287
11
11
  livekit/plugins/aws/experimental/realtime/events.py,sha256=ltdGEipE3ZOkjn7K6rKN6WSCUPJkVg-S88mUmQ_V00s,15981
12
12
  livekit/plugins/aws/experimental/realtime/pretty_printer.py,sha256=KN7KPrfQu8cU7ff34vFAtfrd1umUSTVNKXQU7D8AMiM,1442
13
- livekit/plugins/aws/experimental/realtime/realtime_model.py,sha256=0Eyz3FNOa-xYNFYjwp7NMgftZRr7DAc4S6DSQ4dVcog,61003
13
+ livekit/plugins/aws/experimental/realtime/realtime_model.py,sha256=ksdw7X-wm5wiDoCur9srYTRV2eVadHOjAIIanNS9dUo,64568
14
14
  livekit/plugins/aws/experimental/realtime/turn_tracker.py,sha256=bcufaap-coeIYuK3ct1Is9W_UoefGYRmnJu7Mn5DCYU,6002
15
- livekit_plugins_aws-1.2.16.dist-info/METADATA,sha256=d2UpArja3hG8OgvFLfTiVp36YbYWkbu-yu_k33ENyy4,2083
16
- livekit_plugins_aws-1.2.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- livekit_plugins_aws-1.2.16.dist-info/RECORD,,
15
+ livekit_plugins_aws-1.2.17.dist-info/METADATA,sha256=53WRXByqiLoDG0B9RoWMYIMWqoxxuWl1XHvMIH3rXsE,2083
16
+ livekit_plugins_aws-1.2.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
+ livekit_plugins_aws-1.2.17.dist-info/RECORD,,