@strav/brain 0.2.11 → 0.2.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +12 -0
- package/README.md +13 -0
- package/package.json +3 -3
- package/src/index.ts +1 -0
- package/src/providers/google_provider.ts +398 -0
package/CHANGELOG.md
CHANGED
|
@@ -1,5 +1,17 @@
|
|
|
1
1
|
# Changelog
|
|
2
2
|
|
|
3
|
+
## 0.2.12
|
|
4
|
+
|
|
5
|
+
### Added
|
|
6
|
+
|
|
7
|
+
- **GoogleProvider** — Support for Google's Gemini models
|
|
8
|
+
- Native Gemini API integration using `generativelanguage.googleapis.com`
|
|
9
|
+
- Support for completion, streaming, function calling, and embeddings
|
|
10
|
+
- Models: `gemini-2.0-flash`, `gemini-2.5-flash`, `gemini-3-pro-preview`
|
|
11
|
+
- Authentication via `x-goog-api-key` header
|
|
12
|
+
- Zero new dependencies — uses raw `fetch()` following existing patterns
|
|
13
|
+
- Comprehensive test suite with 29 tests covering all functionality
|
|
14
|
+
|
|
3
15
|
## 0.6.0
|
|
4
16
|
|
|
5
17
|
### Added
|
package/README.md
CHANGED
|
@@ -14,6 +14,7 @@ Requires `@strav/core` as a peer dependency.
|
|
|
14
14
|
|
|
15
15
|
- **Anthropic** (Claude)
|
|
16
16
|
- **OpenAI** (GPT, also works with DeepSeek via custom `baseUrl`)
|
|
17
|
+
- **Google** (Gemini)
|
|
17
18
|
|
|
18
19
|
## Usage
|
|
19
20
|
|
|
@@ -73,6 +74,14 @@ class ResearchAgent extends Agent {
|
|
|
73
74
|
tools = [searchTool, summarizeTool]
|
|
74
75
|
}
|
|
75
76
|
|
|
77
|
+
// Google Gemini agent
|
|
78
|
+
class GeminiResearchAgent extends Agent {
|
|
79
|
+
provider = 'google'
|
|
80
|
+
model = 'gemini-2.0-flash'
|
|
81
|
+
instructions = 'You are a research assistant powered by Gemini.'
|
|
82
|
+
tools = [searchTool, summarizeTool]
|
|
83
|
+
}
|
|
84
|
+
|
|
76
85
|
// Run agent with context
|
|
77
86
|
const runner = brain.agent(ResearchAgent)
|
|
78
87
|
runner.context({ userId: '123' }) // Pass context to tools
|
|
@@ -88,6 +97,10 @@ const thread = brain.thread({ provider: 'anthropic', model: 'claude-sonnet-4-202
|
|
|
88
97
|
await thread.send('Hello')
|
|
89
98
|
await thread.send('Tell me more')
|
|
90
99
|
const saved = thread.serialize() // persist and restore later
|
|
100
|
+
|
|
101
|
+
// Google Gemini example
|
|
102
|
+
const geminiThread = brain.thread({ provider: 'google', model: 'gemini-2.0-flash' })
|
|
103
|
+
await geminiThread.send('Explain quantum computing')
|
|
91
104
|
```
|
|
92
105
|
|
|
93
106
|
## Workflows
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@strav/brain",
|
|
3
|
-
"version": "0.2.11",
|
|
3
|
+
"version": "0.2.12",
|
|
4
4
|
"type": "module",
|
|
5
5
|
"description": "AI module for the Strav framework",
|
|
6
6
|
"license": "MIT",
|
|
@@ -15,10 +15,10 @@
|
|
|
15
15
|
"CHANGELOG.md"
|
|
16
16
|
],
|
|
17
17
|
"peerDependencies": {
|
|
18
|
-
"@strav/kernel": "0.2.
|
|
18
|
+
"@strav/kernel": "0.2.12"
|
|
19
19
|
},
|
|
20
20
|
"dependencies": {
|
|
21
|
-
"@strav/workflow": "0.2.
|
|
21
|
+
"@strav/workflow": "0.2.12",
|
|
22
22
|
"zod": "^3.25 || ^4.0"
|
|
23
23
|
},
|
|
24
24
|
"scripts": {
|
package/src/index.ts
CHANGED
|
@@ -7,6 +7,7 @@ export { Agent } from './agent.ts'
|
|
|
7
7
|
export { defineTool, defineToolbox } from './tool.ts'
|
|
8
8
|
export { Workflow } from './workflow.ts'
|
|
9
9
|
export { AnthropicProvider } from './providers/anthropic_provider.ts'
|
|
10
|
+
export { GoogleProvider } from './providers/google_provider.ts'
|
|
10
11
|
export { OpenAIProvider } from './providers/openai_provider.ts'
|
|
11
12
|
export { OpenAIResponsesProvider } from './providers/openai_responses_provider.ts'
|
|
12
13
|
export { parseSSE } from './utils/sse_parser.ts'
|
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
import { parseSSE } from '../utils/sse_parser.ts'
|
|
2
|
+
import { retryableFetch, type RetryOptions } from '../utils/retry.ts'
|
|
3
|
+
import { ExternalServiceError } from '@strav/kernel'
|
|
4
|
+
import type {
|
|
5
|
+
AIProvider,
|
|
6
|
+
CompletionRequest,
|
|
7
|
+
CompletionResponse,
|
|
8
|
+
StreamChunk,
|
|
9
|
+
EmbeddingResponse,
|
|
10
|
+
ProviderConfig,
|
|
11
|
+
Message,
|
|
12
|
+
ToolCall,
|
|
13
|
+
Usage,
|
|
14
|
+
} from '../types.ts'
|
|
15
|
+
|
|
16
|
+
/**
 * Google Gemini API provider.
 *
 * Translates the framework's normalized CompletionRequest/Response
 * to/from the Google Generative Language API wire format. Uses raw `fetch()`
 * (through `retryableFetch`) and authenticates with the `x-goog-api-key` header.
 *
 * The provider keeps no per-conversation state: tool-call ids are resolved to
 * function names from the message history passed with each request.
 *
 * NOTE(review): Gemini 3 models are documented to require the `thoughtSignature`
 * returned on `functionCall` parts to be echoed back in later turns; ToolCall has
 * no field for it here — confirm multi-turn tool use on `gemini-3-*` models.
 */
export class GoogleProvider implements AIProvider {
  readonly name: string
  private apiKey: string
  private baseUrl: string
  private defaultModel: string
  private defaultMaxTokens?: number
  private retryOptions: RetryOptions

  /**
   * @param config Provider configuration. `baseUrl` defaults to the public
   *   v1beta endpoint (one trailing slash is stripped), `model` to
   *   `gemini-2.0-flash`, `maxRetries` to 3 and `retryBaseDelay` to 1000.
   */
  constructor(config: ProviderConfig) {
    this.name = 'google'
    this.apiKey = config.apiKey
    this.baseUrl = (config.baseUrl ?? 'https://generativelanguage.googleapis.com/v1beta').replace(/\/$/, '')
    this.defaultModel = config.model || 'gemini-2.0-flash'
    this.defaultMaxTokens = config.maxTokens
    this.retryOptions = {
      maxRetries: config.maxRetries ?? 3,
      baseDelay: config.retryBaseDelay ?? 1000,
    }
  }

  /**
   * Runs a non-streaming completion via `models/{model}:generateContent`.
   *
   * @throws ExternalServiceError when the response contains no candidates.
   */
  async complete(request: CompletionRequest): Promise<CompletionResponse> {
    const model = request.model ?? this.defaultModel
    const body = this.buildRequestBody(request, false)

    const response = await retryableFetch(
      'Google',
      `${this.baseUrl}/models/${model}:generateContent`,
      { method: 'POST', headers: this.buildHeaders(), body: JSON.stringify(body) },
      this.retryOptions
    )

    const data: any = await response.json()
    return this.parseResponse(data)
  }

  /**
   * Streams a completion via `models/{model}:streamGenerateContent`.
   *
   * Yields `text` chunks, a `tool_start`/`tool_end` pair per function call,
   * a `usage` chunk when the final event carries usage metadata, and ends
   * with exactly one `done` chunk (unless the underlying stream errors).
   *
   * @throws ExternalServiceError when the response has no body.
   */
  async *stream(request: CompletionRequest): AsyncIterable<StreamChunk> {
    const model = request.model ?? this.defaultModel
    const body = this.buildRequestBody(request, true)

    // `alt=sse` is required: without it streamGenerateContent returns one JSON
    // array instead of Server-Sent Events, which parseSSE cannot consume.
    const response = await retryableFetch(
      'Google',
      `${this.baseUrl}/models/${model}:streamGenerateContent?alt=sse`,
      { method: 'POST', headers: this.buildHeaders(), body: JSON.stringify(body) },
      this.retryOptions
    )

    if (!response.body) {
      throw new ExternalServiceError('Google', undefined, 'No stream body returned')
    }

    let toolIndex = -1
    let finished = false

    for await (const sse of parseSSE(response.body)) {
      if (sse.data === '[DONE]') {
        finished = true
        yield { type: 'done' }
        break
      }

      let parsed: any
      try {
        parsed = JSON.parse(sse.data)
      } catch {
        // Ignore malformed events rather than aborting the whole stream.
        continue
      }

      const candidate = parsed.candidates?.[0]
      if (!candidate) continue

      const parts: any[] = Array.isArray(candidate.content?.parts) ? candidate.content.parts : []
      for (const part of parts) {
        if (part.text) {
          yield { type: 'text', text: part.text }
        } else if (part.functionCall) {
          // Each functionCall part is handled as complete (args are not streamed
          // incrementally), so it opens and closes immediately. This also covers
          // parameterless calls where `args` is omitted.
          toolIndex++
          const toolCall: ToolCall = {
            id: part.functionCall.id || this.generateToolCallId(),
            name: part.functionCall.name,
            arguments: part.functionCall.args || {},
          }
          yield { type: 'tool_start', toolCall, toolIndex }
          yield { type: 'tool_end', toolIndex }
        }
      }

      // A finishReason marks the final event; usage metadata rides along with it.
      if (candidate.finishReason) {
        if (parsed.usageMetadata) {
          yield { type: 'usage', usage: this.parseUsage(parsed.usageMetadata) }
        }
        finished = true
        yield { type: 'done' }
        break
      }
    }

    // The stream closed without a finishReason; still signal completion once.
    if (!finished) {
      yield { type: 'done' }
    }
  }

  /**
   * Embeds one or more strings via `models/{model}:embedContent`, one request
   * per input, issued sequentially so the output order matches the input order.
   *
   * @param input A single string or a list of strings.
   * @param model Embedding model name; defaults to `text-embedding-004`.
   * @throws ExternalServiceError when a response lacks embedding values, so the
   *   returned embeddings can never be shifted relative to the inputs.
   */
  async embed(input: string | string[], model?: string): Promise<EmbeddingResponse> {
    const embeddingModel = model ?? 'text-embedding-004'
    const inputs = Array.isArray(input) ? input : [input]
    const embeddings: number[][] = []

    for (const text of inputs) {
      const request = {
        model: `models/${embeddingModel}`,
        content: { parts: [{ text }] },
      }

      const response = await retryableFetch(
        'Google',
        `${this.baseUrl}/models/${embeddingModel}:embedContent`,
        { method: 'POST', headers: this.buildHeaders(), body: JSON.stringify(request) },
        this.retryOptions
      )

      const data: any = await response.json()
      const values = data.embedding?.values
      if (!Array.isArray(values)) {
        throw new ExternalServiceError('Google', undefined, 'No embedding values in response')
      }
      embeddings.push(values)
    }

    return {
      embeddings,
      model: embeddingModel,
      // Rough estimate: no token count is read from the embedContent response.
      usage: { totalTokens: inputs.length * 10 },
    }
  }

  // ── Private helpers ──────────────────────────────────────────────────────

  /** JSON content type plus API-key authentication header. */
  private buildHeaders(): Record<string, string> {
    return {
      'content-type': 'application/json',
      'x-goog-api-key': this.apiKey,
    }
  }

  /**
   * Builds the generateContent / streamGenerateContent request body.
   *
   * @param _stream Unused: streaming is selected by the endpoint, not the body.
   */
  private buildRequestBody(request: CompletionRequest, _stream: boolean): Record<string, unknown> {
    const body: Record<string, unknown> = {
      contents: this.mapMessages(request.messages),
    }

    if (request.system) {
      body.systemInstruction = { parts: [{ text: request.system }] }
    }

    const generationConfig: Record<string, unknown> = {}

    // A per-request limit wins over the provider-wide default.
    if (request.maxTokens !== undefined) {
      generationConfig.maxOutputTokens = request.maxTokens
    } else if (this.defaultMaxTokens !== undefined) {
      generationConfig.maxOutputTokens = this.defaultMaxTokens
    }

    if (request.temperature !== undefined) {
      generationConfig.temperature = request.temperature
    }

    if (request.stopSequences?.length) {
      generationConfig.stopSequences = request.stopSequences
    }

    // Structured output: ask for JSON constrained by the caller's schema.
    if (request.schema) {
      generationConfig.responseMimeType = 'application/json'
      generationConfig.responseSchema = request.schema
    }

    if (Object.keys(generationConfig).length > 0) {
      body.generationConfig = generationConfig
    }

    if (request.tools?.length) {
      body.tools = [{
        functionDeclarations: request.tools.map(t => ({
          name: t.name,
          description: t.description,
          parameters: t.parameters,
        })),
      }]

      // 'required' and a named tool both map to ANY; the latter is restricted
      // to that single function.
      if (request.toolChoice) {
        const toolConfig: Record<string, unknown> = {}

        if (request.toolChoice === 'auto') {
          toolConfig.functionCallingConfig = { mode: 'AUTO' }
        } else if (request.toolChoice === 'required') {
          toolConfig.functionCallingConfig = { mode: 'ANY' }
        } else if (typeof request.toolChoice === 'object' && request.toolChoice.name) {
          toolConfig.functionCallingConfig = {
            mode: 'ANY',
            allowedFunctionNames: [request.toolChoice.name],
          }
        }

        if (Object.keys(toolConfig).length > 0) {
          body.toolConfig = toolConfig
        }
      }
    }

    return body
  }

  /**
   * Converts framework messages into Gemini `contents`.
   *
   * - assistant → role `model`, with a text part (if non-empty string content)
   *   and one `functionCall` part per tool call
   * - tool → role `user` with a `functionResponse` part; consecutive tool results
   *   are merged into one turn so parallel calls are answered together
   * - anything else → role `user` with a text part (non-string content is JSON-encoded)
   *
   * @throws ExternalServiceError when a tool result references a tool-call id
   *   that no earlier assistant message in `messages` declared.
   */
  private mapMessages(messages: Message[]): any[] {
    const result: any[] = []
    // Gemini identifies function responses by name; tool messages only carry the call id.
    const toolNames = new Map<string, string>()
    // The user turn collecting function responses while consecutive tool messages arrive.
    let toolTurn: { role: string; parts: any[] } | null = null

    for (const msg of messages) {
      if (msg.role === 'tool') {
        const functionName = msg.toolCallId ? toolNames.get(msg.toolCallId) : undefined
        if (!functionName) {
          throw new ExternalServiceError('Google', undefined, `No function name found for tool call ID: ${msg.toolCallId}`)
        }

        const part = {
          functionResponse: {
            name: functionName,
            response: {
              content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content),
            },
          },
        }

        if (toolTurn) {
          toolTurn.parts.push(part)
        } else {
          toolTurn = { role: 'user', parts: [part] }
          result.push(toolTurn)
        }
        continue
      }

      toolTurn = null

      if (msg.role === 'assistant') {
        const parts: any[] = []

        const text = typeof msg.content === 'string' ? msg.content : ''
        if (text) {
          parts.push({ text })
        }

        for (const tc of msg.toolCalls ?? []) {
          toolNames.set(tc.id, tc.name)
          parts.push({ functionCall: { name: tc.name, args: tc.arguments } })
        }

        result.push({
          role: 'model', // Gemini uses 'model' instead of 'assistant'
          parts,
        })
      } else {
        result.push({
          role: 'user',
          parts: [{ text: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content) }],
        })
      }
    }

    return result
  }

  /**
   * Normalizes a generateContent response (first candidate only).
   *
   * @throws ExternalServiceError when the response contains no candidates.
   */
  private parseResponse(data: any): CompletionResponse {
    const candidate = data.candidates?.[0]
    if (!candidate) {
      throw new ExternalServiceError('Google', undefined, 'No candidates in response')
    }

    let content = ''
    const toolCalls: ToolCall[] = []

    const parts: any[] = Array.isArray(candidate.content?.parts) ? candidate.content.parts : []
    for (const part of parts) {
      if (part.text) {
        content += part.text
      } else if (part.functionCall) {
        toolCalls.push({
          id: part.functionCall.id || this.generateToolCallId(),
          name: part.functionCall.name,
          arguments: part.functionCall.args || {},
        })
      }
    }

    let stopReason: CompletionResponse['stopReason'] = 'end'

    // Check tool calls first, as Google may return STOP even with tool calls
    if (toolCalls.length > 0) {
      stopReason = 'tool_use'
    } else {
      switch (candidate.finishReason) {
        case 'STOP':
          stopReason = 'end'
          break
        case 'MAX_TOKENS':
          stopReason = 'max_tokens'
          break
        // Content-filter stops are reported as 'stop_sequence' (no dedicated
        // reason is used here).
        case 'SAFETY':
        case 'RECITATION':
          stopReason = 'stop_sequence'
          break
      }
    }

    return {
      // Gemini reports the id as top-level `responseId`; candidate.id is a fallback.
      id: data.responseId || candidate.id || this.generateResponseId(),
      content,
      toolCalls,
      stopReason,
      usage: this.parseUsage(data.usageMetadata),
      raw: data,
    }
  }

  /** Maps Gemini `usageMetadata` (possibly absent) to Usage; missing counts become 0. */
  private parseUsage(meta: any): Usage {
    return {
      inputTokens: meta?.promptTokenCount ?? 0,
      outputTokens: meta?.candidatesTokenCount ?? 0,
      totalTokens: meta?.totalTokenCount ?? 0,
    }
  }

  /** Local id for function calls that arrive without one. */
  private generateToolCallId(): string {
    return `tool_${Math.random().toString(36).substring(2, 15)}`
  }

  /** Local id for responses that arrive without one. */
  private generateResponseId(): string {
    return `resp_${Math.random().toString(36).substring(2, 15)}`
  }
}
|