@framers/agentos-ext-ml-classifiers 0.1.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. package/.github/workflows/ci.yml +20 -0
  2. package/.github/workflows/release.yml +37 -0
  3. package/.releaserc.json +9 -0
  4. package/LICENSE +96 -21
  5. package/README.md +72 -0
  6. package/dist/MLClassifierGuardrail.d.ts +88 -117
  7. package/dist/MLClassifierGuardrail.d.ts.map +1 -1
  8. package/dist/MLClassifierGuardrail.js +263 -264
  9. package/dist/MLClassifierGuardrail.js.map +1 -1
  10. package/dist/index.d.ts +16 -90
  11. package/dist/index.d.ts.map +1 -1
  12. package/dist/index.js +36 -309
  13. package/dist/index.js.map +1 -1
  14. package/dist/keyword-classifier.d.ts +26 -0
  15. package/dist/keyword-classifier.d.ts.map +1 -0
  16. package/dist/keyword-classifier.js +113 -0
  17. package/dist/keyword-classifier.js.map +1 -0
  18. package/dist/llm-classifier.d.ts +27 -0
  19. package/dist/llm-classifier.d.ts.map +1 -0
  20. package/dist/llm-classifier.js +129 -0
  21. package/dist/llm-classifier.js.map +1 -0
  22. package/dist/tools/ClassifyContentTool.d.ts +53 -80
  23. package/dist/tools/ClassifyContentTool.d.ts.map +1 -1
  24. package/dist/tools/ClassifyContentTool.js +52 -103
  25. package/dist/tools/ClassifyContentTool.js.map +1 -1
  26. package/dist/types.d.ts +77 -277
  27. package/dist/types.d.ts.map +1 -1
  28. package/dist/types.js +9 -55
  29. package/dist/types.js.map +1 -1
  30. package/package.json +10 -24
  31. package/scripts/fix-esm-imports.mjs +181 -0
  32. package/src/MLClassifierGuardrail.ts +306 -310
  33. package/src/index.ts +35 -339
  34. package/src/keyword-classifier.ts +130 -0
  35. package/src/llm-classifier.ts +163 -0
  36. package/src/tools/ClassifyContentTool.ts +75 -132
  37. package/src/types.ts +78 -325
  38. package/test/llm-tier.spec.ts +267 -0
  39. package/test/ml-classifiers.spec.ts +57 -0
  40. package/test/onnx-tier.spec.ts +255 -0
  41. package/test/tier-fallthrough.spec.ts +185 -0
  42. package/tsconfig.json +20 -0
  43. package/vitest.config.ts +35 -0
  44. package/dist/ClassifierOrchestrator.d.ts +0 -126
  45. package/dist/ClassifierOrchestrator.d.ts.map +0 -1
  46. package/dist/ClassifierOrchestrator.js +0 -239
  47. package/dist/ClassifierOrchestrator.js.map +0 -1
  48. package/dist/IContentClassifier.d.ts +0 -117
  49. package/dist/IContentClassifier.d.ts.map +0 -1
  50. package/dist/IContentClassifier.js +0 -22
  51. package/dist/IContentClassifier.js.map +0 -1
  52. package/dist/SlidingWindowBuffer.d.ts +0 -213
  53. package/dist/SlidingWindowBuffer.d.ts.map +0 -1
  54. package/dist/SlidingWindowBuffer.js +0 -246
  55. package/dist/SlidingWindowBuffer.js.map +0 -1
  56. package/dist/classifiers/InjectionClassifier.d.ts +0 -126
  57. package/dist/classifiers/InjectionClassifier.d.ts.map +0 -1
  58. package/dist/classifiers/InjectionClassifier.js +0 -210
  59. package/dist/classifiers/InjectionClassifier.js.map +0 -1
  60. package/dist/classifiers/JailbreakClassifier.d.ts +0 -124
  61. package/dist/classifiers/JailbreakClassifier.d.ts.map +0 -1
  62. package/dist/classifiers/JailbreakClassifier.js +0 -208
  63. package/dist/classifiers/JailbreakClassifier.js.map +0 -1
  64. package/dist/classifiers/ToxicityClassifier.d.ts +0 -125
  65. package/dist/classifiers/ToxicityClassifier.d.ts.map +0 -1
  66. package/dist/classifiers/ToxicityClassifier.js +0 -212
  67. package/dist/classifiers/ToxicityClassifier.js.map +0 -1
  68. package/dist/classifiers/WorkerClassifierProxy.d.ts +0 -158
  69. package/dist/classifiers/WorkerClassifierProxy.d.ts.map +0 -1
  70. package/dist/classifiers/WorkerClassifierProxy.js +0 -268
  71. package/dist/classifiers/WorkerClassifierProxy.js.map +0 -1
  72. package/dist/worker/classifier-worker.d.ts +0 -49
  73. package/dist/worker/classifier-worker.d.ts.map +0 -1
  74. package/dist/worker/classifier-worker.js +0 -180
  75. package/dist/worker/classifier-worker.js.map +0 -1
  76. package/src/ClassifierOrchestrator.ts +0 -290
  77. package/src/IContentClassifier.ts +0 -124
  78. package/src/SlidingWindowBuffer.ts +0 -384
  79. package/src/classifiers/InjectionClassifier.ts +0 -261
  80. package/src/classifiers/JailbreakClassifier.ts +0 -259
  81. package/src/classifiers/ToxicityClassifier.ts +0 -263
  82. package/src/classifiers/WorkerClassifierProxy.ts +0 -366
  83. package/src/worker/classifier-worker.ts +0 -267
