promptfoo 0.19.1 → 0.19.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +2 -2
- package/dist/package.json +2 -1
- package/dist/src/assertions.d.ts +2 -3
- package/dist/src/assertions.d.ts.map +1 -1
- package/dist/src/assertions.js +44 -108
- package/dist/src/assertions.js.map +1 -1
- package/dist/src/evaluator.d.ts.map +1 -1
- package/dist/src/evaluator.js +5 -1
- package/dist/src/evaluator.js.map +1 -1
- package/dist/src/index.d.ts +2 -2
- package/dist/src/main.js +2 -2
- package/dist/src/main.js.map +1 -1
- package/dist/src/matchers.d.ts +4 -0
- package/dist/src/matchers.d.ts.map +1 -0
- package/dist/src/matchers.js +102 -0
- package/dist/src/matchers.js.map +1 -0
- package/dist/src/providers/azureopenai.d.ts.map +1 -1
- package/dist/src/providers/azureopenai.js +1 -1
- package/dist/src/providers/azureopenai.js.map +1 -1
- package/dist/src/providers/openai.d.ts.map +1 -1
- package/dist/src/providers/openai.js +2 -4
- package/dist/src/providers/openai.js.map +1 -1
- package/dist/src/providers/scriptCompletion.d.ts +2 -2
- package/dist/src/providers/scriptCompletion.d.ts.map +1 -1
- package/dist/src/providers/scriptCompletion.js.map +1 -1
- package/dist/src/providers.d.ts +3 -3
- package/dist/src/providers.d.ts.map +1 -1
- package/dist/src/providers.js +11 -10
- package/dist/src/providers.js.map +1 -1
- package/dist/src/types.d.ts +5 -5
- package/dist/src/types.d.ts.map +1 -1
- package/dist/src/util.js.map +1 -1
- package/dist/src/web/nextui/404/index.html +1 -1
- package/dist/src/web/nextui/404.html +1 -1
- package/dist/src/web/nextui/api +1 -1
- package/dist/src/web/nextui/eval/index.html +1 -1
- package/dist/src/web/nextui/eval/index.txt +1 -1
- package/dist/src/web/nextui/index.html +1 -1
- package/dist/src/web/nextui/index.txt +1 -1
- package/dist/src/web/nextui/setup/index.html +1 -1
- package/dist/src/web/nextui/setup/index.txt +1 -1
- package/package.json +2 -1
- package/src/assertions.ts +55 -131
- package/src/evaluator.ts +5 -1
- package/src/main.ts +6 -2
- package/src/matchers.ts +120 -0
- package/src/providers/azureopenai.ts +1 -2
- package/src/providers/openai.ts +2 -5
- package/src/providers/scriptCompletion.ts +2 -2
- package/src/providers.ts +20 -19
- package/src/types.ts +10 -4
- package/src/util.ts +2 -2
- package/src/web/nextui/src/app/setup/ProviderConfigDialog.tsx +3 -3
- package/src/web/nextui/src/app/setup/ProviderSelector.tsx +12 -12
- package/src/web/nextui/src/util/store.ts +3 -3
- /package/dist/src/web/nextui/_next/static/{i1iOxHlErUK1hGZ9mGI2E → eCTjsASjQCuaN3ajMqfGS}/_buildManifest.js +0 -0
- /package/dist/src/web/nextui/_next/static/{i1iOxHlErUK1hGZ9mGI2E → eCTjsASjQCuaN3ajMqfGS}/_ssgManifest.js +0 -0
package/src/matchers.ts
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import { DefaultEmbeddingProvider, DefaultGradingProvider } from './providers/openai';
|
|
2
|
+
import { cosineSimilarity, getNunjucksEngine } from './util';
|
|
3
|
+
import { loadApiProvider } from './providers';
|
|
4
|
+
import { DEFAULT_GRADING_PROMPT } from './prompts';
|
|
5
|
+
|
|
6
|
+
import type { GradingConfig, GradingResult } from './types';
|
|
7
|
+
|
|
8
|
+
// Shared Nunjucks engine, created once at module load; matchesLlmRubric uses it to render
// the grading prompt template with the escaped output and rubric.
const nunjucks = getNunjucksEngine();
|
|
9
|
+
|
|
10
|
+
/**
 * Grades `output` by the cosine similarity between its embedding and the embedding of `expected`.
 *
 * Both embeddings are requested from `DefaultEmbeddingProvider`. Normally the check passes when
 * the similarity is at least `threshold`; with `inverse` set it passes when the similarity is at
 * most `threshold` (i.e. the texts are required to be dissimilar).
 *
 * @param expected - Reference text to compare against.
 * @param output - Text produced by the model under test.
 * @param threshold - Similarity cutoff that decides pass/fail.
 * @param inverse - When true, invert the comparison and score `1 - similarity` instead.
 * @returns A grading result carrying the combined token usage of both embedding calls. Provider
 *   errors and missing embeddings yield a failing result with score 0 instead of throwing.
 */
