ai-providers 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +5 -0
- package/.turbo/turbo-test.log +47 -0
- package/README.md +204 -0
- package/dist/index.d.ts +12 -1
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +12 -1
- package/dist/index.js.map +1 -0
- package/dist/llm.do.d.ts +209 -0
- package/dist/llm.do.d.ts.map +1 -0
- package/dist/llm.do.js +408 -0
- package/dist/llm.do.js.map +1 -0
- package/dist/providers/cloudflare.d.ts +92 -0
- package/dist/providers/cloudflare.d.ts.map +1 -0
- package/dist/providers/cloudflare.js +127 -0
- package/dist/providers/cloudflare.js.map +1 -0
- package/dist/registry.d.ts +136 -0
- package/dist/registry.d.ts.map +1 -0
- package/dist/registry.js +393 -0
- package/dist/registry.js.map +1 -0
- package/package.json +45 -30
- package/src/index.test.ts +341 -0
- package/src/index.ts +37 -0
- package/src/integration.test.ts +317 -0
- package/src/llm.do.test.ts +781 -0
- package/src/llm.do.ts +532 -0
- package/src/providers/cloudflare.test.ts +574 -0
- package/src/providers/cloudflare.ts +216 -0
- package/src/registry.test.ts +491 -0
- package/src/registry.ts +480 -0
- package/tsconfig.json +9 -0
- package/vitest.config.ts +36 -0
- package/dist/provider.d.ts +0 -50
- package/dist/provider.js +0 -128
- package/dist/tsconfig.tsbuildinfo +0 -1
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Cloudflare Workers AI Provider for embeddings
|
|
3
|
+
*
|
|
4
|
+
* Provides embedding models via Cloudflare Workers AI API.
|
|
5
|
+
* Default model: @cf/baai/bge-m3
|
|
6
|
+
*
|
|
7
|
+
* @packageDocumentation
|
|
8
|
+
*/
|
|
9
|
+
|
|
10
|
+
import type { EmbeddingModel } from 'ai'
|
|
11
|
+
|
|
12
|
+
/**
 * Default Cloudflare embedding model.
 *
 * Used as the default `modelId` by both `CloudflareEmbeddingModel` and
 * `cloudflareEmbedding` when the caller does not pass a model ID.
 */
export const DEFAULT_CF_EMBEDDING_MODEL = '@cf/baai/bge-m3'
|
|
16
|
+
|
|
17
|
+
/**
 * Cloudflare Workers AI configuration.
 *
 * Every field is optional. `CloudflareEmbeddingModel` layers these fields over
 * values read from the CLOUDFLARE_ACCOUNT_ID, CLOUDFLARE_API_TOKEN and
 * CLOUDFLARE_AI_GATEWAY environment variables. `accountId` and `apiToken` are
 * only needed on the REST path; they are not consulted when a Workers AI
 * binding is supplied.
 */
export interface CloudflareConfig {
  /** Cloudflare Account ID (interpolated into the REST / AI Gateway URL) */
  accountId?: string
  /** Cloudflare API Token (sent as `Authorization: Bearer <token>`) */
  apiToken?: string
  /**
   * AI Gateway (optional). When set and `baseUrl` is not, requests go to
   * `https://gateway.ai.cloudflare.com/v1/<accountId>/<gateway>/workers-ai/<modelId>`.
   */
  gateway?: string
  /**
   * Base URL override. NOTE: this is used verbatim as the complete request URL
   * (the model ID is NOT appended), and it takes precedence over `gateway`.
   */
  baseUrl?: string
}
|
|
30
|
+
|
|
31
|
+
/**
 * Read Cloudflare settings from environment variables.
 *
 * Returns an object that always has the `accountId`, `apiToken` and `gateway`
 * keys. A key's value is `undefined` when the corresponding variable is unset
 * or when no `process` object exists in the current runtime.
 */
