@agi-cli/server 0.1.55

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,275 @@
1
+ import type { getDb } from '@agi-cli/database';
2
+ import { messages, messageParts } from '@agi-cli/database/schema';
3
+ import { eq } from 'drizzle-orm';
4
+ import { APICallError } from 'ai';
5
+ import { publish } from '../events/bus.ts';
6
+ import { estimateModelCostUsd } from '@agi-cli/sdk';
7
+ import { toErrorPayload } from './error-handling.ts';
8
+ import type { RunOpts } from './session-queue.ts';
9
+ import type { ToolAdapterContext } from '../tools/adapter.ts';
10
+
11
+ type StepFinishEvent = {
12
+ usage?: { inputTokens?: number; outputTokens?: number };
13
+ finishReason?: string;
14
+ response?: unknown;
15
+ };
16
+
17
+ type FinishEvent = {
18
+ usage?: {
19
+ inputTokens?: number;
20
+ outputTokens?: number;
21
+ totalTokens?: number;
22
+ };
23
+ finishReason?: string;
24
+ };
25
+
26
+ type AbortEvent = {
27
+ steps: unknown[];
28
+ };
29
+
30
+ /**
31
+ * Creates the onStepFinish handler for the stream
32
+ */
33
+ export function createStepFinishHandler(
34
+ opts: RunOpts,
35
+ db: Awaited<ReturnType<typeof getDb>>,
36
+ getCurrentPartId: () => string,
37
+ getStepIndex: () => number,
38
+ sharedCtx: ToolAdapterContext,
39
+ updateCurrentPartId: (id: string) => void,
40
+ updateAccumulated: (text: string) => void,
41
+ incrementStepIndex: () => number,
42
+ ) {
43
+ return async (step: StepFinishEvent) => {
44
+ const finishedAt = Date.now();
45
+ const currentPartId = getCurrentPartId();
46
+ const stepIndex = getStepIndex();
47
+
48
+ try {
49
+ await db
50
+ .update(messageParts)
51
+ .set({ completedAt: finishedAt })
52
+ .where(eq(messageParts.id, currentPartId));
53
+ } catch {}
54
+
55
+ try {
56
+ publish({
57
+ type: 'finish-step',
58
+ sessionId: opts.sessionId,
59
+ payload: {
60
+ stepIndex,
61
+ usage: step.usage,
62
+ finishReason: step.finishReason,
63
+ response: step.response,
64
+ },
65
+ });
66
+ if (step.usage) {
67
+ publish({
68
+ type: 'usage',
69
+ sessionId: opts.sessionId,
70
+ payload: { stepIndex, ...step.usage },
71
+ });
72
+ }
73
+ } catch {}
74
+
75
+ try {
76
+ const newStepIndex = incrementStepIndex();
77
+ const newPartId = crypto.randomUUID();
78
+ const index = await sharedCtx.nextIndex();
79
+ const nowTs = Date.now();
80
+ await db.insert(messageParts).values({
81
+ id: newPartId,
82
+ messageId: opts.assistantMessageId,
83
+ index,
84
+ stepIndex: newStepIndex,
85
+ type: 'text',
86
+ content: JSON.stringify({ text: '' }),
87
+ agent: opts.agent,
88
+ provider: opts.provider,
89
+ model: opts.model,
90
+ startedAt: nowTs,
91
+ });
92
+ updateCurrentPartId(newPartId);
93
+ sharedCtx.assistantPartId = newPartId;
94
+ sharedCtx.stepIndex = newStepIndex;
95
+ updateAccumulated('');
96
+ } catch {}
97
+ };
98
+ }
99
+
100
+ /**
101
+ * Creates the onError handler for the stream
102
+ */
103
+ export function createErrorHandler(
104
+ opts: RunOpts,
105
+ db: Awaited<ReturnType<typeof getDb>>,
106
+ getStepIndex: () => number,
107
+ sharedCtx: ToolAdapterContext,
108
+ ) {
109
+ return async (err: unknown) => {
110
+ const errorPayload = toErrorPayload(err);
111
+ const isApiError = APICallError.isInstance(err);
112
+ const stepIndex = getStepIndex();
113
+
114
+ // Create error part for UI display
115
+ const errorPartId = crypto.randomUUID();
116
+ await db.insert(messageParts).values({
117
+ id: errorPartId,
118
+ messageId: opts.assistantMessageId,
119
+ index: await sharedCtx.nextIndex(),
120
+ stepIndex,
121
+ type: 'error',
122
+ content: JSON.stringify({
123
+ message: errorPayload.message,
124
+ type: errorPayload.type,
125
+ details: errorPayload.details,
126
+ isAborted: false,
127
+ }),
128
+ agent: opts.agent,
129
+ provider: opts.provider,
130
+ model: opts.model,
131
+ startedAt: Date.now(),
132
+ completedAt: Date.now(),
133
+ });
134
+
135
+ // Update message status
136
+ await db
137
+ .update(messages)
138
+ .set({
139
+ status: 'error',
140
+ error: errorPayload.message,
141
+ errorType: errorPayload.type,
142
+ errorDetails: JSON.stringify({
143
+ ...errorPayload.details,
144
+ isApiError,
145
+ }),
146
+ isAborted: false,
147
+ })
148
+ .where(eq(messages.id, opts.assistantMessageId));
149
+
150
+ // Publish enhanced error event
151
+ publish({
152
+ type: 'error',
153
+ sessionId: opts.sessionId,
154
+ payload: {
155
+ messageId: opts.assistantMessageId,
156
+ partId: errorPartId,
157
+ error: errorPayload.message,
158
+ errorType: errorPayload.type,
159
+ details: errorPayload.details,
160
+ isAborted: false,
161
+ },
162
+ });
163
+ };
164
+ }
165
+
166
+ /**
167
+ * Creates the onAbort handler for the stream
168
+ */
169
+ export function createAbortHandler(
170
+ opts: RunOpts,
171
+ db: Awaited<ReturnType<typeof getDb>>,
172
+ getStepIndex: () => number,
173
+ sharedCtx: ToolAdapterContext,
174
+ ) {
175
+ return async ({ steps }: AbortEvent) => {
176
+ const stepIndex = getStepIndex();
177
+
178
+ // Create abort part for UI
179
+ const abortPartId = crypto.randomUUID();
180
+ await db.insert(messageParts).values({
181
+ id: abortPartId,
182
+ messageId: opts.assistantMessageId,
183
+ index: await sharedCtx.nextIndex(),
184
+ stepIndex,
185
+ type: 'error',
186
+ content: JSON.stringify({
187
+ message: 'Generation stopped by user',
188
+ type: 'abort',
189
+ isAborted: true,
190
+ stepsCompleted: steps.length,
191
+ }),
192
+ agent: opts.agent,
193
+ provider: opts.provider,
194
+ model: opts.model,
195
+ startedAt: Date.now(),
196
+ completedAt: Date.now(),
197
+ });
198
+
199
+ // Store abort info
200
+ await db
201
+ .update(messages)
202
+ .set({
203
+ status: 'error',
204
+ error: 'Generation stopped by user',
205
+ errorType: 'abort',
206
+ errorDetails: JSON.stringify({
207
+ stepsCompleted: steps.length,
208
+ abortedAt: Date.now(),
209
+ }),
210
+ isAborted: true,
211
+ })
212
+ .where(eq(messages.id, opts.assistantMessageId));
213
+
214
+ // Publish abort event
215
+ publish({
216
+ type: 'error',
217
+ sessionId: opts.sessionId,
218
+ payload: {
219
+ messageId: opts.assistantMessageId,
220
+ partId: abortPartId,
221
+ error: 'Generation stopped by user',
222
+ errorType: 'abort',
223
+ isAborted: true,
224
+ stepsCompleted: steps.length,
225
+ },
226
+ });
227
+ };
228
+ }
229
+
230
+ /**
231
+ * Creates the onFinish handler for the stream
232
+ */
233
+ export function createFinishHandler(
234
+ opts: RunOpts,
235
+ db: Awaited<ReturnType<typeof getDb>>,
236
+ ensureFinishToolCalled: () => Promise<void>,
237
+ updateSessionTokensFn: (
238
+ fin: FinishEvent,
239
+ opts: RunOpts,
240
+ db: Awaited<ReturnType<typeof getDb>>,
241
+ ) => Promise<void>,
242
+ completeAssistantMessageFn: (
243
+ fin: FinishEvent,
244
+ opts: RunOpts,
245
+ db: Awaited<ReturnType<typeof getDb>>,
246
+ ) => Promise<void>,
247
+ ) {
248
+ return async (fin: FinishEvent) => {
249
+ try {
250
+ await ensureFinishToolCalled();
251
+ } catch {}
252
+
253
+ try {
254
+ await updateSessionTokensFn(fin, opts, db);
255
+ } catch {}
256
+
257
+ try {
258
+ await completeAssistantMessageFn(fin, opts, db);
259
+ } catch {}
260
+
261
+ const costUsd = fin.usage
262
+ ? estimateModelCostUsd(opts.provider, opts.model, fin.usage)
263
+ : undefined;
264
+ publish({
265
+ type: 'message.completed',
266
+ sessionId: opts.sessionId,
267
+ payload: {
268
+ id: opts.assistantMessageId,
269
+ usage: fin.usage,
270
+ costUsd,
271
+ finishReason: fin.finishReason,
272
+ },
273
+ });
274
+ };
275
+ }
@@ -0,0 +1,35 @@
1
+ import { catalog } from '@agi-cli/sdk';
2
+ import { debugLog } from './debug.ts';
3
+ import type { ProviderName } from './provider.ts';
4
+
5
+ /**
6
+ * Gets the maximum output tokens allowed for a given provider/model combination.
7
+ * Returns undefined if the information is not available in the catalog.
8
+ */
9
+ export function getMaxOutputTokens(
10
+ provider: ProviderName,
11
+ modelId: string,
12
+ ): number | undefined {
13
+ try {
14
+ const providerCatalog = catalog[provider];
15
+ if (!providerCatalog) {
16
+ debugLog(`[maxOutputTokens] No catalog found for provider: ${provider}`);
17
+ return undefined;
18
+ }
19
+ const modelInfo = providerCatalog.models.find((m) => m.id === modelId);
20
+ if (!modelInfo) {
21
+ debugLog(
22
+ `[maxOutputTokens] No model info found for: ${modelId} in provider: ${provider}`,
23
+ );
24
+ return undefined;
25
+ }
26
+ const outputLimit = modelInfo.limit?.output;
27
+ debugLog(
28
+ `[maxOutputTokens] Provider: ${provider}, Model: ${modelId}, Limit: ${outputLimit}`,
29
+ );
30
+ return outputLimit;
31
+ } catch (err) {
32
+ debugLog(`[maxOutputTokens] Error looking up limit: ${err}`);
33
+ return undefined;
34
+ }
35
+ }
@@ -0,0 +1,58 @@
1
+ import type { getDb } from '@agi-cli/database';
2
+ import { messageParts } from '@agi-cli/database/schema';
3
+ import { eq } from 'drizzle-orm';
4
+ import { time } from './debug.ts';
5
+ import type { ToolAdapterContext } from '../tools/adapter.ts';
6
+ import type { RunOpts } from './session-queue.ts';
7
+
8
+ export type RunnerToolContext = ToolAdapterContext & { stepIndex: number };
9
+
10
+ /**
11
+ * Sets up the shared tool context for a run, including the index counter
12
+ * and first tool call tracking.
13
+ */
14
+ export async function setupToolContext(
15
+ opts: RunOpts,
16
+ db: Awaited<ReturnType<typeof getDb>>,
17
+ ) {
18
+ const firstToolTimer = time('runner:first-tool-call');
19
+ let firstToolSeen = false;
20
+
21
+ const sharedCtx: RunnerToolContext = {
22
+ nextIndex: async () => 0,
23
+ stepIndex: 0,
24
+ sessionId: opts.sessionId,
25
+ messageId: opts.assistantMessageId,
26
+ assistantPartId: opts.assistantPartId,
27
+ db,
28
+ agent: opts.agent,
29
+ provider: opts.provider,
30
+ model: opts.model,
31
+ projectRoot: opts.projectRoot,
32
+ onFirstToolCall: () => {
33
+ if (firstToolSeen) return;
34
+ firstToolSeen = true;
35
+ firstToolTimer.end();
36
+ },
37
+ };
38
+
39
+ let counter = 0;
40
+ try {
41
+ const existing = await db
42
+ .select()
43
+ .from(messageParts)
44
+ .where(eq(messageParts.messageId, opts.assistantMessageId));
45
+ if (existing.length) {
46
+ const indexes = existing.map((p) => Number(p.index ?? 0));
47
+ const maxIndex = Math.max(...indexes);
48
+ if (Number.isFinite(maxIndex)) counter = maxIndex;
49
+ }
50
+ } catch {}
51
+
52
+ sharedCtx.nextIndex = () => {
53
+ counter += 1;
54
+ return counter;
55
+ };
56
+
57
+ return { sharedCtx, firstToolTimer, firstToolSeen: () => firstToolSeen };
58
+ }
@@ -0,0 +1,72 @@
1
+ import { eq } from 'drizzle-orm';
2
+ import type { DB } from '@agi-cli/database';
3
+ import { messageParts } from '@agi-cli/database/schema';
4
+ import { publish } from '../events/bus.ts';
5
+
6
+ export type ToolAdapterContext = {
7
+ sessionId: string;
8
+ messageId: string;
9
+ assistantPartId: string;
10
+ db: DB;
11
+ agent: string;
12
+ provider: string;
13
+ model: string;
14
+ projectRoot: string;
15
+ nextIndex: () => number | Promise<number>;
16
+ stepIndex?: number;
17
+ onFirstToolCall?: () => void;
18
+ };
19
+
20
+ export function extractFinishText(input: unknown): string | undefined {
21
+ if (typeof input === 'string') return input;
22
+ if (!input || typeof input !== 'object') return undefined;
23
+ const obj = input as Record<string, unknown>;
24
+ if (typeof obj.text === 'string') return obj.text;
25
+ if (
26
+ obj.input &&
27
+ typeof (obj.input as Record<string, unknown>).text === 'string'
28
+ )
29
+ return String((obj.input as Record<string, unknown>).text);
30
+ return undefined;
31
+ }
32
+
33
+ export async function appendAssistantText(
34
+ ctx: ToolAdapterContext,
35
+ text: string,
36
+ ): Promise<void> {
37
+ try {
38
+ const rows = await ctx.db
39
+ .select()
40
+ .from(messageParts)
41
+ .where(eq(messageParts.id, ctx.assistantPartId));
42
+ let previous = '';
43
+ if (rows.length) {
44
+ try {
45
+ const parsed = JSON.parse(rows[0]?.content ?? '{}');
46
+ if (parsed && typeof parsed.text === 'string') previous = parsed.text;
47
+ } catch {}
48
+ }
49
+ const addition = text.startsWith(previous)
50
+ ? text.slice(previous.length)
51
+ : text;
52
+ if (addition.length) {
53
+ const payload: Record<string, unknown> = {
54
+ messageId: ctx.messageId,
55
+ partId: ctx.assistantPartId,
56
+ delta: addition,
57
+ };
58
+ if (ctx.stepIndex !== undefined) payload.stepIndex = ctx.stepIndex;
59
+ publish({
60
+ type: 'message.part.delta',
61
+ sessionId: ctx.sessionId,
62
+ payload,
63
+ });
64
+ }
65
+ await ctx.db
66
+ .update(messageParts)
67
+ .set({ content: JSON.stringify({ text }) })
68
+ .where(eq(messageParts.id, ctx.assistantPartId));
69
+ } catch {
70
+ // ignore to keep run alive if we can't persist the text
71
+ }
72
+ }