livekit-plugins-aws 1.1.5__py3-none-any.whl → 1.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of livekit-plugins-aws might be problematic. Click here for more details.

@@ -1,8 +1,8 @@
1
1
  import json
2
2
  import uuid
3
- from typing import Any, Literal, Optional, Union
3
+ from typing import Any, Literal, Optional, Union, cast
4
4
 
5
- from pydantic import BaseModel, ConfigDict, Field
5
+ from pydantic import BaseModel as _BaseModel, ConfigDict, Field
6
6
 
7
7
  from livekit.agents import llm
8
8
 
@@ -20,7 +20,7 @@ SAMPLE_SIZE_BITS = Literal[16] # only supports 16-bit audio
20
20
  CHANNEL_COUNT = Literal[1] # only supports monochannel audio
21
21
 
22
22
 
23
- class BaseModel(BaseModel):
23
+ class BaseModel(_BaseModel):
24
24
  model_config = ConfigDict(populate_by_name=True, extra="forbid")
25
25
 
26
26
 
@@ -91,7 +91,7 @@ class Tool(BaseModel):
91
91
 
92
92
 
93
93
  class ToolConfiguration(BaseModel):
94
- toolChoice: dict[str, dict[str, str]] | None = None
94
+ toolChoice: Optional[dict[str, dict[str, str]]] = None
95
95
  tools: list[Tool]
96
96
 
97
97
 
@@ -260,6 +260,8 @@ class SonicEventBuilder:
260
260
  else:
261
261
  return "other_event"
262
262
 
263
+ raise ValueError(f"Unknown event type: {json_data}")
264
+
263
265
  def create_text_content_block(
264
266
  self,
265
267
  content_name: str,
@@ -313,10 +315,18 @@ class SonicEventBuilder:
313
315
  if chat_ctx.items:
314
316
  logger.debug("initiating session with chat context")
315
317
  for item in chat_ctx.items:
318
+ if item.type != "message":
319
+ continue
320
+
321
+ if (role := item.role.upper()) not in ["USER", "ASSISTANT", "SYSTEM"]:
322
+ continue
323
+
316
324
  ctx_content_name = str(uuid.uuid4())
317
325
  init_events.extend(
318
326
  self.create_text_content_block(
319
- ctx_content_name, item.role.upper(), "".join(item.content)
327
+ ctx_content_name,
328
+ cast(ROLE, role),
329
+ "".join(c for c in item.content if isinstance(c, str)),
320
330
  )
321
331
  )
322
332
 
@@ -481,26 +491,15 @@ class SonicEventBuilder:
481
491
  sample_rate: SAMPLE_RATE_HERTZ,
482
492
  tool_configuration: Optional[Union[ToolConfiguration, dict[str, Any], str]] = None,
483
493
  ) -> str:
484
- tool_configuration = tool_configuration or ToolConfiguration(tools=[])
485
- for tool in tool_configuration.tools:
486
- logger.debug(f"TOOL JSON SCHEMA: {tool.toolSpec.inputSchema}")
487
- tool_objects = [
488
- Tool(
489
- toolSpec=ToolSpec(
490
- name=tool.toolSpec.name,
491
- description=tool.toolSpec.description,
492
- inputSchema=ToolInputSchema(json_=tool.toolSpec.inputSchema.json_),
493
- )
494
- )
495
- for tool in tool_configuration.tools
496
- ]
497
-
498
494
  if tool_configuration is None:
499
495
  tool_configuration = ToolConfiguration(tools=[])
500
496
  elif isinstance(tool_configuration, str):
501
- tool_configuration = ToolConfiguration(**json.loads(tool_configuration))
497
+ tool_configuration = ToolConfiguration.model_validate_json(tool_configuration)
502
498
  elif isinstance(tool_configuration, dict):
503
- tool_configuration = ToolConfiguration(**tool_configuration)
499
+ tool_configuration = ToolConfiguration.model_validate(tool_configuration)
500
+
501
+ for tool in tool_configuration.tools:
502
+ logger.debug(f"TOOL JSON SCHEMA: {tool.toolSpec.inputSchema}")
504
503
 
505
504
  tool_objects = list(tool_configuration.tools)
