liter_llm 1.0.0.pre.rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. checksums.yaml +7 -0
  2. data/README.md +239 -0
  3. data/ext/liter_llm_rb/extconf.rb +65 -0
  4. data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
  5. data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
  6. data/ext/liter_llm_rb/native/Cargo.toml +32 -0
  7. data/ext/liter_llm_rb/native/build.rs +15 -0
  8. data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
  9. data/lib/liter_llm.rb +8 -0
  10. data/sig/liter_llm.rbs +416 -0
  11. data/vendor/Cargo.toml +54 -0
  12. data/vendor/liter-llm/Cargo.toml +92 -0
  13. data/vendor/liter-llm/README.md +252 -0
  14. data/vendor/liter-llm/schemas/pricing.json +40 -0
  15. data/vendor/liter-llm/schemas/providers.json +1662 -0
  16. data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
  17. data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
  18. data/vendor/liter-llm/src/auth/mod.rs +68 -0
  19. data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
  20. data/vendor/liter-llm/src/client/config.rs +351 -0
  21. data/vendor/liter-llm/src/client/managed.rs +622 -0
  22. data/vendor/liter-llm/src/client/mod.rs +864 -0
  23. data/vendor/liter-llm/src/cost.rs +212 -0
  24. data/vendor/liter-llm/src/error.rs +190 -0
  25. data/vendor/liter-llm/src/http/eventstream.rs +860 -0
  26. data/vendor/liter-llm/src/http/mod.rs +12 -0
  27. data/vendor/liter-llm/src/http/request.rs +438 -0
  28. data/vendor/liter-llm/src/http/retry.rs +72 -0
  29. data/vendor/liter-llm/src/http/streaming.rs +289 -0
  30. data/vendor/liter-llm/src/lib.rs +37 -0
  31. data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
  32. data/vendor/liter-llm/src/provider/azure.rs +579 -0
  33. data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
  34. data/vendor/liter-llm/src/provider/cohere.rs +654 -0
  35. data/vendor/liter-llm/src/provider/custom.rs +404 -0
  36. data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
  37. data/vendor/liter-llm/src/provider/mistral.rs +188 -0
  38. data/vendor/liter-llm/src/provider/mod.rs +616 -0
  39. data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
  40. data/vendor/liter-llm/src/tests.rs +1425 -0
  41. data/vendor/liter-llm/src/tokenizer.rs +281 -0
  42. data/vendor/liter-llm/src/tower/budget.rs +599 -0
  43. data/vendor/liter-llm/src/tower/cache.rs +502 -0
  44. data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
  45. data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
  46. data/vendor/liter-llm/src/tower/cost.rs +404 -0
  47. data/vendor/liter-llm/src/tower/fallback.rs +121 -0
  48. data/vendor/liter-llm/src/tower/health.rs +219 -0
  49. data/vendor/liter-llm/src/tower/hooks.rs +369 -0
  50. data/vendor/liter-llm/src/tower/mod.rs +77 -0
  51. data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
  52. data/vendor/liter-llm/src/tower/router.rs +436 -0
  53. data/vendor/liter-llm/src/tower/service.rs +181 -0
  54. data/vendor/liter-llm/src/tower/tests.rs +539 -0
  55. data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
  56. data/vendor/liter-llm/src/tower/tracing.rs +209 -0
  57. data/vendor/liter-llm/src/tower/types.rs +170 -0
  58. data/vendor/liter-llm/src/types/audio.rs +52 -0
  59. data/vendor/liter-llm/src/types/batch.rs +77 -0
  60. data/vendor/liter-llm/src/types/chat.rs +214 -0
  61. data/vendor/liter-llm/src/types/common.rs +244 -0
  62. data/vendor/liter-llm/src/types/embedding.rs +84 -0
  63. data/vendor/liter-llm/src/types/files.rs +58 -0
  64. data/vendor/liter-llm/src/types/image.rs +40 -0
  65. data/vendor/liter-llm/src/types/mod.rs +27 -0
  66. data/vendor/liter-llm/src/types/models.rs +21 -0
  67. data/vendor/liter-llm/src/types/moderation.rs +80 -0
  68. data/vendor/liter-llm/src/types/ocr.rs +87 -0
  69. data/vendor/liter-llm/src/types/rerank.rs +46 -0
  70. data/vendor/liter-llm/src/types/responses.rs +55 -0
  71. data/vendor/liter-llm/src/types/search.rs +45 -0
  72. data/vendor/liter-llm/tests/contract.rs +332 -0
  73. data/vendor/liter-llm-ffi/Cargo.toml +30 -0
  74. data/vendor/liter-llm-ffi/build.rs +66 -0
  75. data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
  76. data/vendor/liter-llm-ffi/liter_llm.h +850 -0
  77. data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
  78. metadata +286 -0
