liter_llm 1.0.0.pre.rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. checksums.yaml +7 -0
  2. data/README.md +239 -0
  3. data/ext/liter_llm_rb/extconf.rb +65 -0
  4. data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
  5. data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
  6. data/ext/liter_llm_rb/native/Cargo.toml +32 -0
  7. data/ext/liter_llm_rb/native/build.rs +15 -0
  8. data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
  9. data/lib/liter_llm.rb +8 -0
  10. data/sig/liter_llm.rbs +416 -0
  11. data/vendor/Cargo.toml +54 -0
  12. data/vendor/liter-llm/Cargo.toml +92 -0
  13. data/vendor/liter-llm/README.md +252 -0
  14. data/vendor/liter-llm/schemas/pricing.json +40 -0
  15. data/vendor/liter-llm/schemas/providers.json +1662 -0
  16. data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
  17. data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
  18. data/vendor/liter-llm/src/auth/mod.rs +68 -0
  19. data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
  20. data/vendor/liter-llm/src/client/config.rs +351 -0
  21. data/vendor/liter-llm/src/client/managed.rs +622 -0
  22. data/vendor/liter-llm/src/client/mod.rs +864 -0
  23. data/vendor/liter-llm/src/cost.rs +212 -0
  24. data/vendor/liter-llm/src/error.rs +190 -0
  25. data/vendor/liter-llm/src/http/eventstream.rs +860 -0
  26. data/vendor/liter-llm/src/http/mod.rs +12 -0
  27. data/vendor/liter-llm/src/http/request.rs +438 -0
  28. data/vendor/liter-llm/src/http/retry.rs +72 -0
  29. data/vendor/liter-llm/src/http/streaming.rs +289 -0
  30. data/vendor/liter-llm/src/lib.rs +37 -0
  31. data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
  32. data/vendor/liter-llm/src/provider/azure.rs +579 -0
  33. data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
  34. data/vendor/liter-llm/src/provider/cohere.rs +654 -0
  35. data/vendor/liter-llm/src/provider/custom.rs +404 -0
  36. data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
  37. data/vendor/liter-llm/src/provider/mistral.rs +188 -0
  38. data/vendor/liter-llm/src/provider/mod.rs +616 -0
  39. data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
  40. data/vendor/liter-llm/src/tests.rs +1425 -0
  41. data/vendor/liter-llm/src/tokenizer.rs +281 -0
  42. data/vendor/liter-llm/src/tower/budget.rs +599 -0
  43. data/vendor/liter-llm/src/tower/cache.rs +502 -0
  44. data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
  45. data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
  46. data/vendor/liter-llm/src/tower/cost.rs +404 -0
  47. data/vendor/liter-llm/src/tower/fallback.rs +121 -0
  48. data/vendor/liter-llm/src/tower/health.rs +219 -0
  49. data/vendor/liter-llm/src/tower/hooks.rs +369 -0
  50. data/vendor/liter-llm/src/tower/mod.rs +77 -0
  51. data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
  52. data/vendor/liter-llm/src/tower/router.rs +436 -0
  53. data/vendor/liter-llm/src/tower/service.rs +181 -0
  54. data/vendor/liter-llm/src/tower/tests.rs +539 -0
  55. data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
  56. data/vendor/liter-llm/src/tower/tracing.rs +209 -0
  57. data/vendor/liter-llm/src/tower/types.rs +170 -0
  58. data/vendor/liter-llm/src/types/audio.rs +52 -0
  59. data/vendor/liter-llm/src/types/batch.rs +77 -0
  60. data/vendor/liter-llm/src/types/chat.rs +214 -0
  61. data/vendor/liter-llm/src/types/common.rs +244 -0
  62. data/vendor/liter-llm/src/types/embedding.rs +84 -0
  63. data/vendor/liter-llm/src/types/files.rs +58 -0
  64. data/vendor/liter-llm/src/types/image.rs +40 -0
  65. data/vendor/liter-llm/src/types/mod.rs +27 -0
  66. data/vendor/liter-llm/src/types/models.rs +21 -0
  67. data/vendor/liter-llm/src/types/moderation.rs +80 -0
  68. data/vendor/liter-llm/src/types/ocr.rs +87 -0
  69. data/vendor/liter-llm/src/types/rerank.rs +46 -0
  70. data/vendor/liter-llm/src/types/responses.rs +55 -0
  71. data/vendor/liter-llm/src/types/search.rs +45 -0
  72. data/vendor/liter-llm/tests/contract.rs +332 -0
  73. data/vendor/liter-llm-ffi/Cargo.toml +30 -0
  74. data/vendor/liter-llm-ffi/build.rs +66 -0
  75. data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
  76. data/vendor/liter-llm-ffi/liter_llm.h +850 -0
  77. data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
  78. metadata +286 -0
