@framers/agentos-ext-ml-classifiers 0.2.1 → 0.3.1

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. package/.github/workflows/ci.yml +20 -0
  2. package/.github/workflows/release.yml +37 -0
  3. package/.releaserc.json +9 -0
  4. package/LICENSE +96 -21
  5. package/README.md +72 -0
  6. package/dist/MLClassifierGuardrail.d.ts.map +1 -1
  7. package/dist/MLClassifierGuardrail.js +14 -6
  8. package/dist/MLClassifierGuardrail.js.map +1 -1
  9. package/dist/index.js +3 -3
  10. package/dist/keyword-classifier.js +1 -1
  11. package/dist/llm-classifier.js +1 -1
  12. package/package.json +5 -13
  13. package/scripts/fix-esm-imports.mjs +181 -0
  14. package/src/MLClassifierGuardrail.ts +38 -5
  15. package/test/llm-tier.spec.ts +267 -0
  16. package/test/ml-classifiers.spec.ts +57 -0
  17. package/test/onnx-tier.spec.ts +255 -0
  18. package/test/tier-fallthrough.spec.ts +185 -0
  19. package/vitest.config.ts +18 -7
  20. package/CHANGELOG.md +0 -18
  21. package/dist/ClassifierOrchestrator.d.ts +0 -126
  22. package/dist/ClassifierOrchestrator.d.ts.map +0 -1
  23. package/dist/ClassifierOrchestrator.js +0 -239
  24. package/dist/ClassifierOrchestrator.js.map +0 -1
  25. package/dist/IContentClassifier.d.ts +0 -117
  26. package/dist/IContentClassifier.d.ts.map +0 -1
  27. package/dist/IContentClassifier.js +0 -22
  28. package/dist/IContentClassifier.js.map +0 -1
  29. package/dist/SlidingWindowBuffer.d.ts +0 -213
  30. package/dist/SlidingWindowBuffer.d.ts.map +0 -1
  31. package/dist/SlidingWindowBuffer.js +0 -246
  32. package/dist/SlidingWindowBuffer.js.map +0 -1
  33. package/dist/classifiers/InjectionClassifier.d.ts +0 -126
  34. package/dist/classifiers/InjectionClassifier.d.ts.map +0 -1
  35. package/dist/classifiers/InjectionClassifier.js +0 -210
  36. package/dist/classifiers/InjectionClassifier.js.map +0 -1
  37. package/dist/classifiers/JailbreakClassifier.d.ts +0 -124
  38. package/dist/classifiers/JailbreakClassifier.d.ts.map +0 -1
  39. package/dist/classifiers/JailbreakClassifier.js +0 -208
  40. package/dist/classifiers/JailbreakClassifier.js.map +0 -1
  41. package/dist/classifiers/ToxicityClassifier.d.ts +0 -125
  42. package/dist/classifiers/ToxicityClassifier.d.ts.map +0 -1
  43. package/dist/classifiers/ToxicityClassifier.js +0 -212
  44. package/dist/classifiers/ToxicityClassifier.js.map +0 -1
  45. package/dist/classifiers/WorkerClassifierProxy.d.ts +0 -158
  46. package/dist/classifiers/WorkerClassifierProxy.d.ts.map +0 -1
  47. package/dist/classifiers/WorkerClassifierProxy.js +0 -268
  48. package/dist/classifiers/WorkerClassifierProxy.js.map +0 -1
  49. package/dist/worker/classifier-worker.d.ts +0 -49
  50. package/dist/worker/classifier-worker.d.ts.map +0 -1
  51. package/dist/worker/classifier-worker.js +0 -180
  52. package/dist/worker/classifier-worker.js.map +0 -1
  53. package/src/ClassifierOrchestrator.ts +0 -290
  54. package/src/IContentClassifier.ts +0 -124
  55. package/src/SlidingWindowBuffer.ts +0 -384
  56. package/src/classifiers/InjectionClassifier.ts +0 -261
  57. package/src/classifiers/JailbreakClassifier.ts +0 -259
  58. package/src/classifiers/ToxicityClassifier.ts +0 -263
  59. package/src/classifiers/WorkerClassifierProxy.ts +0 -366
  60. package/src/worker/classifier-worker.ts +0 -267
  61. package/test/ClassifierOrchestrator.spec.ts +0 -365
  62. package/test/ClassifyContentTool.spec.ts +0 -226
  63. package/test/InjectionClassifier.spec.ts +0 -263
  64. package/test/JailbreakClassifier.spec.ts +0 -295
  65. package/test/MLClassifierGuardrail.spec.ts +0 -486
  66. package/test/SlidingWindowBuffer.spec.ts +0 -391
  67. package/test/ToxicityClassifier.spec.ts +0 -268
  68. package/test/WorkerClassifierProxy.spec.ts +0 -303
  69. package/test/index.spec.ts +0 -431
