@strav/brain 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,509 @@
1
+ import { parseSSE } from '../utils/sse_parser.ts'
2
+ import { retryableFetch, type RetryOptions } from '../utils/retry.ts'
3
+ import { ExternalServiceError } from '@stravigor/kernel'
4
+ import type {
5
+ AIProvider,
6
+ CompletionRequest,
7
+ CompletionResponse,
8
+ StreamChunk,
9
+ EmbeddingResponse,
10
+ ProviderConfig,
11
+ Message,
12
+ ToolCall,
13
+ Usage,
14
+ } from '../types.ts'
15
+
16
/**
 * OpenAI Chat Completions API provider.
 *
 * Also serves DeepSeek and any OpenAI-compatible API by setting `baseUrl`
 * in the provider config. All HTTP traffic goes through `retryableFetch`
 * (declared in `../utils/retry.ts`; presumably it retries and throws on
 * non-OK responses — confirm there, this class does not check `response.ok`).
 */
export class OpenAIProvider implements AIProvider {
  /** Provider name exposed to callers; defaults to `'openai'`. */
  readonly name: string
  private apiKey: string
  /** Base URL without trailing slashes; `/v1/...` paths are appended to it. */
  private baseUrl: string
  private defaultModel: string
  private defaultMaxTokens?: number
  private retryOptions: RetryOptions

  /**
   * @param config Provider configuration. `baseUrl` defaults to the public
   *   OpenAI endpoint, `maxRetries` to 3 and `retryBaseDelay` to 1000.
   * @param name Optional provider name override (e.g. for DeepSeek).
   */
  constructor(config: ProviderConfig, name?: string) {
    this.name = name ?? 'openai'
    this.apiKey = config.apiKey
    // Strip *all* trailing slashes so `https://host//` neither produces
    // `//v1/...` URLs nor defeats the exact comparison in `supportsJsonSchema`.
    this.baseUrl = (config.baseUrl ?? 'https://api.openai.com').replace(/\/+$/, '')
    this.defaultModel = config.model
    this.defaultMaxTokens = config.maxTokens
    this.retryOptions = {
      maxRetries: config.maxRetries ?? 3,
      baseDelay: config.retryBaseDelay ?? 1000,
    }
  }

  /** Whether this provider supports OpenAI's native json_schema response format. */
  private get supportsJsonSchema(): boolean {
    return this.baseUrl === 'https://api.openai.com'
  }

  /**
   * Run a non-streaming chat completion.
   *
   * @param request Completion request (messages, tools, schema, ...).
   * @returns The normalized completion, with the raw payload under `raw`.
   */
  async complete(request: CompletionRequest): Promise<CompletionResponse> {
    const body = this.buildRequestBody(request, false)

    const response = await retryableFetch(
      'OpenAI',
      `${this.baseUrl}/v1/chat/completions`,
      { method: 'POST', headers: this.buildHeaders(), body: JSON.stringify(body) },
      this.retryOptions
    )

    const data: any = await response.json()
    return this.parseResponse(data)
  }

