livekit-plugins-aws 1.1.5__tar.gz → 1.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of livekit-plugins-aws might be problematic.

Files changed (18)
  1. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/PKG-INFO +6 -2
  2. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/README.md +4 -0
  3. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/experimental/realtime/events.py +20 -21
  4. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/experimental/realtime/realtime_model.py +147 -76
  5. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/experimental/realtime/turn_tracker.py +11 -20
  6. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/llm.py +3 -1
  7. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/version.py +1 -1
  8. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/pyproject.toml +1 -1
  9. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/.gitignore +0 -0
  10. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/__init__.py +0 -0
  11. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/experimental/realtime/__init__.py +0 -0
  12. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/experimental/realtime/pretty_printer.py +0 -0
  13. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/log.py +0 -0
  14. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/models.py +0 -0
  15. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/py.typed +0 -0
  16. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/stt.py +0 -0
  17. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/tts.py +0 -0
  18. {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: livekit-plugins-aws
- Version: 1.1.5
+ Version: 1.1.7
  Summary: LiveKit Agents Plugin for services from AWS
  Project-URL: Documentation, https://docs.livekit.io
  Project-URL: Website, https://livekit.io/
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9.0
  Requires-Dist: aioboto3>=14.1.0
  Requires-Dist: amazon-transcribe>=0.6.2
- Requires-Dist: livekit-agents>=1.1.5
+ Requires-Dist: livekit-agents>=1.1.7
  Provides-Extra: realtime
  Requires-Dist: aws-sdk-bedrock-runtime==0.0.2; (python_version >= '3.12') and extra == 'realtime'
  Requires-Dist: boto3>1.35.10; extra == 'realtime'
@@ -44,3 +44,7 @@ pip install livekit-plugins-aws[realtime]
  ## Pre-requisites

  You'll need to specify an AWS Access Key and a Deployment Region. They can be set as environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_DEFAULT_REGION`, respectively.
+
+ ## Example
+
+ For an example of the realtime STS model, Nova Sonic, see: https://github.com/livekit/agents/blob/main/examples/voice_agents/realtime_joke_teller.py
README.md
@@ -16,3 +16,7 @@ pip install livekit-plugins-aws[realtime]
  ## Pre-requisites

  You'll need to specify an AWS Access Key and a Deployment Region. They can be set as environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_DEFAULT_REGION`, respectively.
+
+ ## Example
+
+ For an example of the realtime STS model, Nova Sonic, see: https://github.com/livekit/agents/blob/main/examples/voice_agents/realtime_joke_teller.py
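
The linked joke-teller example is the canonical reference; as a quick orientation, a minimal agent wiring might look like the sketch below. The `AgentSession` scaffolding follows the standard livekit-agents pattern and the import path follows this package's layout; neither is part of this diff.

```python
# Hypothetical minimal agent using the Nova Sonic realtime model.
# Assumes AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_DEFAULT_REGION are set.
from livekit import agents
from livekit.agents import Agent, AgentSession
from livekit.plugins.aws.experimental.realtime import RealtimeModel


async def entrypoint(ctx: agents.JobContext) -> None:
    session = AgentSession(llm=RealtimeModel(voice="tiffany"))
    await session.start(
        room=ctx.room,
        agent=Agent(instructions="You are a helpful voice assistant."),
    )


if __name__ == "__main__":
    agents.cli.run_app(agents.WorkerOptions(entrypoint_fnc=entrypoint))
```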
livekit/plugins/aws/experimental/realtime/events.py
@@ -1,8 +1,8 @@
  import json
  import uuid
- from typing import Any, Literal, Optional, Union
+ from typing import Any, Literal, Optional, Union, cast

- from pydantic import BaseModel, ConfigDict, Field
+ from pydantic import BaseModel as _BaseModel, ConfigDict, Field

  from livekit.agents import llm

@@ -20,7 +20,7 @@ SAMPLE_SIZE_BITS = Literal[16]  # only supports 16-bit audio
  CHANNEL_COUNT = Literal[1]  # only supports monochannel audio


- class BaseModel(BaseModel):
+ class BaseModel(_BaseModel):
      model_config = ConfigDict(populate_by_name=True, extra="forbid")

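Renaming the import to `_BaseModel` removes the confusing self-shadowing subclass (`class BaseModel(BaseModel)`), which only worked because the base name is resolved before the class statement rebinds it. The shared config is what every Sonic event model inherits; a quick illustration of its two settings, using a made-up model rather than one of the plugin's:

```python
from pydantic import BaseModel as _BaseModel, ConfigDict, Field


class BaseModel(_BaseModel):
    model_config = ConfigDict(populate_by_name=True, extra="forbid")


class Demo(BaseModel):  # hypothetical model, for illustration only
    json_: str = Field(alias="json")  # wire name "json", Python name "json_"


Demo(json="{}")   # ok: populated via the alias (the wire name)
Demo(json_="{}")  # ok: populate_by_name accepts the Python field name too
# Demo(json="{}", other=1) would raise a ValidationError: extra="forbid"
```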
@@ -91,7 +91,7 @@ class Tool(BaseModel):


  class ToolConfiguration(BaseModel):
-     toolChoice: dict[str, dict[str, str]] | None = None
+     toolChoice: Optional[dict[str, dict[str, str]]] = None
      tools: list[Tool]

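Swapping `dict[...] | None` for `Optional[dict[...]]` matters here because the package declares `Requires-Python: >=3.9.0` and pydantic evaluates field annotations at runtime: the PEP 604 `X | None` operator only exists on Python 3.10+, so the old spelling raises `TypeError` on 3.9 when the class is created. A minimal reproduction (illustrative, not the plugin's code):

```python
from typing import Optional

from pydantic import BaseModel


class Cfg(BaseModel):
    # On Python 3.9 this annotation is fine, whereas
    # dict[str, dict[str, str]] | None would raise
    # TypeError: unsupported operand type(s) for |
    toolChoice: Optional[dict[str, dict[str, str]]] = None
```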
@@ -260,6 +260,8 @@ class SonicEventBuilder:
              else:
                  return "other_event"

+         raise ValueError(f"Unknown event type: {json_data}")
+
      def create_text_content_block(
          self,
          content_name: str,
@@ -313,10 +315,18 @@ class SonicEventBuilder:
          if chat_ctx.items:
              logger.debug("initiating session with chat context")
              for item in chat_ctx.items:
+                 if item.type != "message":
+                     continue
+
+                 if (role := item.role.upper()) not in ["USER", "ASSISTANT", "SYSTEM"]:
+                     continue
+
                  ctx_content_name = str(uuid.uuid4())
                  init_events.extend(
                      self.create_text_content_block(
-                         ctx_content_name, item.role.upper(), "".join(item.content)
+                         ctx_content_name,
+                         cast(ROLE, role),
+                         "".join(c for c in item.content if isinstance(c, str)),
                      )
                  )

@@ -481,26 +491,15 @@ class SonicEventBuilder:
          sample_rate: SAMPLE_RATE_HERTZ,
          tool_configuration: Optional[Union[ToolConfiguration, dict[str, Any], str]] = None,
      ) -> str:
-         tool_configuration = tool_configuration or ToolConfiguration(tools=[])
-         for tool in tool_configuration.tools:
-             logger.debug(f"TOOL JSON SCHEMA: {tool.toolSpec.inputSchema}")
-         tool_objects = [
-             Tool(
-                 toolSpec=ToolSpec(
-                     name=tool.toolSpec.name,
-                     description=tool.toolSpec.description,
-                     inputSchema=ToolInputSchema(json_=tool.toolSpec.inputSchema.json_),
-                 )
-             )
-             for tool in tool_configuration.tools
-         ]
-
          if tool_configuration is None:
              tool_configuration = ToolConfiguration(tools=[])
          elif isinstance(tool_configuration, str):
-             tool_configuration = ToolConfiguration(**json.loads(tool_configuration))
+             tool_configuration = ToolConfiguration.model_validate_json(tool_configuration)
          elif isinstance(tool_configuration, dict):
-             tool_configuration = ToolConfiguration(**tool_configuration)
+             tool_configuration = ToolConfiguration.model_validate(tool_configuration)
+
+         for tool in tool_configuration.tools:
+             logger.debug(f"TOOL JSON SCHEMA: {tool.toolSpec.inputSchema}")

          tool_objects = list(tool_configuration.tools)
          event = Event(
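
`model_validate_json` / `model_validate` replace the manual `**`-unpacking constructors, so parsing and validation happen in one pass and nested `Tool` entries are coerced properly. It also makes the old pre-validation block above removable: that block iterated `tool_configuration.tools` before a string or dict input had been parsed, which would raise `AttributeError` for those inputs. A small sketch with simplified stand-in models:

```python
from pydantic import BaseModel


class Spec(BaseModel):  # simplified stand-ins, not the plugin's models
    name: str


class ToolConfig(BaseModel):
    tools: list[Spec]


raw = '{"tools": [{"name": "get_weather"}]}'
cfg = ToolConfig.model_validate_json(raw)  # parse + validate in one step
assert isinstance(cfg.tools[0], Spec)      # nested dicts arrive as models
cfg2 = ToolConfig.model_validate({"tools": [{"name": "get_weather"}]})
assert cfg == cfg2
```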
livekit/plugins/aws/experimental/realtime/realtime_model.py
@@ -1,5 +1,8 @@
+ # mypy: disable-error-code=unused-ignore
+
  from __future__ import annotations

+ import ast
  import asyncio
  import base64
  import json
@@ -10,7 +13,7 @@ import weakref
  from collections.abc import Iterator
  from dataclasses import dataclass, field
  from datetime import datetime
- from typing import Any, Literal
+ from typing import Any, Callable, Literal, cast

  import boto3
  from aws_sdk_bedrock_runtime.client import (
@@ -33,11 +36,9 @@ from smithy_core.aio.interfaces.identity import IdentityResolver
  from livekit import rtc
  from livekit.agents import (
      APIStatusError,
-     ToolError,
      llm,
      utils,
  )
- from livekit.agents.llm.realtime import RealtimeSession
  from livekit.agents.metrics import RealtimeModelMetrics
  from livekit.agents.types import NOT_GIVEN, NotGivenOr
  from livekit.agents.utils import is_given
@@ -150,12 +151,12 @@ class _ResponseGeneration:
      speculative_messages: dict[str, str] = field(default_factory=dict)
      tool_messages: dict[str, str] = field(default_factory=dict)
      output_text: str = ""  # agent ASR text
-     _created_timestamp: str = field(default_factory=datetime.now().isoformat())
+     _created_timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
      _first_token_timestamp: float | None = None
      _completed_timestamp: float | None = None


- class Boto3CredentialsResolver(IdentityResolver):
+ class Boto3CredentialsResolver(IdentityResolver):  # type: ignore[misc]
      """IdentityResolver implementation that sources AWS credentials from boto3.

      The resolver delegates to the default boto3.Session() credential chain which
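
The `_created_timestamp` change fixes a real bug rather than style: `default_factory` must be a zero-argument callable, but `datetime.now().isoformat()` was evaluated once at class-definition time, handing dataclasses a *string* as the factory. A condensed reproduction:

```python
from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class Broken:
    # the "factory" is the string returned by isoformat(), so Broken()
    # raises TypeError: 'str' object is not callable
    ts: str = field(default_factory=datetime.now().isoformat())


@dataclass
class Fixed:
    # the lambda defers the call; every instance gets a fresh timestamp
    ts: str = field(default_factory=lambda: datetime.now().isoformat())
```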
@@ -164,10 +165,10 @@ class Boto3CredentialsResolver:
      passed into Bedrock runtime clients.
      """

-     def __init__(self):
-         self.session = boto3.Session()
+     def __init__(self) -> None:
+         self.session = boto3.Session()  # type: ignore[attr-defined]

-     async def get_identity(self, **kwargs):
+     async def get_identity(self, **kwargs: Any) -> AWSCredentialsIdentity:
          """Asynchronously resolve AWS credentials.

          This method is invoked by the Bedrock runtime client whenever a new request needs to be
@@ -247,7 +248,7 @@ class RealtimeModel(llm.RealtimeModel):
          self.temperature = temperature
          self.top_p = top_p
          self._opts = _RealtimeOptions(
-             voice=voice if is_given(voice) else "tiffany",
+             voice=cast(VOICE_ID, voice) if is_given(voice) else "tiffany",
              temperature=temperature if is_given(temperature) else DEFAULT_TEMPERATURE,
              top_p=top_p if is_given(top_p) else DEFAULT_TOP_P,
              max_tokens=max_tokens if is_given(max_tokens) else DEFAULT_MAX_TOKENS,
@@ -295,7 +296,7 @@ class RealtimeSession(  # noqa: F811
          inference options and the Smithy Bedrock client configuration.
          """
          super().__init__(realtime_model)
-         self._realtime_model = realtime_model
+         self._realtime_model: RealtimeModel = realtime_model
          self._event_builder = seb(
              prompt_name=str(uuid.uuid4()),
              audio_content_name=str(uuid.uuid4()),
@@ -309,10 +310,10 @@ class RealtimeSession(  # noqa: F811
          self._audio_input_task = None
          self._stream_response = None
          self._bedrock_client = None
+         self._pending_tools: set[str] = set()
          self._is_sess_active = asyncio.Event()
          self._chat_ctx = llm.ChatContext.empty()
          self._tools = llm.ToolContext.empty()
-         self._tool_type_map = {}
          self._tool_results_ch = utils.aio.Chan[dict[str, str]]()
          self._tools_ready = asyncio.get_running_loop().create_future()
          self._instructions_ready = asyncio.get_running_loop().create_future()
@@ -341,16 +342,12 @@ class RealtimeSession(  # noqa: F811
              "other_event": self._handle_other_event,
          }
          self._turn_tracker = _TurnTracker(
-             self.emit, streams_provider=self._current_generation_streams
+             cast(Callable[[str, Any], None], self.emit),
+             cast(Callable[[], None], self.emit_generation_event),
          )

-     def _current_generation_streams(
-         self,
-     ) -> tuple[utils.aio.Chan[llm.MessageGeneration], utils.aio.Chan[llm.FunctionCall]]:
-         return (self._current_generation.message_ch, self._current_generation.function_ch)
-
      @utils.log_exceptions(logger=logger)
-     def _initialize_client(self):
+     def _initialize_client(self) -> None:
          """Instantiate the Bedrock runtime client"""
          config = Config(
              endpoint_uri=f"https://bedrock-runtime.{self._realtime_model._opts.region}.amazonaws.com",
@@ -358,15 +355,16 @@ class RealtimeSession(  # noqa: F811
              aws_credentials_identity_resolver=Boto3CredentialsResolver(),
              http_auth_scheme_resolver=HTTPAuthSchemeResolver(),
              http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()},
+             user_agent_extra="x-client-framework:livekit-plugins-aws[realtime]",
          )
          self._bedrock_client = BedrockRuntimeClient(config=config)

      @utils.log_exceptions(logger=logger)
-     async def _send_raw_event(self, event_json):
+     async def _send_raw_event(self, event_json: str) -> None:
          """Low-level helper that serialises event_json and forwards it to the bidirectional stream.

          Args:
-             event_json (dict | str): The JSON payload (already in Bedrock wire format) to queue.
+             event_json (str): The JSON payload (already in Bedrock wire format) to queue.

          Raises:
              Exception: Propagates any failures returned by the Bedrock runtime client.
@@ -425,21 +423,21 @@ class RealtimeSession(  # noqa: F811
                  input_schema = llm.utils.build_legacy_openai_schema(f, internally_tagged=True)[
                      "parameters"
                  ]
-                 self._tool_type_map[name] = "FunctionTool"
-             else:
+             elif llm.tool_context.is_raw_function_tool(f):
                  description = llm.tool_context.get_raw_function_info(f).raw_schema.get(
                      "description"
                  )
                  input_schema = llm.tool_context.get_raw_function_info(f).raw_schema[
                      "parameters"
                  ]
-                 self._tool_type_map[name] = "RawFunctionTool"
+             else:
+                 continue

              tool = Tool(
                  toolSpec=ToolSpec(
                      name=name,
-                     description=description,
-                     inputSchema=ToolInputSchema(json_=json.dumps(input_schema)),
+                     description=description or "No description provided",
+                     inputSchema=ToolInputSchema(json_=json.dumps(input_schema)),  # type: ignore
                  )
              )
              tools.append(tool)
@@ -455,7 +453,7 @@ class RealtimeSession(  # noqa: F811
          return tool_cfg

      @utils.log_exceptions(logger=logger)
-     async def initialize_streams(self, is_restart: bool = False):
+     async def initialize_streams(self, is_restart: bool = False) -> None:
          """Open the Bedrock bidirectional stream and spawn background worker tasks.

          This coroutine is idempotent and can be invoked again when recoverable
@@ -469,6 +467,7 @@ class RealtimeSession(  # noqa: F811
          if not self._bedrock_client:
              logger.info("Creating Bedrock client")
              self._initialize_client()
+         assert self._bedrock_client is not None, "bedrock_client is None"

          logger.info("Initializing Bedrock stream")
          self._stream_response = (
@@ -518,7 +517,7 @@ class RealtimeSession(  # noqa: F811
          self._chat_ctx.truncate(max_items=MAX_MESSAGES)
          init_events = self._event_builder.create_prompt_start_block(
              voice_id=self._realtime_model._opts.voice,
-             sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
+             sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,  # type: ignore
              system_content=self._instructions,
              chat_ctx=self.chat_ctx,
              tool_configuration=self._serialize_tool_config(),
@@ -542,15 +541,16 @@ class RealtimeSession(  # noqa: F811
              self._is_sess_active.set()
              logger.debug("Stream initialized successfully")
          except Exception as e:
-             self._is_sess_active.set_exception(e)
              logger.debug(f"Failed to initialize stream: {str(e)}")
              raise
          return self

      @utils.log_exceptions(logger=logger)
-     def _emit_generation_event(self) -> None:
+     def emit_generation_event(self) -> None:
          """Publish a llm.GenerationCreatedEvent to external subscribers."""
          logger.debug("Emitting generation event")
+         assert self._current_generation is not None, "current_generation is None"
+
          generation_ev = llm.GenerationCreatedEvent(
              message_stream=self._current_generation.message_ch,
              function_stream=self._current_generation.function_ch,
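
Dropping `self._is_sess_active.set_exception(e)` also fixes a latent crash: `_is_sess_active` is an `asyncio.Event`, and `set_exception` is a `Future` method, so the old error path would itself raise `AttributeError` and mask the original failure; re-raising is enough. For reference:

```python
import asyncio


async def main() -> None:
    ev = asyncio.Event()
    ev.set()  # an Event can only be set or cleared
    assert not hasattr(ev, "set_exception")  # Future API, not Event API

    fut: asyncio.Future[None] = asyncio.get_running_loop().create_future()
    fut.set_exception(RuntimeError("futures carry exceptions; events do not"))
    assert isinstance(fut.exception(), RuntimeError)


asyncio.run(main())
```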
@@ -605,10 +605,12 @@ class RealtimeSession(  # noqa: F811
          """Handle text_output_content_start for both user and assistant roles."""
          log_event_data(event_data)
          role = event_data["event"]["contentStart"]["role"]
+         self._create_response_generation()

          # note: does not work if you emit llm.GCE too early (for some reason)
          if role == "USER":
-             self._create_response_generation()
+             assert self._current_generation is not None, "current_generation is None"
+
              content_id = event_data["event"]["contentStart"]["contentId"]
              self._current_generation.user_messages[content_id] = self._current_generation.input_id

@@ -616,6 +618,8 @@ class RealtimeSession(  # noqa: F811
              role == "ASSISTANT"
              and "SPECULATIVE" in event_data["event"]["contentStart"]["additionalModelFields"]
          ):
+             assert self._current_generation is not None, "current_generation is None"
+
              text_content_id = event_data["event"]["contentStart"]["contentId"]
              self._current_generation.speculative_messages[text_content_id] = (
                  self._current_generation.response_id
@@ -633,10 +637,15 @@ class RealtimeSession(  # noqa: F811
          # this is b/c audio playback is desynced from text transcription
          # TODO: fix this; possibly via a playback timer
          idx = self._chat_ctx.find_insertion_index(created_at=time.time()) - 1
+         if idx < 0:
+             logger.warning("Barge-in DETECTED but no previous message found")
+             return
+
          logger.debug(
              f"BARGE-IN DETECTED using idx: {idx} and chat_msg: {self._chat_ctx.items[idx]}"
          )
-         self._chat_ctx.items[idx].interrupted = True
+         if (item := self._chat_ctx.items[idx]).type == "message":
+             item.interrupted = True
          self._close_current_generation()
          return

@@ -661,27 +670,31 @@ class RealtimeSession(  # noqa: F811
          # note: this update is per utterance, not per turn
          self._update_chat_ctx(role="assistant", text_content=text_content)

-     def _update_chat_ctx(self, role: str, text_content: str) -> None:
+     def _update_chat_ctx(self, role: llm.ChatRole, text_content: str) -> None:
          """
          Update the chat context with the latest ASR text while guarding against model limitations:
          a) 40 total messages limit
          b) 1kB message size limit
          """
-         prev_utterance = self._chat_ctx.items[-1]
-         if prev_utterance.role == role:
-             if (
-                 len(prev_utterance.content[0].encode("utf-8")) + len(text_content.encode("utf-8"))
-                 < MAX_MESSAGE_SIZE
-             ):
-                 prev_utterance.content[0] = "\n".join([prev_utterance.content[0], text_content])
+         logger.debug(f"Updating chat context with role: {role} and text_content: {text_content}")
+         if len(self._chat_ctx.items) == 0:
+             self._chat_ctx.add_message(role=role, content=text_content)
+         else:
+             prev_utterance = self._chat_ctx.items[-1]
+             if prev_utterance.type == "message" and prev_utterance.role == role:
+                 if isinstance(prev_content := prev_utterance.content[0], str) and (
+                     len(prev_content.encode("utf-8")) + len(text_content.encode("utf-8"))
+                     < MAX_MESSAGE_SIZE
+                 ):
+                     prev_utterance.content[0] = "\n".join([prev_content, text_content])
+                 else:
+                     self._chat_ctx.add_message(role=role, content=text_content)
+                     if len(self._chat_ctx.items) > MAX_MESSAGES:
+                         self._chat_ctx.truncate(max_items=MAX_MESSAGES)
              else:
                  self._chat_ctx.add_message(role=role, content=text_content)
                  if len(self._chat_ctx.items) > MAX_MESSAGES:
                      self._chat_ctx.truncate(max_items=MAX_MESSAGES)
-         else:
-             self._chat_ctx.add_message(role=role, content=text_content)
-             if len(self._chat_ctx.items) > MAX_MESSAGES:
-                 self._chat_ctx.truncate(max_items=MAX_MESSAGES)

      # cannot rely on this event for user b/c stopReason=PARTIAL_TURN always for user
      async def _handle_text_output_content_end_event(self, event_data: dict) -> None:
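
Worth noting in the rewritten `_update_chat_ctx`: the 1 kB guard counts UTF-8 bytes, not characters, which is the right unit for a service-side size limit, and the new `isinstance` check keeps the concatenation from blowing up when the previous message's content is not plain text. For example:

```python
text = "héllo wörld"
print(len(text))                  # 11 characters
print(len(text.encode("utf-8")))  # 13 bytes -- what counts against MAX_MESSAGE_SIZE
```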
@@ -701,6 +714,8 @@ class RealtimeSession(  # noqa: F811
      async def _handle_tool_output_content_start_event(self, event_data: dict) -> None:
          """Track mapping content_id -> response_id for upcoming tool use."""
          log_event_data(event_data)
+         assert self._current_generation is not None, "current_generation is None"
+
          tool_use_content_id = event_data["event"]["contentStart"]["contentId"]
          self._current_generation.tool_messages[tool_use_content_id] = (
              self._current_generation.response_id
@@ -710,6 +725,8 @@ class RealtimeSession(  # noqa: F811
      async def _handle_tool_output_content_event(self, event_data: dict) -> None:
          """Execute the referenced tool locally and forward results back to Bedrock."""
          log_event_data(event_data)
+         assert self._current_generation is not None, "current_generation is None"
+
          tool_use_content_id = event_data["event"]["toolUse"]["contentId"]
          tool_use_id = event_data["event"]["toolUse"]["toolUseId"]
          tool_name = event_data["event"]["toolUse"]["toolName"]
@@ -719,34 +736,38 @@ class RealtimeSession(  # noqa: F811
          ):
              args = event_data["event"]["toolUse"]["content"]
              self._current_generation.function_ch.send_nowait(
-                 llm.FunctionCall(
-                     call_id=tool_use_id,
-                     name=tool_name,
-                     arguments=args,
-                 )
+                 llm.FunctionCall(call_id=tool_use_id, name=tool_name, arguments=args)
              )
-
-             # note: may need to inject RunContext here...
-             tool_type = self._tool_type_map[tool_name]
-             if tool_type == "FunctionTool":
-                 tool_result = await self.tools.function_tools[tool_name](**json.loads(args))
-             elif tool_type == "RawFunctionTool":
-                 tool_result = await self.tools.function_tools[tool_name](json.loads(args))
-             else:
-                 raise ValueError(f"Unknown tool type: {tool_type}")
-             logger.debug(f"TOOL ARGS: {args}\nTOOL RESULT: {tool_result}")
-
-             # Sonic only accepts Structured Output for tool results
-             # therefore, must JSON stringify ToolError
-             if isinstance(tool_result, ToolError):
-                 logger.warning(f"TOOL ERROR: {tool_name} {tool_result.message}")
-                 tool_result = {"error": tool_result.message}
-             self._tool_results_ch.send_nowait(
-                 {
-                     "tool_use_id": tool_use_id,
-                     "tool_result": tool_result,
-                 }
+             self._pending_tools.add(tool_use_id)
+
+             # performing these acrobatics in order to release the deadlock
+             # LiveKit will not accept a new generation until the previous one is closed
+             # the issue is that audio data cannot be generated until toolResult is received
+             # however, toolResults only arrive after update_chat_ctx() is invoked
+             # which will only occur after agent speech has completed
+             # therefore we introduce an artificial turn to trigger update_chat_ctx()
+             # TODO: this is messy-- investigate if there is a better way to handle this
+             curr_gen = self._current_generation.messages[self._current_generation.response_id]
+             curr_gen.audio_ch.close()
+             curr_gen.text_ch.close()
+             self._current_generation.message_ch.close()
+             self._current_generation.message_ch = utils.aio.Chan()
+             self._current_generation.function_ch.close()
+             self._current_generation.function_ch = utils.aio.Chan()
+             msg_gen = _MessageGeneration(
+                 message_id=self._current_generation.response_id,
+                 text_ch=utils.aio.Chan(),
+                 audio_ch=utils.aio.Chan(),
+             )
+             self._current_generation.messages[self._current_generation.response_id] = msg_gen
+             self._current_generation.message_ch.send_nowait(
+                 llm.MessageGeneration(
+                     message_id=msg_gen.message_id,
+                     text_stream=msg_gen.text_ch,
+                     audio_stream=msg_gen.audio_ch,
+                 )
              )
+             self.emit_generation_event()

      async def _handle_tool_output_content_end_event(self, event_data: dict) -> None:
          log_event_data(event_data)
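
Taken together with the `update_chat_ctx` hunk further down, this changes tool handling from inline execution inside the event handler to a deferred round trip: announce the `FunctionCall`, remember its id in `_pending_tools`, recycle the message/function channels to open the "artificial turn" the comment describes, and forward the result only once the framework reports it back. A toy model of that handshake (names mirror the diff; everything around them is stubbed):

```python
from dataclasses import dataclass

pending_tools: set[str] = set()
forwarded: list[dict[str, str]] = []


@dataclass
class FunctionCallOutput:  # stand-in for the framework's chat item
    type: str
    call_id: str
    output: str


def on_tool_use(call_id: str) -> None:
    # step 1: the event handler only announces the call and records its id
    pending_tools.add(call_id)


def update_chat_ctx(items: list[FunctionCallOutput]) -> None:
    # step 2: when the framework reports results, forward announced calls only
    for item in items:
        if item.type == "function_call_output" and item.call_id in pending_tools:
            pending_tools.discard(item.call_id)
            forwarded.append({"tool_use_id": item.call_id, "tool_result": item.output})


on_tool_use("call-1")
update_chat_ctx([FunctionCallOutput("function_call_output", "call-1", '{"temp": 21}')])
assert forwarded == [{"tool_use_id": "call-1", "tool_result": '{"temp": 21}'}]
```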
@@ -804,6 +825,14 @@ class RealtimeSession(  # noqa: F811
              if not curr_gen.text_ch.closed:
                  curr_gen.text_ch.close()

+         # TODO: seems not needed, tool_messages[id] is a str, function_ch is closed below?
+         # if self._current_generation.response_id in self._current_generation.tool_messages:
+         #     curr_gen = self._current_generation.tool_messages[
+         #         self._current_generation.response_id
+         #     ]
+         #     if not curr_gen.function_ch.closed:
+         #         curr_gen.function_ch.close()
+
          if not self._current_generation.message_ch.closed:
              self._current_generation.message_ch.close()
          if not self._current_generation.function_ch.closed:
@@ -855,10 +884,11 @@ class RealtimeSession(  # noqa: F811
          self.emit("metrics_collected", metrics)

      @utils.log_exceptions(logger=logger)
-     async def _process_responses(self):
+     async def _process_responses(self) -> None:
          """Background task that drains Bedrock's output stream and feeds the event handlers."""
          try:
              await self._is_sess_active.wait()
+             assert self._stream_response is not None, "stream_response is None"

              # note: may need another signal here to block input task until bedrock is ready
              # TODO: save this as a field so we're not re-awaiting it every time
@@ -892,7 +922,6 @@ class RealtimeSession(  # noqa: F811

              else:
                  logger.error(f"Validation error: {ve}")
-                 request_id = ve.split(" ")[0].split("=")[1]
                  self.emit(
                      "error",
                      llm.RealtimeModelError(
@@ -901,7 +930,7 @@ class RealtimeSession(  # noqa: F811
                          error=APIStatusError(
                              message=ve.message,
                              status_code=400,
-                             request_id=request_id,
+                             request_id="",
                              body=ve,
                              retryable=False,
                          ),
@@ -940,7 +969,7 @@ class RealtimeSession(  # noqa: F811
                          timestamp=time.monotonic(),
                          label=self._realtime_model._label,
                          error=APIStatusError(
-                             message=e.message,
+                             message=err_msg,
                              status_code=500,
                              request_id=request_id,
                              body=e,
@@ -1014,6 +1043,25 @@ class RealtimeSession(  # noqa: F811
          logger.debug(f"Chat context updated: {self._chat_ctx.items}")
          self._chat_ctx_ready.set_result(True)

+         # for each function tool, send the result to aws
+         for item in chat_ctx.items:
+             if item.type != "function_call_output":
+                 continue
+
+             if item.call_id not in self._pending_tools:
+                 continue
+
+             logger.debug(f"function call output: {item}")
+             self._pending_tools.discard(item.call_id)
+             self._tool_results_ch.send_nowait(
+                 {
+                     "tool_use_id": item.call_id,
+                     "tool_result": item.output
+                     if not item.is_error
+                     else f"{{'error': '{item.output}'}}",
+                 }
+             )
+
      async def _send_tool_events(self, tool_use_id: str, tool_result: str) -> None:
          """Send tool_result back to Bedrock, grouped under tool_use_id."""
          tool_content_name = str(uuid.uuid4())
@@ -1026,7 +1074,9 @@ class RealtimeSession(  # noqa: F811
          await self._send_raw_event(event)
          # logger.debug(f"Sent tool event: {event}")

-     def _tool_choice_adapter(self, tool_choice: llm.ToolChoice) -> dict[str, dict[str, str]] | None:
+     def _tool_choice_adapter(
+         self, tool_choice: llm.ToolChoice | None
+     ) -> dict[str, dict[str, str]] | None:
          """Translate the LiveKit ToolChoice enum into Sonic's JSON schema."""
          if tool_choice == "auto":
              return {"auto": {}}
@@ -1079,7 +1129,7 @@ class RealtimeSession(  # noqa: F811
              yield frame

      @utils.log_exceptions(logger=logger)
-     async def _process_audio_input(self):
+     async def _process_audio_input(self) -> None:
          """Background task that feeds audio and tool results into the Bedrock stream."""
          await self._send_raw_event(self._event_builder.create_audio_content_start_event())
          logger.info("Starting audio input processing loop")
@@ -1090,6 +1140,19 @@ class RealtimeSession(  # noqa: F811
                  val = self._tool_results_ch.recv_nowait()
                  tool_result = val["tool_result"]
                  tool_use_id = val["tool_use_id"]
+                 if not isinstance(tool_result, str):
+                     tool_result = json.dumps(tool_result)
+                 else:
+                     try:
+                         json.loads(tool_result)
+                     except json.JSONDecodeError:
+                         try:
+                             tool_result = json.dumps(ast.literal_eval(tool_result))
+                         except Exception:
+                             # return the original value
+                             pass
+
+                 logger.debug(f"Sending tool result: {tool_result}")
                  await self._send_tool_events(tool_use_id, tool_result)

          except utils.aio.channel.ChanEmpty:
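
This normalization exists because tool results can arrive as non-string objects, as JSON, or as Python-repr strings. In particular, the error payload built in `update_chat_ctx` above (`f"{{'error': '{item.output}'}}"`) uses single quotes and is therefore not valid JSON, while Sonic only accepts structured output. The `ast.literal_eval` round trip converts such reprs into real JSON:

```python
import ast
import json

payload = "{'error': 'tool failed'}"  # Python repr, not valid JSON
try:
    json.loads(payload)
except json.JSONDecodeError:
    payload = json.dumps(ast.literal_eval(payload))
print(payload)  # {"error": "tool failed"} -- now valid JSON
```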
@@ -1152,6 +1215,11 @@ class RealtimeSession(  # noqa: F811
          instructions: NotGivenOr[str] = NOT_GIVEN,
      ) -> asyncio.Future[llm.GenerationCreatedEvent]:
          logger.warning("unprompted generation is not supported by Nova Sonic's Realtime API")
+         fut = asyncio.Future[llm.GenerationCreatedEvent]()
+         fut.set_exception(
+             llm.RealtimeError("unprompted generation is not supported by Nova Sonic's Realtime API")
+         )
+         return fut

      def commit_audio(self) -> None:
          logger.warning("commit_audio is not supported by Nova Sonic's Realtime API")
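
Returning a future with the exception pre-set (instead of implicitly returning `None` from a method annotated to return a future) keeps the contract honest: callers that `await` the result now get a catchable `RealtimeError` rather than a `TypeError` on `None`. The general pattern:

```python
import asyncio


async def main() -> None:
    fut: asyncio.Future[str] = asyncio.get_running_loop().create_future()
    fut.set_exception(RuntimeError("unsupported operation"))
    try:
        await fut
    except RuntimeError as e:
        print(f"caught: {e}")


asyncio.run(main())
```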
@@ -1190,19 +1258,22 @@ class RealtimeSession(  # noqa: F811
          # resulting in an error after cancellation
          # however, it's mostly cosmetic-- the event loop will still exit
          # TODO: fix this nit
+         tasks: list[asyncio.Task[Any]] = []
          if self._response_task:
              try:
                  await asyncio.wait_for(self._response_task, timeout=1.0)
              except asyncio.TimeoutError:
                  logger.warning("shutdown of output event loop timed out-- cancelling")
                  self._response_task.cancel()
+                 tasks.append(self._response_task)

          # must cancel the audio input task before closing the input stream
          if self._audio_input_task and not self._audio_input_task.done():
              self._audio_input_task.cancel()
+             tasks.append(self._audio_input_task)
          if self._stream_response and not self._stream_response.input_stream.closed:
              await self._stream_response.input_stream.close()

-         await asyncio.gather(self._response_task, self._audio_input_task, return_exceptions=True)
+         await asyncio.gather(*tasks, return_exceptions=True)
          logger.debug(f"CHAT CONTEXT: {self._chat_ctx.items}")
          logger.info("Session end")
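
The `tasks` list fixes a shutdown crash: the old `asyncio.gather(self._response_task, self._audio_input_task, ...)` raised `TypeError` whenever either attribute was still `None`, since `gather` only accepts awaitables. Collecting just the tasks that were actually cancelled is the safer shape:

```python
import asyncio
from typing import Any, Optional


async def shutdown(*maybe_tasks: Optional[asyncio.Task[Any]]) -> None:
    tasks = [t for t in maybe_tasks if t is not None]
    for t in tasks:
        t.cancel()
    # return_exceptions=True swallows each task's CancelledError
    await asyncio.gather(*tasks, return_exceptions=True)
```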
livekit/plugins/aws/experimental/realtime/turn_tracker.py
@@ -6,7 +6,7 @@ import uuid
  from dataclasses import dataclass, field
  from typing import Any, Callable

- from livekit.agents import llm, utils
+ from livekit.agents import llm

  from ...log import logger

@@ -34,7 +34,7 @@ class _Turn:
      ev_trans_completed: bool = False
      ev_generation_sent: bool = False

-     def add_partial_text(self, text: str):
+     def add_partial_text(self, text: str) -> None:
          self.transcript.append(text)

      @property
@@ -46,19 +46,17 @@ class _TurnTracker:
      def __init__(
          self,
          emit_fn: Callable[[str, Any], None],
-         streams_provider: Callable[
-             [], tuple[utils.aio.Chan[llm.MessageGeneration], utils.aio.Chan[llm.FunctionCall]]
-         ],
+         emit_generation_fn: Callable[[], None],
      ):
          self._emit = emit_fn
          self._turn_idx = 0
          self._curr_turn: _Turn | None = None
-         self._get_streams = streams_provider
+         self._emit_generation_fn = emit_generation_fn

      # --------------------------------------------------------
      # PUBLIC ENTRY POINT
      # --------------------------------------------------------
-     def feed(self, event: dict):
+     def feed(self, event: dict) -> None:
          turn = self._ensure_turn()
          kind = _classify(event)
@@ -97,13 +95,13 @@ class _TurnTracker:
          self._curr_turn = _Turn(turn_id=self._turn_idx)
          return self._curr_turn

-     def _maybe_emit_input_started(self, turn: _Turn):
+     def _maybe_emit_input_started(self, turn: _Turn) -> None:
          if not turn.ev_input_started:
              turn.ev_input_started = True
              self._emit("input_speech_started", llm.InputSpeechStartedEvent())
              turn.phase = _Phase.USER_SPEAKING

-     def _maybe_emit_input_stopped(self, turn: _Turn):
+     def _maybe_emit_input_stopped(self, turn: _Turn) -> None:
          if not turn.ev_input_stopped:
              turn.ev_input_stopped = True
              self._emit(
@@ -111,7 +109,7 @@ class _TurnTracker:
              )
              turn.phase = _Phase.USER_FINISHED

-     def _emit_transcript_updated(self, turn: _Turn):
+     def _emit_transcript_updated(self, turn: _Turn) -> None:
          self._emit(
              "input_audio_transcription_completed",
              llm.InputTranscriptionCompleted(
@@ -121,7 +119,7 @@ class _TurnTracker:
              ),
          )

-     def _maybe_emit_transcript_completed(self, turn: _Turn):
+     def _maybe_emit_transcript_completed(self, turn: _Turn) -> None:
          if not turn.ev_trans_completed:
              turn.ev_trans_completed = True
              self._emit(
@@ -134,17 +132,10 @@ class _TurnTracker:
              ),
          )

-     def _maybe_emit_generation_created(self, turn: _Turn):
+     def _maybe_emit_generation_created(self, turn: _Turn) -> None:
          if not turn.ev_generation_sent:
              turn.ev_generation_sent = True
-             msg_stream, fn_stream = self._get_streams()
-             logger.debug("Emitting generation event")
-             generation_ev = llm.GenerationCreatedEvent(
-                 message_stream=msg_stream,
-                 function_stream=fn_stream,
-                 user_initiated=False,
-             )
-             self._emit("generation_created", generation_ev)
+             self._emit_generation_fn()
              turn.phase = _Phase.ASSISTANT_RESPONDING

livekit/plugins/aws/llm.py
@@ -19,6 +19,7 @@ from dataclasses import dataclass
  from typing import Any, cast

  import aioboto3  # type: ignore
+ from botocore.config import Config

  from livekit.agents import APIConnectionError, APIStatusError, llm
  from livekit.agents.llm import (
@@ -205,7 +206,8 @@ class LLMStream(llm.LLMStream):
      async def _run(self) -> None:
          retryable = True
          try:
-             async with self._session.client("bedrock-runtime") as client:
+             config = Config(user_agent_extra="x-client-framework:livekit-plugins-aws")
+             async with self._session.client("bedrock-runtime", config=config) as client:
                  response = await client.converse_stream(**self._opts)
                  request_id = response["ResponseMetadata"]["RequestId"]
                  if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
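
Both this botocore `Config` and the Smithy client config in the realtime hunk append an `x-client-framework` token, making plugin traffic identifiable in the request's User-Agent header. The botocore knob on its own:

```python
import boto3
from botocore.config import Config

config = Config(user_agent_extra="x-client-framework:livekit-plugins-aws")
client = boto3.client("bedrock-runtime", config=config)
# every request's User-Agent now ends with the extra token
```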
livekit/plugins/aws/version.py
@@ -12,4 +12,4 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- __version__ = "1.1.5"
+ __version__ = "1.1.7"

pyproject.toml
@@ -23,7 +23,7 @@ classifiers = [
      "Programming Language :: Python :: 3 :: Only",
  ]
  dependencies = [
-     "livekit-agents>=1.1.5",
+     "livekit-agents>=1.1.7",
      "aioboto3>=14.1.0",
      "amazon-transcribe>=0.6.2",
  ]