@mmmbuto/anthmorph 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cargo.lock +1935 -0
- package/Cargo.toml +61 -0
- package/LICENSE +21 -0
- package/README.md +184 -0
- package/bin/anthmorph +22 -0
- package/package.json +52 -0
- package/scripts/anthmorphctl +456 -0
- package/scripts/postinstall.js +23 -0
- package/scripts/smoke_test.sh +72 -0
- package/src/config.rs +39 -0
- package/src/error.rs +54 -0
- package/src/main.rs +120 -0
- package/src/models/anthropic.rs +274 -0
- package/src/models/mod.rs +2 -0
- package/src/models/openai.rs +230 -0
- package/src/proxy.rs +829 -0
- package/src/transform.rs +460 -0
- package/tests/real_backends.rs +213 -0
package/src/proxy.rs
ADDED
|
@@ -0,0 +1,829 @@
|
|
|
1
|
+
use crate::config::BackendProfile;
|
|
2
|
+
use crate::error::{ProxyError, ProxyResult};
|
|
3
|
+
use crate::models::{anthropic, openai};
|
|
4
|
+
use crate::transform::{self, generate_message_id};
|
|
5
|
+
use axum::{
|
|
6
|
+
body::Body,
|
|
7
|
+
http::{header, HeaderMap, HeaderName, HeaderValue},
|
|
8
|
+
response::{IntoResponse, Response},
|
|
9
|
+
Extension, Json,
|
|
10
|
+
};
|
|
11
|
+
use bytes::Bytes;
|
|
12
|
+
use futures::stream::{Stream, StreamExt};
|
|
13
|
+
use reqwest::Client;
|
|
14
|
+
use serde_json::json;
|
|
15
|
+
use std::collections::BTreeMap;
|
|
16
|
+
use std::fmt;
|
|
17
|
+
use std::sync::Arc;
|
|
18
|
+
use std::time::Duration;
|
|
19
|
+
use tokio::pin;
|
|
20
|
+
use tower_http::cors::{AllowOrigin, CorsLayer};
|
|
21
|
+
|
|
22
|
+
fn map_model(client_model: &str, config: &Config) -> String {
|
|
23
|
+
match client_model {
|
|
24
|
+
m if m.is_empty() || m == "default" => config.model.clone(),
|
|
25
|
+
m if m.starts_with("claude-") => config.model.clone(),
|
|
26
|
+
other => other.to_string(),
|
|
27
|
+
}
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
pub async fn proxy_handler(
|
|
31
|
+
headers: HeaderMap,
|
|
32
|
+
Extension(config): Extension<Arc<Config>>,
|
|
33
|
+
Extension(client): Extension<Client>,
|
|
34
|
+
Json(req): Json<anthropic::AnthropicRequest>,
|
|
35
|
+
) -> ProxyResult<Response> {
|
|
36
|
+
authorize_request(&headers, &config)?;
|
|
37
|
+
|
|
38
|
+
let is_streaming = req.stream.unwrap_or(false);
|
|
39
|
+
|
|
40
|
+
tracing::debug!("Received request for model: {}", req.model);
|
|
41
|
+
tracing::debug!("Messages count: {}", req.messages.len());
|
|
42
|
+
for (i, msg) in req.messages.iter().enumerate() {
|
|
43
|
+
let content_type = match &msg.content {
|
|
44
|
+
anthropic::MessageContent::Text(_) => "Text",
|
|
45
|
+
anthropic::MessageContent::Blocks(blocks) => {
|
|
46
|
+
if blocks.is_empty() {
|
|
47
|
+
"empty_blocks"
|
|
48
|
+
} else {
|
|
49
|
+
match &blocks[0] {
|
|
50
|
+
anthropic::ContentBlock::Text { .. } => "text_block",
|
|
51
|
+
anthropic::ContentBlock::Image { .. } => "image_block",
|
|
52
|
+
anthropic::ContentBlock::ToolUse { .. } => "tool_use_block",
|
|
53
|
+
anthropic::ContentBlock::ToolResult { .. } => "tool_result_block",
|
|
54
|
+
anthropic::ContentBlock::Thinking { .. } => "thinking_block",
|
|
55
|
+
anthropic::ContentBlock::Other => "unknown_block",
|
|
56
|
+
}
|
|
57
|
+
}
|
|
58
|
+
}
|
|
59
|
+
};
|
|
60
|
+
tracing::debug!("Message {}: role={}, content={}", i, msg.role, content_type);
|
|
61
|
+
}
|
|
62
|
+
tracing::debug!("Streaming: {}", is_streaming);
|
|
63
|
+
|
|
64
|
+
let model = if req
|
|
65
|
+
.extra
|
|
66
|
+
.get("thinking")
|
|
67
|
+
.and_then(|v| v.get("type"))
|
|
68
|
+
.is_some()
|
|
69
|
+
{
|
|
70
|
+
config
|
|
71
|
+
.reasoning_model
|
|
72
|
+
.clone()
|
|
73
|
+
.unwrap_or_else(|| config.model.clone())
|
|
74
|
+
} else {
|
|
75
|
+
map_model(&req.model, &config)
|
|
76
|
+
};
|
|
77
|
+
|
|
78
|
+
let openai_req = transform::anthropic_to_openai(req, &model, config.backend_profile)?;
|
|
79
|
+
|
|
80
|
+
if is_streaming {
|
|
81
|
+
handle_streaming(config, client, openai_req).await
|
|
82
|
+
} else {
|
|
83
|
+
handle_non_streaming(config, client, openai_req).await
|
|
84
|
+
}
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
async fn handle_non_streaming(
|
|
88
|
+
config: Arc<Config>,
|
|
89
|
+
client: Client,
|
|
90
|
+
openai_req: openai::OpenAIRequest,
|
|
91
|
+
) -> ProxyResult<Response> {
|
|
92
|
+
let url = config.chat_completions_url();
|
|
93
|
+
tracing::debug!("Sending non-streaming request to {}", url);
|
|
94
|
+
|
|
95
|
+
let mut req_builder = client
|
|
96
|
+
.post(&url)
|
|
97
|
+
.json(&openai_req)
|
|
98
|
+
.timeout(Duration::from_secs(300));
|
|
99
|
+
|
|
100
|
+
if let Some(api_key) = &config.api_key {
|
|
101
|
+
req_builder = req_builder.header("Authorization", format!("Bearer {}", api_key));
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
let response = req_builder.send().await.map_err(|err| {
|
|
105
|
+
tracing::error!("Failed to send request to {}: {:?}", url, err);
|
|
106
|
+
ProxyError::Http(err)
|
|
107
|
+
})?;
|
|
108
|
+
|
|
109
|
+
if !response.status().is_success() {
|
|
110
|
+
let status = response.status();
|
|
111
|
+
let error_text = response
|
|
112
|
+
.text()
|
|
113
|
+
.await
|
|
114
|
+
.unwrap_or_else(|_| "Unknown error".to_string());
|
|
115
|
+
tracing::error!("Upstream error ({}): {}", status, error_text);
|
|
116
|
+
return Err(ProxyError::Upstream(format!(
|
|
117
|
+
"Upstream returned {}: {}",
|
|
118
|
+
status, error_text
|
|
119
|
+
)));
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
let openai_resp: openai::OpenAIResponse = response.json().await?;
|
|
123
|
+
let anthropic_resp =
|
|
124
|
+
transform::openai_to_anthropic(openai_resp, &openai_req.model, config.backend_profile)?;
|
|
125
|
+
|
|
126
|
+
Ok(Json(anthropic_resp).into_response())
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
async fn handle_streaming(
|
|
130
|
+
config: Arc<Config>,
|
|
131
|
+
client: Client,
|
|
132
|
+
openai_req: openai::OpenAIRequest,
|
|
133
|
+
) -> ProxyResult<Response> {
|
|
134
|
+
let url = config.chat_completions_url();
|
|
135
|
+
tracing::debug!("Sending streaming request to {}", url);
|
|
136
|
+
|
|
137
|
+
let mut req_builder = client
|
|
138
|
+
.post(&url)
|
|
139
|
+
.json(&openai_req)
|
|
140
|
+
.timeout(Duration::from_secs(300));
|
|
141
|
+
|
|
142
|
+
if let Some(api_key) = &config.api_key {
|
|
143
|
+
req_builder = req_builder.header("Authorization", format!("Bearer {}", api_key));
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
let response = req_builder.send().await.map_err(|err| {
|
|
147
|
+
tracing::error!("Failed to send streaming request: {:?}", err);
|
|
148
|
+
ProxyError::Http(err)
|
|
149
|
+
})?;
|
|
150
|
+
|
|
151
|
+
if !response.status().is_success() {
|
|
152
|
+
let status = response.status();
|
|
153
|
+
let error_text = response
|
|
154
|
+
.text()
|
|
155
|
+
.await
|
|
156
|
+
.unwrap_or_else(|_| "Unknown error".to_string());
|
|
157
|
+
tracing::error!("Upstream streaming error ({}): {}", status, error_text);
|
|
158
|
+
return Err(ProxyError::Upstream(format!(
|
|
159
|
+
"Upstream returned {}: {}",
|
|
160
|
+
status, error_text
|
|
161
|
+
)));
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
let stream = response.bytes_stream();
|
|
165
|
+
let sse_stream = create_sse_stream(stream, openai_req.model.clone(), config.backend_profile);
|
|
166
|
+
|
|
167
|
+
let mut headers = HeaderMap::new();
|
|
168
|
+
headers.insert(
|
|
169
|
+
"Content-Type",
|
|
170
|
+
HeaderValue::from_static("text/event-stream"),
|
|
171
|
+
);
|
|
172
|
+
headers.insert("Cache-Control", HeaderValue::from_static("no-cache"));
|
|
173
|
+
headers.insert("Connection", HeaderValue::from_static("keep-alive"));
|
|
174
|
+
|
|
175
|
+
Ok((headers, Body::from_stream(sse_stream)).into_response())
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
fn create_sse_stream(
|
|
179
|
+
stream: impl Stream<Item = Result<Bytes, reqwest::Error>> + Send + 'static,
|
|
180
|
+
fallback_model: String,
|
|
181
|
+
profile: BackendProfile,
|
|
182
|
+
) -> impl Stream<Item = Result<Bytes, std::io::Error>> + Send {
|
|
183
|
+
async_stream::stream! {
|
|
184
|
+
let mut buffer = String::new();
|
|
185
|
+
let mut message_id = None;
|
|
186
|
+
let mut current_model = None;
|
|
187
|
+
let mut next_content_index = 0usize;
|
|
188
|
+
let mut has_sent_message_start = false;
|
|
189
|
+
let mut active_block: Option<ActiveBlock> = None;
|
|
190
|
+
let mut tool_states: BTreeMap<usize, ToolCallState> = BTreeMap::new();
|
|
191
|
+
|
|
192
|
+
pin!(stream);
|
|
193
|
+
|
|
194
|
+
let mut raw_buffer: Vec<u8> = Vec::new();
|
|
195
|
+
|
|
196
|
+
while let Some(chunk) = stream.next().await {
|
|
197
|
+
match chunk {
|
|
198
|
+
Ok(bytes) => {
|
|
199
|
+
raw_buffer.extend_from_slice(&bytes);
|
|
200
|
+
|
|
201
|
+
loop {
|
|
202
|
+
match std::str::from_utf8(&raw_buffer) {
|
|
203
|
+
Ok(text) => {
|
|
204
|
+
buffer.push_str(&text.replace("\r\n", "\n"));
|
|
205
|
+
raw_buffer.clear();
|
|
206
|
+
break;
|
|
207
|
+
}
|
|
208
|
+
Err(e) => {
|
|
209
|
+
let valid_up_to = e.valid_up_to();
|
|
210
|
+
if valid_up_to > 0 {
|
|
211
|
+
let partial = std::str::from_utf8(&raw_buffer[..valid_up_to]).unwrap();
|
|
212
|
+
buffer.push_str(&partial.replace("\r\n", "\n"));
|
|
213
|
+
raw_buffer = raw_buffer[valid_up_to..].to_vec();
|
|
214
|
+
}
|
|
215
|
+
if raw_buffer.is_empty() || valid_up_to == 0 {
|
|
216
|
+
break;
|
|
217
|
+
}
|
|
218
|
+
}
|
|
219
|
+
}
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
while let Some(pos) = buffer.find("\n\n") {
|
|
223
|
+
let event_block = buffer[..pos].to_string();
|
|
224
|
+
buffer = buffer[pos + 2..].to_string();
|
|
225
|
+
|
|
226
|
+
if event_block.trim().is_empty() {
|
|
227
|
+
continue;
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
let Some(data) = extract_sse_data(&event_block) else {
|
|
231
|
+
continue;
|
|
232
|
+
};
|
|
233
|
+
|
|
234
|
+
if data.trim() == "[DONE]" {
|
|
235
|
+
yield Ok(Bytes::from(message_stop_sse()));
|
|
236
|
+
continue;
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
if let Ok(chunk) = serde_json::from_str::<openai::StreamChunk>(&data) {
|
|
240
|
+
if message_id.is_none() {
|
|
241
|
+
if let Some(id) = &chunk.id {
|
|
242
|
+
message_id = Some(id.clone());
|
|
243
|
+
}
|
|
244
|
+
}
|
|
245
|
+
if current_model.is_none() {
|
|
246
|
+
if let Some(model) = &chunk.model {
|
|
247
|
+
current_model = Some(model.clone());
|
|
248
|
+
}
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
if let Some(choice) = chunk.choices.first() {
|
|
252
|
+
if !has_sent_message_start {
|
|
253
|
+
let event = anthropic::StreamEvent::MessageStart {
|
|
254
|
+
message: anthropic::MessageStartData {
|
|
255
|
+
id: message_id.clone().unwrap_or_else(generate_message_id),
|
|
256
|
+
message_type: "message".to_string(),
|
|
257
|
+
role: "assistant".to_string(),
|
|
258
|
+
model: current_model
|
|
259
|
+
.clone()
|
|
260
|
+
.unwrap_or_else(|| fallback_model.clone()),
|
|
261
|
+
usage: anthropic::Usage {
|
|
262
|
+
input_tokens: 0,
|
|
263
|
+
output_tokens: 0,
|
|
264
|
+
},
|
|
265
|
+
},
|
|
266
|
+
};
|
|
267
|
+
yield Ok(Bytes::from(sse_event("message_start", &event)));
|
|
268
|
+
has_sent_message_start = true;
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
if let Some(reasoning) = &choice.delta.reasoning {
|
|
272
|
+
if !reasoning.is_empty() {
|
|
273
|
+
if !profile.supports_reasoning() {
|
|
274
|
+
yield Ok(Bytes::from(stream_error_sse(
|
|
275
|
+
"reasoning deltas are not supported by the active backend profile",
|
|
276
|
+
)));
|
|
277
|
+
break;
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
let (idx, transitions) = transition_to_thinking(
|
|
281
|
+
&mut active_block,
|
|
282
|
+
&mut next_content_index,
|
|
283
|
+
);
|
|
284
|
+
for event in transitions {
|
|
285
|
+
yield Ok(Bytes::from(event));
|
|
286
|
+
}
|
|
287
|
+
yield Ok(Bytes::from(delta_block_sse(
|
|
288
|
+
idx,
|
|
289
|
+
anthropic::ContentBlockDeltaData::ThinkingDelta {
|
|
290
|
+
thinking: reasoning.clone(),
|
|
291
|
+
},
|
|
292
|
+
)));
|
|
293
|
+
}
|
|
294
|
+
}
|
|
295
|
+
|
|
296
|
+
if let Some(content) = &choice.delta.content {
|
|
297
|
+
if !content.is_empty() {
|
|
298
|
+
let (idx, transitions) = transition_to_text(
|
|
299
|
+
&mut active_block,
|
|
300
|
+
&mut next_content_index,
|
|
301
|
+
);
|
|
302
|
+
for event in transitions {
|
|
303
|
+
yield Ok(Bytes::from(event));
|
|
304
|
+
}
|
|
305
|
+
yield Ok(Bytes::from(delta_block_sse(
|
|
306
|
+
idx,
|
|
307
|
+
anthropic::ContentBlockDeltaData::TextDelta {
|
|
308
|
+
text: content.clone(),
|
|
309
|
+
},
|
|
310
|
+
)));
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
if let Some(tool_calls) = &choice.delta.tool_calls {
|
|
315
|
+
for tool_call in tool_calls {
|
|
316
|
+
let tool_index = tool_call.index.unwrap_or(0);
|
|
317
|
+
let state = tool_states.entry(tool_index).or_default();
|
|
318
|
+
|
|
319
|
+
if let Some(id) = &tool_call.id {
|
|
320
|
+
state.id = Some(id.clone());
|
|
321
|
+
}
|
|
322
|
+
if let Some(function) = &tool_call.function {
|
|
323
|
+
if let Some(name) = &function.name {
|
|
324
|
+
state.name = Some(name.clone());
|
|
325
|
+
}
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
if state.content_index.is_none() {
|
|
329
|
+
if let (Some(id), Some(name)) = (state.id.clone(), state.name.clone()) {
|
|
330
|
+
let (idx, transitions) = transition_to_tool(
|
|
331
|
+
&mut active_block,
|
|
332
|
+
&mut next_content_index,
|
|
333
|
+
tool_index,
|
|
334
|
+
id,
|
|
335
|
+
name,
|
|
336
|
+
);
|
|
337
|
+
state.content_index = Some(idx);
|
|
338
|
+
for event in transitions {
|
|
339
|
+
yield Ok(Bytes::from(event));
|
|
340
|
+
}
|
|
341
|
+
}
|
|
342
|
+
} else if active_block != Some(ActiveBlock::ToolUse(tool_index, state.content_index.unwrap())) {
|
|
343
|
+
yield Ok(Bytes::from(stream_error_sse(
|
|
344
|
+
"interleaved tool call deltas are not supported safely",
|
|
345
|
+
)));
|
|
346
|
+
break;
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
if let Some(function) = &tool_call.function {
|
|
350
|
+
if let Some(arguments) = &function.arguments {
|
|
351
|
+
if let Some(idx) = state.content_index {
|
|
352
|
+
yield Ok(Bytes::from(delta_block_sse(
|
|
353
|
+
idx,
|
|
354
|
+
anthropic::ContentBlockDeltaData::InputJsonDelta {
|
|
355
|
+
partial_json: arguments.clone(),
|
|
356
|
+
},
|
|
357
|
+
)));
|
|
358
|
+
}
|
|
359
|
+
}
|
|
360
|
+
}
|
|
361
|
+
}
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
if let Some(finish_reason) = &choice.finish_reason {
|
|
365
|
+
if let Some(previous) = active_block.take() {
|
|
366
|
+
yield Ok(Bytes::from(stop_block_sse(previous.index())));
|
|
367
|
+
}
|
|
368
|
+
|
|
369
|
+
let event = anthropic::StreamEvent::MessageDelta {
|
|
370
|
+
delta: anthropic::MessageDeltaData {
|
|
371
|
+
stop_reason: transform::map_stop_reason(Some(finish_reason)),
|
|
372
|
+
stop_sequence: (),
|
|
373
|
+
},
|
|
374
|
+
usage: chunk.usage.as_ref().and_then(|u| {
|
|
375
|
+
u.completion_tokens.map(|tokens| anthropic::MessageDeltaUsage {
|
|
376
|
+
output_tokens: tokens,
|
|
377
|
+
})
|
|
378
|
+
}),
|
|
379
|
+
};
|
|
380
|
+
yield Ok(Bytes::from(sse_event("message_delta", &event)));
|
|
381
|
+
}
|
|
382
|
+
}
|
|
383
|
+
}
|
|
384
|
+
}
|
|
385
|
+
}
|
|
386
|
+
Err(e) => {
|
|
387
|
+
tracing::error!("Stream error: {}", e);
|
|
388
|
+
yield Ok(Bytes::from(stream_error_sse(&format!("Stream error: {}", e))));
|
|
389
|
+
break;
|
|
390
|
+
}
|
|
391
|
+
}
|
|
392
|
+
}
|
|
393
|
+
}
|
|
394
|
+
}
|
|
395
|
+
|
|
396
|
+
/// Runtime configuration for the proxy, normally built by [`Config::from_env`].
pub struct Config {
    /// Base URL of the OpenAI-compatible backend; `/chat/completions` is appended to it.
    pub backend_url: String,
    /// Backend profile passed to the request/response transforms and the SSE translator.
    pub backend_profile: BackendProfile,
    /// Default backend model, substituted for empty, `default`, and `claude-*` model names.
    pub model: String,
    /// Optional backend model for requests that ask for extended thinking.
    pub reasoning_model: Option<String>,
    /// Key sent to the backend as `Authorization: Bearer <key>`, if set.
    pub api_key: Option<String>,
    /// Key clients must present (Bearer token or `x-api-key`); `None` disables ingress auth.
    pub ingress_api_key: Option<String>,
    /// Origins allowed by the CORS layer; if empty, no CORS layer is built.
    pub allow_origins: Vec<String>,
    /// Port read from `PORT` (presumably the listen port; it is bound outside this file).
    pub port: u16,
}
|
|
406
|
+
|
|
407
|
+
impl Config {
|
|
408
|
+
pub fn from_env() -> Self {
|
|
409
|
+
Self {
|
|
410
|
+
backend_url: std::env::var("ANTHMORPH_BACKEND_URL")
|
|
411
|
+
.unwrap_or_else(|_| "https://llm.chutes.ai/v1".to_string()),
|
|
412
|
+
backend_profile: std::env::var("ANTHMORPH_BACKEND_PROFILE")
|
|
413
|
+
.ok()
|
|
414
|
+
.and_then(|v| v.parse().ok())
|
|
415
|
+
.unwrap_or(BackendProfile::Chutes),
|
|
416
|
+
model: std::env::var("ANTHMORPH_MODEL")
|
|
417
|
+
.unwrap_or_else(|_| "Qwen/Qwen3-Coder-Next-TEE".to_string()),
|
|
418
|
+
reasoning_model: std::env::var("ANTHMORPH_REASONING_MODEL").ok(),
|
|
419
|
+
api_key: std::env::var("ANTHMORPH_API_KEY").ok(),
|
|
420
|
+
ingress_api_key: std::env::var("ANTHMORPH_INGRESS_API_KEY").ok(),
|
|
421
|
+
allow_origins: std::env::var("ANTHMORPH_ALLOWED_ORIGINS")
|
|
422
|
+
.ok()
|
|
423
|
+
.map(|v| {
|
|
424
|
+
v.split(',')
|
|
425
|
+
.map(str::trim)
|
|
426
|
+
.filter(|s| !s.is_empty())
|
|
427
|
+
.map(ToOwned::to_owned)
|
|
428
|
+
.collect()
|
|
429
|
+
})
|
|
430
|
+
.unwrap_or_default(),
|
|
431
|
+
port: std::env::var("PORT")
|
|
432
|
+
.unwrap_or_else(|_| "3000".to_string())
|
|
433
|
+
.parse()
|
|
434
|
+
.unwrap_or(3000),
|
|
435
|
+
}
|
|
436
|
+
}
|
|
437
|
+
|
|
438
|
+
pub fn chat_completions_url(&self) -> String {
|
|
439
|
+
format!(
|
|
440
|
+
"{}/chat/completions",
|
|
441
|
+
self.backend_url.trim_end_matches('/')
|
|
442
|
+
)
|
|
443
|
+
}
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
impl fmt::Debug for Config {
|
|
447
|
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
448
|
+
f.debug_struct("Config")
|
|
449
|
+
.field("backend_url", &self.backend_url)
|
|
450
|
+
.field("backend_profile", &self.backend_profile.as_str())
|
|
451
|
+
.field("model", &self.model)
|
|
452
|
+
.field("reasoning_model", &self.reasoning_model)
|
|
453
|
+
.field("api_key", &"<hidden>")
|
|
454
|
+
.field("ingress_api_key", &"<hidden>")
|
|
455
|
+
.field("allow_origins", &self.allow_origins)
|
|
456
|
+
.field("port", &self.port)
|
|
457
|
+
.finish()
|
|
458
|
+
}
|
|
459
|
+
}
|
|
460
|
+
|
|
461
|
+
/// The Anthropic content block that is currently open on the outgoing stream.
///
/// Every variant carries the block's position in the Anthropic `content` array. `ToolUse`
/// also records which upstream tool-call index feeds the block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ActiveBlock {
    /// A `thinking` block at the given content index.
    Thinking(usize),
    /// A `text` block at the given content index.
    Text(usize),
    /// A `tool_use` block: (upstream tool-call index, content index).
    ToolUse(usize, usize),
}

impl ActiveBlock {
    /// Position of this block in the Anthropic `content` array.
    fn index(self) -> usize {
        let (ActiveBlock::Thinking(content_index)
        | ActiveBlock::Text(content_index)
        | ActiveBlock::ToolUse(_, content_index)) = self;
        content_index
    }
}
|
|
476
|
+
|
|
477
|
+
/// Accumulated state for one upstream tool call, tracked per OpenAI tool-call `index`
/// while translating a stream.
#[derive(Debug, Default)]
struct ToolCallState {
    /// Most recent tool-call id seen in a delta for this index.
    id: Option<String>,
    /// Most recent function name seen in a delta for this index.
    name: Option<String>,
    /// Anthropic content index of the `tool_use` block, set once id and name are both known
    /// and the block has been opened.
    content_index: Option<usize>,
}
|
|
483
|
+
|
|
484
|
+
fn transition_to_thinking(
|
|
485
|
+
active_block: &mut Option<ActiveBlock>,
|
|
486
|
+
next_content_index: &mut usize,
|
|
487
|
+
) -> (usize, Vec<String>) {
|
|
488
|
+
match active_block {
|
|
489
|
+
Some(ActiveBlock::Thinking(index)) => (*index, Vec::new()),
|
|
490
|
+
_ => {
|
|
491
|
+
let mut events = Vec::new();
|
|
492
|
+
if let Some(previous) = active_block.take() {
|
|
493
|
+
events.push(stop_block_sse(previous.index()));
|
|
494
|
+
*next_content_index += 1;
|
|
495
|
+
}
|
|
496
|
+
let index = *next_content_index;
|
|
497
|
+
*active_block = Some(ActiveBlock::Thinking(index));
|
|
498
|
+
events.push(start_block_sse(
|
|
499
|
+
index,
|
|
500
|
+
anthropic::ContentBlockStartData::Thinking {
|
|
501
|
+
thinking: String::new(),
|
|
502
|
+
},
|
|
503
|
+
));
|
|
504
|
+
(index, events)
|
|
505
|
+
}
|
|
506
|
+
}
|
|
507
|
+
}
|
|
508
|
+
|
|
509
|
+
fn transition_to_text(
|
|
510
|
+
active_block: &mut Option<ActiveBlock>,
|
|
511
|
+
next_content_index: &mut usize,
|
|
512
|
+
) -> (usize, Vec<String>) {
|
|
513
|
+
match active_block {
|
|
514
|
+
Some(ActiveBlock::Text(index)) => (*index, Vec::new()),
|
|
515
|
+
_ => {
|
|
516
|
+
let mut events = Vec::new();
|
|
517
|
+
if let Some(previous) = active_block.take() {
|
|
518
|
+
events.push(stop_block_sse(previous.index()));
|
|
519
|
+
*next_content_index += 1;
|
|
520
|
+
}
|
|
521
|
+
let index = *next_content_index;
|
|
522
|
+
*active_block = Some(ActiveBlock::Text(index));
|
|
523
|
+
events.push(start_block_sse(
|
|
524
|
+
index,
|
|
525
|
+
anthropic::ContentBlockStartData::Text {
|
|
526
|
+
text: String::new(),
|
|
527
|
+
},
|
|
528
|
+
));
|
|
529
|
+
(index, events)
|
|
530
|
+
}
|
|
531
|
+
}
|
|
532
|
+
}
|
|
533
|
+
|
|
534
|
+
fn transition_to_tool(
|
|
535
|
+
active_block: &mut Option<ActiveBlock>,
|
|
536
|
+
next_content_index: &mut usize,
|
|
537
|
+
tool_index: usize,
|
|
538
|
+
id: String,
|
|
539
|
+
name: String,
|
|
540
|
+
) -> (usize, Vec<String>) {
|
|
541
|
+
if let Some(ActiveBlock::ToolUse(active_tool_index, index)) = active_block {
|
|
542
|
+
if *active_tool_index == tool_index {
|
|
543
|
+
return (*index, Vec::new());
|
|
544
|
+
}
|
|
545
|
+
}
|
|
546
|
+
|
|
547
|
+
let mut events = Vec::new();
|
|
548
|
+
if let Some(previous) = active_block.take() {
|
|
549
|
+
events.push(stop_block_sse(previous.index()));
|
|
550
|
+
*next_content_index += 1;
|
|
551
|
+
}
|
|
552
|
+
|
|
553
|
+
let index = *next_content_index;
|
|
554
|
+
*active_block = Some(ActiveBlock::ToolUse(tool_index, index));
|
|
555
|
+
events.push(start_block_sse(
|
|
556
|
+
index,
|
|
557
|
+
anthropic::ContentBlockStartData::ToolUse { id, name },
|
|
558
|
+
));
|
|
559
|
+
(index, events)
|
|
560
|
+
}
|
|
561
|
+
|
|
562
|
+
/// Joins the values of all `data` fields in one SSE event block with `\n`.
///
/// As the SSE specification requires, the space after `data:` is optional and only a single
/// leading space is removed. Both `data: {...}` and `data:{...}` are therefore accepted.
/// Returns `None` when the block has no `data` field (e.g. comments or `event:`-only blocks).
fn extract_sse_data(event_block: &str) -> Option<String> {
    let data_lines: Vec<&str> = event_block
        .lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(|value| value.strip_prefix(' ').unwrap_or(value))
        .collect();

    (!data_lines.is_empty()).then(|| data_lines.join("\n"))
}
|
|
574
|
+
|
|
575
|
+
fn sse_event<T: serde::Serialize>(name: &str, payload: &T) -> String {
|
|
576
|
+
format!(
|
|
577
|
+
"event: {name}\ndata: {}\n\n",
|
|
578
|
+
serde_json::to_string(payload).unwrap_or_default()
|
|
579
|
+
)
|
|
580
|
+
}
|
|
581
|
+
|
|
582
|
+
fn start_block_sse(index: usize, content_block: anthropic::ContentBlockStartData) -> String {
|
|
583
|
+
let event = anthropic::StreamEvent::ContentBlockStart {
|
|
584
|
+
index,
|
|
585
|
+
content_block,
|
|
586
|
+
};
|
|
587
|
+
sse_event("content_block_start", &event)
|
|
588
|
+
}
|
|
589
|
+
|
|
590
|
+
fn delta_block_sse(index: usize, delta: anthropic::ContentBlockDeltaData) -> String {
|
|
591
|
+
let event = anthropic::StreamEvent::ContentBlockDelta { index, delta };
|
|
592
|
+
sse_event("content_block_delta", &event)
|
|
593
|
+
}
|
|
594
|
+
|
|
595
|
+
fn stop_block_sse(index: usize) -> String {
|
|
596
|
+
let event = anthropic::StreamEvent::ContentBlockStop { index };
|
|
597
|
+
sse_event("content_block_stop", &event)
|
|
598
|
+
}
|
|
599
|
+
|
|
600
|
+
fn message_stop_sse() -> String {
|
|
601
|
+
let event = anthropic::StreamEvent::MessageStop;
|
|
602
|
+
sse_event("message_stop", &event)
|
|
603
|
+
}
|
|
604
|
+
|
|
605
|
+
fn stream_error_sse(message: &str) -> String {
|
|
606
|
+
let event = json!({
|
|
607
|
+
"type": "error",
|
|
608
|
+
"error": {
|
|
609
|
+
"type": "stream_error",
|
|
610
|
+
"message": message,
|
|
611
|
+
}
|
|
612
|
+
});
|
|
613
|
+
format!(
|
|
614
|
+
"event: error\ndata: {}\n\n",
|
|
615
|
+
serde_json::to_string(&event).unwrap_or_default()
|
|
616
|
+
)
|
|
617
|
+
}
|
|
618
|
+
|
|
619
|
+
fn authorize_request(headers: &HeaderMap, config: &Config) -> ProxyResult<()> {
|
|
620
|
+
let Some(expected) = &config.ingress_api_key else {
|
|
621
|
+
return Ok(());
|
|
622
|
+
};
|
|
623
|
+
|
|
624
|
+
let bearer = headers
|
|
625
|
+
.get(header::AUTHORIZATION)
|
|
626
|
+
.and_then(|v| v.to_str().ok())
|
|
627
|
+
.and_then(|value| value.strip_prefix("Bearer "));
|
|
628
|
+
let x_api_key = headers.get("x-api-key").and_then(|v| v.to_str().ok());
|
|
629
|
+
|
|
630
|
+
if bearer == Some(expected.as_str()) || x_api_key == Some(expected.as_str()) {
|
|
631
|
+
Ok(())
|
|
632
|
+
} else {
|
|
633
|
+
Err(ProxyError::Upstream(
|
|
634
|
+
"401 unauthorized ingress request".to_string(),
|
|
635
|
+
))
|
|
636
|
+
}
|
|
637
|
+
}
|
|
638
|
+
|
|
639
|
+
pub fn build_cors_layer(config: &Config) -> anyhow::Result<Option<CorsLayer>> {
|
|
640
|
+
if config.allow_origins.is_empty() {
|
|
641
|
+
return Ok(None);
|
|
642
|
+
}
|
|
643
|
+
|
|
644
|
+
let origins: Vec<HeaderValue> = config
|
|
645
|
+
.allow_origins
|
|
646
|
+
.iter()
|
|
647
|
+
.map(|origin| HeaderValue::from_str(origin))
|
|
648
|
+
.collect::<Result<_, _>>()?;
|
|
649
|
+
|
|
650
|
+
Ok(Some(
|
|
651
|
+
CorsLayer::new()
|
|
652
|
+
.allow_methods([axum::http::Method::POST, axum::http::Method::GET])
|
|
653
|
+
.allow_headers([
|
|
654
|
+
header::AUTHORIZATION,
|
|
655
|
+
HeaderName::from_static("x-api-key"),
|
|
656
|
+
header::CONTENT_TYPE,
|
|
657
|
+
])
|
|
658
|
+
.allow_origin(AllowOrigin::list(origins)),
|
|
659
|
+
))
|
|
660
|
+
}
|
|
661
|
+
|
|
662
|
+
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{
            header::ACCESS_CONTROL_REQUEST_METHOD, header::ORIGIN, Method, Request, StatusCode,
        },
        routing::get,
        Router,
    };
    use futures::stream;
    use tower::ServiceExt;

    /// `Config` fixture: fixed backend settings plus a caller-chosen ingress key and CORS origins.
    fn test_config(ingress_api_key: Option<&str>, allow_origins: &[&str]) -> Config {
        Config {
            backend_url: "https://example.com".to_string(),
            backend_profile: BackendProfile::OpenaiGeneric,
            model: "model".to_string(),
            reasoning_model: None,
            api_key: None,
            ingress_api_key: ingress_api_key.map(str::to_string),
            allow_origins: allow_origins
                .iter()
                .map(|origin| (*origin).to_string())
                .collect(),
            port: 3000,
        }
    }

    /// Wraps one payload in an SSE `data:` frame, shaped like an upstream body chunk.
    fn sse_frame(payload: &str) -> Result<Bytes, reqwest::Error> {
        Ok(Bytes::from(format!("data: {payload}\n\n")))
    }

    #[tokio::test]
    async fn create_sse_stream_accumulates_fragmented_tool_calls() {
        // Opens tool call 0 with its id, name, and the first argument fragment.
        let first = json!({
            "id": "abc",
            "model": "qwen",
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_1",
                        "function": {
                            "name": "weather",
                            "arguments": "{\"loc"
                        }
                    }]
                },
                "finish_reason": null
            }],
            "usage": null
        })
        .to_string();
        // Continues the same tool call and finishes the message.
        let second = json!({
            "choices": [{
                "index": 0,
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "ation\":\"Rome\"}"
                        }
                    }]
                },
                "finish_reason": "tool_calls"
            }],
            "usage": {
                "completion_tokens": 7
            }
        })
        .to_string();

        let chunks = vec![sse_frame(&first), sse_frame(&second), sse_frame("[DONE]")];
        let joined = create_sse_stream(
            stream::iter(chunks),
            "fallback".to_string(),
            BackendProfile::Chutes,
        )
        .map(|item| String::from_utf8(item.unwrap().to_vec()).unwrap())
        .collect::<Vec<_>>()
        .await
        .concat();

        assert!(joined.contains("\"type\":\"tool_use\""));
        assert!(joined.contains("\"partial_json\":\"{\\\"loc\""));
        assert!(joined.contains("\"partial_json\":\"ation"));
        assert_eq!(joined.matches("event: content_block_start").count(), 1);
    }

    #[test]
    fn extracts_multi_line_sse_data() {
        let block = "event: message\ndata: first\ndata: second\n";
        assert_eq!(extract_sse_data(block).as_deref(), Some("first\nsecond"));
    }

    #[test]
    fn authorize_request_accepts_bearer_and_x_api_key() {
        let config = test_config(Some("secret"), &[]);

        let mut bearer_headers = HeaderMap::new();
        bearer_headers.insert(
            header::AUTHORIZATION,
            HeaderValue::from_static("Bearer secret"),
        );
        assert!(authorize_request(&bearer_headers, &config).is_ok());

        let mut x_api_headers = HeaderMap::new();
        x_api_headers.insert(
            HeaderName::from_static("x-api-key"),
            HeaderValue::from_static("secret"),
        );
        assert!(authorize_request(&x_api_headers, &config).is_ok());
    }

    #[test]
    fn authorize_request_rejects_invalid_ingress_key() {
        let config = test_config(Some("secret"), &[]);

        let err = authorize_request(&HeaderMap::new(), &config).unwrap_err();
        assert!(err.to_string().contains("unauthorized ingress request"));
    }

    #[tokio::test]
    async fn build_cors_layer_allows_configured_origin() {
        let config = test_config(None, &["https://allowed.example"]);
        let cors = build_cors_layer(&config).unwrap().expect("cors layer");
        let app = Router::new()
            .route("/health", get(|| async { StatusCode::OK }))
            .layer(cors);

        let preflight = Request::builder()
            .method(Method::OPTIONS)
            .uri("/health")
            .header(ORIGIN, "https://allowed.example")
            .header(ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .body(Body::empty())
            .unwrap();
        let response = app.oneshot(preflight).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN),
            Some(&HeaderValue::from_static("https://allowed.example"))
        );
    }
}
|