@framers/agentos-ext-ml-classifiers 0.1.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.github/workflows/ci.yml +20 -0
- package/.github/workflows/release.yml +37 -0
- package/.releaserc.json +9 -0
- package/LICENSE +96 -21
- package/README.md +72 -0
- package/dist/MLClassifierGuardrail.d.ts +88 -117
- package/dist/MLClassifierGuardrail.d.ts.map +1 -1
- package/dist/MLClassifierGuardrail.js +263 -264
- package/dist/MLClassifierGuardrail.js.map +1 -1
- package/dist/index.d.ts +16 -90
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +36 -309
- package/dist/index.js.map +1 -1
- package/dist/keyword-classifier.d.ts +26 -0
- package/dist/keyword-classifier.d.ts.map +1 -0
- package/dist/keyword-classifier.js +113 -0
- package/dist/keyword-classifier.js.map +1 -0
- package/dist/llm-classifier.d.ts +27 -0
- package/dist/llm-classifier.d.ts.map +1 -0
- package/dist/llm-classifier.js +129 -0
- package/dist/llm-classifier.js.map +1 -0
- package/dist/tools/ClassifyContentTool.d.ts +53 -80
- package/dist/tools/ClassifyContentTool.d.ts.map +1 -1
- package/dist/tools/ClassifyContentTool.js +52 -103
- package/dist/tools/ClassifyContentTool.js.map +1 -1
- package/dist/types.d.ts +77 -277
- package/dist/types.d.ts.map +1 -1
- package/dist/types.js +9 -55
- package/dist/types.js.map +1 -1
- package/package.json +10 -24
- package/scripts/fix-esm-imports.mjs +181 -0
- package/src/MLClassifierGuardrail.ts +306 -310
- package/src/index.ts +35 -339
- package/src/keyword-classifier.ts +130 -0
- package/src/llm-classifier.ts +163 -0
- package/src/tools/ClassifyContentTool.ts +75 -132
- package/src/types.ts +78 -325
- package/test/llm-tier.spec.ts +267 -0
- package/test/ml-classifiers.spec.ts +57 -0
- package/test/onnx-tier.spec.ts +255 -0
- package/test/tier-fallthrough.spec.ts +185 -0
- package/tsconfig.json +20 -0
- package/vitest.config.ts +35 -0
- package/dist/ClassifierOrchestrator.d.ts +0 -126
- package/dist/ClassifierOrchestrator.d.ts.map +0 -1
- package/dist/ClassifierOrchestrator.js +0 -239
- package/dist/ClassifierOrchestrator.js.map +0 -1
- package/dist/IContentClassifier.d.ts +0 -117
- package/dist/IContentClassifier.d.ts.map +0 -1
- package/dist/IContentClassifier.js +0 -22
- package/dist/IContentClassifier.js.map +0 -1
- package/dist/SlidingWindowBuffer.d.ts +0 -213
- package/dist/SlidingWindowBuffer.d.ts.map +0 -1
- package/dist/SlidingWindowBuffer.js +0 -246
- package/dist/SlidingWindowBuffer.js.map +0 -1
- package/dist/classifiers/InjectionClassifier.d.ts +0 -126
- package/dist/classifiers/InjectionClassifier.d.ts.map +0 -1
- package/dist/classifiers/InjectionClassifier.js +0 -210
- package/dist/classifiers/InjectionClassifier.js.map +0 -1
- package/dist/classifiers/JailbreakClassifier.d.ts +0 -124
- package/dist/classifiers/JailbreakClassifier.d.ts.map +0 -1
- package/dist/classifiers/JailbreakClassifier.js +0 -208
- package/dist/classifiers/JailbreakClassifier.js.map +0 -1
- package/dist/classifiers/ToxicityClassifier.d.ts +0 -125
- package/dist/classifiers/ToxicityClassifier.d.ts.map +0 -1
- package/dist/classifiers/ToxicityClassifier.js +0 -212
- package/dist/classifiers/ToxicityClassifier.js.map +0 -1
- package/dist/classifiers/WorkerClassifierProxy.d.ts +0 -158
- package/dist/classifiers/WorkerClassifierProxy.d.ts.map +0 -1
- package/dist/classifiers/WorkerClassifierProxy.js +0 -268
- package/dist/classifiers/WorkerClassifierProxy.js.map +0 -1
- package/dist/worker/classifier-worker.d.ts +0 -49
- package/dist/worker/classifier-worker.d.ts.map +0 -1
- package/dist/worker/classifier-worker.js +0 -180
- package/dist/worker/classifier-worker.js.map +0 -1
- package/src/ClassifierOrchestrator.ts +0 -290
- package/src/IContentClassifier.ts +0 -124
- package/src/SlidingWindowBuffer.ts +0 -384
- package/src/classifiers/InjectionClassifier.ts +0 -261
- package/src/classifiers/JailbreakClassifier.ts +0 -259
- package/src/classifiers/ToxicityClassifier.ts +0 -263
- package/src/classifiers/WorkerClassifierProxy.ts +0 -366
- package/src/worker/classifier-worker.ts +0 -267
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file llm-tier.spec.ts
|
|
3
|
+
* @description Tests for the LLM-as-judge (Tier 2) classification path.
|
|
4
|
+
*
|
|
5
|
+
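* Exercises `classifyByLlm()` directly — verifying that the structured
* classification prompt is sent to the invoker, that JSON and markdown-wrapped
* JSON are parsed correctly, and that failures produce zero-confidence scores.
* It also checks that `MLClassifierGuardrail.classify()` falls through to the
* LLM tier when the ONNX runtime cannot be loaded.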
|
|
8
|
+
*/
|
|
9
|
+
|
|
10
|
+
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
11
|
+
|
|
12
|
+
// Make the ONNX runtime unloadable: any import of the package rejects, so the
// guardrail's ONNX tier fails and classify() moves on to the LLM tier.
// The factory must not reference outer variables because Vitest hoists vi.mock().
vi.mock('@huggingface/transformers', (): never => {
  throw new Error('ONNX not available');
});
|
|
16
|
+
|
|
17
|
+
import { classifyByLlm } from '../src/llm-classifier';
|
|
18
|
+
import type { LlmInvoker, ClassifierCategory } from '../src/types';
|
|
19
|
+
import { MLClassifierGuardrail } from '../src/MLClassifierGuardrail';
|
|
20
|
+
|
|
21
|
+
// ---------------------------------------------------------------------------
|
|
22
|
+
// Tests — classifyByLlm directly
|
|
23
|
+
// ---------------------------------------------------------------------------
|
|
24
|
+
|
|
25
|
+
describe('LLM tier — classifyByLlm()', () => {
  /**
   * Builds an invoker stub that answers once with `verdict` serialised as
   * plain JSON (the "well-behaved LLM" case).
   */
  function jsonInvoker(verdict: Record<string, unknown>): LlmInvoker {
    return vi.fn().mockResolvedValue(JSON.stringify(verdict));
  }

  /** Builds an invoker stub that resolves with `text` verbatim. */
  function rawInvoker(text: string): LlmInvoker {
    return vi.fn().mockResolvedValue(text);
  }

  // -----------------------------------------------------------------------
  // Prompt verification
  // -----------------------------------------------------------------------

  describe('prompt sent to invoker', () => {
    it('sends a system prompt and the user text to the invoker', async () => {
      const invoker = jsonInvoker({
        toxic: false,
        injection: false,
        nsfw: false,
        threat: false,
        confidence: 0.9,
      });

      await classifyByLlm('Hello world', invoker);

      expect(invoker).toHaveBeenCalledTimes(1);

      // vi.mocked() exposes the typed call log without an `as any` escape hatch.
      const firstCall = vi.mocked(invoker).mock.calls[0];
      expect(firstCall).toBeDefined();
      const [systemPrompt, userMessage] = firstCall!;
      expect(typeof systemPrompt).toBe('string');
      expect(systemPrompt).toContain('content safety classifier');
      expect(userMessage).toBe('Hello world');
    });
  });

  // -----------------------------------------------------------------------
  // JSON parsing
  // -----------------------------------------------------------------------

  describe('JSON response parsing', () => {
    it('parses a clean JSON response into category scores', async () => {
      const invoker = jsonInvoker({
        toxic: true,
        injection: false,
        nsfw: false,
        threat: true,
        confidence: 0.85,
      });

      const scores = await classifyByLlm('some bad text', invoker);

      expect(scores).toHaveLength(4);

      const toxic = scores.find((s) => s.name === 'toxic');
      expect(toxic?.confidence).toBe(0.85);

      const injection = scores.find((s) => s.name === 'injection');
      expect(injection?.confidence).toBe(0);

      const nsfw = scores.find((s) => s.name === 'nsfw');
      expect(nsfw?.confidence).toBe(0);

      const threat = scores.find((s) => s.name === 'threat');
      expect(threat?.confidence).toBe(0.85);
    });

    it('uses default confidence (0.7) when confidence is omitted', async () => {
      const invoker = jsonInvoker({ toxic: true, injection: false, nsfw: false, threat: false });

      const scores = await classifyByLlm('abusive text', invoker);

      const toxic = scores.find((s) => s.name === 'toxic');
      expect(toxic?.confidence).toBe(0.7);
    });

    it('clamps confidence to [0, 1]', async () => {
      const invoker = jsonInvoker({
        toxic: true,
        injection: false,
        nsfw: false,
        threat: false,
        confidence: 5.0,
      });

      const scores = await classifyByLlm('test', invoker);
      const toxic = scores.find((s) => s.name === 'toxic');
      // An out-of-range 5.0 must be clamped to the upper bound. A bare `<= 1`
      // check would also pass if the value were dropped to 0 or to the 0.7 default.
      expect(toxic?.confidence).toBe(1);
    });
  });

  // -----------------------------------------------------------------------
  // Markdown-wrapped JSON
  // -----------------------------------------------------------------------

  describe('markdown-wrapped JSON handling', () => {
    it('strips ```json fences before parsing', async () => {
      const invoker = rawInvoker(
        '```json\n{"toxic": true, "injection": false, "nsfw": false, "threat": false, "confidence": 0.9}\n```'
      );

      const scores = await classifyByLlm('wrapped response', invoker);
      const toxic = scores.find((s) => s.name === 'toxic');
      expect(toxic?.confidence).toBe(0.9);
    });

    it('strips bare ``` fences (no language tag)', async () => {
      const invoker = rawInvoker(
        '```\n{"toxic": false, "injection": true, "nsfw": false, "threat": false, "confidence": 0.75}\n```'
      );

      const scores = await classifyByLlm('injection attempt', invoker);
      const injection = scores.find((s) => s.name === 'injection');
      expect(injection?.confidence).toBe(0.75);
    });

    it('handles trailing commas in LLM output', async () => {
      const invoker = rawInvoker(
        '{"toxic": true, "injection": false, "nsfw": false, "threat": false, "confidence": 0.8,}'
      );

      const scores = await classifyByLlm('trailing comma', invoker);
      const toxic = scores.find((s) => s.name === 'toxic');
      expect(toxic?.confidence).toBe(0.8);
    });
  });

  // -----------------------------------------------------------------------
  // Failure modes
  // -----------------------------------------------------------------------

  describe('failure handling', () => {
    it('returns zero scores when invoker throws', async () => {
      const invoker: LlmInvoker = vi.fn().mockRejectedValue(new Error('LLM unavailable'));

      const scores = await classifyByLlm('test', invoker);

      expect(scores).toHaveLength(4);
      for (const score of scores) {
        expect(score.confidence).toBe(0);
      }
    });

    it('returns zero scores when invoker returns unparseable text', async () => {
      const invoker = rawInvoker('I cannot classify this content.');

      const scores = await classifyByLlm('test', invoker);

      // Without the length check the loop below passes vacuously on [].
      expect(scores).toHaveLength(4);
      for (const score of scores) {
        expect(score.confidence).toBe(0);
      }
    });

    it('returns zero scores when invoker returns an array instead of object', async () => {
      const invoker = rawInvoker('[1, 2, 3]');

      const scores = await classifyByLlm('test', invoker);

      // Without the length check the loop below passes vacuously on [].
      expect(scores).toHaveLength(4);
      for (const score of scores) {
        expect(score.confidence).toBe(0);
      }
    });
  });

  // -----------------------------------------------------------------------
  // Category filtering
  // -----------------------------------------------------------------------

  describe('category filtering', () => {
    it('returns scores only for requested categories', async () => {
      const invoker = jsonInvoker({
        toxic: true,
        injection: true,
        nsfw: false,
        threat: false,
        confidence: 0.9,
      });

      const subset: ClassifierCategory[] = ['toxic', 'injection'];
      const scores = await classifyByLlm('targeted', invoker, subset);

      expect(scores).toHaveLength(2);
      expect(scores.map((s) => s.name)).toEqual(['toxic', 'injection']);
    });
  });
});
|
|
219
|
+
|
|
220
|
+
// ---------------------------------------------------------------------------
|
|
221
|
+
// Tests — LLM tier via MLClassifierGuardrail.classify()
|
|
222
|
+
// ---------------------------------------------------------------------------
|
|
223
|
+
|
|
224
|
+
describe('LLM tier — via guardrail classify()', () => {
  /** Returns an invoker stub that answers every call with `verdict` encoded as JSON. */
  const verdictInvoker = (verdict: Record<string, unknown>): LlmInvoker =>
    vi.fn().mockResolvedValue(JSON.stringify(verdict));

  beforeEach(() => {
    vi.clearAllMocks();
  });

  it('falls through to LLM when ONNX is unavailable', async () => {
    const invoker = verdictInvoker({
      toxic: true,
      injection: false,
      nsfw: false,
      threat: false,
      confidence: 0.9,
    });

    // The ONNX package is mocked to fail on import, so only the LLM tier can answer.
    const { source } = await new MLClassifierGuardrail({ llmInvoker: invoker }).classify('test');

    expect(source).toBe('llm');
    expect(invoker).toHaveBeenCalledTimes(1);
  });

  it('result.flagged is true when LLM detects a category above threshold', async () => {
    const invoker = verdictInvoker({
      toxic: true,
      injection: false,
      nsfw: false,
      threat: false,
      confidence: 0.85,
    });

    const classifier = new MLClassifierGuardrail({ llmInvoker: invoker });
    const outcome = await classifier.classify('abusive text');

    expect(outcome.flagged).toBe(true);
    expect(outcome.source).toBe('llm');

    const toxicScore = outcome.categories.find(({ name }) => name === 'toxic');
    expect(toxicScore?.confidence).toBe(0.85);
  });
});
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
// @ts-nocheck
|
|
2
|
+
import { describe, it, expect } from 'vitest';
|
|
3
|
+
import { createExtensionPack } from '../src/index';
|
|
4
|
+
|
|
5
|
+
describe('ML Classifiers Extension Pack', () => {
  /** Builds the pack with an empty options object. */
  const buildPack = () => createExtensionPack({ options: {} } as any);

  // -------------------------------------------------------------------------
  // Pack structure
  // -------------------------------------------------------------------------

  it('createExtensionPack returns correct structure', () => {
    const { name, version, descriptors } = buildPack();

    expect(name).toBe('ml-classifiers');
    expect(version).toBe('1.0.0');
    expect(descriptors).toHaveLength(2);

    const kinds = descriptors.map(({ kind }) => kind);
    for (const expectedKind of ['guardrail', 'tool']) {
      expect(kinds).toContain(expectedKind);
    }

    const ids = descriptors.map(({ id }) => id);
    for (const expectedId of ['ml-classifier-guardrail', 'classify_content']) {
      expect(ids).toContain(expectedId);
    }
  });

  // -------------------------------------------------------------------------
  // Guardrail — keyword fallback detection
  // -------------------------------------------------------------------------

  describe('guardrail evaluateInput', () => {
    /** Extracts the guardrail payload from a freshly built pack. */
    function getGuardrail() {
      const guardrailDescriptor = buildPack().descriptors.find(({ kind }) => kind === 'guardrail');
      return guardrailDescriptor!.payload as any;
    }

    it('detects highly toxic input', async () => {
      const guardrail = getGuardrail();
      // Text strong enough that either the ONNX tier or the keyword fallback should flag it.
      const verdict = await guardrail.evaluateInput({
        input: { textInput: 'You are a stupid idiot, kill yourself you moron' },
      });

      expect(verdict).not.toBeNull();
      expect(['flag', 'block']).toContain(verdict!.action);
    });

    it('allows clean input through', async () => {
      const verdict = await getGuardrail().evaluateInput({
        input: { textInput: 'What is the weather like today?' },
      });

      // A null verdict means the guardrail let the input pass untouched.
      expect(verdict).toBeNull();
    });
  });
});
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file onnx-tier.spec.ts
|
|
3
|
+
* @description Tests for the ONNX (Tier 1) classification path in MLClassifierGuardrail.
|
|
4
|
+
*
|
|
5
|
+
* Mocks `@huggingface/transformers` to return controlled toxic-bert label/score
|
|
6
|
+
* pairs, verifying that ONNX results are mapped to internal categories, threshold
|
|
7
|
+
* logic works, and the result carries `source: 'onnx'`.
|
|
8
|
+
*/
|
|
9
|
+
|
|
10
|
+
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
11
|
+
|
|
12
|
+
// ---------------------------------------------------------------------------
|
|
13
|
+
// Mock @huggingface/transformers
|
|
14
|
+
// ---------------------------------------------------------------------------
|
|
15
|
+
|
|
16
|
+
/**
 * Callable mock that stands in for the ONNX text-classification pipeline.
 * Tests configure its return value per-case via `mockResolvedValue`.
 *
 * Created through `vi.hoisted()` because Vitest hoists `vi.mock()` calls above
 * every other statement in the file. A plain `const` here would still be in its
 * temporal dead zone if the mock factory below ever ran before this line (for
 * example if the SUT imported the package statically). `vi.hoisted` moves the
 * initialisation up alongside the hoisted mock, so the factory can always
 * reference it safely.
 */
const mockPipelineCall = vi.hoisted(() => vi.fn());

vi.mock('@huggingface/transformers', () => ({
  // `pipeline()` resolves to an object whose `_call` is the shared mock above.
  // NOTE(review): assumes the SUT invokes the pipeline through `_call` — confirm
  // against MLClassifierGuardrail.
  pipeline: vi.fn().mockResolvedValue({
    _call: mockPipelineCall,
  }),
}));
|
|
27
|
+
|
|
28
|
+
// ---------------------------------------------------------------------------
|
|
29
|
+
// SUT
|
|
30
|
+
// ---------------------------------------------------------------------------
|
|
31
|
+
|
|
32
|
+
import { MLClassifierGuardrail } from '../src/MLClassifierGuardrail';
|
|
33
|
+
|
|
34
|
+
// ---------------------------------------------------------------------------
|
|
35
|
+
// Helpers
|
|
36
|
+
// ---------------------------------------------------------------------------
|
|
37
|
+
|
|
38
|
+
/**
 * Constructs a brand-new guardrail so that each test starts from its own instance.
 * NOTE(review): this is meant to give each test a fresh ONNX pipeline. That only
 * holds if the pipeline is cached per instance rather than per module — confirm
 * in MLClassifierGuardrail.
 *
 * @param options - Passed through unchanged to the `MLClassifierGuardrail` constructor.
 */
const createGuardrail = (options?: any): MLClassifierGuardrail =>
  new MLClassifierGuardrail(options);
|
|
42
|
+
|
|
43
|
+
// ---------------------------------------------------------------------------
|
|
44
|
+
// Tests
|
|
45
|
+
// ---------------------------------------------------------------------------
|
|
46
|
+
|
|
47
|
+
describe('ONNX tier classification', () => {
  /** The six toxic-bert labels, in the order the canonical fixtures list them. */
  const TOXIC_BERT_LABELS = [
    'toxic',
    'severe_toxic',
    'obscene',
    'insult',
    'identity_hate',
    'threat',
  ] as const;

  /**
   * Builds a complete toxic-bert output in canonical label order. Any label not
   * given in `scores` reports a score of 0.
   */
  const toxicBertOutput = (
    scores: Partial<Record<(typeof TOXIC_BERT_LABELS)[number], number>>
  ) => TOXIC_BERT_LABELS.map((label) => ({ label, score: scores[label] ?? 0 }));

  /** Returns the confidence reported for one internal category, if it is present. */
  const confidenceOf = (
    categories: ReadonlyArray<{ name: string; confidence: number }>,
    name: string
  ): number | undefined => categories.find((c) => c.name === name)?.confidence;

  beforeEach(() => {
    vi.clearAllMocks();
  });

  // -------------------------------------------------------------------------
  // Label mapping
  // -------------------------------------------------------------------------

  describe('label-to-category mapping', () => {
    it('maps toxic-bert labels to internal categories', async () => {
      mockPipelineCall.mockResolvedValue(
        toxicBertOutput({
          toxic: 0.92,
          severe_toxic: 0.45,
          obscene: 0.78,
          insult: 0.65,
          identity_hate: 0.3,
          threat: 0.15,
        })
      );

      const result = await createGuardrail().classify('test text');

      expect(result.source).toBe('onnx');
      // toxic = max(toxic:0.92, severe_toxic:0.45, insult:0.65, identity_hate:0.30)
      expect(confidenceOf(result.categories, 'toxic')).toBe(0.92);
      // nsfw = max(obscene:0.78)
      expect(confidenceOf(result.categories, 'nsfw')).toBe(0.78);
      // threat = max(threat:0.15)
      expect(confidenceOf(result.categories, 'threat')).toBe(0.15);
      // toxic-bert has no injection label, so injection stays at 0
      expect(confidenceOf(result.categories, 'injection')).toBe(0);
    });

    it('takes max score when multiple ONNX labels map to the same category', async () => {
      // Deliberately non-canonical order, so keep the fixture explicit.
      mockPipelineCall.mockResolvedValue([
        { label: 'toxic', score: 0.3 },
        { label: 'severe_toxic', score: 0.85 },
        { label: 'insult', score: 0.6 },
        { label: 'identity_hate', score: 0.7 },
        { label: 'obscene', score: 0.1 },
        { label: 'threat', score: 0.05 },
      ]);

      const result = await createGuardrail().classify('some text');

      // toxic category = max(0.30, 0.85, 0.60, 0.70) = 0.85
      expect(confidenceOf(result.categories, 'toxic')).toBe(0.85);
    });

    it('handles labels with mixed case and whitespace', async () => {
      mockPipelineCall.mockResolvedValue([
        { label: 'Toxic', score: 0.7 },
        { label: 'OBSCENE', score: 0.6 },
        { label: 'identity hate', score: 0.5 },
        { label: 'THREAT', score: 0.3 },
        { label: 'severe toxic', score: 0.2 },
        { label: 'INSULT', score: 0.1 },
      ]);

      const result = await createGuardrail().classify('some text');

      // "identity hate" normalises (lower-case, space -> underscore) to identity_hate -> toxic,
      // so toxic = max(toxic:0.7, identity_hate:0.5, severe_toxic:0.2, insult:0.1) = 0.7
      expect(confidenceOf(result.categories, 'toxic')).toBe(0.7);
      // obscene -> nsfw = 0.6
      expect(confidenceOf(result.categories, 'nsfw')).toBe(0.6);
    });
  });

  // -------------------------------------------------------------------------
  // Threshold behaviour
  // -------------------------------------------------------------------------

  describe('threshold behaviour', () => {
    it('flags content above default flag threshold (0.5)', async () => {
      mockPipelineCall.mockResolvedValue(toxicBertOutput({ toxic: 0.65 }));

      const { flagged, source } = await createGuardrail().classify('mildly toxic text');

      expect(flagged).toBe(true);
      expect(source).toBe('onnx');
    });

    it('does not flag content below all thresholds', async () => {
      mockPipelineCall.mockResolvedValue(
        toxicBertOutput({
          toxic: 0.1,
          severe_toxic: 0.05,
          obscene: 0.02,
          insult: 0.08,
          identity_hate: 0.01,
          threat: 0.03,
        })
      );

      const { flagged, source } = await createGuardrail().classify('perfectly clean text');

      expect(flagged).toBe(false);
      expect(source).toBe('onnx');
    });

    it('respects per-category threshold overrides', async () => {
      mockPipelineCall.mockResolvedValue(toxicBertOutput({ toxic: 0.35 }));

      // Lowering the toxic flag threshold to 0.3 puts the 0.35 score above it.
      const guardrail = createGuardrail({
        thresholds: { toxic: { flag: 0.3 } },
      });
      const { flagged } = await guardrail.classify('borderline text');

      expect(flagged).toBe(true);
    });

    it('does not flag when score equals the threshold exactly', async () => {
      mockPipelineCall.mockResolvedValue(toxicBertOutput({ toxic: 0.5 }));

      const { flagged } = await createGuardrail().classify('edge case text');

      // Default flag threshold is 0.5. The comparison is strict (">", not ">="),
      // so a score of exactly 0.5 is not flagged.
      expect(flagged).toBe(false);
    });
  });

  // -------------------------------------------------------------------------
  // Result source
  // -------------------------------------------------------------------------

  describe('result source', () => {
    it('always returns source: onnx when pipeline succeeds', async () => {
      mockPipelineCall.mockResolvedValue(toxicBertOutput({}));

      const { source } = await createGuardrail().classify('hello');

      expect(source).toBe('onnx');
    });
  });

  // -------------------------------------------------------------------------
  // All four categories present
  // -------------------------------------------------------------------------

  describe('category completeness', () => {
    it('returns scores for all four categories', async () => {
      mockPipelineCall.mockResolvedValue(
        toxicBertOutput({ toxic: 0.1, obscene: 0.2, threat: 0.3 })
      );

      const { categories } = await createGuardrail().classify('test');

      const names = categories.map((c) => c.name);
      for (const expected of ['toxic', 'injection', 'nsfw', 'threat']) {
        expect(names).toContain(expected);
      }
      expect(categories).toHaveLength(4);
    });
  });
});
|