liter_llm 1.0.0.pre.rc.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. checksums.yaml +7 -0
  2. data/README.md +239 -0
  3. data/ext/liter_llm_rb/extconf.rb +65 -0
  4. data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
  5. data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
  6. data/ext/liter_llm_rb/native/Cargo.toml +32 -0
  7. data/ext/liter_llm_rb/native/build.rs +15 -0
  8. data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
  9. data/lib/liter_llm.rb +8 -0
  10. data/sig/liter_llm.rbs +416 -0
  11. data/vendor/Cargo.toml +54 -0
  12. data/vendor/liter-llm/Cargo.toml +92 -0
  13. data/vendor/liter-llm/README.md +252 -0
  14. data/vendor/liter-llm/schemas/pricing.json +40 -0
  15. data/vendor/liter-llm/schemas/providers.json +1662 -0
  16. data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
  17. data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
  18. data/vendor/liter-llm/src/auth/mod.rs +68 -0
  19. data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
  20. data/vendor/liter-llm/src/client/config.rs +351 -0
  21. data/vendor/liter-llm/src/client/managed.rs +622 -0
  22. data/vendor/liter-llm/src/client/mod.rs +864 -0
  23. data/vendor/liter-llm/src/cost.rs +212 -0
  24. data/vendor/liter-llm/src/error.rs +190 -0
  25. data/vendor/liter-llm/src/http/eventstream.rs +860 -0
  26. data/vendor/liter-llm/src/http/mod.rs +12 -0
  27. data/vendor/liter-llm/src/http/request.rs +438 -0
  28. data/vendor/liter-llm/src/http/retry.rs +72 -0
  29. data/vendor/liter-llm/src/http/streaming.rs +289 -0
  30. data/vendor/liter-llm/src/lib.rs +37 -0
  31. data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
  32. data/vendor/liter-llm/src/provider/azure.rs +579 -0
  33. data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
  34. data/vendor/liter-llm/src/provider/cohere.rs +654 -0
  35. data/vendor/liter-llm/src/provider/custom.rs +404 -0
  36. data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
  37. data/vendor/liter-llm/src/provider/mistral.rs +188 -0
  38. data/vendor/liter-llm/src/provider/mod.rs +616 -0
  39. data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
  40. data/vendor/liter-llm/src/tests.rs +1425 -0
  41. data/vendor/liter-llm/src/tokenizer.rs +281 -0
  42. data/vendor/liter-llm/src/tower/budget.rs +599 -0
  43. data/vendor/liter-llm/src/tower/cache.rs +502 -0
  44. data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
  45. data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
  46. data/vendor/liter-llm/src/tower/cost.rs +404 -0
  47. data/vendor/liter-llm/src/tower/fallback.rs +121 -0
  48. data/vendor/liter-llm/src/tower/health.rs +219 -0
  49. data/vendor/liter-llm/src/tower/hooks.rs +369 -0
  50. data/vendor/liter-llm/src/tower/mod.rs +77 -0
  51. data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
  52. data/vendor/liter-llm/src/tower/router.rs +436 -0
  53. data/vendor/liter-llm/src/tower/service.rs +181 -0
  54. data/vendor/liter-llm/src/tower/tests.rs +539 -0
  55. data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
  56. data/vendor/liter-llm/src/tower/tracing.rs +209 -0
  57. data/vendor/liter-llm/src/tower/types.rs +170 -0
  58. data/vendor/liter-llm/src/types/audio.rs +52 -0
  59. data/vendor/liter-llm/src/types/batch.rs +77 -0
  60. data/vendor/liter-llm/src/types/chat.rs +214 -0
  61. data/vendor/liter-llm/src/types/common.rs +244 -0
  62. data/vendor/liter-llm/src/types/embedding.rs +84 -0
  63. data/vendor/liter-llm/src/types/files.rs +58 -0
  64. data/vendor/liter-llm/src/types/image.rs +40 -0
  65. data/vendor/liter-llm/src/types/mod.rs +27 -0
  66. data/vendor/liter-llm/src/types/models.rs +21 -0
  67. data/vendor/liter-llm/src/types/moderation.rs +80 -0
  68. data/vendor/liter-llm/src/types/ocr.rs +87 -0
  69. data/vendor/liter-llm/src/types/rerank.rs +46 -0
  70. data/vendor/liter-llm/src/types/responses.rs +55 -0
  71. data/vendor/liter-llm/src/types/search.rs +45 -0
  72. data/vendor/liter-llm/tests/contract.rs +332 -0
  73. data/vendor/liter-llm-ffi/Cargo.toml +30 -0
  74. data/vendor/liter-llm-ffi/build.rs +66 -0
  75. data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
  76. data/vendor/liter-llm-ffi/liter_llm.h +850 -0
  77. data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
  78. metadata +286 -0
