@pwshub/aisdk 0.0.4 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +126 -51
- package/index.d.ts +17 -8
- package/package.json +8 -6
- package/src/coerce.test.js +142 -0
- package/src/config.js +30 -0
- package/src/config.test.js +142 -0
- package/src/errors.js +1 -1
- package/src/index.js +8 -8
- package/src/models.js +56 -0
- package/src/providers.js +72 -2
- package/src/registry.js +75 -34
- package/src/registry.test.js +314 -0
- package/src/validation.test.js +410 -0
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @fileoverview Tests for config module.
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
import {
|
|
6
|
+
describe, it,
|
|
7
|
+
} from 'node:test'
|
|
8
|
+
import assert from 'node:assert'
|
|
9
|
+
import {
|
|
10
|
+
normalizeConfig, getWireMap,
|
|
11
|
+
} from '../src/config.js'
|
|
12
|
+
|
|
13
|
+
describe('config', () => {
  describe('getWireMap', () => {
    // Each built-in provider maps `temperature` onto a wire key of the same name.
    for (const provider of ['openai', 'anthropic', 'google', 'dashscope', 'deepseek']) {
      it(`should return wire map for ${provider}`, () => {
        const entry = getWireMap(provider)
        assert.ok(entry)
        assert.ok(entry.temperature)
        assert.strictEqual(entry.temperature.wireKey, 'temperature')
        if (provider === 'google') {
          // Google nests sampling params under generationConfig.
          assert.strictEqual(entry.temperature.scope, 'generationConfig')
        }
      })
    }

    it('should return undefined for unknown provider', () => {
      assert.strictEqual(getWireMap('unknown'), undefined)
    })
  })

  describe('normalizeConfig', () => {
    /**
     * Strictly compares each listed key of `actual` against its expected value.
     * Use `undefined` as the expected value to assert that a key is unset.
     */
    const expectFields = (actual, expected) => {
      for (const [key, value] of Object.entries(expected)) {
        assert.strictEqual(actual[key], value)
      }
    }

    it('should normalize openai config', () => {
      const result = normalizeConfig(
        { temperature: 0.5, maxTokens: 100, topP: 0.9 },
        'openai',
        ['temperature', 'maxTokens', 'topP'],
        'gpt-4o',
      )
      expectFields(result, { temperature: 0.5, max_completion_tokens: 100, top_p: 0.9 })
    })

    it('should normalize anthropic config', () => {
      const result = normalizeConfig(
        { temperature: 0.5, maxTokens: 100, topK: 50 },
        'anthropic',
        ['temperature', 'maxTokens', 'topK'],
        'claude-sonnet',
      )
      expectFields(result, { temperature: 0.5, max_tokens: 100, top_k: 50 })
    })

    it('should normalize google config with generationConfig nesting', () => {
      const result = normalizeConfig(
        { temperature: 0.5, maxTokens: 100, topP: 0.9 },
        'google',
        ['temperature', 'maxTokens', 'topP'],
        'gemini-2.0-flash',
      )
      // These sampling keys must not appear at the top level of the result.
      expectFields(result, { temperature: undefined, maxTokens: undefined, topP: undefined })
      assert.ok(result.generationConfig)
      expectFields(result.generationConfig, { temperature: 0.5, maxOutputTokens: 100, topP: 0.9 })
    })

    it('should skip unsupported params', () => {
      // openai doesn't support topK, so it is left out of supportedParams
      const result = normalizeConfig({ temperature: 0.5, topK: 50 }, 'openai', ['temperature'], 'gpt-4o')
      expectFields(result, { temperature: 0.5, top_k: undefined })
    })

    it('should skip null/undefined values', () => {
      const result = normalizeConfig(
        { temperature: 0.5, maxTokens: null, topP: undefined },
        'openai',
        ['temperature', 'maxTokens', 'topP'],
        'gpt-4o',
      )
      expectFields(result, { temperature: 0.5, max_completion_tokens: undefined, top_p: undefined })
    })

    it('should return empty object when no config provided', () => {
      assert.deepStrictEqual(normalizeConfig({}, 'openai', ['temperature'], 'gpt-4o'), {})
    })

    it('should return empty object when no supported params match', () => {
      assert.deepStrictEqual(
        normalizeConfig({ unknownParam: 123 }, 'openai', ['temperature'], 'gpt-4o'),
        {},
      )
    })
  })
})
|
package/src/errors.js
CHANGED
|
@@ -35,7 +35,7 @@ export class ProviderError extends Error {
|
|
|
35
35
|
* @param {object} meta
|
|
36
36
|
* @param {number} meta.status - HTTP status code
|
|
37
37
|
* @param {string} meta.provider - Provider ID
|
|
38
|
-
* @param {string} meta.model - Model
|
|
38
|
+
* @param {string} meta.model - Model name that was called
|
|
39
39
|
* @param {string} [meta.raw] - Raw response body from provider
|
|
40
40
|
*/
|
|
41
41
|
constructor(message, {
|
package/src/index.js
CHANGED
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
*
|
|
7
7
|
* const ai = createAi()
|
|
8
8
|
* const result = await ai.ask({
|
|
9
|
-
* model: 'claude-sonnet-4-20250514',
|
|
9
|
+
* model: 'anthropic/claude-sonnet-4-20250514',
|
|
10
10
|
* apikey: 'your-api-key',
|
|
11
11
|
* prompt: 'What is the capital of Vietnam?',
|
|
12
12
|
* temperature: 0.5,
|
|
@@ -16,18 +16,18 @@
|
|
|
16
16
|
*
|
|
17
17
|
* @example With fallbacks
|
|
18
18
|
* const result = await ai.ask({
|
|
19
|
-
* model: 'gpt-4o',
|
|
19
|
+
* model: 'openai/gpt-4o',
|
|
20
20
|
* apikey: 'your-openai-key',
|
|
21
21
|
* prompt: '...',
|
|
22
|
-
* fallbacks: ['gpt-4o-mini', 'claude-haiku-4-5-20251001'],
|
|
22
|
+
* fallbacks: ['openai/gpt-4o-mini', 'anthropic/claude-haiku-4-5-20251001'],
|
|
23
23
|
* })
|
|
24
|
-
* if (result.model !== 'gpt-4o') {
|
|
24
|
+
* if (result.model !== 'openai/gpt-4o') {
|
|
25
25
|
* console.warn('Fell back to', result.model)
|
|
26
26
|
* }
|
|
27
27
|
*
|
|
28
28
|
* @example Google provider-specific options
|
|
29
29
|
* const result = await ai.ask({
|
|
30
|
-
* model: 'gemini-2.0-flash',
|
|
30
|
+
* model: 'google/gemini-2.0-flash',
|
|
31
31
|
* apikey: 'your-google-key',
|
|
32
32
|
* prompt: '...',
|
|
33
33
|
* providerOptions: {
|
|
@@ -40,7 +40,7 @@
|
|
|
40
40
|
*
|
|
41
41
|
* @example Using messages array for multi-turn conversations
|
|
42
42
|
* const result = await ai.ask({
|
|
43
|
-
* model: 'claude-sonnet-4-20250514',
|
|
43
|
+
* model: 'anthropic/claude-sonnet-4-20250514',
|
|
44
44
|
* apikey: 'your-api-key',
|
|
45
45
|
* messages: [
|
|
46
46
|
* { role: 'user', content: 'What is the capital of Vietnam?' },
|
|
@@ -73,12 +73,12 @@ export {
|
|
|
73
73
|
|
|
74
74
|
/**
|
|
75
75
|
* @typedef {Object} AskParams
|
|
76
|
-
* @property {string} model - Model
|
|
76
|
+
* @property {string} model - Model in 'provider/name' format (e.g., 'openai/gpt-4o', 'ollama/llama3.2')
|
|
77
77
|
* @property {string} apikey - API key for the provider
|
|
78
78
|
* @property {string} [prompt] - The user message (alternative to messages)
|
|
79
79
|
* @property {string} [system] - Optional system prompt (used with prompt)
|
|
80
80
|
* @property {import('./providers.js').Message[]} [messages] - Array of messages with role and content (alternative to prompt)
|
|
81
|
-
* @property {string[]} [fallbacks] - Ordered list of fallback model
|
|
81
|
+
* @property {string[]} [fallbacks] - Ordered list of fallback models (same format as model)
|
|
82
82
|
* @property {Record<string, unknown>} [providerOptions] - Provider-specific options merged into body
|
|
83
83
|
* @property {number} [temperature]
|
|
84
84
|
* @property {number} [maxTokens]
|
package/src/models.js
CHANGED
|
@@ -342,4 +342,60 @@ export const DEFAULT_MODELS = [
|
|
|
342
342
|
max_out: 65536,
|
|
343
343
|
enable: true,
|
|
344
344
|
},
|
|
345
|
+
// Mistral models
|
|
346
|
+
{
|
|
347
|
+
id: 'mistral-large-latest',
|
|
348
|
+
name: 'mistral-large-latest',
|
|
349
|
+
provider: 'mistral',
|
|
350
|
+
input_price: 0.5,
|
|
351
|
+
output_price: 1.5,
|
|
352
|
+
cache_price: 0,
|
|
353
|
+
max_in: 128000,
|
|
354
|
+
max_out: 128000,
|
|
355
|
+
enable: true,
|
|
356
|
+
},
|
|
357
|
+
{
|
|
358
|
+
id: 'mistral-medium-latest',
|
|
359
|
+
name: 'mistral-medium-latest',
|
|
360
|
+
provider: 'mistral',
|
|
361
|
+
input_price: 0.4,
|
|
362
|
+
output_price: 2,
|
|
363
|
+
cache_price: 0,
|
|
364
|
+
max_in: 64000,
|
|
365
|
+
max_out: 64000,
|
|
366
|
+
enable: true,
|
|
367
|
+
},
|
|
368
|
+
{
|
|
369
|
+
id: 'mistral-small-latest',
|
|
370
|
+
name: 'mistral-small-latest',
|
|
371
|
+
provider: 'mistral',
|
|
372
|
+
input_price: 0.15,
|
|
373
|
+
output_price: 0.6,
|
|
374
|
+
cache_price: 0,
|
|
375
|
+
max_in: 128000,
|
|
376
|
+
max_out: 128000,
|
|
377
|
+
enable: true,
|
|
378
|
+
},
|
|
379
|
+
{
|
|
380
|
+
id: 'magistral-medium-latest',
|
|
381
|
+
name: 'magistral-medium-latest',
|
|
382
|
+
provider: 'mistral',
|
|
383
|
+
input_price: 2,
|
|
384
|
+
output_price: 5,
|
|
385
|
+
cache_price: 0,
|
|
386
|
+
max_in: 64000,
|
|
387
|
+
max_out: 64000,
|
|
388
|
+
enable: true,
|
|
389
|
+
},
|
|
390
|
+
{
|
|
391
|
+
id: 'magistral-small-latest',
|
|
392
|
+
name: 'magistral-small-latest',
|
|
393
|
+
provider: 'mistral',
|
|
394
|
+
input_price: 0.5,
|
|
395
|
+
output_price: 1.5,
|
|
396
|
+
cache_price: 0,
|
|
397
|
+
max_in: 64000,
|
|
398
|
+
max_out: 64000,
|
|
399
|
+
enable: true,
|
|
400
|
+
},
|
|
345
401
|
]
|
package/src/providers.js
CHANGED
|
@@ -10,7 +10,7 @@
|
|
|
10
10
|
*/
|
|
11
11
|
|
|
12
12
|
/**
|
|
13
|
-
* @typedef {'openai'|'anthropic'|'google'|'dashscope'|'deepseek'} ProviderId
|
|
13
|
+
* @typedef {'openai'|'anthropic'|'google'|'dashscope'|'deepseek'|'mistral'|'ollama'} ProviderId
|
|
14
14
|
*/
|
|
15
15
|
|
|
16
16
|
/**
|
|
@@ -298,9 +298,79 @@ const deepseek = {
|
|
|
298
298
|
}),
|
|
299
299
|
}
|
|
300
300
|
|
|
301
|
+
/**
 * Mistral chat-completions adapter.
 * @type {ProviderAdapter}
 */
const mistral = {
  headers: (apikey) => ({
    Authorization: `Bearer ${apikey}`,
    'Content-Type': 'application/json',
    Accept: 'application/json',
  }),
  url: () => 'https://api.mistral.ai/v1/chat/completions',
  // providerOptions are spread last so they can override normalized config keys.
  buildBody: (modelName, messages, config, providerOptions) => ({
    model: modelName,
    messages,
    ...config,
    ...providerOptions,
  }),
  /**
   * Returns the assistant text of the first choice.
   * @throws {Error} When no non-empty text content is present
   */
  extractText: (data) => {
    const content = data.choices?.[0]?.message?.content
    // Reasoning models (magistral-*) can return content as an array of typed chunks
    // (e.g. 'thinking' and 'text') instead of a plain string; keep only the answer text.
    const text = Array.isArray(content)
      ? content
        .filter((chunk) => chunk?.type === 'text' && typeof chunk.text === 'string')
        .map((chunk) => chunk.text)
        .join('')
      : content
    if (!text) {
      throw new Error('Mistral response missing content')
    }
    return text
  },
  // Missing usage fields are reported as 0; cache/reasoning tokens are not read from the response.
  extractUsage: (data) => ({
    inputTokens: data.usage?.prompt_tokens ?? 0,
    outputTokens: data.usage?.completion_tokens ?? 0,
    cacheTokens: 0,
    reasoningTokens: 0,
  }),
}
|
|
329
|
+
|
|
330
|
+
// Normalized (snake_case) config keys that Ollama expects nested under `options`.
const OLLAMA_OPTION_KEYS = ['temperature', 'top_p', 'top_k', 'num_predict', 'seed', 'stop']

/**
 * Ollama /api/chat adapter (non-streaming).
 * @type {ProviderAdapter}
 */
const ollama = {
  // A local server needs no key; only send Authorization when one is provided.
  headers: (apikey) => ({
    'Content-Type': 'application/json',
    ...(apikey && { Authorization: `Bearer ${apikey}` }),
  }),
  // Default local endpoint. NOTE(review): the override mechanism (gatewayUrl) lives
  // outside this adapter — confirm against the caller.
  url: () => 'http://localhost:11434/api/chat',
  buildBody: (modelName, messages, config = {}, providerOptions = {}) => {
    const options = {}
    for (const key of OLLAMA_OPTION_KEYS) {
      if (config[key] !== undefined) options[key] = config[key]
    }

    // Merge rather than overwrite: caller-supplied providerOptions.options (e.g. num_ctx)
    // must survive alongside the normalized params. Caller values win on conflict,
    // matching the other adapters where providerOptions are spread last.
    const mergedOptions = { ...options, ...providerOptions?.options }

    return {
      model: modelName,
      messages,
      stream: false,
      ...providerOptions,
      ...(Object.keys(mergedOptions).length > 0 && { options: mergedOptions }),
    }
  },
  /**
   * Returns the assistant message text.
   * @throws {Error} When the response has no non-empty message content
   */
  extractText: (data) => {
    const content = data.message?.content
    if (!content) {
      throw new Error('Ollama response missing content')
    }
    return content
  },
  // Ollama reports prompt/eval counts at the top level of the response.
  extractUsage: (data) => ({
    inputTokens: data.prompt_eval_count ?? 0,
    outputTokens: data.eval_count ?? 0,
    cacheTokens: 0,
    reasoningTokens: 0,
  }),
}
|
|
370
|
+
|
|
301
371
|
/**
 * Adapter lookup table keyed by provider id (the same ids as the ProviderId typedef).
 * @type {Record<string, ProviderAdapter>}
 */
const ADAPTERS = {
  openai,
  anthropic,
  google,
  dashscope,
  deepseek,
  mistral,
  ollama,
}
|
|
305
375
|
|
|
306
376
|
/**
|
package/src/registry.js
CHANGED
|
@@ -4,7 +4,8 @@
|
|
|
4
4
|
* Default models are loaded automatically from ./models.js at import time.
|
|
5
5
|
* Users can modify the registry via addModels() and setModels().
|
|
6
6
|
*
|
|
7
|
-
* This module provides O(1) lookups at runtime via a Map
|
|
7
|
+
* Records are stored in a Map keyed by id.
* Models are looked up by 'provider/name' format (e.g., 'openai/gpt-4o'), which scans the registry.
|
|
8
9
|
*
|
|
9
10
|
* `supportedParams` is optional per record. When absent, the provider's
|
|
10
11
|
* default param set is used.
|
|
@@ -43,10 +44,12 @@ export const PROVIDER_DEFAULT_PARAMS = {
|
|
|
43
44
|
google: ['temperature', 'maxTokens', 'topP', 'topK', 'seed', 'stop'],
|
|
44
45
|
dashscope: ['temperature', 'maxTokens', 'topP', 'topK', 'stop'],
|
|
45
46
|
deepseek: ['temperature', 'maxTokens', 'topP', 'frequencyPenalty', 'presencePenalty', 'stop'],
|
|
47
|
+
mistral: ['temperature', 'maxTokens', 'topP', 'randomSeed', 'stop'],
|
|
48
|
+
ollama: ['temperature', 'maxTokens', 'topP', 'topK', 'seed', 'stop'],
|
|
46
49
|
}
|
|
47
50
|
|
|
48
51
|
/**
 * Provider ids accepted by validateModelRecord; the order is reflected in its error message.
 * @type {ProviderId[]}
 */
const VALID_PROVIDERS = [
  'openai',
  'anthropic',
  'google',
  'dashscope',
  'deepseek',
  'mistral',
  'ollama',
]

/**
 * Model records keyed by id. Declared with `let` because it is replaced wholesale elsewhere.
 * @type {Map<string, ModelRecord>}
 */
let REGISTRY = new Map()
|
|
@@ -72,34 +75,29 @@ initRegistry()
|
|
|
72
75
|
const validateModelRecord = (model, index) => {
|
|
73
76
|
const errors = []
|
|
74
77
|
|
|
75
|
-
//
|
|
76
|
-
if (!model.id || typeof model.id !== 'string') {
|
|
77
|
-
errors.push('"id" must be a non-empty string')
|
|
78
|
-
}
|
|
79
|
-
|
|
78
|
+
// Only name and provider are required
|
|
80
79
|
if (!model.name || typeof model.name !== 'string') {
|
|
81
80
|
errors.push('"name" must be a non-empty string')
|
|
82
81
|
}
|
|
83
82
|
|
|
84
|
-
// Check provider is valid
|
|
85
83
|
if (!model.provider || typeof model.provider !== 'string') {
|
|
86
84
|
errors.push('"provider" must be a string')
|
|
87
85
|
} else if (!VALID_PROVIDERS.includes(model.provider)) {
|
|
88
86
|
errors.push(`"provider" must be one of: ${VALID_PROVIDERS.join(', ')}. Got: "${model.provider}"`)
|
|
89
87
|
}
|
|
90
88
|
|
|
91
|
-
// Check
|
|
89
|
+
// Check optional number fields if present (must be non-negative)
|
|
92
90
|
const numberFields = ['input_price', 'output_price', 'cache_price', 'max_in', 'max_out']
|
|
93
91
|
for (const field of numberFields) {
|
|
94
|
-
if (typeof model[field] !== 'number') {
|
|
92
|
+
if (model[field] !== undefined && typeof model[field] !== 'number') {
|
|
95
93
|
errors.push(`"${field}" must be a number`)
|
|
96
|
-
} else if (model[field] < 0) {
|
|
94
|
+
} else if (model[field] !== undefined && model[field] < 0) {
|
|
97
95
|
errors.push(`"${field}" must be non-negative, got: ${model[field]}`)
|
|
98
96
|
}
|
|
99
97
|
}
|
|
100
98
|
|
|
101
|
-
// Check enable
|
|
102
|
-
if (typeof model.enable !== 'boolean') {
|
|
99
|
+
// Check optional enable if present
|
|
100
|
+
if (model.enable !== undefined && typeof model.enable !== 'boolean') {
|
|
103
101
|
errors.push('"enable" must be a boolean')
|
|
104
102
|
}
|
|
105
103
|
|
|
@@ -122,30 +120,69 @@ const validateModelRecord = (model, index) => {
|
|
|
122
120
|
}
|
|
123
121
|
|
|
124
122
|
/**
 * Builds a complete model record from a (possibly partial) validated input.
 *
 * Defaults applied when a field is null/undefined: prices -> 0, `max_in` -> 32000,
 * `max_out` -> 8000, `enable` -> true. A falsy `id` is derived as `<provider>_<name>`.
 * `supportedParams` is attached only when the input defines it (shared by reference).
 *
 * @param {Object} model - The model record to normalize
 * @returns {ModelRecord} A new record object; the input object itself is not modified
 */
const normalizeModelRecord = (model) => {
  const {
    id, name, provider, supportedParams,
  } = model

  const record = {
    id: id || `${provider}_${name}`,
    name,
    provider,
    input_price: model.input_price ?? 0,
    output_price: model.output_price ?? 0,
    cache_price: model.cache_price ?? 0,
    max_in: model.max_in ?? 32000,
    max_out: model.max_out ?? 8000,
    enable: model.enable ?? true,
  }

  if (supportedParams !== undefined) {
    record.supportedParams = supportedParams
  }

  return record
}
|
|
143
|
+
|
|
144
|
+
/**
 * Looks up a model by its 'provider/name' identifier, checks that it is enabled,
 * and resolves its effective supported params.
 *
 * Only the first '/' separates provider from name, so model names that themselves
 * contain slashes (e.g. 'ollama/namespace/model') can be addressed. When several
 * records share the same provider and name (e.g. a user-added record without an
 * explicit id alongside a built-in default), the most recently inserted one wins.
 *
 * @param {string} modelId - Model in 'provider/name' format (e.g., 'openai/gpt-4o', 'ollama/llama3.2')
 * @returns {{ record: ModelRecord, supportedParams: string[] }}
 * @throws {Error} When the identifier is malformed, the model is not found, or it is disabled
 */
export const getModel = (modelId) => {
  const listAvailable = () => [...REGISTRY.values()].map((m) => `${m.provider}/${m.name}`).join(', ')

  // Require a non-empty provider and a non-empty name around the first '/'.
  const slash = typeof modelId === 'string' ? modelId.indexOf('/') : -1
  if (slash <= 0 || slash === modelId.length - 1) {
    throw new Error(`Model must be in 'provider/name' format. Got: "${modelId}". Available: ${listAvailable()}`)
  }

  const provider = modelId.slice(0, slash)
  const name = modelId.slice(slash + 1)

  // Map iteration follows insertion order; keeping the last match lets later
  // additions override earlier records that use a different id.
  let record
  for (const candidate of REGISTRY.values()) {
    if (candidate.provider === provider && candidate.name === name) {
      record = candidate
    }
  }

  if (!record) {
    throw new Error(`Unknown model "${modelId}". Available: ${listAvailable()}`)
  }

  if (!record.enable) {
    throw new Error(`Model "${record.provider}/${record.name}" is currently disabled.`)
  }

  return {
    record,
    supportedParams: record.supportedParams ?? PROVIDER_DEFAULT_PARAMS[record.provider],
  }
}
|
|
150
187
|
|
|
151
188
|
/**
|
|
@@ -168,14 +205,15 @@ export const addModels = (models) => {
|
|
|
168
205
|
throw new Error(`addModels expects an array. Got: ${typeof models}`)
|
|
169
206
|
}
|
|
170
207
|
|
|
171
|
-
// Validate each model record
|
|
208
|
+
// Validate and normalize each model record
|
|
172
209
|
models.forEach((model, index) => {
|
|
173
210
|
validateModelRecord(model, index)
|
|
174
211
|
})
|
|
175
212
|
|
|
176
|
-
// Add models to the registry
|
|
213
|
+
// Add normalized models to the registry
|
|
177
214
|
models.forEach((model) => {
|
|
178
|
-
|
|
215
|
+
const normalized = normalizeModelRecord(model)
|
|
216
|
+
REGISTRY.set(normalized.id, normalized)
|
|
179
217
|
})
|
|
180
218
|
}
|
|
181
219
|
|
|
@@ -191,10 +229,13 @@ export const setModels = (models) => {
|
|
|
191
229
|
throw new Error(`setModels expects an array. Got: ${typeof models}`)
|
|
192
230
|
}
|
|
193
231
|
|
|
194
|
-
// Validate each model record
|
|
232
|
+
// Validate and normalize each model record
|
|
195
233
|
models.forEach((model, index) => {
|
|
196
234
|
validateModelRecord(model, index)
|
|
197
235
|
})
|
|
198
236
|
|
|
199
|
-
REGISTRY = new Map(models.map((model) =>
|
|
237
|
+
REGISTRY = new Map(models.map((model) => {
|
|
238
|
+
const normalized = normalizeModelRecord(model)
|
|
239
|
+
return [normalized.id, normalized]
|
|
240
|
+
}))
|
|
200
241
|
}
|