anveesa 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Cargo.lock +1791 -0
- package/Cargo.toml +17 -0
- package/README.md +253 -0
- package/bin/anveesa.js +64 -0
- package/package.json +32 -0
- package/scripts/install.js +55 -0
- package/src/cli.rs +94 -0
- package/src/config.rs +690 -0
- package/src/lib.rs +1540 -0
- package/src/main.rs +4 -0
- package/src/provider/command.rs +301 -0
- package/src/provider/mod.rs +194 -0
- package/src/provider/openai_compatible.rs +939 -0
- package/src/tools.rs +992 -0
|
@@ -0,0 +1,939 @@
|
|
|
1
|
+
use std::time::Duration;
|
|
2
|
+
|
|
3
|
+
use anyhow::{Context, Result, bail};
|
|
4
|
+
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
|
|
5
|
+
use serde_json::{Value, json};
|
|
6
|
+
use tokio::sync::{mpsc::UnboundedSender, oneshot};
|
|
7
|
+
|
|
8
|
+
use crate::{
|
|
9
|
+
config::OpenAiCompatibleProviderConfig,
|
|
10
|
+
provider::{
|
|
11
|
+
ApprovalDecision, ApprovalPolicy, ChatRole, DiffKind, DiffLine, PromptRequest, StreamEvent,
|
|
12
|
+
ToolConfirmPreview, TurnResult, Usage,
|
|
13
|
+
},
|
|
14
|
+
tools,
|
|
15
|
+
};
|
|
16
|
+
|
|
17
|
+
/// Environment variable that overrides the per-answer tool-round budget.
const MAX_TOOL_ROUNDS_ENV: &str = "ANVEESA_MAX_TOOL_ROUNDS";
/// Tool-round budget used when `MAX_TOOL_ROUNDS_ENV` is unset, zero, or unparsable.
const DEFAULT_MAX_TOOL_ROUNDS: usize = 32;
/// Ceiling applied to whatever budget is configured.
const HARD_MAX_TOOL_ROUNDS: usize = 256;

/// Extra attempts `send_with_retry` may make after the first request.
const MAX_RETRIES: usize = 2;
/// Passed to reqwest's `connect_timeout` when the HTTP client is built.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(15_000);

