kimi-vercel-ai-sdk-provider 0.4.0 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +426 -31
- package/dist/index.d.mts +1608 -2
- package/dist/index.d.ts +1608 -2
- package/dist/index.js +1949 -6
- package/dist/index.js.map +1 -1
- package/dist/index.mjs +1924 -5
- package/dist/index.mjs.map +1 -1
- package/package.json +1 -1
- package/src/__tests__/auto-detect.test.ts +140 -0
- package/src/__tests__/code-validation.test.ts +267 -0
- package/src/__tests__/ensemble.test.ts +242 -0
- package/src/__tests__/multi-agent.test.ts +201 -0
- package/src/__tests__/project-tools.test.ts +181 -0
- package/src/__tests__/tools.test.ts +1 -1
- package/src/chat/kimi-chat-settings.ts +15 -1
- package/src/code-validation/detector.ts +319 -0
- package/src/code-validation/index.ts +31 -0
- package/src/code-validation/types.ts +291 -0
- package/src/code-validation/validator.ts +547 -0
- package/src/core/errors.ts +91 -0
- package/src/core/index.ts +5 -0
- package/src/ensemble/index.ts +17 -0
- package/src/ensemble/multi-sampler.ts +433 -0
- package/src/ensemble/types.ts +279 -0
- package/src/index.ts +102 -3
- package/src/kimi-provider.ts +354 -1
- package/src/multi-agent/index.ts +21 -0
- package/src/multi-agent/types.ts +312 -0
- package/src/multi-agent/workflows.ts +539 -0
- package/src/project-tools/index.ts +16 -0
- package/src/project-tools/scaffolder.ts +494 -0
- package/src/project-tools/types.ts +244 -0
- package/src/tools/auto-detect.ts +276 -0
- package/src/tools/index.ts +6 -2
- package/src/tools/prepare-tools.ts +91 -2
package/package.json
CHANGED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
import { describe, expect, it } from 'vitest';
|
|
2
|
+
import { detectToolsFromPrompt, hasToolOptOut, shouldAutoEnableTools } from '../tools/auto-detect';
|
|
3
|
+
|
|
4
|
+
/**
 * Specs for the prompt heuristics imported from '../tools/auto-detect':
 *  - detectToolsFromPrompt: per-tool flag, confidence value and matched pattern names
 *  - shouldAutoEnableTools: plain boolean flags per tool
 *  - hasToolOptOut: flags set when the prompt explicitly declines a tool
 */