export async function matchesSimilarity(
  expected: string,
  output: string,
  threshold: number,
  inverse: boolean = false,
): Promise<Omit<GradingResult, 'assertion'>> {
  // The two embedding requests do not depend on each other, so issue them concurrently
  // instead of paying two sequential round-trips.
  const [expectedEmbedding, outputEmbedding] = await Promise.all([
    DefaultEmbeddingProvider.callEmbeddingApi(expected),
    DefaultEmbeddingProvider.callEmbeddingApi(output),
  ]);

  const tokensUsed = {
    total: (expectedEmbedding.tokenUsage?.total || 0) + (outputEmbedding.tokenUsage?.total || 0),
    prompt: (expectedEmbedding.tokenUsage?.prompt || 0) + (outputEmbedding.tokenUsage?.prompt || 0),
    completion:
      (expectedEmbedding.tokenUsage?.completion || 0) +
      (outputEmbedding.tokenUsage?.completion || 0),
  };

  // Report the first provider error, if any.
  const embeddingError = expectedEmbedding.error || outputEmbedding.error;
  if (embeddingError) {
    return { pass: false, score: 0, reason: embeddingError, tokensUsed };
  }

  if (!expectedEmbedding.embedding || !outputEmbedding.embedding) {
    return { pass: false, score: 0, reason: 'Embedding not found', tokensUsed };
  }

  const similarity = cosineSimilarity(expectedEmbedding.embedding, outputEmbedding.embedding);
  const pass = inverse ? similarity <= threshold : similarity >= threshold;
  const greaterThanReason = `Similarity ${similarity} is greater than threshold ${threshold}`;
  const lessThanReason = `Similarity ${similarity} is less than threshold ${threshold}`;

  // A passing normal check means "at or above the threshold"; a passing inverse check means
  // "at or below". Failures are the opposite, so the reason flips with `inverse`.
  let reason: string;
  if (pass) {
    reason = inverse ? lessThanReason : greaterThanReason;
  } else {
    reason = inverse ? greaterThanReason : lessThanReason;
  }

  return {
    pass,
    // In inverse mode a lower similarity is the better outcome.
    score: inverse ? 1 - similarity : similarity,
    reason,
    tokensUsed,
  };
}
|
|
65
|
+
|
|
66
|
+
/**
 * Grades `output` against a free-form rubric by asking an LLM grader for a JSON verdict.
 *
 * The rubric prompt (`grading.rubricPrompt`, falling back to `DEFAULT_GRADING_PROMPT`) is rendered
 * with the escaped `output` and `rubric`, then sent to the grading provider. The reply is parsed as
 * JSON, and its `pass` field decides the score (1 when truthy, otherwise 0).
 *
 * @param expected - The rubric describing what an acceptable output looks like.
 * @param output - Text produced by the model under test.
 * @param grading - Grading configuration. `provider` may be a provider instance or a provider id
 *   string, which is resolved with `loadApiProvider`; it defaults to `DefaultGradingProvider`.
 * @returns A grading result carrying the grader's token usage. Provider errors, empty replies and
 *   replies that are not a JSON object yield a failing result with score 0 instead of throwing.
 * @throws Error when `grading` is not supplied.
 */