function getCloudflareConfig(): CloudflareConfig {
  // Guard against runtimes where `process` is not defined at all.
  const env = typeof process === 'undefined' ? undefined : process.env

  return {
    accountId: env?.CLOUDFLARE_ACCOUNT_ID,
    apiToken: env?.CLOUDFLARE_API_TOKEN,
    gateway: env?.CLOUDFLARE_AI_GATEWAY,
  }
}
|
|
41
|
+
|
|
42
|
+
/**
 * Cloudflare embedding model implementation (AI SDK v5 compatible).
 *
 * Exposes the `specificationVersion: 'v2'` embedding-model shape consumed by
 * the `ai` package. When a Workers AI binding is supplied, embeddings are
 * requested through it. Otherwise the Cloudflare REST API is used, optionally
 * via AI Gateway. On both paths one request is issued per input text, awaited
 * in turn. The returned `embeddings` array is index-aligned with the input
 * `values`; a missing vector is an error, never a silent gap.
 */
class CloudflareEmbeddingModel {
  readonly specificationVersion = 'v2' as const
  readonly modelId: string
  readonly provider = 'cloudflare'
  readonly maxEmbeddingsPerCall = 100
  readonly supportsParallelCalls = true

  private config: CloudflareConfig
  private ai?: Ai // Cloudflare AI binding (when running in Workers)

  /**
   * @param modelId - Workers AI model ID. Defaults to `DEFAULT_CF_EMBEDDING_MODEL`.
   * @param config - Explicit settings layered over the CLOUDFLARE_* environment
   *   variables. Keys whose value is `undefined` are ignored, so they fall back
   *   to the environment instead of erasing it.
   * @param ai - Optional Workers AI binding. When present, the REST API (and
   *   its credentials) is not used.
   */
  constructor(
    modelId: string = DEFAULT_CF_EMBEDDING_MODEL,
    config: CloudflareConfig = {},
    ai?: Ai
  ) {
    this.modelId = modelId
    this.config = CloudflareEmbeddingModel.mergeConfig(getCloudflareConfig(), config)
    this.ai = ai
  }

  /**
   * Overlay `overrides` onto `base`, skipping keys whose override is `undefined`.
   * A plain object spread would let `{ accountId: undefined }` wipe out a value
   * read from the environment.
   */
  private static mergeConfig(base: CloudflareConfig, overrides: CloudflareConfig): CloudflareConfig {
    const merged: CloudflareConfig = { ...base }
    for (const key of Object.keys(overrides) as Array<keyof CloudflareConfig>) {
      const value = overrides[key]
      if (value !== undefined) {
        merged[key] = value
      }
    }
    return merged
  }

  /**
   * Embed `values`, returning one vector per input in the same order.
   *
   * @param options.values - Texts to embed.
   * @param options.abortSignal - Forwarded to `fetch`; REST path only.
   * @param options.headers - Extra request headers; REST path only. They are
   *   spread after the defaults, so they can override `Authorization` /
   *   `Content-Type`.
   * @throws Error when REST credentials are missing, a request fails, or the
   *   service returns no embedding for some input.
   */
  async doEmbed(options: {
    values: string[]
    abortSignal?: AbortSignal
    headers?: Record<string, string>
  }): Promise<{
    embeddings: number[][]
    usage?: { tokens: number }
    response?: { headers?: Record<string, string>; body?: unknown }
  }> {
    const { values, abortSignal, headers } = options

    // If running in Cloudflare Workers with AI binding
    if (this.ai) {
      return this.embedWithBinding(values)
    }

    // Otherwise use REST API
    return this.embedWithRest(values, abortSignal, headers)
  }

