liter_llm 1.0.0.pre.rc.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +239 -0
- data/ext/liter_llm_rb/extconf.rb +65 -0
- data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
- data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
- data/ext/liter_llm_rb/native/Cargo.toml +32 -0
- data/ext/liter_llm_rb/native/build.rs +15 -0
- data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
- data/lib/liter_llm.rb +8 -0
- data/sig/liter_llm.rbs +416 -0
- data/vendor/Cargo.toml +54 -0
- data/vendor/liter-llm/Cargo.toml +92 -0
- data/vendor/liter-llm/README.md +252 -0
- data/vendor/liter-llm/schemas/pricing.json +40 -0
- data/vendor/liter-llm/schemas/providers.json +1662 -0
- data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
- data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
- data/vendor/liter-llm/src/auth/mod.rs +68 -0
- data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
- data/vendor/liter-llm/src/client/config.rs +351 -0
- data/vendor/liter-llm/src/client/managed.rs +622 -0
- data/vendor/liter-llm/src/client/mod.rs +864 -0
- data/vendor/liter-llm/src/cost.rs +212 -0
- data/vendor/liter-llm/src/error.rs +190 -0
- data/vendor/liter-llm/src/http/eventstream.rs +860 -0
- data/vendor/liter-llm/src/http/mod.rs +12 -0
- data/vendor/liter-llm/src/http/request.rs +438 -0
- data/vendor/liter-llm/src/http/retry.rs +72 -0
- data/vendor/liter-llm/src/http/streaming.rs +289 -0
- data/vendor/liter-llm/src/lib.rs +37 -0
- data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
- data/vendor/liter-llm/src/provider/azure.rs +579 -0
- data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
- data/vendor/liter-llm/src/provider/cohere.rs +654 -0
- data/vendor/liter-llm/src/provider/custom.rs +404 -0
- data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
- data/vendor/liter-llm/src/provider/mistral.rs +188 -0
- data/vendor/liter-llm/src/provider/mod.rs +616 -0
- data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
- data/vendor/liter-llm/src/tests.rs +1425 -0
- data/vendor/liter-llm/src/tokenizer.rs +281 -0
- data/vendor/liter-llm/src/tower/budget.rs +599 -0
- data/vendor/liter-llm/src/tower/cache.rs +502 -0
- data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
- data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
- data/vendor/liter-llm/src/tower/cost.rs +404 -0
- data/vendor/liter-llm/src/tower/fallback.rs +121 -0
- data/vendor/liter-llm/src/tower/health.rs +219 -0
- data/vendor/liter-llm/src/tower/hooks.rs +369 -0
- data/vendor/liter-llm/src/tower/mod.rs +77 -0
- data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
- data/vendor/liter-llm/src/tower/router.rs +436 -0
- data/vendor/liter-llm/src/tower/service.rs +181 -0
- data/vendor/liter-llm/src/tower/tests.rs +539 -0
- data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
- data/vendor/liter-llm/src/tower/tracing.rs +209 -0
- data/vendor/liter-llm/src/tower/types.rs +170 -0
- data/vendor/liter-llm/src/types/audio.rs +52 -0
- data/vendor/liter-llm/src/types/batch.rs +77 -0
- data/vendor/liter-llm/src/types/chat.rs +214 -0
- data/vendor/liter-llm/src/types/common.rs +244 -0
- data/vendor/liter-llm/src/types/embedding.rs +84 -0
- data/vendor/liter-llm/src/types/files.rs +58 -0
- data/vendor/liter-llm/src/types/image.rs +40 -0
- data/vendor/liter-llm/src/types/mod.rs +27 -0
- data/vendor/liter-llm/src/types/models.rs +21 -0
- data/vendor/liter-llm/src/types/moderation.rs +80 -0
- data/vendor/liter-llm/src/types/ocr.rs +87 -0
- data/vendor/liter-llm/src/types/rerank.rs +46 -0
- data/vendor/liter-llm/src/types/responses.rs +55 -0
- data/vendor/liter-llm/src/types/search.rs +45 -0
- data/vendor/liter-llm/tests/contract.rs +332 -0
- data/vendor/liter-llm-ffi/Cargo.toml +30 -0
- data/vendor/liter-llm-ffi/build.rs +66 -0
- data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
- data/vendor/liter-llm-ffi/liter_llm.h +850 -0
- data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
- metadata +286 -0
|
@@ -0,0 +1,436 @@
|
|
|
1
|
+
use std::sync::Arc;
|
|
2
|
+
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
3
|
+
use std::task::{Context, Poll};
|
|
4
|
+
use std::time::Instant;
|
|
5
|
+
|
|
6
|
+
use dashmap::DashMap;
|
|
7
|
+
use tower::Service;
|
|
8
|
+
|
|
9
|
+
use super::types::{LlmRequest, LlmResponse};
|
|
10
|
+
use crate::client::BoxFuture;
|
|
11
|
+
use crate::error::{LiterLlmError, Result};
|
|
12
|
+
|
|
13
|
+
// ---- Routing strategy ------------------------------------------------------
|
|
14
|
+
|
|
15
|
+
/// Strategy a [`Router`] uses to choose among its deployments.
#[derive(Debug, Clone)]
pub enum RoutingStrategy {
    /// Round-robin across all deployments in order.
    RoundRobin,
    /// Try deployments in order; advance to the next on a transient error.
    /// Non-transient errors are returned to the caller immediately.
    Fallback,
    /// Route to the deployment with the lowest observed latency (exponential
    /// moving average). Deployments without any sample count as `0.0`, so
    /// they are preferred until they have been measured.
    LatencyBased,
    /// Intended to route to the cheapest deployment for the requested model.
    ///
    /// The router has no per-deployment pricing metadata yet, so this
    /// currently tries deployments in order exactly like
    /// [`RoutingStrategy::Fallback`] and only logs the estimated cost of the
    /// successful response.
    CostBased,
    /// Weighted random distribution across deployments; higher weights
    /// receive proportionally more traffic.
    ///
    /// [`Router::new`] validates the weights (the length must match the
    /// number of deployments and the total must be positive). They are used
    /// as given, without normalisation, so only their ratios matter.
    WeightedRandom {
        /// One weight per deployment, in the same order as the deployments
        /// vec.
        weights: Vec<f64>,
    },
}
|
|
38
|
+
|
|
39
|
+
// ---- Per-deployment metrics ------------------------------------------------
|
|
40
|
+
|
|
41
|
+
/// Tracks per-deployment latency using an exponential moving average.
///
/// The derived [`Default`] describes a deployment with no samples yet: an
/// EMA of `0.0` and a request count of `0`.
#[derive(Debug, Default)]
struct DeploymentMetrics {
    /// Exponential moving average of latency in seconds.
    latency_ema: f64,
    /// Number of samples recorded so far; `0` means the EMA is unseeded.
    request_count: u64,
}