  /**
   * Run a streaming chat completion and translate SSE events into chunks.
   *
   * Emits `text`, `tool_start`, `tool_delta`, `tool_end` and `usage` chunks
   * as they arrive, and always finishes with a single `done` chunk once the
   * stream ends (whether or not the server sent a `[DONE]` sentinel).
   *
   * @param request Completion request (messages, tools, schema, ...).
   * @throws ExternalServiceError when the HTTP response carries no body.
   */
  async *stream(request: CompletionRequest): AsyncIterable<StreamChunk> {
    const body = this.buildRequestBody(request, true)

    const response = await retryableFetch(
      'OpenAI',
      `${this.baseUrl}/v1/chat/completions`,
      { method: 'POST', headers: this.buildHeaders(), body: JSON.stringify(body) },
      this.retryOptions
    )

    if (!response.body) {
      throw new ExternalServiceError('OpenAI', undefined, 'No stream body returned')
    }

    // Tool indexes already announced with tool_start (start vs delta distinction)
    const seenTools = new Set<number>()
    // Tool indexes already closed with tool_end, so each tool is closed once
    const closedTools = new Set<number>()

    for await (const sse of parseSSE(response.body)) {
      if (sse.data === '[DONE]') break

      let parsed: any
      try {
        parsed = JSON.parse(sse.data)
      } catch {
        // Not JSON (keep-alive or garbage) — skip this event
        continue
      }
      // `JSON.parse('null')` and friends succeed but carry nothing usable
      if (parsed == null || typeof parsed !== 'object') continue

      const choice = parsed.choices?.[0]
      const delta = choice?.delta

      // Text content
      if (delta?.content) {
        yield { type: 'text', text: delta.content }
      }

      // Tool calls
      if (Array.isArray(delta?.tool_calls)) {
        for (const tc of delta.tool_calls) {
          const index: number = tc.index ?? 0

          if (!seenTools.has(index)) {
            // First chunk for this tool — emit tool_start
            seenTools.add(index)
            yield {
              type: 'tool_start',
              toolCall: { id: tc.id, name: tc.function?.name },
              toolIndex: index,
            }
          }

          // Argument fragments
          if (tc.function?.arguments) {
            yield {
              type: 'tool_delta',
              text: tc.function.arguments,
              toolIndex: index,
            }
          }
        }
      }

      // Finish reason. Tool calls can also end with 'stop' (reported by OpenAI
      // when a specific tool is forced through tool_choice), so close every
      // announced tool on either reason. Truncated finishes ('length', ...)
      // deliberately emit no tool_end because the arguments are incomplete.
      const finishReason = choice?.finish_reason
      if (finishReason === 'tool_calls' || finishReason === 'stop') {
        for (const idx of seenTools) {
          if (closedTools.has(idx)) continue
          closedTools.add(idx)
          yield { type: 'tool_end', toolIndex: idx }
        }
      }

      // Usage. With `stream_options.include_usage` OpenAI sends usage in a
      // trailing chunk whose `choices` array is empty; other backends attach
      // it to the finishing chunk. Both shapes are accepted. Note that this
      // class does not set `stream_options` on the request itself.
      if (parsed.usage && (!choice || finishReason)) {
        yield {
          type: 'usage',
          usage: {
            inputTokens: parsed.usage.prompt_tokens ?? 0,
            outputTokens: parsed.usage.completion_tokens ?? 0,
            totalTokens: parsed.usage.total_tokens ?? 0,
          },
        }
      }
    }

    yield { type: 'done' }
  }

  /**
   * Create embeddings for one or more inputs.
   *
   * @param input A single string or a list of strings (always sent as a list).
   * @param model Embedding model; defaults to `text-embedding-3-small`.
   * @returns One embedding per input plus the model name and token usage.
   * @throws ExternalServiceError when the response has no `data` array.
   */
  async embed(input: string | string[], model?: string): Promise<EmbeddingResponse> {
    const body = {
      input: Array.isArray(input) ? input : [input],
      model: model ?? 'text-embedding-3-small',
    }

    const response = await retryableFetch(
      'OpenAI',
      `${this.baseUrl}/v1/embeddings`,
      { method: 'POST', headers: this.buildHeaders(), body: JSON.stringify(body) },
      this.retryOptions
    )

    const data: any = await response.json()

    // Fail with a service error instead of an opaque TypeError on `.map`
    if (!Array.isArray(data?.data)) {
      throw new ExternalServiceError(
        'OpenAI',
        undefined,
        'Malformed embeddings response: missing "data" array'
      )
    }

    return {
      embeddings: data.data.map((d: any) => d.embedding),
      model: data.model,
      usage: { totalTokens: data.usage?.total_tokens ?? 0 },
    }
  }

  // ── Private helpers ──────────────────────────────────────────────────────

