@oh-my-pi/pi-agent-core 1.337.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,409 @@
1
+ /**
2
+ * Agent loop that works with AgentMessage throughout.
3
+ * Transforms to Message[] only at the LLM call boundary.
4
+ */
5
+
6
+ import {
7
+ type AssistantMessage,
8
+ type Context,
9
+ EventStream,
10
+ streamSimple,
11
+ type ToolResultMessage,
12
+ validateToolArguments,
13
+ } from "@oh-my-pi/pi-ai";
14
+ import type {
15
+ AgentContext,
16
+ AgentEvent,
17
+ AgentLoopConfig,
18
+ AgentMessage,
19
+ AgentTool,
20
+ AgentToolResult,
21
+ StreamFn,
22
+ } from "./types.js";
23
+
24
/**
 * Start an agent loop with one or more new prompt messages.
 *
 * The prompts are appended to a copy of `context.messages` (the caller's array is not
 * mutated), and a `message_start`/`message_end` pair is emitted for each prompt before
 * the first assistant turn runs.
 *
 * The loop runs in the background. The returned stream yields events as they happen and
 * resolves with every message produced during the run (prompts included). If the loop
 * throws (for example because `convertToLlm`, `transformContext` or `getApiKey` rejects),
 * the error is logged and the stream is still terminated with an `agent_end` event
 * carrying the messages gathered so far, so consumers never wait forever.
 *
 * @param prompts - Messages to add to the context before the first turn.
 * @param context - Starting context; its `messages` array is copied, not mutated.
 * @param config - Loop configuration (model, message conversion, queue/tool hooks).
 * @param signal - Optional abort signal forwarded to the LLM call and to tools.
 * @param streamFn - Optional replacement for `streamSimple`.
 * @returns An event stream whose final result is the list of new messages.
 */
export function agentLoop(
	prompts: AgentMessage[],
	context: AgentContext,
	config: AgentLoopConfig,
	signal?: AbortSignal,
	streamFn?: StreamFn,
): EventStream<AgentEvent, AgentMessage[]> {
	const stream = createAgentStream();
	const newMessages: AgentMessage[] = [...prompts];
	const currentContext: AgentContext = {
		...context,
		messages: [...context.messages, ...prompts],
	};

	void (async () => {
		stream.push({ type: "agent_start" });
		stream.push({ type: "turn_start" });
		for (const prompt of prompts) {
			stream.push({ type: "message_start", message: prompt });
			stream.push({ type: "message_end", message: prompt });
		}

		await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
	})().catch((error: unknown) => {
		// Without this handler the rejection would be unhandled and the stream would never
		// end, leaving anyone iterating it (or awaiting its result) hanging forever.
		console.error("agentLoop failed:", error);
		stream.push({ type: "agent_end", messages: newMessages });
		stream.end(newMessages);
	});

	return stream;
}
56
+
57
/**
 * Continue an agent loop from the current context without adding a new message.
 * Used for retries - context already has user message or tool results.
 *
 * **Important:** The last message in context must convert to a `user` or `toolResult` message
 * via `convertToLlm`. If it doesn't, the LLM provider will reject the request.
 * This cannot be validated here since `convertToLlm` is only called once per turn.
 *
 * Like `agentLoop`, this works on a copy of `context.messages`, so the assistant messages
 * and tool results produced by the run are never pushed into the caller's array.
 * If the background loop throws, the error is logged and the stream is terminated with an
 * `agent_end` event carrying the messages gathered so far.
 *
 * @param context - Context to continue from; must be non-empty and not end with an assistant message.
 * @param config - Loop configuration (model, message conversion, queue/tool hooks).
 * @param signal - Optional abort signal forwarded to the LLM call and to tools.
 * @param streamFn - Optional replacement for `streamSimple`.
 * @returns An event stream whose final result is the list of new messages.
 * @throws Error if `context.messages` is empty or its last message has role `assistant`.
 */
