@framers/agentos-ext-ml-classifiers 0.1.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.github/workflows/ci.yml +20 -0
- package/.github/workflows/release.yml +37 -0
- package/.releaserc.json +9 -0
- package/LICENSE +96 -21
- package/README.md +72 -0
- package/dist/MLClassifierGuardrail.d.ts +88 -117
- package/dist/MLClassifierGuardrail.d.ts.map +1 -1
- package/dist/MLClassifierGuardrail.js +263 -264
- package/dist/MLClassifierGuardrail.js.map +1 -1
- package/dist/index.d.ts +16 -90
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +36 -309
- package/dist/index.js.map +1 -1
- package/dist/keyword-classifier.d.ts +26 -0
- package/dist/keyword-classifier.d.ts.map +1 -0
- package/dist/keyword-classifier.js +113 -0
- package/dist/keyword-classifier.js.map +1 -0
- package/dist/llm-classifier.d.ts +27 -0
- package/dist/llm-classifier.d.ts.map +1 -0
- package/dist/llm-classifier.js +129 -0
- package/dist/llm-classifier.js.map +1 -0
- package/dist/tools/ClassifyContentTool.d.ts +53 -80
- package/dist/tools/ClassifyContentTool.d.ts.map +1 -1
- package/dist/tools/ClassifyContentTool.js +52 -103
- package/dist/tools/ClassifyContentTool.js.map +1 -1
- package/dist/types.d.ts +77 -277
- package/dist/types.d.ts.map +1 -1
- package/dist/types.js +9 -55
- package/dist/types.js.map +1 -1
- package/package.json +10 -24
- package/scripts/fix-esm-imports.mjs +181 -0
- package/src/MLClassifierGuardrail.ts +306 -310
- package/src/index.ts +35 -339
- package/src/keyword-classifier.ts +130 -0
- package/src/llm-classifier.ts +163 -0
- package/src/tools/ClassifyContentTool.ts +75 -132
- package/src/types.ts +78 -325
- package/test/llm-tier.spec.ts +267 -0
- package/test/ml-classifiers.spec.ts +57 -0
- package/test/onnx-tier.spec.ts +255 -0
- package/test/tier-fallthrough.spec.ts +185 -0
- package/tsconfig.json +20 -0
- package/vitest.config.ts +35 -0
- package/dist/ClassifierOrchestrator.d.ts +0 -126
- package/dist/ClassifierOrchestrator.d.ts.map +0 -1
- package/dist/ClassifierOrchestrator.js +0 -239
- package/dist/ClassifierOrchestrator.js.map +0 -1
- package/dist/IContentClassifier.d.ts +0 -117
- package/dist/IContentClassifier.d.ts.map +0 -1
- package/dist/IContentClassifier.js +0 -22
- package/dist/IContentClassifier.js.map +0 -1
- package/dist/SlidingWindowBuffer.d.ts +0 -213
- package/dist/SlidingWindowBuffer.d.ts.map +0 -1
- package/dist/SlidingWindowBuffer.js +0 -246
- package/dist/SlidingWindowBuffer.js.map +0 -1
- package/dist/classifiers/InjectionClassifier.d.ts +0 -126
- package/dist/classifiers/InjectionClassifier.d.ts.map +0 -1
- package/dist/classifiers/InjectionClassifier.js +0 -210
- package/dist/classifiers/InjectionClassifier.js.map +0 -1
- package/dist/classifiers/JailbreakClassifier.d.ts +0 -124
- package/dist/classifiers/JailbreakClassifier.d.ts.map +0 -1
- package/dist/classifiers/JailbreakClassifier.js +0 -208
- package/dist/classifiers/JailbreakClassifier.js.map +0 -1
- package/dist/classifiers/ToxicityClassifier.d.ts +0 -125
- package/dist/classifiers/ToxicityClassifier.d.ts.map +0 -1
- package/dist/classifiers/ToxicityClassifier.js +0 -212
- package/dist/classifiers/ToxicityClassifier.js.map +0 -1
- package/dist/classifiers/WorkerClassifierProxy.d.ts +0 -158
- package/dist/classifiers/WorkerClassifierProxy.d.ts.map +0 -1
- package/dist/classifiers/WorkerClassifierProxy.js +0 -268
- package/dist/classifiers/WorkerClassifierProxy.js.map +0 -1
- package/dist/worker/classifier-worker.d.ts +0 -49
- package/dist/worker/classifier-worker.d.ts.map +0 -1
- package/dist/worker/classifier-worker.js +0 -180
- package/dist/worker/classifier-worker.js.map +0 -1
- package/src/ClassifierOrchestrator.ts +0 -290
- package/src/IContentClassifier.ts +0 -124
- package/src/SlidingWindowBuffer.ts +0 -384
- package/src/classifiers/InjectionClassifier.ts +0 -261
- package/src/classifiers/JailbreakClassifier.ts +0 -259
- package/src/classifiers/ToxicityClassifier.ts +0 -263
- package/src/classifiers/WorkerClassifierProxy.ts +0 -366
- package/src/worker/classifier-worker.ts +0 -267
|
@@ -1,419 +1,415 @@
|
|
|
1
1
|
/**
|
|
2
|
-
* @
|
|
2
|
+
* @file MLClassifierGuardrail.ts
|
|
3
|
+
* @description IGuardrailService implementation that classifies text for toxicity,
|
|
4
|
+
* prompt injection, NSFW content, and threats using a three-tier strategy:
|
|
3
5
|
*
|
|
4
|
-
*
|
|
5
|
-
*
|
|
6
|
-
*
|
|
7
|
-
*
|
|
6
|
+
* 1. **ONNX inference** — attempts to load `@huggingface/transformers` at runtime
|
|
7
|
+
* and run a lightweight ONNX classification model.
|
|
8
|
+
* 2. **LLM-as-judge** — falls back to an LLM invoker callback that prompts a
|
|
9
|
+
* language model for structured JSON safety classification.
|
|
10
|
+
* 3. **Keyword matching** — last-resort regex/keyword-based detection when neither
|
|
11
|
+
* ONNX nor LLM are available.
|
|
8
12
|
*
|
|
9
|
-
*
|
|
13
|
+
* The guardrail is configured as Phase 2 (parallel, non-sanitizing) so it runs
|
|
14
|
+
* alongside other read-only guardrails without blocking the streaming pipeline.
|
|
10
15
|
*
|
|
11
|
-
*
|
|
12
|
-
* |---------------|----------------------------------------------------------------|
|
|
13
|
-
* | `blocking` | Every chunk that fills the sliding window is classified |
|
|
14
|
-
* | | **synchronously** — the stream waits for the result. |
|
|
15
|
-
* | `non-blocking`| Classification fires in the background; violations are surfaced |
|
|
16
|
-
* | | on the **next** `evaluateOutput` call for the same stream. |
|
|
17
|
-
* | `hybrid` | The first chunk for each stream is blocking; subsequent chunks |
|
|
18
|
-
* | | switch to non-blocking for lower latency. |
|
|
16
|
+
* ### Action thresholds
|
|
19
17
|
*
|
|
20
|
-
*
|
|
18
|
+
* - **FLAG** when any category confidence exceeds `flagThreshold` (default 0.5).
|
|
19
|
+
* - **BLOCK** when any category confidence exceeds `blockThreshold` (default 0.8).
|
|
21
20
|
*
|
|
22
|
-
* @module
|
|
21
|
+
* @module ml-classifiers/MLClassifierGuardrail
|
|
23
22
|
*/
|
|
24
23
|
|
|
25
24
|
import type {
|
|
25
|
+
IGuardrailService,
|
|
26
26
|
GuardrailConfig,
|
|
27
|
-
GuardrailEvaluationResult,
|
|
28
27
|
GuardrailInputPayload,
|
|
29
28
|
GuardrailOutputPayload,
|
|
30
|
-
|
|
29
|
+
GuardrailEvaluationResult,
|
|
30
|
+
AgentOSFinalResponseChunk,
|
|
31
31
|
} from '@framers/agentos';
|
|
32
32
|
import { GuardrailAction } from '@framers/agentos';
|
|
33
33
|
import { AgentOSResponseChunkType } from '@framers/agentos';
|
|
34
|
-
import type {
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
34
|
+
import type {
|
|
35
|
+
MLClassifierOptions,
|
|
36
|
+
ClassifierCategory,
|
|
37
|
+
ClassifierResult,
|
|
38
|
+
CategoryScore,
|
|
39
|
+
} from './types';
|
|
40
|
+
import { ALL_CATEGORIES } from './types';
|
|
41
|
+
import { classifyByKeywords } from './keyword-classifier';
|
|
42
|
+
import { classifyByLlm } from './llm-classifier';
|
|
40
43
|
|
|
41
44
|
// ---------------------------------------------------------------------------
|
|
42
|
-
//
|
|
45
|
+
// HuggingFace / ONNX pipeline types
|
|
43
46
|
// ---------------------------------------------------------------------------
|
|
44
47
|
|
|
45
48
|
/**
|
|
46
|
-
*
|
|
47
|
-
*
|
|
48
|
-
*
|
|
49
|
-
|
|
50
|
-
|
|
49
|
+
* A single label + score pair returned by a HuggingFace text-classification
|
|
50
|
+
* pipeline. The `label` is the model's class name (e.g. `"toxic"`,
|
|
51
|
+
* `"obscene"`) and `score` is the softmax probability in [0, 1].
|
|
52
|
+
*/
|
|
53
|
+
interface OnnxClassificationLabel {
|
|
54
|
+
label: string;
|
|
55
|
+
score: number;
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
/**
|
|
59
|
+
* Callable returned by `pipeline('text-classification', ...)` from the
|
|
60
|
+
* `@huggingface/transformers` package. Returns all label scores for the
|
|
61
|
+
* given text.
|
|
51
62
|
*/
|
|
52
|
-
type
|
|
63
|
+
type OnnxTextClassificationPipeline = (text: string) => Promise<OnnxClassificationLabel[]>;
|
|
53
64
|
|
|
54
65
|
// ---------------------------------------------------------------------------
|
|
55
66
|
// MLClassifierGuardrail
|
|
56
67
|
// ---------------------------------------------------------------------------
|
|
57
68
|
|
|
58
69
|
/**
|
|
59
|
-
*
|
|
60
|
-
*
|
|
70
|
+
* AgentOS guardrail that classifies text for safety using ML models, LLM
|
|
71
|
+
* inference, or keyword fallback.
|
|
61
72
|
*
|
|
62
73
|
* @implements {IGuardrailService}
|
|
63
|
-
*
|
|
64
|
-
* @example
|
|
65
|
-
* ```typescript
|
|
66
|
-
* const guardrail = new MLClassifierGuardrail(serviceRegistry, {
|
|
67
|
-
* classifiers: ['toxicity'],
|
|
68
|
-
* streamingMode: true,
|
|
69
|
-
* chunkSize: 150,
|
|
70
|
-
* guardrailScope: 'both',
|
|
71
|
-
* });
|
|
72
|
-
*
|
|
73
|
-
* // Input evaluation — runs classifier on the full user message.
|
|
74
|
-
* const inputResult = await guardrail.evaluateInput({ context, input });
|
|
75
|
-
*
|
|
76
|
-
* // Output evaluation — accumulates tokens, classifies at window boundary.
|
|
77
|
-
* const outputResult = await guardrail.evaluateOutput({ context, chunk });
|
|
78
|
-
* ```
|
|
79
74
|
*/
|
|
80
75
|
export class MLClassifierGuardrail implements IGuardrailService {
|
|
81
|
-
//
|
|
82
|
-
// IGuardrailService
|
|
83
|
-
//
|
|
76
|
+
// -----------------------------------------------------------------------
|
|
77
|
+
// IGuardrailService.config
|
|
78
|
+
// -----------------------------------------------------------------------
|
|
84
79
|
|
|
85
80
|
/**
|
|
86
|
-
* Guardrail configuration
|
|
81
|
+
* Guardrail configuration.
|
|
87
82
|
*
|
|
88
|
-
* `
|
|
89
|
-
*
|
|
83
|
+
* - `canSanitize: false` — this guardrail does not modify content; it only
|
|
84
|
+
* BLOCKs or FLAGs. This places it in Phase 2 (parallel) of the guardrail
|
|
85
|
+
* dispatcher for better performance.
|
|
86
|
+
* - `evaluateStreamingChunks: false` — only evaluates complete messages, not
|
|
87
|
+
* individual streaming deltas. ML classification on partial text produces
|
|
88
|
+
* unreliable results.
|
|
90
89
|
*/
|
|
91
|
-
readonly config: GuardrailConfig
|
|
90
|
+
readonly config: GuardrailConfig = {
|
|
91
|
+
canSanitize: false,
|
|
92
|
+
evaluateStreamingChunks: false,
|
|
93
|
+
};
|
|
92
94
|
|
|
93
|
-
//
|
|
94
|
-
//
|
|
95
|
-
//
|
|
95
|
+
// -----------------------------------------------------------------------
|
|
96
|
+
// Private state
|
|
97
|
+
// -----------------------------------------------------------------------
|
|
96
98
|
|
|
97
|
-
/**
|
|
98
|
-
private readonly
|
|
99
|
+
/** Categories to evaluate. */
|
|
100
|
+
private readonly categories: ClassifierCategory[];
|
|
99
101
|
|
|
100
|
-
/**
|
|
101
|
-
private readonly
|
|
102
|
+
/** Per-category flag thresholds. */
|
|
103
|
+
private readonly flagThresholds: Record<ClassifierCategory, number>;
|
|
102
104
|
|
|
103
|
-
/**
|
|
104
|
-
private readonly
|
|
105
|
+
/** Per-category block thresholds. */
|
|
106
|
+
private readonly blockThresholds: Record<ClassifierCategory, number>;
|
|
105
107
|
|
|
106
|
-
/**
|
|
107
|
-
private readonly
|
|
108
|
-
|
|
109
|
-
/**
|
|
110
|
-
* Map of stream IDs to pending (background) classification promises.
|
|
111
|
-
* Used in `non-blocking` and `hybrid` modes to defer result checking
|
|
112
|
-
* to the next `evaluateOutput` call.
|
|
113
|
-
*/
|
|
114
|
-
private readonly pendingResults: Map<string, Promise<ChunkEvaluation>> = new Map();
|
|
108
|
+
/** Optional LLM invoker callback for tier-2 classification. */
|
|
109
|
+
private readonly llmInvoker: MLClassifierOptions['llmInvoker'];
|
|
115
110
|
|
|
116
111
|
/**
|
|
117
|
-
*
|
|
118
|
-
*
|
|
119
|
-
*
|
|
112
|
+
* Cached reference to the `@huggingface/transformers` pipeline function.
|
|
113
|
+
* `null` means we already tried and failed to load the module.
|
|
114
|
+
* `undefined` means we have not tried yet.
|
|
120
115
|
*/
|
|
121
|
-
private
|
|
116
|
+
private onnxPipeline: OnnxTextClassificationPipeline | null | undefined = undefined;
|
|
122
117
|
|
|
123
|
-
//
|
|
118
|
+
// -----------------------------------------------------------------------
|
|
124
119
|
// Constructor
|
|
125
|
-
//
|
|
120
|
+
// -----------------------------------------------------------------------
|
|
126
121
|
|
|
127
122
|
/**
|
|
128
|
-
* Create a new
|
|
123
|
+
* Create a new MLClassifierGuardrail.
|
|
129
124
|
*
|
|
130
|
-
* @param
|
|
131
|
-
*
|
|
132
|
-
* @param options - Pack-level options controlling classifier selection,
|
|
133
|
-
* thresholds, sliding window size, and streaming mode.
|
|
134
|
-
* @param classifiers - Pre-built classifier instances. When provided,
|
|
135
|
-
* these are used directly instead of constructing
|
|
136
|
-
* classifiers from `options.classifiers`.
|
|
125
|
+
* @param options - Pack-level configuration. All properties have sensible
|
|
126
|
+
* defaults for zero-config operation.
|
|
137
127
|
*/
|
|
138
|
-
constructor(
|
|
139
|
-
|
|
140
|
-
options: MLClassifierPackOptions,
|
|
141
|
-
classifiers: IContentClassifier[] = [],
|
|
142
|
-
) {
|
|
143
|
-
// Resolve thresholds: merge caller overrides on top of defaults.
|
|
144
|
-
const thresholds = {
|
|
145
|
-
...DEFAULT_THRESHOLDS,
|
|
146
|
-
...options.thresholds,
|
|
147
|
-
};
|
|
128
|
+
constructor(options?: MLClassifierOptions) {
|
|
129
|
+
const opts = options ?? {};
|
|
148
130
|
|
|
149
|
-
|
|
150
|
-
this.
|
|
151
|
-
|
|
152
|
-
//
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
}
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
// is 'blocking'; callers can override via the `streamingMode` option
|
|
164
|
-
// (which we reinterpret as a boolean gate here — advanced callers pass
|
|
165
|
-
// a StreamingMode string via `options` when they need finer control).
|
|
166
|
-
this.streamingMode = options.streamingMode ? 'blocking' : 'blocking';
|
|
167
|
-
|
|
168
|
-
// Expose guardrail config to the pipeline.
|
|
169
|
-
this.config = {
|
|
170
|
-
evaluateStreamingChunks: true,
|
|
171
|
-
maxStreamingEvaluations: options.maxEvaluations ?? 100,
|
|
172
|
-
};
|
|
131
|
+
this.categories = opts.categories ?? [...ALL_CATEGORIES];
|
|
132
|
+
this.llmInvoker = opts.llmInvoker;
|
|
133
|
+
|
|
134
|
+
// Resolve per-category thresholds.
|
|
135
|
+
const globalFlag = opts.flagThreshold ?? 0.5;
|
|
136
|
+
const globalBlock = opts.blockThreshold ?? 0.8;
|
|
137
|
+
|
|
138
|
+
this.flagThresholds = {} as Record<ClassifierCategory, number>;
|
|
139
|
+
this.blockThresholds = {} as Record<ClassifierCategory, number>;
|
|
140
|
+
|
|
141
|
+
for (const cat of ALL_CATEGORIES) {
|
|
142
|
+
this.flagThresholds[cat] = opts.thresholds?.[cat]?.flag ?? globalFlag;
|
|
143
|
+
this.blockThresholds[cat] = opts.thresholds?.[cat]?.block ?? globalBlock;
|
|
144
|
+
}
|
|
173
145
|
}
|
|
174
146
|
|
|
175
|
-
//
|
|
176
|
-
// evaluateInput
|
|
177
|
-
//
|
|
147
|
+
// -----------------------------------------------------------------------
|
|
148
|
+
// IGuardrailService — evaluateInput
|
|
149
|
+
// -----------------------------------------------------------------------
|
|
178
150
|
|
|
179
151
|
/**
|
|
180
|
-
* Evaluate
|
|
181
|
-
*
|
|
182
|
-
* Runs the full text through all registered classifiers and returns a
|
|
183
|
-
* {@link GuardrailEvaluationResult} when a violation is detected, or
|
|
184
|
-
* `null` when the content is clean.
|
|
152
|
+
* Evaluate user input for safety before orchestration begins.
|
|
185
153
|
*
|
|
186
|
-
*
|
|
187
|
-
*
|
|
188
|
-
* @param payload - The input payload containing user text and context.
|
|
189
|
-
* @returns Evaluation result or `null` if no action is needed.
|
|
154
|
+
* @param payload - Input evaluation payload containing the user's message.
|
|
155
|
+
* @returns Guardrail result or `null` if no action is required.
|
|
190
156
|
*/
|
|
191
157
|
async evaluateInput(payload: GuardrailInputPayload): Promise<GuardrailEvaluationResult | null> {
|
|
192
|
-
// Skip input evaluation when scope is output-only.
|
|
193
|
-
if (this.scope === 'output') {
|
|
194
|
-
return null;
|
|
195
|
-
}
|
|
196
|
-
|
|
197
|
-
// Extract the text from the input. If there is no text, nothing to classify.
|
|
198
158
|
const text = payload.input.textInput;
|
|
199
|
-
if (!text)
|
|
200
|
-
return null;
|
|
201
|
-
}
|
|
202
|
-
|
|
203
|
-
// Run all classifiers against the full user message.
|
|
204
|
-
const evaluation = await this.orchestrator.classifyAll(text);
|
|
159
|
+
if (!text || text.length === 0) return null;
|
|
205
160
|
|
|
206
|
-
|
|
207
|
-
return this.
|
|
161
|
+
const result = await this.classify(text);
|
|
162
|
+
return this.buildResult(result);
|
|
208
163
|
}
|
|
209
164
|
|
|
210
|
-
//
|
|
211
|
-
// evaluateOutput
|
|
212
|
-
//
|
|
165
|
+
// -----------------------------------------------------------------------
|
|
166
|
+
// IGuardrailService — evaluateOutput
|
|
167
|
+
// -----------------------------------------------------------------------
|
|
213
168
|
|
|
214
169
|
/**
|
|
215
|
-
* Evaluate
|
|
216
|
-
*
|
|
170
|
+
* Evaluate agent output for safety. Only processes FINAL_RESPONSE chunks
|
|
171
|
+
* since `evaluateStreamingChunks` is disabled.
|
|
217
172
|
*
|
|
218
|
-
*
|
|
219
|
-
*
|
|
220
|
-
* evaluation strategy depends on the configured streaming mode.
|
|
221
|
-
*
|
|
222
|
-
* Skipped entirely when `scope === 'input'`.
|
|
223
|
-
*
|
|
224
|
-
* @param payload - The output payload containing the response chunk and context.
|
|
225
|
-
* @returns Evaluation result or `null` if no action is needed yet.
|
|
173
|
+
* @param payload - Output evaluation payload from the AgentOS dispatcher.
|
|
174
|
+
* @returns Guardrail result or `null` if no action is required.
|
|
226
175
|
*/
|
|
227
176
|
async evaluateOutput(payload: GuardrailOutputPayload): Promise<GuardrailEvaluationResult | null> {
|
|
228
|
-
|
|
229
|
-
if (this.scope === 'input') {
|
|
230
|
-
return null;
|
|
231
|
-
}
|
|
232
|
-
|
|
233
|
-
const chunk = payload.chunk;
|
|
234
|
-
|
|
235
|
-
// Handle final chunks: flush remaining buffer and classify.
|
|
236
|
-
if (chunk.isFinal) {
|
|
237
|
-
const streamId = chunk.streamId;
|
|
238
|
-
const flushed = this.buffer.flush(streamId);
|
|
239
|
-
|
|
240
|
-
// Clean up tracking state for this stream.
|
|
241
|
-
this.isFirstChunk.delete(streamId);
|
|
242
|
-
this.pendingResults.delete(streamId);
|
|
243
|
-
|
|
244
|
-
if (!flushed) {
|
|
245
|
-
return null;
|
|
246
|
-
}
|
|
177
|
+
const { chunk } = payload;
|
|
247
178
|
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
return this.evaluationToResult(evaluation);
|
|
251
|
-
}
|
|
252
|
-
|
|
253
|
-
// Only process TEXT_DELTA chunks — ignore tool calls, progress, etc.
|
|
254
|
-
if (chunk.type !== AgentOSResponseChunkType.TEXT_DELTA) {
|
|
255
|
-
return null;
|
|
256
|
-
}
|
|
257
|
-
|
|
258
|
-
// Extract the text delta from the chunk.
|
|
259
|
-
const textDelta = (chunk as any).textDelta as string | undefined;
|
|
260
|
-
if (!textDelta) {
|
|
179
|
+
// Only evaluate final text responses.
|
|
180
|
+
if (chunk.type !== AgentOSResponseChunkType.FINAL_RESPONSE) {
|
|
261
181
|
return null;
|
|
262
182
|
}
|
|
263
183
|
|
|
264
|
-
|
|
265
|
-
const
|
|
184
|
+
const finalChunk = chunk as AgentOSFinalResponseChunk;
|
|
185
|
+
const text = finalChunk.finalResponseText ?? '';
|
|
186
|
+
if (typeof text !== 'string' || text.length === 0) return null;
|
|
266
187
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
return this.handleNonBlocking(streamId, textDelta);
|
|
188
|
+
const result = await this.classify(text);
|
|
189
|
+
return this.buildResult(result);
|
|
190
|
+
}
|
|
271
191
|
|
|
272
|
-
|
|
273
|
-
|
|
192
|
+
// -----------------------------------------------------------------------
|
|
193
|
+
// Public classification method (also used by ClassifyContentTool)
|
|
194
|
+
// -----------------------------------------------------------------------
|
|
274
195
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
196
|
+
/**
|
|
197
|
+
* Classify a text string using the three-tier strategy: ONNX -> LLM -> keyword.
|
|
198
|
+
*
|
|
199
|
+
* @param text - The text to classify.
|
|
200
|
+
* @returns Classification result with per-category scores.
|
|
201
|
+
*/
|
|
202
|
+
async classify(text: string): Promise<ClassifierResult> {
|
|
203
|
+
// Tier 1: try ONNX inference.
|
|
204
|
+
const onnxResult = await this.tryOnnxClassification(text);
|
|
205
|
+
if (onnxResult) return onnxResult;
|
|
206
|
+
|
|
207
|
+
// Tier 2: try LLM-as-judge.
|
|
208
|
+
if (this.llmInvoker) {
|
|
209
|
+
const llmResult = await this.tryLlmClassification(text);
|
|
210
|
+
if (llmResult) return llmResult;
|
|
278
211
|
}
|
|
212
|
+
|
|
213
|
+
// Tier 3: keyword fallback.
|
|
214
|
+
const scores = classifyByKeywords(text, this.categories);
|
|
215
|
+
return {
|
|
216
|
+
categories: scores,
|
|
217
|
+
flagged: scores.some((s) => s.confidence > this.flagThresholds[s.name]),
|
|
218
|
+
source: 'keyword',
|
|
219
|
+
};
|
|
279
220
|
}
|
|
280
221
|
|
|
281
|
-
//
|
|
282
|
-
//
|
|
283
|
-
//
|
|
222
|
+
// -----------------------------------------------------------------------
|
|
223
|
+
// Private — ONNX classification (Tier 1)
|
|
224
|
+
// -----------------------------------------------------------------------
|
|
284
225
|
|
|
285
226
|
/**
|
|
286
|
-
*
|
|
287
|
-
*
|
|
227
|
+
* Attempt to load `@huggingface/transformers` and run ONNX-based text
|
|
228
|
+
* classification. Returns `null` if the module is unavailable or inference
|
|
229
|
+
* fails.
|
|
230
|
+
*
|
|
231
|
+
* The module load is attempted only once; subsequent calls use the cached
|
|
232
|
+
* result (either a working pipeline or `null`).
|
|
288
233
|
*
|
|
289
|
-
* @param
|
|
290
|
-
* @
|
|
291
|
-
*
|
|
234
|
+
* @param text - Text to classify.
|
|
235
|
+
* @returns Classification result or `null`.
|
|
236
|
+
*
|
|
237
|
+
* @internal
|
|
292
238
|
*/
|
|
293
|
-
private async
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
if (
|
|
299
|
-
|
|
239
|
+
private async tryOnnxClassification(text: string): Promise<ClassifierResult | null> {
|
|
240
|
+
// If we already know ONNX is unavailable, skip.
|
|
241
|
+
if (this.onnxPipeline === null) return null;
|
|
242
|
+
|
|
243
|
+
// First-time load attempt.
|
|
244
|
+
if (this.onnxPipeline === undefined) {
|
|
245
|
+
try {
|
|
246
|
+
// Dynamic import so the optional dependency does not fail at boot.
|
|
247
|
+
const transformers = await import('@huggingface/transformers');
|
|
248
|
+
const pipelineInstance = await transformers.pipeline(
|
|
249
|
+
'text-classification',
|
|
250
|
+
'Xenova/toxic-bert',
|
|
251
|
+
{ device: 'cpu' }
|
|
252
|
+
);
|
|
253
|
+
// The HuggingFace pipeline is callable as a function. We always
|
|
254
|
+
// request all labels (top_k higher than any model's label count
|
|
255
|
+
// causes HF to return every label, matching `top_k: null`).
|
|
256
|
+
// Cast through unknown because the Pipeline union type is too
|
|
257
|
+
// wide for the inferred call signature here.
|
|
258
|
+
const callable = pipelineInstance as unknown as (
|
|
259
|
+
text: string,
|
|
260
|
+
opts: { top_k: number },
|
|
261
|
+
) => Promise<OnnxClassificationLabel[]>;
|
|
262
|
+
this.onnxPipeline = (text: string) => callable(text, { top_k: 9999 });
|
|
263
|
+
} catch {
|
|
264
|
+
// Module not installed or model load failed — mark as unavailable.
|
|
265
|
+
this.onnxPipeline = null;
|
|
266
|
+
return null;
|
|
267
|
+
}
|
|
300
268
|
}
|
|
301
269
|
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
270
|
+
try {
|
|
271
|
+
const raw = await this.onnxPipeline(text);
|
|
272
|
+
|
|
273
|
+
// Map ONNX labels to our categories.
|
|
274
|
+
const scores = this.mapOnnxScores(raw);
|
|
275
|
+
return {
|
|
276
|
+
categories: scores,
|
|
277
|
+
flagged: scores.some((s) => s.confidence > this.flagThresholds[s.name]),
|
|
278
|
+
source: 'onnx',
|
|
279
|
+
};
|
|
280
|
+
} catch {
|
|
281
|
+
// Inference failed — fall through to next tier.
|
|
282
|
+
return null;
|
|
283
|
+
}
|
|
305
284
|
}
|
|
306
285
|
|
|
307
286
|
/**
|
|
308
|
-
*
|
|
309
|
-
*
|
|
310
|
-
*
|
|
311
|
-
*
|
|
287
|
+
* Map raw ONNX text-classification output labels to our standard categories.
|
|
288
|
+
*
|
|
289
|
+
* ONNX models (e.g. toxic-bert) produce labels like `"toxic"`, `"obscene"`,
|
|
290
|
+
* `"threat"`, `"insult"`, `"identity_hate"`, etc. We map these to our four
|
|
291
|
+
* categories, taking the max score when multiple ONNX labels map to the same
|
|
292
|
+
* category.
|
|
312
293
|
*
|
|
313
|
-
* @param
|
|
314
|
-
* @
|
|
315
|
-
*
|
|
294
|
+
* @param raw - Raw ONNX pipeline output.
|
|
295
|
+
* @returns Per-category scores.
|
|
296
|
+
*
|
|
297
|
+
* @internal
|
|
316
298
|
*/
|
|
317
|
-
private
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
const result = this.evaluationToResult(resolved.val);
|
|
335
|
-
if (result) {
|
|
336
|
-
return result;
|
|
337
|
-
}
|
|
338
|
-
}
|
|
339
|
-
}
|
|
299
|
+
private mapOnnxScores(raw: OnnxClassificationLabel[]): CategoryScore[] {
|
|
300
|
+
/** Map of ONNX label -> our category. */
|
|
301
|
+
const labelMap: Record<string, ClassifierCategory> = {
|
|
302
|
+
toxic: 'toxic',
|
|
303
|
+
severe_toxic: 'toxic',
|
|
304
|
+
obscene: 'nsfw',
|
|
305
|
+
insult: 'toxic',
|
|
306
|
+
identity_hate: 'toxic',
|
|
307
|
+
threat: 'threat',
|
|
308
|
+
};
|
|
309
|
+
|
|
310
|
+
const maxScores: Record<ClassifierCategory, number> = {
|
|
311
|
+
toxic: 0,
|
|
312
|
+
injection: 0,
|
|
313
|
+
nsfw: 0,
|
|
314
|
+
threat: 0,
|
|
315
|
+
};
|
|
340
316
|
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
317
|
+
for (const item of raw) {
|
|
318
|
+
const label = (item.label ?? '').toLowerCase().replace(/\s+/g, '_');
|
|
319
|
+
const score = typeof item.score === 'number' ? item.score : 0;
|
|
320
|
+
const cat = labelMap[label];
|
|
321
|
+
if (cat && score > maxScores[cat]) {
|
|
322
|
+
maxScores[cat] = score;
|
|
323
|
+
}
|
|
347
324
|
}
|
|
348
325
|
|
|
349
|
-
//
|
|
350
|
-
return
|
|
326
|
+
// ONNX models typically do not detect prompt injection; leave at 0.
|
|
327
|
+
return this.categories.map((name) => ({
|
|
328
|
+
name,
|
|
329
|
+
confidence: maxScores[name] ?? 0,
|
|
330
|
+
}));
|
|
351
331
|
}
|
|
352
332
|
|
|
333
|
+
// -----------------------------------------------------------------------
|
|
334
|
+
// Private — LLM classification (Tier 2)
|
|
335
|
+
// -----------------------------------------------------------------------
|
|
336
|
+
|
|
353
337
|
/**
|
|
354
|
-
*
|
|
355
|
-
* blocking mode; subsequent chunks use non-blocking.
|
|
338
|
+
* Classify text using the LLM-as-judge fallback.
|
|
356
339
|
*
|
|
357
|
-
*
|
|
358
|
-
*
|
|
359
|
-
* remainder of the stream.
|
|
340
|
+
* @param text - Text to classify.
|
|
341
|
+
* @returns Classification result or `null` if the LLM call fails.
|
|
360
342
|
*
|
|
361
|
-
* @
|
|
362
|
-
* @param textDelta - New text fragment from the current chunk.
|
|
363
|
-
* @returns Evaluation result or `null`.
|
|
343
|
+
* @internal
|
|
364
344
|
*/
|
|
365
|
-
private async
|
|
366
|
-
|
|
367
|
-
textDelta: string,
|
|
368
|
-
): Promise<GuardrailEvaluationResult | null> {
|
|
369
|
-
// Determine whether this is the first chunk for this stream.
|
|
370
|
-
const isFirst = !this.isFirstChunk.has(streamId);
|
|
371
|
-
if (isFirst) {
|
|
372
|
-
this.isFirstChunk.set(streamId, true);
|
|
373
|
-
}
|
|
345
|
+
private async tryLlmClassification(text: string): Promise<ClassifierResult | null> {
|
|
346
|
+
if (!this.llmInvoker) return null;
|
|
374
347
|
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
348
|
+
try {
|
|
349
|
+
const scores = await classifyByLlm(text, this.llmInvoker, this.categories);
|
|
350
|
+
|
|
351
|
+
// If all scores are zero the LLM likely failed to parse — treat as null.
|
|
352
|
+
if (scores.every((s) => s.confidence === 0)) return null;
|
|
353
|
+
|
|
354
|
+
return {
|
|
355
|
+
categories: scores,
|
|
356
|
+
flagged: scores.some((s) => s.confidence > this.flagThresholds[s.name]),
|
|
357
|
+
source: 'llm',
|
|
358
|
+
};
|
|
359
|
+
} catch {
|
|
360
|
+
return null;
|
|
378
361
|
}
|
|
379
|
-
return this.handleNonBlocking(streamId, textDelta);
|
|
380
362
|
}
|
|
381
363
|
|
|
382
|
-
//
|
|
383
|
-
// Private
|
|
384
|
-
//
|
|
364
|
+
// -----------------------------------------------------------------------
|
|
365
|
+
// Private — result builder
|
|
366
|
+
// -----------------------------------------------------------------------
|
|
385
367
|
|
|
386
368
|
/**
|
|
387
|
-
* Convert a {@link
|
|
388
|
-
*
|
|
369
|
+
* Convert a {@link ClassifierResult} into a {@link GuardrailEvaluationResult},
|
|
370
|
+
* or return `null` when no thresholds are exceeded.
|
|
389
371
|
*
|
|
390
|
-
*
|
|
391
|
-
*
|
|
392
|
-
* metadata for audit/logging.
|
|
372
|
+
* @param result - Classification result from any tier.
|
|
373
|
+
* @returns Guardrail evaluation result or `null`.
|
|
393
374
|
*
|
|
394
|
-
* @
|
|
395
|
-
* @returns A guardrail result or `null` for clean content.
|
|
375
|
+
* @internal
|
|
396
376
|
*/
|
|
397
|
-
private
|
|
398
|
-
//
|
|
399
|
-
|
|
400
|
-
|
|
377
|
+
private buildResult(result: ClassifierResult): GuardrailEvaluationResult | null {
|
|
378
|
+
// Check for BLOCK-level violations first.
|
|
379
|
+
const blockers = result.categories.filter((s) => s.confidence > this.blockThresholds[s.name]);
|
|
380
|
+
|
|
381
|
+
if (blockers.length > 0) {
|
|
382
|
+
const worst = blockers.reduce((a, b) => (b.confidence > a.confidence ? b : a));
|
|
383
|
+
|
|
384
|
+
return {
|
|
385
|
+
action: GuardrailAction.BLOCK,
|
|
386
|
+
reason: `ML classifier detected unsafe content: ${blockers.map((b) => `${b.name}(${b.confidence.toFixed(2)})`).join(', ')}`,
|
|
387
|
+
reasonCode: `ML_CLASSIFIER_${worst.name.toUpperCase()}`,
|
|
388
|
+
metadata: {
|
|
389
|
+
source: result.source,
|
|
390
|
+
categories: result.categories,
|
|
391
|
+
},
|
|
392
|
+
};
|
|
401
393
|
}
|
|
402
394
|
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
395
|
+
// Check for FLAG-level violations.
|
|
396
|
+
const flaggers = result.categories.filter((s) => s.confidence > this.flagThresholds[s.name]);
|
|
397
|
+
|
|
398
|
+
if (flaggers.length > 0) {
|
|
399
|
+
const worst = flaggers.reduce((a, b) => (b.confidence > a.confidence ? b : a));
|
|
400
|
+
|
|
401
|
+
return {
|
|
402
|
+
action: GuardrailAction.FLAG,
|
|
403
|
+
reason: `ML classifier flagged content: ${flaggers.map((f) => `${f.name}(${f.confidence.toFixed(2)})`).join(', ')}`,
|
|
404
|
+
reasonCode: `ML_CLASSIFIER_${worst.name.toUpperCase()}`,
|
|
405
|
+
metadata: {
|
|
406
|
+
source: result.source,
|
|
407
|
+
categories: result.categories,
|
|
408
|
+
},
|
|
409
|
+
};
|
|
410
|
+
}
|
|
411
|
+
|
|
412
|
+
// No thresholds exceeded — allow.
|
|
413
|
+
return null;
|
|
418
414
|
}
|
|
419
415
|
}
|