promptfoo 0.18.1 → 0.18.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/package.json +1 -1
- package/dist/src/assertions.d.ts +2 -2
- package/dist/src/assertions.d.ts.map +1 -1
- package/dist/src/assertions.js +42 -11
- package/dist/src/assertions.js.map +1 -1
- package/dist/src/cache.d.ts +1 -1
- package/dist/src/cache.d.ts.map +1 -1
- package/dist/src/cache.js +4 -4
- package/dist/src/cache.js.map +1 -1
- package/dist/src/evaluator.d.ts.map +1 -1
- package/dist/src/evaluator.js +5 -2
- package/dist/src/evaluator.js.map +1 -1
- package/dist/src/main.js +4 -4
- package/dist/src/main.js.map +1 -1
- package/dist/src/providers/azureopenai.d.ts +2 -2
- package/dist/src/providers/azureopenai.d.ts.map +1 -1
- package/dist/src/providers/azureopenai.js +7 -5
- package/dist/src/providers/azureopenai.js.map +1 -1
- package/dist/src/providers/llama.js +1 -1
- package/dist/src/providers/llama.js.map +1 -1
- package/dist/src/providers/localai.js +2 -2
- package/dist/src/providers/localai.js.map +1 -1
- package/dist/src/providers/ollama.d.ts +9 -0
- package/dist/src/providers/ollama.d.ts.map +1 -0
- package/dist/src/providers/ollama.js +66 -0
- package/dist/src/providers/ollama.js.map +1 -0
- package/dist/src/providers/openai.d.ts +2 -2
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/openai.js +7 -5
- package/dist/src/providers/openai.js.map +1 -1
- package/dist/src/providers.d.ts.map +1 -1
- package/dist/src/providers.js +11 -5
- package/dist/src/providers.js.map +1 -1
- package/dist/src/types.d.ts +6 -2
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/util.d.ts +2 -0
- package/dist/src/util.d.ts.map +1 -1
- package/dist/src/util.js +24 -12
- package/dist/src/util.js.map +1 -1
- package/dist/src/web/client/assets/index-6d2a3573.js +200 -0
- package/dist/src/web/client/index.html +1 -1
- package/package.json +1 -1
- package/src/assertions.ts +45 -11
- package/src/cache.ts +3 -2
- package/src/evaluator.ts +5 -1
- package/src/main.ts +4 -4
- package/src/providers/azureopenai.ts +18 -6
- package/src/providers/llama.ts +2 -2
- package/src/providers/localai.ts +3 -3
- package/src/providers/ollama.ts +88 -0
- package/src/providers/openai.ts +8 -6
- package/src/providers.ts +20 -5
- package/src/types.ts +6 -2
- package/src/util.ts +25 -17
- package/src/web/client/package-lock.json +5726 -0
- package/src/web/client/src/EvalOutputPromptDialog.tsx +78 -16
- package/src/web/client/src/ResultsTable.tsx +32 -9
- package/src/web/client/src/ResultsView.tsx +1 -1
- package/src/web/client/src/types.ts +3 -1
- package/dist/src/web/client/assets/index-8388d689.js +0 -199
package/dist/src/web/client/index.html
CHANGED
|
@@ -5,7 +5,7 @@
|
|
|
5
5
|
<link rel="icon" type="image/svg+xml" href="favicon.ico" />
|
|
6
6
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
|
7
7
|
<title>promptfoo web viewer</title>
|
|
8
|
-
<script type="module" crossorigin src="/assets/index-8388d689.js"></script>
|
|
8
|
+
<script type="module" crossorigin src="/assets/index-6d2a3573.js"></script>
|
|
9
9
|
<link rel="stylesheet" href="/assets/index-d2b6a160.css">
|
|
10
10
|
</head>
|
|
11
11
|
<body>
|
package/package.json
CHANGED
package/src/assertions.ts
CHANGED
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
import rouge from 'rouge';
|
|
2
2
|
import invariant from 'tiny-invariant';
|
|
3
|
-
import nunjucks from 'nunjucks';
|
|
4
3
|
|
|
5
4
|
import telemetry from './telemetry';
|
|
6
5
|
import { DefaultEmbeddingProvider, DefaultGradingProvider } from './providers/openai';
|
|
7
|
-
import { cosineSimilarity, fetchWithRetries } from './util';
|
|
6
|
+
import { cosineSimilarity, fetchWithRetries, getNunjucksEngine } from './util';
|
|
8
7
|
import { loadApiProvider } from './providers';
|
|
9
8
|
import { DEFAULT_GRADING_PROMPT } from './prompts';
|
|
10
9
|
|
|
@@ -18,6 +17,8 @@ import type {
|
|
|
18
17
|
|
|
19
18
|
const DEFAULT_SEMANTIC_SIMILARITY_THRESHOLD = 0.8;
|
|
20
19
|
|
|
20
|
+
const nunjucks = getNunjucksEngine();
|
|
21
|
+
|
|
21
22
|
function handleRougeScore(
|
|
22
23
|
baseType: 'rouge-n',
|
|
23
24
|
assertion: Assertion,
|
|
@@ -40,6 +41,7 @@ function handleRougeScore(
|
|
|
40
41
|
: `${baseType.toUpperCase()} score ${score} is less than threshold ${
|
|
41
42
|
assertion.threshold || 0.75
|
|
42
43
|
}`,
|
|
44
|
+
assertion,
|
|
43
45
|
};
|
|
44
46
|
}
|
|
45
47
|
|
|
@@ -51,17 +53,23 @@ export async function runAssertions(test: AtomicTestCase, output: string): Promi
|
|
|
51
53
|
};
|
|
52
54
|
|
|
53
55
|
if (!test.assert || test.assert.length < 1) {
|
|
54
|
-
return { pass: true, score: 1, reason: 'No assertions', tokensUsed };
|
|
56
|
+
return { pass: true, score: 1, reason: 'No assertions', tokensUsed, assertion: null };
|
|
55
57
|
}
|
|
56
58
|
|
|
57
59
|
let totalScore = 0;
|
|
58
60
|
let totalWeight = 0;
|
|
61
|
+
let allPass = true;
|
|
62
|
+
let failedReason = '';
|
|
63
|
+
const componentResults: GradingResult[] = [];
|
|
64
|
+
|
|
59
65
|
for (const assertion of test.assert) {
|
|
60
66
|
const weight = assertion.weight || 1;
|
|
61
67
|
totalWeight += weight;
|
|
62
68
|
|
|
63
69
|
const result = await runAssertion(assertion, test, output);
|
|
64
70
|
totalScore += result.score * weight;
|
|
71
|
+
componentResults.push(result);
|
|
72
|
+
|
|
65
73
|
if (result.tokensUsed) {
|
|
66
74
|
tokensUsed.total += result.tokensUsed.total;
|
|
67
75
|
tokensUsed.prompt += result.tokensUsed.prompt;
|
|
@@ -69,16 +77,21 @@ export async function runAssertions(test: AtomicTestCase, output: string): Promi
|
|
|
69
77
|
}
|
|
70
78
|
|
|
71
79
|
if (!result.pass) {
|
|
72
|
-
|
|
73
|
-
|
|
80
|
+
allPass = false;
|
|
81
|
+
failedReason = result.reason;
|
|
82
|
+
if (process.env.PROMPTFOO_SHORT_CIRCUIT_TEST_FAILURES) {
|
|
83
|
+
return result;
|
|
84
|
+
}
|
|
74
85
|
}
|
|
75
86
|
}
|
|
76
87
|
|
|
77
88
|
return {
|
|
78
|
-
pass: true,
|
|
89
|
+
pass: allPass,
|
|
79
90
|
score: totalScore / totalWeight,
|
|
80
|
-
reason: 'All assertions passed',
|
|
91
|
+
reason: allPass ? 'All assertions passed' : failedReason,
|
|
81
92
|
tokensUsed,
|
|
93
|
+
componentResults,
|
|
94
|
+
assertion: null,
|
|
82
95
|
};
|
|
83
96
|
}
|
|
84
97
|
|
|
@@ -114,6 +127,7 @@ export async function runAssertion(
|
|
|
114
127
|
pass,
|
|
115
128
|
score: pass ? 1 : 0,
|
|
116
129
|
reason: pass ? 'Assertion passed' : `Expected output "${renderedValue}"`,
|
|
130
|
+
assertion,
|
|
117
131
|
};
|
|
118
132
|
}
|
|
119
133
|
|
|
@@ -128,6 +142,7 @@ export async function runAssertion(
|
|
|
128
142
|
pass,
|
|
129
143
|
score: pass ? 1 : 0,
|
|
130
144
|
reason: pass ? 'Assertion passed' : 'Expected output to be valid JSON',
|
|
145
|
+
assertion,
|
|
131
146
|
};
|
|
132
147
|
}
|
|
133
148
|
|
|
@@ -144,6 +159,7 @@ export async function runAssertion(
|
|
|
144
159
|
reason: pass
|
|
145
160
|
? 'Assertion passed'
|
|
146
161
|
: `Expected output to ${inverse ? 'not ' : ''}contain "${renderedValue}"`,
|
|
162
|
+
assertion,
|
|
147
163
|
};
|
|
148
164
|
}
|
|
149
165
|
|
|
@@ -160,6 +176,7 @@ export async function runAssertion(
|
|
|
160
176
|
reason: pass
|
|
161
177
|
? 'Assertion passed'
|
|
162
178
|
: `Expected output to ${inverse ? 'not ' : ''}contain one of "${renderedValue.join(', ')}"`,
|
|
179
|
+
assertion,
|
|
163
180
|
};
|
|
164
181
|
}
|
|
165
182
|
|
|
@@ -176,6 +193,7 @@ export async function runAssertion(
|
|
|
176
193
|
reason: pass
|
|
177
194
|
? 'Assertion passed'
|
|
178
195
|
: `Expected output to ${inverse ? 'not ' : ''}contain all of "${renderedValue.join(', ')}"`,
|
|
196
|
+
assertion,
|
|
179
197
|
};
|
|
180
198
|
}
|
|
181
199
|
|
|
@@ -193,6 +211,7 @@ export async function runAssertion(
|
|
|
193
211
|
reason: pass
|
|
194
212
|
? 'Assertion passed'
|
|
195
213
|
: `Expected output to ${inverse ? 'not ' : ''}match regex "${renderedValue}"`,
|
|
214
|
+
assertion,
|
|
196
215
|
};
|
|
197
216
|
}
|
|
198
217
|
|
|
@@ -209,6 +228,7 @@ export async function runAssertion(
|
|
|
209
228
|
reason: pass
|
|
210
229
|
? 'Assertion passed'
|
|
211
230
|
: `Expected output to ${inverse ? 'not ' : ''}contain "${renderedValue}"`,
|
|
231
|
+
assertion,
|
|
212
232
|
};
|
|
213
233
|
}
|
|
214
234
|
|
|
@@ -225,6 +245,7 @@ export async function runAssertion(
|
|
|
225
245
|
reason: pass
|
|
226
246
|
? 'Assertion passed'
|
|
227
247
|
: `Expected output to ${inverse ? 'not ' : ''}start with "${renderedValue}"`,
|
|
248
|
+
assertion,
|
|
228
249
|
};
|
|
229
250
|
}
|
|
230
251
|
|
|
@@ -236,6 +257,7 @@ export async function runAssertion(
|
|
|
236
257
|
reason: pass
|
|
237
258
|
? 'Assertion passed'
|
|
238
259
|
: `Expected output to ${inverse ? 'not ' : ''}contain valid JSON`,
|
|
260
|
+
assertion,
|
|
239
261
|
};
|
|
240
262
|
}
|
|
241
263
|
|
|
@@ -265,6 +287,7 @@ export async function runAssertion(
|
|
|
265
287
|
score: 0,
|
|
266
288
|
reason: `Custom function threw error: ${(err as Error).message}
|
|
267
289
|
${renderedValue}`,
|
|
290
|
+
assertion,
|
|
268
291
|
};
|
|
269
292
|
}
|
|
270
293
|
return {
|
|
@@ -274,6 +297,7 @@ ${renderedValue}`,
|
|
|
274
297
|
? 'Assertion passed'
|
|
275
298
|
: `Custom function returned ${inverse ? 'true' : 'false'}
|
|
276
299
|
${renderedValue}`,
|
|
300
|
+
assertion,
|
|
277
301
|
};
|
|
278
302
|
}
|
|
279
303
|
|
|
@@ -309,6 +333,7 @@ ${renderedValue}`,
|
|
|
309
333
|
pass: false,
|
|
310
334
|
score: 0,
|
|
311
335
|
reason: `Python code execution failed: ${(err as Error).message}`,
|
|
336
|
+
assertion,
|
|
312
337
|
};
|
|
313
338
|
}
|
|
314
339
|
return {
|
|
@@ -318,6 +343,7 @@ ${renderedValue}`,
|
|
|
318
343
|
? 'Assertion passed'
|
|
319
344
|
: `Python code returned ${pass ? 'true' : 'false'}
|
|
320
345
|
${assertion.value}`,
|
|
346
|
+
assertion,
|
|
321
347
|
};
|
|
322
348
|
}
|
|
323
349
|
|
|
@@ -327,7 +353,10 @@ ${assertion.value}`,
|
|
|
327
353
|
typeof renderedValue === 'string',
|
|
328
354
|
'"contains" assertion type must have a string value',
|
|
329
355
|
);
|
|
330
|
-
return matchesSimilarity(renderedValue, output, assertion.threshold || 0.75, inverse);
|
|
356
|
+
return {
|
|
357
|
+
assertion,
|
|
358
|
+
...(await matchesSimilarity(renderedValue, output, assertion.threshold || 0.75, inverse)),
|
|
359
|
+
};
|
|
331
360
|
}
|
|
332
361
|
|
|
333
362
|
if (baseType === 'llm-rubric') {
|
|
@@ -336,7 +365,10 @@ ${assertion.value}`,
|
|
|
336
365
|
typeof renderedValue === 'string',
|
|
337
366
|
'"contains" assertion type must have a string value',
|
|
338
367
|
);
|
|
339
|
-
return matchesLlmRubric(renderedValue, output, test.options);
|
|
368
|
+
return {
|
|
369
|
+
assertion,
|
|
370
|
+
...(await matchesLlmRubric(renderedValue, output, test.options)),
|
|
371
|
+
};
|
|
340
372
|
}
|
|
341
373
|
|
|
342
374
|
if (baseType === 'webhook') {
|
|
@@ -378,6 +410,7 @@ ${assertion.value}`,
|
|
|
378
410
|
pass: false,
|
|
379
411
|
score: 0,
|
|
380
412
|
reason: `Webhook error: ${(err as Error).message}`,
|
|
413
|
+
assertion,
|
|
381
414
|
};
|
|
382
415
|
}
|
|
383
416
|
|
|
@@ -385,6 +418,7 @@ ${assertion.value}`,
|
|
|
385
418
|
pass,
|
|
386
419
|
score,
|
|
387
420
|
reason: pass ? 'Assertion passed' : `Webhook returned ${inverse ? 'true' : 'false'}`,
|
|
421
|
+
assertion,
|
|
388
422
|
};
|
|
389
423
|
}
|
|
390
424
|
|
|
@@ -422,7 +456,7 @@ export async function matchesSimilarity(
|
|
|
422
456
|
output: string,
|
|
423
457
|
threshold: number,
|
|
424
458
|
inverse: boolean = false,
|
|
425
|
-
): Promise<GradingResult> {
|
|
459
|
+
): Promise<Omit<GradingResult, 'assertion'>> {
|
|
426
460
|
const expectedEmbedding = await DefaultEmbeddingProvider.callEmbeddingApi(expected);
|
|
427
461
|
const outputEmbedding = await DefaultEmbeddingProvider.callEmbeddingApi(output);
|
|
428
462
|
|
|
@@ -477,7 +511,7 @@ export async function matchesLlmRubric(
|
|
|
477
511
|
expected: string,
|
|
478
512
|
output: string,
|
|
479
513
|
options?: GradingConfig,
|
|
480
|
-
): Promise<GradingResult> {
|
|
514
|
+
): Promise<Omit<GradingResult, 'assertion'>> {
|
|
481
515
|
if (!options) {
|
|
482
516
|
throw new Error(
|
|
483
517
|
'Cannot grade output without grading config. Specify --grader option or grading config.',
|
package/src/cache.ts
CHANGED
|
@@ -42,10 +42,11 @@ export function getCache() {
|
|
|
42
42
|
return cacheInstance;
|
|
43
43
|
}
|
|
44
44
|
|
|
45
|
-
export async function fetchJsonWithCache(
|
|
45
|
+
export async function fetchWithCache(
|
|
46
46
|
url: RequestInfo,
|
|
47
47
|
options: RequestInit = {},
|
|
48
48
|
timeout: number,
|
|
49
|
+
format: 'json' | 'text' = 'json',
|
|
49
50
|
): Promise<{ data: any; cached: boolean }> {
|
|
50
51
|
if (!enabled) {
|
|
51
52
|
const resp = await fetchWithRetries(url, options, timeout);
|
|
@@ -75,7 +76,7 @@ export async function fetchJsonWithCache(
|
|
|
75
76
|
// Fetch the actual data and store it in the cache
|
|
76
77
|
const response = await fetchWithRetries(url, options, timeout);
|
|
77
78
|
try {
|
|
78
|
-
const data = await response.json();
|
|
79
|
+
const data = format === 'json' ? await response.json() : await response.text();
|
|
79
80
|
if (response.ok) {
|
|
80
81
|
logger.debug(`Storing ${url} response in cache: ${JSON.stringify(data)}`);
|
|
81
82
|
await cache.set(cacheKey, JSON.stringify(data));
|
package/src/evaluator.ts
CHANGED
|
@@ -2,13 +2,13 @@ import readline from 'readline';
|
|
|
2
2
|
|
|
3
3
|
import async from 'async';
|
|
4
4
|
import chalk from 'chalk';
|
|
5
|
-
import nunjucks from 'nunjucks';
|
|
6
5
|
import invariant from 'tiny-invariant';
|
|
7
6
|
|
|
8
7
|
import logger from './logger';
|
|
9
8
|
import telemetry from './telemetry';
|
|
10
9
|
import { runAssertions } from './assertions';
|
|
11
10
|
import { generatePrompts } from './suggestions';
|
|
11
|
+
import { getNunjucksEngine } from './util';
|
|
12
12
|
|
|
13
13
|
import type { SingleBar } from 'cli-progress';
|
|
14
14
|
import type {
|
|
@@ -39,6 +39,8 @@ interface RunEvalOptions {
|
|
|
39
39
|
|
|
40
40
|
const DEFAULT_MAX_CONCURRENCY = 4;
|
|
41
41
|
|
|
42
|
+
const nunjucks = getNunjucksEngine();
|
|
43
|
+
|
|
42
44
|
function generateVarCombinations(
|
|
43
45
|
vars: Record<string, string | string[] | any>,
|
|
44
46
|
): Record<string, string | any[]>[] {
|
|
@@ -156,6 +158,7 @@ class Evaluator {
|
|
|
156
158
|
this.stats.tokenUsage.completion += checkResult.tokensUsed.completion;
|
|
157
159
|
}
|
|
158
160
|
ret.response = processedResponse;
|
|
161
|
+
ret.gradingResult = checkResult;
|
|
159
162
|
} else {
|
|
160
163
|
ret.success = false;
|
|
161
164
|
ret.score = 0;
|
|
@@ -464,6 +467,7 @@ class Evaluator {
|
|
|
464
467
|
prompt: row.prompt.raw,
|
|
465
468
|
latencyMs: row.latencyMs,
|
|
466
469
|
tokenUsage: row.response?.tokenUsage,
|
|
470
|
+
gradingResult: row.gradingResult,
|
|
467
471
|
};
|
|
468
472
|
},
|
|
469
473
|
);
|
package/src/main.ts
CHANGED
|
@@ -286,7 +286,7 @@ async function main() {
|
|
|
286
286
|
process.env.PROMPTFOO_DISABLE_SHARING === '1'
|
|
287
287
|
? false
|
|
288
288
|
: fileConfig.sharing ?? defaultConfig.sharing ?? true,
|
|
289
|
-
defaultTest: fileConfig.defaultTest,
|
|
289
|
+
defaultTest: fileConfig.defaultTest || defaultConfig.defaultTest,
|
|
290
290
|
};
|
|
291
291
|
|
|
292
292
|
// Validation
|
|
@@ -312,7 +312,7 @@ async function main() {
|
|
|
312
312
|
cmdObj.tests ? undefined : basePath,
|
|
313
313
|
);
|
|
314
314
|
|
|
315
|
-
//
|
|
315
|
+
// Parse testCases for each scenario
|
|
316
316
|
if (fileConfig.scenarios) {
|
|
317
317
|
for (const scenario of fileConfig.scenarios) {
|
|
318
318
|
const parsedScenarioTests: TestCase[] = await readTests(
|
|
@@ -335,8 +335,8 @@ async function main() {
|
|
|
335
335
|
prefix: cmdObj.promptPrefix,
|
|
336
336
|
suffix: cmdObj.promptSuffix,
|
|
337
337
|
provider: cmdObj.grader,
|
|
338
|
-
// rubricPrompt
|
|
339
|
-
|
|
338
|
+
// rubricPrompt
|
|
339
|
+
...(config.defaultTest?.options || {}),
|
|
340
340
|
},
|
|
341
341
|
...config.defaultTest,
|
|
342
342
|
};
|
package/src/providers/azureopenai.ts
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import logger from '../logger';
|
|
2
|
-
import { fetchJsonWithCache } from '../cache';
|
|
2
|
+
import { fetchWithCache } from '../cache';
|
|
3
3
|
import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
|
|
4
4
|
|
|
5
5
|
import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types.js';
|
|
@@ -61,7 +61,7 @@ export class AzureOpenAiEmbeddingProvider extends AzureOpenAiGenericProvider {
|
|
|
61
61
|
let data,
|
|
62
62
|
cached = false;
|
|
63
63
|
try {
|
|
64
|
-
({ data, cached } = (await fetchJsonWithCache(
|
|
64
|
+
({ data, cached } = (await fetchWithCache(
|
|
65
65
|
`https://${this.apiHost}/openai/deployments/${this.deploymentName}/embeddings?api-version=2023-07-01-preview`,
|
|
66
66
|
{
|
|
67
67
|
method: 'POST',
|
|
@@ -117,9 +117,15 @@ export class AzureOpenAiEmbeddingProvider extends AzureOpenAiGenericProvider {
|
|
|
117
117
|
export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
|
|
118
118
|
options: AzureOpenAiCompletionOptions;
|
|
119
119
|
|
|
120
|
-
constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions) {
|
|
120
|
+
constructor(
|
|
121
|
+
deploymentName: string,
|
|
122
|
+
apiKey?: string,
|
|
123
|
+
context?: AzureOpenAiCompletionOptions,
|
|
124
|
+
id?: string,
|
|
125
|
+
) {
|
|
121
126
|
super(deploymentName, apiKey);
|
|
122
127
|
this.options = context || {};
|
|
128
|
+
this.id = id ? () => id : this.id;
|
|
123
129
|
}
|
|
124
130
|
|
|
125
131
|
async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {
|
|
@@ -165,7 +171,7 @@ export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
|
|
|
165
171
|
let data,
|
|
166
172
|
cached = false;
|
|
167
173
|
try {
|
|
168
|
-
({ data, cached } = (await fetchJsonWithCache(
|
|
174
|
+
({ data, cached } = (await fetchWithCache(
|
|
169
175
|
`https://${this.apiHost}/openai/deployments/${this.deploymentName}/completions?api-version=2023-07-01-preview`,
|
|
170
176
|
{
|
|
171
177
|
method: 'POST',
|
|
@@ -205,9 +211,15 @@ export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
|
|
|
205
211
|
export class AzureOpenAiChatCompletionProvider extends AzureOpenAiGenericProvider {
|
|
206
212
|
options: AzureOpenAiCompletionOptions;
|
|
207
213
|
|
|
208
|
-
constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions) {
|
|
214
|
+
constructor(
|
|
215
|
+
deploymentName: string,
|
|
216
|
+
apiKey?: string,
|
|
217
|
+
context?: AzureOpenAiCompletionOptions,
|
|
218
|
+
id?: string,
|
|
219
|
+
) {
|
|
209
220
|
super(deploymentName, apiKey);
|
|
210
221
|
this.options = context || {};
|
|
222
|
+
this.id = id ? () => id : this.id;
|
|
211
223
|
}
|
|
212
224
|
|
|
213
225
|
async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {
|
|
@@ -246,7 +258,7 @@ export class AzureOpenAiChatCompletionProvider extends AzureOpenAiGenericProvide
|
|
|
246
258
|
let data,
|
|
247
259
|
cached = false;
|
|
248
260
|
try {
|
|
249
|
-
({ data, cached } = (await fetchJsonWithCache(
|
|
261
|
+
({ data, cached } = (await fetchWithCache(
|
|
250
262
|
`https://${this.apiHost}/openai/deployments/${this.deploymentName}/chat/completions?api-version=2023-07-01-preview`,
|
|
251
263
|
{
|
|
252
264
|
method: 'POST',
|
package/src/providers/llama.ts
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { fetchJsonWithCache } from '../cache';
|
|
1
|
+
import { fetchWithCache } from '../cache';
|
|
2
2
|
import { REQUEST_TIMEOUT_MS } from './shared';
|
|
3
3
|
|
|
4
4
|
import type { ApiProvider, ProviderResponse } from '../types.js';
|
|
@@ -65,7 +65,7 @@ export class LlamaProvider implements ApiProvider {
|
|
|
65
65
|
|
|
66
66
|
let response;
|
|
67
67
|
try {
|
|
68
|
-
response = await fetchJsonWithCache(
|
|
68
|
+
response = await fetchWithCache(
|
|
69
69
|
`${process.env.LLAMA_BASE_URL || 'http://localhost:8080'}/completion`,
|
|
70
70
|
{
|
|
71
71
|
method: 'POST',
|
package/src/providers/localai.ts
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import logger from '../logger';
|
|
2
|
-
import { fetchJsonWithCache } from '../cache';
|
|
2
|
+
import { fetchWithCache } from '../cache';
|
|
3
3
|
import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
|
|
4
4
|
|
|
5
5
|
import type { ApiProvider, ProviderResponse } from '../types.js';
|
|
@@ -40,7 +40,7 @@ export class LocalAiChatProvider extends LocalAiGenericProvider {
|
|
|
40
40
|
let data,
|
|
41
41
|
cached = false;
|
|
42
42
|
try {
|
|
43
|
-
({ data, cached } = (await fetchJsonWithCache(
|
|
43
|
+
({ data, cached } = (await fetchWithCache(
|
|
44
44
|
`${this.apiBaseUrl}/chat/completions`,
|
|
45
45
|
{
|
|
46
46
|
method: 'POST',
|
|
@@ -81,7 +81,7 @@ export class LocalAiCompletionProvider extends LocalAiGenericProvider {
|
|
|
81
81
|
let data,
|
|
82
82
|
cached = false;
|
|
83
83
|
try {
|
|
84
|
-
({ data, cached } = (await fetchJsonWithCache(
|
|
84
|
+
({ data, cached } = (await fetchWithCache(
|
|
85
85
|
`${this.apiBaseUrl}/completions`,
|
|
86
86
|
{
|
|
87
87
|
method: 'POST',
|
package/src/providers/ollama.ts
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import logger from '../logger';
|
|
2
|
+
import { fetchWithCache } from '../cache';
|
|
3
|
+
|
|
4
|
+
import type { ApiProvider, ProviderResponse } from '../types.js';
|
|
5
|
+
import { REQUEST_TIMEOUT_MS } from './shared';
|
|
6
|
+
|
|
7
|
+
interface OllamaJsonL {
|
|
8
|
+
model: string;
|
|
9
|
+
created_at: string;
|
|
10
|
+
response?: string;
|
|
11
|
+
done: boolean;
|
|
12
|
+
context?: number[];
|
|
13
|
+
total_duration?: number;
|
|
14
|
+
load_duration?: number;
|
|
15
|
+
sample_count?: number;
|
|
16
|
+
sample_duration?: number;
|
|
17
|
+
prompt_eval_count?: number;
|
|
18
|
+
prompt_eval_duration?: number;
|
|
19
|
+
eval_count?: number;
|
|
20
|
+
eval_duration?: number;
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
export class OllamaProvider implements ApiProvider {
|
|
24
|
+
modelName: string;
|
|
25
|
+
|
|
26
|
+
constructor(modelName: string) {
|
|
27
|
+
this.modelName = modelName;
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
id(): string {
|
|
31
|
+
return `ollama:${this.modelName}`;
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
toString(): string {
|
|
35
|
+
return `[Ollama Provider ${this.modelName}]`;
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
async callApi(prompt: string): Promise<ProviderResponse> {
|
|
39
|
+
const params = {
|
|
40
|
+
model: this.modelName,
|
|
41
|
+
prompt,
|
|
42
|
+
};
|
|
43
|
+
|
|
44
|
+
logger.debug(`Calling Ollama API: ${JSON.stringify(params)}`);
|
|
45
|
+
let response;
|
|
46
|
+
try {
|
|
47
|
+
response = await fetchWithCache(
|
|
48
|
+
`${process.env.OLLAMA_BASE_URL || 'http://localhost:11434'}/api/generate`,
|
|
49
|
+
{
|
|
50
|
+
method: 'POST',
|
|
51
|
+
headers: {
|
|
52
|
+
'Content-Type': 'application/json',
|
|
53
|
+
},
|
|
54
|
+
body: JSON.stringify(params),
|
|
55
|
+
},
|
|
56
|
+
REQUEST_TIMEOUT_MS,
|
|
57
|
+
'text',
|
|
58
|
+
);
|
|
59
|
+
} catch (err) {
|
|
60
|
+
return {
|
|
61
|
+
error: `API call error: ${String(err)}`,
|
|
62
|
+
};
|
|
63
|
+
}
|
|
64
|
+
logger.debug(`\tOllama API response: ${response.data}`);
|
|
65
|
+
|
|
66
|
+
try {
|
|
67
|
+
const output = response.data
|
|
68
|
+
.split('\n')
|
|
69
|
+
.map((line: string) => {
|
|
70
|
+
const parsed = JSON.parse(line) as OllamaJsonL;
|
|
71
|
+
if (parsed.response) {
|
|
72
|
+
return parsed.response;
|
|
73
|
+
}
|
|
74
|
+
return null;
|
|
75
|
+
})
|
|
76
|
+
.filter((s: string | null) => s !== null)
|
|
77
|
+
.join('');
|
|
78
|
+
|
|
79
|
+
return {
|
|
80
|
+
output,
|
|
81
|
+
};
|
|
82
|
+
} catch (err) {
|
|
83
|
+
return {
|
|
84
|
+
error: `API response error: ${String(err)}: ${JSON.stringify(response.data)}`,
|
|
85
|
+
};
|
|
86
|
+
}
|
|
87
|
+
}
|
|
88
|
+
}
|
package/src/providers/openai.ts
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import logger from '../logger';
|
|
2
|
-
import { fetchJsonWithCache } from '../cache';
|
|
2
|
+
import { fetchWithCache } from '../cache';
|
|
3
3
|
import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
|
|
4
4
|
|
|
5
5
|
import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types.js';
|
|
@@ -61,7 +61,7 @@ export class OpenAiEmbeddingProvider extends OpenAiGenericProvider {
|
|
|
61
61
|
let data,
|
|
62
62
|
cached = false;
|
|
63
63
|
try {
|
|
64
|
-
({ data, cached } = (await fetchJsonWithCache(
|
|
64
|
+
({ data, cached } = (await fetchWithCache(
|
|
65
65
|
`https://${this.apiHost}/v1/embeddings`,
|
|
66
66
|
{
|
|
67
67
|
method: 'POST',
|
|
@@ -125,12 +125,13 @@ export class OpenAiCompletionProvider extends OpenAiGenericProvider {
|
|
|
125
125
|
|
|
126
126
|
options: OpenAiCompletionOptions;
|
|
127
127
|
|
|
128
|
-
constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions) {
|
|
128
|
+
constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions, id?: string) {
|
|
129
129
|
if (!OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelName)) {
|
|
130
130
|
logger.warn(`Using unknown OpenAI completion model: ${modelName}`);
|
|
131
131
|
}
|
|
132
132
|
super(modelName, apiKey);
|
|
133
133
|
this.options = context || {};
|
|
134
|
+
this.id = id ? () => id : this.id;
|
|
134
135
|
}
|
|
135
136
|
|
|
136
137
|
async callApi(prompt: string, options?: OpenAiCompletionOptions): Promise<ProviderResponse> {
|
|
@@ -176,7 +177,7 @@ export class OpenAiCompletionProvider extends OpenAiGenericProvider {
|
|
|
176
177
|
let data,
|
|
177
178
|
cached = false;
|
|
178
179
|
try {
|
|
179
|
-
({ data, cached } = (await fetchJsonWithCache(
|
|
180
|
+
({ data, cached } = (await fetchWithCache(
|
|
180
181
|
`https://${this.apiHost}/v1/completions`,
|
|
181
182
|
{
|
|
182
183
|
method: 'POST',
|
|
@@ -229,12 +230,13 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
|
|
|
229
230
|
|
|
230
231
|
options: OpenAiCompletionOptions;
|
|
231
232
|
|
|
232
|
-
constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions) {
|
|
233
|
+
constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions, id?: string) {
|
|
233
234
|
if (!OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.includes(modelName)) {
|
|
234
235
|
logger.warn(`Using unknown OpenAI chat model: ${modelName}`);
|
|
235
236
|
}
|
|
236
237
|
super(modelName, apiKey);
|
|
237
238
|
this.options = context || {};
|
|
239
|
+
this.id = id ? () => id : this.id;
|
|
238
240
|
}
|
|
239
241
|
|
|
240
242
|
async callApi(prompt: string, options?: OpenAiCompletionOptions): Promise<ProviderResponse> {
|
|
@@ -273,7 +275,7 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
|
|
|
273
275
|
let data,
|
|
274
276
|
cached = false;
|
|
275
277
|
try {
|
|
276
|
-
({ data, cached } = (await fetchJsonWithCache(
|
|
278
|
+
({ data, cached } = (await fetchWithCache(
|
|
277
279
|
`https://${this.apiHost}/v1/chat/completions`,
|
|
278
280
|
{
|
|
279
281
|
method: 'POST',
|
package/src/providers.ts
CHANGED
|
@@ -5,6 +5,7 @@ import { AnthropicCompletionProvider } from './providers/anthropic';
|
|
|
5
5
|
import { ReplicateProvider } from './providers/replicate';
|
|
6
6
|
import { LocalAiCompletionProvider, LocalAiChatProvider } from './providers/localai';
|
|
7
7
|
import { LlamaProvider } from './providers/llama';
|
|
8
|
+
import { OllamaProvider } from './providers/ollama';
|
|
8
9
|
import { ScriptCompletionProvider } from './providers/scriptCompletion';
|
|
9
10
|
import {
|
|
10
11
|
AzureOpenAiChatCompletionProvider,
|
|
@@ -44,7 +45,8 @@ export async function loadApiProviders(
|
|
|
44
45
|
};
|
|
45
46
|
} else {
|
|
46
47
|
const id = Object.keys(provider)[0];
|
|
47
|
-
const context = provider[id];
|
|
48
|
+
const providerObject = provider[id];
|
|
49
|
+
const context = { ...providerObject, id: providerObject.id || id };
|
|
48
50
|
return loadApiProvider(id, context, basePath);
|
|
49
51
|
}
|
|
50
52
|
}),
|
|
@@ -84,9 +86,9 @@ export async function loadApiProvider(
|
|
|
84
86
|
context?.config,
|
|
85
87
|
);
|
|
86
88
|
} else if (OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.includes(modelType)) {
|
|
87
|
-
return new OpenAiChatCompletionProvider(modelType, undefined, context?.config);
|
|
89
|
+
return new OpenAiChatCompletionProvider(modelType, undefined, context?.config, context?.id);
|
|
88
90
|
} else if (OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelType)) {
|
|
89
|
-
return new OpenAiCompletionProvider(modelType, undefined, context?.config);
|
|
91
|
+
return new OpenAiCompletionProvider(modelType, undefined, context?.config, context?.id);
|
|
90
92
|
} else {
|
|
91
93
|
throw new Error(
|
|
92
94
|
`Unknown OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
|
|
@@ -99,9 +101,19 @@ export async function loadApiProvider(
|
|
|
99
101
|
const deploymentName = options[2];
|
|
100
102
|
|
|
101
103
|
if (modelType === 'chat') {
|
|
102
|
-
return new AzureOpenAiChatCompletionProvider(deploymentName, undefined, context?.config);
|
|
104
|
+
return new AzureOpenAiChatCompletionProvider(
|
|
105
|
+
deploymentName,
|
|
106
|
+
undefined,
|
|
107
|
+
context?.config,
|
|
108
|
+
context?.id,
|
|
109
|
+
);
|
|
103
110
|
} else if (modelType === 'completion') {
|
|
104
|
-
return new AzureOpenAiCompletionProvider(deploymentName, undefined, context?.config);
|
|
111
|
+
return new AzureOpenAiCompletionProvider(
|
|
112
|
+
deploymentName,
|
|
113
|
+
undefined,
|
|
114
|
+
context?.config,
|
|
115
|
+
context?.id,
|
|
116
|
+
);
|
|
105
117
|
} else {
|
|
106
118
|
throw new Error(
|
|
107
119
|
`Unknown Azure OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
|
|
@@ -137,6 +149,9 @@ export async function loadApiProvider(
|
|
|
137
149
|
if (providerPath === 'llama' || providerPath.startsWith('llama:')) {
|
|
138
150
|
const modelName = providerPath.split(':')[1];
|
|
139
151
|
return new LlamaProvider(modelName, context?.config);
|
|
152
|
+
} else if (providerPath.startsWith('ollama:')) {
|
|
153
|
+
const modelName = providerPath.split(':')[1];
|
|
154
|
+
return new OllamaProvider(modelName);
|
|
140
155
|
} else if (providerPath?.startsWith('localai:')) {
|
|
141
156
|
const options = providerPath.split(':');
|
|
142
157
|
const modelType = options[1];
|