@@ -0,0 +1,1425 @@
1
+ #[cfg(test)]
2
+ mod serde_tests {
3
+ use crate::types::*;
4
+
5
+ #[test]
6
+ fn chat_request_round_trip() {
7
+ let req = ChatCompletionRequest {
8
+ model: "gpt-4".into(),
9
+ messages: vec![
10
+ Message::System(SystemMessage {
11
+ content: "You are helpful.".into(),
12
+ name: None,
13
+ }),
14
+ Message::User(UserMessage {
15
+ content: UserContent::Text("Hello!".into()),
16
+ name: None,
17
+ }),
18
+ ],
19
+ temperature: Some(0.7),
20
+ max_tokens: Some(100),
21
+ ..Default::default()
22
+ };
23
+
24
+ let json = serde_json::to_string(&req).unwrap();
25
+ let parsed: ChatCompletionRequest = serde_json::from_str(&json).unwrap();
26
+ assert_eq!(parsed.model, "gpt-4");
27
+ assert_eq!(parsed.messages.len(), 2);
28
+ assert_eq!(parsed.temperature, Some(0.7));
29
+ assert_eq!(parsed.max_tokens, Some(100));
30
+ }
31
+
32
+ #[test]
33
+ fn chat_response_deserialize() {
34
+ let json = r#"{
35
+ "id": "chatcmpl-abc123",
36
+ "object": "chat.completion",
37
+ "created": 1700000000,
38
+ "model": "gpt-4",
39
+ "choices": [{
40
+ "index": 0,
41
+ "message": {
42
+ "role": "assistant",
43
+ "content": "Hello!"
44
+ },
45
+ "finish_reason": "stop"
46
+ }],
47
+ "usage": {
48
+ "prompt_tokens": 10,
49
+ "completion_tokens": 5,
50
+ "total_tokens": 15
51
+ }
52
+ }"#;
53
+
54
+ let resp: ChatCompletionResponse = serde_json::from_str(json).unwrap();
55
+ assert_eq!(resp.id, "chatcmpl-abc123");
56
+ assert_eq!(resp.choices.len(), 1);
57
+ assert_eq!(resp.choices[0].message.content.as_deref(), Some("Hello!"));
58
+ assert_eq!(resp.usage.as_ref().unwrap().total_tokens, 15);
59
+ }
60
+
61
+ #[test]
62
+ fn stream_chunk_deserialize() {
63
+ let json = r#"{
64
+ "id": "chatcmpl-abc123",
65
+ "object": "chat.completion.chunk",
66
+ "created": 1700000000,
67
+ "model": "gpt-4",
68
+ "choices": [{
69
+ "index": 0,
70
+ "delta": {
71
+ "content": "Hello"
72
+ },
73
+ "finish_reason": null
74
+ }]
75
+ }"#;
76
+
77
+ let chunk: ChatCompletionChunk = serde_json::from_str(json).unwrap();
78
+ assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello"));
79
+ assert!(chunk.choices[0].finish_reason.is_none());
80
+ }
81
+
82
+ #[test]
83
+ fn tool_call_message_round_trip() {
84
+ let msg = Message::Assistant(AssistantMessage {
85
+ content: None,
86
+ name: None,
87
+ tool_calls: Some(vec![ToolCall {
88
+ id: "call_123".into(),
89
+ call_type: ToolType::Function,
90
+ function: FunctionCall {
91
+ name: "get_weather".into(),
92
+ arguments: r#"{"location": "NYC"}"#.into(),
93
+ },
94
+ }]),
95
+ refusal: None,
96
+ function_call: None,
97
+ });
98
+
99
+ let json = serde_json::to_string(&msg).unwrap();
100
+ let parsed: Message = serde_json::from_str(&json).unwrap();
101
+
102
+ if let Message::Assistant(a) = parsed {
103
+ let calls = a.tool_calls.unwrap();
104
+ assert_eq!(calls.len(), 1);
105
+ assert_eq!(calls[0].function.name, "get_weather");
106
+ } else {
107
+ panic!("expected assistant message");
108
+ }
109
+ }
110
+
111
+ #[test]
112
+ fn multipart_content_round_trip() {
113
+ let msg = Message::User(UserMessage {
114
+ content: UserContent::Parts(vec![
115
+ ContentPart::Text {
116
+ text: "What's in this image?".into(),
117
+ },
118
+ ContentPart::ImageUrl {
119
+ image_url: ImageUrl {
120
+ url: "https://example.com/image.png".into(),
121
+ detail: Some(ImageDetail::High),
122
+ },
123
+ },
124
+ ]),
125
+ name: None,
126
+ });
127
+
128
+ let json = serde_json::to_string(&msg).unwrap();
129
+ assert!(json.contains("image_url"));
130
+ let _: Message = serde_json::from_str(&json).unwrap();
131
+ }
132
+
133
+ #[test]
134
+ fn embedding_request_round_trip() {
135
+ let req = EmbeddingRequest {
136
+ model: "text-embedding-3-small".into(),
137
+ input: EmbeddingInput::Multiple(vec!["hello".into(), "world".into()]),
138
+ encoding_format: None,
139
+ dimensions: Some(256),
140
+ user: None,
141
+ };
142
+
143
+ let json = serde_json::to_string(&req).unwrap();
144
+ let parsed: EmbeddingRequest = serde_json::from_str(&json).unwrap();
145
+ assert_eq!(parsed.model, "text-embedding-3-small");
146
+ assert_eq!(parsed.dimensions, Some(256));
147
+ }
148
+
149
+ #[test]
150
+ fn embedding_response_deserialize() {
151
+ let json = r#"{
152
+ "object": "list",
153
+ "data": [{
154
+ "object": "embedding",
155
+ "embedding": [0.1, 0.2, 0.3],
156
+ "index": 0
157
+ }],
158
+ "model": "text-embedding-3-small",
159
+ "usage": {
160
+ "prompt_tokens": 5,
161
+ "completion_tokens": 0,
162
+ "total_tokens": 5
163
+ }
164
+ }"#;
165
+
166
+ let resp: EmbeddingResponse = serde_json::from_str(json).unwrap();
167
+ assert_eq!(resp.data.len(), 1);
168
+ assert_eq!(resp.data[0].embedding.len(), 3);
169
+ }
170
+
171
+ #[test]
172
+ fn developer_message_round_trip() {
173
+ let msg = Message::Developer(DeveloperMessage {
174
+ content: "You are a dev assistant.".into(),
175
+ name: Some("devbot".into()),
176
+ });
177
+ let json = serde_json::to_string(&msg).unwrap();
178
+ assert!(json.contains("\"role\":\"developer\""));
179
+ let parsed: Message = serde_json::from_str(&json).unwrap();
180
+ if let Message::Developer(d) = parsed {
181
+ assert_eq!(d.content, "You are a dev assistant.");
182
+ assert_eq!(d.name.as_deref(), Some("devbot"));
183
+ } else {
184
+ panic!("expected developer message");
185
+ }
186
+ }
187
+
188
+ #[test]
189
+ fn function_message_round_trip() {
190
+ let msg = Message::Function(FunctionMessage {
191
+ content: r#"{"temperature": 72}"#.into(),
192
+ name: "get_weather".into(),
193
+ });
194
+ let json = serde_json::to_string(&msg).unwrap();
195
+ assert!(json.contains("\"role\":\"function\""));
196
+ let parsed: Message = serde_json::from_str(&json).unwrap();
197
+ if let Message::Function(f) = parsed {
198
+ assert_eq!(f.name, "get_weather");
199
+ } else {
200
+ panic!("expected function message");
201
+ }
202
+ }
203
+
204
+ #[test]
205
+ fn assistant_message_with_refusal() {
206
+ let msg = Message::Assistant(AssistantMessage {
207
+ content: None,
208
+ name: None,
209
+ tool_calls: None,
210
+ refusal: Some("I cannot help with that.".into()),
211
+ function_call: None,
212
+ });
213
+ let json = serde_json::to_string(&msg).unwrap();
214
+ assert!(json.contains("refusal"));
215
+ let parsed: Message = serde_json::from_str(&json).unwrap();
216
+ if let Message::Assistant(a) = parsed {
217
+ assert_eq!(a.refusal.as_deref(), Some("I cannot help with that."));
218
+ } else {
219
+ panic!("expected assistant message");
220
+ }
221
+ }
222
+
223
+ #[test]
224
+ fn finish_reason_function_call_serde() {
225
+ let reason = FinishReason::FunctionCall;
226
+ let json = serde_json::to_string(&reason).unwrap();
227
+ assert_eq!(json, "\"function_call\"");
228
+ let parsed: FinishReason = serde_json::from_str(&json).unwrap();
229
+ assert_eq!(parsed, FinishReason::FunctionCall);
230
+ }
231
+
232
+ #[test]
233
+ fn service_tier_in_response() {
234
+ let json = r#"{
235
+ "id": "chatcmpl-abc",
236
+ "object": "chat.completion",
237
+ "created": 1700000000,
238
+ "model": "gpt-4",
239
+ "choices": [{
240
+ "index": 0,
241
+ "message": {"role": "assistant", "content": "Hi"},
242
+ "finish_reason": "stop"
243
+ }],
244
+ "service_tier": "default"
245
+ }"#;
246
+ let resp: ChatCompletionResponse = serde_json::from_str(json).unwrap();
247
+ assert_eq!(resp.service_tier.as_deref(), Some("default"));
248
+ }
249
+
250
+ #[test]
251
+ fn response_format_json_schema() {
252
+ let fmt = ResponseFormat::JsonSchema {
253
+ json_schema: JsonSchemaFormat {
254
+ name: "my_schema".into(),
255
+ description: None,
256
+ schema: serde_json::json!({"type": "object"}),
257
+ strict: Some(true),
258
+ },
259
+ };
260
+
261
+ let json = serde_json::to_string(&fmt).unwrap();
262
+ assert!(json.contains("json_schema"));
263
+ let _: ResponseFormat = serde_json::from_str(&json).unwrap();
264
+ }
265
+
266
+ #[test]
267
+ fn finish_reason_other_unknown_string() {
268
+ // Catch-all variant for unknown finish reasons from non-OpenAI providers
269
+ let json = r#""custom_stop_reason""#;
270
+ let reason: FinishReason = serde_json::from_str(json).unwrap();
271
+ assert_eq!(reason, FinishReason::Other);
272
+ }
273
+
274
+ #[test]
275
+ fn finish_reason_stop_serde() {
276
+ let reason = FinishReason::Stop;
277
+ let json = serde_json::to_string(&reason).unwrap();
278
+ assert_eq!(json, "\"stop\"");
279
+ let parsed: FinishReason = serde_json::from_str(&json).unwrap();
280
+ assert_eq!(parsed, FinishReason::Stop);
281
+ }
282
+
283
+ #[test]
284
+ fn finish_reason_length_serde() {
285
+ let reason = FinishReason::Length;
286
+ let json = serde_json::to_string(&reason).unwrap();
287
+ assert_eq!(json, "\"length\"");
288
+ let parsed: FinishReason = serde_json::from_str(&json).unwrap();
289
+ assert_eq!(parsed, FinishReason::Length);
290
+ }
291
+
292
+ #[test]
293
+ fn embedding_format_float_serde() {
294
+ let fmt = EmbeddingFormat::Float;
295
+ let json = serde_json::to_string(&fmt).unwrap();
296
+ assert_eq!(json, "\"float\"");
297
+ let parsed: EmbeddingFormat = serde_json::from_str(&json).unwrap();
298
+ assert_eq!(parsed, EmbeddingFormat::Float);
299
+ }
300
+
301
+ #[test]
302
+ fn embedding_format_base64_serde() {
303
+ let fmt = EmbeddingFormat::Base64;
304
+ let json = serde_json::to_string(&fmt).unwrap();
305
+ assert_eq!(json, "\"base64\"");
306
+ let parsed: EmbeddingFormat = serde_json::from_str(&json).unwrap();
307
+ assert_eq!(parsed, EmbeddingFormat::Base64);
308
+ }
309
+
310
+ #[test]
311
+ fn embedding_input_single_string() {
312
+ let input = EmbeddingInput::Single("hello world".into());
313
+ let json = serde_json::to_string(&input).unwrap();
314
+ assert_eq!(json, "\"hello world\"");
315
+ let parsed: EmbeddingInput = serde_json::from_str(&json).unwrap();
316
+ assert_eq!(parsed, input);
317
+ }
318
+
319
+ #[test]
320
+ fn embedding_input_multiple_strings() {
321
+ let input = EmbeddingInput::Multiple(vec!["hello".into(), "world".into(), "test".into()]);
322
+ let json = serde_json::to_string(&input).unwrap();
323
+ assert!(json.contains("hello"));
324
+ assert!(json.contains("world"));
325
+ let parsed: EmbeddingInput = serde_json::from_str(&json).unwrap();
326
+ assert_eq!(parsed, input);
327
+ }
328
+
329
+ #[test]
330
+ fn embedding_request_with_format_and_dimensions() {
331
+ let req = EmbeddingRequest {
332
+ model: "text-embedding-3-large".into(),
333
+ input: EmbeddingInput::Single("test".into()),
334
+ encoding_format: Some(EmbeddingFormat::Base64),
335
+ dimensions: Some(1024),
336
+ user: Some("user-123".into()),
337
+ };
338
+
339
+ let json = serde_json::to_string(&req).unwrap();
340
+ let parsed: EmbeddingRequest = serde_json::from_str(&json).unwrap();
341
+ assert_eq!(parsed.encoding_format, Some(EmbeddingFormat::Base64));
342
+ assert_eq!(parsed.dimensions, Some(1024));
343
+ assert_eq!(parsed.user, Some("user-123".into()));
344
+ }
345
+
346
+ #[test]
347
+ fn stop_sequence_single_serde() {
348
+ let stop = StopSequence::Single("\\n".into());
349
+ let json = serde_json::to_string(&stop).unwrap();
350
+ assert_eq!(json, "\"\\\\n\"");
351
+ let parsed: StopSequence = serde_json::from_str(&json).unwrap();
352
+ assert_eq!(parsed, stop);
353
+ }
354
+
355
+ #[test]
356
+ fn stop_sequence_multiple_serde() {
357
+ let stop = StopSequence::Multiple(vec!["\\n".into(), "\\n\\n".into(), "[END]".into()]);
358
+ let json = serde_json::to_string(&stop).unwrap();
359
+ assert!(json.contains("END"));
360
+ let parsed: StopSequence = serde_json::from_str(&json).unwrap();
361
+ assert_eq!(parsed, stop);
362
+ }
363
+
364
+ #[test]
365
+ fn tool_choice_mode_auto_serde() {
366
+ let choice = ToolChoice::Mode(ToolChoiceMode::Auto);
367
+ let json = serde_json::to_string(&choice).unwrap();
368
+ assert_eq!(json, "\"auto\"");
369
+ let parsed: ToolChoice = serde_json::from_str(&json).unwrap();
370
+ assert_eq!(parsed, choice);
371
+ }
372
+
373
+ #[test]
374
+ fn tool_choice_mode_required_serde() {
375
+ let choice = ToolChoice::Mode(ToolChoiceMode::Required);
376
+ let json = serde_json::to_string(&choice).unwrap();
377
+ assert_eq!(json, "\"required\"");
378
+ let parsed: ToolChoice = serde_json::from_str(&json).unwrap();
379
+ assert_eq!(parsed, choice);
380
+ }
381
+
382
+ #[test]
383
+ fn tool_choice_mode_none_serde() {
384
+ let choice = ToolChoice::Mode(ToolChoiceMode::None);
385
+ let json = serde_json::to_string(&choice).unwrap();
386
+ assert_eq!(json, "\"none\"");
387
+ let parsed: ToolChoice = serde_json::from_str(&json).unwrap();
388
+ assert_eq!(parsed, choice);
389
+ }
390
+
391
+ #[test]
392
+ fn tool_choice_specific_serde() {
393
+ let choice = ToolChoice::Specific(SpecificToolChoice {
394
+ choice_type: ToolType::Function,
395
+ function: SpecificFunction {
396
+ name: "get_weather".into(),
397
+ },
398
+ });
399
+ let json = serde_json::to_string(&choice).unwrap();
400
+ assert!(json.contains("get_weather"));
401
+ let parsed: ToolChoice = serde_json::from_str(&json).unwrap();
402
+ assert_eq!(parsed, choice);
403
+ }
404
+
405
+ #[test]
406
+ fn response_format_text_serde() {
407
+ let fmt = ResponseFormat::Text;
408
+ let json = serde_json::to_string(&fmt).unwrap();
409
+ assert_eq!(json, "{\"type\":\"text\"}");
410
+ let parsed: ResponseFormat = serde_json::from_str(&json).unwrap();
411
+ assert_eq!(parsed, fmt);
412
+ }
413
+
414
+ #[test]
415
+ fn response_format_json_object_serde() {
416
+ let fmt = ResponseFormat::JsonObject;
417
+ let json = serde_json::to_string(&fmt).unwrap();
418
+ assert_eq!(json, "{\"type\":\"json_object\"}");
419
+ let parsed: ResponseFormat = serde_json::from_str(&json).unwrap();
420
+ assert_eq!(parsed, fmt);
421
+ }
422
+
423
+ #[test]
424
+ fn chat_completion_request_default_round_trip() {
425
+ let req = ChatCompletionRequest::default();
426
+ let json = serde_json::to_string(&req).unwrap();
427
+ let parsed: ChatCompletionRequest = serde_json::from_str(&json).unwrap();
428
+ assert_eq!(parsed, req);
429
+ assert!(parsed.model.is_empty());
430
+ assert!(parsed.messages.is_empty());
431
+ }
432
+
433
+ #[test]
434
+ fn chat_completion_request_full_fields_partial_eq() {
435
+ let req1 = ChatCompletionRequest {
436
+ model: "gpt-4".into(),
437
+ messages: vec![Message::System(SystemMessage {
438
+ content: "You are helpful.".into(),
439
+ name: None,
440
+ })],
441
+ temperature: Some(0.7),
442
+ max_tokens: Some(100),
443
+ top_p: Some(0.95),
444
+ n: Some(1),
445
+ stream: Some(false),
446
+ stop: Some(StopSequence::Single("\\n".into())),
447
+ presence_penalty: Some(0.0),
448
+ frequency_penalty: Some(0.0),
449
+ logit_bias: None,
450
+ user: Some("user-1".into()),
451
+ tools: None,
452
+ tool_choice: None,
453
+ parallel_tool_calls: None,
454
+ response_format: None,
455
+ stream_options: None,
456
+ seed: None,
457
+ reasoning_effort: None,
458
+ extra_body: None,
459
+ };
460
+
461
+ let json = serde_json::to_string(&req1).unwrap();
462
+ let req2: ChatCompletionRequest = serde_json::from_str(&json).unwrap();
463
+ assert_eq!(req1, req2);
464
+ }
465
+
466
+ #[test]
467
+ fn message_variant_equality() {
468
+ // Test that Message variants support PartialEq correctly
469
+ let msg1 = Message::System(SystemMessage {
470
+ content: "test".into(),
471
+ name: None,
472
+ });
473
+ let msg2 = Message::System(SystemMessage {
474
+ content: "test".into(),
475
+ name: None,
476
+ });
477
+ assert_eq!(msg1, msg2);
478
+
479
+ let msg3 = Message::System(SystemMessage {
480
+ content: "different".into(),
481
+ name: None,
482
+ });
483
+ assert_ne!(msg1, msg3);
484
+ }
485
+
486
+ #[test]
487
+ fn message_assistant_round_trip_equality() {
488
+ let msg = Message::Assistant(AssistantMessage {
489
+ content: Some("Hello!".into()),
490
+ name: None,
491
+ tool_calls: None,
492
+ refusal: None,
493
+ function_call: None,
494
+ });
495
+ let json = serde_json::to_string(&msg).unwrap();
496
+ let parsed: Message = serde_json::from_str(&json).unwrap();
497
+ assert_eq!(msg, parsed);
498
+ }
499
+
500
+ #[test]
501
+ fn message_user_round_trip_equality() {
502
+ let msg = Message::User(UserMessage {
503
+ content: UserContent::Text("What's up?".into()),
504
+ name: None,
505
+ });
506
+ let json = serde_json::to_string(&msg).unwrap();
507
+ let parsed: Message = serde_json::from_str(&json).unwrap();
508
+ assert_eq!(msg, parsed);
509
+ }
510
+
511
+ #[test]
512
+ fn message_tool_round_trip() {
513
+ let msg = Message::Tool(ToolMessage {
514
+ content: r#"{"result": "sunny"}"#.into(),
515
+ tool_call_id: "call_456".into(),
516
+ name: Some("get_weather".into()),
517
+ });
518
+ let json = serde_json::to_string(&msg).unwrap();
519
+ let parsed: Message = serde_json::from_str(&json).unwrap();
520
+ if let Message::Tool(t) = parsed {
521
+ assert_eq!(t.tool_call_id, "call_456");
522
+ assert_eq!(t.name.as_deref(), Some("get_weather"));
523
+ } else {
524
+ panic!("expected tool message");
525
+ }
526
+ }
527
+
528
+ #[test]
529
+ fn user_content_parts_image_detail_low() {
530
+ let msg = Message::User(UserMessage {
531
+ content: UserContent::Parts(vec![
532
+ ContentPart::Text {
533
+ text: "Describe this".into(),
534
+ },
535
+ ContentPart::ImageUrl {
536
+ image_url: ImageUrl {
537
+ url: "https://example.com/img.png".into(),
538
+ detail: Some(ImageDetail::Low),
539
+ },
540
+ },
541
+ ]),
542
+ name: None,
543
+ });
544
+ let json = serde_json::to_string(&msg).unwrap();
545
+ let parsed: Message = serde_json::from_str(&json).unwrap();
546
+ assert_eq!(msg, parsed);
547
+ }
548
+
549
+ #[test]
550
+ fn user_content_parts_image_detail_auto() {
551
+ let msg = Message::User(UserMessage {
552
+ content: UserContent::Parts(vec![ContentPart::ImageUrl {
553
+ image_url: ImageUrl {
554
+ url: "https://example.com/img.png".into(),
555
+ detail: Some(ImageDetail::Auto),
556
+ },
557
+ }]),
558
+ name: None,
559
+ });
560
+ let json = serde_json::to_string(&msg).unwrap();
561
+ assert!(json.contains("\"detail\":\"auto\""));
562
+ let parsed: Message = serde_json::from_str(&json).unwrap();
563
+ assert_eq!(msg, parsed);
564
+ }
565
+
566
+ // -- Search type serialization tests --
567
+
568
+ #[test]
569
+ fn search_request_round_trip() {
570
+ use crate::types::search::SearchRequest;
571
+ let req = SearchRequest {
572
+ model: "brave/web-search".into(),
573
+ query: "What is Rust?".into(),
574
+ ..Default::default()
575
+ };
576
+ let json = serde_json::to_string(&req).unwrap();
577
+ let parsed: SearchRequest = serde_json::from_str(&json).unwrap();
578
+ assert_eq!(parsed.model, "brave/web-search");
579
+ assert_eq!(parsed.query, "What is Rust?");
580
+ }
581
+
582
+ #[test]
583
+ fn search_request_rejects_unknown_fields() {
584
+ use crate::types::search::SearchRequest;
585
+ let json = r#"{"model":"test","query":"q","unknown_field":true}"#;
586
+ assert!(
587
+ serde_json::from_str::<SearchRequest>(json).is_err(),
588
+ "SearchRequest with deny_unknown_fields should reject unknown keys"
589
+ );
590
+ }
591
+
592
+ #[test]
593
+ fn search_request_with_optional_fields() {
594
+ use crate::types::search::SearchRequest;
595
+ let req = SearchRequest {
596
+ model: "tavily/search".into(),
597
+ query: "Rust language".into(),
598
+ max_results: Some(5),
599
+ search_domain_filter: Some(vec!["rust-lang.org".into()]),
600
+ country: Some("US".into()),
601
+ };
602
+ let json = serde_json::to_string(&req).unwrap();
603
+ assert!(json.contains("max_results"));
604
+ assert!(json.contains("rust-lang.org"));
605
+ let parsed: SearchRequest = serde_json::from_str(&json).unwrap();
606
+ assert_eq!(parsed.max_results, Some(5));
607
+ assert_eq!(parsed.country, Some("US".into()));
608
+ }
609
+
610
+ #[test]
611
+ fn search_response_deserialize() {
612
+ use crate::types::search::SearchResponse;
613
+ let json = r#"{
614
+ "results": [
615
+ {"title": "Rust", "url": "https://www.rust-lang.org", "snippet": "A language empowering everyone."}
616
+ ],
617
+ "model": "brave/web-search"
618
+ }"#;
619
+ let resp: SearchResponse = serde_json::from_str(json).unwrap();
620
+ assert_eq!(resp.results.len(), 1);
621
+ assert_eq!(resp.model, "brave/web-search");
622
+ assert_eq!(resp.results[0].title, "Rust");
623
+ }
624
+
625
+ // -- OCR type serialization tests --
626
+
627
+ #[test]
628
+ fn ocr_request_url_variant() {
629
+ use crate::types::ocr::{OcrDocument, OcrRequest};
630
+ let req = OcrRequest {
631
+ model: "mistral/mistral-ocr-latest".into(),
632
+ document: OcrDocument::Url {
633
+ url: "https://example.com/doc.pdf".into(),
634
+ },
635
+ pages: None,
636
+ include_image_base64: None,
637
+ };
638
+ let json = serde_json::to_string(&req).unwrap();
639
+ let parsed: OcrRequest = serde_json::from_str(&json).unwrap();
640
+ assert_eq!(parsed.model, "mistral/mistral-ocr-latest");
641
+ }
642
+
643
+ #[test]
644
+ fn ocr_document_url_serializes_with_tag() {
645
+ use crate::types::ocr::OcrDocument;
646
+ let doc = OcrDocument::Url {
647
+ url: "https://example.com".into(),
648
+ };
649
+ let json = serde_json::to_string(&doc).unwrap();
650
+ assert!(
651
+ json.contains("document_url"),
652
+ "expected 'document_url' tag in serialized output, got: {json}"
653
+ );
654
+ }
655
+
656
+ #[test]
657
+ fn ocr_document_base64_serializes_with_tag() {
658
+ use crate::types::ocr::OcrDocument;
659
+ let doc = OcrDocument::Base64 {
660
+ data: "dGVzdA==".into(),
661
+ media_type: "application/pdf".into(),
662
+ };
663
+ let json = serde_json::to_string(&doc).unwrap();
664
+ assert!(
665
+ json.contains("\"type\":\"base64\""),
666
+ "expected 'base64' tag in serialized output, got: {json}"
667
+ );
668
+ let parsed: OcrDocument = serde_json::from_str(&json).unwrap();
669
+ if let OcrDocument::Base64 { data, media_type } = parsed {
670
+ assert_eq!(data, "dGVzdA==");
671
+ assert_eq!(media_type, "application/pdf");
672
+ } else {
673
+ panic!("expected OcrDocument::Base64 variant");
674
+ }
675
+ }
676
+
677
+ #[test]
678
+ fn ocr_request_rejects_unknown_fields() {
679
+ use crate::types::ocr::OcrRequest;
680
+ let json = r#"{"model":"m","document":{"type":"document_url","url":"u"},"bogus":1}"#;
681
+ assert!(
682
+ serde_json::from_str::<OcrRequest>(json).is_err(),
683
+ "OcrRequest with deny_unknown_fields should reject unknown keys"
684
+ );
685
+ }
686
+
687
+ #[test]
688
+ fn ocr_response_deserialize() {
689
+ use crate::types::ocr::OcrResponse;
690
+ let json = r##"{
691
+ "pages": [{"index": 0, "markdown": "# Title"}],
692
+ "model": "mistral/mistral-ocr-latest"
693
+ }"##;
694
+ let resp: OcrResponse = serde_json::from_str(json).unwrap();
695
+ assert_eq!(resp.pages.len(), 1);
696
+ assert_eq!(resp.pages[0].index, 0);
697
+ assert!(resp.pages[0].markdown.contains("Title"));
698
+ assert_eq!(resp.model, "mistral/mistral-ocr-latest");
699
+ }
700
+ }
701
+
702
+ #[cfg(test)]
703
+ mod provider_tests {
704
+ use std::collections::HashMap;
705
+
706
+ use crate::provider::{
707
+ AuthConfig, AuthType, ConfigDrivenProvider, OpenAiProvider, Provider, ProviderConfig, detect_provider,
708
+ };
709
+
710
+ #[test]
711
+ fn openai_matches() {
712
+ let p = OpenAiProvider;
713
+ assert!(p.matches_model("gpt-4"));
714
+ assert!(p.matches_model("gpt-4o-mini"));
715
+ assert!(p.matches_model("o1-preview"));
716
+ assert!(p.matches_model("o3-mini"));
717
+ assert!(p.matches_model("text-embedding-3-small"));
718
+ assert!(!p.matches_model("claude-3-opus"));
719
+ assert!(!p.matches_model("groq/llama3"));
720
+ }
721
+
722
+ #[test]
723
+ fn detect_openai() {
724
+ let p = detect_provider("gpt-4").unwrap();
725
+ assert_eq!(p.name(), "openai");
726
+ assert_eq!(p.base_url(), "https://api.openai.com/v1");
727
+ }
728
+
729
+ #[test]
730
+ fn detect_groq() {
731
+ let p = detect_provider("groq/llama3-70b").unwrap();
732
+ assert_eq!(p.name(), "groq");
733
+ assert_eq!(p.base_url(), "https://api.groq.com/openai/v1");
734
+ }
735
+
736
+ #[test]
737
+ fn detect_mistral_prefix() {
738
+ // The registry uses the "mistral/" routing prefix for Mistral models.
739
+ // Bare model names without a slash prefix return None; callers should
740
+ // use the prefixed form "mistral/mistral-large-latest".
741
+ let p = detect_provider("mistral/mistral-large-latest").unwrap();
742
+ assert_eq!(p.name(), "mistral");
743
+ assert_eq!(p.base_url(), "https://api.mistral.ai/v1");
744
+ }
745
+
746
+ #[test]
747
+ fn detect_ollama() {
748
+ let p = detect_provider("ollama/llama3").unwrap();
749
+ assert_eq!(p.name(), "ollama");
750
+ assert_eq!(p.base_url(), "http://localhost:11434/v1");
751
+ }
752
+
753
+ #[test]
754
+ fn detect_unknown_returns_none() {
755
+ assert!(detect_provider("some-random-model").is_none());
756
+ }
757
+
758
+ fn make_provider(auth_type: AuthType) -> ConfigDrivenProvider {
759
+ // Box::leak gives us a &'static reference, which ConfigDrivenProvider
760
+ // now requires. Leaking in tests is acceptable — each test process is
761
+ // short-lived and the total number of leaked configs is tiny.
762
+ let cfg: &'static ProviderConfig = Box::leak(Box::new(ProviderConfig {
763
+ name: "test-provider".into(),
764
+ display_name: None,
765
+ base_url: Some("https://api.example.com/v1".into()),
766
+ auth: Some(AuthConfig {
767
+ auth_type,
768
+ env_var: Some("TEST_API_KEY".into()),
769
+ }),
770
+ endpoints: None,
771
+ model_prefixes: None,
772
+ param_mappings: None,
773
+ }));
774
+ ConfigDrivenProvider::new(cfg)
775
+ }
776
+
777
+ #[test]
778
+ fn config_driven_bearer_auth() {
779
+ let provider = make_provider(AuthType::Bearer);
780
+ let header = provider.auth_header("my-secret-key");
781
+ assert!(header.is_some());
782
+ let (name, value) = header.unwrap();
783
+ assert_eq!(name, "Authorization");
784
+ assert_eq!(value, "Bearer my-secret-key");
785
+ }
786
+
787
+ #[test]
788
+ fn config_driven_api_key_auth() {
789
+ let provider = make_provider(AuthType::ApiKey);
790
+ let header = provider.auth_header("my-secret-key");
791
+ assert!(header.is_some());
792
+ let (name, value) = header.unwrap();
793
+ assert_eq!(name, "x-api-key");
794
+ assert_eq!(value, "my-secret-key");
795
+ }
796
+
797
+ #[test]
798
+ fn config_driven_no_auth() {
799
+ let provider = make_provider(AuthType::None);
800
+ let header = provider.auth_header("my-secret-key");
801
+ assert!(header.is_none(), "AuthType::None should return no auth header");
802
+ }
803
+
804
+ #[test]
805
+ fn detect_deepseek() {
806
+ let p = detect_provider("deepseek/deepseek-chat").unwrap();
807
+ assert_eq!(p.name(), "deepseek");
808
+ }
809
+
810
+ #[test]
811
+ fn detect_cerebras() {
812
+ let p = detect_provider("cerebras/llama-3.1-70b").unwrap();
813
+ assert_eq!(p.name(), "cerebras");
814
+ }
815
+
816
+ #[test]
817
+ fn detect_openrouter() {
818
+ let p = detect_provider("openrouter/auto").unwrap();
819
+ assert_eq!(p.name(), "openrouter");
820
+ }
821
+
822
+ #[test]
823
+ fn config_driven_unknown_auth_defaults_to_bearer() {
824
+ let provider = make_provider(AuthType::Unknown);
825
+ let header = provider.auth_header("my-secret-key");
826
+ assert!(header.is_some());
827
+ let (name, value) = header.unwrap();
828
+ assert_eq!(name, "Authorization");
829
+ assert!(value.contains("Bearer"));
830
+ }
831
+
832
+ #[test]
833
+ fn provider_base_url_contract() {
834
+ let provider = OpenAiProvider;
835
+ assert_eq!(provider.base_url(), "https://api.openai.com/v1");
836
+ }
837
+
838
/// The provider detected for a `groq/` model strips the `groq/` routing
/// prefix from the model name.
#[test]
fn provider_strip_model_prefix_groq() {
    let model = "groq/llama3-70b";
    let provider = detect_provider(model).unwrap();
    assert_eq!(provider.strip_model_prefix(model), "llama3-70b");
}
844
+
845
/// OpenAI model names carry no routing prefix, so stripping returns the
/// name unchanged.
#[test]
fn provider_strip_model_prefix_openai() {
    let model = "gpt-4";
    assert_eq!(OpenAiProvider.strip_model_prefix(model), "gpt-4");
}
851
+
852
/// A bare `claude-*` model ID is detected as the `anthropic` provider with
/// the Anthropic v1 base URL.
#[test]
fn detect_anthropic_by_claude_prefix() {
    let provider = detect_provider("claude-3-5-sonnet-20241022").unwrap();
    let (name, base_url) = (provider.name(), provider.base_url());
    assert_eq!(name, "anthropic");
    assert_eq!(base_url, "https://api.anthropic.com/v1");
}
858
+
859
/// An explicit `anthropic/` prefix is detected as the `anthropic` provider.
#[test]
fn detect_anthropic_by_slash_prefix() {
    let provider = detect_provider("anthropic/claude-3-5-sonnet-20241022").unwrap();
    let detected_name = provider.name();
    assert_eq!(detected_name, "anthropic");
}
864
+
865
/// The provider detected for a `claude-*` model must authenticate with the
/// raw key in an `x-api-key` header.
#[test]
fn anthropic_auth_header_uses_x_api_key() {
    let p = detect_provider("claude-3-5-sonnet-20241022").unwrap();
    // `expect` checks for presence and unwraps in one step, and names the
    // violated expectation if the header is missing.
    let (name, value) = p
        .auth_header("sk-ant-test")
        .expect("Anthropic provider should produce an auth header");
    assert_eq!(name, "x-api-key");
    assert_eq!(value, "sk-ant-test");
}
874
+
875
/// `AnthropicProvider` exposes exactly one extra header:
/// `anthropic-version: 2023-06-01`.
#[test]
fn anthropic_extra_headers_contain_version() {
    use crate::provider::anthropic::AnthropicProvider;

    let provider = AnthropicProvider;
    let headers = provider.extra_headers();
    assert_eq!(headers.len(), 1);

    let version_header = &headers[0];
    assert_eq!(version_header.0, "anthropic-version");
    assert_eq!(version_header.1, "2023-06-01");
}
884
+
885
/// The provider detected for an `anthropic/` model strips the `anthropic/`
/// routing prefix from the model name.
#[test]
fn anthropic_strips_prefix() {
    let model = "anthropic/claude-3-5-sonnet-20241022";
    let provider = detect_provider(model).unwrap();
    let stripped = provider.strip_model_prefix(model);
    assert_eq!(stripped, "claude-3-5-sonnet-20241022");
}
893
+
894
/// A model name without the `anthropic/` prefix is returned unchanged by
/// `AnthropicProvider::strip_model_prefix`.
#[test]
fn anthropic_bare_model_not_stripped() {
    use crate::provider::anthropic::AnthropicProvider;

    let bare_model = "claude-3-5-sonnet-20241022";
    let provider = AnthropicProvider;
    let stripped = provider.strip_model_prefix(bare_model);
    assert_eq!(stripped, "claude-3-5-sonnet-20241022");
}
903
+
904
/// An `azure/`-prefixed model ID is detected as the `azure` provider.
#[test]
fn detect_azure_by_prefix() {
    let provider = detect_provider("azure/gpt-4").unwrap();
    let detected_name = provider.name();
    assert_eq!(detected_name, "azure");
}
909
+
910
/// The provider detected for an `azure/` model must authenticate with the
/// raw key in an `api-key` header.
#[test]
fn azure_auth_header_uses_api_key() {
    let p = detect_provider("azure/gpt-4").unwrap();
    // `expect` checks for presence and unwraps in one step, and names the
    // violated expectation if the header is missing.
    let (name, value) = p
        .auth_header("my-azure-key")
        .expect("Azure provider should produce an auth header");
    assert_eq!(name, "api-key");
    assert_eq!(value, "my-azure-key");
}
919
+
920
/// The provider detected for an `azure/` model strips the `azure/` routing
/// prefix from the model name.
#[test]
fn azure_strips_prefix() {
    let model = "azure/gpt-4";
    let provider = detect_provider(model).unwrap();
    let stripped = provider.strip_model_prefix(model);
    assert_eq!(stripped, "gpt-4");
}
925
+
926
/// A freshly constructed `AzureProvider` contributes no extra headers.
#[test]
fn azure_extra_headers_are_empty() {
    use crate::provider::azure::AzureProvider;

    assert!(AzureProvider::new().extra_headers().is_empty());
}
932
+
933
/// A `vertex_ai/`-prefixed model ID is detected as the `vertex_ai` provider.
#[test]
fn detect_vertex_ai_by_prefix() {
    let provider = detect_provider("vertex_ai/gemini-2.0-flash").unwrap();
    let detected_name = provider.name();
    assert_eq!(detected_name, "vertex_ai");
}
938
+
939
/// The provider detected for a `vertex_ai/` model must send the access
/// token as `Authorization: Bearer <token>`.
#[test]
fn vertex_ai_auth_header_uses_bearer() {
    let p = detect_provider("vertex_ai/gemini-2.0-flash").unwrap();
    // `expect` checks for presence and unwraps in one step, keeping the
    // original failure message.
    let (name, value) = p
        .auth_header("ya29.my-access-token")
        .expect("Expected an auth header");
    assert_eq!(name, "Authorization");
    assert_eq!(value, "Bearer ya29.my-access-token");
}
948
+
949
/// The provider detected for a `vertex_ai/` model strips the `vertex_ai/`
/// routing prefix from the model name.
#[test]
fn vertex_ai_strips_prefix() {
    let model = "vertex_ai/gemini-2.0-flash";
    let provider = detect_provider(model).unwrap();
    let stripped = provider.strip_model_prefix(model);
    assert_eq!(stripped, "gemini-2.0-flash");
}
954
+
955
/// A model name without the `vertex_ai/` prefix is returned unchanged by
/// `VertexAiProvider::strip_model_prefix`.
#[test]
fn vertex_ai_bare_model_not_stripped() {
    use crate::provider::vertex::VertexAiProvider;

    let provider = VertexAiProvider::new("test-project", "us-central1");
    let stripped = provider.strip_model_prefix("gemini-2.0-flash");
    assert_eq!(stripped, "gemini-2.0-flash");
}
961
+
962
/// Unprefixed `gemini-*` models must never resolve to Vertex AI.
///
/// Such models route to the Google AI Studio (gemini) provider; Vertex AI
/// requires the explicit `vertex_ai/` prefix.
#[test]
fn vertex_ai_does_not_match_gemini_unprefixed() {
    // Detection may yield a provider (e.g. the registry gemini provider) or
    // nothing at all; the only requirement is that it is not `vertex_ai`.
    if let Some(provider) = detect_provider("gemini-2.0-flash") {
        assert_ne!(provider.name(), "vertex_ai");
    }
}
973
+
974
/// A `VertexAiProvider` built from a project and location contributes no
/// extra headers.
#[test]
fn vertex_ai_extra_headers_are_empty() {
    use crate::provider::vertex::VertexAiProvider;

    let provider = VertexAiProvider::new("test-project", "us-central1");
    let extras = provider.extra_headers();
    assert!(extras.is_empty());
}
980
+
981
+ // ── AWS Bedrock ─────────────────────────────────────────────────────────
982
+
983
/// A `bedrock/`-prefixed model ID is detected as the `bedrock` provider.
#[test]
fn detect_bedrock_by_prefix() {
    let provider = detect_provider("bedrock/anthropic.claude-3-sonnet-20240229-v1:0").unwrap();
    let detected_name = provider.name();
    assert_eq!(detected_name, "bedrock");
}
988
+
989
/// `BedrockProvider::matches_model` accepts only model IDs carrying the
/// `bedrock/` prefix.
#[test]
fn bedrock_matches_only_prefixed_models() {
    use crate::provider::bedrock::BedrockProvider;

    let provider = BedrockProvider::from_env();

    for prefixed in [
        "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        "bedrock/amazon.nova-pro-v1:0",
    ] {
        assert!(provider.matches_model(prefixed));
    }

    // Unprefixed model IDs must NOT match — they belong to other providers.
    for unprefixed in [
        "anthropic.claude-3-sonnet-20240229-v1:0",
        "amazon.nova-pro-v1:0",
    ] {
        assert!(!provider.matches_model(unprefixed));
    }
}
999
+
1000
/// The provider detected for a `bedrock/` model strips the `bedrock/`
/// routing prefix, leaving the native Bedrock model ID.
#[test]
fn bedrock_strips_prefix() {
    let model = "bedrock/anthropic.claude-3-sonnet-20240229-v1:0";
    let provider = detect_provider(model).unwrap();
    let stripped = provider.strip_model_prefix(model);
    assert_eq!(stripped, "anthropic.claude-3-sonnet-20240229-v1:0");
}
1008
+
1009
/// A model ID lacking the `bedrock/` prefix passes through
/// `BedrockProvider::strip_model_prefix` unchanged.
#[test]
fn bedrock_bare_model_not_stripped() {
    use crate::provider::bedrock::BedrockProvider;

    let bare_model = "anthropic.claude-3-sonnet-20240229-v1:0";
    let provider = BedrockProvider::from_env();
    let stripped = provider.strip_model_prefix(bare_model);
    assert_eq!(stripped, "anthropic.claude-3-sonnet-20240229-v1:0");
}
1019
+
1020
/// The Bedrock provider exposes no static auth header: SigV4 relies on
/// computed headers instead.
#[test]
fn bedrock_auth_header_returns_none() {
    let provider = detect_provider("bedrock/anthropic.claude-3-sonnet-20240229-v1:0").unwrap();
    assert!(
        provider.auth_header("ignored-for-sigv4").is_none(),
        "BedrockProvider must return None for auth_header"
    );
}
1027
+
1028
/// A `BedrockProvider` built via `from_env` contributes no static extra
/// headers.
#[test]
fn bedrock_extra_headers_are_empty() {
    use crate::provider::bedrock::BedrockProvider;

    let provider = BedrockProvider::from_env();
    let extras = provider.extra_headers();
    assert!(extras.is_empty(), "Bedrock has no static extra headers");
}
1034
+
1035
/// The Bedrock base URL embeds the region passed to `BedrockProvider::new`.
#[test]
fn bedrock_base_url_includes_region() {
    use crate::provider::bedrock::BedrockProvider;

    let provider = BedrockProvider::new("eu-west-1");
    let url = provider.base_url();
    assert_eq!(url, "https://bedrock-runtime.eu-west-1.amazonaws.com");
}
1041
+
1042
/// A provider constructed for `us-east-1` reports that region in its base
/// URL.
#[test]
fn bedrock_default_region_is_us_east_1() {
    use crate::provider::bedrock::BedrockProvider;

    // The region is passed explicitly rather than resolved from the
    // environment, so the outcome does not depend on whether
    // AWS_DEFAULT_REGION / AWS_REGION happen to be set in this test process.
    // NOTE(review): as written this does not exercise any default-region
    // fallback, despite the test name — confirm whether that is intended.
    let provider = BedrockProvider::new("us-east-1");
    let url = provider.base_url();
    assert!(url.contains("us-east-1"));
}
1050
+
1051
/// Smoke test: `BedrockProvider::signing_headers` must not panic.
///
/// With the `bedrock` feature disabled, `signing_headers()` must return an
/// empty vec so requests work against override base-URLs (e.g. mock
/// servers). With the feature enabled but credentials absent, it also
/// returns empty (the error is silently swallowed to avoid panicking).
#[test]
fn bedrock_signing_headers_without_feature_returns_empty() {
    use crate::provider::bedrock::BedrockProvider;

    let provider = BedrockProvider::new("us-east-1");
    // In CI without real AWS credentials the headers will be empty either
    // way, so the result is deliberately discarded: the important invariant
    // is only that the call returns without panicking.
    let _ = provider.signing_headers("POST", "http://localhost/chat/completions", b"{}");
}
1064
+
1065
+ fn make_provider_with_mappings(mappings: HashMap<String, String>) -> ConfigDrivenProvider {
1066
+ let cfg: &'static ProviderConfig = Box::leak(Box::new(ProviderConfig {
1067
+ name: "test-mapped".into(),
1068
+ display_name: None,
1069
+ base_url: Some("https://api.example.com/v1".into()),
1070
+ auth: Some(AuthConfig {
1071
+ auth_type: AuthType::Bearer,
1072
+ env_var: Some("TEST_API_KEY".into()),
1073
+ }),
1074
+ endpoints: None,
1075
+ model_prefixes: None,
1076
+ param_mappings: Some(mappings),
1077
+ }));
1078
+ ConfigDrivenProvider::new(cfg)
1079
+ }
1080
+
1081
/// A configured mapping moves the value from the source key to the target
/// key and removes the source key from the request body.
#[test]
fn param_mappings_renames_field() {
    let mappings: HashMap<String, String> =
        HashMap::from([("max_completion_tokens".into(), "max_tokens".into())]);
    let provider = make_provider_with_mappings(mappings);

    let mut request = serde_json::json!({
        "model": "test/model",
        "messages": [],
        "max_completion_tokens": 512
    });
    provider.transform_request(&mut request).unwrap();

    assert!(request.get("max_completion_tokens").is_none());
    assert_eq!(request["max_tokens"], 512);
}
1098
+
1099
/// When the mapped source key is absent from the body, neither the source
/// nor the target key appears after the transform.
#[test]
fn param_mappings_skips_absent_field() {
    let mappings: HashMap<String, String> =
        HashMap::from([("max_completion_tokens".into(), "max_tokens".into())]);
    let provider = make_provider_with_mappings(mappings);

    let mut request = serde_json::json!({
        "model": "test/model",
        "messages": []
    });
    provider.transform_request(&mut request).unwrap();

    for key in ["max_tokens", "max_completion_tokens"] {
        assert!(request.get(key).is_none());
    }
}
1115
+
1116
/// A provider built by `make_provider` (expected to carry no
/// `param_mappings`) leaves the mapped-looking field in place.
#[test]
fn param_mappings_none_is_noop() {
    let unmapped_provider = make_provider(AuthType::Bearer);

    let mut request = serde_json::json!({
        "model": "test/model",
        "messages": [],
        "max_completion_tokens": 512
    });
    unmapped_provider.transform_request(&mut request).unwrap();

    // With no mappings configured the field must survive untouched.
    assert_eq!(request["max_completion_tokens"], 512);
}
1130
+
1131
/// Every configured mapping is applied: each source key is removed and its
/// value appears under the corresponding target key.
#[test]
fn param_mappings_multiple_fields() {
    let mappings: HashMap<String, String> = HashMap::from([
        ("max_completion_tokens".into(), "max_tokens".into()),
        ("frequency_penalty".into(), "repetition_penalty".into()),
    ]);
    let provider = make_provider_with_mappings(mappings);

    let mut request = serde_json::json!({
        "model": "test/model",
        "messages": [],
        "max_completion_tokens": 512,
        "frequency_penalty": 0.5
    });
    provider.transform_request(&mut request).unwrap();

    // Renamed targets carry the original values...
    assert_eq!(request["max_tokens"], 512);
    assert_eq!(request["repetition_penalty"], 0.5);
    // ...and the source keys are gone.
    assert!(request.get("max_completion_tokens").is_none());
    assert!(request.get("frequency_penalty").is_none());
}
1152
+
1153
/// The provider detected for an `apertis/` model rewrites
/// `max_completion_tokens` to `max_tokens` in the request body.
#[test]
fn real_provider_apertis_has_param_mappings() {
    let provider = detect_provider("apertis/some-model").unwrap();

    let mut request = serde_json::json!({
        "model": "some-model",
        "messages": [{"role": "user", "content": "hi"}],
        "max_completion_tokens": 256
    });
    provider.transform_request(&mut request).unwrap();

    assert!(request.get("max_completion_tokens").is_none());
    assert_eq!(request["max_tokens"], 256);
}
1167
+ }
1168
+
1169
/// Tests for how `LiterLlmError::from_status` classifies an HTTP status
/// code plus response body into an error variant.
#[cfg(test)]
mod error_tests {
    use crate::error::LiterLlmError;

