@strav/brain 0.2.11 → 0.2.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/CHANGELOG.md CHANGED
@@ -1,5 +1,17 @@
1
1
  # Changelog
2
2
 
3
+ ## 0.2.12
4
+
5
+ ### Added
6
+
7
+ - **GoogleProvider** — Support for Google's Gemini models
8
+ - Native Gemini API integration using `generativelanguage.googleapis.com`
9
+ - Support for completion, streaming, function calling, and embeddings
10
+ - Models: `gemini-2.0-flash`, `gemini-2.5-flash`, `gemini-3-pro-preview`
11
+ - Authentication via `x-goog-api-key` header
12
+ - Zero new dependencies — uses raw `fetch()` following existing patterns
13
+ - Comprehensive test suite with 29 tests covering all functionality
14
+
3
15
  ## 0.6.0
4
16
 
5
17
  ### Added
package/README.md CHANGED
@@ -14,6 +14,7 @@ Requires `@strav/core` as a peer dependency.
14
14
 
15
15
  - **Anthropic** (Claude)
16
16
  - **OpenAI** (GPT, also works with DeepSeek via custom `baseUrl`)
17
+ - **Google** (Gemini)
17
18
 
18
19
  ## Usage
19
20
 
@@ -73,6 +74,14 @@ class ResearchAgent extends Agent {
73
74
  tools = [searchTool, summarizeTool]
74
75
  }
75
76
 
77
+ // Google Gemini agent
78
+ class GeminiResearchAgent extends Agent {
79
+ provider = 'google'
80
+ model = 'gemini-2.0-flash'
81
+ instructions = 'You are a research assistant powered by Gemini.'
82
+ tools = [searchTool, summarizeTool]
83
+ }
84
+
76
85
  // Run agent with context
77
86
  const runner = brain.agent(ResearchAgent)
78
87
  runner.context({ userId: '123' }) // Pass context to tools
