liter_llm 1.0.0.pre.rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. checksums.yaml +7 -0
  2. data/README.md +239 -0
  3. data/ext/liter_llm_rb/extconf.rb +65 -0
  4. data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
  5. data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
  6. data/ext/liter_llm_rb/native/Cargo.toml +32 -0
  7. data/ext/liter_llm_rb/native/build.rs +15 -0
  8. data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
  9. data/lib/liter_llm.rb +8 -0
  10. data/sig/liter_llm.rbs +416 -0
  11. data/vendor/Cargo.toml +54 -0
  12. data/vendor/liter-llm/Cargo.toml +92 -0
  13. data/vendor/liter-llm/README.md +252 -0
  14. data/vendor/liter-llm/schemas/pricing.json +40 -0
  15. data/vendor/liter-llm/schemas/providers.json +1662 -0
  16. data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
  17. data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
  18. data/vendor/liter-llm/src/auth/mod.rs +68 -0
  19. data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
  20. data/vendor/liter-llm/src/client/config.rs +351 -0
  21. data/vendor/liter-llm/src/client/managed.rs +622 -0
  22. data/vendor/liter-llm/src/client/mod.rs +864 -0
  23. data/vendor/liter-llm/src/cost.rs +212 -0
  24. data/vendor/liter-llm/src/error.rs +190 -0
  25. data/vendor/liter-llm/src/http/eventstream.rs +860 -0
  26. data/vendor/liter-llm/src/http/mod.rs +12 -0
  27. data/vendor/liter-llm/src/http/request.rs +438 -0
  28. data/vendor/liter-llm/src/http/retry.rs +72 -0
  29. data/vendor/liter-llm/src/http/streaming.rs +289 -0
  30. data/vendor/liter-llm/src/lib.rs +37 -0
  31. data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
  32. data/vendor/liter-llm/src/provider/azure.rs +579 -0
  33. data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
  34. data/vendor/liter-llm/src/provider/cohere.rs +654 -0
  35. data/vendor/liter-llm/src/provider/custom.rs +404 -0
  36. data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
  37. data/vendor/liter-llm/src/provider/mistral.rs +188 -0
  38. data/vendor/liter-llm/src/provider/mod.rs +616 -0
  39. data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
  40. data/vendor/liter-llm/src/tests.rs +1425 -0
  41. data/vendor/liter-llm/src/tokenizer.rs +281 -0
  42. data/vendor/liter-llm/src/tower/budget.rs +599 -0
  43. data/vendor/liter-llm/src/tower/cache.rs +502 -0
  44. data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
  45. data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
  46. data/vendor/liter-llm/src/tower/cost.rs +404 -0
  47. data/vendor/liter-llm/src/tower/fallback.rs +121 -0
  48. data/vendor/liter-llm/src/tower/health.rs +219 -0
  49. data/vendor/liter-llm/src/tower/hooks.rs +369 -0
  50. data/vendor/liter-llm/src/tower/mod.rs +77 -0
  51. data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
  52. data/vendor/liter-llm/src/tower/router.rs +436 -0
  53. data/vendor/liter-llm/src/tower/service.rs +181 -0
  54. data/vendor/liter-llm/src/tower/tests.rs +539 -0
  55. data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
  56. data/vendor/liter-llm/src/tower/tracing.rs +209 -0
  57. data/vendor/liter-llm/src/tower/types.rs +170 -0
  58. data/vendor/liter-llm/src/types/audio.rs +52 -0
  59. data/vendor/liter-llm/src/types/batch.rs +77 -0
  60. data/vendor/liter-llm/src/types/chat.rs +214 -0
  61. data/vendor/liter-llm/src/types/common.rs +244 -0
  62. data/vendor/liter-llm/src/types/embedding.rs +84 -0
  63. data/vendor/liter-llm/src/types/files.rs +58 -0
  64. data/vendor/liter-llm/src/types/image.rs +40 -0
  65. data/vendor/liter-llm/src/types/mod.rs +27 -0
  66. data/vendor/liter-llm/src/types/models.rs +21 -0
  67. data/vendor/liter-llm/src/types/moderation.rs +80 -0
  68. data/vendor/liter-llm/src/types/ocr.rs +87 -0
  69. data/vendor/liter-llm/src/types/rerank.rs +46 -0
  70. data/vendor/liter-llm/src/types/responses.rs +55 -0
  71. data/vendor/liter-llm/src/types/search.rs +45 -0
  72. data/vendor/liter-llm/tests/contract.rs +332 -0
  73. data/vendor/liter-llm-ffi/Cargo.toml +30 -0
  74. data/vendor/liter-llm-ffi/build.rs +66 -0
  75. data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
  76. data/vendor/liter-llm-ffi/liter_llm.h +850 -0
  77. data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
  78. metadata +286 -0
