anveesa 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,939 @@
1
+ use std::time::Duration;
2
+
3
+ use anyhow::{Context, Result, bail};
4
+ use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
5
+ use serde_json::{Value, json};
6
+ use tokio::sync::{mpsc::UnboundedSender, oneshot};
7
+
8
+ use crate::{
9
+ config::OpenAiCompatibleProviderConfig,
10
+ provider::{
11
+ ApprovalDecision, ApprovalPolicy, ChatRole, DiffKind, DiffLine, PromptRequest, StreamEvent,
12
+ ToolConfirmPreview, TurnResult, Usage,
13
+ },
14
+ tools,
15
+ };
16
+
17
/// Tool rounds allowed per answer when `ANVEESA_MAX_TOOL_ROUNDS` is unset or invalid.
const DEFAULT_MAX_TOOL_ROUNDS: usize = 32;
/// Ceiling applied to any configured tool-round limit.
const HARD_MAX_TOOL_ROUNDS: usize = 256;
/// Environment variable that overrides the per-answer tool-round limit.
const MAX_TOOL_ROUNDS_ENV: &str = "ANVEESA_MAX_TOOL_ROUNDS";
/// Extra attempts `send_with_retry` makes after a retryable failure.
const MAX_RETRIES: usize = 2;
/// Time allowed for establishing the connection to the provider.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(15);
/// How many times the model may call the exact same (tool, arguments) pair before we refuse.
const MAX_IDENTICAL_CALLS: usize = 3;
+
25
+ pub async fn ask(
26
+ provider_name: &str,
27
+ config: &OpenAiCompatibleProviderConfig,
28
+ request: PromptRequest,
29
+ policy: ApprovalPolicy,
30
+ events: &UnboundedSender<StreamEvent>,
31
+ ) -> Result<TurnResult> {
32
+ let model = request
33
+ .model
34
+ .clone()
35
+ .or_else(|| config.default_model.clone())
36
+ .with_context(|| {
37
+ format!("provider '{provider_name}' requires --model or default_model in config")
38
+ })?;
39
+
40
+ // Use the explicit config flag if set; otherwise auto-enable for Anthropic endpoints.
41
+ let prompt_cache = config
42
+ .prompt_cache
43
+ .unwrap_or_else(|| is_anthropic_url(&config.base_url));
44
+ let headers = build_headers(config, prompt_cache)?;
45
+ let mut messages = build_messages(&request, policy, prompt_cache);
46
+
47
+ let client = reqwest::Client::builder()
48
+ .connect_timeout(CONNECT_TIMEOUT)
49
+ .build()
50
+ .context("failed to build HTTP client")?;
51
+ let url = format!("{}/chat/completions", config.base_url.trim_end_matches('/'));
52
+
53
+ let mut tools_enabled = true;
54
+ let mut usage_requested = true;
55
+ let mut tool_rounds = 0usize;
56
+ let max_tool_rounds = max_tool_rounds();
57
+ let mut approval_state = ToolApprovalState::default();
58
+ let mut full_text = String::new();
59
+ let mut last_usage: Option<Usage> = None;
60
+
61
+ loop {
62
+ let mut body = json!({
63
+ "model": model,
64
+ "messages": messages,
65
+ "stream": true,
66
+ });
67
+ if usage_requested {
68
+ body["stream_options"] = json!({ "include_usage": true });
69
+ }
70
+ if tools_enabled {
71
+ body["tools"] = json!(tools::definitions(policy.allows_write_tools()));
72
+ body["tool_choice"] = json!("auto");
73
+ }
74
+
75
+ let response = send_with_retry(&client, &url, &headers, &body)
76
+ .await
77
+ .with_context(|| format!("request to provider '{provider_name}' failed"))?;
78
+
79
+ let status = response.status();
80
+ if !status.is_success() {
81
+ let response_body = response.text().await.unwrap_or_default();
82
+ if tools_enabled && is_tool_parameter_error(&response_body) {
83
+ tools_enabled = false;
84
+ continue;
85
+ }
86
+ if usage_requested && is_stream_options_error(&response_body) {
87
+ usage_requested = false;
88
+ continue;
89
+ }
90
+ bail!("provider '{provider_name}' returned HTTP {status}: {response_body}");
91
+ }
92
+
93
+ let mut state = StreamState::default();
94
+ stream_response(response, &mut state, events).await?;
95
+
96
+ full_text.push_str(&state.content);
97
+ if let Some(usage) = state.usage {
98
+ last_usage = Some(usage);
99
+ }
100
+
101
+ if state.tool_calls.is_empty() {
102
+ break;
103
+ }
104
+
105
+ if !tools_enabled {
106
+ break;
107
+ }
108
+ tool_rounds += 1;
109
+
110
+ messages.push(assistant_tool_message(&state));
111
+ for call in &state.tool_calls {
112
+ let content = dispatch_tool(call, policy, &mut approval_state, events).await;
113
+ messages.push(json!({
114
+ "role": "tool",
115
+ "tool_call_id": call.id,
116
+ "name": call.name,
117
+ "content": content,
118
+ }));
119
+ }
120
+
121
+ if tool_rounds >= max_tool_rounds {
122
+ tools_enabled = false;
123
+ messages.push(tool_limit_message(max_tool_rounds));
124
+ }
125
+ }
126
+
127
+ if let Some(usage) = last_usage {
128
+ let _ = events.send(StreamEvent::Usage(usage));
129
+ }
130
+
131
+ Ok(TurnResult {
132
+ text: full_text,
133
+ usage: last_usage,
134
+ })
135
+ }
136
+
137
/// Per-turn bookkeeping shared by every tool call the model makes in one `ask` turn.
#[derive(Default, Debug)]
struct ToolApprovalState {
    /// Set once the user answers "allow for this turn"; later write tools skip the prompt.
    allow_for_turn: bool,
    /// How often each exact `(tool name, raw arguments)` pair has been dispatched this turn,
    /// used by the repeated-call guard.
    call_counts: std::collections::HashMap<(String, String), usize>,
}
143
+
144
+ async fn dispatch_tool(
145
+ call: &PartialToolCall,
146
+ policy: ApprovalPolicy,
147
+ approval_state: &mut ToolApprovalState,
148
+ events: &UnboundedSender<StreamEvent>,
149
+ ) -> String {
150
+ // Plan tools — display only, no approval or filesystem access needed.
151
+ if call.name == "set_plan" {
152
+ if let Ok(args) = serde_json::from_str::<serde_json::Value>(&call.arguments) {
153
+ let tasks = args["steps"]
154
+ .as_array()
155
+ .map(|arr| {
156
+ arr.iter()
157
+ .filter_map(|v| v.as_str().map(str::to_string))
158
+ .collect()
159
+ })
160
+ .unwrap_or_default();
161
+ let _ = events.send(StreamEvent::PlanSet { tasks });
162
+ }
163
+ return json!({"ok": true}).to_string();
164
+ }
165
+ if call.name == "complete_task" {
166
+ if let Ok(args) = serde_json::from_str::<serde_json::Value>(&call.arguments) {
167
+ if let Some(index) = args["index"].as_u64() {
168
+ let _ = events.send(StreamEvent::PlanTaskDone {
169
+ index: index as usize,
170
+ });
171
+ }
172
+ }
173
+ return json!({"ok": true}).to_string();
174
+ }
175
+
176
+ // Anti-loop guard: refuse if the model is calling the exact same (tool, args) repeatedly.
177
+ {
178
+ let key = (call.name.clone(), call.arguments.clone());
179
+ let count = approval_state.call_counts.entry(key).or_insert(0);
180
+ *count += 1;
181
+ if *count > MAX_IDENTICAL_CALLS {
182
+ return json!({
183
+ "ok": false,
184
+ "error": format!(
185
+ "Refusing to run '{}' again: this identical call has already been made {} time(s) \
186
+ this turn. Do NOT retry — stop and report the failure to the user.",
187
+ call.name, *count - 1
188
+ )
189
+ })
190
+ .to_string();
191
+ }
192
+ }
193
+
194
+ if tools::is_write_tool(&call.name) {
195
+ if !policy.allows_write_tools() {
196
+ return denied_message("write tools are disabled (pass --yes or run interactively)");
197
+ }
198
+ }
199
+
200
+ // Snapshot BEFORE the tool runs — needed both for preview and for post-run diff.
201
+ let file_op_snapshot = capture_file_op_snapshot(&call.name, &call.arguments);
202
+
203
+ let mut preview_was_shown = false;
204
+
205
+ if tools::is_write_tool(&call.name)
206
+ && policy == ApprovalPolicy::Prompt
207
+ && !approval_state.allow_for_turn
208
+ {
209
+ let preview = build_confirm_preview(&call.name, &call.arguments, &file_op_snapshot);
210
+ preview_was_shown = true;
211
+ match request_approval_with_preview(preview, events).await {
212
+ ApprovalDecision::AllowOnce => {}
213
+ ApprovalDecision::AllowForTurn => approval_state.allow_for_turn = true,
214
+ ApprovalDecision::Deny => return denied_message("user declined this action"),
215
+ }
216
+ }
217
+
218
+ let result = tools::run(&call.name, &call.arguments).await;
219
+
220
+ // When the user already reviewed the diff in the approval preview, skip the
221
+ // post-run FileOp so the same diff isn't printed twice.
222
+ if !preview_was_shown {
223
+ if let Some(snapshot) = file_op_snapshot {
224
+ if let Ok(result_json) = serde_json::from_str::<serde_json::Value>(&result) {
225
+ if result_json["ok"].as_bool().unwrap_or(false) {
226
+ emit_file_op_event(snapshot, &result_json, events);
227
+ }
228
+ }
229
+ }
230
+ }
231
+
232
+ result
233
+ }
234
+
235
+ // ── File-op diff helpers ──────────────────────────────────────────────────────
236
+
237
/// Pre-execution view of a file-mutating tool call, used for the approval preview and for
/// the post-run `FileOp` diff event.
enum FileOpSnapshot {
    /// From `write_file`: the first preview lines of the new content and its total line count.
    Write {
        path: String,
        lines: Vec<String>,
        total: usize,
    },
    /// From `edit_file`: the replaced and replacement text split into lines, anchored at
    /// `start_line` (1-based; 1 when the old text could not be located in the file).
    Edit {
        path: String,
        start_line: usize,
        old_lines: Vec<String>,
        new_lines: Vec<String>,
    },
    /// From `create_dir`: only the target path.
    CreateDir { path: String },
}
253
+
254
+ fn capture_file_op_snapshot(tool_name: &str, arguments: &str) -> Option<FileOpSnapshot> {
255
+ let args: serde_json::Value = serde_json::from_str(arguments).ok()?;
256
+ match tool_name {
257
+ "write_file" => {
258
+ let path = args["path"].as_str()?.to_string();
259
+ let content = args["content"].as_str().unwrap_or("");
260
+ let all: Vec<String> = content.lines().map(str::to_string).collect();
261
+ let total = all.len();
262
+ Some(FileOpSnapshot::Write {
263
+ path,
264
+ lines: all.into_iter().take(20).collect(),
265
+ total,
266
+ })
267
+ }
268
+ "edit_file" => {
269
+ let path = args["path"].as_str()?.to_string();
270
+ let old = args["old_string"].as_str().unwrap_or("");
271
+ let new = args["new_string"].as_str().unwrap_or("");
272
+ let start_line = std::fs::read_to_string(&path)
273
+ .ok()
274
+ .and_then(|content| {
275
+ let pos = content.find(old)?;
276
+ Some(content[..pos].lines().count() + 1)
277
+ })
278
+ .unwrap_or(1);
279
+ Some(FileOpSnapshot::Edit {
280
+ path,
281
+ start_line,
282
+ old_lines: old.lines().map(str::to_string).collect(),
283
+ new_lines: new.lines().map(str::to_string).collect(),
284
+ })
285
+ }
286
+ "create_dir" => {
287
+ let path = args["path"].as_str()?.to_string();
288
+ Some(FileOpSnapshot::CreateDir { path })
289
+ }
290
+ _ => None,
291
+ }
292
+ }
293
+
294
+ fn emit_file_op_event(
295
+ snapshot: FileOpSnapshot,
296
+ result: &serde_json::Value,
297
+ events: &tokio::sync::mpsc::UnboundedSender<StreamEvent>,
298
+ ) {
299
+ const MAX_PREVIEW: usize = 20;
300
+
301
+ let event = match snapshot {
302
+ FileOpSnapshot::Write { path, lines, total } => {
303
+ let verb = if result["created"].as_bool().unwrap_or(true) {
304
+ "Create"
305
+ } else {
306
+ "Update"
307
+ }
308
+ .to_string();
309
+ let truncated = total > MAX_PREVIEW;
310
+ let preview = lines
311
+ .into_iter()
312
+ .enumerate()
313
+ .map(|(i, text)| DiffLine {
314
+ kind: DiffKind::Add,
315
+ line_no: i + 1,
316
+ text,
317
+ })
318
+ .collect();
319
+ StreamEvent::FileOp {
320
+ verb,
321
+ path,
322
+ added: total,
323
+ removed: 0,
324
+ preview,
325
+ truncated,
326
+ }
327
+ }
328
+ FileOpSnapshot::Edit {
329
+ path,
330
+ start_line,
331
+ old_lines,
332
+ new_lines,
333
+ } => {
334
+ let added = new_lines.len();
335
+ let removed = old_lines.len();
336
+ // Cap: show at most MAX_PREVIEW removed + MAX_PREVIEW added
337
+ let cap = MAX_PREVIEW;
338
+ let truncated = old_lines.len() > cap || new_lines.len() > cap;
339
+ let mut preview: Vec<DiffLine> = old_lines
340
+ .into_iter()
341
+ .take(cap)
342
+ .enumerate()
343
+ .map(|(i, text)| DiffLine {
344
+ kind: DiffKind::Remove,
345
+ line_no: start_line + i,
346
+ text,
347
+ })
348
+ .collect();
349
+ for (i, text) in new_lines.into_iter().take(cap).enumerate() {
350
+ preview.push(DiffLine {
351
+ kind: DiffKind::Add,
352
+ line_no: start_line + i,
353
+ text,
354
+ });
355
+ }
356
+ StreamEvent::FileOp {
357
+ verb: "Update".to_string(),
358
+ path,
359
+ added,
360
+ removed,
361
+ preview,
362
+ truncated,
363
+ }
364
+ }
365
+ FileOpSnapshot::CreateDir { path } => StreamEvent::FileOp {
366
+ verb: "Create dir".to_string(),
367
+ path,
368
+ added: 0,
369
+ removed: 0,
370
+ preview: vec![],
371
+ truncated: false,
372
+ },
373
+ };
374
+
375
+ let _ = events.send(event);
376
+ }
377
+
378
+ fn max_tool_rounds() -> usize {
379
+ parse_tool_round_limit(std::env::var(MAX_TOOL_ROUNDS_ENV).ok().as_deref())
380
+ }
381
+
382
+ fn parse_tool_round_limit(value: Option<&str>) -> usize {
383
+ value
384
+ .and_then(|value| value.trim().parse::<usize>().ok())
385
+ .filter(|value| *value > 0)
386
+ .unwrap_or(DEFAULT_MAX_TOOL_ROUNDS)
387
+ .min(HARD_MAX_TOOL_ROUNDS)
388
+ }
389
+
390
+ fn tool_limit_message(max_tool_rounds: usize) -> Value {
391
+ json!({
392
+ "role": "system",
393
+ "content": format!(
394
+ "Anveesa has already run {max_tool_rounds} tool rounds for this answer. Do not call tools again. Use the tool results already provided to produce the best final answer. If the requested work is not complete, say exactly what remains."
395
+ )
396
+ })
397
+ }
398
+
399
+ fn denied_message(reason: &str) -> String {
400
+ json!({ "ok": false, "error": reason }).to_string()
401
+ }
402
+
403
+ fn build_confirm_preview(
404
+ tool_name: &str,
405
+ arguments: &str,
406
+ snapshot: &Option<FileOpSnapshot>,
407
+ ) -> ToolConfirmPreview {
408
+ const CAP: usize = 20;
409
+ match snapshot {
410
+ Some(FileOpSnapshot::Write { path, lines, total }) => {
411
+ let verb = if std::path::Path::new(path).exists() {
412
+ "Update"
413
+ } else {
414
+ "Create"
415
+ };
416
+ let truncated = *total > CAP;
417
+ let diff = lines
418
+ .iter()
419
+ .enumerate()
420
+ .map(|(i, text)| DiffLine {
421
+ kind: DiffKind::Add,
422
+ line_no: i + 1,
423
+ text: text.clone(),
424
+ })
425
+ .collect();
426
+ ToolConfirmPreview::FileOp {
427
+ verb: verb.to_string(),
428
+ path: path.clone(),
429
+ added: *total,
430
+ removed: 0,
431
+ diff,
432
+ truncated,
433
+ }
434
+ }
435
+ Some(FileOpSnapshot::Edit {
436
+ path,
437
+ start_line,
438
+ old_lines,
439
+ new_lines,
440
+ }) => {
441
+ let truncated = old_lines.len() > CAP || new_lines.len() > CAP;
442
+ let mut diff: Vec<DiffLine> = old_lines
443
+ .iter()
444
+ .take(CAP)
445
+ .enumerate()
446
+ .map(|(i, text)| DiffLine {
447
+ kind: DiffKind::Remove,
448
+ line_no: start_line + i,
449
+ text: text.clone(),
450
+ })
451
+ .collect();
452
+ for (i, text) in new_lines.iter().take(CAP).enumerate() {
453
+ diff.push(DiffLine {
454
+ kind: DiffKind::Add,
455
+ line_no: start_line + i,
456
+ text: text.clone(),
457
+ });
458
+ }
459
+ ToolConfirmPreview::FileOp {
460
+ verb: "Update".to_string(),
461
+ path: path.clone(),
462
+ added: new_lines.len(),
463
+ removed: old_lines.len(),
464
+ diff,
465
+ truncated,
466
+ }
467
+ }
468
+ Some(FileOpSnapshot::CreateDir { path }) => {
469
+ ToolConfirmPreview::CreateDir { path: path.clone() }
470
+ }
471
+ None => ToolConfirmPreview::Generic {
472
+ summary: tools::describe_call(tool_name, arguments),
473
+ },
474
+ }
475
+ }
476
+
477
+ async fn request_approval_with_preview(
478
+ preview: ToolConfirmPreview,
479
+ events: &UnboundedSender<StreamEvent>,
480
+ ) -> ApprovalDecision {
481
+ let (reply, answer) = oneshot::channel();
482
+ if events
483
+ .send(StreamEvent::Confirm { preview, reply })
484
+ .is_err()
485
+ {
486
+ return ApprovalDecision::Deny;
487
+ }
488
+ answer.await.unwrap_or(ApprovalDecision::Deny)
489
+ }
490
+
491
/// Whether `base_url` points at Anthropic's API: the host is `anthropic.com` or a subdomain.
///
/// Only the URL's host is inspected, case-insensitively. A path, query or look-alike domain
/// that merely contains "anthropic.com" does not match, e.g.
/// `https://proxy.example/anthropic.com/v1` or `https://notanthropic.com`.
/// The scheme is optional; userinfo, port and a trailing root dot are ignored.
fn is_anthropic_url(base_url: &str) -> bool {
    let base_url = base_url.trim();
    let after_scheme = base_url
        .split_once("://")
        .map_or(base_url, |(_, rest)| rest);
    let authority = after_scheme
        .split(['/', '?', '#'])
        .next()
        .unwrap_or_default();
    let host_and_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host);
    let host = host_and_port
        .split(':')
        .next()
        .unwrap_or_default()
        .trim_end_matches('.')
        .to_ascii_lowercase();
    host == "anthropic.com" || host.ends_with(".anthropic.com")
}
494
+
495
+ fn build_headers(config: &OpenAiCompatibleProviderConfig, prompt_cache: bool) -> Result<HeaderMap> {
496
+ let mut headers = HeaderMap::new();
497
+ headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
498
+
499
+ if let Some(api_key_env) = &config.api_key_env {
500
+ let api_key = std::env::var(api_key_env)
501
+ .with_context(|| format!("environment variable {api_key_env} is required"))?;
502
+ headers.insert(
503
+ AUTHORIZATION,
504
+ HeaderValue::from_str(&format!("Bearer {api_key}"))
505
+ .context("failed to build authorization header")?,
506
+ );
507
+ }
508
+
509
+ for (name, value) in &config.headers {
510
+ headers.insert(
511
+ HeaderName::from_bytes(name.as_bytes())
512
+ .with_context(|| format!("invalid header name '{name}'"))?,
513
+ HeaderValue::from_str(value)
514
+ .with_context(|| format!("invalid header value for '{name}'"))?,
515
+ );
516
+ }
517
+
518
+ if prompt_cache {
519
+ headers.insert(
520
+ HeaderName::from_static("anthropic-beta"),
521
+ HeaderValue::from_static("prompt-caching-2024-07-31"),
522
+ );
523
+ }
524
+
525
+ Ok(headers)
526
+ }
527
+
528
+ fn build_messages(
529
+ request: &PromptRequest,
530
+ policy: ApprovalPolicy,
531
+ prompt_cache: bool,
532
+ ) -> Vec<Value> {
533
+ let mut messages = Vec::new();
534
+ if let Some(system) = &request.system {
535
+ messages.push(json!({ "role": "system", "content": system }));
536
+ }
537
+ if let Some(workspace_context) = &request.workspace_context {
538
+ messages.push(json!({ "role": "system", "content": workspace_context }));
539
+ }
540
+ messages
541
+ .push(json!({ "role": "system", "content": tools::guidance(policy.allows_write_tools()) }));
542
+ for message in &request.history {
543
+ let role = match message.role {
544
+ ChatRole::User => "user",
545
+ ChatRole::Assistant => "assistant",
546
+ };
547
+ messages.push(json!({ "role": role, "content": message.content }));
548
+ }
549
+
550
+ // Current user turn — multimodal when a clipboard image is attached.
551
+ let user_content = match &request.image {
552
+ Some(img) => json!([
553
+ { "type": "text", "text": &request.prompt },
554
+ { "type": "image_url", "image_url": { "url": format!("data:{};base64,{}", img.mime, img.data) } }
555
+ ]),
556
+ None => json!(&request.prompt),
557
+ };
558
+ messages.push(json!({ "role": "user", "content": user_content }));
559
+
560
+ if prompt_cache {
561
+ apply_cache_breakpoints(&mut messages);
562
+ }
563
+
564
+ messages
565
+ }
566
+
567
+ /// Add `cache_control: {type: "ephemeral"}` to two breakpoints in the message list:
568
+ /// 1. The last system message (tools guidance — large, fully static).
569
+ /// 2. The last history message before the current user turn (grows each turn).
570
+ ///
571
+ /// Everything up to each breakpoint is served from cache on subsequent turns.
572
+ fn apply_cache_breakpoints(messages: &mut Vec<Value>) {
573
+ let current_turn_idx = messages.len() - 1;
574
+
575
+ // Breakpoint 1: last system message
576
+ if let Some(idx) = messages[..current_turn_idx]
577
+ .iter()
578
+ .rposition(|m| m["role"] == "system")
579
+ {
580
+ add_cache_control(&mut messages[idx]);
581
+ }
582
+
583
+ // Breakpoint 2: last history message (user or assistant before the current turn)
584
+ if let Some(idx) = messages[..current_turn_idx]
585
+ .iter()
586
+ .rposition(|m| m["role"] != "system")
587
+ {
588
+ add_cache_control(&mut messages[idx]);
589
+ }
590
+ }
591
+
592
+ /// Convert a message's `content` to the array form required by Anthropic caching,
593
+ /// then inject `cache_control: {type: "ephemeral"}` on the last content block.
594
+ fn add_cache_control(message: &mut Value) {
595
+ let content = match message.get("content").cloned() {
596
+ Some(c) => c,
597
+ None => return,
598
+ };
599
+
600
+ let cached = match content {
601
+ Value::String(s) => json!([{
602
+ "type": "text",
603
+ "text": s,
604
+ "cache_control": { "type": "ephemeral" }
605
+ }]),
606
+ Value::Array(mut arr) => {
607
+ if let Some(last) = arr.last_mut().and_then(Value::as_object_mut) {
608
+ last.insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
609
+ }
610
+ Value::Array(arr)
611
+ }
612
+ other => other,
613
+ };
614
+
615
+ message["content"] = cached;
616
+ }
617
+
618
+ async fn send_with_retry(
619
+ client: &reqwest::Client,
620
+ url: &str,
621
+ headers: &HeaderMap,
622
+ body: &Value,
623
+ ) -> Result<reqwest::Response> {
624
+ let mut attempt = 0usize;
625
+ loop {
626
+ match client
627
+ .post(url)
628
+ .headers(headers.clone())
629
+ .json(body)
630
+ .send()
631
+ .await
632
+ {
633
+ Ok(response) => {
634
+ if response.status().is_server_error() && attempt < MAX_RETRIES {
635
+ attempt += 1;
636
+ backoff(attempt).await;
637
+ continue;
638
+ }
639
+ return Ok(response);
640
+ }
641
+ Err(error) => {
642
+ let retryable = error.is_connect() || error.is_timeout();
643
+ if retryable && attempt < MAX_RETRIES {
644
+ attempt += 1;
645
+ backoff(attempt).await;
646
+ continue;
647
+ }
648
+ return Err(error.into());
649
+ }
650
+ }
651
+ }
652
+ }
653
+
654
+ async fn backoff(attempt: usize) {
655
+ let millis = 250u64 * (1u64 << (attempt - 1));
656
+ tokio::time::sleep(Duration::from_millis(millis)).await;
657
+ }
658
+
659
+ async fn stream_response(
660
+ mut response: reqwest::Response,
661
+ state: &mut StreamState,
662
+ events: &UnboundedSender<StreamEvent>,
663
+ ) -> Result<()> {
664
+ let mut buffer = String::new();
665
+
666
+ while let Some(chunk) = response
667
+ .chunk()
668
+ .await
669
+ .context("failed to read streamed response chunk")?
670
+ {
671
+ buffer.push_str(&String::from_utf8_lossy(&chunk));
672
+
673
+ while let Some(newline) = buffer.find('\n') {
674
+ let line: String = buffer.drain(..=newline).collect();
675
+ if let Some(token) = state.ingest_line(line.trim_end_matches(['\r', '\n'])) {
676
+ let _ = events.send(StreamEvent::Token(token));
677
+ }
678
+ }
679
+ }
680
+
681
+ if !buffer.is_empty()
682
+ && let Some(token) = state.ingest_line(buffer.trim())
683
+ {
684
+ let _ = events.send(StreamEvent::Token(token));
685
+ }
686
+
687
+ Ok(())
688
+ }
689
+
690
+ #[derive(Debug, Default)]
691
+ struct StreamState {
692
+ content: String,
693
+ tool_calls: Vec<PartialToolCall>,
694
+ usage: Option<Usage>,
695
+ done: bool,
696
+ }
697
+
698
/// A tool call being reassembled from streamed fragments.
///
/// `name` and `arguments` are appended to as deltas arrive, so `arguments` may be
/// incomplete JSON until the stream has ended.
#[derive(Clone, Default, Debug)]
struct PartialToolCall {
    /// Provider-assigned call id, echoed back as `tool_call_id` in the tool result.
    id: String,
    /// Function name, concatenated from its fragments.
    name: String,
    /// Raw JSON argument text, concatenated from its fragments.
    arguments: String,
}
704
+
705
+ impl StreamState {
706
+ /// Process a single SSE line, returning any new assistant text to display.
707
+ fn ingest_line(&mut self, line: &str) -> Option<String> {
708
+ if self.done {
709
+ return None;
710
+ }
711
+ let data = line.strip_prefix("data:")?.trim();
712
+ if data.is_empty() {
713
+ return None;
714
+ }
715
+ if data == "[DONE]" {
716
+ self.done = true;
717
+ return None;
718
+ }
719
+ let chunk: Value = serde_json::from_str(data).ok()?;
720
+ self.apply_chunk(&chunk)
721
+ }
722
+
723
+ fn apply_chunk(&mut self, chunk: &Value) -> Option<String> {
724
+ if let Some(usage) = chunk.get("usage").filter(|value| value.is_object()) {
725
+ self.usage = parse_usage(usage);
726
+ }
727
+
728
+ let delta = chunk.get("choices")?.get(0)?.get("delta")?;
729
+
730
+ if let Some(tool_calls) = delta.get("tool_calls").and_then(Value::as_array) {
731
+ for call in tool_calls {
732
+ self.apply_tool_call_delta(call);
733
+ }
734
+ }
735
+
736
+ let text = delta.get("content").and_then(Value::as_str)?;
737
+ if text.is_empty() {
738
+ return None;
739
+ }
740
+ self.content.push_str(text);
741
+ Some(text.to_string())
742
+ }
743
+
744
+ fn apply_tool_call_delta(&mut self, call: &Value) {
745
+ let index = call.get("index").and_then(Value::as_u64).unwrap_or(0) as usize;
746
+ while self.tool_calls.len() <= index {
747
+ self.tool_calls.push(PartialToolCall::default());
748
+ }
749
+ let slot = &mut self.tool_calls[index];
750
+
751
+ if let Some(id) = call.get("id").and_then(Value::as_str)
752
+ && !id.is_empty()
753
+ {
754
+ slot.id = id.to_string();
755
+ }
756
+ if let Some(function) = call.get("function") {
757
+ if let Some(name) = function.get("name").and_then(Value::as_str) {
758
+ slot.name.push_str(name);
759
+ }
760
+ if let Some(arguments) = function.get("arguments").and_then(Value::as_str) {
761
+ slot.arguments.push_str(arguments);
762
+ }
763
+ }
764
+ }
765
+ }
766
+
767
+ fn parse_usage(value: &Value) -> Option<Usage> {
768
+ let object = value.as_object()?;
769
+
770
+ // Anthropic uses top-level cache fields; OpenAI nests them under prompt_tokens_details.
771
+ let details = object.get("prompt_tokens_details");
772
+ let cache_read_tokens = object
773
+ .get("cache_read_input_tokens")
774
+ .and_then(Value::as_u64)
775
+ .or_else(|| details?.get("cached_tokens").and_then(Value::as_u64))
776
+ .unwrap_or(0);
777
+ let cache_write_tokens = object
778
+ .get("cache_creation_input_tokens")
779
+ .and_then(Value::as_u64)
780
+ .unwrap_or(0);
781
+
782
+ Some(Usage {
783
+ prompt_tokens: object
784
+ .get("prompt_tokens")
785
+ .and_then(Value::as_u64)
786
+ .unwrap_or(0),
787
+ completion_tokens: object
788
+ .get("completion_tokens")
789
+ .and_then(Value::as_u64)
790
+ .unwrap_or(0),
791
+ total_tokens: object
792
+ .get("total_tokens")
793
+ .and_then(Value::as_u64)
794
+ .unwrap_or(0),
795
+ cache_read_tokens,
796
+ cache_write_tokens,
797
+ })
798
+ }
799
+
800
+ fn assistant_tool_message(state: &StreamState) -> Value {
801
+ let tool_calls = state
802
+ .tool_calls
803
+ .iter()
804
+ .filter(|call| !call.name.is_empty())
805
+ .map(|call| {
806
+ json!({
807
+ "id": call.id,
808
+ "type": "function",
809
+ "function": {
810
+ "name": call.name,
811
+ "arguments": call.arguments,
812
+ }
813
+ })
814
+ })
815
+ .collect::<Vec<_>>();
816
+
817
+ json!({
818
+ "role": "assistant",
819
+ "content": state.content,
820
+ "tool_calls": tool_calls,
821
+ })
822
+ }
823
+
824
/// Heuristic: does an error response body say the model or endpoint rejects tool calling?
///
/// The body must both mention tools ("tool" or "function call") *and* contain a rejection
/// phrase such as "not support", "unsupported" or "unrecognized". Merely mentioning
/// "tool" is not enough. Errors about an orphaned `tool` message or a bad `tool_call_id`
/// also mention tools, and misreading them as "tools unsupported" would silently disable
/// tools for the rest of the turn.
fn is_tool_parameter_error(body: &str) -> bool {
    const REJECTION_PHRASES: [&str; 8] = [
        "not support",
        "unsupported",
        "unknown field",
        "unknown parameter",
        "unrecognized",
        "not allowed",
        "not permitted",
        "support tool",
    ];

    let lower = body.to_lowercase();
    let mentions_tools = lower.contains("tool") || lower.contains("function call");
    mentions_tools && REJECTION_PHRASES.iter().any(|phrase| lower.contains(phrase))
}
828
+
829
/// Heuristic: does an error response body complain about the `stream_options` /
/// `include_usage` request field? The match is case-insensitive.
fn is_stream_options_error(body: &str) -> bool {
    let lower = body.to_lowercase();
    ["stream_options", "include_usage"]
        .into_iter()
        .any(|needle| lower.contains(needle))
}
833
+
834
#[cfg(test)]
mod tests {
    use super::*;