impl DeploymentMetrics {
    /// Fold a new latency sample (in seconds) into the moving average.
    ///
    /// The first sample seeds the average directly; every later sample is
    /// blended in with weight `ALPHA`.
    fn record_latency(&mut self, latency_secs: f64) {
        // Smoothing factor for EMA — higher values weight recent samples more.
        const ALPHA: f64 = 0.3;

        self.latency_ema = if self.request_count == 0 {
            latency_secs
        } else {
            ALPHA * latency_secs + (1.0 - ALPHA) * self.latency_ema
        };
        self.request_count += 1;
    }
}
|
|
73
|
+
|
|
74
|
+
/// Shared state tracking per-deployment metrics, keyed by deployment index.
|
|
75
|
+
pub struct RouterState {
|
|
76
|
+
metrics: Arc<DashMap<usize, DeploymentMetrics>>,
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
impl RouterState {
|
|
80
|
+
fn new() -> Self {
|
|
81
|
+
Self {
|
|
82
|
+
metrics: Arc::new(DashMap::new()),
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
impl Clone for RouterState {
|
|
88
|
+
fn clone(&self) -> Self {
|
|
89
|
+
Self {
|
|
90
|
+
metrics: Arc::clone(&self.metrics),
|
|
91
|
+
}
|
|
92
|
+
}
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
// ---- Router ----------------------------------------------------------------
|
|
96
|
+
|
|
97
|
+
/// A router that distributes [`LlmRequest`]s across multiple service
|
|
98
|
+
/// instances according to a [`RoutingStrategy`].
|
|
99
|
+
///
|
|
100
|
+
/// The inner deployments must be `Clone` so the router can hand out
|
|
101
|
+
/// independent service handles per call. Use [`LlmService`] as the
|
|
102
|
+
/// deployment type when wrapping a [`crate::client::LlmClient`].
|
|
103
|
+
///
|
|
104
|
+
/// [`LlmService`]: super::service::LlmService
|
|
105
|
+
pub struct Router<S> {
|
|
106
|
+
deployments: Vec<S>,
|
|
107
|
+
strategy: RoutingStrategy,
|
|
108
|
+
/// Monotonically incrementing counter used by [`RoutingStrategy::RoundRobin`].
|
|
109
|
+
counter: Arc<AtomicUsize>,
|
|
110
|
+
/// Per-deployment metrics (latency tracking, etc.).
|
|
111
|
+
state: RouterState,
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
impl<S> Router<S> {
|
|
115
|
+
/// Create a new router.
|
|
116
|
+
///
|
|
117
|
+
/// # Errors
|
|
118
|
+
///
|
|
119
|
+
/// Returns [`LiterLlmError::BadRequest`] if `deployments` is empty — a
|
|
120
|
+
/// router with no deployments cannot handle any request.
|
|
121
|
+
///
|
|
122
|
+
/// For [`RoutingStrategy::WeightedRandom`], returns an error if the
|
|
123
|
+
/// weights vector length does not match the number of deployments or
|
|
124
|
+
/// if all weights are zero.
|
|
125
|
+
pub fn new(deployments: Vec<S>, strategy: RoutingStrategy) -> Result<Self> {
|
|
126
|
+
if deployments.is_empty() {
|
|
127
|
+
return Err(LiterLlmError::BadRequest {
|
|
128
|
+
message: "Router requires at least one deployment".into(),
|
|
129
|
+
});
|
|
130
|
+
}
|
|
131
|
+
if let RoutingStrategy::WeightedRandom { ref weights } = strategy {
|
|
132
|
+
if weights.len() != deployments.len() {
|
|
133
|
+
return Err(LiterLlmError::BadRequest {
|
|
134
|
+
message: format!(
|
|
135
|
+
"WeightedRandom: weights length ({}) must match deployments length ({})",
|
|
136
|
+
weights.len(),
|
|
137
|
+
deployments.len()
|
|
138
|
+
),
|
|
139
|
+
});
|
|
140
|
+
}
|
|
141
|
+
let total: f64 = weights.iter().sum();
|
|
142
|
+
if total <= 0.0 {
|
|
143
|
+
return Err(LiterLlmError::BadRequest {
|
|
144
|
+
message: "WeightedRandom: total weight must be positive".into(),
|
|
145
|
+
});
|
|
146
|
+
}
|
|
147
|
+
}
|
|
148
|
+
Ok(Self {
|
|
149
|
+
deployments,
|
|
150
|
+
strategy,
|
|
151
|
+
counter: Arc::new(AtomicUsize::new(0)),
|
|
152
|
+
state: RouterState::new(),
|
|
153
|
+
})
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
impl<S: Clone> Clone for Router<S> {
|
|
158
|
+
fn clone(&self) -> Self {
|
|
159
|
+
Self {
|
|
160
|
+
deployments: self.deployments.clone(),
|
|
161
|
+
strategy: self.strategy.clone(),
|
|
162
|
+
counter: Arc::clone(&self.counter),
|
|
163
|
+
state: self.state.clone(),
|
|
164
|
+
}
|
|
165
|
+
}
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
impl<S> Service<LlmRequest> for Router<S>
|
|
169
|
+
where
|
|
170
|
+
S: Service<LlmRequest, Response = LlmResponse, Error = LiterLlmError> + Clone + Send + 'static,
|
|
171
|
+
S::Future: Send + 'static,
|
|
172
|
+
{
|
|
173
|
+
type Response = LlmResponse;
|
|
174
|
+
type Error = LiterLlmError;
|
|
175
|
+
type Future = BoxFuture<'static, LlmResponse>;
|
|
176
|
+
|
|
177
|
+
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<()>> {
|
|
178
|
+
// All inner services are cloned per-call, so there is no persistent
|
|
179
|
+
// readied slot to manage here. A more sophisticated implementation
|
|
180
|
+
// could poll each deployment's readiness and track the result, but
|
|
181
|
+
// for DefaultClient (which is always ready) this is unnecessary.
|
|
182
|
+
Poll::Ready(Ok(()))
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
fn call(&mut self, req: LlmRequest) -> Self::Future {
|
|
186
|
+
match &self.strategy {
|
|
187
|
+
RoutingStrategy::RoundRobin => {
|
|
188
|
+
let idx = self.counter.fetch_add(1, Ordering::Relaxed) % self.deployments.len();
|
|
189
|
+
let mut svc = self.deployments[idx].clone();
|
|
190
|
+
Box::pin(async move { svc.call(req).await })
|
|
191
|
+
}
|
|
192
|
+
RoutingStrategy::Fallback => {
|
|
193
|
+
let deployments = self.deployments.clone();
|
|
194
|
+
Box::pin(async move {
|
|
195
|
+
let mut last_err: Option<LiterLlmError> = None;
|
|
196
|
+
for mut svc in deployments {
|
|
197
|
+
match svc.call(req.clone()).await {
|
|
198
|
+
Ok(resp) => return Ok(resp),
|
|
199
|
+
Err(e) if e.is_transient() => {
|
|
200
|
+
tracing::warn!(
|
|
201
|
+
error = %e,
|
|
202
|
+
"deployment failed with transient error; trying next deployment"
|
|
203
|
+
);
|
|
204
|
+
last_err = Some(e);
|
|
205
|
+
}
|
|
206
|
+
Err(e) => return Err(e),
|
|
207
|
+
}
|
|
208
|
+
}
|
|
209
|
+
Err(last_err.unwrap_or(LiterLlmError::ServerError {
|
|
210
|
+
message: "all deployments failed".into(),
|
|
211
|
+
}))
|
|
212
|
+
})
|
|
213
|
+
}
|
|
214
|
+
RoutingStrategy::LatencyBased => {
|
|
215
|
+
let state = self.state.clone();
|
|
216
|
+
let n = self.deployments.len();
|
|
217
|
+
|
|
218
|
+
// Pick deployment with the lowest latency EMA.
|
|
219
|
+
// Deployments with no data default to EMA 0.0 (optimistic).
|
|
220
|
+
let mut best_idx = 0;
|
|
221
|
+
let mut best_ema = f64::MAX;
|
|
222
|
+
for i in 0..n {
|
|
223
|
+
let ema = state.metrics.get(&i).map_or(0.0, |m| m.latency_ema);
|
|
224
|
+
if ema < best_ema {
|
|
225
|
+
best_ema = ema;
|
|
226
|
+
best_idx = i;
|
|
227
|
+
}
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
let mut svc = self.deployments[best_idx].clone();
|
|
231
|
+
let idx = best_idx;
|
|
232
|
+
|
|
233
|
+
Box::pin(async move {
|
|
234
|
+
let start = Instant::now();
|
|
235
|
+
let result = svc.call(req).await;
|
|
236
|
+
let latency = start.elapsed().as_secs_f64();
|
|
237
|
+
|
|
238
|
+
state.metrics.entry(idx).or_default().record_latency(latency);
|
|
239
|
+
|
|
240
|
+
result
|
|
241
|
+
})
|
|
242
|
+
}
|
|
243
|
+
RoutingStrategy::CostBased => {
|
|
244
|
+
let model = req.model().map(ToOwned::to_owned);
|
|
245
|
+
let deployments = self.deployments.clone();
|
|
246
|
+
|
|
247
|
+
// For cost-based routing, we try to pick the cheapest deployment.
|
|
248
|
+
// Since all deployments serve the same model, cost is typically
|
|
249
|
+
// uniform. The differentiator is when deployments wrap different
|
|
250
|
+
// providers (e.g., OpenAI vs Azure) with different pricing.
|
|
251
|
+
//
|
|
252
|
+
// Without per-deployment provider metadata, we use a simple
|
|
253
|
+
// heuristic: try each deployment in order and return the first
|
|
254
|
+
// success. A future enhancement could attach provider metadata
|
|
255
|
+
// to each deployment.
|
|
256
|
+
//
|
|
257
|
+
// For now, CostBased routes identically to Fallback but logs the
|
|
258
|
+
// cost after success.
|
|
259
|
+
Box::pin(async move {
|
|
260
|
+
let mut last_err: Option<LiterLlmError> = None;
|
|
261
|
+
for mut svc in deployments {
|
|
262
|
+
match svc.call(req.clone()).await {
|
|
263
|
+
Ok(resp) => {
|
|
264
|
+
if let (Some(model_name), Some(usage)) = (&model, resp.usage())
|
|
265
|
+
&& let Some(cost) = crate::cost::completion_cost(
|
|
266
|
+
model_name,
|
|
267
|
+
usage.prompt_tokens,
|
|
268
|
+
usage.completion_tokens,
|
|
269
|
+
)
|
|
270
|
+
{
|
|
271
|
+
tracing::debug!(
|
|
272
|
+
model = %model_name,
|
|
273
|
+
cost_usd = cost,
|
|
274
|
+
"cost-based routing: estimated cost"
|
|
275
|
+
);
|
|
276
|
+
}
|
|
277
|
+
return Ok(resp);
|
|
278
|
+
}
|
|
279
|
+
Err(e) if e.is_transient() => {
|
|
280
|
+
last_err = Some(e);
|
|
281
|
+
}
|
|
282
|
+
Err(e) => return Err(e),
|
|
283
|
+
}
|
|
284
|
+
}
|
|
285
|
+
Err(last_err.unwrap_or(LiterLlmError::ServerError {
|
|
286
|
+
message: "all deployments failed".into(),
|
|
287
|
+
}))
|
|
288
|
+
})
|
|
289
|
+
}
|
|
290
|
+
RoutingStrategy::WeightedRandom { weights } => {
|
|
291
|
+
let idx = weighted_random_select(weights);
|
|
292
|
+
let mut svc = self.deployments[idx].clone();
|
|
293
|
+
Box::pin(async move { svc.call(req).await })
|
|
294
|
+
}
|
|
295
|
+
}
|
|
296
|
+
}
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
/// Select a deployment index using weighted random distribution.
///
/// Uses a simple linear scan with a random threshold. For small deployment
/// counts (typical: 2-5) this is fast enough; no binary search needed.
///
/// Only finite, strictly positive weights can be selected. Zero, negative,
/// NaN and infinite weights are skipped, so one bad entry cannot distort the
/// scan. Returns `0` when `weights` is empty or holds no selectable weight;
/// callers must still make sure that index exists in their own collection.
fn weighted_random_select(weights: &[f64]) -> usize {
    let is_eligible = |w: f64| w.is_finite() && w > 0.0;

    // Index to fall back on if floating-point rounding leaves the threshold
    // at or above the final cumulative sum. Using the last *eligible* index
    // means a zero-weight deployment is never returned by accident.
    let last_eligible = match weights.iter().rposition(|&w| is_eligible(w)) {
        Some(idx) => idx,
        None => return 0,
    };
    let total: f64 = weights.iter().copied().filter(|&w| is_eligible(w)).sum();

    // Entropy without a `rand` dependency: each `RandomState` is created with
    // fresh random keys, so hashing nothing with it yields a pseudo-random
    // u64. Unlike wall-clock nanoseconds, this is not correlated with when
    // requests arrive.
    let state = std::collections::hash_map::RandomState::new();
    let hasher = std::hash::BuildHasher::build_hasher(&state);
    let bits = std::hash::Hasher::finish(&hasher);
    // Keep the top 53 bits, which convert to f64 exactly, and scale them
    // into the half-open range [0, 1).
    let unit = (bits >> 11) as f64 / (1u64 << 53) as f64;
    let threshold = unit * total;

    let mut cumulative = 0.0;
    for (i, &w) in weights.iter().enumerate() {
        if !is_eligible(w) {
            continue;
        }
        cumulative += w;
        if threshold < cumulative {
            return i;
        }
    }
    // Rounding edge case (or a total that overflowed to infinity).
    last_eligible
}
|
|
325
|
+
|
|
326
|
+
// ---- Tests -----------------------------------------------------------------
|
|
327
|
+
|
|
328
|
+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tower::service::LlmService;
    use crate::tower::tests_common::{MockClient, chat_req};
    use crate::tower::types::LlmRequest;

    /// Build `count` deployments that always succeed.
    fn healthy_pool(count: usize) -> Vec<LlmService<MockClient>> {
        (0..count).map(|_| LlmService::new(MockClient::ok())).collect()
    }

    /// The chat request every routing test sends.
    fn gpt4_chat() -> LlmRequest {
        LlmRequest::Chat(chat_req("gpt-4"))
    }

    #[tokio::test]
    async fn latency_based_routes_to_fastest() {
        let mut router =
            Router::new(healthy_pool(2), RoutingStrategy::LatencyBased).expect("non-empty deployments");

        // With no samples both deployments have EMA 0.0, so the first call
        // should land on deployment 0 (ties keep the earliest index).
        let first = router.call(gpt4_chat()).await;
        assert!(first.is_ok());

        // Deployment 0 now has a recorded sample, so the second call should
        // prefer deployment 1 (still at 0.0). Only success is asserted here.
        let second = router.call(gpt4_chat()).await;
        assert!(second.is_ok());
    }

    #[tokio::test]
    async fn cost_based_falls_through_on_transient_error() {
        let pool: Vec<LlmService<MockClient>> = vec![
            LlmService::new(MockClient::failing_service_unavailable()),
            LlmService::new(MockClient::ok()),
        ];
        let mut router = Router::new(pool, RoutingStrategy::CostBased).expect("non-empty deployments");

        let outcome = router.call(gpt4_chat()).await;
        assert!(outcome.is_ok(), "should fall through to second deployment");
    }

    #[tokio::test]
    async fn weighted_random_selects_valid_deployment() {
        let strategy = RoutingStrategy::WeightedRandom {
            weights: vec![1.0, 2.0, 3.0],
        };
        let mut router = Router::new(healthy_pool(3), strategy).expect("non-empty deployments");

        // Whichever deployment gets picked, every request must succeed.
        for _ in 0..20 {
            let outcome = router.call(gpt4_chat()).await;
            assert!(outcome.is_ok());
        }
    }

    #[tokio::test]
    async fn weighted_random_rejects_mismatched_weights() {
        // One weight for two deployments: wrong length.
        let strategy = RoutingStrategy::WeightedRandom { weights: vec![1.0] };
        let built = Router::new(healthy_pool(2), strategy);
        assert!(built.is_err());
    }

    #[tokio::test]
    async fn weighted_random_rejects_zero_total_weight() {
        let strategy = RoutingStrategy::WeightedRandom { weights: vec![0.0] };
        let built = Router::new(healthy_pool(1), strategy);
        assert!(built.is_err());
    }

    #[test]
    fn weighted_random_select_returns_valid_index() {
        let weights = vec![1.0, 2.0, 3.0];
        let all_in_range = (0..100).all(|_| weighted_random_select(&weights) < weights.len());
        assert!(all_in_range);
    }

    #[test]
    fn deployment_metrics_ema_updates() {
        let mut metrics = DeploymentMetrics::default();

        metrics.record_latency(1.0);
        let seeded_error = (metrics.latency_ema - 1.0).abs();
        assert!(seeded_error < 1e-9, "first sample should set EMA directly");

        metrics.record_latency(0.0);
        // EMA = 0.3 * 0.0 + 0.7 * 1.0 = 0.7
        let blended_error = (metrics.latency_ema - 0.7).abs();
        assert!(blended_error < 1e-9, "EMA should be 0.7 after second sample");
    }
}
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
use std::collections::VecDeque;
|
|
2
|
+
use std::pin::Pin;
|
|
3
|
+
use std::sync::Arc;
|
|
4
|
+
use std::task::{Context, Poll};
|
|
5
|
+
|
|
6
|
+
use futures_core::Stream;
|
|
7
|
+
use tower::Service;
|
|
8
|
+
|
|
9
|
+
use super::types::{LlmRequest, LlmResponse};
|
|
10
|
+
use crate::client::{BoxFuture, LlmClient};
|
|
11
|
+
use crate::error::{LiterLlmError, Result};
|
|
12
|
+
use crate::types::ChatCompletionChunk;
|
|
13
|
+
|
|
14
|
+
/// Adapter that exposes any [`LlmClient`] implementation as a tower
/// [`Service`].
///
/// [`LlmClient`] methods take `&self`, so the client is kept behind an
/// [`Arc`]. Cloning the service is therefore just a reference-count bump, and
/// although `tower::Service::call` takes `&mut self`, the async work itself
/// runs through the shared reference inside the arc.
///
/// # Streaming behaviour
///
/// **Important:** Streaming responses (`ChatStream`) are **fully buffered** in
/// memory before being yielded to the caller. Tower's `Service` trait
/// requires `'static` futures, and the borrowed stream returned by
/// [`LlmClient::chat_stream`] cannot outlive the `call` future without unsafe
/// lifetime extension. Instead, every chunk is collected into a `VecDeque`
/// and replayed through a `BoxStream<'static, ...>`.
///
/// If you need incremental, unbuffered streaming, use [`LlmClient`] directly
/// instead of wrapping it in `LlmService`.
pub struct LlmService<C> {
    inner: Arc<C>,
}
|
|
35
|
+
|
|
36
|
+
impl<C> LlmService<C> {
|
|
37
|
+
/// Wrap `client` in a tower-compatible service.
|
|
38
|
+
#[must_use]
|
|
39
|
+
pub fn new(client: C) -> Self {
|
|
40
|
+
Self {
|
|
41
|
+
inner: Arc::new(client),
|
|
42
|
+
}
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
/// Wrap a client that is already behind an `Arc`.
|
|
46
|
+
///
|
|
47
|
+
/// This avoids a redundant `Arc` layer when the caller (e.g.
|
|
48
|
+
/// [`ManagedClient`](crate::client::managed::ManagedClient)) already
|
|
49
|
+
/// owns an `Arc<C>`.
|
|
50
|
+
#[must_use]
|
|
51
|
+
pub fn new_from_arc(client: Arc<C>) -> Self {
|
|
52
|
+
Self { inner: client }
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
/// Return a reference to the inner client.
|
|
56
|
+
pub fn inner(&self) -> &C {
|
|
57
|
+
&self.inner
|
|
58
|
+
}
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
impl<C> Clone for LlmService<C> {
|
|
62
|
+
fn clone(&self) -> Self {
|
|
63
|
+
Self {
|
|
64
|
+
inner: Arc::clone(&self.inner),
|
|
65
|
+
}
|
|
66
|
+
}
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
impl<C> Service<LlmRequest> for LlmService<C>
|
|
70
|
+
where
|
|
71
|
+
C: LlmClient + Send + Sync + 'static,
|
|
72
|
+
{
|
|
73
|
+
type Response = LlmResponse;
|
|
74
|
+
type Error = LiterLlmError;
|
|
75
|
+
type Future = BoxFuture<'static, LlmResponse>;
|
|
76
|
+
|
|
77
|
+
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<()>> {
|
|
78
|
+
Poll::Ready(Ok(()))
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
fn call(&mut self, req: LlmRequest) -> Self::Future {
|
|
82
|
+
let client = Arc::clone(&self.inner);
|
|
83
|
+
Box::pin(async move {
|
|
84
|
+
match req {
|
|
85
|
+
LlmRequest::Chat(r) => {
|
|
86
|
+
let resp = client.chat(r).await?;
|
|
87
|
+
Ok(LlmResponse::Chat(resp))
|
|
88
|
+
}
|
|
89
|
+
LlmRequest::ChatStream(r) => {
|
|
90
|
+
// Collect the stream into a Vec while the Arc-backed client is
|
|
91
|
+
// alive. This avoids the unsound transmute that would otherwise
|
|
92
|
+
// be needed to extend the stream's borrow lifetime to 'static.
|
|
93
|
+
// The cost is that streaming chunks are buffered before being
|
|
94
|
+
// yielded; this is acceptable because tower middleware cannot
|
|
95
|
+
// express borrowed lifetimes across the Service boundary.
|
|
96
|
+
let stream = client.chat_stream(r).await?;
|
|
97
|
+
let chunks = collect_stream(stream).await?;
|
|
98
|
+
let static_stream: crate::client::BoxStream<'static, ChatCompletionChunk> =
|
|
99
|
+
Box::pin(OwnedChunksStream { chunks });
|
|
100
|
+
Ok(LlmResponse::ChatStream(static_stream))
|
|
101
|
+
}
|
|
102
|
+
LlmRequest::Embed(r) => {
|
|
103
|
+
let resp = client.embed(r).await?;
|
|
104
|
+
Ok(LlmResponse::Embed(resp))
|
|
105
|
+
}
|
|
106
|
+
LlmRequest::ListModels => {
|
|
107
|
+
let resp = client.list_models().await?;
|
|
108
|
+
Ok(LlmResponse::ListModels(resp))
|
|
109
|
+
}
|
|
110
|
+
LlmRequest::ImageGenerate(r) => {
|
|
111
|
+
let resp = client.image_generate(r).await?;
|
|
112
|
+
Ok(LlmResponse::ImageGenerate(resp))
|
|
113
|
+
}
|
|
114
|
+
LlmRequest::Speech(r) => {
|
|
115
|
+
let resp = client.speech(r).await?;
|
|
116
|
+
Ok(LlmResponse::Speech(resp))
|
|
117
|
+
}
|
|
118
|
+
LlmRequest::Transcribe(r) => {
|
|
119
|
+
let resp = client.transcribe(r).await?;
|
|
120
|
+
Ok(LlmResponse::Transcribe(resp))
|
|
121
|
+
}
|
|
122
|
+
LlmRequest::Moderate(r) => {
|
|
123
|
+
let resp = client.moderate(r).await?;
|
|
124
|
+
Ok(LlmResponse::Moderate(resp))
|
|
125
|
+
}
|
|
126
|
+
LlmRequest::Rerank(r) => {
|
|
127
|
+
let resp = client.rerank(r).await?;
|
|
128
|
+
Ok(LlmResponse::Rerank(resp))
|
|
129
|
+
}
|
|
130
|
+
LlmRequest::Search(r) => {
|
|
131
|
+
let resp = client.search(r).await?;
|
|
132
|
+
Ok(LlmResponse::Search(resp))
|
|
133
|
+
}
|
|
134
|
+
LlmRequest::Ocr(r) => {
|
|
135
|
+
let resp = client.ocr(r).await?;
|
|
136
|
+
Ok(LlmResponse::Ocr(resp))
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
})
|
|
140
|
+
}
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
/// Collect all items from a stream into a `VecDeque`, stopping on the first error.
|
|
144
|
+
async fn collect_stream<'a>(
|
|
145
|
+
mut stream: crate::client::BoxStream<'a, ChatCompletionChunk>,
|
|
146
|
+
) -> Result<VecDeque<ChatCompletionChunk>> {
|
|
147
|
+
let mut chunks = VecDeque::new();
|
|
148
|
+
loop {
|
|
149
|
+
// Drive the stream by polling it inside a future::poll_fn.
|
|
150
|
+
let item = std::future::poll_fn(|cx| Pin::as_mut(&mut stream).poll_next(cx)).await;
|
|
151
|
+
match item {
|
|
152
|
+
Some(Ok(chunk)) => chunks.push_back(chunk),
|
|
153
|
+
Some(Err(e)) => return Err(e),
|
|
154
|
+
None => break,
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
Ok(chunks)
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
/// A `Stream` that yields items from an owned `VecDeque` in order.
|
|
161
|
+
///
|
|
162
|
+
/// Uses `pop_front` to avoid cloning — each chunk is moved out of the deque
|
|
163
|
+
/// and ownership is transferred to the caller without any copy.
|
|
164
|
+
///
|
|
165
|
+
/// Used to wrap collected streaming chunks so they can be returned as a
|
|
166
|
+
/// `BoxStream<'static, ...>` without any lifetime dependencies.
|
|
167
|
+
struct OwnedChunksStream {
|
|
168
|
+
chunks: VecDeque<ChatCompletionChunk>,
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
impl Stream for OwnedChunksStream {
|
|
172
|
+
type Item = Result<ChatCompletionChunk>;
|
|
173
|
+
|
|
174
|
+
fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
175
|
+
Poll::Ready(self.chunks.pop_front().map(Ok))
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
fn size_hint(&self) -> (usize, Option<usize>) {
|
|
179
|
+
(self.chunks.len(), Some(self.chunks.len()))
|
|
180
|
+
}
|
|
181
|
+
}
|