liter_llm 1.0.0.pre.rc.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +239 -0
- data/ext/liter_llm_rb/extconf.rb +65 -0
- data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
- data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
- data/ext/liter_llm_rb/native/Cargo.toml +32 -0
- data/ext/liter_llm_rb/native/build.rs +15 -0
- data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
- data/lib/liter_llm.rb +8 -0
- data/sig/liter_llm.rbs +416 -0
- data/vendor/Cargo.toml +54 -0
- data/vendor/liter-llm/Cargo.toml +92 -0
- data/vendor/liter-llm/README.md +252 -0
- data/vendor/liter-llm/schemas/pricing.json +40 -0
- data/vendor/liter-llm/schemas/providers.json +1662 -0
- data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
- data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
- data/vendor/liter-llm/src/auth/mod.rs +68 -0
- data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
- data/vendor/liter-llm/src/client/config.rs +351 -0
- data/vendor/liter-llm/src/client/managed.rs +622 -0
- data/vendor/liter-llm/src/client/mod.rs +864 -0
- data/vendor/liter-llm/src/cost.rs +212 -0
- data/vendor/liter-llm/src/error.rs +190 -0
- data/vendor/liter-llm/src/http/eventstream.rs +860 -0
- data/vendor/liter-llm/src/http/mod.rs +12 -0
- data/vendor/liter-llm/src/http/request.rs +438 -0
- data/vendor/liter-llm/src/http/retry.rs +72 -0
- data/vendor/liter-llm/src/http/streaming.rs +289 -0
- data/vendor/liter-llm/src/lib.rs +37 -0
- data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
- data/vendor/liter-llm/src/provider/azure.rs +579 -0
- data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
- data/vendor/liter-llm/src/provider/cohere.rs +654 -0
- data/vendor/liter-llm/src/provider/custom.rs +404 -0
- data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
- data/vendor/liter-llm/src/provider/mistral.rs +188 -0
- data/vendor/liter-llm/src/provider/mod.rs +616 -0
- data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
- data/vendor/liter-llm/src/tests.rs +1425 -0
- data/vendor/liter-llm/src/tokenizer.rs +281 -0
- data/vendor/liter-llm/src/tower/budget.rs +599 -0
- data/vendor/liter-llm/src/tower/cache.rs +502 -0
- data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
- data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
- data/vendor/liter-llm/src/tower/cost.rs +404 -0
- data/vendor/liter-llm/src/tower/fallback.rs +121 -0
- data/vendor/liter-llm/src/tower/health.rs +219 -0
- data/vendor/liter-llm/src/tower/hooks.rs +369 -0
- data/vendor/liter-llm/src/tower/mod.rs +77 -0
- data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
- data/vendor/liter-llm/src/tower/router.rs +436 -0
- data/vendor/liter-llm/src/tower/service.rs +181 -0
- data/vendor/liter-llm/src/tower/tests.rs +539 -0
- data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
- data/vendor/liter-llm/src/tower/tracing.rs +209 -0
- data/vendor/liter-llm/src/tower/types.rs +170 -0
- data/vendor/liter-llm/src/types/audio.rs +52 -0
- data/vendor/liter-llm/src/types/batch.rs +77 -0
- data/vendor/liter-llm/src/types/chat.rs +214 -0
- data/vendor/liter-llm/src/types/common.rs +244 -0
- data/vendor/liter-llm/src/types/embedding.rs +84 -0
- data/vendor/liter-llm/src/types/files.rs +58 -0
- data/vendor/liter-llm/src/types/image.rs +40 -0
- data/vendor/liter-llm/src/types/mod.rs +27 -0
- data/vendor/liter-llm/src/types/models.rs +21 -0
- data/vendor/liter-llm/src/types/moderation.rs +80 -0
- data/vendor/liter-llm/src/types/ocr.rs +87 -0
- data/vendor/liter-llm/src/types/rerank.rs +46 -0
- data/vendor/liter-llm/src/types/responses.rs +55 -0
- data/vendor/liter-llm/src/types/search.rs +45 -0
- data/vendor/liter-llm/tests/contract.rs +332 -0
- data/vendor/liter-llm-ffi/Cargo.toml +30 -0
- data/vendor/liter-llm-ffi/build.rs +66 -0
- data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
- data/vendor/liter-llm-ffi/liter_llm.h +850 -0
- data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
- metadata +286 -0
|
@@ -0,0 +1,404 @@
|
|
|
1
|
+
//! Tower middleware that records estimated cost as a tracing span attribute.
|
|
2
|
+
//!
|
|
3
|
+
//! [`CostTrackingLayer`] wraps any [`Service<LlmRequest>`] and, after each
|
|
4
|
+
//! successful response, calculates the USD cost from the embedded pricing
|
|
5
|
+
//! registry and records it as `gen_ai.usage.cost` on the current tracing span.
|
|
6
|
+
//!
|
|
7
|
+
//! For models not present in the pricing registry the layer only forwards the
//! response — the `gen_ai.usage.cost` span attribute is simply not recorded.
|
|
9
|
+
//!
|
|
10
|
+
//! # Example
|
|
11
|
+
//!
|
|
12
|
+
//! ```rust,ignore
|
|
13
|
+
//! use liter_llm::tower::{CostTrackingLayer, LlmService, TracingLayer};
|
|
14
|
+
//! use tower::ServiceBuilder;
|
|
15
|
+
//!
|
|
16
|
+
//! let client = liter_llm::DefaultClient::new(config, None)?;
|
|
17
|
+
//! let service = ServiceBuilder::new()
|
|
18
|
+
//! .layer(TracingLayer)
|
|
19
|
+
//! .layer(CostTrackingLayer)
|
|
20
|
+
//! .service(LlmService::new(client));
|
|
21
|
+
//! ```
|
|
22
|
+
|
|
23
|
+
use std::task::{Context, Poll};
|
|
24
|
+
|
|
25
|
+
use tower::Layer;
|
|
26
|
+
use tower::Service;
|
|
27
|
+
|
|
28
|
+
use super::types::{LlmRequest, LlmResponse};
|
|
29
|
+
use crate::client::BoxFuture;
|
|
30
|
+
use crate::cost;
|
|
31
|
+
use crate::error::{LiterLlmError, Result};
|
|
32
|
+
|
|
33
|
+
// ─── Layer ────────────────────────────────────────────────────────────────────
|
|
34
|
+
|
|
35
|
+
/// Tower [`Layer`] that records estimated USD cost on the current tracing span.
///
/// After each successful response the layer calls [`cost::completion_cost`] and
/// records the result as `gen_ai.usage.cost` using
/// [`tracing::Span::record`]. If the model is not in the pricing registry the
/// attribute is simply omitted.
///
/// The layer is a field-less unit struct, so it derives `Copy` and `Default`
/// and can be duplicated freely when building several service stacks.
#[derive(Debug, Clone, Copy, Default)]
pub struct CostTrackingLayer;
|
|
42
|
+
|
|
43
|
+
impl<S> Layer<S> for CostTrackingLayer {
|
|
44
|
+
type Service = CostTrackingService<S>;
|
|
45
|
+
|
|
46
|
+
fn layer(&self, inner: S) -> Self::Service {
|
|
47
|
+
CostTrackingService { inner }
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
// ─── Service ──────────────────────────────────────────────────────────────────
|
|
52
|
+
|
|
53
|
+
/// Tower service produced by [`CostTrackingLayer`].
///
/// `Clone` is derived and is available whenever the wrapped service `S` is
/// itself `Clone`, which matches the bound of the former hand-written impl.
#[derive(Debug, Clone)]
pub struct CostTrackingService<S> {
    /// The wrapped service every request is forwarded to.
    inner: S,
}
|
|
68
|
+
|
|
69
|
+
impl<S> Service<LlmRequest> for CostTrackingService<S>
|
|
70
|
+
where
|
|
71
|
+
S: Service<LlmRequest, Response = LlmResponse, Error = LiterLlmError> + Send + 'static,
|
|
72
|
+
S::Future: Send + 'static,
|
|
73
|
+
{
|
|
74
|
+
type Response = LlmResponse;
|
|
75
|
+
type Error = LiterLlmError;
|
|
76
|
+
type Future = BoxFuture<'static, LlmResponse>;
|
|
77
|
+
|
|
78
|
+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
|
79
|
+
self.inner.poll_ready(cx)
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
fn call(&mut self, req: LlmRequest) -> Self::Future {
|
|
83
|
+
// Capture the model name before moving `req` into the inner call, so we
|
|
84
|
+
// can look up pricing after the response arrives.
|
|
85
|
+
let model = req.model().map(ToOwned::to_owned);
|
|
86
|
+
let fut = self.inner.call(req);
|
|
87
|
+
|
|
88
|
+
Box::pin(async move {
|
|
89
|
+
let resp = fut.await?;
|
|
90
|
+
record_cost(&model, &resp);
|
|
91
|
+
Ok(resp)
|
|
92
|
+
})
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
// ─── Helpers ─────────────────────────────────────────────────────────────────
|
|
97
|
+
|
|
98
|
+
/// Extract usage from the response and record an estimated cost on the current
|
|
99
|
+
/// tracing span as `gen_ai.usage.cost`.
|
|
100
|
+
fn record_cost(model: &Option<String>, resp: &LlmResponse) {
|
|
101
|
+
let Some(model_name) = model else { return };
|
|
102
|
+
let Some(usage) = resp.usage() else { return };
|
|
103
|
+
|
|
104
|
+
if let Some(usd) = cost::completion_cost(model_name, usage.prompt_tokens, usage.completion_tokens) {
|
|
105
|
+
tracing::Span::current().record("gen_ai.usage.cost", usd);
|
|
106
|
+
}
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
// ─── Tests ────────────────────────────────────────────────────────────────────
|
|
110
|
+
|
|
111
|
+
#[cfg(test)]
|
|
112
|
+
mod tests {
|
|
113
|
+
use tower::Layer as _;
|
|
114
|
+
use tower::Service as _;
|
|
115
|
+
|
|
116
|
+
use crate::tower::service::LlmService;
|
|
117
|
+
use crate::tower::types::{LlmRequest, LlmResponse};
|
|
118
|
+
use crate::types::audio::{CreateSpeechRequest, CreateTranscriptionRequest, TranscriptionResponse};
|
|
119
|
+
use crate::types::image::{CreateImageRequest, ImagesResponse};
|
|
120
|
+
use crate::types::moderation::{ModerationRequest, ModerationResponse};
|
|
121
|
+
use crate::types::ocr::{OcrRequest, OcrResponse};
|
|
122
|
+
use crate::types::rerank::{RerankRequest, RerankResponse};
|
|
123
|
+
use crate::types::search::{SearchRequest, SearchResponse};
|
|
124
|
+
use crate::types::{
|
|
125
|
+
AssistantMessage, ChatCompletionRequest, ChatCompletionResponse, Choice, EmbeddingObject, EmbeddingRequest,
|
|
126
|
+
EmbeddingResponse, FinishReason, Message, ModelsListResponse, SystemMessage, Usage,
|
|
127
|
+
};
|
|
128
|
+
use crate::{
|
|
129
|
+
client::{BoxFuture, BoxStream, LlmClient},
|
|
130
|
+
error::{LiterLlmError, Result},
|
|
131
|
+
types::ChatCompletionChunk,
|
|
132
|
+
};
|
|
133
|
+
|
|
134
|
+
use std::pin::Pin;
|
|
135
|
+
use std::task::{Context, Poll};
|
|
136
|
+
|
|
137
|
+
use futures_core::Stream;
|
|
138
|
+
|
|
139
|
+
use super::CostTrackingLayer;
|
|
140
|
+
|
|
141
|
+
// ── Minimal mock ─────────────────────────────────────────────────────────
|
|
142
|
+
|
|
143
|
+
struct EmptyStream;
|
|
144
|
+
|
|
145
|
+
impl Stream for EmptyStream {
|
|
146
|
+
type Item = Result<ChatCompletionChunk>;
|
|
147
|
+
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
148
|
+
Poll::Ready(None)
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
#[derive(Clone)]
|
|
153
|
+
struct PricedMockClient {
|
|
154
|
+
#[allow(dead_code)]
|
|
155
|
+
model: String,
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
impl LlmClient for PricedMockClient {
|
|
159
|
+
fn chat(&self, req: ChatCompletionRequest) -> BoxFuture<'_, ChatCompletionResponse> {
|
|
160
|
+
let model = req.model.clone();
|
|
161
|
+
let resp = ChatCompletionResponse {
|
|
162
|
+
id: "test".into(),
|
|
163
|
+
object: "chat.completion".into(),
|
|
164
|
+
created: 0,
|
|
165
|
+
model,
|
|
166
|
+
choices: vec![Choice {
|
|
167
|
+
index: 0,
|
|
168
|
+
message: AssistantMessage {
|
|
169
|
+
content: Some("hi".into()),
|
|
170
|
+
name: None,
|
|
171
|
+
tool_calls: None,
|
|
172
|
+
refusal: None,
|
|
173
|
+
function_call: None,
|
|
174
|
+
},
|
|
175
|
+
finish_reason: Some(FinishReason::Stop),
|
|
176
|
+
}],
|
|
177
|
+
usage: Some(Usage {
|
|
178
|
+
prompt_tokens: 100,
|
|
179
|
+
completion_tokens: 50,
|
|
180
|
+
total_tokens: 150,
|
|
181
|
+
}),
|
|
182
|
+
system_fingerprint: None,
|
|
183
|
+
service_tier: None,
|
|
184
|
+
};
|
|
185
|
+
Box::pin(async move { Ok(resp) })
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
fn chat_stream(&self, _req: ChatCompletionRequest) -> BoxFuture<'_, BoxStream<'_, ChatCompletionChunk>> {
|
|
189
|
+
Box::pin(async move {
|
|
190
|
+
let stream: BoxStream<'_, ChatCompletionChunk> = Box::pin(EmptyStream);
|
|
191
|
+
Ok(stream)
|
|
192
|
+
})
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
fn embed(&self, req: EmbeddingRequest) -> BoxFuture<'_, EmbeddingResponse> {
|
|
196
|
+
let model = req.model.clone();
|
|
197
|
+
let resp = EmbeddingResponse {
|
|
198
|
+
object: "list".into(),
|
|
199
|
+
data: vec![EmbeddingObject {
|
|
200
|
+
object: "embedding".into(),
|
|
201
|
+
embedding: vec![0.1],
|
|
202
|
+
index: 0,
|
|
203
|
+
}],
|
|
204
|
+
model,
|
|
205
|
+
usage: Some(Usage {
|
|
206
|
+
prompt_tokens: 10,
|
|
207
|
+
completion_tokens: 0,
|
|
208
|
+
total_tokens: 10,
|
|
209
|
+
}),
|
|
210
|
+
};
|
|
211
|
+
Box::pin(async move { Ok(resp) })
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
fn list_models(&self) -> BoxFuture<'_, ModelsListResponse> {
|
|
215
|
+
Box::pin(async move {
|
|
216
|
+
Ok(ModelsListResponse {
|
|
217
|
+
object: "list".into(),
|
|
218
|
+
data: vec![],
|
|
219
|
+
})
|
|
220
|
+
})
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
fn image_generate(&self, _req: CreateImageRequest) -> BoxFuture<'_, ImagesResponse> {
|
|
224
|
+
Box::pin(async move {
|
|
225
|
+
Ok(ImagesResponse {
|
|
226
|
+
created: 0,
|
|
227
|
+
data: vec![],
|
|
228
|
+
})
|
|
229
|
+
})
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
fn speech(&self, _req: CreateSpeechRequest) -> BoxFuture<'_, bytes::Bytes> {
|
|
233
|
+
Box::pin(async move { Ok(bytes::Bytes::new()) })
|
|
234
|
+
}
|
|
235
|
+
|
|
236
|
+
fn transcribe(&self, _req: CreateTranscriptionRequest) -> BoxFuture<'_, TranscriptionResponse> {
|
|
237
|
+
Box::pin(async move {
|
|
238
|
+
Ok(TranscriptionResponse {
|
|
239
|
+
text: String::new(),
|
|
240
|
+
language: None,
|
|
241
|
+
duration: None,
|
|
242
|
+
segments: None,
|
|
243
|
+
})
|
|
244
|
+
})
|
|
245
|
+
}
|
|
246
|
+
|
|
247
|
+
fn moderate(&self, _req: ModerationRequest) -> BoxFuture<'_, ModerationResponse> {
|
|
248
|
+
Box::pin(async move {
|
|
249
|
+
Ok(ModerationResponse {
|
|
250
|
+
id: String::new(),
|
|
251
|
+
model: String::new(),
|
|
252
|
+
results: vec![],
|
|
253
|
+
})
|
|
254
|
+
})
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
fn rerank(&self, _req: RerankRequest) -> BoxFuture<'_, RerankResponse> {
|
|
258
|
+
Box::pin(async move {
|
|
259
|
+
Ok(RerankResponse {
|
|
260
|
+
id: None,
|
|
261
|
+
results: vec![],
|
|
262
|
+
meta: None,
|
|
263
|
+
})
|
|
264
|
+
})
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
fn search(&self, _req: SearchRequest) -> BoxFuture<'_, SearchResponse> {
|
|
268
|
+
Box::pin(async {
|
|
269
|
+
Err(LiterLlmError::EndpointNotSupported {
|
|
270
|
+
endpoint: "search".into(),
|
|
271
|
+
provider: "mock".into(),
|
|
272
|
+
})
|
|
273
|
+
})
|
|
274
|
+
}
|
|
275
|
+
|
|
276
|
+
fn ocr(&self, _req: OcrRequest) -> BoxFuture<'_, OcrResponse> {
|
|
277
|
+
Box::pin(async {
|
|
278
|
+
Err(LiterLlmError::EndpointNotSupported {
|
|
279
|
+
endpoint: "ocr".into(),
|
|
280
|
+
provider: "mock".into(),
|
|
281
|
+
})
|
|
282
|
+
})
|
|
283
|
+
}
|
|
284
|
+
}
|
|
285
|
+
|
|
286
|
+
fn chat_req(model: &str) -> ChatCompletionRequest {
|
|
287
|
+
ChatCompletionRequest {
|
|
288
|
+
model: model.into(),
|
|
289
|
+
messages: vec![Message::System(SystemMessage {
|
|
290
|
+
content: "test".into(),
|
|
291
|
+
name: None,
|
|
292
|
+
})],
|
|
293
|
+
..Default::default()
|
|
294
|
+
}
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
// ── Tests ─────────────────────────────────────────────────────────────────
|
|
298
|
+
|
|
299
|
+
/// CostTrackingLayer passes through the response unchanged for a known model.
|
|
300
|
+
#[tokio::test]
|
|
301
|
+
async fn cost_tracking_passes_through_chat_response_for_known_model() {
|
|
302
|
+
let inner = LlmService::new(PricedMockClient { model: "gpt-4".into() });
|
|
303
|
+
let mut svc = CostTrackingLayer.layer(inner);
|
|
304
|
+
let resp = svc
|
|
305
|
+
.call(LlmRequest::Chat(chat_req("gpt-4")))
|
|
306
|
+
.await
|
|
307
|
+
.expect("should succeed");
|
|
308
|
+
// The response must still be a Chat variant with the correct model.
|
|
309
|
+
match resp {
|
|
310
|
+
LlmResponse::Chat(r) => {
|
|
311
|
+
assert_eq!(r.model, "gpt-4");
|
|
312
|
+
// estimated_cost should return Some for gpt-4.
|
|
313
|
+
let cost = r.estimated_cost().expect("gpt-4 must have pricing");
|
|
314
|
+
// 100 * 0.00003 + 50 * 0.00006 = 0.006
|
|
315
|
+
assert!((cost - 0.006).abs() < 1e-9, "unexpected cost: {cost}");
|
|
316
|
+
}
|
|
317
|
+
other => panic!("expected Chat response, got {:?}", std::mem::discriminant(&other)),
|
|
318
|
+
}
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
/// CostTrackingLayer is a no-op (does not panic) for unknown models.
|
|
322
|
+
#[tokio::test]
|
|
323
|
+
async fn cost_tracking_no_op_for_unknown_model() {
|
|
324
|
+
let inner = LlmService::new(PricedMockClient {
|
|
325
|
+
model: "unknown-model".into(),
|
|
326
|
+
});
|
|
327
|
+
let mut svc = CostTrackingLayer.layer(inner);
|
|
328
|
+
let resp = svc
|
|
329
|
+
.call(LlmRequest::Chat(chat_req("unknown-model")))
|
|
330
|
+
.await
|
|
331
|
+
.expect("should succeed without error");
|
|
332
|
+
// Response passes through; no panic even though model has no pricing.
|
|
333
|
+
assert!(matches!(resp, LlmResponse::Chat(_)));
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
/// CostTrackingLayer propagates errors from the inner service.
|
|
337
|
+
#[tokio::test]
|
|
338
|
+
async fn cost_tracking_propagates_inner_errors() {
|
|
339
|
+
use crate::client::{BoxFuture, BoxStream, LlmClient};
|
|
340
|
+
use crate::tower::service::LlmService;
|
|
341
|
+
|
|
342
|
+
#[derive(Clone)]
|
|
343
|
+
struct AlwaysErrorClient;
|
|
344
|
+
|
|
345
|
+
impl LlmClient for AlwaysErrorClient {
|
|
346
|
+
fn chat(&self, _req: ChatCompletionRequest) -> BoxFuture<'_, ChatCompletionResponse> {
|
|
347
|
+
Box::pin(async { Err(LiterLlmError::Timeout) })
|
|
348
|
+
}
|
|
349
|
+
fn chat_stream(&self, _req: ChatCompletionRequest) -> BoxFuture<'_, BoxStream<'_, ChatCompletionChunk>> {
|
|
350
|
+
Box::pin(async move {
|
|
351
|
+
let stream: BoxStream<'_, ChatCompletionChunk> = Box::pin(EmptyStream);
|
|
352
|
+
Ok(stream)
|
|
353
|
+
})
|
|
354
|
+
}
|
|
355
|
+
fn embed(&self, _req: EmbeddingRequest) -> BoxFuture<'_, EmbeddingResponse> {
|
|
356
|
+
Box::pin(async { Err(LiterLlmError::Timeout) })
|
|
357
|
+
}
|
|
358
|
+
fn list_models(&self) -> BoxFuture<'_, ModelsListResponse> {
|
|
359
|
+
Box::pin(async { Err(LiterLlmError::Timeout) })
|
|
360
|
+
}
|
|
361
|
+
fn image_generate(&self, _req: CreateImageRequest) -> BoxFuture<'_, ImagesResponse> {
|
|
362
|
+
Box::pin(async { Err(LiterLlmError::Timeout) })
|
|
363
|
+
}
|
|
364
|
+
fn speech(&self, _req: CreateSpeechRequest) -> BoxFuture<'_, bytes::Bytes> {
|
|
365
|
+
Box::pin(async { Err(LiterLlmError::Timeout) })
|
|
366
|
+
}
|
|
367
|
+
fn transcribe(&self, _req: CreateTranscriptionRequest) -> BoxFuture<'_, TranscriptionResponse> {
|
|
368
|
+
Box::pin(async { Err(LiterLlmError::Timeout) })
|
|
369
|
+
}
|
|
370
|
+
fn moderate(&self, _req: ModerationRequest) -> BoxFuture<'_, ModerationResponse> {
|
|
371
|
+
Box::pin(async { Err(LiterLlmError::Timeout) })
|
|
372
|
+
}
|
|
373
|
+
fn rerank(&self, _req: RerankRequest) -> BoxFuture<'_, RerankResponse> {
|
|
374
|
+
Box::pin(async { Err(LiterLlmError::Timeout) })
|
|
375
|
+
}
|
|
376
|
+
|
|
377
|
+
fn search(&self, _req: SearchRequest) -> BoxFuture<'_, SearchResponse> {
|
|
378
|
+
Box::pin(async {
|
|
379
|
+
Err(LiterLlmError::EndpointNotSupported {
|
|
380
|
+
endpoint: "search".into(),
|
|
381
|
+
provider: "mock".into(),
|
|
382
|
+
})
|
|
383
|
+
})
|
|
384
|
+
}
|
|
385
|
+
|
|
386
|
+
fn ocr(&self, _req: OcrRequest) -> BoxFuture<'_, OcrResponse> {
|
|
387
|
+
Box::pin(async {
|
|
388
|
+
Err(LiterLlmError::EndpointNotSupported {
|
|
389
|
+
endpoint: "ocr".into(),
|
|
390
|
+
provider: "mock".into(),
|
|
391
|
+
})
|
|
392
|
+
})
|
|
393
|
+
}
|
|
394
|
+
}
|
|
395
|
+
|
|
396
|
+
let inner = LlmService::new(AlwaysErrorClient);
|
|
397
|
+
let mut svc = CostTrackingLayer.layer(inner);
|
|
398
|
+
let err = svc
|
|
399
|
+
.call(LlmRequest::Chat(chat_req("gpt-4")))
|
|
400
|
+
.await
|
|
401
|
+
.expect_err("should propagate inner error");
|
|
402
|
+
assert!(matches!(err, LiterLlmError::Timeout));
|
|
403
|
+
}
|
|
404
|
+
}
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
use std::task::{Context, Poll};
|
|
2
|
+
|
|
3
|
+
use tower::Layer;
|
|
4
|
+
use tower::Service;
|
|
5
|
+
|
|
6
|
+
use super::types::{LlmRequest, LlmResponse};
|
|
7
|
+
use crate::client::BoxFuture;
|
|
8
|
+
use crate::error::{LiterLlmError, Result};
|
|
9
|
+
|
|
10
|
+
/// Tower [`Layer`] that routes to a fallback service when the primary service
/// returns an error.
///
/// Only transient errors trigger the fallback — specifically:
/// [`LiterLlmError::RateLimited`], [`LiterLlmError::ServiceUnavailable`],
/// [`LiterLlmError::Timeout`], and [`LiterLlmError::ServerError`].
/// Authentication or bad-request errors are propagated directly without
/// consulting the fallback because retrying on a different service would
/// produce the same result.
///
/// The layer is `Clone` whenever the fallback service `F` is, so one
/// configured layer can be reused across several service stacks.
#[derive(Debug, Clone)]
pub struct FallbackLayer<F> {
    /// Service handed (as a clone) to every [`FallbackService`] this layer builds.
    fallback: F,
}