@@ -27,6 +27,7 @@ import type {
27
27
  GuardrailInputPayload,
28
28
  GuardrailOutputPayload,
29
29
  GuardrailEvaluationResult,
30
+ AgentOSFinalResponseChunk,
30
31
  } from '@framers/agentos';
31
32
  import { GuardrailAction } from '@framers/agentos';
32
33
  import { AgentOSResponseChunkType } from '@framers/agentos';
@@ -40,6 +41,27 @@ import { ALL_CATEGORIES } from './types';
40
41
  import { classifyByKeywords } from './keyword-classifier';
41
42
  import { classifyByLlm } from './llm-classifier';
42
43
 
44
+ // ---------------------------------------------------------------------------
45
+ // HuggingFace / ONNX pipeline types
46
+ // ---------------------------------------------------------------------------
47
+
48
+ /**
49
+ * A single label + score pair returned by a HuggingFace text-classification
50
+ * pipeline. The `label` is the model's class name (e.g. `"toxic"`,
51
+ * `"obscene"`) and `score` is the softmax probability in [0, 1].
52
+ */
53
+ interface OnnxClassificationLabel {
54
+ label: string;
55
+ score: number;
56
+ }
57
+
58
+ /**
59
+ * Callable returned by `pipeline('text-classification', ...)` from the
60
+ * `@huggingface/transformers` package. Returns all label scores for the
61
+ * given text.
62
+ */
63
+ type OnnxTextClassificationPipeline = (text: string) => Promise<OnnxClassificationLabel[]>;
64
+
43
65
  // ---------------------------------------------------------------------------
44
66
  // MLClassifierGuardrail
45
67
  // ---------------------------------------------------------------------------
