livekit-plugins-aws 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of livekit-plugins-aws might be problematic.

@@ -14,50 +14,52 @@
 # limitations under the License.
 from __future__ import annotations
 
-import asyncio
 import os
 from dataclasses import dataclass
-from typing import Any, Literal, MutableSet, Union
+from typing import Any, Literal
 
-import boto3
-from livekit.agents import (
-    APIConnectionError,
-    APIStatusError,
-    llm,
+import aioboto3
+
+from livekit.agents import APIConnectionError, APIStatusError, llm
+from livekit.agents.llm import ChatContext, FunctionTool, FunctionToolCall, ToolChoice
+from livekit.agents.types import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    NOT_GIVEN,
+    APIConnectOptions,
+    NotGivenOr,
 )
-from livekit.agents.llm import LLMCapabilities, ToolChoice, _create_ai_function_info
-from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
+from livekit.agents.utils import is_given
 
-from ._utils import _build_aws_ctx, _build_tools, _get_aws_credentials
 from .log import logger
+from .utils import get_aws_async_session, to_chat_ctx, to_fnc_ctx
 
 
 TEXT_MODEL = Literal["anthropic.claude-3-5-sonnet-20241022-v2:0"]
-DEFAULT_REGION = "us-east-1"
 
 
 @dataclass
-class LLMOptions:
-    model: TEXT_MODEL | str
-    temperature: float | None
-    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto"
-    max_output_tokens: int | None = None
-    top_p: float | None = None
-    additional_request_fields: dict[str, Any] | None = None
+class _LLMOptions:
+    model: str | TEXT_MODEL
+    temperature: NotGivenOr[float]
+    tool_choice: NotGivenOr[ToolChoice]
+    max_output_tokens: NotGivenOr[int]
+    top_p: NotGivenOr[float]
+    additional_request_fields: NotGivenOr[dict[str, Any]]
 
 
 class LLM(llm.LLM):
     def __init__(
         self,
         *,
-        model: TEXT_MODEL | str = "anthropic.claude-3-5-sonnet-20240620-v1:0",
-        api_key: str | None = None,
-        api_secret: str | None = None,
-        region: str = "us-east-1",
-        temperature: float = 0.8,
-        max_output_tokens: int | None = None,
-        top_p: float | None = None,
-        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
-        additional_request_fields: dict[str, Any] | None = None,
+        model: NotGivenOr[str | TEXT_MODEL] = NOT_GIVEN,
+        api_key: NotGivenOr[str] = NOT_GIVEN,
+        api_secret: NotGivenOr[str] = NOT_GIVEN,
+        region: NotGivenOr[str] = NOT_GIVEN,
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+        max_output_tokens: NotGivenOr[int] = NOT_GIVEN,
+        top_p: NotGivenOr[float] = NOT_GIVEN,
+        tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
+        additional_request_fields: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
+        session: aioboto3.Session | None = None,
     ) -> None:
         """
         Create a new instance of AWS Bedrock LLM.
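Note on the new option plumbing: 1.0 replaces `None` defaults with livekit-agents' `NOT_GIVEN` sentinel, so an omitted argument can be told apart from an explicitly passed `None`. A minimal self-contained sketch of the pattern follows; the real `NOT_GIVEN`, `NotGivenOr`, and `is_given` come from `livekit.agents.types` and `livekit.agents.utils`, and this stand-in is illustrative only.

```python
# Illustrative stand-in for the NOT_GIVEN sentinel pattern used above.
# The real types live in livekit.agents.types / livekit.agents.utils.
from typing import TypeVar, Union

T = TypeVar("T")


class NotGiven:
    def __bool__(self) -> bool:
        return False

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = NotGiven()
NotGivenOr = Union[T, NotGiven]  # NotGivenOr[float] == Union[float, NotGiven]


def is_given(value: "NotGivenOr[T]") -> bool:
    # Distinguishes "argument omitted" from "argument explicitly passed".
    return not isinstance(value, NotGiven)


def configure(temperature: "NotGivenOr[float]" = NOT_GIVEN) -> None:
    if is_given(temperature):
        print(f"caller chose temperature={temperature}")
    else:
        print("temperature omitted; defer to the provider's default")
```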
@@ -65,7 +67,7 @@ class LLM(llm.LLM):
         ``api_key`` and ``api_secret`` must be set to your AWS Access key id and secret access key, either using the argument or by setting the
         ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` environmental variables.
 
-        See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html for more details on the the AWS Bedrock Runtime API.
+        See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html for more details on the AWS Bedrock Runtime API.
 
         Args:
             model (TEXT_MODEL, optional): model or inference profile arn to use(https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-use.html). Defaults to 'anthropic.claude-3-5-sonnet-20240620-v1:0'.
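For orientation, here is a hedged usage sketch of the 1.0 constructor this docstring describes; the model id, region, and profile name are placeholders, and credential resolution falls back to the environment variables mentioned above.

```python
# Hypothetical usage of the 1.0 plugin API; values are placeholders.
import aioboto3
from livekit.plugins import aws

# Credentials fall back to AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY.
bedrock_llm = aws.LLM(
    model="anthropic.claude-3-5-sonnet-20241022-v2:0",
    region="us-east-1",
    temperature=0.8,
)

# New in 1.0: inject a preconfigured aioboto3 session instead.
session = aioboto3.Session(profile_name="my-bedrock-profile")  # hypothetical profile
bedrock_llm2 = aws.LLM(model="anthropic.claude-3-5-sonnet-20241022-v2:0", session=session)
```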
@@ -75,67 +77,94 @@ class LLM(llm.LLM):
             temperature (float, optional): Sampling temperature for response generation. Defaults to 0.8.
             max_output_tokens (int, optional): Maximum number of tokens to generate in the output. Defaults to None.
             top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
-            tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
+            tool_choice (ToolChoice, optional): Specifies whether to use tools during response generation. Defaults to "auto".
             additional_request_fields (dict[str, Any], optional): Additional request fields to send to the AWS Bedrock Converse API. Defaults to None.
-        """
-        super().__init__(
-            capabilities=LLMCapabilities(
-                supports_choices_on_int=True,
-                requires_persistent_functions=True,
-            )
-        )
-        self._api_key, self._api_secret = _get_aws_credentials(
-            api_key, api_secret, region
+            session (aioboto3.Session, optional): Optional aioboto3 session to use.
+        """  # noqa: E501
+        super().__init__()
+
+        self._session = session or get_aws_async_session(
+            api_key=api_key if is_given(api_key) else None,
+            api_secret=api_secret if is_given(api_secret) else None,
+            region=region if is_given(region) else None,
         )
 
-        self._model = model or os.environ.get("BEDROCK_INFERENCE_PROFILE_ARN")
-        if not self._model:
+        model = model if is_given(model) else os.environ.get("BEDROCK_INFERENCE_PROFILE_ARN")
+        if not model:
             raise ValueError(
-                "model or inference profile arn must be set using the argument or by setting the BEDROCK_INFERENCE_PROFILE_ARN environment variable."
+                "model or inference profile arn must be set using the argument or by setting the BEDROCK_INFERENCE_PROFILE_ARN environment variable."  # noqa: E501
             )
-        self._opts = LLMOptions(
-            model=self._model,
+        self._opts = _LLMOptions(
+            model=model,
             temperature=temperature,
             tool_choice=tool_choice,
             max_output_tokens=max_output_tokens,
             top_p=top_p,
             additional_request_fields=additional_request_fields,
         )
-        self._region = region
-        self._running_fncs: MutableSet[asyncio.Task[Any]] = set()
 
     def chat(
         self,
         *,
-        chat_ctx: llm.ChatContext,
+        chat_ctx: ChatContext,
+        tools: list[FunctionTool] | None = None,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-        fnc_ctx: llm.FunctionContext | None = None,
-        temperature: float | None = None,
-        n: int | None = 1,
-        parallel_tool_calls: bool | None = None,
-        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]]
-        | None = None,
-    ) -> "LLMStream":
-        if tool_choice is None:
-            tool_choice = self._opts.tool_choice
-
-        if temperature is None:
-            temperature = self._opts.temperature
+        temperature: NotGivenOr[float] = NOT_GIVEN,
+        tool_choice: NotGivenOr[ToolChoice] = NOT_GIVEN,
+    ) -> LLMStream:
+        opts = {}
+
+        if is_given(self._opts.model):
+            opts["modelId"] = self._opts.model
+
+        def _get_tool_config() -> dict[str, Any] | None:
+            nonlocal tool_choice
+
+            if not tools:
+                return None
+
+            tool_config: dict[str, Any] = {"tools": to_fnc_ctx(tools)}
+            tool_choice = tool_choice if is_given(tool_choice) else self._opts.tool_choice
+            if is_given(tool_choice):
+                if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
+                    tool_config["toolChoice"] = {"tool": {"name": tool_choice["function"]["name"]}}
+                elif tool_choice == "required":
+                    tool_config["toolChoice"] = {"any": {}}
+                elif tool_choice == "auto":
+                    tool_config["toolChoice"] = {"auto": {}}
+                else:
+                    return None
+
+            return tool_config
+
+        tool_config = _get_tool_config()
+        if tool_config:
+            opts["toolConfig"] = tool_config
+        messages, system_message = to_chat_ctx(chat_ctx, id(self))
+        opts["messages"] = messages
+        if system_message:
+            opts["system"] = [system_message]
+
+        inference_config = {}
+        if is_given(self._opts.max_output_tokens):
+            inference_config["maxTokens"] = self._opts.max_output_tokens
+        temperature = temperature if is_given(temperature) else self._opts.temperature
+        if is_given(temperature):
+            inference_config["temperature"] = temperature
+        if is_given(self._opts.top_p):
+            inference_config["topP"] = self._opts.top_p
+
+        opts["inferenceConfig"] = inference_config
+        if is_given(self._opts.additional_request_fields):
+            opts["additionalModelRequestFields"] = self._opts.additional_request_fields
 
         return LLMStream(
             self,
-            model=self._opts.model,
-            aws_access_key_id=self._api_key,
-            aws_secret_access_key=self._api_secret,
-            region_name=self._region,
-            max_output_tokens=self._opts.max_output_tokens,
-            top_p=self._opts.top_p,
-            additional_request_fields=self._opts.additional_request_fields,
             chat_ctx=chat_ctx,
-            fnc_ctx=fnc_ctx,
+            tools=tools,
+            session=self._session,
             conn_options=conn_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
+            extra_kwargs=opts,
         )
 
 
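The `opts` dict built by the new `chat()` maps one-to-one onto `converse_stream` keyword arguments. A tool-enabled request would be shaped roughly like the sketch below; field names follow the Bedrock Converse API, while the tool and the values are invented for illustration.

```python
# Illustrative shape of the kwargs ultimately passed to converse_stream(**opts).
opts = {
    "modelId": "anthropic.claude-3-5-sonnet-20241022-v2:0",
    "messages": [
        {"role": "user", "content": [{"text": "What's the weather in Tokyo?"}]},
    ],
    "system": [{"text": "You are a helpful assistant."}],
    "inferenceConfig": {"maxTokens": 512, "temperature": 0.8, "topP": 0.9},
    "toolConfig": {
        "tools": [
            {
                "toolSpec": {
                    "name": "get_weather",  # hypothetical tool
                    "description": "Look up the current weather for a city.",
                    "inputSchema": {
                        "json": {
                            "type": "object",
                            "properties": {"city": {"type": "string"}},
                            "required": ["city"],
                        }
                    },
                }
            }
        ],
        "toolChoice": {"auto": {}},
    },
}
```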
@@ -144,107 +173,39 @@ class LLMStream(llm.LLMStream):
         self,
         llm: LLM,
         *,
-        model: str | TEXT_MODEL,
-        aws_access_key_id: str | None,
-        aws_secret_access_key: str | None,
-        region_name: str,
-        chat_ctx: llm.ChatContext,
+        chat_ctx: ChatContext,
+        session: aioboto3.Session,
         conn_options: APIConnectOptions,
-        fnc_ctx: llm.FunctionContext | None,
-        temperature: float | None,
-        max_output_tokens: int | None,
-        top_p: float | None,
-        tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]],
-        additional_request_fields: dict[str, Any] | None,
+        tools: list[FunctionTool] | None,
+        extra_kwargs: dict[str, Any],
     ) -> None:
-        super().__init__(
-            llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
-        )
-        self._client = boto3.client(
-            "bedrock-runtime",
-            region_name=region_name,
-            aws_access_key_id=aws_access_key_id,
-            aws_secret_access_key=aws_secret_access_key,
-        )
-        self._model = model
+        super().__init__(llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
         self._llm: LLM = llm
-        self._max_output_tokens = max_output_tokens
-        self._top_p = top_p
-        self._temperature = temperature
-        self._tool_choice = tool_choice
-        self._additional_request_fields = additional_request_fields
-
-    async def _run(self) -> None:
+        self._opts = extra_kwargs
+        self._session = session
         self._tool_call_id: str | None = None
         self._fnc_name: str | None = None
         self._fnc_raw_arguments: str | None = None
         self._text: str = ""
-        retryable = True
 
+    async def _run(self) -> None:
+        retryable = True
         try:
-            opts: dict[str, Any] = {}
-            messages, system_instruction = _build_aws_ctx(self._chat_ctx, id(self))
-            messages = _merge_messages(messages)
-
-            def _get_tool_config() -> dict[str, Any] | None:
-                if not (self._fnc_ctx and self._fnc_ctx.ai_functions):
-                    return None
-
-                tools = _build_tools(self._fnc_ctx)
-                config: dict[str, Any] = {"tools": tools}
-
-                if isinstance(self._tool_choice, ToolChoice):
-                    config["toolChoice"] = {"tool": {"name": self._tool_choice.name}}
-                elif self._tool_choice == "required":
-                    config["toolChoice"] = {"any": {}}
-                elif self._tool_choice == "auto":
-                    config["toolChoice"] = {"auto": {}}
-                else:
-                    return None
-
-                return config
-
-            tool_config = _get_tool_config()
-            if tool_config:
-                opts["toolConfig"] = tool_config
-
-            if self._additional_request_fields:
-                opts["additionalModelRequestFields"] = _strip_nones(
-                    self._additional_request_fields
-                )
-            if system_instruction:
-                opts["system"] = [system_instruction]
-
-            inference_config = _strip_nones(
-                {
-                    "maxTokens": self._max_output_tokens,
-                    "temperature": self._temperature,
-                    "topP": self._top_p,
-                }
-            )
-            response = self._client.converse_stream(
-                modelId=self._model,
-                messages=messages,
-                inferenceConfig=inference_config,
-                **_strip_nones(opts),
-            )  # type: ignore
-
-            request_id = response["ResponseMetadata"]["RequestId"]
-            if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
-                raise APIStatusError(
-                    f"aws bedrock llm: error generating content: {response}",
-                    retryable=False,
-                    request_id=request_id,
-                )
-
-            for chunk in response["stream"]:
-                chat_chunk = self._parse_chunk(request_id, chunk)
-                if chat_chunk is not None:
-                    retryable = False
-                    self._event_ch.send_nowait(chat_chunk)
-
-            # Let other coroutines run
-            await asyncio.sleep(0)
+            async with self._session.client("bedrock-runtime") as client:
+                response = await client.converse_stream(**self._opts)  # type: ignore
+                request_id = response["ResponseMetadata"]["RequestId"]
+                if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+                    raise APIStatusError(
+                        f"aws bedrock llm: error generating content: {response}",
+                        retryable=False,
+                        request_id=request_id,
+                    )
+
+                async for chunk in response["stream"]:
+                    chat_chunk = self._parse_chunk(request_id, chunk)
+                    if chat_chunk is not None:
+                        retryable = False
+                        self._event_ch.send_nowait(chat_chunk)
 
         except Exception as e:
             raise APIConnectionError(
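The move from a module-level boto3 client to `async with self._session.client(...)` follows aioboto3's design: clients are async context managers, and the event stream returned by `converse_stream` is consumed with `async for` instead of a blocking loop. A standalone sketch of that pattern, with placeholder model id and prompt:

```python
# Minimal aioboto3 streaming sketch, mirroring the pattern used in _run().
import asyncio

import aioboto3


async def main() -> None:
    session = aioboto3.Session()  # region/credentials from the environment
    # aioboto3 clients are async context managers: resources are created on
    # __aenter__ and released on __aexit__.
    async with session.client("bedrock-runtime") as client:
        response = await client.converse_stream(
            modelId="anthropic.claude-3-5-sonnet-20241022-v2:0",
            messages=[{"role": "user", "content": [{"text": "Hello"}]}],
        )
        async for event in response["stream"]:  # the event stream is async-iterable
            print(event)


asyncio.run(main())
```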
@@ -258,93 +219,53 @@ class LLMStream(llm.LLMStream):
             self._tool_call_id = tool_use["toolUseId"]
             self._fnc_name = tool_use["name"]
             self._fnc_raw_arguments = ""
+
         elif "contentBlockDelta" in chunk:
             delta = chunk["contentBlockDelta"]["delta"]
             if "toolUse" in delta:
                 self._fnc_raw_arguments += delta["toolUse"]["input"]
             elif "text" in delta:
-                self._text += delta["text"]
-        elif "contentBlockStop" in chunk:
-            if self._text:
-                chat_chunk = llm.ChatChunk(
-                    request_id=request_id,
-                    choices=[
-                        llm.Choice(
-                            delta=llm.ChoiceDelta(content=self._text, role="assistant"),
-                            index=chunk["contentBlockStop"]["contentBlockIndex"],
-                        )
-                    ],
+                return llm.ChatChunk(
+                    id=request_id,
+                    delta=llm.ChoiceDelta(content=delta["text"], role="assistant"),
                 )
-                self._text = ""
-                return chat_chunk
-            elif self._tool_call_id:
-                return self._try_build_function(request_id, chunk)
-
-        return None
-
-    def _try_build_function(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
-        if self._tool_call_id is None:
-            logger.warning("aws bedrock llm: no tool call id in the response")
-            return None
-        if self._fnc_name is None:
-            logger.warning("aws bedrock llm: no function name in the response")
-            return None
-        if self._fnc_raw_arguments is None:
-            logger.warning("aws bedrock llm: no function arguments in the response")
-            return None
-        if self._fnc_ctx is None:
-            logger.warning(
-                "aws bedrock llm: stream tried to run function without function context"
+            else:
+                logger.warning(f"aws bedrock llm: unknown chunk type: {chunk}")
+
+        elif "metadata" in chunk:
+            metadata = chunk["metadata"]
+            return llm.ChatChunk(
+                id=request_id,
+                usage=llm.CompletionUsage(
+                    completion_tokens=metadata["usage"]["outputTokens"],
+                    prompt_tokens=metadata["usage"]["inputTokens"],
+                    total_tokens=metadata["usage"]["totalTokens"],
+                ),
             )
-            return None
-
-        fnc_info = _create_ai_function_info(
-            self._fnc_ctx,
-            self._tool_call_id,
-            self._fnc_name,
-            self._fnc_raw_arguments,
-        )
-
-        self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None
-        self._function_calls_info.append(fnc_info)
-
-        return llm.ChatChunk(
-            request_id=request_id,
-            choices=[
-                llm.Choice(
+        elif "contentBlockStop" in chunk:
+            if self._tool_call_id:
+                if self._tool_call_id is None:
+                    logger.warning("aws bedrock llm: no tool call id in the response")
+                    return None
+                if self._fnc_name is None:
+                    logger.warning("aws bedrock llm: no function name in the response")
+                    return None
+                if self._fnc_raw_arguments is None:
+                    logger.warning("aws bedrock llm: no function arguments in the response")
+                    return None
+                chat_chunk = llm.ChatChunk(
+                    id=request_id,
                     delta=llm.ChoiceDelta(
                         role="assistant",
-                        tool_calls=[fnc_info],
+                        tool_calls=[
+                            FunctionToolCall(
+                                arguments=self._fnc_raw_arguments,
+                                name=self._fnc_name,
+                                call_id=self._tool_call_id,
+                            ),
+                        ],
                     ),
-                    index=chunk["contentBlockStop"]["contentBlockIndex"],
                 )
-            ],
-        )
-
-
-def _merge_messages(
-    messages: list[dict],
-) -> list[dict]:
-    # Anthropic enforces alternating messages
-    combined_messages: list[dict] = []
-    for m in messages:
-        if len(combined_messages) == 0 or m["role"] != combined_messages[-1]["role"]:
-            combined_messages.append(m)
-            continue
-        last_message = combined_messages[-1]
-        if not isinstance(last_message["content"], list) or not isinstance(
-            m["content"], list
-        ):
-            logger.error("message content is not a list")
-            continue
-
-        last_message["content"].extend(m["content"])
-
-    if len(combined_messages) == 0 or combined_messages[0]["role"] != "user":
-        combined_messages.insert(0, {"role": "user", "content": [{"text": "(empty)"}]})
-
-    return combined_messages
-
-
-def _strip_nones(d: dict[str, Any]) -> dict[str, Any]:
-    return {k: v for k, v in d.items() if v is not None}
+                self._tool_call_id = self._fnc_name = self._fnc_raw_arguments = None
+                return chat_chunk
+        return None
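The rewritten `_parse_chunk` emits one `ChatChunk` per text delta, accumulates tool-call arguments across deltas, flushes the buffered call on `contentBlockStop`, and turns the trailing `metadata` event into a usage chunk. For a tool-call turn, the Converse stream events it consumes arrive roughly in this order (shapes per the Bedrock docs; ids and values invented):

```python
# Illustrative converse_stream event sequence for one tool-call response.
events = [
    {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "tool-123", "name": "get_weather"}}}},
    {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"city": '}}}},
    {"contentBlockDelta": {"delta": {"toolUse": {"input": '"Tokyo"}'}}}},
    {"contentBlockStop": {}},  # _parse_chunk returns the buffered FunctionToolCall here
    {"metadata": {"usage": {"inputTokens": 25, "outputTokens": 12, "totalTokens": 37}}},
]
```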
@@ -45,4 +45,4 @@ TTS_LANGUAGE = Literal[
     "de-CH",
 ]
 
-TTS_OUTPUT_FORMAT = Literal["pcm", "mp3"]
+TTS_OUTPUT_FORMAT = Literal["mp3"]
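Finally, the TTS change narrows `TTS_OUTPUT_FORMAT` to `"mp3"`, so `"pcm"` is no longer a valid value at the type level; a type checker would now reject it (sketch):

```python
# Consequence of the narrowed Literal: "pcm" now fails type checking.
from typing import Literal

TTS_OUTPUT_FORMAT = Literal["mp3"]


def set_output_format(fmt: TTS_OUTPUT_FORMAT) -> None:
    ...


set_output_format("mp3")  # OK
set_output_format("pcm")  # mypy/pyright error after 1.0
```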