@pwshub/aisdk 0.0.4 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,142 @@
1
+ /**
2
+ * @fileoverview Tests for config module.
3
+ */
4
+
5
+ import {
6
+ describe, it,
7
+ } from 'node:test'
8
+ import assert from 'node:assert'
9
+ import {
10
+ normalizeConfig, getWireMap,
11
+ } from '../src/config.js'
12
+
13
describe('config', () => {
  describe('getWireMap', () => {
    // Every built-in provider is expected to map `temperature` to the same
    // wire key; google additionally nests it under `generationConfig`.
    const cases = [
      { provider: 'openai' },
      { provider: 'anthropic' },
      { provider: 'google', scope: 'generationConfig' },
      { provider: 'dashscope' },
      { provider: 'deepseek' },
    ]

    cases.forEach(({ provider, scope }) => {
      it(`should return wire map for ${provider}`, () => {
        const wireMap = getWireMap(provider)
        assert.ok(wireMap)
        assert.ok(wireMap.temperature)
        assert.strictEqual(wireMap.temperature.wireKey, 'temperature')
        if (scope) {
          assert.strictEqual(wireMap.temperature.scope, scope)
        }
      })
    })

    it('should return undefined for unknown provider', () => {
      assert.strictEqual(getWireMap('unknown'), undefined)
    })
  })

  describe('normalizeConfig', () => {
    // Most cases target the same OpenAI model; keep that call in one place.
    const normalizeOpenai = (config, supportedParams) =>
      normalizeConfig(config, 'openai', supportedParams, 'gpt-4o')

    it('should normalize openai config', () => {
      const result = normalizeOpenai(
        { temperature: 0.5, maxTokens: 100, topP: 0.9 },
        ['temperature', 'maxTokens', 'topP'],
      )

      assert.strictEqual(result.temperature, 0.5)
      assert.strictEqual(result.max_completion_tokens, 100)
      assert.strictEqual(result.top_p, 0.9)
    })

    it('should normalize anthropic config', () => {
      const result = normalizeConfig(
        { temperature: 0.5, maxTokens: 100, topK: 50 },
        'anthropic',
        ['temperature', 'maxTokens', 'topK'],
        'claude-sonnet',
      )

      assert.strictEqual(result.temperature, 0.5)
      assert.strictEqual(result.max_tokens, 100)
      assert.strictEqual(result.top_k, 50)
    })

    it('should normalize google config with generationConfig nesting', () => {
      const result = normalizeConfig(
        { temperature: 0.5, maxTokens: 100, topP: 0.9 },
        'google',
        ['temperature', 'maxTokens', 'topP'],
        'gemini-2.0-flash',
      )

      // Nothing may leak to the top level; google expects nested params.
      for (const key of ['temperature', 'maxTokens', 'topP']) {
        assert.strictEqual(result[key], undefined)
      }
      const { generationConfig } = result
      assert.ok(generationConfig)
      assert.strictEqual(generationConfig.temperature, 0.5)
      assert.strictEqual(generationConfig.maxOutputTokens, 100)
      assert.strictEqual(generationConfig.topP, 0.9)
    })

    it('should skip unsupported params', () => {
      // topK is not in the supported list passed for this openai call
      const result = normalizeOpenai({ temperature: 0.5, topK: 50 }, ['temperature'])

      assert.strictEqual(result.temperature, 0.5)
      assert.strictEqual(result.top_k, undefined)
    })

    it('should skip null/undefined values', () => {
      const result = normalizeOpenai(
        { temperature: 0.5, maxTokens: null, topP: undefined },
        ['temperature', 'maxTokens', 'topP'],
      )

      assert.strictEqual(result.temperature, 0.5)
      assert.strictEqual(result.max_completion_tokens, undefined)
      assert.strictEqual(result.top_p, undefined)
    })

    it('should return empty object when no config provided', () => {
      assert.deepStrictEqual(normalizeOpenai({}, ['temperature']), {})
    })

    it('should return empty object when no supported params match', () => {
      assert.deepStrictEqual(normalizeOpenai({ unknownParam: 123 }, ['temperature']), {})
    })
  })
})
package/src/errors.js CHANGED
@@ -35,7 +35,7 @@ export class ProviderError extends Error {
35
35
  * @param {object} meta
36
36
  * @param {number} meta.status - HTTP status code
37
37
  * @param {string} meta.provider - Provider ID
38
- * @param {string} meta.model - Model ID that was called
38
+ * @param {string} meta.model - Model name that was called
39
39
  * @param {string} [meta.raw] - Raw response body from provider
40
40
  */
41
41
  constructor(message, {
package/src/index.js CHANGED
@@ -6,7 +6,7 @@
6
6
  *
7
7
  * const ai = createAi()
8
8
  * const result = await ai.ask({
9
- * model: 'claude-sonnet-4-20250514',
9
+ * model: 'anthropic/claude-sonnet-4-20250514',
10
10
  * apikey: 'your-api-key',
11
11
  * prompt: 'What is the capital of Vietnam?',
12
12
  * temperature: 0.5,
@@ -16,18 +16,18 @@
16
16
  *
17
17
  * @example With fallbacks
18
18
  * const result = await ai.ask({
19
- * model: 'gpt-4o',
19
+ * model: 'openai/gpt-4o',
20
20
  * apikey: 'your-openai-key',
21
21
  * prompt: '...',
22
- * fallbacks: ['gpt-4o-mini', 'claude-haiku-4-5-20251001'],
22
+ * fallbacks: ['openai/gpt-4o-mini', 'anthropic/claude-haiku-4-5-20251001'],
23
23
  * })
24
- * if (result.model !== 'gpt-4o') {
24
+ * if (result.model !== 'openai/gpt-4o') {
25
25
  * console.warn('Fell back to', result.model)
26
26
  * }
27
27
  *
28
28
  * @example Google provider-specific options
29
29
  * const result = await ai.ask({
30
- * model: 'gemini-2.0-flash',
30
+ * model: 'google/gemini-2.0-flash',
31
31
  * apikey: 'your-google-key',
32
32
  * prompt: '...',
33
33
  * providerOptions: {
@@ -40,7 +40,7 @@
40
40
  *
41
41
  * @example Using messages array for multi-turn conversations
42
42
  * const result = await ai.ask({
43
- * model: 'claude-sonnet-4-20250514',
43
+ * model: 'anthropic/claude-sonnet-4-20250514',
44
44
  * apikey: 'your-api-key',
45
45
  * messages: [
46
46
  * { role: 'user', content: 'What is the capital of Vietnam?' },
@@ -73,12 +73,12 @@ export {
73
73
 
74
74
  /**
75
75
  * @typedef {Object} AskParams
76
- * @property {string} model - Model ID (must be registered via setModels())
76
+ * @property {string} model - Model in 'provider/name' format (e.g., 'openai/gpt-4o', 'ollama/llama3.2')
77
77
  * @property {string} apikey - API key for the provider
78
78
  * @property {string} [prompt] - The user message (alternative to messages)
79
79
  * @property {string} [system] - Optional system prompt (used with prompt)
80
80
  * @property {import('./providers.js').Message[]} [messages] - Array of messages with role and content (alternative to prompt)
81
- * @property {string[]} [fallbacks] - Ordered list of fallback model IDs
81
+ * @property {string[]} [fallbacks] - Ordered list of fallback models (same format as model)
82
82
  * @property {Record<string, unknown>} [providerOptions] - Provider-specific options merged into body
83
83
  * @property {number} [temperature]
84
84
  * @property {number} [maxTokens]
package/src/models.js CHANGED
@@ -342,4 +342,60 @@ export const DEFAULT_MODELS = [
342
342
  max_out: 65536,
343
343
  enable: true,
344
344
  },
345
+ // Mistral models
346
+ {
347
+ id: 'mistral-large-latest',
348
+ name: 'mistral-large-latest',
349
+ provider: 'mistral',
350
+ input_price: 0.5,
351
+ output_price: 1.5,
352
+ cache_price: 0,
353
+ max_in: 128000,
354
+ max_out: 128000,
355
+ enable: true,
356
+ },
357
+ {
358
+ id: 'mistral-medium-latest',
359
+ name: 'mistral-medium-latest',
360
+ provider: 'mistral',
361
+ input_price: 0.4,
362
+ output_price: 2,
363
+ cache_price: 0,
364
+ max_in: 64000,
365
+ max_out: 64000,
366
+ enable: true,
367
+ },
368
+ {
369
+ id: 'mistral-small-latest',
370
+ name: 'mistral-small-latest',
371
+ provider: 'mistral',
372
+ input_price: 0.15,
373
+ output_price: 0.6,
374
+ cache_price: 0,
375
+ max_in: 128000,
376
+ max_out: 128000,
377
+ enable: true,
378
+ },
379
+ {
380
+ id: 'magistral-medium-latest',
381
+ name: 'magistral-medium-latest',
382
+ provider: 'mistral',
383
+ input_price: 2,
384
+ output_price: 5,
385
+ cache_price: 0,
386
+ max_in: 64000,
387
+ max_out: 64000,
388
+ enable: true,
389
+ },
390
+ {
391
+ id: 'magistral-small-latest',
392
+ name: 'magistral-small-latest',
393
+ provider: 'mistral',
394
+ input_price: 0.5,
395
+ output_price: 1.5,
396
+ cache_price: 0,
397
+ max_in: 64000,
398
+ max_out: 64000,
399
+ enable: true,
400
+ },
345
401
  ]
package/src/providers.js CHANGED
@@ -10,7 +10,7 @@
10
10
  */
11
11
 
12
12
  /**
13
- * @typedef {'openai'|'anthropic'|'google'|'dashscope'|'deepseek'} ProviderId
13
+ * @typedef {'openai'|'anthropic'|'google'|'dashscope'|'deepseek'|'mistral'|'ollama'} ProviderId
14
14
  */
15
15
 
16
16
  /**
@@ -298,9 +298,79 @@ const deepseek = {
298
298
  }),
299
299
  }
300
300
 
301
/**
 * Adapter for Mistral's chat-completions endpoint.
 *
 * The request body carries the model, the messages, the normalized config and
 * then providerOptions. The response text is read from
 * `choices[0].message.content`, and usage from
 * `usage.prompt_tokens` / `usage.completion_tokens`.
 *
 * @type {ProviderAdapter}
 */
const mistral = {
  headers: (apikey) => ({
    Authorization: `Bearer ${apikey}`,
    'Content-Type': 'application/json',
    Accept: 'application/json',
  }),
  url: () => 'https://api.mistral.ai/v1/chat/completions',
  // providerOptions is spread last, so it overrides config keys on conflict.
  buildBody: (modelName, messages, config, providerOptions) => ({
    model: modelName,
    messages,
    ...config,
    ...providerOptions,
  }),
  /**
   * @param {object} data - Parsed JSON response body
   * @returns {string} The assistant's answer text
   * @throws {Error} When no answer text is present
   */
  extractText: (data) => {
    const raw = data.choices?.[0]?.message?.content
    // `content` is normally a string. Reasoning (Magistral) models can instead
    // return an array of typed chunks, e.g. a `thinking` chunk followed by
    // `text` chunks. Only the `text` chunks form the user-visible answer.
    const content = Array.isArray(raw)
      ? raw
        .filter((chunk) => chunk?.type === 'text' && typeof chunk.text === 'string')
        .map((chunk) => chunk.text)
        .join('')
      : raw
    if (!content) {
      throw new Error('Mistral response missing content')
    }
    return content
  },
  // Cache and reasoning token counts are not read from the response; always 0.
  extractUsage: (data) => ({
    inputTokens: data.usage?.prompt_tokens ?? 0,
    outputTokens: data.usage?.completion_tokens ?? 0,
    cacheTokens: 0,
    reasoningTokens: 0,
  }),
}
329
+
330
/**
 * Normalized config keys that Ollama expects inside the request's `options`
 * object (snake_case) rather than at the top level of the body.
 */
const OLLAMA_OPTION_KEYS = ['temperature', 'top_p', 'top_k', 'num_predict', 'seed', 'stop']

/**
 * Adapter for Ollama's /api/chat endpoint (non-streaming).
 *
 * @type {ProviderAdapter}
 */
const ollama = {
  headers: (apikey) => ({
    'Content-Type': 'application/json',
    // Only send a bearer token when an API key was supplied.
    ...(apikey && { Authorization: `Bearer ${apikey}` }),
  }),
  // Default local endpoint. NOTE(review): the previous comment said this can be
  // overridden via gatewayUrl, but url() takes no arguments here, so any
  // override must happen in the caller — confirm.
  url: () => 'http://localhost:11434/api/chat',
  /**
   * @param {string} modelName
   * @param {Array<object>} messages
   * @param {Record<string, unknown>} config - Normalized (wire-key) config
   * @param {Record<string, unknown>} [providerOptions] - Extra top-level body fields;
   *   its `options` object is merged with the sampling options from `config`
   * @returns {object} Request body
   */
  buildBody: (modelName, messages, config, providerOptions) => {
    const options = {}
    for (const key of OLLAMA_OPTION_KEYS) {
      if (config[key] !== undefined) {
        options[key] = config[key]
      }
    }

    const body = {
      model: modelName,
      messages,
      stream: false,
      ...providerOptions,
    }

    // Merge rather than replace providerOptions.options (e.g. num_ctx), which
    // was previously discarded whenever any sampling param was set. Config
    // keys still take precedence on conflict.
    if (Object.keys(options).length > 0) {
      body.options = { ...providerOptions?.options, ...options }
    }

    return body
  },
  extractText: (data) => {
    const content = data.message?.content
    if (!content) {
      throw new Error('Ollama response missing content')
    }
    return content
  },
  // Ollama reports prompt/eval counts; cache and reasoning tokens are not read.
  extractUsage: (data) => ({
    inputTokens: data.prompt_eval_count ?? 0,
    outputTokens: data.eval_count ?? 0,
    cacheTokens: 0,
    reasoningTokens: 0,
  }),
}
370
+
301
371
/**
 * Provider adapters, keyed by provider id.
 *
 * @type {Record<string, ProviderAdapter>}
 */
const ADAPTERS = {
  openai,
  anthropic,
  google,
  dashscope,
  deepseek,
  mistral,
  ollama,
}
305
375
 
306
376
  /**
package/src/registry.js CHANGED
@@ -4,7 +4,8 @@
4
4
  * Default models are loaded automatically from ./models.js at import time.
5
5
  * Users can modify the registry via addModels() and setModels().
6
6
  *
7
- * This module provides O(1) lookups at runtime via a Map indexed by model ID.
7
+ * Records are stored in a Map keyed by record id; getModel() resolves
8
+ * a model from its 'provider/name' string by scanning the registry.
8
9
  *
9
10
  * `supportedParams` is optional per record. When absent, the provider's
10
11
  * default param set is used.
@@ -43,10 +44,12 @@ export const PROVIDER_DEFAULT_PARAMS = {
43
44
  google: ['temperature', 'maxTokens', 'topP', 'topK', 'seed', 'stop'],
44
45
  dashscope: ['temperature', 'maxTokens', 'topP', 'topK', 'stop'],
45
46
  deepseek: ['temperature', 'maxTokens', 'topP', 'frequencyPenalty', 'presencePenalty', 'stop'],
47
+ mistral: ['temperature', 'maxTokens', 'topP', 'randomSeed', 'stop'],
48
+ ollama: ['temperature', 'maxTokens', 'topP', 'topK', 'seed', 'stop'],
46
49
  }
47
50
 
48
51
/**
 * Provider ids accepted in the `provider` field of a model record.
 *
 * @type {ProviderId[]}
 */
const VALID_PROVIDERS = [
  'openai',
  'anthropic',
  'google',
  'dashscope',
  'deepseek',
  'mistral',
  'ollama',
]
50
53
 
51
54
  /** @type {Map<string, ModelRecord>} */
52
55
  let REGISTRY = new Map()
@@ -72,34 +75,29 @@ initRegistry()
72
75
  const validateModelRecord = (model, index) => {
73
76
  const errors = []
74
77
 
75
- // Check required string fields
76
- if (!model.id || typeof model.id !== 'string') {
77
- errors.push('"id" must be a non-empty string')
78
- }
79
-
78
+ // Only name and provider are required
80
79
  if (!model.name || typeof model.name !== 'string') {
81
80
  errors.push('"name" must be a non-empty string')
82
81
  }
83
82
 
84
- // Check provider is valid
85
83
  if (!model.provider || typeof model.provider !== 'string') {
86
84
  errors.push('"provider" must be a string')
87
85
  } else if (!VALID_PROVIDERS.includes(model.provider)) {
88
86
  errors.push(`"provider" must be one of: ${VALID_PROVIDERS.join(', ')}. Got: "${model.provider}"`)
89
87
  }
90
88
 
91
- // Check required number fields (must be non-negative)
89
+ // Check optional number fields if present (must be non-negative)
92
90
  const numberFields = ['input_price', 'output_price', 'cache_price', 'max_in', 'max_out']
93
91
  for (const field of numberFields) {
94
- if (typeof model[field] !== 'number') {
92
+ if (model[field] !== undefined && typeof model[field] !== 'number') {
95
93
  errors.push(`"${field}" must be a number`)
96
- } else if (model[field] < 0) {
94
+ } else if (model[field] !== undefined && model[field] < 0) {
97
95
  errors.push(`"${field}" must be non-negative, got: ${model[field]}`)
98
96
  }
99
97
  }
100
98
 
101
- // Check enable is boolean
102
- if (typeof model.enable !== 'boolean') {
99
+ // Check optional enable if present
100
+ if (model.enable !== undefined && typeof model.enable !== 'boolean') {
103
101
  errors.push('"enable" must be a boolean')
104
102
  }
105
103
 
@@ -122,30 +120,69 @@ const validateModelRecord = (model, index) => {
122
120
  }
123
121
 
124
122
/**
 * Fills in defaults for a model record that has already passed validation.
 *
 * Validation only requires `name` and `provider`. Prices default to 0,
 * `max_in`/`max_out` default to 32000/8000, and `enable` defaults to true.
 * A missing or empty `id` is derived as `${provider}_${name}`.
 * `supportedParams` is copied only when the input defines it.
 *
 * @param {Object} model - The model record to normalize (not mutated)
 * @returns {ModelRecord} A new, normalized model record
 */
const normalizeModelRecord = (model) => {
  const { id, name, provider } = model

  // `??` keeps explicit zeros/false and only replaces null/undefined.
  const normalized = {
    id: id || `${provider}_${name}`,
    name,
    provider,
    input_price: model.input_price ?? 0,
    output_price: model.output_price ?? 0,
    cache_price: model.cache_price ?? 0,
    max_in: model.max_in ?? 32000,
    max_out: model.max_out ?? 8000,
    enable: model.enable ?? true,
  }

  // Leave the key absent when unset so lookups fall back to the provider's
  // default param set.
  if (model.supportedParams !== undefined) {
    normalized.supportedParams = model.supportedParams
  }

  return normalized
}
143
+
144
/**
 * Looks up a model by its 'provider/name' string and checks that it is enabled.
 * It then resolves the model's effective supported params: the record's own
 * list, or the provider default when the record has none.
 *
 * Only the first '/' separates provider from name, so model names that
 * themselves contain slashes are supported (e.g. 'ollama/hf.co/org/model').
 * The registry is scanned linearly. When several records share the same
 * provider and name (e.g. a default model re-added via addModels() without
 * its original id), the one that comes last in Map insertion order wins, so
 * later additions override earlier ones.
 *
 * @param {string} modelId - Model in 'provider/name' format (e.g., 'openai/gpt-4o', 'ollama/llama3.2')
 * @returns {{ record: ModelRecord, supportedParams: string[] }}
 * @throws {Error} When the id is not in 'provider/name' format, or the model is unknown or disabled
 */
export const getModel = (modelId) => {
  const listAvailable = () => [...REGISTRY.values()]
    .map((m) => `${m.provider}/${m.name}`)
    .join(', ')

  // Require a non-empty provider and a non-empty name around the first '/'.
  const slash = typeof modelId === 'string' ? modelId.indexOf('/') : -1
  if (slash <= 0 || slash === modelId.length - 1) {
    throw new Error(`Model must be in 'provider/name' format. Got: "${modelId}". Available: ${listAvailable()}`)
  }

  const provider = modelId.slice(0, slash)
  const name = modelId.slice(slash + 1)

  // Keep the LAST match: normalizeModelRecord() may give an override a
  // different id than the record it is meant to replace.
  let record
  for (const m of REGISTRY.values()) {
    if (m.provider === provider && m.name === name) {
      record = m
    }
  }

  if (!record) {
    throw new Error(`Unknown model "${modelId}". Available: ${listAvailable()}`)
  }

  if (!record.enable) {
    throw new Error(`Model "${record.provider}/${record.name}" is currently disabled.`)
  }

  const supportedParams = record.supportedParams ?? PROVIDER_DEFAULT_PARAMS[record.provider]

  return {
    record, supportedParams,
  }
}
150
187
 
151
188
  /**
@@ -168,14 +205,15 @@ export const addModels = (models) => {
168
205
  throw new Error(`addModels expects an array. Got: ${typeof models}`)
169
206
  }
170
207
 
171
- // Validate each model record
208
+ // Validate and normalize each model record
172
209
  models.forEach((model, index) => {
173
210
  validateModelRecord(model, index)
174
211
  })
175
212
 
176
- // Add models to the registry
213
+ // Add normalized models to the registry
177
214
  models.forEach((model) => {
178
- REGISTRY.set(model.id, model)
215
+ const normalized = normalizeModelRecord(model)
216
+ REGISTRY.set(normalized.id, normalized)
179
217
  })
180
218
  }
181
219
 
@@ -191,10 +229,13 @@ export const setModels = (models) => {
191
229
  throw new Error(`setModels expects an array. Got: ${typeof models}`)
192
230
  }
193
231
 
194
- // Validate each model record strictly
232
+ // Validate and normalize each model record
195
233
  models.forEach((model, index) => {
196
234
  validateModelRecord(model, index)
197
235
  })
198
236
 
199
- REGISTRY = new Map(models.map((model) => [model.id, model]))
237
+ REGISTRY = new Map(models.map((model) => {
238
+ const normalized = normalizeModelRecord(model)
239
+ return [normalized.id, normalized]
240
+ }))
200
241
  }