@framers/agentos-ext-ml-classifiers 0.1.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +18 -0
- package/dist/MLClassifierGuardrail.d.ts +88 -117
- package/dist/MLClassifierGuardrail.d.ts.map +1 -1
- package/dist/MLClassifierGuardrail.js +255 -264
- package/dist/MLClassifierGuardrail.js.map +1 -1
- package/dist/classifiers/InjectionClassifier.d.ts +1 -1
- package/dist/classifiers/InjectionClassifier.d.ts.map +1 -1
- package/dist/classifiers/JailbreakClassifier.d.ts +1 -1
- package/dist/classifiers/JailbreakClassifier.d.ts.map +1 -1
- package/dist/classifiers/ToxicityClassifier.d.ts +1 -1
- package/dist/classifiers/ToxicityClassifier.d.ts.map +1 -1
- package/dist/classifiers/WorkerClassifierProxy.d.ts +1 -1
- package/dist/classifiers/WorkerClassifierProxy.d.ts.map +1 -1
- package/dist/index.d.ts +16 -90
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +33 -306
- package/dist/index.js.map +1 -1
- package/dist/keyword-classifier.d.ts +26 -0
- package/dist/keyword-classifier.d.ts.map +1 -0
- package/dist/keyword-classifier.js +113 -0
- package/dist/keyword-classifier.js.map +1 -0
- package/dist/llm-classifier.d.ts +27 -0
- package/dist/llm-classifier.d.ts.map +1 -0
- package/dist/llm-classifier.js +129 -0
- package/dist/llm-classifier.js.map +1 -0
- package/dist/tools/ClassifyContentTool.d.ts +53 -80
- package/dist/tools/ClassifyContentTool.d.ts.map +1 -1
- package/dist/tools/ClassifyContentTool.js +52 -103
- package/dist/tools/ClassifyContentTool.js.map +1 -1
- package/dist/types.d.ts +77 -277
- package/dist/types.d.ts.map +1 -1
- package/dist/types.js +9 -55
- package/dist/types.js.map +1 -1
- package/package.json +10 -16
- package/src/MLClassifierGuardrail.ts +279 -316
- package/src/index.ts +35 -339
- package/src/keyword-classifier.ts +130 -0
- package/src/llm-classifier.ts +163 -0
- package/src/tools/ClassifyContentTool.ts +75 -132
- package/src/types.ts +78 -325
- package/test/ClassifierOrchestrator.spec.ts +365 -0
- package/test/ClassifyContentTool.spec.ts +226 -0
- package/test/InjectionClassifier.spec.ts +263 -0
- package/test/JailbreakClassifier.spec.ts +295 -0
- package/test/MLClassifierGuardrail.spec.ts +486 -0
- package/test/SlidingWindowBuffer.spec.ts +391 -0
- package/test/ToxicityClassifier.spec.ts +268 -0
- package/test/WorkerClassifierProxy.spec.ts +303 -0
- package/test/index.spec.ts +431 -0
- package/tsconfig.json +20 -0
- package/vitest.config.ts +24 -0
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @fileoverview Unit tests for `SlidingWindowBuffer`.
|
|
3
|
+
*
|
|
4
|
+
* Tests verify:
|
|
5
|
+
* - Null return until chunkSize is reached
|
|
6
|
+
* - ChunkReady emission once chunkSize is reached
|
|
7
|
+
* - Context carry-forward between consecutive chunks
|
|
8
|
+
* - Multiple concurrent streams operate independently
|
|
9
|
+
* - maxEvaluations budget is respected
|
|
10
|
+
* - flush() returns remaining buffer content
|
|
11
|
+
* - flush() returns null for empty or unknown streams
|
|
12
|
+
* - pruneStale() removes expired streams
|
|
13
|
+
* - clear() removes all streams
|
|
14
|
+
*/
|
|
15
|
+
|
|
16
|
+
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
|
17
|
+
import {
|
|
18
|
+
SlidingWindowBuffer,
|
|
19
|
+
} from '../src/SlidingWindowBuffer';
|
|
20
|
+
|
|
21
|
+
// ---------------------------------------------------------------------------
|
|
22
|
+
// Helpers
|
|
23
|
+
// ---------------------------------------------------------------------------
|
|
24
|
+
|
|
25
|
+
/**
 * Produce a filler string that is exactly `charCount` characters long.
 *
 * The tests in this file assume the buffer estimates tokens as
 * ceil(length / 4), so a known character count yields a known token count.
 *
 * @param charCount - Number of characters in the returned string.
 * @returns The letter `a` repeated `charCount` times.
 */
function chars(charCount: number): string {
  const filler = 'a';
  return filler.repeat(charCount);
}
|
|
32
|
+
|
|
33
|
+
/**
 * Feed `tokenCount` estimated tokens' worth of filler text into `buf` for
 * the given stream, using the 4-characters-per-token convention these tests
 * rely on.
 *
 * @param buf - Buffer under test.
 * @param streamId - Stream the filler text is pushed to.
 * @param tokenCount - Number of estimated tokens to push.
 * @returns Whatever `buf.push` returns for this push.
 */