    /// 401 with a JSON error body maps to `Authentication`.
    #[test]
    fn error_from_401() {
        let body = r#"{"error":{"message":"Invalid API key","type":"invalid_request_error"}}"#;
        let error = LiterLlmError::from_status(401, body, None);
        assert!(matches!(error, LiterLlmError::Authentication { .. }));
    }

    /// 429 with a JSON error body maps to `RateLimited`.
    #[test]
    fn error_from_429() {
        let body = r#"{"error":{"message":"Rate limited","type":"rate_limit_error"}}"#;
        let error = LiterLlmError::from_status(429, body, None);
        assert!(matches!(error, LiterLlmError::RateLimited { .. }));
    }

    /// 400 whose message mentions the maximum context length maps to
    /// `ContextWindowExceeded`.
    #[test]
    fn error_from_context_window() {
        let body =
            r#"{"error":{"message":"maximum context length exceeded","type":"invalid_request_error"}}"#;
        let error = LiterLlmError::from_status(400, body, None);
        assert!(matches!(error, LiterLlmError::ContextWindowExceeded { .. }));
    }

    /// 500 with a non-JSON body maps to `ServerError`.
    #[test]
    fn error_from_plain_text() {
        let error = LiterLlmError::from_status(500, "Internal Server Error", None);
        assert!(matches!(error, LiterLlmError::ServerError { .. }));
    }