@@ -1,419 +1,415 @@
1
1
  /**
2
- * @fileoverview IGuardrailService implementation backed by ML classifiers.
2
+ * @file MLClassifierGuardrail.ts
3
+ * @description IGuardrailService implementation that classifies text for toxicity,
4
+ * prompt injection, NSFW content, and threats using a three-tier strategy:
3
5
  *
4
- * `MLClassifierGuardrail` bridges the AgentOS guardrail pipeline to the ML
5
- * classifier subsystem. It implements both `evaluateInput` (full-text
6
- * classification of user messages) and `evaluateOutput` (sliding-window
7
- * classification of streamed agent responses).
6
+ * 1. **ONNX inference** — attempts to load `@huggingface/transformers` at runtime
7
+ * and run a lightweight ONNX classification model.
8
+ * 2. **LLM-as-judge** — falls back to an LLM invoker callback that prompts a
9
+ * language model for structured JSON safety classification.
10
+ * 3. **Keyword matching** — last-resort regex/keyword-based detection when neither
11
+ * ONNX nor LLM are available.
8
12
  *
9
- * Three streaming evaluation modes are supported:
13
+ * The guardrail is configured as Phase 2 (parallel, non-sanitizing) so it runs
14
+ * alongside other read-only guardrails without blocking the streaming pipeline.
10
15
  *
11
- * | Mode | Behaviour |
12
- * |---------------|----------------------------------------------------------------|
13
- * | `blocking` | Every chunk that fills the sliding window is classified |
14
- * | | **synchronously** — the stream waits for the result. |
15
- * | `non-blocking`| Classification fires in the background; violations are surfaced |
16
- * | | on the **next** `evaluateOutput` call for the same stream. |
17
- * | `hybrid` | The first chunk for each stream is blocking; subsequent chunks |
18
- * | | switch to non-blocking for lower latency. |
16
+ * ### Action thresholds
19
17
  *
20
- * The default mode is `blocking` when `streamingMode` is enabled.
18
+ * - **FLAG** when any category confidence exceeds `flagThreshold` (default 0.5).
19
+ * - **BLOCK** when any category confidence exceeds `blockThreshold` (default 0.8).
21
20
  *
22
- * @module agentos/extensions/packs/ml-classifiers/MLClassifierGuardrail
21
+ * @module ml-classifiers/MLClassifierGuardrail
23
22
  */
24
23
 
25
24
  import type {
25
+ IGuardrailService,
26
26
  GuardrailConfig,
27
- GuardrailEvaluationResult,
28
27
  GuardrailInputPayload,
29
28
  GuardrailOutputPayload,
30
- IGuardrailService,
29
+ GuardrailEvaluationResult,
30
+ AgentOSFinalResponseChunk,
31
31
  } from '@framers/agentos';
32
32
  import { GuardrailAction } from '@framers/agentos';
33
33
  import { AgentOSResponseChunkType } from '@framers/agentos';
34
- import type { ISharedServiceRegistry } from '@framers/agentos';
35
- import type { MLClassifierPackOptions, ChunkEvaluation } from './types';
36
- import { DEFAULT_THRESHOLDS } from './types';
37
- import { SlidingWindowBuffer } from './SlidingWindowBuffer';
38
- import { ClassifierOrchestrator } from './ClassifierOrchestrator';
39
- import type { IContentClassifier } from './IContentClassifier';
34
+ import type {
35
+ MLClassifierOptions,
36
+ ClassifierCategory,
37
+ ClassifierResult,
38
+ CategoryScore,
39
+ } from './types';
40
+ import { ALL_CATEGORIES } from './types';
41
+ import { classifyByKeywords } from './keyword-classifier';
42
+ import { classifyByLlm } from './llm-classifier';
40
43
 
41
44
  // ---------------------------------------------------------------------------
42
- // Streaming mode union
45
+ // HuggingFace / ONNX pipeline types
43
46
  // ---------------------------------------------------------------------------