/// How many times the model may call the exact same (tool, arguments) pair before we refuse.
const MAX_IDENTICAL_CALLS: usize = 3;
|
|
24
|
+
|
|
25
|
+
pub async fn ask(
|
|
26
|
+
provider_name: &str,
|
|
27
|
+
config: &OpenAiCompatibleProviderConfig,
|
|
28
|
+
request: PromptRequest,
|
|
29
|
+
policy: ApprovalPolicy,
|
|
30
|
+
events: &UnboundedSender<StreamEvent>,
|
|
31
|
+
) -> Result<TurnResult> {
|
|
32
|
+
let model = request
|
|
33
|
+
.model
|
|
34
|
+
.clone()
|
|
35
|
+
.or_else(|| config.default_model.clone())
|
|
36
|
+
.with_context(|| {
|
|
37
|
+
format!("provider '{provider_name}' requires --model or default_model in config")
|
|
38
|
+
})?;
|
|
39
|
+
|
|
40
|
+
// Use the explicit config flag if set; otherwise auto-enable for Anthropic endpoints.
|
|
41
|
+
let prompt_cache = config
|
|
42
|
+
.prompt_cache
|
|
43
|
+
.unwrap_or_else(|| is_anthropic_url(&config.base_url));
|
|
44
|
+
let headers = build_headers(config, prompt_cache)?;
|
|
45
|
+
let mut messages = build_messages(&request, policy, prompt_cache);
|
|
46
|
+
|
|
47
|
+
let client = reqwest::Client::builder()
|
|
48
|
+
.connect_timeout(CONNECT_TIMEOUT)
|
|
49
|
+
.build()
|
|
50
|
+
.context("failed to build HTTP client")?;
|
|
51
|
+
let url = format!("{}/chat/completions", config.base_url.trim_end_matches('/'));
|
|
52
|
+
|
|
53
|
+
let mut tools_enabled = true;
|
|
54
|
+
let mut usage_requested = true;
|
|
55
|
+
let mut tool_rounds = 0usize;
|
|
56
|
+
let max_tool_rounds = max_tool_rounds();
|
|
57
|
+
let mut approval_state = ToolApprovalState::default();
|
|
58
|
+
let mut full_text = String::new();
|
|
59
|
+
let mut last_usage: Option<Usage> = None;
|
|
60
|
+
|
|
61
|
+
loop {
|
|
62
|
+
let mut body = json!({
|
|
63
|
+
"model": model,
|
|
64
|
+
"messages": messages,
|
|
65
|
+
"stream": true,
|
|
66
|
+
});
|
|
67
|
+
if usage_requested {
|
|
68
|
+
body["stream_options"] = json!({ "include_usage": true });
|
|
69
|
+
}
|
|
70
|
+
if tools_enabled {
|
|
71
|
+
body["tools"] = json!(tools::definitions(policy.allows_write_tools()));
|
|
72
|
+
body["tool_choice"] = json!("auto");
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
let response = send_with_retry(&client, &url, &headers, &body)
|
|
76
|
+
.await
|
|
77
|
+
.with_context(|| format!("request to provider '{provider_name}' failed"))?;
|
|
78
|
+
|
|
79
|
+
let status = response.status();
|
|
80
|
+
if !status.is_success() {
|
|
81
|
+
let response_body = response.text().await.unwrap_or_default();
|
|
82
|
+
if tools_enabled && is_tool_parameter_error(&response_body) {
|
|
83
|
+
tools_enabled = false;
|
|
84
|
+
continue;
|
|
85
|
+
}
|
|
86
|
+
if usage_requested && is_stream_options_error(&response_body) {
|
|
87
|
+
usage_requested = false;
|
|
88
|
+
continue;
|
|
89
|
+
}
|
|
90
|
+
bail!("provider '{provider_name}' returned HTTP {status}: {response_body}");
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
let mut state = StreamState::default();
|
|
94
|
+
stream_response(response, &mut state, events).await?;
|
|
95
|
+
|
|
96
|
+
full_text.push_str(&state.content);
|
|
97
|
+
if let Some(usage) = state.usage {
|
|
98
|
+
last_usage = Some(usage);
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
if state.tool_calls.is_empty() {
|
|
102
|
+
break;
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
if !tools_enabled {
|
|
106
|
+
break;
|
|
107
|
+
}
|
|
108
|
+
tool_rounds += 1;
|
|
109
|
+
|
|
110
|
+
messages.push(assistant_tool_message(&state));
|
|
111
|
+
for call in &state.tool_calls {
|
|
112
|
+
let content = dispatch_tool(call, policy, &mut approval_state, events).await;
|
|
113
|
+
messages.push(json!({
|
|
114
|
+
"role": "tool",
|
|
115
|
+
"tool_call_id": call.id,
|
|
116
|
+
"name": call.name,
|
|
117
|
+
"content": content,
|
|
118
|
+
}));
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
if tool_rounds >= max_tool_rounds {
|
|
122
|
+
tools_enabled = false;
|
|
123
|
+
messages.push(tool_limit_message(max_tool_rounds));
|
|
124
|
+
}
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
if let Some(usage) = last_usage {
|
|
128
|
+
let _ = events.send(StreamEvent::Usage(usage));
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
Ok(TurnResult {
|
|
132
|
+
text: full_text,
|
|
133
|
+
usage: last_usage,
|
|
134
|
+
})
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
/// Per-turn bookkeeping shared by every `dispatch_tool` call made during one `ask`.
#[derive(Debug, Default)]
struct ToolApprovalState {
    /// Becomes `true` once the user answers "allow for turn"; later write tools in
    /// the same turn then skip the approval prompt.
    allow_for_turn: bool,
    /// How often each exact `(tool name, raw argument string)` pair has been
    /// requested this turn — input to the anti-loop guard.
    call_counts: std::collections::hash_map::HashMap<(String, String), usize>,
}
|
|
143
|
+
|
|
144
|
+
async fn dispatch_tool(
|
|
145
|
+
call: &PartialToolCall,
|
|
146
|
+
policy: ApprovalPolicy,
|
|
147
|
+
approval_state: &mut ToolApprovalState,
|
|
148
|
+
events: &UnboundedSender<StreamEvent>,
|
|
149
|
+
) -> String {
|
|
150
|
+
// Plan tools — display only, no approval or filesystem access needed.
|
|
151
|
+
if call.name == "set_plan" {
|
|
152
|
+
if let Ok(args) = serde_json::from_str::<serde_json::Value>(&call.arguments) {
|
|
153
|
+
let tasks = args["steps"]
|
|
154
|
+
.as_array()
|
|
155
|
+
.map(|arr| {
|
|
156
|
+
arr.iter()
|
|
157
|
+
.filter_map(|v| v.as_str().map(str::to_string))
|
|
158
|
+
.collect()
|
|
159
|
+
})
|
|
160
|
+
.unwrap_or_default();
|
|
161
|
+
let _ = events.send(StreamEvent::PlanSet { tasks });
|
|
162
|
+
}
|
|
163
|
+
return json!({"ok": true}).to_string();
|
|
164
|
+
}
|
|
165
|
+
if call.name == "complete_task" {
|
|
166
|
+
if let Ok(args) = serde_json::from_str::<serde_json::Value>(&call.arguments) {
|
|
167
|
+
if let Some(index) = args["index"].as_u64() {
|
|
168
|
+
let _ = events.send(StreamEvent::PlanTaskDone {
|
|
169
|
+
index: index as usize,
|
|
170
|
+
});
|
|
171
|
+
}
|
|
172
|
+
}
|
|
173
|
+
return json!({"ok": true}).to_string();
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
// Anti-loop guard: refuse if the model is calling the exact same (tool, args) repeatedly.
|
|
177
|
+
{
|
|
178
|
+
let key = (call.name.clone(), call.arguments.clone());
|
|
179
|
+
let count = approval_state.call_counts.entry(key).or_insert(0);
|
|
180
|
+
*count += 1;
|
|
181
|
+
if *count > MAX_IDENTICAL_CALLS {
|
|
182
|
+
return json!({
|
|
183
|
+
"ok": false,
|
|
184
|
+
"error": format!(
|
|
185
|
+
"Refusing to run '{}' again: this identical call has already been made {} time(s) \
|
|
186
|
+
this turn. Do NOT retry — stop and report the failure to the user.",
|
|
187
|
+
call.name, *count - 1
|
|
188
|
+
)
|
|
189
|
+
})
|
|
190
|
+
.to_string();
|
|
191
|
+
}
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
if tools::is_write_tool(&call.name) {
|
|
195
|
+
if !policy.allows_write_tools() {
|
|
196
|
+
return denied_message("write tools are disabled (pass --yes or run interactively)");
|
|
197
|
+
}
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
// Snapshot BEFORE the tool runs — needed both for preview and for post-run diff.
|
|
201
|
+
let file_op_snapshot = capture_file_op_snapshot(&call.name, &call.arguments);
|
|
202
|
+
|
|
203
|
+
let mut preview_was_shown = false;
|
|
204
|
+
|
|
205
|
+
if tools::is_write_tool(&call.name)
|
|
206
|
+
&& policy == ApprovalPolicy::Prompt
|
|
207
|
+
&& !approval_state.allow_for_turn
|
|
208
|
+
{
|
|
209
|
+
let preview = build_confirm_preview(&call.name, &call.arguments, &file_op_snapshot);
|
|
210
|
+
preview_was_shown = true;
|
|
211
|
+
match request_approval_with_preview(preview, events).await {
|
|
212
|
+
ApprovalDecision::AllowOnce => {}
|
|
213
|
+
ApprovalDecision::AllowForTurn => approval_state.allow_for_turn = true,
|
|
214
|
+
ApprovalDecision::Deny => return denied_message("user declined this action"),
|
|
215
|
+
}
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
let result = tools::run(&call.name, &call.arguments).await;
|
|
219
|
+
|
|
220
|
+
// When the user already reviewed the diff in the approval preview, skip the
|
|
221
|
+
// post-run FileOp so the same diff isn't printed twice.
|
|
222
|
+
if !preview_was_shown {
|
|
223
|
+
if let Some(snapshot) = file_op_snapshot {
|
|
224
|
+
if let Ok(result_json) = serde_json::from_str::<serde_json::Value>(&result) {
|
|
225
|
+
if result_json["ok"].as_bool().unwrap_or(false) {
|
|
226
|
+
emit_file_op_event(snapshot, &result_json, events);
|
|
227
|
+
}
|
|
228
|
+
}
|
|
229
|
+
}
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
result
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
// ── File-op diff helpers ──────────────────────────────────────────────────────
|
|
236
|
+
|
|
237
|
+
/// State captured *before* a file-mutating tool runs, so the change can be rendered
/// as a diff (in the approval preview, or after the tool succeeds).
enum FileOpSnapshot {
    /// `write_file`: a prefix of the new content plus its full line count.
    Write {
        path: String,
        /// Leading lines of the content being written (capped by the capturer).
        lines: Vec<String>,
        /// Line count of the complete content being written.
        total: usize,
    },
    /// `edit_file`: replaced text, replacement text, and where the match starts.
    Edit {
        path: String,
        /// 1-based line of the match in the current file (1 when it can't be located).
        start_line: usize,
        old_lines: Vec<String>,
        new_lines: Vec<String>,
    },
    /// `create_dir`: only the directory path.
    CreateDir { path: String },
}
|
|
253
|
+
|
|
254
|
+
fn capture_file_op_snapshot(tool_name: &str, arguments: &str) -> Option<FileOpSnapshot> {
|
|
255
|
+
let args: serde_json::Value = serde_json::from_str(arguments).ok()?;
|
|
256
|
+
match tool_name {
|
|
257
|
+
"write_file" => {
|
|
258
|
+
let path = args["path"].as_str()?.to_string();
|
|
259
|
+
let content = args["content"].as_str().unwrap_or("");
|
|
260
|
+
let all: Vec<String> = content.lines().map(str::to_string).collect();
|
|
261
|
+
let total = all.len();
|
|
262
|
+
Some(FileOpSnapshot::Write {
|
|
263
|
+
path,
|
|
264
|
+
lines: all.into_iter().take(20).collect(),
|
|
265
|
+
total,
|
|
266
|
+
})
|
|
267
|
+
}
|
|
268
|
+
"edit_file" => {
|
|
269
|
+
let path = args["path"].as_str()?.to_string();
|
|
270
|
+
let old = args["old_string"].as_str().unwrap_or("");
|
|
271
|
+
let new = args["new_string"].as_str().unwrap_or("");
|
|
272
|
+
let start_line = std::fs::read_to_string(&path)
|
|
273
|
+
.ok()
|
|
274
|
+
.and_then(|content| {
|
|
275
|
+
let pos = content.find(old)?;
|
|
276
|
+
Some(content[..pos].lines().count() + 1)
|
|
277
|
+
})
|
|
278
|
+
.unwrap_or(1);
|
|
279
|
+
Some(FileOpSnapshot::Edit {
|
|
280
|
+
path,
|
|
281
|
+
start_line,
|
|
282
|
+
old_lines: old.lines().map(str::to_string).collect(),
|
|
283
|
+
new_lines: new.lines().map(str::to_string).collect(),
|
|
284
|
+
})
|
|
285
|
+
}
|
|
286
|
+
"create_dir" => {
|
|
287
|
+
let path = args["path"].as_str()?.to_string();
|
|
288
|
+
Some(FileOpSnapshot::CreateDir { path })
|
|
289
|
+
}
|
|
290
|
+
_ => None,
|
|
291
|
+
}
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
fn emit_file_op_event(
|
|
295
|
+
snapshot: FileOpSnapshot,
|
|
296
|
+
result: &serde_json::Value,
|
|
297
|
+
events: &tokio::sync::mpsc::UnboundedSender<StreamEvent>,
|
|
298
|
+
) {
|
|
299
|
+
const MAX_PREVIEW: usize = 20;
|
|
300
|
+
|
|
301
|
+
let event = match snapshot {
|
|
302
|
+
FileOpSnapshot::Write { path, lines, total } => {
|
|
303
|
+
let verb = if result["created"].as_bool().unwrap_or(true) {
|
|
304
|
+
"Create"
|
|
305
|
+
} else {
|
|
306
|
+
"Update"
|
|
307
|
+
}
|
|
308
|
+
.to_string();
|
|
309
|
+
let truncated = total > MAX_PREVIEW;
|
|
310
|
+
let preview = lines
|
|
311
|
+
.into_iter()
|
|
312
|
+
.enumerate()
|
|
313
|
+
.map(|(i, text)| DiffLine {
|
|
314
|
+
kind: DiffKind::Add,
|
|
315
|
+
line_no: i + 1,
|
|
316
|
+
text,
|
|
317
|
+
})
|
|
318
|
+
.collect();
|
|
319
|
+
StreamEvent::FileOp {
|
|
320
|
+
verb,
|
|
321
|
+
path,
|
|
322
|
+
added: total,
|
|
323
|
+
removed: 0,
|
|
324
|
+
preview,
|
|
325
|
+
truncated,
|
|
326
|
+
}
|
|
327
|
+
}
|
|
328
|
+
FileOpSnapshot::Edit {
|
|
329
|
+
path,
|
|
330
|
+
start_line,
|
|
331
|
+
old_lines,
|
|
332
|
+
new_lines,
|
|
333
|
+
} => {
|
|
334
|
+
let added = new_lines.len();
|
|
335
|
+
let removed = old_lines.len();
|
|
336
|
+
// Cap: show at most MAX_PREVIEW removed + MAX_PREVIEW added
|
|
337
|
+
let cap = MAX_PREVIEW;
|
|
338
|
+
let truncated = old_lines.len() > cap || new_lines.len() > cap;
|
|
339
|
+
let mut preview: Vec<DiffLine> = old_lines
|
|
340
|
+
.into_iter()
|
|
341
|
+
.take(cap)
|
|
342
|
+
.enumerate()
|
|
343
|
+
.map(|(i, text)| DiffLine {
|
|
344
|
+
kind: DiffKind::Remove,
|
|
345
|
+
line_no: start_line + i,
|
|
346
|
+
text,
|
|
347
|
+
})
|
|
348
|
+
.collect();
|
|
349
|
+
for (i, text) in new_lines.into_iter().take(cap).enumerate() {
|
|
350
|
+
preview.push(DiffLine {
|
|
351
|
+
kind: DiffKind::Add,
|
|
352
|
+
line_no: start_line + i,
|
|
353
|
+
text,
|
|
354
|
+
});
|
|
355
|
+
}
|
|
356
|
+
StreamEvent::FileOp {
|
|
357
|
+
verb: "Update".to_string(),
|
|
358
|
+
path,
|
|
359
|
+
added,
|
|
360
|
+
removed,
|
|
361
|
+
preview,
|
|
362
|
+
truncated,
|
|
363
|
+
}
|
|
364
|
+
}
|
|
365
|
+
FileOpSnapshot::CreateDir { path } => StreamEvent::FileOp {
|
|
366
|
+
verb: "Create dir".to_string(),
|
|
367
|
+
path,
|
|
368
|
+
added: 0,
|
|
369
|
+
removed: 0,
|
|
370
|
+
preview: vec![],
|
|
371
|
+
truncated: false,
|
|
372
|
+
},
|
|
373
|
+
};
|
|
374
|
+
|
|
375
|
+
let _ = events.send(event);
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
fn max_tool_rounds() -> usize {
|
|
379
|
+
parse_tool_round_limit(std::env::var(MAX_TOOL_ROUNDS_ENV).ok().as_deref())
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
fn parse_tool_round_limit(value: Option<&str>) -> usize {
|
|
383
|
+
value
|
|
384
|
+
.and_then(|value| value.trim().parse::<usize>().ok())
|
|
385
|
+
.filter(|value| *value > 0)
|
|
386
|
+
.unwrap_or(DEFAULT_MAX_TOOL_ROUNDS)
|
|
387
|
+
.min(HARD_MAX_TOOL_ROUNDS)
|
|
388
|
+
}
|
|
389
|
+
|
|
390
|
+
fn tool_limit_message(max_tool_rounds: usize) -> Value {
|
|
391
|
+
json!({
|
|
392
|
+
"role": "system",
|
|
393
|
+
"content": format!(
|
|
394
|
+
"Anveesa has already run {max_tool_rounds} tool rounds for this answer. Do not call tools again. Use the tool results already provided to produce the best final answer. If the requested work is not complete, say exactly what remains."
|
|
395
|
+
)
|
|
396
|
+
})
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
fn denied_message(reason: &str) -> String {
|
|
400
|
+
json!({ "ok": false, "error": reason }).to_string()
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
fn build_confirm_preview(
|
|
404
|
+
tool_name: &str,
|
|
405
|
+
arguments: &str,
|
|
406
|
+
snapshot: &Option<FileOpSnapshot>,
|
|
407
|
+
) -> ToolConfirmPreview {
|
|
408
|
+
const CAP: usize = 20;
|
|
409
|
+
match snapshot {
|
|
410
|
+
Some(FileOpSnapshot::Write { path, lines, total }) => {
|
|
411
|
+
let verb = if std::path::Path::new(path).exists() {
|
|
412
|
+
"Update"
|
|
413
|
+
} else {
|
|
414
|
+
"Create"
|
|
415
|
+
};
|
|
416
|
+
let truncated = *total > CAP;
|
|
417
|
+
let diff = lines
|
|
418
|
+
.iter()
|
|
419
|
+
.enumerate()
|
|
420
|
+
.map(|(i, text)| DiffLine {
|
|
421
|
+
kind: DiffKind::Add,
|
|
422
|
+
line_no: i + 1,
|
|
423
|
+
text: text.clone(),
|
|
424
|
+
})
|
|
425
|
+
.collect();
|
|
426
|
+
ToolConfirmPreview::FileOp {
|
|
427
|
+
verb: verb.to_string(),
|
|
428
|
+
path: path.clone(),
|
|
429
|
+
added: *total,
|
|
430
|
+
removed: 0,
|
|
431
|
+
diff,
|
|
432
|
+
truncated,
|
|
433
|
+
}
|
|
434
|
+
}
|
|
435
|
+
Some(FileOpSnapshot::Edit {
|
|
436
|
+
path,
|
|
437
|
+
start_line,
|
|
438
|
+
old_lines,
|
|
439
|
+
new_lines,
|
|
440
|
+
}) => {
|
|
441
|
+
let truncated = old_lines.len() > CAP || new_lines.len() > CAP;
|
|
442
|
+
let mut diff: Vec<DiffLine> = old_lines
|
|
443
|
+
.iter()
|
|
444
|
+
.take(CAP)
|
|
445
|
+
.enumerate()
|
|
446
|
+
.map(|(i, text)| DiffLine {
|
|
447
|
+
kind: DiffKind::Remove,
|
|
448
|
+
line_no: start_line + i,
|
|
449
|
+
text: text.clone(),
|
|
450
|
+
})
|
|
451
|
+
.collect();
|
|
452
|
+
for (i, text) in new_lines.iter().take(CAP).enumerate() {
|
|
453
|
+
diff.push(DiffLine {
|
|
454
|
+
kind: DiffKind::Add,
|
|
455
|
+
line_no: start_line + i,
|
|
456
|
+
text: text.clone(),
|
|
457
|
+
});
|
|
458
|
+
}
|
|
459
|
+
ToolConfirmPreview::FileOp {
|
|
460
|
+
verb: "Update".to_string(),
|
|
461
|
+
path: path.clone(),
|
|
462
|
+
added: new_lines.len(),
|
|
463
|
+
removed: old_lines.len(),
|
|
464
|
+
diff,
|
|
465
|
+
truncated,
|
|
466
|
+
}
|
|
467
|
+
}
|
|
468
|
+
Some(FileOpSnapshot::CreateDir { path }) => {
|
|
469
|
+
ToolConfirmPreview::CreateDir { path: path.clone() }
|
|
470
|
+
}
|
|
471
|
+
None => ToolConfirmPreview::Generic {
|
|
472
|
+
summary: tools::describe_call(tool_name, arguments),
|
|
473
|
+
},
|
|
474
|
+
}
|
|
475
|
+
}
|
|
476
|
+
|
|
477
|
+
async fn request_approval_with_preview(
|
|
478
|
+
preview: ToolConfirmPreview,
|
|
479
|
+
events: &UnboundedSender<StreamEvent>,
|
|
480
|
+
) -> ApprovalDecision {
|
|
481
|
+
let (reply, answer) = oneshot::channel();
|
|
482
|
+
if events
|
|
483
|
+
.send(StreamEvent::Confirm { preview, reply })
|
|
484
|
+
.is_err()
|
|
485
|
+
{
|
|
486
|
+
return ApprovalDecision::Deny;
|
|
487
|
+
}
|
|
488
|
+
answer.await.unwrap_or(ApprovalDecision::Deny)
|
|
489
|
+
}
|
|
490
|
+
|
|
491
|
+
/// Heuristic: `true` when `anthropic.com` appears anywhere in `base_url`.
fn is_anthropic_url(base_url: &str) -> bool {
    const ANTHROPIC_DOMAIN: &str = "anthropic.com";
    base_url.find(ANTHROPIC_DOMAIN).is_some()
}
|
|
494
|
+
|
|
495
|
+
fn build_headers(config: &OpenAiCompatibleProviderConfig, prompt_cache: bool) -> Result<HeaderMap> {
|
|
496
|
+
let mut headers = HeaderMap::new();
|
|
497
|
+
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
|
|
498
|
+
|
|
499
|
+
if let Some(api_key_env) = &config.api_key_env {
|
|
500
|
+
let api_key = std::env::var(api_key_env)
|
|
501
|
+
.with_context(|| format!("environment variable {api_key_env} is required"))?;
|
|
502
|
+
headers.insert(
|
|
503
|
+
AUTHORIZATION,
|
|
504
|
+
HeaderValue::from_str(&format!("Bearer {api_key}"))
|
|
505
|
+
.context("failed to build authorization header")?,
|
|
506
|
+
);
|
|
507
|
+
}
|
|
508
|
+
|
|
509
|
+
for (name, value) in &config.headers {
|
|
510
|
+
headers.insert(
|
|
511
|
+
HeaderName::from_bytes(name.as_bytes())
|
|
512
|
+
.with_context(|| format!("invalid header name '{name}'"))?,
|
|
513
|
+
HeaderValue::from_str(value)
|
|
514
|
+
.with_context(|| format!("invalid header value for '{name}'"))?,
|
|
515
|
+
);
|
|
516
|
+
}
|
|
517
|
+
|
|
518
|
+
if prompt_cache {
|
|
519
|
+
headers.insert(
|
|
520
|
+
HeaderName::from_static("anthropic-beta"),
|
|
521
|
+
HeaderValue::from_static("prompt-caching-2024-07-31"),
|
|
522
|
+
);
|
|
523
|
+
}
|
|
524
|
+
|
|
525
|
+
Ok(headers)
|
|
526
|
+
}
|
|
527
|
+
|
|
528
|
+
fn build_messages(
|
|
529
|
+
request: &PromptRequest,
|
|
530
|
+
policy: ApprovalPolicy,
|
|
531
|
+
prompt_cache: bool,
|
|
532
|
+
) -> Vec<Value> {
|
|
533
|
+
let mut messages = Vec::new();
|
|
534
|
+
if let Some(system) = &request.system {
|
|
535
|
+
messages.push(json!({ "role": "system", "content": system }));
|
|
536
|
+
}
|
|
537
|
+
if let Some(workspace_context) = &request.workspace_context {
|
|
538
|
+
messages.push(json!({ "role": "system", "content": workspace_context }));
|
|
539
|
+
}
|
|
540
|
+
messages
|
|
541
|
+
.push(json!({ "role": "system", "content": tools::guidance(policy.allows_write_tools()) }));
|
|
542
|
+
for message in &request.history {
|
|
543
|
+
let role = match message.role {
|
|
544
|
+
ChatRole::User => "user",
|
|
545
|
+
ChatRole::Assistant => "assistant",
|
|
546
|
+
};
|
|
547
|
+
messages.push(json!({ "role": role, "content": message.content }));
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
// Current user turn — multimodal when a clipboard image is attached.
|
|
551
|
+
let user_content = match &request.image {
|
|
552
|
+
Some(img) => json!([
|
|
553
|
+
{ "type": "text", "text": &request.prompt },
|
|
554
|
+
{ "type": "image_url", "image_url": { "url": format!("data:{};base64,{}", img.mime, img.data) } }
|
|
555
|
+
]),
|
|
556
|
+
None => json!(&request.prompt),
|
|
557
|
+
};
|
|
558
|
+
messages.push(json!({ "role": "user", "content": user_content }));
|
|
559
|
+
|
|
560
|
+
if prompt_cache {
|
|
561
|
+
apply_cache_breakpoints(&mut messages);
|
|
562
|
+
}
|
|
563
|
+
|
|
564
|
+
messages
|
|
565
|
+
}
|
|
566
|
+
|
|
567
|
+
/// Add `cache_control: {type: "ephemeral"}` to two breakpoints in the message list:
|
|
568
|
+
/// 1. The last system message (tools guidance — large, fully static).
|
|
569
|
+
/// 2. The last history message before the current user turn (grows each turn).
|
|
570
|
+
///
|
|
571
|
+
/// Everything up to each breakpoint is served from cache on subsequent turns.
|
|
572
|
+
fn apply_cache_breakpoints(messages: &mut Vec<Value>) {
|
|
573
|
+
let current_turn_idx = messages.len() - 1;
|
|
574
|
+
|
|
575
|
+
// Breakpoint 1: last system message
|
|
576
|
+
if let Some(idx) = messages[..current_turn_idx]
|
|
577
|
+
.iter()
|
|
578
|
+
.rposition(|m| m["role"] == "system")
|
|
579
|
+
{
|
|
580
|
+
add_cache_control(&mut messages[idx]);
|
|
581
|
+
}
|
|
582
|
+
|
|
583
|
+
// Breakpoint 2: last history message (user or assistant before the current turn)
|
|
584
|
+
if let Some(idx) = messages[..current_turn_idx]
|
|
585
|
+
.iter()
|
|
586
|
+
.rposition(|m| m["role"] != "system")
|
|
587
|
+
{
|
|
588
|
+
add_cache_control(&mut messages[idx]);
|
|
589
|
+
}
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
/// Convert a message's `content` to the array form required by Anthropic caching,
|
|
593
|
+
/// then inject `cache_control: {type: "ephemeral"}` on the last content block.
|
|
594
|
+
fn add_cache_control(message: &mut Value) {
|
|
595
|
+
let content = match message.get("content").cloned() {
|
|
596
|
+
Some(c) => c,
|
|
597
|
+
None => return,
|
|
598
|
+
};
|
|
599
|
+
|
|
600
|
+
let cached = match content {
|
|
601
|
+
Value::String(s) => json!([{
|
|
602
|
+
"type": "text",
|
|
603
|
+
"text": s,
|
|
604
|
+
"cache_control": { "type": "ephemeral" }
|
|
605
|
+
}]),
|
|
606
|
+
Value::Array(mut arr) => {
|
|
607
|
+
if let Some(last) = arr.last_mut().and_then(Value::as_object_mut) {
|
|
608
|
+
last.insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
|
|
609
|
+
}
|
|
610
|
+
Value::Array(arr)
|
|
611
|
+
}
|
|
612
|
+
other => other,
|
|
613
|
+
};
|
|
614
|
+
|
|
615
|
+
message["content"] = cached;
|
|
616
|
+
}
|
|
617
|
+
|
|
618
|
+
async fn send_with_retry(
|
|
619
|
+
client: &reqwest::Client,
|
|
620
|
+
url: &str,
|
|
621
|
+
headers: &HeaderMap,
|
|
622
|
+
body: &Value,
|
|
623
|
+
) -> Result<reqwest::Response> {
|
|
624
|
+
let mut attempt = 0usize;
|
|
625
|
+
loop {
|
|
626
|
+
match client
|
|
627
|
+
.post(url)
|
|
628
|
+
.headers(headers.clone())
|
|
629
|
+
.json(body)
|
|
630
|
+
.send()
|
|
631
|
+
.await
|
|
632
|
+
{
|
|
633
|
+
Ok(response) => {
|
|
634
|
+
if response.status().is_server_error() && attempt < MAX_RETRIES {
|
|
635
|
+
attempt += 1;
|
|
636
|
+
backoff(attempt).await;
|
|
637
|
+
continue;
|
|
638
|
+
}
|
|
639
|
+
return Ok(response);
|
|
640
|
+
}
|
|
641
|
+
Err(error) => {
|
|
642
|
+
let retryable = error.is_connect() || error.is_timeout();
|
|
643
|
+
if retryable && attempt < MAX_RETRIES {
|
|
644
|
+
attempt += 1;
|
|
645
|
+
backoff(attempt).await;
|
|
646
|
+
continue;
|
|
647
|
+
}
|
|
648
|
+
return Err(error.into());
|
|
649
|
+
}
|
|
650
|
+
}
|
|
651
|
+
}
|
|
652
|
+
}
|
|
653
|
+
|
|
654
|
+
async fn backoff(attempt: usize) {
|
|
655
|
+
let millis = 250u64 * (1u64 << (attempt - 1));
|
|
656
|
+
tokio::time::sleep(Duration::from_millis(millis)).await;
|
|
657
|
+
}
|
|
658
|
+
|
|
659
|
+
async fn stream_response(
|
|
660
|
+
mut response: reqwest::Response,
|
|
661
|
+
state: &mut StreamState,
|
|
662
|
+
events: &UnboundedSender<StreamEvent>,
|
|
663
|
+
) -> Result<()> {
|
|
664
|
+
let mut buffer = String::new();
|
|
665
|
+
|
|
666
|
+
while let Some(chunk) = response
|
|
667
|
+
.chunk()
|
|
668
|
+
.await
|
|
669
|
+
.context("failed to read streamed response chunk")?
|
|
670
|
+
{
|
|
671
|
+
buffer.push_str(&String::from_utf8_lossy(&chunk));
|
|
672
|
+
|
|
673
|
+
while let Some(newline) = buffer.find('\n') {
|
|
674
|
+
let line: String = buffer.drain(..=newline).collect();
|
|
675
|
+
if let Some(token) = state.ingest_line(line.trim_end_matches(['\r', '\n'])) {
|
|
676
|
+
let _ = events.send(StreamEvent::Token(token));
|
|
677
|
+
}
|
|
678
|
+
}
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
if !buffer.is_empty()
|
|
682
|
+
&& let Some(token) = state.ingest_line(buffer.trim())
|
|
683
|
+
{
|
|
684
|
+
let _ = events.send(StreamEvent::Token(token));
|
|
685
|
+
}
|
|
686
|
+
|
|
687
|
+
Ok(())
|
|
688
|
+
}
|
|
689
|
+
|
|
690
|
+
#[derive(Debug, Default)]
|
|
691
|
+
struct StreamState {
|
|
692
|
+
content: String,
|
|
693
|
+
tool_calls: Vec<PartialToolCall>,
|
|
694
|
+
usage: Option<Usage>,
|
|
695
|
+
done: bool,
|
|
696
|
+
}
|
|
697
|
+
|
|
698
|
+
/// A tool call being reassembled from streamed fragments.
#[derive(Debug, Default, Clone)]
struct PartialToolCall {
    /// Provider-assigned call id (the last non-empty id received).
    id: String,
    /// Function name; streamed fragments are appended in arrival order.
    name: String,
    /// Raw JSON argument text; streamed fragments are appended in arrival order.
    arguments: String,
}
|
|
704
|
+
|
|
705
|
+
impl StreamState {
|
|
706
|
+
/// Process a single SSE line, returning any new assistant text to display.
|
|
707
|
+
fn ingest_line(&mut self, line: &str) -> Option<String> {
|
|
708
|
+
if self.done {
|
|
709
|
+
return None;
|
|
710
|
+
}
|
|
711
|
+
let data = line.strip_prefix("data:")?.trim();
|
|
712
|
+
if data.is_empty() {
|
|
713
|
+
return None;
|
|
714
|
+
}
|
|
715
|
+
if data == "[DONE]" {
|
|
716
|
+
self.done = true;
|
|
717
|
+
return None;
|
|
718
|
+
}
|
|
719
|
+
let chunk: Value = serde_json::from_str(data).ok()?;
|
|
720
|
+
self.apply_chunk(&chunk)
|
|
721
|
+
}
|
|
722
|
+
|
|
723
|
+
fn apply_chunk(&mut self, chunk: &Value) -> Option<String> {
|
|
724
|
+
if let Some(usage) = chunk.get("usage").filter(|value| value.is_object()) {
|
|
725
|
+
self.usage = parse_usage(usage);
|
|
726
|
+
}
|
|
727
|
+
|
|
728
|
+
let delta = chunk.get("choices")?.get(0)?.get("delta")?;
|
|
729
|
+
|
|
730
|
+
if let Some(tool_calls) = delta.get("tool_calls").and_then(Value::as_array) {
|
|
731
|
+
for call in tool_calls {
|
|
732
|
+
self.apply_tool_call_delta(call);
|
|
733
|
+
}
|
|
734
|
+
}
|
|
735
|
+
|
|
736
|
+
let text = delta.get("content").and_then(Value::as_str)?;
|
|
737
|
+
if text.is_empty() {
|
|
738
|
+
return None;
|
|
739
|
+
}
|
|
740
|
+
self.content.push_str(text);
|
|
741
|
+
Some(text.to_string())
|
|
742
|
+
}
|
|
743
|
+
|
|
744
|
+
fn apply_tool_call_delta(&mut self, call: &Value) {
|
|
745
|
+
let index = call.get("index").and_then(Value::as_u64).unwrap_or(0) as usize;
|
|
746
|
+
while self.tool_calls.len() <= index {
|
|
747
|
+
self.tool_calls.push(PartialToolCall::default());
|
|
748
|
+
}
|
|
749
|
+
let slot = &mut self.tool_calls[index];
|
|
750
|
+
|
|
751
|
+
if let Some(id) = call.get("id").and_then(Value::as_str)
|
|
752
|
+
&& !id.is_empty()
|
|
753
|
+
{
|
|
754
|
+
slot.id = id.to_string();
|
|
755
|
+
}
|
|
756
|
+
if let Some(function) = call.get("function") {
|
|
757
|
+
if let Some(name) = function.get("name").and_then(Value::as_str) {
|
|
758
|
+
slot.name.push_str(name);
|
|
759
|
+
}
|
|
760
|
+
if let Some(arguments) = function.get("arguments").and_then(Value::as_str) {
|
|
761
|
+
slot.arguments.push_str(arguments);
|
|
762
|
+
}
|
|
763
|
+
}
|
|
764
|
+
}
|
|
765
|
+
}
|
|
766
|
+
|
|
767
|
+
fn parse_usage(value: &Value) -> Option<Usage> {
|
|
768
|
+
let object = value.as_object()?;
|
|
769
|
+
|
|
770
|
+
// Anthropic uses top-level cache fields; OpenAI nests them under prompt_tokens_details.
|
|
771
|
+
let details = object.get("prompt_tokens_details");
|
|
772
|
+
let cache_read_tokens = object
|
|
773
|
+
.get("cache_read_input_tokens")
|
|
774
|
+
.and_then(Value::as_u64)
|
|
775
|
+
.or_else(|| details?.get("cached_tokens").and_then(Value::as_u64))
|
|
776
|
+
.unwrap_or(0);
|
|
777
|
+
let cache_write_tokens = object
|
|
778
|
+
.get("cache_creation_input_tokens")
|
|
779
|
+
.and_then(Value::as_u64)
|
|
780
|
+
.unwrap_or(0);
|
|
781
|
+
|
|
782
|
+
Some(Usage {
|
|
783
|
+
prompt_tokens: object
|
|
784
|
+
.get("prompt_tokens")
|
|
785
|
+
.and_then(Value::as_u64)
|
|
786
|
+
.unwrap_or(0),
|
|
787
|
+
completion_tokens: object
|
|
788
|
+
.get("completion_tokens")
|
|
789
|
+
.and_then(Value::as_u64)
|
|
790
|
+
.unwrap_or(0),
|
|
791
|
+
total_tokens: object
|
|
792
|
+
.get("total_tokens")
|
|
793
|
+
.and_then(Value::as_u64)
|
|
794
|
+
.unwrap_or(0),
|
|
795
|
+
cache_read_tokens,
|
|
796
|
+
cache_write_tokens,
|
|
797
|
+
})
|
|
798
|
+
}
|
|
799
|
+
|
|
800
|
+
fn assistant_tool_message(state: &StreamState) -> Value {
|
|
801
|
+
let tool_calls = state
|
|
802
|
+
.tool_calls
|
|
803
|
+
.iter()
|
|
804
|
+
.filter(|call| !call.name.is_empty())
|
|
805
|
+
.map(|call| {
|
|
806
|
+
json!({
|
|
807
|
+
"id": call.id,
|
|
808
|
+
"type": "function",
|
|
809
|
+
"function": {
|
|
810
|
+
"name": call.name,
|
|
811
|
+
"arguments": call.arguments,
|
|
812
|
+
}
|
|
813
|
+
})
|
|
814
|
+
})
|
|
815
|
+
.collect::<Vec<_>>();
|
|
816
|
+
|
|
817
|
+
json!({
|
|
818
|
+
"role": "assistant",
|
|
819
|
+
"content": state.content,
|
|
820
|
+
"tool_calls": tool_calls,
|
|
821
|
+
})
|
|
822
|
+
}
|
|
823
|
+
|
|
824
|
+
/// Heuristically detect a provider error saying the `tools` / `tool_choice`
/// parameters are unsupported, so the request can be retried without tools.
///
/// Matches bodies mentioning "tool" or "function call" (case-insensitive), except
/// errors about the tool-calling *conversation* itself. Those include an unanswered
/// or unknown `tool_call_id`, a misplaced `tool`-role message, or missing
/// `tool_result` blocks. Retrying without tools cannot fix them, because the
/// offending messages are still sent. Treating them as "tools unsupported" would
/// silently disable tools for the rest of the turn.
fn is_tool_parameter_error(body: &str) -> bool {
    const CONVERSATION_ERROR_MARKERS: [&str; 4] =
        ["tool_call_id", "role 'tool'", "role \"tool\"", "tool_result"];

    let lower = body.to_lowercase();
    let mentions_tools = lower.contains("tool") || lower.contains("function call");
    mentions_tools
        && !CONVERSATION_ERROR_MARKERS
            .iter()
            .any(|marker| lower.contains(marker))
}
|
|
828
|
+
|
|
829
|
+
/// Heuristically detect a provider error rejecting `stream_options` /
/// `include_usage`, so the request can be retried without asking for usage.
fn is_stream_options_error(body: &str) -> bool {
    let lower = body.to_lowercase();
    ["stream_options", "include_usage"]
        .into_iter()
        .any(|needle| lower.contains(needle))
}
|
|
833
|
+
|
|
834
|
+
#[cfg(test)]
mod tests {
    use super::*;

    /// A streamed chunk whose first choice carries `content` as its delta text.
    fn chunk(content: &str) -> Value {
        json!({ "choices": [{ "delta": { "content": content } }] })
    }

    /// A streamed chunk carrying exactly one tool-call fragment.
    fn tool_call_chunk(fragment: Value) -> Value {
        json!({ "choices": [{ "delta": { "tool_calls": [fragment] } }] })
    }

    #[test]
    fn accumulates_content_tokens() {
        let mut state = StreamState::default();
        for piece in ["Hel", "lo"] {
            assert_eq!(state.apply_chunk(&chunk(piece)), Some(piece.to_owned()));
        }
        assert_eq!(state.content, "Hello");
        assert!(state.tool_calls.is_empty());
    }

    #[test]
    fn ignores_empty_content_delta() {
        let mut state = StreamState::default();
        assert_eq!(state.apply_chunk(&chunk("")), None);
        let without_content = json!({ "choices": [{ "delta": {} }] });
        assert_eq!(state.apply_chunk(&without_content), None);
    }

    #[test]
    fn merges_fragmented_tool_call_deltas() {
        let mut state = StreamState::default();
        state.apply_chunk(&tool_call_chunk(json!({
            "index": 0, "id": "call_1", "function": { "name": "read_file", "arguments": "{\"pa" }
        })));
        state.apply_chunk(&tool_call_chunk(json!({
            "index": 0, "function": { "arguments": "th\":\"x\"}" }
        })));

        let [call] = state.tool_calls.as_slice() else {
            panic!("expected exactly one tool call, got {}", state.tool_calls.len());
        };
        assert_eq!(call.id, "call_1");
        assert_eq!(call.name, "read_file");
        assert_eq!(call.arguments, "{\"path\":\"x\"}");
    }

    #[test]
    fn parses_usage_chunk() {
        let mut state = StreamState::default();
        state.apply_chunk(&json!({
            "choices": [],
            "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
        }));
        let usage = state.usage.expect("usage parsed");
        assert_eq!(
            (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens),
            (10, 5, 15)
        );
    }

    #[test]
    fn ingest_line_handles_sse_framing() {
        let mut state = StreamState::default();
        let greeting = "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}";
        let after_done = "data: {\"choices\":[{\"delta\":{\"content\":\"x\"}}]}";

        assert_eq!(state.ingest_line(greeting), Some("hi".to_string()));
        assert_eq!(state.ingest_line(""), None);
        assert_eq!(state.ingest_line("data: [DONE]"), None);
        assert!(state.done);
        // Anything after the `[DONE]` sentinel is ignored.
        assert_eq!(state.ingest_line(after_done), None);
    }

    #[test]
    fn detects_parameter_errors() {
        assert!(is_tool_parameter_error("This model does not support tools"));
        assert!(is_stream_options_error("Unknown field stream_options"));
        assert!(!is_stream_options_error("rate limit exceeded"));
    }

    #[test]
    fn parses_tool_round_limit() {
        let cases = [
            (None, DEFAULT_MAX_TOOL_ROUNDS),
            (Some("12"), 12),
            (Some("0"), DEFAULT_MAX_TOOL_ROUNDS),
            (Some("nope"), DEFAULT_MAX_TOOL_ROUNDS),
            (Some("999"), HARD_MAX_TOOL_ROUNDS),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_tool_round_limit(input), expected, "input {input:?}");
        }
    }

    #[test]
    fn tool_limit_message_forces_final_answer() {
        let message = tool_limit_message(3);
        assert_eq!(message["role"], json!("system"));
        let content = message["content"].as_str().expect("content is a string");
        assert!(content.contains("Do not call tools again"));
    }
}
|