@auto-engineer/ai-gateway 0.7.0 → 0.8.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +1 -1
- package/.turbo/turbo-test.log +8 -8
- package/.turbo/turbo-type-check.log +1 -1
- package/CHANGELOG.md +2 -0
- package/README.md +365 -0
- package/dist/config.d.ts +2 -0
- package/dist/config.d.ts.map +1 -1
- package/dist/config.js +31 -2
- package/dist/config.js.map +1 -1
- package/dist/config.specs.d.ts +2 -0
- package/dist/config.specs.d.ts.map +1 -0
- package/dist/config.specs.js +123 -0
- package/dist/config.specs.js.map +1 -0
- package/dist/constants.d.ts +20 -0
- package/dist/constants.d.ts.map +1 -0
- package/dist/constants.js +15 -0
- package/dist/constants.js.map +1 -0
- package/dist/index-custom.specs.d.ts +2 -0
- package/dist/index-custom.specs.d.ts.map +1 -0
- package/dist/index-custom.specs.js +161 -0
- package/dist/index-custom.specs.js.map +1 -0
- package/dist/index.d.ts +12 -13
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +93 -59
- package/dist/index.js.map +1 -1
- package/dist/index.specs.js +152 -11
- package/dist/index.specs.js.map +1 -1
- package/dist/providers/custom.d.ts +6 -0
- package/dist/providers/custom.d.ts.map +1 -0
- package/dist/providers/custom.js +16 -0
- package/dist/providers/custom.js.map +1 -0
- package/dist/providers/custom.specs.d.ts +2 -0
- package/dist/providers/custom.specs.d.ts.map +1 -0
- package/dist/providers/custom.specs.js +129 -0
- package/dist/providers/custom.specs.js.map +1 -0
- package/package.json +5 -5
- package/src/config.specs.ts +147 -0
- package/src/config.ts +46 -2
- package/src/constants.ts +21 -0
- package/src/index-custom.specs.ts +192 -0
- package/src/index.specs.ts +199 -10
- package/src/index.ts +99 -78
- package/src/providers/custom.specs.ts +161 -0
- package/src/providers/custom.ts +24 -0
- package/tsconfig.tsbuildinfo +1 -1
package/src/index.ts
CHANGED
|
@@ -4,6 +4,10 @@ import { anthropic } from '@ai-sdk/anthropic';
|
|
|
4
4
|
import { google } from '@ai-sdk/google';
|
|
5
5
|
import { xai } from '@ai-sdk/xai';
|
|
6
6
|
import { configureAIProvider } from './config';
|
|
7
|
+
import { DEFAULT_MODELS, AIProvider } from './constants';
|
|
8
|
+
import { createCustomProvider } from './providers/custom';
|
|
9
|
+
|
|
10
|
+
export { AIProvider } from './constants';
|
|
7
11
|
import { z } from 'zod';
|
|
8
12
|
import { getRegisteredToolsForAI } from './mcp-server';
|
|
9
13
|
import { startServer } from './mcp-server';
|
|
@@ -137,14 +141,8 @@ interface AIToolValidationError extends Error {
|
|
|
137
141
|
[key: string]: unknown;
|
|
138
142
|
}
|
|
139
143
|
|
|
140
|
-
export enum AIProvider {
|
|
141
|
-
OpenAI = 'openai',
|
|
142
|
-
Anthropic = 'anthropic',
|
|
143
|
-
Google = 'google',
|
|
144
|
-
XAI = 'xai',
|
|
145
|
-
}
|
|
146
|
-
|
|
147
144
|
export interface AIOptions {
|
|
145
|
+
provider?: AIProvider;
|
|
148
146
|
model?: string;
|
|
149
147
|
temperature?: number;
|
|
150
148
|
maxTokens?: number;
|
|
@@ -167,22 +165,47 @@ const defaultOptions: AIOptions = {
|
|
|
167
165
|
maxTokens: 1000,
|
|
168
166
|
};
|
|
169
167
|
|
|
170
|
-
function
|
|
171
|
-
const
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
168
|
+
/**
 * Choose the provider to use when a call does not pass `options.provider`.
 *
 * `DEFAULT_AI_PROVIDER` is matched case-insensitively, after trimming surrounding whitespace,
 * against `openai`, `anthropic`, `google`, `xai` and `custom`. If the variable is unset or holds
 * any other value, the first entry of `getAvailableProviders()` is used, or
 * `AIProvider.Anthropic` when no provider is configured at all.
 *
 * @returns The provider to use for the call.
 */
export function getDefaultAIProvider(): AIProvider {
  // Trim for consistency with DEFAULT_AI_MODEL handling in getDefaultModel.
  const envProvider = process.env.DEFAULT_AI_PROVIDER?.trim().toLowerCase();
  switch (envProvider) {
    case 'openai':
      return AIProvider.OpenAI;
    case 'anthropic':
      return AIProvider.Anthropic;
    case 'google':
      return AIProvider.Google;
    case 'xai':
      return AIProvider.XAI;
    case 'custom':
      // Without this case DEFAULT_AI_PROVIDER=custom was ignored and the fallback below applied.
      return AIProvider.Custom;
    default: {
      if (envProvider !== undefined && envProvider !== '') {
        // Log unrecognized values (e.g. typos) so the silent fallback is diagnosable.
        debugConfig('Unrecognized DEFAULT_AI_PROVIDER %s, falling back to first available provider', envProvider);
      }
      // Fallback to the first available provider
      const available = getAvailableProviders();
      return available[0] ?? AIProvider.Anthropic;
    }
  }
}
|
|
186
|
+
|
|
187
|
+
/**
 * Work out which model to use for `provider` when the caller supplied none.
 *
 * Precedence:
 * 1. A non-blank `DEFAULT_AI_MODEL` environment variable, returned trimmed. It is checked
 *    before `provider` is looked at, so it applies to every provider.
 * 2. For `AIProvider.Custom`: the configured custom provider's `defaultModel`.
 * 3. Otherwise: `DEFAULT_MODELS[provider]`.
 *
 * @param provider - Provider the model is being chosen for.
 * @returns The model identifier.
 * @throws `Error('Custom provider not configured')` if `provider` is Custom and no custom config exists.
 * @throws `Error('Unknown provider: <provider>')` if `DEFAULT_MODELS` has no entry for `provider`.
 */
export function getDefaultModel(provider: AIProvider): string {
  const rawEnvModel = process.env.DEFAULT_AI_MODEL;
  const trimmedEnvModel = rawEnvModel?.trim() ?? '';
  if (trimmedEnvModel !== '') {
    debugConfig('Using DEFAULT_AI_MODEL from environment: %s for provider %s', rawEnvModel, provider);
    return trimmedEnvModel;
  }

  if (provider === AIProvider.Custom) {
    const customConfig = configureAIProvider().custom;
    if (customConfig == null) {
      throw new Error('Custom provider not configured');
    }
    debugConfig('Selected custom provider default model %s for provider %s', customConfig.defaultModel, provider);
    return customConfig.defaultModel;
  }

  const model = DEFAULT_MODELS[provider];
  if (model == null) {
    throw new Error(`Unknown provider: ${provider}`);
  }
  debugConfig('Selected default model %s for provider %s', model, provider);
  return model;
}
|
|
@@ -200,30 +223,35 @@ function getModel(provider: AIProvider, model?: string) {
|
|
|
200
223
|
return google(modelName);
|
|
201
224
|
case AIProvider.XAI:
|
|
202
225
|
return xai(modelName);
|
|
226
|
+
case AIProvider.Custom: {
|
|
227
|
+
const config = configureAIProvider();
|
|
228
|
+
if (config.custom == null) {
|
|
229
|
+
throw new Error('Custom provider not configured');
|
|
230
|
+
}
|
|
231
|
+
const customProvider = createCustomProvider(config.custom);
|
|
232
|
+
return customProvider.languageModel(modelName);
|
|
233
|
+
}
|
|
203
234
|
default:
|
|
204
235
|
throw new Error(`Unknown provider: ${provider as string}`);
|
|
205
236
|
}
|
|
206
237
|
}
|
|
207
238
|
|
|
208
|
-
export async function generateTextWithAI(
|
|
209
|
-
|
|
210
|
-
provider:
|
|
211
|
-
options: AIOptions = {},
|
|
212
|
-
): Promise<string> {
|
|
213
|
-
debugAPI('generateTextWithAI called - provider: %s, promptLength: %d', provider, prompt.length);
|
|
239
|
+
export async function generateTextWithAI(prompt: string, options: AIOptions = {}): Promise<string> {
|
|
240
|
+
const resolvedProvider = options.provider ?? getDefaultAIProvider();
|
|
241
|
+
debugAPI('generateTextWithAI called - provider: %s, promptLength: %d', resolvedProvider, prompt.length);
|
|
214
242
|
const finalOptions = { ...defaultOptions, ...options };
|
|
215
|
-
const model = finalOptions.model ?? getDefaultModel(
|
|
216
|
-
const modelInstance = getModel(
|
|
243
|
+
const model = finalOptions.model ?? getDefaultModel(resolvedProvider);
|
|
244
|
+
const modelInstance = getModel(resolvedProvider, model);
|
|
217
245
|
|
|
218
246
|
if (finalOptions.includeTools === true) {
|
|
219
247
|
debugTools('Tools requested, starting MCP server');
|
|
220
248
|
await startServer();
|
|
221
|
-
const result = await generateTextWithToolsAI(prompt,
|
|
249
|
+
const result = await generateTextWithToolsAI(prompt, options);
|
|
222
250
|
return result.text;
|
|
223
251
|
}
|
|
224
252
|
|
|
225
253
|
try {
|
|
226
|
-
debugAPI('Making API call to %s with model %s',
|
|
254
|
+
debugAPI('Making API call to %s with model %s', resolvedProvider, model);
|
|
227
255
|
debugAPI('Request params - temperature: %d, maxTokens: %d', finalOptions.temperature, finalOptions.maxTokens);
|
|
228
256
|
|
|
229
257
|
const result = await generateText({
|
|
@@ -236,22 +264,19 @@ export async function generateTextWithAI(
|
|
|
236
264
|
debugAPI('API call successful - response length: %d, usage: %o', result.text.length, result.usage);
|
|
237
265
|
return result.text;
|
|
238
266
|
} catch (error) {
|
|
239
|
-
extractAndLogError(error,
|
|
267
|
+
extractAndLogError(error, resolvedProvider, 'generateTextWithAI');
|
|
240
268
|
throw error;
|
|
241
269
|
}
|
|
242
270
|
}
|
|
243
271
|
|
|
244
|
-
export async function* streamTextWithAI(
|
|
245
|
-
|
|
246
|
-
provider:
|
|
247
|
-
options: AIOptions = {},
|
|
248
|
-
): AsyncGenerator<string> {
|
|
249
|
-
debugStream('streamTextWithAI called - provider: %s, promptLength: %d', provider, prompt.length);
|
|
272
|
+
export async function* streamTextWithAI(prompt: string, options: AIOptions = {}): AsyncGenerator<string> {
|
|
273
|
+
const resolvedProvider = options.provider ?? getDefaultAIProvider();
|
|
274
|
+
debugStream('streamTextWithAI called - provider: %s, promptLength: %d', resolvedProvider, prompt.length);
|
|
250
275
|
const finalOptions = { ...defaultOptions, ...options };
|
|
251
|
-
const model = getModel(
|
|
276
|
+
const model = getModel(resolvedProvider, finalOptions.model);
|
|
252
277
|
|
|
253
278
|
try {
|
|
254
|
-
debugStream('Starting stream from %s',
|
|
279
|
+
debugStream('Starting stream from %s', resolvedProvider);
|
|
255
280
|
const stream = await streamText({
|
|
256
281
|
model,
|
|
257
282
|
prompt,
|
|
@@ -269,7 +294,7 @@ export async function* streamTextWithAI(
|
|
|
269
294
|
}
|
|
270
295
|
debugStream('Stream completed - total chunks: %d, total length: %d', totalChunks, totalLength);
|
|
271
296
|
} catch (error) {
|
|
272
|
-
extractAndLogError(error,
|
|
297
|
+
extractAndLogError(error, resolvedProvider, 'streamTextWithAI');
|
|
273
298
|
throw error;
|
|
274
299
|
}
|
|
275
300
|
}
|
|
@@ -279,16 +304,13 @@ export async function* streamTextWithAI(
|
|
|
279
304
|
* Optionally calls a stream callback for each token if provided.
|
|
280
305
|
* Always returns the complete collected response.
|
|
281
306
|
*/
|
|
282
|
-
export async function generateTextStreamingWithAI(
|
|
283
|
-
|
|
284
|
-
provider:
|
|
285
|
-
options: AIOptions = {},
|
|
286
|
-
): Promise<string> {
|
|
287
|
-
debugStream('generateTextStreamingWithAI called - provider: %s', provider);
|
|
307
|
+
export async function generateTextStreamingWithAI(prompt: string, options: AIOptions = {}): Promise<string> {
|
|
308
|
+
const resolvedProvider = options.provider ?? getDefaultAIProvider();
|
|
309
|
+
debugStream('generateTextStreamingWithAI called - provider: %s', resolvedProvider);
|
|
288
310
|
const finalOptions = { ...defaultOptions, ...options };
|
|
289
311
|
let collectedResult = '';
|
|
290
312
|
|
|
291
|
-
const stream = streamTextWithAI(prompt,
|
|
313
|
+
const stream = streamTextWithAI(prompt, finalOptions);
|
|
292
314
|
|
|
293
315
|
let tokenCount = 0;
|
|
294
316
|
for await (const token of stream) {
|
|
@@ -386,13 +408,13 @@ async function executeToolConversation(
|
|
|
386
408
|
|
|
387
409
|
export async function generateTextWithToolsAI(
|
|
388
410
|
prompt: string,
|
|
389
|
-
provider: AIProvider,
|
|
390
411
|
options: AIOptions = {},
|
|
391
412
|
): Promise<{ text: string; toolCalls?: unknown[] }> {
|
|
392
|
-
|
|
413
|
+
const resolvedProvider = options.provider ?? getDefaultAIProvider();
|
|
414
|
+
debugTools('generateTextWithToolsAI called - provider: %s', resolvedProvider);
|
|
393
415
|
const finalOptions = { ...defaultOptions, ...options };
|
|
394
|
-
const model = finalOptions.model ?? getDefaultModel(
|
|
395
|
-
const modelInstance = getModel(
|
|
416
|
+
const model = finalOptions.model ?? getDefaultModel(resolvedProvider);
|
|
417
|
+
const modelInstance = getModel(resolvedProvider, model);
|
|
396
418
|
|
|
397
419
|
const registeredTools = finalOptions.includeTools === true ? getRegisteredToolsForAI() : {};
|
|
398
420
|
debugTools('Registered tools: %o', Object.keys(registeredTools));
|
|
@@ -408,7 +430,7 @@ export async function generateTextWithToolsAI(
|
|
|
408
430
|
registeredTools,
|
|
409
431
|
hasTools,
|
|
410
432
|
finalOptions,
|
|
411
|
-
|
|
433
|
+
resolvedProvider,
|
|
412
434
|
);
|
|
413
435
|
|
|
414
436
|
return {
|
|
@@ -450,26 +472,26 @@ async function executeToolCalls(
|
|
|
450
472
|
export async function generateTextWithImageAI(
|
|
451
473
|
text: string,
|
|
452
474
|
imageBase64: string,
|
|
453
|
-
provider: AIProvider,
|
|
454
475
|
options: AIOptions = {},
|
|
455
476
|
): Promise<string> {
|
|
477
|
+
const resolvedProvider = options.provider ?? getDefaultAIProvider();
|
|
456
478
|
debugAPI(
|
|
457
479
|
'generateTextWithImageAI called - provider: %s, textLength: %d, imageSize: %d',
|
|
458
|
-
|
|
480
|
+
resolvedProvider,
|
|
459
481
|
text.length,
|
|
460
482
|
imageBase64.length,
|
|
461
483
|
);
|
|
462
484
|
const finalOptions = { ...defaultOptions, ...options };
|
|
463
|
-
const model = finalOptions.model ?? getDefaultModel(
|
|
464
|
-
const modelInstance = getModel(
|
|
485
|
+
const model = finalOptions.model ?? getDefaultModel(resolvedProvider);
|
|
486
|
+
const modelInstance = getModel(resolvedProvider, model);
|
|
465
487
|
|
|
466
|
-
if (
|
|
467
|
-
debugError('Provider %s does not support image inputs',
|
|
468
|
-
throw new Error(`Provider ${
|
|
488
|
+
if (resolvedProvider !== AIProvider.OpenAI && resolvedProvider !== AIProvider.XAI) {
|
|
489
|
+
debugError('Provider %s does not support image inputs', resolvedProvider);
|
|
490
|
+
throw new Error(`Provider ${resolvedProvider} does not support image inputs`);
|
|
469
491
|
}
|
|
470
492
|
|
|
471
493
|
try {
|
|
472
|
-
debugAPI('Sending image+text to %s',
|
|
494
|
+
debugAPI('Sending image+text to %s', resolvedProvider);
|
|
473
495
|
const result = await generateText({
|
|
474
496
|
model: modelInstance,
|
|
475
497
|
messages: [
|
|
@@ -488,7 +510,7 @@ export async function generateTextWithImageAI(
|
|
|
488
510
|
debugAPI('Image API call successful - response length: %d', result.text.length);
|
|
489
511
|
return result.text;
|
|
490
512
|
} catch (error) {
|
|
491
|
-
extractAndLogError(error,
|
|
513
|
+
extractAndLogError(error, resolvedProvider, 'generateTextWithImageAI');
|
|
492
514
|
throw error;
|
|
493
515
|
}
|
|
494
516
|
}
|
|
@@ -496,10 +518,11 @@ export async function generateTextWithImageAI(
|
|
|
496
518
|
/**
 * List the providers that currently have configuration.
 *
 * Order matters: `getDefaultAIProvider` falls back to the first entry, so Anthropic is listed
 * first, followed by OpenAI, Google, xAI and finally the custom provider.
 *
 * @returns Configured providers in preference order (possibly empty).
 */
export function getAvailableProviders(): AIProvider[] {
  const config = configureAIProvider();
  const candidates: ReadonlyArray<readonly [AIProvider, unknown]> = [
    [AIProvider.Anthropic, config.anthropic],
    [AIProvider.OpenAI, config.openai],
    [AIProvider.Google, config.google],
    [AIProvider.XAI, config.xai],
    [AIProvider.Custom, config.custom],
  ];

  const providers = candidates.filter(([, settings]) => settings != null).map(([provider]) => provider);

  debugConfig('Available providers: %o', providers);
  return providers;
}
|
|
@@ -626,12 +649,13 @@ async function attemptStructuredGeneration<T>(
|
|
|
626
649
|
throw lastError;
|
|
627
650
|
}
|
|
628
651
|
|
|
629
|
-
export async function generateStructuredDataWithAI<T>(
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
652
|
+
export async function generateStructuredDataWithAI<T>(prompt: string, options: StructuredAIOptions<T>): Promise<T> {
|
|
653
|
+
const resolvedProvider = options.provider ?? getDefaultAIProvider();
|
|
654
|
+
debugAPI(
|
|
655
|
+
'generateStructuredDataWithAI called - provider: %s, schema: %s',
|
|
656
|
+
resolvedProvider,
|
|
657
|
+
options.schemaName ?? 'unnamed',
|
|
658
|
+
);
|
|
635
659
|
|
|
636
660
|
if (options.includeTools === true) {
|
|
637
661
|
debugTools('Tools requested, starting MCP server');
|
|
@@ -641,17 +665,14 @@ export async function generateStructuredDataWithAI<T>(
|
|
|
641
665
|
debugTools('Registered tools for structured data: %o', Object.keys(registeredTools));
|
|
642
666
|
const hasTools = Object.keys(registeredTools).length > 0;
|
|
643
667
|
|
|
644
|
-
return attemptStructuredGeneration(prompt,
|
|
668
|
+
return attemptStructuredGeneration(prompt, resolvedProvider, options, registeredTools, hasTools);
|
|
645
669
|
}
|
|
646
670
|
|
|
647
|
-
export async function streamStructuredDataWithAI<T>(
|
|
648
|
-
|
|
649
|
-
provider: AIProvider,
|
|
650
|
-
options: StreamStructuredAIOptions<T>,
|
|
651
|
-
): Promise<T> {
|
|
671
|
+
export async function streamStructuredDataWithAI<T>(prompt: string, options: StreamStructuredAIOptions<T>): Promise<T> {
|
|
672
|
+
const resolvedProvider = options.provider ?? getDefaultAIProvider();
|
|
652
673
|
debugStream(
|
|
653
674
|
'streamStructuredDataWithAI called - provider: %s, schema: %s',
|
|
654
|
-
|
|
675
|
+
resolvedProvider,
|
|
655
676
|
options.schemaName ?? 'unnamed',
|
|
656
677
|
);
|
|
657
678
|
const maxSchemaRetries = 3;
|
|
@@ -660,7 +681,7 @@ export async function streamStructuredDataWithAI<T>(
|
|
|
660
681
|
for (let attempt = 0; attempt < maxSchemaRetries; attempt++) {
|
|
661
682
|
try {
|
|
662
683
|
debugValidation('Stream structured data attempt %d/%d', attempt + 1, maxSchemaRetries);
|
|
663
|
-
const model = getModel(
|
|
684
|
+
const model = getModel(resolvedProvider, options.model);
|
|
664
685
|
|
|
665
686
|
const enhancedPrompt = attempt > 0 && lastError ? getEnhancedPrompt(prompt, lastError) : prompt;
|
|
666
687
|
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
2
|
+
import { createCustomProvider } from './custom';
|
|
3
|
+
import { CustomProviderConfig } from '../constants';
|
|
4
|
+
|
|
5
|
+
/**
 * Options object received by the mocked `createOpenAI`.
 * Uses the SDK's `baseURL` spelling, which is what `createCustomProvider` passes through.
 */
interface MockConfig {
  apiKey: string;
  baseURL: string;
  name: string;
}
|
|
10
|
+
|
|
11
|
+
// Mock the createOpenAI function
|
|
12
|
+
// Mock the createOpenAI function.
// Each call records its options and returns a provider whose languageModel(id)
// echoes the model id and the configured provider name. vitest hoists vi.mock,
// so the factory only uses names it defines itself.
vi.mock('@ai-sdk/openai', () => {
  const buildLanguageModel = (providerName: string) =>
    vi.fn((modelId: string) => ({
      modelId,
      provider: providerName,
      specificationVersion: 'v1' as const,
      defaultObjectGenerationMode: 'json' as const,
    }));

  return {
    createOpenAI: vi.fn((config: MockConfig) => ({
      languageModel: buildLanguageModel(config.name),
    })),
  };
});
|
|
22
|
+
|
|
23
|
+
describe('Custom Provider', () => {
  const mockConfig: CustomProviderConfig = {
    name: 'test-provider',
    baseUrl: 'https://api.example.com/v1',
    apiKey: 'test-api-key',
    defaultModel: 'test-model',
  };

  // Resolves the module-mocked createOpenAI with mock typings so tests can inspect its calls.
  const loadCreateOpenAIMock = async () => vi.mocked((await import('@ai-sdk/openai')).createOpenAI);

  beforeEach(() => {
    // Reset call history so call-count and last-call assertions are isolated per test.
    vi.clearAllMocks();
  });

  describe('createCustomProvider', () => {
    it('should create a custom provider with the correct configuration', async () => {
      const createOpenAI = await loadCreateOpenAIMock();

      const provider = createCustomProvider(mockConfig);

      expect(provider).toBeDefined();
      // createCustomProvider must return exactly the provider createOpenAI built.
      expect(createOpenAI).toHaveBeenCalledTimes(1);
      expect(provider).toBe(createOpenAI.mock.results[0]?.value);
    });

    it('should call createOpenAI with the correct parameters', async () => {
      const createOpenAI = await loadCreateOpenAIMock();

      createCustomProvider(mockConfig);

      // Our `baseUrl` is mapped to the SDK's `baseURL`; `defaultModel` is not forwarded.
      expect(createOpenAI).toHaveBeenCalledWith({
        name: mockConfig.name,
        baseURL: mockConfig.baseUrl,
        apiKey: mockConfig.apiKey,
      });
    });

    it('should create a provider that can create language models', () => {
      const provider = createCustomProvider(mockConfig);
      const model = provider.languageModel('test-model');

      expect(model).toBeDefined();
      expect(model.modelId).toBe('test-model');
      expect(model.provider).toBe(mockConfig.name);
    });

    it('should handle different base URLs correctly', async () => {
      const createOpenAI = await loadCreateOpenAIMock();
      const baseUrls = ['https://api.litellm.ai', 'http://localhost:8000', 'https://custom-llm.company.com/api'];

      for (const baseUrl of baseUrls) {
        const provider = createCustomProvider({ ...mockConfig, baseUrl });
        expect(provider).toBeDefined();
        // Each URL must reach the SDK unchanged.
        expect(createOpenAI).toHaveBeenLastCalledWith(expect.objectContaining({ baseURL: baseUrl }));
      }
      expect(createOpenAI).toHaveBeenCalledTimes(baseUrls.length);
    });

    it('should handle different provider names', async () => {
      const createOpenAI = await loadCreateOpenAIMock();
      const names = ['litellm', 'local-llm', 'company-custom-llm'];

      for (const name of names) {
        const provider = createCustomProvider({ ...mockConfig, name });
        expect(provider).toBeDefined();
        expect(createOpenAI).toHaveBeenLastCalledWith(expect.objectContaining({ name }));
        // Models created by the provider report the configured name.
        expect(provider.languageModel('any-model').provider).toBe(name);
      }
      expect(createOpenAI).toHaveBeenCalledTimes(names.length);
    });
  });

  describe('provider compatibility', () => {
    it('should be compatible with OpenAI-style endpoints', async () => {
      const createOpenAI = await loadCreateOpenAIMock();
      // `baseURL` is a URL prefix for the OpenAI provider (its default is https://api.openai.com/v1);
      // request paths such as /chat/completions are appended to it, so configure the API root,
      // not a full endpoint URL.
      const litellmConfig: CustomProviderConfig = {
        name: 'litellm',
        baseUrl: 'https://api.litellm.ai/v1',
        apiKey: 'sk-litellm-key',
        defaultModel: 'claude-3-sonnet',
      };

      const provider = createCustomProvider(litellmConfig);

      expect(provider).toBeDefined();
      expect(createOpenAI).toHaveBeenCalledWith({
        name: 'litellm',
        baseURL: 'https://api.litellm.ai/v1',
        apiKey: 'sk-litellm-key',
      });
    });

    it('should work with localhost endpoints for development', async () => {
      const createOpenAI = await loadCreateOpenAIMock();
      const localConfig: CustomProviderConfig = {
        name: 'local-dev',
        baseUrl: 'http://localhost:8000',
        apiKey: 'local-key',
        defaultModel: 'local-model',
      };

      const provider = createCustomProvider(localConfig);

      expect(provider).toBeDefined();
      // Plain-HTTP localhost URLs are passed through untouched.
      expect(createOpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: 'http://localhost:8000' }));
    });
  });

  describe('error handling', () => {
    it('should pass through any errors from createOpenAI', async () => {
      const createOpenAI = await loadCreateOpenAIMock();
      const mockError = new Error('Invalid API key');

      createOpenAI.mockImplementationOnce(() => {
        throw mockError;
      });

      expect(() => createCustomProvider(mockConfig)).toThrow('Invalid API key');
    });
  });

  describe('configuration validation', () => {
    it('should handle minimal required configuration', async () => {
      const createOpenAI = await loadCreateOpenAIMock();
      const minimalConfig: CustomProviderConfig = {
        name: 'minimal',
        baseUrl: 'https://api.example.com',
        apiKey: 'key',
        defaultModel: 'model',
      };

      const provider = createCustomProvider(minimalConfig);

      expect(provider).toBeDefined();
      expect(createOpenAI).toHaveBeenCalledWith({
        name: 'minimal',
        baseURL: 'https://api.example.com',
        apiKey: 'key',
      });
    });

    it('should preserve all configuration properties', async () => {
      const fullConfig: CustomProviderConfig = {
        name: 'full-config',
        baseUrl: 'https://api.example.com/v1',
        apiKey: 'sk-test-key-123',
        defaultModel: 'gpt-4o',
      };

      createCustomProvider(fullConfig);

      const createOpenAI = await loadCreateOpenAIMock();
      expect(createOpenAI).toHaveBeenCalledWith({
        name: fullConfig.name,
        baseURL: fullConfig.baseUrl,
        apiKey: fullConfig.apiKey,
      });
    });
  });
});
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import { createOpenAI } from '@ai-sdk/openai';
|
|
2
|
+
import { CustomProviderConfig } from '../constants';
|
|
3
|
+
import createDebug from 'debug';
|
|
4
|
+
|
|
5
|
+
// Debug logger namespaced 'ai-gateway:custom'; used by createCustomProvider to trace provider creation.
const debug = createDebug('ai-gateway:custom');
|
|
6
|
+
|
|
7
|
+
/**
 * Options wrapper carrying a custom provider's configuration.
 * NOTE(review): not referenced by `createCustomProvider`, which takes the config directly;
 * kept because it is exported.
 */
export interface CustomProviderOptions {
  /** Custom provider settings (name, base URL, API key, default model). */
  config: CustomProviderConfig;
}
|
|
10
|
+
|
|
11
|
+
/**
 * Build an AI SDK provider for a user-configured, OpenAI-compatible endpoint.
 *
 * The provider comes from `createOpenAI` (`@ai-sdk/openai`):
 * - `config.name` becomes the provider name.
 * - `config.baseUrl` is passed as `baseURL`.
 * - `config.apiKey` is passed as `apiKey`.
 *
 * `config.defaultModel` is not used here.
 *
 * @param config - Custom provider settings.
 * @returns The provider returned by `createOpenAI`; call `.languageModel(id)` on it to get a model.
 */
export function createCustomProvider(config: CustomProviderConfig) {
  const { name, baseUrl, apiKey } = config;
  debug('Creating custom provider: %s with baseUrl: %s', name, baseUrl);

  // Reuse the existing OpenAI provider implementation; only the endpoint and credentials differ.
  const provider = createOpenAI({ name, baseURL: baseUrl, apiKey });

  debug('Custom provider created successfully: %s', name);
  return provider;
}
|