@@ -0,0 +1,1543 @@
1
+ use std::borrow::Cow;
2
+
3
+ #[cfg(feature = "bedrock")]
4
+ use crate::error::LiterLlmError;
5
+ use crate::error::Result;
6
+ use crate::provider::{Provider, StreamFormat};
7
+ use crate::types::ChatCompletionChunk;
8
+
9
/// Region used by [`BedrockProvider::from_env`] when the environment does not
/// supply one.
const DEFAULT_REGION: &str = "us-east-1";
11
+
12
/// Map an OpenAI `reasoning_effort` level to a `budget_tokens` value for
/// Claude-on-Bedrock extended thinking.
///
/// Matching ignores ASCII case and surrounding whitespace. `minimal` shares
/// the `low` budget (1024 — NOTE(review): believed to be Anthropic's minimum
/// thinking budget; confirm against current limits). Unrecognised levels fall
/// back to the `medium` budget.
fn reasoning_effort_to_budget_tokens(effort: &str) -> u64 {
    match effort.trim().to_ascii_lowercase().as_str() {
        "minimal" | "low" => 1024,
        "medium" => 4096,
        "high" => 16384,
        // Unknown levels default to medium.
        _ => 4096,
    }
}
21
+
22
/// Map a document MIME type to a Bedrock `DocumentBlock` format name.
///
/// Known MIME types map to Bedrock's spelling, e.g. `"application/pdf"` → `"pdf"`,
/// `"text/plain"` → `"txt"`, `"text/markdown"` → `"md"`, and the Office
/// OOXML types → `"docx"` / `"xlsx"`. MIME parameters such as
/// `; charset=utf-8` are ignored, and matching is ASCII case-insensitive.
///
/// For any other type the subtype after the `/` is returned unchanged, e.g.
/// `"text/csv"` → `"csv"`. The fallback is `"pdf"` when there is no subtype.
fn format_from_media_type(media_type: &str) -> &str {
    // Drop MIME parameters (`type/subtype; key=value`) and stray whitespace.
    let essence = media_type.split(';').next().unwrap_or("").trim();
    match essence.to_ascii_lowercase().as_str() {
        "application/pdf" => "pdf",
        "text/csv" => "csv",
        "application/msword" => "doc",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document" => "docx",
        "application/vnd.ms-excel" => "xls",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" => "xlsx",
        "text/html" => "html",
        "text/plain" => "txt",
        "text/markdown" | "text/x-markdown" => "md",
        // Unknown type: pass the subtype through, defaulting to PDF when absent.
        _ => essence
            .split('/')
            .nth(1)
            .filter(|subtype| !subtype.is_empty())
            .unwrap_or("pdf"),
    }
}
29
+
30
/// Percent-encode a Bedrock model ID so it fits in a single URL path segment.
///
/// Bedrock model IDs can contain colons and slashes, which must not reach the
/// URL unescaped. Only RFC 3986 §2.3 unreserved bytes (`A-Z a-z 0-9 - _ . ~`)
/// pass through. Every other byte, including each byte of a multi-byte UTF-8
/// sequence, is written as `%XX` with uppercase hex digits (RFC 3986 §2.1).
fn percent_encode_model(model: &str) -> String {
    const UPPER_HEX: &[u8; 16] = b"0123456789ABCDEF";

    model
        .bytes()
        .fold(String::with_capacity(model.len()), |mut out, byte| {
            let unreserved =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
            if unreserved {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(UPPER_HEX[usize::from(byte >> 4)]));
                out.push(char::from(UPPER_HEX[usize::from(byte & 0x0f)]));
            }
            out
        })
}
53
+
54
/// AWS Bedrock provider.
///
/// Differences from the OpenAI-compatible baseline:
/// - Routes `bedrock/` prefixed model names to the Bedrock runtime endpoint.
/// - The model prefix is stripped before the model ID is sent in the request.
/// - When the `bedrock` feature is enabled, every request is signed with
///   AWS Signature Version 4 using credentials from the environment
///   (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN`).
/// - When the `bedrock` feature is disabled, the provider is usable with a
///   `base_url` override (e.g. in tests against a mock server) without any
///   signing.
///
/// # Region resolution
///
/// [`BedrockProvider::new`] uses exactly the region it is given.
/// [`BedrockProvider::from_env`] resolves the region in priority order:
/// 1. `AWS_DEFAULT_REGION` environment variable.
/// 2. `AWS_REGION` environment variable.
/// 3. Hard-coded default: `us-east-1`.
///
/// # Configuration
///
/// ```rust,ignore
/// let config = ClientConfigBuilder::new("unused-for-sigv4")
///     .build();
/// let client = DefaultClient::new(config, Some("bedrock/anthropic.claude-3-sonnet-20240229-v1:0"))?;
/// ```
pub struct BedrockProvider {
    /// AWS region this provider targets; used to build `base_url` and passed
    /// to SigV4 signing.
    // Only read by `region()` and (with the `bedrock` feature) `sigv4_sign`,
    // so it can look unused in some feature combinations.
    #[allow(dead_code)]
    region: String,
    /// Cached base URL: `https://bedrock-runtime.{region}.amazonaws.com`.
    base_url: String,
    /// Cross-region prefix taken from the `BEDROCK_CROSS_REGION` env var at
    /// construction time (e.g. `Some("us.")`), cached so the environment is
    /// not read on every request.
    cross_region_prefix: Option<String>,
}
91
+
92
+ impl BedrockProvider {
93
+ /// Construct with the given AWS region.
94
+ #[must_use]
95
+ pub fn new(region: impl Into<String>) -> Self {
96
+ let region = region.into();
97
+ let base_url = format!("https://bedrock-runtime.{region}.amazonaws.com");
98
+ let cross_region_prefix = std::env::var("BEDROCK_CROSS_REGION")
99
+ .ok()
100
+ .filter(|v| !v.is_empty())
101
+ .map(|v| format!("{v}."));
102
+ Self {
103
+ region,
104
+ base_url,
105
+ cross_region_prefix,
106
+ }
107
+ }
108
+
109
+ /// Construct using region from the environment, falling back to `us-east-1`.
110
+ ///
111
+ /// Reads `AWS_DEFAULT_REGION` then `AWS_REGION`.
112
+ #[must_use]
113
+ pub fn from_env() -> Self {
114
+ let region = std::env::var("AWS_DEFAULT_REGION")
115
+ .or_else(|_| std::env::var("AWS_REGION"))
116
+ .unwrap_or_else(|_| DEFAULT_REGION.to_owned());
117
+ Self::new(region)
118
+ }
119
+
120
+ /// Return the AWS region this provider is configured for.
121
+ #[must_use]
122
+ #[allow(dead_code)] // useful for consumers of the library
123
+ pub fn region(&self) -> &str {
124
+ &self.region
125
+ }
126
+ }
127
+
128
+ impl Provider for BedrockProvider {
129
+ fn name(&self) -> &str {
130
+ "bedrock"
131
+ }
132
+
133
+ /// Base URL for the Bedrock runtime service.
134
+ ///
135
+ /// When a `base_url` override is set in [`ClientConfig`] (as in tests),
136
+ /// the override takes precedence and this value is never used.
137
+ fn base_url(&self) -> &str {
138
+ &self.base_url
139
+ }
140
+
141
+ /// Bedrock uses SigV4 signing rather than a static authorization header.
142
+ ///
143
+ /// Returns `None` so the HTTP layer skips adding an `Authorization` header.
144
+ /// Actual signing headers are injected by [`BedrockProvider::signing_headers`]
145
+ /// when the `bedrock` feature is enabled.
146
+ fn auth_header<'a>(&'a self, _api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)> {
147
+ None
148
+ }
149
+
150
+ fn matches_model(&self, model: &str) -> bool {
151
+ model.starts_with("bedrock/")
152
+ }
153
+
154
+ fn strip_model_prefix<'m>(&self, model: &'m str) -> &'m str {
155
+ model.strip_prefix("bedrock/").unwrap_or(model)
156
+ }
157
+
158
+ /// Validate that the provider is usable in the current environment.
159
+ ///
160
+ /// When the `bedrock` feature is enabled, checks that AWS credentials are
161
+ /// available in the environment (`AWS_ACCESS_KEY_ID` at minimum). Without
162
+ /// credentials, every real Bedrock request will be rejected with a 403.
163
+ ///
164
+ /// When the `bedrock` feature is disabled (e.g. in tests with `base_url`
165
+ /// override), validation is skipped so callers can connect to a mock server
166
+ /// without real AWS credentials.
167
+ fn validate(&self) -> Result<()> {
168
+ #[cfg(feature = "bedrock")]
169
+ {
170
+ if std::env::var("AWS_ACCESS_KEY_ID").is_err() {
171
+ return Err(LiterLlmError::BadRequest {
172
+ message: "AWS Bedrock requires AWS credentials. \
173
+ Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY (and optionally \
174
+ AWS_SESSION_TOKEN) in the environment."
175
+ .into(),
176
+ });
177
+ }
178
+ }
179
+ Ok(())
180
+ }
181
+
182
+ /// Bedrock uses AWS EventStream binary framing, not SSE.
183
+ fn stream_format(&self) -> StreamFormat {
184
+ StreamFormat::AwsEventStream
185
+ }
186
+
187
+ /// Build the full URL for a Bedrock Converse API request.
188
+ ///
189
+ /// Chat completions map to `/model/{encoded_model}/converse`.
190
+ /// Embeddings map to `/model/{encoded_model}/invoke`.
191
+ /// All other paths are passed through unchanged.
192
+ ///
193
+ /// When the `BEDROCK_CROSS_REGION` environment variable is set, the
194
+ /// cross-region inference profile prefix is prepended to the model ID.
195
+ /// For example, with `BEDROCK_CROSS_REGION=us`, model
196
+ /// `anthropic.claude-3-sonnet-20240229-v1:0` becomes
197
+ /// `us.anthropic.claude-3-sonnet-20240229-v1:0`.
198
+ fn build_url(&self, endpoint_path: &str, model: &str) -> String {
199
+ let base = self.base_url();
200
+ let effective_model = self.apply_cross_region_prefix(model);
201
+ let encoded_model = percent_encode_model(&effective_model);
202
+ if endpoint_path.contains("chat/completions") {
203
+ format!("{base}/model/{encoded_model}/converse")
204
+ } else if endpoint_path.contains("embeddings") {
205
+ format!("{base}/model/{encoded_model}/invoke")
206
+ } else {
207
+ format!("{base}{endpoint_path}")
208
+ }
209
+ }
210
+
211
+ /// Build the streaming URL: `/model/{id}/converse-stream`.
212
+ fn build_stream_url(&self, endpoint_path: &str, model: &str) -> String {
213
+ let base = self.base_url();
214
+ let effective_model = self.apply_cross_region_prefix(model);
215
+ let encoded_model = percent_encode_model(&effective_model);
216
+ if endpoint_path.contains("chat/completions") {
217
+ format!("{base}/model/{encoded_model}/converse-stream")
218
+ } else {
219
+ // Non-chat streaming falls back to the regular URL.
220
+ self.build_url(endpoint_path, model)
221
+ }
222
+ }
223
+
224
+ /// Convert an OpenAI-style chat request to Bedrock Converse API format.
225
+ ///
226
+ /// Key differences from the OpenAI format:
227
+ /// - System messages are extracted to a top-level `system` array.
228
+ /// - Messages use `content` arrays with typed blocks (`text`, `toolUse`, `toolResult`).
229
+ /// - Generation parameters live in `inferenceConfig`.
230
+ /// - Tools are described in `toolConfig.tools[].toolSpec`.
231
+ fn transform_request(&self, body: &mut serde_json::Value) -> Result<()> {
232
+ use serde_json::json;
233
+
234
+ // Take ownership of the messages array to avoid cloning.
235
+ let messages = body
236
+ .as_object_mut()
237
+ .and_then(|o| o.remove("messages"))
238
+ .and_then(|v| match v {
239
+ serde_json::Value::Array(arr) => Some(arr),
240
+ _ => None,
241
+ })
242
+ .unwrap_or_default();
243
+
244
+ let mut system_parts = vec![];
245
+ let mut converse_messages = vec![];
246
+
247
+ for msg in &messages {
248
+ let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("");
249
+ let content = msg.get("content");
250
+
251
+ match role {
252
+ "system" | "developer" => {
253
+ if let Some(text) = content.and_then(|c| c.as_str()) {
254
+ system_parts.push(json!({"text": text}));
255
+ } else if let Some(array) = content.and_then(|c| c.as_array()) {
256
+ // System content may be a list of typed blocks; extract text parts.
257
+ for part in array {
258
+ if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
259
+ system_parts.push(json!({"text": text}));
260
+ }
261
+ }
262
+ }
263
+ }
264
+ "user" => {
265
+ let parts = if let Some(text) = content.and_then(|c| c.as_str()) {
266
+ vec![json!({"text": text})]
267
+ } else if let Some(array) = content.and_then(|c| c.as_array()) {
268
+ // Multimodal content: iterate parts and translate each block.
269
+ array
270
+ .iter()
271
+ .filter_map(|part| {
272
+ let part_type = part.get("type").and_then(|t| t.as_str()).unwrap_or("");
273
+ match part_type {
274
+ "text" => {
275
+ let text = part.get("text").and_then(|t| t.as_str()).unwrap_or("");
276
+ Some(json!({"text": text}))
277
+ }
278
+ "image_url" => {
279
+ // Map OpenAI image_url to Bedrock image block.
280
+ let url = part.pointer("/image_url/url").and_then(|u| u.as_str()).unwrap_or("");
281
+ if let Some(data_part) = url.strip_prefix("data:") {
282
+ // data:{media_type};base64,{data}
283
+ let mut iter = data_part.splitn(2, ';');
284
+ let media_type = iter.next().unwrap_or("image/jpeg");
285
+ let b64 = iter.next().and_then(|s| s.strip_prefix("base64,")).unwrap_or("");
286
+ Some(json!({
287
+ "image": {
288
+ "format": media_type.split('/').nth(1).unwrap_or("jpeg"),
289
+ "source": {"bytes": b64}
290
+ }
291
+ }))
292
+ } else {
293
+ // Plain URL — not directly supported by Bedrock;
294
+ // include as text so the message is not silently dropped.
295
+ Some(json!({"text": url}))
296
+ }
297
+ }
298
+ "document" => {
299
+ // Map ContentPart::Document to Bedrock document block.
300
+ let data =
301
+ part.pointer("/document/data").and_then(|d| d.as_str()).unwrap_or("");
302
+ let media_type = part
303
+ .pointer("/document/media_type")
304
+ .and_then(|m| m.as_str())
305
+ .unwrap_or("application/pdf");
306
+ let format = format_from_media_type(media_type);
307
+ Some(json!({
308
+ "document": {
309
+ "name": "doc",
310
+ "format": format,
311
+ "source": {"bytes": data}
312
+ }
313
+ }))
314
+ }
315
+ _ => None,
316
+ }
317
+ })
318
+ .collect()
319
+ } else {
320
+ // Fallback: represent unknown content as an empty text block.
321
+ vec![json!({"text": ""})]
322
+ };
323
+ converse_messages.push(json!({"role": "user", "content": parts}));
324
+ }
325
+ "assistant" => {
326
+ let mut parts = vec![];
327
+ if let Some(text) = content.and_then(|c| c.as_str())
328
+ && !text.is_empty()
329
+ {
330
+ parts.push(json!({"text": text}));
331
+ }
332
+ // Convert OpenAI tool_calls to Bedrock toolUse blocks.
333
+ if let Some(tool_calls) = msg.get("tool_calls").and_then(|t| t.as_array()) {
334
+ for tc in tool_calls {
335
+ let input: serde_json::Value = tc
336
+ .pointer("/function/arguments")
337
+ .and_then(|a| a.as_str())
338
+ .and_then(|s| serde_json::from_str(s).ok())
339
+ .unwrap_or_else(|| json!({}));
340
+ parts.push(json!({
341
+ "toolUse": {
342
+ "toolUseId": tc.get("id"),
343
+ "name": tc.pointer("/function/name"),
344
+ "input": input
345
+ }
346
+ }));
347
+ }
348
+ }
349
+ if parts.is_empty() {
350
+ parts.push(json!({"text": ""}));
351
+ }
352
+ converse_messages.push(json!({"role": "assistant", "content": parts}));
353
+ }
354
+ "tool" => {
355
+ let tool_call_id = msg.get("tool_call_id").and_then(|t| t.as_str()).unwrap_or("");
356
+ let result_text = content.and_then(|c| c.as_str()).unwrap_or("");
357
+ // Determine status: treat explicit error markers as failures.
358
+ let is_error = msg.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false);
359
+ let status = if is_error { "error" } else { "success" };
360
+ converse_messages.push(json!({
361
+ "role": "user",
362
+ "content": [{
363
+ "toolResult": {
364
+ "toolUseId": tool_call_id,
365
+ "content": [{"text": result_text}],
366
+ "status": status
367
+ }
368
+ }]
369
+ }));
370
+ }
371
+ _ => {}
372
+ }
373
+ }
374
+
375
+ // Build inferenceConfig from OpenAI generation parameters.
376
+ let mut inference_config = json!({});
377
+ if let Some(max_tokens) = body.get("max_tokens").or_else(|| body.get("max_completion_tokens")) {
378
+ inference_config["maxTokens"] = max_tokens.clone();
379
+ }
380
+ if let Some(temp) = body.get("temperature") {
381
+ inference_config["temperature"] = temp.clone();
382
+ }
383
+ if let Some(top_p) = body.get("top_p") {
384
+ inference_config["topP"] = top_p.clone();
385
+ }
386
+ if let Some(stop) = body.get("stop") {
387
+ let sequences = if let Some(s) = stop.as_str() {
388
+ vec![json!(s)]
389
+ } else {
390
+ stop.as_array().cloned().unwrap_or_default()
391
+ };
392
+ inference_config["stopSequences"] = json!(sequences);
393
+ }
394
+
395
+ // Build toolConfig if tools are present.
396
+ let tool_config = body.get("tools").and_then(|tools| {
397
+ tools.as_array().map(|arr| {
398
+ let bedrock_tools: Vec<serde_json::Value> = arr
399
+ .iter()
400
+ .map(|t| {
401
+ let parameters = t
402
+ .pointer("/function/parameters")
403
+ .cloned()
404
+ .unwrap_or_else(|| json!({"type": "object"}));
405
+ json!({
406
+ "toolSpec": {
407
+ "name": t.pointer("/function/name"),
408
+ "description": t.pointer("/function/description"),
409
+ "inputSchema": {"json": parameters}
410
+ }
411
+ })
412
+ })
413
+ .collect();
414
+ json!({"tools": bedrock_tools})
415
+ })
416
+ });
417
+
418
+ // ── Extended thinking / reasoning effort ────────────────────────────
419
+ // When reasoning_effort is set, map to Bedrock's additionalModelRequestFields
420
+ // for Claude-on-Bedrock extended thinking.
421
+ let mut additional_model_fields: Option<serde_json::Value> = None;
422
+ if let Some(effort) = body.get("reasoning_effort").and_then(|e| e.as_str()) {
423
+ let budget_tokens = reasoning_effort_to_budget_tokens(effort);
424
+ additional_model_fields = Some(json!({
425
+ "thinking": {
426
+ "type": "enabled",
427
+ "budget_tokens": budget_tokens
428
+ }
429
+ }));
430
+ }
431
+
432
+ // ── Response format → system instruction ────────────────────────────
433
+ // Bedrock/Claude doesn't have native JSON mode. When response_format is
434
+ // json_schema or json_object, add a system instruction for JSON output.
435
+ if let Some(response_format) = body.get("response_format") {
436
+ let rf_type = response_format.get("type").and_then(|t| t.as_str()).unwrap_or("");
437
+ match rf_type {
438
+ "json_schema" => {
439
+ let schema = response_format.get("json_schema").and_then(|js| js.get("schema"));
440
+ let schema_str = schema
441
+ .map(|s| serde_json::to_string_pretty(s).unwrap_or_default())
442
+ .unwrap_or_default();
443
+ let instruction = if schema_str.is_empty() {
444
+ "You MUST respond with valid JSON only. No other text.".to_owned()
445
+ } else {
446
+ format!(
447
+ "You MUST respond with valid JSON only that conforms to this schema:\n```json\n{schema_str}\n```\nNo other text outside the JSON."
448
+ )
449
+ };
450
+ system_parts.push(json!({"text": instruction}));
451
+ }
452
+ "json_object" => {
453
+ system_parts.push(json!({"text": "You MUST respond with valid JSON only. No other text."}));
454
+ }
455
+ _ => {}
456
+ }
457
+ }
458
+
459
+ // ── Guardrails ──────────────────────────────────────────────────────
460
+ // Extract guardrailConfig from extra_body if present.
461
+ let guardrail_config = body.get("extra_body").and_then(|eb| eb.get("guardrailConfig")).cloned();
462
+
463
+ // Assemble the Bedrock Converse request body.
464
+ let mut new_body = json!({
465
+ "messages": converse_messages,
466
+ });
467
+ if !system_parts.is_empty() {
468
+ new_body["system"] = json!(system_parts);
469
+ }
470
+ if let Some(obj) = inference_config.as_object()
471
+ && !obj.is_empty()
472
+ {
473
+ new_body["inferenceConfig"] = inference_config;
474
+ }
475
+ if let Some(tc) = tool_config {
476
+ new_body["toolConfig"] = tc;
477
+ }
478
+ if let Some(amf) = additional_model_fields {
479
+ new_body["additionalModelRequestFields"] = amf;
480
+ }
481
+ if let Some(gc) = guardrail_config {
482
+ new_body["guardrailConfig"] = gc;
483
+ }
484
+
485
+ *body = new_body;
486
+ Ok(())
487
+ }
488
+
489
+ /// Normalize a Bedrock Converse API response to OpenAI chat completion format.
490
+ ///
491
+ /// Bedrock wraps the assistant's message in `output.message.content[]` blocks.
492
+ /// Stop reasons use Bedrock terminology (`end_turn`, `tool_use`, etc.) and are
493
+ /// mapped to the OpenAI `finish_reason` set.
494
+ ///
495
+ /// **Known limitation:** The `model` field in the normalized response is
496
+ /// always `""`. Bedrock does not include the model name in its response
497
+ /// body — the model is only present in the request URL path. Threading
498
+ /// the model through would require a signature change to `transform_response`.
499
+ fn transform_response(&self, body: &mut serde_json::Value) -> Result<()> {
500
+ use serde_json::json;
501
+
502
+ let stop_reason = body.get("stopReason").and_then(|s| s.as_str()).unwrap_or("end_turn");
503
+ let usage = body.get("usage").cloned();
504
+
505
+ // Content blocks live under output.message.content[].
506
+ let content_blocks = body
507
+ .pointer("/output/message/content")
508
+ .and_then(|c| c.as_array())
509
+ .cloned()
510
+ .unwrap_or_default();
511
+
512
+ // Collect text and toolUse blocks separately.
513
+ let text: String = content_blocks
514
+ .iter()
515
+ .filter_map(|b| b.get("text").and_then(|t| t.as_str()))
516
+ .collect::<Vec<_>>()
517
+ .join("");
518
+
519
+ let tool_calls: Vec<serde_json::Value> = content_blocks
520
+ .iter()
521
+ .filter_map(|b| {
522
+ b.get("toolUse").map(|tu| {
523
+ let arguments = serde_json::to_string(tu.get("input").unwrap_or(&json!({}))).unwrap_or_default();
524
+ json!({
525
+ "id": tu.get("toolUseId"),
526
+ "type": "function",
527
+ "function": {
528
+ "name": tu.get("name"),
529
+ "arguments": arguments
530
+ }
531
+ })
532
+ })
533
+ })
534
+ .collect();
535
+
536
+ let finish_reason = match stop_reason {
537
+ "end_turn" => "stop",
538
+ "tool_use" => "tool_calls",
539
+ "max_tokens" => "length",
540
+ "stop_sequence" => "stop",
541
+ "content_filtered" | "guardrail_intervened" => "content_filter",
542
+ _ => "stop",
543
+ };
544
+
545
+ let input_tokens = usage
546
+ .as_ref()
547
+ .and_then(|u| u.get("inputTokens"))
548
+ .and_then(|v| v.as_u64())
549
+ .unwrap_or(0);
550
+ let output_tokens = usage
551
+ .as_ref()
552
+ .and_then(|u| u.get("outputTokens"))
553
+ .and_then(|v| v.as_u64())
554
+ .unwrap_or(0);
555
+
556
+ let response_id = body
557
+ .get("requestId")
558
+ .or_else(|| body.get("conversationId"))
559
+ .cloned()
560
+ .unwrap_or_else(|| json!("bedrock-resp"));
561
+
562
+ let content_value: serde_json::Value = if text.is_empty() { json!(null) } else { json!(text) };
563
+
564
+ let mut message = json!({"role": "assistant", "content": content_value});
565
+ if !tool_calls.is_empty() {
566
+ message["tool_calls"] = json!(tool_calls);
567
+ }
568
+
569
+ *body = json!({
570
+ "id": response_id,
571
+ "object": "chat.completion",
572
+ "created": super::unix_timestamp_secs(),
573
+ "model": "",
574
+ "choices": [{
575
+ "index": 0,
576
+ "message": message,
577
+ "finish_reason": finish_reason
578
+ }],
579
+ "usage": {
580
+ "prompt_tokens": input_tokens,
581
+ "completion_tokens": output_tokens,
582
+ "total_tokens": input_tokens + output_tokens
583
+ }
584
+ });
585
+
586
+ Ok(())
587
+ }
588
+
589
+ /// Compute AWS SigV4 signing headers for the request.
590
+ ///
591
+ /// When the `bedrock` feature is enabled, derives the `Authorization`,
592
+ /// `x-amz-date`, and (when a session token is present) `x-amz-security-token`
593
+ /// headers from the current request parameters and AWS credentials.
594
+ ///
595
+ /// When the `bedrock` feature is disabled, returns an empty vector so
596
+ /// requests work against override base-URLs (e.g. mock servers in tests).
597
+ fn signing_headers(&self, method: &str, url: &str, body: &[u8]) -> Vec<(String, String)> {
598
+ #[cfg(feature = "bedrock")]
599
+ {
600
+ sigv4_sign(method, url, body, &self.region).unwrap_or_default()
601
+ }
602
+
603
+ #[cfg(not(feature = "bedrock"))]
604
+ {
605
+ let _ = (method, url, body);
606
+ vec![]
607
+ }
608
+ }
609
+ }
610
+
611
+ /// Parse a Bedrock ConverseStream EventStream event into a `ChatCompletionChunk`.
612
+ ///
613
+ /// Bedrock ConverseStream events:
614
+ /// - `messageStart` → role delta
615
+ /// - `contentBlockStart` → tool_use start (with toolUseId and name)
616
+ /// - `contentBlockDelta` → text delta or tool_use input delta
617
+ /// - `contentBlockStop` → (ignored)
618
+ /// - `messageStop` → finish_reason
619
+ /// - `metadata` → usage (emitted as a final chunk with empty delta)
620
+ ///
621
+ /// Returns `Ok(None)` for events that don't map to a chunk (e.g. `contentBlockStop`).
622
+ ///
623
+ /// **Known limitation:** The `id` field is hardcoded to `"bedrock-stream"` and
624
+ /// `model` is always `""` on every chunk. Bedrock's ConverseStream protocol does
625
+ /// not include a request/response ID or model name in its event payloads, and
626
+ /// this parser is stateless so it cannot carry forward values from the original
627
+ /// request. This differs from the OpenAI format where every chunk includes the
628
+ /// real `id` and `model`.
629
+ pub(crate) fn parse_bedrock_stream_event(event_type: &str, payload: &str) -> Result<Option<ChatCompletionChunk>> {
630
+ use crate::error::LiterLlmError;
631
+ use serde_json::json;
632
+
633
+ let v: serde_json::Value = serde_json::from_str(payload).map_err(|e| LiterLlmError::Streaming {
634
+ message: format!("Bedrock stream event parse error: {e}"),
635
+ })?;
636
+
637
+ let chunk_from_json = |chunk_json: serde_json::Value| -> Result<ChatCompletionChunk> {
638
+ serde_json::from_value(chunk_json).map_err(|e| LiterLlmError::Streaming {
639
+ message: format!("Bedrock chunk deserialization error: {e}"),
640
+ })
641
+ };
642
+
643
+ match event_type {
644
+ "messageStart" => {
645
+ let role = v.get("role").and_then(|r| r.as_str()).unwrap_or("assistant");
646
+ chunk_from_json(json!({
647
+ "id": "bedrock-stream",
648
+ "object": "chat.completion.chunk",
649
+ "created": 0,
650
+ "model": "",
651
+ "choices": [{
652
+ "index": 0,
653
+ "delta": {"role": role},
654
+ "finish_reason": null
655
+ }]
656
+ }))
657
+ .map(Some)
658
+ }
659
+ "contentBlockStart" => {
660
+ let index = v.get("contentBlockIndex").and_then(|i| i.as_u64()).unwrap_or(0);
661
+ // Check if this is a tool_use start.
662
+ if let Some(tool_use) = v.pointer("/start/toolUse") {
663
+ let tool_use_id = tool_use.get("toolUseId").and_then(|t| t.as_str()).unwrap_or("");
664
+ let name = tool_use.get("name").and_then(|n| n.as_str()).unwrap_or("");
665
+ chunk_from_json(json!({
666
+ "id": "bedrock-stream",
667
+ "object": "chat.completion.chunk",
668
+ "created": 0,
669
+ "model": "",
670
+ "choices": [{
671
+ "index": 0,
672
+ "delta": {
673
+ "tool_calls": [{
674
+ "index": index,
675
+ "id": tool_use_id,
676
+ "type": "function",
677
+ "function": {"name": name, "arguments": ""}
678
+ }]
679
+ },
680
+ "finish_reason": null
681
+ }]
682
+ }))
683
+ .map(Some)
684
+ } else {
685
+ // Text content block start — no delta content yet.
686
+ Ok(None)
687
+ }
688
+ }
689
+ "contentBlockDelta" => {
690
+ let index = v.get("contentBlockIndex").and_then(|i| i.as_u64()).unwrap_or(0);
691
+
692
+ // Text delta.
693
+ if let Some(text) = v.pointer("/delta/text").and_then(|t| t.as_str()) {
694
+ return chunk_from_json(json!({
695
+ "id": "bedrock-stream",
696
+ "object": "chat.completion.chunk",
697
+ "created": 0,
698
+ "model": "",
699
+ "choices": [{
700
+ "index": 0,
701
+ "delta": {"content": text},
702
+ "finish_reason": null
703
+ }]
704
+ }))
705
+ .map(Some);
706
+ }
707
+
708
+ // Tool use input delta.
709
+ if let Some(input_json) = v.pointer("/delta/toolUse/input").and_then(|i| i.as_str()) {
710
+ return chunk_from_json(json!({
711
+ "id": "bedrock-stream",
712
+ "object": "chat.completion.chunk",
713
+ "created": 0,
714
+ "model": "",
715
+ "choices": [{
716
+ "index": 0,
717
+ "delta": {
718
+ "tool_calls": [{
719
+ "index": index,
720
+ "function": {"arguments": input_json}
721
+ }]
722
+ },
723
+ "finish_reason": null
724
+ }]
725
+ }))
726
+ .map(Some);
727
+ }
728
+
729
+ // Unrecognized delta shape — log so callers know data was skipped.
730
+ #[cfg(feature = "tracing")]
731
+ tracing::debug!(
732
+ content_block_index = index,
733
+ "Bedrock contentBlockDelta with unrecognized delta shape; skipping"
734
+ );
735
+
736
+ Ok(None)
737
+ }
738
+ "contentBlockStop" => Ok(None),
739
+ "messageStop" => {
740
+ let stop_reason = v.get("stopReason").and_then(|s| s.as_str()).unwrap_or("end_turn");
741
+ let finish_reason = match stop_reason {
742
+ "end_turn" => "stop",
743
+ "tool_use" => "tool_calls",
744
+ "max_tokens" => "length",
745
+ "stop_sequence" => "stop",
746
+ "content_filtered" | "guardrail_intervened" => "content_filter",
747
+ _ => "stop",
748
+ };
749
+ chunk_from_json(json!({
750
+ "id": "bedrock-stream",
751
+ "object": "chat.completion.chunk",
752
+ "created": 0,
753
+ "model": "",
754
+ "choices": [{
755
+ "index": 0,
756
+ "delta": {},
757
+ "finish_reason": finish_reason
758
+ }]
759
+ }))
760
+ .map(Some)
761
+ }
762
+ "metadata" => {
763
+ // Emit usage as a final chunk with empty choices.
764
+ let input_tokens = v.pointer("/usage/inputTokens").and_then(|t| t.as_u64()).unwrap_or(0);
765
+ let output_tokens = v.pointer("/usage/outputTokens").and_then(|t| t.as_u64()).unwrap_or(0);
766
+ chunk_from_json(json!({
767
+ "id": "bedrock-stream",
768
+ "object": "chat.completion.chunk",
769
+ "created": 0,
770
+ "model": "",
771
+ "choices": [],
772
+ "usage": {
773
+ "prompt_tokens": input_tokens,
774
+ "completion_tokens": output_tokens,
775
+ "total_tokens": input_tokens + output_tokens
776
+ }
777
+ }))
778
+ .map(Some)
779
+ }
780
+ _ => {
781
+ // Unknown event type — skip silently.
782
+ Ok(None)
783
+ }
784
+ }
785
+ }
786
+
787
+ /// Apply the cross-region inference profile prefix using the value cached at
788
+ /// construction time from the `BEDROCK_CROSS_REGION` environment variable.
789
+ ///
790
+ /// When the prefix is set (e.g. `"us."`), the model ID
791
+ /// `anthropic.claude-3-sonnet-20240229-v1:0` becomes
792
+ /// `us.anthropic.claude-3-sonnet-20240229-v1:0`.
793
+ ///
794
+ /// If the model already starts with the cross-region prefix, it is returned
795
+ /// unchanged to avoid double-prefixing.
796
+ impl BedrockProvider {
797
+ fn apply_cross_region_prefix(&self, model: &str) -> String {
798
+ match &self.cross_region_prefix {
799
+ Some(prefix) => {
800
+ if model.starts_with(prefix.as_str()) {
801
+ model.to_owned()
802
+ } else {
803
+ format!("{prefix}{model}")
804
+ }
805
+ }
806
+ None => model.to_owned(),
807
+ }
808
+ }
809
+ }
810
+
811
/// Test-only variant of [`BedrockProvider::apply_cross_region_prefix`] that reads
/// `BEDROCK_CROSS_REGION` on every call instead of using a cached value.
///
/// An unset, empty, or non-Unicode variable leaves `model` untouched. Otherwise
/// `"<region>."` is prepended unless `model` already starts with it.
#[cfg(test)]
fn apply_cross_region_prefix(model: &str) -> String {
    let Some(region) = std::env::var("BEDROCK_CROSS_REGION").ok().filter(|r| !r.is_empty()) else {
        return model.to_owned();
    };

    let prefix = format!("{region}.");
    if model.starts_with(&prefix) {
        model.to_owned()
    } else {
        prefix + model
    }
}
829
+
830
/// Sign a Bedrock request with AWS Signature Version 4 using the `aws-sigv4` crate.
///
/// Credentials are read from the standard AWS environment variables:
/// - `AWS_ACCESS_KEY_ID` (required)
/// - `AWS_SECRET_ACCESS_KEY` (required)
/// - `AWS_SESSION_TOKEN` (optional, for temporary credentials)
///
/// The signature covers `method`, `url` and `body` for the `bedrock` signing
/// name in `region`, stamped with the current system time.
///
/// Returns the `(header-name, header-value)` pairs to add to the outgoing HTTP
/// request.
///
/// # Errors
///
/// Returns `LiterLlmError::BadRequest` when a required credential variable is
/// missing, or when building the signing parameters, the signable request, or the
/// signature fails.
#[cfg(feature = "bedrock")]
fn sigv4_sign(method: &str, url: &str, body: &[u8], region: &str) -> Result<Vec<(String, String)>> {
    use aws_credential_types::Credentials;
    use aws_sigv4::http_request::{SignableBody, SignableRequest, SigningSettings, sign};
    use aws_sigv4::sign::v4::SigningParams;

    // Both mandatory credential variables report absence with the same message shape.
    let required_env = |name: &str| {
        std::env::var(name).map_err(|_| LiterLlmError::BadRequest {
            message: format!("{name} environment variable is required for Bedrock requests"),
        })
    };

    // Arguments are evaluated left to right, so the access key is checked before the secret.
    let identity = Credentials::new(
        required_env("AWS_ACCESS_KEY_ID")?,
        required_env("AWS_SECRET_ACCESS_KEY")?,
        std::env::var("AWS_SESSION_TOKEN").ok(),
        None, // expiry: none supplied
        "env",
    )
    .into();

    let params = SigningParams::builder()
        .identity(&identity)
        .region(region)
        .name("bedrock")
        .time(std::time::SystemTime::now())
        .settings(SigningSettings::default())
        .build()
        .map_err(|e| LiterLlmError::BadRequest {
            message: format!("failed to build SigV4 signing params: {e}"),
        })?;

    // No extra headers are passed in; only method, URL and body are signed.
    let signable = SignableRequest::new(
        method,
        url,
        std::iter::empty::<(&str, &str)>(),
        SignableBody::Bytes(body),
    )
    .map_err(|e| LiterLlmError::BadRequest {
        message: format!("failed to create signable request: {e}"),
    })?;

    let signing_output = sign(signable, &params.into()).map_err(|e| LiterLlmError::BadRequest {
        message: format!("SigV4 signing failed: {e}"),
    })?;

    Ok(signing_output
        .output()
        .headers()
        .map(|(name, value)| (name.to_owned(), value.to_owned()))
        .collect())
}
900
+
901
+ // ── Unit tests ────────────────────────────────────────────────────────────────
902
+
903
#[cfg(test)]
mod tests {
    use serde_json::json;