describe('Auto-detect Tools', () => {
  describe('detectToolsFromPrompt', () => {
    describe('web search detection', () => {
      it('detects direct search requests', () => {
        const { webSearch, webSearchConfidence, webSearchMatches } = detectToolsFromPrompt(
          'Search for the latest news about AI'
        );

        expect(webSearch).toBe(true);
        expect(webSearchConfidence).toBeGreaterThan(0.5);
        expect(webSearchMatches).toContain('direct_search');
      });

      it('detects current/latest information requests', () => {
        const { webSearch, webSearchConfidence } = detectToolsFromPrompt('What is the current price of Bitcoin?');

        expect(webSearch).toBe(true);
        expect(webSearchConfidence).toBeGreaterThan(0.7);
      });

      it('detects weather requests', () => {
        const { webSearch, webSearchMatches } = detectToolsFromPrompt('What is the weather in New York today?');

        expect(webSearch).toBe(true);
        expect(webSearchMatches).toContain('weather');
      });

      it('detects news requests', () => {
        const detection = detectToolsFromPrompt('What are the latest headlines about technology?');

        expect(detection.webSearch).toBe(true);
      });

      it('does not detect web search for general questions', () => {
        const { webSearch, webSearchConfidence } = detectToolsFromPrompt('What is the capital of France?');

        expect(webSearch).toBe(false);
        expect(webSearchConfidence).toBeLessThan(0.3);
      });
    });

    describe('code interpreter detection', () => {
      it('detects write code requests', () => {
        const { codeInterpreter, codeInterpreterConfidence, codeInterpreterMatches } = detectToolsFromPrompt(
          'Write a function to sort an array'
        );

        expect(codeInterpreter).toBe(true);
        expect(codeInterpreterConfidence).toBeGreaterThan(0.5);
        expect(codeInterpreterMatches).toContain('write_code');
      });

      it('detects calculation requests', () => {
        const { codeInterpreter, codeInterpreterMatches } = detectToolsFromPrompt('Calculate 5 factorial');

        expect(codeInterpreter).toBe(true);
        expect(codeInterpreterMatches).toContain('calculate');
      });

      it('detects debug requests', () => {
        const { codeInterpreter, codeInterpreterMatches } = detectToolsFromPrompt('Debug this code and find the error');

        expect(codeInterpreter).toBe(true);
        expect(codeInterpreterMatches).toContain('debug');
      });

      it('detects code execution requests', () => {
        const { codeInterpreter, codeInterpreterMatches } = detectToolsFromPrompt(
          'Run this script and show the output'
        );

        expect(codeInterpreter).toBe(true);
        expect(codeInterpreterMatches).toContain('run_code');
      });

      it('detects code blocks in prompt', () => {
        // The prompt text (including its line breaks and the fenced block) is the input under test,
        // so the template literal is kept flush-left.
        const promptWithFence = `
Please help me with this code:
\`\`\`javascript
const x = 5;
console.log(x);
\`\`\`
`;
        const detection = detectToolsFromPrompt(promptWithFence);

        expect(detection.codeInterpreter).toBe(true);
      });

      it('does not detect code interpreter for general questions', () => {
        const detection = detectToolsFromPrompt('Tell me a joke');

        expect(detection.codeInterpreter).toBe(false);
      });
    });

    describe('combined detection', () => {
      it('can detect both tools in one prompt', () => {
        const { webSearch, codeInterpreter } = detectToolsFromPrompt(
          'Search for the latest Python tutorials and write a function to parse JSON'
        );

        expect(webSearch).toBe(true);
        expect(codeInterpreter).toBe(true);
      });

      it('respects custom confidence threshold', () => {
        const strictOptions = { confidenceThreshold: 0.9 };
        const detection = detectToolsFromPrompt('What is the price of something?', strictOptions);

        // Price alone may not meet 0.9 threshold
        expect(detection.webSearchConfidence).toBeLessThan(0.9);
      });
    });
  });

  describe('shouldAutoEnableTools', () => {
    it('returns simple boolean flags', () => {
      const flags = shouldAutoEnableTools('Search for news about AI');

      expect(typeof flags.webSearch).toBe('boolean');
      expect(typeof flags.codeInterpreter).toBe('boolean');
      expect(flags.webSearch).toBe(true);
    });

    it('handles empty prompts', () => {
      const flags = shouldAutoEnableTools('');

      expect(flags.webSearch).toBe(false);
      expect(flags.codeInterpreter).toBe(false);
    });
  });

  describe('hasToolOptOut', () => {
    it('detects web search opt-out', () => {
      const optOut = hasToolOptOut("Don't search the web, just answer from memory");

      expect(optOut.webSearch).toBe(true);
      expect(optOut.codeInterpreter).toBe(false);
    });

    it('detects code execution opt-out', () => {
      const optOut = hasToolOptOut("Don't run or execute any code");

      expect(optOut.webSearch).toBe(false);
      expect(optOut.codeInterpreter).toBe(true);
    });

    it('detects both opt-outs', () => {
      const optOut = hasToolOptOut('Without searching online or running code, explain this concept');

      expect(optOut.webSearch).toBe(true);
      expect(optOut.codeInterpreter).toBe(true);
    });

    it('returns false when no opt-out', () => {
      const optOut = hasToolOptOut('Please help me with this task');

      expect(optOut.webSearch).toBe(false);
      expect(optOut.codeInterpreter).toBe(false);
    });
  });
});
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
import { describe, expect, it, vi } from 'vitest';
|
|
2
|
+
import {
|
|
3
|
+
containsCode,
|
|
4
|
+
detectLanguage,
|
|
5
|
+
extractCodeBlocks,
|
|
6
|
+
extractPrimaryCode,
|
|
7
|
+
getFileExtension
|
|
8
|
+
} from '../code-validation/detector';
|
|
9
|
+
import {
|
|
10
|
+
CodeValidator,
|
|
11
|
+
createFailedValidationResult,
|
|
12
|
+
createPassedValidationResult
|
|
13
|
+
} from '../code-validation/validator';
|
|
14
|
+
|
|
15
|
+
/**
 * Specs for the code-validation helpers:
 *  - '../code-validation/detector': detectLanguage, extractCodeBlocks, extractPrimaryCode,
 *    containsCode, getFileExtension
 *  - '../code-validation/validator': CodeValidator and the pass/fail result factories
 *
 * Fixture snippets inside template literals are kept flush-left on purpose: their exact text
 * (including line breaks) is the input handed to the function under test.
 */
