livekit-plugins-aws 1.1.5__py3-none-any.whl → 1.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of livekit-plugins-aws might be problematic.
- livekit/plugins/aws/experimental/realtime/events.py +20 -21
- livekit/plugins/aws/experimental/realtime/realtime_model.py +143 -75
- livekit/plugins/aws/experimental/realtime/turn_tracker.py +11 -20
- livekit/plugins/aws/version.py +1 -1
- {livekit_plugins_aws-1.1.5.dist-info → livekit_plugins_aws-1.1.6.dist-info}/METADATA +6 -2
- {livekit_plugins_aws-1.1.5.dist-info → livekit_plugins_aws-1.1.6.dist-info}/RECORD +7 -7
- {livekit_plugins_aws-1.1.5.dist-info → livekit_plugins_aws-1.1.6.dist-info}/WHEEL +0 -0
livekit/plugins/aws/experimental/realtime/events.py CHANGED

```diff
@@ -1,8 +1,8 @@
 import json
 import uuid
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union, cast
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel as _BaseModel, ConfigDict, Field
 
 from livekit.agents import llm
 
@@ -20,7 +20,7 @@ SAMPLE_SIZE_BITS = Literal[16]  # only supports 16-bit audio
 CHANNEL_COUNT = Literal[1]  # only supports monochannel audio
 
 
-class BaseModel(
+class BaseModel(_BaseModel):
     model_config = ConfigDict(populate_by_name=True, extra="forbid")
 
 
@@ -91,7 +91,7 @@ class Tool(BaseModel):
 
 
 class ToolConfiguration(BaseModel):
-    toolChoice: dict[str, dict[str, str]]
+    toolChoice: Optional[dict[str, dict[str, str]]] = None
     tools: list[Tool]
 
 
@@ -260,6 +260,8 @@ class SonicEventBuilder:
         else:
            return "other_event"
 
+        raise ValueError(f"Unknown event type: {json_data}")
+
    def create_text_content_block(
        self,
        content_name: str,
@@ -313,10 +315,18 @@ class SonicEventBuilder:
        if chat_ctx.items:
            logger.debug("initiating session with chat context")
            for item in chat_ctx.items:
+                if item.type != "message":
+                    continue
+
+                if (role := item.role.upper()) not in ["USER", "ASSISTANT", "SYSTEM"]:
+                    continue
+
                ctx_content_name = str(uuid.uuid4())
                init_events.extend(
                    self.create_text_content_block(
-                        ctx_content_name,
+                        ctx_content_name,
+                        cast(ROLE, role),
+                        "".join(c for c in item.content if isinstance(c, str)),
                    )
                )
 
@@ -481,26 +491,15 @@ class SonicEventBuilder:
        sample_rate: SAMPLE_RATE_HERTZ,
        tool_configuration: Optional[Union[ToolConfiguration, dict[str, Any], str]] = None,
    ) -> str:
-        tool_configuration = tool_configuration or ToolConfiguration(tools=[])
-        for tool in tool_configuration.tools:
-            logger.debug(f"TOOL JSON SCHEMA: {tool.toolSpec.inputSchema}")
-        tool_objects = [
-            Tool(
-                toolSpec=ToolSpec(
-                    name=tool.toolSpec.name,
-                    description=tool.toolSpec.description,
-                    inputSchema=ToolInputSchema(json_=tool.toolSpec.inputSchema.json_),
-                )
-            )
-            for tool in tool_configuration.tools
-        ]
-
        if tool_configuration is None:
            tool_configuration = ToolConfiguration(tools=[])
        elif isinstance(tool_configuration, str):
-            tool_configuration = ToolConfiguration
+            tool_configuration = ToolConfiguration.model_validate_json(tool_configuration)
        elif isinstance(tool_configuration, dict):
-            tool_configuration = ToolConfiguration(
+            tool_configuration = ToolConfiguration.model_validate(tool_configuration)
+
+        for tool in tool_configuration.tools:
+            logger.debug(f"TOOL JSON SCHEMA: {tool.toolSpec.inputSchema}")
 
        tool_objects = list(tool_configuration.tools)
        event = Event(
```
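The tool-configuration branches above now lean on pydantic v2's `model_validate` / `model_validate_json` instead of hand-building `ToolConfiguration` from strings and dicts. A minimal, self-contained sketch of that behavior, using plain pydantic and a simplified `tools` field rather than the plugin's real classes:

```python
# Minimal sketch (not the plugin's code): pydantic v2 hydrates a model
# from either a dict or a JSON string, which the new branches rely on.
from typing import Optional

from pydantic import BaseModel


class ToolConfiguration(BaseModel):
    # mirrors the relaxed schema above: toolChoice is now optional
    toolChoice: Optional[dict[str, dict[str, str]]] = None
    tools: list[dict] = []


as_dict = {"tools": [], "toolChoice": {"auto": {}}}
as_json = '{"tools": []}'

cfg_from_dict = ToolConfiguration.model_validate(as_dict)        # dict input
cfg_from_json = ToolConfiguration.model_validate_json(as_json)   # JSON string input
print(cfg_from_dict.toolChoice, cfg_from_json.toolChoice)        # {'auto': {}} None
```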
livekit/plugins/aws/experimental/realtime/realtime_model.py CHANGED

```diff
@@ -1,5 +1,8 @@
+# mypy: disable-error-code=unused-ignore
+
 from __future__ import annotations
 
+import ast
 import asyncio
 import base64
 import json
@@ -10,7 +13,7 @@ import weakref
 from collections.abc import Iterator
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal
+from typing import Any, Callable, Literal, cast
 
 import boto3
 from aws_sdk_bedrock_runtime.client import (
@@ -33,11 +36,9 @@ from smithy_core.aio.interfaces.identity import IdentityResolver
 from livekit import rtc
 from livekit.agents import (
     APIStatusError,
-    ToolError,
     llm,
     utils,
 )
-from livekit.agents.llm.realtime import RealtimeSession
 from livekit.agents.metrics import RealtimeModelMetrics
 from livekit.agents.types import NOT_GIVEN, NotGivenOr
 from livekit.agents.utils import is_given
@@ -150,12 +151,12 @@ class _ResponseGeneration:
     speculative_messages: dict[str, str] = field(default_factory=dict)
     tool_messages: dict[str, str] = field(default_factory=dict)
     output_text: str = ""  # agent ASR text
-    _created_timestamp: str = field(default_factory=datetime.now().isoformat())
+    _created_timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
     _first_token_timestamp: float | None = None
     _completed_timestamp: float | None = None
 
 
-class Boto3CredentialsResolver(IdentityResolver):
+class Boto3CredentialsResolver(IdentityResolver):  # type: ignore[misc]
     """IdentityResolver implementation that sources AWS credentials from boto3.
 
     The resolver delegates to the default boto3.Session() credential chain which
```
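The `_created_timestamp` change fixes a real bug: 1.1.5 passed the already-computed string returned by `datetime.now().isoformat()` as the `default_factory`, so creating a `_ResponseGeneration` with the default would raise `TypeError: 'str' object is not callable`; the lambda defers the call until each instance is built. A small stand-alone illustration:

```python
# Why the lambda matters: dataclasses call default_factory at instantiation time,
# so it must be a callable, not an already-computed value.
from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class Broken:
    # the factory here is a str (the timestamp computed at class-definition time)
    created: str = field(default_factory=datetime.now().isoformat())


@dataclass
class Fixed:
    # the lambda is invoked per instance, yielding a fresh timestamp each time
    created: str = field(default_factory=lambda: datetime.now().isoformat())


print(Fixed().created)  # works
try:
    Broken()
except TypeError as e:
    print(e)  # 'str' object is not callable
```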
```diff
@@ -164,10 +165,10 @@ class Boto3CredentialsResolver(IdentityResolver):
     passed into Bedrock runtime clients.
     """
 
-    def __init__(self):
-        self.session = boto3.Session()
+    def __init__(self) -> None:
+        self.session = boto3.Session()  # type: ignore[attr-defined]
 
-    async def get_identity(self, **kwargs):
+    async def get_identity(self, **kwargs: Any) -> AWSCredentialsIdentity:
        """Asynchronously resolve AWS credentials.
 
        This method is invoked by the Bedrock runtime client whenever a new request needs to be
```
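`Boto3CredentialsResolver` simply defers to the default boto3 credential chain described in the docstring (environment variables, shared config/credentials files, or instance roles). A minimal sketch of what that chain lookup looks like, not the plugin's resolver itself:

```python
# Minimal sketch: resolve credentials through the default boto3 session chain.
import boto3

session = boto3.Session()
creds = session.get_credentials()
if creds is None:
    raise RuntimeError("Unable to load AWS credentials")

# an immutable snapshot with access_key, secret_key and (optionally) a session token
frozen = creds.get_frozen_credentials()
print(frozen.access_key[:4] + "...", bool(frozen.token))
```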
```diff
@@ -247,7 +248,7 @@ class RealtimeModel(llm.RealtimeModel):
         self.temperature = temperature
         self.top_p = top_p
         self._opts = _RealtimeOptions(
-            voice=voice if is_given(voice) else "tiffany",
+            voice=cast(VOICE_ID, voice) if is_given(voice) else "tiffany",
             temperature=temperature if is_given(temperature) else DEFAULT_TEMPERATURE,
             top_p=top_p if is_given(top_p) else DEFAULT_TOP_P,
             max_tokens=max_tokens if is_given(max_tokens) else DEFAULT_MAX_TOKENS,
@@ -295,7 +296,7 @@ class RealtimeSession( # noqa: F811
         inference options and the Smithy Bedrock client configuration.
         """
         super().__init__(realtime_model)
-        self._realtime_model = realtime_model
+        self._realtime_model: RealtimeModel = realtime_model
         self._event_builder = seb(
             prompt_name=str(uuid.uuid4()),
             audio_content_name=str(uuid.uuid4()),
@@ -309,10 +310,10 @@ class RealtimeSession( # noqa: F811
         self._audio_input_task = None
         self._stream_response = None
         self._bedrock_client = None
+        self._pending_tools: set[str] = set()
         self._is_sess_active = asyncio.Event()
         self._chat_ctx = llm.ChatContext.empty()
         self._tools = llm.ToolContext.empty()
-        self._tool_type_map = {}
         self._tool_results_ch = utils.aio.Chan[dict[str, str]]()
         self._tools_ready = asyncio.get_running_loop().create_future()
         self._instructions_ready = asyncio.get_running_loop().create_future()
@@ -341,16 +342,12 @@ class RealtimeSession( # noqa: F811
             "other_event": self._handle_other_event,
         }
         self._turn_tracker = _TurnTracker(
-            self.emit,
+            cast(Callable[[str, Any], None], self.emit),
+            cast(Callable[[], None], self.emit_generation_event),
         )
 
-    def _current_generation_streams(
-        self,
-    ) -> tuple[utils.aio.Chan[llm.MessageGeneration], utils.aio.Chan[llm.FunctionCall]]:
-        return (self._current_generation.message_ch, self._current_generation.function_ch)
-
     @utils.log_exceptions(logger=logger)
-    def _initialize_client(self):
+    def _initialize_client(self) -> None:
         """Instantiate the Bedrock runtime client"""
         config = Config(
             endpoint_uri=f"https://bedrock-runtime.{self._realtime_model._opts.region}.amazonaws.com",
@@ -362,11 +359,11 @@ class RealtimeSession( # noqa: F811
         self._bedrock_client = BedrockRuntimeClient(config=config)
 
     @utils.log_exceptions(logger=logger)
-    async def _send_raw_event(self, event_json):
+    async def _send_raw_event(self, event_json: str) -> None:
         """Low-level helper that serialises event_json and forwards it to the bidirectional stream.
 
         Args:
-            event_json (
+            event_json (str): The JSON payload (already in Bedrock wire format) to queue.
 
         Raises:
             Exception: Propagates any failures returned by the Bedrock runtime client.
@@ -425,21 +422,21 @@ class RealtimeSession( # noqa: F811
             input_schema = llm.utils.build_legacy_openai_schema(f, internally_tagged=True)[
                 "parameters"
             ]
-
-        else:
+        elif llm.tool_context.is_raw_function_tool(f):
             description = llm.tool_context.get_raw_function_info(f).raw_schema.get(
                 "description"
             )
             input_schema = llm.tool_context.get_raw_function_info(f).raw_schema[
                 "parameters"
             ]
-
+        else:
+            continue
 
         tool = Tool(
             toolSpec=ToolSpec(
                 name=name,
-                description=description,
-                inputSchema=ToolInputSchema(json_=json.dumps(input_schema)),
+                description=description or "No description provided",
+                inputSchema=ToolInputSchema(json_=json.dumps(input_schema)),  # type: ignore
             )
         )
         tools.append(tool)
@@ -455,7 +452,7 @@ class RealtimeSession( # noqa: F811
         return tool_cfg
 
     @utils.log_exceptions(logger=logger)
-    async def initialize_streams(self, is_restart: bool = False):
+    async def initialize_streams(self, is_restart: bool = False) -> None:
         """Open the Bedrock bidirectional stream and spawn background worker tasks.
 
         This coroutine is idempotent and can be invoked again when recoverable
@@ -469,6 +466,7 @@ class RealtimeSession( # noqa: F811
         if not self._bedrock_client:
             logger.info("Creating Bedrock client")
             self._initialize_client()
+        assert self._bedrock_client is not None, "bedrock_client is None"
 
         logger.info("Initializing Bedrock stream")
         self._stream_response = (
@@ -542,15 +540,16 @@ class RealtimeSession( # noqa: F811
             self._is_sess_active.set()
             logger.debug("Stream initialized successfully")
         except Exception as e:
-            self._is_sess_active.set_exception(e)
             logger.debug(f"Failed to initialize stream: {str(e)}")
             raise
         return self
 
     @utils.log_exceptions(logger=logger)
-    def
+    def emit_generation_event(self) -> None:
         """Publish a llm.GenerationCreatedEvent to external subscribers."""
         logger.debug("Emitting generation event")
+        assert self._current_generation is not None, "current_generation is None"
+
         generation_ev = llm.GenerationCreatedEvent(
             message_stream=self._current_generation.message_ch,
             function_stream=self._current_generation.function_ch,
@@ -605,10 +604,12 @@ class RealtimeSession( # noqa: F811
         """Handle text_output_content_start for both user and assistant roles."""
         log_event_data(event_data)
         role = event_data["event"]["contentStart"]["role"]
+        self._create_response_generation()
 
         # note: does not work if you emit llm.GCE too early (for some reason)
         if role == "USER":
-            self.
+            assert self._current_generation is not None, "current_generation is None"
+
             content_id = event_data["event"]["contentStart"]["contentId"]
             self._current_generation.user_messages[content_id] = self._current_generation.input_id
 
@@ -616,6 +617,8 @@ class RealtimeSession( # noqa: F811
             role == "ASSISTANT"
             and "SPECULATIVE" in event_data["event"]["contentStart"]["additionalModelFields"]
         ):
+            assert self._current_generation is not None, "current_generation is None"
+
             text_content_id = event_data["event"]["contentStart"]["contentId"]
             self._current_generation.speculative_messages[text_content_id] = (
                 self._current_generation.response_id
@@ -633,10 +636,15 @@ class RealtimeSession( # noqa: F811
             # this is b/c audio playback is desynced from text transcription
             # TODO: fix this; possibly via a playback timer
             idx = self._chat_ctx.find_insertion_index(created_at=time.time()) - 1
+            if idx < 0:
+                logger.warning("Barge-in DETECTED but no previous message found")
+                return
+
             logger.debug(
                 f"BARGE-IN DETECTED using idx: {idx} and chat_msg: {self._chat_ctx.items[idx]}"
             )
-            self._chat_ctx.items[idx].
+            if (item := self._chat_ctx.items[idx]).type == "message":
+                item.interrupted = True
             self._close_current_generation()
             return
@@ -661,27 +669,31 @@ class RealtimeSession( # noqa: F811
         # note: this update is per utterance, not per turn
         self._update_chat_ctx(role="assistant", text_content=text_content)
 
-    def _update_chat_ctx(self, role:
+    def _update_chat_ctx(self, role: llm.ChatRole, text_content: str) -> None:
         """
         Update the chat context with the latest ASR text while guarding against model limitations:
         a) 40 total messages limit
         b) 1kB message size limit
         """
-
-        if
-
-
-
-
-
+        logger.debug(f"Updating chat context with role: {role} and text_content: {text_content}")
+        if len(self._chat_ctx.items) == 0:
+            self._chat_ctx.add_message(role=role, content=text_content)
+        else:
+            prev_utterance = self._chat_ctx.items[-1]
+            if prev_utterance.type == "message" and prev_utterance.role == role:
+                if isinstance(prev_content := prev_utterance.content[0], str) and (
+                    len(prev_content.encode("utf-8")) + len(text_content.encode("utf-8"))
+                    < MAX_MESSAGE_SIZE
+                ):
+                    prev_utterance.content[0] = "\n".join([prev_content, text_content])
+                else:
+                    self._chat_ctx.add_message(role=role, content=text_content)
+                if len(self._chat_ctx.items) > MAX_MESSAGES:
+                    self._chat_ctx.truncate(max_items=MAX_MESSAGES)
             else:
                 self._chat_ctx.add_message(role=role, content=text_content)
                 if len(self._chat_ctx.items) > MAX_MESSAGES:
                     self._chat_ctx.truncate(max_items=MAX_MESSAGES)
-        else:
-            self._chat_ctx.add_message(role=role, content=text_content)
-            if len(self._chat_ctx.items) > MAX_MESSAGES:
-                self._chat_ctx.truncate(max_items=MAX_MESSAGES)
 
     # cannot rely on this event for user b/c stopReason=PARTIAL_TURN always for user
     async def _handle_text_output_content_end_event(self, event_data: dict) -> None:
```
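`_update_chat_ctx` now folds consecutive utterances from the same role into the previous message as long as the combined UTF-8 size stays under the per-message cap, and truncates the context once the message count limit is exceeded. A minimal stand-alone sketch of the merge-or-append rule; the 1 kB cap mirrors the limit named in the docstring, and the list of tuples is only a stand-in for the real chat context:

```python
# Merge-or-append sketch: same speaker and under the size cap -> extend the
# previous message; otherwise start a new one.
MAX_MESSAGE_SIZE = 1024  # assumed 1 kB per-message limit


def merge_or_append(messages: list[tuple[str, str]], role: str, text: str) -> None:
    """messages is a list of (role, content) pairs standing in for the chat context."""
    if messages and messages[-1][0] == role:
        prev_role, prev_content = messages[-1]
        combined = len(prev_content.encode("utf-8")) + len(text.encode("utf-8"))
        if combined < MAX_MESSAGE_SIZE:
            messages[-1] = (prev_role, "\n".join([prev_content, text]))
            return
    messages.append((role, text))


msgs: list[tuple[str, str]] = []
merge_or_append(msgs, "user", "hello")
merge_or_append(msgs, "user", "are you there?")
merge_or_append(msgs, "assistant", "yes!")
print(msgs)  # [('user', 'hello\nare you there?'), ('assistant', 'yes!')]
```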
```diff
@@ -701,6 +713,8 @@ class RealtimeSession( # noqa: F811
     async def _handle_tool_output_content_start_event(self, event_data: dict) -> None:
         """Track mapping content_id -> response_id for upcoming tool use."""
         log_event_data(event_data)
+        assert self._current_generation is not None, "current_generation is None"
+
         tool_use_content_id = event_data["event"]["contentStart"]["contentId"]
         self._current_generation.tool_messages[tool_use_content_id] = (
             self._current_generation.response_id
@@ -710,6 +724,8 @@ class RealtimeSession( # noqa: F811
     async def _handle_tool_output_content_event(self, event_data: dict) -> None:
         """Execute the referenced tool locally and forward results back to Bedrock."""
         log_event_data(event_data)
+        assert self._current_generation is not None, "current_generation is None"
+
         tool_use_content_id = event_data["event"]["toolUse"]["contentId"]
         tool_use_id = event_data["event"]["toolUse"]["toolUseId"]
         tool_name = event_data["event"]["toolUse"]["toolName"]
@@ -719,34 +735,38 @@ class RealtimeSession( # noqa: F811
         ):
             args = event_data["event"]["toolUse"]["content"]
             self._current_generation.function_ch.send_nowait(
-                llm.FunctionCall(
-                    call_id=tool_use_id,
-                    name=tool_name,
-                    arguments=args,
-                )
+                llm.FunctionCall(call_id=tool_use_id, name=tool_name, arguments=args)
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            self._pending_tools.add(tool_use_id)
+
+            # performing these acrobatics in order to release the deadlock
+            # LiveKit will not accept a new generation until the previous one is closed
+            # the issue is that audio data cannot be generated until toolResult is received
+            # however, toolResults only arrive after update_chat_ctx() is invoked
+            # which will only occur after agent speech has completed
+            # therefore we introduce an artificial turn to trigger update_chat_ctx()
+            # TODO: this is messy-- investigate if there is a better way to handle this
+            curr_gen = self._current_generation.messages[self._current_generation.response_id]
+            curr_gen.audio_ch.close()
+            curr_gen.text_ch.close()
+            self._current_generation.message_ch.close()
+            self._current_generation.message_ch = utils.aio.Chan()
+            self._current_generation.function_ch.close()
+            self._current_generation.function_ch = utils.aio.Chan()
+            msg_gen = _MessageGeneration(
+                message_id=self._current_generation.response_id,
+                text_ch=utils.aio.Chan(),
+                audio_ch=utils.aio.Chan(),
+            )
+            self._current_generation.messages[self._current_generation.response_id] = msg_gen
+            self._current_generation.message_ch.send_nowait(
+                llm.MessageGeneration(
+                    message_id=msg_gen.message_id,
+                    text_stream=msg_gen.text_ch,
+                    audio_stream=msg_gen.audio_ch,
+                )
             )
+            self.emit_generation_event()
 
     async def _handle_tool_output_content_end_event(self, event_data: dict) -> None:
         log_event_data(event_data)
@@ -803,6 +823,12 @@ class RealtimeSession( # noqa: F811
             curr_gen.audio_ch.close()
         if not curr_gen.text_ch.closed:
             curr_gen.text_ch.close()
+        if self._current_generation.response_id in self._current_generation.tool_messages:
+            curr_gen = self._current_generation.tool_messages[
+                self._current_generation.response_id
+            ]
+            if not curr_gen.function_ch.closed:
+                curr_gen.function_ch.close()
 
         if not self._current_generation.message_ch.closed:
             self._current_generation.message_ch.close()
@@ -855,10 +881,11 @@ class RealtimeSession( # noqa: F811
         self.emit("metrics_collected", metrics)
 
     @utils.log_exceptions(logger=logger)
-    async def _process_responses(self):
+    async def _process_responses(self) -> None:
         """Background task that drains Bedrock's output stream and feeds the event handlers."""
         try:
             await self._is_sess_active.wait()
+            assert self._stream_response is not None, "stream_response is None"
 
             # note: may need another signal here to block input task until bedrock is ready
             # TODO: save this as a field so we're not re-awaiting it every time
@@ -892,7 +919,6 @@ class RealtimeSession( # noqa: F811
 
             else:
                 logger.error(f"Validation error: {ve}")
-                request_id = ve.split(" ")[0].split("=")[1]
                 self.emit(
                     "error",
                     llm.RealtimeModelError(
@@ -901,7 +927,7 @@ class RealtimeSession( # noqa: F811
                         error=APIStatusError(
                             message=ve.message,
                             status_code=400,
-                            request_id=
+                            request_id="",
                             body=ve,
                             retryable=False,
                         ),
@@ -940,7 +966,7 @@ class RealtimeSession( # noqa: F811
                     timestamp=time.monotonic(),
                     label=self._realtime_model._label,
                     error=APIStatusError(
-                        message=
+                        message=err_msg,
                         status_code=500,
                         request_id=request_id,
                         body=e,
@@ -1014,6 +1040,25 @@ class RealtimeSession( # noqa: F811
         logger.debug(f"Chat context updated: {self._chat_ctx.items}")
         self._chat_ctx_ready.set_result(True)
 
+        # for each function tool, send the result to aws
+        for item in chat_ctx.items:
+            if item.type != "function_call_output":
+                continue
+
+            if item.call_id not in self._pending_tools:
+                continue
+
+            logger.debug(f"function call output: {item}")
+            self._pending_tools.discard(item.call_id)
+            self._tool_results_ch.send_nowait(
+                {
+                    "tool_use_id": item.call_id,
+                    "tool_result": item.output
+                    if not item.is_error
+                    else f"{{'error': '{item.error}'}}",
+                }
+            )
+
     async def _send_tool_events(self, tool_use_id: str, tool_result: str) -> None:
         """Send tool_result back to Bedrock, grouped under tool_use_id."""
         tool_content_name = str(uuid.uuid4())
@@ -1026,7 +1071,9 @@ class RealtimeSession( # noqa: F811
         await self._send_raw_event(event)
         # logger.debug(f"Sent tool event: {event}")
 
-    def _tool_choice_adapter(
+    def _tool_choice_adapter(
+        self, tool_choice: llm.ToolChoice | None
+    ) -> dict[str, dict[str, str]] | None:
         """Translate the LiveKit ToolChoice enum into Sonic's JSON schema."""
         if tool_choice == "auto":
             return {"auto": {}}
@@ -1079,7 +1126,7 @@ class RealtimeSession( # noqa: F811
             yield frame
 
     @utils.log_exceptions(logger=logger)
-    async def _process_audio_input(self):
+    async def _process_audio_input(self) -> None:
         """Background task that feeds audio and tool results into the Bedrock stream."""
         await self._send_raw_event(self._event_builder.create_audio_content_start_event())
         logger.info("Starting audio input processing loop")
@@ -1090,6 +1137,19 @@ class RealtimeSession( # noqa: F811
                     val = self._tool_results_ch.recv_nowait()
                     tool_result = val["tool_result"]
                     tool_use_id = val["tool_use_id"]
+                    if not isinstance(tool_result, str):
+                        tool_result = json.dumps(tool_result)
+                    else:
+                        try:
+                            json.loads(tool_result)
+                        except json.JSONDecodeError:
+                            try:
+                                tool_result = json.dumps(ast.literal_eval(tool_result))
+                            except Exception:
+                                # return the original value
+                                pass
+
+                    logger.debug(f"Sending tool result: {tool_result}")
                     await self._send_tool_events(tool_use_id, tool_result)
 
                 except utils.aio.channel.ChanEmpty:
```
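Tool results queued on `_tool_results_ch` are now normalized to valid JSON before being forwarded to Bedrock: non-strings are serialized, and strings that are Python reprs rather than JSON (single-quoted dicts, for example) are repaired via `ast.literal_eval`, while free-form text passes through untouched. The same normalization as a stand-alone function:

```python
# Stand-alone version of the tool-result normalization shown in the hunk above.
import ast
import json


def normalize_tool_result(tool_result):
    if not isinstance(tool_result, str):
        return json.dumps(tool_result)
    try:
        json.loads(tool_result)          # already valid JSON: pass through
        return tool_result
    except json.JSONDecodeError:
        try:
            # e.g. "{'temp': 21}" (Python repr) -> '{"temp": 21}'
            return json.dumps(ast.literal_eval(tool_result))
        except Exception:
            return tool_result           # leave free-form text untouched


print(normalize_tool_result({"temp": 21}))
print(normalize_tool_result("{'temp': 21}"))
print(normalize_tool_result("sunny with light wind"))
```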
```diff
@@ -1152,6 +1212,11 @@ class RealtimeSession( # noqa: F811
         instructions: NotGivenOr[str] = NOT_GIVEN,
     ) -> asyncio.Future[llm.GenerationCreatedEvent]:
         logger.warning("unprompted generation is not supported by Nova Sonic's Realtime API")
+        fut = asyncio.Future[llm.GenerationCreatedEvent]()
+        fut.set_exception(
+            llm.RealtimeError("unprompted generation is not supported by Nova Sonic's Realtime API")
+        )
+        return fut
 
     def commit_audio(self) -> None:
         logger.warning("commit_audio is not supported by Nova Sonic's Realtime API")
```
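`generate_reply` now returns a future that already carries a `RealtimeError`, so callers awaiting the result see the failure immediately instead of just a log warning. A minimal sketch of the pattern, using a stand-in error type rather than the LiveKit one:

```python
# Returning an already-failed Future so awaiting callers get a clean exception.
import asyncio


class RealtimeError(Exception):
    """Stand-in for livekit.agents.llm.RealtimeError."""


def generate_reply() -> "asyncio.Future[str]":
    fut: asyncio.Future[str] = asyncio.get_event_loop().create_future()
    fut.set_exception(RealtimeError("unprompted generation is not supported"))
    return fut


async def main() -> None:
    try:
        await generate_reply()
    except RealtimeError as e:
        print(f"caught: {e}")


asyncio.run(main())
```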
```diff
@@ -1190,19 +1255,22 @@ class RealtimeSession( # noqa: F811
         # resulting in an error after cancellation
         # however, it's mostly cosmetic-- the event loop will still exit
         # TODO: fix this nit
+        tasks: list[asyncio.Task[Any]] = []
         if self._response_task:
             try:
                 await asyncio.wait_for(self._response_task, timeout=1.0)
             except asyncio.TimeoutError:
                 logger.warning("shutdown of output event loop timed out-- cancelling")
                 self._response_task.cancel()
+                tasks.append(self._response_task)
 
         # must cancel the audio input task before closing the input stream
         if self._audio_input_task and not self._audio_input_task.done():
             self._audio_input_task.cancel()
+            tasks.append(self._audio_input_task)
         if self._stream_response and not self._stream_response.input_stream.closed:
             await self._stream_response.input_stream.close()
 
-        await asyncio.gather(
+        await asyncio.gather(*tasks, return_exceptions=True)
         logger.debug(f"CHAT CONTEXT: {self._chat_ctx.items}")
         logger.info("Session end")
```
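Shutdown now collects the cancelled background tasks and awaits them with `return_exceptions=True`, so the pending `CancelledError`s are absorbed rather than re-raised into the close path. A minimal illustration of that behavior:

```python
# Draining cancelled background tasks without re-raising their CancelledError.
import asyncio


async def worker(name: str) -> None:
    try:
        await asyncio.sleep(3600)
    finally:
        print(f"{name} shutting down")


async def main() -> None:
    tasks = [asyncio.create_task(worker("response")), asyncio.create_task(worker("audio_input"))]
    await asyncio.sleep(0.1)
    for t in tasks:
        t.cancel()
    # return_exceptions=True turns each CancelledError into a result instead of raising it
    results = await asyncio.gather(*tasks, return_exceptions=True)
    print(results)
    print("session end")


asyncio.run(main())
```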
livekit/plugins/aws/experimental/realtime/turn_tracker.py CHANGED

```diff
@@ -6,7 +6,7 @@ import uuid
 from dataclasses import dataclass, field
 from typing import Any, Callable
 
-from livekit.agents import llm
+from livekit.agents import llm
 
 from ...log import logger
 
@@ -34,7 +34,7 @@ class _Turn:
     ev_trans_completed: bool = False
     ev_generation_sent: bool = False
 
-    def add_partial_text(self, text: str):
+    def add_partial_text(self, text: str) -> None:
         self.transcript.append(text)
 
     @property
@@ -46,19 +46,17 @@ class _TurnTracker:
     def __init__(
         self,
         emit_fn: Callable[[str, Any], None],
-
-            [], tuple[utils.aio.Chan[llm.MessageGeneration], utils.aio.Chan[llm.FunctionCall]]
-        ],
+        emit_generation_fn: Callable[[], None],
     ):
         self._emit = emit_fn
         self._turn_idx = 0
         self._curr_turn: _Turn | None = None
-        self.
+        self._emit_generation_fn = emit_generation_fn
 
     # --------------------------------------------------------
     # PUBLIC ENTRY POINT
     # --------------------------------------------------------
-    def feed(self, event: dict):
+    def feed(self, event: dict) -> None:
         turn = self._ensure_turn()
         kind = _classify(event)
 
@@ -97,13 +95,13 @@ class _TurnTracker:
             self._curr_turn = _Turn(turn_id=self._turn_idx)
         return self._curr_turn
 
-    def _maybe_emit_input_started(self, turn: _Turn):
+    def _maybe_emit_input_started(self, turn: _Turn) -> None:
         if not turn.ev_input_started:
             turn.ev_input_started = True
             self._emit("input_speech_started", llm.InputSpeechStartedEvent())
             turn.phase = _Phase.USER_SPEAKING
 
-    def _maybe_emit_input_stopped(self, turn: _Turn):
+    def _maybe_emit_input_stopped(self, turn: _Turn) -> None:
         if not turn.ev_input_stopped:
             turn.ev_input_stopped = True
             self._emit(
@@ -111,7 +109,7 @@ class _TurnTracker:
             )
             turn.phase = _Phase.USER_FINISHED
 
-    def _emit_transcript_updated(self, turn: _Turn):
+    def _emit_transcript_updated(self, turn: _Turn) -> None:
         self._emit(
             "input_audio_transcription_completed",
             llm.InputTranscriptionCompleted(
@@ -121,7 +119,7 @@ class _TurnTracker:
             ),
         )
 
-    def _maybe_emit_transcript_completed(self, turn: _Turn):
+    def _maybe_emit_transcript_completed(self, turn: _Turn) -> None:
         if not turn.ev_trans_completed:
             turn.ev_trans_completed = True
             self._emit(
@@ -134,17 +132,10 @@ class _TurnTracker:
             ),
         )
 
-    def _maybe_emit_generation_created(self, turn: _Turn):
+    def _maybe_emit_generation_created(self, turn: _Turn) -> None:
         if not turn.ev_generation_sent:
             turn.ev_generation_sent = True
-
-            logger.debug("Emitting generation event")
-            generation_ev = llm.GenerationCreatedEvent(
-                message_stream=msg_stream,
-                function_stream=fn_stream,
-                user_initiated=False,
-            )
-            self._emit("generation_created", generation_ev)
+            self._emit_generation_fn()
             turn.phase = _Phase.ASSISTANT_RESPONDING
 
 
```
livekit/plugins/aws/version.py CHANGED

{livekit_plugins_aws-1.1.5.dist-info → livekit_plugins_aws-1.1.6.dist-info}/METADATA CHANGED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: livekit-plugins-aws
-Version: 1.1.5
+Version: 1.1.6
 Summary: LiveKit Agents Plugin for services from AWS
 Project-URL: Documentation, https://docs.livekit.io
 Project-URL: Website, https://livekit.io/
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9.0
 Requires-Dist: aioboto3>=14.1.0
 Requires-Dist: amazon-transcribe>=0.6.2
-Requires-Dist: livekit-agents>=1.1.
+Requires-Dist: livekit-agents>=1.1.6
 Provides-Extra: realtime
 Requires-Dist: aws-sdk-bedrock-runtime==0.0.2; (python_version >= '3.12') and extra == 'realtime'
 Requires-Dist: boto3>1.35.10; extra == 'realtime'
@@ -44,3 +44,7 @@ pip install livekit-plugins-aws[realtime]
 ## Pre-requisites
 
 You'll need to specify an AWS Access Key and a Deployment Region. They can be set as environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_DEFAULT_REGION`, respectively.
+
+## Example
+
+For an example of the realtime STS model, Nova Sonic, see: https://github.com/livekit/agents/blob/main/examples/voice_agents/realtime_joke_teller.py
```
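The linked joke-teller example is the canonical reference for the new README section; as a quick orientation, a minimal construction sketch is below. The import path and constructor arguments are assumptions inferred from this diff, not copied from the example:

```python
# Minimal sketch only: import path and arguments are assumptions based on this diff;
# see the linked realtime_joke_teller.py example for real, working usage.
from livekit.agents import AgentSession
from livekit.plugins.aws.experimental.realtime import RealtimeModel

session = AgentSession(
    llm=RealtimeModel(voice="tiffany", temperature=0.7),  # Nova Sonic speech-to-speech
)
```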
{livekit_plugins_aws-1.1.5.dist-info → livekit_plugins_aws-1.1.6.dist-info}/RECORD CHANGED

```diff
@@ -6,12 +6,12 @@ livekit/plugins/aws/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 livekit/plugins/aws/stt.py,sha256=PSR89aN28wm4i83yEdhkDJ9xzM0CsNIKrc3v3EbPndQ,9018
 livekit/plugins/aws/tts.py,sha256=T5dVpTuIuzQimYNnkfXi5dRLmRldWySL4IcbkXjmJLM,6083
 livekit/plugins/aws/utils.py,sha256=nA5Ua1f4T-25Loar6EvlrKTXI9N-zpTIH7cdQkwGyGI,1518
-livekit/plugins/aws/version.py,sha256
+livekit/plugins/aws/version.py,sha256=-bNd31cMcYCdhZCIKJ1-jtY4NgZvppVgKyzXAIzQtqM,600
 livekit/plugins/aws/experimental/realtime/__init__.py,sha256=mm_TGZc9QAWSO-VOO3PdE8Y5R6xlWckXRZuiFUIHa-Q,287
-livekit/plugins/aws/experimental/realtime/events.py,sha256
+livekit/plugins/aws/experimental/realtime/events.py,sha256=-pJrwVrH5AZFxa1eDbX5nDdnJMz4BNucNZlYUYLsP-Y,15853
 livekit/plugins/aws/experimental/realtime/pretty_printer.py,sha256=KN7KPrfQu8cU7ff34vFAtfrd1umUSTVNKXQU7D8AMiM,1442
-livekit/plugins/aws/experimental/realtime/realtime_model.py,sha256=
-livekit/plugins/aws/experimental/realtime/turn_tracker.py,sha256=
-livekit_plugins_aws-1.1.
-livekit_plugins_aws-1.1.
-livekit_plugins_aws-1.1.
+livekit/plugins/aws/experimental/realtime/realtime_model.py,sha256=em_3Fbp1qefF7cIIHc6ib1FLdD1MOGes2Lwq61o2wlk,59464
+livekit/plugins/aws/experimental/realtime/turn_tracker.py,sha256=bcufaap-coeIYuK3ct1Is9W_UoefGYRmnJu7Mn5DCYU,6002
+livekit_plugins_aws-1.1.6.dist-info/METADATA,sha256=ST8uYsoqQgHUVRCLC3BdkYdwALh3joYGRjblVKQgDrE,1989
+livekit_plugins_aws-1.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+livekit_plugins_aws-1.1.6.dist-info/RECORD,,
```

{livekit_plugins_aws-1.1.5.dist-info → livekit_plugins_aws-1.1.6.dist-info}/WHEEL — file without changes