function pushTokens(
  buf: SlidingWindowBuffer,
  streamId: string,
  tokenCount: number,
) {
  const CHARS_PER_TOKEN = 4;
  const text = chars(tokenCount * CHARS_PER_TOKEN);
  return buf.push(streamId, text);
}
|
|
44
|
+
|
|
45
|
+
// ---------------------------------------------------------------------------
|
|
46
|
+
// Tests
|
|
47
|
+
// ---------------------------------------------------------------------------
|
|
48
|
+
|
|
49
|
+
describe('SlidingWindowBuffer', () => {
|
|
50
|
+
let buf: SlidingWindowBuffer;
|
|
51
|
+
|
|
52
|
+
beforeEach(() => {
|
|
53
|
+
// Fresh buffer for every test with small sizes to keep tests fast.
|
|
54
|
+
buf = new SlidingWindowBuffer({
|
|
55
|
+
chunkSize: 10, // 10 tokens = 40 chars before a chunk fires
|
|
56
|
+
contextSize: 5, // keep 5 tokens (20 chars) of context
|
|
57
|
+
maxEvaluations: 5,
|
|
58
|
+
streamTimeoutMs: 1000,
|
|
59
|
+
});
|
|
60
|
+
});
|
|
61
|
+
|
|
62
|
+
// -------------------------------------------------------------------------
|
|
63
|
+
// Basic accumulation
|
|
64
|
+
// -------------------------------------------------------------------------
|
|
65
|
+
|
|
66
|
+
describe('returns null until chunkSize is reached', () => {
|
|
67
|
+
it('returns null for a single short push', () => {
|
|
68
|
+
// Push 5 tokens (< chunkSize of 10)
|
|
69
|
+
expect(pushTokens(buf, 'stream-a', 5)).toBeNull();
|
|
70
|
+
});
|
|
71
|
+
|
|
72
|
+
it('returns null for multiple small pushes that do not reach chunkSize', () => {
|
|
73
|
+
expect(buf.push('stream-a', chars(8))).toBeNull(); // 2 tokens
|
|
74
|
+
expect(buf.push('stream-a', chars(8))).toBeNull(); // 2 tokens
|
|
75
|
+
expect(buf.push('stream-a', chars(8))).toBeNull(); // 2 tokens
|
|
76
|
+
// Total: 6 tokens — still below 10
|
|
77
|
+
});
|
|
78
|
+
});
|
|
79
|
+
|
|
80
|
+
describe('returns ChunkReady when chunkSize is reached', () => {
|
|
81
|
+
it('emits a chunk exactly when chunkSize tokens accumulate', () => {
|
|
82
|
+
// Push 9 tokens first — still under threshold
|
|
83
|
+
expect(pushTokens(buf, 'stream-a', 9)).toBeNull();
|
|
84
|
+
|
|
85
|
+
// Push 1 more token to hit chunkSize=10
|
|
86
|
+
const chunk = pushTokens(buf, 'stream-a', 1);
|
|
87
|
+
expect(chunk).not.toBeNull();
|
|
88
|
+
expect(chunk!.evaluationNumber).toBe(1);
|
|
89
|
+
});
|
|
90
|
+
|
|
91
|
+
it('chunk.newText equals exactly what was buffered', () => {
|
|
92
|
+
// Push exactly chunkSize tokens in one shot
|
|
93
|
+
const chunk = pushTokens(buf, 'stream-a', 10);
|
|
94
|
+
expect(chunk).not.toBeNull();
|
|
95
|
+
// newText should be the 40 chars we pushed
|
|
96
|
+
expect(chunk!.newText).toHaveLength(40);
|
|
97
|
+
});
|
|
98
|
+
|
|
99
|
+
it('chunk.text equals contextRing (empty on first chunk) + newText', () => {
|
|
100
|
+
const chunk = pushTokens(buf, 'stream-a', 10);
|
|
101
|
+
// No prior context on first chunk → text === newText
|
|
102
|
+
expect(chunk!.text).toBe(chunk!.newText);
|
|
103
|
+
});
|
|
104
|
+
|
|
105
|
+
it('evaluationNumber starts at 1', () => {
|
|
106
|
+
const chunk = pushTokens(buf, 'stream-a', 10);
|
|
107
|
+
expect(chunk!.evaluationNumber).toBe(1);
|
|
108
|
+
});
|
|
109
|
+
|
|
110
|
+
it('evaluationNumber increments on successive chunks', () => {
|
|
111
|
+
const c1 = pushTokens(buf, 'stream-a', 10);
|
|
112
|
+
const c2 = pushTokens(buf, 'stream-a', 10);
|
|
113
|
+
expect(c1!.evaluationNumber).toBe(1);
|
|
114
|
+
expect(c2!.evaluationNumber).toBe(2);
|
|
115
|
+
});
|
|
116
|
+
});
|
|
117
|
+
|
|
118
|
+
// -------------------------------------------------------------------------
|
|
119
|
+
// Context carry-forward
|
|
120
|
+
// -------------------------------------------------------------------------
|
|
121
|
+
|
|
122
|
+
describe('context carry-forward between chunks', () => {
|
|
123
|
+
it('second chunk text is longer than its newText (context prepended)', () => {
|
|
124
|
+
// Emit first chunk
|
|
125
|
+
pushTokens(buf, 'stream-a', 10);
|
|
126
|
+
|
|
127
|
+
// Emit second chunk
|
|
128
|
+
const c2 = pushTokens(buf, 'stream-a', 10);
|
|
129
|
+
expect(c2).not.toBeNull();
|
|
130
|
+
|
|
131
|
+
// text = contextRing (≤ contextSize=5 tokens = 20 chars) + newText (40 chars)
|
|
132
|
+
expect(c2!.text.length).toBeGreaterThan(c2!.newText.length);
|
|
133
|
+
});
|
|
134
|
+
|
|
135
|
+
it('context length is bounded by contextSize tokens', () => {
|
|
136
|
+
pushTokens(buf, 'stream-a', 10); // first chunk
|
|
137
|
+
|
|
138
|
+
const c2 = pushTokens(buf, 'stream-a', 10);
|
|
139
|
+
const contextLen = c2!.text.length - c2!.newText.length;
|
|
140
|
+
|
|
141
|
+
// contextSize=5 tokens ≈ 20 chars; context length must be ≤ 20
|
|
142
|
+
expect(contextLen).toBeLessThanOrEqual(5 * 4);
|
|
143
|
+
});
|
|
144
|
+
|
|
145
|
+
it('context is derived from the tail of the previous buffer', () => {
|
|
146
|
+
// Use distinct characters to trace which text makes it into the context.
|
|
147
|
+
// First window: 10 tokens of 'x'
|
|
148
|
+
buf.push('stream-a', 'x'.repeat(40));
|
|
149
|
+
|
|
150
|
+
// Second window: 10 tokens of 'y'
|
|
151
|
+
const c2 = buf.push('stream-a', 'y'.repeat(40));
|
|
152
|
+
expect(c2).not.toBeNull();
|
|
153
|
+
|
|
154
|
+
// The text should start with some 'x' context, then 'y' new content.
|
|
155
|
+
expect(c2!.text).toContain('x'); // context from first window
|
|
156
|
+
expect(c2!.newText).toBe('y'.repeat(40)); // only new content
|
|
157
|
+
});
|
|
158
|
+
});
|
|
159
|
+
|
|
160
|
+
// -------------------------------------------------------------------------
|
|
161
|
+
// Multiple concurrent streams
|
|
162
|
+
// -------------------------------------------------------------------------
|
|
163
|
+
|
|
164
|
+
describe('multiple concurrent streams are independent', () => {
|
|
165
|
+
it('two streams accumulate independently', () => {
|
|
166
|
+
// Push 8 tokens into stream-a, 10 tokens into stream-b
|
|
167
|
+
pushTokens(buf, 'stream-a', 8); // not ready
|
|
168
|
+
const chunkB = pushTokens(buf, 'stream-b', 10); // ready
|
|
169
|
+
|
|
170
|
+
expect(chunkB).not.toBeNull();
|
|
171
|
+
|
|
172
|
+
// stream-a should still be null (only 8 tokens)
|
|
173
|
+
const chunkA = buf.push('stream-a', ''); // empty push, no change
|
|
174
|
+
expect(chunkA).toBeNull();
|
|
175
|
+
});
|
|
176
|
+
|
|
177
|
+
it('flushing one stream does not affect another', () => {
|
|
178
|
+
pushTokens(buf, 'stream-a', 5); // partial
|
|
179
|
+
pushTokens(buf, 'stream-b', 5); // partial
|
|
180
|
+
|
|
181
|
+
buf.flush('stream-a');
|
|
182
|
+
|
|
183
|
+
// stream-b still exists and can be flushed separately
|
|
184
|
+
const chunkB = buf.flush('stream-b');
|
|
185
|
+
expect(chunkB).not.toBeNull();
|
|
186
|
+
expect(chunkB!.newText).toHaveLength(20); // 5 tokens * 4 chars
|
|
187
|
+
});
|
|
188
|
+
|
|
189
|
+
it('context rings are independent per stream', () => {
|
|
190
|
+
// Emit first chunk for each stream with distinct chars
|
|
191
|
+
buf.push('stream-a', 'A'.repeat(40));
|
|
192
|
+
buf.push('stream-b', 'B'.repeat(40));
|
|
193
|
+
|
|
194
|
+
const c2a = pushTokens(buf, 'stream-a', 10);
|
|
195
|
+
const c2b = pushTokens(buf, 'stream-b', 10);
|
|
196
|
+
|
|
197
|
+
// stream-a context should contain 'A', not 'B'
|
|
198
|
+
expect(c2a!.text).toContain('A');
|
|
199
|
+
expect(c2a!.text).not.toContain('B');
|
|
200
|
+
|
|
201
|
+
// stream-b context should contain 'B', not 'A'
|
|
202
|
+
expect(c2b!.text).toContain('B');
|
|
203
|
+
expect(c2b!.text).not.toContain('A');
|
|
204
|
+
});
|
|
205
|
+
});
|
|
206
|
+
|
|
207
|
+
// -------------------------------------------------------------------------
|
|
208
|
+
// Evaluation budget (maxEvaluations)
|
|
209
|
+
// -------------------------------------------------------------------------
|
|
210
|
+
|
|
211
|
+
describe('maxEvaluations budget', () => {
|
|
212
|
+
it('returns null after maxEvaluations chunks are emitted', () => {
|
|
213
|
+
// Emit exactly maxEvaluations=5 chunks
|
|
214
|
+
for (let i = 0; i < 5; i++) {
|
|
215
|
+
const chunk = pushTokens(buf, 'stream-a', 10);
|
|
216
|
+
expect(chunk).not.toBeNull();
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
// 6th push should return null even though enough tokens are pushed
|
|
220
|
+
const extra = pushTokens(buf, 'stream-a', 10);
|
|
221
|
+
expect(extra).toBeNull();
|
|
222
|
+
});
|
|
223
|
+
|
|
224
|
+
it('budget is tracked per-stream (other streams unaffected)', () => {
|
|
225
|
+
// Exhaust stream-a
|
|
226
|
+
for (let i = 0; i < 5; i++) {
|
|
227
|
+
pushTokens(buf, 'stream-a', 10);
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
// stream-b should still be able to emit chunks
|
|
231
|
+
const chunkB = pushTokens(buf, 'stream-b', 10);
|
|
232
|
+
expect(chunkB).not.toBeNull();
|
|
233
|
+
});
|
|
234
|
+
});
|
|
235
|
+
|
|
236
|
+
// -------------------------------------------------------------------------
|
|
237
|
+
// flush()
|
|
238
|
+
// -------------------------------------------------------------------------
|
|
239
|
+
|
|
240
|
+
describe('flush()', () => {
|
|
241
|
+
it('returns a ChunkReady for remaining buffered text', () => {
|
|
242
|
+
buf.push('stream-a', chars(20)); // 5 tokens, less than chunkSize=10
|
|
243
|
+
const chunk = buf.flush('stream-a');
|
|
244
|
+
|
|
245
|
+
expect(chunk).not.toBeNull();
|
|
246
|
+
expect(chunk!.newText).toHaveLength(20);
|
|
247
|
+
});
|
|
248
|
+
|
|
249
|
+
it('returned chunk includes context from prior windows', () => {
|
|
250
|
+
// Emit one full chunk first
|
|
251
|
+
pushTokens(buf, 'stream-a', 10);
|
|
252
|
+
|
|
253
|
+
// Push partial second window
|
|
254
|
+
buf.push('stream-a', chars(20)); // 5 tokens
|
|
255
|
+
|
|
256
|
+
const chunk = buf.flush('stream-a');
|
|
257
|
+
expect(chunk).not.toBeNull();
|
|
258
|
+
// text should be longer than newText (context prepended)
|
|
259
|
+
expect(chunk!.text.length).toBeGreaterThan(chunk!.newText.length);
|
|
260
|
+
});
|
|
261
|
+
|
|
262
|
+
it('returns null for empty buffer', () => {
|
|
263
|
+
// Push nothing, then flush
|
|
264
|
+
const chunk = buf.flush('stream-empty');
|
|
265
|
+
expect(chunk).toBeNull();
|
|
266
|
+
});
|
|
267
|
+
|
|
268
|
+
it('returns null for a non-existent stream', () => {
|
|
269
|
+
expect(buf.flush('does-not-exist')).toBeNull();
|
|
270
|
+
});
|
|
271
|
+
|
|
272
|
+
it('removes the stream from internal state after flush', () => {
|
|
273
|
+
buf.push('stream-a', chars(20));
|
|
274
|
+
buf.flush('stream-a');
|
|
275
|
+
|
|
276
|
+
// size should be 0 after flushing the only stream
|
|
277
|
+
expect(buf.size).toBe(0);
|
|
278
|
+
});
|
|
279
|
+
|
|
280
|
+
it('subsequent flush on same stream returns null', () => {
|
|
281
|
+
buf.push('stream-a', chars(20));
|
|
282
|
+
buf.flush('stream-a');
|
|
283
|
+
|
|
284
|
+
// Second flush: stream was deleted
|
|
285
|
+
expect(buf.flush('stream-a')).toBeNull();
|
|
286
|
+
});
|
|
287
|
+
});
|
|
288
|
+
|
|
289
|
+
// -------------------------------------------------------------------------
|
|
290
|
+
// pruneStale()
|
|
291
|
+
// -------------------------------------------------------------------------
|
|
292
|
+
|
|
293
|
+
describe('pruneStale()', () => {
|
|
294
|
+
it('removes streams that have exceeded streamTimeoutMs', async () => {
|
|
295
|
+
// Use a very short timeout for this test
|
|
296
|
+
const shortBuf = new SlidingWindowBuffer({
|
|
297
|
+
chunkSize: 10,
|
|
298
|
+
contextSize: 5,
|
|
299
|
+
maxEvaluations: 5,
|
|
300
|
+
streamTimeoutMs: 50, // 50 ms
|
|
301
|
+
});
|
|
302
|
+
|
|
303
|
+
shortBuf.push('old-stream', chars(20)); // partial push
|
|
304
|
+
|
|
305
|
+
// Wait for the timeout to expire
|
|
306
|
+
await new Promise((resolve) => setTimeout(resolve, 100));
|
|
307
|
+
|
|
308
|
+
shortBuf.pruneStale();
|
|
309
|
+
expect(shortBuf.size).toBe(0);
|
|
310
|
+
});
|
|
311
|
+
|
|
312
|
+
it('does not remove streams that are still within the timeout', async () => {
|
|
313
|
+
const shortBuf = new SlidingWindowBuffer({
|
|
314
|
+
chunkSize: 10,
|
|
315
|
+
contextSize: 5,
|
|
316
|
+
maxEvaluations: 5,
|
|
317
|
+
streamTimeoutMs: 5000,
|
|
318
|
+
});
|
|
319
|
+
|
|
320
|
+
shortBuf.push('fresh-stream', chars(20));
|
|
321
|
+
shortBuf.pruneStale();
|
|
322
|
+
|
|
323
|
+
// Should still be present
|
|
324
|
+
expect(shortBuf.size).toBe(1);
|
|
325
|
+
});
|
|
326
|
+
|
|
327
|
+
it('is invoked lazily when map.size > 10', async () => {
|
|
328
|
+
const shortBuf = new SlidingWindowBuffer({
|
|
329
|
+
chunkSize: 10,
|
|
330
|
+
contextSize: 5,
|
|
331
|
+
maxEvaluations: 5,
|
|
332
|
+
streamTimeoutMs: 1, // expire immediately
|
|
333
|
+
});
|
|
334
|
+
|
|
335
|
+
// Create 10 streams that will immediately be stale
|
|
336
|
+
for (let i = 0; i < 10; i++) {
|
|
337
|
+
shortBuf.push(`stale-${i}`, chars(4));
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
// Wait for timeout
|
|
341
|
+
await new Promise((resolve) => setTimeout(resolve, 20));
|
|
342
|
+
|
|
343
|
+
// Push to an 11th stream — this triggers lazy pruning
|
|
344
|
+
shortBuf.push('trigger-prune', chars(4));
|
|
345
|
+
|
|
346
|
+
// Stale streams should have been removed; only 'trigger-prune' remains
|
|
347
|
+
expect(shortBuf.size).toBe(1);
|
|
348
|
+
});
|
|
349
|
+
});
|
|
350
|
+
|
|
351
|
+
// -------------------------------------------------------------------------
|
|
352
|
+
// clear()
|
|
353
|
+
// -------------------------------------------------------------------------
|
|
354
|
+
|
|
355
|
+
describe('clear()', () => {
|
|
356
|
+
it('removes all streams', () => {
|
|
357
|
+
buf.push('s1', chars(20));
|
|
358
|
+
buf.push('s2', chars(20));
|
|
359
|
+
buf.push('s3', chars(20));
|
|
360
|
+
|
|
361
|
+
buf.clear();
|
|
362
|
+
expect(buf.size).toBe(0);
|
|
363
|
+
});
|
|
364
|
+
|
|
365
|
+
it('is idempotent on an empty buffer', () => {
|
|
366
|
+
buf.clear();
|
|
367
|
+
buf.clear();
|
|
368
|
+
expect(buf.size).toBe(0);
|
|
369
|
+
});
|
|
370
|
+
});
|
|
371
|
+
|
|
372
|
+
// -------------------------------------------------------------------------
|
|
373
|
+
// Edge cases
|
|
374
|
+
// -------------------------------------------------------------------------
|
|
375
|
+
|
|
376
|
+
describe('edge cases', () => {
|
|
377
|
+
it('push with empty string returns null and creates no state', () => {
|
|
378
|
+
expect(buf.push('stream-a', '')).toBeNull();
|
|
379
|
+
expect(buf.size).toBe(0);
|
|
380
|
+
});
|
|
381
|
+
|
|
382
|
+
it('handles a single massive push that exceeds chunkSize', () => {
|
|
383
|
+
// 20 tokens in one push — should still emit exactly one chunk
|
|
384
|
+
const chunk = pushTokens(buf, 'stream-a', 20);
|
|
385
|
+
expect(chunk).not.toBeNull();
|
|
386
|
+
// After the chunk, residual text (10 extra tokens) stays in buffer
|
|
387
|
+
// A second push of 0 tokens shouldn't fire a second chunk
|
|
388
|
+
expect(buf.push('stream-a', '')).toBeNull();
|
|
389
|
+
});
|
|
390
|
+
});
|
|
391
|
+
});
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
/**
 * @fileoverview Unit tests for {@link ToxicityClassifier}.
 *
 * All tests use a mocked {@link ISharedServiceRegistry} that returns a
 * pre-configured pipeline function. No real model weights are downloaded.
 *
 * Test coverage (numbered to match the sections below):
 *  1. Correct static identity: `id`, `displayName`, `modelId`
 *  2. `isLoaded` flag — false before any classify() call, true after a
 *     successful classification, still false after a model-load failure
 *  3. Maps pipeline output to ClassificationResult correctly
 *     (bestClass = highest-score label, confidence = its score, allScores = all labels)
 *  4. Graceful degradation — returns pass result when model fails to load
 *  5. Uses ISharedServiceRegistry with the correct service ID
 *     (getOrCreate on classify(), release on dispose())
 *  6. `ClassifierConfig.modelId` override still goes through the registry
 */
|
|
16
|
+
|
|
17
|
+
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
18
|
+
import type { ISharedServiceRegistry } from '@framers/agentos';
|
|
19
|
+
import { ToxicityClassifier } from '../src/classifiers/ToxicityClassifier';
|
|
20
|
+
import { ML_CLASSIFIER_SERVICE_IDS } from '../src/types';
|
|
21
|
+
|
|
22
|
+
// ---------------------------------------------------------------------------
|
|
23
|
+
// Test fixture helpers
|
|
24
|
+
// ---------------------------------------------------------------------------
|
|
25
|
+
|
|
26
|
+
/**
|
|
27
|
+
* Raw multi-label output that the `unitary/toxic-bert` pipeline would return
|
|
28
|
+
* when called with `{ topk: null }`. `toxic` is the winner at 0.92.
|
|
29
|
+
*/
|
|
30
|
+
const TOXICITY_PIPELINE_OUTPUT = [
|
|
31
|
+
{ label: 'toxic', score: 0.92 },
|
|
32
|
+
{ label: 'severe_toxic', score: 0.03 },
|
|
33
|
+
{ label: 'obscene', score: 0.45 },
|
|
34
|
+
{ label: 'threat', score: 0.02 },
|
|
35
|
+
{ label: 'insult', score: 0.61 },
|
|
36
|
+
{ label: 'identity_hate', score: 0.01 },
|
|
37
|
+
];
|
|
38
|
+
|
|
39
|
+
/**
 * Create a stub {@link ISharedServiceRegistry} for tests.
 *
 * Every `getOrCreate` call resolves to one shared mock pipeline, and that
 * pipeline resolves with `pipelineResult` whenever it is invoked.
 *
 * @param pipelineResult - The value the mock pipeline resolves with.
 * @returns A registry whose four methods are all `vi.fn` spies.
 */
function mockRegistry(pipelineResult: unknown): ISharedServiceRegistry {
  // Stand-in for the real pipeline, which the classifier invokes as pipeline(text, opts).
  const fakePipeline = vi.fn(async () => pipelineResult);

  return {
    has: vi.fn(() => false),
    // The factory argument is never run; the same mock pipeline is handed
    // back every time. Call arguments (including the service ID) are
    // recorded by the spy so tests can assert on them.
    getOrCreate: vi.fn(async () => fakePipeline),
    release: vi.fn(async () => {}),
    releaseAll: vi.fn(async () => {}),
  };
}
|
|
60
|
+
|
|
61
|
+
/**
 * Create a stub {@link ISharedServiceRegistry} that simulates a model-load
 * failure: every `getOrCreate` call rejects with
 * `Error('Model download failed')`.
 *
 * @returns A registry whose four methods are all `vi.fn` spies.
 */
function failingRegistry(): ISharedServiceRegistry {
  return {
    has: vi.fn(() => false),
    release: vi.fn(async () => {}),
    releaseAll: vi.fn(async () => {}),
    // Throwing inside an async function yields a rejected promise.
    getOrCreate: vi.fn(async () => {
      throw new Error('Model download failed');
    }),
  };
}
|
|
75
|
+
|
|
76
|
+
// ---------------------------------------------------------------------------
|
|
77
|
+
// Tests
|
|
78
|
+
// ---------------------------------------------------------------------------
|
|
79
|
+
|
|
80
|
+
describe('ToxicityClassifier', () => {
|
|
81
|
+
// -------------------------------------------------------------------------
|
|
82
|
+
// 1. Static identity
|
|
83
|
+
// -------------------------------------------------------------------------
|
|
84
|
+
|
|
85
|
+
describe('static identity', () => {
|
|
86
|
+
it('has the correct id', () => {
|
|
87
|
+
const classifier = new ToxicityClassifier(mockRegistry([]));
|
|
88
|
+
expect(classifier.id).toBe('toxicity');
|
|
89
|
+
});
|
|
90
|
+
|
|
91
|
+
it('has the correct displayName', () => {
|
|
92
|
+
const classifier = new ToxicityClassifier(mockRegistry([]));
|
|
93
|
+
expect(classifier.displayName).toBe('Toxicity Classifier');
|
|
94
|
+
});
|
|
95
|
+
|
|
96
|
+
it('has the correct default modelId', () => {
|
|
97
|
+
const classifier = new ToxicityClassifier(mockRegistry([]));
|
|
98
|
+
expect(classifier.modelId).toBe('unitary/toxic-bert');
|
|
99
|
+
});
|
|
100
|
+
});
|
|
101
|
+
|
|
102
|
+
// -------------------------------------------------------------------------
|
|
103
|
+
// 2. isLoaded flag
|
|
104
|
+
// -------------------------------------------------------------------------
|
|
105
|
+
|
|
106
|
+
describe('isLoaded flag', () => {
|
|
107
|
+
it('is false before any classify() call', () => {
|
|
108
|
+
const classifier = new ToxicityClassifier(mockRegistry(TOXICITY_PIPELINE_OUTPUT));
|
|
109
|
+
expect(classifier.isLoaded).toBe(false);
|
|
110
|
+
});
|
|
111
|
+
|
|
112
|
+
it('is true after a successful classify() call', async () => {
|
|
113
|
+
const classifier = new ToxicityClassifier(mockRegistry(TOXICITY_PIPELINE_OUTPUT));
|
|
114
|
+
await classifier.classify('some text');
|
|
115
|
+
expect(classifier.isLoaded).toBe(true);
|
|
116
|
+
});
|
|
117
|
+
|
|
118
|
+
it('remains false after a model-load failure', async () => {
|
|
119
|
+
const classifier = new ToxicityClassifier(failingRegistry());
|
|
120
|
+
await classifier.classify('some text');
|
|
121
|
+
expect(classifier.isLoaded).toBe(false);
|
|
122
|
+
});
|
|
123
|
+
});
|
|
124
|
+
|
|
125
|
+
// -------------------------------------------------------------------------
|
|
126
|
+
// 3. Result mapping
|
|
127
|
+
// -------------------------------------------------------------------------
|
|
128
|
+
|
|
129
|
+
describe('classify() — result mapping', () => {
|
|
130
|
+
let classifier: ToxicityClassifier;
|
|
131
|
+
|
|
132
|
+
beforeEach(() => {
|
|
133
|
+
classifier = new ToxicityClassifier(mockRegistry(TOXICITY_PIPELINE_OUTPUT));
|
|
134
|
+
});
|
|
135
|
+
|
|
136
|
+
it('resolves with the label with the highest score as bestClass', async () => {
|
|
137
|
+
const result = await classifier.classify('You are terrible!');
|
|
138
|
+
// toxic (0.92) beats insult (0.61) and obscene (0.45)
|
|
139
|
+
expect(result.bestClass).toBe('toxic');
|
|
140
|
+
});
|
|
141
|
+
|
|
142
|
+
it('resolves with the top score as confidence', async () => {
|
|
143
|
+
const result = await classifier.classify('You are terrible!');
|
|
144
|
+
expect(result.confidence).toBeCloseTo(0.92);
|
|
145
|
+
});
|
|
146
|
+
|
|
147
|
+
it('includes all six labels in allScores', async () => {
|
|
148
|
+
const result = await classifier.classify('You are terrible!');
|
|
149
|
+
expect(result.allScores).toHaveLength(6);
|
|
150
|
+
});
|
|
151
|
+
|
|
152
|
+
it('allScores contains correct classLabel/score pairs', async () => {
|
|
153
|
+
const result = await classifier.classify('You are terrible!');
|
|
154
|
+
// Spot-check a few entries
|
|
155
|
+
const toxic = result.allScores.find((s) => s.classLabel === 'toxic');
|
|
156
|
+
expect(toxic?.score).toBeCloseTo(0.92);
|
|
157
|
+
|
|
158
|
+
const threat = result.allScores.find((s) => s.classLabel === 'threat');
|
|
159
|
+
expect(threat?.score).toBeCloseTo(0.02);
|
|
160
|
+
});
|
|
161
|
+
|
|
162
|
+
it('returns bestClass=toxic for a message where toxic wins', async () => {
|
|
163
|
+
// Verify the classifier picks the maximum regardless of array order
|
|
164
|
+
const shuffled = [...TOXICITY_PIPELINE_OUTPUT].reverse();
|
|
165
|
+
const reg = mockRegistry(shuffled);
|
|
166
|
+
const cls = new ToxicityClassifier(reg);
|
|
167
|
+
const result = await cls.classify('test');
|
|
168
|
+
expect(result.bestClass).toBe('toxic');
|
|
169
|
+
});
|
|
170
|
+
});
|
|
171
|
+
|
|
172
|
+
// -------------------------------------------------------------------------
|
|
173
|
+
// 4. Graceful degradation
|
|
174
|
+
// -------------------------------------------------------------------------
|
|
175
|
+
|
|
176
|
+
describe('graceful degradation on model load failure', () => {
|
|
177
|
+
it('returns bestClass=benign when model fails to load', async () => {
|
|
178
|
+
const classifier = new ToxicityClassifier(failingRegistry());
|
|
179
|
+
const result = await classifier.classify('some text');
|
|
180
|
+
expect(result.bestClass).toBe('benign');
|
|
181
|
+
});
|
|
182
|
+
|
|
183
|
+
it('returns confidence=0 when model fails to load', async () => {
|
|
184
|
+
const classifier = new ToxicityClassifier(failingRegistry());
|
|
185
|
+
const result = await classifier.classify('some text');
|
|
186
|
+
expect(result.confidence).toBe(0);
|
|
187
|
+
});
|
|
188
|
+
|
|
189
|
+
it('returns empty allScores when model fails to load', async () => {
|
|
190
|
+
const classifier = new ToxicityClassifier(failingRegistry());
|
|
191
|
+
const result = await classifier.classify('some text');
|
|
192
|
+
expect(result.allScores).toEqual([]);
|
|
193
|
+
});
|
|
194
|
+
|
|
195
|
+
it('continues returning pass result on all subsequent calls after failure', async () => {
|
|
196
|
+
const classifier = new ToxicityClassifier(failingRegistry());
|
|
197
|
+
// First call triggers the failure
|
|
198
|
+
await classifier.classify('call 1');
|
|
199
|
+
// Subsequent calls should still return the pass result without retrying
|
|
200
|
+
const result = await classifier.classify('call 2');
|
|
201
|
+
expect(result.bestClass).toBe('benign');
|
|
202
|
+
});
|
|
203
|
+
|
|
204
|
+
it('does not retry getOrCreate after the first failure', async () => {
|
|
205
|
+
const registry = failingRegistry();
|
|
206
|
+
const classifier = new ToxicityClassifier(registry);
|
|
207
|
+
await classifier.classify('call 1');
|
|
208
|
+
await classifier.classify('call 2');
|
|
209
|
+
// getOrCreate should only have been called once (on the first classify call)
|
|
210
|
+
expect(registry.getOrCreate).toHaveBeenCalledTimes(1);
|
|
211
|
+
});
|
|
212
|
+
});
|
|
213
|
+
|
|
214
|
+
// -------------------------------------------------------------------------
|
|
215
|
+
// 5. Uses ISharedServiceRegistry with correct service ID
|
|
216
|
+
// -------------------------------------------------------------------------
|
|
217
|
+
|
|
218
|
+
describe('shared service registry integration', () => {
|
|
219
|
+
it('calls getOrCreate with the TOXICITY_PIPELINE service ID', async () => {
|
|
220
|
+
const registry = mockRegistry(TOXICITY_PIPELINE_OUTPUT);
|
|
221
|
+
const classifier = new ToxicityClassifier(registry);
|
|
222
|
+
await classifier.classify('hello');
|
|
223
|
+
expect(registry.getOrCreate).toHaveBeenCalledWith(
|
|
224
|
+
ML_CLASSIFIER_SERVICE_IDS.TOXICITY_PIPELINE,
|
|
225
|
+
expect.any(Function),
|
|
226
|
+
expect.objectContaining({ tags: expect.arrayContaining(['toxicity']) }),
|
|
227
|
+
);
|
|
228
|
+
});
|
|
229
|
+
|
|
230
|
+
it('calls getOrCreate once per classify() call and leaves caching to the registry', async () => {
  const registry = mockRegistry(TOXICITY_PIPELINE_OUTPUT);
  const classifier = new ToxicityClassifier(registry);
  await classifier.classify('first call');
  await classifier.classify('second call');
  // The classifier requests the pipeline from the registry on every
  // successful classify(); de-duplicating the underlying load is the
  // registry's responsibility. Two classify() calls therefore produce two
  // getOrCreate() calls against this mock.
  expect(registry.getOrCreate).toHaveBeenCalledTimes(2);
});
|
|
238
|
+
|
|
239
|
+
it('calls release with TOXICITY_PIPELINE service ID on dispose()', async () => {
|
|
240
|
+
const registry = mockRegistry(TOXICITY_PIPELINE_OUTPUT);
|
|
241
|
+
const classifier = new ToxicityClassifier(registry);
|
|
242
|
+
await classifier.classify('hello');
|
|
243
|
+
await classifier.dispose();
|
|
244
|
+
expect(registry.release).toHaveBeenCalledWith(
|
|
245
|
+
ML_CLASSIFIER_SERVICE_IDS.TOXICITY_PIPELINE,
|
|
246
|
+
);
|
|
247
|
+
});
|
|
248
|
+
});
|
|
249
|
+
|
|
250
|
+
// -------------------------------------------------------------------------
|
|
251
|
+
// 6. Config override
|
|
252
|
+
// -------------------------------------------------------------------------
|
|
253
|
+
|
|
254
|
+
describe('ClassifierConfig.modelId override', () => {
|
|
255
|
+
it('still resolves the pipeline through the registry when a custom modelId is configured', async () => {
  // The factory handed to getOrCreate is a closure over the configured
  // modelId, and this mock registry never runs it, so the model actually
  // requested cannot be observed here. Verifying that would need an
  // integration test with a real import. This test only confirms that
  // supplying a custom modelId does not bypass the registry.
  const registry = mockRegistry(TOXICITY_PIPELINE_OUTPUT);
  const classifier = new ToxicityClassifier(registry, {
    modelId: 'my-org/custom-toxic-bert',
  });
  await classifier.classify('hello');
  expect(registry.getOrCreate).toHaveBeenCalled();
});
|
|
267
|
+
});
|
|
268
|
+
});
|