@framers/agentos-ext-ml-classifiers 0.2.1 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.github/workflows/ci.yml +20 -0
- package/.github/workflows/release.yml +37 -0
- package/.releaserc.json +9 -0
- package/LICENSE +96 -21
- package/README.md +72 -0
- package/dist/MLClassifierGuardrail.d.ts.map +1 -1
- package/dist/MLClassifierGuardrail.js +14 -6
- package/dist/MLClassifierGuardrail.js.map +1 -1
- package/dist/index.js +3 -3
- package/dist/keyword-classifier.js +1 -1
- package/dist/llm-classifier.js +1 -1
- package/package.json +5 -13
- package/scripts/fix-esm-imports.mjs +181 -0
- package/src/MLClassifierGuardrail.ts +38 -5
- package/test/llm-tier.spec.ts +267 -0
- package/test/ml-classifiers.spec.ts +57 -0
- package/test/onnx-tier.spec.ts +255 -0
- package/test/tier-fallthrough.spec.ts +185 -0
- package/vitest.config.ts +18 -7
- package/CHANGELOG.md +0 -18
- package/dist/ClassifierOrchestrator.d.ts +0 -126
- package/dist/ClassifierOrchestrator.d.ts.map +0 -1
- package/dist/ClassifierOrchestrator.js +0 -239
- package/dist/ClassifierOrchestrator.js.map +0 -1
- package/dist/IContentClassifier.d.ts +0 -117
- package/dist/IContentClassifier.d.ts.map +0 -1
- package/dist/IContentClassifier.js +0 -22
- package/dist/IContentClassifier.js.map +0 -1
- package/dist/SlidingWindowBuffer.d.ts +0 -213
- package/dist/SlidingWindowBuffer.d.ts.map +0 -1
- package/dist/SlidingWindowBuffer.js +0 -246
- package/dist/SlidingWindowBuffer.js.map +0 -1
- package/dist/classifiers/InjectionClassifier.d.ts +0 -126
- package/dist/classifiers/InjectionClassifier.d.ts.map +0 -1
- package/dist/classifiers/InjectionClassifier.js +0 -210
- package/dist/classifiers/InjectionClassifier.js.map +0 -1
- package/dist/classifiers/JailbreakClassifier.d.ts +0 -124
- package/dist/classifiers/JailbreakClassifier.d.ts.map +0 -1
- package/dist/classifiers/JailbreakClassifier.js +0 -208
- package/dist/classifiers/JailbreakClassifier.js.map +0 -1
- package/dist/classifiers/ToxicityClassifier.d.ts +0 -125
- package/dist/classifiers/ToxicityClassifier.d.ts.map +0 -1
- package/dist/classifiers/ToxicityClassifier.js +0 -212
- package/dist/classifiers/ToxicityClassifier.js.map +0 -1
- package/dist/classifiers/WorkerClassifierProxy.d.ts +0 -158
- package/dist/classifiers/WorkerClassifierProxy.d.ts.map +0 -1
- package/dist/classifiers/WorkerClassifierProxy.js +0 -268
- package/dist/classifiers/WorkerClassifierProxy.js.map +0 -1
- package/dist/worker/classifier-worker.d.ts +0 -49
- package/dist/worker/classifier-worker.d.ts.map +0 -1
- package/dist/worker/classifier-worker.js +0 -180
- package/dist/worker/classifier-worker.js.map +0 -1
- package/src/ClassifierOrchestrator.ts +0 -290
- package/src/IContentClassifier.ts +0 -124
- package/src/SlidingWindowBuffer.ts +0 -384
- package/src/classifiers/InjectionClassifier.ts +0 -261
- package/src/classifiers/JailbreakClassifier.ts +0 -259
- package/src/classifiers/ToxicityClassifier.ts +0 -263
- package/src/classifiers/WorkerClassifierProxy.ts +0 -366
- package/src/worker/classifier-worker.ts +0 -267
- package/test/ClassifierOrchestrator.spec.ts +0 -365
- package/test/ClassifyContentTool.spec.ts +0 -226
- package/test/InjectionClassifier.spec.ts +0 -263
- package/test/JailbreakClassifier.spec.ts +0 -295
- package/test/MLClassifierGuardrail.spec.ts +0 -486
- package/test/SlidingWindowBuffer.spec.ts +0 -391
- package/test/ToxicityClassifier.spec.ts +0 -268
- package/test/WorkerClassifierProxy.spec.ts +0 -303
- package/test/index.spec.ts +0 -431
|
@@ -27,6 +27,7 @@ import type {
|
|
|
27
27
|
GuardrailInputPayload,
|
|
28
28
|
GuardrailOutputPayload,
|
|
29
29
|
GuardrailEvaluationResult,
|
|
30
|
+
AgentOSFinalResponseChunk,
|
|
30
31
|
} from '@framers/agentos';
|
|
31
32
|
import { GuardrailAction } from '@framers/agentos';
|
|
32
33
|
import { AgentOSResponseChunkType } from '@framers/agentos';
|
|
@@ -40,6 +41,27 @@ import { ALL_CATEGORIES } from './types';
|
|
|
40
41
|
import { classifyByKeywords } from './keyword-classifier';
|
|
41
42
|
import { classifyByLlm } from './llm-classifier';
|
|
42
43
|
|
|
44
|
+
// ---------------------------------------------------------------------------
|
|
45
|
+
// HuggingFace / ONNX pipeline types
|
|
46
|
+
// ---------------------------------------------------------------------------
|
|
47
|
+
|
|
48
|
+
/**
|
|
49
|
+
* A single label + score pair returned by a HuggingFace text-classification
|
|
50
|
+
* pipeline. The `label` is the model's class name (e.g. `"toxic"`,
|
|
51
|
+
* `"obscene"`) and `score` is the softmax probability in [0, 1].
|
|
52
|
+
*/
|
|
53
|
+
interface OnnxClassificationLabel {
|
|
54
|
+
label: string;
|
|
55
|
+
score: number;
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
/**
|
|
59
|
+
* Callable returned by `pipeline('text-classification', ...)` from the
|
|
60
|
+
* `@huggingface/transformers` package. Returns all label scores for the
|
|
61
|
+
* given text.
|
|
62
|
+
*/
|
|
63
|
+
type OnnxTextClassificationPipeline = (text: string) => Promise<OnnxClassificationLabel[]>;
|
|
64
|
+
|
|
43
65
|
// ---------------------------------------------------------------------------
|
|
44
66
|
// MLClassifierGuardrail
|
|
45
67
|
// ---------------------------------------------------------------------------
|
|
@@ -91,7 +113,7 @@ export class MLClassifierGuardrail implements IGuardrailService {
|
|
|
91
113
|
* `null` means we already tried and failed to load the module.
|
|
92
114
|
* `undefined` means we have not tried yet.
|
|
93
115
|
*/
|
|
94
|
-
private onnxPipeline:
|
|
116
|
+
private onnxPipeline: OnnxTextClassificationPipeline | null | undefined = undefined;
|
|
95
117
|
|
|
96
118
|
// -----------------------------------------------------------------------
|
|
97
119
|
// Constructor
|
|
@@ -159,7 +181,8 @@ export class MLClassifierGuardrail implements IGuardrailService {
|
|
|
159
181
|
return null;
|
|
160
182
|
}
|
|
161
183
|
|
|
162
|
-
const
|
|
184
|
+
const finalChunk = chunk as AgentOSFinalResponseChunk;
|
|
185
|
+
const text = finalChunk.finalResponseText ?? '';
|
|
163
186
|
if (typeof text !== 'string' || text.length === 0) return null;
|
|
164
187
|
|
|
165
188
|
const result = await this.classify(text);
|
|
@@ -222,11 +245,21 @@ export class MLClassifierGuardrail implements IGuardrailService {
|
|
|
222
245
|
try {
|
|
223
246
|
// Dynamic import so the optional dependency does not fail at boot.
|
|
224
247
|
const transformers = await import('@huggingface/transformers');
|
|
225
|
-
|
|
248
|
+
const pipelineInstance = await transformers.pipeline(
|
|
226
249
|
'text-classification',
|
|
227
250
|
'Xenova/toxic-bert',
|
|
228
251
|
{ device: 'cpu' }
|
|
229
252
|
);
|
|
253
|
+
// The HuggingFace pipeline is callable as a function. We always
|
|
254
|
+
// request all labels (a top_k higher than any model's label count
|
|
255
|
+
// causes HF to return every label, matching `top_k: null`).
|
|
256
|
+
// Cast through unknown because the Pipeline union type is too
|
|
257
|
+
// wide for the inferred call signature here.
|
|
258
|
+
const callable = pipelineInstance as unknown as (
|
|
259
|
+
text: string,
|
|
260
|
+
opts: { top_k: number },
|
|
261
|
+
) => Promise<OnnxClassificationLabel[]>;
|
|
262
|
+
this.onnxPipeline = (text: string) => callable(text, { top_k: 9999 });
|
|
230
263
|
} catch {
|
|
231
264
|
// Module not installed or model load failed — mark as unavailable.
|
|
232
265
|
this.onnxPipeline = null;
|
|
@@ -235,7 +268,7 @@ export class MLClassifierGuardrail implements IGuardrailService {
|
|
|
235
268
|
}
|
|
236
269
|
|
|
237
270
|
try {
|
|
238
|
-
const raw = await this.onnxPipeline(text
|
|
271
|
+
const raw = await this.onnxPipeline(text);
|
|
239
272
|
|
|
240
273
|
// Map ONNX labels to our categories.
|
|
241
274
|
const scores = this.mapOnnxScores(raw);
|
|
@@ -263,7 +296,7 @@ export class MLClassifierGuardrail implements IGuardrailService {
|
|
|
263
296
|
*
|
|
264
297
|
* @internal
|
|
265
298
|
*/
|
|
266
|
-
private mapOnnxScores(raw:
|
|
299
|
+
private mapOnnxScores(raw: OnnxClassificationLabel[]): CategoryScore[] {
|
|
267
300
|
/** Map of ONNX label -> our category. */
|
|
268
301
|
const labelMap: Record<string, ClassifierCategory> = {
|
|
269
302
|
toxic: 'toxic',
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file llm-tier.spec.ts
|
|
3
|
+
* @description Tests for the LLM-as-judge (Tier 2) classification path.
|
|
4
|
+
*
|
|
5
|
+
* Exercises `classifyByLlm()` directly — verifying that the structured
|
|
6
|
+
* classification prompt is sent to the invoker, JSON and markdown-wrapped JSON
|
|
7
|
+
* are parsed correctly, and failures produce zero-confidence scores. Also covers the fall-through to the LLM tier via `MLClassifierGuardrail.classify()`.
|
|
8
|
+
*/
|
|
9
|
+
|
|
10
|
+
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
11
|
+
|
|
12
|
+
// Force ONNX unavailable so classify() falls through to LLM tier
|
|
13
|
+
vi.mock('@huggingface/transformers', () => {
|
|
14
|
+
throw new Error('ONNX not available');
|
|
15
|
+
});
|
|
16
|
+
|
|
17
|
+
import { classifyByLlm } from '../src/llm-classifier';
|
|
18
|
+
import type { LlmInvoker, ClassifierCategory } from '../src/types';
|
|
19
|
+
import { MLClassifierGuardrail } from '../src/MLClassifierGuardrail';
|
|
20
|
+
|
|
21
|
+
// ---------------------------------------------------------------------------
|
|
22
|
+
// Tests — classifyByLlm directly
|
|
23
|
+
// ---------------------------------------------------------------------------
|
|
24
|
+
|
|
25
|
+
describe('LLM tier — classifyByLlm()', () => {
|
|
26
|
+
// -----------------------------------------------------------------------
|
|
27
|
+
// Prompt verification
|
|
28
|
+
// -----------------------------------------------------------------------
|
|
29
|
+
|
|
30
|
+
describe('prompt sent to invoker', () => {
|
|
31
|
+
it('sends a system prompt and the user text to the invoker', async () => {
|
|
32
|
+
const invoker: LlmInvoker = vi.fn().mockResolvedValue(
|
|
33
|
+
JSON.stringify({
|
|
34
|
+
toxic: false,
|
|
35
|
+
injection: false,
|
|
36
|
+
nsfw: false,
|
|
37
|
+
threat: false,
|
|
38
|
+
confidence: 0.9,
|
|
39
|
+
})
|
|
40
|
+
);
|
|
41
|
+
|
|
42
|
+
await classifyByLlm('Hello world', invoker);
|
|
43
|
+
|
|
44
|
+
expect(invoker).toHaveBeenCalledTimes(1);
|
|
45
|
+
|
|
46
|
+
const [systemPrompt, userMessage] = (invoker as any).mock.calls[0];
|
|
47
|
+
expect(typeof systemPrompt).toBe('string');
|
|
48
|
+
expect(systemPrompt).toContain('content safety classifier');
|
|
49
|
+
expect(userMessage).toBe('Hello world');
|
|
50
|
+
});
|
|
51
|
+
});
|
|
52
|
+
|
|
53
|
+
// -----------------------------------------------------------------------
|
|
54
|
+
// JSON parsing
|
|
55
|
+
// -----------------------------------------------------------------------
|
|
56
|
+
|
|
57
|
+
describe('JSON response parsing', () => {
|
|
58
|
+
it('parses a clean JSON response into category scores', async () => {
|
|
59
|
+
const invoker: LlmInvoker = vi.fn().mockResolvedValue(
|
|
60
|
+
JSON.stringify({
|
|
61
|
+
toxic: true,
|
|
62
|
+
injection: false,
|
|
63
|
+
nsfw: false,
|
|
64
|
+
threat: true,
|
|
65
|
+
confidence: 0.85,
|
|
66
|
+
})
|
|
67
|
+
);
|
|
68
|
+
|
|
69
|
+
const scores = await classifyByLlm('some bad text', invoker);
|
|
70
|
+
|
|
71
|
+
expect(scores).toHaveLength(4);
|
|
72
|
+
|
|
73
|
+
const toxic = scores.find((s) => s.name === 'toxic');
|
|
74
|
+
expect(toxic?.confidence).toBe(0.85);
|
|
75
|
+
|
|
76
|
+
const injection = scores.find((s) => s.name === 'injection');
|
|
77
|
+
expect(injection?.confidence).toBe(0);
|
|
78
|
+
|
|
79
|
+
const nsfw = scores.find((s) => s.name === 'nsfw');
|
|
80
|
+
expect(nsfw?.confidence).toBe(0);
|
|
81
|
+
|
|
82
|
+
const threat = scores.find((s) => s.name === 'threat');
|
|
83
|
+
expect(threat?.confidence).toBe(0.85);
|
|
84
|
+
});
|
|
85
|
+
|
|
86
|
+
it('uses default confidence (0.7) when confidence is omitted', async () => {
|
|
87
|
+
const invoker: LlmInvoker = vi
|
|
88
|
+
.fn()
|
|
89
|
+
.mockResolvedValue(
|
|
90
|
+
JSON.stringify({ toxic: true, injection: false, nsfw: false, threat: false })
|
|
91
|
+
);
|
|
92
|
+
|
|
93
|
+
const scores = await classifyByLlm('abusive text', invoker);
|
|
94
|
+
|
|
95
|
+
const toxic = scores.find((s) => s.name === 'toxic');
|
|
96
|
+
expect(toxic?.confidence).toBe(0.7);
|
|
97
|
+
});
|
|
98
|
+
|
|
99
|
+
it('clamps confidence to [0, 1]', async () => {
|
|
100
|
+
const invoker: LlmInvoker = vi.fn().mockResolvedValue(
|
|
101
|
+
JSON.stringify({
|
|
102
|
+
toxic: true,
|
|
103
|
+
injection: false,
|
|
104
|
+
nsfw: false,
|
|
105
|
+
threat: false,
|
|
106
|
+
confidence: 5.0,
|
|
107
|
+
})
|
|
108
|
+
);
|
|
109
|
+
|
|
110
|
+
const scores = await classifyByLlm('test', invoker);
|
|
111
|
+
const toxic = scores.find((s) => s.name === 'toxic');
|
|
112
|
+
expect(toxic?.confidence).toBeLessThanOrEqual(1.0);
|
|
113
|
+
});
|
|
114
|
+
});
|
|
115
|
+
|
|
116
|
+
// -----------------------------------------------------------------------
|
|
117
|
+
// Markdown-wrapped JSON
|
|
118
|
+
// -----------------------------------------------------------------------
|
|
119
|
+
|
|
120
|
+
describe('markdown-wrapped JSON handling', () => {
|
|
121
|
+
it('strips ```json fences before parsing', async () => {
|
|
122
|
+
const invoker: LlmInvoker = vi
|
|
123
|
+
.fn()
|
|
124
|
+
.mockResolvedValue(
|
|
125
|
+
'```json\n{"toxic": true, "injection": false, "nsfw": false, "threat": false, "confidence": 0.9}\n```'
|
|
126
|
+
);
|
|
127
|
+
|
|
128
|
+
const scores = await classifyByLlm('wrapped response', invoker);
|
|
129
|
+
const toxic = scores.find((s) => s.name === 'toxic');
|
|
130
|
+
expect(toxic?.confidence).toBe(0.9);
|
|
131
|
+
});
|
|
132
|
+
|
|
133
|
+
it('strips bare ``` fences (no language tag)', async () => {
|
|
134
|
+
const invoker: LlmInvoker = vi
|
|
135
|
+
.fn()
|
|
136
|
+
.mockResolvedValue(
|
|
137
|
+
'```\n{"toxic": false, "injection": true, "nsfw": false, "threat": false, "confidence": 0.75}\n```'
|
|
138
|
+
);
|
|
139
|
+
|
|
140
|
+
const scores = await classifyByLlm('injection attempt', invoker);
|
|
141
|
+
const injection = scores.find((s) => s.name === 'injection');
|
|
142
|
+
expect(injection?.confidence).toBe(0.75);
|
|
143
|
+
});
|
|
144
|
+
|
|
145
|
+
it('handles trailing commas in LLM output', async () => {
|
|
146
|
+
const invoker: LlmInvoker = vi
|
|
147
|
+
.fn()
|
|
148
|
+
.mockResolvedValue(
|
|
149
|
+
'{"toxic": true, "injection": false, "nsfw": false, "threat": false, "confidence": 0.8,}'
|
|
150
|
+
);
|
|
151
|
+
|
|
152
|
+
const scores = await classifyByLlm('trailing comma', invoker);
|
|
153
|
+
const toxic = scores.find((s) => s.name === 'toxic');
|
|
154
|
+
expect(toxic?.confidence).toBe(0.8);
|
|
155
|
+
});
|
|
156
|
+
});
|
|
157
|
+
|
|
158
|
+
// -----------------------------------------------------------------------
|
|
159
|
+
// Failure modes
|
|
160
|
+
// -----------------------------------------------------------------------
|
|
161
|
+
|
|
162
|
+
describe('failure handling', () => {
|
|
163
|
+
it('returns zero scores when invoker throws', async () => {
|
|
164
|
+
const invoker: LlmInvoker = vi.fn().mockRejectedValue(new Error('LLM unavailable'));
|
|
165
|
+
|
|
166
|
+
const scores = await classifyByLlm('test', invoker);
|
|
167
|
+
|
|
168
|
+
expect(scores).toHaveLength(4);
|
|
169
|
+
for (const score of scores) {
|
|
170
|
+
expect(score.confidence).toBe(0);
|
|
171
|
+
}
|
|
172
|
+
});
|
|
173
|
+
|
|
174
|
+
it('returns zero scores when invoker returns unparseable text', async () => {
|
|
175
|
+
const invoker: LlmInvoker = vi.fn().mockResolvedValue('I cannot classify this content.');
|
|
176
|
+
|
|
177
|
+
const scores = await classifyByLlm('test', invoker);
|
|
178
|
+
|
|
179
|
+
for (const score of scores) {
|
|
180
|
+
expect(score.confidence).toBe(0);
|
|
181
|
+
}
|
|
182
|
+
});
|
|
183
|
+
|
|
184
|
+
it('returns zero scores when invoker returns an array instead of object', async () => {
|
|
185
|
+
const invoker: LlmInvoker = vi.fn().mockResolvedValue('[1, 2, 3]');
|
|
186
|
+
|
|
187
|
+
const scores = await classifyByLlm('test', invoker);
|
|
188
|
+
|
|
189
|
+
for (const score of scores) {
|
|
190
|
+
expect(score.confidence).toBe(0);
|
|
191
|
+
}
|
|
192
|
+
});
|
|
193
|
+
});
|
|
194
|
+
|
|
195
|
+
// -----------------------------------------------------------------------
|
|
196
|
+
// Category filtering
|
|
197
|
+
// -----------------------------------------------------------------------
|
|
198
|
+
|
|
199
|
+
describe('category filtering', () => {
|
|
200
|
+
it('returns scores only for requested categories', async () => {
|
|
201
|
+
const invoker: LlmInvoker = vi.fn().mockResolvedValue(
|
|
202
|
+
JSON.stringify({
|
|
203
|
+
toxic: true,
|
|
204
|
+
injection: true,
|
|
205
|
+
nsfw: false,
|
|
206
|
+
threat: false,
|
|
207
|
+
confidence: 0.9,
|
|
208
|
+
})
|
|
209
|
+
);
|
|
210
|
+
|
|
211
|
+
const subset: ClassifierCategory[] = ['toxic', 'injection'];
|
|
212
|
+
const scores = await classifyByLlm('targeted', invoker, subset);
|
|
213
|
+
|
|
214
|
+
expect(scores).toHaveLength(2);
|
|
215
|
+
expect(scores.map((s) => s.name)).toEqual(['toxic', 'injection']);
|
|
216
|
+
});
|
|
217
|
+
});
|
|
218
|
+
});
|
|
219
|
+
|
|
220
|
+
// ---------------------------------------------------------------------------
|
|
221
|
+
// Tests — LLM tier via MLClassifierGuardrail.classify()
|
|
222
|
+
// ---------------------------------------------------------------------------
|
|
223
|
+
|
|
224
|
+
describe('LLM tier — via guardrail classify()', () => {
|
|
225
|
+
beforeEach(() => {
|
|
226
|
+
vi.clearAllMocks();
|
|
227
|
+
});
|
|
228
|
+
|
|
229
|
+
it('falls through to LLM when ONNX is unavailable', async () => {
|
|
230
|
+
const invoker: LlmInvoker = vi.fn().mockResolvedValue(
|
|
231
|
+
JSON.stringify({
|
|
232
|
+
toxic: true,
|
|
233
|
+
injection: false,
|
|
234
|
+
nsfw: false,
|
|
235
|
+
threat: false,
|
|
236
|
+
confidence: 0.9,
|
|
237
|
+
})
|
|
238
|
+
);
|
|
239
|
+
|
|
240
|
+
const guardrail = new MLClassifierGuardrail({ llmInvoker: invoker });
|
|
241
|
+
const result = await guardrail.classify('test');
|
|
242
|
+
|
|
243
|
+
expect(result.source).toBe('llm');
|
|
244
|
+
expect(invoker).toHaveBeenCalledTimes(1);
|
|
245
|
+
});
|
|
246
|
+
|
|
247
|
+
it('result.flagged is true when LLM detects a category above threshold', async () => {
|
|
248
|
+
const invoker: LlmInvoker = vi.fn().mockResolvedValue(
|
|
249
|
+
JSON.stringify({
|
|
250
|
+
toxic: true,
|
|
251
|
+
injection: false,
|
|
252
|
+
nsfw: false,
|
|
253
|
+
threat: false,
|
|
254
|
+
confidence: 0.85,
|
|
255
|
+
})
|
|
256
|
+
);
|
|
257
|
+
|
|
258
|
+
const guardrail = new MLClassifierGuardrail({ llmInvoker: invoker });
|
|
259
|
+
const result = await guardrail.classify('abusive text');
|
|
260
|
+
|
|
261
|
+
expect(result.flagged).toBe(true);
|
|
262
|
+
expect(result.source).toBe('llm');
|
|
263
|
+
|
|
264
|
+
const toxic = result.categories.find((c) => c.name === 'toxic');
|
|
265
|
+
expect(toxic?.confidence).toBe(0.85);
|
|
266
|
+
});
|
|
267
|
+
});
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
// @ts-nocheck
|
|
2
|
+
import { describe, it, expect } from 'vitest';
|
|
3
|
+
import { createExtensionPack } from '../src/index';
|
|
4
|
+
|
|
5
|
+
describe('ML Classifiers Extension Pack', () => {
|
|
6
|
+
// -------------------------------------------------------------------------
|
|
7
|
+
// Pack structure
|
|
8
|
+
// -------------------------------------------------------------------------
|
|
9
|
+
|
|
10
|
+
it('createExtensionPack returns correct structure', () => {
|
|
11
|
+
const pack = createExtensionPack({ options: {} } as any);
|
|
12
|
+
|
|
13
|
+
expect(pack.name).toBe('ml-classifiers');
|
|
14
|
+
expect(pack.version).toBe('1.0.0');
|
|
15
|
+
expect(pack.descriptors).toHaveLength(2);
|
|
16
|
+
|
|
17
|
+
const kinds = pack.descriptors.map((d) => d.kind);
|
|
18
|
+
expect(kinds).toContain('guardrail');
|
|
19
|
+
expect(kinds).toContain('tool');
|
|
20
|
+
|
|
21
|
+
const ids = pack.descriptors.map((d) => d.id);
|
|
22
|
+
expect(ids).toContain('ml-classifier-guardrail');
|
|
23
|
+
expect(ids).toContain('classify_content');
|
|
24
|
+
});
|
|
25
|
+
|
|
26
|
+
// -------------------------------------------------------------------------
|
|
27
|
+
// Guardrail — keyword fallback detection
|
|
28
|
+
// -------------------------------------------------------------------------
|
|
29
|
+
|
|
30
|
+
describe('guardrail evaluateInput', () => {
|
|
31
|
+
function getGuardrail() {
|
|
32
|
+
const pack = createExtensionPack({ options: {} } as any);
|
|
33
|
+
const desc = pack.descriptors.find((d) => d.kind === 'guardrail');
|
|
34
|
+
return desc!.payload as any;
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
it('detects highly toxic input', async () => {
|
|
38
|
+
const guardrail = getGuardrail();
|
|
39
|
+
// Use strongly toxic text that both ONNX and keyword fallback would flag
|
|
40
|
+
const result = await guardrail.evaluateInput({
|
|
41
|
+
input: { textInput: 'You are a stupid idiot, kill yourself you moron' },
|
|
42
|
+
});
|
|
43
|
+
|
|
44
|
+
expect(result).not.toBeNull();
|
|
45
|
+
expect(['flag', 'block']).toContain(result!.action);
|
|
46
|
+
});
|
|
47
|
+
|
|
48
|
+
it('allows clean input through', async () => {
|
|
49
|
+
const guardrail = getGuardrail();
|
|
50
|
+
const result = await guardrail.evaluateInput({
|
|
51
|
+
input: { textInput: 'What is the weather like today?' },
|
|
52
|
+
});
|
|
53
|
+
|
|
54
|
+
expect(result).toBeNull();
|
|
55
|
+
});
|
|
56
|
+
});
|
|
57
|
+
});
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file onnx-tier.spec.ts
|
|
3
|
+
* @description Tests for the ONNX (Tier 1) classification path in MLClassifierGuardrail.
|
|
4
|
+
*
|
|
5
|
+
* Mocks `@huggingface/transformers` to return controlled toxic-bert label/score
|
|
6
|
+
* pairs, verifying that ONNX results are mapped to internal categories, threshold
|
|
7
|
+
* logic works, and the result carries `source: 'onnx'`.
|
|
8
|
+
*/
|
|
9
|
+
|
|
10
|
+
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
11
|
+
|
|
12
|
+
// ---------------------------------------------------------------------------
|
|
13
|
+
// Mock @huggingface/transformers
|
|
14
|
+
// ---------------------------------------------------------------------------
|
|
15
|
+
|
|
16
|
+
/**
|
|
17
|
+
* Callable mock that stands in for the ONNX text-classification pipeline.
|
|
18
|
+
* Tests configure its return value per-case via `mockResolvedValue`.
|
|
19
|
+
*/
|
|
20
|
+
const mockPipelineCall = vi.fn();
|
|
21
|
+
|
|
22
|
+
vi.mock('@huggingface/transformers', () => ({
|
|
23
|
+
pipeline: vi.fn().mockResolvedValue({
|
|
24
|
+
_call: mockPipelineCall,
|
|
25
|
+
}),
|
|
26
|
+
}));
|
|
27
|
+
|
|
28
|
+
// ---------------------------------------------------------------------------
|
|
29
|
+
// SUT
|
|
30
|
+
// ---------------------------------------------------------------------------
|
|
31
|
+
|
|
32
|
+
import { MLClassifierGuardrail } from '../src/MLClassifierGuardrail';
|
|
33
|
+
|
|
34
|
+
// ---------------------------------------------------------------------------
|
|
35
|
+
// Helpers
|
|
36
|
+
// ---------------------------------------------------------------------------
|
|
37
|
+
|
|
38
|
+
/** Builds a fresh guardrail instance for each test (resets cached pipeline). */
|
|
39
|
+
function createGuardrail(options?: any): MLClassifierGuardrail {
|
|
40
|
+
return new MLClassifierGuardrail(options);
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
// ---------------------------------------------------------------------------
|
|
44
|
+
// Tests
|
|
45
|
+
// ---------------------------------------------------------------------------
|
|
46
|
+
|
|
47
|
+
describe('ONNX tier classification', () => {
|
|
48
|
+
beforeEach(() => {
|
|
49
|
+
vi.clearAllMocks();
|
|
50
|
+
});
|
|
51
|
+
|
|
52
|
+
// -------------------------------------------------------------------------
|
|
53
|
+
// Label mapping
|
|
54
|
+
// -------------------------------------------------------------------------
|
|
55
|
+
|
|
56
|
+
describe('label-to-category mapping', () => {
|
|
57
|
+
it('maps toxic-bert labels to internal categories', async () => {
|
|
58
|
+
mockPipelineCall.mockResolvedValue([
|
|
59
|
+
{ label: 'toxic', score: 0.92 },
|
|
60
|
+
{ label: 'severe_toxic', score: 0.45 },
|
|
61
|
+
{ label: 'obscene', score: 0.78 },
|
|
62
|
+
{ label: 'insult', score: 0.65 },
|
|
63
|
+
{ label: 'identity_hate', score: 0.3 },
|
|
64
|
+
{ label: 'threat', score: 0.15 },
|
|
65
|
+
]);
|
|
66
|
+
|
|
67
|
+
const guardrail = createGuardrail();
|
|
68
|
+
const result = await guardrail.classify('test text');
|
|
69
|
+
|
|
70
|
+
expect(result.source).toBe('onnx');
|
|
71
|
+
|
|
72
|
+
// toxic = max(toxic:0.92, severe_toxic:0.45, insult:0.65, identity_hate:0.30)
|
|
73
|
+
const toxic = result.categories.find((c) => c.name === 'toxic');
|
|
74
|
+
expect(toxic?.confidence).toBe(0.92);
|
|
75
|
+
|
|
76
|
+
// nsfw = max(obscene:0.78)
|
|
77
|
+
const nsfw = result.categories.find((c) => c.name === 'nsfw');
|
|
78
|
+
expect(nsfw?.confidence).toBe(0.78);
|
|
79
|
+
|
|
80
|
+
// threat = max(threat:0.15)
|
|
81
|
+
const threat = result.categories.find((c) => c.name === 'threat');
|
|
82
|
+
expect(threat?.confidence).toBe(0.15);
|
|
83
|
+
|
|
84
|
+
// injection is not produced by toxic-bert, stays at 0
|
|
85
|
+
const injection = result.categories.find((c) => c.name === 'injection');
|
|
86
|
+
expect(injection?.confidence).toBe(0);
|
|
87
|
+
});
|
|
88
|
+
|
|
89
|
+
it('takes max score when multiple ONNX labels map to the same category', async () => {
|
|
90
|
+
mockPipelineCall.mockResolvedValue([
|
|
91
|
+
{ label: 'toxic', score: 0.3 },
|
|
92
|
+
{ label: 'severe_toxic', score: 0.85 },
|
|
93
|
+
{ label: 'insult', score: 0.6 },
|
|
94
|
+
{ label: 'identity_hate', score: 0.7 },
|
|
95
|
+
{ label: 'obscene', score: 0.1 },
|
|
96
|
+
{ label: 'threat', score: 0.05 },
|
|
97
|
+
]);
|
|
98
|
+
|
|
99
|
+
const guardrail = createGuardrail();
|
|
100
|
+
const result = await guardrail.classify('some text');
|
|
101
|
+
|
|
102
|
+
// toxic category = max(0.30, 0.85, 0.60, 0.70) = 0.85
|
|
103
|
+
const toxic = result.categories.find((c) => c.name === 'toxic');
|
|
104
|
+
expect(toxic?.confidence).toBe(0.85);
|
|
105
|
+
});
|
|
106
|
+
|
|
107
|
+
it('handles labels with mixed case and whitespace', async () => {
|
|
108
|
+
mockPipelineCall.mockResolvedValue([
|
|
109
|
+
{ label: 'Toxic', score: 0.7 },
|
|
110
|
+
{ label: 'OBSCENE', score: 0.6 },
|
|
111
|
+
{ label: 'identity hate', score: 0.5 },
|
|
112
|
+
{ label: 'THREAT', score: 0.3 },
|
|
113
|
+
{ label: 'severe toxic', score: 0.2 },
|
|
114
|
+
{ label: 'INSULT', score: 0.1 },
|
|
115
|
+
]);
|
|
116
|
+
|
|
117
|
+
const guardrail = createGuardrail();
|
|
118
|
+
const result = await guardrail.classify('some text');
|
|
119
|
+
|
|
120
|
+
// "identity hate" (with a space) is lowercased and underscored to identity_hate, which maps to toxic
|
|
121
|
+
// toxic = max(toxic:0.7, identity_hate:0.5, severe_toxic:0.2, insult:0.1) = 0.7
|
|
122
|
+
const toxic = result.categories.find((c) => c.name === 'toxic');
|
|
123
|
+
expect(toxic?.confidence).toBe(0.7);
|
|
124
|
+
|
|
125
|
+
// obscene -> nsfw = 0.6
|
|
126
|
+
const nsfw = result.categories.find((c) => c.name === 'nsfw');
|
|
127
|
+
expect(nsfw?.confidence).toBe(0.6);
|
|
128
|
+
});
|
|
129
|
+
});
|
|
130
|
+
|
|
131
|
+
// -------------------------------------------------------------------------
|
|
132
|
+
// Threshold behaviour
|
|
133
|
+
// -------------------------------------------------------------------------
|
|
134
|
+
|
|
135
|
+
describe('threshold behaviour', () => {
|
|
136
|
+
it('flags content above default flag threshold (0.5)', async () => {
|
|
137
|
+
mockPipelineCall.mockResolvedValue([
|
|
138
|
+
{ label: 'toxic', score: 0.65 },
|
|
139
|
+
{ label: 'severe_toxic', score: 0.0 },
|
|
140
|
+
{ label: 'obscene', score: 0.0 },
|
|
141
|
+
{ label: 'insult', score: 0.0 },
|
|
142
|
+
{ label: 'identity_hate', score: 0.0 },
|
|
143
|
+
{ label: 'threat', score: 0.0 },
|
|
144
|
+
]);
|
|
145
|
+
|
|
146
|
+
const guardrail = createGuardrail();
|
|
147
|
+
const result = await guardrail.classify('mildly toxic text');
|
|
148
|
+
|
|
149
|
+
expect(result.flagged).toBe(true);
|
|
150
|
+
expect(result.source).toBe('onnx');
|
|
151
|
+
});
|
|
152
|
+
|
|
153
|
+
it('does not flag content below all thresholds', async () => {
|
|
154
|
+
mockPipelineCall.mockResolvedValue([
|
|
155
|
+
{ label: 'toxic', score: 0.1 },
|
|
156
|
+
{ label: 'severe_toxic', score: 0.05 },
|
|
157
|
+
{ label: 'obscene', score: 0.02 },
|
|
158
|
+
{ label: 'insult', score: 0.08 },
|
|
159
|
+
{ label: 'identity_hate', score: 0.01 },
|
|
160
|
+
{ label: 'threat', score: 0.03 },
|
|
161
|
+
]);
|
|
162
|
+
|
|
163
|
+
const guardrail = createGuardrail();
|
|
164
|
+
const result = await guardrail.classify('perfectly clean text');
|
|
165
|
+
|
|
166
|
+
expect(result.flagged).toBe(false);
|
|
167
|
+
expect(result.source).toBe('onnx');
|
|
168
|
+
});
|
|
169
|
+
|
|
170
|
+
it('respects per-category threshold overrides', async () => {
|
|
171
|
+
mockPipelineCall.mockResolvedValue([
|
|
172
|
+
{ label: 'toxic', score: 0.35 },
|
|
173
|
+
{ label: 'severe_toxic', score: 0.0 },
|
|
174
|
+
{ label: 'obscene', score: 0.0 },
|
|
175
|
+
{ label: 'insult', score: 0.0 },
|
|
176
|
+
{ label: 'identity_hate', score: 0.0 },
|
|
177
|
+
{ label: 'threat', score: 0.0 },
|
|
178
|
+
]);
|
|
179
|
+
|
|
180
|
+
// Lower the toxic flag threshold so 0.35 exceeds it
|
|
181
|
+
const guardrail = createGuardrail({
|
|
182
|
+
thresholds: { toxic: { flag: 0.3 } },
|
|
183
|
+
});
|
|
184
|
+
const result = await guardrail.classify('borderline text');
|
|
185
|
+
|
|
186
|
+
expect(result.flagged).toBe(true);
|
|
187
|
+
});
|
|
188
|
+
|
|
189
|
+
it('does not flag when score equals the threshold exactly', async () => {
|
|
190
|
+
mockPipelineCall.mockResolvedValue([
|
|
191
|
+
{ label: 'toxic', score: 0.5 },
|
|
192
|
+
{ label: 'severe_toxic', score: 0.0 },
|
|
193
|
+
{ label: 'obscene', score: 0.0 },
|
|
194
|
+
{ label: 'insult', score: 0.0 },
|
|
195
|
+
{ label: 'identity_hate', score: 0.0 },
|
|
196
|
+
{ label: 'threat', score: 0.0 },
|
|
197
|
+
]);
|
|
198
|
+
|
|
199
|
+
const guardrail = createGuardrail();
|
|
200
|
+
const result = await guardrail.classify('edge case text');
|
|
201
|
+
|
|
202
|
+
// Flag threshold is 0.5, score is exactly 0.5 -> ">" check, not ">="
|
|
203
|
+
expect(result.flagged).toBe(false);
|
|
204
|
+
});
|
|
205
|
+
});
|
|
206
|
+
|
|
207
|
+
// -------------------------------------------------------------------------
|
|
208
|
+
// Result source
|
|
209
|
+
// -------------------------------------------------------------------------
|
|
210
|
+
|
|
211
|
+
describe('result source', () => {
|
|
212
|
+
it('always returns source: onnx when pipeline succeeds', async () => {
|
|
213
|
+
mockPipelineCall.mockResolvedValue([
|
|
214
|
+
{ label: 'toxic', score: 0.0 },
|
|
215
|
+
{ label: 'severe_toxic', score: 0.0 },
|
|
216
|
+
{ label: 'obscene', score: 0.0 },
|
|
217
|
+
{ label: 'insult', score: 0.0 },
|
|
218
|
+
{ label: 'identity_hate', score: 0.0 },
|
|
219
|
+
{ label: 'threat', score: 0.0 },
|
|
220
|
+
]);
|
|
221
|
+
|
|
222
|
+
const guardrail = createGuardrail();
|
|
223
|
+
const result = await guardrail.classify('hello');
|
|
224
|
+
|
|
225
|
+
expect(result.source).toBe('onnx');
|
|
226
|
+
});
|
|
227
|
+
});
|
|
228
|
+
|
|
229
|
+
// -------------------------------------------------------------------------
|
|
230
|
+
// All four categories present
|
|
231
|
+
// -------------------------------------------------------------------------
|
|
232
|
+
|
|
233
|
+
describe('category completeness', () => {
|
|
234
|
+
it('returns scores for all four categories', async () => {
|
|
235
|
+
mockPipelineCall.mockResolvedValue([
|
|
236
|
+
{ label: 'toxic', score: 0.1 },
|
|
237
|
+
{ label: 'severe_toxic', score: 0.0 },
|
|
238
|
+
{ label: 'obscene', score: 0.2 },
|
|
239
|
+
{ label: 'insult', score: 0.0 },
|
|
240
|
+
{ label: 'identity_hate', score: 0.0 },
|
|
241
|
+
{ label: 'threat', score: 0.3 },
|
|
242
|
+
]);
|
|
243
|
+
|
|
244
|
+
const guardrail = createGuardrail();
|
|
245
|
+
const result = await guardrail.classify('test');
|
|
246
|
+
|
|
247
|
+
const names = result.categories.map((c) => c.name);
|
|
248
|
+
expect(names).toContain('toxic');
|
|
249
|
+
expect(names).toContain('injection');
|
|
250
|
+
expect(names).toContain('nsfw');
|
|
251
|
+
expect(names).toContain('threat');
|
|
252
|
+
expect(result.categories).toHaveLength(4);
|
|
253
|
+
});
|
|
254
|
+
});
|
|
255
|
+
});
|