@sparkleideas/agentdb-onnx 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/ARCHITECTURE.md +331 -0
- package/IMPLEMENTATION-SUMMARY.md +456 -0
- package/README.md +418 -0
- package/examples/complete-workflow.ts +281 -0
- package/package.json +41 -0
- package/src/benchmarks/benchmark-runner.ts +301 -0
- package/src/cli.ts +245 -0
- package/src/index.ts +128 -0
- package/src/services/ONNXEmbeddingService.ts +459 -0
- package/src/tests/integration.test.ts +302 -0
- package/src/tests/onnx-embedding.test.ts +317 -0
- package/tsconfig.json +19 -0
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Integration tests for AgentDB + ONNX
|
|
3
|
+
*
|
|
4
|
+
* Tests the integration between ONNXEmbeddingService and AgentDB controllers
|
|
5
|
+
*/
|
|
6
|
+
|
|
7
|
+
import { describe, it, expect, beforeAll, afterAll } from 'vitest';
|
|
8
|
+
import { createONNXAgentDB } from '../index.js';
|
|
9
|
+
|
|
10
|
+
// End-to-end tests exercising ONNXEmbeddingService through the AgentDB
// controllers (ReasoningBank, ReflexionMemory). The whole suite shares one
// in-memory database, so later tests observe data stored by earlier ones.
describe('AgentDB + ONNX Integration', () => {
  // Shared fixture built once in beforeAll; type is derived from the factory
  // so it tracks whatever createONNXAgentDB resolves to.
  let agentdb: Awaited<ReturnType<typeof createONNXAgentDB>>;

  beforeAll(async () => {
    agentdb = await createONNXAgentDB({
      dbPath: ':memory:', // Use in-memory database for tests
      modelName: 'Xenova/all-MiniLM-L6-v2',
      useGPU: false,
      batchSize: 4,
      cacheSize: 100
    });
  });

  afterAll(async () => {
    // Releases the database handle; the in-memory DB is discarded with it.
    await agentdb.close();
  });

  describe('ReasoningBank', () => {
    it('should store and retrieve patterns', async () => {
      const patternId = await agentdb.reasoningBank.storePattern({
        taskType: 'debugging',
        approach: 'Binary search through code execution',
        successRate: 0.92,
        tags: ['systematic', 'efficient']
      });

      // IDs are expected to be positive rowids.
      expect(patternId).toBeGreaterThan(0);

      const pattern = agentdb.reasoningBank.getPattern(patternId);
      expect(pattern).not.toBeNull();
      expect(pattern?.taskType).toBe('debugging');
      expect(pattern?.successRate).toBe(0.92);
    });

    it('should search for similar patterns with semantic matching', async () => {
      // Store multiple debugging patterns
      await agentdb.reasoningBank.storePattern({
        taskType: 'debugging',
        approach: 'Check logs first, then reproduce the issue systematically',
        successRate: 0.88
      });

      await agentdb.reasoningBank.storePattern({
        taskType: 'debugging',
        approach: 'Use debugger breakpoints and step through execution',
        successRate: 0.85
      });

      // One unrelated pattern, to check the query ranks debugging higher.
      await agentdb.reasoningBank.storePattern({
        taskType: 'optimization',
        approach: 'Profile before optimizing to identify bottlenecks',
        successRate: 0.95
      });

      // Search for debugging patterns using semantic query
      const results = await agentdb.reasoningBank.searchPatterns({
        task: 'how to debug code issues',
        k: 5,
        threshold: 0.5
      });

      expect(results.length).toBeGreaterThan(0);
      // Every hit must respect the similarity threshold and the [0, 1] range.
      results.forEach(r => {
        expect(r.similarity).toBeGreaterThanOrEqual(0.5);
        expect(r.similarity).toBeLessThanOrEqual(1.0);
      });

      // Most results should be debugging-related (semantic matching)
      const debuggingResults = results.filter(r => r.taskType === 'debugging');
      expect(debuggingResults.length).toBeGreaterThan(0);
    });

    it('should filter by task type', async () => {
      const results = await agentdb.reasoningBank.searchPatterns({
        task: 'approach for solving problems',
        k: 10,
        filters: { taskType: 'debugging' }
      });

      // The metadata filter must be applied on top of the semantic search.
      results.forEach(r => {
        expect(r.taskType).toBe('debugging');
      });
    });

    it('should use ONNX cache for repeated queries', async () => {
      // Embedding-count deltas between getStats() snapshots reveal whether a
      // query actually hit the model or was served from the cache.
      const stats1 = agentdb.embedder.getStats();

      // First search - will generate embedding
      await agentdb.reasoningBank.searchPatterns({
        task: 'test query for caching',
        k: 5
      });

      const stats2 = agentdb.embedder.getStats();
      const embeddings1 = stats2.totalEmbeddings - stats1.totalEmbeddings;

      // Second search - should use cache
      await agentdb.reasoningBank.searchPatterns({
        task: 'test query for caching',
        k: 5
      });

      const stats3 = agentdb.embedder.getStats();
      const embeddings2 = stats3.totalEmbeddings - stats2.totalEmbeddings;

      // Second query should use cache (0 new embeddings)
      expect(embeddings2).toBe(0);
      expect(embeddings1).toBeGreaterThan(0);
    });

    it('should delete patterns', async () => {
      const id = await agentdb.reasoningBank.storePattern({
        taskType: 'test',
        approach: 'temporary pattern',
        successRate: 0.5
      });

      // Verify it exists
      let pattern = agentdb.reasoningBank.getPattern(id);
      expect(pattern).not.toBeNull();

      // Delete it
      const deleted = agentdb.reasoningBank.deletePattern(id);
      expect(deleted).toBe(true);

      // Verify it's gone
      pattern = agentdb.reasoningBank.getPattern(id);
      expect(pattern).toBeNull();
    });

    it('should record outcomes for learning', async () => {
      const id = await agentdb.reasoningBank.storePattern({
        taskType: 'testing',
        approach: 'learning pattern',
        successRate: 0.5,
        uses: 0,
        avgReward: 0
      });

      // Record successful outcome
      await agentdb.reasoningBank.recordOutcome(id, true, 0.95);

      // Verify stats updated
      const pattern = agentdb.reasoningBank.getPattern(id);
      expect(pattern?.uses).toBe(1);
      expect(pattern?.avgReward).toBeGreaterThan(0);
    });
  });

  describe('ReflexionMemory', () => {
    it('should store and retrieve episodes', async () => {
      const episodeId = await agentdb.reflexionMemory.storeEpisode({
        sessionId: 'test-session-1',
        task: 'Debug memory leak in server',
        reward: 0.95,
        success: true,
        critique: 'Profiling helped identify the leak quickly'
      });

      expect(episodeId).toBeGreaterThan(0);
    });

    it('should retrieve relevant episodes', async () => {
      // Store multiple episodes
      await agentdb.reflexionMemory.storeEpisode({
        sessionId: 'session-2',
        task: 'Optimize database queries',
        reward: 0.88,
        success: true,
        critique: 'Adding indexes improved performance significantly'
      });

      await agentdb.reflexionMemory.storeEpisode({
        sessionId: 'session-2',
        task: 'Debug connection timeout',
        reward: 0.65,
        success: false,
        critique: 'Should have checked network logs first'
      });

      await agentdb.reflexionMemory.storeEpisode({
        sessionId: 'session-2',
        task: 'Fix API response time',
        reward: 0.92,
        success: true,
        critique: 'Caching strategy worked well'
      });

      // Retrieve episodes related to performance
      const results = await agentdb.reflexionMemory.retrieveRelevant({
        task: 'performance optimization',
        k: 5
      });

      expect(results.length).toBeGreaterThan(0);
      results.forEach(r => {
        expect(r.similarity).toBeGreaterThan(0);
      });
    });

    it('should filter by success', async () => {
      // Relies on episodes stored by the previous test (suite shares one DB).
      const successes = await agentdb.reflexionMemory.retrieveRelevant({
        task: 'debugging approach',
        k: 10,
        onlySuccesses: true
      });

      successes.forEach(r => {
        expect(r.success).toBe(true);
      });

      const failures = await agentdb.reflexionMemory.retrieveRelevant({
        task: 'debugging approach',
        k: 10,
        onlyFailures: true
      });

      failures.forEach(r => {
        expect(r.success).toBe(false);
      });
    });

    it('should get critique summary', async () => {
      const summary = await agentdb.reflexionMemory.getCritiqueSummary({
        task: 'debugging',
        k: 5
      });

      expect(typeof summary).toBe('string');
      // Summary should contain critique content if failures exist
    });

    it('should get success strategies', async () => {
      const strategies = await agentdb.reflexionMemory.getSuccessStrategies({
        task: 'optimization',
        k: 5
      });

      expect(typeof strategies).toBe('string');
      // Strategies should contain successful approach descriptions
    });
  });

  describe('Performance', () => {
    it('should have good cache hit rate', async () => {
      // Clear cache first
      agentdb.embedder.clearCache();

      // Generate some queries
      const queries = [
        'debug memory issue',
        'optimize performance',
        'debug memory issue', // Repeat
        'fix bug',
        'optimize performance' // Repeat
      ];

      for (const query of queries) {
        await agentdb.reasoningBank.searchPatterns({
          task: query,
          k: 3
        });
      }

      // 2 of 5 queries are repeats, so the hit rate should clear 20%.
      const stats = agentdb.embedder.getStats();
      expect(stats.cache.hitRate).toBeGreaterThan(0.2); // At least 20% hit rate
    });

    it('should maintain low latency with warmup', async () => {
      const stats = agentdb.embedder.getStats();

      // After warmup, average latency should be reasonable
      // NOTE(review): wall-clock threshold may be flaky on slow CI machines.
      expect(stats.avgLatency).toBeLessThan(200); // < 200ms average
      expect(stats.warmupComplete).toBe(true);
    });
  });

  describe('Statistics', () => {
    it('should provide comprehensive stats', async () => {
      const stats = agentdb.getStats();

      // Top-level shape: embedder metrics and database metrics.
      expect(stats).toHaveProperty('embedder');
      expect(stats).toHaveProperty('database');

      expect(stats.embedder).toHaveProperty('totalEmbeddings');
      expect(stats.embedder).toHaveProperty('avgLatency');
      expect(stats.embedder).toHaveProperty('cache');

      expect(stats.embedder.cache).toHaveProperty('hitRate');
      expect(stats.embedder.cache).toHaveProperty('size');
    });
  });
});
|
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Comprehensive tests for ONNX Embedding Service
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
import { describe, it, expect, beforeAll, afterAll } from 'vitest';
|
|
6
|
+
import { ONNXEmbeddingService } from '../services/ONNXEmbeddingService.js';
|
|
7
|
+
|
|
8
|
+
describe('ONNXEmbeddingService', () => {
|
|
9
|
+
let embedder: ONNXEmbeddingService;
|
|
10
|
+
|
|
11
|
+
beforeAll(async () => {
|
|
12
|
+
embedder = new ONNXEmbeddingService({
|
|
13
|
+
modelName: 'Xenova/all-MiniLM-L6-v2',
|
|
14
|
+
useGPU: false, // Use CPU for tests
|
|
15
|
+
batchSize: 4,
|
|
16
|
+
cacheSize: 100
|
|
17
|
+
});
|
|
18
|
+
await embedder.initialize();
|
|
19
|
+
});
|
|
20
|
+
|
|
21
|
+
afterAll(() => {
|
|
22
|
+
embedder.clearCache();
|
|
23
|
+
});
|
|
24
|
+
|
|
25
|
+
describe('Initialization', () => {
|
|
26
|
+
it('should initialize successfully', async () => {
|
|
27
|
+
const stats = embedder.getStats();
|
|
28
|
+
expect(stats.initialized).toBe(true);
|
|
29
|
+
expect(stats.model).toBe('Xenova/all-MiniLM-L6-v2');
|
|
30
|
+
});
|
|
31
|
+
|
|
32
|
+
it('should have correct dimension', () => {
|
|
33
|
+
const dimension = embedder.getDimension();
|
|
34
|
+
expect(dimension).toBe(384);
|
|
35
|
+
});
|
|
36
|
+
});
|
|
37
|
+
|
|
38
|
+
describe('Single Embedding', () => {
|
|
39
|
+
it('should generate embedding for single text', async () => {
|
|
40
|
+
const text = 'This is a test sentence';
|
|
41
|
+
const result = await embedder.embed(text);
|
|
42
|
+
|
|
43
|
+
expect(result.embedding).toBeInstanceOf(Float32Array);
|
|
44
|
+
expect(result.embedding.length).toBe(384);
|
|
45
|
+
expect(result.latency).toBeGreaterThan(0);
|
|
46
|
+
expect(result.cached).toBe(false);
|
|
47
|
+
expect(result.model).toBe('Xenova/all-MiniLM-L6-v2');
|
|
48
|
+
});
|
|
49
|
+
|
|
50
|
+
it('should return cached result for same text', async () => {
|
|
51
|
+
const text = 'Cached test sentence';
|
|
52
|
+
|
|
53
|
+
// First call
|
|
54
|
+
const result1 = await embedder.embed(text);
|
|
55
|
+
expect(result1.cached).toBe(false);
|
|
56
|
+
|
|
57
|
+
// Second call should be cached
|
|
58
|
+
const result2 = await embedder.embed(text);
|
|
59
|
+
expect(result2.cached).toBe(true);
|
|
60
|
+
expect(result2.latency).toBeLessThan(result1.latency);
|
|
61
|
+
});
|
|
62
|
+
|
|
63
|
+
it('should generate different embeddings for different texts', async () => {
|
|
64
|
+
const text1 = 'First sentence';
|
|
65
|
+
const text2 = 'Second sentence';
|
|
66
|
+
|
|
67
|
+
const result1 = await embedder.embed(text1);
|
|
68
|
+
const result2 = await embedder.embed(text2);
|
|
69
|
+
|
|
70
|
+
// Embeddings should be different
|
|
71
|
+
const areDifferent = Array.from(result1.embedding).some(
|
|
72
|
+
(val, i) => val !== result2.embedding[i]
|
|
73
|
+
);
|
|
74
|
+
expect(areDifferent).toBe(true);
|
|
75
|
+
});
|
|
76
|
+
|
|
77
|
+
it('should handle empty text', async () => {
|
|
78
|
+
const result = await embedder.embed('');
|
|
79
|
+
expect(result.embedding).toBeInstanceOf(Float32Array);
|
|
80
|
+
expect(result.embedding.length).toBe(384);
|
|
81
|
+
});
|
|
82
|
+
|
|
83
|
+
it('should handle very long text', async () => {
|
|
84
|
+
const longText = 'word '.repeat(1000);
|
|
85
|
+
const result = await embedder.embed(longText);
|
|
86
|
+
expect(result.embedding).toBeInstanceOf(Float32Array);
|
|
87
|
+
expect(result.embedding.length).toBe(384);
|
|
88
|
+
});
|
|
89
|
+
});
|
|
90
|
+
|
|
91
|
+
describe('Batch Embedding', () => {
|
|
92
|
+
it('should generate embeddings for batch', async () => {
|
|
93
|
+
const texts = [
|
|
94
|
+
'First text',
|
|
95
|
+
'Second text',
|
|
96
|
+
'Third text',
|
|
97
|
+
'Fourth text'
|
|
98
|
+
];
|
|
99
|
+
|
|
100
|
+
const result = await embedder.embedBatch(texts);
|
|
101
|
+
|
|
102
|
+
expect(result.embeddings).toHaveLength(4);
|
|
103
|
+
expect(result.total).toBe(4);
|
|
104
|
+
expect(result.cached).toBe(0);
|
|
105
|
+
expect(result.latency).toBeGreaterThan(0);
|
|
106
|
+
|
|
107
|
+
result.embeddings.forEach(emb => {
|
|
108
|
+
expect(emb).toBeInstanceOf(Float32Array);
|
|
109
|
+
expect(emb.length).toBe(384);
|
|
110
|
+
});
|
|
111
|
+
});
|
|
112
|
+
|
|
113
|
+
it('should use cache for batch processing', async () => {
|
|
114
|
+
const texts = ['Cached 1', 'Cached 2', 'Cached 3'];
|
|
115
|
+
|
|
116
|
+
// First batch
|
|
117
|
+
const result1 = await embedder.embedBatch(texts);
|
|
118
|
+
expect(result1.cached).toBe(0);
|
|
119
|
+
|
|
120
|
+
// Second batch (should all be cached)
|
|
121
|
+
const result2 = await embedder.embedBatch(texts);
|
|
122
|
+
expect(result2.cached).toBe(3);
|
|
123
|
+
expect(result2.latency).toBeLessThan(result1.latency);
|
|
124
|
+
});
|
|
125
|
+
|
|
126
|
+
it('should handle large batches', async () => {
|
|
127
|
+
const texts = Array.from({ length: 50 }, (_, i) => `Batch text ${i}`);
|
|
128
|
+
const result = await embedder.embedBatch(texts);
|
|
129
|
+
|
|
130
|
+
expect(result.embeddings).toHaveLength(50);
|
|
131
|
+
expect(result.total).toBe(50);
|
|
132
|
+
});
|
|
133
|
+
|
|
134
|
+
it('should handle empty batch', async () => {
|
|
135
|
+
const result = await embedder.embedBatch([]);
|
|
136
|
+
expect(result.embeddings).toHaveLength(0);
|
|
137
|
+
expect(result.total).toBe(0);
|
|
138
|
+
});
|
|
139
|
+
});
|
|
140
|
+
|
|
141
|
+
describe('Performance', () => {
|
|
142
|
+
it('should generate embedding quickly', async () => {
|
|
143
|
+
const startTime = Date.now();
|
|
144
|
+
await embedder.embed('Performance test');
|
|
145
|
+
const latency = Date.now() - startTime;
|
|
146
|
+
|
|
147
|
+
// Should be under 1 second for first run
|
|
148
|
+
expect(latency).toBeLessThan(1000);
|
|
149
|
+
});
|
|
150
|
+
|
|
151
|
+
it('should be fast with cache', async () => {
|
|
152
|
+
const text = 'Cache speed test';
|
|
153
|
+
|
|
154
|
+
// Warm up cache
|
|
155
|
+
await embedder.embed(text);
|
|
156
|
+
|
|
157
|
+
// Measure cached access
|
|
158
|
+
const startTime = Date.now();
|
|
159
|
+
await embedder.embed(text);
|
|
160
|
+
const latency = Date.now() - startTime;
|
|
161
|
+
|
|
162
|
+
// Cached access should be < 10ms
|
|
163
|
+
expect(latency).toBeLessThan(10);
|
|
164
|
+
});
|
|
165
|
+
|
|
166
|
+
it('should show performance improvement with warmup', async () => {
|
|
167
|
+
const newEmbedder = new ONNXEmbeddingService({
|
|
168
|
+
modelName: 'Xenova/all-MiniLM-L6-v2',
|
|
169
|
+
useGPU: false
|
|
170
|
+
});
|
|
171
|
+
await newEmbedder.initialize();
|
|
172
|
+
|
|
173
|
+
// Before warmup
|
|
174
|
+
const start1 = Date.now();
|
|
175
|
+
await newEmbedder.embed('Test before warmup');
|
|
176
|
+
const beforeWarmup = Date.now() - start1;
|
|
177
|
+
|
|
178
|
+
// Warmup
|
|
179
|
+
await newEmbedder.warmup(5);
|
|
180
|
+
|
|
181
|
+
// After warmup
|
|
182
|
+
const start2 = Date.now();
|
|
183
|
+
await newEmbedder.embed('Test after warmup');
|
|
184
|
+
const afterWarmup = Date.now() - start2;
|
|
185
|
+
|
|
186
|
+
// Warmup should improve performance
|
|
187
|
+
expect(newEmbedder.getStats().warmupComplete).toBe(true);
|
|
188
|
+
});
|
|
189
|
+
});
|
|
190
|
+
|
|
191
|
+
describe('Cache Management', () => {
|
|
192
|
+
it('should respect cache size limit', async () => {
|
|
193
|
+
const smallCache = new ONNXEmbeddingService({
|
|
194
|
+
modelName: 'Xenova/all-MiniLM-L6-v2',
|
|
195
|
+
useGPU: false,
|
|
196
|
+
cacheSize: 5
|
|
197
|
+
});
|
|
198
|
+
await smallCache.initialize();
|
|
199
|
+
|
|
200
|
+
// Add 10 items (should evict 5)
|
|
201
|
+
for (let i = 0; i < 10; i++) {
|
|
202
|
+
await smallCache.embed(`Cache test ${i}`);
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
const stats = smallCache.getStats();
|
|
206
|
+
expect(stats.cache.size).toBeLessThanOrEqual(5);
|
|
207
|
+
});
|
|
208
|
+
|
|
209
|
+
it('should clear cache', async () => {
|
|
210
|
+
await embedder.embed('Test 1');
|
|
211
|
+
await embedder.embed('Test 2');
|
|
212
|
+
|
|
213
|
+
let stats = embedder.getStats();
|
|
214
|
+
expect(stats.cache.size).toBeGreaterThan(0);
|
|
215
|
+
|
|
216
|
+
embedder.clearCache();
|
|
217
|
+
|
|
218
|
+
stats = embedder.getStats();
|
|
219
|
+
expect(stats.cache.size).toBe(0);
|
|
220
|
+
});
|
|
221
|
+
|
|
222
|
+
it('should track cache hit rate', async () => {
|
|
223
|
+
embedder.clearCache();
|
|
224
|
+
|
|
225
|
+
// Generate some embeddings
|
|
226
|
+
await embedder.embed('Hit rate test 1');
|
|
227
|
+
await embedder.embed('Hit rate test 2');
|
|
228
|
+
|
|
229
|
+
// Access cached items
|
|
230
|
+
await embedder.embed('Hit rate test 1');
|
|
231
|
+
await embedder.embed('Hit rate test 2');
|
|
232
|
+
|
|
233
|
+
const stats = embedder.getStats();
|
|
234
|
+
expect(stats.cache.hitRate).toBeGreaterThan(0);
|
|
235
|
+
});
|
|
236
|
+
});
|
|
237
|
+
|
|
238
|
+
describe('Statistics', () => {
|
|
239
|
+
it('should track total embeddings', async () => {
|
|
240
|
+
const initialStats = embedder.getStats();
|
|
241
|
+
const initialCount = initialStats.totalEmbeddings;
|
|
242
|
+
|
|
243
|
+
await embedder.embed('Stats test 1');
|
|
244
|
+
await embedder.embed('Stats test 2');
|
|
245
|
+
|
|
246
|
+
const newStats = embedder.getStats();
|
|
247
|
+
expect(newStats.totalEmbeddings).toBe(initialCount + 2);
|
|
248
|
+
});
|
|
249
|
+
|
|
250
|
+
it('should track average latency', async () => {
|
|
251
|
+
const stats = embedder.getStats();
|
|
252
|
+
expect(stats.avgLatency).toBeGreaterThan(0);
|
|
253
|
+
});
|
|
254
|
+
|
|
255
|
+
it('should track batch sizes', async () => {
|
|
256
|
+
await embedder.embedBatch(['Batch 1', 'Batch 2', 'Batch 3']);
|
|
257
|
+
const stats = embedder.getStats();
|
|
258
|
+
expect(stats.avgBatchSize).toBeGreaterThan(0);
|
|
259
|
+
});
|
|
260
|
+
});
|
|
261
|
+
|
|
262
|
+
describe('Error Handling', () => {
|
|
263
|
+
it('should throw error if not initialized', async () => {
|
|
264
|
+
const uninitialized = new ONNXEmbeddingService({
|
|
265
|
+
modelName: 'Xenova/all-MiniLM-L6-v2'
|
|
266
|
+
});
|
|
267
|
+
|
|
268
|
+
await expect(uninitialized.embed('Test')).rejects.toThrow('not initialized');
|
|
269
|
+
});
|
|
270
|
+
});
|
|
271
|
+
|
|
272
|
+
describe('Similarity', () => {
|
|
273
|
+
it('should generate similar embeddings for similar texts', async () => {
|
|
274
|
+
const text1 = 'The cat sits on the mat';
|
|
275
|
+
const text2 = 'A cat is sitting on a mat';
|
|
276
|
+
|
|
277
|
+
const result1 = await embedder.embed(text1);
|
|
278
|
+
const result2 = await embedder.embed(text2);
|
|
279
|
+
|
|
280
|
+
// Calculate cosine similarity
|
|
281
|
+
const similarity = cosineSimilarity(result1.embedding, result2.embedding);
|
|
282
|
+
|
|
283
|
+
// Similar texts should have high similarity (>0.7)
|
|
284
|
+
expect(similarity).toBeGreaterThan(0.7);
|
|
285
|
+
});
|
|
286
|
+
|
|
287
|
+
it('should generate dissimilar embeddings for different texts', async () => {
|
|
288
|
+
const text1 = 'The weather is sunny today';
|
|
289
|
+
const text2 = 'Quantum physics is fascinating';
|
|
290
|
+
|
|
291
|
+
const result1 = await embedder.embed(text1);
|
|
292
|
+
const result2 = await embedder.embed(text2);
|
|
293
|
+
|
|
294
|
+
const similarity = cosineSimilarity(result1.embedding, result2.embedding);
|
|
295
|
+
|
|
296
|
+
// Different texts should have lower similarity (<0.7)
|
|
297
|
+
expect(similarity).toBeLessThan(0.7);
|
|
298
|
+
});
|
|
299
|
+
});
|
|
300
|
+
});
|
|
301
|
+
|
|
302
|
+
/**
|
|
303
|
+
* Helper: Calculate cosine similarity
|
|
304
|
+
*/
|
|
305
|
+
function cosineSimilarity(a: Float32Array, b: Float32Array): number {
|
|
306
|
+
let dotProduct = 0;
|
|
307
|
+
let normA = 0;
|
|
308
|
+
let normB = 0;
|
|
309
|
+
|
|
310
|
+
for (let i = 0; i < a.length; i++) {
|
|
311
|
+
dotProduct += a[i] * b[i];
|
|
312
|
+
normA += a[i] * a[i];
|
|
313
|
+
normB += b[i] * b[i];
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
|
|
317
|
+
}
|
package/tsconfig.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
{
|
|
2
|
+
"compilerOptions": {
|
|
3
|
+
"target": "ES2022",
|
|
4
|
+
"module": "ES2022",
|
|
5
|
+
"lib": ["ES2022"],
|
|
6
|
+
"moduleResolution": "node",
|
|
7
|
+
"esModuleInterop": true,
|
|
8
|
+
"resolveJsonModule": true,
|
|
9
|
+
"declaration": true,
|
|
10
|
+
"outDir": "./dist",
|
|
11
|
+
"rootDir": "./src",
|
|
12
|
+
"strict": false,
|
|
13
|
+
"skipLibCheck": true,
|
|
14
|
+
"forceConsistentCasingInFileNames": true,
|
|
15
|
+
"allowSyntheticDefaultImports": true
|
|
16
|
+
},
|
|
17
|
+
"include": ["src/**/*"],
|
|
18
|
+
"exclude": ["node_modules", "dist", "**/*.test.ts"]
|
|
19
|
+
}
|