    /// A streaming chunk whose first choice carries `content` as its text delta.
    fn text_chunk(content: &str) -> Value {
        json!({ "choices": [{ "delta": { "content": content } }] })
    }

    /// A streaming chunk whose first choice carries a single `tool_calls` delta.
    fn tool_call_chunk(call: Value) -> Value {
        json!({ "choices": [{ "delta": { "tool_calls": [call] } }] })
    }

    #[test]
    fn accumulates_content_tokens() {
        let mut state = StreamState::default();
        for piece in ["Hel", "lo"] {
            assert_eq!(state.apply_chunk(&text_chunk(piece)), Some(piece.to_string()));
        }
        assert_eq!(state.content, "Hello");
        assert!(state.tool_calls.is_empty());
    }

    #[test]
    fn ignores_empty_content_delta() {
        let mut state = StreamState::default();
        assert!(state.apply_chunk(&text_chunk("")).is_none());
        assert!(
            state
                .apply_chunk(&json!({ "choices": [{ "delta": {} }] }))
                .is_none()
        );
    }

    #[test]
    fn merges_fragmented_tool_call_deltas() {
        let mut state = StreamState::default();
        state.apply_chunk(&tool_call_chunk(json!({
            "index": 0, "id": "call_1", "function": { "name": "read_file", "arguments": "{\"pa" }
        })));
        state.apply_chunk(&tool_call_chunk(json!({
            "index": 0, "function": { "arguments": "th\":\"x\"}" }
        })));

        let [call] = state.tool_calls.as_slice() else {
            panic!("expected exactly one tool call, got {:?}", state.tool_calls);
        };
        assert_eq!(call.id, "call_1");
        assert_eq!(call.name, "read_file");
        assert_eq!(call.arguments, "{\"path\":\"x\"}");
    }