    use serial_test::serial;

    use super::*;
    use crate::provider::Provider;
    use crate::types::chat::FinishReason;

    /// Provider under test, pinned to `us-east-1` so URL assertions are stable.
    fn provider() -> BedrockProvider {
        BedrockProvider::new("us-east-1")
    }

    /// RAII guard that sets or clears `BEDROCK_CROSS_REGION` and restores the
    /// value that was present before the guard was created.
    ///
    /// The restore runs in `Drop`, so cleanup still happens when an assertion
    /// panics and the variable cannot leak into later tests. Environment
    /// variables are process-global: only use this from `#[serial]` tests.
    struct CrossRegionEnv {
        previous: Option<std::ffi::OsString>,
    }

    impl CrossRegionEnv {
        const VAR: &'static str = "BEDROCK_CROSS_REGION";

        /// Set the variable to `value`, or remove it when `value` is `None`.
        fn set(value: Option<&str>) -> Self {
            let previous = std::env::var_os(Self::VAR);
            // SAFETY: env vars are process-global; callers are `#[serial]`, so no other
            // guarded test mutates the environment concurrently.
            unsafe {
                match value {
                    Some(v) => std::env::set_var(Self::VAR, v),
                    None => std::env::remove_var(Self::VAR),
                }
            }
            Self { previous }
        }
    }

