@agi-cli/server 0.1.61 → 0.1.62

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@agi-cli/server",
3
- "version": "0.1.61",
3
+ "version": "0.1.62",
4
4
  "description": "HTTP API server for AGI CLI",
5
5
  "type": "module",
6
6
  "main": "./src/index.ts",
@@ -29,8 +29,8 @@
29
29
  "typecheck": "tsc --noEmit"
30
30
  },
31
31
  "dependencies": {
32
- "@agi-cli/sdk": "0.1.61",
33
- "@agi-cli/database": "0.1.61",
32
+ "@agi-cli/sdk": "0.1.62",
33
+ "@agi-cli/database": "0.1.62",
34
34
  "drizzle-orm": "^0.44.5",
35
35
  "hono": "^4.9.9"
36
36
  },
@@ -1,23 +1,5 @@
1
1
  import type { ModelMessage } from 'ai';
2
2
 
3
- type SystemMessage =
4
- | string
5
- | Array<{
6
- type: 'text';
7
- text: string;
8
- cache_control?: { type: 'ephemeral' };
9
- }>;
10
-
11
- interface ContentPart {
12
- type: string;
13
- [key: string]: unknown;
14
- providerOptions?: {
15
- anthropic?: {
16
- cacheControl?: { type: 'ephemeral' };
17
- };
18
- };
19
- }
20
-
21
3
  /**
22
4
  * Adds cache control to messages for prompt caching optimization.
23
5
  * Anthropic supports caching for system messages, tools, and long context.
@@ -27,7 +9,13 @@ export function addCacheControl(
27
9
  system: string | undefined,
28
10
  messages: ModelMessage[],
29
11
  ): {
30
- system?: SystemMessage;
12
+ system?:
13
+ | string
14
+ | Array<{
15
+ type: 'text';
16
+ text: string;
17
+ cache_control?: { type: 'ephemeral' };
18
+ }>;
31
19
  messages: ModelMessage[];
32
20
  } {
33
21
  // Only Anthropic supports prompt caching currently
@@ -36,7 +24,7 @@ export function addCacheControl(
36
24
  }
37
25
 
38
26
  // Convert system to cacheable format if it's long enough
39
- let cachedSystem: SystemMessage | undefined = system;
27
+ let cachedSystem: any = system;
40
28
  if (system && system.length > 1024) {
41
29
  // Anthropic requires 1024+ tokens for Claude Sonnet/Opus
42
30
  cachedSystem = [
@@ -73,21 +61,55 @@ export function addCacheControl(
73
61
  // Add cache control to the last content part of that message
74
62
  const lastPart = targetMsg.content[targetMsg.content.length - 1];
75
63
  if (lastPart && typeof lastPart === 'object' && 'type' in lastPart) {
76
- (lastPart as ContentPart).providerOptions = {
64
+ (lastPart as any).providerOptions = {
77
65
  anthropic: { cacheControl: { type: 'ephemeral' } },
78
66
  };
79
67
  }
80
68
  }
81
69
  }
82
70
 
83
- return {
84
- system: cachedSystem,
85
- messages: cachedMessages,
86
- };
71
+ return { system: cachedSystem, messages: cachedMessages };
72
+ }
73
+
74
+ return { system: cachedSystem, messages };
75
+ }
76
+
77
+ /**
78
+ * Truncates old messages to reduce context size while keeping recent context.
79
+ * Strategy: Keep system message + last N messages
80
+ */
81
+ export function truncateHistory(
82
+ messages: ModelMessage[],
83
+ maxMessages = 20,
84
+ ): ModelMessage[] {
85
+ if (messages.length <= maxMessages) {
86
+ return messages;
87
+ }
88
+
89
+ // Keep the most recent messages
90
+ return messages.slice(-maxMessages);
91
+ }
92
+
93
+ /**
94
+ * Estimates token count (rough approximation: ~4 chars per token)
95
+ */
96
+ export function estimateTokens(text: string): number {
97
+ return Math.ceil(text.length / 4);
98
+ }
99
+
100
+ /**
101
+ * Summarizes tool results if they're too long
102
+ */
103
+ export function summarizeToolResult(result: unknown, maxLength = 5000): string {
104
+ const str = typeof result === 'string' ? result : JSON.stringify(result);
105
+
106
+ if (str.length <= maxLength) {
107
+ return str;
87
108
  }
88
109
 
89
- return {
90
- system: cachedSystem,
91
- messages,
92
- };
110
+ // Truncate and add indicator
111
+ return (
112
+ str.slice(0, maxLength) +
113
+ `\n\n[... truncated ${str.length - maxLength} characters]`
114
+ );
93
115
  }
@@ -11,20 +11,13 @@ type UsageData = {
11
11
  reasoningTokens?: number;
12
12
  };
13
13
 
14
- interface ProviderMetadata {
15
- openai?: {
16
- cachedPromptTokens?: number;
17
- };
18
- [key: string]: unknown;
19
- }
20
-
21
14
  /**
22
15
  * Updates session token counts incrementally after each step.
23
16
  * Note: onStepFinish.usage is CUMULATIVE per message, so we compute DELTA and add to session.
24
17
  */