  /** True for model names starting with `o1`..`o9` or `gpt-5`. */
  private isReasoningModel(model: string): boolean {
    return /^(o[1-9]|gpt-5)/.test(model)
  }

  /**
   * Whether the token limit must be sent as `max_completion_tokens`.
   *
   * Note: `^` binds only to the first alternative; `gpt-4o-mini-2024`
   * matches anywhere in the model name.
   */
  private usesMaxCompletionTokens(model: string): boolean {
    return this.isReasoningModel(model) || /^gpt-4\.1|gpt-4o-mini-2024/.test(model)
  }

  /** JSON content type plus bearer authorization for every request. */
  private buildHeaders(): Record<string, string> {
    return {
      'content-type': 'application/json',
      authorization: `Bearer ${this.apiKey}`,
    }
  }

  /**
   * Translate a completion request into a Chat Completions request body.
   *
   * @param request Provider-agnostic completion request.
   * @param stream Whether to set `stream: true` on the body.
   * @returns A JSON-serializable request body.
   */
  private buildRequestBody(request: CompletionRequest, stream: boolean): Record<string, unknown> {
    const model = request.model ?? this.defaultModel
    const modelName = typeof model === 'string' ? model : ''
    // mapMessages builds fresh objects, so editing `messages` below never
    // touches the caller's request.
    const messages = this.mapMessages(request.messages, request.system)

    const body: Record<string, unknown> = { model, messages }

    if (stream) body.stream = true

    // A limit of 0/undefined means "no limit sent"
    const tokens = request.maxTokens ?? this.defaultMaxTokens
    if (tokens) {
      if (this.usesMaxCompletionTokens(modelName)) {
        body.max_completion_tokens = tokens
      } else {
        body.max_tokens = tokens
      }
    }

    // Temperature is not forwarded to reasoning models
    if (request.temperature !== undefined && !this.isReasoningModel(modelName)) {
      body.temperature = request.temperature
    }
    if (request.stopSequences?.length) body.stop = request.stopSequences

    // Tools
    if (request.tools?.length) {
      body.tools = request.tools.map(t => ({
        type: 'function',
        function: {
          name: t.name,
          description: t.description,
          parameters: t.parameters,
        },
      }))
    }

    // Tool choice: either a mode string or a specific function name
    if (request.toolChoice) {
      body.tool_choice =
        typeof request.toolChoice === 'string'
          ? request.toolChoice
          : { type: 'function', function: { name: request.toolChoice.name } }
    }

    // Structured output
    if (request.schema) {
      const useStrict = this.supportsJsonSchema && this.isStrictCompatible(request.schema)

      if (useStrict) {
        body.response_format = {
          type: 'json_schema',
          json_schema: {
            name: 'response',
            schema: this.normalizeSchemaForOpenAI(request.schema),
            strict: true,
          },
        }
      } else {
        // Fallback: json_object mode with schema injected into system prompt
        body.response_format = { type: 'json_object' }
        const schemaHint = `\n\nYou MUST respond with valid JSON matching this schema:\n${JSON.stringify(request.schema, null, 2)}`
        if (messages[0]?.role === 'system') {
          messages[0].content += schemaHint
        } else {
          messages.unshift({ role: 'system', content: `Respond with valid JSON.${schemaHint}` })
        }
      }
    }

    return body
  }

  /**
   * Convert provider-agnostic messages into Chat Completions messages.
   *
   * @param messages Conversation messages.
   * @param system Optional system prompt, emitted as the first message.
   * @returns Newly created message objects in wire format.
   */
  private mapMessages(messages: Message[], system?: string): any[] {
    const result: any[] = []

    // System prompt as first message
    if (system) {
      result.push({ role: 'system', content: system })
    }

    for (const msg of messages) {
      if (msg.role === 'tool') {
        // Tool results must be strings on the wire
        result.push({
          role: 'tool',
          tool_call_id: msg.toolCallId,
          content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content),
        })
      } else if (msg.role === 'assistant') {
        const mapped: any = {
          role: 'assistant',
          // Non-string assistant content is sent as null
          content: typeof msg.content === 'string' ? msg.content : null,
        }

        if (msg.toolCalls?.length) {
          mapped.tool_calls = msg.toolCalls.map(tc => ({
            id: tc.id,
            type: 'function',
            function: {
              name: tc.name,
              arguments: JSON.stringify(tc.arguments),
            },
          }))
        }

        result.push(mapped)
      } else {
        // Every other role is sent as a user message, content passed through
        result.push({ role: 'user', content: msg.content })
      }
    }

