@strav/brain 1.0.0-alpha.16 → 1.0.0-alpha.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +4 -2
- package/src/agent.ts +34 -5
- package/src/agent_generate_result.ts +2 -0
- package/src/agent_result.ts +7 -0
- package/src/agent_runner.ts +134 -15
- package/src/agent_stream_event.ts +100 -0
- package/src/brain_config.ts +91 -1
- package/src/brain_manager.ts +287 -6
- package/src/brain_provider.ts +25 -1
- package/src/index.ts +37 -2
- package/src/mcp/client.ts +99 -13
- package/src/mcp/index.ts +7 -0
- package/src/mcp/oauth.ts +227 -0
- package/src/mcp/pool.ts +106 -0
- package/src/mcp/resolve_mcp_tools.ts +31 -9
- package/src/mcp_server.ts +16 -0
- package/src/persistence/brain_message.ts +34 -0
- package/src/persistence/brain_message_repository.ts +106 -0
- package/src/persistence/brain_store.ts +166 -0
- package/src/persistence/brain_suspended_run.ts +30 -0
- package/src/persistence/brain_suspended_run_repository.ts +68 -0
- package/src/persistence/brain_thread.ts +30 -0
- package/src/persistence/brain_thread_repository.ts +65 -0
- package/src/persistence/database_brain_store.ts +190 -0
- package/src/persistence/index.ts +48 -0
- package/src/persistence/schema/brain_message_schema.ts +61 -0
- package/src/persistence/schema/brain_suspended_run_schema.ts +58 -0
- package/src/persistence/schema/brain_thread_schema.ts +50 -0
- package/src/persistence/schema/index.ts +3 -0
- package/src/provider.ts +145 -1
- package/src/providers/anthropic_provider.ts +723 -38
- package/src/providers/deepseek_provider.ts +117 -0
- package/src/providers/gemini_provider.ts +625 -33
- package/src/providers/ollama_provider.ts +86 -0
- package/src/providers/openai_compat_provider.ts +616 -0
- package/src/providers/openai_provider.ts +801 -43
- package/src/providers/openai_responses_provider.ts +1015 -0
- package/src/suspended_run.ts +153 -0
- package/src/thread.ts +40 -1
- package/src/tool.ts +7 -0
- package/src/tool_runner.ts +81 -0
- package/src/types.ts +343 -0
|
@@ -60,11 +60,26 @@ import type { AgentResult } from '../agent_result.ts'
|
|
|
60
60
|
import { BrainError } from '../brain_error.ts'
|
|
61
61
|
import type { GeminiProviderConfig } from '../brain_config.ts'
|
|
62
62
|
import type { MCPServer } from '../mcp_server.ts'
|
|
63
|
+
import type { AgentGenerateResult } from '../agent_generate_result.ts'
|
|
64
|
+
import type { AgentStreamEvent } from '../agent_stream_event.ts'
|
|
65
|
+
import type {
|
|
66
|
+
AudioSource,
|
|
67
|
+
EmbedOptions,
|
|
68
|
+
EmbedResult,
|
|
69
|
+
ServerTool,
|
|
70
|
+
TranscribeOptions,
|
|
71
|
+
TranscribeResult,
|
|
72
|
+
} from '../types.ts'
|
|
63
73
|
import { resolveMcpTools, type ResolveMcpToolsOptions } from '../mcp/resolve_mcp_tools.ts'
|
|
64
74
|
import { parseGenerated, type OutputSchema } from '../output_schema.ts'
|
|
65
|
-
import
|
|
75
|
+
import { runToolWithRecovery } from '../tool_runner.ts'
|
|
76
|
+
import type {
|
|
77
|
+
Provider,
|
|
78
|
+
RunWithToolsOptions,
|
|
79
|
+
RunWithToolsOptionsWithSuspend,
|
|
80
|
+
} from '../provider.ts'
|
|
81
|
+
import type { SuspendedRun } from '../suspended_run.ts'
|
|
66
82
|
import type { Tool } from '../tool.ts'
|
|
67
|
-
import { ToolExecutionError } from '../tool_execution_error.ts'
|
|
68
83
|
import type {
|
|
69
84
|
ChatOptions,
|
|
70
85
|
ChatResult,
|
|
@@ -80,6 +95,7 @@ import type {
|
|
|
80
95
|
} from '../types.ts'
|
|
81
96
|
|
|
82
97
|
/** Chat model used when the provider config has no `defaultModel`. */
const DEFAULT_GEMINI_MODEL = 'gemini-2.5-flash'
/** Embedding model used when neither `EmbedOptions.model` nor `config.defaultEmbedModel` is set. */
const DEFAULT_GEMINI_EMBED_MODEL = 'text-embedding-004'
|
|
83
99
|
|
|
84
100
|
/**
|
|
85
101
|
* The slice of `GoogleGenAI` the provider exercises. Narrowed so
|
|
@@ -91,12 +107,25 @@ export interface GeminiModelsClient {
|
|
|
91
107
|
params: GenerateContentParameters,
|
|
92
108
|
): Promise<AsyncIterable<GenerateContentResponse>>
|
|
93
109
|
countTokens(params: { model: string; contents: Content[] }): Promise<{ totalTokens?: number }>
|
|
110
|
+
/**
|
|
111
|
+
* Optional on the test seam — the real SDK always provides it,
|
|
112
|
+
* but tests that don't exercise embed don't need to stub it.
|
|
113
|
+
* `embed()` calls this directly; missing it throws a clear
|
|
114
|
+
* TypeError if invoked.
|
|
115
|
+
*/
|
|
116
|
+
embedContent?(params: {
|
|
117
|
+
model: string
|
|
118
|
+
contents: string[]
|
|
119
|
+
config?: { outputDimensionality?: number; abortSignal?: AbortSignal }
|
|
120
|
+
}): Promise<{ embeddings?: Array<{ values?: number[] }> }>
|
|
94
121
|
}
|
|
95
122
|
|
|
96
123
|
/**
 * Optional construction-time seams for `GeminiProvider`.
 * When `client` is supplied, its `models` slice is used directly
 * instead of the provider constructing its own SDK client.
 */