export function agentLoopContinue(
	context: AgentContext,
	config: AgentLoopConfig,
	signal?: AbortSignal,
	streamFn?: StreamFn,
): EventStream<AgentEvent, AgentMessage[]> {
	if (context.messages.length === 0) {
		throw new Error("Cannot continue: no messages in context");
	}

	if (context.messages[context.messages.length - 1].role === "assistant") {
		throw new Error("Cannot continue from message role: assistant");
	}

	const stream = createAgentStream();
	const newMessages: AgentMessage[] = [];
	// Copy the array as well: a plain `{ ...context }` would share `messages` with the caller,
	// and the loop would then append assistant messages and tool results to it.
	const currentContext: AgentContext = { ...context, messages: [...context.messages] };

	void (async () => {
		stream.push({ type: "agent_start" });
		stream.push({ type: "turn_start" });

		await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
	})().catch((error: unknown) => {
		// Without this handler the rejection would be unhandled and the stream would never
		// end, leaving anyone iterating it (or awaiting its result) hanging forever.
		console.error("agentLoopContinue failed:", error);
		stream.push({ type: "agent_end", messages: newMessages });
		stream.end(newMessages);
	});

	return stream;
}
93
+
94
/**
 * Build the event stream shared by both loop entry points.
 *
 * The stream treats `agent_end` as its terminal event and resolves to that event's
 * `messages`; any other event would yield an empty list as the result.
 */
function createAgentStream(): EventStream<AgentEvent, AgentMessage[]> {
	const isTerminal = (event: AgentEvent): boolean => event.type === "agent_end";

	const extractMessages = (event: AgentEvent): AgentMessage[] => {
		if (event.type !== "agent_end") {
			return [];
		}
		return event.messages;
	};

	return new EventStream<AgentEvent, AgentMessage[]>(isTerminal, extractMessages);
}
100
+
101
/**
 * Turn loop shared by `agentLoop` and `agentLoopContinue`.
 *
 * The caller has already emitted the first `turn_start`; every later turn emits its own.
 * Each turn:
 * 1. Injects pending queued messages, appending them to the context and `newMessages`.
 * 2. Streams an assistant response; `streamAssistantResponse` records it in the context,
 *    and it is appended to `newMessages` here.
 * 3. Runs the response's tool calls, if any, appending the results to both lists.
 * 4. Emits `turn_end`.
 *
 * Turns continue while the last assistant message requested tools or queued messages are
 * waiting. Messages that interrupted tool execution are used first; otherwise
 * `config.getQueuedMessages` is polled.
 *
 * An assistant stop reason of `error` or `aborted` ends the run immediately. Every exit
 * emits `agent_end` with `newMessages` and closes the stream with them.
 */
async function runLoop(
	currentContext: AgentContext,
	newMessages: AgentMessage[],
	config: AgentLoopConfig,
	signal: AbortSignal | undefined,
	stream: EventStream<AgentEvent, AgentMessage[]>,
	streamFn?: StreamFn,
): Promise<void> {
	let continueForTools = true;
	// The entry points already pushed the first `turn_start`.
	let turnStartPending = false;
	let pending: AgentMessage[] = (await config.getQueuedMessages?.()) || [];

	const finish = (): void => {
		stream.push({ type: "agent_end", messages: newMessages });
		stream.end(newMessages);
	};

	while (continueForTools || pending.length > 0) {
		if (turnStartPending) {
			stream.push({ type: "turn_start" });
		}
		turnStartPending = true;

		// Inject queued messages ahead of the next assistant response.
		for (const queued of pending) {
			stream.push({ type: "message_start", message: queued });
			stream.push({ type: "message_end", message: queued });
			currentContext.messages.push(queued);
			newMessages.push(queued);
		}
		pending = [];

		const assistant = await streamAssistantResponse(currentContext, config, signal, stream, streamFn);
		newMessages.push(assistant);

		const stoppedEarly = assistant.stopReason === "error" || assistant.stopReason === "aborted";
		if (stoppedEarly) {
			stream.push({ type: "turn_end", message: assistant, toolResults: [] });
			finish();
			return;
		}

		continueForTools = assistant.content.some((block) => block.type === "toolCall");

		const toolResults: ToolResultMessage[] = [];
		let interruptQueue: AgentMessage[] | undefined;
		if (continueForTools) {
			const execution = await executeToolCalls(
				currentContext.tools,
				assistant,
				signal,
				stream,
				config.getQueuedMessages,
				config.getToolContext,
			);
			interruptQueue = execution.queuedMessages;
			for (const toolResult of execution.toolResults) {
				toolResults.push(toolResult);
				currentContext.messages.push(toolResult);
				newMessages.push(toolResult);
			}
		}

		stream.push({ type: "turn_end", message: assistant, toolResults });

		// Messages that interrupted tool execution take priority; otherwise poll the queue.
		pending =
			interruptQueue && interruptQueue.length > 0
				? interruptQueue
				: (await config.getQueuedMessages?.()) || [];
	}

	finish();
}
183
+
184
/**
 * Stream one assistant response from the LLM and mirror it into the context and the event stream.
 *
 * This is where `AgentMessage[]` becomes LLM `Message[]`: `config.transformContext`
 * (if set) runs first, then `config.convertToLlm`. The API key is resolved per call via
 * `config.getApiKey` (falling back to `config.apiKey`) so expiring tokens are refreshed.
 *
 * While streaming, the in-progress assistant message is kept as the last entry of
 * `context.messages`: it is pushed on `start` and replaced on every content event, and each
 * change is emitted as `message_update`.
 *
 * When the response finishes, the final message replaces that partial (or is appended, with a
 * `message_start`, if no `start` event arrived) and `message_end` is emitted. This also
 * happens when the provider stream ends without a `done`/`error` event, so the context, the
 * event stream and the returned message always agree.
 *
 * @returns The final assistant message from `response.result()`.
 */