  /**
   * Embed each value through the Workers AI binding, one `run` call per text.
   *
   * Throws instead of skipping when a call yields no vector. A skipped entry
   * would shift every later embedding onto the wrong input.
   */
  private async embedWithBinding(values: string[]): Promise<{
    embeddings: number[][]
    usage?: { tokens: number }
  }> {
    const ai = this.ai
    if (!ai) {
      throw new Error('Cloudflare AI binding is not configured')
    }

    const embeddings: number[][] = []

    for (const [index, text] of values.entries()) {
      const result = (await ai.run(this.modelId as BaseAiTextEmbeddingsModels, {
        text
      })) as { data?: number[][] } | undefined

      const embedding = result?.data?.[0]
      if (!Array.isArray(embedding)) {
        throw new Error(`Cloudflare AI error: no embedding returned for input at index ${index}`)
      }
      embeddings.push(embedding)
    }

    return { embeddings }
  }

  /**
   * Embed each value via the Cloudflare REST API, one POST per text.
   *
   * The endpoint is chosen in this order:
   * 1. `baseUrl`, used verbatim as the full URL (the model ID is not appended).
   * 2. The AI Gateway endpoint, when `gateway` is set.
   * 3. The standard Workers AI `ai/run` endpoint.
   *
   * @throws Error when `accountId`/`apiToken` are missing, on a non-2xx
   *   response, or when a response carries no embedding.
   */
  private async embedWithRest(
    values: string[],
    abortSignal?: AbortSignal,
    headers?: Record<string, string>
  ): Promise<{
    embeddings: number[][]
    usage?: { tokens: number }
    response?: { headers?: Record<string, string>; body?: unknown }
  }> {
    const { accountId, apiToken, gateway, baseUrl } = this.config

    if (!accountId || !apiToken) {
      throw new Error(
        'Cloudflare credentials required. Set CLOUDFLARE_ACCOUNT_ID and CLOUDFLARE_API_TOKEN environment variables.'
      )
    }

    const url = baseUrl ||
      (gateway
        ? `https://gateway.ai.cloudflare.com/v1/${accountId}/${gateway}/workers-ai/${this.modelId}`
        : `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${this.modelId}`)

    const embeddings: number[][] = []

    // One request per text, awaited in turn; vectors are pushed in input order.
    for (const text of values) {
      const response = await fetch(url, {
        method: 'POST',
        headers: {
          'Authorization': `Bearer ${apiToken}`,
          'Content-Type': 'application/json',
          ...headers
        },
        body: JSON.stringify({ text }),
        signal: abortSignal
      })

      if (!response.ok) {
        const error = await response.text()
        throw new Error(`Cloudflare AI error: ${response.status} ${error}`)
      }

      const result = await response.json() as {
        success: boolean
        result?: { data?: number[][] }
        errors?: Array<{ message: string }>
      }

      // Optional-chain `data` as well. A success payload without `data` would
      // otherwise crash with a TypeError instead of this descriptive error.
      const embedding = result.result?.data?.[0]
      if (!result.success || !embedding) {
        throw new Error(`Cloudflare AI error: ${result.errors?.[0]?.message || 'Unknown error'}`)
      }

      embeddings.push(embedding)
    }

    return { embeddings }
  }
}
|
|
163
|
+
|
|
164
|
+
/**
 * Create a Cloudflare Workers AI embedding model.
 *
 * Without a binding, requests go to the Cloudflare REST API. That path needs
 * an account ID and API token, taken from `config` or the CLOUDFLARE_*
 * environment variables. Inside a Worker, pass the AI binding as the third
 * argument instead.
 *
 * @param modelId - Workers AI embedding model ID; defaults to `@cf/baai/bge-m3`.
 * @param config - Optional account / token / gateway / base-URL settings.
 * @param ai - Optional Workers AI binding.
 * @returns The model, typed as the `ai` package's `EmbeddingModel<string>`.
 *
 * @example
 * ```ts
 * // Using REST API (outside Workers)
 * import { cloudflareEmbedding, embed } from 'ai-functions'
 *
 * const model = cloudflareEmbedding('@cf/baai/bge-m3')
 * const { embedding } = await embed({ model, value: 'hello world' })
 *
 * // Using AI binding (inside Workers)
 * const workerModel = cloudflareEmbedding('@cf/baai/bge-m3', {}, env.AI)
 * ```
 */