export interface GeminiProviderOptions {
  /** Pre-built client; only its `models` slice is exercised. */
  client?: { models: GeminiModelsClient }
  /** Internal seam — tests inject a stub MCP client factory. */
  mcpClientFactory?: ResolveMcpToolsOptions['clientFactory']
  /** See `OpenAIProviderOptions.mcpPool` — same semantics. */
  mcpPool?: ResolveMcpToolsOptions['pool']
}
|
|
101
130
|
|
|
102
131
|
export class GeminiProvider implements Provider {
|
|
@@ -104,13 +133,17 @@ export class GeminiProvider implements Provider {
|
|
|
104
133
|
private readonly models: GeminiModelsClient
|
|
105
134
|
private readonly defaultModel: string
|
|
106
135
|
private readonly defaultMaxTokens: number
|
|
136
|
+
private readonly defaultEmbedModel: string
|
|
107
137
|
private readonly mcpClientFactory?: ResolveMcpToolsOptions['clientFactory']
|
|
138
|
+
private readonly mcpPool?: ResolveMcpToolsOptions['pool']
|
|
108
139
|
|
|
109
140
|
constructor(name: string, config: GeminiProviderConfig, options: GeminiProviderOptions = {}) {
|
|
110
141
|
this.name = name
|
|
111
142
|
this.defaultModel = config.defaultModel ?? DEFAULT_GEMINI_MODEL
|
|
112
143
|
this.defaultMaxTokens = config.defaultMaxTokens ?? 4096
|
|
144
|
+
this.defaultEmbedModel = config.defaultEmbedModel ?? DEFAULT_GEMINI_EMBED_MODEL
|
|
113
145
|
this.mcpClientFactory = options.mcpClientFactory
|
|
146
|
+
this.mcpPool = options.mcpPool
|
|
114
147
|
if (options.client) {
|
|
115
148
|
this.models = options.client.models
|
|
116
149
|
} else {
|
|
@@ -169,18 +202,122 @@ export class GeminiProvider implements Provider {
|
|
|
169
202
|
return response.totalTokens ?? 0
|
|
170
203
|
}
|
|
171
204
|
|
|
205
|
+
/**
 * Gemini has no dedicated transcription endpoint, so we wrap a
 * chat call: an AudioBlock + a system message that tells the
 * model to transcribe verbatim. Apps that want OpenAI-style
 * Whisper transcription with `language` / `duration` metadata
 * route to OpenAI (or local Whisper via Ollama).
 *
 * `options.prompt` threads into the system instruction —
 * useful for style/vocabulary hints. `options.language` is
 * surfaced to the model in the system prompt (Gemini doesn't
 * have a dedicated language field).
 *
 * @param audio the audio to transcribe, sent as a single `audio` content block.
 * @param options optional model override, abort signal, language and prompt hints.
 * @returns the model's text output, the model that produced it and the raw chat response.
 */
async transcribe(
  audio: AudioSource,
  options: TranscribeOptions = {},
): Promise<TranscribeResult> {
  const lines = [
    'Transcribe the attached audio verbatim. Output ONLY the transcribed text — no preamble, no quotes, no commentary.',
    options.language ? `Audio language: ${options.language}.` : undefined,
    options.prompt ? `Style / vocabulary hints: ${options.prompt}` : undefined,
  ].filter((s): s is string => s !== undefined)
  const system = lines.join(' ')
  const chatResult = await this.chat(
    [
      {
        role: 'user',
        content: [{ type: 'audio', source: audio }],
      },
    ],
    {
      system,
      // Only forward overrides the caller actually set.
      ...(options.model !== undefined ? { model: options.model } : {}),
      ...(options.signal !== undefined ? { signal: options.signal } : {}),
    },
  )
  return {
    text: chatResult.text,
    model: chatResult.model,
    raw: chatResult.raw,
  }
}

/**
 * Gemini embeddings via `ai.models.embedContent`. Returns one
 * vector per input text. `usage.inputTokens` is `0` — Gemini's
 * embed endpoint doesn't surface token counts in the response
 * for the Gemini Developer API tier (Vertex's request-level
 * metadata exposes billable characters, but that's a different
 * accounting unit and not the framework's contract). Apps that
 * need exact embed-token usage call `countTokens` separately
 * before the call.
 *
 * An empty `texts` array returns an empty result without a network
 * call (the API has nothing to embed).
 *
 * @param texts input strings; not mutated (a copy is sent to the SDK).
 * @param options optional model override, output dimensionality and abort signal.
 * @throws BrainError when the underlying models client lacks `embedContent`.
 */
async embed(
  texts: readonly string[],
  options: EmbedOptions = {},
): Promise<EmbedResult<{ embeddings?: Array<{ values?: number[] }> }>> {
  const model = options.model ?? this.defaultEmbedModel
  if (!this.models.embedContent) {
    throw new BrainError(
      `GeminiProvider.embed: underlying SDK does not implement embedContent. This usually means a test stub omitted it.`,
      { context: { provider: this.name } },
    )
  }
  if (texts.length === 0) {
    return {
      embeddings: [],
      model,
      usage: { inputTokens: 0 },
      raw: { embeddings: [] },
    }
  }
  const config: { outputDimensionality?: number; abortSignal?: AbortSignal } = {}
  if (options.dimensions !== undefined) config.outputDimensionality = options.dimensions
  if (options.signal !== undefined) config.abortSignal = options.signal
  const response = await this.models.embedContent({
    model,
    // Copy instead of casting away `readonly` so the SDK can never
    // mutate the caller's array.
    contents: [...texts],
    ...(Object.keys(config).length > 0 ? { config } : {}),
  })
  // Missing vectors become empty arrays rather than `undefined`.
  const embeddings = (response.embeddings ?? []).map((e) => e.values ?? [])
  return {
    embeddings,
    model,
    usage: { inputTokens: 0 },
    raw: response,
  }
}
|
|
284
|
+
|
|
285
|
+
/**
 * Resolve the MCP servers in `servers` into extra tool descriptors.
 * Forwards the provider's optional `clientFactory` (test seam) and
 * `mcpPool` (long-lived connections) to `resolveMcpTools`. With no
 * servers, resolves immediately to an empty tool list and a no-op
 * `close`. Callers always invoke `close()` in a `finally`; with a pool
 * the pool owns connection lifetimes and `close()` is a no-op.
 */
private resolveMcp(servers: readonly MCPServer[]): Promise<{
  tools: Tool[]
  close: () => Promise<void>
}> {
  if (servers.length > 0) {
    const factoryOpt = this.mcpClientFactory ? { clientFactory: this.mcpClientFactory } : {}
    const poolOpt = this.mcpPool ? { pool: this.mcpPool } : {}
    return resolveMcpTools(servers, { ...factoryOpt, ...poolOpt })
  }
  const noopClose = async (): Promise<void> => {}
  return Promise.resolve({ tools: [], close: noopClose })
}
|
|
304
|
+
|
|
305
|
+
runWithTools(
|
|
306
|
+
messages: readonly Message[],
|
|
307
|
+
tools: readonly Tool[],
|
|
308
|
+
options: RunWithToolsOptionsWithSuspend,
|
|
309
|
+
): Promise<AgentResult | SuspendedRun>
|
|
310
|
+
runWithTools(
|
|
311
|
+
messages: readonly Message[],
|
|
312
|
+
tools: readonly Tool[],
|
|
313
|
+
options?: RunWithToolsOptions,
|
|
314
|
+
): Promise<AgentResult>
|
|
172
315
|
async runWithTools(
|
|
173
316
|
messages: readonly Message[],
|
|
174
317
|
tools: readonly Tool[],
|
|
175
318
|
options: RunWithToolsOptions = {},
|
|
176
|
-
): Promise<AgentResult> {
|
|
177
|
-
const
|
|
178
|
-
const resolved =
|
|
179
|
-
mcpServers.length > 0
|
|
180
|
-
? await resolveMcpTools(mcpServers, {
|
|
181
|
-
...(this.mcpClientFactory ? { clientFactory: this.mcpClientFactory } : {}),
|
|
182
|
-
})
|
|
183
|
-
: { tools: [] as Tool[], close: async () => {} }
|
|
319
|
+
): Promise<AgentResult | SuspendedRun> {
|
|
320
|
+
const resolved = await this.resolveMcp(options.mcpServers ?? [])
|
|
184
321
|
try {
|
|
185
322
|
return await this._runLoop(messages, [...tools, ...resolved.tools], options)
|
|
186
323
|
} finally {
|
|
@@ -192,7 +329,7 @@ export class GeminiProvider implements Provider {
|
|
|
192
329
|
messages: readonly Message[],
|
|
193
330
|
tools: readonly Tool[],
|
|
194
331
|
options: RunWithToolsOptions,
|
|
195
|
-
): Promise<AgentResult> {
|
|
332
|
+
): Promise<AgentResult | SuspendedRun> {
|
|
196
333
|
const maxIterations = options.maxIterations ?? 10
|
|
197
334
|
const toolMap = new Map<string, Tool>(tools.map((t) => [t.name, t]))
|
|
198
335
|
const workingMessages: Message[] = [...messages]
|
|
@@ -205,6 +342,7 @@ export class GeminiProvider implements Provider {
|
|
|
205
342
|
let iterations = 0
|
|
206
343
|
|
|
207
344
|
while (true) {
|
|
345
|
+
checkAborted(options.signal)
|
|
208
346
|
const params = this.buildParams(workingMessages, options, tools)
|
|
209
347
|
const response = await this.models.generateContent(params)
|
|
210
348
|
addUsage(aggregated, response.usageMetadata)
|
|
@@ -234,30 +372,28 @@ export class GeminiProvider implements Provider {
|
|
|
234
372
|
}
|
|
235
373
|
|
|
236
374
|
const resultBlocks: ContentBlock[] = []
|
|
237
|
-
for (
|
|
238
|
-
const
|
|
239
|
-
if (
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
375
|
+
for (let i = 0; i < toolUses.length; i++) {
|
|
376
|
+
const call = toolUses[i]!
|
|
377
|
+
if (options.shouldSuspend && await options.shouldSuspend(call, options.context)) {
|
|
378
|
+
return {
|
|
379
|
+
status: 'suspended',
|
|
380
|
+
pendingToolCalls: toolUses.slice(i),
|
|
381
|
+
state: { messages: workingMessages, iterations, usage: aggregated },
|
|
382
|
+
}
|
|
245
383
|
}
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
}
|
|
255
|
-
const resultBlock: ToolResultBlock = {
|
|
384
|
+
const { content, isError } = await runToolWithRecovery(
|
|
385
|
+
toolMap.get(call.name),
|
|
386
|
+
call.name,
|
|
387
|
+
call.id,
|
|
388
|
+
call.input,
|
|
389
|
+
options,
|
|
390
|
+
)
|
|
391
|
+
resultBlocks.push({
|
|
256
392
|
type: 'tool_result',
|
|
257
393
|
toolUseId: call.id,
|
|
258
|
-
content
|
|
259
|
-
|
|
260
|
-
|
|
394
|
+
content,
|
|
395
|
+
...(isError ? { isError: true } : {}),
|
|
396
|
+
} satisfies ToolResultBlock)
|
|
261
397
|
}
|
|
262
398
|
workingMessages.push({ role: 'user', content: resultBlocks })
|
|
263
399
|
|
|
@@ -274,6 +410,364 @@ export class GeminiProvider implements Provider {
|
|
|
274
410
|
}
|
|
275
411
|
}
|
|
276
412
|
|
|
413
|
+
/**
 * Run the tool loop with structured (JSON-schema) output.
 * MCP tools from `options.mcpServers` are appended to `tools`. The MCP
 * handle is closed afterwards whether the loop succeeds or throws.
 */
async runWithToolsAndSchema<T>(
  messages: readonly Message[],
  tools: readonly Tool[],
  schema: OutputSchema<T>,
  options: RunWithToolsOptions = {},
): Promise<AgentGenerateResult<T>> {
  const servers = options.mcpServers ?? []
  const mcp = await this.resolveMcp(servers)
  try {
    const combined = [...tools, ...mcp.tools]
    return await this._runLoopWithSchema(combined, messages, schema, options)
  } finally {
    await mcp.close()
  }
}
|
|
426
|
+
|
|
427
|
+
/**
 * Non-streaming tool loop where every turn asks Gemini for JSON
 * matching `schema` (`responseMimeType` + `responseJsonSchema`).
 *
 * Each iteration sends the working transcript, appends the assistant
 * turn, and runs any requested tools via `runToolWithRecovery`. It
 * returns when a turn contains no tool calls (parsing that turn's text
 * against `schema`), or after `maxIterations` tool rounds (default 10),
 * with `stopReason: 'max_iterations'`.
 *
 * @param tools local + MCP tools available to the model.
 * @param messages conversation so far; copied, never mutated.
 * @param schema output schema used both for the request and for parsing.
 * @param options loop options (`maxIterations`, `signal`, ...).
 * @throws BrainError when a response has no candidates.
 * @throws the signal's abort reason if `options.signal` has fired.
 */
private async _runLoopWithSchema<T>(
  tools: readonly Tool[],
  messages: readonly Message[],
  schema: OutputSchema<T>,
  options: RunWithToolsOptions,
): Promise<AgentGenerateResult<T>> {
  const maxIterations = options.maxIterations ?? 10
  const toolMap = new Map<string, Tool>(tools.map((t) => [t.name, t]))
  const workingMessages: Message[] = [...messages]
  const aggregated: ChatUsage = {
    inputTokens: 0,
    outputTokens: 0,
    cacheReadTokens: 0,
    cacheCreationTokens: 0,
  }
  let iterations = 0

  while (true) {
    // Same per-iteration abort check as the other tool loops; without it
    // a cancelled run would keep issuing model calls between tool rounds.
    checkAborted(options.signal)
    const params = this.buildParams(workingMessages, options, tools)
    params.config = {
      ...(params.config ?? {}),
      responseMimeType: 'application/json',
      responseJsonSchema: schema.jsonSchema,
    }
    const response = await this.models.generateContent(params)
    addUsage(aggregated, response.usageMetadata)

    const candidate = response.candidates?.[0]
    if (!candidate) {
      throw new BrainError('GeminiProvider: response had no candidates.')
    }
    const parts = candidate.content?.parts ?? []
    const assistantContent = fromGeminiParts(parts)
    workingMessages.push({ role: 'assistant', content: assistantContent })

    const toolUses = (Array.isArray(assistantContent) ? assistantContent : []).filter(
      (b): b is ToolUseBlock => b.type === 'tool_use',
    )

    if (toolUses.length === 0) {
      const text = typeof assistantContent === 'string'
        ? assistantContent
        : candidateText(candidate)
      return {
        value: parseGenerated(text, schema),
        text,
        messages: workingMessages,
        iterations,
        stopReason: candidate.finishReason ? String(candidate.finishReason) : 'stop',
        usage: aggregated,
      }
    }

    const resultBlocks: ContentBlock[] = []
    for (const call of toolUses) {
      const { content, isError } = await runToolWithRecovery(
        toolMap.get(call.name),
        call.name,
        call.id,
        call.input,
        options,
      )
      resultBlocks.push({
        type: 'tool_result',
        toolUseId: call.id,
        content,
        ...(isError ? { isError: true } : {}),
      } satisfies ToolResultBlock)
    }
    workingMessages.push({ role: 'user', content: resultBlocks })

    iterations++
    if (iterations >= maxIterations) {
      const text = candidateText(candidate)
      return {
        value: parseGenerated(text, schema),
        text,
        messages: workingMessages,
        iterations,
        stopReason: 'max_iterations',
        usage: aggregated,
      }
    }
  }
}
|
|
512
|
+
|
|
513
|
+
/**
 * Streaming tool loop. Tools from `options.mcpServers` are resolved and
 * appended to `tools`. The MCP handle is closed when the stream
 * finishes, throws, or is abandoned by the consumer.
 */
async *streamWithTools(
  messages: readonly Message[],
  tools: readonly Tool[],
  options: RunWithToolsOptions = {},
): AsyncIterable<AgentStreamEvent> {
  const servers = options.mcpServers ?? []
  const mcp = await this.resolveMcp(servers)
  try {
    const combined = [...tools, ...mcp.tools]
    yield* this._streamLoop(messages, combined, options)
  } finally {
    await mcp.close()
  }
}
|
|
525
|
+
|
|
526
|
+
/**
 * Streaming tool loop: each iteration streams one model turn,
 * forwarding text deltas, then executes any requested tools.
 *
 * Per iteration it emits:
 * - `iteration_start`;
 * - zero or more `text` events;
 * - `iteration_end`;
 * - a `tool_use` / `tool_result` pair for each call.
 *
 * It finishes with one `stop` event. That carries the model's finish
 * reason (or `'stop'`) when a turn has no tool calls, or
 * `'max_iterations'` after `maxIterations` tool rounds (default 10).
 *
 * @param messages conversation so far; copied, never mutated.
 * @param tools local + MCP tools available to the model.
 * @param options loop options (`maxIterations`, `signal`, ...).
 * @throws the signal's abort reason if `options.signal` has fired.
 */
private async *_streamLoop(
  messages: readonly Message[],
  tools: readonly Tool[],
  options: RunWithToolsOptions,
): AsyncIterable<AgentStreamEvent> {
  const maxIterations = options.maxIterations ?? 10
  const toolMap = new Map<string, Tool>(tools.map((t) => [t.name, t]))
  const workingMessages: Message[] = [...messages]
  const aggregated: ChatUsage = {
    inputTokens: 0,
    outputTokens: 0,
    cacheReadTokens: 0,
    cacheCreationTokens: 0,
  }
  let iterations = 0

  while (true) {
    checkAborted(options.signal)
    yield { type: 'iteration_start', iteration: iterations }

    const params = this.buildParams(workingMessages, options, tools)
    const stream = await this.models.generateContentStream(params)

    const accumulatedParts: Part[] = []
    let finishReason: string | null = null
    // Only the last chunk's usage is kept (presumably the turn's final
    // totals), so a turn is folded into `aggregated` exactly once.
    let lastUsage: ChatUsage | undefined

    for await (const chunk of stream) {
      const candidate = chunk.candidates?.[0]
      const chunkParts = candidate?.content?.parts ?? []
      for (const part of chunkParts) {
        if (typeof part.text === 'string' && part.text.length > 0) {
          yield { type: 'text', delta: part.text }
        }
      }
      accumulatedParts.push(...chunkParts)
      if (candidate?.finishReason) finishReason = String(candidate.finishReason)
      if (chunk.usageMetadata) lastUsage = toUsage(chunk.usageMetadata)
    }
    if (lastUsage) {
      aggregated.inputTokens += lastUsage.inputTokens
      aggregated.outputTokens += lastUsage.outputTokens
      aggregated.cacheReadTokens += lastUsage.cacheReadTokens
      // Fold every tracked counter, not just three of the four.
      aggregated.cacheCreationTokens += lastUsage.cacheCreationTokens
    }

    yield { type: 'iteration_end', iteration: iterations, stopReason: finishReason }

    const assistantContent = fromGeminiParts(accumulatedParts)
    workingMessages.push({ role: 'assistant', content: assistantContent })

    const toolUses = (Array.isArray(assistantContent) ? assistantContent : []).filter(
      (b): b is ToolUseBlock => b.type === 'tool_use',
    )

    if (toolUses.length === 0) {
      yield {
        type: 'stop',
        stopReason: finishReason ?? 'stop',
        iterations,
        usage: aggregated,
        messages: workingMessages,
      }
      return
    }

    const resultBlocks: ContentBlock[] = []
    for (const call of toolUses) {
      yield { type: 'tool_use', id: call.id, name: call.name, input: call.input }
      const { content, isError } = await runToolWithRecovery(
        toolMap.get(call.name),
        call.name,
        call.id,
        call.input,
        options,
      )
      resultBlocks.push({
        type: 'tool_result',
        toolUseId: call.id,
        content,
        ...(isError ? { isError: true } : {}),
      } satisfies ToolResultBlock)
      yield {
        type: 'tool_result',
        id: call.id,
        name: call.name,
        content,
        isError,
      }
    }
    workingMessages.push({ role: 'user', content: resultBlocks })

    iterations++
    if (iterations >= maxIterations) {
      yield {
        type: 'stop',
        stopReason: 'max_iterations',
        iterations,
        usage: aggregated,
        messages: workingMessages,
      }
      return
    }
  }
}
|
|
630
|
+
|
|
631
|
+
/**
 * Streaming tool loop with structured (JSON-schema) output.
 * MCP tools from `options.mcpServers` are appended to `tools`. The MCP
 * handle is closed once the stream ends, errors, or is abandoned.
 */
async *streamWithToolsAndSchema<T>(
  messages: readonly Message[],
  tools: readonly Tool[],
  schema: OutputSchema<T>,
  options: RunWithToolsOptions = {},
): AsyncIterable<AgentStreamEvent<T>> {
  const servers = options.mcpServers ?? []
  const mcp = await this.resolveMcp(servers)
  try {
    const combined = [...tools, ...mcp.tools]
    yield* this._streamLoopWithSchema(combined, messages, schema, options)
  } finally {
    await mcp.close()
  }
}
|
|
649
|
+
|
|
650
|
+
/**
 * Streaming tool loop where every turn requests JSON matching `schema`
 * (`responseMimeType` + `responseJsonSchema`).
 *
 * The events are the same as `_streamLoop`. The final `stop` event also
 * carries `value` and `text`: `text` is the concatenated text of the
 * last streamed turn, and `value` is `parseGenerated(text, schema)`.
 *
 * @param tools local + MCP tools available to the model.
 * @param messages conversation so far; copied, never mutated.
 * @param schema output schema used both for the request and for parsing.
 * @param options loop options (`maxIterations` defaults to 10, `signal`, ...).
 * @throws the signal's abort reason if `options.signal` has fired.
 */
private async *_streamLoopWithSchema<T>(
  tools: readonly Tool[],
  messages: readonly Message[],
  schema: OutputSchema<T>,
  options: RunWithToolsOptions,
): AsyncIterable<AgentStreamEvent<T>> {
  const maxIterations = options.maxIterations ?? 10
  const toolMap = new Map<string, Tool>(tools.map((t) => [t.name, t]))
  const workingMessages: Message[] = [...messages]
  const aggregated: ChatUsage = {
    inputTokens: 0,
    outputTokens: 0,
    cacheReadTokens: 0,
    cacheCreationTokens: 0,
  }
  let iterations = 0

  while (true) {
    checkAborted(options.signal)
    yield { type: 'iteration_start', iteration: iterations }

    const params = this.buildParams(workingMessages, options, tools)
    params.config = {
      ...(params.config ?? {}),
      responseMimeType: 'application/json',
      responseJsonSchema: schema.jsonSchema,
    }
    const stream = await this.models.generateContentStream(params)

    const accumulatedParts: Part[] = []
    // Text of this turn only; reset each iteration.
    let textBuf = ''
    let finishReason: string | null = null
    // Only the last chunk's usage is kept (presumably the turn's final
    // totals), so a turn is folded into `aggregated` exactly once.
    let lastUsage: ChatUsage | undefined

    for await (const chunk of stream) {
      const candidate = chunk.candidates?.[0]
      const chunkParts = candidate?.content?.parts ?? []
      for (const part of chunkParts) {
        if (typeof part.text === 'string' && part.text.length > 0) {
          textBuf += part.text
          yield { type: 'text', delta: part.text }
        }
      }
      accumulatedParts.push(...chunkParts)
      if (candidate?.finishReason) finishReason = String(candidate.finishReason)
      if (chunk.usageMetadata) lastUsage = toUsage(chunk.usageMetadata)
    }
    if (lastUsage) {
      aggregated.inputTokens += lastUsage.inputTokens
      aggregated.outputTokens += lastUsage.outputTokens
      aggregated.cacheReadTokens += lastUsage.cacheReadTokens
      // Fold every tracked counter, not just three of the four.
      aggregated.cacheCreationTokens += lastUsage.cacheCreationTokens
    }

    yield { type: 'iteration_end', iteration: iterations, stopReason: finishReason }

    const assistantContent = fromGeminiParts(accumulatedParts)
    workingMessages.push({ role: 'assistant', content: assistantContent })

    const toolUses = (Array.isArray(assistantContent) ? assistantContent : []).filter(
      (b): b is ToolUseBlock => b.type === 'tool_use',
    )

    if (toolUses.length === 0) {
      const text = textBuf
      const value = parseGenerated(text, schema)
      yield {
        type: 'stop',
        stopReason: finishReason ?? 'stop',
        iterations,
        usage: aggregated,
        messages: workingMessages,
        value,
        text,
      } as AgentStreamEvent<T>
      return
    }

    const resultBlocks: ContentBlock[] = []
    for (const call of toolUses) {
      yield { type: 'tool_use', id: call.id, name: call.name, input: call.input }
      const { content, isError } = await runToolWithRecovery(
        toolMap.get(call.name),
        call.name,
        call.id,
        call.input,
        options,
      )
      resultBlocks.push({
        type: 'tool_result',
        toolUseId: call.id,
        content,
        ...(isError ? { isError: true } : {}),
      } satisfies ToolResultBlock)
      yield {
        type: 'tool_result',
        id: call.id,
        name: call.name,
        content,
        isError,
      }
    }
    workingMessages.push({ role: 'user', content: resultBlocks })

    iterations++
    if (iterations >= maxIterations) {
      const text = textBuf
      const value = parseGenerated(text, schema)
      yield {
        type: 'stop',
        stopReason: 'max_iterations',
        iterations,
        usage: aggregated,
        messages: workingMessages,
        value,
        text,
      } as AgentStreamEvent<T>
      return
    }
  }
}
|
|
770
|
+
|
|
277
771
|
async generate<T>(
|
|
278
772
|
messages: readonly Message[],
|
|
279
773
|
schema: OutputSchema<T>,
|
|
@@ -317,18 +811,27 @@ export class GeminiProvider implements Provider {
|
|
|
317
811
|
config.systemInstruction = systemText
|
|
318
812
|
}
|
|
319
813
|
|
|
814
|
+
const configTools: NonNullable<GenerateContentConfig['tools']> = []
|
|
320
815
|
if (tools.length > 0) {
|
|
321
816
|
const functionDeclarations: FunctionDeclaration[] = tools.map((t) => ({
|
|
322
817
|
name: t.name,
|
|
323
818
|
description: t.description,
|
|
324
819
|
parametersJsonSchema: t.inputSchema,
|
|
325
820
|
}))
|
|
326
|
-
|
|
821
|
+
configTools.push({ functionDeclarations })
|
|
822
|
+
}
|
|
823
|
+
if (options.serverTools && options.serverTools.length > 0) {
|
|
824
|
+
configTools.push(...geminiServerTools(options.serverTools))
|
|
825
|
+
}
|
|
826
|
+
if (configTools.length > 0) {
|
|
827
|
+
config.tools = configTools
|
|
327
828
|
}
|
|
328
829
|
|
|
329
830
|
const thinking = buildThinkingConfig(options)
|
|
330
831
|
if (thinking !== undefined) config.thinkingConfig = thinking
|
|
331
832
|
|
|
833
|
+
if (options.signal !== undefined) config.abortSignal = options.signal
|
|
834
|
+
|
|
332
835
|
return { model, contents, config }
|
|
333
836
|
}
|
|
334
837
|
|
|
@@ -356,6 +859,13 @@ export class GeminiProvider implements Provider {
|
|
|
356
859
|
|
|
357
860
|
// ─── Shape converters ─────────────────────────────────────────────────────
|
|
358
861
|
|
|
862
|
+
/**
 * Guard for cooperative cancellation. Does nothing when `signal` is
 * absent or not yet aborted. Otherwise throws the signal's `reason`, or
 * a `DOMException` named `AbortError` when no reason is set.
 */
function checkAborted(signal: AbortSignal | undefined): void {
  if (signal === undefined || !signal.aborted) return
  const reason: unknown = signal.reason
  throw reason ?? new DOMException('Aborted', 'AbortError')
}
|
|
868
|
+
|
|
359
869
|
function systemPromptText(system: SystemPrompt | undefined): string {
|
|
360
870
|
if (system === undefined) return ''
|
|
361
871
|
if (typeof system === 'string') return system
|
|
@@ -369,6 +879,25 @@ function toGeminiParts(content: string | ContentBlock[]): Part[] {
|
|
|
369
879
|
for (const block of content) {
|
|
370
880
|
if (block.type === 'text') {
|
|
371
881
|
parts.push({ text: block.text })
|
|
882
|
+
} else if (block.type === 'image' || block.type === 'document' || block.type === 'audio') {
|
|
883
|
+
// All three media block types share Gemini's inlineData /
|
|
884
|
+
// fileData wire shape; only the MIME differs. Base64 →
|
|
885
|
+
// inlineData. URL → fileData with fileUri. Gemini's
|
|
886
|
+
// fileData accepts public HTTPS and gs:// URIs; arbitrary
|
|
887
|
+
// private URLs need to be fetched and converted to base64
|
|
888
|
+
// by the app.
|
|
889
|
+
if (block.source.type === 'base64') {
|
|
890
|
+
parts.push({
|
|
891
|
+
inlineData: { mimeType: block.source.mediaType, data: block.source.data },
|
|
892
|
+
})
|
|
893
|
+
} else {
|
|
894
|
+
parts.push({
|
|
895
|
+
fileData: {
|
|
896
|
+
fileUri: block.source.url,
|
|
897
|
+
mimeType: guessMimeFromUrl(block.source.url, block.type),
|
|
898
|
+
},
|
|
899
|
+
})
|
|
900
|
+
}
|
|
372
901
|
} else if (block.type === 'tool_use') {
|
|
373
902
|
parts.push({
|
|
374
903
|
functionCall: {
|
|
@@ -394,6 +923,69 @@ function toGeminiParts(content: string | ContentBlock[]): Part[] {
|
|
|
394
923
|
return parts
|
|
395
924
|
}
|
|
396
925
|
|
|
926
|
+
/**
 * Translate framework `ServerTool[]` into Gemini's typed entries
 * (`googleSearch` / `codeExecution` / `urlContext`). Anthropic-
 * specific tools (`web_fetch`) throw with clear guidance.
 *
 * Gemini's server tools have no per-tool config — they're enabled
 * with empty `{}` objects. Domain allowlists / max_uses /
 * blocked_domains on `web_search` are silently dropped (Gemini
 * doesn't accept them). Tool types not listed here are skipped.
 *
 * @throws BrainError for `web_fetch`.
 */
function geminiServerTools(
  serverTools: readonly ServerTool[],
): NonNullable<GenerateContentConfig['tools']> {
  const out: NonNullable<GenerateContentConfig['tools']> = []
  for (const t of serverTools) {
    if (t.type === 'web_search') {
      out.push({ googleSearch: {} })
    } else if (t.type === 'code_execution') {
      out.push({ codeExecution: {} })
    } else if (t.type === 'url_context') {
      out.push({ urlContext: {} })
    } else if (t.type === 'web_fetch') {
      throw new BrainError(
        'GeminiProvider: server tool `web_fetch` is Anthropic-only. Use `url_context` for Gemini or route the call to Anthropic.',
        { context: { provider: 'google' } },
      )
    }
  }
  return out
}

/**
 * Gemini's `fileData.mimeType` is required, but our media-block
 * URL-source variants don't carry it (the app may not know).
 * Best-effort from the file extension (case-insensitive, ignoring any
 * query string or fragment). Default falls back to the block type's
 * most-common MIME: jpeg for images, pdf for documents, mp3 for audio.
 * Documents always map to PDF.
 */
function guessMimeFromUrl(
  url: string,
  kind: 'image' | 'document' | 'audio',
): string {
  // Strip `?query` and `#fragment` so `photo.png?v=2` and
  // `clip.mp3#t=10` still match on their extension.
  const lower = url.toLowerCase().split(/[?#]/, 1)[0] ?? ''
  if (kind === 'image') {
    if (lower.endsWith('.png')) return 'image/png'
    if (lower.endsWith('.webp')) return 'image/webp'
    if (lower.endsWith('.gif')) return 'image/gif'
    if (lower.endsWith('.heic')) return 'image/heic'
    if (lower.endsWith('.heif')) return 'image/heif'
    if (lower.endsWith('.jpg') || lower.endsWith('.jpeg')) return 'image/jpeg'
    return 'image/jpeg'
  }
  if (kind === 'document') {
    return 'application/pdf'
  }
  // audio
  if (lower.endsWith('.mp3')) return 'audio/mp3'
  if (lower.endsWith('.wav')) return 'audio/wav'
  if (lower.endsWith('.ogg')) return 'audio/ogg'
  if (lower.endsWith('.flac')) return 'audio/flac'
  if (lower.endsWith('.webm')) return 'audio/webm'
  if (lower.endsWith('.aiff') || lower.endsWith('.aif')) return 'audio/aiff'
  if (lower.endsWith('.aac') || lower.endsWith('.m4a')) return 'audio/aac'
  return 'audio/mp3'
}
|
|
988
|
+
|
|
397
989
|
function fromGeminiParts(parts: readonly Part[]): string | ContentBlock[] {
|
|
398
990
|
const blocks: ContentBlock[] = []
|
|
399
991
|
for (const part of parts) {
|