promptfoo 0.17.9 → 0.18.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/assertions.ts CHANGED
@@ -99,12 +99,21 @@ export async function runAssertion(
99
99
  type: baseType,
100
100
  });
101
101
 
102
+ //render assertion values
103
+ let renderedValue = assertion.value;
104
+ // renderString for assertion values
105
+ if (renderedValue && typeof renderedValue === 'string') {
106
+ renderedValue = nunjucks.renderString(renderedValue, test.vars || {});
107
+ } else if (renderedValue && Array.isArray(renderedValue)) {
108
+ renderedValue = renderedValue.map((v) => nunjucks.renderString(v, test.vars || {}));
109
+ }
110
+
102
111
  if (baseType === 'equals') {
103
- pass = assertion.value === output;
112
+ pass = renderedValue === output;
104
113
  return {
105
114
  pass,
106
115
  score: pass ? 1 : 0,
107
- reason: pass ? 'Assertion passed' : `Expected output "${assertion.value}"`,
116
+ reason: pass ? 'Assertion passed' : `Expected output "${renderedValue}"`,
108
117
  };
109
118
  }
110
119
 
@@ -123,103 +132,99 @@ export async function runAssertion(
123
132
  }
124
133
 
125
134
  if (baseType === 'contains') {
126
- invariant(assertion.value, '"contains" assertion type must have a string or number value');
135
+ invariant(renderedValue, '"contains" assertion type must have a string or number value');
127
136
  invariant(
128
- typeof assertion.value === 'string' || typeof assertion.value === 'number',
137
+ typeof renderedValue === 'string' || typeof renderedValue === 'number',
129
138
  '"contains" assertion type must have a string or number value',
130
139
  );
131
- pass = output.includes(String(assertion.value)) !== inverse;
140
+ pass = output.includes(String(renderedValue)) !== inverse;
132
141
  return {
133
142
  pass,
134
143
  score: pass ? 1 : 0,
135
144
  reason: pass
136
145
  ? 'Assertion passed'
137
- : `Expected output to ${inverse ? 'not ' : ''}contain "${assertion.value}"`,
146
+ : `Expected output to ${inverse ? 'not ' : ''}contain "${renderedValue}"`,
138
147
  };
139
148
  }
140
149
 
141
150
  if (baseType === 'contains-any') {
142
- invariant(assertion.value, '"contains-any" assertion type must have a value');
151
+ invariant(renderedValue, '"contains-any" assertion type must have a value');
143
152
  invariant(
144
- Array.isArray(assertion.value),
153
+ Array.isArray(renderedValue),
145
154
  '"contains-any" assertion type must have an array value',
146
155
  );
147
- pass = assertion.value.some((value) => output.includes(value)) !== inverse;
156
+ pass = renderedValue.some((value) => output.includes(value)) !== inverse;
148
157
  return {
149
158
  pass,
150
159
  score: pass ? 1 : 0,
151
160
  reason: pass
152
161
  ? 'Assertion passed'
153
- : `Expected output to ${inverse ? 'not ' : ''}contain one of "${assertion.value.join(
154
- ', ',
155
- )}"`,
162
+ : `Expected output to ${inverse ? 'not ' : ''}contain one of "${renderedValue.join(', ')}"`,
156
163
  };
157
164
  }
158
165
 
159
166
  if (baseType === 'contains-all') {
160
- invariant(assertion.value, '"contains-all" assertion type must have a value');
167
+ invariant(renderedValue, '"contains-all" assertion type must have a value');
161
168
  invariant(
162
- Array.isArray(assertion.value),
169
+ Array.isArray(renderedValue),
163
170
  '"contains-all" assertion type must have an array value',
164
171
  );
165
- pass = assertion.value.every((value) => output.includes(value)) !== inverse;
172
+ pass = renderedValue.every((value) => output.includes(value)) !== inverse;
166
173
  return {
167
174
  pass,
168
175
  score: pass ? 1 : 0,
169
176
  reason: pass
170
177
  ? 'Assertion passed'
171
- : `Expected output to ${inverse ? 'not ' : ''}contain all of "${assertion.value.join(
172
- ', ',
173
- )}"`,
178
+ : `Expected output to ${inverse ? 'not ' : ''}contain all of "${renderedValue.join(', ')}"`,
174
179
  };
175
180
  }
176
181
 