impl<F> FallbackLayer<F> {
    /// Create a new fallback layer with the given fallback service.
    #[must_use]
    pub fn new(fallback: F) -> Self {
        Self { fallback }
    }
}
|
|
30
|
+
|
|
31
|
+
impl<S, F> Layer<S> for FallbackLayer<F>
|
|
32
|
+
where
|
|
33
|
+
F: Clone,
|
|
34
|
+
{
|
|
35
|
+
type Service = FallbackService<S, F>;
|
|
36
|
+
|
|
37
|
+
fn layer(&self, primary: S) -> Self::Service {
|
|
38
|
+
FallbackService {
|
|
39
|
+
primary,
|
|
40
|
+
// Clone the fallback so the produced service owns it independently.
|
|
41
|
+
fallback: self.fallback.clone(),
|
|
42
|
+
}
|
|
43
|
+
}
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
/// Tower service produced by [`FallbackLayer`].
///
/// `Clone` is derived and is available whenever both `S` and `F` are `Clone`,
/// which matches the bounds of the former hand-written impl.
#[derive(Debug, Clone)]
pub struct FallbackService<S, F> {
    /// Service that receives every request first.
    primary: S,
    /// Service the request is replayed on after a transient primary failure.
    fallback: F,
}
|
|
64
|
+
|
|
65
|
+
impl<S, F> Service<LlmRequest> for FallbackService<S, F>
|
|
66
|
+
where
|
|
67
|
+
S: Service<LlmRequest, Response = LlmResponse, Error = LiterLlmError> + Send + 'static,
|
|
68
|
+
S::Future: Send + 'static,
|
|
69
|
+
F: Service<LlmRequest, Response = LlmResponse, Error = LiterLlmError> + Clone + Send + 'static,
|
|
70
|
+
F::Future: Send + 'static,
|
|
71
|
+
{
|
|
72
|
+
type Response = LlmResponse;
|
|
73
|
+
type Error = LiterLlmError;
|
|
74
|
+
type Future = BoxFuture<'static, LlmResponse>;
|
|
75
|
+
|
|
76
|
+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
|
77
|
+
// Tower contract: poll_ready should prepare exactly one service for a
|
|
78
|
+
// subsequent call. Ideally we would only poll the primary here and
|
|
79
|
+
// poll the fallback lazily in `call`. However, because `call` takes
|
|
80
|
+
// `&mut self` and must return a `'static` future (no reference to
|
|
81
|
+
// `self`), we cannot hold a mutable borrow across the await point.
|
|
82
|
+
// For our concrete use case (DefaultClient is always ready), polling
|
|
83
|
+
// both here is not harmful — neither service blocks and both remain
|
|
84
|
+
// ready until the next call. Callers that compose non-trivially-ready
|
|
85
|
+
// services should use a dedicated load-balancing layer instead.
|
|
86
|
+
match self.primary.poll_ready(cx) {
|
|
87
|
+
Poll::Pending => return Poll::Pending,
|
|
88
|
+
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
|
89
|
+
Poll::Ready(Ok(())) => {}
|
|
90
|
+
}
|
|
91
|
+
self.fallback.poll_ready(cx)
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
fn call(&mut self, req: LlmRequest) -> Self::Future {
|
|
95
|
+
// Clone the request so it can be replayed on the fallback if needed.
|
|
96
|
+
let fallback_req = req.clone();
|
|
97
|
+
let primary_fut = self.primary.call(req);
|
|
98
|
+
|
|
99
|
+
// `poll_ready` readied `self.fallback` for exactly one call.
|
|
100
|
+
// We move the readied service into the async block (so the future is
|
|
101
|
+
// 'static) and replace it with a fresh clone for the *next* call cycle.
|
|
102
|
+
// Tower's contract guarantees at most one `call` per `poll_ready`, so
|
|
103
|
+
// the fresh clone is not used until `poll_ready` runs again.
|
|
104
|
+
let fresh = self.fallback.clone();
|
|
105
|
+
let mut readied_fallback = std::mem::replace(&mut self.fallback, fresh);
|
|
106
|
+
|
|
107
|
+
Box::pin(async move {
|
|
108
|
+
match primary_fut.await {
|
|
109
|
+
Ok(resp) => Ok(resp),
|
|
110
|
+
Err(e) if e.is_transient() => {
|
|
111
|
+
tracing::warn!(
|
|
112
|
+
error = %e,
|
|
113
|
+
"primary service failed with transient error; trying fallback"
|
|
114
|
+
);
|
|
115
|
+
readied_fallback.call(fallback_req).await
|
|
116
|
+
}
|
|
117
|
+
Err(e) => Err(e),
|
|
118
|
+
}
|
|
119
|
+
})
|
|
120
|
+
}
|
|
121
|
+
}
|