    /// 503 maps to `ServiceUnavailable`.
    #[test]
    fn error_from_503() {
        let error = LiterLlmError::from_status(503, "Service Unavailable", None);
        assert!(matches!(error, LiterLlmError::ServiceUnavailable { .. }));
    }

    /// 403 Forbidden indicates invalid credentials / insufficient
    /// permissions, so it maps to `Authentication`, not `ServerError`.
    #[test]
    fn error_from_403_forbidden() {
        let error = LiterLlmError::from_status(403, "Forbidden", None);
        assert!(matches!(error, LiterLlmError::Authentication { .. }));
    }

    /// 502 maps to `ServiceUnavailable`.
    #[test]
    fn error_from_502_bad_gateway() {
        let error = LiterLlmError::from_status(502, "Bad Gateway", None);
        assert!(matches!(error, LiterLlmError::ServiceUnavailable { .. }));
    }

    /// 504 maps to `ServiceUnavailable`.
    #[test]
    fn error_from_504_gateway_timeout() {
        let error = LiterLlmError::from_status(504, "Gateway Timeout", None);
        assert!(matches!(error, LiterLlmError::ServiceUnavailable { .. }));
    }

    /// 400 whose message mentions `content_filter` maps to `ContentPolicy`.
    #[test]
    fn error_from_content_policy_filter() {
        let body = r#"{"error":{"message":"Request violates content_filter policy","type":"invalid_request_error"}}"#;
        let error = LiterLlmError::from_status(400, body, None);
        assert!(matches!(error, LiterLlmError::ContentPolicy { .. }));
    }