    impl Drop for CrossRegionEnv {
        fn drop(&mut self) {
            // SAFETY: same invariant as in `CrossRegionEnv::set` — only `#[serial]` tests use this guard.
            unsafe {
                match self.previous.take() {
                    Some(v) => std::env::set_var(Self::VAR, v),
                    None => std::env::remove_var(Self::VAR),
                }
            }
        }
    }

    // ── build_url ─────────────────────────────────────────────────────────────

    #[test]
    fn build_url_chat_completions() {
        let p = provider();
        let url = p.build_url("/chat/completions", "anthropic.claude-3-sonnet-20240229-v1:0");
        // Colon must be uppercase-encoded per RFC 3986 §2.1.
        assert_eq!(
            url,
            "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1%3A0/converse"
        );
    }

    #[test]
    fn build_url_embeddings() {
        let p = provider();
        let url = p.build_url("/embeddings", "amazon.titan-embed-text-v1");
        assert_eq!(
            url,
            "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-embed-text-v1/invoke"
        );
    }

    #[test]
    fn build_url_other_path() {
        let p = provider();
        let url = p.build_url("/models", "any-model");
        assert_eq!(url, "https://bedrock-runtime.us-east-1.amazonaws.com/models");
    }

    // ── percent_encode_model ──────────────────────────────────────────────────

