liter_llm 1.0.0.pre.rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. checksums.yaml +7 -0
  2. data/README.md +239 -0
  3. data/ext/liter_llm_rb/extconf.rb +65 -0
  4. data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
  5. data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
  6. data/ext/liter_llm_rb/native/Cargo.toml +32 -0
  7. data/ext/liter_llm_rb/native/build.rs +15 -0
  8. data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
  9. data/lib/liter_llm.rb +8 -0
  10. data/sig/liter_llm.rbs +416 -0
  11. data/vendor/Cargo.toml +54 -0
  12. data/vendor/liter-llm/Cargo.toml +92 -0
  13. data/vendor/liter-llm/README.md +252 -0
  14. data/vendor/liter-llm/schemas/pricing.json +40 -0
  15. data/vendor/liter-llm/schemas/providers.json +1662 -0
  16. data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
  17. data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
  18. data/vendor/liter-llm/src/auth/mod.rs +68 -0
  19. data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
  20. data/vendor/liter-llm/src/client/config.rs +351 -0
  21. data/vendor/liter-llm/src/client/managed.rs +622 -0
  22. data/vendor/liter-llm/src/client/mod.rs +864 -0
  23. data/vendor/liter-llm/src/cost.rs +212 -0
  24. data/vendor/liter-llm/src/error.rs +190 -0
  25. data/vendor/liter-llm/src/http/eventstream.rs +860 -0
  26. data/vendor/liter-llm/src/http/mod.rs +12 -0
  27. data/vendor/liter-llm/src/http/request.rs +438 -0
  28. data/vendor/liter-llm/src/http/retry.rs +72 -0
  29. data/vendor/liter-llm/src/http/streaming.rs +289 -0
  30. data/vendor/liter-llm/src/lib.rs +37 -0
  31. data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
  32. data/vendor/liter-llm/src/provider/azure.rs +579 -0
  33. data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
  34. data/vendor/liter-llm/src/provider/cohere.rs +654 -0
  35. data/vendor/liter-llm/src/provider/custom.rs +404 -0
  36. data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
  37. data/vendor/liter-llm/src/provider/mistral.rs +188 -0
  38. data/vendor/liter-llm/src/provider/mod.rs +616 -0
  39. data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
  40. data/vendor/liter-llm/src/tests.rs +1425 -0
  41. data/vendor/liter-llm/src/tokenizer.rs +281 -0
  42. data/vendor/liter-llm/src/tower/budget.rs +599 -0
  43. data/vendor/liter-llm/src/tower/cache.rs +502 -0
  44. data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
  45. data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
  46. data/vendor/liter-llm/src/tower/cost.rs +404 -0
  47. data/vendor/liter-llm/src/tower/fallback.rs +121 -0
  48. data/vendor/liter-llm/src/tower/health.rs +219 -0
  49. data/vendor/liter-llm/src/tower/hooks.rs +369 -0
  50. data/vendor/liter-llm/src/tower/mod.rs +77 -0
  51. data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
  52. data/vendor/liter-llm/src/tower/router.rs +436 -0
  53. data/vendor/liter-llm/src/tower/service.rs +181 -0
  54. data/vendor/liter-llm/src/tower/tests.rs +539 -0
  55. data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
  56. data/vendor/liter-llm/src/tower/tracing.rs +209 -0
  57. data/vendor/liter-llm/src/tower/types.rs +170 -0
  58. data/vendor/liter-llm/src/types/audio.rs +52 -0
  59. data/vendor/liter-llm/src/types/batch.rs +77 -0
  60. data/vendor/liter-llm/src/types/chat.rs +214 -0
  61. data/vendor/liter-llm/src/types/common.rs +244 -0
  62. data/vendor/liter-llm/src/types/embedding.rs +84 -0
  63. data/vendor/liter-llm/src/types/files.rs +58 -0
  64. data/vendor/liter-llm/src/types/image.rs +40 -0
  65. data/vendor/liter-llm/src/types/mod.rs +27 -0
  66. data/vendor/liter-llm/src/types/models.rs +21 -0
  67. data/vendor/liter-llm/src/types/moderation.rs +80 -0
  68. data/vendor/liter-llm/src/types/ocr.rs +87 -0
  69. data/vendor/liter-llm/src/types/rerank.rs +46 -0
  70. data/vendor/liter-llm/src/types/responses.rs +55 -0
  71. data/vendor/liter-llm/src/types/search.rs +45 -0
  72. data/vendor/liter-llm/tests/contract.rs +332 -0
  73. data/vendor/liter-llm-ffi/Cargo.toml +30 -0
  74. data/vendor/liter-llm-ffi/build.rs +66 -0
  75. data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
  76. data/vendor/liter-llm-ffi/liter_llm.h +850 -0
  77. data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
  78. metadata +286 -0