    /// 400 whose message mentions `content_policy` maps to `ContentPolicy`.
    #[test]
    fn error_from_content_policy_explicit() {
        let body = r#"{"error":{"message":"content_policy violation detected","type":"invalid_request_error"}}"#;
        let error = LiterLlmError::from_status(400, body, None);
        assert!(matches!(error, LiterLlmError::ContentPolicy { .. }));
    }

    /// The provider's message text survives verbatim inside `ServerError`.
    #[test]
    fn error_message_preservation() {
        let msg = "Custom error message from provider";
        let body = format!(r#"{{"error":{{"message":"{}","type":"server_error"}}}}"#, msg);

        match LiterLlmError::from_status(500, &body, None) {
            LiterLlmError::ServerError { message } => assert_eq!(message, msg),
            _ => panic!("expected ServerError"),
        }
    }
}
1270
+
1271
/// Tests for the retry policy (`should_retry`) and `Retry-After` parsing
/// (`parse_retry_after`).
#[cfg(test)]
mod retry_tests {
    use std::time::Duration;

    use crate::http::retry::{parse_retry_after, should_retry};

    /// 429 is retried until the attempt counter reaches the retry budget.
    #[test]
    fn retry_on_429() {
        for attempt in 0..3 {
            assert!(should_retry(429, attempt, 3, None).is_some());
        }
        // Budget exhausted.
        assert!(should_retry(429, 3, 3, None).is_none());
    }