    return result
  }

  /**
   * Decode the `arguments` field of a tool call into an object.
   *
   * - already-decoded plain objects are returned as-is (defensive; some
   *   compatible backends may not stringify arguments)
   * - empty / missing arguments become `{}`
   * - invalid JSON, or valid JSON that is not an object, is preserved
   *   under `_raw` so nothing is lost
   */
  private parseToolArguments(raw: unknown): Record<string, unknown> {
    if (raw == null) return {}
    if (typeof raw === 'object' && !Array.isArray(raw)) return raw as Record<string, unknown>
    if (typeof raw !== 'string') return { _raw: raw }
    if (raw.trim() === '') return {}

    try {
      const decoded: unknown = JSON.parse(raw)
      if (decoded !== null && typeof decoded === 'object' && !Array.isArray(decoded)) {
        return decoded as Record<string, unknown>
      }
    } catch {
      // Invalid JSON from the model — fall through to the wrapper
    }

    return { _raw: raw }
  }

  /**
   * Normalize a Chat Completions response payload.
   *
   * @param data Parsed JSON body of the API response.
   * @returns Content, tool calls, stop reason and usage; missing fields
   *   fall back to empty values / zero.
   */
  private parseResponse(data: any): CompletionResponse {
    const choice = data.choices?.[0]
    const message = choice?.message

    const content: string = message?.content ?? ''
    const toolCalls: ToolCall[] = []

    if (Array.isArray(message?.tool_calls)) {
      for (const tc of message.tool_calls) {
        toolCalls.push({
          id: tc.id,
          name: tc.function?.name,
          arguments: this.parseToolArguments(tc.function?.arguments),
        })
      }
    }

    const usage: Usage = {
      inputTokens: data.usage?.prompt_tokens ?? 0,
      outputTokens: data.usage?.completion_tokens ?? 0,
      totalTokens: data.usage?.total_tokens ?? 0,
    }

    let stopReason: CompletionResponse['stopReason'] = 'end'
    switch (choice?.finish_reason) {
      case 'tool_calls':
        stopReason = 'tool_use'
        break
      case 'length':
        stopReason = 'max_tokens'
        break
      case 'stop':
        // A forced tool_choice finishes with 'stop' even though the message
        // carries tool calls; report it as tool use so callers act on them.
        stopReason = toolCalls.length > 0 ? 'tool_use' : 'end'
        break
    }

    return {
      id: data.id ?? '',
      content,
      toolCalls,
      stopReason,
      usage,
      raw: data,
    }
  }

  /**
   * Check if a schema is compatible with OpenAI's strict structured output.
   * Record types (object with additionalProperties != false) are not supported.
   *
   * Recurses into `properties`, `items`, `anyOf` and `oneOf`.
   */
  private isStrictCompatible(schema: Record<string, unknown>): boolean {
    if (schema == null || typeof schema !== 'object') return true

    // Record type: object with additionalProperties that isn't false
    const isRecordType =
      schema.type === 'object' &&
      schema.additionalProperties !== undefined &&
      schema.additionalProperties !== false
    if (isRecordType) return false

    // Check nested properties
    if (schema.properties) {
      for (const prop of Object.values(schema.properties as Record<string, any>)) {
        if (!this.isStrictCompatible(prop)) return false
      }
    }

    // Check array items
    if (schema.items && !this.isStrictCompatible(schema.items as Record<string, unknown>)) {
      return false
    }

    // Check anyOf / oneOf
    for (const key of ['anyOf', 'oneOf'] as const) {
      const variants = schema[key]
      if (Array.isArray(variants)) {
        for (const variant of variants) {
          if (!this.isStrictCompatible(variant)) return false
        }
      }
    }

    return true
  }

  /** Keywords OpenAI strict mode does NOT support. */
  private static UNSUPPORTED_KEYWORDS = new Set([
    'propertyNames',
    'patternProperties',
    'if',
    'then',
    'else',
    'not',
    'contains',
    'minItems',
    'maxItems',
    'minProperties',
    'maxProperties',
    'minLength',
    'maxLength',
    'minimum',
    'maximum',
    'exclusiveMinimum',
    'exclusiveMaximum',
    'multipleOf',
    'pattern',
    'format',
    'contentEncoding',
    'contentMediaType',
    'unevaluatedProperties',
    '$schema',
  ])

  /**
   * Rewrite a JSON schema for OpenAI's strict structured output, which requires:
   * - All properties listed in `required`
   * - Optional properties use nullable types instead
   * - `additionalProperties: false` on every object
   *
   * Unsupported keywords are stripped. The input schema is not mutated;
   * a new object is returned at every level that is rewritten.
   */
  private normalizeSchemaForOpenAI(schema: Record<string, unknown>): Record<string, unknown> {
    if (schema == null || typeof schema !== 'object') return schema

    // Strip unsupported keywords
    const result: Record<string, unknown> = {}
    for (const [k, v] of Object.entries(schema)) {
      if (!OpenAIProvider.UNSUPPORTED_KEYWORDS.has(k)) {
        result[k] = v
      }
    }

    // Handle object types with explicit properties
    if (result.type === 'object' && result.properties) {
      const props = result.properties as Record<string, any>
      const currentRequired = new Set(
        Array.isArray(result.required) ? (result.required as string[]) : []
      )

      const normalizedProps: Record<string, any> = {}

      for (const [key, prop] of Object.entries(props)) {
        const normalizedProp = this.normalizeSchemaForOpenAI(prop)

        // If property is not required, make it nullable; it is then listed
        // in `required` below like every other property.
        normalizedProps[key] = currentRequired.has(key)
          ? normalizedProp
          : this.makeNullable(normalizedProp)
      }

      result.properties = normalizedProps
      result.required = Object.keys(normalizedProps)
      result.additionalProperties = false
    }

    // Handle arrays
    if (result.type === 'array' && result.items) {
      result.items = this.normalizeSchemaForOpenAI(result.items as Record<string, unknown>)
    }

    // Handle anyOf / oneOf
    for (const key of ['anyOf', 'oneOf'] as const) {
      const variants = result[key]
      if (Array.isArray(variants)) {
        result[key] = variants.map((s: any) => this.normalizeSchemaForOpenAI(s))
      }
    }

    return result
  }

  /**
   * Return a schema that additionally accepts `null`.
   *
   * Schemas that already admit null, non-object schemas and the empty
   * schema are returned unchanged. Schemas without a `type` (`$ref`,
   * `enum`, `oneOf`, ...) are wrapped too, so an optional property never
   * ends up required without a null alternative.
   */
  private makeNullable(schema: Record<string, unknown>): Record<string, unknown> {
    if (schema == null || typeof schema !== 'object') return schema

    // Already nullable
    if (schema.type === 'null') return schema
    if (Array.isArray(schema.type) && schema.type.includes('null')) return schema

    // Has anyOf — add null variant
    if (Array.isArray(schema.anyOf)) {
      const hasNull = schema.anyOf.some((s: any) => s?.type === 'null')
      return hasNull ? schema : { ...schema, anyOf: [...schema.anyOf, { type: 'null' }] }
    }

    // The empty schema accepts anything, null included
    if (Object.keys(schema).length === 0) return schema

    // Everything else — wrap in anyOf with null
    return { anyOf: [{ ...schema }, { type: 'null' }] }
  }
}