promptfoo 0.17.9 → 0.18.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +5 -5
- package/dist/package.json +1 -1
- package/dist/src/assertions.d.ts.map +1 -1
- package/dist/src/assertions.js +97 -42
- package/dist/src/assertions.js.map +1 -1
- package/dist/src/evaluator.d.ts.map +1 -1
- package/dist/src/evaluator.js +35 -7
- package/dist/src/evaluator.js.map +1 -1
- package/dist/src/index.d.ts.map +1 -1
- package/dist/src/index.js +3 -0
- package/dist/src/index.js.map +1 -1
- package/dist/src/main.js +9 -0
- package/dist/src/main.js.map +1 -1
- package/dist/src/providers/llama.d.ts +30 -0
- package/dist/src/providers/llama.d.ts.map +1 -0
- package/dist/src/providers/llama.js +67 -0
- package/dist/src/providers/llama.js.map +1 -0
- package/dist/src/providers.d.ts +2 -2
- package/dist/src/providers.d.ts.map +1 -1
- package/dist/src/providers.js +21 -2
- package/dist/src/providers.js.map +1 -1
- package/dist/src/table.js +2 -2
- package/dist/src/table.js.map +1 -1
- package/dist/src/types.d.ts +11 -4
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/util.d.ts.map +1 -1
- package/dist/src/util.js +5 -2
- package/dist/src/util.js.map +1 -1
- package/package.json +1 -1
- package/src/assertions.ts +102 -49
- package/src/evaluator.ts +33 -4
- package/src/index.ts +6 -1
- package/src/main.ts +14 -0
- package/src/providers/llama.ts +95 -0
- package/src/providers.ts +27 -5
- package/src/table.ts +2 -2
- package/src/types.ts +25 -5
- package/src/util.ts +12 -2
- package/src/web/client/package-lock.json +0 -5726
package/src/assertions.ts
CHANGED
|
@@ -99,12 +99,21 @@ export async function runAssertion(
|
|
|
99
99
|
type: baseType,
|
|
100
100
|
});
|
|
101
101
|
|
|
102
|
+
//render assertion values
|
|
103
|
+
let renderedValue = assertion.value;
|
|
104
|
+
// renderString for assertion values
|
|
105
|
+
if (renderedValue && typeof renderedValue === 'string') {
|
|
106
|
+
renderedValue = nunjucks.renderString(renderedValue, test.vars || {});
|
|
107
|
+
} else if (renderedValue && Array.isArray(renderedValue)) {
|
|
108
|
+
renderedValue = renderedValue.map((v) => nunjucks.renderString(v, test.vars || {}));
|
|
109
|
+
}
|
|
110
|
+
|
|
102
111
|
if (baseType === 'equals') {
|
|
103
|
-
pass =
|
|
112
|
+
pass = renderedValue === output;
|
|
104
113
|
return {
|
|
105
114
|
pass,
|
|
106
115
|
score: pass ? 1 : 0,
|
|
107
|
-
reason: pass ? 'Assertion passed' : `Expected output "${
|
|
116
|
+
reason: pass ? 'Assertion passed' : `Expected output "${renderedValue}"`,
|
|
108
117
|
};
|
|
109
118
|
}
|
|
110
119
|
|
|
@@ -123,103 +132,99 @@ export async function runAssertion(
|
|
|
123
132
|
}
|
|
124
133
|
|
|
125
134
|
if (baseType === 'contains') {
|
|
126
|
-
invariant(
|
|
135
|
+
invariant(renderedValue, '"contains" assertion type must have a string or number value');
|
|
127
136
|
invariant(
|
|
128
|
-
typeof
|
|
137
|
+
typeof renderedValue === 'string' || typeof renderedValue === 'number',
|
|
129
138
|
'"contains" assertion type must have a string or number value',
|
|
130
139
|
);
|
|
131
|
-
pass = output.includes(String(
|
|
140
|
+
pass = output.includes(String(renderedValue)) !== inverse;
|
|
132
141
|
return {
|
|
133
142
|
pass,
|
|
134
143
|
score: pass ? 1 : 0,
|
|
135
144
|
reason: pass
|
|
136
145
|
? 'Assertion passed'
|
|
137
|
-
: `Expected output to ${inverse ? 'not ' : ''}contain "${
|
|
146
|
+
: `Expected output to ${inverse ? 'not ' : ''}contain "${renderedValue}"`,
|
|
138
147
|
};
|
|
139
148
|
}
|
|
140
149
|
|
|
141
150
|
if (baseType === 'contains-any') {
|
|
142
|
-
invariant(
|
|
151
|
+
invariant(renderedValue, '"contains-any" assertion type must have a value');
|
|
143
152
|
invariant(
|
|
144
|
-
Array.isArray(
|
|
153
|
+
Array.isArray(renderedValue),
|
|
145
154
|
'"contains-any" assertion type must have an array value',
|
|
146
155
|
);
|
|
147
|
-
pass =
|
|
156
|
+
pass = renderedValue.some((value) => output.includes(value)) !== inverse;
|
|
148
157
|
return {
|
|
149
158
|
pass,
|
|
150
159
|
score: pass ? 1 : 0,
|
|
151
160
|
reason: pass
|
|
152
161
|
? 'Assertion passed'
|
|
153
|
-
: `Expected output to ${inverse ? 'not ' : ''}contain one of "${
|
|
154
|
-
', ',
|
|
155
|
-
)}"`,
|
|
162
|
+
: `Expected output to ${inverse ? 'not ' : ''}contain one of "${renderedValue.join(', ')}"`,
|
|
156
163
|
};
|
|
157
164
|
}
|
|
158
165
|
|
|
159
166
|
if (baseType === 'contains-all') {
|
|
160
|
-
invariant(
|
|
167
|
+
invariant(renderedValue, '"contains-all" assertion type must have a value');
|
|
161
168
|
invariant(
|
|
162
|
-
Array.isArray(
|
|
169
|
+
Array.isArray(renderedValue),
|
|
163
170
|
'"contains-all" assertion type must have an array value',
|
|
164
171
|
);
|
|
165
|
-
pass =
|
|
172
|
+
pass = renderedValue.every((value) => output.includes(value)) !== inverse;
|
|
166
173
|
return {
|
|
167
174
|
pass,
|
|
168
175
|
score: pass ? 1 : 0,
|
|
169
176
|
reason: pass
|
|
170
177
|
? 'Assertion passed'
|
|
171
|
-
: `Expected output to ${inverse ? 'not ' : ''}contain all of "${
|
|
172
|
-
', ',
|
|
173
|
-
)}"`,
|
|
178
|
+
: `Expected output to ${inverse ? 'not ' : ''}contain all of "${renderedValue.join(', ')}"`,
|
|
174
179
|
};
|
|
175
180
|
}
|
|
176
181
|
|
|
177
182
|
if (baseType === 'regex') {
|
|
178
|
-
invariant(
|
|
183
|
+
invariant(renderedValue, '"regex" assertion type must have a string value');
|
|
179
184
|
invariant(
|
|
180
|
-
typeof
|
|
185
|
+
typeof renderedValue === 'string',
|
|
181
186
|
'"contains" assertion type must have a string value',
|
|
182
187
|
);
|
|
183
|
-
const regex = new RegExp(
|
|
188
|
+
const regex = new RegExp(renderedValue);
|
|
184
189
|
pass = regex.test(output) !== inverse;
|
|
185
190
|
return {
|
|
186
191
|
pass,
|
|
187
192
|
score: pass ? 1 : 0,
|
|
188
193
|
reason: pass
|
|
189
194
|
? 'Assertion passed'
|
|
190
|
-
: `Expected output to ${inverse ? 'not ' : ''}match regex "${
|
|
195
|
+
: `Expected output to ${inverse ? 'not ' : ''}match regex "${renderedValue}"`,
|
|
191
196
|
};
|
|
192
197
|
}
|
|
193
198
|
|
|
194
199
|
if (baseType === 'icontains') {
|
|
195
|
-
invariant(
|
|
200
|
+
invariant(renderedValue, '"icontains" assertion type must have a string or number value');
|
|
196
201
|
invariant(
|
|
197
|
-
typeof
|
|
202
|
+
typeof renderedValue === 'string' || typeof renderedValue === 'number',
|
|
198
203
|
'"icontains" assertion type must have a string or number value',
|
|
199
204
|
);
|
|
200
|
-
pass = output.toLowerCase().includes(String(
|
|
205
|
+
pass = output.toLowerCase().includes(String(renderedValue).toLowerCase()) !== inverse;
|
|
201
206
|
return {
|
|
202
207
|
pass,
|
|
203
208
|
score: pass ? 1 : 0,
|
|
204
209
|
reason: pass
|
|
205
210
|
? 'Assertion passed'
|
|
206
|
-
: `Expected output to ${inverse ? 'not ' : ''}contain "${
|
|
211
|
+
: `Expected output to ${inverse ? 'not ' : ''}contain "${renderedValue}"`,
|
|
207
212
|
};
|
|
208
213
|
}
|
|
209
214
|
|
|
210
215
|
if (baseType === 'starts-with') {
|
|
211
|
-
invariant(
|
|
216
|
+
invariant(renderedValue, '"starts-with" assertion type must have a string value');
|
|
212
217
|
invariant(
|
|
213
|
-
typeof
|
|
218
|
+
typeof renderedValue === 'string',
|
|
214
219
|
'"starts-with" assertion type must have a string value',
|
|
215
220
|
);
|
|
216
|
-
pass = output.startsWith(String(
|
|
221
|
+
pass = output.startsWith(String(renderedValue)) !== inverse;
|
|
217
222
|
return {
|
|
218
223
|
pass,
|
|
219
224
|
score: pass ? 1 : 0,
|
|
220
225
|
reason: pass
|
|
221
226
|
? 'Assertion passed'
|
|
222
|
-
: `Expected output to ${inverse ? 'not ' : ''}start with "${
|
|
227
|
+
: `Expected output to ${inverse ? 'not ' : ''}start with "${renderedValue}"`,
|
|
223
228
|
};
|
|
224
229
|
}
|
|
225
230
|
|
|
@@ -234,12 +239,16 @@ export async function runAssertion(
|
|
|
234
239
|
};
|
|
235
240
|
}
|
|
236
241
|
|
|
242
|
+
const context = {
|
|
243
|
+
vars: test.vars || {},
|
|
244
|
+
};
|
|
245
|
+
|
|
237
246
|
if (baseType === 'javascript') {
|
|
238
247
|
try {
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
};
|
|
248
|
+
if (typeof assertion.value === 'function') {
|
|
249
|
+
return assertion.value(output, test, assertion);
|
|
250
|
+
}
|
|
251
|
+
const customFunction = new Function('output', 'context', `return ${renderedValue}`);
|
|
243
252
|
const result = customFunction(output, context) as any;
|
|
244
253
|
if (typeof result === 'boolean') {
|
|
245
254
|
pass = result !== inverse;
|
|
@@ -255,7 +264,7 @@ export async function runAssertion(
|
|
|
255
264
|
pass: false,
|
|
256
265
|
score: 0,
|
|
257
266
|
reason: `Custom function threw error: ${(err as Error).message}
|
|
258
|
-
${
|
|
267
|
+
${renderedValue}`,
|
|
259
268
|
};
|
|
260
269
|
}
|
|
261
270
|
return {
|
|
@@ -264,41 +273,82 @@ ${assertion.value}`,
|
|
|
264
273
|
reason: pass
|
|
265
274
|
? 'Assertion passed'
|
|
266
275
|
: `Custom function returned ${inverse ? 'true' : 'false'}
|
|
276
|
+
${renderedValue}`,
|
|
277
|
+
};
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
if (baseType === 'python') {
|
|
281
|
+
try {
|
|
282
|
+
const { execSync } = require('child_process');
|
|
283
|
+
const escapedOutput = output.replace(/'/g, "\\'").replace(/"/g, '\\"');
|
|
284
|
+
const escapedContext = JSON.stringify(context).replace(/'/g, "\\'").replace(/"/g, '\\"');
|
|
285
|
+
const result = execSync(
|
|
286
|
+
`python -c "import json; import math; import os; import sys; import re; import datetime; import random; import collections; output='${escapedOutput}'; context='${escapedContext}'; print(json.dumps(${assertion.value}))"`,
|
|
287
|
+
)
|
|
288
|
+
.toString()
|
|
289
|
+
.trim();
|
|
290
|
+
if (result === 'true') {
|
|
291
|
+
pass = true;
|
|
292
|
+
score = 1.0;
|
|
293
|
+
} else if (result === 'false') {
|
|
294
|
+
pass = false;
|
|
295
|
+
score = 0.0;
|
|
296
|
+
} else if (result.startsWith('{')) {
|
|
297
|
+
return JSON.parse(result);
|
|
298
|
+
} else {
|
|
299
|
+
pass = true;
|
|
300
|
+
score = parseFloat(result);
|
|
301
|
+
if (isNaN(score)) {
|
|
302
|
+
throw new Error(
|
|
303
|
+
'Python code must return a boolean, number, or {pass, score, reason} object',
|
|
304
|
+
);
|
|
305
|
+
}
|
|
306
|
+
}
|
|
307
|
+
} catch (err) {
|
|
308
|
+
return {
|
|
309
|
+
pass: false,
|
|
310
|
+
score: 0,
|
|
311
|
+
reason: `Python code execution failed: ${(err as Error).message}`,
|
|
312
|
+
};
|
|
313
|
+
}
|
|
314
|
+
return {
|
|
315
|
+
pass,
|
|
316
|
+
score,
|
|
317
|
+
reason: pass
|
|
318
|
+
? 'Assertion passed'
|
|
319
|
+
: `Python code returned ${pass ? 'true' : 'false'}
|
|
267
320
|
${assertion.value}`,
|
|
268
321
|
};
|
|
269
322
|
}
|
|
270
323
|
|
|
271
324
|
if (baseType === 'similar') {
|
|
272
|
-
invariant(
|
|
325
|
+
invariant(renderedValue, 'Similarity assertion must have a string value');
|
|
273
326
|
invariant(
|
|
274
|
-
typeof
|
|
327
|
+
typeof renderedValue === 'string',
|
|
275
328
|
'"contains" assertion type must have a string value',
|
|
276
329
|
);
|
|
277
|
-
return matchesSimilarity(
|
|
330
|
+
return matchesSimilarity(renderedValue, output, assertion.threshold || 0.75, inverse);
|
|
278
331
|
}
|
|
279
332
|
|
|
280
333
|
if (baseType === 'llm-rubric') {
|
|
281
|
-
invariant(
|
|
334
|
+
invariant(renderedValue, 'Similarity assertion must have a string value');
|
|
282
335
|
invariant(
|
|
283
|
-
typeof
|
|
336
|
+
typeof renderedValue === 'string',
|
|
284
337
|
'"contains" assertion type must have a string value',
|
|
285
338
|
);
|
|
286
|
-
return matchesLlmRubric(
|
|
339
|
+
return matchesLlmRubric(renderedValue, output, test.options);
|
|
287
340
|
}
|
|
288
341
|
|
|
289
342
|
if (baseType === 'webhook') {
|
|
290
|
-
invariant(
|
|
291
|
-
invariant(
|
|
292
|
-
typeof assertion.value === 'string',
|
|
293
|
-
'"webhook" assertion type must have a URL value',
|
|
294
|
-
);
|
|
343
|
+
invariant(renderedValue, '"webhook" assertion type must have a URL value');
|
|
344
|
+
invariant(typeof renderedValue === 'string', '"webhook" assertion type must have a URL value');
|
|
295
345
|
|
|
296
346
|
try {
|
|
297
347
|
const context = {
|
|
298
348
|
vars: test.vars || {},
|
|
299
349
|
};
|
|
300
350
|
const response = await fetchWithRetries(
|
|
301
|
-
|
|
351
|
+
renderedValue,
|
|
302
352
|
{
|
|
303
353
|
method: 'POST',
|
|
304
354
|
headers: {
|
|
@@ -339,8 +389,11 @@ ${assertion.value}`,
|
|
|
339
389
|
}
|
|
340
390
|
|
|
341
391
|
if (baseType === 'rouge-n') {
|
|
342
|
-
invariant(
|
|
343
|
-
|
|
392
|
+
invariant(
|
|
393
|
+
typeof renderedValue === 'string' || Array.isArray(renderedValue),
|
|
394
|
+
'"rouge" assertion type must be a value (string or string array)',
|
|
395
|
+
);
|
|
396
|
+
return handleRougeScore(baseType, assertion, renderedValue, output, inverse);
|
|
344
397
|
}
|
|
345
398
|
|
|
346
399
|
throw new Error('Unknown assertion type: ' + assertion.type);
|
package/src/evaluator.ts
CHANGED
|
@@ -255,10 +255,11 @@ class Evaluator {
|
|
|
255
255
|
}
|
|
256
256
|
|
|
257
257
|
// Aggregate all vars across test cases
|
|
258
|
-
|
|
259
|
-
const tests = (
|
|
258
|
+
let tests = (
|
|
260
259
|
testSuite.tests && testSuite.tests.length > 0
|
|
261
260
|
? testSuite.tests
|
|
261
|
+
: testSuite.scenarios
|
|
262
|
+
? []
|
|
262
263
|
: [
|
|
263
264
|
{
|
|
264
265
|
// Dummy test for cases when we're only comparing raw prompts.
|
|
@@ -269,6 +270,35 @@ class Evaluator {
|
|
|
269
270
|
return Object.assign(finalTestCase, test);
|
|
270
271
|
});
|
|
271
272
|
|
|
273
|
+
// Build scenarios and add to tests
|
|
274
|
+
if (testSuite.scenarios && testSuite.scenarios.length > 0) {
|
|
275
|
+
for (const scenario of testSuite.scenarios) {
|
|
276
|
+
for (const data of scenario.config) {
|
|
277
|
+
// Merge defaultTest with scenario config
|
|
278
|
+
const scenarioTests = (
|
|
279
|
+
scenario.tests || [
|
|
280
|
+
{
|
|
281
|
+
// Dummy test for cases when we're only comparing raw prompts.
|
|
282
|
+
},
|
|
283
|
+
]
|
|
284
|
+
).map((test) => {
|
|
285
|
+
return {
|
|
286
|
+
...testSuite.defaultTest,
|
|
287
|
+
...data,
|
|
288
|
+
...test,
|
|
289
|
+
vars: {
|
|
290
|
+
...testSuite.defaultTest?.vars,
|
|
291
|
+
...data.vars,
|
|
292
|
+
...test.vars,
|
|
293
|
+
},
|
|
294
|
+
};
|
|
295
|
+
});
|
|
296
|
+
// Add scenario tests to tests
|
|
297
|
+
tests = tests.concat(scenarioTests);
|
|
298
|
+
}
|
|
299
|
+
}
|
|
300
|
+
}
|
|
301
|
+
|
|
272
302
|
const varNames: Set<string> = new Set();
|
|
273
303
|
const varsWithSpecialColsRemoved: Record<string, string | string[] | object>[] = [];
|
|
274
304
|
for (const testCase of tests) {
|
|
@@ -352,8 +382,7 @@ class Evaluator {
|
|
|
352
382
|
// Set up progress bar...
|
|
353
383
|
let progressbar: SingleBar | undefined;
|
|
354
384
|
if (options.showProgressBar) {
|
|
355
|
-
const totalNumRuns =
|
|
356
|
-
testSuite.prompts.length * testSuite.providers.length * (totalVarCombinations || 1);
|
|
385
|
+
const totalNumRuns = runEvalOptions.length;
|
|
357
386
|
const cliProgress = await import('cli-progress');
|
|
358
387
|
progressbar = new cliProgress.SingleBar(
|
|
359
388
|
{
|
package/src/index.ts
CHANGED
|
@@ -3,7 +3,7 @@ import providers from './providers';
|
|
|
3
3
|
import telemetry from './telemetry';
|
|
4
4
|
import { evaluate as doEvaluate } from './evaluator';
|
|
5
5
|
import { loadApiProviders } from './providers';
|
|
6
|
-
import { readTests } from './util';
|
|
6
|
+
import { readTests, writeOutput } from './util';
|
|
7
7
|
import type { EvaluateOptions, TestSuite, TestSuiteConfig } from './types';
|
|
8
8
|
|
|
9
9
|
export * from './types';
|
|
@@ -28,6 +28,11 @@ async function evaluate(testSuite: EvaluateTestSuite, options: EvaluateOptions =
|
|
|
28
28
|
};
|
|
29
29
|
telemetry.maybeShowNotice();
|
|
30
30
|
const ret = await doEvaluate(constructedTestSuite, options);
|
|
31
|
+
|
|
32
|
+
if (testSuite.outputPath) {
|
|
33
|
+
writeOutput(testSuite.outputPath, ret, testSuite, null);
|
|
34
|
+
}
|
|
35
|
+
|
|
31
36
|
await telemetry.send();
|
|
32
37
|
return ret;
|
|
33
38
|
}
|
package/src/main.ts
CHANGED
|
@@ -281,6 +281,7 @@ async function main() {
|
|
|
281
281
|
prompts: cmdObj.prompts || fileConfig.prompts || defaultConfig.prompts,
|
|
282
282
|
providers: cmdObj.providers || fileConfig.providers || defaultConfig.providers,
|
|
283
283
|
tests: cmdObj.tests || cmdObj.vars || fileConfig.tests || defaultConfig.tests,
|
|
284
|
+
scenarios: fileConfig.scenarios || defaultConfig.scenarios,
|
|
284
285
|
sharing:
|
|
285
286
|
process.env.PROMPTFOO_DISABLE_SHARING === '1'
|
|
286
287
|
? false
|
|
@@ -310,6 +311,18 @@ async function main() {
|
|
|
310
311
|
config.tests,
|
|
311
312
|
cmdObj.tests ? undefined : basePath,
|
|
312
313
|
);
|
|
314
|
+
|
|
315
|
+
//parse testCases for each scenario
|
|
316
|
+
if (fileConfig.scenarios) {
|
|
317
|
+
for (const scenario of fileConfig.scenarios) {
|
|
318
|
+
const parsedScenarioTests: TestCase[] = await readTests(
|
|
319
|
+
scenario.tests,
|
|
320
|
+
cmdObj.tests ? undefined : basePath,
|
|
321
|
+
);
|
|
322
|
+
scenario.tests = parsedScenarioTests;
|
|
323
|
+
}
|
|
324
|
+
}
|
|
325
|
+
|
|
313
326
|
const parsedProviderPromptMap = readProviderPromptMap(config, parsedPrompts);
|
|
314
327
|
|
|
315
328
|
if (parsedPrompts.length === 0) {
|
|
@@ -334,6 +347,7 @@ async function main() {
|
|
|
334
347
|
providers: parsedProviders,
|
|
335
348
|
providerPromptMap: parsedProviderPromptMap,
|
|
336
349
|
tests: parsedTests,
|
|
350
|
+
scenarios: config.scenarios,
|
|
337
351
|
defaultTest,
|
|
338
352
|
};
|
|
339
353
|
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import { fetchJsonWithCache } from '../cache';
|
|
2
|
+
import { REQUEST_TIMEOUT_MS } from './shared';
|
|
3
|
+
|
|
4
|
+
import type { ApiProvider, ProviderResponse } from '../types.js';
|
|
5
|
+
|
|
6
|
+
interface LlamaCompletionOptions {
|
|
7
|
+
n_predict?: number;
|
|
8
|
+
temperature?: number;
|
|
9
|
+
top_k?: number;
|
|
10
|
+
top_p?: number;
|
|
11
|
+
n_keep?: number;
|
|
12
|
+
stop?: string[];
|
|
13
|
+
repeat_penalty?: number;
|
|
14
|
+
repeat_last_n?: number;
|
|
15
|
+
penalize_nl?: boolean;
|
|
16
|
+
presence_penalty?: number;
|
|
17
|
+
frequency_penalty?: number;
|
|
18
|
+
mirostat?: boolean;
|
|
19
|
+
mirostat_tau?: number;
|
|
20
|
+
mirostat_eta?: number;
|
|
21
|
+
seed?: number;
|
|
22
|
+
ignore_eos?: boolean;
|
|
23
|
+
logit_bias?: Record<string, number>;
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
export class LlamaProvider implements ApiProvider {
|
|
27
|
+
modelName: string;
|
|
28
|
+
options?: LlamaCompletionOptions;
|
|
29
|
+
|
|
30
|
+
constructor(modelName: string, options?: LlamaCompletionOptions) {
|
|
31
|
+
this.modelName = modelName;
|
|
32
|
+
this.options = options;
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
id(): string {
|
|
36
|
+
return `llama:${this.modelName}`;
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
toString(): string {
|
|
40
|
+
return `[Llama Provider ${this.modelName}]`;
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
async callApi(prompt: string, options?: LlamaCompletionOptions): Promise<ProviderResponse> {
|
|
44
|
+
options = Object.assign({}, this.options, options);
|
|
45
|
+
const body = {
|
|
46
|
+
prompt,
|
|
47
|
+
n_predict: options?.n_predict || 512,
|
|
48
|
+
temperature: options?.temperature,
|
|
49
|
+
top_k: options?.top_k,
|
|
50
|
+
top_p: options?.top_p,
|
|
51
|
+
n_keep: options?.n_keep,
|
|
52
|
+
stop: options?.stop,
|
|
53
|
+
repeat_penalty: options?.repeat_penalty,
|
|
54
|
+
repeat_last_n: options?.repeat_last_n,
|
|
55
|
+
penalize_nl: options?.penalize_nl,
|
|
56
|
+
presence_penalty: options?.presence_penalty,
|
|
57
|
+
frequency_penalty: options?.frequency_penalty,
|
|
58
|
+
mirostat: options?.mirostat,
|
|
59
|
+
mirostat_tau: options?.mirostat_tau,
|
|
60
|
+
mirostat_eta: options?.mirostat_eta,
|
|
61
|
+
seed: options?.seed,
|
|
62
|
+
ignore_eos: options?.ignore_eos,
|
|
63
|
+
logit_bias: options?.logit_bias,
|
|
64
|
+
};
|
|
65
|
+
|
|
66
|
+
let response;
|
|
67
|
+
try {
|
|
68
|
+
response = await fetchJsonWithCache(
|
|
69
|
+
`${process.env.LLAMA_BASE_URL || 'http://localhost:8080'}/completion`,
|
|
70
|
+
{
|
|
71
|
+
method: 'POST',
|
|
72
|
+
headers: {
|
|
73
|
+
'Content-Type': 'application/json',
|
|
74
|
+
},
|
|
75
|
+
body: JSON.stringify(body),
|
|
76
|
+
},
|
|
77
|
+
REQUEST_TIMEOUT_MS,
|
|
78
|
+
);
|
|
79
|
+
} catch (err) {
|
|
80
|
+
return {
|
|
81
|
+
error: `API call error: ${String(err)}`,
|
|
82
|
+
};
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
try {
|
|
86
|
+
return {
|
|
87
|
+
output: response.data.content,
|
|
88
|
+
};
|
|
89
|
+
} catch (err) {
|
|
90
|
+
return {
|
|
91
|
+
error: `API response error: ${String(err)}: ${JSON.stringify(response.data)}`,
|
|
92
|
+
};
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
}
|
package/src/providers.ts
CHANGED
|
@@ -1,28 +1,47 @@
|
|
|
1
1
|
import path from 'path';
|
|
2
2
|
|
|
3
|
-
import { ApiProvider, ProviderConfig, ProviderId, RawProviderConfig } from './types';
|
|
4
|
-
|
|
5
3
|
import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from './providers/openai';
|
|
6
4
|
import { AnthropicCompletionProvider } from './providers/anthropic';
|
|
7
5
|
import { ReplicateProvider } from './providers/replicate';
|
|
8
6
|
import { LocalAiCompletionProvider, LocalAiChatProvider } from './providers/localai';
|
|
7
|
+
import { LlamaProvider } from './providers/llama';
|
|
9
8
|
import { ScriptCompletionProvider } from './providers/scriptCompletion';
|
|
10
9
|
import {
|
|
11
10
|
AzureOpenAiChatCompletionProvider,
|
|
12
11
|
AzureOpenAiCompletionProvider,
|
|
13
12
|
} from './providers/azureopenai';
|
|
14
13
|
|
|
14
|
+
import type {
|
|
15
|
+
ApiProvider,
|
|
16
|
+
ProviderConfig,
|
|
17
|
+
ProviderFunction,
|
|
18
|
+
ProviderId,
|
|
19
|
+
RawProviderConfig,
|
|
20
|
+
} from './types';
|
|
21
|
+
|
|
15
22
|
export async function loadApiProviders(
|
|
16
|
-
providerPaths: ProviderId | ProviderId[] | RawProviderConfig[],
|
|
23
|
+
providerPaths: ProviderId | ProviderId[] | RawProviderConfig[] | ProviderFunction,
|
|
17
24
|
basePath?: string,
|
|
18
25
|
): Promise<ApiProvider[]> {
|
|
19
26
|
if (typeof providerPaths === 'string') {
|
|
20
27
|
return [await loadApiProvider(providerPaths, undefined, basePath)];
|
|
28
|
+
} else if (typeof providerPaths === 'function') {
|
|
29
|
+
return [
|
|
30
|
+
{
|
|
31
|
+
id: () => 'custom-function',
|
|
32
|
+
callApi: providerPaths,
|
|
33
|
+
},
|
|
34
|
+
];
|
|
21
35
|
} else if (Array.isArray(providerPaths)) {
|
|
22
36
|
return Promise.all(
|
|
23
|
-
providerPaths.map((provider) => {
|
|
37
|
+
providerPaths.map((provider, idx) => {
|
|
24
38
|
if (typeof provider === 'string') {
|
|
25
39
|
return loadApiProvider(provider, undefined, basePath);
|
|
40
|
+
} else if (typeof provider === 'function') {
|
|
41
|
+
return {
|
|
42
|
+
id: () => `custom-function-${idx}`,
|
|
43
|
+
callApi: provider,
|
|
44
|
+
};
|
|
26
45
|
} else {
|
|
27
46
|
const id = Object.keys(provider)[0];
|
|
28
47
|
const context = { ...provider[id], id };
|
|
@@ -115,7 +134,10 @@ export async function loadApiProvider(
|
|
|
115
134
|
return new ReplicateProvider(modelName, undefined, context?.config);
|
|
116
135
|
}
|
|
117
136
|
|
|
118
|
-
if (providerPath
|
|
137
|
+
if (providerPath === 'llama' || providerPath.startsWith('llama:')) {
|
|
138
|
+
const modelName = providerPath.split(':')[1];
|
|
139
|
+
return new LlamaProvider(modelName, context?.config);
|
|
140
|
+
} else if (providerPath?.startsWith('localai:')) {
|
|
119
141
|
const options = providerPath.split(':');
|
|
120
142
|
const modelType = options[1];
|
|
121
143
|
const modelName = options[2];
|
package/src/table.ts
CHANGED
|
@@ -24,11 +24,11 @@ export function generateTable(summary: EvaluateSummary, tableCellMaxLength = 250
|
|
|
24
24
|
text = text.slice(0, tableCellMaxLength) + '...';
|
|
25
25
|
}
|
|
26
26
|
if (pass) {
|
|
27
|
-
return chalk.green
|
|
27
|
+
return chalk.green('[PASS] ') + text;
|
|
28
28
|
} else if (!pass) {
|
|
29
29
|
// color everything red up until '---'
|
|
30
30
|
return (
|
|
31
|
-
chalk.red
|
|
31
|
+
chalk.red('[FAIL] ') +
|
|
32
32
|
text
|
|
33
33
|
.split('---')
|
|
34
34
|
.map((c, idx) => (idx === 0 ? chalk.red.bold(c) : c))
|
package/src/types.ts
CHANGED
|
@@ -151,6 +151,7 @@ type BaseAssertionTypes =
|
|
|
151
151
|
| 'is-json'
|
|
152
152
|
| 'contains-json'
|
|
153
153
|
| 'javascript'
|
|
154
|
+
| 'python'
|
|
154
155
|
| 'similar'
|
|
155
156
|
| 'llm-rubric'
|
|
156
157
|
| 'webhook'
|
|
@@ -168,7 +169,10 @@ export interface Assertion {
|
|
|
168
169
|
type: AssertionType;
|
|
169
170
|
|
|
170
171
|
// The expected value, if applicable
|
|
171
|
-
value?:
|
|
172
|
+
value?:
|
|
173
|
+
| string
|
|
174
|
+
| string[]
|
|
175
|
+
| ((output: string, testCase: AtomicTestCase, assertion: Assertion) => Promise<GradingResult>);
|
|
172
176
|
|
|
173
177
|
// The threshold value, only applicable for similarity (cosine distance)
|
|
174
178
|
threshold?: number;
|
|
@@ -188,9 +192,6 @@ export interface TestCase {
|
|
|
188
192
|
// Key-value pairs to substitute in the prompt
|
|
189
193
|
vars?: Record<string, string | string[] | object>;
|
|
190
194
|
|
|
191
|
-
// Optional filepath or glob pattern to load vars from
|
|
192
|
-
loadVars?: string | string[];
|
|
193
|
-
|
|
194
195
|
// Optional list of automatic checks to run on the LLM output
|
|
195
196
|
assert?: Assertion[];
|
|
196
197
|
|
|
@@ -198,6 +199,17 @@ export interface TestCase {
|
|
|
198
199
|
options?: PromptConfig & OutputConfig & GradingConfig;
|
|
199
200
|
}
|
|
200
201
|
|
|
202
|
+
export interface Scenario {
|
|
203
|
+
// Optional description of what you're testing
|
|
204
|
+
description?: string;
|
|
205
|
+
|
|
206
|
+
// Default test case config
|
|
207
|
+
config: Partial<TestCase>[];
|
|
208
|
+
|
|
209
|
+
// Optional list of automatic checks to run on the LLM output
|
|
210
|
+
tests: TestCase[];
|
|
211
|
+
}
|
|
212
|
+
|
|
201
213
|
// Same as a TestCase, except the `vars` object has been flattened into its final form.
|
|
202
214
|
export interface AtomicTestCase extends TestCase {
|
|
203
215
|
vars?: Record<string, string | object>;
|
|
@@ -221,12 +233,17 @@ export interface TestSuite {
|
|
|
221
233
|
// Test cases
|
|
222
234
|
tests?: TestCase[];
|
|
223
235
|
|
|
236
|
+
// scenarios
|
|
237
|
+
scenarios?: Scenario[];
|
|
238
|
+
|
|
224
239
|
// Default test case config
|
|
225
240
|
defaultTest?: Partial<TestCase>;
|
|
226
241
|
}
|
|
227
242
|
|
|
228
243
|
export type ProviderId = string;
|
|
229
244
|
|
|
245
|
+
export type ProviderFunction = (prompt: string) => Promise<ProviderResponse>;
|
|
246
|
+
|
|
230
247
|
export type RawProviderConfig = Record<ProviderId, Omit<ProviderConfig, 'id'>>;
|
|
231
248
|
|
|
232
249
|
// TestSuiteConfig = Test Suite, but before everything is parsed and resolved. Providers are just strings, prompts are filepaths, tests can be filepath or inline.
|
|
@@ -235,7 +252,7 @@ export interface TestSuiteConfig {
|
|
|
235
252
|
description?: string;
|
|
236
253
|
|
|
237
254
|
// One or more LLM APIs to use, for example: openai:gpt-3.5-turbo, openai:gpt-4, localai:chat:vicuna
|
|
238
|
-
providers: ProviderId | ProviderId[] | RawProviderConfig[];
|
|
255
|
+
providers: ProviderId | ProviderId[] | RawProviderConfig[] | ProviderFunction;
|
|
239
256
|
|
|
240
257
|
// One or more prompt files to load
|
|
241
258
|
prompts: string | string[];
|
|
@@ -243,6 +260,9 @@ export interface TestSuiteConfig {
|
|
|
243
260
|
// Path to a test file, OR list of LLM prompt variations (aka "test case")
|
|
244
261
|
tests: string | string[] | TestCase[];
|
|
245
262
|
|
|
263
|
+
// Scenarios, groupings of data and tests to be evaluated
|
|
264
|
+
scenarios?: Scenario[];
|
|
265
|
+
|
|
246
266
|
// Sets the default properties for each test case. Useful for setting an assertion, on all test cases, for example.
|
|
247
267
|
defaultTest?: Omit<TestCase, 'description'>;
|
|
248
268
|
|