    /// 500 and 503 are retryable.
    #[test]
    fn retry_on_500() {
        for status in [500, 503] {
            assert!(should_retry(status, 0, 3, None).is_some());
        }
    }

    /// 400, 401 and 404 are never retried.
    #[test]
    fn no_retry_on_400() {
        for status in [400, 401, 404] {
            assert!(should_retry(status, 0, 3, None).is_none());
        }
    }

    /// A retry budget of zero disables retries entirely.
    #[test]
    fn no_retry_when_disabled() {
        let decision = should_retry(429, 0, 0, None);
        assert!(decision.is_none());
    }

    /// Base delays are 1s, 2s, 4s; with jitter in [0.5, 1.0] of the base
    /// delay, each attempt must land in its expected range.
    #[test]
    fn exponential_backoff() {
        let first = should_retry(429, 0, 3, None).unwrap();
        let second = should_retry(429, 1, 3, None).unwrap();
        let third = should_retry(429, 2, 3, None).unwrap();

        assert!((Duration::from_millis(500)..=Duration::from_secs(1)).contains(&first));
        assert!((Duration::from_secs(1)..=Duration::from_secs(2)).contains(&second));
        assert!((Duration::from_secs(2)..=Duration::from_secs(4)).contains(&third));
    }

    /// On 429 the server-provided delay is used as-is.
    #[test]
    fn retry_after_header_respected_on_429() {
        let server_delay = Duration::from_secs(42);
        let chosen = should_retry(429, 0, 3, Some(server_delay)).unwrap();
        assert_eq!(chosen, server_delay);
    }

