@pwshub/aisdk 0.0.4 → 0.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/index.js CHANGED
@@ -6,7 +6,7 @@
6
6
  *
7
7
  * const ai = createAi()
8
8
  * const result = await ai.ask({
9
- * model: 'claude-sonnet-4-20250514',
9
+ * model: 'anthropic/claude-sonnet-4-20250514',
10
10
  * apikey: 'your-api-key',
11
11
  * prompt: 'What is the capital of Vietnam?',
12
12
  * temperature: 0.5,
@@ -16,18 +16,18 @@
16
16
  *
17
17
  * @example With fallbacks
18
18
  * const result = await ai.ask({
19
- * model: 'gpt-4o',
19
+ * model: 'openai/gpt-4o',
20
20
  * apikey: 'your-openai-key',
21
21
  * prompt: '...',
22
- * fallbacks: ['gpt-4o-mini', 'claude-haiku-4-5-20251001'],
22
+ * fallbacks: ['openai/gpt-4o-mini', 'anthropic/claude-haiku-4-5-20251001'],
23
23
  * })
24
- * if (result.model !== 'gpt-4o') {
24
+ * if (result.model !== 'openai/gpt-4o') {
25
25
  * console.warn('Fell back to', result.model)
26
26
  * }
27
27
  *
28
28
  * @example Google provider-specific options
29
29
  * const result = await ai.ask({
30
- * model: 'gemini-2.0-flash',
30
+ * model: 'google/gemini-2.0-flash',
31
31
  * apikey: 'your-google-key',
32
32
  * prompt: '...',
33
33
  * providerOptions: {
@@ -40,7 +40,7 @@
40
40
  *
41
41
  * @example Using messages array for multi-turn conversations
42
42
  * const result = await ai.ask({
43
- * model: 'claude-sonnet-4-20250514',
43
+ * model: 'anthropic/claude-sonnet-4-20250514',
44
44
  * apikey: 'your-api-key',
45
45
  * messages: [
46
46
  * { role: 'user', content: 'What is the capital of Vietnam?' },
@@ -52,7 +52,7 @@
52
52
  */
53
53
 
54
54
  import {
55
- getModel, listModels, setModels, addModels,
55
+ getModel, createRegistry,
56
56
  } from './registry.js'
57
57
  import { normalizeConfig } from './config.js'
58
58
  import { coerceConfig } from './coerce.js'
@@ -61,24 +61,53 @@ import {
61
61
  ProviderError, InputError, throwHttpError,
62
62
  } from './errors.js'
63
63
  import { validateAskOptions } from './validation.js'
64
+ import { getLogger, setLogger, noopLogger } from './logger.js'
65
+ import { validateApiKey } from './security.js'
64
66
 
65
67
  export {
66
68
  ProviderError, InputError,
69
+ setLogger, noopLogger, getLogger,
67
70
  }
68
71
 
72
+ export { addModels, setModels, listModels, createRegistry } from './registry.js'
73
+ /**
74
+ * @typedef {Object} HookContext
75
+ * @property {string} model - Model identifier
76
+ * @property {string} provider - Provider ID
77
+ * @property {string} url - Request URL
78
+ * @property {Record<string, string>} headers - Request headers
79
+ * @property {Record<string, unknown>} body - Request body
80
+ */
81
+
82
+ /**
83
+ * @typedef {Object} ResponseHookContext
84
+ * @property {string} model - Model identifier
85
+ * @property {string} provider - Provider ID
86
+ * @property {string} url - Request URL
87
+ * @property {Record<string, string>} headers - Request headers
88
+ * @property {Record<string, unknown>} body - Request body
89
+ * @property {number} status - Response status code
90
+ * @property {unknown} data - Response data
91
+ * @property {number} duration - Request duration in milliseconds
92
+ */
93
+
69
94
  /**
70
95
  * @typedef {Object} AiOptions
71
96
  * @property {string} [gatewayUrl] - Optional AI gateway URL override
97
+ * @property {number} [timeout] - Request timeout in milliseconds (default: 30000)
98
+ * @property {import('./models.js').ModelRecord[]} [models] - Custom model registry
99
+ * @property {(context: HookContext) => void | Promise<void>} [onRequest] - Hook called before each request
100
+ * @property {(context: ResponseHookContext) => void | Promise<void>} [onResponse] - Hook called after each response
72
101
  */
73
102
 
74
103
  /**
75
104
  * @typedef {Object} AskParams
76
- * @property {string} model - Model ID (must be registered via setModels())
105
+ * @property {string} model - Model name or 'provider/name' format (e.g., 'gpt-4o', 'ollama/llama3.2')
77
106
  * @property {string} apikey - API key for the provider
78
107
  * @property {string} [prompt] - The user message (alternative to messages)
79
108
  * @property {string} [system] - Optional system prompt (used with prompt)
80
109
  * @property {import('./providers.js').Message[]} [messages] - Array of messages with role and content (alternative to prompt)
81
- * @property {string[]} [fallbacks] - Ordered list of fallback model IDs
110
+ * @property {string[]} [fallbacks] - Ordered list of fallback models (same format as model)
82
111
  * @property {Record<string, unknown>} [providerOptions] - Provider-specific options merged into body
83
112
  * @property {number} [temperature]
84
113
  * @property {number} [maxTokens]
@@ -110,7 +139,7 @@ export {
110
139
  * @returns {import('./config.js').GenerationConfig}
111
140
  */
112
141
  const extractGenConfig = (params) => {
113
- const keys = ['temperature', 'maxTokens', 'topP', 'topK', 'frequencyPenalty', 'presencePenalty']
142
+ const keys = ['temperature', 'maxTokens', 'topP', 'topK', 'frequencyPenalty', 'presencePenalty', 'stop', 'seed']
114
143
  return Object.fromEntries(
115
144
  keys.filter((k) => params[k] !== undefined).map((k) => [k, params[k]])
116
145
  )
@@ -126,7 +155,9 @@ const extractGenConfig = (params) => {
126
155
  const calcCost = (usage, record) => {
127
156
  const M = 1_000_000
128
157
  const inputCost = (usage.inputTokens / M) * record.input_price
129
- const outputCost = ((usage.outputTokens + usage.reasoningTokens) / M) * record.output_price
158
+ // Don't add reasoningTokens - they're already included in outputTokens
159
+ // reasoningTokens is for informational/tracking purposes only
160
+ const outputCost = (usage.outputTokens / M) * record.output_price
130
161
  const cacheCost = (usage.cacheTokens / M) * record.cache_price
131
162
 
132
163
  // Round to 8 decimal places to avoid floating point noise
@@ -144,21 +175,34 @@ const calcCost = (usage, record) => {
144
175
  * @throws {ProviderError} On 429 / 5xx — safe to retry or fallback
145
176
  * @throws {InputError} On 4xx — do not retry, fix the input
146
177
  */
147
- const callModel = async (modelId, params, gatewayUrl) => {
178
+ const callModel = async (modelId, params, gatewayUrl, registry = null, timeout = 30000, hooks = {}) => {
179
+ const logger = getLogger()
180
+ const { onRequest, onResponse } = hooks
181
+
182
+ // Use provided registry instance or fall back to global getModel
183
+ const modelLookup = registry ? registry.getModel : getModel
148
184
  const {
149
- record, supportedParams,
150
- } = getModel(modelId)
185
+ record, supportedParams, paramOverrides,
186
+ } = modelLookup(modelId)
151
187
  const {
152
188
  provider: providerId, name: modelName,
153
189
  } = record
154
190
 
155
191
  const { apikey } = params
192
+
193
+ // Validate API key before making request
194
+ validateApiKey(apikey, providerId, logger)
195
+
156
196
  const adapter = getAdapter(providerId)
157
197
 
158
198
  const genConfig = extractGenConfig(params)
159
199
 
160
200
  // Coerce values to provider's acceptable ranges (clamp, don't throw)
161
- const coerced = coerceConfig(genConfig, providerId)
201
+ // Pass model-specific param overrides
202
+ const { coerced } = coerceConfig(genConfig, providerId, {
203
+ modelId,
204
+ overrides: paramOverrides,
205
+ })
162
206
 
163
207
  // Normalize to wire format
164
208
  const normalizedConfig = normalizeConfig(coerced, providerId, supportedParams, modelId)
@@ -177,30 +221,79 @@ const callModel = async (modelId, params, gatewayUrl) => {
177
221
  },
178
222
  ]
179
223
 
180
- const url = gatewayUrl ?? adapter.url(modelName, apikey)
224
+ const url = adapter.url(modelName, apikey, gatewayUrl)
225
+ const requestHeaders = adapter.headers(apikey)
181
226
  const body = adapter.buildBody(modelName, messageList, normalizedConfig, providerOptions)
182
227
 
228
+ // Invoke onRequest hook
229
+ if (onRequest) {
230
+ await onRequest({
231
+ model: modelId,
232
+ provider: providerId,
233
+ url,
234
+ headers: requestHeaders,
235
+ body,
236
+ })
237
+ }
238
+
183
239
  let res
240
+ const controller = new AbortController()
241
+ const timeoutId = setTimeout(() => controller.abort(), timeout)
242
+ const startTime = Date.now()
243
+
184
244
  try {
185
245
  res = await fetch(url, {
186
246
  method: 'POST',
187
- headers: adapter.headers(apikey),
247
+ headers: requestHeaders,
188
248
  body: JSON.stringify(body),
249
+ signal: controller.signal,
189
250
  })
190
251
  } catch (networkErr) {
252
+ clearTimeout(timeoutId)
253
+
191
254
  // Network-level failure (DNS, connection refused) — treat as provider error
192
- throw new ProviderError(`Network error calling ${providerId}/${modelId}: ${networkErr.message}`, {
255
+ logger.warn(
256
+ `[ai-client] Network error calling ${providerId}/${modelId}: ${networkErr.message}`
257
+ )
258
+
259
+ if (networkErr.name === 'AbortError') {
260
+ throw new ProviderError(`Request timeout after ${timeout}ms`, {
261
+ status: 408,
262
+ provider: providerId,
263
+ model: modelId,
264
+ })
265
+ }
266
+
267
+ throw new ProviderError(`Network error calling ${providerId}/${modelId}`, {
193
268
  status: 0,
194
269
  provider: providerId,
195
270
  model: modelId,
196
271
  })
197
272
  }
198
273
 
274
+ clearTimeout(timeoutId)
275
+
199
276
  if (!res.ok) {
200
- await throwHttpError(res, providerId, modelId)
277
+ await throwHttpError(res, providerId, modelId, logger)
201
278
  }
202
279
 
203
280
  const data = await res.json()
281
+ const duration = Date.now() - startTime
282
+
283
+ // Invoke onResponse hook
284
+ if (onResponse) {
285
+ await onResponse({
286
+ model: modelId,
287
+ provider: providerId,
288
+ url,
289
+ headers: requestHeaders,
290
+ body,
291
+ status: res.status,
292
+ data,
293
+ duration,
294
+ })
295
+ }
296
+
204
297
  const rawUsage = adapter.extractUsage(data)
205
298
 
206
299
  /** @type {Usage} */
@@ -228,7 +321,11 @@ const callModel = async (modelId, params, gatewayUrl) => {
228
321
  * @returns {{ ask: (params: AskParams) => Promise<AskResult>, listModels: () => import('./registry.js').ModelRecord[] }}
229
322
  */
230
323
  export const createAi = (opts = {}) => {
231
- const { gatewayUrl } = opts
324
+ const { gatewayUrl, models, timeout, onRequest, onResponse } = opts
325
+ // Create isolated registry instance for this AI client
326
+ const registry = models
327
+ ? createRegistry(models)
328
+ : createRegistry()
232
329
 
233
330
  /**
234
331
  * Sends a text generation request, with optional fallback chain.
@@ -240,6 +337,8 @@ export const createAi = (opts = {}) => {
240
337
  * @throws {InputError} Immediately, without trying fallbacks
241
338
  */
242
339
  const ask = async (params) => {
340
+ const logger = getLogger()
341
+
243
342
  // Validate input structure and types
244
343
  try {
245
344
  validateAskOptions(params)
@@ -254,17 +353,18 @@ export const createAi = (opts = {}) => {
254
353
 
255
354
  const chain = [params.model, ...(params.fallbacks ?? [])]
256
355
  let lastProviderError
356
+ const hooks = { onRequest, onResponse }
257
357
 
258
358
  for (const modelId of chain) {
259
359
  try {
260
- return await callModel(modelId, params, gatewayUrl)
360
+ return await callModel(modelId, params, gatewayUrl, registry, timeout, hooks)
261
361
  } catch (err) {
262
362
  if (err instanceof InputError) {
263
363
  // Input errors are not fallback-able — rethrow immediately
264
364
  throw err
265
365
  }
266
366
  // ProviderError — log and try next model in chain
267
- console.warn(
367
+ logger.warn(
268
368
  `[ai-client] ${err.message}. ${modelId === chain.at(-1) ? 'No more fallbacks.' : 'Trying next fallback...'}`
269
369
  )
270
370
  lastProviderError = err
@@ -275,8 +375,8 @@ export const createAi = (opts = {}) => {
275
375
  }
276
376
 
277
377
  return {
278
- ask, listModels,
378
+ ask,
379
+ listModels: () => registry.listModels(),
380
+ addModels: (m) => registry.addModels(m),
279
381
  }
280
382
  }
281
-
282
- export { addModels, setModels, listModels }