506
505
  event = Event(
@@ -1,5 +1,8 @@
1
+ # mypy: disable-error-code=unused-ignore
2
+
1
3
  from __future__ import annotations
2
4
 
5
+ import ast
3
6
  import asyncio
4
7
  import base64
5
8
  import json
@@ -10,7 +13,7 @@ import weakref
10
13
  from collections.abc import Iterator
11
14
  from dataclasses import dataclass, field
12
15
  from datetime import datetime
13
- from typing import Any, Literal
16
+ from typing import Any, Callable, Literal, cast
14
17
 
15
18
  import boto3
16
19
  from aws_sdk_bedrock_runtime.client import (
@@ -33,11 +36,9 @@ from smithy_core.aio.interfaces.identity import IdentityResolver
33
36
  from livekit import rtc
34
37
  from livekit.agents import (
35
38
  APIStatusError,
36
- ToolError,
37
39
  llm,
38
40
  utils,
39
41
  )
40
- from livekit.agents.llm.realtime import RealtimeSession
41
42
  from livekit.agents.metrics import RealtimeModelMetrics
42
43
  from livekit.agents.types import NOT_GIVEN, NotGivenOr
43
44
  from livekit.agents.utils import is_given
@@ -150,12 +151,12 @@ class _ResponseGeneration:
150
151
  speculative_messages: dict[str, str] = field(default_factory=dict)
151
152
  tool_messages: dict[str, str] = field(default_factory=dict)
152
153
  output_text: str = "" # agent ASR text
153
- _created_timestamp: str = field(default_factory=datetime.now().isoformat())
154
+ _created_timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
154
155
  _first_token_timestamp: float | None = None
155
156
  _completed_timestamp: float | None = None
156
157
 
157
158
 
158
- class Boto3CredentialsResolver(IdentityResolver):
159
+ class Boto3CredentialsResolver(IdentityResolver): # type: ignore[misc]
159
160
  """IdentityResolver implementation that sources AWS credentials from boto3.
160
161
 
161
162
  The resolver delegates to the default boto3.Session() credential chain which
@@ -164,10 +165,10 @@ class Boto3CredentialsResolver(IdentityResolver):
164
165
  passed into Bedrock runtime clients.
165
166
  """
166
167
 
167
- def __init__(self):
168
- self.session = boto3.Session()
168
+ def __init__(self) -> None:
169
+ self.session = boto3.Session() # type: ignore[attr-defined]
169
170
 
170
- async def get_identity(self, **kwargs):
171
+ async def get_identity(self, **kwargs: Any) -> AWSCredentialsIdentity:
171
172
  """Asynchronously resolve AWS credentials.
172
173
 
173
174
  This method is invoked by the Bedrock runtime client whenever a new request needs to be
@@ -247,7 +248,7 @@ class RealtimeModel(llm.RealtimeModel):
247
248
  self.temperature = temperature
248
249
  self.top_p = top_p
249
250
  self._opts = _RealtimeOptions(
250
- voice=voice if is_given(voice) else "tiffany",
251
+ voice=cast(VOICE_ID, voice) if is_given(voice) else "tiffany",
251
252
  temperature=temperature if is_given(temperature) else DEFAULT_TEMPERATURE,
252
253
  top_p=top_p if is_given(top_p) else DEFAULT_TOP_P,
253
254
  max_tokens=max_tokens if is_given(max_tokens) else DEFAULT_MAX_TOKENS,
@@ -295,7 +296,7 @@ class RealtimeSession( # noqa: F811
295
296
  inference options and the Smithy Bedrock client configuration.
296
297
  """
297
298
  super().__init__(realtime_model)
298
- self._realtime_model = realtime_model
299
+ self._realtime_model: RealtimeModel = realtime_model
299
300
  self._event_builder = seb(
300
301
  prompt_name=str(uuid.uuid4()),
301
302
  audio_content_name=str(uuid.uuid4()),
@@ -309,10 +310,10 @@ class RealtimeSession( # noqa: F811
309
310
  self._audio_input_task = None
310
311
  self._stream_response = None
311
312
  self._bedrock_client = None
313
+ self._pending_tools: set[str] = set()
312
314
  self._is_sess_active = asyncio.Event()
313
315
  self._chat_ctx = llm.ChatContext.empty()
314
316
  self._tools = llm.ToolContext.empty()
315
- self._tool_type_map = {}
316
317
  self._tool_results_ch = utils.aio.Chan[dict[str, str]]()
317
318
  self._tools_ready = asyncio.get_running_loop().create_future()
318
319
  self._instructions_ready = asyncio.get_running_loop().create_future()
@@ -341,16 +342,12 @@ class RealtimeSession( # noqa: F811
341
342
  "other_event": self._handle_other_event,
342
343
  }
343
344
  self._turn_tracker = _TurnTracker(
344
- self.emit, streams_provider=self._current_generation_streams
345
+ cast(Callable[[str, Any], None], self.emit),
346
+ cast(Callable[[], None], self.emit_generation_event),
345
347
  )
346
348
 
347
- def _current_generation_streams(
348
- self,
349
- ) -> tuple[utils.aio.Chan[llm.MessageGeneration], utils.aio.Chan[llm.FunctionCall]]:
350
- return (self._current_generation.message_ch, self._current_generation.function_ch)
351
-
352
349
  @utils.log_exceptions(logger=logger)
353
- def _initialize_client(self):
350
+ def _initialize_client(self) -> None:
354
351
  """Instantiate the Bedrock runtime client"""
355
352
  config = Config(
356
353
  endpoint_uri=f"https://bedrock-runtime.{self._realtime_model._opts.region}.amazonaws.com",
@@ -362,11 +359,11 @@ class RealtimeSession( # noqa: F811
362
359
  self._bedrock_client = BedrockRuntimeClient(config=config)
363
360
 
364
361
  @utils.log_exceptions(logger=logger)
365
- async def _send_raw_event(self, event_json):
362
+ async def _send_raw_event(self, event_json: str) -> None:
366
363
  """Low-level helper that serialises event_json and forwards it to the bidirectional stream.
367
364
 
368
365
  Args:
369
- event_json (dict | str): The JSON payload (already in Bedrock wire format) to queue.
366
+ event_json (str): The JSON payload (already in Bedrock wire format) to queue.
370
367
 
371
368
  Raises:
372
369
  Exception: Propagates any failures returned by the Bedrock runtime client.
@@ -425,21 +422,21 @@ class RealtimeSession( # noqa: F811
425
422
  input_schema = llm.utils.build_legacy_openai_schema(f, internally_tagged=True)[
426
423
  "parameters"
427
424
  ]
428
- self._tool_type_map[name] = "FunctionTool"
429
- else:
425
+ elif llm.tool_context.is_raw_function_tool(f):
430
426
  description = llm.tool_context.get_raw_function_info(f).raw_schema.get(
431
427
  "description"
432
428
  )
433
429
  input_schema = llm.tool_context.get_raw_function_info(f).raw_schema[
434
430
  "parameters"
435
431
  ]
436
- self._tool_type_map[name] = "RawFunctionTool"
432
+ else:
433
+ continue
437
434
 
438
435
  tool = Tool(
439
436
  toolSpec=ToolSpec(
440
437
  name=name,
441
- description=description,
442
- inputSchema=ToolInputSchema(json_=json.dumps(input_schema)),
438
+ description=description or "No description provided",
439
+ inputSchema=ToolInputSchema(json_=json.dumps(input_schema)), # type: ignore
443
440
  )
444
441
  )
445
442
  tools.append(tool)
@@ -455,7 +452,7 @@ class RealtimeSession( # noqa: F811
455
452
  return tool_cfg
456
453
 
457
454
  @utils.log_exceptions(logger=logger)
458
- async def initialize_streams(self, is_restart: bool = False):
455
+ async def initialize_streams(self, is_restart: bool = False) -> None:
459
456
  """Open the Bedrock bidirectional stream and spawn background worker tasks.
460
457
 
461
458
  This coroutine is idempotent and can be invoked again when recoverable
@@ -469,6 +466,7 @@ class RealtimeSession( # noqa: F811
469
466
  if not self._bedrock_client:
470
467
  logger.info("Creating Bedrock client")
471
468
  self._initialize_client()
469
+ assert self._bedrock_client is not None, "bedrock_client is None"
472
470
 
473
471
  logger.info("Initializing Bedrock stream")
474
472
  self._stream_response = (
@@ -542,15 +540,16 @@ class RealtimeSession( # noqa: F811
542
540
  self._is_sess_active.set()
543
541
  logger.debug("Stream initialized successfully")
544
542
  except Exception as e:
545
- self._is_sess_active.set_exception(e)
546
543
  logger.debug(f"Failed to initialize stream: {str(e)}")
547
544
  raise
548
545
  return self
549
546
 
550
547
  @utils.log_exceptions(logger=logger)
551
- def _emit_generation_event(self) -> None:
548
+ def emit_generation_event(self) -> None:
552
549
  """Publish a llm.GenerationCreatedEvent to external subscribers."""
553
550
  logger.debug("Emitting generation event")
551
+ assert self._current_generation is not None, "current_generation is None"
552
+
554
553
  generation_ev = llm.GenerationCreatedEvent(
555
554
  message_stream=self._current_generation.message_ch,
556
555
  function_stream=self._current_generation.function_ch,
@@ -605,10 +604,12 @@ class RealtimeSession( # noqa: F811
605
604
  """Handle text_output_content_start for both user and assistant roles."""
606
605
  log_event_data(event_data)
607
606
  role = event_data["event"]["contentStart"]["role"]
607
+ self._create_response_generation()
608
608
 
609
609
  # note: does not work if you emit llm.GCE too early (for some reason)
610
610
  if role == "USER":
611
- self._create_response_generation()
611
+ assert self._current_generation is not None, "current_generation is None"
612
+
612
613
  content_id = event_data["event"]["contentStart"]["contentId"]
613
614
  self._current_generation.user_messages[content_id] = self._current_generation.input_id
614
615
 
@@ -616,6 +617,8 @@ class RealtimeSession( # noqa: F811
616
617
  role == "ASSISTANT"
617
618
  and "SPECULATIVE" in event_data["event"]["contentStart"]["additionalModelFields"]
618
619
  ):
620
+ assert self._current_generation is not None, "current_generation is None"
621
+
619
622
  text_content_id = event_data["event"]["contentStart"]["contentId"]
620
623
  self._current_generation.speculative_messages[text_content_id] = (
621
624
  self._current_generation.response_id
@@ -633,10 +636,15 @@ class RealtimeSession( # noqa: F811
633
636
  # this is b/c audio playback is desynced from text transcription
634
637
  # TODO: fix this; possibly via a playback timer
635
638
  idx = self._chat_ctx.find_insertion_index(created_at=time.time()) - 1
639
+ if idx < 0:
640
+ logger.warning("Barge-in DETECTED but no previous message found")
641
+ return
642
+
636
643
  logger.debug(
637
644
  f"BARGE-IN DETECTED using idx: {idx} and chat_msg: {self._chat_ctx.items[idx]}"
638
645
  )
639
- self._chat_ctx.items[idx].interrupted = True
646
+ if (item := self._chat_ctx.items[idx]).type == "message":
647
+ item.interrupted = True
640
648
  self._close_current_generation()
641
649
  return
642
650
 
@@ -661,27 +669,31 @@ class RealtimeSession( # noqa: F811
661
669
  # note: this update is per utterance, not per turn
662
670
  self._update_chat_ctx(role="assistant", text_content=text_content)
663
671
 
664
- def _update_chat_ctx(self, role: str, text_content: str) -> None:
672
+ def _update_chat_ctx(self, role: llm.ChatRole, text_content: str) -> None:
665
673
  """
666
674
  Update the chat context with the latest ASR text while guarding against model limitations:
667
675
  a) 40 total messages limit
668
676
  b) 1kB message size limit
669
677
  """
670
- prev_utterance = self._chat_ctx.items[-1]
671
- if prev_utterance.role == role:
672
- if (
673
- len(prev_utterance.content[0].encode("utf-8")) + len(text_content.encode("utf-8"))
674
- < MAX_MESSAGE_SIZE
675
- ):
676
- prev_utterance.content[0] = "\n".join([prev_utterance.content[0], text_content])
678
+ logger.debug(f"Updating chat context with role: {role} and text_content: {text_content}")
679
+ if len(self._chat_ctx.items) == 0:
680
+ self._chat_ctx.add_message(role=role, content=text_content)
681
+ else:
682
+ prev_utterance = self._chat_ctx.items[-1]
683
+ if prev_utterance.type == "message" and prev_utterance.role == role:
684
+ if isinstance(prev_content := prev_utterance.content[0], str) and (
685
+ len(prev_content.encode("utf-8")) + len(text_content.encode("utf-8"))
686
+ < MAX_MESSAGE_SIZE
687
+ ):
688
+ prev_utterance.content[0] = "\n".join([prev_content, text_content])
689
+ else:
690
+ self._chat_ctx.add_message(role=role, content=text_content)
691
+ if len(self._chat_ctx.items) > MAX_MESSAGES:
692
+ self._chat_ctx.truncate(max_items=MAX_MESSAGES)
677
693
  else:
678
694
  self._chat_ctx.add_message(role=role, content=text_content)
679
695
  if len(self._chat_ctx.items) > MAX_MESSAGES:
680
696
  self._chat_ctx.truncate(max_items=MAX_MESSAGES)
681
- else:
682
- self._chat_ctx.add_message(role=role, content=text_content)
683
- if len(self._chat_ctx.items) > MAX_MESSAGES:
684
- self._chat_ctx.truncate(max_items=MAX_MESSAGES)
685
697
 
686
698
  # cannot rely on this event for user b/c stopReason=PARTIAL_TURN always for user
687
699
  async def _handle_text_output_content_end_event(self, event_data: dict) -> None:
@@ -701,6 +713,8 @@ class RealtimeSession( # noqa: F811
701
713
  async def _handle_tool_output_content_start_event(self, event_data: dict) -> None:
702
714
  """Track mapping content_id -> response_id for upcoming tool use."""
703
715
  log_event_data(event_data)
716
+ assert self._current_generation is not None, "current_generation is None"
717
+
704
718
  tool_use_content_id = event_data["event"]["contentStart"]["contentId"]
705
719
  self._current_generation.tool_messages[tool_use_content_id] = (
706
720
  self._current_generation.response_id
@@ -710,6 +724,8 @@ class RealtimeSession( # noqa: F811
710
724
  async def _handle_tool_output_content_event(self, event_data: dict) -> None:
711
725
  """Execute the referenced tool locally and forward results back to Bedrock."""
712
726
  log_event_data(event_data)
727
+ assert self._current_generation is not None, "current_generation is None"
728
+
713
729
  tool_use_content_id = event_data["event"]["toolUse"]["contentId"]
714
730
  tool_use_id = event_data["event"]["toolUse"]["toolUseId"]
715
731
  tool_name = event_data["event"]["toolUse"]["toolName"]
@@ -719,34 +735,38 @@ class RealtimeSession( # noqa: F811
719
735
  ):
720
736
  args = event_data["event"]["toolUse"]["content"]
721
737
  self._current_generation.function_ch.send_nowait(
722
- llm.FunctionCall(
723
- call_id=tool_use_id,
724
- name=tool_name,
725
- arguments=args,
726
- )
738
+ llm.FunctionCall(call_id=tool_use_id, name=tool_name, arguments=args)
727
739
  )
728
-
729
- # note: may need to inject RunContext here...
730
- tool_type = self._tool_type_map[tool_name]
731
- if tool_type == "FunctionTool":
732
- tool_result = await self.tools.function_tools[tool_name](**json.loads(args))
733
- elif tool_type == "RawFunctionTool":
734
- tool_result = await self.tools.function_tools[tool_name](json.loads(args))
735
- else:
736
- raise ValueError(f"Unknown tool type: {tool_type}")
737
- logger.debug(f"TOOL ARGS: {args}\nTOOL RESULT: {tool_result}")
738
-
739
- # Sonic only accepts Structured Output for tool results
740
- # therefore, must JSON stringify ToolError
741
- if isinstance(tool_result, ToolError):
742
- logger.warning(f"TOOL ERROR: {tool_name} {tool_result.message}")
743
- tool_result = {"error": tool_result.message}
744
- self._tool_results_ch.send_nowait(
745
- {
746
- "tool_use_id": tool_use_id,
747
- "tool_result": tool_result,
748
- }
740
+ self._pending_tools.add(tool_use_id)
741
+
742
+ # performing these acrobatics in order to release the deadlock
743
+ # LiveKit will not accept a new generation until the previous one is closed
744
+ # the issue is that audio data cannot be generated until toolResult is received
745
+ # however, toolResults only arrive after update_chat_ctx() is invoked
746
+ # which will only occur after agent speech has completed
747
+ # therefore we introduce an artificial turn to trigger update_chat_ctx()
748
+ # TODO: this is messy-- investigate if there is a better way to handle this
749
+ curr_gen = self._current_generation.messages[self._current_generation.response_id]
750
+ curr_gen.audio_ch.close()
751
+ curr_gen.text_ch.close()
752
+ self._current_generation.message_ch.close()
753
+ self._current_generation.message_ch = utils.aio.Chan()
754
+ self._current_generation.function_ch.close()
755
+ self._current_generation.function_ch = utils.aio.Chan()
756
+ msg_gen = _MessageGeneration(
757
+ message_id=self._current_generation.response_id,
758
+ text_ch=utils.aio.Chan(),
759
+ audio_ch=utils.aio.Chan(),
760
+ )
761
+ self._current_generation.messages[self._current_generation.response_id] = msg_gen
762
+ self._current_generation.message_ch.send_nowait(
763
+ llm.MessageGeneration(
764
+ message_id=msg_gen.message_id,
765
+ text_stream=msg_gen.text_ch,
766
+ audio_stream=msg_gen.audio_ch,
767
+ )
749
768
  )
769
+ self.emit_generation_event()
750
770
 
751
771
  async def _handle_tool_output_content_end_event(self, event_data: dict) -> None:
752
772
  log_event_data(event_data)
@@ -803,6 +823,12 @@ class RealtimeSession( # noqa: F811
803
823
  curr_gen.audio_ch.close()
804
824
  if not curr_gen.text_ch.closed:
805
825
  curr_gen.text_ch.close()
826
+ if self._current_generation.response_id in self._current_generation.tool_messages:
827
+ curr_gen = self._current_generation.tool_messages[
828
+ self._current_generation.response_id
829
+ ]
830
+ if not curr_gen.function_ch.closed:
831
+ curr_gen.function_ch.close()
806
832
 
807
833
  if not self._current_generation.message_ch.closed:
808
834
  self._current_generation.message_ch.close()
@@ -855,10 +881,11 @@ class RealtimeSession( # noqa: F811
855
881
  self.emit("metrics_collected", metrics)
856
882
 
857
883
  @utils.log_exceptions(logger=logger)
858
- async def _process_responses(self):
884
+ async def _process_responses(self) -> None:
859
885
  """Background task that drains Bedrock's output stream and feeds the event handlers."""
860
886
  try:
861
887
  await self._is_sess_active.wait()
888
+ assert self._stream_response is not None, "stream_response is None"
862
889
 
863
890
  # note: may need another signal here to block input task until bedrock is ready
864
891
  # TODO: save this as a field so we're not re-awaiting it every time
@@ -892,7 +919,6 @@ class RealtimeSession( # noqa: F811
892
919
 
893
920
  else:
894
921
  logger.error(f"Validation error: {ve}")
895
- request_id = ve.split(" ")[0].split("=")[1]
896
922
  self.emit(
897
923
  "error",
898
924
  llm.RealtimeModelError(
@@ -901,7 +927,7 @@ class RealtimeSession( # noqa: F811
901
927
  error=APIStatusError(
902
928
  message=ve.message,
903
929
  status_code=400,
904
- request_id=request_id,
930
+ request_id="",
905
931
  body=ve,
906
932
  retryable=False,
907
933
  ),
@@ -940,7 +966,7 @@ class RealtimeSession( # noqa: F811
940
966
  timestamp=time.monotonic(),
941
967
  label=self._realtime_model._label,
942
968
  error=APIStatusError(
943
- message=e.message,
969
+ message=err_msg,
944
970
  status_code=500,
945
971
  request_id=request_id,
946
972
  body=e,
@@ -1014,6 +1040,25 @@ class RealtimeSession( # noqa: F811
1014
1040
  logger.debug(f"Chat context updated: {self._chat_ctx.items}")
1015
1041
  self._chat_ctx_ready.set_result(True)
1016
1042
 
1043
+ # for each function tool, send the result to aws
1044
+ for item in chat_ctx.items:
1045
+ if item.type != "function_call_output":
1046
+ continue
1047
+
1048
+ if item.call_id not in self._pending_tools:
1049
+ continue
1050
+
1051
+ logger.debug(f"function call output: {item}")
1052
+ self._pending_tools.discard(item.call_id)
1053
+ self._tool_results_ch.send_nowait(
1054
+ {
1055
+ "tool_use_id": item.call_id,
1056
+ "tool_result": item.output
1057
+ if not item.is_error
1058
+ else f"{{'error': '{item.error}'}}",
1059
+ }
1060
+ )
1061
+
1017
1062
  async def _send_tool_events(self, tool_use_id: str, tool_result: str) -> None:
1018
1063
  """Send tool_result back to Bedrock, grouped under tool_use_id."""
1019
1064
  tool_content_name = str(uuid.uuid4())
@@ -1026,7 +1071,9 @@ class RealtimeSession( # noqa: F811
1026
1071
  await self._send_raw_event(event)
1027
1072
  # logger.debug(f"Sent tool event: {event}")
1028
1073
 
1029
- def _tool_choice_adapter(self, tool_choice: llm.ToolChoice) -> dict[str, dict[str, str]] | None:
1074
+ def _tool_choice_adapter(
1075
+ self, tool_choice: llm.ToolChoice | None
1076
+ ) -> dict[str, dict[str, str]] | None:
1030
1077
  """Translate the LiveKit ToolChoice enum into Sonic's JSON schema."""
1031
1078
  if tool_choice == "auto":
1032
1079
  return {"auto": {}}
@@ -1079,7 +1126,7 @@ class RealtimeSession( # noqa: F811
1079
1126
  yield frame
1080
1127
 
1081
1128
  @utils.log_exceptions(logger=logger)
1082
- async def _process_audio_input(self):
1129
+ async def _process_audio_input(self) -> None:
1083
1130
  """Background task that feeds audio and tool results into the Bedrock stream."""
1084
1131
  await self._send_raw_event(self._event_builder.create_audio_content_start_event())
1085
1132
  logger.info("Starting audio input processing loop")
@@ -1090,6 +1137,19 @@ class RealtimeSession( # noqa: F811
1090
1137
  val = self._tool_results_ch.recv_nowait()
1091
1138
  tool_result = val["tool_result"]
1092
1139
  tool_use_id = val["tool_use_id"]
1140
+ if not isinstance(tool_result, str):
1141
+ tool_result = json.dumps(tool_result)
1142
+ else:
1143
+ try:
1144
+ json.loads(tool_result)
1145
+ except json.JSONDecodeError:
1146
+ try:
1147
+ tool_result = json.dumps(ast.literal_eval(tool_result))
1148
+ except Exception:
1149
+ # return the original value
1150
+ pass
1151
+
1152
+ logger.debug(f"Sending tool result: {tool_result}")
1093
1153
  await self._send_tool_events(tool_use_id, tool_result)
1094
1154
 
1095
1155
  except utils.aio.channel.ChanEmpty:
@@ -1152,6 +1212,11 @@ class RealtimeSession( # noqa: F811
1152
1212
  instructions: NotGivenOr[str] = NOT_GIVEN,
1153
1213
  ) -> asyncio.Future[llm.GenerationCreatedEvent]:
1154
1214
  logger.warning("unprompted generation is not supported by Nova Sonic's Realtime API")
1215
+ fut = asyncio.Future[llm.GenerationCreatedEvent]()
1216
+ fut.set_exception(
1217
+ llm.RealtimeError("unprompted generation is not supported by Nova Sonic's Realtime API")
1218
+ )
1219
+ return fut
1155
1220
 
1156
1221
  def commit_audio(self) -> None:
1157
1222
  logger.warning("commit_audio is not supported by Nova Sonic's Realtime API")
@@ -1190,19 +1255,22 @@ class RealtimeSession( # noqa: F811
1190
1255
  # resulting in an error after cancellation
1191
1256
  # however, it's mostly cosmetic-- the event loop will still exit
1192
1257
  # TODO: fix this nit
1258
+ tasks: list[asyncio.Task[Any]] = []
1193
1259
  if self._response_task:
1194
1260
  try:
1195
1261
  await asyncio.wait_for(self._response_task, timeout=1.0)
1196
1262
  except asyncio.TimeoutError:
1197
1263
  logger.warning("shutdown of output event loop timed out-- cancelling")
1198
1264
  self._response_task.cancel()
1265
+ tasks.append(self._response_task)
1199
1266
 
1200
1267
  # must cancel the audio input task before closing the input stream
1201
1268
  if self._audio_input_task and not self._audio_input_task.done():
1202
1269
  self._audio_input_task.cancel()
1270
+ tasks.append(self._audio_input_task)
1203
1271
  if self._stream_response and not self._stream_response.input_stream.closed:
1204
1272
  await self._stream_response.input_stream.close()
1205
1273
 
1206
- await asyncio.gather(self._response_task, self._audio_input_task, return_exceptions=True)
1274
+ await asyncio.gather(*tasks, return_exceptions=True)
1207
1275
  logger.debug(f"CHAT CONTEXT: {self._chat_ctx.items}")
1208
1276
  logger.info("Session end")
@@ -6,7 +6,7 @@ import uuid
6
6
  from dataclasses import dataclass, field
7
7
  from typing import Any, Callable
8
8
 
9
- from livekit.agents import llm, utils
9
+ from livekit.agents import llm
10
10
 
11
11
  from ...log import logger
12
12
 
@@ -34,7 +34,7 @@ class _Turn:
34
34
  ev_trans_completed: bool = False
35
35
  ev_generation_sent: bool = False
36
36
 
37
- def add_partial_text(self, text: str):
37
+ def add_partial_text(self, text: str) -> None:
38
38
  self.transcript.append(text)
39
39
 
40
40
  @property
@@ -46,19 +46,17 @@ class _TurnTracker:
46
46
  def __init__(
47
47
  self,
48
48
  emit_fn: Callable[[str, Any], None],
49
- streams_provider: Callable[
50
- [], tuple[utils.aio.Chan[llm.MessageGeneration], utils.aio.Chan[llm.FunctionCall]]
51
- ],
49
+ emit_generation_fn: Callable[[], None],
52
50
  ):
53
51
  self._emit = emit_fn
54
52
  self._turn_idx = 0
55
53
  self._curr_turn: _Turn | None = None
56
- self._get_streams = streams_provider
54
+ self._emit_generation_fn = emit_generation_fn
57
55
 
58
56
  # --------------------------------------------------------
59
57
  # PUBLIC ENTRY POINT
60
58
  # --------------------------------------------------------
61
- def feed(self, event: dict):
59
+ def feed(self, event: dict) -> None:
62
60
  turn = self._ensure_turn()
63
61
  kind = _classify(event)
64
62
 
@@ -97,13 +95,13 @@ class _TurnTracker:
97
95
  self._curr_turn = _Turn(turn_id=self._turn_idx)
98
96
  return self._curr_turn
99
97
 
100
- def _maybe_emit_input_started(self, turn: _Turn):
98
+ def _maybe_emit_input_started(self, turn: _Turn) -> None:
101
99
  if not turn.ev_input_started:
102
100
  turn.ev_input_started = True
103
101
  self._emit("input_speech_started", llm.InputSpeechStartedEvent())
104
102
  turn.phase = _Phase.USER_SPEAKING
105
103
 
106
- def _maybe_emit_input_stopped(self, turn: _Turn):
104
+ def _maybe_emit_input_stopped(self, turn: _Turn) -> None:
107
105
  if not turn.ev_input_stopped:
108
106
  turn.ev_input_stopped = True
109
107
  self._emit(
@@ -111,7 +109,7 @@ class _TurnTracker:
111
109
  )
112
110
  turn.phase = _Phase.USER_FINISHED
113
111
 
114
- def _emit_transcript_updated(self, turn: _Turn):
112
+ def _emit_transcript_updated(self, turn: _Turn) -> None:
115
113
  self._emit(
116
114
  "input_audio_transcription_completed",
117
115
  llm.InputTranscriptionCompleted(
@@ -121,7 +119,7 @@ class _TurnTracker:
121
119
  ),
122
120
  )
123
121
 
124
- def _maybe_emit_transcript_completed(self, turn: _Turn):
122
+ def _maybe_emit_transcript_completed(self, turn: _Turn) -> None:
125
123
  if not turn.ev_trans_completed:
126
124
  turn.ev_trans_completed = True
127
125
  self._emit(
@@ -134,17 +132,10 @@ class _TurnTracker:
134
132
  ),
135
133
  )
136
134
 
137
- def _maybe_emit_generation_created(self, turn: _Turn):
135
+ def _maybe_emit_generation_created(self, turn: _Turn) -> None:
138
136
  if not turn.ev_generation_sent:
139
137
  turn.ev_generation_sent = True
140
- msg_stream, fn_stream = self._get_streams()
141
- logger.debug("Emitting generation event")
142
- generation_ev = llm.GenerationCreatedEvent(
143
- message_stream=msg_stream,
144
- function_stream=fn_stream,
145
- user_initiated=False,
146
- )
147
- self._emit("generation_created", generation_ev)
138
+ self._emit_generation_fn()
148
139
  turn.phase = _Phase.ASSISTANT_RESPONDING
149
140
 
150
141
 
@@ -12,4 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- __version__ = "1.1.5"
15
+ __version__ = "1.1.6"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: livekit-plugins-aws
3
- Version: 1.1.5
3
+ Version: 1.1.6
4
4
  Summary: LiveKit Agents Plugin for services from AWS
5
5
  Project-URL: Documentation, https://docs.livekit.io
6
6
  Project-URL: Website, https://livekit.io/
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
20
  Requires-Python: >=3.9.0
21
21
  Requires-Dist: aioboto3>=14.1.0
22
22
  Requires-Dist: amazon-transcribe>=0.6.2
23
- Requires-Dist: livekit-agents>=1.1.5
23
+ Requires-Dist: livekit-agents>=1.1.6
24
24
  Provides-Extra: realtime
25
25
  Requires-Dist: aws-sdk-bedrock-runtime==0.0.2; (python_version >= '3.12') and extra == 'realtime'
26
26
  Requires-Dist: boto3>1.35.10; extra == 'realtime'
@@ -44,3 +44,7 @@ pip install livekit-plugins-aws[realtime]
44
44
  ## Pre-requisites
45
45
 
46
46
  You'll need to specify an AWS Access Key and a Deployment Region. They can be set as environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_DEFAULT_REGION`, respectively.
47
+
48
+ ## Example
49
+
50
+ For an example of the realtime STS model, Nova Sonic, see: https://github.com/livekit/agents/blob/main/examples/voice_agents/realtime_joke_teller.py
@@ -6,12 +6,12 @@ livekit/plugins/aws/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
6
6
  livekit/plugins/aws/stt.py,sha256=PSR89aN28wm4i83yEdhkDJ9xzM0CsNIKrc3v3EbPndQ,9018
7
7
  livekit/plugins/aws/tts.py,sha256=T5dVpTuIuzQimYNnkfXi5dRLmRldWySL4IcbkXjmJLM,6083
8
8
  livekit/plugins/aws/utils.py,sha256=nA5Ua1f4T-25Loar6EvlrKTXI9N-zpTIH7cdQkwGyGI,1518
9
- livekit/plugins/aws/version.py,sha256=OKtayGMVDYKyoKBO2yNM4kfRbH-PODJqECIiYhUzNWg,600
9
+ livekit/plugins/aws/version.py,sha256=-bNd31cMcYCdhZCIKJ1-jtY4NgZvppVgKyzXAIzQtqM,600
10
10
  livekit/plugins/aws/experimental/realtime/__init__.py,sha256=mm_TGZc9QAWSO-VOO3PdE8Y5R6xlWckXRZuiFUIHa-Q,287
11
- livekit/plugins/aws/experimental/realtime/events.py,sha256=ViWr4_RLY0VDGTF-dDL0b_-7GFlF08Lw5_x6q3EJ5eM,15917
11
+ livekit/plugins/aws/experimental/realtime/events.py,sha256=-pJrwVrH5AZFxa1eDbX5nDdnJMz4BNucNZlYUYLsP-Y,15853
12
12
  livekit/plugins/aws/experimental/realtime/pretty_printer.py,sha256=KN7KPrfQu8cU7ff34vFAtfrd1umUSTVNKXQU7D8AMiM,1442
13
- livekit/plugins/aws/experimental/realtime/realtime_model.py,sha256=1FyGB7VkUHWHsgkzNEELc0-qOz3tbDEuT2PWlqI-2GU,55978
14
- livekit/plugins/aws/experimental/realtime/turn_tracker.py,sha256=ER1Inu9D7X4EZ_wqpKeidrx52JXfnmnQHmxOielbjvc,6363
15
- livekit_plugins_aws-1.1.5.dist-info/METADATA,sha256=EU-x14QER4ma3vx2b9vALcZMYYc1fv1f9fB5lX01E7Y,1827
16
- livekit_plugins_aws-1.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- livekit_plugins_aws-1.1.5.dist-info/RECORD,,
13
+ livekit/plugins/aws/experimental/realtime/realtime_model.py,sha256=em_3Fbp1qefF7cIIHc6ib1FLdD1MOGes2Lwq61o2wlk,59464
14
+ livekit/plugins/aws/experimental/realtime/turn_tracker.py,sha256=bcufaap-coeIYuK3ct1Is9W_UoefGYRmnJu7Mn5DCYU,6002
15
+ livekit_plugins_aws-1.1.6.dist-info/METADATA,sha256=ST8uYsoqQgHUVRCLC3BdkYdwALh3joYGRjblVKQgDrE,1989
16
+ livekit_plugins_aws-1.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
+ livekit_plugins_aws-1.1.6.dist-info/RECORD,,