@@ -0,0 +1,539 @@
1
+ //! Unit tests for the tower middleware integration.
2
+ //!
3
+ //! These tests use a mock [`LlmClient`] to avoid real HTTP calls.
4
+
5
+ use std::sync::Arc;
6
+ use std::sync::atomic::{AtomicUsize, Ordering};
7
+
8
+ use tower::Service;
9
+
10
+ use std::pin::Pin;
11
+ use std::task::{Context, Poll};
12
+
13
+ use futures_core::Stream;
14
+
15
+ use crate::client::{BoxFuture, BoxStream, LlmClient};
16
+ use crate::error::{LiterLlmError, Result};
17
+ use crate::tower::fallback::FallbackLayer;
18
+ use crate::tower::service::LlmService;
19
+ use crate::tower::tracing::TracingLayer;
20
+ use crate::tower::types::{LlmRequest, LlmResponse};
21
+ use crate::types::audio::{CreateSpeechRequest, CreateTranscriptionRequest, TranscriptionResponse};
22
+ use crate::types::image::{CreateImageRequest, ImagesResponse};
23
+ use crate::types::moderation::{ModerationRequest, ModerationResponse};
24
+ use crate::types::ocr::{OcrRequest, OcrResponse};
25
+ use crate::types::rerank::{RerankRequest, RerankResponse};
26
+ use crate::types::search::{SearchRequest, SearchResponse};
27
+ use crate::types::{
28
+ AssistantMessage, ChatCompletionRequest, ChatCompletionResponse, Choice, EmbeddingInput, EmbeddingObject,
29
+ EmbeddingRequest, EmbeddingResponse, FinishReason, Message, ModelsListResponse, SystemMessage, Usage,
30
+ };
31
+ use tower::Layer;
32
+
33
+ // ─── Stream helpers ──────────────────────────────────────────────────────────
34
+
35
+ /// A stream that yields no items.
36
+ struct EmptyStream;
37
+
38
+ impl Stream for EmptyStream {
39
+ type Item = Result<crate::types::ChatCompletionChunk>;
40
+ fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
41
+ Poll::Ready(None)
42
+ }
43
+ }
44
+
45
+ // ─── Mock client ─────────────────────────────────────────────────────────────
46
+
47
+ /// A synchronous mock client. All methods return configurable canned
48
+ /// responses or errors.
49
+ ///
50
+ /// The inner state is wrapped in `Arc` so the struct can be cheaply cloned
51
+ /// (required by [`FallbackLayer`] which requires `F: Clone`).
52
+ #[derive(Clone)]
53
+ struct MockClient {
54
+ /// Shared inner state.
55
+ inner: Arc<MockClientInner>,
56
+ }
57
+
58
+ struct MockClientInner {
59
+ /// When set, `chat` returns this error instead of the canned response.
60
+ chat_error: Option<LiterLlmErrorKind>,
61
+ /// Number of times `chat` has been called.
62
+ call_count: AtomicUsize,
63
+ }
64
+
65
+ /// A serializable subset of [`LiterLlmError`] variants used in tests.
66
+ /// `LiterLlmError` is not `Clone`, so we store an enum of the variants we care about.
67
+ enum LiterLlmErrorKind {
68
+ RateLimited { message: String },
69
+ ServiceUnavailable { message: String },
70
+ Timeout,
71
+ Authentication { message: String },
72
+ }
73
+
74
+ impl LiterLlmErrorKind {
75
+ fn to_error(&self) -> LiterLlmError {
76
+ match self {
77
+ Self::RateLimited { message } => LiterLlmError::RateLimited {
78
+ message: message.clone(),
79
+ retry_after: None,
80
+ },
81
+ Self::ServiceUnavailable { message } => LiterLlmError::ServiceUnavailable {
82
+ message: message.clone(),
83
+ },
84
+ Self::Timeout => LiterLlmError::Timeout,
85
+ // MockClient maps auth error to BadRequest (not Authentication) because
86
+ // the mock's chat() doesn't distinguish — see MockClient::ok().
87
+ Self::Authentication { message } => LiterLlmError::BadRequest {
88
+ message: message.clone(),
89
+ },
90
+ }
91
+ }
92
+ }
93
+
94
+ impl MockClient {
95
+ fn ok() -> Self {
96
+ Self {
97
+ inner: Arc::new(MockClientInner {
98
+ chat_error: None,
99
+ call_count: AtomicUsize::new(0),
100
+ }),
101
+ }
102
+ }
103
+
104
+ fn failing_rate_limited() -> Self {
105
+ Self {
106
+ inner: Arc::new(MockClientInner {
107
+ chat_error: Some(LiterLlmErrorKind::RateLimited {
108
+ message: "too many requests".into(),
109
+ }),
110
+ call_count: AtomicUsize::new(0),
111
+ }),
112
+ }
113
+ }
114
+
115
+ fn failing_service_unavailable() -> Self {
116
+ Self {
117
+ inner: Arc::new(MockClientInner {
118
+ chat_error: Some(LiterLlmErrorKind::ServiceUnavailable { message: "503".into() }),
119
+ call_count: AtomicUsize::new(0),
120
+ }),
121
+ }
122
+ }
123
+
124
+ fn failing_auth() -> Self {
125
+ Self {
126
+ inner: Arc::new(MockClientInner {
127
+ chat_error: Some(LiterLlmErrorKind::Authentication {
128
+ message: "invalid key".into(),
129
+ }),
130
+ call_count: AtomicUsize::new(0),
131
+ }),
132
+ }
133
+ }
134
+
135
+ fn failing_timeout() -> Self {
136
+ Self {
137
+ inner: Arc::new(MockClientInner {
138
+ chat_error: Some(LiterLlmErrorKind::Timeout),
139
+ call_count: AtomicUsize::new(0),
140
+ }),
141
+ }
142
+ }
143
+ }
144
+
145
+ fn make_chat_response(model: &str) -> ChatCompletionResponse {
146
+ ChatCompletionResponse {
147
+ id: "test-id".into(),
148
+ object: "chat.completion".into(),
149
+ created: 0,
150
+ model: model.into(),
151
+ choices: vec![Choice {
152
+ index: 0,
153
+ message: AssistantMessage {
154
+ content: Some("Hello!".into()),
155
+ name: None,
156
+ tool_calls: None,
157
+ refusal: None,
158
+ function_call: None,
159
+ },
160
+ finish_reason: Some(FinishReason::Stop),
161
+ }],
162
+ usage: Some(Usage {
163
+ prompt_tokens: 10,
164
+ completion_tokens: 5,
165
+ total_tokens: 15,
166
+ }),
167
+ system_fingerprint: None,
168
+ service_tier: None,
169
+ }
170
+ }
171
+
172
+ impl LlmClient for MockClient {
173
+ fn chat(&self, req: ChatCompletionRequest) -> BoxFuture<'_, ChatCompletionResponse> {
174
+ self.inner.call_count.fetch_add(1, Ordering::SeqCst);
175
+ let result = match &self.inner.chat_error {
176
+ Some(kind) => Err(kind.to_error()),
177
+ None => Ok(make_chat_response(&req.model)),
178
+ };
179
+ Box::pin(async move { result })
180
+ }
181
+
182
+ fn chat_stream(
183
+ &self,
184
+ _req: ChatCompletionRequest,
185
+ ) -> BoxFuture<'_, BoxStream<'_, crate::types::ChatCompletionChunk>> {
186
+ Box::pin(async move {
187
+ // Return an immediately-finished stream.
188
+ let stream: BoxStream<'_, crate::types::ChatCompletionChunk> = Box::pin(EmptyStream);
189
+ Ok(stream)
190
+ })
191
+ }
192
+
193
+ fn embed(&self, req: EmbeddingRequest) -> BoxFuture<'_, EmbeddingResponse> {
194
+ let resp = EmbeddingResponse {
195
+ object: "list".into(),
196
+ data: vec![EmbeddingObject {
197
+ object: "embedding".into(),
198
+ embedding: vec![0.1, 0.2, 0.3],
199
+ index: 0,
200
+ }],
201
+ model: req.model.clone(),
202
+ usage: Some(Usage {
203
+ prompt_tokens: 4,
204
+ completion_tokens: 0,
205
+ total_tokens: 4,
206
+ }),
207
+ };
208
+ Box::pin(async move { Ok(resp) })
209
+ }
210
+
211
+ fn list_models(&self) -> BoxFuture<'_, ModelsListResponse> {
212
+ Box::pin(async move {
213
+ Ok(ModelsListResponse {
214
+ object: "list".into(),
215
+ data: vec![],
216
+ })
217
+ })
218
+ }
219
+
220
+ fn image_generate(&self, _req: CreateImageRequest) -> BoxFuture<'_, ImagesResponse> {
221
+ Box::pin(async move {
222
+ Ok(ImagesResponse {
223
+ created: 0,
224
+ data: vec![],
225
+ })
226
+ })
227
+ }
228
+
229
+ fn speech(&self, _req: CreateSpeechRequest) -> BoxFuture<'_, bytes::Bytes> {
230
+ Box::pin(async move { Ok(bytes::Bytes::new()) })
231
+ }
232
+
233
+ fn transcribe(&self, _req: CreateTranscriptionRequest) -> BoxFuture<'_, TranscriptionResponse> {
234
+ Box::pin(async move {
235
+ Ok(TranscriptionResponse {
236
+ text: String::new(),
237
+ language: None,
238
+ duration: None,
239
+ segments: None,
240
+ })
241
+ })
242
+ }
243
+
244
+ fn moderate(&self, _req: ModerationRequest) -> BoxFuture<'_, ModerationResponse> {
245
+ Box::pin(async move {
246
+ Ok(ModerationResponse {
247
+ id: String::new(),
248
+ model: String::new(),
249
+ results: vec![],
250
+ })
251
+ })
252
+ }
253
+
254
+ fn rerank(&self, _req: RerankRequest) -> BoxFuture<'_, RerankResponse> {
255
+ Box::pin(async move {
256
+ Ok(RerankResponse {
257
+ id: None,
258
+ results: vec![],
259
+ meta: None,
260
+ })
261
+ })
262
+ }
263
+
264
+ fn search(&self, _req: SearchRequest) -> BoxFuture<'_, SearchResponse> {
265
+ Box::pin(async {
266
+ Err(LiterLlmError::EndpointNotSupported {
267
+ endpoint: "search".into(),
268
+ provider: "mock".into(),
269
+ })
270
+ })
271
+ }
272
+
273
+ fn ocr(&self, _req: OcrRequest) -> BoxFuture<'_, OcrResponse> {
274
+ Box::pin(async {
275
+ Err(LiterLlmError::EndpointNotSupported {
276
+ endpoint: "ocr".into(),
277
+ provider: "mock".into(),
278
+ })
279
+ })
280
+ }
281
+ }
282
+
283
+ // ─── Request helpers ─────────────────────────────────────────────────────────
284
+
285
+ fn chat_req(model: &str) -> ChatCompletionRequest {
286
+ ChatCompletionRequest {
287
+ model: model.into(),
288
+ messages: vec![Message::System(SystemMessage {
289
+ content: "test".into(),
290
+ name: None,
291
+ })],
292
+ ..Default::default()
293
+ }
294
+ }
295
+
296
+ // ─── LlmService tests ────────────────────────────────────────────────────────
297
+
298
+ #[tokio::test]
299
+ async fn service_chat_returns_correct_response() {
300
+ let mut svc = LlmService::new(MockClient::ok());
301
+ let resp = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
302
+ match resp {
303
+ LlmResponse::Chat(r) => assert_eq!(r.model, "gpt-4"),
304
+ other => panic!("expected Chat response, got {:?}", std::mem::discriminant(&other)),
305
+ }
306
+ }
307
+
308
+ #[tokio::test]
309
+ async fn service_embed_returns_embedding_response() {
310
+ let mut svc = LlmService::new(MockClient::ok());
311
+ let req = EmbeddingRequest {
312
+ model: "text-embedding-3-small".into(),
313
+ input: EmbeddingInput::Single("hello world".into()),
314
+ encoding_format: None,
315
+ dimensions: None,
316
+ user: None,
317
+ };
318
+ let resp = svc.call(LlmRequest::Embed(req)).await.unwrap();
319
+ match resp {
320
+ LlmResponse::Embed(r) => assert_eq!(r.model, "text-embedding-3-small"),
321
+ other => panic!("expected Embed response, got {:?}", std::mem::discriminant(&other)),
322
+ }
323
+ }
324
+
325
+ #[tokio::test]
326
+ async fn service_list_models_returns_model_list() {
327
+ let mut svc = LlmService::new(MockClient::ok());
328
+ let resp = svc.call(LlmRequest::ListModels).await.unwrap();
329
+ assert!(matches!(resp, LlmResponse::ListModels(_)));
330
+ }
331
+
332
+ #[tokio::test]
333
+ async fn service_propagates_client_error() {
334
+ let mut svc = LlmService::new(MockClient::failing_auth());
335
+ let err = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap_err();
336
+ assert!(matches!(
337
+ err,
338
+ LiterLlmError::BadRequest { .. } | LiterLlmError::Authentication { .. }
339
+ ));
340
+ }
341
+
342
+ // ─── TracingLayer tests ───────────────────────────────────────────────────────
343
+
344
+ #[tokio::test]
345
+ async fn tracing_layer_passes_through_success() {
346
+ let inner = LlmService::new(MockClient::ok());
347
+ let mut svc = TracingLayer.layer(inner);
348
+ let resp = svc.call(LlmRequest::Chat(chat_req("gpt-4o"))).await.unwrap();
349
+ assert!(matches!(resp, LlmResponse::Chat(_)));
350
+ }
351
+
352
+ #[tokio::test]
353
+ async fn tracing_layer_propagates_error() {
354
+ let inner = LlmService::new(MockClient::failing_timeout());
355
+ let mut svc = TracingLayer.layer(inner);
356
+ let err = svc.call(LlmRequest::Chat(chat_req("gpt-4o"))).await.unwrap_err();
357
+ assert!(matches!(err, LiterLlmError::Timeout));
358
+ }
359
+
360
+ // ─── FallbackLayer tests ──────────────────────────────────────────────────────
361
+
362
+ #[tokio::test]
363
+ async fn fallback_not_triggered_on_success() {
364
+ let primary = LlmService::new(MockClient::ok());
365
+ let fallback = LlmService::new(MockClient::ok());
366
+
367
+ let mut svc = FallbackLayer::new(fallback).layer(primary);
368
+ let resp = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
369
+ // The response is Chat — confirming primary was called and succeeded.
370
+ assert!(matches!(resp, LlmResponse::Chat(_)));
371
+ }
372
+
373
+ #[tokio::test]
374
+ async fn fallback_triggered_on_rate_limit() {
375
+ let primary = LlmService::new(MockClient::failing_rate_limited());
376
+ let fallback = LlmService::new(MockClient::ok());
377
+
378
+ let mut svc = FallbackLayer::new(fallback).layer(primary);
379
+ let resp = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
380
+ assert!(matches!(resp, LlmResponse::Chat(_)));
381
+ }
382
+
383
+ #[tokio::test]
384
+ async fn fallback_triggered_on_service_unavailable() {
385
+ let primary = LlmService::new(MockClient::failing_service_unavailable());
386
+ let fallback = LlmService::new(MockClient::ok());
387
+
388
+ let mut svc = FallbackLayer::new(fallback).layer(primary);
389
+ let resp = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
390
+ assert!(matches!(resp, LlmResponse::Chat(_)));
391
+ }
392
+
393
+ #[tokio::test]
394
+ async fn fallback_not_triggered_on_auth_error() {
395
+ let primary = LlmService::new(MockClient::failing_auth());
396
+ let fallback = LlmService::new(MockClient::ok());
397
+
398
+ let mut svc = FallbackLayer::new(fallback).layer(primary);
399
+ // Authentication errors are not transient; fallback should NOT be tried.
400
+ // MockClient::failing_auth maps the error to BadRequest (non-transient),
401
+ // so an error should propagate rather than the fallback succeeding.
402
+ let result = svc.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
403
+ assert!(result.is_err(), "expected auth error to propagate, not fall back");
404
+ }
405
+
406
+ // ─── LlmRequest helpers ───────────────────────────────────────────────────────
407
+
408
#[test]
fn request_type_labels() {
    let embed = LlmRequest::Embed(EmbeddingRequest {
        model: "e".into(),
        input: EmbeddingInput::Single("x".into()),
        encoding_format: None,
        dimensions: None,
        user: None,
    });

    // (request, expected `request_type()` label)
    let cases = [
        (LlmRequest::Chat(chat_req("m")), "chat"),
        (LlmRequest::ChatStream(chat_req("m")), "chat_stream"),
        (embed, "embeddings"),
        (LlmRequest::ListModels, "list_models"),
    ];
    for (request, expected) in cases {
        assert_eq!(request.request_type(), expected);
    }
}
425
+
426
#[test]
fn operation_name_labels() {
    let embed = LlmRequest::Embed(EmbeddingRequest {
        model: "e".into(),
        input: EmbeddingInput::Single("x".into()),
        encoding_format: None,
        dimensions: None,
        user: None,
    });

    // (request, expected `operation_name()`); streaming and non-streaming
    // chat share the same operation name.
    let cases = [
        (LlmRequest::Chat(chat_req("m")), "chat"),
        (LlmRequest::ChatStream(chat_req("m")), "chat"),
        (embed, "embeddings"),
        (LlmRequest::ListModels, "list_models"),
    ];
    for (request, expected) in cases {
        assert_eq!(request.operation_name(), expected);
    }
}
443
+
444
#[test]
fn request_model_returns_none_for_list_models() {
    // `ListModels` carries no model name.
    let request = LlmRequest::ListModels;
    assert!(request.model().is_none());
}
448
+
449
+ // ─── Router tests ─────────────────────────────────────────────────────────────
450
+
451
+ use crate::tower::router::{Router, RoutingStrategy};
452
+
453
+ /// Build a `Router` over three mock `LlmService` instances and return
454
+ /// a handle to each service's shared call counter so tests can inspect
455
+ /// how many times each was called.
456
+ fn make_round_robin_router() -> (Router<LlmService<MockClient>>, [Arc<MockClientInner>; 3]) {
457
+ let clients = [MockClient::ok(), MockClient::ok(), MockClient::ok()];
458
+ let counters = [
459
+ Arc::clone(&clients[0].inner),
460
+ Arc::clone(&clients[1].inner),
461
+ Arc::clone(&clients[2].inner),
462
+ ];
463
+ let deployments: Vec<_> = clients.into_iter().map(LlmService::new).collect();
464
+ (
465
+ Router::new(deployments, RoutingStrategy::RoundRobin).expect("non-empty deployments"),
466
+ counters,
467
+ )
468
+ }
469
+
470
+ #[tokio::test]
471
+ async fn router_round_robin_distributes_across_deployments() {
472
+ let (mut router, counters) = make_round_robin_router();
473
+
474
+ // Fire 6 requests — each deployment should receive exactly 2.
475
+ for _ in 0..6 {
476
+ router.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
477
+ }
478
+
479
+ assert_eq!(
480
+ counters[0].call_count.load(Ordering::SeqCst),
481
+ 2,
482
+ "deployment 0 should receive 2 calls"
483
+ );
484
+ assert_eq!(
485
+ counters[1].call_count.load(Ordering::SeqCst),
486
+ 2,
487
+ "deployment 1 should receive 2 calls"
488
+ );
489
+ assert_eq!(
490
+ counters[2].call_count.load(Ordering::SeqCst),
491
+ 2,
492
+ "deployment 2 should receive 2 calls"
493
+ );
494
+ }
495
+
496
+ #[tokio::test]
497
+ async fn router_fallback_tries_next_on_transient_error_then_succeeds() {
498
+ // First two deployments fail with a transient error; the third succeeds.
499
+ let deployments: Vec<LlmService<MockClient>> = vec![
500
+ LlmService::new(MockClient::failing_rate_limited()),
501
+ LlmService::new(MockClient::failing_service_unavailable()),
502
+ LlmService::new(MockClient::ok()),
503
+ ];
504
+ let third_counter = Arc::clone(&deployments[2].inner().inner);
505
+
506
+ let mut router = Router::new(deployments, RoutingStrategy::Fallback).expect("non-empty deployments");
507
+ let resp = router.call(LlmRequest::Chat(chat_req("gpt-4"))).await.unwrap();
508
+
509
+ assert!(
510
+ matches!(resp, LlmResponse::Chat(_)),
511
+ "expected successful Chat response from third deployment"
512
+ );
513
+ assert_eq!(
514
+ third_counter.call_count.load(Ordering::SeqCst),
515
+ 1,
516
+ "third deployment should have been called exactly once"
517
+ );
518
+ }
519
+
520
+ #[tokio::test]
521
+ async fn router_fallback_stops_on_non_transient_error() {
522
+ // First deployment fails with a non-transient (auth) error; the second
523
+ // would succeed but must NOT be reached.
524
+ let ok_client = MockClient::ok();
525
+ let ok_counter = Arc::clone(&ok_client.inner);
526
+
527
+ let deployments: Vec<LlmService<MockClient>> =
528
+ vec![LlmService::new(MockClient::failing_auth()), LlmService::new(ok_client)];
529
+
530
+ let mut router = Router::new(deployments, RoutingStrategy::Fallback).expect("non-empty deployments");
531
+ let result = router.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
532
+
533
+ assert!(result.is_err(), "expected non-transient error to propagate immediately");
534
+ assert_eq!(
535
+ ok_counter.call_count.load(Ordering::SeqCst),
536
+ 0,
537
+ "second deployment must not be called after a non-transient error"
538
+ );
539
+ }