describe('Code Validation', () => {
  describe('detectLanguage', () => {
    it('detects TypeScript', () => {
      const snippet = `
interface User {
name: string;
age: number;
}

const user: User = { name: 'John', age: 30 };
`;
      const { language, confidence, indicators } = detectLanguage(snippet);

      expect(language).toBe('typescript');
      expect(confidence).toBeGreaterThan(0.5);
      expect(indicators).toContain('interface');
    });

    it('detects JavaScript', () => {
      const snippet = `
const express = require('express');
const app = express();

app.get('/', (req, res) => {
res.send('Hello World!');
});

module.exports = app;
`;
      const { language, indicators } = detectLanguage(snippet);

      expect(language).toBe('javascript');
      expect(indicators).toContain('require');
    });

    it('detects Python', () => {
      const snippet = `
def hello_world():
print("Hello, World!")

if __name__ == "__main__":
hello_world()
`;
      const { language, indicators } = detectLanguage(snippet);

      expect(language).toBe('python');
      expect(indicators).toContain('def');
      expect(indicators).toContain('main_guard');
    });

    it('detects Java', () => {
      const snippet = `
public class HelloWorld {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}
`;
      const { language, indicators } = detectLanguage(snippet);

      expect(language).toBe('java');
      expect(indicators).toContain('main_method');
    });

    it('detects Go', () => {
      const snippet = `
package main

import "fmt"

func main() {
fmt.Println("Hello, World!")
}
`;
      const { language, indicators } = detectLanguage(snippet);

      expect(language).toBe('go');
      expect(indicators).toContain('package_main');
    });

    it('detects Rust', () => {
      const snippet = `
fn main() {
let mut x = 5;
println!("The value is: {}", x);
}
`;
      const { language, indicators } = detectLanguage(snippet);

      expect(language).toBe('rust');
      expect(indicators).toContain('fn');
      expect(indicators).toContain('let_mut');
    });
  });

  describe('extractCodeBlocks', () => {
    it('extracts fenced code blocks', () => {
      const markdown = `
Here is some code:

\`\`\`javascript
const x = 5;
console.log(x);
\`\`\`

And another:

\`\`\`python
print("hello")
\`\`\`
`;

      const { hasCode, blocks } = extractCodeBlocks(markdown);

      expect(hasCode).toBe(true);
      expect(blocks).toHaveLength(2);
      expect(blocks[0].language).toBe('javascript');
      expect(blocks[0].code).toContain('const x = 5');
      expect(blocks[1].language).toBe('python');
    });

    it('handles code blocks without language annotation', () => {
      const markdown = `
\`\`\`
some code here
\`\`\`
`;

      const { hasCode, blocks } = extractCodeBlocks(markdown);

      expect(hasCode).toBe(true);
      expect(blocks[0].language).toBeUndefined();
    });

    it('returns empty when no code found', () => {
      const { hasCode, blocks } = extractCodeBlocks('Just plain text without any code.');

      expect(hasCode).toBe(false);
      expect(blocks).toHaveLength(0);
    });
  });

  describe('extractPrimaryCode', () => {
    it('returns the largest code block', () => {
      const markdown = `
\`\`\`
small
\`\`\`

\`\`\`
this is a much larger code block
with multiple lines
of code
\`\`\`
`;

      const primary = extractPrimaryCode(markdown);

      expect(primary).toContain('much larger');
    });

    it('returns undefined for no code', () => {
      const primary = extractPrimaryCode('no code here');

      expect(primary).toBeUndefined();
    });
  });

  describe('containsCode', () => {
    it('returns true for code fences', () => {
      const fenced = '```\ncode\n```';
      expect(containsCode(fenced)).toBe(true);
    });

    it('returns true for code-like patterns', () => {
      const inlineFunction = 'function test() { return 42; }';
      expect(containsCode(inlineFunction)).toBe(true);
    });

    it('returns false for plain text', () => {
      const plainSentence = 'Hello, this is just text.';
      expect(containsCode(plainSentence)).toBe(false);
    });
  });

  describe('getFileExtension', () => {
    it('returns correct extensions', () => {
      // [language, expected extension] pairs, checked in this order.
      const expectedExtensions = [
        ['javascript', 'js'],
        ['typescript', 'ts'],
        ['python', 'py'],
        ['java', 'java'],
        ['go', 'go'],
        ['rust', 'rs'],
        ['auto', 'txt']
      ] as const;

      for (const [language, extension] of expectedExtensions) {
        expect(getFileExtension(language)).toBe(extension);
      }
    });
  });

  describe('CodeValidator', () => {
    /** Builds a vitest mock that resolves to `{ text: response }` on every call. */
    function createMockGenerator(response: string) {
      return vi.fn().mockResolvedValue({ text: response });
    }

    it('validates correct code', async () => {
      const generateText = createMockGenerator('{"errors": [], "warnings": []}');
      const validator = new CodeValidator({ generateText });

      const outcome = await validator.validate(
        'function add(a, b) { return a + b; }',
        { enabled: true, language: 'javascript' },
        ''
      );

      // Static analysis should pass for this simple function
      expect(outcome.attempts).toBe(1);
      expect(outcome.language).toBe('javascript');
    });

    it('detects syntax errors', async () => {
      const generateText = createMockGenerator(
        '{"errors": [{"message": "Missing bracket", "type": "syntax"}], "warnings": []}'
      );
      const validator = new CodeValidator({ generateText });

      const outcome = await validator.validate(
        'function test( { return 42; }', // Missing closing paren
        { enabled: true, language: 'javascript', maxAttempts: 1 },
        ''
      );

      expect(outcome.valid).toBe(false);
      expect(outcome.errors.length).toBeGreaterThan(0);
    });

    it('respects maxAttempts', async () => {
      // Counts how often the validator actually invokes the text generator.
      let generatorCalls = 0;
      const generateText = vi.fn().mockImplementation(async () => {
        generatorCalls++;
        return { text: '{"errors": [{"message": "error"}], "warnings": []}' };
      });
      const validator = new CodeValidator({ generateText });

      const outcome = await validator.validate('broken code {{{', { enabled: true, maxAttempts: 2 }, '');

      expect(outcome.attempts).toBeLessThanOrEqual(2);
      expect(generatorCalls).toBe(1);
    });
  });

  describe('utility functions', () => {
    it('createPassedValidationResult creates valid result', () => {
      const passed = createPassedValidationResult('code', 'javascript', 'output');

      expect(passed.valid).toBe(true);
      expect(passed.errors).toHaveLength(0);
      expect(passed.output).toBe('output');
    });

    it('createFailedValidationResult creates invalid result', () => {
      const failed = createFailedValidationResult('code', 'python', [
        { type: 'syntax', message: 'Error', severity: 'error' }
      ]);

      expect(failed.valid).toBe(false);
      expect(failed.errors).toHaveLength(1);
    });
  });
});
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
import type { EnsembleResponse } from '../ensemble/types';
|
|
2
|
+
import { describe, expect, it, vi } from 'vitest';
|
|
3
|
+
import { MultiSampler, createSingletonEnsembleResult } from '../ensemble/multi-sampler';
|
|
4
|
+
|
|
5
|
+
/**
 * Specs for the ensemble helpers imported from '../ensemble/multi-sampler':
 * MultiSampler.generate (sample fan-out, temperature variance, selection strategies,
 * failure handling, custom scoring) and createSingletonEnsembleResult.
 */