25
18
  export async function updateSessionTokensIncremental(
26
19
  usage: UsageData,
27
- providerMetadata: ProviderMetadata | undefined,
20
+ providerMetadata: Record<string, any> | undefined,
28
21
  opts: RunOpts,
29
22
  db: Awaited<ReturnType<typeof getDb>>,
30
23
  ) {
@@ -136,7 +129,7 @@ export async function updateSessionTokens(
136
129
  */
137
130
  export async function updateMessageTokensIncremental(
138
131
  usage: UsageData,
139
- providerMetadata: ProviderMetadata | undefined,
132
+ providerMetadata: Record<string, any> | undefined,
140
133
  opts: RunOpts,
141
134
  db: Awaited<ReturnType<typeof getDb>>,
142
135
  ) {
@@ -155,74 +148,86 @@ export async function updateMessageTokensIncremental(
155
148
  const priorReasoning = Number(msg.reasoningTokens ?? 0);
156
149
 
157
150
  // Treat usage as cumulative per-message - REPLACE not ADD
158
- const nextPrompt =
151
+ const cumPrompt =
159
152
  usage.inputTokens != null ? Number(usage.inputTokens) : priorPrompt;
160
- const nextCompletion =
153
+ const cumCompletion =
161
154
  usage.outputTokens != null ? Number(usage.outputTokens) : priorCompletion;
162
- const nextReasoning =
155
+ const cumReasoning =
163
156
  usage.reasoningTokens != null
164
157
  ? Number(usage.reasoningTokens)
165
158
  : priorReasoning;
166
159
 
167
- const nextCached =
160
+ const cumCached =
168
161
  usage.cachedInputTokens != null
169
162
  ? Number(usage.cachedInputTokens)
170
163
  : providerMetadata?.openai?.cachedPromptTokens != null
171
164
  ? Number(providerMetadata.openai.cachedPromptTokens)
172
165
  : priorCached;
173
166
 
167
+ const cumTotal =
168
+ usage.totalTokens != null
169
+ ? Number(usage.totalTokens)
170
+ : cumPrompt + cumCompletion + cumReasoning;
171
+
174
172
  await db
175
173
  .update(messages)
176
174
  .set({
177
- promptTokens: nextPrompt,
178
- completionTokens: nextCompletion,
179
- cachedInputTokens: nextCached,
180
- reasoningTokens: nextReasoning,
175
+ promptTokens: cumPrompt,
176
+ completionTokens: cumCompletion,
177
+ totalTokens: cumTotal,
178
+ cachedInputTokens: cumCached,
179
+ reasoningTokens: cumReasoning,
181
180
  })
182
181
  .where(eq(messages.id, opts.assistantMessageId));
183
182
  }
184
183
  }
185
184
 
186
185
  /**
187
- * Completes the assistant message after the run finishes.
188
- * Used to finalize timing but NOT tokens, which are already incremental.
186
+ * Marks an assistant message as complete.
187
+ * Token usage is tracked incrementally via updateMessageTokensIncremental().
189
188
  */
190
189
  export async function completeAssistantMessage(
191
- _fin: {
190
+ fin: {
192
191
  usage?: {
193
192
  inputTokens?: number;
194
193
  outputTokens?: number;
194
+ totalTokens?: number;
195
195
  };
196
196
  },
197
197
  opts: RunOpts,
198
198
  db: Awaited<ReturnType<typeof getDb>>,
199
199
  ) {
200
- const msgRow = await db
201
- .select()
202
- .from(messages)
200
+ // Only mark as complete - tokens are already tracked incrementally
201
+ await db
202
+ .update(messages)
203
+ .set({
204
+ status: 'complete',
205
+ completedAt: Date.now(),
206
+ })
203
207
  .where(eq(messages.id, opts.assistantMessageId));
204
-
205
- if (msgRow.length > 0) {
206
- await db
207
- .update(messages)
208
- .set({
209
- finishedAt: new Date(),
210
- })
211
- .where(eq(messages.id, opts.assistantMessageId));
212
- }
213
208
  }
214
209
 
215
- export async function createMessagePart(
216
- partData: {
217
- messageId: number;
218
- contentType: 'text' | 'tool' | 'other';
219
- toolName?: string | null;
220
- toolArgs?: unknown;
221
- toolResult?: unknown;
222
- textContent?: string | null;
223
- stepIndex?: number | null;
224
- },
210
+ /**
211
+ * Removes empty text parts from an assistant message.
212
+ */
213
+ export async function cleanupEmptyTextParts(
214
+ opts: RunOpts,
225
215
  db: Awaited<ReturnType<typeof getDb>>,
226
216
  ) {
227
- await db.insert(messageParts).values(partData);
217
+ const parts = await db
218
+ .select()
219
+ .from(messageParts)
220
+ .where(eq(messageParts.messageId, opts.assistantMessageId));
221
+
222
+ for (const p of parts) {
223
+ if (p.type === 'text') {
224
+ let t = '';
225
+ try {
226
+ t = JSON.parse(p.content || '{}')?.text || '';
227
+ } catch {}
228
+ if (!t || t.length === 0) {
229
+ await db.delete(messageParts).where(eq(messageParts.id, p.id));
230
+ }
231
+ }
232
+ }
228
233
  }
@@ -1,4 +1,4 @@
1
- import { streamText } from 'ai';
1
+ import { hasToolCall, streamText } from 'ai';
2
2
  import { loadConfig } from '@agi-cli/sdk';
3
3
  import { getDb } from '@agi-cli/database';
4
4
  import { messageParts } from '@agi-cli/database/schema';
@@ -7,8 +7,11 @@ import { resolveModel } from './provider.ts';
7
7
  import { resolveAgentConfig } from './agent-registry.ts';
8
8
  import { composeSystemPrompt } from './prompt.ts';
9
9
  import { discoverProjectTools } from '@agi-cli/sdk';
10
- import { publish } from '../events/bus.ts';
10
+ import { adaptTools } from '../tools/adapter.ts';
11
+ import { publish, subscribe } from '../events/bus.ts';
12
+ import { debugLog, time } from './debug.ts';
11
13
  import { buildHistoryMessages } from './history-builder.ts';
14
+ import { toErrorPayload } from './error-handling.ts';
12
15
  import { getMaxOutputTokens } from './token-utils.ts';
13
16
  import {
14
17
  type RunOpts,
@@ -19,11 +22,15 @@ import {
19
22
  dequeueJob,
20
23
  cleanupSession,
21
24
  } from './session-queue.ts';
22
- import { setupToolContext } from './tool-context-setup.ts';
25
+ import {
26
+ setupToolContext,
27
+ type RunnerToolContext,
28
+ } from './tool-context-setup.ts';
23
29
  import {
24
30
  updateSessionTokensIncremental,
25
31
  updateMessageTokensIncremental,
26
32
  completeAssistantMessage,
33
+ cleanupEmptyTextParts,
27
34
  } from './db-operations.ts';
28
35
  import {
29
36
  createStepFinishHandler,
@@ -31,38 +38,175 @@ import {
31
38
  createAbortHandler,
32
39
  createFinishHandler,
33
40
  } from './stream-handlers.ts';
34
- import { addCacheControl } from './cache-optimizer.ts';
35
- import { optimizeContext } from './context-optimizer.ts';
36
- import { truncateHistory } from './history-truncator.ts';
37
41
 
38
42
  /**
39
- * Main runner that executes the LLM streaming loop with tools
43
+ * Enqueues an assistant run for processing.
44
+ */
45
+ export function enqueueAssistantRun(opts: Omit<RunOpts, 'abortSignal'>) {
46
+ enqueueRun(opts, processQueue);
47
+ }
48
+
49
+ /**
50
+ * Aborts an active session.
51
+ */
52
+ export function abortSession(sessionId: string) {
53
+ abortSessionQueue(sessionId);
54
+ }
55
+
56
+ /**
57
+ * Processes the queue of assistant runs for a session.
58
+ */
59
+ async function processQueue(sessionId: string) {
60
+ const state = getRunnerState(sessionId);
61
+ if (!state) return;
62
+ if (state.running) return;
63
+ setRunning(sessionId, true);
64
+
65
+ while (state.queue.length > 0) {
66
+ const job = dequeueJob(sessionId);
67
+ if (!job) break;
68
+ try {
69
+ await runAssistant(job);
70
+ } catch (_err) {
71
+ // Swallow to keep the loop alive; event published by runner
72
+ }
73
+ }
74
+
75
+ setRunning(sessionId, false);
76
+ cleanupSession(sessionId);
77
+ }
78
+
79
+ /**
80
+ * Ensures the finish tool is called if not already observed.
81
+ */
82
+ async function ensureFinishToolCalled(
83
+ finishObserved: boolean,
84
+ toolset: ReturnType<typeof adaptTools>,
85
+ sharedCtx: RunnerToolContext,
86
+ stepIndex: number,
87
+ ) {
88
+ if (finishObserved || !toolset?.finish?.execute) return;
89
+
90
+ const finishInput = {} as const;
91
+ const callOptions = { input: finishInput } as const;
92
+
93
+ sharedCtx.stepIndex = stepIndex;
94
+
95
+ try {
96
+ await toolset.finish.onInputStart?.(callOptions as never);
97
+ } catch {}
98
+
99
+ try {
100
+ await toolset.finish.onInputAvailable?.(callOptions as never);
101
+ } catch {}
102
+
103
+ await toolset.finish.execute(finishInput, {} as never);
104
+ }
105
+
106
+ /**
107
+ * Main function to run the assistant for a given request.
40
108
  */
41
- export async function runAssistant(opts: RunOpts) {
42
- const db = await getDb();
43
- const config = await loadConfig();
44
- const [provider, modelName] = opts.model.split('/', 2);
45
- const model = resolveModel(provider, modelName);
46
-
47
- // Build agent + system prompt
48
- const agentConfig = resolveAgentConfig(opts.agent);
49
- const availableTools = await discoverProjectTools(config.project.root);
50
- const system = composeSystemPrompt(agentConfig, availableTools);
51
-
52
- // Build message history
53
- const history = await buildHistoryMessages(opts, db);
54
-
55
- // Setup tool context
56
- const toolContext = await setupToolContext(opts, db);
57
- const { tools, sharedCtx } = toolContext;
58
-
59
- // State
60
- let currentPartId = sharedCtx.assistantPartId;
61
- let stepIndex = sharedCtx.stepIndex;
109
+ async function runAssistant(opts: RunOpts) {
110
+ const cfgTimer = time('runner:loadConfig+db');
111
+ const cfg = await loadConfig(opts.projectRoot);
112
+ const db = await getDb(cfg.projectRoot);
113
+ cfgTimer.end();
114
+
115
+ const agentTimer = time('runner:resolveAgentConfig');
116
+ const agentCfg = await resolveAgentConfig(cfg.projectRoot, opts.agent);
117
+ agentTimer.end({ agent: opts.agent });
118
+
119
+ const agentPrompt = agentCfg.prompt || '';
120
+
121
+ const historyTimer = time('runner:buildHistory');
122
+ const history = await buildHistoryMessages(db, opts.sessionId);
123
+ historyTimer.end({ messages: history.length });
124
+
125
+ const isFirstMessage = history.length === 0;
126
+
127
+ const systemTimer = time('runner:composeSystemPrompt');
128
+ const { getAuth } = await import('@agi-cli/sdk');
129
+ const { getProviderSpoofPrompt } = await import('./prompt.ts');
130
+ const auth = await getAuth(opts.provider, cfg.projectRoot);
131
+ const needsSpoof = auth?.type === 'oauth';
132
+ const spoofPrompt = needsSpoof
133
+ ? getProviderSpoofPrompt(opts.provider)
134
+ : undefined;
135
+
136
+ let system: string;
137
+ let additionalSystemMessages: Array<{ role: 'system'; content: string }> = [];
138
+
139
+ if (spoofPrompt) {
140
+ system = spoofPrompt;
141
+ const fullPrompt = await composeSystemPrompt({
142
+ provider: opts.provider,
143
+ model: opts.model,
144
+ projectRoot: cfg.projectRoot,
145
+ agentPrompt,
146
+ oneShot: opts.oneShot,
147
+ spoofPrompt: undefined,
148
+ includeProjectTree: isFirstMessage,
149
+ });
150
+ additionalSystemMessages = [{ role: 'system', content: fullPrompt }];
151
+ } else {
152
+ system = await composeSystemPrompt({
153
+ provider: opts.provider,
154
+ model: opts.model,
155
+ projectRoot: cfg.projectRoot,
156
+ agentPrompt,
157
+ oneShot: opts.oneShot,
158
+ spoofPrompt: undefined,
159
+ includeProjectTree: isFirstMessage,
160
+ });
161
+ }
162
+ systemTimer.end();
163
+ debugLog('[system] composed prompt (provider+base+agent):');
164
+ debugLog(system);
165
+
166
+ const toolsTimer = time('runner:discoverTools');
167
+ const allTools = await discoverProjectTools(cfg.projectRoot);
168
+ toolsTimer.end({ count: allTools.length });
169
+ const allowedNames = new Set([
170
+ ...(agentCfg.tools || []),
171
+ 'finish',
172
+ 'progress_update',
173
+ ]);
174
+ const gated = allTools.filter((t) => allowedNames.has(t.name));
175
+ const messagesWithSystemInstructions = [
176
+ ...(isFirstMessage ? additionalSystemMessages : []),
177
+ ...history,
178
+ ];
179
+
180
+ const { sharedCtx, firstToolTimer, firstToolSeen } = await setupToolContext(
181
+ opts,
182
+ db,
183
+ );
184
+ const toolset = adaptTools(gated, sharedCtx, opts.provider);
185
+
186
+ const modelTimer = time('runner:resolveModel');
187
+ const model = await resolveModel(opts.provider, opts.model, cfg);
188
+ modelTimer.end();
189
+
190
+ const maxOutputTokens = getMaxOutputTokens(opts.provider, opts.model);
191
+
192
+ let currentPartId = opts.assistantPartId;
62
193
  let accumulated = '';
63
- const abortController = new AbortController();
194
+ let stepIndex = 0;
195
+
196
+ let finishObserved = false;
197
+ const unsubscribeFinish = subscribe(opts.sessionId, (evt) => {
198
+ if (evt.type !== 'tool.result') return;
199
+ try {
200
+ const name = (evt.payload as { name?: string } | undefined)?.name;
201
+ if (name === 'finish') finishObserved = true;
202
+ } catch {}
203
+ });
64
204
 
65
- // State getters/setters
205
+ const streamStartTimer = time('runner:first-delta');
206
+ let firstDeltaSeen = false;
207
+ debugLog(`[streamText] Calling with maxOutputTokens: ${maxOutputTokens}`);
208
+
209
+ // State management helpers
66
210
  const getCurrentPartId = () => currentPartId;
67
211
  const getStepIndex = () => stepIndex;
68
212
  const updateCurrentPartId = (id: string) => {
@@ -71,10 +215,12 @@ export async function runAssistant(opts: RunOpts) {
71
215
  const updateAccumulated = (text: string) => {
72
216
  accumulated = text;
73
217
  };
74
- const getAccumulated = () => accumulated;
75
- const incrementStepIndex = () => ++stepIndex;
218
+ const incrementStepIndex = () => {
219
+ stepIndex += 1;
220
+ return stepIndex;
221
+ };
76
222
 
77
- // Handlers
223
+ // Create stream handlers
78
224
  const onStepFinish = createStepFinishHandler(
79
225
  opts,
80
226
  db,
@@ -88,102 +234,105 @@ export async function runAssistant(opts: RunOpts) {
88
234
  updateMessageTokensIncremental,
89
235
  );
90
236
 
237
+ const onError = createErrorHandler(opts, db, getStepIndex, sharedCtx);
238
+
239
+ const onAbort = createAbortHandler(opts, db, getStepIndex, sharedCtx);
240
+
91
241
  const onFinish = createFinishHandler(
92
242
  opts,
93
243
  db,
244
+ () => ensureFinishToolCalled(finishObserved, toolset, sharedCtx, stepIndex),
94
245
  completeAssistantMessage,
95
- getAccumulated,
96
- abortController,
97
246
  );
98
247
 
99
- const _onAbort = createAbortHandler(opts, db, abortController);
100
- const onError = createErrorHandler(opts, db);
248
+ // Apply optimizations: deduplication, pruning, cache control, and truncation
249
+ const { addCacheControl, truncateHistory } = await import(
250
+ './cache-optimizer.ts'
251
+ );
252
+ const { optimizeContext } = await import('./context-optimizer.ts');
101
253
 
102
- // Context optimization
103
- const contextOptimized = optimizeContext(history, {
254
+ // 1. Optimize context (deduplicate file reads, prune old tool results)
255
+ const contextOptimized = optimizeContext(messagesWithSystemInstructions, {
104
256
  deduplicateFiles: true,
105
257
  maxToolResults: 30,
106
258
  });
107
259
 
108
- // Truncate history
260
+ // 2. Truncate history
109
261
  const truncatedMessages = truncateHistory(contextOptimized, 20);
110
262
 
111
- // Add cache control
263
+ // 3. Add cache control
112
264
  const { system: cachedSystem, messages: optimizedMessages } = addCacheControl(
113
- opts.provider,
265
+ opts.provider as any,
114
266
  system,
115
267
  truncatedMessages,
116
268
  );
117
269
 
118
270
  try {
119
- const maxTokens = getMaxOutputTokens(provider, modelName);
120
- const result = await streamText({
271
+ // @ts-expect-error this is fine 🔥
272
+ const result = streamText({
121
273
  model,
122
- system: cachedSystem,
274
+ tools: toolset,
275
+ ...(cachedSystem ? { system: cachedSystem } : {}),
123
276
  messages: optimizedMessages,
124
- tools,
125
- maxSteps: 50,
126
- maxTokens,
127
- temperature: agentConfig.temperature ?? 0.7,
128
- abortSignal: abortController.signal,
277
+ ...(maxOutputTokens ? { maxOutputTokens } : {}),
278
+ abortSignal: opts.abortSignal,
279
+ stopWhen: hasToolCall('finish'),
129
280
  onStepFinish,
281
+ onError,
282
+ onAbort,
130
283
  onFinish,
131
- experimental_continueSteps: true,
132
284
  });
133
285
 
134
- // Process the stream
135
286
  for await (const delta of result.textStream) {
136
- if (abortController.signal.aborted) break;
137
-
138
- accumulated += delta;
139
- if (currentPartId) {
140
- await db
141
- .update(messageParts)
142
- .set({ content: accumulated })
143
- .where(eq(messageParts.id, currentPartId));
287
+ if (!delta) continue;
288
+ if (!firstDeltaSeen) {
289
+ firstDeltaSeen = true;
290
+ streamStartTimer.end();
144
291
  }
145
-
146
- publish('stream:text-delta', {
292
+ accumulated += delta;
293
+ publish({
294
+ type: 'message.part.delta',
147
295
  sessionId: opts.sessionId,
148
- messageId: opts.assistantMessageId,
149
- assistantMessageId: opts.assistantMessageId,
150
- stepIndex,
151
- textDelta: delta,
152
- fullText: accumulated,
296
+ payload: {
297
+ messageId: opts.assistantMessageId,
298
+ partId: currentPartId,
299
+ stepIndex,
300
+ delta,
301
+ },
153
302
  });
303
+ await db
304
+ .update(messageParts)
305
+ .set({ content: JSON.stringify({ text: accumulated }) })
306
+ .where(eq(messageParts.id, currentPartId));
154
307
  }
155
- } catch (err) {
156
- await onError(err);
308
+ } catch (error) {
309
+ const errorPayload = toErrorPayload(error);
310
+ await db
311
+ .update(messageParts)
312
+ .set({
313
+ content: JSON.stringify({
314
+ text: accumulated,
315
+ error: errorPayload.message,
316
+ }),
317
+ })
318
+ .where(eq(messageParts.messageId, opts.assistantMessageId));
319
+ publish({
320
+ type: 'error',
321
+ sessionId: opts.sessionId,
322
+ payload: {
323
+ messageId: opts.assistantMessageId,
324
+ error: errorPayload.message,
325
+ details: errorPayload.details,
326
+ },
327
+ });
328
+ throw error;
157
329
  } finally {
158
- setRunning(opts.sessionId, false);
159
- dequeueJob(opts.sessionId);
330
+ if (!firstToolSeen()) firstToolTimer.end({ skipped: true });
331
+ try {
332
+ unsubscribeFinish();
333
+ } catch {}
334
+ try {
335
+ await cleanupEmptyTextParts(opts, db);
336
+ } catch {}
160
337
  }
161
338
  }
162
-
163
- /**
164
- * Enqueues an assistant run
165
- */
166
- export async function enqueueAssistantRun(opts: RunOpts) {
167
- return enqueueRun(opts);
168
- }
169
-
170
- /**
171
- * Aborts a running session
172
- */
173
- export async function abortSession(sessionId: number) {
174
- return abortSessionQueue(sessionId);
175
- }
176
-
177
- /**
178
- * Gets the current runner state for a session
179
- */
180
- export function getSessionState(sessionId: number) {
181
- return getRunnerState(sessionId);
182
- }
183
-
184
- /**
185
- * Cleanup session resources
186
- */
187
- export function cleanupSessionResources(sessionId: number) {
188
- return cleanupSession(sessionId);
189
- }
@@ -8,26 +8,17 @@ import { toErrorPayload } from './error-handling.ts';
8
8
  import type { RunOpts } from './session-queue.ts';
9
9
  import type { ToolAdapterContext } from '../tools/adapter.ts';
10
10
 
11
- interface ProviderMetadata {
12
- openai?: {
13
- cachedPromptTokens?: number;
14
- };
15
- [key: string]: unknown;
16
- }
17
-
18
- interface UsageData {
19
- inputTokens?: number;
20
- outputTokens?: number;
21
- totalTokens?: number;
22
- cachedInputTokens?: number;
23
- reasoningTokens?: number;
24
- }
25
-
26
11
  type StepFinishEvent = {
27
- usage?: UsageData;
12
+ usage?: {
13
+ inputTokens?: number;
14
+ outputTokens?: number;
15
+ totalTokens?: number;
16
+ cachedInputTokens?: number;
17
+ reasoningTokens?: number;
18
+ };
28
19
  finishReason?: string;
29
20
  response?: unknown;
30
- experimental_providerMetadata?: ProviderMetadata;
21
+ experimental_providerMetadata?: Record<string, any>;
31
22
  };
32
23
 
33
24
  type FinishEvent = {
@@ -51,19 +42,19 @@ export function createStepFinishHandler(
51
42
  db: Awaited<ReturnType<typeof getDb>>,
52
43
  getCurrentPartId: () => string,
53
44
  getStepIndex: () => number,
54
- _sharedCtx: ToolAdapterContext,
55
- _updateCurrentPartId: (id: string) => void,
56
- _updateAccumulated: (text: string) => void,
45
+ sharedCtx: ToolAdapterContext,
46
+ updateCurrentPartId: (id: string) => void,
47
+ updateAccumulated: (text: string) => void,
57
48
  incrementStepIndex: () => number,
58
49
  updateSessionTokensIncrementalFn: (
59
- usage: UsageData,
60
- providerMetadata: ProviderMetadata | undefined,
50
+ usage: any,
51
+ providerMetadata: Record<string, any> | undefined,
61
52
  opts: RunOpts,
62
53
  db: Awaited<ReturnType<typeof getDb>>,
63
54
  ) => Promise<void>,
64
55
  updateMessageTokensIncrementalFn: (
65
- usage: UsageData,
66
- providerMetadata: ProviderMetadata | undefined,
56
+ usage: any,
57
+ providerMetadata: Record<string, any> | undefined,
67
58
  opts: RunOpts,
68
59
  db: Awaited<ReturnType<typeof getDb>>,
69
60
  ) => Promise<void>,
@@ -78,11 +69,9 @@ export function createStepFinishHandler(
78
69
  .update(messageParts)
79
70
  .set({ completedAt: finishedAt })
80
71
  .where(eq(messageParts.id, currentPartId));
81
- } catch (err) {
82
- console.error('[createStepFinishHandler] Failed to update part', err);
83
- }
72
+ } catch {}
84
73
 
85
- // Update tokens incrementally
74
+ // Update token counts incrementally after each step
86
75
  if (step.usage) {
87
76
  try {
88
77
  await updateSessionTokensIncrementalFn(
@@ -91,81 +80,126 @@ export function createStepFinishHandler(
91
80
  opts,
92
81
  db,
93
82
  );
83
+ } catch {}
84
+
85
+ try {
94
86
  await updateMessageTokensIncrementalFn(
95
87
  step.usage,
96
88
  step.experimental_providerMetadata,
97
89
  opts,
98
90
  db,
99
91
  );
100
- } catch (err) {
101
- console.error('[createStepFinishHandler] Token update failed', err);
102
- }
92
+ } catch {}
103
93
  }
104
94
 
105
- // Publish step-finished event
106
- publish('stream:step-finished', {
107
- sessionId: opts.sessionId,
108
- messageId: opts.assistantMessageId,
109
- assistantMessageId: opts.assistantMessageId,
110
- stepIndex,
111
- finishReason: step.finishReason,
112
- usage: step.usage,
113
- });
95
+ try {
96
+ publish({
97
+ type: 'finish-step',
98
+ sessionId: opts.sessionId,
99
+ payload: {
100
+ stepIndex,
101
+ usage: step.usage,
102
+ finishReason: step.finishReason,
103
+ response: step.response,
104
+ },
105
+ });
106
+ if (step.usage) {
107
+ publish({
108
+ type: 'usage',
109
+ sessionId: opts.sessionId,
110
+ payload: { stepIndex, ...step.usage },
111
+ });
112
+ }
113
+ } catch {}
114
114
 
115
- incrementStepIndex();
115
+ try {
116
+ const newStepIndex = incrementStepIndex();
117
+ const newPartId = crypto.randomUUID();
118
+ const index = await sharedCtx.nextIndex();
119
+ const nowTs = Date.now();
120
+ await db.insert(messageParts).values({
121
+ id: newPartId,
122
+ messageId: opts.assistantMessageId,
123
+ index,
124
+ stepIndex: newStepIndex,
125
+ type: 'text',
126
+ content: JSON.stringify({ text: '' }),
127
+ agent: opts.agent,
128
+ provider: opts.provider,
129
+ model: opts.model,
130
+ startedAt: nowTs,
131
+ });
132
+ updateCurrentPartId(newPartId);
133
+ sharedCtx.assistantPartId = newPartId;
134
+ sharedCtx.stepIndex = newStepIndex;
135
+ updateAccumulated('');
136
+ } catch {}
116
137
  };
117
138
  }
118
139
 
119
140
  /**
120
- * Creates the onFinish handler for the stream
141
+ * Creates the onError handler for the stream
121
142
  */
122
- export function createFinishHandler(
143
+ export function createErrorHandler(
123
144
  opts: RunOpts,
124
145
  db: Awaited<ReturnType<typeof getDb>>,
125
- completeAssistantMessageFn: (
126
- fin: FinishEvent,
127
- opts: RunOpts,
128
- db: Awaited<ReturnType<typeof getDb>>,
129
- ) => Promise<void>,
130
- _getAccumulated: () => string,
131
- _abortController: AbortController,
146
+ getStepIndex: () => number,
147
+ sharedCtx: ToolAdapterContext,
132
148
  ) {
133
- return async (fin: FinishEvent) => {
134
- try {
135
- await completeAssistantMessageFn(fin, opts, db);
149
+ return async (err: unknown) => {
150
+ const errorPayload = toErrorPayload(err);
151
+ const isApiError = APICallError.isInstance(err);
152
+ const stepIndex = getStepIndex();
136
153
 
137
- const msgRows = await db
138
- .select()
139
- .from(messages)
140
- .where(eq(messages.id, opts.assistantMessageId));
154
+ // Create error part for UI display
155
+ const errorPartId = crypto.randomUUID();
156
+ await db.insert(messageParts).values({
157
+ id: errorPartId,
158
+ messageId: opts.assistantMessageId,
159
+ index: await sharedCtx.nextIndex(),
160
+ stepIndex,
161
+ type: 'error',
162
+ content: JSON.stringify({
163
+ message: errorPayload.message,
164
+ type: errorPayload.type,
165
+ details: errorPayload.details,
166
+ isAborted: false,
167
+ }),
168
+ agent: opts.agent,
169
+ provider: opts.provider,
170
+ model: opts.model,
171
+ startedAt: Date.now(),
172
+ completedAt: Date.now(),
173
+ });
141
174
 
142
- let estimatedCost = 0;
143
- if (msgRows.length > 0 && msgRows[0]) {
144
- const msg = msgRows[0];
145
- estimatedCost = estimateModelCostUsd(
146
- opts.provider,
147
- opts.model,
148
- Number(msg.promptTokens ?? 0),
149
- Number(msg.completionTokens ?? 0),
150
- );
151
- }
175
+ // Update message status
176
+ await db
177
+ .update(messages)
178
+ .set({
179
+ status: 'error',
180
+ error: errorPayload.message,
181
+ errorType: errorPayload.type,
182
+ errorDetails: JSON.stringify({
183
+ ...errorPayload.details,
184
+ isApiError,
185
+ }),
186
+ isAborted: false,
187
+ })
188
+ .where(eq(messages.id, opts.assistantMessageId));
152
189
 
153
- publish('stream:finished', {
154
- sessionId: opts.sessionId,
155
- messageId: opts.assistantMessageId,
156
- assistantMessageId: opts.assistantMessageId,
157
- usage: fin.usage,
158
- finishReason: fin.finishReason,
159
- estimatedCost,
160
- });
161
- } catch (err) {
162
- console.error('[createFinishHandler] Error in onFinish', err);
163
- publish('stream:error', {
164
- sessionId: opts.sessionId,
190
+ // Publish enhanced error event
191
+ publish({
192
+ type: 'error',
193
+ sessionId: opts.sessionId,
194
+ payload: {
165
195
  messageId: opts.assistantMessageId,
166
- error: toErrorPayload(err),
167
- });
168
- }
196
+ partId: errorPartId,
197
+ error: errorPayload.message,
198
+ errorType: errorPayload.type,
199
+ details: errorPayload.details,
200
+ isAborted: false,
201
+ },
202
+ });
169
203
  };
170
204
  }
171
205
 
@@ -175,116 +209,116 @@ export function createFinishHandler(
175
209
  export function createAbortHandler(
176
210
  opts: RunOpts,
177
211
  db: Awaited<ReturnType<typeof getDb>>,
178
- _abortController: AbortController,
212
+ getStepIndex: () => number,
213
+ sharedCtx: ToolAdapterContext,
179
214
  ) {
180
- return async (_event: AbortEvent) => {
181
- try {
182
- await db
183
- .update(messages)
184
- .set({ status: 'aborted', finishedAt: new Date() })
185
- .where(eq(messages.id, opts.assistantMessageId));
215
+ return async ({ steps }: AbortEvent) => {
216
+ const stepIndex = getStepIndex();
186
217
 
187
- publish('stream:aborted', {
188
- sessionId: opts.sessionId,
218
+ // Create abort part for UI
219
+ const abortPartId = crypto.randomUUID();
220
+ await db.insert(messageParts).values({
221
+ id: abortPartId,
222
+ messageId: opts.assistantMessageId,
223
+ index: await sharedCtx.nextIndex(),
224
+ stepIndex,
225
+ type: 'error',
226
+ content: JSON.stringify({
227
+ message: 'Generation stopped by user',
228
+ type: 'abort',
229
+ isAborted: true,
230
+ stepsCompleted: steps.length,
231
+ }),
232
+ agent: opts.agent,
233
+ provider: opts.provider,
234
+ model: opts.model,
235
+ startedAt: Date.now(),
236
+ completedAt: Date.now(),
237
+ });
238
+
239
+ // Store abort info
240
+ await db
241
+ .update(messages)
242
+ .set({
243
+ status: 'error',
244
+ error: 'Generation stopped by user',
245
+ errorType: 'abort',
246
+ errorDetails: JSON.stringify({
247
+ stepsCompleted: steps.length,
248
+ abortedAt: Date.now(),
249
+ }),
250
+ isAborted: true,
251
+ })
252
+ .where(eq(messages.id, opts.assistantMessageId));
253
+
254
+ // Publish abort event
255
+ publish({
256
+ type: 'error',
257
+ sessionId: opts.sessionId,
258
+ payload: {
189
259
  messageId: opts.assistantMessageId,
190
- assistantMessageId: opts.assistantMessageId,
191
- });
192
- } catch (err) {
193
- console.error('[createAbortHandler] Error in onAbort', err);
194
- }
260
+ partId: abortPartId,
261
+ error: 'Generation stopped by user',
262
+ errorType: 'abort',
263
+ isAborted: true,
264
+ stepsCompleted: steps.length,
265
+ },
266
+ });
195
267
  };
196
268
  }
197
269
 
198
270
  /**
199
- * Creates the error handler for the stream
271
+ * Creates the onFinish handler for the stream
200
272
  */
201
- export function createErrorHandler(
273
+ export function createFinishHandler(
202
274
  opts: RunOpts,
203
275
  db: Awaited<ReturnType<typeof getDb>>,
276
+ ensureFinishToolCalled: () => Promise<void>,
277
+ completeAssistantMessageFn: (
278
+ fin: FinishEvent,
279
+ opts: RunOpts,
280
+ db: Awaited<ReturnType<typeof getDb>>,
281
+ ) => Promise<void>,
204
282
  ) {
205
- return async (err: unknown) => {
206
- console.error('[createErrorHandler] Stream error:', err);
207
-
283
+ return async (fin: FinishEvent) => {
208
284
  try {
209
- let errorMessage = 'Unknown error';
210
- let errorType = 'UNKNOWN_ERROR';
211
- let errorStack: string | undefined;
285
+ await ensureFinishToolCalled();
286
+ } catch {}
212
287
 
213
- if (err instanceof APICallError) {
214
- errorMessage = err.message;
215
- errorType = 'API_CALL_ERROR';
216
- errorStack = err.stack;
217
- } else if (err instanceof Error) {
218
- errorMessage = err.message;
219
- errorType = err.name || 'ERROR';
220
- errorStack = err.stack;
221
- } else if (typeof err === 'string') {
222
- errorMessage = err;
223
- }
288
+ // Note: Token updates are handled incrementally in onStepFinish
289
+ // Do NOT add fin.usage here as it would cause double-counting
224
290
 
225
- await db
226
- .update(messages)
227
- .set({
228
- status: 'error',
229
- finishedAt: new Date(),
230
- error: errorMessage,
231
- })
232
- .where(eq(messages.id, opts.assistantMessageId));
233
-
234
- publish('stream:error', {
235
- sessionId: opts.sessionId,
236
- messageId: opts.assistantMessageId,
237
- assistantMessageId: opts.assistantMessageId,
238
- error: {
239
- message: errorMessage,
240
- type: errorType,
241
- stack: errorStack,
242
- },
243
- });
244
- } catch (dbErr) {
245
- console.error('[createErrorHandler] Failed to save error to DB', dbErr);
246
- }
247
- };
248
- }
291
+ try {
292
+ await completeAssistantMessageFn(fin, opts, db);
293
+ } catch {}
249
294
 
250
- /**
251
- * Creates the text delta handler for the stream
252
- */
253
- export function createTextHandler(
254
- opts: RunOpts,
255
- db: Awaited<ReturnType<typeof getDb>>,
256
- getCurrentPartId: () => string,
257
- getStepIndex: () => number,
258
- _updateCurrentPartId: (id: string) => void,
259
- updateAccumulated: (text: string) => void,
260
- getAccumulated: () => string,
261
- ) {
262
- return async (textDelta: string) => {
263
- const currentPartId = getCurrentPartId();
264
- const stepIndex = getStepIndex();
295
+ // Use session totals from DB for accurate cost calculation
296
+ const sessRows = await db
297
+ .select()
298
+ .from(messages)
299
+ .where(eq(messages.id, opts.assistantMessageId));
265
300
 
266
- // Accumulate the text
267
- const accumulated = getAccumulated() + textDelta;
268
- updateAccumulated(accumulated);
301
+ const usage = sessRows[0]
302
+ ? {
303
+ inputTokens: Number(sessRows[0].promptTokens ?? 0),
304
+ outputTokens: Number(sessRows[0].completionTokens ?? 0),
305
+ totalTokens: Number(sessRows[0].totalTokens ?? 0),
306
+ }
307
+ : fin.usage;
269
308
 
270
- try {
271
- if (currentPartId) {
272
- await db
273
- .update(messageParts)
274
- .set({ content: accumulated })
275
- .where(eq(messageParts.id, currentPartId));
276
- }
309
+ const costUsd = usage
310
+ ? estimateModelCostUsd(opts.provider, opts.model, usage)
311
+ : undefined;
277
312
 
278
- publish('stream:text-delta', {
279
- sessionId: opts.sessionId,
280
- messageId: opts.assistantMessageId,
281
- assistantMessageId: opts.assistantMessageId,
282
- stepIndex,
283
- textDelta,
284
- fullText: accumulated,
285
- });
286
- } catch (err) {
287
- console.error('[createTextHandler] Error updating text part', err);
288
- }
313
+ publish({
314
+ type: 'message.completed',
315
+ sessionId: opts.sessionId,
316
+ payload: {
317
+ id: opts.assistantMessageId,
318
+ usage,
319
+ costUsd,
320
+ finishReason: fin.finishReason,
321
+ },
322
+ });
289
323
  };
290
324
  }