    #[test]
    fn percent_encode_model_colon() {
        let encoded = percent_encode_model("anthropic.claude-3-sonnet-20240229-v1:0");
        // RFC 3986 §2.1 requires uppercase hex digits.
        assert!(
            encoded.contains("%3A"),
            "colon should be percent-encoded with uppercase hex: {encoded}"
        );
        assert!(!encoded.contains("%3a"), "lowercase hex must not appear: {encoded}");
        assert!(!encoded.contains(':'), "raw colon should not remain: {encoded}");
    }

    #[test]
    fn percent_encode_model_safe_chars() {
        let encoded = percent_encode_model("amazon.titan-embed-text-v1");
        assert_eq!(encoded, "amazon.titan-embed-text-v1");
    }

    // ── transform_request ─────────────────────────────────────────────────────

    #[test]
    fn transform_request_basic_chat() {
        let p = provider();
        let mut body = json!({
            "model": "anthropic.claude-3-sonnet",
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "Hello!"}
            ],
            "max_tokens": 100,
            "temperature": 0.7
        });

        p.transform_request(&mut body).unwrap();

        // System messages extracted to top-level array.
        assert_eq!(body["system"][0]["text"], "You are helpful.");

        // User message converted to content blocks.
        assert_eq!(body["messages"][0]["role"], "user");
        assert_eq!(body["messages"][0]["content"][0]["text"], "Hello!");

        // Generation params in inferenceConfig.
        assert_eq!(body["inferenceConfig"]["maxTokens"], 100);
        assert_eq!(body["inferenceConfig"]["temperature"], 0.7);
    }

    #[test]
    fn transform_request_with_tool_calls() {
        let p = provider();
        let mut body = json!({
            "messages": [
                {"role": "user", "content": "What is the weather?"},
                {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_abc",
                        "type": "function",
                        "function": {"name": "get_weather", "arguments": "{\"city\":\"Berlin\"}"}
                    }]
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_abc",
                    "content": "Sunny, 22°C"
                }
            ]
        });

        p.transform_request(&mut body).unwrap();

        let messages = body["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 3);

        // Assistant message has toolUse block.
        let assistant = &messages[1];
        assert_eq!(assistant["role"], "assistant");
        let tool_use = &assistant["content"][0]["toolUse"];
        assert_eq!(tool_use["toolUseId"], "call_abc");
        assert_eq!(tool_use["name"], "get_weather");
        assert_eq!(tool_use["input"]["city"], "Berlin");

        // Tool result converted to user message with toolResult block.
        let tool_result_msg = &messages[2];
        assert_eq!(tool_result_msg["role"], "user");
        let tool_result = &tool_result_msg["content"][0]["toolResult"];
        assert_eq!(tool_result["toolUseId"], "call_abc");
        assert_eq!(tool_result["status"], "success");
    }

    #[test]
    fn transform_request_tools_schema() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}],
            "tools": [{
                "type": "function",
                "function": {
                    "name": "search",
                    "description": "Search the web",
                    "parameters": {"type": "object", "properties": {"query": {"type": "string"}}}
                }
            }]
        });

        p.transform_request(&mut body).unwrap();

        let tools = body["toolConfig"]["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 1);
        let spec = &tools[0]["toolSpec"];
        assert_eq!(spec["name"], "search");
        assert_eq!(spec["description"], "Search the web");
        assert_eq!(spec["inputSchema"]["json"]["type"], "object");
    }

    // ── transform_response ────────────────────────────────────────────────────

    #[test]
    fn transform_response_basic() {
        let p = provider();
        let mut body = json!({
            "requestId": "req-123",
            "stopReason": "end_turn",
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [{"text": "Hello, world!"}]
                }
            },
            "usage": {
                "inputTokens": 10,
                "outputTokens": 5
            }
        });

        p.transform_response(&mut body).unwrap();

        assert_eq!(body["object"], "chat.completion");
        assert_eq!(body["id"], "req-123");
        assert_eq!(body["choices"][0]["message"]["content"], "Hello, world!");
        assert_eq!(body["choices"][0]["finish_reason"], "stop");
        assert_eq!(body["usage"]["prompt_tokens"], 10);
        assert_eq!(body["usage"]["completion_tokens"], 5);
        assert_eq!(body["usage"]["total_tokens"], 15);
    }

    #[test]
    fn transform_response_tool_calls() {
        let p = provider();
        let mut body = json!({
            "stopReason": "tool_use",
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [
                        {"toolUse": {
                            "toolUseId": "call_xyz",
                            "name": "get_weather",
                            "input": {"city": "Berlin"}
                        }}
                    ]
                }
            },
            "usage": {"inputTokens": 20, "outputTokens": 10}
        });

        p.transform_response(&mut body).unwrap();

        assert_eq!(body["choices"][0]["finish_reason"], "tool_calls");
        let tool_calls = body["choices"][0]["message"]["tool_calls"].as_array().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0]["id"], "call_xyz");
        assert_eq!(tool_calls[0]["function"]["name"], "get_weather");
        let args: serde_json::Value =
            serde_json::from_str(tool_calls[0]["function"]["arguments"].as_str().unwrap()).unwrap();
        assert_eq!(args["city"], "Berlin");
    }

    #[test]
    fn transform_response_finish_reason_mapping() {
        let p = provider();

        for (bedrock_reason, expected_oai_reason) in [
            ("end_turn", "stop"),
            ("tool_use", "tool_calls"),
            ("max_tokens", "length"),
            ("stop_sequence", "stop"),
            ("content_filtered", "content_filter"),
            ("guardrail_intervened", "content_filter"),
            ("unknown_future_reason", "stop"),
        ] {
            let mut body = json!({
                "stopReason": bedrock_reason,
                "output": {"message": {"role": "assistant", "content": [{"text": ""}]}},
                "usage": {"inputTokens": 0, "outputTokens": 0}
            });
            p.transform_response(&mut body).unwrap();
            assert_eq!(
                body["choices"][0]["finish_reason"], expected_oai_reason,
                "bedrock stopReason '{bedrock_reason}' should map to '{expected_oai_reason}'"
            );
        }
    }

    // ── model prefix / matching ───────────────────────────────────────────────

    #[test]
    fn strip_model_prefix() {
        let p = provider();
        assert_eq!(p.strip_model_prefix("bedrock/anthropic.claude-3"), "anthropic.claude-3");
        assert_eq!(p.strip_model_prefix("anthropic.claude-3"), "anthropic.claude-3");
    }

    #[test]
    fn matches_model() {
        let p = provider();
        assert!(p.matches_model("bedrock/anthropic.claude-3"));
        assert!(!p.matches_model("anthropic.claude-3"));
        assert!(!p.matches_model("gpt-4"));
    }

    // ── stream_format ─────────────────────────────────────────────────────────

    #[test]
    fn stream_format_is_eventstream() {
        let p = provider();
        assert_eq!(p.stream_format(), StreamFormat::AwsEventStream);
    }

    // ── build_stream_url ──────────────────────────────────────────────────────

    #[test]
    fn build_stream_url_chat_completions() {
        let p = provider();
        let url = p.build_stream_url("/chat/completions", "anthropic.claude-3-sonnet-20240229-v1:0");
        assert_eq!(
            url,
            "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1%3A0/converse-stream"
        );
    }

    #[test]
    fn build_stream_url_non_chat_falls_back() {
        let p = provider();
        let url = p.build_stream_url("/embeddings", "amazon.titan-embed-text-v1");
        assert_eq!(
            url,
            "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-embed-text-v1/invoke"
        );
    }

    // ── parse_bedrock_stream_event ────────────────────────────────────────────

    #[test]
    fn parse_stream_event_message_start() {
        let chunk = parse_bedrock_stream_event("messageStart", r#"{"role":"assistant"}"#)
            .unwrap()
            .unwrap();
        assert_eq!(chunk.choices[0].delta.role.as_deref(), Some("assistant"));
    }

    #[test]
    fn parse_stream_event_text_delta() {
        let chunk = parse_bedrock_stream_event(
            "contentBlockDelta",
            r#"{"contentBlockIndex":0,"delta":{"text":"Hello world"}}"#,
        )
        .unwrap()
        .unwrap();
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello world"));
    }

    #[test]
    fn parse_stream_event_tool_use_start() {
        let chunk = parse_bedrock_stream_event(
            "contentBlockStart",
            r#"{"contentBlockIndex":0,"start":{"toolUse":{"toolUseId":"call_123","name":"get_weather"}}}"#,
        )
        .unwrap()
        .unwrap();
        let tc = &chunk.choices[0].delta.tool_calls.as_ref().unwrap()[0];
        assert_eq!(tc.id.as_deref(), Some("call_123"));
        assert_eq!(tc.function.as_ref().unwrap().name.as_deref(), Some("get_weather"));
    }

    #[test]
    fn parse_stream_event_tool_use_input_delta() {
        let chunk = parse_bedrock_stream_event(
            "contentBlockDelta",
            r#"{"contentBlockIndex":0,"delta":{"toolUse":{"input":"{\"city\":\"Berlin\"}"}}}"#,
        )
        .unwrap()
        .unwrap();
        let tc = &chunk.choices[0].delta.tool_calls.as_ref().unwrap()[0];
        assert_eq!(
            tc.function.as_ref().unwrap().arguments.as_deref(),
            Some("{\"city\":\"Berlin\"}")
        );
    }

    #[test]
    fn parse_stream_event_message_stop() {
        let chunk = parse_bedrock_stream_event("messageStop", r#"{"stopReason":"end_turn"}"#)
            .unwrap()
            .unwrap();
        assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::Stop));
    }

    #[test]
    fn parse_stream_event_metadata_usage() {
        let chunk = parse_bedrock_stream_event("metadata", r#"{"usage":{"inputTokens":42,"outputTokens":10}}"#)
            .unwrap()
            .unwrap();
        let usage = chunk.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 42);
        assert_eq!(usage.completion_tokens, 10);
        // The metadata branch computes the total itself; pin it.
        assert_eq!(usage.total_tokens, 52);
    }

    #[test]
    fn parse_stream_event_content_block_stop_returns_none() {
        let result = parse_bedrock_stream_event("contentBlockStop", r#"{"contentBlockIndex":0}"#).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn parse_stream_event_unknown_returns_none() {
        let result = parse_bedrock_stream_event("futureEventType", r#"{}"#).unwrap();
        assert!(result.is_none());
    }

    // ── Extended thinking / reasoning effort ─────────────────────────────────

    #[test]
    fn transform_request_reasoning_effort_low() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "Think step by step."}],
            "reasoning_effort": "low",
            "max_tokens": 1000
        });
        p.transform_request(&mut body).unwrap();

        let amf = &body["additionalModelRequestFields"];
        assert_eq!(amf["thinking"]["type"], "enabled");
        assert_eq!(amf["thinking"]["budget_tokens"], 1024);
    }

    #[test]
    fn transform_request_reasoning_effort_medium() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "Think."}],
            "reasoning_effort": "medium"
        });
        p.transform_request(&mut body).unwrap();

        assert_eq!(body["additionalModelRequestFields"]["thinking"]["budget_tokens"], 4096);
    }

    #[test]
    fn transform_request_reasoning_effort_high() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "Think hard."}],
            "reasoning_effort": "high"
        });
        p.transform_request(&mut body).unwrap();

        assert_eq!(body["additionalModelRequestFields"]["thinking"]["budget_tokens"], 16384);
    }

    #[test]
    fn transform_request_no_reasoning_effort_omits_amf() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hi"}]
        });
        p.transform_request(&mut body).unwrap();

        assert!(body.get("additionalModelRequestFields").is_none());
    }

    // ── Document handling ────────────────────────────────────────────────────

    #[test]
    fn transform_request_document_content_part() {
        let p = provider();
        let mut body = json!({
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "Summarize this document."},
                    {
                        "type": "document",
                        "document": {
                            "data": "JVBERi0xLjQ=",
                            "media_type": "application/pdf"
                        }
                    }
                ]
            }]
        });
        p.transform_request(&mut body).unwrap();

        let content = body["messages"][0]["content"].as_array().unwrap();
        assert_eq!(content.len(), 2);

        // First part: text
        assert_eq!(content[0]["text"], "Summarize this document.");

        // Second part: document
        let doc = &content[1]["document"];
        assert_eq!(doc["name"], "doc");
        assert_eq!(doc["format"], "pdf");
        assert_eq!(doc["source"]["bytes"], "JVBERi0xLjQ=");
    }

    #[test]
    fn transform_request_document_csv_format() {
        let p = provider();
        let mut body = json!({
            "messages": [{
                "role": "user",
                "content": [
                    {
                        "type": "document",
                        "document": {
                            "data": "Y29sMSxjb2wy",
                            "media_type": "text/csv"
                        }
                    }
                ]
            }]
        });
        p.transform_request(&mut body).unwrap();

        let doc = &body["messages"][0]["content"][0]["document"];
        assert_eq!(doc["format"], "csv");
    }

    // ── Guardrails ───────────────────────────────────────────────────────────

    #[test]
    fn transform_request_guardrails() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hello"}],
            "extra_body": {
                "guardrailConfig": {
                    "guardrailIdentifier": "my-guardrail-id",
                    "guardrailVersion": "DRAFT",
                    "trace": "enabled"
                }
            }
        });
        p.transform_request(&mut body).unwrap();

        let gc = &body["guardrailConfig"];
        assert_eq!(gc["guardrailIdentifier"], "my-guardrail-id");
        assert_eq!(gc["guardrailVersion"], "DRAFT");
        assert_eq!(gc["trace"], "enabled");
    }

    #[test]
    fn transform_request_no_guardrails_omits_config() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hello"}]
        });
        p.transform_request(&mut body).unwrap();

        assert!(body.get("guardrailConfig").is_none());
    }

    // ── Response format / structured output ──────────────────────────────────

    #[test]
    fn transform_request_json_object_response_format() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "Give me JSON."}],
            "response_format": {"type": "json_object"}
        });
        p.transform_request(&mut body).unwrap();

        // Should have a system instruction for JSON output.
        let system = body["system"].as_array().unwrap();
        let has_json_instruction = system
            .iter()
            .any(|s| s["text"].as_str().unwrap_or("").contains("valid JSON"));
        assert!(has_json_instruction, "should inject JSON instruction in system");
    }

    #[test]
    fn transform_request_json_schema_response_format() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "Give me structured data."}],
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "my_schema",
                    "schema": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"}
                        }
                    }
                }
            }
        });
        p.transform_request(&mut body).unwrap();

        let system = body["system"].as_array().unwrap();
        let json_instruction = system
            .iter()
            .find(|s| s["text"].as_str().unwrap_or("").contains("valid JSON"))
            .unwrap();
        let text = json_instruction["text"].as_str().unwrap();
        assert!(
            text.contains("conforms to this schema"),
            "should include schema reference: {text}"
        );
        assert!(text.contains("\"name\""), "should include the schema content: {text}");
    }

    #[test]
    fn transform_request_text_response_format_no_injection() {
        let p = provider();
        let mut body = json!({
            "messages": [{"role": "user", "content": "hello"}],
            "response_format": {"type": "text"}
        });
        p.transform_request(&mut body).unwrap();

        // No system instruction should be added for plain text format.
        assert!(body.get("system").is_none());
    }

    // ── Cross-region inference ───────────────────────────────────────────────

    #[test]
    #[serial]
    fn apply_cross_region_prefix_when_set() {
        let _env = CrossRegionEnv::set(Some("us"));
        let result = super::apply_cross_region_prefix("anthropic.claude-3-sonnet-20240229-v1:0");
        assert_eq!(result, "us.anthropic.claude-3-sonnet-20240229-v1:0");
    }

    #[test]
    #[serial]
    fn apply_cross_region_prefix_no_double_prefix() {
        let _env = CrossRegionEnv::set(Some("eu"));
        let result = super::apply_cross_region_prefix("eu.anthropic.claude-3-sonnet-20240229-v1:0");
        assert_eq!(
            result, "eu.anthropic.claude-3-sonnet-20240229-v1:0",
            "should not double-prefix"
        );
    }

    #[test]
    #[serial]
    fn apply_cross_region_prefix_unset() {
        let _env = CrossRegionEnv::set(None);
        let result = super::apply_cross_region_prefix("anthropic.claude-3-sonnet-20240229-v1:0");
        assert_eq!(result, "anthropic.claude-3-sonnet-20240229-v1:0");
    }

    /// Exercises the production method, which reads the cached prefix field
    /// rather than the environment.
    #[test]
    fn provider_cross_region_prefix_uses_cached_value() {
        let mut p = provider();

        p.cross_region_prefix = Some("us.".into());
        assert_eq!(
            p.apply_cross_region_prefix("anthropic.claude-3-sonnet-20240229-v1:0"),
            "us.anthropic.claude-3-sonnet-20240229-v1:0"
        );
        assert_eq!(
            p.apply_cross_region_prefix("us.anthropic.claude-3-sonnet-20240229-v1:0"),
            "us.anthropic.claude-3-sonnet-20240229-v1:0",
            "should not double-prefix"
        );

        p.cross_region_prefix = None;
        assert_eq!(
            p.apply_cross_region_prefix("anthropic.claude-3-sonnet-20240229-v1:0"),
            "anthropic.claude-3-sonnet-20240229-v1:0"
        );
    }

    // ── Helper function tests ────────────────────────────────────────────────

    #[test]
    fn reasoning_effort_budget_tokens() {
        assert_eq!(super::reasoning_effort_to_budget_tokens("low"), 1024);
        assert_eq!(super::reasoning_effort_to_budget_tokens("medium"), 4096);
        assert_eq!(super::reasoning_effort_to_budget_tokens("high"), 16384);
        assert_eq!(super::reasoning_effort_to_budget_tokens("unknown"), 4096);
    }

    #[test]
    fn format_from_media_type_extraction() {
        assert_eq!(super::format_from_media_type("application/pdf"), "pdf");
        assert_eq!(super::format_from_media_type("text/csv"), "csv");
        assert_eq!(
            super::format_from_media_type("application/vnd.openxmlformats-officedocument.wordprocessingml.document"),
            "vnd.openxmlformats-officedocument.wordprocessingml.document"
        );
        assert_eq!(super::format_from_media_type("pdf"), "pdf"); // fallback
    }
}