export function cloudflareEmbedding(
  modelId: string = DEFAULT_CF_EMBEDDING_MODEL,
  config: CloudflareConfig = {},
  ai?: Ai
): EmbeddingModel<string> {
  const instance = new CloudflareEmbeddingModel(modelId, config, ai)
  // The class matches the embedding-model shape structurally but does not
  // declare the `ai` package's interface, hence the double cast.
  return instance as unknown as EmbeddingModel<string>
}
|
|
186
|
+
|
|
187
|
+
/**
 * Cloudflare Workers AI provider.
 *
 * `embedding` and `textEmbeddingModel` both point to the same factory,
 * `cloudflareEmbedding`; either name can be used.
 */
export const cloudflare = {
  /** Create an embedding model (same as `cloudflareEmbedding`). */
  embedding: cloudflareEmbedding,
  /** Alias of `embedding`. */
  textEmbeddingModel: cloudflareEmbedding,
}
|
|
201
|
+
|
|
202
|
+
// Type definitions for Cloudflare AI binding
//
// Minimal ambient declarations, presumably so this module type-checks without
// @cloudflare/workers-types. NOTE(review): if that package is also loaded,
// these globals will merge with (or may conflict with) its `Ai` interface and
// `BaseAiTextEmbeddingsModels` alias. Confirm against the build setup.
declare global {
  /** Workers AI binding (e.g. `env.AI`): runs `model` with the given inputs. */
  interface Ai {
    run<T = unknown>(model: string, inputs: unknown): Promise<T>
  }

  /**
   * Known embedding model IDs. The `(string & {})` member keeps these literals
   * available as editor suggestions while still accepting any other string.
   */
  type BaseAiTextEmbeddingsModels =
    | '@cf/baai/bge-small-en-v1.5'
    | '@cf/baai/bge-base-en-v1.5'
    | '@cf/baai/bge-large-en-v1.5'
    | '@cf/baai/bge-m3'
    | (string & {})
}
|
|
215
|
+
|
|
216
|
+
// Type-only export: consumers can name the instance type, but the class itself
// is not exported as a value (instances are created via `cloudflareEmbedding`).
export type { CloudflareEmbeddingModel }
|
|
@@ -0,0 +1,491 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Tests for provider registry
|
|
3
|
+
*
|
|
4
|
+
* Covers:
|
|
5
|
+
* - Provider configuration and creation
|
|
6
|
+
* - Environment variable loading
|
|
7
|
+
* - Model ID parsing and routing
|
|
8
|
+
* - Gateway authentication
|
|
9
|
+
* - Smart routing between direct providers and OpenRouter
|
|
10
|
+
*/
|
|
11
|
+
|
|
12
|
+
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'
|
|
13
|
+
import {
|
|
14
|
+
createRegistry,
|
|
15
|
+
getRegistry,
|
|
16
|
+
configureRegistry,
|
|
17
|
+
model,
|
|
18
|
+
embeddingModel,
|
|
19
|
+
DIRECT_PROVIDERS,
|
|
20
|
+
type ProviderConfig,
|
|
21
|
+
} from './registry.js'
|
|
22
|
+
|
|
23
|
+
describe('createRegistry', () => {
  // Construction smoke tests: each case only checks that createRegistry resolves.
  describe('basic provider creation', () => {
    it('creates registry with default config from environment', async () => {
      const registry = await createRegistry()
      expect(registry).toBeDefined()
      expect(typeof registry.languageModel).toBe('function')
      expect(typeof registry.textEmbeddingModel).toBe('function')
    })

    it('creates registry with custom config', async () => {
      const config: ProviderConfig = {
        gatewayUrl: 'https://example.com/gateway',
        gatewayToken: 'test-token',
      }
      const registry = await createRegistry(config)
      expect(registry).toBeDefined()
    })

    it('creates registry with only specific providers', async () => {
      // This should not throw even if some provider SDKs are missing
      const registry = await createRegistry({}, { providers: ['openrouter'] })
      expect(registry).toBeDefined()
    })

    it('handles missing provider SDKs gracefully', async () => {
      // Should not throw when optional dependencies are missing
      const registry = await createRegistry({})
      expect(registry).toBeDefined()
    })
  })

  // Each test works on a shallow copy of process.env; the original object is
  // put back after every test.
  describe('environment variable loading', () => {
    const originalEnv = process.env

    beforeEach(() => {
      // NOTE(review): createRegistry is imported statically at the top of this
      // file, and top-level imports are not re-evaluated by vi.resetModules().
      // Confirm whether this reset is still needed.
      vi.resetModules()
      process.env = { ...originalEnv }
    })

    afterEach(() => {
      process.env = originalEnv
    })

    // NOTE(review): the cases below only assert that a registry is returned;
    // they do not verify which environment values were actually picked up.
    it('loads gateway config from AI_GATEWAY_URL', async () => {
      process.env.AI_GATEWAY_URL = 'https://gateway.ai.cloudflare.com/v1/account/gateway'
      process.env.AI_GATEWAY_TOKEN = 'test-token'

      const registry = await createRegistry()
      expect(registry).toBeDefined()
    })

    it('loads gateway token from DO_TOKEN as fallback', async () => {
      process.env.AI_GATEWAY_URL = 'https://gateway.ai.cloudflare.com/v1/account/gateway'
      process.env.DO_TOKEN = 'do-token'
      // Remove the primary token so only the DO_TOKEN fallback is available.
      delete process.env.AI_GATEWAY_TOKEN

      const registry = await createRegistry()
      expect(registry).toBeDefined()
    })

    it('loads individual provider API keys', async () => {
      process.env.OPENAI_API_KEY = 'sk-openai-test'
      process.env.ANTHROPIC_API_KEY = 'sk-anthropic-test'
      process.env.GOOGLE_GENERATIVE_AI_API_KEY = 'google-test'
      process.env.OPENROUTER_API_KEY = 'sk-or-test'
      process.env.CLOUDFLARE_ACCOUNT_ID = 'cf-account'
      process.env.CLOUDFLARE_API_TOKEN = 'cf-token'

      const registry = await createRegistry()
      expect(registry).toBeDefined()
    })

    it('uses GOOGLE_AI_API_KEY as fallback for Google', async () => {
      process.env.GOOGLE_AI_API_KEY = 'google-test'
      // Remove the primary Google key so only the fallback is available.
      delete process.env.GOOGLE_GENERATIVE_AI_API_KEY

      const registry = await createRegistry()
      expect(registry).toBeDefined()
    })
  })

  // URL-selection scenarios; resolved URLs are not inspected, only creation.
  describe('base URL construction', () => {
    it('uses custom base URLs when provided', async () => {
      const config: ProviderConfig = {
        baseUrls: {
          openai: 'https://custom-openai.example.com',
          anthropic: 'https://custom-anthropic.example.com',
        },
      }
      const registry = await createRegistry(config)
      expect(registry).toBeDefined()
    })

    it('constructs gateway URLs for each provider', async () => {
      const config: ProviderConfig = {
        gatewayUrl: 'https://gateway.ai.cloudflare.com/v1/account123/gateway456',
        gatewayToken: 'test-token',
      }
      const registry = await createRegistry(config)
      expect(registry).toBeDefined()

      // Expected gateway paths:
      // openai -> /openai
      // anthropic -> /anthropic
      // google -> /google-ai-studio
      // openrouter -> /openrouter
      // cloudflare -> /workers-ai
    })

    it('uses default URLs when no gateway or custom URLs', async () => {
      const registry = await createRegistry({})
      expect(registry).toBeDefined()
    })
  })

  // Authentication modes; the keys actually sent are not inspected here.
  describe('gateway authentication', () => {
    it('uses gateway secrets mode when both URL and token configured', async () => {
      const config: ProviderConfig = {
        gatewayUrl: 'https://gateway.ai.cloudflare.com/v1/account/gateway',
        gatewayToken: 'test-token',
      }
      const registry = await createRegistry(config)
      expect(registry).toBeDefined()
      // When using gateway secrets, API keys should be placeholders
    })

    it('falls back to direct API keys without gateway', async () => {
      const config: ProviderConfig = {
        openaiApiKey: 'sk-openai-test',
        anthropicApiKey: 'sk-anthropic-test',
      }
      const registry = await createRegistry(config)
      expect(registry).toBeDefined()
    })

    it('uses placeholder API keys with gateway secrets', async () => {
      const config: ProviderConfig = {
        gatewayUrl: 'https://gateway.ai.cloudflare.com/v1/account/gateway',
        gatewayToken: 'test-token',
        // No individual API keys needed
      }
      const registry = await createRegistry(config)
      expect(registry).toBeDefined()
    })
  })
})
|
|
169
|
+
|
|
170
|
+
describe('getRegistry', () => {
  it('returns a singleton registry', async () => {
    const first = await getRegistry()
    const second = await getRegistry()

    // Both calls must hand back the very same cached instance.
    expect(second).toBe(first)
  })

  it('caches the registry promise to prevent duplicate initialization', async () => {
    // Fire the calls concurrently so they race on initialization.
    const [head, ...rest] = await Promise.all([getRegistry(), getRegistry(), getRegistry()])

    for (const other of rest) {
      expect(other).toBe(head)
    }
  })
})
|
|
191
|
+
|
|
192
|
+
// NOTE(review): configureRegistry appears to swap the shared default registry
// (the first test asserts getRegistry() returns a new instance afterwards), and
// nothing here restores it. Suites that run later in this file (model,
// embeddingModel) therefore see the last configuration applied below.
describe('configureRegistry', () => {
  it('replaces the default registry with new config', async () => {
    const originalRegistry = await getRegistry()

    await configureRegistry({
      gatewayUrl: 'https://new-gateway.example.com',
      gatewayToken: 'new-token',
    })

    const newRegistry = await getRegistry()

    // Should be a different instance after reconfiguration
    expect(newRegistry).not.toBe(originalRegistry)
  })

  it('resets the registry promise', async () => {
    await configureRegistry({
      gatewayUrl: 'https://test-gateway.example.com',
    })

    // getRegistry must still resolve after the cached promise was reset.
    const registry = await getRegistry()
    expect(registry).toBeDefined()
  })
})
|
|
216
|
+
|
|
217
|
+
describe('model ID parsing', () => {
  // No parsing helper is imported by this file, so these cases cannot assert
  // anything yet. They are tracked as todos (reported as pending) instead of
  // `expect(true).toBe(true)`, which always passed and overstated coverage.

  // Intended: test internal parsing logic through the model() function behavior.
  it.todo('parses provider/model format')

  // Intended: IDs without a provider prefix should route to OpenRouter.
  it.todo('defaults to openrouter for IDs without slash')
})
|
|
228
|
+
|
|
229
|
+
describe('model', () => {
  // Most cases below accept either outcome: a resolved model, or an error when
  // language-models, a provider SDK, or credentials are unavailable in the
  // test environment.
  describe('smart routing', () => {
    it('resolves simple aliases to full model IDs', async () => {
      // Mock test - actual resolution requires language-models package
      try {
        const m = await model('sonnet')
        expect(m).toBeDefined()
      } catch (error) {
        // language-models may not be available in test environment
        expect(error).toBeDefined()
      }
    })

    it('routes direct providers to their native SDKs', async () => {
      // When language-models provides provider_model_id and the provider
      // is in DIRECT_PROVIDERS (openai, anthropic, google), should route directly
      try {
        const m = await model('anthropic/claude-sonnet-4.5')
        expect(m).toBeDefined()
      } catch (error) {
        // Provider SDK may not be available
        expect(error).toBeDefined()
      }
    })

    it('routes other models through OpenRouter', async () => {
      try {
        const m = await model('meta-llama/llama-3.3-70b-instruct')
        expect(m).toBeDefined()
      } catch (error) {
        // OpenRouter may not be configured
        expect(error).toBeDefined()
      }
    })

    it('handles models without language-models package', async () => {
      // Should fall back to OpenRouter routing
      try {
        const m = await model('some-random-model')
        expect(m).toBeDefined()
      } catch (error) {
        expect(error).toBeDefined()
      }
    })

    it('respects provider prefix in model ID for direct routing', async () => {
      // openai/* should route to OpenAI SDK
      // anthropic/* should route to Anthropic SDK
      // google/* should route to Google SDK
      // others/* should route to OpenRouter
      expect(DIRECT_PROVIDERS).toContain('openai')
      expect(DIRECT_PROVIDERS).toContain('anthropic')
      expect(DIRECT_PROVIDERS).toContain('google')
    })

    // Not implemented yet; tracked as todo (pending) rather than an always-passing
    // `expect(true).toBe(true)`.
    // When language-models provides provider_model_id, use it for direct routing.
    // This enables provider-specific features.
    it.todo('uses provider_model_id when available')

    // Not implemented yet. Should only route directly when the provider in the
    // model ID matches the provider in the metadata (prevents incorrect routing).
    it.todo('validates provider matches between ID and metadata')
  })

  describe('full model IDs', () => {
    // model() IDs use the `provider/model` form (the `provider:model` form is
    // what embeddingModel() uses), so the test name now matches the ID it passes.
    it('accepts full provider/model format', async () => {
      try {
        const m = await model('anthropic/claude-opus-4.5')
        expect(m).toBeDefined()
      } catch (error) {
        expect(error).toBeDefined()
      }
    })

    it('accepts models without provider prefix', async () => {
      try {
        const m = await model('gpt-4o')
        expect(m).toBeDefined()
      } catch (error) {
        expect(error).toBeDefined()
      }
    })
  })

  describe('error handling', () => {
    // Not implemented yet; tracked as todo (pending). If @ai-sdk/openai is not
    // installed, should throw (or handle gracefully depending on implementation).
    it.todo('throws when provider SDK is not installed')

    it('handles invalid model IDs gracefully', async () => {
      // Should not crash on invalid input
      try {
        await model('')
      } catch (error) {
        expect(error).toBeDefined()
      }
    })
  })
})
|
|
334
|
+
|
|
335
|
+
describe('embeddingModel', () => {
  // Resolves an embedding model by `provider:model` ID. If the provider is not
  // configured in the test environment, the call may reject; that outcome is
  // accepted as well.
  const expectModelOrError = async (id: string): Promise<void> => {
    try {
      const em = await embeddingModel(id)
      expect(em).toBeDefined()
    } catch (error) {
      expect(error).toBeDefined()
    }
  }

  it('returns an embedding model from the registry', () =>
    expectModelOrError('openai:text-embedding-3-small'))

  it('accepts cloudflare embedding models', () =>
    expectModelOrError('cloudflare:@cf/baai/bge-m3'))

  it('requires provider:model format', () =>
    expectModelOrError('openai:text-embedding-ada-002'))
})
|
|
366
|
+
|
|
367
|
+
describe('DIRECT_PROVIDERS', () => {
  it('exports the list of direct providers', () => {
    expect(DIRECT_PROVIDERS).toBeDefined()
    expect(Array.isArray(DIRECT_PROVIDERS)).toBe(true)
  })

  it('includes openai, anthropic, google', () => {
    for (const provider of ['openai', 'anthropic', 'google'] as const) {
      expect(DIRECT_PROVIDERS).toContain(provider)
    }
  })

  it('matches the DIRECT_PROVIDERS from language-models', () => {
    // Intended to be re-exported from language-models; only non-emptiness is
    // checked here.
    expect(DIRECT_PROVIDERS.length).toBeGreaterThan(0)
  })
})
|
|
384
|
+
|
|
385
|
+
describe('provider-specific features', () => {
  // Each case only asserts that the provider is routed directly (listed in
  // DIRECT_PROVIDERS). The named feature (MCP, function calling, grounding)
  // is the motivation for direct routing and is not exercised here.
  const cases = [
    ['enables Anthropic MCP when routing directly', 'anthropic'],
    ['enables OpenAI function calling when routing directly', 'openai'],
    ['enables Google grounding when routing directly', 'google'],
  ] as const

  for (const [title, provider] of cases) {
    it(title, () => {
      expect(DIRECT_PROVIDERS).toContain(provider)
    })
  }
})
|
|
401
|
+
|
|
402
|
+
describe('gateway fetch customization', () => {
  // These behaviours are not exercised yet. They are tracked as todos (reported
  // as pending) instead of `expect(true).toBe(true)`, which always passed.

  // When gatewayUrl and gatewayToken are set, should remove:
  // - x-api-key
  // - authorization
  // - x-goog-api-key
  // And add: cf-aig-authorization
  it.todo('strips SDK API key headers when using gateway secrets')

  // Should only strip auth headers, keep others
  it.todo('preserves other headers')

  // Should add: cf-aig-authorization: Bearer <gatewayToken>
  it.todo('adds cf-aig-authorization header with Bearer token')

  // When no gateway configured, should not provide custom fetch
  it.todo('does not modify fetch when not using gateway')
})
|
|
427
|
+
|
|
428
|
+
describe('integration scenarios', () => {
  const gatewayUrl = 'https://gateway.ai.cloudflare.com/v1/account/gateway'

  // [test title, config]: each scenario only checks that a registry is created.
  const scenarios: Array<[string, ProviderConfig]> = [
    // No individual API keys - gateway provides them
    ['works with Cloudflare AI Gateway and stored secrets', { gatewayUrl, gatewayToken: 'test-token' }],
    [
      'works with direct API keys (no gateway)',
      {
        openaiApiKey: 'sk-test',
        anthropicApiKey: 'sk-test',
        googleApiKey: 'test',
        openrouterApiKey: 'sk-or-test',
      },
    ],
    // Gateway plus a fallback key
    [
      'works with mixed config (gateway + fallback keys)',
      { gatewayUrl, gatewayToken: 'test-token', openaiApiKey: 'sk-test' },
    ],
    [
      'supports custom base URLs overriding gateway',
      {
        gatewayUrl,
        gatewayToken: 'test-token',
        baseUrls: { openai: 'https://custom-openai.example.com' },
      },
    ],
  ]

  for (const [title, config] of scenarios) {
    it(title, async () => {
      const registry = await createRegistry(config)
      expect(registry).toBeDefined()
    })
  }
})
|
|
473
|
+
|
|
474
|
+
describe('type safety', () => {
  // Compile-time check: this object literal must satisfy ProviderConfig.
  it('exports ProviderConfig type', () => {
    const config: ProviderConfig = {
      gatewayUrl: 'test',
    }
    expect(config).toBeDefined()
  })

  // ProviderId and DirectProvider are not imported by this file, so nothing
  // checks them yet. They are tracked as todos (pending) instead of
  // always-passing assertions.

  // Should be: 'openai' | 'anthropic' | 'google' | 'openrouter' | 'cloudflare'
  it.todo('exports ProviderId type')

  // Should match language-models DirectProvider type
  it.todo('exports DirectProvider type')
})
|