44
47
 
45
48
  /**
46
- * The evaluation strategy used for output (streaming) chunks.
47
- *
48
- * - `blocking` — await classification on every filled window.
49
- * - `non-blocking` — fire classification in the background; surface result later.
50
- * - `hybrid` — first chunk per stream is blocking, rest non-blocking.
49
+ * A single label + score pair returned by a HuggingFace text-classification
50
+ * pipeline. The `label` is the model's class name (e.g. `"toxic"`,
51
+ * `"obscene"`) and `score` is the softmax probability in [0, 1].
52
+ */
53
+ interface OnnxClassificationLabel {
54
+ label: string;
55
+ score: number;
56
+ }
57
+
58
+ /**
59
+ * Callable returned by `pipeline('text-classification', ...)` from the
60
+ * `@huggingface/transformers` package. Returns all label scores for the
61
+ * given text.
51
62
  */
52
- type StreamingMode = 'blocking' | 'non-blocking' | 'hybrid';
63
+ type OnnxTextClassificationPipeline = (text: string) => Promise<OnnxClassificationLabel[]>;
53
64
 
54
65
  // ---------------------------------------------------------------------------
55
66
  // MLClassifierGuardrail
56
67
  // ---------------------------------------------------------------------------
57
68
 
58
69
  /**
59
- * Guardrail implementation that runs ML classifiers against both user input
60
- * and streamed agent output.
70
+ * AgentOS guardrail that classifies text for safety using ML models, LLM
71
+ * inference, or keyword fallback.
61
72
  *
62
73
  * @implements {IGuardrailService}
63
- *
64
- * @example
65
- * ```typescript
66
- * const guardrail = new MLClassifierGuardrail(serviceRegistry, {
67
- * classifiers: ['toxicity'],
68
- * streamingMode: true,
69
- * chunkSize: 150,
70
- * guardrailScope: 'both',
71
- * });
72
- *
73
- * // Input evaluation — runs classifier on the full user message.
74
- * const inputResult = await guardrail.evaluateInput({ context, input });
75
- *
76
- * // Output evaluation — accumulates tokens, classifies at window boundary.
77
- * const outputResult = await guardrail.evaluateOutput({ context, chunk });
78
- * ```
79
74
  */
