liter_llm 1.0.0.pre.rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. checksums.yaml +7 -0
  2. data/README.md +239 -0
  3. data/ext/liter_llm_rb/extconf.rb +65 -0
  4. data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
  5. data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
  6. data/ext/liter_llm_rb/native/Cargo.toml +32 -0
  7. data/ext/liter_llm_rb/native/build.rs +15 -0
  8. data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
  9. data/lib/liter_llm.rb +8 -0
  10. data/sig/liter_llm.rbs +416 -0
  11. data/vendor/Cargo.toml +54 -0
  12. data/vendor/liter-llm/Cargo.toml +92 -0
  13. data/vendor/liter-llm/README.md +252 -0
  14. data/vendor/liter-llm/schemas/pricing.json +40 -0
  15. data/vendor/liter-llm/schemas/providers.json +1662 -0
  16. data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
  17. data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
  18. data/vendor/liter-llm/src/auth/mod.rs +68 -0
  19. data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
  20. data/vendor/liter-llm/src/client/config.rs +351 -0
  21. data/vendor/liter-llm/src/client/managed.rs +622 -0
  22. data/vendor/liter-llm/src/client/mod.rs +864 -0
  23. data/vendor/liter-llm/src/cost.rs +212 -0
  24. data/vendor/liter-llm/src/error.rs +190 -0
  25. data/vendor/liter-llm/src/http/eventstream.rs +860 -0
  26. data/vendor/liter-llm/src/http/mod.rs +12 -0
  27. data/vendor/liter-llm/src/http/request.rs +438 -0
  28. data/vendor/liter-llm/src/http/retry.rs +72 -0
  29. data/vendor/liter-llm/src/http/streaming.rs +289 -0
  30. data/vendor/liter-llm/src/lib.rs +37 -0
  31. data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
  32. data/vendor/liter-llm/src/provider/azure.rs +579 -0
  33. data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
  34. data/vendor/liter-llm/src/provider/cohere.rs +654 -0
  35. data/vendor/liter-llm/src/provider/custom.rs +404 -0
  36. data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
  37. data/vendor/liter-llm/src/provider/mistral.rs +188 -0
  38. data/vendor/liter-llm/src/provider/mod.rs +616 -0
  39. data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
  40. data/vendor/liter-llm/src/tests.rs +1425 -0
  41. data/vendor/liter-llm/src/tokenizer.rs +281 -0
  42. data/vendor/liter-llm/src/tower/budget.rs +599 -0
  43. data/vendor/liter-llm/src/tower/cache.rs +502 -0
  44. data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
  45. data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
  46. data/vendor/liter-llm/src/tower/cost.rs +404 -0
  47. data/vendor/liter-llm/src/tower/fallback.rs +121 -0
  48. data/vendor/liter-llm/src/tower/health.rs +219 -0
  49. data/vendor/liter-llm/src/tower/hooks.rs +369 -0
  50. data/vendor/liter-llm/src/tower/mod.rs +77 -0
  51. data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
  52. data/vendor/liter-llm/src/tower/router.rs +436 -0
  53. data/vendor/liter-llm/src/tower/service.rs +181 -0
  54. data/vendor/liter-llm/src/tower/tests.rs +539 -0
  55. data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
  56. data/vendor/liter-llm/src/tower/tracing.rs +209 -0
  57. data/vendor/liter-llm/src/tower/types.rs +170 -0
  58. data/vendor/liter-llm/src/types/audio.rs +52 -0
  59. data/vendor/liter-llm/src/types/batch.rs +77 -0
  60. data/vendor/liter-llm/src/types/chat.rs +214 -0
  61. data/vendor/liter-llm/src/types/common.rs +244 -0
  62. data/vendor/liter-llm/src/types/embedding.rs +84 -0
  63. data/vendor/liter-llm/src/types/files.rs +58 -0
  64. data/vendor/liter-llm/src/types/image.rs +40 -0
  65. data/vendor/liter-llm/src/types/mod.rs +27 -0
  66. data/vendor/liter-llm/src/types/models.rs +21 -0
  67. data/vendor/liter-llm/src/types/moderation.rs +80 -0
  68. data/vendor/liter-llm/src/types/ocr.rs +87 -0
  69. data/vendor/liter-llm/src/types/rerank.rs +46 -0
  70. data/vendor/liter-llm/src/types/responses.rs +55 -0
  71. data/vendor/liter-llm/src/types/search.rs +45 -0
  72. data/vendor/liter-llm/tests/contract.rs +332 -0
  73. data/vendor/liter-llm-ffi/Cargo.toml +30 -0
  74. data/vendor/liter-llm-ffi/build.rs +66 -0
  75. data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
  76. data/vendor/liter-llm-ffi/liter_llm.h +850 -0
  77. data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
  78. metadata +286 -0
