promptfoo 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +19 -0
- package/README.md +353 -0
- package/dist/__mocks__/esm.d.ts +2 -0
- package/dist/__mocks__/esm.d.ts.map +1 -0
- package/dist/__mocks__/esm.js +4 -0
- package/dist/__mocks__/esm.js.map +1 -0
- package/dist/esm.d.ts +2 -0
- package/dist/esm.d.ts.map +1 -0
- package/dist/esm.js +9 -0
- package/dist/esm.js.map +1 -0
- package/dist/evaluator.d.ts +3 -0
- package/dist/evaluator.d.ts.map +1 -0
- package/dist/evaluator.js +162 -0
- package/dist/evaluator.js.map +1 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +29 -0
- package/dist/index.js.map +1 -0
- package/dist/logger.d.ts +11 -0
- package/dist/logger.d.ts.map +1 -0
- package/dist/logger.js +38 -0
- package/dist/logger.js.map +1 -0
- package/dist/main.d.ts +3 -0
- package/dist/main.d.ts.map +1 -0
- package/dist/main.js +90 -0
- package/dist/main.js.map +1 -0
- package/dist/providers.d.ts +21 -0
- package/dist/providers.d.ts.map +1 -0
- package/dist/providers.js +145 -0
- package/dist/providers.js.map +1 -0
- package/dist/tableOutput.html +55 -0
- package/dist/types.d.ts +55 -0
- package/dist/types.d.ts.map +1 -0
- package/dist/types.js +2 -0
- package/dist/types.js.map +1 -0
- package/dist/util.d.ts +6 -0
- package/dist/util.d.ts.map +1 -0
- package/dist/util.js +62 -0
- package/dist/util.js.map +1 -0
- package/package.json +55 -0
- package/src/__mocks__/esm.ts +3 -0
- package/src/esm.ts +10 -0
- package/src/evaluator.ts +203 -0
- package/src/index.ts +35 -0
- package/src/logger.ts +38 -0
- package/src/main.ts +108 -0
- package/src/providers.ts +170 -0
- package/src/tableOutput.html +55 -0
- package/src/types.ts +63 -0
- package/src/util.ts +67 -0
package/src/evaluator.ts
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
import async from 'async';
|
|
2
|
+
import nunjucks from 'nunjucks';
|
|
3
|
+
|
|
4
|
+
import type { SingleBar } from 'cli-progress';
|
|
5
|
+
|
|
6
|
+
import { EvaluateOptions, EvaluateSummary, EvaluateResult, ApiProvider, Prompt } from './types.js';
|
|
7
|
+
|
|
8
|
+
// Inputs for a single (provider, prompt, vars) evaluation run.
interface RunEvalOptions {
  provider: ApiProvider; // backend that will receive the rendered prompt
  prompt: string; // raw (un-rendered) nunjucks prompt template
  vars?: Record<string, string>; // template variables; may carry a special __expected key
  includeProviderId?: boolean; // when true, prefix the display prompt with "[provider id]"
}