export async function matchesLlmRubric(
  expected: string,
  output: string,
  grading?: GradingConfig,
): Promise<Omit<GradingResult, 'assertion'>> {
  if (!grading) {
    throw new Error(
      'Cannot grade output without grading config. Specify --grader option or grading config.',
    );
  }

  // NOTE(review): the values are escaped as JSON string contents, presumably because the prompt
  // template embeds them inside double-quoted JSON strings (which is why `\n` and `"` were escaped
  // before). Escaping only those two characters left backslashes, tabs, `\r` and other control
  // characters raw, which breaks such a prompt; JSON.stringify escapes all of them consistently.
  const escapeForJsonString = (text: string): string => JSON.stringify(text).slice(1, -1);

  const prompt = nunjucks.renderString(grading.rubricPrompt || DEFAULT_GRADING_PROMPT, {
    output: escapeForJsonString(output),
    rubric: escapeForJsonString(expected),
  });

  let provider = grading.provider || DefaultGradingProvider;
  if (typeof provider === 'string') {
    provider = await loadApiProvider(provider);
  }
  const resp = await provider.callApi(prompt);

  // Every outcome reports the grader's token usage, so build it once.
  const tokensUsed = {
    total: resp.tokenUsage?.total || 0,
    prompt: resp.tokenUsage?.prompt || 0,
    completion: resp.tokenUsage?.completion || 0,
  };

  if (resp.error || !resp.output) {
    return { pass: false, score: 0, reason: resp.error || 'No output', tokensUsed };
  }

  const invalidJson = {
    pass: false,
    score: 0,
    reason: `Output is not valid JSON: ${resp.output}`,
    tokensUsed,
  };

  let parsed: unknown;
  try {
    // Some providers may already return structured output; only parse text replies.
    parsed = typeof resp.output === 'string' ? JSON.parse(resp.output) : resp.output;
  } catch {
    return invalidJson;
  }

  // A bare JSON primitive or null cannot carry a verdict.
  if (typeof parsed !== 'object' || parsed === null) {
    return invalidJson;
  }

  const verdict = parsed as Omit<GradingResult, 'score'>;
  return { ...verdict, score: verdict.pass ? 1 : 0, tokensUsed };
}
|
|
@@ -279,8 +279,7 @@ export class AzureOpenAiChatCompletionProvider extends AzureOpenAiGenericProvide
|
|
|
279
279
|
logger.debug(`\tAzure OpenAI API response: ${JSON.stringify(data)}`);
|
|
280
280
|
try {
|
|
281
281
|
const message = data.choices[0].message;
|
|
282
|
-
const output =
|
|
283
|
-
message.content === null ? JSON.stringify(message.function_call) : message.content;
|
|
282
|
+
const output = message.content == null ? message.function_call : message.content;
|
|
284
283
|
return {
|
|
285
284
|
output,
|
|
286
285
|
tokenUsage: cached
|
package/src/providers/openai.ts
CHANGED
|
@@ -78,9 +78,7 @@ export class OpenAiEmbeddingProvider extends OpenAiGenericProvider {
|
|
|
78
78
|
headers: {
|
|
79
79
|
'Content-Type': 'application/json',
|
|
80
80
|
Authorization: `Bearer ${this.getApiKey()}`,
|
|
81
|
-
...(this.getOrganization()
|
|
82
|
-
? { 'OpenAI-Organization': this.getOrganization() }
|
|
83
|
-
: {}),
|
|
81
|
+
...(this.getOrganization() ? { 'OpenAI-Organization': this.getOrganization() } : {}),
|
|
84
82
|
},
|
|
85
83
|
body: JSON.stringify(body),
|
|
86
84
|
},
|
|
@@ -315,8 +313,7 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
|
|
|
315
313
|
logger.debug(`\tOpenAI API response: ${JSON.stringify(data)}`);
|
|
316
314
|
try {
|
|
317
315
|
const message = data.choices[0].message;
|
|
318
|
-
const output =
|
|
319
|
-
message.content === null ? JSON.stringify(message.function_call) : message.content;
|
|
316
|
+
const output = message.content === null ? message.function_call : message.content;
|
|
320
317
|
return {
|
|
321
318
|
output,
|
|
322
319
|
tokenUsage: cached
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { exec } from 'child_process';
|
|
2
2
|
|
|
3
|
-
import { ApiProvider,
|
|
3
|
+
import { ApiProvider, ProviderOptions, ProviderResponse } from '../types';
|
|
4
4
|
|
|
5
5
|
const ANSI_ESCAPE = /\x1b(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g;
|
|
6
6
|
|
|
@@ -9,7 +9,7 @@ function stripText(text: string) {
|
|
|
9
9
|
}
|
|
10
10
|
|
|
11
11
|
export class ScriptCompletionProvider implements ApiProvider {
|
|
12
|
-
constructor(private scriptPath: string, private config?:
|
|
12
|
+
constructor(private scriptPath: string, private config?: ProviderOptions) {}
|
|
13
13
|
|
|
14
14
|
id() {
|
|
15
15
|
return `exec:${this.scriptPath}`;
|
package/src/providers.ts
CHANGED
|
@@ -14,18 +14,18 @@ import {
|
|
|
14
14
|
|
|
15
15
|
import type {
|
|
16
16
|
ApiProvider,
|
|
17
|
-
|
|
17
|
+
ProviderOptions,
|
|
18
18
|
ProviderFunction,
|
|
19
19
|
ProviderId,
|
|
20
|
-
|
|
20
|
+
ProviderOptionsMap,
|
|
21
21
|
} from './types';
|
|
22
22
|
|
|
23
23
|
export async function loadApiProviders(
|
|
24
24
|
providerPaths:
|
|
25
25
|
| ProviderId
|
|
26
26
|
| ProviderId[]
|
|
27
|
-
|
|
|
28
|
-
|
|
|
27
|
+
| ProviderOptionsMap[]
|
|
28
|
+
| ProviderOptions[]
|
|
29
29
|
| ProviderFunction,
|
|
30
30
|
basePath?: string,
|
|
31
31
|
): Promise<ApiProvider[]> {
|
|
@@ -50,11 +50,11 @@ export async function loadApiProviders(
|
|
|
50
50
|
};
|
|
51
51
|
} else if (provider.id) {
|
|
52
52
|
// List of ProviderConfig objects
|
|
53
|
-
return loadApiProvider((provider as
|
|
53
|
+
return loadApiProvider((provider as ProviderOptions).id!, provider, basePath);
|
|
54
54
|
} else {
|
|
55
55
|
// List of { id: string, config: ProviderConfig } objects
|
|
56
56
|
const id = Object.keys(provider)[0];
|
|
57
|
-
const providerObject = (provider as
|
|
57
|
+
const providerObject = (provider as ProviderOptionsMap)[id];
|
|
58
58
|
const context = { ...providerObject, id: providerObject.id || id };
|
|
59
59
|
return loadApiProvider(id, context, basePath);
|
|
60
60
|
}
|
|
@@ -66,9 +66,10 @@ export async function loadApiProviders(
|
|
|
66
66
|
|
|
67
67
|
export async function loadApiProvider(
|
|
68
68
|
providerPath: string,
|
|
69
|
-
context?:
|
|
69
|
+
context?: ProviderOptions,
|
|
70
70
|
basePath?: string,
|
|
71
71
|
): Promise<ApiProvider> {
|
|
72
|
+
context = context || {};
|
|
72
73
|
if (providerPath?.startsWith('exec:')) {
|
|
73
74
|
// Load script module
|
|
74
75
|
const scriptPath = providerPath.split(':')[1];
|
|
@@ -86,18 +87,18 @@ export async function loadApiProvider(
|
|
|
86
87
|
return new OpenAiChatCompletionProvider(
|
|
87
88
|
modelName || 'gpt-3.5-turbo',
|
|
88
89
|
undefined,
|
|
89
|
-
context
|
|
90
|
+
context.config,
|
|
90
91
|
);
|
|
91
92
|
} else if (modelType === 'completion') {
|
|
92
93
|
return new OpenAiCompletionProvider(
|
|
93
94
|
modelName || 'text-davinci-003',
|
|
94
95
|
undefined,
|
|
95
|
-
context
|
|
96
|
+
context.config,
|
|
96
97
|
);
|
|
97
98
|
} else if (OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.includes(modelType)) {
|
|
98
|
-
return new OpenAiChatCompletionProvider(modelType, undefined, context
|
|
99
|
+
return new OpenAiChatCompletionProvider(modelType, undefined, context.config, context.id);
|
|
99
100
|
} else if (OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelType)) {
|
|
100
|
-
return new OpenAiCompletionProvider(modelType, undefined, context
|
|
101
|
+
return new OpenAiCompletionProvider(modelType, undefined, context.config, context.id);
|
|
101
102
|
} else {
|
|
102
103
|
throw new Error(
|
|
103
104
|
`Unknown OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
|
|
@@ -113,15 +114,15 @@ export async function loadApiProvider(
|
|
|
113
114
|
return new AzureOpenAiChatCompletionProvider(
|
|
114
115
|
deploymentName,
|
|
115
116
|
undefined,
|
|
116
|
-
context
|
|
117
|
-
context
|
|
117
|
+
context.config,
|
|
118
|
+
context.id,
|
|
118
119
|
);
|
|
119
120
|
} else if (modelType === 'completion') {
|
|
120
121
|
return new AzureOpenAiCompletionProvider(
|
|
121
122
|
deploymentName,
|
|
122
123
|
undefined,
|
|
123
|
-
context
|
|
124
|
-
context
|
|
124
|
+
context.config,
|
|
125
|
+
context.id,
|
|
125
126
|
);
|
|
126
127
|
} else {
|
|
127
128
|
throw new Error(
|
|
@@ -138,10 +139,10 @@ export async function loadApiProvider(
|
|
|
138
139
|
return new AnthropicCompletionProvider(
|
|
139
140
|
modelName || 'claude-instant-1',
|
|
140
141
|
undefined,
|
|
141
|
-
context
|
|
142
|
+
context.config,
|
|
142
143
|
);
|
|
143
144
|
} else if (AnthropicCompletionProvider.ANTHROPIC_COMPLETION_MODELS.includes(modelType)) {
|
|
144
|
-
return new AnthropicCompletionProvider(modelType, undefined, context
|
|
145
|
+
return new AnthropicCompletionProvider(modelType, undefined, context.config);
|
|
145
146
|
} else {
|
|
146
147
|
throw new Error(
|
|
147
148
|
`Unknown Anthropic model type: ${modelType}. Use one of the following providers: anthropic:completion:<model name>`,
|
|
@@ -152,12 +153,12 @@ export async function loadApiProvider(
|
|
|
152
153
|
const options = providerPath.split(':');
|
|
153
154
|
const modelName = options.slice(1).join(':');
|
|
154
155
|
|
|
155
|
-
return new ReplicateProvider(modelName, undefined, context
|
|
156
|
+
return new ReplicateProvider(modelName, undefined, context.config);
|
|
156
157
|
}
|
|
157
158
|
|
|
158
159
|
if (providerPath === 'llama' || providerPath.startsWith('llama:')) {
|
|
159
160
|
const modelName = providerPath.split(':')[1];
|
|
160
|
-
return new LlamaProvider(modelName, context
|
|
161
|
+
return new LlamaProvider(modelName, context.config);
|
|
161
162
|
} else if (providerPath.startsWith('ollama:')) {
|
|
162
163
|
const modelName = providerPath.split(':')[1];
|
|
163
164
|
return new OllamaProvider(modelName);
|
package/src/types.ts
CHANGED
|
@@ -27,7 +27,7 @@ export interface CommandLineOptions {
|
|
|
27
27
|
promptSuffix?: string;
|
|
28
28
|
}
|
|
29
29
|
|
|
30
|
-
export interface
|
|
30
|
+
export interface ProviderOptions {
|
|
31
31
|
id?: ProviderId;
|
|
32
32
|
config?: any;
|
|
33
33
|
prompts?: string[]; // List of prompt display strings
|
|
@@ -177,6 +177,7 @@ export interface Assertion {
|
|
|
177
177
|
value?:
|
|
178
178
|
| string
|
|
179
179
|
| string[]
|
|
180
|
+
| object
|
|
180
181
|
| ((output: string, testCase: AtomicTestCase, assertion: Assertion) => Promise<GradingResult>);
|
|
181
182
|
|
|
182
183
|
// The threshold value, only applicable for similarity (cosine distance)
|
|
@@ -186,7 +187,7 @@ export interface Assertion {
|
|
|
186
187
|
weight?: number;
|
|
187
188
|
|
|
188
189
|
// Some assertions (similarity, llm-rubric) require an LLM provider
|
|
189
|
-
provider?:
|
|
190
|
+
provider?: GradingConfig['provider'];
|
|
190
191
|
}
|
|
191
192
|
|
|
192
193
|
// Each test case is graded pass/fail. A test case represents a unique input to the LLM after substituting `vars` in the prompt.
|
|
@@ -249,7 +250,7 @@ export type ProviderId = string;
|
|
|
249
250
|
|
|
250
251
|
export type ProviderFunction = (prompt: string) => Promise<ProviderResponse>;
|
|
251
252
|
|
|
252
|
-
export type
|
|
253
|
+
export type ProviderOptionsMap = Record<ProviderId, ProviderOptions>;
|
|
253
254
|
|
|
254
255
|
// TestSuiteConfig = Test Suite, but before everything is parsed and resolved. Providers are just strings, prompts are filepaths, tests can be filepath or inline.
|
|
255
256
|
export interface TestSuiteConfig {
|
|
@@ -257,7 +258,12 @@ export interface TestSuiteConfig {
|
|
|
257
258
|
description?: string;
|
|
258
259
|
|
|
259
260
|
// One or more LLM APIs to use, for example: openai:gpt-3.5-turbo, openai:gpt-4, localai:chat:vicuna
|
|
260
|
-
providers:
|
|
261
|
+
providers:
|
|
262
|
+
| ProviderId
|
|
263
|
+
| ProviderId[]
|
|
264
|
+
| ProviderOptionsMap[]
|
|
265
|
+
| ProviderOptions[]
|
|
266
|
+
| ProviderFunction;
|
|
261
267
|
|
|
262
268
|
// One or more prompt files to load
|
|
263
269
|
prompts: string | string[];
|
package/src/util.ts
CHANGED
|
@@ -24,7 +24,7 @@ import type {
|
|
|
24
24
|
UnifiedConfig,
|
|
25
25
|
TestCase,
|
|
26
26
|
Prompt,
|
|
27
|
-
|
|
27
|
+
ProviderOptionsMap,
|
|
28
28
|
TestSuite,
|
|
29
29
|
} from './types';
|
|
30
30
|
|
|
@@ -53,7 +53,7 @@ export function readProviderPromptMap(
|
|
|
53
53
|
|
|
54
54
|
for (const provider of config.providers) {
|
|
55
55
|
if (typeof provider === 'object') {
|
|
56
|
-
const rawProvider = provider as
|
|
56
|
+
const rawProvider = provider as ProviderOptionsMap;
|
|
57
57
|
const originalId = Object.keys(rawProvider)[0];
|
|
58
58
|
const providerObject = rawProvider[originalId];
|
|
59
59
|
const id = providerObject.id || originalId;
|
|
@@ -8,14 +8,14 @@ import {
|
|
|
8
8
|
DialogActions,
|
|
9
9
|
Button,
|
|
10
10
|
} from '@mui/material';
|
|
11
|
-
import {
|
|
11
|
+
import { ProviderOptions } from '../../../../../types';
|
|
12
12
|
|
|
13
13
|
interface ProviderConfigDialogProps {
|
|
14
14
|
open: boolean;
|
|
15
15
|
providerId: string;
|
|
16
|
-
config:
|
|
16
|
+
config: ProviderOptions['config'];
|
|
17
17
|
onClose: () => void;
|
|
18
|
-
onSave: (config:
|
|
18
|
+
onSave: (config: ProviderOptions['config']) => void;
|
|
19
19
|
}
|
|
20
20
|
|
|
21
21
|
const ProviderConfigDialog: React.FC<ProviderConfigDialogProps> = ({
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import React from 'react';
|
|
2
2
|
import { Autocomplete, Box, Chip, TextField } from '@mui/material';
|
|
3
|
-
import {
|
|
3
|
+
import { ProviderOptions } from '../../../../../types';
|
|
4
4
|
import ProviderConfigDialog from './ProviderConfigDialog';
|
|
5
5
|
|
|
6
|
-
const defaultProviders:
|
|
6
|
+
const defaultProviders: ProviderOptions[] = [
|
|
7
7
|
{
|
|
8
8
|
id: 'replicate:replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48',
|
|
9
9
|
config: { temperature: 0.5 },
|
|
@@ -48,38 +48,38 @@ const defaultProviders: ProviderConfig[] = [
|
|
|
48
48
|
.sort((a, b) => a.id.localeCompare(b.id));
|
|
49
49
|
|
|
50
50
|
interface ProviderSelectorProps {
|
|
51
|
-
providers:
|
|
52
|
-
onChange: (providers:
|
|
51
|
+
providers: ProviderOptions[];
|
|
52
|
+
onChange: (providers: ProviderOptions[]) => void;
|
|
53
53
|
}
|
|
54
54
|
|
|
55
55
|
const ProviderSelector: React.FC<ProviderSelectorProps> = ({ providers, onChange }) => {
|
|
56
|
-
const [selectedProvider, setSelectedProvider] = React.useState<
|
|
56
|
+
const [selectedProvider, setSelectedProvider] = React.useState<ProviderOptions | null>(null);
|
|
57
57
|
|
|
58
|
-
const getProviderLabel = (provider: string |
|
|
58
|
+
const getProviderLabel = (provider: string | ProviderOptions) => {
|
|
59
59
|
if (typeof provider === 'string') {
|
|
60
60
|
return provider;
|
|
61
61
|
}
|
|
62
62
|
return provider.id || 'Unknown provider';
|
|
63
63
|
};
|
|
64
64
|
|
|
65
|
-
const getProviderKey = (provider: string |
|
|
65
|
+
const getProviderKey = (provider: string | ProviderOptions, index: number) => {
|
|
66
66
|
if (typeof provider === 'string') {
|
|
67
67
|
return provider;
|
|
68
68
|
}
|
|
69
69
|
return provider.id || index;
|
|
70
70
|
};
|
|
71
71
|
|
|
72
|
-
const handleProviderClick = (provider: string |
|
|
72
|
+
const handleProviderClick = (provider: string | ProviderOptions) => {
|
|
73
73
|
if (typeof provider === 'string') {
|
|
74
74
|
alert('Cannot edit custom providers');
|
|
75
75
|
} else if (!provider.config) {
|
|
76
76
|
alert('There is no config for this provider');
|
|
77
77
|
} else {
|
|
78
|
-
setSelectedProvider(provider as
|
|
78
|
+
setSelectedProvider(provider as ProviderOptions);
|
|
79
79
|
}
|
|
80
80
|
};
|
|
81
81
|
|
|
82
|
-
const handleSave = (config:
|
|
82
|
+
const handleSave = (config: ProviderOptions['config']) => {
|
|
83
83
|
if (selectedProvider) {
|
|
84
84
|
const updatedProviders = providers.map((provider) =>
|
|
85
85
|
provider.id === selectedProvider.id ? { ...provider, config } : provider,
|
|
@@ -96,7 +96,7 @@ const ProviderSelector: React.FC<ProviderSelectorProps> = ({ providers, onChange
|
|
|
96
96
|
freeSolo
|
|
97
97
|
options={defaultProviders}
|
|
98
98
|
value={providers}
|
|
99
|
-
onChange={(event, newValue: (string |
|
|
99
|
+
onChange={(event, newValue: (string | ProviderOptions)[]) => {
|
|
100
100
|
onChange(newValue.map((value) => (typeof value === 'string' ? { id: value } : value)));
|
|
101
101
|
}}
|
|
102
102
|
getOptionLabel={(option) => {
|
|
@@ -106,7 +106,7 @@ const ProviderSelector: React.FC<ProviderSelectorProps> = ({ providers, onChange
|
|
|
106
106
|
if (typeof option === 'string') {
|
|
107
107
|
return option;
|
|
108
108
|
}
|
|
109
|
-
return (option as
|
|
109
|
+
return (option as ProviderOptions).id || 'Unknown provider';
|
|
110
110
|
}}
|
|
111
111
|
renderTags={(value, getTagProps) =>
|
|
112
112
|
value.map((provider, index: number) => {
|
|
@@ -1,18 +1,18 @@
|
|
|
1
1
|
import { create } from 'zustand';
|
|
2
2
|
import { persist } from 'zustand/middleware';
|
|
3
3
|
|
|
4
|
-
import type { Assertion,
|
|
4
|
+
import type { Assertion, ProviderOptions, TestCase } from '../../../../types';
|
|
5
5
|
|
|
6
6
|
export interface State {
|
|
7
7
|
asserts: Assertion[];
|
|
8
8
|
testCases: TestCase[];
|
|
9
9
|
description: string;
|
|
10
|
-
providers:
|
|
10
|
+
providers: ProviderOptions[];
|
|
11
11
|
prompts: string[];
|
|
12
12
|
setAsserts: (asserts: Assertion[]) => void;
|
|
13
13
|
setTestCases: (testCases: TestCase[]) => void;
|
|
14
14
|
setDescription: (description: string) => void;
|
|
15
|
-
setProviders: (providers:
|
|
15
|
+
setProviders: (providers: ProviderOptions[]) => void;
|
|
16
16
|
setPrompts: (prompts: string[]) => void;
|
|
17
17
|
}
|
|
18
18
|
|
|
File without changes
|
|
File without changes
|