@@ -0,0 +1,219 @@
1
+ //! Health check middleware.
2
+ //!
3
+ //! [`HealthCheckLayer`] wraps a service and spawns a background task that
4
+ //! periodically probes the service by sending a [`LlmRequest::ListModels`]
5
+ //! request. If the probe fails, the service is marked unhealthy and incoming
6
+ //! requests are immediately rejected with [`LiterLlmError::ServiceUnavailable`].
7
+ //!
8
+ //! The health flag is an [`AtomicBool`] shared between the background probe
9
+ //! task and the request path, so checking health adds minimal overhead (a
10
+ //! single atomic load).
11
+
12
+ use std::sync::Arc;
13
+ use std::sync::atomic::{AtomicBool, Ordering};
14
+ use std::task::{Context, Poll};
15
+ use std::time::Duration;
16
+
17
+ use tower::{Layer, Service};
18
+
19
+ use super::types::{LlmRequest, LlmResponse};
20
+ use crate::client::BoxFuture;
21
+ use crate::error::{LiterLlmError, Result};
22
+
23
+ // ---- Layer -----------------------------------------------------------------
24
+
25
/// Tower [`Layer`] that monitors service health via periodic probes.
///
/// The background health-check task is spawned when the layer wraps a service
/// (i.e. when [`Layer::layer`] is called). The task exits once the
/// [`HealthCheckService`] and all of its clones have been dropped; it only
/// notices this after its current sleep finishes, so shutdown may lag behind
/// the drop.
#[derive(Debug, Clone)]
pub struct HealthCheckLayer {
    /// Pause between the end of one probe and the start of the next.
    interval: Duration,
}

impl HealthCheckLayer {
    /// Create a new health-check layer that probes every `interval`.
    ///
    /// The first probe runs one `interval` after the layer wraps a service;
    /// until then the wrapped service is treated as healthy. Note that an
    /// `interval` of [`Duration::ZERO`] makes probes run effectively
    /// back-to-back against the provider.
    #[must_use]
    pub fn new(interval: Duration) -> Self {
        Self { interval }
    }
}
41
+
42
+ impl<S> Layer<S> for HealthCheckLayer
43
+ where
44
+ S: Service<LlmRequest, Response = LlmResponse, Error = LiterLlmError> + Clone + Send + 'static,
45
+ S::Future: Send + 'static,
46
+ {
47
+ type Service = HealthCheckService<S>;
48
+
49
+ fn layer(&self, inner: S) -> Self::Service {
50
+ let healthy = Arc::new(AtomicBool::new(true));
51
+
52
+ // Spawn the background probe task.
53
+ let probe_svc = inner.clone();
54
+ let probe_healthy = Arc::clone(&healthy);
55
+ let interval = self.interval;
56
+
57
+ tokio::spawn(async move {
58
+ run_health_probe(probe_svc, probe_healthy, interval).await;
59
+ });
60
+
61
+ HealthCheckService { inner, healthy }
62
+ }
63
+ }
64
+
65
+ // ---- Background probe ------------------------------------------------------
66
+
67
+ async fn run_health_probe<S>(mut svc: S, healthy: Arc<AtomicBool>, interval: Duration)
68
+ where
69
+ S: Service<LlmRequest, Response = LlmResponse, Error = LiterLlmError> + Send + 'static,
70
+ S::Future: Send + 'static,
71
+ {
72
+ loop {
73
+ tokio::time::sleep(interval).await;
74
+
75
+ // If the Arc is held only by us, all service clones have been dropped
76
+ // and we should stop probing.
77
+ if Arc::strong_count(&healthy) <= 1 {
78
+ break;
79
+ }
80
+
81
+ let result = svc.call(LlmRequest::ListModels).await;
82
+ let is_healthy = result.is_ok();
83
+ healthy.store(is_healthy, Ordering::Release);
84
+
85
+ if !is_healthy {
86
+ tracing::warn!("health check failed; marking service as unhealthy");
87
+ }
88
+ }
89
+ }
90
+
91
+ // ---- Service ---------------------------------------------------------------
92
+
93
/// Tower service produced by [`HealthCheckLayer`].
///
/// Clones share one health flag, so every clone observes the outcome of the
/// same background probe.
#[derive(Clone)]
pub struct HealthCheckService<S> {
    /// Wrapped service that receives requests while the flag reads healthy.
    inner: S,
    /// Flag written by the background probe and read on the request path.
    healthy: Arc<AtomicBool>,
}