    /// Retry-After is only honoured for 429; on 500 exponential backoff is
    /// used instead.
    #[test]
    fn retry_after_header_ignored_on_500() {
        let server_delay = Duration::from_secs(42);
        let chosen = should_retry(500, 0, 3, Some(server_delay)).unwrap();
        // Attempt 0 has a 1 s base delay, scaled by jitter in [0.5, 1.0].
        assert!((Duration::from_millis(500)..=Duration::from_secs(1)).contains(&chosen));
    }

    /// Integer seconds parse (surrounding whitespace tolerated);
    /// non-numeric input yields `None`.
    #[test]
    fn parse_retry_after_header() {
        assert_eq!(parse_retry_after("30"), Some(Duration::from_secs(30)));
        assert_eq!(parse_retry_after(" 5 "), Some(Duration::from_secs(5)));
        assert_eq!(parse_retry_after("not-a-number"), None);
    }

    /// 504 is retried with the same jittered exponential schedule.
    #[test]
    fn retry_on_504() {
        // Jitter scales the base delay by [0.5, 1.0], so ranges are checked.
        let first = should_retry(504, 0, 3, None).unwrap();
        assert!((Duration::from_millis(500)..=Duration::from_secs(1)).contains(&first));

        let second = should_retry(504, 1, 3, None).unwrap();
        assert!((Duration::from_secs(1)..=Duration::from_secs(2)).contains(&second));

        let third = should_retry(504, 2, 3, None).unwrap();
        assert!((Duration::from_secs(2)..=Duration::from_secs(4)).contains(&third));
    }