@@ -0,0 +1,654 @@
1
+ use std::borrow::Cow;
2
+
3
+ use serde_json::Value;
4
+
5
+ use crate::error::{LiterLlmError, Result};
6
+ use crate::provider::{Provider, unix_timestamp_secs};
7
+ use crate::types::{ChatCompletionChunk, FinishReason, StreamChoice, StreamDelta, StreamFunctionCall, StreamToolCall};
8
+
9
/// Provider adapter for Cohere's v2 API (the Command model family).
///
/// Where Cohere departs from the OpenAI-compatible baseline:
/// - Chat requests go to `/chat` rather than `/chat/completions`.
/// - Rerank requests go to `/rerank`.
/// - The OpenAI-only `stream_options` field is dropped from request bodies,
///   while `stream` itself is kept because Cohere v2 expects it.
/// - Finish reasons arrive as Cohere-specific upper-case names such as
///   `COMPLETE`, `MAX_TOKENS` and `TOOL_CALL`.
/// - Token counts are reported as `input_tokens` / `output_tokens` rather than
///   `prompt_tokens` / `completion_tokens`.
/// - Responses may omit the `object` and `created` fields.
pub struct CohereProvider;
19
+
20
+ impl Provider for CohereProvider {
21
+ fn name(&self) -> &str {
22
+ "cohere"
23
+ }
24
+
25
+ fn base_url(&self) -> &str {
26
+ "https://api.cohere.com/v2"
27
+ }
28
+
29
+ fn auth_header<'a>(&'a self, api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)> {
30
+ Some((Cow::Borrowed("Authorization"), Cow::Owned(format!("Bearer {api_key}"))))
31
+ }
32
+
33
+ fn matches_model(&self, model: &str) -> bool {
34
+ model.starts_with("command-r") || model.starts_with("command-") || model.starts_with("cohere/")
35
+ }
36
+
37
+ fn strip_model_prefix<'m>(&self, model: &'m str) -> &'m str {
38
+ model.strip_prefix("cohere/").unwrap_or(model)
39
+ }
40
+
41
+ /// Cohere uses `/chat` instead of `/chat/completions`.
42
+ fn chat_completions_path(&self) -> &str {
43
+ "/chat"
44
+ }
45
+
46
+ /// Cohere uses `/rerank` at the v2 base.
47
+ fn rerank_path(&self) -> &str {
48
+ "/rerank"
49
+ }
50
+
51
+ /// Strip transport-level parameters that Cohere does not accept in the body.
52
+ ///
53
+ /// Note: Cohere v2 requires `stream` in the body, so only `stream_options`
54
+ /// (an OpenAI-specific field) is removed.
55
+ fn transform_request(&self, body: &mut Value) -> Result<()> {
56
+ if let Some(obj) = body.as_object_mut() {
57
+ obj.remove("stream_options");
58
+ }
59
+ Ok(())
60
+ }
61
+
62
+ /// Parse a Cohere v2 streaming SSE event into a `ChatCompletionChunk`.
63
+ ///
64
+ /// Cohere v2 streaming events use a `type` field to distinguish event kinds:
65
+ /// - `stream-start`: beginning of stream, emit role = assistant
66
+ /// - `content-delta`: text content token, extract from `delta.text`
67
+ /// - `tool-call-start`: start of a tool call with id and function name
68
+ /// - `tool-call-delta`: partial tool call arguments
69
+ /// - `tool-call-end`: end of a tool call (skipped)
70
+ /// - `stream-end`: end of stream with finish reason and usage
71
+ fn parse_stream_event(&self, event_data: &str) -> Result<Option<ChatCompletionChunk>> {
72
+ let v: Value = serde_json::from_str(event_data).map_err(|e| LiterLlmError::Streaming {
73
+ message: format!("failed to parse Cohere SSE event: {e}"),
74
+ })?;
75
+
76
+ let event_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("");
77
+
78
+ match event_type {
79
+ "stream-start" => {
80
+ let id = v.get("generation_id").and_then(|g| g.as_str()).unwrap_or("").to_owned();
81
+
82
+ Ok(Some(ChatCompletionChunk {
83
+ id,
84
+ object: "chat.completion.chunk".to_owned(),
85
+ created: unix_timestamp_secs(),
86
+ model: String::new(),
87
+ choices: vec![StreamChoice {
88
+ index: 0,
89
+ delta: StreamDelta {
90
+ role: Some("assistant".to_owned()),
91
+ content: None,
92
+ tool_calls: None,
93
+ function_call: None,
94
+ refusal: None,
95
+ },
96
+ finish_reason: None,
97
+ }],
98
+ usage: None,
99
+ system_fingerprint: None,
100
+ service_tier: None,
101
+ }))
102
+ }
103
+
104
+ "content-delta" => {
105
+ let text = v
106
+ .pointer("/delta/text")
107
+ .and_then(|t| t.as_str())
108
+ .unwrap_or("")
109
+ .to_owned();
110
+
111
+ Ok(Some(ChatCompletionChunk {
112
+ id: String::new(),
113
+ object: "chat.completion.chunk".to_owned(),
114
+ created: unix_timestamp_secs(),
115
+ model: String::new(),
116
+ choices: vec![StreamChoice {
117
+ index: 0,
118
+ delta: StreamDelta {
119
+ role: None,
120
+ content: Some(text),
121
+ tool_calls: None,
122
+ function_call: None,
123
+ refusal: None,
124
+ },
125
+ finish_reason: None,
126
+ }],
127
+ usage: None,
128
+ system_fingerprint: None,
129
+ service_tier: None,
130
+ }))
131
+ }
132
+
133
+ "tool-call-start" => {
134
+ let index = v.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as u32;
135
+ let tool_id = v.pointer("/delta/id").and_then(|i| i.as_str()).unwrap_or("").to_owned();
136
+ let tool_name = v
137
+ .pointer("/delta/function/name")
138
+ .and_then(|n| n.as_str())
139
+ .unwrap_or("")
140
+ .to_owned();
141
+
142
+ Ok(Some(ChatCompletionChunk {
143
+ id: String::new(),
144
+ object: "chat.completion.chunk".to_owned(),
145
+ created: unix_timestamp_secs(),
146
+ model: String::new(),
147
+ choices: vec![StreamChoice {
148
+ index: 0,
149
+ delta: StreamDelta {
150
+ role: None,
151
+ content: None,
152
+ tool_calls: Some(vec![StreamToolCall {
153
+ index,
154
+ id: Some(tool_id),
155
+ call_type: Some(crate::types::ToolType::Function),
156
+ function: Some(StreamFunctionCall {
157
+ name: Some(tool_name),
158
+ arguments: None,
159
+ }),
160
+ }]),
161
+ function_call: None,
162
+ refusal: None,
163
+ },
164
+ finish_reason: None,
165
+ }],
166
+ usage: None,
167
+ system_fingerprint: None,
168
+ service_tier: None,
169
+ }))
170
+ }
171
+
172
+ "tool-call-delta" => {
173
+ let index = v.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as u32;
174
+ let arguments = v
175
+ .pointer("/delta/function/arguments")
176
+ .and_then(|a| a.as_str())
177
+ .unwrap_or("")
178
+ .to_owned();
179
+
180
+ Ok(Some(ChatCompletionChunk {
181
+ id: String::new(),
182
+ object: "chat.completion.chunk".to_owned(),
183
+ created: unix_timestamp_secs(),
184
+ model: String::new(),
185
+ choices: vec![StreamChoice {
186
+ index: 0,
187
+ delta: StreamDelta {
188
+ role: None,
189
+ content: None,
190
+ tool_calls: Some(vec![StreamToolCall {
191
+ index,
192
+ id: None,
193
+ call_type: None,
194
+ function: Some(StreamFunctionCall {
195
+ name: None,
196
+ arguments: Some(arguments),
197
+ }),
198
+ }]),
199
+ function_call: None,
200
+ refusal: None,
201
+ },
202
+ finish_reason: None,
203
+ }],
204
+ usage: None,
205
+ system_fingerprint: None,
206
+ service_tier: None,
207
+ }))
208
+ }
209
+
210
+ "tool-call-end" => Ok(None),
211
+
212
+ "stream-end" => {
213
+ let finish_reason = v
214
+ .get("finish_reason")
215
+ .and_then(|r| r.as_str())
216
+ .map(map_cohere_finish_reason);
217
+
218
+ let usage = extract_cohere_stream_usage(&v);
219
+
220
+ Ok(Some(ChatCompletionChunk {
221
+ id: String::new(),
222
+ object: "chat.completion.chunk".to_owned(),
223
+ created: unix_timestamp_secs(),
224
+ model: String::new(),
225
+ choices: vec![StreamChoice {
226
+ index: 0,
227
+ delta: StreamDelta {
228
+ role: None,
229
+ content: None,
230
+ tool_calls: None,
231
+ function_call: None,
232
+ refusal: None,
233
+ },
234
+ finish_reason,
235
+ }],
236
+ usage,
237
+ system_fingerprint: None,
238
+ service_tier: None,
239
+ }))
240
+ }
241
+
242
+ // Unknown event types are silently skipped.
243
+ _ => Ok(None),
244
+ }
245
+ }
246
+
247
+ /// Normalize Cohere response format to OpenAI-compatible JSON.
248
+ ///
249
+ /// - Maps finish reasons: `COMPLETE` -> `stop`, `MAX_TOKENS` -> `length`,
250
+ /// `TOOL_CALL` -> `tool_calls`.
251
+ /// - Normalizes usage from `tokens.{input,output}_tokens` to
252
+ /// `usage.{prompt,completion,total}_tokens`.
253
+ /// - Ensures `object` and `created` fields are present.
254
+ fn transform_response(&self, body: &mut Value) -> Result<()> {
255
+ // Map finish reasons in choices.
256
+ if let Some(choices) = body.get_mut("choices").and_then(Value::as_array_mut) {
257
+ for choice in choices {
258
+ if let Some(reason) = choice.get("finish_reason").and_then(Value::as_str) {
259
+ let mapped = match reason {
260
+ "COMPLETE" => "stop",
261
+ "MAX_TOKENS" => "length",
262
+ "TOOL_CALL" => "tool_calls",
263
+ other => other,
264
+ };
265
+ choice["finish_reason"] = Value::String(mapped.to_owned());
266
+ }
267
+ }
268
+ }
269
+
270
+ // Normalize usage from Cohere's `tokens` format.
271
+ if body.get("usage").is_none()
272
+ && let Some(tokens) = body.get("tokens")
273
+ {
274
+ let input = tokens.get("input_tokens").and_then(Value::as_u64).unwrap_or(0);
275
+ let output = tokens.get("output_tokens").and_then(Value::as_u64).unwrap_or(0);
276
+ body["usage"] = serde_json::json!({
277
+ "prompt_tokens": input,
278
+ "completion_tokens": output,
279
+ "total_tokens": input + output,
280
+ });
281
+ }
282
+
283
+ // Ensure standard OpenAI fields are present.
284
+ if body.get("object").is_none() {
285
+ body["object"] = Value::String("chat.completion".to_owned());
286
+ }
287
+ if body.get("created").is_none() {
288
+ body["created"] = Value::Number(unix_timestamp_secs().into());
289
+ }
290
+
291
+ Ok(())
292
+ }
293
+ }
294
+
295
+ /// Map Cohere finish reason strings to OpenAI-compatible `FinishReason`.
296
+ fn map_cohere_finish_reason(reason: &str) -> FinishReason {
297
+ match reason {
298
+ "COMPLETE" => FinishReason::Stop,
299
+ "MAX_TOKENS" => FinishReason::Length,
300
+ "TOOL_CALL" => FinishReason::ToolCalls,
301
+ _ => FinishReason::Other,
302
+ }
303
+ }
304
+
305
+ /// Extract usage from a Cohere `stream-end` event.
306
+ ///
307
+ /// Cohere v2 reports usage under `usage.billed_units.{input_tokens, output_tokens}`.
308
+ fn extract_cohere_stream_usage(v: &Value) -> Option<crate::types::Usage> {
309
+ let billed = v.pointer("/usage/billed_units")?;
310
+ let input = billed.get("input_tokens").and_then(|t| t.as_u64()).unwrap_or(0);
311
+ let output = billed.get("output_tokens").and_then(|t| t.as_u64()).unwrap_or(0);
312
+
313
+ Some(crate::types::Usage {
314
+ prompt_tokens: input,
315
+ completion_tokens: output,
316
+ total_tokens: input + output,
317
+ })
318
+ }
319
+
320
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    // ── Helpers ──────────────────────────────────────────────────────────────

    /// Parse one SSE payload, failing the test on an error or on a skipped event.
    fn chunk_from(event: &str) -> ChatCompletionChunk {
        CohereProvider
            .parse_stream_event(event)
            .expect("should parse")
            .expect("should return Some")
    }

    /// Run `transform_response` on `body` and hand back the normalized JSON.
    fn normalized(mut body: Value) -> Value {
        CohereProvider
            .transform_response(&mut body)
            .expect("transform should succeed");
        body
    }

    /// Tool calls carried by the first choice of `chunk`.
    fn tool_calls_of(chunk: &ChatCompletionChunk) -> &[StreamToolCall] {
        chunk.choices[0]
            .delta
            .tool_calls
            .as_deref()
            .expect("should have tool_calls")
    }

    /// Assert the `(prompt, completion, total)` token counts reported on `chunk`.
    fn assert_usage(chunk: &ChatCompletionChunk, expected: (u64, u64, u64)) {
        let usage = chunk.usage.as_ref().expect("should have usage");
        assert_eq!(
            (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens),
            expected
        );
    }

    // ── Provider metadata ────────────────────────────────────────────────────

    #[test]
    fn test_cohere_name_and_base_url() {
        let cohere = CohereProvider;
        assert_eq!(cohere.name(), "cohere");
        assert_eq!(cohere.base_url(), "https://api.cohere.com/v2");
    }

    #[test]
    fn test_cohere_auth_header() {
        let cohere = CohereProvider;
        let (header, token) = cohere.auth_header("test-key").expect("should return auth header");
        assert_eq!(header, "Authorization");
        assert_eq!(token, "Bearer test-key");
    }

    #[test]
    fn test_cohere_matches_model() {
        let cohere = CohereProvider;
        for model in ["command-r-plus", "command-r", "command-light", "cohere/command-r-plus"] {
            assert!(cohere.matches_model(model), "{model} should be claimed by Cohere");
        }
        for model in ["gpt-4", "claude-3"] {
            assert!(!cohere.matches_model(model), "{model} should not be claimed by Cohere");
        }
    }

    #[test]
    fn test_cohere_strip_prefix() {
        let cohere = CohereProvider;
        for (input, expected) in [("cohere/command-r", "command-r"), ("command-r", "command-r")] {
            assert_eq!(cohere.strip_model_prefix(input), expected);
        }
    }

    #[test]
    fn test_cohere_endpoints() {
        let cohere = CohereProvider;
        assert_eq!(cohere.chat_completions_path(), "/chat");
        assert_eq!(cohere.rerank_path(), "/rerank");
    }

    // ── Request / response transforms ────────────────────────────────────────

    #[test]
    fn test_cohere_transform_request_preserves_stream_strips_options() {
        let mut body = json!({
            "model": "command-r-plus",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": true,
            "stream_options": {"include_usage": true}
        });
        CohereProvider
            .transform_request(&mut body)
            .expect("transform should succeed");

        // `stream` must survive (Cohere v2 requires it); only `stream_options` goes.
        assert_eq!(body["stream"], true);
        assert!(body.get("stream_options").is_none());
        // Unrelated fields pass through untouched.
        assert_eq!(body["model"], "command-r-plus");
    }

    #[test]
    fn test_cohere_transform_response_finish_reasons() {
        let body = normalized(json!({
            "choices": [
                {"finish_reason": "COMPLETE", "message": {"content": "hi"}},
                {"finish_reason": "MAX_TOKENS", "message": {"content": "..."}},
                {"finish_reason": "TOOL_CALL", "message": {"content": ""}}
            ]
        }));

        let reasons: Vec<&str> = body["choices"]
            .as_array()
            .expect("choices array")
            .iter()
            .map(|choice| choice["finish_reason"].as_str().expect("string finish_reason"))
            .collect();
        assert_eq!(reasons, ["stop", "length", "tool_calls"]);
    }

    #[test]
    fn test_cohere_transform_response_usage_normalization() {
        let body = normalized(json!({
            "choices": [{"finish_reason": "COMPLETE"}],
            "tokens": {"input_tokens": 10, "output_tokens": 20}
        }));

        let usage = &body["usage"];
        assert_eq!(usage["prompt_tokens"], 10);
        assert_eq!(usage["completion_tokens"], 20);
        assert_eq!(usage["total_tokens"], 30);
    }

    #[test]
    fn test_cohere_transform_response_adds_object_and_created() {
        let body = normalized(json!({"choices": []}));

        assert_eq!(body["object"], "chat.completion");
        assert!(body["created"].as_u64().is_some());
    }

    #[test]
    fn test_cohere_transform_response_preserves_existing_usage() {
        let body = normalized(json!({
            "choices": [],
            "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
            "tokens": {"input_tokens": 99, "output_tokens": 99}
        }));

        // An already OpenAI-shaped `usage` wins over Cohere's `tokens`.
        assert_eq!(body["usage"]["prompt_tokens"], 5);
    }

    // ── Streaming SSE parser ─────────────────────────────────────────────────

    #[test]
    fn test_parse_stream_event_stream_start() {
        let chunk = chunk_from(r#"{"type":"stream-start","generation_id":"gen-123"}"#);

        assert_eq!(chunk.id, "gen-123");
        assert_eq!(chunk.object, "chat.completion.chunk");
        assert_eq!(chunk.choices.len(), 1);
        let choice = &chunk.choices[0];
        assert_eq!(choice.delta.role.as_deref(), Some("assistant"));
        assert!(choice.delta.content.is_none());
        assert!(choice.finish_reason.is_none());
        assert!(chunk.usage.is_none());
    }

    #[test]
    fn test_parse_stream_event_content_delta() {
        let chunk = chunk_from(r#"{"type":"content-delta","delta":{"type":"text_content","text":"Hello"}}"#);

        let delta = &chunk.choices[0].delta;
        assert_eq!(delta.content.as_deref(), Some("Hello"));
        assert!(delta.role.is_none());
        assert!(delta.tool_calls.is_none());
    }

    #[test]
    fn test_parse_stream_event_content_delta_whitespace() {
        let chunk = chunk_from(r#"{"type":"content-delta","delta":{"type":"text_content","text":" world"}}"#);

        // Leading whitespace in a token must be preserved verbatim.
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some(" world"));
    }

    #[test]
    fn test_parse_stream_event_tool_call_start() {
        let chunk = chunk_from(
            r#"{"type":"tool-call-start","index":0,"delta":{"type":"tool_call","id":"tc-001","function":{"name":"get_weather","arguments":""}}}"#,
        );

        let calls = tool_calls_of(&chunk);
        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.index, 0);
        assert_eq!(call.id.as_deref(), Some("tc-001"));
        let func = call.function.as_ref().expect("should have function");
        assert_eq!(func.name.as_deref(), Some("get_weather"));
        assert!(func.arguments.is_none());
    }

    #[test]
    fn test_parse_stream_event_tool_call_delta() {
        let chunk = chunk_from(
            r#"{"type":"tool-call-delta","index":0,"delta":{"type":"tool_call","function":{"arguments":"{\"ci"}}}"#,
        );

        let calls = tool_calls_of(&chunk);
        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.index, 0);
        assert!(call.id.is_none());
        let func = call.function.as_ref().expect("should have function");
        assert!(func.name.is_none());
        assert_eq!(func.arguments.as_deref(), Some("{\"ci"));
    }

    #[test]
    fn test_parse_stream_event_tool_call_end_returns_none() {
        let skipped = CohereProvider
            .parse_stream_event(r#"{"type":"tool-call-end","index":0}"#)
            .expect("should parse");

        assert!(skipped.is_none());
    }

    #[test]
    fn test_parse_stream_event_stream_end_complete() {
        let chunk = chunk_from(
            r#"{"type":"stream-end","finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":10,"output_tokens":5}}}"#,
        );

        assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::Stop));
        assert_usage(&chunk, (10, 5, 15));
    }

    #[test]
    fn test_parse_stream_event_stream_end_max_tokens() {
        let chunk = chunk_from(
            r#"{"type":"stream-end","finish_reason":"MAX_TOKENS","usage":{"billed_units":{"input_tokens":20,"output_tokens":100}}}"#,
        );

        assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::Length));
        assert_usage(&chunk, (20, 100, 120));
    }

    #[test]
    fn test_parse_stream_event_stream_end_tool_call() {
        let chunk = chunk_from(
            r#"{"type":"stream-end","finish_reason":"TOOL_CALL","usage":{"billed_units":{"input_tokens":15,"output_tokens":8}}}"#,
        );

        assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::ToolCalls));
    }

    #[test]
    fn test_parse_stream_event_stream_end_no_usage() {
        let chunk = chunk_from(r#"{"type":"stream-end","finish_reason":"COMPLETE"}"#);

        assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::Stop));
        assert!(chunk.usage.is_none());
    }

    #[test]
    fn test_parse_stream_event_unknown_type_returns_none() {
        let skipped = CohereProvider
            .parse_stream_event(r#"{"type":"debug","message":"some debug info"}"#)
            .expect("should parse");

        assert!(skipped.is_none());
    }

    #[test]
    fn test_parse_stream_event_invalid_json_returns_err() {
        assert!(CohereProvider.parse_stream_event("not valid json").is_err());
    }

    #[test]
    fn test_parse_stream_event_tool_call_start_index_1() {
        let chunk = chunk_from(
            r#"{"type":"tool-call-start","index":1,"delta":{"type":"tool_call","id":"tc-002","function":{"name":"search","arguments":""}}}"#,
        );

        let call = &tool_calls_of(&chunk)[0];
        assert_eq!(call.index, 1);
        assert_eq!(call.id.as_deref(), Some("tc-002"));
    }

    #[test]
    fn test_parse_stream_event_stream_end_unknown_finish_reason() {
        let chunk = chunk_from(r#"{"type":"stream-end","finish_reason":"ERROR"}"#);

        assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::Other));
    }
}