@strav/brain 1.0.0-alpha.17 → 1.0.0-alpha.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +4 -2
- package/src/agent_generate_result.ts +2 -0
- package/src/agent_result.ts +7 -0
- package/src/agent_runner.ts +80 -4
- package/src/brain_manager.ts +119 -2
- package/src/index.ts +20 -2
- package/src/mcp/client.ts +17 -0
- package/src/mcp/index.ts +1 -0
- package/src/mcp/pool.ts +106 -0
- package/src/mcp/resolve_mcp_tools.ts +25 -7
- package/src/persistence/brain_message.ts +34 -0
- package/src/persistence/brain_message_repository.ts +106 -0
- package/src/persistence/brain_store.ts +166 -0
- package/src/persistence/brain_suspended_run.ts +30 -0
- package/src/persistence/brain_suspended_run_repository.ts +68 -0
- package/src/persistence/brain_thread.ts +30 -0
- package/src/persistence/brain_thread_repository.ts +65 -0
- package/src/persistence/database_brain_store.ts +190 -0
- package/src/persistence/index.ts +48 -0
- package/src/persistence/schema/brain_message_schema.ts +61 -0
- package/src/persistence/schema/brain_suspended_run_schema.ts +58 -0
- package/src/persistence/schema/brain_thread_schema.ts +50 -0
- package/src/persistence/schema/index.ts +3 -0
- package/src/provider.ts +36 -1
- package/src/providers/anthropic_provider.ts +140 -23
- package/src/providers/gemini_provider.ts +55 -32
- package/src/providers/openai_compat_provider.ts +452 -23
- package/src/providers/openai_provider.ts +87 -32
- package/src/providers/openai_responses_provider.ts +365 -50
- package/src/suspended_run.ts +153 -0
- package/src/thread.ts +40 -1
- package/src/types.ts +110 -0
|
@@ -73,7 +73,12 @@ import type {
|
|
|
73
73
|
import { resolveMcpTools, type ResolveMcpToolsOptions } from '../mcp/resolve_mcp_tools.ts'
|
|
74
74
|
import { parseGenerated, type OutputSchema } from '../output_schema.ts'
|
|
75
75
|
import { runToolWithRecovery } from '../tool_runner.ts'
|
|
76
|
-
import type {
|
|
76
|
+
import type {
|
|
77
|
+
Provider,
|
|
78
|
+
RunWithToolsOptions,
|
|
79
|
+
RunWithToolsOptionsWithSuspend,
|
|
80
|
+
} from '../provider.ts'
|
|
81
|
+
import type { SuspendedRun } from '../suspended_run.ts'
|
|
77
82
|
import type { Tool } from '../tool.ts'
|
|
78
83
|
import type {
|
|
79
84
|
ChatOptions,
|
|
@@ -119,6 +124,8 @@ export interface GeminiProviderOptions {
|
|
|
119
124
|
client?: { models: GeminiModelsClient }
|
|
120
125
|
/** Internal seam — tests inject a stub MCP client factory. */
|
|
121
126
|
mcpClientFactory?: ResolveMcpToolsOptions['clientFactory']
|
|
127
|
+
/** See `OpenAIProviderOptions.mcpPool` — same semantics. */
|
|
128
|
+
mcpPool?: ResolveMcpToolsOptions['pool']
|
|
122
129
|
}
|
|
123
130
|
|
|
124
131
|
export class GeminiProvider implements Provider {
|
|
@@ -128,6 +135,7 @@ export class GeminiProvider implements Provider {
|
|
|
128
135
|
private readonly defaultMaxTokens: number
|
|
129
136
|
private readonly defaultEmbedModel: string
|
|
130
137
|
private readonly mcpClientFactory?: ResolveMcpToolsOptions['clientFactory']
|
|
138
|
+
private readonly mcpPool?: ResolveMcpToolsOptions['pool']
|
|
131
139
|
|
|
132
140
|
constructor(name: string, config: GeminiProviderConfig, options: GeminiProviderOptions = {}) {
|
|
133
141
|
this.name = name
|
|
@@ -135,6 +143,7 @@ export class GeminiProvider implements Provider {
|
|
|
135
143
|
this.defaultMaxTokens = config.defaultMaxTokens ?? 4096
|
|
136
144
|
this.defaultEmbedModel = config.defaultEmbedModel ?? DEFAULT_GEMINI_EMBED_MODEL
|
|
137
145
|
this.mcpClientFactory = options.mcpClientFactory
|
|
146
|
+
this.mcpPool = options.mcpPool
|
|
138
147
|
if (options.client) {
|
|
139
148
|
this.models = options.client.models
|
|
140
149
|
} else {
|
|
@@ -273,18 +282,42 @@ export class GeminiProvider implements Provider {
|
|
|
273
282
|
}
|
|
274
283
|
}
|
|
275
284
|
|
|
285
|
+
/**
|
|
286
|
+
* Resolve MCP tool descriptors for `servers`, threading the
|
|
287
|
+
* provider's optional `clientFactory` (test seam) and `mcpPool`
|
|
288
|
+
* (long-lived connections) through. Caller invokes
|
|
289
|
+
* `resolved.close()` in `finally` — a no-op when the pool owns
|
|
290
|
+
* lifetimes.
|
|
291
|
+
*/
|
|
292
|
+
private resolveMcp(servers: readonly MCPServer[]): Promise<{
|
|
293
|
+
tools: Tool[]
|
|
294
|
+
close: () => Promise<void>
|
|
295
|
+
}> {
|
|
296
|
+
if (servers.length === 0) {
|
|
297
|
+
return Promise.resolve({ tools: [], close: async () => {} })
|
|
298
|
+
}
|
|
299
|
+
return resolveMcpTools(servers, {
|
|
300
|
+
...(this.mcpClientFactory ? { clientFactory: this.mcpClientFactory } : {}),
|
|
301
|
+
...(this.mcpPool ? { pool: this.mcpPool } : {}),
|
|
302
|
+
})
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
runWithTools(
|
|
306
|
+
messages: readonly Message[],
|
|
307
|
+
tools: readonly Tool[],
|
|
308
|
+
options: RunWithToolsOptionsWithSuspend,
|
|
309
|
+
): Promise<AgentResult | SuspendedRun>
|
|
310
|
+
runWithTools(
|
|
311
|
+
messages: readonly Message[],
|
|
312
|
+
tools: readonly Tool[],
|
|
313
|
+
options?: RunWithToolsOptions,
|
|
314
|
+
): Promise<AgentResult>
|
|
276
315
|
async runWithTools(
|
|
277
316
|
messages: readonly Message[],
|
|
278
317
|
tools: readonly Tool[],
|
|
279
318
|
options: RunWithToolsOptions = {},
|
|
280
|
-
): Promise<AgentResult> {
|
|
281
|
-
const
|
|
282
|
-
const resolved =
|
|
283
|
-
mcpServers.length > 0
|
|
284
|
-
? await resolveMcpTools(mcpServers, {
|
|
285
|
-
...(this.mcpClientFactory ? { clientFactory: this.mcpClientFactory } : {}),
|
|
286
|
-
})
|
|
287
|
-
: { tools: [] as Tool[], close: async () => {} }
|
|
319
|
+
): Promise<AgentResult | SuspendedRun> {
|
|
320
|
+
const resolved = await this.resolveMcp(options.mcpServers ?? [])
|
|
288
321
|
try {
|
|
289
322
|
return await this._runLoop(messages, [...tools, ...resolved.tools], options)
|
|
290
323
|
} finally {
|
|
@@ -296,7 +329,7 @@ export class GeminiProvider implements Provider {
|
|
|
296
329
|
messages: readonly Message[],
|
|
297
330
|
tools: readonly Tool[],
|
|
298
331
|
options: RunWithToolsOptions,
|
|
299
|
-
): Promise<AgentResult> {
|
|
332
|
+
): Promise<AgentResult | SuspendedRun> {
|
|
300
333
|
const maxIterations = options.maxIterations ?? 10
|
|
301
334
|
const toolMap = new Map<string, Tool>(tools.map((t) => [t.name, t]))
|
|
302
335
|
const workingMessages: Message[] = [...messages]
|
|
@@ -339,7 +372,15 @@ export class GeminiProvider implements Provider {
|
|
|
339
372
|
}
|
|
340
373
|
|
|
341
374
|
const resultBlocks: ContentBlock[] = []
|
|
342
|
-
for (
|
|
375
|
+
for (let i = 0; i < toolUses.length; i++) {
|
|
376
|
+
const call = toolUses[i]!
|
|
377
|
+
if (options.shouldSuspend && await options.shouldSuspend(call, options.context)) {
|
|
378
|
+
return {
|
|
379
|
+
status: 'suspended',
|
|
380
|
+
pendingToolCalls: toolUses.slice(i),
|
|
381
|
+
state: { messages: workingMessages, iterations, usage: aggregated },
|
|
382
|
+
}
|
|
383
|
+
}
|
|
343
384
|
const { content, isError } = await runToolWithRecovery(
|
|
344
385
|
toolMap.get(call.name),
|
|
345
386
|
call.name,
|
|
@@ -375,13 +416,7 @@ export class GeminiProvider implements Provider {
|
|
|
375
416
|
schema: OutputSchema<T>,
|
|
376
417
|
options: RunWithToolsOptions = {},
|
|
377
418
|
): Promise<AgentGenerateResult<T>> {
|
|
378
|
-
const
|
|
379
|
-
const resolved =
|
|
380
|
-
mcpServers.length > 0
|
|
381
|
-
? await resolveMcpTools(mcpServers, {
|
|
382
|
-
...(this.mcpClientFactory ? { clientFactory: this.mcpClientFactory } : {}),
|
|
383
|
-
})
|
|
384
|
-
: { tools: [] as Tool[], close: async () => {} }
|
|
419
|
+
const resolved = await this.resolveMcp(options.mcpServers ?? [])
|
|
385
420
|
try {
|
|
386
421
|
return await this._runLoopWithSchema([...tools, ...resolved.tools], messages, schema, options)
|
|
387
422
|
} finally {
|
|
@@ -480,13 +515,7 @@ export class GeminiProvider implements Provider {
|
|
|
480
515
|
tools: readonly Tool[],
|
|
481
516
|
options: RunWithToolsOptions = {},
|
|
482
517
|
): AsyncIterable<AgentStreamEvent> {
|
|
483
|
-
const
|
|
484
|
-
const resolved =
|
|
485
|
-
mcpServers.length > 0
|
|
486
|
-
? await resolveMcpTools(mcpServers, {
|
|
487
|
-
...(this.mcpClientFactory ? { clientFactory: this.mcpClientFactory } : {}),
|
|
488
|
-
})
|
|
489
|
-
: { tools: [] as Tool[], close: async () => {} }
|
|
518
|
+
const resolved = await this.resolveMcp(options.mcpServers ?? [])
|
|
490
519
|
try {
|
|
491
520
|
yield* this._streamLoop(messages, [...tools, ...resolved.tools], options)
|
|
492
521
|
} finally {
|
|
@@ -605,13 +634,7 @@ export class GeminiProvider implements Provider {
|
|
|
605
634
|
schema: OutputSchema<T>,
|
|
606
635
|
options: RunWithToolsOptions = {},
|
|
607
636
|
): AsyncIterable<AgentStreamEvent<T>> {
|
|
608
|
-
const
|
|
609
|
-
const resolved =
|
|
610
|
-
mcpServers.length > 0
|
|
611
|
-
? await resolveMcpTools(mcpServers, {
|
|
612
|
-
...(this.mcpClientFactory ? { clientFactory: this.mcpClientFactory } : {}),
|
|
613
|
-
})
|
|
614
|
-
: { tools: [] as Tool[], close: async () => {} }
|
|
637
|
+
const resolved = await this.resolveMcp(options.mcpServers ?? [])
|
|
615
638
|
try {
|
|
616
639
|
yield* this._streamLoopWithSchema(
|
|
617
640
|
[...tools, ...resolved.tools],
|
|
@@ -45,12 +45,17 @@ import { BrainError } from '../brain_error.ts'
|
|
|
45
45
|
import { parseGenerated, type OutputSchema } from '../output_schema.ts'
|
|
46
46
|
import type { RunWithToolsOptions } from '../provider.ts'
|
|
47
47
|
import type { Tool } from '../tool.ts'
|
|
48
|
+
import { recoverOrThrow, runToolWithRecovery } from '../tool_runner.ts'
|
|
49
|
+
import { ToolExecutionError } from '../tool_execution_error.ts'
|
|
48
50
|
import type {
|
|
49
51
|
ChatOptions,
|
|
50
52
|
ChatUsage,
|
|
53
|
+
ContentBlock,
|
|
51
54
|
GenerateResult,
|
|
52
55
|
Message,
|
|
53
56
|
SystemPrompt,
|
|
57
|
+
ToolResultBlock,
|
|
58
|
+
ToolUseBlock,
|
|
54
59
|
} from '../types.ts'
|
|
55
60
|
import { OpenAIProvider } from './openai_provider.ts'
|
|
56
61
|
|
|
@@ -116,36 +121,351 @@ export abstract class OpenAICompatProvider extends OpenAIProvider {
|
|
|
116
121
|
}
|
|
117
122
|
|
|
118
123
|
/**
|
|
119
|
-
* Combined tool-loop + structured output
|
|
120
|
-
* OpenAI-compat
|
|
121
|
-
*
|
|
122
|
-
*
|
|
123
|
-
* tool
|
|
124
|
-
* `
|
|
125
|
-
*
|
|
124
|
+
* Combined tool-loop + structured output via the **tool-forcing**
|
|
125
|
+
* pattern. OpenAI-compat endpoints don't support per-turn
|
|
126
|
+
* `json_schema` enforcement, but they do support OpenAI-style
|
|
127
|
+
* function calling — so the framework injects a synthetic
|
|
128
|
+
* `respond_with_<schemaName>` tool whose JSON-Schema
|
|
129
|
+
* `parameters` IS the desired output schema. The model uses it
|
|
130
|
+
* (and only it) for its final answer; the args become the
|
|
131
|
+
* parsed structured value. Regular tools work normally
|
|
132
|
+
* alongside.
|
|
133
|
+
*
|
|
134
|
+
* The model is prompted to call regular tools first, then
|
|
135
|
+
* `respond_with` exactly once when ready to answer. If it
|
|
136
|
+
* doesn't (returns plain text instead, or hits `maxIterations`),
|
|
137
|
+
* the framework throws `BrainError` — apps should reinforce the
|
|
138
|
+
* pattern via a clearer system prompt, or simplify the task.
|
|
139
|
+
*
|
|
140
|
+
* Caveats vs OpenAI's `strict: true`:
|
|
141
|
+
* - Smaller models may emit invalid JSON in the tool args.
|
|
142
|
+
* `parseGenerated` + the optional `schema.parse` hook catch
|
|
143
|
+
* it at the boundary.
|
|
144
|
+
* - Schema features beyond OpenAI function-calling's subset
|
|
145
|
+
* (recursive refs, advanced keywords) may not be honored.
|
|
146
|
+
* Stick to flat object schemas for best results.
|
|
126
147
|
*/
|
|
127
148
|
override async runWithToolsAndSchema<T>(
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
149
|
+
messages: readonly Message[],
|
|
150
|
+
tools: readonly Tool[],
|
|
151
|
+
schema: OutputSchema<T>,
|
|
152
|
+
options: RunWithToolsOptions = {},
|
|
132
153
|
): Promise<AgentGenerateResult<T>> {
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
154
|
+
const resolved = await this.resolveMcp(options.mcpServers ?? [])
|
|
155
|
+
try {
|
|
156
|
+
return await this._toolForcingLoop(
|
|
157
|
+
messages,
|
|
158
|
+
[...tools, ...resolved.tools],
|
|
159
|
+
schema,
|
|
160
|
+
options,
|
|
161
|
+
)
|
|
162
|
+
} finally {
|
|
163
|
+
await resolved.close()
|
|
164
|
+
}
|
|
137
165
|
}
|
|
138
166
|
|
|
139
167
|
override async *streamWithToolsAndSchema<T>(
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
168
|
+
messages: readonly Message[],
|
|
169
|
+
tools: readonly Tool[],
|
|
170
|
+
schema: OutputSchema<T>,
|
|
171
|
+
options: RunWithToolsOptions = {},
|
|
144
172
|
): AsyncIterable<AgentStreamEvent<T>> {
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
173
|
+
const resolved = await this.resolveMcp(options.mcpServers ?? [])
|
|
174
|
+
try {
|
|
175
|
+
yield* this._toolForcingStream(
|
|
176
|
+
messages,
|
|
177
|
+
[...tools, ...resolved.tools],
|
|
178
|
+
schema,
|
|
179
|
+
options,
|
|
180
|
+
)
|
|
181
|
+
} finally {
|
|
182
|
+
await resolved.close()
|
|
183
|
+
}
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
private async _toolForcingLoop<T>(
|
|
187
|
+
messages: readonly Message[],
|
|
188
|
+
tools: readonly Tool[],
|
|
189
|
+
schema: OutputSchema<T>,
|
|
190
|
+
options: RunWithToolsOptions,
|
|
191
|
+
): Promise<AgentGenerateResult<T>> {
|
|
192
|
+
const { respondTool, respondName, augmented } = prepareToolForcing(schema, options, tools)
|
|
193
|
+
const maxIterations = options.maxIterations ?? 10
|
|
194
|
+
const toolMap = new Map<string, Tool>(tools.map((t) => [t.name, t]))
|
|
195
|
+
const workingMessages: Message[] = [...messages]
|
|
196
|
+
const aggregated: ChatUsage = {
|
|
197
|
+
inputTokens: 0,
|
|
198
|
+
outputTokens: 0,
|
|
199
|
+
cacheReadTokens: 0,
|
|
200
|
+
cacheCreationTokens: 0,
|
|
201
|
+
}
|
|
202
|
+
let iterations = 0
|
|
203
|
+
|
|
204
|
+
while (true) {
|
|
205
|
+
checkAborted(options.signal)
|
|
206
|
+
const params = this.buildParams(workingMessages, augmented, tools)
|
|
207
|
+
params.tools = [...(params.tools ?? []), respondTool]
|
|
208
|
+
const response = await this.client.chat.completions.create(
|
|
209
|
+
params,
|
|
210
|
+
reqOpts(options),
|
|
211
|
+
)
|
|
212
|
+
addUsageHere(aggregated, response.usage, this)
|
|
213
|
+
|
|
214
|
+
const choice = response.choices[0]
|
|
215
|
+
if (!choice) {
|
|
216
|
+
throw new BrainError(
|
|
217
|
+
`${this.name}.runWithToolsAndSchema: response had no choices.`,
|
|
218
|
+
)
|
|
219
|
+
}
|
|
220
|
+
const assistantMessage = choice.message
|
|
221
|
+
workingMessages.push({
|
|
222
|
+
role: 'assistant',
|
|
223
|
+
content: fromOpenAIAssistant(assistantMessage),
|
|
224
|
+
})
|
|
225
|
+
|
|
226
|
+
const toolCalls = assistantMessage.tool_calls ?? []
|
|
227
|
+
const respond = toolCalls.find(
|
|
228
|
+
(c) => c.type === 'function' && c.function.name === respondName,
|
|
229
|
+
)
|
|
230
|
+
if (respond && respond.type === 'function') {
|
|
231
|
+
const text = respond.function.arguments ?? ''
|
|
232
|
+
const value = parseGenerated(text, schema)
|
|
233
|
+
return {
|
|
234
|
+
value,
|
|
235
|
+
text,
|
|
236
|
+
messages: workingMessages,
|
|
237
|
+
iterations,
|
|
238
|
+
stopReason: choice.finish_reason ?? 'stop',
|
|
239
|
+
usage: aggregated,
|
|
240
|
+
}
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
if (toolCalls.length === 0 || choice.finish_reason !== 'tool_calls') {
|
|
244
|
+
throw new BrainError(
|
|
245
|
+
`${this.name}.runWithToolsAndSchema: model returned without calling \`${respondName}\`. Add a stronger instruction in the system prompt — apps must steer the model to use the synthetic respond tool for its final answer.`,
|
|
246
|
+
{ context: { provider: this.name, text: assistantMessage.content ?? '' } },
|
|
247
|
+
)
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
const resultBlocks: ContentBlock[] = []
|
|
251
|
+
for (const call of toolCalls) {
|
|
252
|
+
if (call.type !== 'function') continue
|
|
253
|
+
let parsedInput: unknown
|
|
254
|
+
let parseFailed: { content: string; isError: boolean } | undefined
|
|
255
|
+
try {
|
|
256
|
+
parsedInput = call.function.arguments ? JSON.parse(call.function.arguments) : {}
|
|
257
|
+
} catch (err) {
|
|
258
|
+
parseFailed = recoverOrThrow(
|
|
259
|
+
new ToolExecutionError(
|
|
260
|
+
call.function.name,
|
|
261
|
+
call.id,
|
|
262
|
+
new Error(`Failed to parse tool input JSON: ${(err as Error).message}`),
|
|
263
|
+
),
|
|
264
|
+
options,
|
|
265
|
+
)
|
|
266
|
+
}
|
|
267
|
+
const { content, isError } = parseFailed
|
|
268
|
+
?? (await runToolWithRecovery(
|
|
269
|
+
toolMap.get(call.function.name),
|
|
270
|
+
call.function.name,
|
|
271
|
+
call.id,
|
|
272
|
+
parsedInput,
|
|
273
|
+
options,
|
|
274
|
+
))
|
|
275
|
+
resultBlocks.push({
|
|
276
|
+
type: 'tool_result',
|
|
277
|
+
toolUseId: call.id,
|
|
278
|
+
content,
|
|
279
|
+
...(isError ? { isError: true } : {}),
|
|
280
|
+
} satisfies ToolResultBlock)
|
|
281
|
+
}
|
|
282
|
+
workingMessages.push({ role: 'user', content: resultBlocks })
|
|
283
|
+
|
|
284
|
+
iterations++
|
|
285
|
+
if (iterations >= maxIterations) {
|
|
286
|
+
throw new BrainError(
|
|
287
|
+
`${this.name}.runWithToolsAndSchema: hit maxIterations (${maxIterations}) without the model calling \`${respondName}\`. Bump maxIterations, simplify the task, or strengthen the system-prompt nudge.`,
|
|
288
|
+
{ context: { provider: this.name } },
|
|
289
|
+
)
|
|
290
|
+
}
|
|
291
|
+
}
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
private async *_toolForcingStream<T>(
|
|
295
|
+
messages: readonly Message[],
|
|
296
|
+
tools: readonly Tool[],
|
|
297
|
+
schema: OutputSchema<T>,
|
|
298
|
+
options: RunWithToolsOptions,
|
|
299
|
+
): AsyncIterable<AgentStreamEvent<T>> {
|
|
300
|
+
const { respondTool, respondName, augmented } = prepareToolForcing(schema, options, tools)
|
|
301
|
+
const maxIterations = options.maxIterations ?? 10
|
|
302
|
+
const toolMap = new Map<string, Tool>(tools.map((t) => [t.name, t]))
|
|
303
|
+
const workingMessages: Message[] = [...messages]
|
|
304
|
+
const aggregated: ChatUsage = {
|
|
305
|
+
inputTokens: 0,
|
|
306
|
+
outputTokens: 0,
|
|
307
|
+
cacheReadTokens: 0,
|
|
308
|
+
cacheCreationTokens: 0,
|
|
309
|
+
}
|
|
310
|
+
let iterations = 0
|
|
311
|
+
|
|
312
|
+
while (true) {
|
|
313
|
+
checkAborted(options.signal)
|
|
314
|
+
yield { type: 'iteration_start', iteration: iterations }
|
|
315
|
+
|
|
316
|
+
const baseParams = this.buildParams(workingMessages, augmented, tools)
|
|
317
|
+
baseParams.tools = [...(baseParams.tools ?? []), respondTool]
|
|
318
|
+
const params: OpenAI.Chat.ChatCompletionCreateParamsStreaming = {
|
|
319
|
+
...baseParams,
|
|
320
|
+
stream: true,
|
|
321
|
+
stream_options: { include_usage: true },
|
|
322
|
+
}
|
|
323
|
+
const stream = await this.client.chat.completions.create(params, reqOpts(options))
|
|
324
|
+
|
|
325
|
+
let textBuf = ''
|
|
326
|
+
const toolCallsByIndex = new Map<
|
|
327
|
+
number,
|
|
328
|
+
{ id?: string; name?: string; args: string; started: boolean }
|
|
329
|
+
>()
|
|
330
|
+
let finishReason: string | null = null
|
|
331
|
+
let lastUsage: OpenAI.CompletionUsage | undefined
|
|
332
|
+
|
|
333
|
+
for await (const chunk of stream) {
|
|
334
|
+
const choice = chunk.choices[0]
|
|
335
|
+
const delta = choice?.delta
|
|
336
|
+
if (delta?.content && typeof delta.content === 'string' && delta.content.length > 0) {
|
|
337
|
+
textBuf += delta.content
|
|
338
|
+
yield { type: 'text', delta: delta.content }
|
|
339
|
+
}
|
|
340
|
+
if (delta?.tool_calls) {
|
|
341
|
+
for (const tc of delta.tool_calls) {
|
|
342
|
+
const entry = toolCallsByIndex.get(tc.index) ?? { args: '', started: false }
|
|
343
|
+
if (tc.id) entry.id = tc.id
|
|
344
|
+
if (tc.function?.name) entry.name = tc.function.name
|
|
345
|
+
toolCallsByIndex.set(tc.index, entry)
|
|
346
|
+
if (!entry.started && entry.id !== undefined && entry.name !== undefined) {
|
|
347
|
+
entry.started = true
|
|
348
|
+
if (entry.name !== respondName) {
|
|
349
|
+
yield { type: 'tool_use_start', id: entry.id, name: entry.name }
|
|
350
|
+
}
|
|
351
|
+
}
|
|
352
|
+
if (tc.function?.arguments) {
|
|
353
|
+
entry.args += tc.function.arguments
|
|
354
|
+
if (
|
|
355
|
+
entry.started &&
|
|
356
|
+
entry.id !== undefined &&
|
|
357
|
+
entry.name !== respondName
|
|
358
|
+
) {
|
|
359
|
+
yield {
|
|
360
|
+
type: 'tool_use_delta',
|
|
361
|
+
id: entry.id,
|
|
362
|
+
argsDelta: tc.function.arguments,
|
|
363
|
+
}
|
|
364
|
+
}
|
|
365
|
+
}
|
|
366
|
+
}
|
|
367
|
+
}
|
|
368
|
+
if (choice?.finish_reason) finishReason = choice.finish_reason
|
|
369
|
+
if (chunk.usage) lastUsage = chunk.usage
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
addUsageHere(aggregated, lastUsage, this)
|
|
373
|
+
yield { type: 'iteration_end', iteration: iterations, stopReason: finishReason }
|
|
374
|
+
|
|
375
|
+
const assistantBlocks: ContentBlock[] = []
|
|
376
|
+
if (textBuf.length > 0) assistantBlocks.push({ type: 'text', text: textBuf })
|
|
377
|
+
const orderedCalls = [...toolCallsByIndex.entries()]
|
|
378
|
+
.sort(([a], [b]) => a - b)
|
|
379
|
+
.map(([, v]) => v)
|
|
380
|
+
for (const call of orderedCalls) {
|
|
381
|
+
if (!call.id || !call.name) continue
|
|
382
|
+
let parsedInput: unknown = {}
|
|
383
|
+
try {
|
|
384
|
+
parsedInput = call.args ? JSON.parse(call.args) : {}
|
|
385
|
+
} catch {
|
|
386
|
+
parsedInput = call.args
|
|
387
|
+
}
|
|
388
|
+
assistantBlocks.push({
|
|
389
|
+
type: 'tool_use',
|
|
390
|
+
id: call.id,
|
|
391
|
+
name: call.name,
|
|
392
|
+
input: parsedInput,
|
|
393
|
+
} satisfies ToolUseBlock)
|
|
394
|
+
}
|
|
395
|
+
const assistantContent: string | ContentBlock[] =
|
|
396
|
+
assistantBlocks.length === 1 && assistantBlocks[0]?.type === 'text'
|
|
397
|
+
? assistantBlocks[0].text
|
|
398
|
+
: assistantBlocks
|
|
399
|
+
workingMessages.push({ role: 'assistant', content: assistantContent })
|
|
400
|
+
|
|
401
|
+
const respond = orderedCalls.find((c) => c.name === respondName)
|
|
402
|
+
if (respond && respond.id) {
|
|
403
|
+
const text = respond.args
|
|
404
|
+
const value = parseGenerated(text, schema)
|
|
405
|
+
yield {
|
|
406
|
+
type: 'stop',
|
|
407
|
+
stopReason: finishReason ?? 'stop',
|
|
408
|
+
iterations,
|
|
409
|
+
usage: aggregated,
|
|
410
|
+
messages: workingMessages,
|
|
411
|
+
value,
|
|
412
|
+
text,
|
|
413
|
+
} as AgentStreamEvent<T>
|
|
414
|
+
return
|
|
415
|
+
}
|
|
416
|
+
|
|
417
|
+
if (finishReason !== 'tool_calls' || orderedCalls.length === 0) {
|
|
418
|
+
throw new BrainError(
|
|
419
|
+
`${this.name}.streamWithToolsAndSchema: model returned without calling \`${respondName}\`. Strengthen the system-prompt nudge.`,
|
|
420
|
+
{ context: { provider: this.name, text: textBuf } },
|
|
421
|
+
)
|
|
422
|
+
}
|
|
423
|
+
|
|
424
|
+
const resultBlocks: ContentBlock[] = []
|
|
425
|
+
for (const call of orderedCalls) {
|
|
426
|
+
if (!call.id || !call.name) continue
|
|
427
|
+
let parsedInput: unknown
|
|
428
|
+
let parseFailed: { content: string; isError: boolean } | undefined
|
|
429
|
+
try {
|
|
430
|
+
parsedInput = call.args ? JSON.parse(call.args) : {}
|
|
431
|
+
} catch (err) {
|
|
432
|
+
parseFailed = recoverOrThrow(
|
|
433
|
+
new ToolExecutionError(
|
|
434
|
+
call.name,
|
|
435
|
+
call.id,
|
|
436
|
+
new Error(`Failed to parse tool input JSON: ${(err as Error).message}`),
|
|
437
|
+
),
|
|
438
|
+
options,
|
|
439
|
+
)
|
|
440
|
+
parsedInput = call.args
|
|
441
|
+
}
|
|
442
|
+
yield { type: 'tool_use', id: call.id, name: call.name, input: parsedInput }
|
|
443
|
+
const { content, isError } = parseFailed
|
|
444
|
+
?? (await runToolWithRecovery(
|
|
445
|
+
toolMap.get(call.name),
|
|
446
|
+
call.name,
|
|
447
|
+
call.id,
|
|
448
|
+
parsedInput,
|
|
449
|
+
options,
|
|
450
|
+
))
|
|
451
|
+
resultBlocks.push({
|
|
452
|
+
type: 'tool_result',
|
|
453
|
+
toolUseId: call.id,
|
|
454
|
+
content,
|
|
455
|
+
...(isError ? { isError: true } : {}),
|
|
456
|
+
} satisfies ToolResultBlock)
|
|
457
|
+
yield { type: 'tool_result', id: call.id, name: call.name, content, isError }
|
|
458
|
+
}
|
|
459
|
+
workingMessages.push({ role: 'user', content: resultBlocks })
|
|
460
|
+
|
|
461
|
+
iterations++
|
|
462
|
+
if (iterations >= maxIterations) {
|
|
463
|
+
throw new BrainError(
|
|
464
|
+
`${this.name}.streamWithToolsAndSchema: hit maxIterations (${maxIterations}) without the model calling \`${respondName}\`.`,
|
|
465
|
+
{ context: { provider: this.name } },
|
|
466
|
+
)
|
|
467
|
+
}
|
|
468
|
+
}
|
|
149
469
|
}
|
|
150
470
|
|
|
151
471
|
/**
|
|
@@ -185,3 +505,112 @@ function schemaInstruction(schema: OutputSchema<unknown>): string {
|
|
|
185
505
|
].filter((s): s is string => s !== undefined)
|
|
186
506
|
return lines.join('\n')
|
|
187
507
|
}
|
|
508
|
+
|
|
509
|
+
// ─── Tool-forcing helpers ────────────────────────────────────────────────
|
|
510
|
+
|
|
511
|
+
const RESPOND_TOOL_PREFIX = 'respond_with_'
|
|
512
|
+
|
|
513
|
+
/**
|
|
514
|
+
* Build the synthetic respond-tool entry + the system-prompt nudge
|
|
515
|
+
* apps inject alongside their own system message. Validates that
|
|
516
|
+
* the chosen tool name doesn't collide with any user tool — that
|
|
517
|
+
* would make the loop's terminal detection ambiguous.
|
|
518
|
+
*/
|
|
519
|
+
function prepareToolForcing(
|
|
520
|
+
schema: OutputSchema<unknown>,
|
|
521
|
+
options: ChatOptions,
|
|
522
|
+
userTools: readonly Tool[],
|
|
523
|
+
): {
|
|
524
|
+
respondTool: OpenAI.Chat.ChatCompletionTool
|
|
525
|
+
respondName: string
|
|
526
|
+
augmented: ChatOptions
|
|
527
|
+
} {
|
|
528
|
+
const respondName = `${RESPOND_TOOL_PREFIX}${schema.name}`
|
|
529
|
+
if (userTools.some((t) => t.name === respondName)) {
|
|
530
|
+
throw new BrainError(
|
|
531
|
+
`OpenAICompatProvider.runWithToolsAndSchema: synthetic tool name "${respondName}" collides with a user-supplied tool. Rename your tool or the OutputSchema.name to avoid the clash.`,
|
|
532
|
+
{ context: { conflictingName: respondName } },
|
|
533
|
+
)
|
|
534
|
+
}
|
|
535
|
+
const respondTool: OpenAI.Chat.ChatCompletionTool = {
|
|
536
|
+
type: 'function',
|
|
537
|
+
function: {
|
|
538
|
+
name: respondName,
|
|
539
|
+
description:
|
|
540
|
+
`Submit your final answer. Call this exactly once, after using any other tools you need. ` +
|
|
541
|
+
`The arguments MUST conform to the schema below. Do not return prose alongside or in place of this call.` +
|
|
542
|
+
(schema.description ? ` (${schema.description})` : ''),
|
|
543
|
+
parameters: schema.jsonSchema as Record<string, unknown>,
|
|
544
|
+
},
|
|
545
|
+
}
|
|
546
|
+
const augmented: ChatOptions = {
|
|
547
|
+
...options,
|
|
548
|
+
system: combineSystem(options.system, toolForcingInstruction(respondName)),
|
|
549
|
+
}
|
|
550
|
+
return { respondTool, respondName, augmented }
|
|
551
|
+
}
|
|
552
|
+
|
|
553
|
+
function toolForcingInstruction(respondName: string): string {
|
|
554
|
+
return [
|
|
555
|
+
`When you are ready to give the final answer, call the \`${respondName}\` function with the structured arguments.`,
|
|
556
|
+
`Use any other available tools first to gather what you need. Once you have enough information, call \`${respondName}\` exactly once and do NOT also return prose text.`,
|
|
557
|
+
].join(' ')
|
|
558
|
+
}
|
|
559
|
+
|
|
560
|
+
function reqOpts(options: { signal?: AbortSignal }): { signal?: AbortSignal } | undefined {
|
|
561
|
+
return options.signal !== undefined ? { signal: options.signal } : undefined
|
|
562
|
+
}
|
|
563
|
+
|
|
564
|
+
function checkAborted(signal: AbortSignal | undefined): void {
|
|
565
|
+
if (signal?.aborted) {
|
|
566
|
+
throw signal.reason ?? new DOMException('Aborted', 'AbortError')
|
|
567
|
+
}
|
|
568
|
+
}
|
|
569
|
+
|
|
570
|
+
function fromOpenAIAssistant(
|
|
571
|
+
msg: OpenAI.Chat.ChatCompletionMessage,
|
|
572
|
+
): string | ContentBlock[] {
|
|
573
|
+
const blocks: ContentBlock[] = []
|
|
574
|
+
if (msg.content) blocks.push({ type: 'text', text: msg.content })
|
|
575
|
+
if (msg.tool_calls) {
|
|
576
|
+
for (const call of msg.tool_calls) {
|
|
577
|
+
if (call.type !== 'function') continue
|
|
578
|
+
let parsedInput: unknown = {}
|
|
579
|
+
try {
|
|
580
|
+
parsedInput = call.function.arguments ? JSON.parse(call.function.arguments) : {}
|
|
581
|
+
} catch {
|
|
582
|
+
parsedInput = call.function.arguments ?? {}
|
|
583
|
+
}
|
|
584
|
+
blocks.push({
|
|
585
|
+
type: 'tool_use',
|
|
586
|
+
id: call.id,
|
|
587
|
+
name: call.function.name,
|
|
588
|
+
input: parsedInput,
|
|
589
|
+
} satisfies ToolUseBlock)
|
|
590
|
+
}
|
|
591
|
+
}
|
|
592
|
+
if (blocks.length === 1 && blocks[0]?.type === 'text') return blocks[0].text
|
|
593
|
+
return blocks
|
|
594
|
+
}
|
|
595
|
+
|
|
596
|
+
/**
|
|
597
|
+
* Add provider-mapped usage onto an accumulator. Calls `mapUsage`
|
|
598
|
+
* on the provider instance so subclasses (e.g., DeepSeek) honor
|
|
599
|
+
* their vendor-specific cache fields.
|
|
600
|
+
*/
|
|
601
|
+
function addUsageHere(
|
|
602
|
+
acc: ChatUsage,
|
|
603
|
+
u: OpenAI.CompletionUsage | undefined,
|
|
604
|
+
provider: OpenAICompatProvider,
|
|
605
|
+
): void {
|
|
606
|
+
if (!u) return
|
|
607
|
+
// Cast: `mapUsage` is protected on the abstract class; we're
|
|
608
|
+
// inside the module so the access is valid at runtime.
|
|
609
|
+
const mapped = (provider as unknown as {
|
|
610
|
+
mapUsage(u: OpenAI.CompletionUsage | undefined): ChatUsage
|
|
611
|
+
}).mapUsage(u)
|
|
612
|
+
acc.inputTokens += mapped.inputTokens
|
|
613
|
+
acc.outputTokens += mapped.outputTokens
|
|
614
|
+
acc.cacheReadTokens += mapped.cacheReadTokens
|
|
615
|
+
acc.cacheCreationTokens += mapped.cacheCreationTokens
|
|
616
|
+
}
|