    /// 502 is retryable.
    #[test]
    fn retry_on_502() {
        for attempt in 0..2 {
            assert!(should_retry(502, attempt, 3, None).is_some());
        }
    }

    /// A server-provided delay above 60 seconds is capped at 60 seconds.
    #[test]
    fn server_retry_after_capped_at_60s() {
        let server_delay = Duration::from_secs(120);
        let chosen = should_retry(429, 0, 3, Some(server_delay)).unwrap();
        assert_eq!(chosen, Duration::from_secs(60));
    }

    /// A server-provided delay under the 60 second cap is used unchanged.
    #[test]
    fn server_retry_after_under_cap() {
        let server_delay = Duration::from_secs(30);
        let chosen = should_retry(429, 0, 3, Some(server_delay)).unwrap();
        assert_eq!(chosen, Duration::from_secs(30));
    }

    /// Attempt 5 would be 2^5 = 32 seconds; it is capped at 30 and then
    /// scaled by jitter in [0.5, 1.0].
    #[test]
    fn exponential_backoff_caps_at_30s() {
        let chosen = should_retry(500, 5, 10, None).unwrap();
        assert!((Duration::from_secs(15)..=Duration::from_secs(30)).contains(&chosen));
    }
}
1384
+
1385
/// Tests for `parse_sse_line`, which turns one server-sent-events line into
/// an optional parse result.
#[cfg(test)]
mod sse_tests {
    use crate::http::streaming::parse_sse_line;

    /// JSON payload of a single chat-completion chunk whose delta content
    /// is `"Hi"`.
    const CHUNK_JSON: &str = r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#;

    /// A well-formed `data: {json}` line yields the decoded chunk.
    #[test]
    fn parse_valid_chunk() {
        let line = format!("data: {}", CHUNK_JSON);
        let parsed = parse_sse_line(&line).unwrap();
        let chunk = parsed.unwrap();
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hi"));
    }

    /// The `[DONE]` sentinel produces no item.
    #[test]
    fn parse_done_returns_none() {
        let outcome = parse_sse_line("data: [DONE]");
        assert!(outcome.is_none());
    }

    /// Lines that are not `data:` fields (events, comments) produce no item.
    #[test]
    fn parse_non_data_returns_none() {
        for line in ["event: ping", ": comment"] {
            assert!(parse_sse_line(line).is_none());
        }
    }

    /// A `data:` line with malformed JSON yields an item holding an error.
    #[test]
    fn parse_invalid_json() {
        let outcome = parse_sse_line("data: {invalid}").unwrap();
        assert!(outcome.is_err());
    }

    /// Some implementations emit `data:{json}` with no space after the
    /// colon; that form must parse as well.
    #[test]
    fn parse_data_without_space() {
        let line = format!("data:{}", CHUNK_JSON);
        let parsed = parse_sse_line(&line).unwrap();
        let chunk = parsed.unwrap();
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hi"));
    }

    /// The `[DONE]` sentinel is recognised without the space, too.
    #[test]
    fn parse_done_without_space() {
        let outcome = parse_sse_line("data:[DONE]");
        assert!(outcome.is_none());
    }
}