@@ -0,0 +1,860 @@
1
+ //! AWS EventStream binary protocol parser for Bedrock ConverseStream.
2
+ //!
3
+ //! The EventStream protocol frames each message as:
4
+ //!
5
+ //! ```text
6
+ //! [total_length:4][headers_length:4][prelude_crc:4][headers:N][payload:M][message_crc:4]
7
+ //! ```
8
+ //!
9
+ //! Where lengths are big-endian `u32` values. Each header is:
10
+ //!
11
+ //! ```text
12
+ //! [name_len:1][name:name_len][type:1][value_len:2][value:value_len]
13
+ //! ```
14
+ //!
15
+ //! This implementation focuses on correctness and zero-copy where possible.
16
+ //! CRC validation is performed to detect corrupted frames.
17
+
18
+ use std::pin::Pin;
19
+ use std::task::{Context, Poll};
20
+
21
+ use bytes::{Bytes, BytesMut};
22
+ use futures_core::Stream;
23
+ use pin_project_lite::pin_project;
24
+
25
+ use crate::error::{LiterLlmError, Result};
26
+ use crate::http::request::with_retry;
27
+ use crate::types::ChatCompletionChunk;
28
+
29
+ /// Minimum frame size: prelude (12) + message CRC (4) = 16 bytes.
30
+ const MIN_FRAME_SIZE: usize = 16;
31
+
32
+ /// Maximum frame size to prevent unbounded buffering (16 MiB).
33
+ const MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
34
+
35
+ /// Header value type for UTF-8 strings.
36
+ const HEADER_TYPE_STRING: u8 = 7;
37
+
38
+ // ---------------------------------------------------------------------------
39
+ // Public entry point
40
+ // ---------------------------------------------------------------------------
41
+
42
+ /// Send a streaming POST request and return a stream of `ChatCompletionChunk`s
43
+ /// parsed from AWS EventStream binary frames.
44
+ ///
45
+ /// The `parse_event` function receives `(event_type, payload_json)` for each
46
+ /// event and returns a parsed chunk or `None` for terminal events.
47
+ #[cfg_attr(
48
+ feature = "tracing",
49
+ tracing::instrument(
50
+ skip_all,
51
+ fields(
52
+ http.method = "POST",
53
+ http.url = %url,
54
+ http.status_code = tracing::field::Empty,
55
+ http.retry_count = tracing::field::Empty,
56
+ )
57
+ )
58
+ )]
59
+ pub async fn post_eventstream<P>(
60
+ client: &reqwest::Client,
61
+ url: &str,
62
+ auth_header: Option<(&str, &str)>,
63
+ extra_headers: &[(&str, &str)],
64
+ body: Bytes,
65
+ max_retries: u32,
66
+ parse_event: P,
67
+ ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>>
68
+ where
69
+ P: Fn(&str, &str) -> Result<Option<ChatCompletionChunk>> + Send + 'static,
70
+ {
71
+ let mut retry_count = 0u32;
72
+
73
+ let resp = with_retry(max_retries, || {
74
+ // Clone is a zero-copy ref-count bump on `Bytes`.
75
+ let mut builder = client
76
+ .post(url)
77
+ .header(reqwest::header::CONTENT_TYPE, "application/json")
78
+ .body(body.clone());
79
+ if let Some((name, value)) = auth_header {
80
+ builder = builder.header(name, value);
81
+ }
82
+ for (name, value) in extra_headers {
83
+ builder = builder.header(*name, *value);
84
+ }
85
+ retry_count += 1;
86
+ builder.send()
87
+ })
88
+ .await?;
89
+
90
+ #[cfg(feature = "tracing")]
91
+ {
92
+ let span = tracing::Span::current();
93
+ span.record("http.status_code", resp.status().as_u16());
94
+ span.record("http.retry_count", retry_count.saturating_sub(1));
95
+ }
96
+
97
+ let byte_stream = resp.bytes_stream();
98
+ let stream = EventStreamParser::new(byte_stream, parse_event);
99
+ Ok(Box::pin(stream))
100
+ }
101
+
102
+ // ---------------------------------------------------------------------------
103
+ // EventStream frame parser
104
+ // ---------------------------------------------------------------------------
105
+
106
/// One string-typed header decoded from an EventStream frame.
///
/// Both fields are owned copies of the UTF-8 bytes found in the frame's
/// header section.
struct EventHeader {
    /// Header name, e.g. `:event-type` or `:message-type`.
    name: String,
    /// Header value as UTF-8 text.
    value: String,
}
111
+
112
+ /// Parse headers from the header section of an EventStream frame.
113
+ ///
114
+ /// Only extracts string-typed headers (type 7); other types are skipped.
115
+ fn parse_headers(mut data: &[u8]) -> Result<Vec<EventHeader>> {
116
+ let mut headers = Vec::new();
117
+ while !data.is_empty() {
118
+ let name_len = data[0] as usize;
119
+ data = &data[1..];
120
+ if data.len() < name_len {
121
+ return Err(LiterLlmError::Streaming {
122
+ message: "EventStream header name truncated".into(),
123
+ });
124
+ }
125
+ let name = std::str::from_utf8(&data[..name_len])
126
+ .map_err(|_| LiterLlmError::Streaming {
127
+ message: "EventStream header name is not UTF-8".into(),
128
+ })?
129
+ .to_owned();
130
+ data = &data[name_len..];
131
+
132
+ if data.is_empty() {
133
+ return Err(LiterLlmError::Streaming {
134
+ message: "EventStream header type byte missing".into(),
135
+ });
136
+ }
137
+ let value_type = data[0];
138
+ data = &data[1..];
139
+
140
+ if value_type == HEADER_TYPE_STRING {
141
+ // String: 2-byte big-endian length + UTF-8 value.
142
+ if data.len() < 2 {
143
+ return Err(LiterLlmError::Streaming {
144
+ message: "EventStream string header length truncated".into(),
145
+ });
146
+ }
147
+ let value_len = u16::from_be_bytes([data[0], data[1]]) as usize;
148
+ data = &data[2..];
149
+ if data.len() < value_len {
150
+ return Err(LiterLlmError::Streaming {
151
+ message: "EventStream string header value truncated".into(),
152
+ });
153
+ }
154
+ let value = std::str::from_utf8(&data[..value_len])
155
+ .map_err(|_| LiterLlmError::Streaming {
156
+ message: "EventStream header value is not UTF-8".into(),
157
+ })?
158
+ .to_owned();
159
+ data = &data[value_len..];
160
+ headers.push(EventHeader { name, value });
161
+ } else {
162
+ // Skip non-string header types based on their wire sizes.
163
+ // AWS EventStream spec: bool types have no value bytes (the type
164
+ // byte itself encodes true/false).
165
+ let skip = match value_type {
166
+ 0 => 0, // bool_true — no value bytes
167
+ 1 => 0, // bool_false — no value bytes
168
+ 2 => 1, // byte
169
+ 3 => 2, // short
170
+ 4 => 4, // int
171
+ 5 => 8, // long
172
+ 6 => {
173
+ // bytes: 2-byte length prefix
174
+ if data.len() < 2 {
175
+ return Err(LiterLlmError::Streaming {
176
+ message: "EventStream bytes header length truncated".into(),
177
+ });
178
+ }
179
+ let len = u16::from_be_bytes([data[0], data[1]]) as usize;
180
+ 2 + len
181
+ }
182
+ 8 => 8, // timestamp
183
+ 9 => 16, // uuid
184
+ _ => {
185
+ return Err(LiterLlmError::Streaming {
186
+ message: format!("unknown EventStream header type: {value_type}"),
187
+ });
188
+ }
189
+ };
190
+ if data.len() < skip {
191
+ return Err(LiterLlmError::Streaming {
192
+ message: "EventStream header value data truncated".into(),
193
+ });
194
+ }
195
+ data = &data[skip..];
196
+ }
197
+ }
198
+ Ok(headers)
199
+ }
200
+
201
/// Compute the CRC32C (Castagnoli) checksum of `data`.
///
/// This is the table-driven, byte-at-a-time form using the reflected
/// polynomial `0x82F63B78`, an all-ones initial value and a final bit
/// inversion. It is the checksum used for both the prelude CRC and the
/// message CRC of an EventStream frame.
fn crc32c(data: &[u8]) -> u32 {
    /// Reflected Castagnoli polynomial.
    const POLY: u32 = 0x82F6_3B78;

    /// Build the 256-entry lookup table; evaluated at compile time.
    const fn build_table() -> [u32; 256] {
        let mut table = [0u32; 256];
        let mut byte = 0usize;
        while byte < 256 {
            let mut value = byte as u32;
            let mut bit = 0;
            while bit < 8 {
                value = if value & 1 == 0 {
                    value >> 1
                } else {
                    (value >> 1) ^ POLY
                };
                bit += 1;
            }
            table[byte] = value;
            byte += 1;
        }
        table
    }

    static TABLE: [u32; 256] = build_table();

    // Low byte of the running CRC xor the input byte selects the table entry.
    let folded = data.iter().fold(u32::MAX, |crc, &byte| {
        TABLE[usize::from((crc as u8) ^ byte)] ^ (crc >> 8)
    });
    !folded
}
233
+
234
+ // ---------------------------------------------------------------------------
235
+ // EventStream Stream adapter
236
+ // ---------------------------------------------------------------------------
237
+
238
+ pin_project! {
239
+ /// Wraps a `bytes::Bytes` stream and yields `ChatCompletionChunk`s parsed
240
+ /// from AWS EventStream binary frames.
241
+ ///
242
+ /// The `P` type parameter is the parse function that receives
243
+ /// `(event_type, payload_json)` for each event.
244
+ struct EventStreamParser<S, P> {
245
+ #[pin]
246
+ inner: S,
247
+ buffer: BytesMut,
248
+ done: bool,
249
+ parse_event: P,
250
+ }
251
+ }
252
+
253
+ impl<S, P> EventStreamParser<S, P> {
254
+ fn new(inner: S, parse_event: P) -> Self {
255
+ Self {
256
+ inner,
257
+ buffer: BytesMut::new(),
258
+ done: false,
259
+ parse_event,
260
+ }
261
+ }
262
+ }
263
+
264
+ impl<S, P> Stream for EventStreamParser<S, P>
265
+ where
266
+ S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send,
267
+ P: Fn(&str, &str) -> Result<Option<ChatCompletionChunk>>,
268
+ {
269
+ type Item = Result<ChatCompletionChunk>;
270
+
271
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
272
+ let mut this = self.project();
273
+
274
+ loop {
275
+ // --- Try to parse a complete frame from the buffer ---
276
+ if this.buffer.len() >= MIN_FRAME_SIZE {
277
+ let total_length =
278
+ u32::from_be_bytes([this.buffer[0], this.buffer[1], this.buffer[2], this.buffer[3]]) as usize;
279
+
280
+ // Sanity check on frame size.
281
+ if !(MIN_FRAME_SIZE..=MAX_FRAME_SIZE).contains(&total_length) {
282
+ return Poll::Ready(Some(Err(LiterLlmError::Streaming {
283
+ message: format!(
284
+ "EventStream frame size {total_length} is out of range [{MIN_FRAME_SIZE}, {MAX_FRAME_SIZE}]"
285
+ ),
286
+ })));
287
+ }
288
+
289
+ // Wait for the full frame to be buffered.
290
+ if this.buffer.len() < total_length {
291
+ // Need more data — fall through to read more bytes.
292
+ } else {
293
+ // We have a complete frame. Extract and advance.
294
+ let frame = this.buffer.split_to(total_length);
295
+
296
+ let headers_length = u32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]) as usize;
297
+
298
+ // Validate prelude CRC (covers first 8 bytes).
299
+ let prelude_crc_expected = u32::from_be_bytes([frame[8], frame[9], frame[10], frame[11]]);
300
+ let prelude_crc_actual = crc32c(&frame[..8]);
301
+ if prelude_crc_expected != prelude_crc_actual {
302
+ return Poll::Ready(Some(Err(LiterLlmError::Streaming {
303
+ message: format!(
304
+ "EventStream prelude CRC mismatch: expected {prelude_crc_expected:#010X}, got {prelude_crc_actual:#010X}"
305
+ ),
306
+ })));
307
+ }
308
+
309
+ // Validate message CRC (covers everything except the last 4 bytes).
310
+ let message_crc_expected = u32::from_be_bytes([
311
+ frame[total_length - 4],
312
+ frame[total_length - 3],
313
+ frame[total_length - 2],
314
+ frame[total_length - 1],
315
+ ]);
316
+ let message_crc_actual = crc32c(&frame[..total_length - 4]);
317
+ if message_crc_expected != message_crc_actual {
318
+ return Poll::Ready(Some(Err(LiterLlmError::Streaming {
319
+ message: format!(
320
+ "EventStream message CRC mismatch: expected {message_crc_expected:#010X}, got {message_crc_actual:#010X}"
321
+ ),
322
+ })));
323
+ }
324
+
325
+ // Parse headers.
326
+ let headers_start = 12; // after prelude (total_len + headers_len + prelude_crc)
327
+ let headers_end = headers_start + headers_length;
328
+ if headers_end > total_length - 4 {
329
+ return Poll::Ready(Some(Err(LiterLlmError::Streaming {
330
+ message: "EventStream headers extend past frame boundary".into(),
331
+ })));
332
+ }
333
+
334
+ let headers = match parse_headers(&frame[headers_start..headers_end]) {
335
+ Ok(h) => h,
336
+ Err(e) => return Poll::Ready(Some(Err(e))),
337
+ };
338
+
339
+ // Extract event-type and message-type headers.
340
+ let mut event_type = "";
341
+ let mut message_type = "";
342
+ for h in &headers {
343
+ match h.name.as_str() {
344
+ ":event-type" => event_type = &h.value,
345
+ ":message-type" => message_type = &h.value,
346
+ _ => {}
347
+ }
348
+ }
349
+
350
+ // Handle exceptions (error events from the service).
351
+ if message_type == "exception" {
352
+ let payload = &frame[headers_end..total_length - 4];
353
+ let payload_str = std::str::from_utf8(payload).unwrap_or("<binary>");
354
+ return Poll::Ready(Some(Err(LiterLlmError::Streaming {
355
+ message: format!("Bedrock EventStream exception ({event_type}): {payload_str}"),
356
+ })));
357
+ }
358
+
359
+ // Only process "event" messages.
360
+ if message_type != "event" {
361
+ continue;
362
+ }
363
+
364
+ // Extract payload.
365
+ let payload = &frame[headers_end..total_length - 4];
366
+ let payload_str = match std::str::from_utf8(payload) {
367
+ Ok(s) => s,
368
+ Err(e) => {
369
+ return Poll::Ready(Some(Err(LiterLlmError::Streaming {
370
+ message: format!("EventStream payload is not UTF-8: {e}"),
371
+ })));
372
+ }
373
+ };
374
+
375
+ // Delegate to the provider-supplied parser.
376
+ match (this.parse_event)(event_type, payload_str) {
377
+ Ok(None) => {
378
+ // Terminal event (e.g. messageStop) — stream complete.
379
+ // But don't end yet; there may be a metadata event after.
380
+ // Only truly end when the inner stream is done.
381
+ continue;
382
+ }
383
+ Ok(Some(chunk)) => return Poll::Ready(Some(Ok(chunk))),
384
+ Err(e) => return Poll::Ready(Some(Err(e))),
385
+ }
386
+ }
387
+ }
388
+
389
+ // --- Need more bytes ---
390
+ if *this.done {
391
+ // If the buffer still has bytes, the stream ended mid-frame.
392
+ if !this.buffer.is_empty() {
393
+ let leftover = this.buffer.len();
394
+ this.buffer.clear();
395
+ return Poll::Ready(Some(Err(LiterLlmError::Streaming {
396
+ message: format!("EventStream ended with {leftover} bytes of incomplete frame data"),
397
+ })));
398
+ }
399
+ return Poll::Ready(None);
400
+ }
401
+
402
+ match this.inner.as_mut().poll_next(cx) {
403
+ Poll::Ready(Some(Ok(bytes))) => {
404
+ if this.buffer.len() + bytes.len() > MAX_FRAME_SIZE {
405
+ *this.done = true;
406
+ return Poll::Ready(Some(Err(LiterLlmError::Streaming {
407
+ message: format!("EventStream buffer exceeded {MAX_FRAME_SIZE} bytes"),
408
+ })));
409
+ }
410
+ this.buffer.extend_from_slice(&bytes);
411
+ }
412
+ Poll::Ready(Some(Err(e))) => {
413
+ return Poll::Ready(Some(Err(LiterLlmError::from(e))));
414
+ }
415
+ Poll::Ready(None) => {
416
+ *this.done = true;
417
+ if this.buffer.is_empty() {
418
+ return Poll::Ready(None);
419
+ }
420
+ // Loop once more to try parsing any buffered frames.
421
+ continue;
422
+ }
423
+ Poll::Pending => {
424
+ return Poll::Pending;
425
+ }
426
+ }
427
+ }
428
+ }
429
+ }
430
+
431
+ // ---------------------------------------------------------------------------
432
+ // Tests
433
+ // ---------------------------------------------------------------------------
434
+
435
+ #[cfg(test)]
436
+ mod tests {
437
+ use super::*;
438
+
439
+ /// Build a valid EventStream frame from headers and payload.
440
+ fn build_frame(headers: &[(&str, &str)], payload: &[u8]) -> Vec<u8> {
441
+ // Encode headers.
442
+ let mut header_bytes = Vec::new();
443
+ for (name, value) in headers {
444
+ header_bytes.push(name.len() as u8);
445
+ header_bytes.extend_from_slice(name.as_bytes());
446
+ header_bytes.push(HEADER_TYPE_STRING); // type 7 = string
447
+ let value_bytes = value.as_bytes();
448
+ header_bytes.extend_from_slice(&(value_bytes.len() as u16).to_be_bytes());
449
+ header_bytes.extend_from_slice(value_bytes);
450
+ }
451
+
452
+ let headers_length = header_bytes.len() as u32;
453
+ let total_length = 12 + header_bytes.len() + payload.len() + 4;
454
+
455
+ let mut frame = Vec::with_capacity(total_length);
456
+
457
+ // Prelude: total_length + headers_length.
458
+ frame.extend_from_slice(&(total_length as u32).to_be_bytes());
459
+ frame.extend_from_slice(&headers_length.to_be_bytes());
460
+
461
+ // Prelude CRC.
462
+ let prelude_crc = crc32c(&frame[..8]);
463
+ frame.extend_from_slice(&prelude_crc.to_be_bytes());
464
+
465
+ // Headers + payload.
466
+ frame.extend_from_slice(&header_bytes);
467
+ frame.extend_from_slice(payload);
468
+
469
+ // Message CRC.
470
+ let message_crc = crc32c(&frame);
471
+ frame.extend_from_slice(&message_crc.to_be_bytes());
472
+
473
+ frame
474
+ }
475
+
476
+ #[test]
477
+ fn crc32c_known_value() {
478
+ // CRC32C of empty string is 0x00000000.
479
+ assert_eq!(crc32c(b""), 0x0000_0000);
480
+ // CRC32C of "123456789" is 0xE3069283.
481
+ assert_eq!(crc32c(b"123456789"), 0xE306_9283);
482
+ }
483
+
484
+ #[test]
485
+ fn parse_headers_basic() {
486
+ let headers_data = {
487
+ let mut buf = Vec::new();
488
+ // Header: ":event-type" = "contentBlockDelta"
489
+ let name = b":event-type";
490
+ buf.push(name.len() as u8);
491
+ buf.extend_from_slice(name);
492
+ buf.push(HEADER_TYPE_STRING);
493
+ let value = b"contentBlockDelta";
494
+ buf.extend_from_slice(&(value.len() as u16).to_be_bytes());
495
+ buf.extend_from_slice(value);
496
+ buf
497
+ };
498
+
499
+ let headers = parse_headers(&headers_data).unwrap();
500
+ assert_eq!(headers.len(), 1);
501
+ assert_eq!(headers[0].name, ":event-type");
502
+ assert_eq!(headers[0].value, "contentBlockDelta");
503
+ }
504
+
505
+ #[test]
506
+ fn build_and_parse_frame() {
507
+ let payload = br#"{"delta":{"text":"hello"}}"#;
508
+ let frame = build_frame(
509
+ &[
510
+ (":message-type", "event"),
511
+ (":event-type", "contentBlockDelta"),
512
+ (":content-type", "application/json"),
513
+ ],
514
+ payload,
515
+ );
516
+
517
+ // Verify frame is parseable: check total_length.
518
+ let total_length = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]) as usize;
519
+ assert_eq!(total_length, frame.len());
520
+
521
+ // Verify CRCs.
522
+ let prelude_crc_stored = u32::from_be_bytes([frame[8], frame[9], frame[10], frame[11]]);
523
+ assert_eq!(crc32c(&frame[..8]), prelude_crc_stored);
524
+
525
+ let message_crc_stored = u32::from_be_bytes([
526
+ frame[total_length - 4],
527
+ frame[total_length - 3],
528
+ frame[total_length - 2],
529
+ frame[total_length - 1],
530
+ ]);
531
+ assert_eq!(crc32c(&frame[..total_length - 4]), message_crc_stored);
532
+ }
533
+
534
+ #[tokio::test]
535
+ async fn eventstream_parser_yields_chunks() {
536
+ use super::once_future_stream::once_future;
537
+ use std::pin::pin;
538
+ use std::task::{Context, Poll};
539
+
540
+ // Build a few EventStream frames.
541
+ let frame1 = build_frame(
542
+ &[(":message-type", "event"), (":event-type", "contentBlockDelta")],
543
+ br#"{"contentBlockIndex":0,"delta":{"text":"Hello"}}"#,
544
+ );
545
+ let frame2 = build_frame(
546
+ &[(":message-type", "event"), (":event-type", "messageStop")],
547
+ br#"{"stopReason":"end_turn"}"#,
548
+ );
549
+
550
+ // Concatenate frames into a single bytes chunk.
551
+ let mut all_bytes = Vec::new();
552
+ all_bytes.extend_from_slice(&frame1);
553
+ all_bytes.extend_from_slice(&frame2);
554
+
555
+ // Create a simple stream that yields the bytes in one chunk.
556
+ let byte_stream = once_future(async { Ok::<_, reqwest::Error>(bytes::Bytes::from(all_bytes)) });
557
+
558
+ // Parse function that creates a chunk for text deltas and returns None for stop.
559
+ let parse = |event_type: &str, payload: &str| -> Result<Option<ChatCompletionChunk>> {
560
+ match event_type {
561
+ "contentBlockDelta" => {
562
+ let v: serde_json::Value = serde_json::from_str(payload)
563
+ .map_err(|e| LiterLlmError::Streaming { message: e.to_string() })?;
564
+ let text = v.pointer("/delta/text").and_then(|t| t.as_str()).unwrap_or("");
565
+ // Build a minimal ChatCompletionChunk.
566
+ let chunk_json = serde_json::json!({
567
+ "id": "test",
568
+ "object": "chat.completion.chunk",
569
+ "created": 0,
570
+ "model": "test",
571
+ "choices": [{
572
+ "index": 0,
573
+ "delta": {"content": text},
574
+ "finish_reason": null
575
+ }]
576
+ });
577
+ let chunk: ChatCompletionChunk = serde_json::from_value(chunk_json)
578
+ .map_err(|e| LiterLlmError::Streaming { message: e.to_string() })?;
579
+ Ok(Some(chunk))
580
+ }
581
+ "messageStop" => Ok(None),
582
+ _ => Ok(None),
583
+ }
584
+ };
585
+
586
+ let parser = EventStreamParser::new(byte_stream, parse);
587
+ let mut pinned = pin!(parser);
588
+
589
+ // Poll once — should get the text delta chunk.
590
+ let waker = std::task::Waker::noop();
591
+ let mut cx = Context::from_waker(waker);
592
+
593
+ match pinned.as_mut().poll_next(&mut cx) {
594
+ Poll::Ready(Some(Ok(chunk))) => {
595
+ assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello"));
596
+ }
597
+ other => panic!("expected Ready(Some(Ok(chunk))), got {other:?}"),
598
+ }
599
+ }
600
+
601
+ #[test]
602
+ fn exception_frame_yields_error() {
603
+ let frame = build_frame(
604
+ &[(":message-type", "exception"), (":event-type", "validationException")],
605
+ br#"{"message":"Invalid request"}"#,
606
+ );
607
+
608
+ let headers_length = u32::from_be_bytes([frame[4], frame[5], frame[6], frame[7]]) as usize;
609
+ let headers_start = 12;
610
+ let headers_end = headers_start + headers_length;
611
+ let total_length = frame.len();
612
+
613
+ let headers = parse_headers(&frame[headers_start..headers_end]).unwrap();
614
+ let message_type = headers.iter().find(|h| h.name == ":message-type").unwrap();
615
+ assert_eq!(message_type.value, "exception");
616
+
617
+ let payload = std::str::from_utf8(&frame[headers_end..total_length - 4]).unwrap();
618
+ assert!(payload.contains("Invalid request"));
619
+ }
620
+
621
/// A frame whose stored prelude CRC is damaged must be rejected by the
/// parser, not merely differ from a hand-computed checksum.
#[test]
fn corrupt_prelude_crc_detected() {
    use super::vec_stream::VecStream;
    use std::pin::pin;
    use std::task::{Context, Poll};

    let mut frame = build_frame(
        &[(":message-type", "event"), (":event-type", "messageStop")],
        br#"{"stopReason":"end_turn"}"#,
    );
    // Corrupt byte 9 (part of prelude CRC).
    frame[9] ^= 0xFF;

    // The length prefix is untouched, so the parser will still cut a full frame.
    let total_length = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]) as usize;
    assert_eq!(total_length, frame.len());

    let prelude_crc_stored = u32::from_be_bytes([frame[8], frame[9], frame[10], frame[11]]);
    let prelude_crc_actual = crc32c(&frame[..8]);
    assert_ne!(prelude_crc_stored, prelude_crc_actual);

    // Drive the real parser over the damaged frame.
    let byte_stream = VecStream::new(vec![Ok(bytes::Bytes::from(frame))]);
    let parse = |_: &str, _: &str| -> Result<Option<ChatCompletionChunk>> { Ok(None) };

    let parser = EventStreamParser::new(byte_stream, parse);
    let mut pinned = pin!(parser);

    let waker = std::task::Waker::noop();
    let mut cx = Context::from_waker(waker);

    match pinned.as_mut().poll_next(&mut cx) {
        Poll::Ready(Some(Err(e))) => {
            let msg = e.to_string();
            assert!(msg.contains("prelude CRC mismatch"), "got: {msg}");
        }
        other => panic!("expected prelude CRC error, got {other:?}"),
    }
}
637
+
638
/// A frame with an intact prelude but a damaged body must be rejected by the
/// parser with a message-CRC error.
#[test]
fn corrupt_message_crc_detected() {
    use super::vec_stream::VecStream;
    use std::pin::pin;
    use std::task::{Context, Poll};

    let mut frame = build_frame(
        &[(":message-type", "event"), (":event-type", "messageStop")],
        br#"{"stopReason":"end_turn"}"#,
    );
    let len = frame.len();
    // Flip a byte in the middle of the frame. For this frame that index falls
    // after the 12-byte prelude and before the trailing CRC (in the header
    // section), so only the message CRC is invalidated.
    assert!(len / 2 >= 12 && len / 2 < len - 4);
    frame[len / 2] ^= 0xFF;

    let message_crc_stored = u32::from_be_bytes([frame[len - 4], frame[len - 3], frame[len - 2], frame[len - 1]]);
    let message_crc_actual = crc32c(&frame[..len - 4]);
    assert_ne!(message_crc_stored, message_crc_actual);

    // Drive the real parser over the damaged frame.
    let byte_stream = VecStream::new(vec![Ok(bytes::Bytes::from(frame))]);
    let parse = |_: &str, _: &str| -> Result<Option<ChatCompletionChunk>> { Ok(None) };

    let parser = EventStreamParser::new(byte_stream, parse);
    let mut pinned = pin!(parser);

    let waker = std::task::Waker::noop();
    let mut cx = Context::from_waker(waker);

    match pinned.as_mut().poll_next(&mut cx) {
        Poll::Ready(Some(Err(e))) => {
            let msg = e.to_string();
            assert!(msg.contains("message CRC mismatch"), "got: {msg}");
        }
        other => panic!("expected message CRC error, got {other:?}"),
    }
}
652
+
653
+ #[test]
654
+ fn empty_payload_frame() {
655
+ let frame = build_frame(&[(":message-type", "event"), (":event-type", "contentBlockStop")], b"");
656
+ let total_length = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]) as usize;
657
+ assert_eq!(total_length, frame.len());
658
+ }
659
+
660
+ #[tokio::test]
661
+ async fn parser_handles_split_frames() {
662
+ use super::vec_stream::VecStream;
663
+ use std::pin::pin;
664
+ use std::task::{Context, Poll};
665
+
666
+ let frame = build_frame(
667
+ &[(":message-type", "event"), (":event-type", "contentBlockDelta")],
668
+ br#"{"contentBlockIndex":0,"delta":{"text":"split"}}"#,
669
+ );
670
+
671
+ // Split the frame in half to simulate TCP segmentation.
672
+ let mid = frame.len() / 2;
673
+ let chunk1 = bytes::Bytes::from(frame[..mid].to_vec());
674
+ let chunk2 = bytes::Bytes::from(frame[mid..].to_vec());
675
+
676
+ let byte_stream = VecStream::new(vec![Ok(chunk1), Ok(chunk2)]);
677
+ let parse = |event_type: &str, payload: &str| -> Result<Option<ChatCompletionChunk>> {
678
+ if event_type == "contentBlockDelta" {
679
+ let v: serde_json::Value = serde_json::from_str(payload).unwrap();
680
+ let text = v.pointer("/delta/text").and_then(|t| t.as_str()).unwrap_or("");
681
+ let chunk: ChatCompletionChunk = serde_json::from_value(serde_json::json!({
682
+ "id": "t", "object": "chat.completion.chunk", "created": 0, "model": "t",
683
+ "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": null}]
684
+ }))
685
+ .unwrap();
686
+ Ok(Some(chunk))
687
+ } else {
688
+ Ok(None)
689
+ }
690
+ };
691
+
692
+ let parser = EventStreamParser::new(byte_stream, parse);
693
+ let mut pinned = pin!(parser);
694
+
695
+ let waker = std::task::Waker::noop();
696
+ let mut cx = Context::from_waker(waker);
697
+
698
+ // First poll may return Pending (first chunk doesn't complete the frame).
699
+ // Keep polling until we get the result.
700
+ let mut result = None;
701
+ for _ in 0..10 {
702
+ match pinned.as_mut().poll_next(&mut cx) {
703
+ Poll::Ready(Some(Ok(chunk))) => {
704
+ result = Some(chunk);
705
+ break;
706
+ }
707
+ Poll::Ready(Some(Err(e))) => panic!("unexpected error: {e}"),
708
+ Poll::Ready(None) => panic!("unexpected stream end"),
709
+ Poll::Pending => continue,
710
+ }
711
+ }
712
+ let chunk = result.expect("should have parsed the split frame");
713
+ assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("split"));
714
+ }
715
+
716
+ #[tokio::test]
717
+ async fn parser_errors_on_truncated_stream() {
718
+ use super::vec_stream::VecStream;
719
+ use std::pin::pin;
720
+ use std::task::{Context, Poll};
721
+
722
+ let frame = build_frame(
723
+ &[(":message-type", "event"), (":event-type", "contentBlockDelta")],
724
+ br#"{"contentBlockIndex":0,"delta":{"text":"truncated"}}"#,
725
+ );
726
+
727
+ // Only deliver the first half — simulate connection drop.
728
+ let partial = bytes::Bytes::from(frame[..frame.len() / 2].to_vec());
729
+ let byte_stream = VecStream::new(vec![Ok(partial)]);
730
+
731
+ let parse = |_: &str, _: &str| -> Result<Option<ChatCompletionChunk>> { Ok(None) };
732
+
733
+ let parser = EventStreamParser::new(byte_stream, parse);
734
+ let mut pinned = pin!(parser);
735
+
736
+ let waker = std::task::Waker::noop();
737
+ let mut cx = Context::from_waker(waker);
738
+
739
+ // Poll until we get an error about truncated data.
740
+ let mut got_error = false;
741
+ for _ in 0..10 {
742
+ match pinned.as_mut().poll_next(&mut cx) {
743
+ Poll::Ready(Some(Err(e))) => {
744
+ let msg = e.to_string();
745
+ assert!(
746
+ msg.contains("incomplete frame"),
747
+ "expected truncation error, got: {msg}"
748
+ );
749
+ got_error = true;
750
+ break;
751
+ }
752
+ Poll::Ready(Some(Ok(_))) => panic!("unexpected success"),
753
+ Poll::Ready(None) => panic!("unexpected clean end"),
754
+ Poll::Pending => continue,
755
+ }
756
+ }
757
+ assert!(got_error, "should have received a truncation error");
758
+ }
759
+
760
+ #[tokio::test]
761
+ async fn parser_exception_frame_through_stream() {
762
+ use super::vec_stream::VecStream;
763
+ use std::pin::pin;
764
+ use std::task::{Context, Poll};
765
+
766
+ let frame = build_frame(
767
+ &[(":message-type", "exception"), (":event-type", "throttlingException")],
768
+ br#"{"message":"Rate exceeded"}"#,
769
+ );
770
+
771
+ let byte_stream = VecStream::new(vec![Ok(bytes::Bytes::from(frame))]);
772
+ let parse = |_: &str, _: &str| -> Result<Option<ChatCompletionChunk>> { Ok(None) };
773
+
774
+ let parser = EventStreamParser::new(byte_stream, parse);
775
+ let mut pinned = pin!(parser);
776
+
777
+ let waker = std::task::Waker::noop();
778
+ let mut cx = Context::from_waker(waker);
779
+
780
+ match pinned.as_mut().poll_next(&mut cx) {
781
+ Poll::Ready(Some(Err(e))) => {
782
+ let msg = e.to_string();
783
+ assert!(msg.contains("throttlingException"), "got: {msg}");
784
+ assert!(msg.contains("Rate exceeded"), "got: {msg}");
785
+ }
786
+ other => panic!("expected error, got {other:?}"),
787
+ }
788
+ }
789
+ }
790
+
791
/// Test helper: a stream that hands out pre-built items in order and is
/// always immediately ready.
#[cfg(test)]
mod vec_stream {
    use std::collections::VecDeque;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures_core::Stream;