@@ -91,7 +113,7 @@ export class MLClassifierGuardrail implements IGuardrailService {
91
113
  * `null` means we already tried and failed to load the module.
92
114
  * `undefined` means we have not tried yet.
93
115
  */
94
- private onnxPipeline: any | null | undefined = undefined;
116
+ private onnxPipeline: OnnxTextClassificationPipeline | null | undefined = undefined;
95
117
 
96
118
  // -----------------------------------------------------------------------
97
119
  // Constructor
@@ -159,7 +181,8 @@ export class MLClassifierGuardrail implements IGuardrailService {
159
181
  return null;
160
182
  }
161
183
 
162
- const text = (chunk as any).text ?? (chunk as any).content ?? '';
184
+ const finalChunk = chunk as AgentOSFinalResponseChunk;
185
+ const text = finalChunk.finalResponseText ?? '';
163
186
  if (typeof text !== 'string' || text.length === 0) return null;
164
187
 
165
188
  const result = await this.classify(text);
@@ -222,11 +245,21 @@ export class MLClassifierGuardrail implements IGuardrailService {
222
245
  try {
223
246
  // Dynamic import so the optional dependency does not fail at boot.
224
247
  const transformers = await import('@huggingface/transformers');
225
- this.onnxPipeline = await transformers.pipeline(
248
+ const pipelineInstance = await transformers.pipeline(
226
249
  'text-classification',
227
250
  'Xenova/toxic-bert',
228
251
  { device: 'cpu' }
229
252
  );
253
+ // The HuggingFace pipeline is callable as a function. We always
254
+ // request all labels (top_k higher than any model's label count
255
+ // causes HF to return every label, matching `top_k: null`).
256
+ // Cast through unknown because the Pipeline union type is too
257
+ // wide for the inferred call signature here.
258
+ const callable = pipelineInstance as unknown as (
259
+ text: string,
260
+ opts: { top_k: number },
261
+ ) => Promise<OnnxClassificationLabel[]>;
262
+ this.onnxPipeline = (text: string) => callable(text, { top_k: 9999 });
230
263
  } catch {
231
264
  // Module not installed or model load failed — mark as unavailable.
232
265
  this.onnxPipeline = null;
@@ -235,7 +268,7 @@ export class MLClassifierGuardrail implements IGuardrailService {
235
268
  }
236
269
 
237
270
  try {
238
- const raw = await this.onnxPipeline(text, { topk: null });
271
+ const raw = await this.onnxPipeline(text);
239
272
 
240
273
  // Map ONNX labels to our categories.
241
274
  const scores = this.mapOnnxScores(raw);
@@ -263,7 +296,7 @@ export class MLClassifierGuardrail implements IGuardrailService {
263
296
  *
264
297
  * @internal
265
298
  */
266
- private mapOnnxScores(raw: any[]): CategoryScore[] {
299
+ private mapOnnxScores(raw: OnnxClassificationLabel[]): CategoryScore[] {
267
300
  /** Map of ONNX label -> our category. */
268
301
  const labelMap: Record<string, ClassifierCategory> = {
269
302
  toxic: 'toxic',
@@ -0,0 +1,267 @@
1
+ /**
2
+ * @file llm-tier.spec.ts
3
+ * @description Tests for the LLM-as-judge (Tier 2) classification path.
4
+ *
5
+ * Exercises `classifyByLlm()` directly — verifying that the structured
6
+ * classification prompt is sent to the invoker, JSON and markdown-wrapped JSON
7
+ * are parsed correctly, and failures produce zero-confidence scores.
8
+ */
9
+
10
+ import { describe, it, expect, vi, beforeEach } from 'vitest';
11
+
12
+ // Force ONNX unavailable so classify() falls through to LLM tier
13
+ vi.mock('@huggingface/transformers', () => {
14
+ throw new Error('ONNX not available');
15
+ });
16
+
17
+ import { classifyByLlm } from '../src/llm-classifier';
18
+ import type { LlmInvoker, ClassifierCategory } from '../src/types';
19
+ import { MLClassifierGuardrail } from '../src/MLClassifierGuardrail';
20
+
21
+ // ---------------------------------------------------------------------------
22
+ // Tests — classifyByLlm directly
23
+ // ---------------------------------------------------------------------------
24
+
25
+ describe('LLM tier — classifyByLlm()', () => {
26
+ // -----------------------------------------------------------------------
27
+ // Prompt verification
28
+ // -----------------------------------------------------------------------
29
+
30
+ describe('prompt sent to invoker', () => {
31
+ it('sends a system prompt and the user text to the invoker', async () => {
32
+ const invoker: LlmInvoker = vi.fn().mockResolvedValue(
33
+ JSON.stringify({
34
+ toxic: false,
35
+ injection: false,
36
+ nsfw: false,
37
+ threat: false,
38
+ confidence: 0.9,
39
+ })
40
+ );
41
+
42
+ await classifyByLlm('Hello world', invoker);
43
+
44
+ expect(invoker).toHaveBeenCalledTimes(1);
45
+
46
+ const [systemPrompt, userMessage] = (invoker as any).mock.calls[0];
47
+ expect(typeof systemPrompt).toBe('string');
48
+ expect(systemPrompt).toContain('content safety classifier');
49
+ expect(userMessage).toBe('Hello world');
50
+ });
51
+ });
52
+
53
+ // -----------------------------------------------------------------------
54
+ // JSON parsing
55
+ // -----------------------------------------------------------------------
56
+
57
+ describe('JSON response parsing', () => {
58
+ it('parses a clean JSON response into category scores', async () => {
59
+ const invoker: LlmInvoker = vi.fn().mockResolvedValue(
60
+ JSON.stringify({
61
+ toxic: true,
62
+ injection: false,
63
+ nsfw: false,
64
+ threat: true,
65
+ confidence: 0.85,
66
+ })
67
+ );
68
+
69
+ const scores = await classifyByLlm('some bad text', invoker);
70
+
71
+ expect(scores).toHaveLength(4);
72
+
73
+ const toxic = scores.find((s) => s.name === 'toxic');
74
+ expect(toxic?.confidence).toBe(0.85);
75
+
76
+ const injection = scores.find((s) => s.name === 'injection');
77
+ expect(injection?.confidence).toBe(0);
78
+
79
+ const nsfw = scores.find((s) => s.name === 'nsfw');
80
+ expect(nsfw?.confidence).toBe(0);
81
+
82
+ const threat = scores.find((s) => s.name === 'threat');
83
+ expect(threat?.confidence).toBe(0.85);
84
+ });
85
+
86
+ it('uses default confidence (0.7) when confidence is omitted', async () => {
87
+ const invoker: LlmInvoker = vi
88
+ .fn()
89
+ .mockResolvedValue(
90
+ JSON.stringify({ toxic: true, injection: false, nsfw: false, threat: false })
91
+ );
92
+
93
+ const scores = await classifyByLlm('abusive text', invoker);
94
+
95
+ const toxic = scores.find((s) => s.name === 'toxic');
96
+ expect(toxic?.confidence).toBe(0.7);
97
+ });
98
+
99
+ it('clamps confidence to [0, 1]', async () => {
100
+ const invoker: LlmInvoker = vi.fn().mockResolvedValue(
101
+ JSON.stringify({
102
+ toxic: true,
103
+ injection: false,
104
+ nsfw: false,
105
+ threat: false,
106
+ confidence: 5.0,
107
+ })
108
+ );
109
+
110
+ const scores = await classifyByLlm('test', invoker);
111
+ const toxic = scores.find((s) => s.name === 'toxic');
112
+ expect(toxic?.confidence).toBeLessThanOrEqual(1.0);
113
+ });
114
+ });
115
+
116
+ // -----------------------------------------------------------------------
117
+ // Markdown-wrapped JSON
118
+ // -----------------------------------------------------------------------
119
+
120
+ describe('markdown-wrapped JSON handling', () => {
121
+ it('strips ```json fences before parsing', async () => {
122
+ const invoker: LlmInvoker = vi
123
+ .fn()
124
+ .mockResolvedValue(
125
+ '```json\n{"toxic": true, "injection": false, "nsfw": false, "threat": false, "confidence": 0.9}\n```'
126
+ );
127
+
128
+ const scores = await classifyByLlm('wrapped response', invoker);
129
+ const toxic = scores.find((s) => s.name === 'toxic');
130
+ expect(toxic?.confidence).toBe(0.9);
131
+ });
132
+
133
+ it('strips bare ``` fences (no language tag)', async () => {
134
+ const invoker: LlmInvoker = vi
135
+ .fn()
136
+ .mockResolvedValue(
137
+ '```\n{"toxic": false, "injection": true, "nsfw": false, "threat": false, "confidence": 0.75}\n```'
138
+ );
139
+
140
+ const scores = await classifyByLlm('injection attempt', invoker);
141
+ const injection = scores.find((s) => s.name === 'injection');
142
+ expect(injection?.confidence).toBe(0.75);
143
+ });
144
+
145
+ it('handles trailing commas in LLM output', async () => {
146
+ const invoker: LlmInvoker = vi
147
+ .fn()
148
+ .mockResolvedValue(
149
+ '{"toxic": true, "injection": false, "nsfw": false, "threat": false, "confidence": 0.8,}'
150
+ );
151
+
152
+ const scores = await classifyByLlm('trailing comma', invoker);
153
+ const toxic = scores.find((s) => s.name === 'toxic');
154
+ expect(toxic?.confidence).toBe(0.8);
155
+ });
156
+ });
157
+
158
+ // -----------------------------------------------------------------------
159
+ // Failure modes
160
+ // -----------------------------------------------------------------------
161
+
162
+ describe('failure handling', () => {
163
+ it('returns zero scores when invoker throws', async () => {
164
+ const invoker: LlmInvoker = vi.fn().mockRejectedValue(new Error('LLM unavailable'));
165
+
166
+ const scores = await classifyByLlm('test', invoker);
167
+
168
+ expect(scores).toHaveLength(4);
169
+ for (const score of scores) {
170
+ expect(score.confidence).toBe(0);
171
+ }
172
+ });
173
+
174
+ it('returns zero scores when invoker returns unparseable text', async () => {
175
+ const invoker: LlmInvoker = vi.fn().mockResolvedValue('I cannot classify this content.');
176
+
177
+ const scores = await classifyByLlm('test', invoker);
178
+
179
+ for (const score of scores) {
180
+ expect(score.confidence).toBe(0);
181
+ }
182
+ });
183
+
184
+ it('returns zero scores when invoker returns an array instead of object', async () => {
185
+ const invoker: LlmInvoker = vi.fn().mockResolvedValue('[1, 2, 3]');
186
+
187
+ const scores = await classifyByLlm('test', invoker);
188
+
189
+ for (const score of scores) {
190
+ expect(score.confidence).toBe(0);
191
+ }
192
+ });
193
+ });
194
+
195
+ // -----------------------------------------------------------------------
196
+ // Category filtering
197
+ // -----------------------------------------------------------------------
198
+
199
+ describe('category filtering', () => {
200
+ it('returns scores only for requested categories', async () => {
201
+ const invoker: LlmInvoker = vi.fn().mockResolvedValue(
202
+ JSON.stringify({
203
+ toxic: true,
204
+ injection: true,
205
+ nsfw: false,
206
+ threat: false,
207
+ confidence: 0.9,
208
+ })
209
+ );
210
+
211
+ const subset: ClassifierCategory[] = ['toxic', 'injection'];
212
+ const scores = await classifyByLlm('targeted', invoker, subset);
213
+
214
+ expect(scores).toHaveLength(2);
215
+ expect(scores.map((s) => s.name)).toEqual(['toxic', 'injection']);
216
+ });
217
+ });
218
+ });
219
+
220
+ // ---------------------------------------------------------------------------
221
+ // Tests — LLM tier via MLClassifierGuardrail.classify()
222
+ // ---------------------------------------------------------------------------
223
+
224
+ describe('LLM tier — via guardrail classify()', () => {
225
+ beforeEach(() => {
226
+ vi.clearAllMocks();
227
+ });
228
+
229
+ it('falls through to LLM when ONNX is unavailable', async () => {
230
+ const invoker: LlmInvoker = vi.fn().mockResolvedValue(
231
+ JSON.stringify({
232
+ toxic: true,
233
+ injection: false,
234
+ nsfw: false,
235
+ threat: false,
236
+ confidence: 0.9,
237
+ })
238
+ );
239
+
240
+ const guardrail = new MLClassifierGuardrail({ llmInvoker: invoker });
241
+ const result = await guardrail.classify('test');
242
+
243
+ expect(result.source).toBe('llm');
244
+ expect(invoker).toHaveBeenCalledTimes(1);
245
+ });
246
+
247
+ it('result.flagged is true when LLM detects a category above threshold', async () => {
248
+ const invoker: LlmInvoker = vi.fn().mockResolvedValue(
249
+ JSON.stringify({
250
+ toxic: true,
251
+ injection: false,
252
+ nsfw: false,
253
+ threat: false,
254
+ confidence: 0.85,
255
+ })
256
+ );
257
+
258
+ const guardrail = new MLClassifierGuardrail({ llmInvoker: invoker });
259
+ const result = await guardrail.classify('abusive text');
260
+
261
+ expect(result.flagged).toBe(true);
262
+ expect(result.source).toBe('llm');
263
+
264
+ const toxic = result.categories.find((c) => c.name === 'toxic');
265
+ expect(toxic?.confidence).toBe(0.85);
266
+ });
267
+ });
@@ -0,0 +1,57 @@
1
+ // @ts-nocheck
2
+ import { describe, it, expect } from 'vitest';
3
+ import { createExtensionPack } from '../src/index';
4
+
5
+ describe('ML Classifiers Extension Pack', () => {
6
+ // -------------------------------------------------------------------------
7
+ // Pack structure
8
+ // -------------------------------------------------------------------------
9
+
10
+ it('createExtensionPack returns correct structure', () => {
11
+ const pack = createExtensionPack({ options: {} } as any);
12
+
13
+ expect(pack.name).toBe('ml-classifiers');
14
+ expect(pack.version).toBe('1.0.0');
15
+ expect(pack.descriptors).toHaveLength(2);
16
+
17
+ const kinds = pack.descriptors.map((d) => d.kind);
18
+ expect(kinds).toContain('guardrail');
19
+ expect(kinds).toContain('tool');
20
+
21
+ const ids = pack.descriptors.map((d) => d.id);
22
+ expect(ids).toContain('ml-classifier-guardrail');
23
+ expect(ids).toContain('classify_content');
24
+ });
25
+
26
+ // -------------------------------------------------------------------------
27
+ // Guardrail — keyword fallback detection
28
+ // -------------------------------------------------------------------------
29
+
30
+ describe('guardrail evaluateInput', () => {
31
+ function getGuardrail() {
32
+ const pack = createExtensionPack({ options: {} } as any);
33
+ const desc = pack.descriptors.find((d) => d.kind === 'guardrail');
34
+ return desc!.payload as any;
35
+ }
36
+
37
+ it('detects highly toxic input', async () => {
38
+ const guardrail = getGuardrail();
39
+ // Use strongly toxic text that both ONNX and keyword fallback would flag
40
+ const result = await guardrail.evaluateInput({
41
+ input: { textInput: 'You are a stupid idiot, kill yourself you moron' },
42
+ });
43
+
44
+ expect(result).not.toBeNull();
45
+ expect(['flag', 'block']).toContain(result!.action);
46
+ });
47
+
48
+ it('allows clean input through', async () => {
49
+ const guardrail = getGuardrail();
50
+ const result = await guardrail.evaluateInput({
51
+ input: { textInput: 'What is the weather like today?' },
52
+ });
53
+
54
+ expect(result).toBeNull();
55
+ });
56
+ });
57
+ });
@@ -0,0 +1,255 @@
1
+ /**
2
+ * @file onnx-tier.spec.ts
3
+ * @description Tests for the ONNX (Tier 1) classification path in MLClassifierGuardrail.
4
+ *
5
+ * Mocks `@huggingface/transformers` to return controlled toxic-bert label/score
6
+ * pairs, verifying that ONNX results are mapped to internal categories, threshold
7
+ * logic works, and the result carries `source: 'onnx'`.
8
+ */
9
+
10
+ import { describe, it, expect, vi, beforeEach } from 'vitest';
11
+
12
+ // ---------------------------------------------------------------------------
13
+ // Mock @huggingface/transformers
14
+ // ---------------------------------------------------------------------------
15
+
16
+ /**
17
+ * Callable mock that stands in for the ONNX text-classification pipeline.
18
+ * Tests configure its return value per-case via `mockResolvedValue`.
19
+ */
20
+ const mockPipelineCall = vi.fn();
21
+
22
+ vi.mock('@huggingface/transformers', () => ({
23
+ pipeline: vi.fn().mockResolvedValue({
24
+ _call: mockPipelineCall,
25
+ }),
26
+ }));
27
+
28
+ // ---------------------------------------------------------------------------
29
+ // SUT
30
+ // ---------------------------------------------------------------------------
31
+
32
+ import { MLClassifierGuardrail } from '../src/MLClassifierGuardrail';
33
+
34
+ // ---------------------------------------------------------------------------
35
+ // Helpers
36
+ // ---------------------------------------------------------------------------
37
+
38
+ /** Builds a fresh guardrail instance for each test (resets cached pipeline). */
39
+ function createGuardrail(options?: any): MLClassifierGuardrail {
40
+ return new MLClassifierGuardrail(options);
41
+ }
42
+
43
+ // ---------------------------------------------------------------------------
44
+ // Tests
45
+ // ---------------------------------------------------------------------------
46
+
47
+ describe('ONNX tier classification', () => {
48
+ beforeEach(() => {
49
+ vi.clearAllMocks();
50
+ });
51
+
52
+ // -------------------------------------------------------------------------
53
+ // Label mapping
54
+ // -------------------------------------------------------------------------
55
+
56
+ describe('label-to-category mapping', () => {
57
+ it('maps toxic-bert labels to internal categories', async () => {
58
+ mockPipelineCall.mockResolvedValue([
59
+ { label: 'toxic', score: 0.92 },
60
+ { label: 'severe_toxic', score: 0.45 },
61
+ { label: 'obscene', score: 0.78 },
62
+ { label: 'insult', score: 0.65 },
63
+ { label: 'identity_hate', score: 0.3 },
64
+ { label: 'threat', score: 0.15 },
65
+ ]);
66
+
67
+ const guardrail = createGuardrail();
68
+ const result = await guardrail.classify('test text');
69
+
70
+ expect(result.source).toBe('onnx');
71
+
72
+ // toxic = max(toxic:0.92, severe_toxic:0.45, insult:0.65, identity_hate:0.30)
73
+ const toxic = result.categories.find((c) => c.name === 'toxic');
74
+ expect(toxic?.confidence).toBe(0.92);
75
+
76
+ // nsfw = max(obscene:0.78)
77
+ const nsfw = result.categories.find((c) => c.name === 'nsfw');
78
+ expect(nsfw?.confidence).toBe(0.78);
79
+
80
+ // threat = max(threat:0.15)
81
+ const threat = result.categories.find((c) => c.name === 'threat');
82
+ expect(threat?.confidence).toBe(0.15);
83
+
84
+ // injection is not produced by toxic-bert, stays at 0
85
+ const injection = result.categories.find((c) => c.name === 'injection');
86
+ expect(injection?.confidence).toBe(0);
87
+ });
88
+
89
+ it('takes max score when multiple ONNX labels map to the same category', async () => {
90
+ mockPipelineCall.mockResolvedValue([
91
+ { label: 'toxic', score: 0.3 },
92
+ { label: 'severe_toxic', score: 0.85 },
93
+ { label: 'insult', score: 0.6 },
94
+ { label: 'identity_hate', score: 0.7 },
95
+ { label: 'obscene', score: 0.1 },
96
+ { label: 'threat', score: 0.05 },
97
+ ]);
98
+
99
+ const guardrail = createGuardrail();
100
+ const result = await guardrail.classify('some text');
101
+
102
+ // toxic category = max(0.30, 0.85, 0.60, 0.70) = 0.85
103
+ const toxic = result.categories.find((c) => c.name === 'toxic');
104
+ expect(toxic?.confidence).toBe(0.85);
105
+ });
106
+
107
+ it('handles labels with mixed case and whitespace', async () => {
108
+ mockPipelineCall.mockResolvedValue([
109
+ { label: 'Toxic', score: 0.7 },
110
+ { label: 'OBSCENE', score: 0.6 },
111
+ { label: 'identity hate', score: 0.5 },
112
+ { label: 'THREAT', score: 0.3 },
113
+ { label: 'severe toxic', score: 0.2 },
114
+ { label: 'INSULT', score: 0.1 },
115
+ ]);
116
+
117
+ const guardrail = createGuardrail();
118
+ const result = await guardrail.classify('some text');
119
+
120
+ // identity_hate is "identity hate" with space -> lowered + underscore = identity_hate -> toxic
121
+ // toxic = max(toxic:0.7, identity_hate:0.5, severe_toxic:0.2, insult:0.1) = 0.7
122
+ const toxic = result.categories.find((c) => c.name === 'toxic');
123
+ expect(toxic?.confidence).toBe(0.7);
124
+
125
+ // obscene -> nsfw = 0.6
126
+ const nsfw = result.categories.find((c) => c.name === 'nsfw');
127
+ expect(nsfw?.confidence).toBe(0.6);
128
+ });
129
+ });
130
+
131
+ // -------------------------------------------------------------------------
132
+ // Threshold behaviour
133
+ // -------------------------------------------------------------------------
134
+
135
+ describe('threshold behaviour', () => {
136
+ it('flags content above default flag threshold (0.5)', async () => {
137
+ mockPipelineCall.mockResolvedValue([
138
+ { label: 'toxic', score: 0.65 },
139
+ { label: 'severe_toxic', score: 0.0 },
140
+ { label: 'obscene', score: 0.0 },
141
+ { label: 'insult', score: 0.0 },
142
+ { label: 'identity_hate', score: 0.0 },
143
+ { label: 'threat', score: 0.0 },
144
+ ]);
145
+
146
+ const guardrail = createGuardrail();
147
+ const result = await guardrail.classify('mildly toxic text');
148
+
149
+ expect(result.flagged).toBe(true);
150
+ expect(result.source).toBe('onnx');
151
+ });
152
+
153
+ it('does not flag content below all thresholds', async () => {
154
+ mockPipelineCall.mockResolvedValue([
155
+ { label: 'toxic', score: 0.1 },
156
+ { label: 'severe_toxic', score: 0.05 },
157
+ { label: 'obscene', score: 0.02 },
158
+ { label: 'insult', score: 0.08 },
159
+ { label: 'identity_hate', score: 0.01 },
160
+ { label: 'threat', score: 0.03 },
161
+ ]);
162
+
163
+ const guardrail = createGuardrail();
164
+ const result = await guardrail.classify('perfectly clean text');
165
+
166
+ expect(result.flagged).toBe(false);
167
+ expect(result.source).toBe('onnx');
168
+ });
169
+
170
+ it('respects per-category threshold overrides', async () => {
171
+ mockPipelineCall.mockResolvedValue([
172
+ { label: 'toxic', score: 0.35 },
173
+ { label: 'severe_toxic', score: 0.0 },
174
+ { label: 'obscene', score: 0.0 },
175
+ { label: 'insult', score: 0.0 },
176
+ { label: 'identity_hate', score: 0.0 },
177
+ { label: 'threat', score: 0.0 },
178
+ ]);
179
+
180
+ // Lower the toxic flag threshold so 0.35 exceeds it
181
+ const guardrail = createGuardrail({
182
+ thresholds: { toxic: { flag: 0.3 } },
183
+ });
184
+ const result = await guardrail.classify('borderline text');
185
+
186
+ expect(result.flagged).toBe(true);
187
+ });
188
+
189
+ it('does not flag when score equals the threshold exactly', async () => {
190
+ mockPipelineCall.mockResolvedValue([
191
+ { label: 'toxic', score: 0.5 },
192
+ { label: 'severe_toxic', score: 0.0 },
193
+ { label: 'obscene', score: 0.0 },
194
+ { label: 'insult', score: 0.0 },
195
+ { label: 'identity_hate', score: 0.0 },
196
+ { label: 'threat', score: 0.0 },
197
+ ]);
198
+
199
+ const guardrail = createGuardrail();
200
+ const result = await guardrail.classify('edge case text');
201
+
202
+ // Flag threshold is 0.5, score is exactly 0.5 -> ">" check, not ">="
203
+ expect(result.flagged).toBe(false);
204
+ });
205
+ });
206
+
207
+ // -------------------------------------------------------------------------
208
+ // Result source
209
+ // -------------------------------------------------------------------------
210
+
211
+ describe('result source', () => {
212
+ it('always returns source: onnx when pipeline succeeds', async () => {
213
+ mockPipelineCall.mockResolvedValue([
214
+ { label: 'toxic', score: 0.0 },
215
+ { label: 'severe_toxic', score: 0.0 },
216
+ { label: 'obscene', score: 0.0 },
217
+ { label: 'insult', score: 0.0 },
218
+ { label: 'identity_hate', score: 0.0 },
219
+ { label: 'threat', score: 0.0 },
220
+ ]);
221
+
222
+ const guardrail = createGuardrail();
223
+ const result = await guardrail.classify('hello');
224
+
225
+ expect(result.source).toBe('onnx');
226
+ });
227
+ });
228
+
229
+ // -------------------------------------------------------------------------
230
+ // All four categories present
231
+ // -------------------------------------------------------------------------
232
+
233
+ describe('category completeness', () => {
234
+ it('returns scores for all four categories', async () => {
235
+ mockPipelineCall.mockResolvedValue([
236
+ { label: 'toxic', score: 0.1 },
237
+ { label: 'severe_toxic', score: 0.0 },
238
+ { label: 'obscene', score: 0.2 },
239
+ { label: 'insult', score: 0.0 },
240
+ { label: 'identity_hate', score: 0.0 },
241
+ { label: 'threat', score: 0.3 },
242
+ ]);
243
+
244
+ const guardrail = createGuardrail();
245
+ const result = await guardrail.classify('test');
246
+
247
+ const names = result.categories.map((c) => c.name);
248
+ expect(names).toContain('toxic');
249
+ expect(names).toContain('injection');
250
+ expect(names).toContain('nsfw');
251
+ expect(names).toContain('threat');
252
+ expect(result.categories).toHaveLength(4);
253
+ });
254
+ });
255
+ });