async function streamAssistantResponse(
	context: AgentContext,
	config: AgentLoopConfig,
	signal: AbortSignal | undefined,
	stream: EventStream<AgentEvent, AgentMessage[]>,
	streamFn?: StreamFn,
): Promise<AssistantMessage> {
	// Apply context transform if configured (AgentMessage[] → AgentMessage[])
	let messages = context.messages;
	if (config.transformContext) {
		messages = await config.transformContext(messages, signal);
	}

	// Convert to LLM-compatible messages (AgentMessage[] → Message[])
	const llmMessages = await config.convertToLlm(messages);

	const llmContext: Context = {
		systemPrompt: context.systemPrompt,
		messages: llmMessages,
		tools: context.tools,
	};

	const streamFunction = streamFn || streamSimple;

	// Resolve API key on every call (important for expiring tokens)
	const resolvedApiKey =
		(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || config.apiKey;

	const response = await streamFunction(config.model, llmContext, {
		...config,
		apiKey: resolvedApiKey,
		signal,
	});

	let partialMessage: AssistantMessage | null = null;
	let addedPartial = false;

	// Store the final message in the context (replacing the streamed partial when one was
	// added) and close it out on the event stream.
	const finalize = async (): Promise<AssistantMessage> => {
		const finalMessage = await response.result();
		if (addedPartial) {
			context.messages[context.messages.length - 1] = finalMessage;
		} else {
			context.messages.push(finalMessage);
			stream.push({ type: "message_start", message: { ...finalMessage } });
		}
		stream.push({ type: "message_end", message: finalMessage });
		return finalMessage;
	};

	for await (const event of response) {
		switch (event.type) {
			case "start":
				partialMessage = event.partial;
				context.messages.push(partialMessage);
				addedPartial = true;
				stream.push({ type: "message_start", message: { ...partialMessage } });
				break;

			case "text_start":
			case "text_delta":
			case "text_end":
			case "thinking_start":
			case "thinking_delta":
			case "thinking_end":
			case "toolcall_start":
			case "toolcall_delta":
			case "toolcall_end":
				if (partialMessage) {
					partialMessage = event.partial;
					context.messages[context.messages.length - 1] = partialMessage;
					stream.push({
						type: "message_update",
						assistantMessageEvent: event,
						message: { ...partialMessage },
					});
				}
				break;

			case "done":
			case "error":
				// Await here so finalization completes before the iterator is closed.
				return await finalize();
		}
	}

	// The provider stream ended without a terminal event: still record the message and
	// emit `message_end` so consumers and the context are not left with a stale partial.
	return await finalize();
}
274
+
275
/**
 * Run the tool calls of an assistant message one at a time, in order.
 *
 * For every call this emits `tool_execution_start`, any `tool_execution_update`s the tool
 * reports, `tool_execution_end`, and a `message_start`/`message_end` pair for the resulting
 * `toolResult` message. A missing tool, an argument-validation failure, or a thrown error
 * becomes an error result (`isError: true`) whose text is the error message.
 *
 * After each call, `getQueuedMessages` (if provided) is polled. As soon as it returns
 * messages, every remaining call is recorded as skipped and those messages are returned
 * so the caller can inject them; otherwise `queuedMessages` is `undefined`.
 */
async function executeToolCalls(
	tools: AgentTool<any>[] | undefined,
	assistantMessage: AssistantMessage,
	signal: AbortSignal | undefined,
	stream: EventStream<AgentEvent, AgentMessage[]>,
	getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
	getToolContext?: AgentLoopConfig["getToolContext"],
): Promise<{ toolResults: ToolResultMessage[]; queuedMessages?: AgentMessage[] }> {
	const calls = assistantMessage.content.filter((c) => c.type === "toolCall");
	const toolResults: ToolResultMessage[] = [];

	for (const [position, call] of calls.entries()) {
		const tool = tools?.find((candidate) => candidate.name === call.name);

		stream.push({
			type: "tool_execution_start",
			toolCallId: call.id,
			toolName: call.name,
			args: call.arguments,
		});

		let outcome: AgentToolResult<any>;
		let failed = false;
		try {
			if (!tool) {
				throw new Error(`Tool ${call.name} not found`);
			}

			const checkedArgs = validateToolArguments(tool, call);
			const toolContext = getToolContext?.();
			const onUpdate = (partialResult: Parameters<Parameters<typeof tool.execute>[3] & ((...a: any[]) => any)>[0]) => {
				stream.push({
					type: "tool_execution_update",
					toolCallId: call.id,
					toolName: call.name,
					args: call.arguments,
					partialResult,
				});
			};
			outcome = await tool.execute(call.id, checkedArgs, signal, onUpdate, toolContext);
		} catch (err) {
			const text = err instanceof Error ? err.message : String(err);
			outcome = { content: [{ type: "text", text }], details: {} };
			failed = true;
		}

		stream.push({
			type: "tool_execution_end",
			toolCallId: call.id,
			toolName: call.name,
			result: outcome,
			isError: failed,
		});

		const resultMessage: ToolResultMessage = {
			role: "toolResult",
			toolCallId: call.id,
			toolName: call.name,
			content: outcome.content,
			details: outcome.details,
			isError: failed,
			timestamp: Date.now(),
		};

		toolResults.push(resultMessage);
		stream.push({ type: "message_start", message: resultMessage });
		stream.push({ type: "message_end", message: resultMessage });

		// A queued user message interrupts the batch: skip whatever calls remain.
		if (!getQueuedMessages) {
			continue;
		}
		const queued = await getQueuedMessages();
		if (queued.length === 0) {
			continue;
		}
		for (const skipped of calls.slice(position + 1)) {
			toolResults.push(skipToolCall(skipped, stream));
		}
		return { toolResults, queuedMessages: queued };
	}

	return { toolResults, queuedMessages: undefined };
}
371
+
372
/**
 * Record a tool call that will not run because a queued user message interrupted the batch.
 *
 * Emits the same event sequence a real execution would: `tool_execution_start`, then
 * `tool_execution_end` with `isError: true`, then `message_start`/`message_end`. Returns the
 * error `toolResult` message so the conversation still has a result for this call.
 */
function skipToolCall(
	toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
	stream: EventStream<AgentEvent, AgentMessage[]>,
): ToolResultMessage {
	const identity = { toolCallId: toolCall.id, toolName: toolCall.name };
	const skippedResult: AgentToolResult<any> = {
		content: [{ type: "text", text: "Skipped due to queued user message." }],
		details: {},
	};

	stream.push({ type: "tool_execution_start", ...identity, args: toolCall.arguments });
	stream.push({ type: "tool_execution_end", ...identity, result: skippedResult, isError: true });

	const skippedMessage: ToolResultMessage = {
		role: "toolResult",
		...identity,
		content: skippedResult.content,
		details: {},
		isError: true,
		timestamp: Date.now(),
	};

	stream.push({ type: "message_start", message: skippedMessage });
	stream.push({ type: "message_end", message: skippedMessage });

	return skippedMessage;
}