livekit-plugins-aws 1.1.5__tar.gz → 1.1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of livekit-plugins-aws has been flagged as potentially problematic.
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/PKG-INFO +6 -2
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/README.md +4 -0
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/experimental/realtime/events.py +20 -21
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/experimental/realtime/realtime_model.py +147 -76
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/experimental/realtime/turn_tracker.py +11 -20
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/llm.py +3 -1
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/version.py +1 -1
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/pyproject.toml +1 -1
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/.gitignore +0 -0
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/__init__.py +0 -0
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/experimental/realtime/__init__.py +0 -0
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/experimental/realtime/pretty_printer.py +0 -0
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/log.py +0 -0
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/models.py +0 -0
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/py.typed +0 -0
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/stt.py +0 -0
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/tts.py +0 -0
- {livekit_plugins_aws-1.1.5 → livekit_plugins_aws-1.1.7}/livekit/plugins/aws/utils.py +0 -0
```diff
--- livekit_plugins_aws-1.1.5/PKG-INFO
+++ livekit_plugins_aws-1.1.7/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: livekit-plugins-aws
-Version: 1.1.5
+Version: 1.1.7
 Summary: LiveKit Agents Plugin for services from AWS
 Project-URL: Documentation, https://docs.livekit.io
 Project-URL: Website, https://livekit.io/
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9.0
 Requires-Dist: aioboto3>=14.1.0
 Requires-Dist: amazon-transcribe>=0.6.2
-Requires-Dist: livekit-agents>=1.1.5
+Requires-Dist: livekit-agents>=1.1.7
 Provides-Extra: realtime
 Requires-Dist: aws-sdk-bedrock-runtime==0.0.2; (python_version >= '3.12') and extra == 'realtime'
 Requires-Dist: boto3>1.35.10; extra == 'realtime'
@@ -44,3 +44,7 @@ pip install livekit-plugins-aws[realtime]
 ## Pre-requisites
 
 You'll need to specify an AWS Access Key and a Deployment Region. They can be set as environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_DEFAULT_REGION`, respectively.
+
+## Example
+
+For an example of the realtime STS model, Nova Sonic, see: https://github.com/livekit/agents/blob/main/examples/voice_agents/realtime_joke_teller.py
```
```diff
--- livekit_plugins_aws-1.1.5/README.md
+++ livekit_plugins_aws-1.1.7/README.md
@@ -16,3 +16,7 @@ pip install livekit-plugins-aws[realtime]
 ## Pre-requisites
 
 You'll need to specify an AWS Access Key and a Deployment Region. They can be set as environment variables: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and `AWS_DEFAULT_REGION`, respectively.
+
+## Example
+
+For an example of the realtime STS model, Nova Sonic, see: https://github.com/livekit/agents/blob/main/examples/voice_agents/realtime_joke_teller.py
```
```diff
--- livekit_plugins_aws-1.1.5/livekit/plugins/aws/experimental/realtime/events.py
+++ livekit_plugins_aws-1.1.7/livekit/plugins/aws/experimental/realtime/events.py
@@ -1,8 +1,8 @@
 import json
 import uuid
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union, cast
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel as _BaseModel, ConfigDict, Field
 
 from livekit.agents import llm
 
@@ -20,7 +20,7 @@ SAMPLE_SIZE_BITS = Literal[16]  # only supports 16-bit audio
 CHANNEL_COUNT = Literal[1]  # only supports monochannel audio
 
 
-class BaseModel(BaseModel):
+class BaseModel(_BaseModel):
     model_config = ConfigDict(populate_by_name=True, extra="forbid")
 
 
@@ -91,7 +91,7 @@ class Tool(BaseModel):
 
 
 class ToolConfiguration(BaseModel):
-    toolChoice: dict[str, dict[str, str]]
+    toolChoice: Optional[dict[str, dict[str, str]]] = None
    tools: list[Tool]
 
 
@@ -260,6 +260,8 @@ class SonicEventBuilder:
         else:
             return "other_event"
 
+        raise ValueError(f"Unknown event type: {json_data}")
+
     def create_text_content_block(
         self,
         content_name: str,
@@ -313,10 +315,18 @@ class SonicEventBuilder:
         if chat_ctx.items:
             logger.debug("initiating session with chat context")
             for item in chat_ctx.items:
+                if item.type != "message":
+                    continue
+
+                if (role := item.role.upper()) not in ["USER", "ASSISTANT", "SYSTEM"]:
+                    continue
+
                 ctx_content_name = str(uuid.uuid4())
                 init_events.extend(
                     self.create_text_content_block(
-                        ctx_content_name,
+                        ctx_content_name,
+                        cast(ROLE, role),
+                        "".join(c for c in item.content if isinstance(c, str)),
                     )
                 )
 
@@ -481,26 +491,15 @@ class SonicEventBuilder:
         sample_rate: SAMPLE_RATE_HERTZ,
         tool_configuration: Optional[Union[ToolConfiguration, dict[str, Any], str]] = None,
     ) -> str:
-        tool_configuration = tool_configuration or ToolConfiguration(tools=[])
-        for tool in tool_configuration.tools:
-            logger.debug(f"TOOL JSON SCHEMA: {tool.toolSpec.inputSchema}")
-        tool_objects = [
-            Tool(
-                toolSpec=ToolSpec(
-                    name=tool.toolSpec.name,
-                    description=tool.toolSpec.description,
-                    inputSchema=ToolInputSchema(json_=tool.toolSpec.inputSchema.json_),
-                )
-            )
-            for tool in tool_configuration.tools
-        ]
-
         if tool_configuration is None:
             tool_configuration = ToolConfiguration(tools=[])
         elif isinstance(tool_configuration, str):
-            tool_configuration = ToolConfiguration
+            tool_configuration = ToolConfiguration.model_validate_json(tool_configuration)
         elif isinstance(tool_configuration, dict):
-            tool_configuration = ToolConfiguration(
+            tool_configuration = ToolConfiguration.model_validate(tool_configuration)
+
+        for tool in tool_configuration.tools:
+            logger.debug(f"TOOL JSON SCHEMA: {tool.toolSpec.inputSchema}")
 
         tool_objects = list(tool_configuration.tools)
         event = Event(
```
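The events.py changes above lean on pydantic v2's built-in validators instead of hand-rolling `Tool` objects: `model_validate_json` handles the string form, `model_validate` the dict form, and `toolChoice` becomes optional. A minimal self-contained sketch of the pattern (the models here are simplified stand-ins, not the plugin's full definitions):

```python
from typing import Optional

from pydantic import BaseModel


class ToolSpec(BaseModel):
    name: str
    description: str


class ToolConfiguration(BaseModel):
    # Optional with a None default, mirroring the 1.1.7 change
    toolChoice: Optional[dict[str, dict[str, str]]] = None
    tools: list[ToolSpec]


raw = '{"tools": [{"name": "get_weather", "description": "Look up the weather"}]}'

# model_validate_json parses and validates a JSON string in one step
cfg = ToolConfiguration.model_validate_json(raw)

# model_validate does the same for an already-parsed dict
cfg2 = ToolConfiguration.model_validate({"tools": [], "toolChoice": {"auto": {}}})

assert cfg.toolChoice is None
assert cfg2.toolChoice == {"auto": {}}
```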
```diff
--- livekit_plugins_aws-1.1.5/livekit/plugins/aws/experimental/realtime/realtime_model.py
+++ livekit_plugins_aws-1.1.7/livekit/plugins/aws/experimental/realtime/realtime_model.py
@@ -1,5 +1,8 @@
+# mypy: disable-error-code=unused-ignore
+
 from __future__ import annotations
 
+import ast
 import asyncio
 import base64
 import json
@@ -10,7 +13,7 @@ import weakref
 from collections.abc import Iterator
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal
+from typing import Any, Callable, Literal, cast
 
 import boto3
 from aws_sdk_bedrock_runtime.client import (
@@ -33,11 +36,9 @@ from smithy_core.aio.interfaces.identity import IdentityResolver
 from livekit import rtc
 from livekit.agents import (
     APIStatusError,
-    ToolError,
     llm,
     utils,
 )
-from livekit.agents.llm.realtime import RealtimeSession
 from livekit.agents.metrics import RealtimeModelMetrics
 from livekit.agents.types import NOT_GIVEN, NotGivenOr
 from livekit.agents.utils import is_given
@@ -150,12 +151,12 @@ class _ResponseGeneration:
     speculative_messages: dict[str, str] = field(default_factory=dict)
     tool_messages: dict[str, str] = field(default_factory=dict)
     output_text: str = ""  # agent ASR text
-    _created_timestamp: str = field(default_factory=datetime.now().isoformat())
+    _created_timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
     _first_token_timestamp: float | None = None
     _completed_timestamp: float | None = None
 
 
-class Boto3CredentialsResolver(IdentityResolver):
+class Boto3CredentialsResolver(IdentityResolver):  # type: ignore[misc]
     """IdentityResolver implementation that sources AWS credentials from boto3.
 
     The resolver delegates to the default boto3.Session() credential chain which
```
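The `_created_timestamp` fix in this hunk corrects a classic `dataclasses` mistake: `default_factory` must be a callable, but the old code passed the already-evaluated result of `datetime.now().isoformat()` (a plain `str`). A small reproduction of the failure and the fix:

```python
from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class Broken:
    # datetime.now().isoformat() is evaluated once at class definition,
    # yielding a str; default_factory then tries to *call* that str
    created: str = field(default_factory=datetime.now().isoformat())


@dataclass
class Fixed:
    # the lambda defers evaluation, so every instance gets a fresh timestamp
    created: str = field(default_factory=lambda: datetime.now().isoformat())


Fixed()        # fine: current timestamp at instantiation time
try:
    Broken()   # TypeError: 'str' object is not callable
except TypeError as e:
    print(e)
```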
```diff
@@ -164,10 +165,10 @@ class Boto3CredentialsResolver(IdentityResolver):
     passed into Bedrock runtime clients.
     """
 
-    def __init__(self):
-        self.session = boto3.Session()
+    def __init__(self) -> None:
+        self.session = boto3.Session()  # type: ignore[attr-defined]
 
-    async def get_identity(self, **kwargs):
+    async def get_identity(self, **kwargs: Any) -> AWSCredentialsIdentity:
         """Asynchronously resolve AWS credentials.
 
         This method is invoked by the Bedrock runtime client whenever a new request needs to be
```
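`get_identity` now declares an `AWSCredentialsIdentity` return type. For orientation, this is roughly how such a resolver bridges boto3's default credential chain to an identity object; only the boto3 calls below are standard API, and the field names the real identity type expects are an assumption, not confirmed by this diff:

```python
import boto3


async def resolve_credentials() -> dict[str, str | None]:
    """Sketch: walk boto3's default chain (env vars, shared config,
    instance metadata, ...) the way Boto3CredentialsResolver does."""
    session = boto3.Session()
    creds = session.get_credentials()
    if creds is None:
        raise ValueError("no AWS credentials found")
    frozen = creds.get_frozen_credentials()  # immutable snapshot of the creds
    # the real resolver would wrap these values in AWSCredentialsIdentity(...)
    # before handing them to the Bedrock runtime client
    return {
        "access_key_id": frozen.access_key,
        "secret_access_key": frozen.secret_key,
        "session_token": frozen.token,
    }
```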
```diff
@@ -247,7 +248,7 @@ class RealtimeModel(llm.RealtimeModel):
         self.temperature = temperature
         self.top_p = top_p
         self._opts = _RealtimeOptions(
-            voice=voice if is_given(voice) else "tiffany",
+            voice=cast(VOICE_ID, voice) if is_given(voice) else "tiffany",
             temperature=temperature if is_given(temperature) else DEFAULT_TEMPERATURE,
             top_p=top_p if is_given(top_p) else DEFAULT_TOP_P,
             max_tokens=max_tokens if is_given(max_tokens) else DEFAULT_MAX_TOKENS,
@@ -295,7 +296,7 @@ class RealtimeSession(  # noqa: F811
         inference options and the Smithy Bedrock client configuration.
         """
         super().__init__(realtime_model)
-        self._realtime_model = realtime_model
+        self._realtime_model: RealtimeModel = realtime_model
         self._event_builder = seb(
             prompt_name=str(uuid.uuid4()),
             audio_content_name=str(uuid.uuid4()),
@@ -309,10 +310,10 @@ class RealtimeSession(  # noqa: F811
         self._audio_input_task = None
         self._stream_response = None
         self._bedrock_client = None
+        self._pending_tools: set[str] = set()
         self._is_sess_active = asyncio.Event()
         self._chat_ctx = llm.ChatContext.empty()
         self._tools = llm.ToolContext.empty()
-        self._tool_type_map = {}
         self._tool_results_ch = utils.aio.Chan[dict[str, str]]()
         self._tools_ready = asyncio.get_running_loop().create_future()
         self._instructions_ready = asyncio.get_running_loop().create_future()
@@ -341,16 +342,12 @@ class RealtimeSession(  # noqa: F811
             "other_event": self._handle_other_event,
         }
         self._turn_tracker = _TurnTracker(
-            self.emit,
+            cast(Callable[[str, Any], None], self.emit),
+            cast(Callable[[], None], self.emit_generation_event),
         )
 
-    def _current_generation_streams(
-        self,
-    ) -> tuple[utils.aio.Chan[llm.MessageGeneration], utils.aio.Chan[llm.FunctionCall]]:
-        return (self._current_generation.message_ch, self._current_generation.function_ch)
-
     @utils.log_exceptions(logger=logger)
-    def _initialize_client(self):
+    def _initialize_client(self) -> None:
         """Instantiate the Bedrock runtime client"""
         config = Config(
             endpoint_uri=f"https://bedrock-runtime.{self._realtime_model._opts.region}.amazonaws.com",
@@ -358,15 +355,16 @@ class RealtimeSession(  # noqa: F811
             aws_credentials_identity_resolver=Boto3CredentialsResolver(),
             http_auth_scheme_resolver=HTTPAuthSchemeResolver(),
             http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()},
+            user_agent_extra="x-client-framework:livekit-plugins-aws[realtime]",
         )
         self._bedrock_client = BedrockRuntimeClient(config=config)
 
     @utils.log_exceptions(logger=logger)
-    async def _send_raw_event(self, event_json):
+    async def _send_raw_event(self, event_json: str) -> None:
         """Low-level helper that serialises event_json and forwards it to the bidirectional stream.
 
         Args:
-            event_json (
+            event_json (str): The JSON payload (already in Bedrock wire format) to queue.
 
         Raises:
             Exception: Propagates any failures returned by the Bedrock runtime client.
@@ -425,21 +423,21 @@ class RealtimeSession(  # noqa: F811
                input_schema = llm.utils.build_legacy_openai_schema(f, internally_tagged=True)[
                    "parameters"
                ]
-
-            else:
+            elif llm.tool_context.is_raw_function_tool(f):
                description = llm.tool_context.get_raw_function_info(f).raw_schema.get(
                    "description"
                )
                input_schema = llm.tool_context.get_raw_function_info(f).raw_schema[
                    "parameters"
                ]
-
+            else:
+                continue
 
            tool = Tool(
                toolSpec=ToolSpec(
                    name=name,
-                    description=description,
-                    inputSchema=ToolInputSchema(json_=json.dumps(input_schema)),
+                    description=description or "No description provided",
+                    inputSchema=ToolInputSchema(json_=json.dumps(input_schema)),  # type: ignore
                )
            )
            tools.append(tool)
@@ -455,7 +453,7 @@ class RealtimeSession(  # noqa: F811
         return tool_cfg
 
     @utils.log_exceptions(logger=logger)
-    async def initialize_streams(self, is_restart: bool = False):
+    async def initialize_streams(self, is_restart: bool = False) -> None:
         """Open the Bedrock bidirectional stream and spawn background worker tasks.
 
         This coroutine is idempotent and can be invoked again when recoverable
@@ -469,6 +467,7 @@ class RealtimeSession(  # noqa: F811
         if not self._bedrock_client:
             logger.info("Creating Bedrock client")
             self._initialize_client()
+        assert self._bedrock_client is not None, "bedrock_client is None"
 
         logger.info("Initializing Bedrock stream")
         self._stream_response = (
@@ -518,7 +517,7 @@ class RealtimeSession(  # noqa: F811
         self._chat_ctx.truncate(max_items=MAX_MESSAGES)
         init_events = self._event_builder.create_prompt_start_block(
             voice_id=self._realtime_model._opts.voice,
-            sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,
+            sample_rate=DEFAULT_OUTPUT_SAMPLE_RATE,  # type: ignore
             system_content=self._instructions,
             chat_ctx=self.chat_ctx,
             tool_configuration=self._serialize_tool_config(),
@@ -542,15 +541,16 @@ class RealtimeSession(  # noqa: F811
             self._is_sess_active.set()
             logger.debug("Stream initialized successfully")
         except Exception as e:
-            self._is_sess_active.set_exception(e)
             logger.debug(f"Failed to initialize stream: {str(e)}")
             raise
         return self
 
     @utils.log_exceptions(logger=logger)
-    def
+    def emit_generation_event(self) -> None:
         """Publish a llm.GenerationCreatedEvent to external subscribers."""
         logger.debug("Emitting generation event")
+        assert self._current_generation is not None, "current_generation is None"
+
         generation_ev = llm.GenerationCreatedEvent(
             message_stream=self._current_generation.message_ch,
             function_stream=self._current_generation.function_ch,
@@ -605,10 +605,12 @@ class RealtimeSession(  # noqa: F811
         """Handle text_output_content_start for both user and assistant roles."""
         log_event_data(event_data)
         role = event_data["event"]["contentStart"]["role"]
+        self._create_response_generation()
 
         # note: does not work if you emit llm.GCE too early (for some reason)
         if role == "USER":
-            self.
+            assert self._current_generation is not None, "current_generation is None"
+
             content_id = event_data["event"]["contentStart"]["contentId"]
             self._current_generation.user_messages[content_id] = self._current_generation.input_id
 
@@ -616,6 +618,8 @@ class RealtimeSession(  # noqa: F811
             role == "ASSISTANT"
             and "SPECULATIVE" in event_data["event"]["contentStart"]["additionalModelFields"]
         ):
+            assert self._current_generation is not None, "current_generation is None"
+
             text_content_id = event_data["event"]["contentStart"]["contentId"]
             self._current_generation.speculative_messages[text_content_id] = (
                 self._current_generation.response_id
@@ -633,10 +637,15 @@ class RealtimeSession(  # noqa: F811
             # this is b/c audio playback is desynced from text transcription
             # TODO: fix this; possibly via a playback timer
             idx = self._chat_ctx.find_insertion_index(created_at=time.time()) - 1
+            if idx < 0:
+                logger.warning("Barge-in DETECTED but no previous message found")
+                return
+
             logger.debug(
                 f"BARGE-IN DETECTED using idx: {idx} and chat_msg: {self._chat_ctx.items[idx]}"
             )
-            self._chat_ctx.items[idx].
+            if (item := self._chat_ctx.items[idx]).type == "message":
+                item.interrupted = True
             self._close_current_generation()
             return
 
@@ -661,27 +670,31 @@ class RealtimeSession(  # noqa: F811
         # note: this update is per utterance, not per turn
         self._update_chat_ctx(role="assistant", text_content=text_content)
 
-    def _update_chat_ctx(self, role:
+    def _update_chat_ctx(self, role: llm.ChatRole, text_content: str) -> None:
         """
         Update the chat context with the latest ASR text while guarding against model limitations:
         a) 40 total messages limit
         b) 1kB message size limit
         """
-
-        if
-
-
-
-
-
+        logger.debug(f"Updating chat context with role: {role} and text_content: {text_content}")
+        if len(self._chat_ctx.items) == 0:
+            self._chat_ctx.add_message(role=role, content=text_content)
+        else:
+            prev_utterance = self._chat_ctx.items[-1]
+            if prev_utterance.type == "message" and prev_utterance.role == role:
+                if isinstance(prev_content := prev_utterance.content[0], str) and (
+                    len(prev_content.encode("utf-8")) + len(text_content.encode("utf-8"))
+                    < MAX_MESSAGE_SIZE
+                ):
+                    prev_utterance.content[0] = "\n".join([prev_content, text_content])
+                else:
+                    self._chat_ctx.add_message(role=role, content=text_content)
+                    if len(self._chat_ctx.items) > MAX_MESSAGES:
+                        self._chat_ctx.truncate(max_items=MAX_MESSAGES)
             else:
                 self._chat_ctx.add_message(role=role, content=text_content)
                 if len(self._chat_ctx.items) > MAX_MESSAGES:
                     self._chat_ctx.truncate(max_items=MAX_MESSAGES)
-        else:
-            self._chat_ctx.add_message(role=role, content=text_content)
-            if len(self._chat_ctx.items) > MAX_MESSAGES:
-                self._chat_ctx.truncate(max_items=MAX_MESSAGES)
 
     # cannot rely on this event for user b/c stopReason=PARTIAL_TURN always for user
     async def _handle_text_output_content_end_event(self, event_data: dict) -> None:
```
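The rewritten `_update_chat_ctx` above merges consecutive same-role utterances until the combined payload would cross the 1 kB per-message cap, then falls back to appending a new message and truncating to the 40-message window. Note that the size check counts UTF-8 bytes, not characters. The merge rule, distilled into a standalone sketch:

```python
MAX_MESSAGE_SIZE = 1024  # bytes, the per-message limit named in the docstring
MAX_MESSAGES = 40        # the total-message limit


def merge_or_append(history: list[tuple[str, str]], role: str, text: str) -> None:
    """Sketch: extend the previous message when roles match and the combined
    UTF-8 payload stays under the cap, otherwise append a new message."""
    if history and history[-1][0] == role:
        prev = history[-1][1]
        # byte length matters: 'é' is one code point but two UTF-8 bytes
        if len(prev.encode("utf-8")) + len(text.encode("utf-8")) < MAX_MESSAGE_SIZE:
            history[-1] = (role, "\n".join([prev, text]))
            return
    history.append((role, text))
    del history[:-MAX_MESSAGES]  # keep only the newest 40 messages


h: list[tuple[str, str]] = []
merge_or_append(h, "user", "hello")
merge_or_append(h, "user", "are you there?")
assert h == [("user", "hello\nare you there?")]
```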
```diff
@@ -701,6 +714,8 @@ class RealtimeSession(  # noqa: F811
     async def _handle_tool_output_content_start_event(self, event_data: dict) -> None:
         """Track mapping content_id -> response_id for upcoming tool use."""
         log_event_data(event_data)
+        assert self._current_generation is not None, "current_generation is None"
+
         tool_use_content_id = event_data["event"]["contentStart"]["contentId"]
         self._current_generation.tool_messages[tool_use_content_id] = (
             self._current_generation.response_id
@@ -710,6 +725,8 @@ class RealtimeSession(  # noqa: F811
     async def _handle_tool_output_content_event(self, event_data: dict) -> None:
         """Execute the referenced tool locally and forward results back to Bedrock."""
         log_event_data(event_data)
+        assert self._current_generation is not None, "current_generation is None"
+
         tool_use_content_id = event_data["event"]["toolUse"]["contentId"]
         tool_use_id = event_data["event"]["toolUse"]["toolUseId"]
         tool_name = event_data["event"]["toolUse"]["toolName"]
@@ -719,34 +736,38 @@ class RealtimeSession(  # noqa: F811
         ):
             args = event_data["event"]["toolUse"]["content"]
             self._current_generation.function_ch.send_nowait(
-                llm.FunctionCall(
-                    call_id=tool_use_id,
-                    name=tool_name,
-                    arguments=args,
-                )
+                llm.FunctionCall(call_id=tool_use_id, name=tool_name, arguments=args)
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            self._pending_tools.add(tool_use_id)
+
+            # performing these acrobatics in order to release the deadlock
+            # LiveKit will not accept a new generation until the previous one is closed
+            # the issue is that audio data cannot be generated until toolResult is received
+            # however, toolResults only arrive after update_chat_ctx() is invoked
+            # which will only occur after agent speech has completed
+            # therefore we introduce an artificial turn to trigger update_chat_ctx()
+            # TODO: this is messy-- investigate if there is a better way to handle this
+            curr_gen = self._current_generation.messages[self._current_generation.response_id]
+            curr_gen.audio_ch.close()
+            curr_gen.text_ch.close()
+            self._current_generation.message_ch.close()
+            self._current_generation.message_ch = utils.aio.Chan()
+            self._current_generation.function_ch.close()
+            self._current_generation.function_ch = utils.aio.Chan()
+            msg_gen = _MessageGeneration(
+                message_id=self._current_generation.response_id,
+                text_ch=utils.aio.Chan(),
+                audio_ch=utils.aio.Chan(),
+            )
+            self._current_generation.messages[self._current_generation.response_id] = msg_gen
+            self._current_generation.message_ch.send_nowait(
+                llm.MessageGeneration(
+                    message_id=msg_gen.message_id,
+                    text_stream=msg_gen.text_ch,
+                    audio_stream=msg_gen.audio_ch,
+                )
             )
+            self.emit_generation_event()
 
     async def _handle_tool_output_content_end_event(self, event_data: dict) -> None:
         log_event_data(event_data)
```
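The "acrobatics" comment marks the trickiest change in this release: the session closes the live text/audio/message/function channels so LiveKit treats the current generation as finished, then immediately swaps in fresh channels for the continuation that follows the tool result. Reduced to its essentials, the maneuver is a close-and-replace on the stream objects (sketch below uses a toy `Chan` stand-in for livekit's `utils.aio.Chan`):

```python
import asyncio
from typing import Any


class Chan:
    """Minimal stand-in for livekit's utils.aio.Chan as used in the diff."""

    def __init__(self) -> None:
        self._q: asyncio.Queue[Any] = asyncio.Queue()
        self.closed = False

    def send_nowait(self, item: Any) -> None:
        assert not self.closed, "cannot send on a closed channel"
        self._q.put_nowait(item)

    def close(self) -> None:
        self.closed = True


class Generation:
    def __init__(self) -> None:
        self.message_ch = Chan()
        self.function_ch = Chan()


def start_artificial_turn(gen: Generation) -> None:
    # close the current turn's streams so the consumer treats it as done...
    gen.message_ch.close()
    gen.function_ch.close()
    # ...then swap in fresh channels for the post-tool-call continuation
    gen.message_ch = Chan()
    gen.function_ch = Chan()
```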
```diff
@@ -804,6 +825,14 @@ class RealtimeSession(  # noqa: F811
         if not curr_gen.text_ch.closed:
             curr_gen.text_ch.close()
 
+        # TODO: seems not needed, tool_messages[id] is a str, function_ch is closed below?
+        # if self._current_generation.response_id in self._current_generation.tool_messages:
+        #     curr_gen = self._current_generation.tool_messages[
+        #         self._current_generation.response_id
+        #     ]
+        #     if not curr_gen.function_ch.closed:
+        #         curr_gen.function_ch.close()
+
         if not self._current_generation.message_ch.closed:
             self._current_generation.message_ch.close()
         if not self._current_generation.function_ch.closed:
@@ -855,10 +884,11 @@ class RealtimeSession(  # noqa: F811
         self.emit("metrics_collected", metrics)
 
     @utils.log_exceptions(logger=logger)
-    async def _process_responses(self):
+    async def _process_responses(self) -> None:
         """Background task that drains Bedrock's output stream and feeds the event handlers."""
         try:
             await self._is_sess_active.wait()
+            assert self._stream_response is not None, "stream_response is None"
 
             # note: may need another signal here to block input task until bedrock is ready
             # TODO: save this as a field so we're not re-awaiting it every time
@@ -892,7 +922,6 @@ class RealtimeSession(  # noqa: F811
 
             else:
                 logger.error(f"Validation error: {ve}")
-                request_id = ve.split(" ")[0].split("=")[1]
                 self.emit(
                     "error",
                     llm.RealtimeModelError(
@@ -901,7 +930,7 @@ class RealtimeSession(  # noqa: F811
                         error=APIStatusError(
                             message=ve.message,
                             status_code=400,
-                            request_id=
+                            request_id="",
                             body=ve,
                             retryable=False,
                         ),
@@ -940,7 +969,7 @@ class RealtimeSession(  # noqa: F811
                     timestamp=time.monotonic(),
                     label=self._realtime_model._label,
                     error=APIStatusError(
-                        message=
+                        message=err_msg,
                         status_code=500,
                         request_id=request_id,
                         body=e,
@@ -1014,6 +1043,25 @@ class RealtimeSession(  # noqa: F811
         logger.debug(f"Chat context updated: {self._chat_ctx.items}")
         self._chat_ctx_ready.set_result(True)
 
+        # for each function tool, send the result to aws
+        for item in chat_ctx.items:
+            if item.type != "function_call_output":
+                continue
+
+            if item.call_id not in self._pending_tools:
+                continue
+
+            logger.debug(f"function call output: {item}")
+            self._pending_tools.discard(item.call_id)
+            self._tool_results_ch.send_nowait(
+                {
+                    "tool_use_id": item.call_id,
+                    "tool_result": item.output
+                    if not item.is_error
+                    else f"{{'error': '{item.output}'}}",
+                }
+            )
+
     async def _send_tool_events(self, tool_use_id: str, tool_result: str) -> None:
         """Send tool_result back to Bedrock, grouped under tool_use_id."""
         tool_content_name = str(uuid.uuid4())
@@ -1026,7 +1074,9 @@ class RealtimeSession(  # noqa: F811
         await self._send_raw_event(event)
         # logger.debug(f"Sent tool event: {event}")
 
-    def _tool_choice_adapter(
+    def _tool_choice_adapter(
+        self, tool_choice: llm.ToolChoice | None
+    ) -> dict[str, dict[str, str]] | None:
         """Translate the LiveKit ToolChoice enum into Sonic's JSON schema."""
         if tool_choice == "auto":
             return {"auto": {}}
@@ -1079,7 +1129,7 @@ class RealtimeSession(  # noqa: F811
             yield frame
 
     @utils.log_exceptions(logger=logger)
-    async def _process_audio_input(self):
+    async def _process_audio_input(self) -> None:
         """Background task that feeds audio and tool results into the Bedrock stream."""
         await self._send_raw_event(self._event_builder.create_audio_content_start_event())
         logger.info("Starting audio input processing loop")
@@ -1090,6 +1140,19 @@ class RealtimeSession(  # noqa: F811
                 val = self._tool_results_ch.recv_nowait()
                 tool_result = val["tool_result"]
                 tool_use_id = val["tool_use_id"]
+                if not isinstance(tool_result, str):
+                    tool_result = json.dumps(tool_result)
+                else:
+                    try:
+                        json.loads(tool_result)
+                    except json.JSONDecodeError:
+                        try:
+                            tool_result = json.dumps(ast.literal_eval(tool_result))
+                        except Exception:
+                            # return the original value
+                            pass
+
+                logger.debug(f"Sending tool result: {tool_result}")
                 await self._send_tool_events(tool_use_id, tool_result)
 
             except utils.aio.channel.ChanEmpty:
```
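The new block normalizes tool results into valid JSON before they go upstream: non-strings are serialized outright, strings that already parse as JSON pass through, and Python-literal strings (e.g. single-quoted dict reprs) are rescued via `ast.literal_eval`. The same logic as a standalone helper:

```python
import ast
import json
from typing import Any


def normalize_tool_result(tool_result: Any) -> Any:
    if not isinstance(tool_result, str):
        return json.dumps(tool_result)   # e.g. {"a": 1} -> '{"a": 1}'
    try:
        json.loads(tool_result)          # already valid JSON? pass through
        return tool_result
    except json.JSONDecodeError:
        try:
            # "{'a': 1}" is a Python repr, not JSON; literal_eval recovers it
            return json.dumps(ast.literal_eval(tool_result))
        except Exception:
            return tool_result           # give up: send the raw string


assert normalize_tool_result({"a": 1}) == '{"a": 1}'
assert normalize_tool_result('{"a": 1}') == '{"a": 1}'
assert normalize_tool_result("{'a': 1}") == '{"a": 1}'
assert normalize_tool_result("plain text") == "plain text"
```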
```diff
@@ -1152,6 +1215,11 @@ class RealtimeSession(  # noqa: F811
         instructions: NotGivenOr[str] = NOT_GIVEN,
     ) -> asyncio.Future[llm.GenerationCreatedEvent]:
         logger.warning("unprompted generation is not supported by Nova Sonic's Realtime API")
+        fut = asyncio.Future[llm.GenerationCreatedEvent]()
+        fut.set_exception(
+            llm.RealtimeError("unprompted generation is not supported by Nova Sonic's Realtime API")
+        )
+        return fut
 
     def commit_audio(self) -> None:
         logger.warning("commit_audio is not supported by Nova Sonic's Realtime API")
```
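`generate_reply` now returns a future that already carries a `RealtimeError` instead of only logging a warning, so a caller that awaits the result fails immediately rather than hanging on a future that is never resolved. The pattern in isolation (an ordinary `RuntimeError` stands in for `llm.RealtimeError`):

```python
import asyncio


async def main() -> None:
    # a pre-failed future: set_exception before anyone awaits it
    fut: asyncio.Future[str] = asyncio.get_running_loop().create_future()
    fut.set_exception(RuntimeError("unprompted generation is not supported"))

    try:
        await fut
    except RuntimeError as e:
        print(f"caller sees the error immediately: {e}")


asyncio.run(main())
```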
```diff
@@ -1190,19 +1258,22 @@ class RealtimeSession(  # noqa: F811
         # resulting in an error after cancellation
         # however, it's mostly cosmetic-- the event loop will still exit
         # TODO: fix this nit
+        tasks: list[asyncio.Task[Any]] = []
         if self._response_task:
             try:
                 await asyncio.wait_for(self._response_task, timeout=1.0)
             except asyncio.TimeoutError:
                 logger.warning("shutdown of output event loop timed out-- cancelling")
                 self._response_task.cancel()
+                tasks.append(self._response_task)
 
         # must cancel the audio input task before closing the input stream
         if self._audio_input_task and not self._audio_input_task.done():
             self._audio_input_task.cancel()
+            tasks.append(self._audio_input_task)
         if self._stream_response and not self._stream_response.input_stream.closed:
             await self._stream_response.input_stream.close()
 
-        await asyncio.gather(
+        await asyncio.gather(*tasks, return_exceptions=True)
         logger.debug(f"CHAT CONTEXT: {self._chat_ctx.items}")
         logger.info("Session end")
```
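The shutdown path now collects every task it cancels and awaits them with `return_exceptions=True`, which consumes the pending `CancelledError`s instead of letting them surface as "exception was never retrieved" warnings. A condensed illustration:

```python
import asyncio
from typing import Any


async def worker() -> None:
    await asyncio.sleep(3600)  # simulates the long-running response/audio loops


async def shutdown() -> None:
    tasks: list[asyncio.Task[Any]] = [asyncio.create_task(worker()) for _ in range(2)]
    await asyncio.sleep(0)  # let the workers start and suspend
    for t in tasks:
        t.cancel()
    # return_exceptions=True captures the CancelledErrors in the result list
    # instead of re-raising them out of gather()
    results = await asyncio.gather(*tasks, return_exceptions=True)
    assert all(isinstance(r, asyncio.CancelledError) for r in results)


asyncio.run(shutdown())
```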
```diff
--- livekit_plugins_aws-1.1.5/livekit/plugins/aws/experimental/realtime/turn_tracker.py
+++ livekit_plugins_aws-1.1.7/livekit/plugins/aws/experimental/realtime/turn_tracker.py
@@ -6,7 +6,7 @@ import uuid
 from dataclasses import dataclass, field
 from typing import Any, Callable
 
-from livekit.agents import llm, utils
+from livekit.agents import llm
 
 from ...log import logger
 
@@ -34,7 +34,7 @@ class _Turn:
     ev_trans_completed: bool = False
     ev_generation_sent: bool = False
 
-    def add_partial_text(self, text: str):
+    def add_partial_text(self, text: str) -> None:
         self.transcript.append(text)
 
     @property
@@ -46,19 +46,17 @@ class _TurnTracker:
     def __init__(
         self,
         emit_fn: Callable[[str, Any], None],
-
-            [], tuple[utils.aio.Chan[llm.MessageGeneration], utils.aio.Chan[llm.FunctionCall]]
-        ],
+        emit_generation_fn: Callable[[], None],
     ):
         self._emit = emit_fn
         self._turn_idx = 0
         self._curr_turn: _Turn | None = None
-        self.
+        self._emit_generation_fn = emit_generation_fn
 
     # --------------------------------------------------------
     # PUBLIC ENTRY POINT
     # --------------------------------------------------------
-    def feed(self, event: dict):
+    def feed(self, event: dict) -> None:
         turn = self._ensure_turn()
         kind = _classify(event)
 
@@ -97,13 +95,13 @@ class _TurnTracker:
         self._curr_turn = _Turn(turn_id=self._turn_idx)
         return self._curr_turn
 
-    def _maybe_emit_input_started(self, turn: _Turn):
+    def _maybe_emit_input_started(self, turn: _Turn) -> None:
         if not turn.ev_input_started:
             turn.ev_input_started = True
             self._emit("input_speech_started", llm.InputSpeechStartedEvent())
             turn.phase = _Phase.USER_SPEAKING
 
-    def _maybe_emit_input_stopped(self, turn: _Turn):
+    def _maybe_emit_input_stopped(self, turn: _Turn) -> None:
         if not turn.ev_input_stopped:
             turn.ev_input_stopped = True
             self._emit(
@@ -111,7 +109,7 @@ class _TurnTracker:
             )
             turn.phase = _Phase.USER_FINISHED
 
-    def _emit_transcript_updated(self, turn: _Turn):
+    def _emit_transcript_updated(self, turn: _Turn) -> None:
         self._emit(
             "input_audio_transcription_completed",
             llm.InputTranscriptionCompleted(
@@ -121,7 +119,7 @@ class _TurnTracker:
             ),
         )
 
-    def _maybe_emit_transcript_completed(self, turn: _Turn):
+    def _maybe_emit_transcript_completed(self, turn: _Turn) -> None:
         if not turn.ev_trans_completed:
             turn.ev_trans_completed = True
             self._emit(
@@ -134,17 +132,10 @@ class _TurnTracker:
             ),
         )
 
-    def _maybe_emit_generation_created(self, turn: _Turn):
+    def _maybe_emit_generation_created(self, turn: _Turn) -> None:
         if not turn.ev_generation_sent:
             turn.ev_generation_sent = True
-
-            logger.debug("Emitting generation event")
-            generation_ev = llm.GenerationCreatedEvent(
-                message_stream=msg_stream,
-                function_stream=fn_stream,
-                user_initiated=False,
-            )
-            self._emit("generation_created", generation_ev)
+            self._emit_generation_fn()
             turn.phase = _Phase.ASSISTANT_RESPONDING
```
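With this change `_TurnTracker` no longer reaches into the generation's streams; it only invokes a zero-argument callback, and the session (which owns the channels) assembles and emits the `GenerationCreatedEvent`. The inversion in miniature (toy classes, not the plugin's real ones):

```python
from typing import Any, Callable


class TurnTracker:
    def __init__(
        self,
        emit_fn: Callable[[str, Any], None],
        emit_generation_fn: Callable[[], None],
    ) -> None:
        self._emit = emit_fn
        self._emit_generation_fn = emit_generation_fn

    def on_assistant_turn(self) -> None:
        # the tracker only decides *when* a generation starts...
        self._emit_generation_fn()


class Session:
    def __init__(self) -> None:
        self.tracker = TurnTracker(self.emit, self.emit_generation_event)

    def emit(self, event: str, payload: Any) -> None:
        print(f"emit {event}: {payload}")

    def emit_generation_event(self) -> None:
        # ...while the session, which owns the message/function channels,
        # builds and publishes the actual GenerationCreatedEvent
        print("emit generation_created")


Session().tracker.on_assistant_turn()
```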
```diff
--- livekit_plugins_aws-1.1.5/livekit/plugins/aws/llm.py
+++ livekit_plugins_aws-1.1.7/livekit/plugins/aws/llm.py
@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import Any, cast
 
 import aioboto3  # type: ignore
+from botocore.config import Config
 
 from livekit.agents import APIConnectionError, APIStatusError, llm
 from livekit.agents.llm import (
@@ -205,7 +206,8 @@ class LLMStream(llm.LLMStream):
     async def _run(self) -> None:
         retryable = True
         try:
-            async with self._session.client("bedrock-runtime") as client:
+            config = Config(user_agent_extra="x-client-framework:livekit-plugins-aws")
+            async with self._session.client("bedrock-runtime", config=config) as client:
                 response = await client.converse_stream(**self._opts)
                 request_id = response["ResponseMetadata"]["RequestId"]
                 if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
```
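Both the realtime client and the `converse_stream` path now tag outgoing requests with a framework identifier. On the boto3/aioboto3 side this uses botocore's standard `Config(user_agent_extra=...)`, which appends the token to the User-Agent header of every request the client makes (the region below is an arbitrary choice for the sake of a runnable example):

```python
import boto3
from botocore.config import Config

# appends "x-client-framework:livekit-plugins-aws" to the User-Agent header,
# letting AWS attribute traffic to the plugin
config = Config(user_agent_extra="x-client-framework:livekit-plugins-aws")
client = boto3.client("bedrock-runtime", region_name="us-east-1", config=config)
print(client.meta.config.user_agent_extra)
```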