@agi-cli/server 0.1.61 → 0.1.62
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +3 -3
- package/src/runtime/cache-optimizer.ts +51 -29
- package/src/runtime/db-operations.ts +48 -43
- package/src/runtime/runner.ts +248 -99
- package/src/runtime/stream-handlers.ts +209 -175
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@agi-cli/server",
|
|
3
|
-
"version": "0.1.
|
|
3
|
+
"version": "0.1.62",
|
|
4
4
|
"description": "HTTP API server for AGI CLI",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"main": "./src/index.ts",
|
|
@@ -29,8 +29,8 @@
|
|
|
29
29
|
"typecheck": "tsc --noEmit"
|
|
30
30
|
},
|
|
31
31
|
"dependencies": {
|
|
32
|
-
"@agi-cli/sdk": "0.1.
|
|
33
|
-
"@agi-cli/database": "0.1.
|
|
32
|
+
"@agi-cli/sdk": "0.1.62",
|
|
33
|
+
"@agi-cli/database": "0.1.62",
|
|
34
34
|
"drizzle-orm": "^0.44.5",
|
|
35
35
|
"hono": "^4.9.9"
|
|
36
36
|
},
|
|
@@ -1,23 +1,5 @@
|
|
|
1
1
|
import type { ModelMessage } from 'ai';
|
|
2
2
|
|
|
3
|
-
type SystemMessage =
|
|
4
|
-
| string
|
|
5
|
-
| Array<{
|
|
6
|
-
type: 'text';
|
|
7
|
-
text: string;
|
|
8
|
-
cache_control?: { type: 'ephemeral' };
|
|
9
|
-
}>;
|
|
10
|
-
|
|
11
|
-
interface ContentPart {
|
|
12
|
-
type: string;
|
|
13
|
-
[key: string]: unknown;
|
|
14
|
-
providerOptions?: {
|
|
15
|
-
anthropic?: {
|
|
16
|
-
cacheControl?: { type: 'ephemeral' };
|
|
17
|
-
};
|
|
18
|
-
};
|
|
19
|
-
}
|
|
20
|
-
|
|
21
3
|
/**
|
|
22
4
|
* Adds cache control to messages for prompt caching optimization.
|
|
23
5
|
* Anthropic supports caching for system messages, tools, and long context.
|
|
@@ -27,7 +9,13 @@ export function addCacheControl(
|
|
|
27
9
|
system: string | undefined,
|
|
28
10
|
messages: ModelMessage[],
|
|
29
11
|
): {
|
|
30
|
-
system?:
|
|
12
|
+
system?:
|
|
13
|
+
| string
|
|
14
|
+
| Array<{
|
|
15
|
+
type: 'text';
|
|
16
|
+
text: string;
|
|
17
|
+
cache_control?: { type: 'ephemeral' };
|
|
18
|
+
}>;
|
|
31
19
|
messages: ModelMessage[];
|
|
32
20
|
} {
|
|
33
21
|
// Only Anthropic supports prompt caching currently
|
|
@@ -36,7 +24,7 @@ export function addCacheControl(
|
|
|
36
24
|
}
|
|
37
25
|
|
|
38
26
|
// Convert system to cacheable format if it's long enough
|
|
39
|
-
let cachedSystem:
|
|
27
|
+
let cachedSystem: any = system;
|
|
40
28
|
if (system && system.length > 1024) {
|
|
41
29
|
// Anthropic requires 1024+ tokens for Claude Sonnet/Opus
|
|
42
30
|
cachedSystem = [
|
|
@@ -73,21 +61,55 @@ export function addCacheControl(
|
|
|
73
61
|
// Add cache control to the last content part of that message
|
|
74
62
|
const lastPart = targetMsg.content[targetMsg.content.length - 1];
|
|
75
63
|
if (lastPart && typeof lastPart === 'object' && 'type' in lastPart) {
|
|
76
|
-
(lastPart as
|
|
64
|
+
(lastPart as any).providerOptions = {
|
|
77
65
|
anthropic: { cacheControl: { type: 'ephemeral' } },
|
|
78
66
|
};
|
|
79
67
|
}
|
|
80
68
|
}
|
|
81
69
|
}
|
|
82
70
|
|
|
83
|
-
return {
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
71
|
+
return { system: cachedSystem, messages: cachedMessages };
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
return { system: cachedSystem, messages };
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
/**
|
|
78
|
+
* Truncates old messages to reduce context size while keeping recent context.
|
|
79
|
+
* Strategy: Keep system message + last N messages
|
|
80
|
+
*/
|
|
81
|
+
export function truncateHistory(
|
|
82
|
+
messages: ModelMessage[],
|
|
83
|
+
maxMessages = 20,
|
|
84
|
+
): ModelMessage[] {
|
|
85
|
+
if (messages.length <= maxMessages) {
|
|
86
|
+
return messages;
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
// Keep the most recent messages
|
|
90
|
+
return messages.slice(-maxMessages);
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
/**
|
|
94
|
+
* Estimates token count (rough approximation: ~4 chars per token)
|
|
95
|
+
*/
|
|
96
|
+
export function estimateTokens(text: string): number {
|
|
97
|
+
return Math.ceil(text.length / 4);
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
/**
|
|
101
|
+
* Summarizes tool results if they're too long
|
|
102
|
+
*/
|
|
103
|
+
export function summarizeToolResult(result: unknown, maxLength = 5000): string {
|
|
104
|
+
const str = typeof result === 'string' ? result : JSON.stringify(result);
|
|
105
|
+
|
|
106
|
+
if (str.length <= maxLength) {
|
|
107
|
+
return str;
|
|
87
108
|
}
|
|
88
109
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
110
|
+
// Truncate and add indicator
|
|
111
|
+
return (
|
|
112
|
+
str.slice(0, maxLength) +
|
|
113
|
+
`\n\n[... truncated ${str.length - maxLength} characters]`
|
|
114
|
+
);
|
|
93
115
|
}
|
|
@@ -11,20 +11,13 @@ type UsageData = {
|
|
|
11
11
|
reasoningTokens?: number;
|
|
12
12
|
};
|
|
13
13
|
|
|
14
|
-
interface ProviderMetadata {
|
|
15
|
-
openai?: {
|
|
16
|
-
cachedPromptTokens?: number;
|
|
17
|
-
};
|
|
18
|
-
[key: string]: unknown;
|
|
19
|
-
}
|
|
20
|
-
|
|
21
14
|
/**
|
|
22
15
|
* Updates session token counts incrementally after each step.
|
|
23
16
|
* Note: onStepFinish.usage is CUMULATIVE per message, so we compute DELTA and add to session.
|
|
24
17
|
*/
|
|
25
18
|
export async function updateSessionTokensIncremental(
|
|
26
19
|
usage: UsageData,
|
|
27
|
-
providerMetadata:
|
|
20
|
+
providerMetadata: Record<string, any> | undefined,
|
|
28
21
|
opts: RunOpts,
|
|
29
22
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
30
23
|
) {
|
|
@@ -136,7 +129,7 @@ export async function updateSessionTokens(
|
|
|
136
129
|
*/
|
|
137
130
|
export async function updateMessageTokensIncremental(
|
|
138
131
|
usage: UsageData,
|
|
139
|
-
providerMetadata:
|
|
132
|
+
providerMetadata: Record<string, any> | undefined,
|
|
140
133
|
opts: RunOpts,
|
|
141
134
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
142
135
|
) {
|
|
@@ -155,74 +148,86 @@ export async function updateMessageTokensIncremental(
|
|
|
155
148
|
const priorReasoning = Number(msg.reasoningTokens ?? 0);
|
|
156
149
|
|
|
157
150
|
// Treat usage as cumulative per-message - REPLACE not ADD
|
|
158
|
-
const
|
|
151
|
+
const cumPrompt =
|
|
159
152
|
usage.inputTokens != null ? Number(usage.inputTokens) : priorPrompt;
|
|
160
|
-
const
|
|
153
|
+
const cumCompletion =
|
|
161
154
|
usage.outputTokens != null ? Number(usage.outputTokens) : priorCompletion;
|
|
162
|
-
const
|
|
155
|
+
const cumReasoning =
|
|
163
156
|
usage.reasoningTokens != null
|
|
164
157
|
? Number(usage.reasoningTokens)
|
|
165
158
|
: priorReasoning;
|
|
166
159
|
|
|
167
|
-
const
|
|
160
|
+
const cumCached =
|
|
168
161
|
usage.cachedInputTokens != null
|
|
169
162
|
? Number(usage.cachedInputTokens)
|
|
170
163
|
: providerMetadata?.openai?.cachedPromptTokens != null
|
|
171
164
|
? Number(providerMetadata.openai.cachedPromptTokens)
|
|
172
165
|
: priorCached;
|
|
173
166
|
|
|
167
|
+
const cumTotal =
|
|
168
|
+
usage.totalTokens != null
|
|
169
|
+
? Number(usage.totalTokens)
|
|
170
|
+
: cumPrompt + cumCompletion + cumReasoning;
|
|
171
|
+
|
|
174
172
|
await db
|
|
175
173
|
.update(messages)
|
|
176
174
|
.set({
|
|
177
|
-
promptTokens:
|
|
178
|
-
completionTokens:
|
|
179
|
-
|
|
180
|
-
|
|
175
|
+
promptTokens: cumPrompt,
|
|
176
|
+
completionTokens: cumCompletion,
|
|
177
|
+
totalTokens: cumTotal,
|
|
178
|
+
cachedInputTokens: cumCached,
|
|
179
|
+
reasoningTokens: cumReasoning,
|
|
181
180
|
})
|
|
182
181
|
.where(eq(messages.id, opts.assistantMessageId));
|
|
183
182
|
}
|
|
184
183
|
}
|
|
185
184
|
|
|
186
185
|
/**
|
|
187
|
-
*
|
|
188
|
-
*
|
|
186
|
+
* Marks an assistant message as complete.
|
|
187
|
+
* Token usage is tracked incrementally via updateMessageTokensIncremental().
|
|
189
188
|
*/
|
|
190
189
|
export async function completeAssistantMessage(
|
|
191
|
-
|
|
190
|
+
fin: {
|
|
192
191
|
usage?: {
|
|
193
192
|
inputTokens?: number;
|
|
194
193
|
outputTokens?: number;
|
|
194
|
+
totalTokens?: number;
|
|
195
195
|
};
|
|
196
196
|
},
|
|
197
197
|
opts: RunOpts,
|
|
198
198
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
199
199
|
) {
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
.
|
|
200
|
+
// Only mark as complete - tokens are already tracked incrementally
|
|
201
|
+
await db
|
|
202
|
+
.update(messages)
|
|
203
|
+
.set({
|
|
204
|
+
status: 'complete',
|
|
205
|
+
completedAt: Date.now(),
|
|
206
|
+
})
|
|
203
207
|
.where(eq(messages.id, opts.assistantMessageId));
|
|
204
|
-
|
|
205
|
-
if (msgRow.length > 0) {
|
|
206
|
-
await db
|
|
207
|
-
.update(messages)
|
|
208
|
-
.set({
|
|
209
|
-
finishedAt: new Date(),
|
|
210
|
-
})
|
|
211
|
-
.where(eq(messages.id, opts.assistantMessageId));
|
|
212
|
-
}
|
|
213
208
|
}
|
|
214
209
|
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
toolArgs?: unknown;
|
|
221
|
-
toolResult?: unknown;
|
|
222
|
-
textContent?: string | null;
|
|
223
|
-
stepIndex?: number | null;
|
|
224
|
-
},
|
|
210
|
+
/**
|
|
211
|
+
* Removes empty text parts from an assistant message.
|
|
212
|
+
*/
|
|
213
|
+
export async function cleanupEmptyTextParts(
|
|
214
|
+
opts: RunOpts,
|
|
225
215
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
226
216
|
) {
|
|
227
|
-
await db
|
|
217
|
+
const parts = await db
|
|
218
|
+
.select()
|
|
219
|
+
.from(messageParts)
|
|
220
|
+
.where(eq(messageParts.messageId, opts.assistantMessageId));
|
|
221
|
+
|
|
222
|
+
for (const p of parts) {
|
|
223
|
+
if (p.type === 'text') {
|
|
224
|
+
let t = '';
|
|
225
|
+
try {
|
|
226
|
+
t = JSON.parse(p.content || '{}')?.text || '';
|
|
227
|
+
} catch {}
|
|
228
|
+
if (!t || t.length === 0) {
|
|
229
|
+
await db.delete(messageParts).where(eq(messageParts.id, p.id));
|
|
230
|
+
}
|
|
231
|
+
}
|
|
232
|
+
}
|
|
228
233
|
}
|
package/src/runtime/runner.ts
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { streamText } from 'ai';
|
|
1
|
+
import { hasToolCall, streamText } from 'ai';
|
|
2
2
|
import { loadConfig } from '@agi-cli/sdk';
|
|
3
3
|
import { getDb } from '@agi-cli/database';
|
|
4
4
|
import { messageParts } from '@agi-cli/database/schema';
|
|
@@ -7,8 +7,11 @@ import { resolveModel } from './provider.ts';
|
|
|
7
7
|
import { resolveAgentConfig } from './agent-registry.ts';
|
|
8
8
|
import { composeSystemPrompt } from './prompt.ts';
|
|
9
9
|
import { discoverProjectTools } from '@agi-cli/sdk';
|
|
10
|
-
import {
|
|
10
|
+
import { adaptTools } from '../tools/adapter.ts';
|
|
11
|
+
import { publish, subscribe } from '../events/bus.ts';
|
|
12
|
+
import { debugLog, time } from './debug.ts';
|
|
11
13
|
import { buildHistoryMessages } from './history-builder.ts';
|
|
14
|
+
import { toErrorPayload } from './error-handling.ts';
|
|
12
15
|
import { getMaxOutputTokens } from './token-utils.ts';
|
|
13
16
|
import {
|
|
14
17
|
type RunOpts,
|
|
@@ -19,11 +22,15 @@ import {
|
|
|
19
22
|
dequeueJob,
|
|
20
23
|
cleanupSession,
|
|
21
24
|
} from './session-queue.ts';
|
|
22
|
-
import {
|
|
25
|
+
import {
|
|
26
|
+
setupToolContext,
|
|
27
|
+
type RunnerToolContext,
|
|
28
|
+
} from './tool-context-setup.ts';
|
|
23
29
|
import {
|
|
24
30
|
updateSessionTokensIncremental,
|
|
25
31
|
updateMessageTokensIncremental,
|
|
26
32
|
completeAssistantMessage,
|
|
33
|
+
cleanupEmptyTextParts,
|
|
27
34
|
} from './db-operations.ts';
|
|
28
35
|
import {
|
|
29
36
|
createStepFinishHandler,
|
|
@@ -31,38 +38,175 @@ import {
|
|
|
31
38
|
createAbortHandler,
|
|
32
39
|
createFinishHandler,
|
|
33
40
|
} from './stream-handlers.ts';
|
|
34
|
-
import { addCacheControl } from './cache-optimizer.ts';
|
|
35
|
-
import { optimizeContext } from './context-optimizer.ts';
|
|
36
|
-
import { truncateHistory } from './history-truncator.ts';
|
|
37
41
|
|
|
38
42
|
/**
|
|
39
|
-
*
|
|
43
|
+
* Enqueues an assistant run for processing.
|
|
44
|
+
*/
|
|
45
|
+
export function enqueueAssistantRun(opts: Omit<RunOpts, 'abortSignal'>) {
|
|
46
|
+
enqueueRun(opts, processQueue);
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
/**
|
|
50
|
+
* Aborts an active session.
|
|
51
|
+
*/
|
|
52
|
+
export function abortSession(sessionId: string) {
|
|
53
|
+
abortSessionQueue(sessionId);
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
/**
|
|
57
|
+
* Processes the queue of assistant runs for a session.
|
|
58
|
+
*/
|
|
59
|
+
async function processQueue(sessionId: string) {
|
|
60
|
+
const state = getRunnerState(sessionId);
|
|
61
|
+
if (!state) return;
|
|
62
|
+
if (state.running) return;
|
|
63
|
+
setRunning(sessionId, true);
|
|
64
|
+
|
|
65
|
+
while (state.queue.length > 0) {
|
|
66
|
+
const job = dequeueJob(sessionId);
|
|
67
|
+
if (!job) break;
|
|
68
|
+
try {
|
|
69
|
+
await runAssistant(job);
|
|
70
|
+
} catch (_err) {
|
|
71
|
+
// Swallow to keep the loop alive; event published by runner
|
|
72
|
+
}
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
setRunning(sessionId, false);
|
|
76
|
+
cleanupSession(sessionId);
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
/**
|
|
80
|
+
* Ensures the finish tool is called if not already observed.
|
|
81
|
+
*/
|
|
82
|
+
async function ensureFinishToolCalled(
|
|
83
|
+
finishObserved: boolean,
|
|
84
|
+
toolset: ReturnType<typeof adaptTools>,
|
|
85
|
+
sharedCtx: RunnerToolContext,
|
|
86
|
+
stepIndex: number,
|
|
87
|
+
) {
|
|
88
|
+
if (finishObserved || !toolset?.finish?.execute) return;
|
|
89
|
+
|
|
90
|
+
const finishInput = {} as const;
|
|
91
|
+
const callOptions = { input: finishInput } as const;
|
|
92
|
+
|
|
93
|
+
sharedCtx.stepIndex = stepIndex;
|
|
94
|
+
|
|
95
|
+
try {
|
|
96
|
+
await toolset.finish.onInputStart?.(callOptions as never);
|
|
97
|
+
} catch {}
|
|
98
|
+
|
|
99
|
+
try {
|
|
100
|
+
await toolset.finish.onInputAvailable?.(callOptions as never);
|
|
101
|
+
} catch {}
|
|
102
|
+
|
|
103
|
+
await toolset.finish.execute(finishInput, {} as never);
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
/**
|
|
107
|
+
* Main function to run the assistant for a given request.
|
|
40
108
|
*/
|
|
41
|
-
|
|
42
|
-
const
|
|
43
|
-
const
|
|
44
|
-
const
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
const
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
const
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
const
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
109
|
+
async function runAssistant(opts: RunOpts) {
|
|
110
|
+
const cfgTimer = time('runner:loadConfig+db');
|
|
111
|
+
const cfg = await loadConfig(opts.projectRoot);
|
|
112
|
+
const db = await getDb(cfg.projectRoot);
|
|
113
|
+
cfgTimer.end();
|
|
114
|
+
|
|
115
|
+
const agentTimer = time('runner:resolveAgentConfig');
|
|
116
|
+
const agentCfg = await resolveAgentConfig(cfg.projectRoot, opts.agent);
|
|
117
|
+
agentTimer.end({ agent: opts.agent });
|
|
118
|
+
|
|
119
|
+
const agentPrompt = agentCfg.prompt || '';
|
|
120
|
+
|
|
121
|
+
const historyTimer = time('runner:buildHistory');
|
|
122
|
+
const history = await buildHistoryMessages(db, opts.sessionId);
|
|
123
|
+
historyTimer.end({ messages: history.length });
|
|
124
|
+
|
|
125
|
+
const isFirstMessage = history.length === 0;
|
|
126
|
+
|
|
127
|
+
const systemTimer = time('runner:composeSystemPrompt');
|
|
128
|
+
const { getAuth } = await import('@agi-cli/sdk');
|
|
129
|
+
const { getProviderSpoofPrompt } = await import('./prompt.ts');
|
|
130
|
+
const auth = await getAuth(opts.provider, cfg.projectRoot);
|
|
131
|
+
const needsSpoof = auth?.type === 'oauth';
|
|
132
|
+
const spoofPrompt = needsSpoof
|
|
133
|
+
? getProviderSpoofPrompt(opts.provider)
|
|
134
|
+
: undefined;
|
|
135
|
+
|
|
136
|
+
let system: string;
|
|
137
|
+
let additionalSystemMessages: Array<{ role: 'system'; content: string }> = [];
|
|
138
|
+
|
|
139
|
+
if (spoofPrompt) {
|
|
140
|
+
system = spoofPrompt;
|
|
141
|
+
const fullPrompt = await composeSystemPrompt({
|
|
142
|
+
provider: opts.provider,
|
|
143
|
+
model: opts.model,
|
|
144
|
+
projectRoot: cfg.projectRoot,
|
|
145
|
+
agentPrompt,
|
|
146
|
+
oneShot: opts.oneShot,
|
|
147
|
+
spoofPrompt: undefined,
|
|
148
|
+
includeProjectTree: isFirstMessage,
|
|
149
|
+
});
|
|
150
|
+
additionalSystemMessages = [{ role: 'system', content: fullPrompt }];
|
|
151
|
+
} else {
|
|
152
|
+
system = await composeSystemPrompt({
|
|
153
|
+
provider: opts.provider,
|
|
154
|
+
model: opts.model,
|
|
155
|
+
projectRoot: cfg.projectRoot,
|
|
156
|
+
agentPrompt,
|
|
157
|
+
oneShot: opts.oneShot,
|
|
158
|
+
spoofPrompt: undefined,
|
|
159
|
+
includeProjectTree: isFirstMessage,
|
|
160
|
+
});
|
|
161
|
+
}
|
|
162
|
+
systemTimer.end();
|
|
163
|
+
debugLog('[system] composed prompt (provider+base+agent):');
|
|
164
|
+
debugLog(system);
|
|
165
|
+
|
|
166
|
+
const toolsTimer = time('runner:discoverTools');
|
|
167
|
+
const allTools = await discoverProjectTools(cfg.projectRoot);
|
|
168
|
+
toolsTimer.end({ count: allTools.length });
|
|
169
|
+
const allowedNames = new Set([
|
|
170
|
+
...(agentCfg.tools || []),
|
|
171
|
+
'finish',
|
|
172
|
+
'progress_update',
|
|
173
|
+
]);
|
|
174
|
+
const gated = allTools.filter((t) => allowedNames.has(t.name));
|
|
175
|
+
const messagesWithSystemInstructions = [
|
|
176
|
+
...(isFirstMessage ? additionalSystemMessages : []),
|
|
177
|
+
...history,
|
|
178
|
+
];
|
|
179
|
+
|
|
180
|
+
const { sharedCtx, firstToolTimer, firstToolSeen } = await setupToolContext(
|
|
181
|
+
opts,
|
|
182
|
+
db,
|
|
183
|
+
);
|
|
184
|
+
const toolset = adaptTools(gated, sharedCtx, opts.provider);
|
|
185
|
+
|
|
186
|
+
const modelTimer = time('runner:resolveModel');
|
|
187
|
+
const model = await resolveModel(opts.provider, opts.model, cfg);
|
|
188
|
+
modelTimer.end();
|
|
189
|
+
|
|
190
|
+
const maxOutputTokens = getMaxOutputTokens(opts.provider, opts.model);
|
|
191
|
+
|
|
192
|
+
let currentPartId = opts.assistantPartId;
|
|
62
193
|
let accumulated = '';
|
|
63
|
-
|
|
194
|
+
let stepIndex = 0;
|
|
195
|
+
|
|
196
|
+
let finishObserved = false;
|
|
197
|
+
const unsubscribeFinish = subscribe(opts.sessionId, (evt) => {
|
|
198
|
+
if (evt.type !== 'tool.result') return;
|
|
199
|
+
try {
|
|
200
|
+
const name = (evt.payload as { name?: string } | undefined)?.name;
|
|
201
|
+
if (name === 'finish') finishObserved = true;
|
|
202
|
+
} catch {}
|
|
203
|
+
});
|
|
64
204
|
|
|
65
|
-
|
|
205
|
+
const streamStartTimer = time('runner:first-delta');
|
|
206
|
+
let firstDeltaSeen = false;
|
|
207
|
+
debugLog(`[streamText] Calling with maxOutputTokens: ${maxOutputTokens}`);
|
|
208
|
+
|
|
209
|
+
// State management helpers
|
|
66
210
|
const getCurrentPartId = () => currentPartId;
|
|
67
211
|
const getStepIndex = () => stepIndex;
|
|
68
212
|
const updateCurrentPartId = (id: string) => {
|
|
@@ -71,10 +215,12 @@ export async function runAssistant(opts: RunOpts) {
|
|
|
71
215
|
const updateAccumulated = (text: string) => {
|
|
72
216
|
accumulated = text;
|
|
73
217
|
};
|
|
74
|
-
const
|
|
75
|
-
|
|
218
|
+
const incrementStepIndex = () => {
|
|
219
|
+
stepIndex += 1;
|
|
220
|
+
return stepIndex;
|
|
221
|
+
};
|
|
76
222
|
|
|
77
|
-
//
|
|
223
|
+
// Create stream handlers
|
|
78
224
|
const onStepFinish = createStepFinishHandler(
|
|
79
225
|
opts,
|
|
80
226
|
db,
|
|
@@ -88,102 +234,105 @@ export async function runAssistant(opts: RunOpts) {
|
|
|
88
234
|
updateMessageTokensIncremental,
|
|
89
235
|
);
|
|
90
236
|
|
|
237
|
+
const onError = createErrorHandler(opts, db, getStepIndex, sharedCtx);
|
|
238
|
+
|
|
239
|
+
const onAbort = createAbortHandler(opts, db, getStepIndex, sharedCtx);
|
|
240
|
+
|
|
91
241
|
const onFinish = createFinishHandler(
|
|
92
242
|
opts,
|
|
93
243
|
db,
|
|
244
|
+
() => ensureFinishToolCalled(finishObserved, toolset, sharedCtx, stepIndex),
|
|
94
245
|
completeAssistantMessage,
|
|
95
|
-
getAccumulated,
|
|
96
|
-
abortController,
|
|
97
246
|
);
|
|
98
247
|
|
|
99
|
-
|
|
100
|
-
const
|
|
248
|
+
// Apply optimizations: deduplication, pruning, cache control, and truncation
|
|
249
|
+
const { addCacheControl, truncateHistory } = await import(
|
|
250
|
+
'./cache-optimizer.ts'
|
|
251
|
+
);
|
|
252
|
+
const { optimizeContext } = await import('./context-optimizer.ts');
|
|
101
253
|
|
|
102
|
-
//
|
|
103
|
-
const contextOptimized = optimizeContext(
|
|
254
|
+
// 1. Optimize context (deduplicate file reads, prune old tool results)
|
|
255
|
+
const contextOptimized = optimizeContext(messagesWithSystemInstructions, {
|
|
104
256
|
deduplicateFiles: true,
|
|
105
257
|
maxToolResults: 30,
|
|
106
258
|
});
|
|
107
259
|
|
|
108
|
-
// Truncate history
|
|
260
|
+
// 2. Truncate history
|
|
109
261
|
const truncatedMessages = truncateHistory(contextOptimized, 20);
|
|
110
262
|
|
|
111
|
-
// Add cache control
|
|
263
|
+
// 3. Add cache control
|
|
112
264
|
const { system: cachedSystem, messages: optimizedMessages } = addCacheControl(
|
|
113
|
-
opts.provider,
|
|
265
|
+
opts.provider as any,
|
|
114
266
|
system,
|
|
115
267
|
truncatedMessages,
|
|
116
268
|
);
|
|
117
269
|
|
|
118
270
|
try {
|
|
119
|
-
|
|
120
|
-
const result =
|
|
271
|
+
// @ts-expect-error this is fine 🔥
|
|
272
|
+
const result = streamText({
|
|
121
273
|
model,
|
|
122
|
-
|
|
274
|
+
tools: toolset,
|
|
275
|
+
...(cachedSystem ? { system: cachedSystem } : {}),
|
|
123
276
|
messages: optimizedMessages,
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
temperature: agentConfig.temperature ?? 0.7,
|
|
128
|
-
abortSignal: abortController.signal,
|
|
277
|
+
...(maxOutputTokens ? { maxOutputTokens } : {}),
|
|
278
|
+
abortSignal: opts.abortSignal,
|
|
279
|
+
stopWhen: hasToolCall('finish'),
|
|
129
280
|
onStepFinish,
|
|
281
|
+
onError,
|
|
282
|
+
onAbort,
|
|
130
283
|
onFinish,
|
|
131
|
-
experimental_continueSteps: true,
|
|
132
284
|
});
|
|
133
285
|
|
|
134
|
-
// Process the stream
|
|
135
286
|
for await (const delta of result.textStream) {
|
|
136
|
-
if (
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
await db
|
|
141
|
-
.update(messageParts)
|
|
142
|
-
.set({ content: accumulated })
|
|
143
|
-
.where(eq(messageParts.id, currentPartId));
|
|
287
|
+
if (!delta) continue;
|
|
288
|
+
if (!firstDeltaSeen) {
|
|
289
|
+
firstDeltaSeen = true;
|
|
290
|
+
streamStartTimer.end();
|
|
144
291
|
}
|
|
145
|
-
|
|
146
|
-
publish(
|
|
292
|
+
accumulated += delta;
|
|
293
|
+
publish({
|
|
294
|
+
type: 'message.part.delta',
|
|
147
295
|
sessionId: opts.sessionId,
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
296
|
+
payload: {
|
|
297
|
+
messageId: opts.assistantMessageId,
|
|
298
|
+
partId: currentPartId,
|
|
299
|
+
stepIndex,
|
|
300
|
+
delta,
|
|
301
|
+
},
|
|
153
302
|
});
|
|
303
|
+
await db
|
|
304
|
+
.update(messageParts)
|
|
305
|
+
.set({ content: JSON.stringify({ text: accumulated }) })
|
|
306
|
+
.where(eq(messageParts.id, currentPartId));
|
|
154
307
|
}
|
|
155
|
-
} catch (
|
|
156
|
-
|
|
308
|
+
} catch (error) {
|
|
309
|
+
const errorPayload = toErrorPayload(error);
|
|
310
|
+
await db
|
|
311
|
+
.update(messageParts)
|
|
312
|
+
.set({
|
|
313
|
+
content: JSON.stringify({
|
|
314
|
+
text: accumulated,
|
|
315
|
+
error: errorPayload.message,
|
|
316
|
+
}),
|
|
317
|
+
})
|
|
318
|
+
.where(eq(messageParts.messageId, opts.assistantMessageId));
|
|
319
|
+
publish({
|
|
320
|
+
type: 'error',
|
|
321
|
+
sessionId: opts.sessionId,
|
|
322
|
+
payload: {
|
|
323
|
+
messageId: opts.assistantMessageId,
|
|
324
|
+
error: errorPayload.message,
|
|
325
|
+
details: errorPayload.details,
|
|
326
|
+
},
|
|
327
|
+
});
|
|
328
|
+
throw error;
|
|
157
329
|
} finally {
|
|
158
|
-
|
|
159
|
-
|
|
330
|
+
if (!firstToolSeen()) firstToolTimer.end({ skipped: true });
|
|
331
|
+
try {
|
|
332
|
+
unsubscribeFinish();
|
|
333
|
+
} catch {}
|
|
334
|
+
try {
|
|
335
|
+
await cleanupEmptyTextParts(opts, db);
|
|
336
|
+
} catch {}
|
|
160
337
|
}
|
|
161
338
|
}
|
|
162
|
-
|
|
163
|
-
/**
|
|
164
|
-
* Enqueues an assistant run
|
|
165
|
-
*/
|
|
166
|
-
export async function enqueueAssistantRun(opts: RunOpts) {
|
|
167
|
-
return enqueueRun(opts);
|
|
168
|
-
}
|
|
169
|
-
|
|
170
|
-
/**
|
|
171
|
-
* Aborts a running session
|
|
172
|
-
*/
|
|
173
|
-
export async function abortSession(sessionId: number) {
|
|
174
|
-
return abortSessionQueue(sessionId);
|
|
175
|
-
}
|
|
176
|
-
|
|
177
|
-
/**
|
|
178
|
-
* Gets the current runner state for a session
|
|
179
|
-
*/
|
|
180
|
-
export function getSessionState(sessionId: number) {
|
|
181
|
-
return getRunnerState(sessionId);
|
|
182
|
-
}
|
|
183
|
-
|
|
184
|
-
/**
|
|
185
|
-
* Cleanup session resources
|
|
186
|
-
*/
|
|
187
|
-
export function cleanupSessionResources(sessionId: number) {
|
|
188
|
-
return cleanupSession(sessionId);
|
|
189
|
-
}
|
|
@@ -8,26 +8,17 @@ import { toErrorPayload } from './error-handling.ts';
|
|
|
8
8
|
import type { RunOpts } from './session-queue.ts';
|
|
9
9
|
import type { ToolAdapterContext } from '../tools/adapter.ts';
|
|
10
10
|
|
|
11
|
-
interface ProviderMetadata {
|
|
12
|
-
openai?: {
|
|
13
|
-
cachedPromptTokens?: number;
|
|
14
|
-
};
|
|
15
|
-
[key: string]: unknown;
|
|
16
|
-
}
|
|
17
|
-
|
|
18
|
-
interface UsageData {
|
|
19
|
-
inputTokens?: number;
|
|
20
|
-
outputTokens?: number;
|
|
21
|
-
totalTokens?: number;
|
|
22
|
-
cachedInputTokens?: number;
|
|
23
|
-
reasoningTokens?: number;
|
|
24
|
-
}
|
|
25
|
-
|
|
26
11
|
type StepFinishEvent = {
|
|
27
|
-
usage?:
|
|
12
|
+
usage?: {
|
|
13
|
+
inputTokens?: number;
|
|
14
|
+
outputTokens?: number;
|
|
15
|
+
totalTokens?: number;
|
|
16
|
+
cachedInputTokens?: number;
|
|
17
|
+
reasoningTokens?: number;
|
|
18
|
+
};
|
|
28
19
|
finishReason?: string;
|
|
29
20
|
response?: unknown;
|
|
30
|
-
experimental_providerMetadata?:
|
|
21
|
+
experimental_providerMetadata?: Record<string, any>;
|
|
31
22
|
};
|
|
32
23
|
|
|
33
24
|
type FinishEvent = {
|
|
@@ -51,19 +42,19 @@ export function createStepFinishHandler(
|
|
|
51
42
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
52
43
|
getCurrentPartId: () => string,
|
|
53
44
|
getStepIndex: () => number,
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
45
|
+
sharedCtx: ToolAdapterContext,
|
|
46
|
+
updateCurrentPartId: (id: string) => void,
|
|
47
|
+
updateAccumulated: (text: string) => void,
|
|
57
48
|
incrementStepIndex: () => number,
|
|
58
49
|
updateSessionTokensIncrementalFn: (
|
|
59
|
-
usage:
|
|
60
|
-
providerMetadata:
|
|
50
|
+
usage: any,
|
|
51
|
+
providerMetadata: Record<string, any> | undefined,
|
|
61
52
|
opts: RunOpts,
|
|
62
53
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
63
54
|
) => Promise<void>,
|
|
64
55
|
updateMessageTokensIncrementalFn: (
|
|
65
|
-
usage:
|
|
66
|
-
providerMetadata:
|
|
56
|
+
usage: any,
|
|
57
|
+
providerMetadata: Record<string, any> | undefined,
|
|
67
58
|
opts: RunOpts,
|
|
68
59
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
69
60
|
) => Promise<void>,
|
|
@@ -78,11 +69,9 @@ export function createStepFinishHandler(
|
|
|
78
69
|
.update(messageParts)
|
|
79
70
|
.set({ completedAt: finishedAt })
|
|
80
71
|
.where(eq(messageParts.id, currentPartId));
|
|
81
|
-
} catch
|
|
82
|
-
console.error('[createStepFinishHandler] Failed to update part', err);
|
|
83
|
-
}
|
|
72
|
+
} catch {}
|
|
84
73
|
|
|
85
|
-
// Update
|
|
74
|
+
// Update token counts incrementally after each step
|
|
86
75
|
if (step.usage) {
|
|
87
76
|
try {
|
|
88
77
|
await updateSessionTokensIncrementalFn(
|
|
@@ -91,81 +80,126 @@ export function createStepFinishHandler(
|
|
|
91
80
|
opts,
|
|
92
81
|
db,
|
|
93
82
|
);
|
|
83
|
+
} catch {}
|
|
84
|
+
|
|
85
|
+
try {
|
|
94
86
|
await updateMessageTokensIncrementalFn(
|
|
95
87
|
step.usage,
|
|
96
88
|
step.experimental_providerMetadata,
|
|
97
89
|
opts,
|
|
98
90
|
db,
|
|
99
91
|
);
|
|
100
|
-
} catch
|
|
101
|
-
console.error('[createStepFinishHandler] Token update failed', err);
|
|
102
|
-
}
|
|
92
|
+
} catch {}
|
|
103
93
|
}
|
|
104
94
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
95
|
+
try {
|
|
96
|
+
publish({
|
|
97
|
+
type: 'finish-step',
|
|
98
|
+
sessionId: opts.sessionId,
|
|
99
|
+
payload: {
|
|
100
|
+
stepIndex,
|
|
101
|
+
usage: step.usage,
|
|
102
|
+
finishReason: step.finishReason,
|
|
103
|
+
response: step.response,
|
|
104
|
+
},
|
|
105
|
+
});
|
|
106
|
+
if (step.usage) {
|
|
107
|
+
publish({
|
|
108
|
+
type: 'usage',
|
|
109
|
+
sessionId: opts.sessionId,
|
|
110
|
+
payload: { stepIndex, ...step.usage },
|
|
111
|
+
});
|
|
112
|
+
}
|
|
113
|
+
} catch {}
|
|
114
114
|
|
|
115
|
-
|
|
115
|
+
try {
|
|
116
|
+
const newStepIndex = incrementStepIndex();
|
|
117
|
+
const newPartId = crypto.randomUUID();
|
|
118
|
+
const index = await sharedCtx.nextIndex();
|
|
119
|
+
const nowTs = Date.now();
|
|
120
|
+
await db.insert(messageParts).values({
|
|
121
|
+
id: newPartId,
|
|
122
|
+
messageId: opts.assistantMessageId,
|
|
123
|
+
index,
|
|
124
|
+
stepIndex: newStepIndex,
|
|
125
|
+
type: 'text',
|
|
126
|
+
content: JSON.stringify({ text: '' }),
|
|
127
|
+
agent: opts.agent,
|
|
128
|
+
provider: opts.provider,
|
|
129
|
+
model: opts.model,
|
|
130
|
+
startedAt: nowTs,
|
|
131
|
+
});
|
|
132
|
+
updateCurrentPartId(newPartId);
|
|
133
|
+
sharedCtx.assistantPartId = newPartId;
|
|
134
|
+
sharedCtx.stepIndex = newStepIndex;
|
|
135
|
+
updateAccumulated('');
|
|
136
|
+
} catch {}
|
|
116
137
|
};
|
|
117
138
|
}
|
|
118
139
|
|
|
119
140
|
/**
|
|
120
|
-
* Creates the
|
|
141
|
+
* Creates the onError handler for the stream
|
|
121
142
|
*/
|
|
122
|
-
export function
|
|
143
|
+
export function createErrorHandler(
|
|
123
144
|
opts: RunOpts,
|
|
124
145
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
opts: RunOpts,
|
|
128
|
-
db: Awaited<ReturnType<typeof getDb>>,
|
|
129
|
-
) => Promise<void>,
|
|
130
|
-
_getAccumulated: () => string,
|
|
131
|
-
_abortController: AbortController,
|
|
146
|
+
getStepIndex: () => number,
|
|
147
|
+
sharedCtx: ToolAdapterContext,
|
|
132
148
|
) {
|
|
133
|
-
return async (
|
|
134
|
-
|
|
135
|
-
|
|
149
|
+
return async (err: unknown) => {
|
|
150
|
+
const errorPayload = toErrorPayload(err);
|
|
151
|
+
const isApiError = APICallError.isInstance(err);
|
|
152
|
+
const stepIndex = getStepIndex();
|
|
136
153
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
154
|
+
// Create error part for UI display
|
|
155
|
+
const errorPartId = crypto.randomUUID();
|
|
156
|
+
await db.insert(messageParts).values({
|
|
157
|
+
id: errorPartId,
|
|
158
|
+
messageId: opts.assistantMessageId,
|
|
159
|
+
index: await sharedCtx.nextIndex(),
|
|
160
|
+
stepIndex,
|
|
161
|
+
type: 'error',
|
|
162
|
+
content: JSON.stringify({
|
|
163
|
+
message: errorPayload.message,
|
|
164
|
+
type: errorPayload.type,
|
|
165
|
+
details: errorPayload.details,
|
|
166
|
+
isAborted: false,
|
|
167
|
+
}),
|
|
168
|
+
agent: opts.agent,
|
|
169
|
+
provider: opts.provider,
|
|
170
|
+
model: opts.model,
|
|
171
|
+
startedAt: Date.now(),
|
|
172
|
+
completedAt: Date.now(),
|
|
173
|
+
});
|
|
141
174
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
175
|
+
// Update message status
|
|
176
|
+
await db
|
|
177
|
+
.update(messages)
|
|
178
|
+
.set({
|
|
179
|
+
status: 'error',
|
|
180
|
+
error: errorPayload.message,
|
|
181
|
+
errorType: errorPayload.type,
|
|
182
|
+
errorDetails: JSON.stringify({
|
|
183
|
+
...errorPayload.details,
|
|
184
|
+
isApiError,
|
|
185
|
+
}),
|
|
186
|
+
isAborted: false,
|
|
187
|
+
})
|
|
188
|
+
.where(eq(messages.id, opts.assistantMessageId));
|
|
152
189
|
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
finishReason: fin.finishReason,
|
|
159
|
-
estimatedCost,
|
|
160
|
-
});
|
|
161
|
-
} catch (err) {
|
|
162
|
-
console.error('[createFinishHandler] Error in onFinish', err);
|
|
163
|
-
publish('stream:error', {
|
|
164
|
-
sessionId: opts.sessionId,
|
|
190
|
+
// Publish enhanced error event
|
|
191
|
+
publish({
|
|
192
|
+
type: 'error',
|
|
193
|
+
sessionId: opts.sessionId,
|
|
194
|
+
payload: {
|
|
165
195
|
messageId: opts.assistantMessageId,
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
196
|
+
partId: errorPartId,
|
|
197
|
+
error: errorPayload.message,
|
|
198
|
+
errorType: errorPayload.type,
|
|
199
|
+
details: errorPayload.details,
|
|
200
|
+
isAborted: false,
|
|
201
|
+
},
|
|
202
|
+
});
|
|
169
203
|
};
|
|
170
204
|
}
|
|
171
205
|
|
|
@@ -175,116 +209,116 @@ export function createFinishHandler(
|
|
|
175
209
|
export function createAbortHandler(
|
|
176
210
|
opts: RunOpts,
|
|
177
211
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
178
|
-
|
|
212
|
+
getStepIndex: () => number,
|
|
213
|
+
sharedCtx: ToolAdapterContext,
|
|
179
214
|
) {
|
|
180
|
-
return async (
|
|
181
|
-
|
|
182
|
-
await db
|
|
183
|
-
.update(messages)
|
|
184
|
-
.set({ status: 'aborted', finishedAt: new Date() })
|
|
185
|
-
.where(eq(messages.id, opts.assistantMessageId));
|
|
215
|
+
return async ({ steps }: AbortEvent) => {
|
|
216
|
+
const stepIndex = getStepIndex();
|
|
186
217
|
|
|
187
|
-
|
|
188
|
-
|
|
218
|
+
// Create abort part for UI
|
|
219
|
+
const abortPartId = crypto.randomUUID();
|
|
220
|
+
await db.insert(messageParts).values({
|
|
221
|
+
id: abortPartId,
|
|
222
|
+
messageId: opts.assistantMessageId,
|
|
223
|
+
index: await sharedCtx.nextIndex(),
|
|
224
|
+
stepIndex,
|
|
225
|
+
type: 'error',
|
|
226
|
+
content: JSON.stringify({
|
|
227
|
+
message: 'Generation stopped by user',
|
|
228
|
+
type: 'abort',
|
|
229
|
+
isAborted: true,
|
|
230
|
+
stepsCompleted: steps.length,
|
|
231
|
+
}),
|
|
232
|
+
agent: opts.agent,
|
|
233
|
+
provider: opts.provider,
|
|
234
|
+
model: opts.model,
|
|
235
|
+
startedAt: Date.now(),
|
|
236
|
+
completedAt: Date.now(),
|
|
237
|
+
});
|
|
238
|
+
|
|
239
|
+
// Store abort info
|
|
240
|
+
await db
|
|
241
|
+
.update(messages)
|
|
242
|
+
.set({
|
|
243
|
+
status: 'error',
|
|
244
|
+
error: 'Generation stopped by user',
|
|
245
|
+
errorType: 'abort',
|
|
246
|
+
errorDetails: JSON.stringify({
|
|
247
|
+
stepsCompleted: steps.length,
|
|
248
|
+
abortedAt: Date.now(),
|
|
249
|
+
}),
|
|
250
|
+
isAborted: true,
|
|
251
|
+
})
|
|
252
|
+
.where(eq(messages.id, opts.assistantMessageId));
|
|
253
|
+
|
|
254
|
+
// Publish abort event
|
|
255
|
+
publish({
|
|
256
|
+
type: 'error',
|
|
257
|
+
sessionId: opts.sessionId,
|
|
258
|
+
payload: {
|
|
189
259
|
messageId: opts.assistantMessageId,
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
260
|
+
partId: abortPartId,
|
|
261
|
+
error: 'Generation stopped by user',
|
|
262
|
+
errorType: 'abort',
|
|
263
|
+
isAborted: true,
|
|
264
|
+
stepsCompleted: steps.length,
|
|
265
|
+
},
|
|
266
|
+
});
|
|
195
267
|
};
|
|
196
268
|
}
|
|
197
269
|
|
|
198
270
|
/**
|
|
199
|
-
* Creates the
|
|
271
|
+
* Creates the onFinish handler for the stream
|
|
200
272
|
*/
|
|
201
|
-
export function
|
|
273
|
+
export function createFinishHandler(
|
|
202
274
|
opts: RunOpts,
|
|
203
275
|
db: Awaited<ReturnType<typeof getDb>>,
|
|
276
|
+
ensureFinishToolCalled: () => Promise<void>,
|
|
277
|
+
completeAssistantMessageFn: (
|
|
278
|
+
fin: FinishEvent,
|
|
279
|
+
opts: RunOpts,
|
|
280
|
+
db: Awaited<ReturnType<typeof getDb>>,
|
|
281
|
+
) => Promise<void>,
|
|
204
282
|
) {
|
|
205
|
-
return async (
|
|
206
|
-
console.error('[createErrorHandler] Stream error:', err);
|
|
207
|
-
|
|
283
|
+
return async (fin: FinishEvent) => {
|
|
208
284
|
try {
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
let errorStack: string | undefined;
|
|
285
|
+
await ensureFinishToolCalled();
|
|
286
|
+
} catch {}
|
|
212
287
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
errorType = 'API_CALL_ERROR';
|
|
216
|
-
errorStack = err.stack;
|
|
217
|
-
} else if (err instanceof Error) {
|
|
218
|
-
errorMessage = err.message;
|
|
219
|
-
errorType = err.name || 'ERROR';
|
|
220
|
-
errorStack = err.stack;
|
|
221
|
-
} else if (typeof err === 'string') {
|
|
222
|
-
errorMessage = err;
|
|
223
|
-
}
|
|
288
|
+
// Note: Token updates are handled incrementally in onStepFinish
|
|
289
|
+
// Do NOT add fin.usage here as it would cause double-counting
|
|
224
290
|
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
status: 'error',
|
|
229
|
-
finishedAt: new Date(),
|
|
230
|
-
error: errorMessage,
|
|
231
|
-
})
|
|
232
|
-
.where(eq(messages.id, opts.assistantMessageId));
|
|
233
|
-
|
|
234
|
-
publish('stream:error', {
|
|
235
|
-
sessionId: opts.sessionId,
|
|
236
|
-
messageId: opts.assistantMessageId,
|
|
237
|
-
assistantMessageId: opts.assistantMessageId,
|
|
238
|
-
error: {
|
|
239
|
-
message: errorMessage,
|
|
240
|
-
type: errorType,
|
|
241
|
-
stack: errorStack,
|
|
242
|
-
},
|
|
243
|
-
});
|
|
244
|
-
} catch (dbErr) {
|
|
245
|
-
console.error('[createErrorHandler] Failed to save error to DB', dbErr);
|
|
246
|
-
}
|
|
247
|
-
};
|
|
248
|
-
}
|
|
291
|
+
try {
|
|
292
|
+
await completeAssistantMessageFn(fin, opts, db);
|
|
293
|
+
} catch {}
|
|
249
294
|
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
db: Awaited<ReturnType<typeof getDb>>,
|
|
256
|
-
getCurrentPartId: () => string,
|
|
257
|
-
getStepIndex: () => number,
|
|
258
|
-
_updateCurrentPartId: (id: string) => void,
|
|
259
|
-
updateAccumulated: (text: string) => void,
|
|
260
|
-
getAccumulated: () => string,
|
|
261
|
-
) {
|
|
262
|
-
return async (textDelta: string) => {
|
|
263
|
-
const currentPartId = getCurrentPartId();
|
|
264
|
-
const stepIndex = getStepIndex();
|
|
295
|
+
// Use session totals from DB for accurate cost calculation
|
|
296
|
+
const sessRows = await db
|
|
297
|
+
.select()
|
|
298
|
+
.from(messages)
|
|
299
|
+
.where(eq(messages.id, opts.assistantMessageId));
|
|
265
300
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
301
|
+
const usage = sessRows[0]
|
|
302
|
+
? {
|
|
303
|
+
inputTokens: Number(sessRows[0].promptTokens ?? 0),
|
|
304
|
+
outputTokens: Number(sessRows[0].completionTokens ?? 0),
|
|
305
|
+
totalTokens: Number(sessRows[0].totalTokens ?? 0),
|
|
306
|
+
}
|
|
307
|
+
: fin.usage;
|
|
269
308
|
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
.update(messageParts)
|
|
274
|
-
.set({ content: accumulated })
|
|
275
|
-
.where(eq(messageParts.id, currentPartId));
|
|
276
|
-
}
|
|
309
|
+
const costUsd = usage
|
|
310
|
+
? estimateModelCostUsd(opts.provider, opts.model, usage)
|
|
311
|
+
: undefined;
|
|
277
312
|
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
}
|
|
313
|
+
publish({
|
|
314
|
+
type: 'message.completed',
|
|
315
|
+
sessionId: opts.sessionId,
|
|
316
|
+
payload: {
|
|
317
|
+
id: opts.assistantMessageId,
|
|
318
|
+
usage,
|
|
319
|
+
costUsd,
|
|
320
|
+
finishReason: fin.finishReason,
|
|
321
|
+
},
|
|
322
|
+
});
|
|
289
323
|
};
|
|
290
324
|
}
|