    #[test]
    fn parses_usage_chunk() {
        let mut state = StreamState::default();
        state.apply_chunk(&json!({
            "choices": [],
            "usage": { "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15 }
        }));
        let Some(usage) = state.usage else {
            panic!("usage parsed");
        };
        assert_eq!(
            (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens),
            (10, 5, 15)
        );
    }

    #[test]
    fn ingest_line_handles_sse_framing() {
        let mut state = StreamState::default();
        let hi = state.ingest_line("data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}");
        assert_eq!(hi, Some("hi".to_string()));
        assert!(state.ingest_line("").is_none());
        assert!(state.ingest_line("data: [DONE]").is_none());
        assert!(state.done);
        let after_done = state.ingest_line("data: {\"choices\":[{\"delta\":{\"content\":\"x\"}}]}");
        assert!(after_done.is_none());
    }

    #[test]
    fn detects_parameter_errors() {
        assert!(is_tool_parameter_error("This model does not support tools"));
        assert!(is_stream_options_error("Unknown field stream_options"));
        assert!(!is_stream_options_error("rate limit exceeded"));
    }

    #[test]
    fn parses_tool_round_limit() {
        let cases = [
            (None, DEFAULT_MAX_TOOL_ROUNDS),
            (Some("12"), 12),
            (Some("0"), DEFAULT_MAX_TOOL_ROUNDS),
            (Some("nope"), DEFAULT_MAX_TOOL_ROUNDS),
            (Some("999"), HARD_MAX_TOOL_ROUNDS),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_tool_round_limit(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn tool_limit_message_forces_final_answer() {
        let message = tool_limit_message(3);
        assert_eq!(message["role"], json!("system"));
        let content = message["content"].as_str().expect("content is a string");
        assert!(content.contains("Do not call tools again"));
    }
}