describe('Ensemble / Multi-Sampling', () => {
  describe('MultiSampler', () => {
    /**
     * Builds a mock generator that serves `responses` in call order (wrapping around when
     * called more often than there are entries). An entry with `error` set makes that call
     * throw instead of returning; successful calls echo the received `temperature` and
     * `sampleIndex` alongside fixed usage numbers.
     */
    const createMockGenerator = (responses: Array<{ text: string; error?: string }>) => {
      let callIndex = 0;
      return vi.fn().mockImplementation(async ({ temperature, sampleIndex }) => {
        const response = responses[callIndex % responses.length];
        callIndex++;

        if (response.error) {
          throw new Error(response.error);
        }

        return {
          text: response.text,
          usage: { promptTokens: 10, completionTokens: 20, totalTokens: 30 },
          finishReason: 'stop',
          temperature,
          sampleIndex
        };
      });
    };

    it('generates multiple samples', async () => {
      const sampler = new MultiSampler({ modelId: 'test-model' });
      const generator = createMockGenerator([{ text: 'Response 1' }, { text: 'Response 2' }, { text: 'Response 3' }]);

      const result = await sampler.generate(generator, {
        n: 3,
        selectionStrategy: 'first'
      });

      expect(generator).toHaveBeenCalledTimes(3);
      expect(result.text).toBe('Response 1');
      expect(result.metadata.nRequested).toBe(3);
      expect(result.metadata.nCompleted).toBe(3);
    });

    it('applies temperature variance', async () => {
      const sampler = new MultiSampler({ modelId: 'test-model', baseTemperature: 0.5 });
      const temperatures: number[] = [];

      // Record the temperature handed to each sample call.
      const generator = vi.fn().mockImplementation(async ({ temperature }) => {
        temperatures.push(temperature);
        return {
          text: 'Response',
          usage: { promptTokens: 10, completionTokens: 20, totalTokens: 30 },
          finishReason: 'stop'
        };
      });

      await sampler.generate(generator, {
        n: 3,
        temperatureVariance: 0.2,
        selectionStrategy: 'first'
      });

      // The temperatures are computed floating-point values, so compare with a tolerance
      // instead of exact equality (e.g. 0.7 + 0.2 !== 0.9 in IEEE-754 arithmetic).
      expect(temperatures[0]).toBeCloseTo(0.5, 10);
      expect(temperatures[1]).toBeCloseTo(0.7, 10);
      expect(temperatures[2]).toBeCloseTo(0.9, 10);
    });

    describe('selection strategies', () => {
      it('first strategy returns first response', async () => {
        const sampler = new MultiSampler({ modelId: 'test-model' });
        const generator = createMockGenerator([{ text: 'First' }, { text: 'Second' }, { text: 'Third' }]);

        const result = await sampler.generate(generator, {
          n: 3,
          selectionStrategy: 'first'
        });

        expect(result.text).toBe('First');
        expect(result.metadata.winningIndex).toBe(0);
      });

      it('best strategy with confidence heuristic prefers longer completions', async () => {
        const sampler = new MultiSampler({ modelId: 'test-model' });
        let callIndex = 0;

        const generator = vi.fn().mockImplementation(async () => {
          callIndex++;
          const completionTokens = callIndex === 2 ? 100 : 20; // Second response has most tokens
          return {
            text: `Response ${callIndex}`,
            usage: { promptTokens: 10, completionTokens, totalTokens: 10 + completionTokens },
            finishReason: 'stop'
          };
        });

        const result = await sampler.generate(generator, {
          n: 3,
          selectionStrategy: 'best',
          scoringHeuristic: 'confidence'
        });

        expect(result.metadata.winningIndex).toBe(1); // Second response wins
      });

      it('best strategy with code heuristic penalizes errors', async () => {
        const sampler = new MultiSampler({ modelId: 'test-model' });

        // Candidate texts, looked up by the sampleIndex passed to the generator.
        const texts = [
          'function test() { return SyntaxError; }', // Has error pattern
          'function test() { return 42; }', // Clean code
          'TypeError: undefined is not a function' // Has error
        ];

        const generator = vi.fn().mockImplementation(async ({ sampleIndex }) => {
          return {
            text: texts[sampleIndex],
            usage: { promptTokens: 10, completionTokens: 20, totalTokens: 30 },
            finishReason: 'stop'
          };
        });

        const result = await sampler.generate(generator, {
          n: 3,
          selectionStrategy: 'best',
          scoringHeuristic: 'code'
        });

        expect(result.metadata.winningIndex).toBe(1); // Clean code wins
      });

      it('all strategy returns alternatives', async () => {
        const sampler = new MultiSampler({ modelId: 'test-model' });
        const generator = createMockGenerator([{ text: 'First' }, { text: 'Second' }, { text: 'Third' }]);

        const result = await sampler.generate(generator, {
          n: 3,
          selectionStrategy: 'all'
        });

        expect(result.alternatives).toHaveLength(3);
        // One order-sensitive comparison instead of three non-null-asserted index reads.
        expect(result.alternatives?.map((alternative) => alternative.text)).toEqual(['First', 'Second', 'Third']);
      });
    });

    describe('error handling', () => {
      it('handles partial failures with allowPartialFailure', async () => {
        const sampler = new MultiSampler({ modelId: 'test-model' });
        // The second call throws; the other two succeed.
        const generator = createMockGenerator([
          { text: 'Success' },
          { text: '', error: 'Failed' },
          { text: 'Another success' }
        ]);

        const result = await sampler.generate(generator, {
          n: 3,
          selectionStrategy: 'first',
          allowPartialFailure: true
        });

        expect(result.metadata.nCompleted).toBe(2);
        expect(result.metadata.nFailed).toBe(1);
      });

      it('throws when all samples fail', async () => {
        const sampler = new MultiSampler({ modelId: 'test-model' });
        const generator = createMockGenerator([
          { text: '', error: 'Failed 1' },
          { text: '', error: 'Failed 2' },
          { text: '', error: 'Failed 3' }
        ]);

        await expect(
          sampler.generate(generator, {
            n: 3,
            selectionStrategy: 'first'
          })
        ).rejects.toThrow('All ensemble samples failed');
      });

      it('validates n parameter', async () => {
        const sampler = new MultiSampler({ modelId: 'test-model' });
        const generator = createMockGenerator([{ text: 'test' }]);

        // Both the lower (0) and upper (11) out-of-range values must be rejected.
        await expect(sampler.generate(generator, { n: 0, selectionStrategy: 'first' })).rejects.toThrow(
          'Ensemble n must be between 1 and 10'
        );

        await expect(sampler.generate(generator, { n: 11, selectionStrategy: 'first' })).rejects.toThrow(
          'Ensemble n must be between 1 and 10'
        );
      });
    });

    describe('custom scorer', () => {
      it('uses custom scorer when heuristic is custom', async () => {
        const sampler = new MultiSampler({ modelId: 'test-model' });

        const generator = vi.fn().mockImplementation(async ({ sampleIndex }) => {
          return {
            text: `Response ${sampleIndex}`,
            usage: { promptTokens: 10, completionTokens: 20, totalTokens: 30 },
            finishReason: 'stop'
          };
        });

        const customScorer = vi.fn().mockImplementation((response: EnsembleResponse) => {
          // Prefer response 2
          return response.sampleIndex === 2 ? 100 : 10;
        });

        const result = await sampler.generate(generator, {
          n: 3,
          selectionStrategy: 'best',
          scoringHeuristic: 'custom',
          customScorer
        });

        expect(customScorer).toHaveBeenCalledTimes(3);
        expect(result.metadata.winningIndex).toBe(2);
      });
    });
  });

  describe('createSingletonEnsembleResult', () => {
    it('creates a result for single response', () => {
      const response = {
        text: 'Hello',
        reasoning: 'Thinking...',
        usage: { promptTokens: 10, completionTokens: 20, totalTokens: 30 },
        finishReason: 'stop'
      };

      const result = createSingletonEnsembleResult(response, 'test-model', 100);

      expect(result.text).toBe('Hello');
      expect(result.reasoning).toBe('Thinking...');
      expect(result.metadata.nRequested).toBe(1);
      expect(result.metadata.nCompleted).toBe(1);
      expect(result.metadata.modelId).toBe('test-model');
      expect(result.metadata.durationMs).toBe(100);
    });
  });
});
|