livekit-plugins-aws 0.1.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of livekit-plugins-aws has been flagged as potentially problematic.

livekit/plugins/aws/stt.py

@@ -14,70 +14,72 @@ from __future__ import annotations
 
 import asyncio
 from dataclasses import dataclass
-from typing import Optional
 
+import aioboto3
+from amazon_transcribe.auth import StaticCredentialResolver
 from amazon_transcribe.client import TranscribeStreamingClient
 from amazon_transcribe.model import Result, TranscriptEvent
+
 from livekit import rtc
-from livekit.agents import (
-    DEFAULT_API_CONNECT_OPTIONS,
-    APIConnectOptions,
-    stt,
-    utils,
-)
-
-from ._utils import _get_aws_credentials
+from livekit.agents import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions, stt, utils
+from livekit.agents.types import NOT_GIVEN, NotGivenOr
+from livekit.agents.utils import is_given
+
 from .log import logger
+from .utils import DEFAULT_REGION, get_aws_async_session
+
+REFRESH_INTERVAL = 1800
 
 
 @dataclass
 class STTOptions:
-    speech_region: str
     sample_rate: int
     language: str
     encoding: str
-    vocabulary_name: Optional[str]
-    session_id: Optional[str]
-    vocab_filter_method: Optional[str]
-    vocab_filter_name: Optional[str]
-    show_speaker_label: Optional[bool]
-    enable_channel_identification: Optional[bool]
-    number_of_channels: Optional[int]
-    enable_partial_results_stabilization: Optional[bool]
-    partial_results_stability: Optional[str]
-    language_model_name: Optional[str]
+    vocabulary_name: NotGivenOr[str]
+    session_id: NotGivenOr[str]
+    vocab_filter_method: NotGivenOr[str]
+    vocab_filter_name: NotGivenOr[str]
+    show_speaker_label: NotGivenOr[bool]
+    enable_channel_identification: NotGivenOr[bool]
+    number_of_channels: NotGivenOr[int]
+    enable_partial_results_stabilization: NotGivenOr[bool]
+    partial_results_stability: NotGivenOr[str]
+    language_model_name: NotGivenOr[str]
 
 
 class STT(stt.STT):
     def __init__(
         self,
         *,
-        speech_region: str = "us-east-1",
-        api_key: str | None = None,
-        api_secret: str | None = None,
+        region: NotGivenOr[str] = NOT_GIVEN,
+        api_key: NotGivenOr[str] = NOT_GIVEN,
+        api_secret: NotGivenOr[str] = NOT_GIVEN,
         sample_rate: int = 48000,
         language: str = "en-US",
         encoding: str = "pcm",
-        vocabulary_name: Optional[str] = None,
-        session_id: Optional[str] = None,
-        vocab_filter_method: Optional[str] = None,
-        vocab_filter_name: Optional[str] = None,
-        show_speaker_label: Optional[bool] = None,
-        enable_channel_identification: Optional[bool] = None,
-        number_of_channels: Optional[int] = None,
-        enable_partial_results_stabilization: Optional[bool] = None,
-        partial_results_stability: Optional[str] = None,
-        language_model_name: Optional[str] = None,
+        vocabulary_name: NotGivenOr[str] = NOT_GIVEN,
+        session_id: NotGivenOr[str] = NOT_GIVEN,
+        vocab_filter_method: NotGivenOr[str] = NOT_GIVEN,
+        vocab_filter_name: NotGivenOr[str] = NOT_GIVEN,
+        show_speaker_label: NotGivenOr[bool] = NOT_GIVEN,
+        enable_channel_identification: NotGivenOr[bool] = NOT_GIVEN,
+        number_of_channels: NotGivenOr[int] = NOT_GIVEN,
+        enable_partial_results_stabilization: NotGivenOr[bool] = NOT_GIVEN,
+        partial_results_stability: NotGivenOr[str] = NOT_GIVEN,
+        language_model_name: NotGivenOr[str] = NOT_GIVEN,
+        session: aioboto3.Session | None = None,
+        refresh_interval: NotGivenOr[int] = NOT_GIVEN,
     ):
-        super().__init__(
-            capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
+        super().__init__(capabilities=stt.STTCapabilities(streaming=True, interim_results=True))
+        self._region = region if is_given(region) else DEFAULT_REGION
+        self._session = session or get_aws_async_session(
+            api_key=api_key if is_given(api_key) else None,
+            api_secret=api_secret if is_given(api_secret) else None,
+            region=self._region,
         )
 
-        self._api_key, self._api_secret = _get_aws_credentials(
-            api_key, api_secret, speech_region
-        )
         self._config = STTOptions(
-            speech_region=speech_region,
             language=language,
             sample_rate=sample_rate,
             encoding=encoding,
@@ -92,26 +94,47 @@ class STT(stt.STT):
             partial_results_stability=partial_results_stability,
             language_model_name=language_model_name,
         )
+        self._pool = utils.ConnectionPool[TranscribeStreamingClient](
+            connect_cb=self._create_client,
+            max_session_duration=refresh_interval
+            if is_given(refresh_interval)
+            else REFRESH_INTERVAL,
+        )
+
+    async def _create_client(self) -> TranscribeStreamingClient:
+        creds = await self._session.get_credentials()
+        frozen_credentials = await creds.get_frozen_credentials()
+        return TranscribeStreamingClient(
+            region=self._region,
+            credential_resolver=StaticCredentialResolver(
+                access_key_id=frozen_credentials.access_key,
+                secret_access_key=frozen_credentials.secret_key,
+                session_token=frozen_credentials.token,
+            ),
+        )
+
+    async def aclose(self) -> None:
+        await self._pool.aclose()
+        await super().aclose()
 
     async def _recognize_impl(
         self,
         buffer: utils.AudioBuffer,
         *,
-        language: str | None,
+        language: NotGivenOr[str] = NOT_GIVEN,
         conn_options: APIConnectOptions,
     ) -> stt.SpeechEvent:
-        raise NotImplementedError(
-            "Amazon Transcribe does not support single frame recognition"
-        )
+        raise NotImplementedError("Amazon Transcribe does not support single frame recognition")
 
     def stream(
         self,
         *,
-        language: str | None = None,
+        language: NotGivenOr[str] = NOT_GIVEN,
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-    ) -> "SpeechStream":
+    ) -> SpeechStream:
         return SpeechStream(
             stt=self,
+            pool=self._pool,
             conn_options=conn_options,
             opts=self._config,
         )
@@ -122,54 +145,54 @@ class SpeechStream(stt.SpeechStream):
         self,
         stt: STT,
         opts: STTOptions,
+        pool: utils.ConnectionPool[TranscribeStreamingClient],
         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
     ) -> None:
-        super().__init__(
-            stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
-        )
+        super().__init__(stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate)
         self._opts = opts
-        self._client = TranscribeStreamingClient(region=self._opts.speech_region)
+        self._pool = pool
 
     async def _run(self) -> None:
-        stream = await self._client.start_stream_transcription(
-            language_code=self._opts.language,
-            media_sample_rate_hz=self._opts.sample_rate,
-            media_encoding=self._opts.encoding,
-            vocabulary_name=self._opts.vocabulary_name,
-            session_id=self._opts.session_id,
-            vocab_filter_method=self._opts.vocab_filter_method,
-            vocab_filter_name=self._opts.vocab_filter_name,
-            show_speaker_label=self._opts.show_speaker_label,
-            enable_channel_identification=self._opts.enable_channel_identification,
-            number_of_channels=self._opts.number_of_channels,
-            enable_partial_results_stabilization=self._opts.enable_partial_results_stabilization,
-            partial_results_stability=self._opts.partial_results_stability,
-            language_model_name=self._opts.language_model_name,
-        )
-
-        @utils.log_exceptions(logger=logger)
-        async def input_generator():
-            async for frame in self._input_ch:
-                if isinstance(frame, rtc.AudioFrame):
-                    await stream.input_stream.send_audio_event(
-                        audio_chunk=frame.data.tobytes()
-                    )
-            await stream.input_stream.end_stream()
-
-        @utils.log_exceptions(logger=logger)
-        async def handle_transcript_events():
-            async for event in stream.output_stream:
-                if isinstance(event, TranscriptEvent):
-                    self._process_transcript_event(event)
-
-        tasks = [
-            asyncio.create_task(input_generator()),
-            asyncio.create_task(handle_transcript_events()),
-        ]
-        try:
-            await asyncio.gather(*tasks)
-        finally:
-            await utils.aio.gracefully_cancel(*tasks)
+        async with self._pool.connection() as client:
+            live_config = {
+                "language_code": self._opts.language,
+                "media_sample_rate_hz": self._opts.sample_rate,
+                "media_encoding": self._opts.encoding,
+                "vocabulary_name": self._opts.vocabulary_name,
+                "session_id": self._opts.session_id,
+                "vocab_filter_method": self._opts.vocab_filter_method,
+                "vocab_filter_name": self._opts.vocab_filter_name,
+                "show_speaker_label": self._opts.show_speaker_label,
+                "enable_channel_identification": self._opts.enable_channel_identification,
+                "number_of_channels": self._opts.number_of_channels,
+                "enable_partial_results_stabilization": self._opts.enable_partial_results_stabilization,  # noqa: E501
+                "partial_results_stability": self._opts.partial_results_stability,
+                "language_model_name": self._opts.language_model_name,
+            }
+            filtered_config = {k: v for k, v in live_config.items() if v and is_given(v)}
+            stream = await client.start_stream_transcription(**filtered_config)
+
+            @utils.log_exceptions(logger=logger)
+            async def input_generator():
+                async for frame in self._input_ch:
+                    if isinstance(frame, rtc.AudioFrame):
+                        await stream.input_stream.send_audio_event(audio_chunk=frame.data.tobytes())
+                await stream.input_stream.end_stream()
+
+            @utils.log_exceptions(logger=logger)
+            async def handle_transcript_events():
+                async for event in stream.output_stream:
+                    if isinstance(event, TranscriptEvent):
+                        self._process_transcript_event(event)
+
+            tasks = [
+                asyncio.create_task(input_generator()),
+                asyncio.create_task(handle_transcript_events()),
+            ]
+            try:
+                await asyncio.gather(*tasks)
+            finally:
+                await utils.aio.gracefully_cancel(*tasks)
 
     def _process_transcript_event(self, transcript_event: TranscriptEvent):
         stream = transcript_event.transcript.results
@@ -184,9 +207,7 @@ class SpeechStream(stt.SpeechStream):
                     self._event_ch.send_nowait(
                         stt.SpeechEvent(
                             type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
-                            alternatives=[
-                                _streaming_recognize_response_to_speech_data(resp)
-                            ],
+                            alternatives=[_streaming_recognize_response_to_speech_data(resp)],
                         )
                     )
 
@@ -194,16 +215,12 @@ class SpeechStream(stt.SpeechStream):
                     self._event_ch.send_nowait(
                         stt.SpeechEvent(
                             type=stt.SpeechEventType.FINAL_TRANSCRIPT,
-                            alternatives=[
-                                _streaming_recognize_response_to_speech_data(resp)
-                            ],
+                            alternatives=[_streaming_recognize_response_to_speech_data(resp)],
                         )
                     )
 
             if not resp.is_partial:
-                self._event_ch.send_nowait(
-                    stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)
-                )
+                self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
 
 
 def _streaming_recognize_response_to_speech_data(resp: Result) -> stt.SpeechData:
@@ -211,7 +228,6 @@ def _streaming_recognize_response_to_speech_data(resp: Result) -> stt.SpeechData
         language="en-US",
         start_time=resp.start_time if resp.start_time else 0.0,
         end_time=resp.end_time if resp.end_time else 0.0,
-        confidence=0.0,
         text=resp.alternatives[0].transcript if resp.alternatives else "",
     )
 
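For orientation, a minimal usage sketch of the new STT constructor shown above. It assumes the 1.0.0 package still re-exports STT as livekit.plugins.aws.STT (not visible in this diff); the region, profile name, and interval values are placeholders.

import aioboto3

from livekit.plugins import aws

# With api_key/api_secret left NOT_GIVEN, credentials come from the default
# AWS resolution chain and the region falls back to DEFAULT_REGION ("us-east-1").
stt = aws.STT(
    language="en-US",
    sample_rate=48000,
    refresh_interval=1800,  # seconds before a pooled TranscribeStreamingClient is rebuilt
)

# Or inject a preconfigured aioboto3 session instead of key/secret arguments:
stt_from_profile = aws.STT(session=aioboto3.Session(profile_name="my-profile"))

Per the diff, aclose() now drains the Transcribe connection pool before closing the base class.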
livekit/plugins/aws/tts.py

@@ -14,10 +14,10 @@ from __future__ import annotations
 
 import asyncio
 from dataclasses import dataclass
-from typing import Any, Callable, Optional
 
+import aioboto3
 import aiohttp
-from aiobotocore.session import AioSession, get_session
+
 from livekit.agents import (
     APIConnectionError,
     APIConnectOptions,
@@ -26,13 +26,18 @@ from livekit.agents import (
     tts,
     utils,
 )
+from livekit.agents.types import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    NOT_GIVEN,
+    NotGivenOr,
+)
+from livekit.agents.utils import is_given
 
-from ._utils import _get_aws_credentials
 from .models import TTS_LANGUAGE, TTS_SPEECH_ENGINE
+from .utils import _strip_nones, get_aws_async_session
 
 TTS_NUM_CHANNELS: int = 1
 DEFAULT_SPEECH_ENGINE: TTS_SPEECH_ENGINE = "generative"
-DEFAULT_SPEECH_REGION = "us-east-1"
 DEFAULT_VOICE = "Ruth"
 DEFAULT_SAMPLE_RATE = 16000
 
@@ -40,25 +45,25 @@ DEFAULT_SAMPLE_RATE = 16000
 @dataclass
 class _TTSOptions:
     # https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
-    voice: str | None
-    speech_engine: TTS_SPEECH_ENGINE
-    speech_region: str
+    voice: NotGivenOr[str]
+    speech_engine: NotGivenOr[TTS_SPEECH_ENGINE]
+    region: str
     sample_rate: int
-    language: TTS_LANGUAGE | str | None
+    language: NotGivenOr[TTS_LANGUAGE | str]
 
 
 class TTS(tts.TTS):
     def __init__(
         self,
         *,
-        voice: str | None = DEFAULT_VOICE,
-        language: TTS_LANGUAGE | str | None = None,
-        speech_engine: TTS_SPEECH_ENGINE = DEFAULT_SPEECH_ENGINE,
+        voice: NotGivenOr[str] = NOT_GIVEN,
+        language: NotGivenOr[TTS_LANGUAGE | str] = NOT_GIVEN,
+        speech_engine: NotGivenOr[TTS_SPEECH_ENGINE] = NOT_GIVEN,
         sample_rate: int = DEFAULT_SAMPLE_RATE,
-        speech_region: str = DEFAULT_SPEECH_REGION,
-        api_key: str | None = None,
-        api_secret: str | None = None,
-        session: AioSession | None = None,
+        region: NotGivenOr[str] = NOT_GIVEN,
+        api_key: NotGivenOr[str] = NOT_GIVEN,
+        api_secret: NotGivenOr[str] = NOT_GIVEN,
+        session: aioboto3.Session | None = None,
     ) -> None:
         """
         Create a new instance of AWS Polly TTS.
@@ -73,10 +78,11 @@ class TTS(tts.TTS):
             language (TTS_LANGUAGE, optional): language code for the Synthesize Speech request. This is only necessary if using a bilingual voice, such as Aditi, which can be used for either Indian English (en-IN) or Hindi (hi-IN).
             sample_rate(int, optional): The audio frequency specified in Hz. Defaults to 16000.
             speech_engine(TTS_SPEECH_ENGINE, optional): The engine to use for the synthesis. Defaults to "generative".
-            speech_region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
+            region(str, optional): The region to use for the synthesis. Defaults to "us-east-1".
             api_key(str, optional): AWS access key id.
             api_secret(str, optional): AWS secret access key.
-        """
+            session(aioboto3.Session, optional): Optional aioboto3 session to use.
+        """  # noqa: E501
         super().__init__(
             capabilities=tts.TTSCapabilities(
                 streaming=False,
@@ -84,40 +90,31 @@ class TTS(tts.TTS):
             sample_rate=sample_rate,
             num_channels=TTS_NUM_CHANNELS,
         )
-
-        self._api_key, self._api_secret = _get_aws_credentials(
-            api_key, api_secret, speech_region
+        self._session = session or get_aws_async_session(
+            api_key=api_key if is_given(api_key) else None,
+            api_secret=api_secret if is_given(api_secret) else None,
+            region=region if is_given(region) else None,
        )
-
         self._opts = _TTSOptions(
             voice=voice,
             speech_engine=speech_engine,
-            speech_region=speech_region,
+            region=region,
             language=language,
             sample_rate=sample_rate,
         )
-        self._session = session or get_session()
-
-    def _get_client(self):
-        return self._session.create_client(
-            "polly",
-            region_name=self._opts.speech_region,
-            aws_access_key_id=self._api_key,
-            aws_secret_access_key=self._api_secret,
-        )
 
     def synthesize(
         self,
         text: str,
         *,
-        conn_options: Optional[APIConnectOptions] = None,
-    ) -> "ChunkedStream":
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+    ) -> ChunkedStream:
         return ChunkedStream(
             tts=self,
             text=text,
             conn_options=conn_options,
+            session=self._session,
             opts=self._opts,
-            get_client=self._get_client,
         )
 
 
@@ -127,28 +124,30 @@ class ChunkedStream(tts.ChunkedStream):
         *,
         tts: TTS,
         text: str,
-        conn_options: Optional[APIConnectOptions] = None,
+        session: aioboto3.Session,
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
         opts: _TTSOptions,
-        get_client: Callable[[], Any],
     ) -> None:
         super().__init__(tts=tts, input_text=text, conn_options=conn_options)
         self._opts = opts
-        self._get_client = get_client
         self._segment_id = utils.shortuuid()
+        self._session = session
 
     async def _run(self):
         request_id = utils.shortuuid()
 
         try:
-            async with self._get_client() as client:
+            async with self._session.client("polly") as client:
                 params = {
                     "Text": self._input_text,
                     "OutputFormat": "mp3",
-                    "Engine": self._opts.speech_engine,
-                    "VoiceId": self._opts.voice,
+                    "Engine": self._opts.speech_engine
+                    if is_given(self._opts.speech_engine)
+                    else DEFAULT_SPEECH_ENGINE,
+                    "VoiceId": self._opts.voice if is_given(self._opts.voice) else DEFAULT_VOICE,
                     "TextType": "text",
                     "SampleRate": str(self._opts.sample_rate),
-                    "LanguageCode": self._opts.language,
+                    "LanguageCode": self._opts.language if is_given(self._opts.language) else None,
                 }
                 response = await client.synthesize_speech(**_strip_nones(params))
                 if "AudioStream" in response:
@@ -194,7 +193,3 @@ class ChunkedStream(tts.ChunkedStream):
                 ) from e
         except Exception as e:
             raise APIConnectionError() from e
-
-
-def _strip_nones(d: dict[str, Any]) -> dict[str, Any]:
-    return {k: v for k, v in d.items() if v is not None}
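Likewise for Polly: the constructor moves to NotGivenOr parameters plus an optional aioboto3.Session, and ChunkedStream now opens a Polly client from that session instead of the removed _get_client callback. A short sketch under the same assumption that TTS is re-exported from livekit.plugins.aws; the voice, engine, and region values are placeholders.

from livekit.plugins import aws

tts = aws.TTS(
    voice="Ruth",                # DEFAULT_VOICE is applied when left NOT_GIVEN
    speech_engine="generative",  # DEFAULT_SPEECH_ENGINE is applied when left NOT_GIVEN
    region="us-east-1",
    sample_rate=16000,
)

# synthesize() now defaults conn_options to DEFAULT_API_CONNECT_OPTIONS; an
# unspecified language resolves to None and is dropped by _strip_nones().
stream = tts.synthesize("Hello from Polly")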
livekit/plugins/aws/utils.py (new file)

@@ -0,0 +1,144 @@
+from __future__ import annotations
+
+import json
+from typing import Any
+
+import aioboto3
+import boto3
+from botocore.exceptions import NoCredentialsError
+
+from livekit.agents import llm
+from livekit.agents.llm import ChatContext, FunctionTool, ImageContent, utils
+
+__all__ = ["to_fnc_ctx", "to_chat_ctx", "get_aws_async_session"]
+DEFAULT_REGION = "us-east-1"
+
+
+def get_aws_async_session(
+    region: str | None = None,
+    api_key: str | None = None,
+    api_secret: str | None = None,
+) -> aioboto3.Session:
+    _validate_aws_credentials(api_key, api_secret)
+    session = aioboto3.Session(
+        aws_access_key_id=api_key,
+        aws_secret_access_key=api_secret,
+        region_name=region or DEFAULT_REGION,
+    )
+    return session
+
+
+def _validate_aws_credentials(
+    api_key: str | None = None,
+    api_secret: str | None = None,
+) -> None:
+    try:
+        session = boto3.Session(aws_access_key_id=api_key, aws_secret_access_key=api_secret)
+        creds = session.get_credentials()
+        if not creds:
+            raise ValueError("No credentials found")
+    except (NoCredentialsError, Exception) as e:
+        raise ValueError(f"Unable to locate valid AWS credentials: {str(e)}") from e
+
+
+def to_fnc_ctx(fncs: list[FunctionTool]) -> list[dict]:
+    return [_build_tool_spec(fnc) for fnc in fncs]
+
+
+def to_chat_ctx(chat_ctx: ChatContext, cache_key: Any) -> tuple[list[dict], dict | None]:
+    messages: list[dict] = []
+    system_message: dict | None = None
+    current_role: str | None = None
+    current_content: list[dict] = []
+
+    for msg in chat_ctx.items:
+        if msg.type == "message" and msg.role == "system":
+            for content in msg.content:
+                if content and isinstance(content, str):
+                    system_message = {"text": content}
+            continue
+
+        if msg.type == "message":
+            role = "assistant" if msg.role == "assistant" else "user"
+        elif msg.type == "function_call":
+            role = "assistant"
+        elif msg.type == "function_call_output":
+            role = "user"
+
+        # if the effective role changed, finalize the previous turn.
+        if role != current_role:
+            if current_content and current_role is not None:
+                messages.append({"role": current_role, "content": current_content})
+            current_content = []
+            current_role = role
+
+        if msg.type == "message":
+            for content in msg.content:
+                if content and isinstance(content, str):
+                    current_content.append({"text": content})
+                elif isinstance(content, ImageContent):
+                    current_content.append(_build_image(content, cache_key))
+        elif msg.type == "function_call":
+            current_content.append(
+                {
+                    "toolUse": {
+                        "toolUseId": msg.call_id,
+                        "name": msg.name,
+                        "input": json.loads(msg.arguments or "{}"),
+                    }
+                }
+            )
+        elif msg.type == "function_call_output":
+            tool_response = {
+                "toolResult": {
+                    "toolUseId": msg.call_id,
+                    "content": [],
+                    "status": "success",
+                }
+            }
+            if isinstance(msg.output, dict):
+                tool_response["toolResult"]["content"].append({"json": msg.output})
+            elif isinstance(msg.output, str):
+                tool_response["toolResult"]["content"].append({"text": msg.output})
+            current_content.append(tool_response)
+
+    # Finalize the last message if there's any content left
+    if current_role is not None and current_content:
+        messages.append({"role": current_role, "content": current_content})
+
+    # Ensure the message list starts with a "user" message
+    if not messages or messages[0]["role"] != "user":
+        messages.insert(0, {"role": "user", "content": [{"text": "(empty)"}]})
+
+    return messages, system_message
+
+
+def _build_tool_spec(fnc: FunctionTool) -> dict:
+    fnc = llm.utils.build_legacy_openai_schema(fnc, internally_tagged=True)
+    return {
+        "toolSpec": _strip_nones(
+            {
+                "name": fnc["name"],
+                "description": fnc["description"] if fnc["description"] else None,
+                "inputSchema": {"json": fnc["parameters"] if fnc["parameters"] else {}},
+            }
+        )
+    }
+
+
+def _build_image(image: ImageContent, cache_key: Any) -> dict:
+    img = utils.serialize_image(image)
+    if img.external_url:
+        raise ValueError("external_url is not supported by AWS Bedrock.")
+    if cache_key not in image._cache:
+        image._cache[cache_key] = img.data_bytes
+    return {
+        "image": {
+            "format": "jpeg",
+            "source": {"bytes": image._cache[cache_key]},
+        }
+    }
+
+
+def _strip_nones(d: dict) -> dict:
+    return {k: v for k, v in d.items() if v is not None}
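The new utils module centralizes session creation (validating via boto3 that some credential source exists before building the aioboto3.Session) plus the Bedrock-style chat and tool converters. A hypothetical standalone use of the session helper, assuming the module is importable as livekit.plugins.aws.utils; the region and the Polly call are purely illustrative.

import asyncio

from livekit.plugins.aws.utils import get_aws_async_session


async def main() -> None:
    # Raises ValueError when boto3 cannot resolve any credentials.
    session = get_aws_async_session(region="us-west-2")
    async with session.client("polly") as polly:
        voices = await polly.describe_voices(Engine="generative")
        print([v["Id"] for v in voices["Voices"]])


asyncio.run(main())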
livekit/plugins/aws/version.py

@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.1.1"
+__version__ = "1.0.0"