liter_llm 1.0.0.pre.rc.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +239 -0
- data/ext/liter_llm_rb/extconf.rb +65 -0
- data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
- data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
- data/ext/liter_llm_rb/native/Cargo.toml +32 -0
- data/ext/liter_llm_rb/native/build.rs +15 -0
- data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
- data/lib/liter_llm.rb +8 -0
- data/sig/liter_llm.rbs +416 -0
- data/vendor/Cargo.toml +54 -0
- data/vendor/liter-llm/Cargo.toml +92 -0
- data/vendor/liter-llm/README.md +252 -0
- data/vendor/liter-llm/schemas/pricing.json +40 -0
- data/vendor/liter-llm/schemas/providers.json +1662 -0
- data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
- data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
- data/vendor/liter-llm/src/auth/mod.rs +68 -0
- data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
- data/vendor/liter-llm/src/client/config.rs +351 -0
- data/vendor/liter-llm/src/client/managed.rs +622 -0
- data/vendor/liter-llm/src/client/mod.rs +864 -0
- data/vendor/liter-llm/src/cost.rs +212 -0
- data/vendor/liter-llm/src/error.rs +190 -0
- data/vendor/liter-llm/src/http/eventstream.rs +860 -0
- data/vendor/liter-llm/src/http/mod.rs +12 -0
- data/vendor/liter-llm/src/http/request.rs +438 -0
- data/vendor/liter-llm/src/http/retry.rs +72 -0
- data/vendor/liter-llm/src/http/streaming.rs +289 -0
- data/vendor/liter-llm/src/lib.rs +37 -0
- data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
- data/vendor/liter-llm/src/provider/azure.rs +579 -0
- data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
- data/vendor/liter-llm/src/provider/cohere.rs +654 -0
- data/vendor/liter-llm/src/provider/custom.rs +404 -0
- data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
- data/vendor/liter-llm/src/provider/mistral.rs +188 -0
- data/vendor/liter-llm/src/provider/mod.rs +616 -0
- data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
- data/vendor/liter-llm/src/tests.rs +1425 -0
- data/vendor/liter-llm/src/tokenizer.rs +281 -0
- data/vendor/liter-llm/src/tower/budget.rs +599 -0
- data/vendor/liter-llm/src/tower/cache.rs +502 -0
- data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
- data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
- data/vendor/liter-llm/src/tower/cost.rs +404 -0
- data/vendor/liter-llm/src/tower/fallback.rs +121 -0
- data/vendor/liter-llm/src/tower/health.rs +219 -0
- data/vendor/liter-llm/src/tower/hooks.rs +369 -0
- data/vendor/liter-llm/src/tower/mod.rs +77 -0
- data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
- data/vendor/liter-llm/src/tower/router.rs +436 -0
- data/vendor/liter-llm/src/tower/service.rs +181 -0
- data/vendor/liter-llm/src/tower/tests.rs +539 -0
- data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
- data/vendor/liter-llm/src/tower/tracing.rs +209 -0
- data/vendor/liter-llm/src/tower/types.rs +170 -0
- data/vendor/liter-llm/src/types/audio.rs +52 -0
- data/vendor/liter-llm/src/types/batch.rs +77 -0
- data/vendor/liter-llm/src/types/chat.rs +214 -0
- data/vendor/liter-llm/src/types/common.rs +244 -0
- data/vendor/liter-llm/src/types/embedding.rs +84 -0
- data/vendor/liter-llm/src/types/files.rs +58 -0
- data/vendor/liter-llm/src/types/image.rs +40 -0
- data/vendor/liter-llm/src/types/mod.rs +27 -0
- data/vendor/liter-llm/src/types/models.rs +21 -0
- data/vendor/liter-llm/src/types/moderation.rs +80 -0
- data/vendor/liter-llm/src/types/ocr.rs +87 -0
- data/vendor/liter-llm/src/types/rerank.rs +46 -0
- data/vendor/liter-llm/src/types/responses.rs +55 -0
- data/vendor/liter-llm/src/types/search.rs +45 -0
- data/vendor/liter-llm/tests/contract.rs +332 -0
- data/vendor/liter-llm-ffi/Cargo.toml +30 -0
- data/vendor/liter-llm-ffi/build.rs +66 -0
- data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
- data/vendor/liter-llm-ffi/liter_llm.h +850 -0
- data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
- metadata +286 -0
|
@@ -0,0 +1,654 @@
|
|
|
1
|
+
use std::borrow::Cow;
|
|
2
|
+
|
|
3
|
+
use serde_json::Value;
|
|
4
|
+
|
|
5
|
+
use crate::error::{LiterLlmError, Result};
|
|
6
|
+
use crate::provider::{Provider, unix_timestamp_secs};
|
|
7
|
+
use crate::types::{ChatCompletionChunk, FinishReason, StreamChoice, StreamDelta, StreamFunctionCall, StreamToolCall};
|
|
8
|
+
|
|
9
|
+
/// Cohere provider (Command model family).
///
/// Talks to `https://api.cohere.com/v2` and authenticates with an
/// `Authorization: Bearer <api_key>` header.
///
/// Differences from the OpenAI-compatible baseline:
/// - Chat endpoint is `/chat` instead of `/chat/completions`.
/// - Rerank endpoint is `/rerank` instead of the default path.
/// - `stream_options` is an OpenAI-specific field and must be stripped; `stream` is kept (Cohere v2 requires it).
/// - Finish reasons use Cohere-specific names (`COMPLETE`, `MAX_TOKENS`, `TOOL_CALL`).
/// - Usage is reported under `tokens.input_tokens` / `tokens.output_tokens` in regular
///   responses, and is read from `usage.billed_units` in the `stream-end` streaming event.
/// - Response may lack `object` and `created` fields.
pub struct CohereProvider;
|
|
19
|
+
|
|
20
|
+
impl Provider for CohereProvider {
|
|
21
|
+
fn name(&self) -> &str {
|
|
22
|
+
"cohere"
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
fn base_url(&self) -> &str {
|
|
26
|
+
"https://api.cohere.com/v2"
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
fn auth_header<'a>(&'a self, api_key: &'a str) -> Option<(Cow<'static, str>, Cow<'a, str>)> {
|
|
30
|
+
Some((Cow::Borrowed("Authorization"), Cow::Owned(format!("Bearer {api_key}"))))
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
fn matches_model(&self, model: &str) -> bool {
|
|
34
|
+
model.starts_with("command-r") || model.starts_with("command-") || model.starts_with("cohere/")
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
fn strip_model_prefix<'m>(&self, model: &'m str) -> &'m str {
|
|
38
|
+
model.strip_prefix("cohere/").unwrap_or(model)
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
/// Cohere uses `/chat` instead of `/chat/completions`.
|
|
42
|
+
fn chat_completions_path(&self) -> &str {
|
|
43
|
+
"/chat"
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
/// Cohere uses `/rerank` at the v2 base.
|
|
47
|
+
fn rerank_path(&self) -> &str {
|
|
48
|
+
"/rerank"
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
/// Strip transport-level parameters that Cohere does not accept in the body.
|
|
52
|
+
///
|
|
53
|
+
/// Note: Cohere v2 requires `stream` in the body, so only `stream_options`
|
|
54
|
+
/// (an OpenAI-specific field) is removed.
|
|
55
|
+
fn transform_request(&self, body: &mut Value) -> Result<()> {
|
|
56
|
+
if let Some(obj) = body.as_object_mut() {
|
|
57
|
+
obj.remove("stream_options");
|
|
58
|
+
}
|
|
59
|
+
Ok(())
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
/// Parse a Cohere v2 streaming SSE event into a `ChatCompletionChunk`.
|
|
63
|
+
///
|
|
64
|
+
/// Cohere v2 streaming events use a `type` field to distinguish event kinds:
|
|
65
|
+
/// - `stream-start`: beginning of stream, emit role = assistant
|
|
66
|
+
/// - `content-delta`: text content token, extract from `delta.text`
|
|
67
|
+
/// - `tool-call-start`: start of a tool call with id and function name
|
|
68
|
+
/// - `tool-call-delta`: partial tool call arguments
|
|
69
|
+
/// - `tool-call-end`: end of a tool call (skipped)
|
|
70
|
+
/// - `stream-end`: end of stream with finish reason and usage
|
|
71
|
+
fn parse_stream_event(&self, event_data: &str) -> Result<Option<ChatCompletionChunk>> {
|
|
72
|
+
let v: Value = serde_json::from_str(event_data).map_err(|e| LiterLlmError::Streaming {
|
|
73
|
+
message: format!("failed to parse Cohere SSE event: {e}"),
|
|
74
|
+
})?;
|
|
75
|
+
|
|
76
|
+
let event_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("");
|
|
77
|
+
|
|
78
|
+
match event_type {
|
|
79
|
+
"stream-start" => {
|
|
80
|
+
let id = v.get("generation_id").and_then(|g| g.as_str()).unwrap_or("").to_owned();
|
|
81
|
+
|
|
82
|
+
Ok(Some(ChatCompletionChunk {
|
|
83
|
+
id,
|
|
84
|
+
object: "chat.completion.chunk".to_owned(),
|
|
85
|
+
created: unix_timestamp_secs(),
|
|
86
|
+
model: String::new(),
|
|
87
|
+
choices: vec![StreamChoice {
|
|
88
|
+
index: 0,
|
|
89
|
+
delta: StreamDelta {
|
|
90
|
+
role: Some("assistant".to_owned()),
|
|
91
|
+
content: None,
|
|
92
|
+
tool_calls: None,
|
|
93
|
+
function_call: None,
|
|
94
|
+
refusal: None,
|
|
95
|
+
},
|
|
96
|
+
finish_reason: None,
|
|
97
|
+
}],
|
|
98
|
+
usage: None,
|
|
99
|
+
system_fingerprint: None,
|
|
100
|
+
service_tier: None,
|
|
101
|
+
}))
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
"content-delta" => {
|
|
105
|
+
let text = v
|
|
106
|
+
.pointer("/delta/text")
|
|
107
|
+
.and_then(|t| t.as_str())
|
|
108
|
+
.unwrap_or("")
|
|
109
|
+
.to_owned();
|
|
110
|
+
|
|
111
|
+
Ok(Some(ChatCompletionChunk {
|
|
112
|
+
id: String::new(),
|
|
113
|
+
object: "chat.completion.chunk".to_owned(),
|
|
114
|
+
created: unix_timestamp_secs(),
|
|
115
|
+
model: String::new(),
|
|
116
|
+
choices: vec![StreamChoice {
|
|
117
|
+
index: 0,
|
|
118
|
+
delta: StreamDelta {
|
|
119
|
+
role: None,
|
|
120
|
+
content: Some(text),
|
|
121
|
+
tool_calls: None,
|
|
122
|
+
function_call: None,
|
|
123
|
+
refusal: None,
|
|
124
|
+
},
|
|
125
|
+
finish_reason: None,
|
|
126
|
+
}],
|
|
127
|
+
usage: None,
|
|
128
|
+
system_fingerprint: None,
|
|
129
|
+
service_tier: None,
|
|
130
|
+
}))
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
"tool-call-start" => {
|
|
134
|
+
let index = v.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as u32;
|
|
135
|
+
let tool_id = v.pointer("/delta/id").and_then(|i| i.as_str()).unwrap_or("").to_owned();
|
|
136
|
+
let tool_name = v
|
|
137
|
+
.pointer("/delta/function/name")
|
|
138
|
+
.and_then(|n| n.as_str())
|
|
139
|
+
.unwrap_or("")
|
|
140
|
+
.to_owned();
|
|
141
|
+
|
|
142
|
+
Ok(Some(ChatCompletionChunk {
|
|
143
|
+
id: String::new(),
|
|
144
|
+
object: "chat.completion.chunk".to_owned(),
|
|
145
|
+
created: unix_timestamp_secs(),
|
|
146
|
+
model: String::new(),
|
|
147
|
+
choices: vec![StreamChoice {
|
|
148
|
+
index: 0,
|
|
149
|
+
delta: StreamDelta {
|
|
150
|
+
role: None,
|
|
151
|
+
content: None,
|
|
152
|
+
tool_calls: Some(vec![StreamToolCall {
|
|
153
|
+
index,
|
|
154
|
+
id: Some(tool_id),
|
|
155
|
+
call_type: Some(crate::types::ToolType::Function),
|
|
156
|
+
function: Some(StreamFunctionCall {
|
|
157
|
+
name: Some(tool_name),
|
|
158
|
+
arguments: None,
|
|
159
|
+
}),
|
|
160
|
+
}]),
|
|
161
|
+
function_call: None,
|
|
162
|
+
refusal: None,
|
|
163
|
+
},
|
|
164
|
+
finish_reason: None,
|
|
165
|
+
}],
|
|
166
|
+
usage: None,
|
|
167
|
+
system_fingerprint: None,
|
|
168
|
+
service_tier: None,
|
|
169
|
+
}))
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
"tool-call-delta" => {
|
|
173
|
+
let index = v.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as u32;
|
|
174
|
+
let arguments = v
|
|
175
|
+
.pointer("/delta/function/arguments")
|
|
176
|
+
.and_then(|a| a.as_str())
|
|
177
|
+
.unwrap_or("")
|
|
178
|
+
.to_owned();
|
|
179
|
+
|
|
180
|
+
Ok(Some(ChatCompletionChunk {
|
|
181
|
+
id: String::new(),
|
|
182
|
+
object: "chat.completion.chunk".to_owned(),
|
|
183
|
+
created: unix_timestamp_secs(),
|
|
184
|
+
model: String::new(),
|
|
185
|
+
choices: vec![StreamChoice {
|
|
186
|
+
index: 0,
|
|
187
|
+
delta: StreamDelta {
|
|
188
|
+
role: None,
|
|
189
|
+
content: None,
|
|
190
|
+
tool_calls: Some(vec![StreamToolCall {
|
|
191
|
+
index,
|
|
192
|
+
id: None,
|
|
193
|
+
call_type: None,
|
|
194
|
+
function: Some(StreamFunctionCall {
|
|
195
|
+
name: None,
|
|
196
|
+
arguments: Some(arguments),
|
|
197
|
+
}),
|
|
198
|
+
}]),
|
|
199
|
+
function_call: None,
|
|
200
|
+
refusal: None,
|
|
201
|
+
},
|
|
202
|
+
finish_reason: None,
|
|
203
|
+
}],
|
|
204
|
+
usage: None,
|
|
205
|
+
system_fingerprint: None,
|
|
206
|
+
service_tier: None,
|
|
207
|
+
}))
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
"tool-call-end" => Ok(None),
|
|
211
|
+
|
|
212
|
+
"stream-end" => {
|
|
213
|
+
let finish_reason = v
|
|
214
|
+
.get("finish_reason")
|
|
215
|
+
.and_then(|r| r.as_str())
|
|
216
|
+
.map(map_cohere_finish_reason);
|
|
217
|
+
|
|
218
|
+
let usage = extract_cohere_stream_usage(&v);
|
|
219
|
+
|
|
220
|
+
Ok(Some(ChatCompletionChunk {
|
|
221
|
+
id: String::new(),
|
|
222
|
+
object: "chat.completion.chunk".to_owned(),
|
|
223
|
+
created: unix_timestamp_secs(),
|
|
224
|
+
model: String::new(),
|
|
225
|
+
choices: vec![StreamChoice {
|
|
226
|
+
index: 0,
|
|
227
|
+
delta: StreamDelta {
|
|
228
|
+
role: None,
|
|
229
|
+
content: None,
|
|
230
|
+
tool_calls: None,
|
|
231
|
+
function_call: None,
|
|
232
|
+
refusal: None,
|
|
233
|
+
},
|
|
234
|
+
finish_reason,
|
|
235
|
+
}],
|
|
236
|
+
usage,
|
|
237
|
+
system_fingerprint: None,
|
|
238
|
+
service_tier: None,
|
|
239
|
+
}))
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
// Unknown event types are silently skipped.
|
|
243
|
+
_ => Ok(None),
|
|
244
|
+
}
|
|
245
|
+
}
|
|
246
|
+
|
|
247
|
+
/// Normalize Cohere response format to OpenAI-compatible JSON.
|
|
248
|
+
///
|
|
249
|
+
/// - Maps finish reasons: `COMPLETE` -> `stop`, `MAX_TOKENS` -> `length`,
|
|
250
|
+
/// `TOOL_CALL` -> `tool_calls`.
|
|
251
|
+
/// - Normalizes usage from `tokens.{input,output}_tokens` to
|
|
252
|
+
/// `usage.{prompt,completion,total}_tokens`.
|
|
253
|
+
/// - Ensures `object` and `created` fields are present.
|
|
254
|
+
fn transform_response(&self, body: &mut Value) -> Result<()> {
|
|
255
|
+
// Map finish reasons in choices.
|
|
256
|
+
if let Some(choices) = body.get_mut("choices").and_then(Value::as_array_mut) {
|
|
257
|
+
for choice in choices {
|
|
258
|
+
if let Some(reason) = choice.get("finish_reason").and_then(Value::as_str) {
|
|
259
|
+
let mapped = match reason {
|
|
260
|
+
"COMPLETE" => "stop",
|
|
261
|
+
"MAX_TOKENS" => "length",
|
|
262
|
+
"TOOL_CALL" => "tool_calls",
|
|
263
|
+
other => other,
|
|
264
|
+
};
|
|
265
|
+
choice["finish_reason"] = Value::String(mapped.to_owned());
|
|
266
|
+
}
|
|
267
|
+
}
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
// Normalize usage from Cohere's `tokens` format.
|
|
271
|
+
if body.get("usage").is_none()
|
|
272
|
+
&& let Some(tokens) = body.get("tokens")
|
|
273
|
+
{
|
|
274
|
+
let input = tokens.get("input_tokens").and_then(Value::as_u64).unwrap_or(0);
|
|
275
|
+
let output = tokens.get("output_tokens").and_then(Value::as_u64).unwrap_or(0);
|
|
276
|
+
body["usage"] = serde_json::json!({
|
|
277
|
+
"prompt_tokens": input,
|
|
278
|
+
"completion_tokens": output,
|
|
279
|
+
"total_tokens": input + output,
|
|
280
|
+
});
|
|
281
|
+
}
|
|
282
|
+
|
|
283
|
+
// Ensure standard OpenAI fields are present.
|
|
284
|
+
if body.get("object").is_none() {
|
|
285
|
+
body["object"] = Value::String("chat.completion".to_owned());
|
|
286
|
+
}
|
|
287
|
+
if body.get("created").is_none() {
|
|
288
|
+
body["created"] = Value::Number(unix_timestamp_secs().into());
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
Ok(())
|
|
292
|
+
}
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
/// Map Cohere finish reason strings to OpenAI-compatible `FinishReason`.
|
|
296
|
+
fn map_cohere_finish_reason(reason: &str) -> FinishReason {
|
|
297
|
+
match reason {
|
|
298
|
+
"COMPLETE" => FinishReason::Stop,
|
|
299
|
+
"MAX_TOKENS" => FinishReason::Length,
|
|
300
|
+
"TOOL_CALL" => FinishReason::ToolCalls,
|
|
301
|
+
_ => FinishReason::Other,
|
|
302
|
+
}
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
/// Extract usage from a Cohere `stream-end` event.
|
|
306
|
+
///
|
|
307
|
+
/// Cohere v2 reports usage under `usage.billed_units.{input_tokens, output_tokens}`.
|
|
308
|
+
fn extract_cohere_stream_usage(v: &Value) -> Option<crate::types::Usage> {
|
|
309
|
+
let billed = v.pointer("/usage/billed_units")?;
|
|
310
|
+
let input = billed.get("input_tokens").and_then(|t| t.as_u64()).unwrap_or(0);
|
|
311
|
+
let output = billed.get("output_tokens").and_then(|t| t.as_u64()).unwrap_or(0);
|
|
312
|
+
|
|
313
|
+
Some(crate::types::Usage {
|
|
314
|
+
prompt_tokens: input,
|
|
315
|
+
completion_tokens: output,
|
|
316
|
+
total_tokens: input + output,
|
|
317
|
+
})
|
|
318
|
+
}
|
|
319
|
+
|
|
320
|
+
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Parse `event` and unwrap down to the chunk it must produce.
    fn parse_chunk(event: &str) -> ChatCompletionChunk {
        CohereProvider
            .parse_stream_event(event)
            .expect("should parse")
            .expect("should return Some")
    }

    /// Parse `event`, which must be valid JSON, returning whatever the parser emits.
    fn parse_optional(event: &str) -> Option<ChatCompletionChunk> {
        CohereProvider.parse_stream_event(event).expect("should parse")
    }

    /// Tool calls carried by the first choice of `chunk`.
    fn tool_calls_of(chunk: &ChatCompletionChunk) -> &[StreamToolCall] {
        chunk.choices[0]
            .delta
            .tool_calls
            .as_deref()
            .expect("should have tool_calls")
    }

    /// Run `transform_response` over `body` and hand back the normalized value.
    fn normalized(mut body: Value) -> Value {
        CohereProvider
            .transform_response(&mut body)
            .expect("transform should succeed");
        body
    }

    /// Check the three usage counters of `chunk` in one go.
    fn assert_usage(chunk: &ChatCompletionChunk, prompt: u64, completion: u64, total: u64) {
        let usage = chunk.usage.as_ref().expect("should have usage");
        assert_eq!(usage.prompt_tokens, prompt);
        assert_eq!(usage.completion_tokens, completion);
        assert_eq!(usage.total_tokens, total);
    }

    #[test]
    fn test_cohere_name_and_base_url() {
        assert_eq!(CohereProvider.name(), "cohere");
        assert_eq!(CohereProvider.base_url(), "https://api.cohere.com/v2");
    }

    #[test]
    fn test_cohere_auth_header() {
        let (name, value) = CohereProvider
            .auth_header("test-key")
            .expect("should return auth header");
        assert_eq!(name, "Authorization");
        assert_eq!(value, "Bearer test-key");
    }

    #[test]
    fn test_cohere_matches_model() {
        let provider = CohereProvider;
        for accepted in ["command-r-plus", "command-r", "command-light", "cohere/command-r-plus"] {
            assert!(provider.matches_model(accepted));
        }
        for rejected in ["gpt-4", "claude-3"] {
            assert!(!provider.matches_model(rejected));
        }
    }

    #[test]
    fn test_cohere_strip_prefix() {
        let provider = CohereProvider;
        for input in ["cohere/command-r", "command-r"] {
            assert_eq!(provider.strip_model_prefix(input), "command-r");
        }
    }

    #[test]
    fn test_cohere_endpoints() {
        assert_eq!(CohereProvider.chat_completions_path(), "/chat");
        assert_eq!(CohereProvider.rerank_path(), "/rerank");
    }

    #[test]
    fn test_cohere_transform_request_preserves_stream_strips_options() {
        let mut body = json!({
            "model": "command-r-plus",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": true,
            "stream_options": {"include_usage": true}
        });
        CohereProvider
            .transform_request(&mut body)
            .expect("transform should succeed");

        // `stream` must survive (Cohere v2 needs it); `stream_options` must not.
        assert_eq!(body["stream"], true);
        assert!(body.get("stream_options").is_none());
        // Unrelated fields are untouched.
        assert_eq!(body["model"], "command-r-plus");
    }

    #[test]
    fn test_cohere_transform_response_finish_reasons() {
        let body = normalized(json!({
            "choices": [
                {"finish_reason": "COMPLETE", "message": {"content": "hi"}},
                {"finish_reason": "MAX_TOKENS", "message": {"content": "..."}},
                {"finish_reason": "TOOL_CALL", "message": {"content": ""}}
            ]
        }));

        let choices = body["choices"].as_array().expect("choices array");
        for (position, expected) in ["stop", "length", "tool_calls"].into_iter().enumerate() {
            assert_eq!(choices[position]["finish_reason"], expected);
        }
    }

    #[test]
    fn test_cohere_transform_response_usage_normalization() {
        let body = normalized(json!({
            "choices": [{"finish_reason": "COMPLETE"}],
            "tokens": {
                "input_tokens": 10,
                "output_tokens": 20
            }
        }));

        let usage = &body["usage"];
        assert_eq!(usage["prompt_tokens"], 10);
        assert_eq!(usage["completion_tokens"], 20);
        assert_eq!(usage["total_tokens"], 30);
    }

    #[test]
    fn test_cohere_transform_response_adds_object_and_created() {
        let body = normalized(json!({"choices": []}));

        assert_eq!(body["object"], "chat.completion");
        assert!(body["created"].as_u64().is_some());
    }

    #[test]
    fn test_cohere_transform_response_preserves_existing_usage() {
        let body = normalized(json!({
            "choices": [],
            "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15},
            "tokens": {"input_tokens": 99, "output_tokens": 99}
        }));

        // A pre-existing `usage` object wins over `tokens`.
        assert_eq!(body["usage"]["prompt_tokens"], 5);
    }

    // ── Streaming SSE parser tests ───────────────────────────────────────────

    #[test]
    fn test_parse_stream_event_stream_start() {
        let chunk = parse_chunk(r#"{"type":"stream-start","generation_id":"gen-123"}"#);

        assert_eq!(chunk.id, "gen-123");
        assert_eq!(chunk.object, "chat.completion.chunk");
        assert_eq!(chunk.choices.len(), 1);

        let choice = &chunk.choices[0];
        assert_eq!(choice.delta.role.as_deref(), Some("assistant"));
        assert!(choice.delta.content.is_none());
        assert!(choice.finish_reason.is_none());
        assert!(chunk.usage.is_none());
    }

    #[test]
    fn test_parse_stream_event_content_delta() {
        let chunk = parse_chunk(r#"{"type":"content-delta","delta":{"type":"text_content","text":"Hello"}}"#);

        let delta = &chunk.choices[0].delta;
        assert_eq!(delta.content.as_deref(), Some("Hello"));
        assert!(delta.role.is_none());
        assert!(delta.tool_calls.is_none());
    }

    #[test]
    fn test_parse_stream_event_content_delta_whitespace() {
        let chunk = parse_chunk(r#"{"type":"content-delta","delta":{"type":"text_content","text":" world"}}"#);

        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some(" world"));
    }

    #[test]
    fn test_parse_stream_event_tool_call_start() {
        let chunk = parse_chunk(
            r#"{"type":"tool-call-start","index":0,"delta":{"type":"tool_call","id":"tc-001","function":{"name":"get_weather","arguments":""}}}"#,
        );

        let calls = tool_calls_of(&chunk);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].index, 0);
        assert_eq!(calls[0].id.as_deref(), Some("tc-001"));

        let function = calls[0].function.as_ref().expect("should have function");
        assert_eq!(function.name.as_deref(), Some("get_weather"));
        assert!(function.arguments.is_none());
    }

    #[test]
    fn test_parse_stream_event_tool_call_delta() {
        let chunk = parse_chunk(
            r#"{"type":"tool-call-delta","index":0,"delta":{"type":"tool_call","function":{"arguments":"{\"ci"}}}"#,
        );

        let calls = tool_calls_of(&chunk);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].index, 0);
        assert!(calls[0].id.is_none());

        let function = calls[0].function.as_ref().expect("should have function");
        assert!(function.name.is_none());
        assert_eq!(function.arguments.as_deref(), Some("{\"ci"));
    }

    #[test]
    fn test_parse_stream_event_tool_call_end_returns_none() {
        assert!(parse_optional(r#"{"type":"tool-call-end","index":0}"#).is_none());
    }

    #[test]
    fn test_parse_stream_event_stream_end_complete() {
        let chunk = parse_chunk(
            r#"{"type":"stream-end","finish_reason":"COMPLETE","usage":{"billed_units":{"input_tokens":10,"output_tokens":5}}}"#,
        );

        assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::Stop));
        assert_usage(&chunk, 10, 5, 15);
    }

    #[test]
    fn test_parse_stream_event_stream_end_max_tokens() {
        let chunk = parse_chunk(
            r#"{"type":"stream-end","finish_reason":"MAX_TOKENS","usage":{"billed_units":{"input_tokens":20,"output_tokens":100}}}"#,
        );

        assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::Length));
        assert_usage(&chunk, 20, 100, 120);
    }

    #[test]
    fn test_parse_stream_event_stream_end_tool_call() {
        let chunk = parse_chunk(
            r#"{"type":"stream-end","finish_reason":"TOOL_CALL","usage":{"billed_units":{"input_tokens":15,"output_tokens":8}}}"#,
        );

        assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::ToolCalls));
    }

    #[test]
    fn test_parse_stream_event_stream_end_no_usage() {
        let chunk = parse_chunk(r#"{"type":"stream-end","finish_reason":"COMPLETE"}"#);

        assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::Stop));
        assert!(chunk.usage.is_none());
    }

    #[test]
    fn test_parse_stream_event_unknown_type_returns_none() {
        assert!(parse_optional(r#"{"type":"debug","message":"some debug info"}"#).is_none());
    }

    #[test]
    fn test_parse_stream_event_invalid_json_returns_err() {
        assert!(CohereProvider.parse_stream_event("not valid json").is_err());
    }

    #[test]
    fn test_parse_stream_event_tool_call_start_index_1() {
        let chunk = parse_chunk(
            r#"{"type":"tool-call-start","index":1,"delta":{"type":"tool_call","id":"tc-002","function":{"name":"search","arguments":""}}}"#,
        );

        let calls = tool_calls_of(&chunk);
        assert_eq!(calls[0].index, 1);
        assert_eq!(calls[0].id.as_deref(), Some("tc-002"));
    }

    #[test]
    fn test_parse_stream_event_stream_end_unknown_finish_reason() {
        let chunk = parse_chunk(r#"{"type":"stream-end","finish_reason":"ERROR"}"#);

        assert_eq!(chunk.choices[0].finish_reason, Some(FinishReason::Other));
    }
}
|