@@ -88,6 +97,10 @@ const thread = brain.thread({ provider: 'anthropic', model: 'claude-sonnet-4-202
88
97
  await thread.send('Hello')
89
98
  await thread.send('Tell me more')
90
99
  const saved = thread.serialize() // persist and restore later
100
+
101
+ // Google Gemini example
102
+ const geminiThread = brain.thread({ provider: 'google', model: 'gemini-2.0-flash' })
103
+ await geminiThread.send('Explain quantum computing')
91
104
  ```
92
105
 
93
106
  ## Workflows
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@strav/brain",
3
- "version": "0.2.11",
3
+ "version": "0.2.13",
4
4
  "type": "module",
5
5
  "description": "AI module for the Strav framework",
6
6
  "license": "MIT",
@@ -15,10 +15,10 @@
15
15
  "CHANGELOG.md"
16
16
  ],
17
17
  "peerDependencies": {
18
- "@strav/kernel": "0.2.9"
18
+ "@strav/kernel": "0.2.13"
19
19
  },
20
20
  "dependencies": {
21
- "@strav/workflow": "0.2.9",
21
+ "@strav/workflow": "0.2.13",
22
22
  "zod": "^3.25 || ^4.0"
23
23
  },
24
24
  "scripts": {
package/src/index.ts CHANGED
@@ -7,6 +7,7 @@ export { Agent } from './agent.ts'
7
7
  export { defineTool, defineToolbox } from './tool.ts'
8
8
  export { Workflow } from './workflow.ts'
9
9
  export { AnthropicProvider } from './providers/anthropic_provider.ts'
10
+ export { GoogleProvider } from './providers/google_provider.ts'
10
11
  export { OpenAIProvider } from './providers/openai_provider.ts'
11
12
  export { OpenAIResponsesProvider } from './providers/openai_responses_provider.ts'
12
13
  export { parseSSE } from './utils/sse_parser.ts'
@@ -0,0 +1,398 @@
1
+ import { parseSSE } from '../utils/sse_parser.ts'
2
+ import { retryableFetch, type RetryOptions } from '../utils/retry.ts'
3
+ import { ExternalServiceError } from '@strav/kernel'
4
+ import type {
5
+ AIProvider,
6
+ CompletionRequest,
7
+ CompletionResponse,
8
+ StreamChunk,
9
+ EmbeddingResponse,
10
+ ProviderConfig,
11
+ Message,
12
+ ToolCall,
13
+ Usage,
14
+ } from '../types.ts'
15
+
16
/**
 * Google Gemini API provider.
 *
 * Translates the framework's normalized CompletionRequest/Response to/from the
 * Google Generative Language API wire format. Uses raw `fetch()` through
 * `retryableFetch` and authenticates with the `x-goog-api-key` header.
 *
 * Instances hold only configuration: the tool-call-ID → function-name lookup
 * needed to build `functionResponse` parts is rebuilt from each request's
 * message history, so long-lived or concurrent use accumulates no state.
 */
export class GoogleProvider implements AIProvider {
  readonly name: string
  private readonly apiKey: string
  private readonly baseUrl: string
  private readonly defaultModel: string
  private readonly defaultMaxTokens?: number
  private readonly retryOptions: RetryOptions

  /**
   * @param config - `apiKey` is sent as `x-goog-api-key`. `baseUrl` defaults to
   *   the v1beta Generative Language endpoint (one trailing slash is stripped),
   *   `model` defaults to `gemini-2.0-flash`, and retries default to
   *   `maxRetries: 3`, `baseDelay: 1000`.
   */
  constructor(config: ProviderConfig) {
    this.name = 'google'
    this.apiKey = config.apiKey
    this.baseUrl = (config.baseUrl ?? 'https://generativelanguage.googleapis.com/v1beta').replace(/\/$/, '')
    this.defaultModel = config.model || 'gemini-2.0-flash'
    this.defaultMaxTokens = config.maxTokens
    this.retryOptions = {
      maxRetries: config.maxRetries ?? 3,
      baseDelay: config.retryBaseDelay ?? 1000,
    }
  }

  /**
   * Runs a non-streaming `generateContent` call and normalizes the result.
   *
   * @throws ExternalServiceError when the response contains no candidates
   *   (the message includes Gemini's `blockReason` when one is reported).
   */
  async complete(request: CompletionRequest): Promise<CompletionResponse> {
    const model = request.model ?? this.defaultModel
    const body = this.buildRequestBody(request)

    const response = await this.post(this.modelUrl(model, 'generateContent'), body)

    const data: any = await response.json()
    return this.parseResponse(data)
  }

  /**
   * Streams a `streamGenerateContent` call as normalized chunks.
   *
   * Yields `text` chunks as they arrive, a `tool_start`/`tool_end` pair per
   * function call, a `usage` chunk if any event carried `usageMetadata`, and
   * always finishes with exactly one `done` chunk.
   *
   * @throws ExternalServiceError when the response has no body.
   */
  async *stream(request: CompletionRequest): AsyncIterable<StreamChunk> {
    const model = request.model ?? this.defaultModel
    const body = this.buildRequestBody(request)

    // `alt=sse` makes Gemini emit server-sent events; without it the endpoint
    // streams one JSON array, which parseSSE cannot split into events.
    const response = await this.post(`${this.modelUrl(model, 'streamGenerateContent')}?alt=sse`, body)

    if (!response.body) {
      throw new ExternalServiceError('Google', undefined, 'No stream body returned')
    }

    let toolIndex = -1
    let usageMetadata: any

    for await (const sse of parseSSE(response.body)) {
      // Defensive terminator; Gemini itself signals completion via finishReason.
      if (sse.data === '[DONE]') break

      let parsed: any
      try {
        parsed = JSON.parse(sse.data)
      } catch {
        continue // ignore events whose payload is not JSON
      }

      // Remember the latest usage report even if it arrives without candidates.
      if (parsed.usageMetadata) usageMetadata = parsed.usageMetadata

      const candidate = parsed.candidates?.[0]
      if (!candidate) continue

      const parts: any[] = Array.isArray(candidate.content?.parts) ? candidate.content.parts : []
      for (const part of parts) {
        if (part.text) {
          yield { type: 'text', text: part.text }
        } else if (part.functionCall) {
          // Each functionCall part is a complete call (name + args), so it is
          // opened and closed immediately. The full ToolCall, including its
          // arguments, travels on tool_start.
          // NOTE(review): if consumers only rebuild arguments from delta
          // chunks, they must read toolCall.arguments here instead — confirm.
          toolIndex++
          yield { type: 'tool_start', toolCall: this.toToolCall(part.functionCall), toolIndex }
          yield { type: 'tool_end', toolIndex }
        }
      }

      if (candidate.finishReason) break
    }

    if (usageMetadata) {
      yield { type: 'usage', usage: this.toUsage(usageMetadata) }
    }
    // Emitted unconditionally so consumers terminate even if the stream closed
    // without a finishReason.
    yield { type: 'done' }
  }

  /**
   * Embeds one or more strings, issuing one `embedContent` request per input in order.
   *
   * @param input - A single string or a list of strings.
   * @param model - Embedding model; defaults to `text-embedding-004`.
   * @returns Embeddings aligned index-for-index with the inputs.
   * @throws ExternalServiceError when a response carries no embedding values,
   *   rather than returning a shorter, misaligned list.
   */
  async embed(input: string | string[], model?: string): Promise<EmbeddingResponse> {
    const embeddingModel = model ?? 'text-embedding-004'
    const inputs = Array.isArray(input) ? input : [input]
    const embeddings: number[][] = []

    for (const [index, text] of inputs.entries()) {
      const response = await this.post(this.modelUrl(embeddingModel, 'embedContent'), {
        model: `models/${embeddingModel}`,
        content: { parts: [{ text }] },
      })

      const data: any = await response.json()
      const values = data.embedding?.values
      if (!Array.isArray(values)) {
        throw new ExternalServiceError('Google', undefined, `No embedding returned for input at index ${index}`)
      }
      embeddings.push(values)
    }

    return {
      embeddings,
      model: embeddingModel,
      // Rough estimate: the embedContent response handled here carries no token count.
      usage: { totalTokens: inputs.length * 10 },
    }
  }

  // ── Private helpers ──────────────────────────────────────────────────────

  /** Builds `{baseUrl}/models/{model}:{method}` for a Gemini REST method. */
  private modelUrl(model: string, method: string): string {
    return `${this.baseUrl}/models/${model}:${method}`
  }

  /** POSTs a JSON body with the provider's headers and retry policy. */
  private post(url: string, body: unknown) {
    return retryableFetch(
      'Google',
      url,
      { method: 'POST', headers: this.buildHeaders(), body: JSON.stringify(body) },
      this.retryOptions
    )
  }

  private buildHeaders(): Record<string, string> {
    return {
      'content-type': 'application/json',
      'x-goog-api-key': this.apiKey,
    }
  }

  /**
   * Maps a normalized request onto a Gemini `generateContent` body:
   * `contents`, optional `systemInstruction`, `generationConfig` (only when
   * non-empty), and `tools`/`toolConfig` when tools are supplied.
   */
  private buildRequestBody(request: CompletionRequest): Record<string, unknown> {
    const body: Record<string, unknown> = {
      contents: this.mapMessages(request.messages),
    }

    if (request.system) {
      body.systemInstruction = { parts: [{ text: request.system }] }
    }

    const generationConfig: Record<string, unknown> = {}

    const maxTokens = request.maxTokens ?? this.defaultMaxTokens
    if (maxTokens !== undefined) {
      generationConfig.maxOutputTokens = maxTokens
    }

    if (request.temperature !== undefined) {
      generationConfig.temperature = request.temperature
    }

    if (request.stopSequences?.length) {
      generationConfig.stopSequences = request.stopSequences
    }

    // Structured output: ask for JSON constrained by the supplied schema.
    if (request.schema) {
      generationConfig.responseMimeType = 'application/json'
      generationConfig.responseSchema = request.schema
    }

    if (Object.keys(generationConfig).length > 0) {
      body.generationConfig = generationConfig
    }

    if (request.tools?.length) {
      body.tools = [
        {
          functionDeclarations: request.tools.map(t => ({
            name: t.name,
            description: t.description,
            parameters: t.parameters,
          })),
        },
      ]

      // 'auto' → AUTO, 'required' → ANY, { name } → ANY restricted to that function.
      const choice = request.toolChoice
      let functionCallingConfig: Record<string, unknown> | undefined
      if (choice === 'auto') {
        functionCallingConfig = { mode: 'AUTO' }
      } else if (choice === 'required') {
        functionCallingConfig = { mode: 'ANY' }
      } else if (choice && typeof choice === 'object' && choice.name) {
        functionCallingConfig = { mode: 'ANY', allowedFunctionNames: [choice.name] }
      }

      if (functionCallingConfig) {
        body.toolConfig = { functionCallingConfig }
      }
    }

    return body
  }

  /**
   * Converts normalized messages into Gemini `contents`.
   *
   * Assistant messages become `model` turns with text and `functionCall`
   * parts. Tool results become `functionResponse` parts in a `user` turn;
   * consecutive results are merged into one turn, because Gemini expects every
   * response to a parallel function-call turn together. Function names for
   * results are resolved from assistant tool calls earlier in the same list.
   *
   * @throws ExternalServiceError when a tool result's ID matches no earlier tool call.
   */
  private mapMessages(messages: Message[]): any[] {
    const contents: any[] = []
    // Rebuilt per request; Gemini keys functionResponse by name, not by our IDs.
    const toolNames = new Map<string, string>()
    // Parts array of the `user` turn currently collecting tool results, if any.
    let toolResponseParts: any[] | null = null

    for (const msg of messages) {
      if (msg.role === 'tool') {
        const functionName = msg.toolCallId ? toolNames.get(msg.toolCallId) : undefined
        if (!functionName) {
          throw new ExternalServiceError('Google', undefined, `No function name found for tool call ID: ${msg.toolCallId}`)
        }

        const part = {
          functionResponse: {
            name: functionName,
            response: {
              content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content),
            },
          },
        }

        if (toolResponseParts) {
          toolResponseParts.push(part)
        } else {
          toolResponseParts = [part]
          contents.push({ role: 'user', parts: toolResponseParts })
        }
        continue
      }

      // Any non-tool message closes the current tool-result turn.
      toolResponseParts = null

      if (msg.role === 'assistant') {
        const parts: any[] = []

        if (typeof msg.content === 'string' && msg.content) {
          parts.push({ text: msg.content })
        }

        for (const tc of msg.toolCalls ?? []) {
          toolNames.set(tc.id, tc.name)
          parts.push({ functionCall: { name: tc.name, args: tc.arguments } })
        }

        contents.push({ role: 'model', parts }) // Gemini uses 'model' instead of 'assistant'
      } else {
        contents.push({
          role: 'user',
          parts: [{ text: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content) }],
        })
      }
    }

    return contents
  }

  /**
   * Normalizes a `generateContent` response from the first candidate.
   *
   * @throws ExternalServiceError when there are no candidates.
   */
  private parseResponse(data: any): CompletionResponse {
    const candidate = data.candidates?.[0]
    if (!candidate) {
      const blockReason = data.promptFeedback?.blockReason
      throw new ExternalServiceError(
        'Google',
        undefined,
        blockReason ? `No candidates in response (prompt blocked: ${blockReason})` : 'No candidates in response'
      )
    }

    let content = ''
    const toolCalls: ToolCall[] = []

    const parts: any[] = Array.isArray(candidate.content?.parts) ? candidate.content.parts : []
    for (const part of parts) {
      if (part.text) {
        content += part.text
      } else if (part.functionCall) {
        toolCalls.push(this.toToolCall(part.functionCall))
      }
    }

    return {
      id: data.responseId || candidate.id || this.generateResponseId(),
      content,
      toolCalls,
      // Tool calls take precedence: Gemini may report STOP even when calling tools.
      stopReason: toolCalls.length > 0 ? 'tool_use' : this.mapFinishReason(candidate.finishReason),
      usage: this.toUsage(data.usageMetadata),
      raw: data,
    }
  }

  /**
   * Maps Gemini's `finishReason` to the normalized stop reason; unknown or
   * absent values (including STOP) map to 'end'.
   */
  private mapFinishReason(finishReason: unknown): CompletionResponse['stopReason'] {
    switch (finishReason) {
      case 'MAX_TOKENS':
        return 'max_tokens'
      case 'SAFETY':
      case 'RECITATION':
        // NOTE(review): content-filter stops are reported as 'stop_sequence',
        // presumably for lack of a closer normalized value — confirm.
        return 'stop_sequence'
      default:
        return 'end'
    }
  }

  /** Normalizes a Gemini `functionCall` part, synthesizing an ID when none is given. */
  private toToolCall(functionCall: any): ToolCall {
    return {
      id: functionCall.id || this.generateToolCallId(),
      name: functionCall.name,
      arguments: functionCall.args || {},
    }
  }

  /** Maps Gemini `usageMetadata` to Usage, treating missing counts as 0. */
  private toUsage(metadata: any): Usage {
    return {
      inputTokens: metadata?.promptTokenCount ?? 0,
      outputTokens: metadata?.candidatesTokenCount ?? 0,
      totalTokens: metadata?.totalTokenCount ?? 0,
    }
  }

  private generateToolCallId(): string {
    return `tool_${Math.random().toString(36).substring(2, 15)}`
  }

  private generateResponseId(): string {
    return `resp_${Math.random().toString(36).substring(2, 15)}`
  }
}