audacity-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- audacity/__init__.py +20 -0
- audacity/_client.py +256 -0
- audacity/_mapping.py +346 -0
- audacity/_sse.py +280 -0
- audacity/_version.py +2 -0
- audacity/exceptions.py +349 -0
- audacity/py.typed +0 -0
- audacity_sdk-0.1.0.dist-info/METADATA +270 -0
- audacity_sdk-0.1.0.dist-info/RECORD +12 -0
- audacity_sdk-0.1.0.dist-info/WHEEL +5 -0
- audacity_sdk-0.1.0.dist-info/licenses/LICENSE +1 -0
- audacity_sdk-0.1.0.dist-info/top_level.txt +1 -0
audacity/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""
|
|
2
|
+
audacity — Audacity Investments AI SDK (Python).
|
|
3
|
+
|
|
4
|
+
Usage::
|
|
5
|
+
|
|
6
|
+
from audacity import Audacity
|
|
7
|
+
|
|
8
|
+
client = Audacity(api_key="audacity_api_…")
|
|
9
|
+
response = client.converse(
|
|
10
|
+
modelId="gpt-5.4-mini",
|
|
11
|
+
messages=[{"role": "user", "content": [{"text": "Hi"}]}],
|
|
12
|
+
)
|
|
13
|
+
print(response["output"]["message"]["content"][0]["text"])
|
|
14
|
+
"""
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from ._client import Audacity
|
|
18
|
+
from ._version import __version__
|
|
19
|
+
|
|
20
|
+
__all__ = ["Audacity", "__version__"]
|
audacity/_client.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Audacity SDK client.
|
|
3
|
+
|
|
4
|
+
Public surface (spec §2):
|
|
5
|
+
from audacity import Audacity
|
|
6
|
+
|
|
7
|
+
client = Audacity(api_key=None, base_url=None, timeout=120.0, max_retries=2)
|
|
8
|
+
response = client.converse(**kwargs) # returns dict
|
|
9
|
+
stream = client.converse_stream(**kwargs) # returns {"stream": generator}
|
|
10
|
+
"""
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import http.client
|
|
14
|
+
import json
|
|
15
|
+
import os
|
|
16
|
+
import random
|
|
17
|
+
import socket
|
|
18
|
+
import ssl
|
|
19
|
+
import time
|
|
20
|
+
import urllib.parse
|
|
21
|
+
from typing import Any, Dict, Optional, Tuple
|
|
22
|
+
|
|
23
|
+
from .exceptions import (
|
|
24
|
+
AudacityError,
|
|
25
|
+
MissingApiKeyError,
|
|
26
|
+
SdkError,
|
|
27
|
+
_ExceptionsNamespace,
|
|
28
|
+
parse_error_response,
|
|
29
|
+
)
|
|
30
|
+
from ._mapping import build_request_body, map_converse_response
|
|
31
|
+
from ._sse import iter_stream_events
|
|
32
|
+
from ._version import __version__ as _VERSION
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Fallback values used by the client constructor when the caller (and, for the
# base URL, the environment) does not supply one.
_DEFAULT_BASE_URL = "https://portal.audacityinvestments.com"
_DEFAULT_TIMEOUT = 120.0
_DEFAULT_MAX_RETRIES = 2

# Value sent in the User-Agent request header; embeds the package version.
_USER_AGENT = f"audacity-sdk-python/{_VERSION}"

# HTTP status codes on which we retry (unless the specific error code overrides)
_RETRY_STATUSES = frozenset((408, 429, 500, 502, 503, 504))

# Network exception types we treat as transient
_NETWORK_ERRORS = (OSError, socket.timeout, http.client.HTTPException)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _should_retry(exc: BaseException) -> bool:
    """Decide whether *exc* justifies another request attempt.

    Only ``AudacityError`` instances are ever retried; for those the decision
    is read from the instance's ``retryable`` attribute (treated as False when
    the attribute is absent). Any other exception type is never retried.
    """
    if not isinstance(exc, AudacityError):
        return False
    retryable_flag = getattr(exc, "retryable", False)
    return bool(retryable_flag)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _backoff_seconds(attempt: int, exc: Optional[BaseException]) -> float:
|
|
59
|
+
"""
|
|
60
|
+
Compute sleep duration for retry *attempt* (1-based).
|
|
61
|
+
|
|
62
|
+
Formula: min(20, 0.5 * 2^attempt) with full jitter.
|
|
63
|
+
If the exception carries retry_after_seconds, use max(that, jittered).
|
|
64
|
+
"""
|
|
65
|
+
cap = min(20.0, 0.5 * (2 ** attempt))
|
|
66
|
+
jittered = random.uniform(0.0, cap)
|
|
67
|
+
|
|
68
|
+
retry_after: Optional[float] = getattr(exc, "retry_after_seconds", None)
|
|
69
|
+
if retry_after is not None:
|
|
70
|
+
return max(float(retry_after), jittered)
|
|
71
|
+
return jittered
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class Audacity:
    """
    Audacity Investments LLM gateway client.

    Constructor parameters
    ----------------------
    api_key : str | None
        API key (``audacity_api_…``). Falls back to ``AUDACITY_API_KEY`` env var.
    base_url : str | None
        Override gateway URL. Falls back to ``AUDACITY_BASE_URL`` env var, then
        ``https://portal.audacityinvestments.com``.
    timeout : float
        Per-socket-operation timeout in seconds (default 120): each connect,
        write, and read must complete within this window. For ``converse``
        this bounds every step of the request; for ``converse_stream`` it acts
        as an idle timeout between stream reads and never bounds the total
        stream duration (spec §1).
    max_retries : int
        Maximum number of retry attempts after a transient failure (default 2,
        so up to 3 total attempts). Negative values are treated as 0, i.e. a
        single attempt with no retries.

    Attributes
    ----------
    exceptions : _ExceptionsNamespace
        boto3-parity namespace giving attribute access to all exception classes,
        e.g. ``client.exceptions.ThrottlingException``.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: float = _DEFAULT_TIMEOUT,
        max_retries: int = _DEFAULT_MAX_RETRIES,
    ) -> None:
        """Resolve credentials and endpoint, and prepare connection settings.

        Raises:
            MissingApiKeyError: if neither *api_key* nor the
                ``AUDACITY_API_KEY`` environment variable provides a key.
        """
        resolved_key = api_key or os.environ.get("AUDACITY_API_KEY")
        if not resolved_key:
            raise MissingApiKeyError()
        self._api_key = resolved_key

        self._base_url = (
            (base_url or os.environ.get("AUDACITY_BASE_URL") or _DEFAULT_BASE_URL)
            .rstrip("/")
        )
        self._timeout = float(timeout)
        # Clamp to zero: with a negative budget the retry loop below would run
        # zero times, so no request would ever be sent and the caller would get
        # a misleading "Exhausted retries" error.
        self._max_retries = max(0, int(max_retries))

        parsed = urllib.parse.urlparse(self._base_url)
        self._scheme = parsed.scheme
        self._host = parsed.netloc
        # Path prefix for the request path
        self._path_prefix = parsed.path.rstrip("/")

        # One TLS context is built up front and shared by every HTTPS connection.
        self._ssl_context: Optional[ssl.SSLContext] = (
            ssl.create_default_context() if self._scheme == "https" else None
        )

        self.exceptions = _ExceptionsNamespace()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def converse(self, **kwargs: Any) -> Dict[str, Any]:
        """
        Non-streaming inference (Bedrock Converse parity).

        Returns a dict with keys: output, stopReason, usage, metrics.
        """
        body = build_request_body(kwargs, stream=False)
        headers = self._make_headers(stream=False)
        body_bytes = json.dumps(body).encode("utf-8")

        raw, _, start_time = self._request_with_retries(body_bytes, headers, stream=False)
        return map_converse_response(raw, start_time)

    def converse_stream(self, **kwargs: Any) -> Dict[str, Any]:
        """
        Streaming inference (Bedrock ConverseStream parity).

        Returns ``{"stream": <generator of event dicts>}``.
        Iterate ``response["stream"]`` to receive events.

        Retries apply only until the HTTP 200 response is received.
        Once the first SSE byte is consumed the generator owns the connection;
        mid-stream failures surface as ``ModelStreamErrorException``.
        """
        body = build_request_body(kwargs, stream=True)
        headers = self._make_headers(stream=True)
        body_bytes = json.dumps(body).encode("utf-8")

        response, conn, start_time = self._request_with_retries(body_bytes, headers, stream=True)
        # Ownership of conn + response passes to the generator
        return {"stream": iter_stream_events(response, conn, start_time)}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _request_with_retries(
        self,
        body_bytes: bytes,
        headers: Dict[str, str],
        stream: bool,
    ) -> Tuple[Any, Optional[http.client.HTTPConnection], float]:
        """
        POST the request, applying the retry policy (spec §4): up to
        ``max_retries`` retries on network errors and retryable HTTP errors,
        with capped exponential backoff, full jitter, and Retry-After support.

        Returns a tuple for the successful (HTTP 200) attempt:
          stream=False → ``(raw_body_bytes, None, start_time)`` — body fully
                         read, connection closed.
          stream=True  → ``(open_response, open_connection, start_time)`` —
                         no body bytes consumed; the caller takes ownership of both.

        Raises the mapped AudacityError (or SdkError for network failures)
        once retries are exhausted or the error is not retryable.
        """
        path = self._path_prefix + "/v1/chat/completions"

        last_exc: Optional[BaseException] = None
        for attempt in range(self._max_retries + 1):
            if attempt > 0:
                # Back off before every attempt except the first; the previous
                # failure may carry a Retry-After hint.
                time.sleep(_backoff_seconds(attempt, last_exc))

            # A fresh connection (and a fresh start time) per attempt.
            start_time = time.time()
            conn = self._new_connection()
            handed_off = False
            try:
                conn.request("POST", path, body_bytes, headers)
                response = conn.getresponse()
                status = response.status
                if status == 200 and stream:
                    # Hand off ownership of conn + response to the caller
                    handed_off = True
                    return response, conn, start_time
                raw = response.read()
            except _NETWORK_ERRORS as e:
                exc = SdkError(str(e), cause=e)
                if attempt < self._max_retries:
                    last_exc = exc
                    continue
                raise exc from e
            finally:
                # Runs on return, continue and raise alike; only a handed-off
                # streaming connection is left open.
                if not handed_off:
                    conn.close()

            if status == 200:
                return raw, None, start_time

            # Parse the error
            retry_after_hdr = response.headers.get("Retry-After") if response.headers else None
            exc = parse_error_response(status, raw, retry_after_hdr)
            if attempt < self._max_retries and _should_retry(exc):
                last_exc = exc
                continue
            raise exc

        # Should not reach here, but make the type-checker happy
        if last_exc is not None:
            raise last_exc
        raise SdkError("Exhausted retries")  # pragma: no cover

    def _new_connection(self) -> http.client.HTTPConnection:
        """Create an unopened connection to the gateway host.

        Uses HTTPS with the shared TLS context when the base URL scheme is
        ``https``; any other scheme gets a plain ``HTTPConnection``.
        """
        if self._scheme == "https":
            return http.client.HTTPSConnection(
                self._host,
                timeout=self._timeout,
                context=self._ssl_context,
            )
        return http.client.HTTPConnection(
            self._host,
            timeout=self._timeout,
        )

    def _make_headers(self, stream: bool) -> Dict[str, str]:
        """Build the request headers; ``Accept`` depends on *stream*."""
        return {
            "Authorization": "Bearer " + self._api_key,
            "Content-Type": "application/json",
            "Accept": "text/event-stream" if stream else "application/json",
            "User-Agent": _USER_AGENT,
        }
|
audacity/_mapping.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Request/response mapping — spec §3 (normative).
|
|
3
|
+
|
|
4
|
+
Transforms between Bedrock-shaped dicts and OpenAI wire format.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import base64
|
|
9
|
+
import json
|
|
10
|
+
import time
|
|
11
|
+
from typing import Any, Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
from .exceptions import SdkError, ValidationException
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# ---------------------------------------------------------------------------
|
|
17
|
+
# Helpers
|
|
18
|
+
# ---------------------------------------------------------------------------
|
|
19
|
+
|
|
20
|
+
def _map_finish_reason(finish_reason: Optional[str]) -> str:
|
|
21
|
+
"""Map OpenAI finish_reason to Bedrock stopReason."""
|
|
22
|
+
if finish_reason == "stop":
|
|
23
|
+
return "end_turn"
|
|
24
|
+
if finish_reason == "length":
|
|
25
|
+
return "max_tokens"
|
|
26
|
+
if finish_reason in ("tool_calls", "function_call"):
|
|
27
|
+
return "tool_use"
|
|
28
|
+
if finish_reason == "content_filter":
|
|
29
|
+
return "content_filtered"
|
|
30
|
+
return "end_turn"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _map_usage(usage: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
|
34
|
+
"""Map an OpenAI usage object to the Bedrock usage shape."""
|
|
35
|
+
usage = usage or {}
|
|
36
|
+
return {
|
|
37
|
+
"inputTokens": usage.get("prompt_tokens", 0),
|
|
38
|
+
"outputTokens": usage.get("completion_tokens", 0),
|
|
39
|
+
"totalTokens": usage.get("total_tokens", 0),
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _toolresult_content_text(content_list: List[Dict[str, Any]]) -> str:
|
|
44
|
+
"""Flatten toolResult content blocks into a single string."""
|
|
45
|
+
parts: List[str] = []
|
|
46
|
+
for block in content_list:
|
|
47
|
+
if "text" in block:
|
|
48
|
+
parts.append(block["text"])
|
|
49
|
+
elif "json" in block:
|
|
50
|
+
parts.append(json.dumps(block["json"]))
|
|
51
|
+
return "".join(parts)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _image_block_to_part(image: Dict[str, Any]) -> Dict[str, Any]:
|
|
55
|
+
"""
|
|
56
|
+
Map a Bedrock-style image block to an OpenAI image_url content part (spec §3).
|
|
57
|
+
|
|
58
|
+
source.url is passed through verbatim; source.bytes is base64-encoded into a
|
|
59
|
+
data URL. A str passed as bytes is treated as already-base64 (not re-encoded).
|
|
60
|
+
"""
|
|
61
|
+
fmt = image.get("format", "png")
|
|
62
|
+
source = image.get("source", {})
|
|
63
|
+
|
|
64
|
+
if "url" in source:
|
|
65
|
+
url = source["url"]
|
|
66
|
+
else:
|
|
67
|
+
data = source.get("bytes", b"")
|
|
68
|
+
if isinstance(data, str):
|
|
69
|
+
b64 = data
|
|
70
|
+
else:
|
|
71
|
+
b64 = base64.b64encode(bytes(data)).decode("ascii")
|
|
72
|
+
url = f"data:image/{fmt};base64,{b64}"
|
|
73
|
+
|
|
74
|
+
return {"type": "image_url", "image_url": {"url": url}}
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# ---------------------------------------------------------------------------
|
|
78
|
+
# Converse input → OpenAI request body (spec §3)
|
|
79
|
+
# ---------------------------------------------------------------------------
|
|
80
|
+
|
|
81
|
+
def _build_openai_messages(
|
|
82
|
+
canonical_messages: List[Dict[str, Any]],
|
|
83
|
+
system: Optional[List[Dict[str, Any]]],
|
|
84
|
+
) -> List[Dict[str, Any]]:
|
|
85
|
+
"""
|
|
86
|
+
Convert Bedrock-shaped messages + system prompt to an OpenAI messages list.
|
|
87
|
+
|
|
88
|
+
Mapping rules (spec §3 steps 2–3):
|
|
89
|
+
- system blocks: join all .text with "\\n\\n" → single leading system message.
|
|
90
|
+
- user turn:
|
|
91
|
+
toolResult blocks → role:tool messages (emitted first)
|
|
92
|
+
text blocks → joined role:user message (emitted after tool messages)
|
|
93
|
+
- assistant turn:
|
|
94
|
+
content = joined text blocks (or null if none)
|
|
95
|
+
tool_calls = mapped toolUse blocks (omitted if none)
|
|
96
|
+
"""
|
|
97
|
+
result: List[Dict[str, Any]] = []
|
|
98
|
+
|
|
99
|
+
# Step 2: system
|
|
100
|
+
if system:
|
|
101
|
+
texts = [b.get("text", "") for b in system if "text" in b]
|
|
102
|
+
joined = "\n\n".join(texts)
|
|
103
|
+
if joined:
|
|
104
|
+
result.append({"role": "system", "content": joined})
|
|
105
|
+
|
|
106
|
+
# Step 3: messages
|
|
107
|
+
for msg in canonical_messages:
|
|
108
|
+
role = msg.get("role", "")
|
|
109
|
+
content_blocks: List[Dict[str, Any]] = msg.get("content", [])
|
|
110
|
+
|
|
111
|
+
if role == "user":
|
|
112
|
+
tool_msgs: List[Dict[str, Any]] = []
|
|
113
|
+
# Ordered (kind, payload) pairs for text/image blocks so the
|
|
114
|
+
# multimodal path preserves original block order (spec §3).
|
|
115
|
+
user_parts: List[Dict[str, Any]] = []
|
|
116
|
+
has_image = False
|
|
117
|
+
|
|
118
|
+
for block in content_blocks:
|
|
119
|
+
if "toolResult" in block:
|
|
120
|
+
tr = block["toolResult"]
|
|
121
|
+
text = _toolresult_content_text(tr.get("content", []))
|
|
122
|
+
tool_msgs.append({
|
|
123
|
+
"role": "tool",
|
|
124
|
+
"tool_call_id": tr.get("toolUseId", ""),
|
|
125
|
+
"content": text,
|
|
126
|
+
})
|
|
127
|
+
elif "text" in block:
|
|
128
|
+
user_parts.append({"type": "text", "text": block["text"]})
|
|
129
|
+
elif "image" in block:
|
|
130
|
+
has_image = True
|
|
131
|
+
user_parts.append(_image_block_to_part(block["image"]))
|
|
132
|
+
|
|
133
|
+
result.extend(tool_msgs)
|
|
134
|
+
if has_image:
|
|
135
|
+
result.append({"role": "user", "content": user_parts})
|
|
136
|
+
elif user_parts:
|
|
137
|
+
# Text-only turn keeps the plain-string form (spec §3).
|
|
138
|
+
result.append({
|
|
139
|
+
"role": "user",
|
|
140
|
+
"content": "\n".join(p["text"] for p in user_parts),
|
|
141
|
+
})
|
|
142
|
+
|
|
143
|
+
elif role == "assistant":
|
|
144
|
+
text_parts = []
|
|
145
|
+
tool_calls: List[Dict[str, Any]] = []
|
|
146
|
+
|
|
147
|
+
for block in content_blocks:
|
|
148
|
+
if "text" in block:
|
|
149
|
+
text_parts.append(block["text"])
|
|
150
|
+
elif "toolUse" in block:
|
|
151
|
+
tu = block["toolUse"]
|
|
152
|
+
tool_calls.append({
|
|
153
|
+
"id": tu.get("toolUseId", ""),
|
|
154
|
+
"type": "function",
|
|
155
|
+
"function": {
|
|
156
|
+
"name": tu.get("name", ""),
|
|
157
|
+
"arguments": json.dumps(tu.get("input", {})),
|
|
158
|
+
},
|
|
159
|
+
})
|
|
160
|
+
|
|
161
|
+
oai_msg: Dict[str, Any] = {"role": "assistant"}
|
|
162
|
+
oai_msg["content"] = "\n".join(text_parts) if text_parts else None
|
|
163
|
+
if tool_calls:
|
|
164
|
+
oai_msg["tool_calls"] = tool_calls
|
|
165
|
+
result.append(oai_msg)
|
|
166
|
+
|
|
167
|
+
return result
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _build_tool_list(
|
|
171
|
+
tools: List[Dict[str, Any]],
|
|
172
|
+
) -> List[Dict[str, Any]]:
|
|
173
|
+
"""Map toolConfig.tools to OpenAI tools array."""
|
|
174
|
+
out: List[Dict[str, Any]] = []
|
|
175
|
+
for t in tools:
|
|
176
|
+
spec = t.get("toolSpec", {})
|
|
177
|
+
func: Dict[str, Any] = {"name": spec.get("name", "")}
|
|
178
|
+
if spec.get("description"):
|
|
179
|
+
func["description"] = spec["description"]
|
|
180
|
+
input_schema = spec.get("inputSchema", {})
|
|
181
|
+
if "json" in input_schema:
|
|
182
|
+
func["parameters"] = input_schema["json"]
|
|
183
|
+
out.append({"type": "function", "function": func})
|
|
184
|
+
return out
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _build_tool_choice(tc: Dict[str, Any]) -> Any:
|
|
188
|
+
"""Map toolConfig.toolChoice to OpenAI tool_choice."""
|
|
189
|
+
if "auto" in tc:
|
|
190
|
+
return "auto"
|
|
191
|
+
if "any" in tc:
|
|
192
|
+
return "required"
|
|
193
|
+
if "tool" in tc:
|
|
194
|
+
return {"type": "function", "function": {"name": tc["tool"]["name"]}}
|
|
195
|
+
return "auto"
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def build_request_body(kwargs: Dict[str, Any], stream: bool = False) -> Dict[str, Any]:
    """
    Assemble the OpenAI request body from Bedrock-shaped keyword arguments.

    Implements spec §3 steps 1–7. Keys are added in a fixed order: ``model``,
    ``messages``, the inference settings, the tool settings, ``stream`` (only
    when *stream* is true), and finally every entry of
    ``additionalModelRequestFields``, which may overwrite earlier keys.

    Raises:
        ValidationException: if ``modelId`` or ``messages`` is missing or empty.
    """
    # Client-side validation (spec §2: modelId REQUIRED, messages min 1)
    model_id = kwargs.get("modelId")
    if not model_id:
        raise ValidationException(
            message="modelId is required.",
            status_code=0,
        )
    messages = kwargs.get("messages")
    if not messages:
        raise ValidationException(
            message="messages is required and must contain at least one message.",
            status_code=0,
        )

    # Step 1 (modelId → model) and steps 2–3 (system + messages)
    body: Dict[str, Any] = {"model": model_id}
    body["messages"] = _build_openai_messages(
        canonical_messages=messages,
        system=kwargs.get("system"),
    )

    # Step 4: inferenceConfig — copy each setting that is present, renamed.
    inference = kwargs.get("inferenceConfig") or {}
    renames = (
        ("maxTokens", "max_tokens"),
        ("temperature", "temperature"),
        ("topP", "top_p"),
        ("stopSequences", "stop"),
    )
    for bedrock_name, openai_name in renames:
        if bedrock_name in inference:
            body[openai_name] = inference[bedrock_name]

    # Step 5: toolConfig
    tool_config = kwargs.get("toolConfig") or {}
    if "tools" in tool_config:
        body["tools"] = _build_tool_list(tool_config["tools"])
    if "toolChoice" in tool_config:
        body["tool_choice"] = _build_tool_choice(tool_config["toolChoice"])

    # Step 7: stream flag
    if stream:
        body["stream"] = True

    # Step 6: additionalModelRequestFields (shallow-merged last)
    extras = kwargs.get("additionalModelRequestFields") or {}
    body.update(extras.items())

    return body
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
# ---------------------------------------------------------------------------
|
|
260
|
+
# OpenAI response → Converse output (spec §3)
|
|
261
|
+
# ---------------------------------------------------------------------------
|
|
262
|
+
|
|
263
|
+
def map_converse_response(
    raw_body: bytes,
    start_time: float,
) -> Dict[str, Any]:
    """
    Parse a non-streaming HTTP 200 response into the Bedrock Converse output shape.

    Implements the defensive unwrap rule (spec §1):
      if parsed 200 body has no 'choices' but has 'data.choices', use body['data'].

    Only ``choices[0]`` is used. Its message text (if any) becomes a ``text``
    block and each tool call becomes a ``toolUse`` block. Tool-call arguments
    are JSON-decoded when they arrive as a string; a string that is not valid
    JSON, or arguments that are already a decoded value (dict/list), are
    passed through unchanged as ``input``.

    Args:
        raw_body: The undecoded response body.
        start_time: ``time.time()`` value captured when the request started;
            used to compute ``metrics.latencyMs``.

    Returns:
        A dict with keys ``output``, ``stopReason``, ``usage`` and ``metrics``.

    Raises:
        SdkError: (non-retryable) if the body is not JSON, is not a JSON
            object, has no non-empty ``choices`` list, or ``choices[0]`` has
            no ``message`` object.
    """
    try:
        parsed = json.loads(raw_body)
    except (json.JSONDecodeError, ValueError) as e:
        raise SdkError(
            f"Failed to decode 200 response body as JSON: {e}",
            cause=e,
            retryable=False,
        ) from e

    if not isinstance(parsed, dict):
        raise SdkError(
            "Malformed 200 response: body is not a JSON object",
            retryable=False,
        )

    # Defensive unwrap
    if "choices" not in parsed and isinstance(parsed.get("data"), dict):
        if "choices" in parsed["data"]:
            parsed = parsed["data"]

    choices = parsed.get("choices")
    if not isinstance(choices, list) or not choices:
        raise SdkError(
            "Malformed 200 response: missing or empty 'choices'",
            retryable=False,
        )
    choice = choices[0]
    msg = choice.get("message") if isinstance(choice, dict) else None
    if not isinstance(msg, dict):
        raise SdkError(
            "Malformed 200 response: missing 'message' in choices[0]",
            retryable=False,
        )

    content_blocks: List[Dict[str, Any]] = []

    # text content
    if msg.get("content"):
        content_blocks.append({"text": msg["content"]})

    # tool_calls
    for tc in msg.get("tool_calls") or []:
        func = tc.get("function") or {}
        raw_args = func.get("arguments") or "{}"
        input_val: Any
        if isinstance(raw_args, (str, bytes, bytearray)):
            try:
                input_val = json.loads(raw_args)
            except (json.JSONDecodeError, ValueError):
                # Not valid JSON: surface the raw string rather than failing.
                input_val = raw_args
        else:
            # Already-decoded arguments (e.g. a dict): json.loads would raise
            # TypeError, so pass the value through as-is.
            input_val = raw_args
        content_blocks.append({
            "toolUse": {
                "toolUseId": tc.get("id", ""),
                "name": func.get("name", ""),
                "input": input_val,
            }
        })

    finish_reason = choice.get("finish_reason")
    stop_reason = _map_finish_reason(finish_reason)

    latency_ms = int((time.time() - start_time) * 1000)

    return {
        "output": {
            "message": {
                "role": "assistant",
                "content": content_blocks,
            }
        },
        "stopReason": stop_reason,
        "usage": _map_usage(parsed.get("usage")),
        "metrics": {
            "latencyMs": latency_ms,
        },
    }
|