177
182
  if (baseType === 'regex') {
178
- invariant(assertion.value, '"regex" assertion type must have a string value');
183
+ invariant(renderedValue, '"regex" assertion type must have a string value');
179
184
  invariant(
180
- typeof assertion.value === 'string',
185
+ typeof renderedValue === 'string',
181
186
  '"contains" assertion type must have a string value',
182
187
  );
183
- const regex = new RegExp(assertion.value);
188
+ const regex = new RegExp(renderedValue);
184
189
  pass = regex.test(output) !== inverse;
185
190
  return {
186
191
  pass,
187
192
  score: pass ? 1 : 0,
188
193
  reason: pass
189
194
  ? 'Assertion passed'
190
- : `Expected output to ${inverse ? 'not ' : ''}match regex "${assertion.value}"`,
195
+ : `Expected output to ${inverse ? 'not ' : ''}match regex "${renderedValue}"`,
191
196
  };
192
197
  }
193
198
 
194
199
  if (baseType === 'icontains') {
195
- invariant(assertion.value, '"icontains" assertion type must have a string or number value');
200
+ invariant(renderedValue, '"icontains" assertion type must have a string or number value');
196
201
  invariant(
197
- typeof assertion.value === 'string' || typeof assertion.value === 'number',
202
+ typeof renderedValue === 'string' || typeof renderedValue === 'number',
198
203
  '"icontains" assertion type must have a string or number value',
199
204
  );
200
- pass = output.toLowerCase().includes(String(assertion.value).toLowerCase()) !== inverse;
205
+ pass = output.toLowerCase().includes(String(renderedValue).toLowerCase()) !== inverse;
201
206
  return {
202
207
  pass,
203
208
  score: pass ? 1 : 0,
204
209
  reason: pass
205
210
  ? 'Assertion passed'
206
- : `Expected output to ${inverse ? 'not ' : ''}contain "${assertion.value}"`,
211
+ : `Expected output to ${inverse ? 'not ' : ''}contain "${renderedValue}"`,
207
212
  };
208
213
  }
209
214
 
210
215
  if (baseType === 'starts-with') {
211
- invariant(assertion.value, '"starts-with" assertion type must have a string value');
216
+ invariant(renderedValue, '"starts-with" assertion type must have a string value');
212
217
  invariant(
213
- typeof assertion.value === 'string',
218
+ typeof renderedValue === 'string',
214
219
  '"starts-with" assertion type must have a string value',
215
220
  );
216
- pass = output.startsWith(String(assertion.value)) !== inverse;
221
+ pass = output.startsWith(String(renderedValue)) !== inverse;
217
222
  return {
218
223
  pass,
219
224
  score: pass ? 1 : 0,
220
225
  reason: pass
221
226
  ? 'Assertion passed'
222
- : `Expected output to ${inverse ? 'not ' : ''}start with "${assertion.value}"`,
227
+ : `Expected output to ${inverse ? 'not ' : ''}start with "${renderedValue}"`,
223
228
  };
224
229
  }
225
230
 
@@ -234,12 +239,16 @@ export async function runAssertion(
234
239
  };
235
240
  }
236
241
 