// Default cap on concurrent API calls when EvaluateOptions.maxConcurrency is unset.
const DEFAULT_MAX_CONCURRENCY = 3;
|
|
16
|
+
|
|
17
|
+
function checkExpectedValue(expected: string, output: string): boolean {
|
|
18
|
+
if (expected.startsWith('eval:')) {
|
|
19
|
+
const evalBody = expected.slice(5);
|
|
20
|
+
const evalFunction = new Function('output', `return ${evalBody}`);
|
|
21
|
+
return evalFunction(output);
|
|
22
|
+
} else if (expected.startsWith('grade:')) {
|
|
23
|
+
// NYI
|
|
24
|
+
return false;
|
|
25
|
+
} else {
|
|
26
|
+
return expected === output;
|
|
27
|
+
}
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
async function runEval({
|
|
31
|
+
provider,
|
|
32
|
+
prompt,
|
|
33
|
+
vars,
|
|
34
|
+
includeProviderId,
|
|
35
|
+
}: RunEvalOptions): Promise<EvaluateResult> {
|
|
36
|
+
vars = vars || {};
|
|
37
|
+
const renderedPrompt = nunjucks.renderString(prompt, vars);
|
|
38
|
+
|
|
39
|
+
// Note that we're using original prompt, not renderedPrompt
|
|
40
|
+
const promptDisplay = includeProviderId ? `[${provider.id()}] ${prompt}` : prompt;
|
|
41
|
+
|
|
42
|
+
const setup = {
|
|
43
|
+
prompt: {
|
|
44
|
+
raw: renderedPrompt,
|
|
45
|
+
display: promptDisplay,
|
|
46
|
+
},
|
|
47
|
+
vars,
|
|
48
|
+
};
|
|
49
|
+
|
|
50
|
+
try {
|
|
51
|
+
const response = await provider.callApi(renderedPrompt);
|
|
52
|
+
const success = vars.__expected ? checkExpectedValue(vars.__expected, response.output) : true;
|
|
53
|
+
const ret: EvaluateResult = {
|
|
54
|
+
...setup,
|
|
55
|
+
response,
|
|
56
|
+
success,
|
|
57
|
+
};
|
|
58
|
+
if (!success) {
|
|
59
|
+
ret.error = `Expected ${vars.__expected}, got "${response.output}"`;
|
|
60
|
+
}
|
|
61
|
+
return ret;
|
|
62
|
+
} catch (err) {
|
|
63
|
+
return {
|
|
64
|
+
...setup,
|
|
65
|
+
error: String(err),
|
|
66
|
+
success: false,
|
|
67
|
+
};
|
|
68
|
+
}
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
/**
 * Evaluates every prompt against every provider for every vars row.
 *
 * Returns per-run results, aggregate stats (pass/fail counts, token usage),
 * and a 2-D string table (header row first) suitable for terminal or HTML
 * rendering. When the first vars row has an `__expected` key, the run is
 * treated as a test suite and a RESULT column is prepended to the table.
 */
export async function evaluate(options: EvaluateOptions): Promise<EvaluateSummary> {
  const prompts: Prompt[] = [];
  const results: EvaluateResult[] = [];

  // One table column per (prompt, provider) pair; the provider id is shown
  // only when there is more than one provider.
  for (const promptContent of options.prompts) {
    for (const provider of options.providers) {
      prompts.push({
        raw: promptContent,
        display:
          options.providers.length > 1 ? `[${provider.id()}] ${promptContent}` : promptContent,
      });
    }
  }

  // No vars file means a single run with an empty mapping.
  const vars = options.vars && options.vars.length > 0 ? options.vars : [{}];
  // Copies of each row without the control key, used for table display.
  const varsWithExpectedKeyRemoved = vars.map((v) => {
    const ret = { ...v };
    delete ret.__expected;
    return ret;
  });
  // Test mode is decided by the FIRST row only — assumes all rows agree on
  // whether __expected is present. TODO confirm.
  const isTest = vars[0].__expected;
  // Header row: optional RESULT column, then one column per prompt, then
  // one per var name (var names taken from the first row).
  const table: string[][] = [
    isTest
      ? [
          'RESULT',
          [...prompts.map((p) => p.display), ...Object.keys(varsWithExpectedKeyRemoved[0])],
        ].flat()
      : [...prompts.map((p) => p.display), ...Object.keys(varsWithExpectedKeyRemoved[0])],
  ];

  const stats = {
    successes: 0,
    failures: 0,
    tokenUsage: {
      total: 0,
      prompt: 0,
      completion: 0,
    },
  };

  // Progress bar is optional and lazily imported so non-interactive callers
  // don't pay for cli-progress.
  let progressbar: SingleBar | undefined;
  if (options.showProgressBar) {
    const totalNumRuns =
      options.prompts.length * options.providers.length * (options.vars?.length || 1);
    const cliProgress = await import('cli-progress');
    progressbar = new cliProgress.SingleBar(
      {
        format:
          'Eval: [{bar}] {percentage}% | ETA: {eta}s | {value}/{total} | {provider} "{prompt}" {vars}',
      },
      cliProgress.Presets.shades_classic,
    );
    progressbar.start(totalNumRuns, 0, {
      provider: '',
      prompt: '',
      vars: '',
    });
  }

  // Enumerate runs vars-major so runs for one vars row are contiguous; the
  // inner (prompt, provider) order matches the `prompts` column order above.
  const runEvalOptions: RunEvalOptions[] = [];
  for (const row of vars) {
    for (const promptContent of options.prompts) {
      for (const provider of options.providers) {
        runEvalOptions.push({
          provider,
          prompt: promptContent,
          vars: row,
          includeProviderId: options.providers.length > 1,
        });
      }
    }
  }

  // One output bucket per vars row; each collects the outputs of that row's runs.
  const combinedOutputs: string[][] = new Array(vars.length).fill(null).map(() => []);
  await async.forEachOfLimit(
    runEvalOptions,
    options.maxConcurrency || DEFAULT_MAX_CONCURRENCY,
    // NOTE: this callback parameter shadows the outer `options`.
    async (options: RunEvalOptions, index: number | string) => {
      const row = await runEval(options);
      // NOTE(review): with concurrency > 1, `results` is filled in completion
      // order, not index order — so `results[index]` used for the table below
      // may not correspond to the matching combinedOutputs row. Likewise the
      // push into combinedOutputs below is not ordered within a row. Confirm
      // and consider `results[index] = row` instead.
      results.push(row);
      if (row.error) {
        stats.failures++;
      } else {
        if (row.success) {
          stats.successes++;
        } else {
          stats.failures++;
        }
        // NOTE(review): token usage is only accumulated when row.error is
        // unset, but a failed expectation also sets error — so graded
        // failures are excluded from token totals. Confirm intended.
        stats.tokenUsage.total += row.response?.tokenUsage?.total || 0;
        stats.tokenUsage.prompt += row.response?.tokenUsage?.prompt || 0;
        stats.tokenUsage.completion += row.response?.tokenUsage?.completion || 0;
      }

      if (progressbar) {
        progressbar.increment({
          provider: options.provider.id(),
          prompt: options.prompt.slice(0, 10),
          vars: Object.entries(options.vars || {})
            .map(([k, v]) => `${k}=${v}`)
            .join(' ')
            .slice(0, 10),
        });
      }

      // Bookkeeping for table
      if (typeof index !== 'number') {
        throw new Error('Expected index to be a number');
      }
      // prompts.length === runs per vars row, so this maps a flat run index
      // back to its vars row.
      const combinedOutputIndex = Math.floor(index / prompts.length);
      combinedOutputs[combinedOutputIndex].push(row.response?.output || '');
    },
  );

  if (progressbar) {
    progressbar.stop();
  }

  // Assemble data rows: [RESULT?, ...outputs, ...var values] per vars row.
  if (isTest) {
    table.push(
      ...combinedOutputs.map((output, index) => [
        results[index].success ? 'PASS' : `FAIL: ${results[index].error}`,
        ...output,
        ...Object.values(varsWithExpectedKeyRemoved[index]),
      ]),
    );
  } else {
    table.push(
      ...combinedOutputs.map((output, index) => [...output, ...Object.values(vars[index])]),
    );
  }

  return { results, stats, table };
}
|
package/src/index.ts
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import { evaluate as doEvaluate } from './evaluator.js';
|
|
2
|
+
import { loadApiProvider } from './providers.js';
|
|
3
|
+
|
|
4
|
+
import type { ApiProvider, EvaluateOptions, EvaluateSummary } from './types.js';
|
|
5
|
+
|
|
6
|
+
async function evaluate(
|
|
7
|
+
providers: (string | ApiProvider)[] | (string | ApiProvider),
|
|
8
|
+
options: Omit<EvaluateOptions, 'providers'>,
|
|
9
|
+
): Promise<EvaluateSummary> {
|
|
10
|
+
let apiProviders: ApiProvider[] = [];
|
|
11
|
+
const addProvider = async (provider: ApiProvider | string) => {
|
|
12
|
+
if (typeof provider === 'string') {
|
|
13
|
+
apiProviders.push(await loadApiProvider(provider));
|
|
14
|
+
} else {
|
|
15
|
+
apiProviders.push(provider);
|
|
16
|
+
}
|
|
17
|
+
};
|
|
18
|
+
|
|
19
|
+
if (Array.isArray(providers)) {
|
|
20
|
+
for (const provider of providers) {
|
|
21
|
+
await addProvider(provider);
|
|
22
|
+
}
|
|
23
|
+
} else {
|
|
24
|
+
await addProvider(providers);
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
return doEvaluate({
|
|
28
|
+
...options,
|
|
29
|
+
providers: apiProviders,
|
|
30
|
+
});
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
// Public API surface of the package.
export default {
  evaluate,
};
|
package/src/logger.ts
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import chalk from 'chalk';
|
|
2
|
+
import winston from 'winston';
|
|
3
|
+
|
|
4
|
+
// Numeric severities (lower = more severe), mirroring winston's standard
// error/warn/info/debug ordering.
const logLevels = {
  error: 0,
  warn: 1,
  info: 2,
  debug: 3,
};

// Colors each message by level: red errors, yellow warnings, plain info,
// cyan debug. Any level outside the table above is a programming error.
const customFormatter = winston.format.printf(({ level, message, ...args }) => {
  if (level === 'error') {
    return chalk.red(message);
  } else if (level === 'warn') {
    return chalk.yellow(message);
  } else if (level === 'info') {
    return message;
  } else if (level === 'debug') {
    return chalk.cyan(message);
  }
  throw new Error(`Invalid log level: ${level}`);
});

// Shared console logger used throughout the CLI.
const logger = winston.createLogger({
  levels: logLevels,
  format: winston.format.combine(winston.format.simple(), customFormatter),
  transports: [new winston.transports.Console()],
});
|
|
29
|
+
|
|
30
|
+
export function setLogLevel(level: keyof typeof logLevels) {
|
|
31
|
+
if (logLevels.hasOwnProperty(level)) {
|
|
32
|
+
logger.transports[0].level = level;
|
|
33
|
+
} else {
|
|
34
|
+
throw new Error(`Invalid log level: ${level}`);
|
|
35
|
+
}
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
export default logger;
|
package/src/main.ts
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
#!/usr/bin/env node
|
|
2
|
+
import { readFileSync } from 'fs';
|
|
3
|
+
import { parse } from 'path';
|
|
4
|
+
|
|
5
|
+
import Table from 'cli-table3';
|
|
6
|
+
import chalk from 'chalk';
|
|
7
|
+
import { Command } from 'commander';
|
|
8
|
+
|
|
9
|
+
import logger, { setLogLevel } from './logger.js';
|
|
10
|
+
import { loadApiProvider } from './providers.js';
|
|
11
|
+
import { evaluate } from './evaluator.js';
|
|
12
|
+
import { readPrompts, readVars, writeOutput } from './util.js';
|
|
13
|
+
|
|
14
|
+
import type { CommandLineOptions, EvaluateOptions, VarMapping } from './types.js';
|
|
15
|
+
|
|
16
|
+
// CLI entry point: defines the `eval` command, runs the evaluation, then
// renders the summary to a file or as a terminal table.
const program = new Command();

program
  .command('eval')
  .description('Evaluate prompts')
  .requiredOption('-p, --prompt <paths...>', 'Paths to prompt files (.txt)')
  .requiredOption(
    '-r, --provider <name or path...>',
    'One of: openai:chat, openai:completion, openai:<model name>, or path to custom API caller module',
  )
  .option('-o, --output <path>', 'Path to output file (csv, json, yaml, html)')
  .option('-v, --vars <path>', 'Path to file with prompt variables (csv, json, yaml)')
  .option('-c, --config <path>', 'Path to configuration file')
  .option('-j, --max-concurrency <number>', 'Maximum number of concurrent API calls')
  .option('--verbose', 'Show debug logs')
  .action(async (cmdObj: CommandLineOptions & Command) => {
    if (cmdObj.verbose) {
      setLogLevel('debug');
    }

    // Optional config file: .json is parsed; .js is require()d.
    // NOTE(review): require() of a .js config presumably relies on CommonJS
    // interop; if the package runs as ESM this will fail — confirm against
    // the esm.ts shim.
    const configPath = cmdObj.config;
    let config = {};
    if (configPath) {
      const ext = parse(configPath).ext;
      switch (ext) {
        case '.json':
          const content = readFileSync(configPath, 'utf-8');
          config = JSON.parse(content);
          break;
        case '.js':
          config = require(configPath);
          break;
        default:
          throw new Error(`Unsupported configuration file format: ${ext}`);
      }
    }

    let vars: VarMapping[] = [];
    if (cmdObj.vars) {
      vars = readVars(cmdObj.vars);
    }

    const providers = cmdObj.provider.map((p) => loadApiProvider(p));
    const options: EvaluateOptions = {
      prompts: readPrompts(cmdObj.prompt),
      vars,
      providers,
      showProgressBar: true,
      // NOTE(review): commander supplies option values as strings; `> 0`
      // coerces for the comparison, but the value is then forwarded as a
      // string — confirm downstream consumers accept it, or parseInt here.
      maxConcurrency:
        cmdObj.maxConcurrency && cmdObj.maxConcurrency > 0 ? cmdObj.maxConcurrency : undefined,
      // Config-file values take precedence over CLI-derived options above.
      ...config,
    };

    const summary = await evaluate(options);

    if (cmdObj.output) {
      logger.info(chalk.yellow(`Writing output to ${cmdObj.output}`));
      writeOutput(cmdObj.output, summary);
    } else {
      // Output table by default
      const maxWidth = process.stdout.columns ? process.stdout.columns - 10 : 120;
      const head = summary.table[0];
      const table = new Table({
        head,
        colWidths: Array(head.length).fill(Math.floor(maxWidth / head.length)),
        wordWrap: true,
        wrapOnWordBoundary: true,
        style: {
          head: ['blue', 'bold'],
        },
      });
      // Skip first row (header) and add the rest. Color the first column green if it's a success, red if it's a failure.
      for (const row of summary.table.slice(1)) {
        const color = row[0] === 'PASS' ? 'green' : row[0].startsWith('FAIL') ? 'red' : undefined;
        table.push(row.map((col, i) => (i === 0 && color ? chalk[color](col) : col)));
      }

      logger.info('\n' + table.toString());
    }
    logger.info('Evaluation complete');
    logger.info(chalk.green.bold(`Successes: ${summary.stats.successes}`));
    logger.info(chalk.red.bold(`Failures: ${summary.stats.failures}`));
    logger.info(
      `Token usage: Total ${summary.stats.tokenUsage.total} Prompt ${summary.stats.tokenUsage.prompt} Completion ${summary.stats.tokenUsage.completion}`,
    );
    logger.info('Done.');
  });

program.parse(process.argv);

// With no arguments, print usage instead of doing nothing.
if (!process.argv.slice(2).length) {
  program.outputHelp();
}
|
package/src/providers.ts
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
import fetch from 'node-fetch';
|
|
2
|
+
|
|
3
|
+
import { ApiProvider, ProviderResponse } from './types.js';
|
|
4
|
+
import logger from './logger.js';
|
|
5
|
+
|
|
6
|
+
export class OpenAiGenericProvider implements ApiProvider {
|
|
7
|
+
modelName: string;
|
|
8
|
+
apiKey: string;
|
|
9
|
+
|
|
10
|
+
constructor(modelName: string, apiKey?: string) {
|
|
11
|
+
this.modelName = modelName;
|
|
12
|
+
|
|
13
|
+
const key = apiKey || process.env.OPENAI_API_KEY;
|
|
14
|
+
if (!key) {
|
|
15
|
+
throw new Error(
|
|
16
|
+
'OpenAI API key is not set. Set OPENAI_API_KEY environment variable or pass it as an argument to the constructor.',
|
|
17
|
+
);
|
|
18
|
+
}
|
|
19
|
+
this.apiKey = key;
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
id(): string {
|
|
23
|
+
return `openai:${this.modelName}`;
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
toString(): string {
|
|
27
|
+
return `[OpenAI Provider ${this.modelName}]`;
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
// @ts-ignore: Prompt is not used in this implementation
|
|
31
|
+
async callApi(prompt: string): Promise<ProviderResponse> {
|
|
32
|
+
throw new Error('Not implemented');
|
|
33
|
+
}
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
export class OpenAiCompletionProvider extends OpenAiGenericProvider {
|
|
37
|
+
static OPENAI_COMPLETION_MODELS = [
|
|
38
|
+
'text-davinci-003',
|
|
39
|
+
'text-davinci-002',
|
|
40
|
+
'text-curie-001',
|
|
41
|
+
'text-babbage-001',
|
|
42
|
+
'text-ada-001',
|
|
43
|
+
];
|
|
44
|
+
|
|
45
|
+
constructor(modelName: string, apiKey?: string) {
|
|
46
|
+
if (!OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelName)) {
|
|
47
|
+
throw new Error(
|
|
48
|
+
`Unknown OpenAI completion model name: ${modelName}. Use one of the following: ${OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.join(
|
|
49
|
+
', ',
|
|
50
|
+
)}`,
|
|
51
|
+
);
|
|
52
|
+
}
|
|
53
|
+
super(modelName, apiKey);
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
async callApi(prompt: string): Promise<ProviderResponse> {
|
|
57
|
+
const body = {
|
|
58
|
+
model: this.modelName,
|
|
59
|
+
prompt,
|
|
60
|
+
max_tokens: process.env.OPENAI_MAX_TOKENS || 1024,
|
|
61
|
+
temperature: process.env.OPENAI_TEMPERATURE || 0,
|
|
62
|
+
};
|
|
63
|
+
logger.debug(`Calling OpenAI API: ${JSON.stringify(body)}`);
|
|
64
|
+
const response = await fetch('https://api.openai.com/v1/completions', {
|
|
65
|
+
method: 'POST',
|
|
66
|
+
headers: {
|
|
67
|
+
'Content-Type': 'application/json',
|
|
68
|
+
Authorization: `Bearer ${this.apiKey}`,
|
|
69
|
+
},
|
|
70
|
+
body: JSON.stringify(body),
|
|
71
|
+
});
|
|
72
|
+
|
|
73
|
+
const data = (await response.json()) as unknown as any;
|
|
74
|
+
logger.debug(`\tOpenAI API response: ${JSON.stringify(data)}`);
|
|
75
|
+
return {
|
|
76
|
+
output: data.choices[0].text,
|
|
77
|
+
tokenUsage: {
|
|
78
|
+
total: data.usage.total_tokens,
|
|
79
|
+
prompt: data.usage.prompt_tokens,
|
|
80
|
+
completion: data.usage.completion_tokens,
|
|
81
|
+
},
|
|
82
|
+
};
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
|
|
87
|
+
static OPENAI_CHAT_MODELS = [
|
|
88
|
+
'gpt-4',
|
|
89
|
+
'gpt-4-0314',
|
|
90
|
+
'gpt-4-32k',
|
|
91
|
+
'gpt-4-32k-0314',
|
|
92
|
+
'gpt-3.5-turbo',
|
|
93
|
+
'gpt-3.5-turbo-0301',
|
|
94
|
+
];
|
|
95
|
+
|
|
96
|
+
constructor(modelName: string, apiKey?: string) {
|
|
97
|
+
if (!OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.includes(modelName)) {
|
|
98
|
+
throw new Error(
|
|
99
|
+
`Unknown OpenAI completion model name: ${modelName}. Use one of the following: ${OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.join(
|
|
100
|
+
', ',
|
|
101
|
+
)}`,
|
|
102
|
+
);
|
|
103
|
+
}
|
|
104
|
+
super(modelName, apiKey);
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
async callApi(prompt: string): Promise<ProviderResponse> {
|
|
108
|
+
let messages: { role: string; content: string }[];
|
|
109
|
+
try {
|
|
110
|
+
// User can specify `messages` payload as JSON, or we'll just put the
|
|
111
|
+
// string prompt into a `messages` array.
|
|
112
|
+
messages = JSON.parse(prompt);
|
|
113
|
+
} catch (e) {
|
|
114
|
+
messages = [{ role: 'user', content: prompt }];
|
|
115
|
+
}
|
|
116
|
+
const body = {
|
|
117
|
+
model: this.modelName,
|
|
118
|
+
messages: messages,
|
|
119
|
+
max_tokens: process.env.OPENAI_MAX_TOKENS || 1024,
|
|
120
|
+
temperature: process.env.OPENAI_MAX_TEMPERATURE || 0,
|
|
121
|
+
};
|
|
122
|
+
logger.debug(`Calling OpenAI API: ${JSON.stringify(body)}`);
|
|
123
|
+
const response = await fetch('https://api.openai.com/v1/chat/completions', {
|
|
124
|
+
method: 'POST',
|
|
125
|
+
headers: {
|
|
126
|
+
'Content-Type': 'application/json',
|
|
127
|
+
Authorization: `Bearer ${this.apiKey}`,
|
|
128
|
+
},
|
|
129
|
+
body: JSON.stringify(body),
|
|
130
|
+
});
|
|
131
|
+
|
|
132
|
+
const data = (await response.json()) as unknown as any;
|
|
133
|
+
logger.debug(`\tOpenAI API response: ${JSON.stringify(data)}`);
|
|
134
|
+
return {
|
|
135
|
+
output: data.choices[0].message.content,
|
|
136
|
+
tokenUsage: {
|
|
137
|
+
total: data.usage.total_tokens,
|
|
138
|
+
prompt: data.usage.prompt_tokens,
|
|
139
|
+
completion: data.usage.completion_tokens,
|
|
140
|
+
},
|
|
141
|
+
};
|
|
142
|
+
}
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
export function loadApiProvider(providerPath: string): ApiProvider {
|
|
146
|
+
if (providerPath?.startsWith('openai:')) {
|
|
147
|
+
// Load OpenAI module
|
|
148
|
+
const options = providerPath.split(':');
|
|
149
|
+
const modelType = options[1];
|
|
150
|
+
const modelName = options[2];
|
|
151
|
+
|
|
152
|
+
if (modelType === 'chat') {
|
|
153
|
+
return new OpenAiChatCompletionProvider(modelName || 'gpt-3.5-turbo');
|
|
154
|
+
} else if (modelType === 'completion') {
|
|
155
|
+
return new OpenAiCompletionProvider(modelName || 'text-davinci-003');
|
|
156
|
+
} else if (OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.includes(modelType)) {
|
|
157
|
+
return new OpenAiChatCompletionProvider(modelType);
|
|
158
|
+
} else if (OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelType)) {
|
|
159
|
+
return new OpenAiCompletionProvider(modelType);
|
|
160
|
+
} else {
|
|
161
|
+
throw new Error(
|
|
162
|
+
`Unknown OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
|
|
163
|
+
);
|
|
164
|
+
}
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
// Load custom module
|
|
168
|
+
const CustomApiProvider = require(providerPath).default;
|
|
169
|
+
return new CustomApiProvider();
|
|
170
|
+
}
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
<!DOCTYPE html>
<!-- Nunjucks template: renders the evaluator's `table` (header row first,
     then data rows). The first cell of each row is colored green for PASS
     and red for FAIL* via the data-content attribute selectors below.
     NOTE(review): `table.slice(1)` relies on nunjucks exposing JS array
     methods on the context value — confirm with the nunjucks version used. -->
<html>
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width" />
    <title>Table Output</title>
    <style>
      body {
        font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica, Arial,
          sans-serif;
      }
      table,
      th,
      td {
        border: 1px solid black;
        border-collapse: collapse;
        text-align: left;
        word-break: break-all;
      }
      th,
      td {
        padding: 5px;
      }
      /* If data-content is exactly "PASS", set font color to green */
      tr > td[data-content='PASS']:first-child {
        color: green;
        font-weight: bold;
      }

      /* If data-content starts with "FAIL", set font color to red */
      tr > td[data-content^='FAIL']:first-child {
        color: red;
        font-weight: bold;
      }
    </style>
  </head>
  <body>
    <table>
      <thead>
        {% for header in table[0] %}
        <th>{{ header }}</th>
        {% endfor %}
      </thead>
      <tbody>
        {% for row in table.slice(1) %}
        <tr>
          {% for cell in row %}
          <td data-content="{{cell}}">{{ cell }}</td>
          {% endfor %}
        </tr>
        {% endfor %}
      </tbody>
    </table>
  </body>
</html>
|
package/src/types.ts
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
/** Options parsed from the `eval` CLI command. */
export interface CommandLineOptions {
  prompt: string[]; // paths to prompt files
  provider: string[]; // provider specs or custom module paths
  output?: string; // optional output file path
  vars?: string; // optional path to a vars file
  config?: string; // optional path to a config file
  verbose?: boolean; // enable debug logging
  maxConcurrency?: number; // cap on concurrent API calls
}

/** Contract every model/API backend implements. */
export interface ApiProvider {
  /** Stable identifier for display and logging, e.g. "openai:gpt-3.5-turbo". */
  id: () => string;
  /** Sends one rendered prompt and resolves with its output. */
  callApi: (prompt: string) => Promise<ProviderResponse>;
}

/** Token counts reported by a provider for a single call. */
interface TokenUsage {
  total: number;
  prompt: number;
  completion: number;
}

/** Result of a single provider call. */
export interface ProviderResponse {
  output: string;
  tokenUsage?: TokenUsage; // absent when the backend does not report usage
}

/** One parsed row of a CSV vars file. */
export interface CsvRow {
  [key: string]: string;
}

/** Variable name → value mapping used to render one prompt. */
export type VarMapping = Record<string, string>;

/** Inputs to the core evaluate() function. */
export interface EvaluateOptions {
  providers: ApiProvider[];
  prompts: string[];
  vars?: VarMapping[];

  maxConcurrency?: number; // concurrent API calls (defaults internally)
  showProgressBar?: boolean; // render a cli-progress bar during the run
}

/** A prompt in both raw (rendered) and display form. */
export interface Prompt {
  raw: string;
  display: string;
}

/** Outcome of a single (provider, prompt, vars) run. */
export interface EvaluateResult {
  prompt: Prompt;
  vars: Record<string, string>;
  response?: ProviderResponse; // absent when the call threw
  error?: string; // set on API failure or failed expectation
  success: boolean;
}

/** Aggregate output of an evaluation run. */
export interface EvaluateSummary {
  results: EvaluateResult[];
  table: string[][]; // header row first, then one row per vars row
  stats: {
    successes: number;
    failures: number;
    tokenUsage: TokenUsage;
  };
}
|