@agi-cli/server 0.1.57 → 0.1.59
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +3 -3
- package/src/index.ts +30 -21
- package/src/runtime/agent-registry.ts +3 -2
- package/src/runtime/cache-optimizer.ts +115 -0
- package/src/runtime/context-optimizer.ts +192 -0
- package/src/runtime/db-operations.ts +154 -15
- package/src/runtime/runner.ts +29 -4
- package/src/runtime/session-manager.ts +2 -0
- package/src/runtime/stream-handlers.ts +61 -12
- package/src/tools/adapter.ts +261 -173
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@agi-cli/server",
|
|
3
|
-
"version": "0.1.
|
|
3
|
+
"version": "0.1.59",
|
|
4
4
|
"description": "HTTP API server for AGI CLI",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"main": "./src/index.ts",
|
|
@@ -29,8 +29,8 @@
|
|
|
29
29
|
"typecheck": "tsc --noEmit"
|
|
30
30
|
},
|
|
31
31
|
"dependencies": {
|
|
32
|
-
"@agi-cli/sdk": "0.1.
|
|
33
|
-
"@agi-cli/database": "0.1.
|
|
32
|
+
"@agi-cli/sdk": "0.1.59",
|
|
33
|
+
"@agi-cli/database": "0.1.59",
|
|
34
34
|
"drizzle-orm": "^0.44.5",
|
|
35
35
|
"hono": "^4.9.9"
|
|
36
36
|
},
|
package/src/index.ts
CHANGED
|
@@ -14,7 +14,7 @@ import type { AgentConfigEntry } from './runtime/agent-registry.ts';
|
|
|
14
14
|
function initApp() {
|
|
15
15
|
const app = new Hono();
|
|
16
16
|
|
|
17
|
-
// Enable CORS for
|
|
17
|
+
// Enable CORS for localhost and local network access
|
|
18
18
|
app.use(
|
|
19
19
|
'*',
|
|
20
20
|
cors({
|
|
@@ -22,15 +22,16 @@ function initApp() {
|
|
|
22
22
|
// Allow all localhost and 127.0.0.1 on any port
|
|
23
23
|
if (
|
|
24
24
|
origin.startsWith('http://localhost:') ||
|
|
25
|
-
origin.startsWith('http://127.0.0.1:')
|
|
25
|
+
origin.startsWith('http://127.0.0.1:') ||
|
|
26
|
+
origin.startsWith('https://localhost:') ||
|
|
27
|
+
origin.startsWith('https://127.0.0.1:')
|
|
26
28
|
) {
|
|
27
29
|
return origin;
|
|
28
30
|
}
|
|
29
|
-
// Allow
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
) {
|
|
31
|
+
// Allow local network IPs (192.168.x.x, 10.x.x.x, 172.16-31.x.x)
|
|
32
|
+
const localNetworkPattern =
|
|
33
|
+
/^https?:\/\/(192\.168\.\d{1,3}\.\d{1,3}|10\.\d{1,3}\.\d{1,3}\.\d{1,3}|172\.(1[6-9]|2\d|3[01])\.\d{1,3}\.\d{1,3}):\d+$/;
|
|
34
|
+
if (localNetworkPattern.test(origin)) {
|
|
34
35
|
return origin;
|
|
35
36
|
}
|
|
36
37
|
// Default to allowing the origin (can be restricted in production)
|
|
@@ -76,7 +77,7 @@ export type StandaloneAppConfig = {
|
|
|
76
77
|
export function createStandaloneApp(_config?: StandaloneAppConfig) {
|
|
77
78
|
const honoApp = new Hono();
|
|
78
79
|
|
|
79
|
-
// Enable CORS for
|
|
80
|
+
// Enable CORS for localhost and local network access
|
|
80
81
|
honoApp.use(
|
|
81
82
|
'*',
|
|
82
83
|
cors({
|
|
@@ -84,15 +85,16 @@ export function createStandaloneApp(_config?: StandaloneAppConfig) {
|
|
|
84
85
|
// Allow all localhost and 127.0.0.1 on any port
|
|
85
86
|
if (
|
|
86
87
|
origin.startsWith('http://localhost:') ||
|
|
87
|
-
origin.startsWith('http://127.0.0.1:')
|
|
88
|
+
origin.startsWith('http://127.0.0.1:') ||
|
|
89
|
+
origin.startsWith('https://localhost:') ||
|
|
90
|
+
origin.startsWith('https://127.0.0.1:')
|
|
88
91
|
) {
|
|
89
92
|
return origin;
|
|
90
93
|
}
|
|
91
|
-
// Allow
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
) {
|
|
94
|
+
// Allow local network IPs (192.168.x.x, 10.x.x.x, 172.16-31.x.x)
|
|
95
|
+
const localNetworkPattern =
|
|
96
|
+
/^https?:\/\/(192\.168\.\d{1,3}\.\d{1,3}|10\.\d{1,3}\.\d{1,3}\.\d{1,3}|172\.(1[6-9]|2\d|3[01])\.\d{1,3}\.\d{1,3}):\d+$/;
|
|
97
|
+
if (localNetworkPattern.test(origin)) {
|
|
96
98
|
return origin;
|
|
97
99
|
}
|
|
98
100
|
// Default to allowing the origin
|
|
@@ -148,6 +150,8 @@ export type EmbeddedAppConfig = {
|
|
|
148
150
|
model?: string;
|
|
149
151
|
agent?: string;
|
|
150
152
|
};
|
|
153
|
+
/** Additional CORS origins for proxies/Tailscale (e.g., ['https://myapp.ts.net', 'https://example.com']) */
|
|
154
|
+
corsOrigins?: string[];
|
|
151
155
|
};
|
|
152
156
|
|
|
153
157
|
export function createEmbeddedApp(config: EmbeddedAppConfig = {}) {
|
|
@@ -160,7 +164,7 @@ export function createEmbeddedApp(config: EmbeddedAppConfig = {}) {
|
|
|
160
164
|
await next();
|
|
161
165
|
});
|
|
162
166
|
|
|
163
|
-
// Enable CORS for
|
|
167
|
+
// Enable CORS for localhost and local network access
|
|
164
168
|
honoApp.use(
|
|
165
169
|
'*',
|
|
166
170
|
cors({
|
|
@@ -168,15 +172,20 @@ export function createEmbeddedApp(config: EmbeddedAppConfig = {}) {
|
|
|
168
172
|
// Allow all localhost and 127.0.0.1 on any port
|
|
169
173
|
if (
|
|
170
174
|
origin.startsWith('http://localhost:') ||
|
|
171
|
-
origin.startsWith('http://127.0.0.1:')
|
|
175
|
+
origin.startsWith('http://127.0.0.1:') ||
|
|
176
|
+
origin.startsWith('https://localhost:') ||
|
|
177
|
+
origin.startsWith('https://127.0.0.1:')
|
|
172
178
|
) {
|
|
173
179
|
return origin;
|
|
174
180
|
}
|
|
175
|
-
// Allow
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
181
|
+
// Allow local network IPs (192.168.x.x, 10.x.x.x, 172.16-31.x.x)
|
|
182
|
+
const localNetworkPattern =
|
|
183
|
+
/^https?:\/\/(192\.168\.\d{1,3}\.\d{1,3}|10\.\d{1,3}\.\d{1,3}\.\d{1,3}|172\.(1[6-9]|2\d|3[01])\.\d{1,3}\.\d{1,3}):\d+$/;
|
|
184
|
+
if (localNetworkPattern.test(origin)) {
|
|
185
|
+
return origin;
|
|
186
|
+
}
|
|
187
|
+
// Allow custom CORS origins (for Tailscale, proxies, etc.)
|
|
188
|
+
if (config.corsOrigins?.includes(origin)) {
|
|
180
189
|
return origin;
|
|
181
190
|
}
|
|
182
191
|
// Default to allowing the origin
|
|
@@ -119,10 +119,10 @@ const defaultToolExtras: Record<string, string[]> = {
|
|
|
119
119
|
'tree',
|
|
120
120
|
'bash',
|
|
121
121
|
'update_plan',
|
|
122
|
-
'
|
|
122
|
+
'glob',
|
|
123
|
+
'ripgrep',
|
|
123
124
|
'git_status',
|
|
124
125
|
'git_diff',
|
|
125
|
-
'ripgrep',
|
|
126
126
|
'apply_patch',
|
|
127
127
|
'websearch',
|
|
128
128
|
],
|
|
@@ -134,6 +134,7 @@ const defaultToolExtras: Record<string, string[]> = {
|
|
|
134
134
|
'tree',
|
|
135
135
|
'bash',
|
|
136
136
|
'ripgrep',
|
|
137
|
+
'glob',
|
|
137
138
|
'websearch',
|
|
138
139
|
'update_plan',
|
|
139
140
|
],
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import type { ModelMessage } from 'ai';
|
|
2
|
+
|
|
3
|
+
/**
 * Adds Anthropic prompt-caching hints to the system prompt and the conversation.
 *
 * For any provider other than `'anthropic'` the inputs are returned untouched.
 * For Anthropic:
 * - A system prompt longer than 1024 characters is converted into a single text
 *   block carrying `cache_control: { type: 'ephemeral' }`.
 * - When there are at least 3 messages and at least 2 user messages, the last
 *   content part of the SECOND-TO-LAST user message (the previous turn, not the
 *   current one) is marked with `providerOptions.anthropic.cacheControl`, so the
 *   conversation prefix up to that point can be reused.
 *
 * The caller's array, messages and parts are never mutated: the marked message and
 * part are shallow-copied, and any `providerOptions` already present on the part
 * are preserved (only `anthropic.cacheControl` is added/overwritten).
 *
 * Cache breakpoint budget (Anthropic allows at most 4): system (1) + tools (2, set
 * by the tool adapter) + one message (1).
 *
 * NOTE(review): the 1024 threshold is a character count, whereas Anthropic's
 * minimum cacheable prompt length is specified in tokens (model dependent), so a
 * system prompt may be marked without actually being cached.
 * NOTE(review): confirm the AI SDK version in use accepts a block array for
 * `system`; otherwise the system-prompt conversion must go through a system
 * message with `providerOptions` instead.
 *
 * @param provider - Provider id; only `'anthropic'` receives cache hints.
 * @param system - System prompt, if any.
 * @param messages - Conversation in chronological order.
 * @returns The (possibly converted) system prompt and a messages array.
 */
export function addCacheControl(
  provider: string,
  system: string | undefined,
  messages: ModelMessage[],
): {
  system?:
    | string
    | Array<{
        type: 'text';
        text: string;
        cache_control?: { type: 'ephemeral' };
      }>;
  messages: ModelMessage[];
} {
  // Only Anthropic supports prompt caching currently
  if (provider !== 'anthropic') {
    return { system, messages };
  }

  // Character-based threshold (see NOTE above).
  const cachedSystem =
    system && system.length > 1024
      ? [
          {
            type: 'text' as const,
            text: system,
            cache_control: { type: 'ephemeral' as const },
          },
        ]
      : system;

  if (messages.length < 3) {
    return { system: cachedSystem, messages };
  }

  const cachedMessages = [...messages];

  const userIndices: number[] = [];
  cachedMessages.forEach((m, i) => {
    if (m.role === 'user') userIndices.push(i);
  });

  // Mark the previous user turn (not the current one) as the cache breakpoint.
  const targetIndex =
    userIndices.length >= 2 ? userIndices[userIndices.length - 2] : undefined;
  const targetMsg =
    targetIndex === undefined ? undefined : cachedMessages[targetIndex];

  if (
    targetIndex !== undefined &&
    targetMsg &&
    Array.isArray(targetMsg.content) &&
    targetMsg.content.length > 0
  ) {
    const parts = [...targetMsg.content];
    const lastIdx = parts.length - 1;
    const lastPart = parts[lastIdx];

    if (lastPart && typeof lastPart === 'object' && 'type' in lastPart) {
      const existing = (
        lastPart as {
          providerOptions?: Record<string, Record<string, unknown>>;
        }
      ).providerOptions;

      // Copy the part instead of mutating it: the same objects are shared with
      // the caller (and possibly persisted history).
      parts[lastIdx] = {
        ...lastPart,
        providerOptions: {
          ...existing,
          anthropic: {
            ...existing?.anthropic,
            cacheControl: { type: 'ephemeral' },
          },
        },
      } as typeof lastPart;

      cachedMessages[targetIndex] = {
        ...targetMsg,
        content: parts,
      } as ModelMessage;
    }
  }

  return { system: cachedSystem, messages: cachedMessages };
}
|
|
76
|
+
|
|
77
|
+
/**
 * Truncates old messages to reduce context size while keeping recent context.
 *
 * Strategy:
 * - `system` role messages are always kept (they carry instructions, not history)
 *   and are placed ahead of the retained conversation.
 * - Of the remaining messages, at most the last `maxMessages` are kept.
 * - The retained window is advanced to start at a `user` message, so it never opens
 *   with an assistant turn or with a tool result whose tool call was cut off. If the
 *   window contains no user message at all, the plain tail is kept.
 *
 * @param messages - Conversation in chronological order.
 * @param maxMessages - Upper bound on non-system messages kept; values <= 0 keep none.
 * @returns `messages` itself when it is already within the limit, otherwise a new
 *   array (which may exceed `maxMessages` by the number of system messages).
 */
export function truncateHistory(
  messages: ModelMessage[],
  maxMessages = 20,
): ModelMessage[] {
  if (messages.length <= maxMessages) {
    return messages;
  }

  const systemMessages = messages.filter((m) => m.role === 'system');
  const conversation = messages.filter((m) => m.role !== 'system');

  // Clamp explicitly: `slice(-0)` would otherwise keep the entire array.
  const limit = Math.max(0, Math.floor(maxMessages));
  const cut = Math.max(0, conversation.length - limit);

  let start = cut;
  while (
    start < conversation.length &&
    conversation[start]?.role !== 'user'
  ) {
    start++;
  }
  if (start >= conversation.length) {
    // No user message inside the window: fall back to the plain tail.
    start = cut;
  }

  return [...systemMessages, ...conversation.slice(start)];
}
|
|
92
|
+
|
|
93
|
+
/**
 * Rough token estimate for `text`: assumes about four characters per token and
 * rounds up. This is a heuristic, not a tokenizer.
 */
export function estimateTokens(text: string): number {
  const charsPerToken = 4;
  const approximate = text.length / charsPerToken;
  return Math.ceil(approximate);
}
|
|
99
|
+
|
|
100
|
+
/**
 * Renders a tool result as a string and truncates it if it is too long.
 *
 * Strings are used as-is; other values are JSON-encoded. Values that JSON cannot
 * encode (`undefined`, functions, symbols) or that make `JSON.stringify` throw
 * (circular structures, BigInt) fall back to `String(result)` instead of crashing.
 *
 * @param result - Arbitrary tool output.
 * @param maxLength - Maximum number of characters kept (negative values act as 0).
 * @returns The full string, or its first `maxLength` characters followed by a
 *   truncation marker stating how many characters were dropped.
 */
export function summarizeToolResult(result: unknown, maxLength = 5000): string {
  let str: string;
  if (typeof result === 'string') {
    str = result;
  } else {
    let json: string | undefined;
    try {
      json = JSON.stringify(result);
    } catch {
      json = undefined;
    }
    // JSON.stringify returns undefined for undefined/functions/symbols.
    str = json ?? String(result);
  }

  const limit = Math.max(0, maxLength);
  if (str.length <= limit) {
    return str;
  }

  // Truncate and add indicator
  return (
    str.slice(0, limit) +
    `\n\n[... truncated ${str.length - limit} characters]`
  );
}
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
import type { ModelMessage } from 'ai';
|
|
2
|
+
|
|
3
|
+
/**
|
|
4
|
+
* Optimizes message context by deduplicating file reads and pruning old tool results.
|
|
5
|
+
*/
|
|
6
|
+
|
|
7
|
+
/** Position of one read/grep/glob tool part inside the conversation. */
type FileRead = {
  /** Index of the message within the messages array. */
  messageIndex: number;
  /** Index of the part within that message's content array. */
  partIndex: number;
  /** Path or pattern taken from the tool input. */
  path: string;
};
|
|
12
|
+
|
|
13
|
+
/**
 * Deduplicates file read results, keeping only the latest version of each file.
 *
 * Strategy:
 * - Inspect parts of assistant messages whose type is `tool-read`, `tool-grep` or
 *   `tool-glob` and whose input names a `path`, `filePattern` or `pattern`.
 * - Group them: `read` calls by the file they read; `grep`/`glob` calls by their
 *   complete input, so two different searches that share a directory are not
 *   treated as duplicates. The tool name is part of the key, so different tools
 *   never collide.
 * - Within each group keep only the most recent occurrence (latest message, and the
 *   later part when one message contains several) and remove the other parts
 *   entirely.
 *
 * NOTE(review): AI SDK `ModelMessage` assistant parts are normally typed
 * `tool-call` (results live in `tool` role messages); confirm the messages passed
 * here really use `tool-<name>` parts, otherwise this function is a no-op.
 *
 * @returns `messages` itself when nothing is removed, otherwise a rebuilt array.
 */
export function deduplicateFileReads(messages: ModelMessage[]): ModelMessage[] {
  const fileReads = new Map<string, FileRead[]>();

  // First pass: record every read/grep/glob part, in conversation order.
  messages.forEach((msg, msgIdx) => {
    if (msg.role !== 'assistant' || !Array.isArray(msg.content)) return;

    msg.content.forEach((part, partIdx) => {
      if (!part || typeof part !== 'object') return;
      if (!('type' in part)) return;

      const toolType = part.type as string;
      if (!toolType.startsWith('tool-')) return;

      const toolName = toolType.slice('tool-'.length);
      if (toolName !== 'read' && toolName !== 'grep' && toolName !== 'glob') {
        return;
      }

      const input = (part as any).input;
      if (!input) return;

      const path = input.path || input.filePattern || input.pattern;
      if (!path) return;

      // Reads are identified by their file; searches by their full input.
      const key =
        toolName === 'read'
          ? `read:${path}`
          : `${toolName}:${JSON.stringify(input)}`;

      const entry: FileRead = { messageIndex: msgIdx, partIndex: partIdx, path };
      const group = fileReads.get(key);
      if (group) {
        group.push(entry);
      } else {
        fileReads.set(key, [entry]);
      }
    });
  });

  // Second pass: entries were pushed in conversation order, so the last entry of
  // each group is the most recent one and every earlier entry is stale.
  const readsToRemove = new Set<string>();
  for (const reads of fileReads.values()) {
    for (const stale of reads.slice(0, -1)) {
      readsToRemove.add(`${stale.messageIndex}-${stale.partIndex}`);
    }
  }

  if (readsToRemove.size === 0) {
    return messages;
  }

  // Third pass: rebuild assistant messages without the stale parts.
  return messages.map((msg, msgIdx) => {
    if (msg.role !== 'assistant' || !Array.isArray(msg.content)) return msg;

    const filteredContent = msg.content.filter(
      (_part, partIdx) => !readsToRemove.has(`${msgIdx}-${partIdx}`),
    );

    return {
      ...msg,
      content: filteredContent,
    };
  });
}
|
|
88
|
+
|
|
89
|
+
/**
 * Prunes old tool results to reduce context size.
 *
 * Only `tool-*` parts of assistant messages that carry an `output` are counted.
 * The newest `maxToolResults` of them keep their output; each older one keeps the
 * rest of its fields but gets `output` replaced by a placeholder string. Text and
 * all other parts are left intact.
 *
 * NOTE(review): confirm the messages passed here use `tool-<name>` parts with an
 * inline `output`; standard AI SDK `ModelMessage`s keep results in `tool` role
 * messages, which this function does not touch.
 *
 * @param messages - Conversation in chronological order.
 * @param maxToolResults - How many of the newest tool outputs to keep.
 * @returns `messages` itself when nothing needs pruning, otherwise a rebuilt array
 *   in which every assistant message with array content is a shallow copy.
 */
export function pruneToolResults(
  messages: ModelMessage[],
  maxToolResults = 30,
): ModelMessage[] {
  // A prunable part is a `tool-*` part that carries an output.
  const carriesToolOutput = (part: unknown): boolean => {
    if (!part || typeof part !== 'object' || !('type' in part)) return false;
    const kind = (part as any).type as string;
    return kind.startsWith('tool-') && (part as any).output !== undefined;
  };

  const positions: string[] = [];
  messages.forEach((msg, m) => {
    if (msg.role !== 'assistant' || !Array.isArray(msg.content)) return;
    msg.content.forEach((part, p) => {
      if (carriesToolOutput(part)) positions.push(`${m}-${p}`);
    });
  });

  if (positions.length <= maxToolResults) {
    return messages;
  }

  const firstKept =
    positions.length - Math.min(maxToolResults, positions.length);
  const kept = new Set<string>();
  for (let i = firstKept; i < positions.length; i++) {
    kept.add(positions[i]);
  }

  return messages.map((msg, m) => {
    if (msg.role !== 'assistant' || !Array.isArray(msg.content)) return msg;

    const content = msg.content.map((part, p) =>
      carriesToolOutput(part) && !kept.has(`${m}-${p}`)
        ? { ...part, output: '[pruned to save context]' }
        : part,
    );

    return { ...msg, content };
  });
}
|
|
168
|
+
|
|
169
|
+
/**
 * Runs the context optimizations in order:
 * 1. `deduplicateFileReads`, unless `options.deduplicateFiles` is explicitly `false`.
 * 2. `pruneToolResults`, only when `options.maxToolResults` is provided.
 *
 * @returns The optimized messages (may be the input array when no step changes it).
 */
export function optimizeContext(
  messages: ModelMessage[],
  options: {
    deduplicateFiles?: boolean;
    maxToolResults?: number;
  } = {},
): ModelMessage[] {
  const afterDedup =
    options.deduplicateFiles === false
      ? messages
      : deduplicateFileReads(messages);

  if (options.maxToolResults === undefined) {
    return afterDedup;
  }
  return pruneToolResults(afterDedup, options.maxToolResults);
}
|
|
@@ -3,8 +3,96 @@ import { messages, messageParts, sessions } from '@agi-cli/database/schema';
|
|
|
3
3
|
import { eq } from 'drizzle-orm';
|
|
4
4
|
import type { RunOpts } from './session-queue.ts';
|
|
5
5
|
|
|
6
|
+
/** Token usage reported for a model step; any field may be absent. */
interface UsageData {
  /** Input (prompt) tokens. */
  inputTokens?: number;
  /** Output (completion) tokens. */
  outputTokens?: number;
  /** Total tokens as reported by the provider. */
  totalTokens?: number;
  /** Cached input tokens as reported by the provider. */
  cachedInputTokens?: number;
  /** Reasoning tokens as reported by the provider. */
  reasoningTokens?: number;
}
|
|
13
|
+
|
|
14
|
+
/**
 * Adds the token usage of one finished step to the session's running totals.
 *
 * The step's usage is treated as the cumulative usage of the assistant message so
 * far. The amount added to each session total is the difference between that value
 * and what the message row currently stores, clamped at 0. Because the message row
 * is the baseline, this must run BEFORE `updateMessageTokensIncremental` overwrites
 * that row for the same step; otherwise every delta is 0.
 *
 * NOTE(review): confirm that the SDK's `onStepFinish` usage is really cumulative per
 * message. If it is per-step usage, this delta logic undercounts multi-step runs.
 * NOTE(review): the session row is read, then written (no atomic increment), so
 * concurrent runs on the same session could lose updates — presumably serialized
 * by the session queue; verify.
 *
 * @param usage - Usage for the step; a missing field means "unchanged" (delta 0).
 * @param providerMetadata - Used for `openai.cachedPromptTokens` when
 *   `usage.cachedInputTokens` is absent.
 * @param opts - Supplies `sessionId` and `assistantMessageId`.
 * @param db - Database handle.
 */
export async function updateSessionTokensIncremental(
  usage: UsageData,
  providerMetadata: Record<string, any> | undefined,
  opts: RunOpts,
  db: Awaited<ReturnType<typeof getDb>>,
) {
  if (!usage) return;

  const [sess] = await db
    .select()
    .from(sessions)
    .where(eq(sessions.id, opts.sessionId));
  if (!sess) return;

  const [msg] = await db
    .select()
    .from(messages)
    .where(eq(messages.id, opts.assistantMessageId));

  // What earlier steps already attributed to this message.
  const storedPrompt = Number(msg?.promptTokens ?? 0);
  const storedCompletion = Number(msg?.completionTokens ?? 0);
  const storedCached = Number(msg?.cachedInputTokens ?? 0);
  const storedReasoning = Number(msg?.reasoningTokens ?? 0);

  // A reported value wins; an absent one keeps the stored value.
  const orStored = (reported: unknown, stored: number): number =>
    reported != null ? Number(reported) : stored;

  const cumPrompt = orStored(usage.inputTokens, storedPrompt);
  const cumCompletion = orStored(usage.outputTokens, storedCompletion);
  const cumReasoning = orStored(usage.reasoningTokens, storedReasoning);
  const cumCached = orStored(
    usage.cachedInputTokens,
    orStored(providerMetadata?.openai?.cachedPromptTokens, storedCached),
  );

  // Providers may report smaller values than before; never subtract.
  const growth = (cumulative: number, stored: number): number =>
    Math.max(0, cumulative - stored);

  await db
    .update(sessions)
    .set({
      totalInputTokens:
        Number(sess.totalInputTokens ?? 0) + growth(cumPrompt, storedPrompt),
      totalOutputTokens:
        Number(sess.totalOutputTokens ?? 0) +
        growth(cumCompletion, storedCompletion),
      totalCachedTokens:
        Number(sess.totalCachedTokens ?? 0) + growth(cumCached, storedCached),
      totalReasoningTokens:
        Number(sess.totalReasoningTokens ?? 0) +
        growth(cumReasoning, storedReasoning),
    })
    .where(eq(sessions.id, opts.sessionId));
}
|
|
92
|
+
|
|
6
93
|
/**
|
|
7
94
|
* Updates session token counts after a run completes.
|
|
95
|
+
* @deprecated Use updateSessionTokensIncremental for per-step tracking
|
|
8
96
|
*/
|
|
9
97
|
export async function updateSessionTokens(
|
|
10
98
|
fin: { usage?: { inputTokens?: number; outputTokens?: number } },
|
|
@@ -36,7 +124,67 @@ export async function updateSessionTokens(
|
|
|
36
124
|
}
|
|
37
125
|
|
|
38
126
|
/**
 * Overwrites the assistant message's token counts with the usage reported for the
 * latest step (treated as cumulative for the message, so values are replaced, not
 * added). Fields missing from `usage` keep the values already stored on the row.
 * Does nothing when the message row does not exist.
 *
 * `totalTokens` falls back to prompt + completion + reasoning when the provider does
 * not report a total.
 * NOTE(review): if the provider already counts reasoning inside `outputTokens`, that
 * fallback double-counts reasoning tokens — confirm per provider.
 *
 * @param usage - Usage for the step.
 * @param providerMetadata - Used for `openai.cachedPromptTokens` when
 *   `usage.cachedInputTokens` is absent.
 * @param opts - Supplies `assistantMessageId`.
 * @param db - Database handle.
 */
export async function updateMessageTokensIncremental(
  usage: UsageData,
  providerMetadata: Record<string, any> | undefined,
  opts: RunOpts,
  db: Awaited<ReturnType<typeof getDb>>,
) {
  if (!usage) return;

  const [msg] = await db
    .select()
    .from(messages)
    .where(eq(messages.id, opts.assistantMessageId));
  if (!msg) return;

  // A reported value replaces the stored one; an absent one keeps it.
  const pick = (reported: unknown, fallback: number): number =>
    reported != null ? Number(reported) : fallback;

  const promptTokens = pick(usage.inputTokens, Number(msg.promptTokens ?? 0));
  const completionTokens = pick(
    usage.outputTokens,
    Number(msg.completionTokens ?? 0),
  );
  const reasoningTokens = pick(
    usage.reasoningTokens,
    Number(msg.reasoningTokens ?? 0),
  );
  const cachedInputTokens = pick(
    usage.cachedInputTokens,
    pick(
      providerMetadata?.openai?.cachedPromptTokens,
      Number(msg.cachedInputTokens ?? 0),
    ),
  );
  const totalTokens = pick(
    usage.totalTokens,
    promptTokens + completionTokens + reasoningTokens,
  );

  await db
    .update(messages)
    .set({
      promptTokens,
      completionTokens,
      totalTokens,
      cachedInputTokens,
      reasoningTokens,
    })
    .where(eq(messages.id, opts.assistantMessageId));
}
|
|
184
|
+
|
|
185
|
+
/**
|
|
186
|
+
* Marks an assistant message as complete.
|
|
187
|
+
* Token usage is tracked incrementally via updateMessageTokensIncremental().
|
|
40
188
|
*/
|
|
41
189
|
export async function completeAssistantMessage(
|
|
42
190
|
fin: {
|
|
@@ -49,22 +197,13 @@ export async function completeAssistantMessage(
|
|
|
49
197
|
opts: RunOpts,
|
|
50
198
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
51
199
|
) {
|
|
52
|
-
|
|
53
|
-
status: 'complete',
|
|
54
|
-
completedAt: Date.now(),
|
|
55
|
-
};
|
|
56
|
-
|
|
57
|
-
if (fin.usage) {
|
|
58
|
-
vals.promptTokens = fin.usage.inputTokens;
|
|
59
|
-
vals.completionTokens = fin.usage.outputTokens;
|
|
60
|
-
vals.totalTokens =
|
|
61
|
-
fin.usage.totalTokens ??
|
|
62
|
-
(vals.promptTokens as number) + (vals.completionTokens as number);
|
|
63
|
-
}
|
|
64
|
-
|
|
200
|
+
// Only mark as complete - tokens are already tracked incrementally
|
|
65
201
|
await db
|
|
66
202
|
.update(messages)
|
|
67
|
-
.set(
|
|
203
|
+
.set({
|
|
204
|
+
status: 'complete',
|
|
205
|
+
completedAt: Date.now(),
|
|
206
|
+
})
|
|
68
207
|
.where(eq(messages.id, opts.assistantMessageId));
|
|
69
208
|
}
|
|
70
209
|
|
package/src/runtime/runner.ts
CHANGED
|
@@ -28,6 +28,8 @@ import {
|
|
|
28
28
|
} from './tool-context-setup.ts';
|
|
29
29
|
import {
|
|
30
30
|
updateSessionTokens,
|
|
31
|
+
updateSessionTokensIncremental,
|
|
32
|
+
updateMessageTokensIncremental,
|
|
31
33
|
completeAssistantMessage,
|
|
32
34
|
cleanupEmptyTextParts,
|
|
33
35
|
} from './db-operations.ts';
|
|
@@ -180,7 +182,7 @@ async function runAssistant(opts: RunOpts) {
|
|
|
180
182
|
opts,
|
|
181
183
|
db,
|
|
182
184
|
);
|
|
183
|
-
const toolset = adaptTools(gated, sharedCtx);
|
|
185
|
+
const toolset = adaptTools(gated, sharedCtx, opts.provider);
|
|
184
186
|
|
|
185
187
|
const modelTimer = time('runner:resolveModel');
|
|
186
188
|
const model = await resolveModel(opts.provider, opts.model, cfg);
|
|
@@ -229,6 +231,8 @@ async function runAssistant(opts: RunOpts) {
|
|
|
229
231
|
updateCurrentPartId,
|
|
230
232
|
updateAccumulated,
|
|
231
233
|
incrementStepIndex,
|
|
234
|
+
updateSessionTokensIncremental,
|
|
235
|
+
updateMessageTokensIncremental,
|
|
232
236
|
);
|
|
233
237
|
|
|
234
238
|
const onError = createErrorHandler(opts, db, getStepIndex, sharedCtx);
|
|
@@ -239,16 +243,37 @@ async function runAssistant(opts: RunOpts) {
|
|
|
239
243
|
opts,
|
|
240
244
|
db,
|
|
241
245
|
() => ensureFinishToolCalled(finishObserved, toolset, sharedCtx, stepIndex),
|
|
242
|
-
updateSessionTokens,
|
|
243
246
|
completeAssistantMessage,
|
|
244
247
|
);
|
|
245
248
|
|
|
249
|
+
// Apply optimizations: deduplication, pruning, cache control, and truncation
|
|
250
|
+
const { addCacheControl, truncateHistory } = await import(
|
|
251
|
+
'./cache-optimizer.ts'
|
|
252
|
+
);
|
|
253
|
+
const { optimizeContext } = await import('./context-optimizer.ts');
|
|
254
|
+
|
|
255
|
+
// 1. Optimize context (deduplicate file reads, prune old tool results)
|
|
256
|
+
const contextOptimized = optimizeContext(messagesWithSystemInstructions, {
|
|
257
|
+
deduplicateFiles: true,
|
|
258
|
+
maxToolResults: 30,
|
|
259
|
+
});
|
|
260
|
+
|
|
261
|
+
// 2. Truncate history
|
|
262
|
+
const truncatedMessages = truncateHistory(contextOptimized, 20);
|
|
263
|
+
|
|
264
|
+
// 3. Add cache control
|
|
265
|
+
const { system: cachedSystem, messages: optimizedMessages } = addCacheControl(
|
|
266
|
+
opts.provider as any,
|
|
267
|
+
system,
|
|
268
|
+
truncatedMessages,
|
|
269
|
+
);
|
|
270
|
+
|
|
246
271
|
try {
|
|
247
272
|
const result = streamText({
|
|
248
273
|
model,
|
|
249
274
|
tools: toolset,
|
|
250
|
-
...(
|
|
251
|
-
messages:
|
|
275
|
+
...(cachedSystem ? { system: cachedSystem } : {}),
|
|
276
|
+
messages: optimizedMessages,
|
|
252
277
|
...(maxOutputTokens ? { maxOutputTokens } : {}),
|
|
253
278
|
abortSignal: opts.abortSignal,
|
|
254
279
|
stopWhen: hasToolCall('finish'),
|
|
@@ -9,9 +9,16 @@ import type { RunOpts } from './session-queue.ts';
|
|
|
9
9
|
import type { ToolAdapterContext } from '../tools/adapter.ts';
|
|
10
10
|
|
|
11
11
|
type StepFinishEvent = {
|
|
12
|
-
usage?: {
|
|
12
|
+
usage?: {
|
|
13
|
+
inputTokens?: number;
|
|
14
|
+
outputTokens?: number;
|
|
15
|
+
totalTokens?: number;
|
|
16
|
+
cachedInputTokens?: number;
|
|
17
|
+
reasoningTokens?: number;
|
|
18
|
+
};
|
|
13
19
|
finishReason?: string;
|
|
14
20
|
response?: unknown;
|
|
21
|
+
experimental_providerMetadata?: Record<string, any>;
|
|
15
22
|
};
|
|
16
23
|
|
|
17
24
|
type FinishEvent = {
|
|
@@ -39,6 +46,18 @@ export function createStepFinishHandler(
|
|
|
39
46
|
updateCurrentPartId: (id: string) => void,
|
|
40
47
|
updateAccumulated: (text: string) => void,
|
|
41
48
|
incrementStepIndex: () => number,
|
|
49
|
+
updateSessionTokensIncrementalFn: (
|
|
50
|
+
usage: any,
|
|
51
|
+
providerMetadata: Record<string, any> | undefined,
|
|
52
|
+
opts: RunOpts,
|
|
53
|
+
db: Awaited<ReturnType<typeof getDb>>,
|
|
54
|
+
) => Promise<void>,
|
|
55
|
+
updateMessageTokensIncrementalFn: (
|
|
56
|
+
usage: any,
|
|
57
|
+
providerMetadata: Record<string, any> | undefined,
|
|
58
|
+
opts: RunOpts,
|
|
59
|
+
db: Awaited<ReturnType<typeof getDb>>,
|
|
60
|
+
) => Promise<void>,
|
|
42
61
|
) {
|
|
43
62
|
return async (step: StepFinishEvent) => {
|
|
44
63
|
const finishedAt = Date.now();
|
|
@@ -52,6 +71,27 @@ export function createStepFinishHandler(
|
|
|
52
71
|
.where(eq(messageParts.id, currentPartId));
|
|
53
72
|
} catch {}
|
|
54
73
|
|
|
74
|
+
// Update token counts incrementally after each step
|
|
75
|
+
if (step.usage) {
|
|
76
|
+
try {
|
|
77
|
+
await updateSessionTokensIncrementalFn(
|
|
78
|
+
step.usage,
|
|
79
|
+
step.experimental_providerMetadata,
|
|
80
|
+
opts,
|
|
81
|
+
db,
|
|
82
|
+
);
|
|
83
|
+
} catch {}
|
|
84
|
+
|
|
85
|
+
try {
|
|
86
|
+
await updateMessageTokensIncrementalFn(
|
|
87
|
+
step.usage,
|
|
88
|
+
step.experimental_providerMetadata,
|
|
89
|
+
opts,
|
|
90
|
+
db,
|
|
91
|
+
);
|
|
92
|
+
} catch {}
|
|
93
|
+
}
|
|
94
|
+
|
|
55
95
|
try {
|
|
56
96
|
publish({
|
|
57
97
|
type: 'finish-step',
|
|
@@ -234,11 +274,6 @@ export function createFinishHandler(
|
|
|
234
274
|
opts: RunOpts,
|
|
235
275
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
236
276
|
ensureFinishToolCalled: () => Promise<void>,
|
|
237
|
-
updateSessionTokensFn: (
|
|
238
|
-
fin: FinishEvent,
|
|
239
|
-
opts: RunOpts,
|
|
240
|
-
db: Awaited<ReturnType<typeof getDb>>,
|
|
241
|
-
) => Promise<void>,
|
|
242
277
|
completeAssistantMessageFn: (
|
|
243
278
|
fin: FinishEvent,
|
|
244
279
|
opts: RunOpts,
|
|
@@ -250,23 +285,37 @@ export function createFinishHandler(
|
|
|
250
285
|
await ensureFinishToolCalled();
|
|
251
286
|
} catch {}
|
|
252
287
|
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
} catch {}
|
|
288
|
+
// Note: Token updates are handled incrementally in onStepFinish
|
|
289
|
+
// Do NOT add fin.usage here as it would cause double-counting
|
|
256
290
|
|
|
257
291
|
try {
|
|
258
292
|
await completeAssistantMessageFn(fin, opts, db);
|
|
259
293
|
} catch {}
|
|
260
294
|
|
|
261
|
-
|
|
262
|
-
|
|
295
|
+
// Use session totals from DB for accurate cost calculation
|
|
296
|
+
const sessRows = await db
|
|
297
|
+
.select()
|
|
298
|
+
.from(messages)
|
|
299
|
+
.where(eq(messages.id, opts.assistantMessageId));
|
|
300
|
+
|
|
301
|
+
const usage = sessRows[0]
|
|
302
|
+
? {
|
|
303
|
+
inputTokens: Number(sessRows[0].promptTokens ?? 0),
|
|
304
|
+
outputTokens: Number(sessRows[0].completionTokens ?? 0),
|
|
305
|
+
totalTokens: Number(sessRows[0].totalTokens ?? 0),
|
|
306
|
+
}
|
|
307
|
+
: fin.usage;
|
|
308
|
+
|
|
309
|
+
const costUsd = usage
|
|
310
|
+
? estimateModelCostUsd(opts.provider, opts.model, usage)
|
|
263
311
|
: undefined;
|
|
312
|
+
|
|
264
313
|
publish({
|
|
265
314
|
type: 'message.completed',
|
|
266
315
|
sessionId: opts.sessionId,
|
|
267
316
|
payload: {
|
|
268
317
|
id: opts.assistantMessageId,
|
|
269
|
-
usage
|
|
318
|
+
usage,
|
|
270
319
|
costUsd,
|
|
271
320
|
finishReason: fin.finishReason,
|
|
272
321
|
},
|
package/src/tools/adapter.ts
CHANGED
|
@@ -39,15 +39,40 @@ function getPendingQueue(
|
|
|
39
39
|
return queue;
|
|
40
40
|
}
|
|
41
41
|
|
|
42
|
-
export function adaptTools(
|
|
42
|
+
export function adaptTools(
|
|
43
|
+
tools: DiscoveredTool[],
|
|
44
|
+
ctx: ToolAdapterContext,
|
|
45
|
+
provider?: string,
|
|
46
|
+
) {
|
|
43
47
|
const out: Record<string, Tool> = {};
|
|
44
48
|
const pendingCalls = new Map<string, PendingCallMeta[]>();
|
|
45
49
|
let firstToolCallReported = false;
|
|
46
50
|
|
|
51
|
+
// Anthropic allows max 4 cache_control blocks
|
|
52
|
+
// Cache only the most frequently used tools: read, write, bash
|
|
53
|
+
const cacheableTools = new Set(['read', 'write', 'bash', 'edit']);
|
|
54
|
+
let cachedToolCount = 0;
|
|
55
|
+
|
|
47
56
|
for (const { name, tool } of tools) {
|
|
48
57
|
const base = tool;
|
|
58
|
+
|
|
59
|
+
// Add cache control for Anthropic to cache tool definitions (max 2 tools)
|
|
60
|
+
const shouldCache =
|
|
61
|
+
provider === 'anthropic' &&
|
|
62
|
+
cacheableTools.has(name) &&
|
|
63
|
+
cachedToolCount < 2;
|
|
64
|
+
|
|
65
|
+
if (shouldCache) {
|
|
66
|
+
cachedToolCount++;
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
const providerOptions = shouldCache
|
|
70
|
+
? { anthropic: { cacheControl: { type: 'ephemeral' as const } } }
|
|
71
|
+
: undefined;
|
|
72
|
+
|
|
49
73
|
out[name] = {
|
|
50
74
|
...base,
|
|
75
|
+
...(providerOptions ? { providerOptions } : {}),
|
|
51
76
|
async onInputStart(options: unknown) {
|
|
52
77
|
const queue = getPendingQueue(pendingCalls, name);
|
|
53
78
|
queue.push({
|
|
@@ -185,194 +210,257 @@ export function adaptTools(tools: DiscoveredTool[], ctx: ToolAdapterContext) {
|
|
|
185
210
|
const callIdFromQueue = meta?.callId;
|
|
186
211
|
const startTsFromQueue = meta?.startTs;
|
|
187
212
|
const stepIndexForEvent = meta?.stepIndex ?? ctx.stepIndex;
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
chunks
|
|
213
|
+
|
|
214
|
+
try {
|
|
215
|
+
// Handle session-relative paths and cwd tools
|
|
216
|
+
let res: ToolExecuteReturn | { cwd: string } | null | undefined;
|
|
217
|
+
const cwd = getCwd(ctx.sessionId);
|
|
218
|
+
if (name === 'pwd') {
|
|
219
|
+
res = { cwd };
|
|
220
|
+
} else if (name === 'cd') {
|
|
221
|
+
const next = joinRelative(
|
|
222
|
+
cwd,
|
|
223
|
+
String((input as Record<string, unknown>)?.path ?? '.'),
|
|
224
|
+
);
|
|
225
|
+
setCwd(ctx.sessionId, next);
|
|
226
|
+
res = { cwd: next };
|
|
227
|
+
} else if (
|
|
228
|
+
['read', 'write', 'ls', 'tree'].includes(name) &&
|
|
229
|
+
typeof (input as Record<string, unknown>)?.path === 'string'
|
|
230
|
+
) {
|
|
231
|
+
const rel = joinRelative(
|
|
232
|
+
cwd,
|
|
233
|
+
String((input as Record<string, unknown>).path),
|
|
234
|
+
);
|
|
235
|
+
const nextInput = {
|
|
236
|
+
...(input as Record<string, unknown>),
|
|
237
|
+
path: rel,
|
|
238
|
+
} as ToolExecuteInput;
|
|
239
|
+
// biome-ignore lint/suspicious/noExplicitAny: AI SDK types are complex
|
|
240
|
+
res = base.execute?.(nextInput, options as any);
|
|
241
|
+
} else if (name === 'bash') {
|
|
242
|
+
const needsCwd =
|
|
243
|
+
!input ||
|
|
244
|
+
typeof (input as Record<string, unknown>).cwd !== 'string';
|
|
245
|
+
const nextInput = needsCwd
|
|
246
|
+
? ({
|
|
247
|
+
...(input as Record<string, unknown>),
|
|
248
|
+
cwd,
|
|
249
|
+
} as ToolExecuteInput)
|
|
250
|
+
: input;
|
|
251
|
+
// biome-ignore lint/suspicious/noExplicitAny: AI SDK types are complex
|
|
252
|
+
res = base.execute?.(nextInput, options as any);
|
|
253
|
+
} else {
|
|
254
|
+
// biome-ignore lint/suspicious/noExplicitAny: AI SDK types are complex
|
|
255
|
+
res = base.execute?.(input, options as any);
|
|
256
|
+
}
|
|
257
|
+
let result: unknown = res;
|
|
258
|
+
// If tool returns an async iterable, stream deltas while accumulating
|
|
259
|
+
if (res && typeof res === 'object' && Symbol.asyncIterator in res) {
|
|
260
|
+
const chunks: unknown[] = [];
|
|
261
|
+
for await (const chunk of res as AsyncIterable<unknown>) {
|
|
262
|
+
chunks.push(chunk);
|
|
263
|
+
publish({
|
|
264
|
+
type: 'tool.delta',
|
|
265
|
+
sessionId: ctx.sessionId,
|
|
266
|
+
payload: {
|
|
267
|
+
name,
|
|
268
|
+
channel: 'output',
|
|
269
|
+
delta: chunk,
|
|
270
|
+
stepIndex: stepIndexForEvent,
|
|
271
|
+
callId: callIdFromQueue,
|
|
272
|
+
},
|
|
273
|
+
});
|
|
274
|
+
}
|
|
275
|
+
// Prefer the last chunk as the result if present, otherwise the entire array
|
|
276
|
+
result = chunks.length > 0 ? chunks[chunks.length - 1] : null;
|
|
277
|
+
} else {
|
|
278
|
+
// Await promise or passthrough value
|
|
279
|
+
result = await Promise.resolve(res as ToolExecuteReturn);
|
|
280
|
+
}
|
|
281
|
+
const resultPartId = crypto.randomUUID();
|
|
282
|
+
const callId = callIdFromQueue;
|
|
283
|
+
const startTs = startTsFromQueue;
|
|
284
|
+
const contentObj: {
|
|
285
|
+
name: string;
|
|
286
|
+
result: unknown;
|
|
287
|
+
callId?: string;
|
|
288
|
+
artifact?: unknown;
|
|
289
|
+
args?: unknown;
|
|
290
|
+
} = {
|
|
291
|
+
name,
|
|
292
|
+
result,
|
|
293
|
+
callId,
|
|
294
|
+
};
|
|
295
|
+
if (meta?.args !== undefined) {
|
|
296
|
+
contentObj.args = meta.args;
|
|
297
|
+
}
|
|
298
|
+
if (result && typeof result === 'object' && 'artifact' in result) {
|
|
299
|
+
try {
|
|
300
|
+
const maybeArtifact = (result as { artifact?: unknown }).artifact;
|
|
301
|
+
if (maybeArtifact !== undefined)
|
|
302
|
+
contentObj.artifact = maybeArtifact;
|
|
303
|
+
} catch {}
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
const index = await ctx.nextIndex();
|
|
307
|
+
const endTs = Date.now();
|
|
308
|
+
const dur =
|
|
309
|
+
typeof startTs === 'number' ? Math.max(0, endTs - startTs) : null;
|
|
310
|
+
|
|
311
|
+
// Special-case: keep progress_update result lightweight; publish first, persist best-effort
|
|
312
|
+
if (name === 'progress_update') {
|
|
236
313
|
publish({
|
|
237
|
-
type: 'tool.
|
|
314
|
+
type: 'tool.result',
|
|
238
315
|
sessionId: ctx.sessionId,
|
|
239
|
-
payload: {
|
|
240
|
-
name,
|
|
241
|
-
channel: 'output',
|
|
242
|
-
delta: chunk,
|
|
243
|
-
stepIndex: stepIndexForEvent,
|
|
244
|
-
callId: callIdFromQueue,
|
|
245
|
-
},
|
|
316
|
+
payload: { ...contentObj, stepIndex: stepIndexForEvent },
|
|
246
317
|
});
|
|
318
|
+
// Persist without blocking the event loop
|
|
319
|
+
(async () => {
|
|
320
|
+
try {
|
|
321
|
+
await ctx.db.insert(messageParts).values({
|
|
322
|
+
id: resultPartId,
|
|
323
|
+
messageId: ctx.messageId,
|
|
324
|
+
index,
|
|
325
|
+
stepIndex: stepIndexForEvent,
|
|
326
|
+
type: 'tool_result',
|
|
327
|
+
content: JSON.stringify(contentObj),
|
|
328
|
+
agent: ctx.agent,
|
|
329
|
+
provider: ctx.provider,
|
|
330
|
+
model: ctx.model,
|
|
331
|
+
startedAt: startTs,
|
|
332
|
+
completedAt: endTs,
|
|
333
|
+
toolName: name,
|
|
334
|
+
toolCallId: callId,
|
|
335
|
+
toolDurationMs: dur ?? undefined,
|
|
336
|
+
});
|
|
337
|
+
} catch {}
|
|
338
|
+
})();
|
|
339
|
+
return result as ToolExecuteReturn;
|
|
247
340
|
}
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
callId,
|
|
267
|
-
};
|
|
268
|
-
if (meta?.args !== undefined) {
|
|
269
|
-
contentObj.args = meta.args;
|
|
270
|
-
}
|
|
271
|
-
if (result && typeof result === 'object' && 'artifact' in result) {
|
|
341
|
+
|
|
342
|
+
await ctx.db.insert(messageParts).values({
|
|
343
|
+
id: resultPartId,
|
|
344
|
+
messageId: ctx.messageId,
|
|
345
|
+
index,
|
|
346
|
+
stepIndex: stepIndexForEvent,
|
|
347
|
+
type: 'tool_result',
|
|
348
|
+
content: JSON.stringify(contentObj),
|
|
349
|
+
agent: ctx.agent,
|
|
350
|
+
provider: ctx.provider,
|
|
351
|
+
model: ctx.model,
|
|
352
|
+
startedAt: startTs,
|
|
353
|
+
completedAt: endTs,
|
|
354
|
+
toolName: name,
|
|
355
|
+
toolCallId: callId,
|
|
356
|
+
toolDurationMs: dur ?? undefined,
|
|
357
|
+
});
|
|
358
|
+
// Update session aggregates: total tool time and counts per tool
|
|
272
359
|
try {
|
|
273
|
-
const
|
|
274
|
-
|
|
275
|
-
|
|
360
|
+
const sessRows = await ctx.db
|
|
361
|
+
.select()
|
|
362
|
+
.from(sessions)
|
|
363
|
+
.where(eq(sessions.id, ctx.sessionId));
|
|
364
|
+
if (sessRows.length) {
|
|
365
|
+
const row = sessRows[0] as typeof sessions.$inferSelect;
|
|
366
|
+
const totalToolTimeMs =
|
|
367
|
+
Number(row.totalToolTimeMs || 0) + (dur ?? 0);
|
|
368
|
+
let counts: Record<string, number> = {};
|
|
369
|
+
try {
|
|
370
|
+
counts = row.toolCountsJson
|
|
371
|
+
? JSON.parse(row.toolCountsJson)
|
|
372
|
+
: {};
|
|
373
|
+
} catch {}
|
|
374
|
+
counts[name] = (counts[name] || 0) + 1;
|
|
375
|
+
await ctx.db
|
|
376
|
+
.update(sessions)
|
|
377
|
+
.set({
|
|
378
|
+
totalToolTimeMs,
|
|
379
|
+
toolCountsJson: JSON.stringify(counts),
|
|
380
|
+
lastActiveAt: endTs,
|
|
381
|
+
})
|
|
382
|
+
.where(eq(sessions.id, ctx.sessionId));
|
|
383
|
+
}
|
|
276
384
|
} catch {}
|
|
277
|
-
}
|
|
278
|
-
|
|
279
|
-
const index = await ctx.nextIndex();
|
|
280
|
-
const endTs = Date.now();
|
|
281
|
-
const dur =
|
|
282
|
-
typeof startTs === 'number' ? Math.max(0, endTs - startTs) : null;
|
|
283
|
-
|
|
284
|
-
// Special-case: keep progress_update result lightweight; publish first, persist best-effort
|
|
285
|
-
if (name === 'progress_update') {
|
|
286
385
|
publish({
|
|
287
386
|
type: 'tool.result',
|
|
288
387
|
sessionId: ctx.sessionId,
|
|
289
388
|
payload: { ...contentObj, stepIndex: stepIndexForEvent },
|
|
290
389
|
});
|
|
291
|
-
|
|
292
|
-
(async () => {
|
|
390
|
+
if (name === 'update_plan') {
|
|
293
391
|
try {
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
startedAt: startTs,
|
|
305
|
-
completedAt: endTs,
|
|
306
|
-
toolName: name,
|
|
307
|
-
toolCallId: callId,
|
|
308
|
-
toolDurationMs: dur ?? undefined,
|
|
309
|
-
});
|
|
392
|
+
const result = (contentObj as { result?: unknown }).result as
|
|
393
|
+
| { items?: unknown; note?: unknown }
|
|
394
|
+
| undefined;
|
|
395
|
+
if (result && Array.isArray(result.items)) {
|
|
396
|
+
publish({
|
|
397
|
+
type: 'plan.updated',
|
|
398
|
+
sessionId: ctx.sessionId,
|
|
399
|
+
payload: { items: result.items, note: result.note },
|
|
400
|
+
});
|
|
401
|
+
}
|
|
310
402
|
} catch {}
|
|
311
|
-
}
|
|
312
|
-
return result
|
|
313
|
-
}
|
|
403
|
+
}
|
|
404
|
+
return result;
|
|
405
|
+
} catch (error) {
|
|
406
|
+
// Tool execution failed - save error to database as tool_result
|
|
407
|
+
const resultPartId = crypto.randomUUID();
|
|
408
|
+
const callId = callIdFromQueue;
|
|
409
|
+
const startTs = startTsFromQueue;
|
|
410
|
+
const endTs = Date.now();
|
|
411
|
+
const dur =
|
|
412
|
+
typeof startTs === 'number' ? Math.max(0, endTs - startTs) : null;
|
|
314
413
|
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
const sessRows = await ctx.db
|
|
334
|
-
.select()
|
|
335
|
-
.from(sessions)
|
|
336
|
-
.where(eq(sessions.id, ctx.sessionId));
|
|
337
|
-
if (sessRows.length) {
|
|
338
|
-
const row = sessRows[0] as typeof sessions.$inferSelect;
|
|
339
|
-
const totalToolTimeMs =
|
|
340
|
-
Number(row.totalToolTimeMs || 0) + (dur ?? 0);
|
|
341
|
-
let counts: Record<string, number> = {};
|
|
342
|
-
try {
|
|
343
|
-
counts = row.toolCountsJson ? JSON.parse(row.toolCountsJson) : {};
|
|
344
|
-
} catch {}
|
|
345
|
-
counts[name] = (counts[name] || 0) + 1;
|
|
346
|
-
await ctx.db
|
|
347
|
-
.update(sessions)
|
|
348
|
-
.set({
|
|
349
|
-
totalToolTimeMs,
|
|
350
|
-
toolCountsJson: JSON.stringify(counts),
|
|
351
|
-
lastActiveAt: endTs,
|
|
352
|
-
})
|
|
353
|
-
.where(eq(sessions.id, ctx.sessionId));
|
|
414
|
+
const errorMessage =
|
|
415
|
+
error instanceof Error ? error.message : String(error);
|
|
416
|
+
const errorStack = error instanceof Error ? error.stack : undefined;
|
|
417
|
+
|
|
418
|
+
const errorResult = {
|
|
419
|
+
ok: false,
|
|
420
|
+
error: errorMessage,
|
|
421
|
+
stack: errorStack,
|
|
422
|
+
};
|
|
423
|
+
|
|
424
|
+
const contentObj = {
|
|
425
|
+
name,
|
|
426
|
+
result: errorResult,
|
|
427
|
+
callId,
|
|
428
|
+
};
|
|
429
|
+
|
|
430
|
+
if (meta?.args !== undefined) {
|
|
431
|
+
contentObj.args = meta.args;
|
|
354
432
|
}
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
433
|
+
|
|
434
|
+
const index = await ctx.nextIndex();
|
|
435
|
+
|
|
436
|
+
// Save error result to database
|
|
437
|
+
await ctx.db.insert(messageParts).values({
|
|
438
|
+
id: resultPartId,
|
|
439
|
+
messageId: ctx.messageId,
|
|
440
|
+
index,
|
|
441
|
+
stepIndex: stepIndexForEvent,
|
|
442
|
+
type: 'tool_result',
|
|
443
|
+
content: JSON.stringify(contentObj),
|
|
444
|
+
agent: ctx.agent,
|
|
445
|
+
provider: ctx.provider,
|
|
446
|
+
model: ctx.model,
|
|
447
|
+
startedAt: startTs,
|
|
448
|
+
completedAt: endTs,
|
|
449
|
+
toolName: name,
|
|
450
|
+
toolCallId: callId,
|
|
451
|
+
toolDurationMs: dur ?? undefined,
|
|
452
|
+
});
|
|
453
|
+
|
|
454
|
+
// Publish error result
|
|
455
|
+
publish({
|
|
456
|
+
type: 'tool.result',
|
|
457
|
+
sessionId: ctx.sessionId,
|
|
458
|
+
payload: { ...contentObj, stepIndex: stepIndexForEvent },
|
|
459
|
+
});
|
|
460
|
+
|
|
461
|
+
// Re-throw so AI SDK can handle it
|
|
462
|
+
throw error;
|
|
374
463
|
}
|
|
375
|
-
return result;
|
|
376
464
|
},
|
|
377
465
|
} as Tool;
|
|
378
466
|
}
|