@agi-cli/server 0.1.118 → 0.1.120

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@agi-cli/server",
3
- "version": "0.1.118",
3
+ "version": "0.1.120",
4
4
  "description": "HTTP API server for AGI CLI",
5
5
  "type": "module",
6
6
  "main": "./src/index.ts",
@@ -29,8 +29,8 @@
29
29
  "typecheck": "tsc --noEmit"
30
30
  },
31
31
  "dependencies": {
32
- "@agi-cli/sdk": "0.1.118",
33
- "@agi-cli/database": "0.1.118",
32
+ "@agi-cli/sdk": "0.1.120",
33
+ "@agi-cli/database": "0.1.120",
34
34
  "drizzle-orm": "^0.44.5",
35
35
  "hono": "^4.9.9",
36
36
  "zod": "^4.1.8"
@@ -11,6 +11,7 @@ export type AGIEventType =
11
11
  | 'plan.updated'
12
12
  | 'finish-step'
13
13
  | 'usage'
14
+ | 'queue.updated'
14
15
  | 'error'
15
16
  | 'heartbeat';
16
17
 
package/src/index.ts CHANGED
@@ -14,6 +14,7 @@ import { registerFilesRoutes } from './routes/files.ts';
14
14
  import { registerGitRoutes } from './routes/git/index.ts';
15
15
  import { registerTerminalsRoutes } from './routes/terminals.ts';
16
16
  import { registerSessionFilesRoutes } from './routes/session-files.ts';
17
+ import { registerBranchRoutes } from './routes/branch.ts';
17
18
  import type { AgentConfigEntry } from './runtime/agent-registry.ts';
18
19
 
19
20
  const globalTerminalManager = new TerminalManager();
@@ -64,6 +65,7 @@ function initApp() {
64
65
  registerGitRoutes(app);
65
66
  registerTerminalsRoutes(app, globalTerminalManager);
66
67
  registerSessionFilesRoutes(app);
68
+ registerBranchRoutes(app);
67
69
 
68
70
  return app;
69
71
  }
@@ -130,6 +132,7 @@ export function createStandaloneApp(_config?: StandaloneAppConfig) {
130
132
  registerGitRoutes(honoApp);
131
133
  registerTerminalsRoutes(honoApp, globalTerminalManager);
132
134
  registerSessionFilesRoutes(honoApp);
135
+ registerBranchRoutes(honoApp);
133
136
 
134
137
  return honoApp;
135
138
  }
@@ -224,6 +227,7 @@ export function createEmbeddedApp(config: EmbeddedAppConfig = {}) {
224
227
  registerGitRoutes(honoApp);
225
228
  registerTerminalsRoutes(honoApp, globalTerminalManager);
226
229
  registerSessionFilesRoutes(honoApp);
230
+ registerBranchRoutes(honoApp);
227
231
 
228
232
  return honoApp;
229
233
  }
@@ -0,0 +1,106 @@
1
+ import type { Hono } from 'hono';
2
+ import { loadConfig } from '@agi-cli/sdk';
3
+ import { getDb } from '@agi-cli/database';
4
+ import { isProviderId, logger } from '@agi-cli/sdk';
5
+ import {
6
+ createBranch,
7
+ listBranches,
8
+ getParentSession,
9
+ } from '../runtime/branch.ts';
10
+ import { serializeError } from '../runtime/api-error.ts';
11
+
12
/**
 * Registers the session-branching routes on the given Hono app:
 *
 * - `POST /v1/sessions/:sessionId/branch`   creates a branch from a message
 * - `GET  /v1/sessions/:sessionId/branches` lists the branches of a session
 * - `GET  /v1/sessions/:sessionId/parent`   returns a session's parent, or null
 *
 * Each route reads the project root from the optional `project` query
 * parameter and falls back to `process.cwd()`. Failures are logged and
 * returned through `serializeError`.
 */
export function registerBranchRoutes(app: Hono) {
  // Loads the project config and opens the database for that project.
  const openProject = async (projectParam: string | undefined) => {
    const cfg = await loadConfig(projectParam || process.cwd());
    const db = await getDb(cfg.projectRoot);
    return { cfg, db };
  };

  // Returns the trimmed string, or undefined for non-strings and blank strings.
  const trimmedOrUndefined = (value: unknown): string | undefined => {
    if (typeof value !== 'string') return undefined;
    const trimmed = value.trim();
    return trimmed ? trimmed : undefined;
  };

  app.post('/v1/sessions/:sessionId/branch', async (c) => {
    try {
      const parentSessionId = c.req.param('sessionId');
      const { cfg, db } = await openProject(c.req.query('project'));

      // A missing or unparsable JSON body is treated as an empty object.
      const body = (await c.req.json().catch(() => ({}))) as Record<
        string,
        unknown
      >;

      const fromMessageId = trimmedOrUndefined(body.fromMessageId);
      if (fromMessageId === undefined) {
        return c.json({ error: 'fromMessageId is required' }, 400);
      }

      // Unknown provider ids are dropped rather than rejected.
      const provider =
        typeof body.provider === 'string' && isProviderId(body.provider)
          ? body.provider
          : undefined;

      const result = await createBranch({
        db,
        parentSessionId,
        fromMessageId,
        provider,
        model: trimmedOrUndefined(body.model),
        agent: trimmedOrUndefined(body.agent),
        title: trimmedOrUndefined(body.title),
        projectPath: cfg.projectRoot,
      });

      return c.json(result, 201);
    } catch (err) {
      logger.error('Failed to create branch', err);
      const errorResponse = serializeError(err);
      return c.json(errorResponse, errorResponse.error.status || 400);
    }
  });

  app.get('/v1/sessions/:sessionId/branches', async (c) => {
    try {
      const sessionId = c.req.param('sessionId');
      const { cfg, db } = await openProject(c.req.query('project'));
      const branches = await listBranches(db, sessionId, cfg.projectRoot);
      return c.json({ branches });
    } catch (err) {
      logger.error('Failed to list branches', err);
      const errorResponse = serializeError(err);
      return c.json(errorResponse, errorResponse.error.status || 500);
    }
  });

  app.get('/v1/sessions/:sessionId/parent', async (c) => {
    try {
      const sessionId = c.req.param('sessionId');
      const { cfg, db } = await openProject(c.req.query('project'));
      const parent = await getParentSession(db, sessionId, cfg.projectRoot);
      // A session without a resolvable parent is reported as `parent: null`.
      return c.json({ parent: parent ?? null });
    } catch (err) {
      logger.error('Failed to get parent session', err);
      const errorResponse = serializeError(err);
      return c.json(errorResponse, errorResponse.error.status || 500);
    }
  });
}
@@ -1,8 +1,8 @@
1
1
  import type { Hono } from 'hono';
2
2
  import { loadConfig } from '@agi-cli/sdk';
3
3
  import { getDb } from '@agi-cli/database';
4
- import { sessions } from '@agi-cli/database/schema';
5
- import { desc, eq } from 'drizzle-orm';
4
+ import { sessions, messages, messageParts } from '@agi-cli/database/schema';
5
+ import { desc, eq, and, inArray } from 'drizzle-orm';
6
6
  import type { ProviderId } from '@agi-cli/sdk';
7
7
  import { isProviderId, catalog } from '@agi-cli/sdk';
8
8
  import { resolveAgentConfig } from '../runtime/agent-registry.ts';
@@ -194,8 +194,108 @@ export function registerSessionsRoutes(app: Hono) {
194
194
  // Abort session stream
195
195
  app.delete('/v1/sessions/:sessionId/abort', async (c) => {
196
196
  const sessionId = c.req.param('sessionId');
197
- const { abortSession } = await import('../runtime/runner.ts');
198
- abortSession(sessionId);
197
+ const body = (await c.req.json().catch(() => ({}))) as Record<
198
+ string,
199
+ unknown
200
+ >;
201
+ const messageId =
202
+ typeof body.messageId === 'string' ? body.messageId : undefined;
203
+ const clearQueue = body.clearQueue === true;
204
+
205
+ const { abortSession, abortMessage } = await import('../runtime/runner.ts');
206
+
207
+ if (messageId) {
208
+ const result = abortMessage(sessionId, messageId);
209
+ return c.json({
210
+ success: result.removed,
211
+ wasRunning: result.wasRunning,
212
+ messageId,
213
+ });
214
+ }
215
+
216
+ abortSession(sessionId, clearQueue);
199
217
  return c.json({ success: true });
200
218
  });
219
+
220
// Get queue state for a session.
// Sessions with no runner state are reported as idle with an empty queue.
app.get('/v1/sessions/:sessionId/queue', async (c) => {
  const sessionId = c.req.param('sessionId');
  const { getQueueState } = await import('../runtime/session-queue.ts');

  const snapshot = getQueueState(sessionId);
  if (snapshot) {
    return c.json(snapshot);
  }

  return c.json({
    currentMessageId: null,
    queuedMessages: [],
    isRunning: false,
  });
});
233
+
234
// Remove a message from the queue.
//
// `:messageId` is the assistant message id of the queued run. If the message is
// still queued it is dequeued and its rows (plus the user message that
// triggered it) are deleted. If it is already running it is aborted instead.
// Responds 404 when the message is neither queued nor running.
app.delete('/v1/sessions/:sessionId/queue/:messageId', async (c) => {
  const sessionId = c.req.param('sessionId');
  const messageId = c.req.param('messageId');
  const projectRoot = c.req.query('project') || process.cwd();
  const cfg = await loadConfig(projectRoot);
  const db = await getDb(cfg.projectRoot);
  const { removeFromQueue, abortMessage } = await import(
    '../runtime/session-queue.ts'
  );

  // First try to remove from queue (queued messages)
  if (removeFromQueue(sessionId, messageId)) {
    // Best-effort DB cleanup: the message is already out of the queue, so a
    // failure here is logged but does not fail the request.
    try {
      // Scope the lookup to this session so an id from another session can
      // never be deleted through this route.
      const [assistantMsg] = await db
        .select()
        .from(messages)
        .where(and(eq(messages.id, messageId), eq(messages.sessionId, sessionId)))
        .limit(1);

      if (assistantMsg) {
        // Find the user message paired with THIS assistant message: the newest
        // user message created at or before it. Taking simply the newest user
        // message in the session would delete the wrong row whenever more than
        // one message is queued.
        // NOTE(review): assumes the user message is created no later than its
        // assistant message — confirm against the message-creation code.
        const userMsgs = await db
          .select()
          .from(messages)
          .where(
            and(eq(messages.sessionId, sessionId), eq(messages.role, 'user')),
          )
          .orderBy(desc(messages.createdAt));
        const pairedUserMsg = userMsgs.find(
          (m) => m.createdAt <= assistantMsg.createdAt,
        );

        const messageIdsToDelete = [messageId];
        if (pairedUserMsg) {
          messageIdsToDelete.push(pairedUserMsg.id);
        }

        // Delete message parts first (foreign key constraint)
        await db
          .delete(messageParts)
          .where(inArray(messageParts.messageId, messageIdsToDelete));
        // Delete messages
        await db
          .delete(messages)
          .where(inArray(messages.id, messageIdsToDelete));
      }
    } catch (err) {
      logger.error('Failed to delete queued messages from DB', err);
    }
    return c.json({ success: true, removed: true, wasQueued: true });
  }

  // If not in queue, try to abort (might be running)
  const result = abortMessage(sessionId, messageId);
  if (result.removed) {
    return c.json({
      success: true,
      removed: true,
      wasQueued: false,
      wasRunning: result.wasRunning,
    });
  }

  return c.json({ success: false, removed: false }, 404);
});
201
301
  }
@@ -0,0 +1,277 @@
1
+ import { eq, asc } from 'drizzle-orm';
2
+ import type { DB } from '@agi-cli/database';
3
+ import { sessions, messages, messageParts } from '@agi-cli/database/schema';
4
+ import { publish } from '../events/bus.ts';
5
+ import type { ProviderId } from '@agi-cli/sdk';
6
+
7
/** A row of the `sessions` table as returned by a select. */
type SessionRow = typeof sessions.$inferSelect;

/** Arguments accepted by `createBranch`. */
export type CreateBranchInput = {
  /** Database handle used for all reads and writes. */
  db: DB;
  /** Id of the session being branched from. */
  parentSessionId: string;
  /** Id of the last message to copy; it must belong to the parent session. */
  fromMessageId: string;
  /** Optional override; the parent's provider is used when omitted. */
  provider?: ProviderId;
  /** Optional override; the parent's model is used when omitted. */
  model?: string;
  /** Optional override; the parent's agent is used when omitted. */
  agent?: string;
  /** Optional title; defaults to "Branch of <parent title>". */
  title?: string;
  /** Project the caller is operating in; must match the parent's project. */
  projectPath: string;
};

/** Outcome of a successful `createBranch` call. */
export type BranchResult = {
  /** The newly created branch session. */
  session: SessionRow;
  /** Id of the session that was branched from. */
  parentSessionId: string;
  /** Id of the parent message at which the branch was taken. */
  branchPointMessageId: string;
  /** Number of messages copied into the branch. */
  copiedMessages: number;
  /** Number of message parts copied into the branch. */
  copiedParts: number;
};
27
+
28
/**
 * Creates a new session of type `'branch'` that contains a copy of the parent
 * session's messages (ordered by `createdAt`) up to and including
 * `fromMessageId`, together with all of their parts.
 *
 * Provider, model, agent and title fall back to the parent's values when not
 * supplied. On success a `session.created` event is published.
 *
 * If copying fails partway through, the partially populated branch is removed
 * (best effort) before the original error is rethrown, so callers never see a
 * truncated branch session.
 *
 * @throws Error when the parent session does not exist or belongs to another
 *   project, or when the branch point message does not exist or does not
 *   belong to the parent session.
 */
export async function createBranch({
  db,
  parentSessionId,
  fromMessageId,
  provider,
  model,
  agent,
  title,
  projectPath,
}: CreateBranchInput): Promise<BranchResult> {
  const [parent] = await db
    .select()
    .from(sessions)
    .where(eq(sessions.id, parentSessionId));

  if (!parent) {
    throw new Error('Parent session not found');
  }
  if (parent.projectPath !== projectPath) {
    throw new Error('Parent session not found in this project');
  }

  const [branchPoint] = await db
    .select()
    .from(messages)
    .where(eq(messages.id, fromMessageId));

  if (!branchPoint) {
    throw new Error('Branch point message not found');
  }
  if (branchPoint.sessionId !== parentSessionId) {
    throw new Error('Branch point message does not belong to parent session');
  }

  const allMessages = await db
    .select()
    .from(messages)
    .where(eq(messages.sessionId, parentSessionId))
    .orderBy(asc(messages.createdAt));

  const branchPointIndex = allMessages.findIndex((m) => m.id === fromMessageId);
  if (branchPointIndex === -1) {
    throw new Error('Branch point message not found in session');
  }

  // Everything up to and including the branch point is carried over.
  const messagesToCopy = allMessages.slice(0, branchPointIndex + 1);

  const newSessionId = crypto.randomUUID();
  const now = Date.now();

  const newSession: typeof sessions.$inferInsert = {
    id: newSessionId,
    title: title || `Branch of ${parent.title || 'Untitled'}`,
    agent: agent || parent.agent,
    provider: provider || parent.provider,
    model: model || parent.model,
    projectPath: parent.projectPath,
    createdAt: now,
    lastActiveAt: now,
    parentSessionId,
    branchPointMessageId: fromMessageId,
    sessionType: 'branch',
  };

  await db.insert(sessions).values(newSession);

  // Ids of the copies created so far, tracked so a failed copy can be undone.
  const copiedMessageIds: string[] = [];
  let copiedParts = 0;

  try {
    for (const msg of messagesToCopy) {
      const newMessageId = crypto.randomUUID();
      // Record the id before inserting so cleanup also covers a message whose
      // insert succeeded but whose parts failed.
      copiedMessageIds.push(newMessageId);

      const newMessage: typeof messages.$inferInsert = {
        id: newMessageId,
        sessionId: newSessionId,
        role: msg.role,
        status: msg.status,
        agent: msg.agent,
        provider: msg.provider,
        model: msg.model,
        createdAt: msg.createdAt,
        completedAt: msg.completedAt,
        latencyMs: msg.latencyMs,
        promptTokens: msg.promptTokens,
        completionTokens: msg.completionTokens,
        totalTokens: msg.totalTokens,
        cachedInputTokens: msg.cachedInputTokens,
        reasoningTokens: msg.reasoningTokens,
        error: msg.error,
        errorType: msg.errorType,
        errorDetails: msg.errorDetails,
        isAborted: msg.isAborted,
      };

      await db.insert(messages).values(newMessage);

      const parts = await db
        .select()
        .from(messageParts)
        .where(eq(messageParts.messageId, msg.id))
        .orderBy(asc(messageParts.index));

      for (const part of parts) {
        const newPart: typeof messageParts.$inferInsert = {
          id: crypto.randomUUID(),
          messageId: newMessageId,
          index: part.index,
          stepIndex: part.stepIndex,
          type: part.type,
          content: part.content,
          agent: part.agent,
          provider: part.provider,
          model: part.model,
          startedAt: part.startedAt,
          completedAt: part.completedAt,
          compactedAt: part.compactedAt,
          toolName: part.toolName,
          toolCallId: part.toolCallId,
          toolDurationMs: part.toolDurationMs,
        };

        await db.insert(messageParts).values(newPart);
        copiedParts++;
      }
    }
  } catch (err) {
    await discardPartialBranch(db, newSessionId, copiedMessageIds);
    throw err;
  }

  // Shape returned to callers and published: the inserted values plus the
  // aggregate columns, reported as null for a fresh session.
  const result: SessionRow = {
    ...newSession,
    totalInputTokens: null,
    totalOutputTokens: null,
    totalCachedTokens: null,
    totalReasoningTokens: null,
    totalToolTimeMs: null,
    toolCountsJson: null,
    contextSummary: null,
    lastCompactedAt: null,
  };

  publish({
    type: 'session.created',
    sessionId: newSessionId,
    payload: result,
  });

  return {
    session: result,
    parentSessionId,
    branchPointMessageId: fromMessageId,
    copiedMessages: messagesToCopy.length,
    copiedParts,
  };
}

/**
 * Removes a half-copied branch: parts first, then messages, then the session,
 * so the order is safe under foreign key constraints.
 */
async function discardPartialBranch(
  db: DB,
  sessionId: string,
  messageIds: readonly string[],
): Promise<void> {
  try {
    for (const messageId of messageIds) {
      await db.delete(messageParts).where(eq(messageParts.messageId, messageId));
    }
    await db.delete(messages).where(eq(messages.sessionId, sessionId));
    await db.delete(sessions).where(eq(sessions.id, sessionId));
  } catch {
    // Best effort only: the caller rethrows the original copy error, which is
    // the failure that matters; a cleanup error must not mask it.
  }
}
187
+
188
/** One entry per branch session returned by `listBranches`. */
export type ListBranchesResult = Array<{
  session: SessionRow;
  branchPointMessageId: string | null;
  branchPointPreview: string | null;
  createdAt: number;
}>;

/**
 * Lists the sessions branched from `sessionId`, oldest first, restricted to
 * sessions whose `projectPath` equals `projectPath`.
 *
 * Each entry carries a preview of the branch point message: the first 100
 * characters of its first text part, or `null` when none can be read.
 */
export async function listBranches(
  db: DB,
  sessionId: string,
  projectPath: string,
): Promise<ListBranchesResult> {
  const children = await db
    .select()
    .from(sessions)
    .where(eq(sessions.parentSessionId, sessionId))
    .orderBy(asc(sessions.createdAt));

  const results: ListBranchesResult = [];

  for (const branch of children.filter((s) => s.projectPath === projectPath)) {
    const branchPointPreview = branch.branchPointMessageId
      ? await loadBranchPointPreview(db, branch.branchPointMessageId)
      : null;

    results.push({
      session: branch,
      branchPointMessageId: branch.branchPointMessageId,
      branchPointPreview,
      createdAt: branch.createdAt,
    });
  }

  return results;
}

/**
 * Returns up to 100 characters of the first readable text part of a message,
 * or `null` when the message does not exist or has no such part.
 */
async function loadBranchPointPreview(
  db: DB,
  messageId: string,
): Promise<string | null> {
  const found = await db
    .select()
    .from(messages)
    .where(eq(messages.id, messageId));
  if (found.length === 0) return null;

  const parts = await db
    .select()
    .from(messageParts)
    .where(eq(messageParts.messageId, messageId))
    .orderBy(asc(messageParts.index));

  for (const part of parts) {
    if (part.type !== 'text') continue;
    try {
      const content = JSON.parse(part.content || '{}');
      if (content.text) {
        return content.text.slice(0, 100);
      }
    } catch {
      // Unparsable or unexpected content: skip it and try the next part.
    }
  }

  return null;
}
250
+
251
/**
 * Resolves the parent of a branch session.
 *
 * Returns `null` when the session does not exist, has no `parentSessionId`,
 * the parent row is missing, or the parent belongs to a different project
 * than `projectPath`.
 */
export async function getParentSession(
  db: DB,
  sessionId: string,
  projectPath: string,
): Promise<SessionRow | null> {
  const [session] = await db
    .select()
    .from(sessions)
    .where(eq(sessions.id, sessionId));

  if (!session || !session.parentSessionId) return null;

  const [parent] = await db
    .select()
    .from(sessions)
    .where(eq(sessions.id, session.parentSessionId));

  if (!parent) return null;

  return parent.projectPath === projectPath ? parent : null;
}
@@ -437,6 +437,7 @@ export async function resolveModel(
437
437
  provider: ProviderName,
438
438
  model: string,
439
439
  cfg: AGIConfig,
440
+ options?: { systemPrompt?: string },
440
441
  ) {
441
442
  if (provider === 'openai') {
442
443
  const auth = await getAuth('openai', cfg.projectRoot);
@@ -447,6 +448,7 @@ export async function resolveModel(
447
448
  projectRoot: cfg.projectRoot,
448
449
  reasoningEffort: isCodexModel ? 'high' : 'medium',
449
450
  reasoningSummary: 'auto',
451
+ instructions: options?.systemPrompt,
450
452
  });
451
453
  }
452
454
  if (auth?.type === 'api' && auth.key) {
@@ -34,8 +34,14 @@ import {
34
34
  } from './stream-handlers.ts';
35
35
  import { getCompactionSystemPrompt, pruneSession } from './compaction.ts';
36
36
 
37
- export { enqueueAssistantRun, abortSession } from './session-queue.ts';
38
- export { getRunnerState } from './session-queue.ts';
37
+ export {
38
+ enqueueAssistantRun,
39
+ abortSession,
40
+ abortMessage,
41
+ removeFromQueue,
42
+ getQueueState,
43
+ getRunnerState,
44
+ } from './session-queue.ts';
39
45
 
40
46
  /**
41
47
  * Main loop that processes the queue for a given session.
@@ -253,7 +259,14 @@ async function runAssistant(opts: RunOpts) {
253
259
  );
254
260
  }
255
261
 
256
- const model = await resolveModel(opts.provider, opts.model, cfg);
262
+ // For OpenAI OAuth, pass the full system prompt as instructions
263
+ const oauthSystemPrompt =
264
+ needsSpoof && opts.provider === 'openai' && additionalSystemMessages[0]
265
+ ? additionalSystemMessages[0].content
266
+ : undefined;
267
+ const model = await resolveModel(opts.provider, opts.model, cfg, {
268
+ systemPrompt: oauthSystemPrompt,
269
+ });
257
270
  debugLog(
258
271
  `[RUNNER] Model created: ${JSON.stringify({ id: model.modelId, provider: model.provider })}`,
259
272
  );
@@ -1,4 +1,5 @@
1
1
  import type { ProviderName } from './provider.ts';
2
+ import { publish } from '../events/bus.ts';
2
3
 
3
4
  export type RunOpts = {
4
5
  sessionId: string;
@@ -15,45 +16,190 @@ export type RunOpts = {
15
16
  compactionContext?: string;
16
17
  };
17
18
 
18
- type RunnerState = { queue: RunOpts[]; running: boolean };
19
/** A pending assistant message and its zero-based position in the queue. */
export type QueuedMessage = {
  messageId: string;
  position: number;
};

/** Bookkeeping kept for each session that has, or had, queued runs. */
type RunnerState = {
  /** Runs waiting to be processed, in order. */
  queue: RunOpts[];
  /** Whether the session's queue is currently being processed. */
  running: boolean;
  /** Assistant message id of the most recently dequeued run, if any. */
  currentMessageId: string | null;
};

// Global state for session queues, keyed by session id
const runners = new Map<string, RunnerState>();

// Track active abort controllers per MESSAGE (not session),
// keyed by assistant message id
const messageAbortControllers = new Map<string, AbortController>();
35
+
36
/**
 * Publishes a `queue.updated` event describing the session's current message
 * and queued messages. Does nothing when the session has no runner state.
 */
function publishQueueState(sessionId: string) {
  const state = runners.get(sessionId);
  if (!state) return;

  const queuedMessages: QueuedMessage[] = [];
  state.queue.forEach((job, position) => {
    queuedMessages.push({ messageId: job.assistantMessageId, position });
  });

  const { currentMessageId } = state;
  publish({
    type: 'queue.updated',
    sessionId,
    payload: {
      currentMessageId,
      queuedMessages,
      queueLength: state.queue.length,
    },
  });
}
25
55
 
26
56
  /**
27
57
  * Enqueues an assistant run for a given session.
28
- * Creates an abort controller for the session if one doesn't exist.
58
+ * Creates an abort controller per message.
29
59
  */
30
60
  export function enqueueAssistantRun(
31
61
  opts: Omit<RunOpts, 'abortSignal'>,
32
62
  processQueueFn: (sessionId: string) => Promise<void>,
33
63
  ) {
34
- // Create abort controller for this session
35
64
  const abortController = new AbortController();
36
- sessionAbortControllers.set(opts.sessionId, abortController);
65
+ messageAbortControllers.set(opts.assistantMessageId, abortController);
37
66
 
38
- const state = runners.get(opts.sessionId) ?? { queue: [], running: false };
67
+ const state = runners.get(opts.sessionId) ?? {
68
+ queue: [],
69
+ running: false,
70
+ currentMessageId: null,
71
+ };
39
72
  state.queue.push({ ...opts, abortSignal: abortController.signal });
40
73
  runners.set(opts.sessionId, state);
74
+
75
+ publishQueueState(opts.sessionId);
76
+
41
77
  if (!state.running) void processQueueFn(opts.sessionId);
42
78
  }
43
79
 
44
80
  /**
45
- * Signals the abort controller for a session.
46
- * This will trigger the abortSignal in the streamText call.
81
+ * Aborts the currently running message for a session.
82
+ * Optionally clears the queue.
83
+ */
84
+ export function abortSession(sessionId: string, clearQueue = false) {
85
+ const state = runners.get(sessionId);
86
+ if (!state) return;
87
+
88
+ // Abort the currently running message
89
+ if (state.currentMessageId) {
90
+ const controller = messageAbortControllers.get(state.currentMessageId);
91
+ if (controller) {
92
+ controller.abort();
93
+ messageAbortControllers.delete(state.currentMessageId);
94
+ }
95
+ }
96
+
97
+ // Optionally clear the queue and abort all queued messages
98
+ if (clearQueue && state.queue.length > 0) {
99
+ for (const opts of state.queue) {
100
+ const controller = messageAbortControllers.get(opts.assistantMessageId);
101
+ if (controller) {
102
+ controller.abort();
103
+ messageAbortControllers.delete(opts.assistantMessageId);
104
+ }
105
+ }
106
+ state.queue = [];
107
+ publishQueueState(sessionId);
108
+ }
109
+ }
110
+
111
/**
 * Aborts one message of a session.
 *
 * - Currently running: its abort controller is signalled.
 * - Still queued: it is removed from the queue, its controller is signalled
 *   and the new queue state is published.
 *
 * @returns `removed` tells whether the message was found; `wasRunning` tells
 *   whether it was the session's current message.
 */
export function abortMessage(
  sessionId: string,
  messageId: string,
): { removed: boolean; wasRunning: boolean } {
  const state = runners.get(sessionId);
  if (!state) return { removed: false, wasRunning: false };

  const wasRunning = state.currentMessageId === messageId;
  const queuePosition = wasRunning
    ? -1
    : state.queue.findIndex((job) => job.assistantMessageId === messageId);

  if (!wasRunning && queuePosition === -1) {
    return { removed: false, wasRunning: false };
  }

  if (!wasRunning) {
    state.queue.splice(queuePosition, 1);
  }

  const controller = messageAbortControllers.get(messageId);
  if (controller) {
    controller.abort();
    messageAbortControllers.delete(messageId);
  }

  // Only a queue change is published; aborting the running message leaves the
  // queue as it was.
  if (!wasRunning) {
    publishQueueState(sessionId);
  }

  return { removed: true, wasRunning };
}
150
+
151
/**
 * Removes a message that is still waiting in the session's queue.
 *
 * The currently running message is never removed by this function. On removal
 * the message's abort controller is signalled and discarded, and the new queue
 * state is published.
 *
 * @returns `true` when a queued message was removed, `false` otherwise.
 */
export function removeFromQueue(sessionId: string, messageId: string): boolean {
  const state = runners.get(sessionId);
  if (!state || state.currentMessageId === messageId) {
    return false;
  }

  const position = state.queue.findIndex(
    (job) => job.assistantMessageId === messageId,
  );
  if (position < 0) return false;

  state.queue.splice(position, 1);

  const controller = messageAbortControllers.get(messageId);
  if (controller) {
    controller.abort();
    messageAbortControllers.delete(messageId);
  }

  publishQueueState(sessionId);
  return true;
}
55
178
 
56
- export function getRunnerState(sessionId: string): RunnerState | undefined {
179
/**
 * Returns a snapshot of a session's queue: the current message id, the queued
 * messages with their positions, and whether the session is running.
 * Returns `null` when the session has no runner state.
 */
export function getQueueState(sessionId: string): {
  currentMessageId: string | null;
  queuedMessages: QueuedMessage[];
  isRunning: boolean;
} | null {
  const state = runners.get(sessionId);
  if (!state) return null;

  const queuedMessages = state.queue.map(({ assistantMessageId }, position) => ({
    messageId: assistantMessageId,
    position,
  }));

  return {
    currentMessageId: state.currentMessageId,
    queuedMessages,
    isRunning: state.running,
  };
}
199
+
200
/**
 * Returns the live runner state stored for a session (not a copy), or
 * `undefined` when the session has none.
 */
export function getRunnerState(
  sessionId: string,
): { queue: RunOpts[]; running: boolean } | undefined {
  const state = runners.get(sessionId);
  return state;
}
59
205
 
@@ -62,16 +208,33 @@ export function setRunning(sessionId: string, running: boolean) {
62
208
  if (state) state.running = running;
63
209
  }
64
210
 
211
/**
 * Records which message the session is currently processing (or `null` for
 * none) and publishes the updated queue state. Unknown sessions are ignored.
 */
export function setCurrentMessage(sessionId: string, messageId: string | null) {
  const state = runners.get(sessionId);
  if (!state) return;

  state.currentMessageId = messageId;
  publishQueueState(sessionId);
}
218
+
65
219
  export function dequeueJob(sessionId: string): RunOpts | undefined {
66
220
  const state = runners.get(sessionId);
67
- return state?.queue.shift();
221
+ const job = state?.queue.shift();
222
+ if (job && state) {
223
+ state.currentMessageId = job.assistantMessageId;
224
+ publishQueueState(sessionId);
225
+ }
226
+ return job;
68
227
  }
69
228
 
70
229
  export function cleanupSession(sessionId: string) {
71
230
  const state = runners.get(sessionId);
72
231
  if (state && state.queue.length === 0 && !state.running) {
232
+ // Clean up any lingering abort controller for current message
233
+ if (state.currentMessageId) {
234
+ messageAbortControllers.delete(state.currentMessageId);
235
+ }
236
+ state.currentMessageId = null;
73
237
  runners.delete(sessionId);
74
- // Clean up any lingering abort controller
75
- sessionAbortControllers.delete(sessionId);
238
+ publishQueueState(sessionId);
76
239
  }
77
240
  }
@@ -191,6 +191,7 @@ export function adaptTools(
191
191
  delta,
192
192
  stepIndex: meta?.stepIndex ?? ctx.stepIndex,
193
193
  callId: meta?.callId,
194
+ messageId: ctx.messageId,
194
195
  },
195
196
  });
196
197
  if (typeof base.onInputDelta === 'function')
@@ -235,6 +236,7 @@ export function adaptTools(
235
236
  args,
236
237
  callId,
237
238
  stepIndex: ctx.stepIndex,
239
+ messageId: ctx.messageId,
238
240
  },
239
241
  });
240
242
  // Persist synchronously to maintain correct ordering
@@ -266,7 +268,13 @@ export function adaptTools(
266
268
  publish({
267
269
  type: 'tool.call',
268
270
  sessionId: ctx.sessionId,
269
- payload: { name, args, callId, stepIndex: ctx.stepIndex },
271
+ payload: {
272
+ name,
273
+ args,
274
+ callId,
275
+ stepIndex: ctx.stepIndex,
276
+ messageId: ctx.messageId,
277
+ },
270
278
  });
271
279
  // Persist synchronously to maintain correct ordering
272
280
  try {
@@ -373,6 +381,7 @@ export function adaptTools(
373
381
  delta: chunk,
374
382
  stepIndex: stepIndexForEvent,
375
383
  callId: callIdFromQueue,
384
+ messageId: ctx.messageId,
376
385
  },
377
386
  });
378
387
  }