    /// Queue-backed stream; ends once the queue is empty.
    pub struct VecStream<T> {
        items: VecDeque<T>,
    }

    impl<T> VecStream<T> {
        /// Create a stream that yields `items` front to back.
        pub fn new(items: Vec<T>) -> Self {
            Self {
                items: VecDeque::from(items),
            }
        }
    }

    impl<T: Unpin> Stream for VecStream<T> {
        type Item = T;

        /// Pop the next queued item; `None` once the queue is drained.
        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Ready(self.items.pop_front())
        }
    }
}
822
+
823
/// Test helper: adapt a single future into a stream that yields the future's
/// output once and then ends.
#[cfg(test)]
mod once_future_stream {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures_core::Stream;

    pin_project_lite::pin_project! {
        pub struct OnceFuture<F> {
            // `Some` until the future has completed, then `None`.
            #[pin]
            future: Option<F>,
        }
    }

    impl<F: Future> Stream for OnceFuture<F> {
        type Item = F::Output;

        /// Poll the wrapped future; after it completes, report end-of-stream.
        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let mut slot = self.project().future;

            // Already completed on an earlier poll.
            let Some(pending) = slot.as_mut().as_pin_mut() else {
                return Poll::Ready(None);
            };

            let output = std::task::ready!(pending.poll(cx));
            // Drop the finished future so it is never polled again.
            slot.set(None);
            Poll::Ready(Some(output))
        }
    }

    /// Wrap `f` in a [`OnceFuture`] stream.
    pub fn once_future<F: Future>(f: F) -> OnceFuture<F> {
        OnceFuture { future: Some(f) }
    }
}