80
75
  export class MLClassifierGuardrail implements IGuardrailService {
81
- // -------------------------------------------------------------------------
82
- // IGuardrailService config
83
- // -------------------------------------------------------------------------
76
+ // -----------------------------------------------------------------------
77
+ // IGuardrailService.config
78
+ // -----------------------------------------------------------------------
84
79
 
85
80
  /**
86
- * Guardrail configuration exposed to the AgentOS pipeline.
81
+ * Guardrail configuration.
87
82
  *
88
- * `evaluateStreamingChunks` is always `true` because this guardrail uses
89
- * the sliding window to evaluate output tokens incrementally.
83
+ * - `canSanitize: false` — this guardrail does not modify content; it only
84
+ * BLOCKs or FLAGs. This places it in Phase 2 (parallel) of the guardrail
85
+ * dispatcher for better performance.
86
+ * - `evaluateStreamingChunks: false` — only evaluates complete messages, not
87
+ * individual streaming deltas. ML classification on partial text produces
88
+ * unreliable results.
90
89
  */
91
- readonly config: GuardrailConfig;
90
+ readonly config: GuardrailConfig = {
91
+ canSanitize: false,
92
+ evaluateStreamingChunks: false,
93
+ };
92
94
 
93
- // -------------------------------------------------------------------------
94
- // Internal state
95
- // -------------------------------------------------------------------------
95
+ // -----------------------------------------------------------------------
96
+ // Private state
97
+ // -----------------------------------------------------------------------
96
98
 
97
- /** The classifier orchestrator that runs all classifiers in parallel. */
98
- private readonly orchestrator: ClassifierOrchestrator;
99
+ /** Categories to evaluate. */
100
+ private readonly categories: ClassifierCategory[];
99
101
 
100
- /** Sliding window buffer for accumulating streaming tokens. */
101
- private readonly buffer: SlidingWindowBuffer;
102
+ /** Per-category flag thresholds. */
103
+ private readonly flagThresholds: Record<ClassifierCategory, number>;
102
104
 
103
- /** Guardrail scope — which direction(s) this guardrail evaluates. */
104
- private readonly scope: 'input' | 'output' | 'both';
105
+ /** Per-category block thresholds. */
106
+ private readonly blockThresholds: Record<ClassifierCategory, number>;
105
107
 
106
- /** Streaming evaluation strategy for output chunks. */
107
- private readonly streamingMode: StreamingMode;
108
-
109
- /**
110
- * Map of stream IDs to pending (background) classification promises.
111
- * Used in `non-blocking` and `hybrid` modes to defer result checking
112
- * to the next `evaluateOutput` call.
113
- */
114
- private readonly pendingResults: Map<string, Promise<ChunkEvaluation>> = new Map();
108
+ /** Optional LLM invoker callback for tier-2 classification. */
109
+ private readonly llmInvoker: MLClassifierOptions['llmInvoker'];
115
110
 
116
111
  /**
117
- * Tracks whether the first chunk for a given stream has been processed.
118
- * Used by `hybrid` mode to apply blocking evaluation on the first chunk
119
- * and non-blocking for subsequent chunks.
112
+ * Cached reference to the `@huggingface/transformers` pipeline function.
113
+ * `null` means we already tried and failed to load the module.
114
+ * `undefined` means we have not tried yet.
120
115
  */
121
- private readonly isFirstChunk: Map<string, boolean> = new Map();
116
+ private onnxPipeline: OnnxTextClassificationPipeline | null | undefined = undefined;
122
117
 
123
- // -------------------------------------------------------------------------
118
+ // -----------------------------------------------------------------------
124
119
  // Constructor
125
- // -------------------------------------------------------------------------
120
+ // -----------------------------------------------------------------------
126
121
 
127
122
  /**
128
- * Create a new ML classifier guardrail.
123
+ * Create a new MLClassifierGuardrail.
129
124
  *
130
- * @param _services - Shared service registry (reserved for future use by
131
- * classifier factories that need lazy model loading).
132
- * @param options - Pack-level options controlling classifier selection,
133
- * thresholds, sliding window size, and streaming mode.
134
- * @param classifiers - Pre-built classifier instances. When provided,
135
- * these are used directly instead of constructing
136
- * classifiers from `options.classifiers`.
125
+ * @param options - Pack-level configuration. All properties have sensible
126
+ * defaults for zero-config operation.
137
127
  */
138
- constructor(
139
- _services: ISharedServiceRegistry,
140
- options: MLClassifierPackOptions,
141
- classifiers: IContentClassifier[] = [],
142
- ) {
143
- // Resolve thresholds: merge caller overrides on top of defaults.
144
- const thresholds = {
145
- ...DEFAULT_THRESHOLDS,
146
- ...options.thresholds,
147
- };
128
+ constructor(options?: MLClassifierOptions) {
129
+ const opts = options ?? {};
148
130
 
149
- // Build the orchestrator from the supplied classifiers.
150
- this.orchestrator = new ClassifierOrchestrator(classifiers, thresholds);
151
-
152
- // Initialise the sliding window buffer for streaming evaluation.
153
- this.buffer = new SlidingWindowBuffer({
154
- chunkSize: options.chunkSize,
155
- contextSize: options.contextSize,
156
- maxEvaluations: options.maxEvaluations,
157
- });
158
-
159
- // Store the guardrail scope (defaults to 'both').
160
- this.scope = options.guardrailScope ?? 'both';
161
-
162
- // Determine streaming mode. When `streamingMode` is enabled the default
163
- // is 'blocking'; callers can override via the `streamingMode` option
164
- // (which we reinterpret as a boolean gate here — advanced callers pass
165
- // a StreamingMode string via `options` when they need finer control).
166
- this.streamingMode = options.streamingMode ? 'blocking' : 'blocking';
167
-
168
- // Expose guardrail config to the pipeline.
169
- this.config = {
170
- evaluateStreamingChunks: true,
171
- maxStreamingEvaluations: options.maxEvaluations ?? 100,
172
- };
131
+ this.categories = opts.categories ?? [...ALL_CATEGORIES];
132
+ this.llmInvoker = opts.llmInvoker;
133
+
134
+ // Resolve per-category thresholds.
135
+ const globalFlag = opts.flagThreshold ?? 0.5;
136
+ const globalBlock = opts.blockThreshold ?? 0.8;
137
+
138
+ this.flagThresholds = {} as Record<ClassifierCategory, number>;
139
+ this.blockThresholds = {} as Record<ClassifierCategory, number>;
140
+
141
+ for (const cat of ALL_CATEGORIES) {
142
+ this.flagThresholds[cat] = opts.thresholds?.[cat]?.flag ?? globalFlag;
143
+ this.blockThresholds[cat] = opts.thresholds?.[cat]?.block ?? globalBlock;
144
+ }
173
145
  }
174
146
 
175
- // -------------------------------------------------------------------------
176
- // evaluateInput
177
- // -------------------------------------------------------------------------
147
+ // -----------------------------------------------------------------------
148
+ // IGuardrailService — evaluateInput
149
+ // -----------------------------------------------------------------------
178
150
 
179
151
  /**
180
- * Evaluate a user's input message before it enters the orchestration pipeline.
181
- *
182
- * Runs the full text through all registered classifiers and returns a
183
- * {@link GuardrailEvaluationResult} when a violation is detected, or
184
- * `null` when the content is clean.
152
+ * Evaluate user input for safety before orchestration begins.
185
153
  *
186
- * Skipped entirely when `scope === 'output'`.
187
- *
188
- * @param payload - The input payload containing user text and context.
189
- * @returns Evaluation result or `null` if no action is needed.
154
+ * @param payload - Input evaluation payload containing the user's message.
155
+ * @returns Guardrail result or `null` if no action is required.
190
156
  */
191
157
  async evaluateInput(payload: GuardrailInputPayload): Promise<GuardrailEvaluationResult | null> {
192
- // Skip input evaluation when scope is output-only.
193
- if (this.scope === 'output') {
194
- return null;
195
- }
196
-
197
- // Extract the text from the input. If there is no text, nothing to classify.
198
158
  const text = payload.input.textInput;
199
- if (!text) {
200
- return null;
201
- }
202
-
203
- // Run all classifiers against the full user message.
204
- const evaluation = await this.orchestrator.classifyAll(text);
159
+ if (!text || text.length === 0) return null;
205
160
 
206
- // Map the evaluation to a guardrail result (null for ALLOW).
207
- return this.evaluationToResult(evaluation);
161
+ const result = await this.classify(text);
162
+ return this.buildResult(result);
208
163
  }
209
164
 
210
- // -------------------------------------------------------------------------
211
- // evaluateOutput
212
- // -------------------------------------------------------------------------
165
+ // -----------------------------------------------------------------------
166
+ // IGuardrailService — evaluateOutput
167
+ // -----------------------------------------------------------------------
213
168
 
214
169
  /**
215
- * Evaluate a streamed output chunk from the agent before it is delivered
216
- * to the client.
170
+ * Evaluate agent output for safety. Only processes FINAL_RESPONSE chunks
171
+ * since `evaluateStreamingChunks` is disabled.
217
172
  *
218
- * The method accumulates text tokens in the sliding window buffer and
219
- * triggers classifier evaluation when a full window is available. The
220
- * evaluation strategy depends on the configured streaming mode.
221
- *
222
- * Skipped entirely when `scope === 'input'`.
223
- *
224
- * @param payload - The output payload containing the response chunk and context.
225
- * @returns Evaluation result or `null` if no action is needed yet.
173
+ * @param payload - Output evaluation payload from the AgentOS dispatcher.
174
+ * @returns Guardrail result or `null` if no action is required.
226
175
  */
227
176
  async evaluateOutput(payload: GuardrailOutputPayload): Promise<GuardrailEvaluationResult | null> {
228
- // Skip output evaluation when scope is input-only.
229
- if (this.scope === 'input') {
230
- return null;
231
- }
232
-
233
- const chunk = payload.chunk;
234
-
235
- // Handle final chunks: flush remaining buffer and classify.
236
- if (chunk.isFinal) {
237
- const streamId = chunk.streamId;
238
- const flushed = this.buffer.flush(streamId);
239
-
240
- // Clean up tracking state for this stream.
241
- this.isFirstChunk.delete(streamId);
242
- this.pendingResults.delete(streamId);
243
-
244
- if (!flushed) {
245
- return null;
246
- }
177
+ const { chunk } = payload;
247
178
 
248
- // Classify the remaining buffered text.
249
- const evaluation = await this.orchestrator.classifyAll(flushed.text);
250
- return this.evaluationToResult(evaluation);
251
- }
252
-
253
- // Only process TEXT_DELTA chunks — ignore tool calls, progress, etc.
254
- if (chunk.type !== AgentOSResponseChunkType.TEXT_DELTA) {
255
- return null;
256
- }
257
-
258
- // Extract the text delta from the chunk.
259
- const textDelta = (chunk as any).textDelta as string | undefined;
260
- if (!textDelta) {
179
+ // Only evaluate final text responses.
180
+ if (chunk.type !== AgentOSResponseChunkType.FINAL_RESPONSE) {
261
181
  return null;
262
182
  }
263
183
 
264
- // Resolve the stream identifier for the sliding window.
265
- const streamId = chunk.streamId;
184
+ const finalChunk = chunk as AgentOSFinalResponseChunk;
185
+ const text = finalChunk.finalResponseText ?? '';
186
+ if (typeof text !== 'string' || text.length === 0) return null;
266
187
 
267
- // Dispatch to the appropriate streaming mode handler.
268
- switch (this.streamingMode) {
269
- case 'non-blocking':
270
- return this.handleNonBlocking(streamId, textDelta);
188
+ const result = await this.classify(text);
189
+ return this.buildResult(result);
190
+ }
271
191
 
272
- case 'hybrid':
273
- return this.handleHybrid(streamId, textDelta);
192
+ // -----------------------------------------------------------------------
193
+ // Public classification method (also used by ClassifyContentTool)
194
+ // -----------------------------------------------------------------------
274
195
 
275
- case 'blocking':
276
- default:
277
- return this.handleBlocking(streamId, textDelta);
196
+ /**
197
+ * Classify a text string using the three-tier strategy: ONNX -> LLM -> keyword.
198
+ *
199
+ * @param text - The text to classify.
200
+ * @returns Classification result with per-category scores.
201
+ */
202
+ async classify(text: string): Promise<ClassifierResult> {
203
+ // Tier 1: try ONNX inference.
204
+ const onnxResult = await this.tryOnnxClassification(text);
205
+ if (onnxResult) return onnxResult;
206
+
207
+ // Tier 2: try LLM-as-judge.
208
+ if (this.llmInvoker) {
209
+ const llmResult = await this.tryLlmClassification(text);
210
+ if (llmResult) return llmResult;
278
211
  }
212
+
213
+ // Tier 3: keyword fallback.
214
+ const scores = classifyByKeywords(text, this.categories);
215
+ return {
216
+ categories: scores,
217
+ flagged: scores.some((s) => s.confidence > this.flagThresholds[s.name]),
218
+ source: 'keyword',
219
+ };
279
220
  }
280
221
 
281
- // -------------------------------------------------------------------------
282
- // Streaming mode handlers
283
- // -------------------------------------------------------------------------
222
+ // -----------------------------------------------------------------------
223
+ // Private — ONNX classification (Tier 1)
224
+ // -----------------------------------------------------------------------
284
225
 
285
226
  /**
286
- * **Blocking mode**: push text into the buffer and, when a full window is
287
- * ready, await the classifier result before returning.
227
+ * Attempt to load `@huggingface/transformers` and run ONNX-based text
228
+ * classification. Returns `null` if the module is unavailable or inference
229
+ * fails.
230
+ *
231
+ * The module load is attempted only once; subsequent calls use the cached
232
+ * result (either a working pipeline or `null`).
288
233
  *
289
- * @param streamId - Identifier of the active stream.
290
- * @param textDelta - New text fragment from the current chunk.
291
- * @returns Evaluation result (possibly BLOCK/FLAG) or `null`.
234
+ * @param text - Text to classify.
235
+ * @returns Classification result or `null`.
236
+ *
237
+ * @internal
292
238
  */
293
- private async handleBlocking(
294
- streamId: string,
295
- textDelta: string,
296
- ): Promise<GuardrailEvaluationResult | null> {
297
- const ready = this.buffer.push(streamId, textDelta);
298
- if (!ready) {
299
- return null;
239
+ private async tryOnnxClassification(text: string): Promise<ClassifierResult | null> {
240
+ // If we already know ONNX is unavailable, skip.
241
+ if (this.onnxPipeline === null) return null;
242
+
243
+ // First-time load attempt.
244
+ if (this.onnxPipeline === undefined) {
245
+ try {
246
+ // Dynamic import so the optional dependency does not fail at boot.
247
+ const transformers = await import('@huggingface/transformers');
248
+ const pipelineInstance = await transformers.pipeline(
249
+ 'text-classification',
250
+ 'Xenova/toxic-bert',
251
+ { device: 'cpu' }
252
+ );
253
+ // The HuggingFace pipeline is callable as a function. We always
254
+ // request all labels (top_k higher than any model's label count
255
+ // causes HF to return every label, matching `top_k: null`).
256
+ // Cast through unknown because the Pipeline union type is too
257
+ // wide for the inferred call signature here.
258
+ const callable = pipelineInstance as unknown as (
259
+ text: string,
260
+ opts: { top_k: number },
261
+ ) => Promise<OnnxClassificationLabel[]>;
262
+ this.onnxPipeline = (text: string) => callable(text, { top_k: 9999 });
263
+ } catch {
264
+ // Module not installed or model load failed — mark as unavailable.
265
+ this.onnxPipeline = null;
266
+ return null;
267
+ }
300
268
  }
301
269
 
302
- // Classify the filled window synchronously.
303
- const evaluation = await this.orchestrator.classifyAll(ready.text);
304
- return this.evaluationToResult(evaluation);
270
+ try {
271
+ const raw = await this.onnxPipeline(text);
272
+
273
+ // Map ONNX labels to our categories.
274
+ const scores = this.mapOnnxScores(raw);
275
+ return {
276
+ categories: scores,
277
+ flagged: scores.some((s) => s.confidence > this.flagThresholds[s.name]),
278
+ source: 'onnx',
279
+ };
280
+ } catch {
281
+ // Inference failed — fall through to next tier.
282
+ return null;
283
+ }
305
284
  }
306
285
 
307
286
  /**
308
- * **Non-blocking mode**: push text into the buffer. When a window is
309
- * ready, fire classification in the background and store the promise.
310
- * On the **next** `evaluateOutput` call for the same stream, check the
311
- * pending promise — if it resolved with a violation, return that result.
287
+ * Map raw ONNX text-classification output labels to our standard categories.
288
+ *
289
+ * ONNX models (e.g. toxic-bert) produce labels like `"toxic"`, `"obscene"`,
290
+ * `"threat"`, `"insult"`, `"identity_hate"`, etc. We map these to our four
291
+ * categories, taking the max score when multiple ONNX labels map to the same
292
+ * category.
312
293
  *
313
- * @param streamId - Identifier of the active stream.
314
- * @param textDelta - New text fragment from the current chunk.
315
- * @returns A previously resolved violation result, or `null`.
294
+ * @param raw - Raw ONNX pipeline output.
295
+ * @returns Per-category scores.
296
+ *
297
+ * @internal
316
298
  */
317
- private async handleNonBlocking(
318
- streamId: string,
319
- textDelta: string,
320
- ): Promise<GuardrailEvaluationResult | null> {
321
- // First, check if there is a pending result from a previous window.
322
- const pending = this.pendingResults.get(streamId);
323
- if (pending) {
324
- // Check if the promise has settled without blocking.
325
- const resolved = await Promise.race([
326
- pending.then((val) => ({ done: true as const, val })),
327
- Promise.resolve({ done: false as const, val: null as ChunkEvaluation | null }),
328
- ]);
329
-
330
- if (resolved.done && resolved.val) {
331
- // Consume the pending result.
332
- this.pendingResults.delete(streamId);
333
-
334
- const result = this.evaluationToResult(resolved.val);
335
- if (result) {
336
- return result;
337
- }
338
- }
339
- }
299
+ private mapOnnxScores(raw: OnnxClassificationLabel[]): CategoryScore[] {
300
+ /** Map of ONNX label -> our category. */
301
+ const labelMap: Record<string, ClassifierCategory> = {
302
+ toxic: 'toxic',
303
+ severe_toxic: 'toxic',
304
+ obscene: 'nsfw',
305
+ insult: 'toxic',
306
+ identity_hate: 'toxic',
307
+ threat: 'threat',
308
+ };
309
+
310
+ const maxScores: Record<ClassifierCategory, number> = {
311
+ toxic: 0,
312
+ injection: 0,
313
+ nsfw: 0,
314
+ threat: 0,
315
+ };
340
316
 
341
- // Push text into the buffer.
342
- const ready = this.buffer.push(streamId, textDelta);
343
- if (ready) {
344
- // Fire classification in the background — do NOT await.
345
- const classifyPromise = this.orchestrator.classifyAll(ready.text);
346
- this.pendingResults.set(streamId, classifyPromise);
317
+ for (const item of raw) {
318
+ const label = (item.label ?? '').toLowerCase().replace(/\s+/g, '_');
319
+ const score = typeof item.score === 'number' ? item.score : 0;
320
+ const cat = labelMap[label];
321
+ if (cat && score > maxScores[cat]) {
322
+ maxScores[cat] = score;
323
+ }
347
324
  }
348
325
 
349
- // Return null immediately — result will be checked on next call.
350
- return null;
326
+ // ONNX models typically do not detect prompt injection; leave at 0.
327
+ return this.categories.map((name) => ({
328
+ name,
329
+ confidence: maxScores[name] ?? 0,
330
+ }));
351
331
  }
352
332
 
333
+ // -----------------------------------------------------------------------
334
+ // Private — LLM classification (Tier 2)
335
+ // -----------------------------------------------------------------------
336
+
353
337
  /**
354
- * **Hybrid mode**: the first chunk for each stream is evaluated in
355
- * blocking mode; subsequent chunks use non-blocking.
338
+ * Classify text using the LLM-as-judge fallback.
356
339
  *
357
- * This provides immediate feedback on the first window (where early
358
- * jailbreak attempts are most likely) while minimising latency for the
359
- * remainder of the stream.
340
+ * @param text - Text to classify.
341
+ * @returns Classification result or `null` if the LLM call fails.
360
342
  *
361
- * @param streamId - Identifier of the active stream.
362
- * @param textDelta - New text fragment from the current chunk.
363
- * @returns Evaluation result or `null`.
343
+ * @internal
364
344
  */
365
- private async handleHybrid(
366
- streamId: string,
367
- textDelta: string,
368
- ): Promise<GuardrailEvaluationResult | null> {
369
- // Determine whether this is the first chunk for this stream.
370
- const isFirst = !this.isFirstChunk.has(streamId);
371
- if (isFirst) {
372
- this.isFirstChunk.set(streamId, true);
373
- }
345
+ private async tryLlmClassification(text: string): Promise<ClassifierResult | null> {
346
+ if (!this.llmInvoker) return null;
374
347
 
375
- // First chunk → blocking, subsequent → non-blocking.
376
- if (isFirst) {
377
- return this.handleBlocking(streamId, textDelta);
348
+ try {
349
+ const scores = await classifyByLlm(text, this.llmInvoker, this.categories);
350
+
351
+ // If all scores are zero the LLM likely failed to parse — treat as null.
352
+ if (scores.every((s) => s.confidence === 0)) return null;
353
+
354
+ return {
355
+ categories: scores,
356
+ flagged: scores.some((s) => s.confidence > this.flagThresholds[s.name]),
357
+ source: 'llm',
358
+ };
359
+ } catch {
360
+ return null;
378
361
  }
379
- return this.handleNonBlocking(streamId, textDelta);
380
362
  }
381
363
 
382
- // -------------------------------------------------------------------------
383
- // Private helpers
384
- // -------------------------------------------------------------------------
364
+ // -----------------------------------------------------------------------
365
+ // Private — result builder
366
+ // -----------------------------------------------------------------------
385
367
 
386
368
  /**
387
- * Convert a {@link ChunkEvaluation} into a {@link GuardrailEvaluationResult}
388
- * suitable for the AgentOS guardrail pipeline.
369
+ * Convert a {@link ClassifierResult} into a {@link GuardrailEvaluationResult},
370
+ * or return `null` when no thresholds are exceeded.
389
371
  *
390
- * Returns `null` when the recommended action is ALLOW (no intervention
391
- * needed). For all other actions, the evaluation details are attached as
392
- * metadata for audit/logging.
372
+ * @param result - Classification result from any tier.
373
+ * @returns Guardrail evaluation result or `null`.
393
374
  *
394
- * @param evaluation - Aggregated classifier evaluation.
395
- * @returns A guardrail result or `null` for clean content.
375
+ * @internal
396
376
  */
397
- private evaluationToResult(evaluation: ChunkEvaluation): GuardrailEvaluationResult | null {
398
- // ALLOW means no guardrail action is needed.
399
- if (evaluation.recommendedAction === GuardrailAction.ALLOW) {
400
- return null;
377
+ private buildResult(result: ClassifierResult): GuardrailEvaluationResult | null {
378
+ // Check for BLOCK-level violations first.
379
+ const blockers = result.categories.filter((s) => s.confidence > this.blockThresholds[s.name]);
380
+
381
+ if (blockers.length > 0) {
382
+ const worst = blockers.reduce((a, b) => (b.confidence > a.confidence ? b : a));
383
+
384
+ return {
385
+ action: GuardrailAction.BLOCK,
386
+ reason: `ML classifier detected unsafe content: ${blockers.map((b) => `${b.name}(${b.confidence.toFixed(2)})`).join(', ')}`,
387
+ reasonCode: `ML_CLASSIFIER_${worst.name.toUpperCase()}`,
388
+ metadata: {
389
+ source: result.source,
390
+ categories: result.categories,
391
+ },
392
+ };
401
393
  }
402
394
 
403
- return {
404
- action: evaluation.recommendedAction,
405
- reason: `ML classifier "${evaluation.triggeredBy}" flagged content`,
406
- reasonCode: `ML_CLASSIFIER_${evaluation.recommendedAction.toUpperCase()}`,
407
- metadata: {
408
- triggeredBy: evaluation.triggeredBy,
409
- totalLatencyMs: evaluation.totalLatencyMs,
410
- classifierResults: evaluation.results.map((r) => ({
411
- classifierId: r.classifierId,
412
- bestClass: r.bestClass,
413
- confidence: r.confidence,
414
- latencyMs: r.latencyMs,
415
- })),
416
- },
417
- };
395
+ // Check for FLAG-level violations.
396
+ const flaggers = result.categories.filter((s) => s.confidence > this.flagThresholds[s.name]);
397
+
398
+ if (flaggers.length > 0) {
399
+ const worst = flaggers.reduce((a, b) => (b.confidence > a.confidence ? b : a));
400
+
401
+ return {
402
+ action: GuardrailAction.FLAG,
403
+ reason: `ML classifier flagged content: ${flaggers.map((f) => `${f.name}(${f.confidence.toFixed(2)})`).join(', ')}`,
404
+ reasonCode: `ML_CLASSIFIER_${worst.name.toUpperCase()}`,
405
+ metadata: {
406
+ source: result.source,
407
+ categories: result.categories,
408
+ },
409
+ };
410
+ }
411
+
412
+ // No thresholds exceeded — allow.
413
+ return null;
418
414
  }
419
415
  }