impl<S> HealthCheckService<S> {
    /// Reports whether the most recent health probe succeeded.
    ///
    /// When the service was built by [`HealthCheckLayer`], this stays `true`
    /// until the first probe has completed.
    #[must_use]
    pub fn is_healthy(&self) -> bool {
        let flag: &AtomicBool = &self.healthy;
        flag.load(Ordering::Acquire)
    }
}
115
+
116
+ impl<S> Service<LlmRequest> for HealthCheckService<S>
117
+ where
118
+ S: Service<LlmRequest, Response = LlmResponse, Error = LiterLlmError> + Send + 'static,
119
+ S::Future: Send + 'static,
120
+ {
121
+ type Response = LlmResponse;
122
+ type Error = LiterLlmError;
123
+ type Future = BoxFuture<'static, LlmResponse>;
124
+
125
+ fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
126
+ if !self.healthy.load(Ordering::Acquire) {
127
+ return Poll::Ready(Err(LiterLlmError::ServiceUnavailable {
128
+ message: "service is unhealthy (health check failed)".into(),
129
+ }));
130
+ }
131
+ self.inner.poll_ready(cx)
132
+ }
133
+
134
+ fn call(&mut self, req: LlmRequest) -> Self::Future {
135
+ if !self.healthy.load(Ordering::Acquire) {
136
+ return Box::pin(async {
137
+ Err(LiterLlmError::ServiceUnavailable {
138
+ message: "service is unhealthy (health check failed)".into(),
139
+ })
140
+ });
141
+ }
142
+ let fut = self.inner.call(req);
143
+ Box::pin(fut)
144
+ }
145
+ }
146
+
147
+ // ---- Tests -----------------------------------------------------------------
148
+
149
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering;

    use tower::Service as _;

    use super::*;
    use crate::tower::service::LlmService;
    use crate::tower::tests_common::{MockClient, chat_req};
    use crate::tower::types::LlmRequest;

    /// A healthy flag lets the request reach the inner mock client.
    #[tokio::test]
    async fn healthy_service_passes_through() {
        let flag = Arc::new(AtomicBool::new(true));
        let mut service = HealthCheckService {
            inner: LlmService::new(MockClient::ok()),
            healthy: Arc::clone(&flag),
        };

        let outcome = service.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        assert!(outcome.is_ok());
    }

    /// An unhealthy flag turns every call into `ServiceUnavailable`.
    #[tokio::test]
    async fn unhealthy_service_rejects_requests() {
        let flag = Arc::new(AtomicBool::new(false));
        let mut service = HealthCheckService {
            inner: LlmService::new(MockClient::ok()),
            healthy: Arc::clone(&flag),
        };

        let outcome = service.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        let rejection = outcome.expect_err("unhealthy service should reject");
        assert!(matches!(rejection, LiterLlmError::ServiceUnavailable { .. }));
    }

    /// `is_healthy` mirrors whatever is currently stored in the shared flag.
    #[tokio::test]
    async fn is_healthy_reflects_flag() {
        let flag = Arc::new(AtomicBool::new(true));
        let service = HealthCheckService {
            inner: LlmService::new(MockClient::ok()),
            healthy: Arc::clone(&flag),
        };

        assert!(service.is_healthy());
        flag.store(false, Ordering::Release);
        assert!(!service.is_healthy());
    }

    /// Flipping the flag back to healthy lets requests through again.
    #[tokio::test]
    async fn recovery_after_becoming_healthy_again() {
        let flag = Arc::new(AtomicBool::new(false));
        let mut service = HealthCheckService {
            inner: LlmService::new(MockClient::ok()),
            healthy: Arc::clone(&flag),
        };

        // Unhealthy — should reject.
        let first = service.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        assert!(first.is_err());

        // Mark as healthy again.
        flag.store(true, Ordering::Release);
        let second = service.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        assert!(second.is_ok());
    }
}
@@ -0,0 +1,369 @@
1
+ //! Tower middleware that invokes user-defined hooks before and after requests.
2
+ //!
3
+ //! [`HooksLayer`] wraps any [`Service<LlmRequest>`] and calls registered
4
+ //! [`LlmHook`] implementations at three lifecycle points:
5
+ //!
6
+ //! - **`on_request`** — before the request is forwarded to the inner service.
7
+ //! Returning `Err` from any hook short-circuits the chain (guardrail
8
+ //! rejection).
9
+ //! - **`on_response`** — after a successful response from the inner service.
10
+ //! - **`on_error`** — when the inner service returns an error.
11
+ //!
12
+ //! Hooks are invoked sequentially in registration order.
13
+ //!
14
+ //! # Example
15
+ //!
16
+ //! ```rust,ignore
17
+ //! use std::sync::Arc;
18
+ //! use liter_llm::tower::{HooksLayer, LlmHook, LlmService, TracingLayer};
19
+ //! use tower::ServiceBuilder;
20
+ //!
21
+ //! let hook: Arc<dyn LlmHook> = Arc::new(MyAuditHook);
22
+ //! let service = ServiceBuilder::new()
23
+ //! .layer(HooksLayer::single(hook))
24
+ //! .service(LlmService::new(client));
25
+ //! ```
26
+
27
+ use std::future::Future;
28
+ use std::panic::AssertUnwindSafe;
29
+ use std::pin::Pin;
30
+ use std::sync::Arc;
31
+ use std::task::{Context, Poll};
32
+
33
+ use futures_util::FutureExt as _;
34
+ use tower::Layer;
35
+ use tower::Service;
36
+
37
+ use super::types::{LlmRequest, LlmResponse};
38
+ use crate::client::BoxFuture;
39
+ use crate::error::{LiterLlmError, Result};
40
+
41
+ // ─── Hook Trait ──────────────────────────────────────────────────────────────
42
+
43
+ /// Callback trait for observing and guarding LLM requests.
44
+ ///
45
+ /// All methods have default no-op implementations, so consumers only need to
46
+ /// override the lifecycle points they care about.
47
+ pub trait LlmHook: Send + Sync + 'static {
48
+ /// Called before the request is sent to the inner service.
49
+ ///
50
+ /// Return `Err` to short-circuit the entire service chain — this enables
51
+ /// guardrail patterns such as content filtering or budget enforcement.
52
+ fn on_request(&self, _req: &LlmRequest) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
53
+ Box::pin(async { Ok(()) })
54
+ }
55
+
56
+ /// Called after the inner service returns a successful response.
57
+ fn on_response(&self, _req: &LlmRequest, _resp: &LlmResponse) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
58
+ Box::pin(async {})
59
+ }
60
+
61
+ /// Called when the inner service returns an error.
62
+ fn on_error(&self, _req: &LlmRequest, _err: &LiterLlmError) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
63
+ Box::pin(async {})
64
+ }
65
+ }
66
+
67
+ // ─── Layer ───────────────────────────────────────────────────────────────────
68
+
69
+ /// Tower [`Layer`] that attaches [`LlmHook`] callbacks to a service.
70
+ ///
71
+ /// Hooks are stored behind `Arc` so that the layer and all services it
72
+ /// produces share the same hook instances without cloning them.
73
+ #[derive(Clone)]
74
+ pub struct HooksLayer {
75
+ hooks: Arc<Vec<Arc<dyn LlmHook>>>,
76
+ }
77
+
78
+ impl HooksLayer {
79
+ /// Create a new layer with the given list of hooks.
80
+ ///
81
+ /// Hooks are invoked sequentially in the order they appear in the vector.
82
+ #[must_use]
83
+ pub fn new(hooks: Vec<Arc<dyn LlmHook>>) -> Self {
84
+ Self { hooks: Arc::new(hooks) }
85
+ }
86
+
87
+ /// Convenience constructor for a single hook.
88
+ #[must_use]
89
+ pub fn single(hook: Arc<dyn LlmHook>) -> Self {
90
+ Self::new(vec![hook])
91
+ }
92
+ }
93
+
94
+ impl<S> Layer<S> for HooksLayer {
95
+ type Service = HooksService<S>;
96
+
97
+ fn layer(&self, inner: S) -> Self::Service {
98
+ HooksService {
99
+ inner,
100
+ hooks: Arc::clone(&self.hooks),
101
+ }
102
+ }
103
+ }
104
+
105
+ // ─── Service ─────────────────────────────────────────────────────────────────
106
+
107
+ /// Tower service produced by [`HooksLayer`].
108
+ pub struct HooksService<S> {
109
+ inner: S,
110
+ hooks: Arc<Vec<Arc<dyn LlmHook>>>,
111
+ }
112
+
113
+ impl<S: Clone> Clone for HooksService<S> {
114
+ fn clone(&self) -> Self {
115
+ Self {
116
+ inner: self.inner.clone(),
117
+ hooks: Arc::clone(&self.hooks),
118
+ }
119
+ }
120
+ }
121
+
122
+ impl<S> Service<LlmRequest> for HooksService<S>
123
+ where
124
+ S: Service<LlmRequest, Response = LlmResponse, Error = LiterLlmError> + Send + 'static,
125
+ S::Future: Send + 'static,
126
+ {
127
+ type Response = LlmResponse;
128
+ type Error = LiterLlmError;
129
+ type Future = BoxFuture<'static, LlmResponse>;
130
+
131
+ fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
132
+ self.inner.poll_ready(cx)
133
+ }
134
+
135
+ fn call(&mut self, req: LlmRequest) -> Self::Future {
136
+ let hooks = Arc::clone(&self.hooks);
137
+ // Clone the request so we can pass it to post-hooks after the inner
138
+ // service consumes the original.
139
+ let req_clone = req.clone();
140
+ let fut = self.inner.call(req);
141
+
142
+ Box::pin(async move {
143
+ // Pre-hooks: run sequentially; short-circuit on first Err or panic.
144
+ for hook in hooks.iter() {
145
+ let result = AssertUnwindSafe(hook.on_request(&req_clone)).catch_unwind().await;
146
+ match result {
147
+ Ok(Ok(())) => {}
148
+ Ok(Err(e)) => return Err(e),
149
+ Err(_panic) => {
150
+ tracing::error!("hook panicked during on_request");
151
+ return Err(LiterLlmError::HookRejected {
152
+ message: "hook panicked".into(),
153
+ });
154
+ }
155
+ }
156
+ }
157
+
158
+ match fut.await {
159
+ Ok(resp) => {
160
+ // Post-hooks (success path) — panics are logged but do not
161
+ // propagate so the caller still receives the response.
162
+ for hook in hooks.iter() {
163
+ if AssertUnwindSafe(hook.on_response(&req_clone, &resp))
164
+ .catch_unwind()
165
+ .await
166
+ .is_err()
167
+ {
168
+ tracing::error!("hook panicked during on_response");
169
+ }
170
+ }
171
+ Ok(resp)
172
+ }
173
+ Err(err) => {
174
+ // Post-hooks (error path) — panics are logged but do not
175
+ // replace the original error.
176
+ for hook in hooks.iter() {
177
+ if AssertUnwindSafe(hook.on_error(&req_clone, &err))
178
+ .catch_unwind()
179
+ .await
180
+ .is_err()
181
+ {
182
+ tracing::error!("hook panicked during on_error");
183
+ }
184
+ }
185
+ Err(err)
186
+ }
187
+ }
188
+ })
189
+ }
190
+ }
191
+
192
+ // ─── Tests ───────────────────────────────────────────────────────────────────
193
+
194
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use tower::Layer as _;
    use tower::Service as _;

    use super::*;
    use crate::tower::service::LlmService;
    use crate::tower::tests_common::{MockClient, chat_req};
    use crate::tower::types::{LlmRequest, LlmResponse};

    // ── Test hook implementations ────────────────────────────────────────────

    /// Tallies every lifecycle callback it receives.
    #[derive(Default)]
    struct CountingHook {
        on_request_count: AtomicUsize,
        on_response_count: AtomicUsize,
        on_error_count: AtomicUsize,
    }

    impl CountingHook {
        fn new() -> Self {
            Self::default()
        }
    }

    impl LlmHook for CountingHook {
        fn on_request(&self, _req: &LlmRequest) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
            self.on_request_count.fetch_add(1, Ordering::SeqCst);
            Box::pin(async { Ok(()) })
        }

        fn on_response(&self, _req: &LlmRequest, _resp: &LlmResponse) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
            self.on_response_count.fetch_add(1, Ordering::SeqCst);
            Box::pin(async {})
        }

        fn on_error(&self, _req: &LlmRequest, _err: &LiterLlmError) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
            self.on_error_count.fetch_add(1, Ordering::SeqCst);
            Box::pin(async {})
        }
    }

    /// Guardrail stand-in that refuses every request.
    struct RejectAllHook;

    impl LlmHook for RejectAllHook {
        fn on_request(&self, _req: &LlmRequest) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
            let rejection = LiterLlmError::HookRejected {
                message: "rejected by guardrail".into(),
            };
            Box::pin(async move { Err(rejection) })
        }
    }

    /// Appends an entry to a shared log from each callback it overrides.
    struct OrderTrackingHook {
        id: usize,
        order: Arc<std::sync::Mutex<Vec<usize>>>,
    }

    impl OrderTrackingHook {
        fn record(&self, entry: usize) {
            self.order.lock().expect("lock poisoned").push(entry);
        }
    }

    impl LlmHook for OrderTrackingHook {
        fn on_request(&self, _req: &LlmRequest) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
            self.record(self.id);
            Box::pin(async { Ok(()) })
        }

        fn on_response(&self, _req: &LlmRequest, _resp: &LlmResponse) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
            self.record(self.id + 100);
            Box::pin(async {})
        }
    }

    // ── Tests ────────────────────────────────────────────────────────────────

    #[tokio::test]
    async fn on_request_hook_is_called() {
        let counter = Arc::new(CountingHook::new());
        let mut service =
            HooksLayer::single(Arc::clone(&counter) as Arc<dyn LlmHook>).layer(LlmService::new(MockClient::ok()));

        let outcome = service.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        let _response = outcome.expect("should succeed");

        assert_eq!(counter.on_request_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn on_response_hook_is_called_on_success() {
        let counter = Arc::new(CountingHook::new());
        let mut service =
            HooksLayer::single(Arc::clone(&counter) as Arc<dyn LlmHook>).layer(LlmService::new(MockClient::ok()));

        let outcome = service.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        let _response = outcome.expect("should succeed");

        assert_eq!(counter.on_response_count.load(Ordering::SeqCst), 1);
        assert_eq!(counter.on_error_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn on_error_hook_is_called_on_failure() {
        let counter = Arc::new(CountingHook::new());
        let mut service = HooksLayer::single(Arc::clone(&counter) as Arc<dyn LlmHook>)
            .layer(LlmService::new(MockClient::failing_timeout()));

        let outcome = service.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        let failure = outcome.expect_err("should fail");

        assert!(matches!(failure, LiterLlmError::Timeout));
        assert_eq!(counter.on_error_count.load(Ordering::SeqCst), 1);
        assert_eq!(counter.on_response_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn guardrail_rejection_short_circuits_inner_service() {
        let mock = MockClient::ok();
        let inner_calls = Arc::clone(&mock.call_count);
        let mut service = HooksLayer::single(Arc::new(RejectAllHook) as Arc<dyn LlmHook>).layer(LlmService::new(mock));

        let outcome = service.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        let failure = outcome.expect_err("should be rejected by guardrail");

        assert!(matches!(failure, LiterLlmError::HookRejected { .. }));
        // The inner service must NOT have been called.
        assert_eq!(inner_calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn multiple_hooks_called_in_registration_order() {
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));

        let hooks: Vec<Arc<dyn LlmHook>> = (1..=3)
            .map(|id| {
                Arc::new(OrderTrackingHook {
                    id,
                    order: Arc::clone(&order),
                }) as Arc<dyn LlmHook>
            })
            .collect();

        let mut service = HooksLayer::new(hooks).layer(LlmService::new(MockClient::ok()));

        let outcome = service.call(LlmRequest::Chat(chat_req("gpt-4"))).await;
        let _response = outcome.expect("should succeed");

        let recorded = order.lock().expect("lock poisoned").clone();
        // Pre-hooks: 1, 2, 3 then post-hooks: 101, 102, 103
        assert_eq!(recorded, vec![1, 2, 3, 101, 102, 103]);
    }
}
@@ -0,0 +1,77 @@
1
+ //! Tower middleware integration for [`crate::client::LlmClient`].
2
+ //!
3
+ //! This module is only compiled when the `tower` feature is enabled. It
4
+ //! provides:
5
+ //!
6
+ //! - [`types::LlmRequest`] / [`types::LlmResponse`] — the request/response
7
+ //! enums that cross the tower `Service` boundary.
8
+ //! - [`service::LlmService`] — a thin `tower::Service` wrapper around any
9
+ //! [`crate::client::LlmClient`].
10
+ //! - [`tracing::TracingLayer`] / [`tracing::TracingService`] — OTEL-compatible
11
+ //! tracing middleware.
12
+ //! - [`fallback::FallbackLayer`] / [`fallback::FallbackService`] — route to a
13
+ //! backup service on transient errors.
14
+ //! - [`cost::CostTrackingLayer`] / [`cost::CostTrackingService`] — emit
15
+ //! `gen_ai.usage.cost` tracing span attribute from embedded pricing data.
16
+ //! - [`rate_limit::ModelRateLimitLayer`] / [`rate_limit::ModelRateLimitService`]
17
+ //! — per-model RPM / TPM rate limiting.
18
+ //! - [`cache::CacheLayer`] / [`cache::CacheService`] — in-memory response
19
+ //! caching for non-streaming requests.
20
+ //! - [`cooldown::CooldownLayer`] / [`cooldown::CooldownService`] — deployment
21
+ //! cooldowns after transient errors.
22
+ //! - [`health::HealthCheckLayer`] / [`health::HealthCheckService`] — periodic
23
+ //! health probes with automatic request rejection on failure.
24
+ //! - [`budget::BudgetLayer`] / [`budget::BudgetService`] — global and per-model
25
+ //! spending budget enforcement (hard reject or soft warn).
26
+ //! - [`hooks::HooksLayer`] / [`hooks::HooksService`] — user-defined pre/post
27
+ //! request hooks for guardrails, logging, and auditing.
28
+ //!
29
+ //! # Example
30
+ //!
31
+ //! ```rust,ignore
32
+ //! use liter_llm::tower::{CostTrackingLayer, LlmService, TracingLayer};
33
+ //! use tower::ServiceBuilder;
34
+ //!
35
+ //! let client = liter_llm::DefaultClient::new(config, None)?;
36
+ //! let service = ServiceBuilder::new()
37
+ //! .layer(TracingLayer)
38
+ //! .layer(CostTrackingLayer)
39
+ //! .service(LlmService::new(client));
40
+ //! ```
41
+
42
+ pub mod budget;
43
+ pub mod cache;
44
+ #[cfg(feature = "opendal-cache")]
45
+ pub mod cache_opendal;
46
+ pub mod cooldown;
47
+ pub mod cost;
48
+ pub mod fallback;
49
+ pub mod health;
50
+ pub mod hooks;
51
+ pub mod rate_limit;
52
+ pub mod router;
53
+ pub mod service;
54
+ #[cfg(test)]
55
+ mod tests;
56
+ #[cfg(test)]
57
+ pub(crate) mod tests_common;
58
+ pub mod tracing;
59
+ pub mod types;
60
+
61
+ // Re-export tower core types for convenient access
62
+ pub use tower::ServiceExt;
63
+
64
+ pub use budget::{BudgetConfig, BudgetLayer, BudgetService, BudgetState, Enforcement};
65
+ pub use cache::{CacheBackend, CacheConfig, CacheLayer, CacheService, CacheStore, CachedResponse, InMemoryStore};
66
+ #[cfg(feature = "opendal-cache")]
67
+ pub use cache_opendal::OpenDalCacheStore;
68
+ pub use cooldown::{CooldownLayer, CooldownService};
69
+ pub use cost::{CostTrackingLayer, CostTrackingService};
70
+ pub use fallback::{FallbackLayer, FallbackService};
71
+ pub use health::{HealthCheckLayer, HealthCheckService};
72
+ pub use hooks::{HooksLayer, HooksService, LlmHook};
73
+ pub use rate_limit::{ModelRateLimitLayer, ModelRateLimitService, RateLimitConfig};
74
+ pub use router::{Router, RoutingStrategy};
75
+ pub use service::LlmService;
76
+ pub use tracing::{TracingLayer, TracingService};
77
+ pub use types::{LlmRequest, LlmResponse};