242
+ const context = {
243
+ vars: test.vars || {},
244
+ };
245
+
237
246
  if (baseType === 'javascript') {
238
247
  try {
239
- const customFunction = new Function('output', 'context', `return ${assertion.value}`);
240
- const context = {
241
- vars: test.vars || {},
242
- };
248
+ if (typeof assertion.value === 'function') {
249
+ return assertion.value(output, test, assertion);
250
+ }
251
+ const customFunction = new Function('output', 'context', `return ${renderedValue}`);
243
252
  const result = customFunction(output, context) as any;
244
253
  if (typeof result === 'boolean') {
245
254
  pass = result !== inverse;
@@ -255,7 +264,7 @@ export async function runAssertion(
255
264
  pass: false,
256
265
  score: 0,
257
266
  reason: `Custom function threw error: ${(err as Error).message}
258
- ${assertion.value}`,
267
+ ${renderedValue}`,
259
268
  };
260
269
  }
261
270
  return {
@@ -264,41 +273,82 @@ ${assertion.value}`,
264
273
  reason: pass
265
274
  ? 'Assertion passed'
266
275
  : `Custom function returned ${inverse ? 'true' : 'false'}
276
+ ${renderedValue}`,
277
+ };
278
+ }
279
+
280
+ if (baseType === 'python') {
281
+ try {
282
+ const { execSync } = require('child_process');
283
+ const escapedOutput = output.replace(/'/g, "\\'").replace(/"/g, '\\"');
284
+ const escapedContext = JSON.stringify(context).replace(/'/g, "\\'").replace(/"/g, '\\"');
285
+ const result = execSync(
286
+ `python -c "import json; import math; import os; import sys; import re; import datetime; import random; import collections; output='${escapedOutput}'; context='${escapedContext}'; print(json.dumps(${assertion.value}))"`,
287
+ )
288
+ .toString()
289
+ .trim();
290
+ if (result === 'true') {
291
+ pass = true;
292
+ score = 1.0;
293
+ } else if (result === 'false') {
294
+ pass = false;
295
+ score = 0.0;
296
+ } else if (result.startsWith('{')) {
297
+ return JSON.parse(result);
298
+ } else {
299
+ pass = true;
300
+ score = parseFloat(result);
301
+ if (isNaN(score)) {
302
+ throw new Error(
303
+ 'Python code must return a boolean, number, or {pass, score, reason} object',
304
+ );
305
+ }
306
+ }
307
+ } catch (err) {
308
+ return {
309
+ pass: false,
310
+ score: 0,
311
+ reason: `Python code execution failed: ${(err as Error).message}`,
312
+ };
313
+ }
314
+ return {
315
+ pass,
316
+ score,
317
+ reason: pass
318
+ ? 'Assertion passed'
319
+ : `Python code returned ${pass ? 'true' : 'false'}
267
320
  ${assertion.value}`,
268
321
  };
269
322
  }
270
323
 
271
324
  if (baseType === 'similar') {
272
- invariant(assertion.value, 'Similarity assertion must have a string value');
325
+ invariant(renderedValue, 'Similarity assertion must have a string value');
273
326
  invariant(
274
- typeof assertion.value === 'string',
327
+ typeof renderedValue === 'string',
275
328
  '"contains" assertion type must have a string value',
276
329
  );
277
- return matchesSimilarity(assertion.value, output, assertion.threshold || 0.75, inverse);
330
+ return matchesSimilarity(renderedValue, output, assertion.threshold || 0.75, inverse);
278
331
  }
279
332
 
280
333
  if (baseType === 'llm-rubric') {
281
- invariant(assertion.value, 'Similarity assertion must have a string value');
334
+ invariant(renderedValue, 'Similarity assertion must have a string value');
282
335
  invariant(
283
- typeof assertion.value === 'string',
336
+ typeof renderedValue === 'string',
284
337
  '"contains" assertion type must have a string value',
285
338
  );
286
- return matchesLlmRubric(assertion.value, output, test.options);
339
+ return matchesLlmRubric(renderedValue, output, test.options);
287
340
  }
288
341
 
289
342
  if (baseType === 'webhook') {
290
- invariant(assertion.value, '"webhook" assertion type must have a URL value');
291
- invariant(
292
- typeof assertion.value === 'string',
293
- '"webhook" assertion type must have a URL value',
294
- );
343
+ invariant(renderedValue, '"webhook" assertion type must have a URL value');
344
+ invariant(typeof renderedValue === 'string', '"webhook" assertion type must have a URL value');
295
345
 
296
346
  try {
297
347
  const context = {
298
348
  vars: test.vars || {},
299
349
  };
300
350
  const response = await fetchWithRetries(
301
- assertion.value,
351
+ renderedValue,
302
352
  {
303
353
  method: 'POST',
304
354
  headers: {
@@ -339,8 +389,11 @@ ${assertion.value}`,
339
389
  }
340
390
 
341
391
  if (baseType === 'rouge-n') {
342
- invariant(assertion.value, '"rouge" assertion type must a value (string or string array)');
343
- return handleRougeScore(baseType, assertion, assertion.value, output, inverse);
392
+ invariant(
393
+ typeof renderedValue === 'string' || Array.isArray(renderedValue),
394
+ '"rouge" assertion type must be a value (string or string array)',
395
+ );
396
+ return handleRougeScore(baseType, assertion, renderedValue, output, inverse);
344
397
  }
345
398
 
346
399
  throw new Error('Unknown assertion type: ' + assertion.type);
package/src/evaluator.ts CHANGED
@@ -255,10 +255,11 @@ class Evaluator {
255
255
  }
256
256
 
257
257
  // Aggregate all vars across test cases
258
-
259
- const tests = (
258
+ let tests = (
260
259
  testSuite.tests && testSuite.tests.length > 0
261
260
  ? testSuite.tests
261
+ : testSuite.scenarios
262
+ ? []
262
263
  : [
263
264
  {
264
265
  // Dummy test for cases when we're only comparing raw prompts.
@@ -269,6 +270,35 @@ class Evaluator {
269
270
  return Object.assign(finalTestCase, test);
270
271
  });
271
272
 
273
+ // Build scenarios and add to tests
274
+ if (testSuite.scenarios && testSuite.scenarios.length > 0) {
275
+ for (const scenario of testSuite.scenarios) {
276
+ for (const data of scenario.config) {
277
+ // Merge defaultTest with scenario config
278
+ const scenarioTests = (
279
+ scenario.tests || [
280
+ {
281
+ // Dummy test for cases when we're only comparing raw prompts.
282
+ },
283
+ ]
284
+ ).map((test) => {
285
+ return {
286
+ ...testSuite.defaultTest,
287
+ ...data,
288
+ ...test,
289
+ vars: {
290
+ ...testSuite.defaultTest?.vars,
291
+ ...data.vars,
292
+ ...test.vars,
293
+ },
294
+ };
295
+ });
296
+ // Add scenario tests to tests
297
+ tests = tests.concat(scenarioTests);
298
+ }
299
+ }
300
+ }
301
+
272
302
  const varNames: Set<string> = new Set();
273
303
  const varsWithSpecialColsRemoved: Record<string, string | string[] | object>[] = [];
274
304
  for (const testCase of tests) {
@@ -352,8 +382,7 @@ class Evaluator {
352
382
  // Set up progress bar...
353
383
  let progressbar: SingleBar | undefined;
354
384
  if (options.showProgressBar) {
355
- const totalNumRuns =
356
- testSuite.prompts.length * testSuite.providers.length * (totalVarCombinations || 1);
385
+ const totalNumRuns = runEvalOptions.length;
357
386
  const cliProgress = await import('cli-progress');
358
387
  progressbar = new cliProgress.SingleBar(
359
388
  {
package/src/index.ts CHANGED
@@ -3,7 +3,7 @@ import providers from './providers';
3
3
  import telemetry from './telemetry';
4
4
  import { evaluate as doEvaluate } from './evaluator';
5
5
  import { loadApiProviders } from './providers';
6
- import { readTests } from './util';
6
+ import { readTests, writeOutput } from './util';
7
7
  import type { EvaluateOptions, TestSuite, TestSuiteConfig } from './types';
8
8
 
9
9
  export * from './types';
@@ -28,6 +28,11 @@ async function evaluate(testSuite: EvaluateTestSuite, options: EvaluateOptions =
28
28
  };
29
29
  telemetry.maybeShowNotice();
30
30
  const ret = await doEvaluate(constructedTestSuite, options);
31
+
32
+ if (testSuite.outputPath) {
33
+ writeOutput(testSuite.outputPath, ret, testSuite, null);
34
+ }
35
+
31
36
  await telemetry.send();
32
37
  return ret;
33
38
  }
package/src/main.ts CHANGED
@@ -281,6 +281,7 @@ async function main() {
281
281
  prompts: cmdObj.prompts || fileConfig.prompts || defaultConfig.prompts,
282
282
  providers: cmdObj.providers || fileConfig.providers || defaultConfig.providers,
283
283
  tests: cmdObj.tests || cmdObj.vars || fileConfig.tests || defaultConfig.tests,
284
+ scenarios: fileConfig.scenarios || defaultConfig.scenarios,
284
285
  sharing:
285
286
  process.env.PROMPTFOO_DISABLE_SHARING === '1'
286
287
  ? false
@@ -310,6 +311,18 @@ async function main() {
310
311
  config.tests,
311
312
  cmdObj.tests ? undefined : basePath,
312
313
  );
314
+
315
+ //parse testCases for each scenario
316
+ if (fileConfig.scenarios) {
317
+ for (const scenario of fileConfig.scenarios) {
318
+ const parsedScenarioTests: TestCase[] = await readTests(
319
+ scenario.tests,
320
+ cmdObj.tests ? undefined : basePath,
321
+ );
322
+ scenario.tests = parsedScenarioTests;
323
+ }
324
+ }
325
+
313
326
  const parsedProviderPromptMap = readProviderPromptMap(config, parsedPrompts);
314
327
 
315
328
  if (parsedPrompts.length === 0) {
@@ -334,6 +347,7 @@ async function main() {
334
347
  providers: parsedProviders,
335
348
  providerPromptMap: parsedProviderPromptMap,
336
349
  tests: parsedTests,
350
+ scenarios: config.scenarios,
337
351
  defaultTest,
338
352
  };
339
353
 
package/src/providers/llama.ts ADDED
@@ -0,0 +1,95 @@
import { fetchJsonWithCache } from '../cache';
import { REQUEST_TIMEOUT_MS } from './shared';

import type { ApiProvider, ProviderResponse } from '../types.js';

/**
 * Generation options forwarded to the llama.cpp server's `/completion` endpoint.
 * Field names are sent verbatim as request-body keys; unset (undefined) fields are
 * dropped by `JSON.stringify` and therefore omitted from the request.
 */
interface LlamaCompletionOptions {
  n_predict?: number;
  temperature?: number;
  top_k?: number;
  top_p?: number;
  n_keep?: number;
  stop?: string[];
  repeat_penalty?: number;
  repeat_last_n?: number;
  penalize_nl?: boolean;
  presence_penalty?: number;
  frequency_penalty?: number;
  // The llama.cpp server documents this as a mode number (0 = disabled, 1 = Mirostat,
  // 2 = Mirostat 2.0). Booleans are still accepted for backward compatibility and are
  // sent as 0/1.
  mirostat?: number | boolean;
  mirostat_tau?: number;
  mirostat_eta?: number;
  seed?: number;
  ignore_eos?: boolean;
  logit_bias?: Record<string, number>;
}

/** Generation length used when neither the constructor nor the call sets `n_predict`. */
const DEFAULT_N_PREDICT = 512;

/** Server used when `LLAMA_BASE_URL` is unset or empty. */
const DEFAULT_BASE_URL = 'http://localhost:8080';

/**
 * Returns `data.content` when it is a string, otherwise `undefined`.
 * The response body is untyped JSON, so it is narrowed before use.
 */
function extractContent(data: unknown): string | undefined {
  if (typeof data !== 'object' || data === null) {
    return undefined;
  }
  const { content } = data as { content?: unknown };
  return typeof content === 'string' ? content : undefined;
}

/**
 * Provider backed by a llama.cpp HTTP server (`llama` or `llama:<model>` provider ids).
 *
 * Requests are POSTed to `${LLAMA_BASE_URL}/completion` (default `http://localhost:8080`)
 * via `fetchJsonWithCache`. The model name is only used for `id()`/`toString()`; it is
 * not sent to the server.
 */
export class LlamaProvider implements ApiProvider {
  modelName: string;
  options?: LlamaCompletionOptions;

  /**
   * @param modelName Label used in `id()`/`toString()`; may be empty/undefined for the bare `llama` id.
   * @param options Default completion options; per-call options override them key by key.
   */
  constructor(modelName: string, options?: LlamaCompletionOptions) {
    this.modelName = modelName;
    this.options = options;
  }

  /** Returns `llama:<model>`, or plain `llama` when no model name was given. */
  id(): string {
    return this.modelName ? `llama:${this.modelName}` : 'llama';
  }

  toString(): string {
    return this.modelName ? `[Llama Provider ${this.modelName}]` : '[Llama Provider]';
  }

  /**
   * Sends `prompt` to the llama.cpp `/completion` endpoint.
   *
   * @param prompt Raw prompt text.
   * @param options Per-call overrides merged on top of the constructor options.
   * @returns `{ output }` with the generated text, or `{ error }` when the request fails
   *   or the reply has no string `content` field.
   */
  async callApi(prompt: string, options?: LlamaCompletionOptions): Promise<ProviderResponse> {
    const opts: LlamaCompletionOptions = { ...this.options, ...options };
    const body = {
      prompt,
      // `??` rather than `||`: an explicit 0 is a legitimate value and must not be
      // replaced by the default.
      n_predict: opts.n_predict ?? DEFAULT_N_PREDICT,
      temperature: opts.temperature,
      top_k: opts.top_k,
      top_p: opts.top_p,
      n_keep: opts.n_keep,
      stop: opts.stop,
      repeat_penalty: opts.repeat_penalty,
      repeat_last_n: opts.repeat_last_n,
      penalize_nl: opts.penalize_nl,
      presence_penalty: opts.presence_penalty,
      frequency_penalty: opts.frequency_penalty,
      mirostat: typeof opts.mirostat === 'boolean' ? Number(opts.mirostat) : opts.mirostat,
      mirostat_tau: opts.mirostat_tau,
      mirostat_eta: opts.mirostat_eta,
      seed: opts.seed,
      ignore_eos: opts.ignore_eos,
      logit_bias: opts.logit_bias,
    };

    // `||` on purpose: an empty LLAMA_BASE_URL should also fall back to the default.
    const url = `${process.env.LLAMA_BASE_URL || DEFAULT_BASE_URL}/completion`;

    let data: unknown;
    try {
      const response = await fetchJsonWithCache(
        url,
        {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
          },
          body: JSON.stringify(body),
        },
        REQUEST_TIMEOUT_MS,
      );
      data = response.data;
    } catch (err) {
      return {
        error: `API call error: ${String(err)}`,
      };
    }

    const content = extractContent(data);
    if (content === undefined) {
      // Surface a malformed reply as an error instead of returning `output: undefined`.
      return {
        error: `API response error: missing string "content" field: ${JSON.stringify(data)}`,
      };
    }
    return {
      output: content,
    };
  }
}
package/src/providers.ts CHANGED
@@ -1,28 +1,47 @@
1
1
  import path from 'path';
2
2
 
3
- import { ApiProvider, ProviderConfig, ProviderId, RawProviderConfig } from './types';
4
-
5
3
  import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from './providers/openai';
6
4
  import { AnthropicCompletionProvider } from './providers/anthropic';
7
5
  import { ReplicateProvider } from './providers/replicate';
8
6
  import { LocalAiCompletionProvider, LocalAiChatProvider } from './providers/localai';
7
+ import { LlamaProvider } from './providers/llama';
9
8
  import { ScriptCompletionProvider } from './providers/scriptCompletion';
10
9
  import {
11
10
  AzureOpenAiChatCompletionProvider,
12
11
  AzureOpenAiCompletionProvider,
13
12
  } from './providers/azureopenai';
14
13
 
14
+ import type {
15
+ ApiProvider,
16
+ ProviderConfig,
17
+ ProviderFunction,
18
+ ProviderId,
19
+ RawProviderConfig,
20
+ } from './types';
21
+
15
22
  export async function loadApiProviders(
16
- providerPaths: ProviderId | ProviderId[] | RawProviderConfig[],
23
+ providerPaths: ProviderId | ProviderId[] | RawProviderConfig[] | ProviderFunction,
17
24
  basePath?: string,
18
25
  ): Promise<ApiProvider[]> {
19
26
  if (typeof providerPaths === 'string') {
20
27
  return [await loadApiProvider(providerPaths, undefined, basePath)];
28
+ } else if (typeof providerPaths === 'function') {
29
+ return [
30
+ {
31
+ id: () => 'custom-function',
32
+ callApi: providerPaths,
33
+ },
34
+ ];
21
35
  } else if (Array.isArray(providerPaths)) {
22
36
  return Promise.all(
23
- providerPaths.map((provider) => {
37
+ providerPaths.map((provider, idx) => {
24
38
  if (typeof provider === 'string') {
25
39
  return loadApiProvider(provider, undefined, basePath);
40
+ } else if (typeof provider === 'function') {
41
+ return {
42
+ id: () => `custom-function-${idx}`,
43
+ callApi: provider,
44
+ };
26
45
  } else {
27
46
  const id = Object.keys(provider)[0];
28
47
  const context = { ...provider[id], id };
@@ -115,7 +134,10 @@ export async function loadApiProvider(
115
134
  return new ReplicateProvider(modelName, undefined, context?.config);
116
135
  }
117
136
 
118
- if (providerPath?.startsWith('localai:')) {
137
+ if (providerPath === 'llama' || providerPath.startsWith('llama:')) {
138
+ const modelName = providerPath.split(':')[1];
139
+ return new LlamaProvider(modelName, context?.config);
140
+ } else if (providerPath?.startsWith('localai:')) {
119
141
  const options = providerPath.split(':');
120
142
  const modelType = options[1];
121
143
  const modelName = options[2];
package/src/table.ts CHANGED
@@ -24,11 +24,11 @@ export function generateTable(summary: EvaluateSummary, tableCellMaxLength = 250
24
24
  text = text.slice(0, tableCellMaxLength) + '...';
25
25
  }
26
26
  if (pass) {
27
- return chalk.green.bold('[PASS] ') + text;
27
+ return chalk.green('[PASS] ') + text;
28
28
  } else if (!pass) {
29
29
  // color everything red up until '---'
30
30
  return (
31
- chalk.red.bold('[FAIL] ') +
31
+ chalk.red('[FAIL] ') +
32
32
  text
33
33
  .split('---')
34
34
  .map((c, idx) => (idx === 0 ? chalk.red.bold(c) : c))
package/src/types.ts CHANGED
@@ -151,6 +151,7 @@ type BaseAssertionTypes =
151
151
  | 'is-json'
152
152
  | 'contains-json'
153
153
  | 'javascript'
154
+ | 'python'
154
155
  | 'similar'
155
156
  | 'llm-rubric'
156
157
  | 'webhook'
@@ -168,7 +169,10 @@ export interface Assertion {
168
169
  type: AssertionType;
169
170
 
170
171
  // The expected value, if applicable
171
- value?: string | string[];
172
+ value?:
173
+ | string
174
+ | string[]
175
+ | ((output: string, testCase: AtomicTestCase, assertion: Assertion) => Promise<GradingResult>);
172
176
 
173
177
  // The threshold value, only applicable for similarity (cosine distance)
174
178
  threshold?: number;
@@ -188,9 +192,6 @@ export interface TestCase {
188
192
  // Key-value pairs to substitute in the prompt
189
193
  vars?: Record<string, string | string[] | object>;
190
194
 
191
- // Optional filepath or glob pattern to load vars from
192
- loadVars?: string | string[];
193
-
194
195
  // Optional list of automatic checks to run on the LLM output
195
196
  assert?: Assertion[];
196
197
 
@@ -198,6 +199,17 @@ export interface TestCase {
198
199
  options?: PromptConfig & OutputConfig & GradingConfig;
199
200
  }
200
201
 
202
+ export interface Scenario { // A group of data sets (`config`) combined with a shared list of `tests`
203
+ // Optional human-readable description of what this scenario covers
204
+ description?: string;
205
+
206
+ // One entry per data set: each entry is merged over defaultTest into every test in `tests`, so the test list runs once per entry
207
+ config: Partial<TestCase>[];
208
+
209
+ // Test cases run for each `config` entry; a test's fields (and vars) override the entry's. NOTE(review): the evaluator falls back to one empty test when this is missing, even though the field is typed as required
210
+ tests: TestCase[];
211
+ }
212
+
201
213
  // Same as a TestCase, except the `vars` object has been flattened into its final form.
202
214
  export interface AtomicTestCase extends TestCase {
203
215
  vars?: Record<string, string | object>;
@@ -221,12 +233,17 @@ export interface TestSuite {
221
233
  // Test cases
222
234
  tests?: TestCase[];
223
235
 
236
+ // scenarios
237
+ scenarios?: Scenario[];
238
+
224
239
  // Default test case config
225
240
  defaultTest?: Partial<TestCase>;
226
241
  }
227
242
 
228
243
  export type ProviderId = string;
229
244
 
245
+ export type ProviderFunction = (prompt: string) => Promise<ProviderResponse>;
246
+
230
247
  export type RawProviderConfig = Record<ProviderId, Omit<ProviderConfig, 'id'>>;
231
248
 
232
249
  // TestSuiteConfig = Test Suite, but before everything is parsed and resolved. Providers are just strings, prompts are filepaths, tests can be filepath or inline.
@@ -235,7 +252,7 @@ export interface TestSuiteConfig {
235
252
  description?: string;
236
253
 
237
254
  // One or more LLM APIs to use, for example: openai:gpt-3.5-turbo, openai:gpt-4, localai:chat:vicuna
238
- providers: ProviderId | ProviderId[] | RawProviderConfig[];
255
+ providers: ProviderId | ProviderId[] | RawProviderConfig[] | ProviderFunction;
239
256
 
240
257
  // One or more prompt files to load
241
258
  prompts: string | string[];
@@ -243,6 +260,9 @@ export interface TestSuiteConfig {
243
260
  // Path to a test file, OR list of LLM prompt variations (aka "test case")
244
261
  tests: string | string[] | TestCase[];
245
262
 
263
+ // Scenarios, groupings of data and tests to be evaluated
264
+ scenarios?: Scenario[];
265
+
246
266
  // Sets the default properties for each test case. Useful for setting an assertion, on all test cases, for example.
247
267
  defaultTest?: Omit<TestCase, 'description'>;
248
268