ai-shield-classifier-onnx 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md ADDED
@@ -0,0 +1,72 @@
1
+ # ai-shield-classifier-onnx
2
+
3
+ Optional ONNX-runtime ML classifier for [ai-shield](https://github.com/studiomeyer-io/ai-shield).
4
+ Pairs with `ai-shield-core` to add a DeBERTa-style prompt-injection classifier
5
+ alongside the heuristic patterns.
6
+
7
+ ## Why a separate package?
8
+
9
+ `ai-shield-core` is zero-dependency by design. ONNX inference requires
10
+ `onnxruntime-node`, which ships native binaries. Install this package only
11
+ when you actively want ML-augmented detection on top of the regex layer.
12
+
13
+ ## Install
14
+
15
+ ```bash
16
+ npm install ai-shield-core ai-shield-classifier-onnx onnxruntime-node
17
+ ```
18
+
19
+ ## Usage
20
+
21
+ ```ts
22
+ import { ScannerChain, HeuristicScanner } from "ai-shield-core";
23
+ import { loadOnnxClassifier } from "ai-shield-classifier-onnx";
24
+
25
+ // Bring your own tokenizer. Example: protectai/deberta-v3-base-prompt-injection
26
+ const tokenizer = await yourTokenizerFor("protectai/deberta-v3-base-prompt-injection");
27
+
28
+ const ml = await loadOnnxClassifier({
29
+ modelPath: "./models/deberta-injection.onnx",
30
+ tokenizer,
31
+ threshold: 0.85, // tune per model
32
+ });
33
+
34
+ const chain = new ScannerChain({ earlyExit: true });
35
+ chain.add(new HeuristicScanner({ strictness: "high" })); // cheap regex first
36
+ chain.add(ml); // ML fallback
37
+
38
+ const result = await chain.run("Ignore previous instructions...");
39
+ console.log(result.decision); // "block"
40
+ ```
41
+
42
+ Or manually with an already-constructed `InferenceSession`:
43
+
44
+ ```ts
45
+ import * as ort from "onnxruntime-node";
46
+ import { OnnxInjectionScanner } from "ai-shield-classifier-onnx";
47
+
48
+ const session = await ort.InferenceSession.create("./models/deberta.onnx");
49
+ const scanner = new OnnxInjectionScanner({ session, tokenizer, threshold: 0.85 });
50
+ ```
51
+
52
+ ## Recommended models
53
+
54
+ - [`protectai/deberta-v3-base-prompt-injection`](https://huggingface.co/protectai/deberta-v3-base-prompt-injection) (Apache-2.0)
55
+ - [`protectai/deberta-v3-base-prompt-injection-v2`](https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2)
56
+ - Any HF model exported to ONNX via `optimum-cli export onnx`
57
+
58
+ ## Notes
59
+
60
+ - The scanner **fails open** on inference errors: the failure is recorded as a
61
+ `content_policy` violation with decision `allow`, so traffic is not blocked.
62
+ This keeps a misbehaving model or runtime (e.g. a hardware-specific edge case)
63
+ from taking down the whole chain. `loadOnnxClassifier()` itself still throws if
+ the model cannot be loaded.
64
+ - Use after the heuristic scanner. Most known attacks short-circuit on
65
+ cheap regex; the ML pass catches paraphrases and obfuscations that slip
66
+ through.
67
+ - The probability threshold is **calibrated per model**. Start at 0.85 and
68
+ tune against your false-positive budget. Scores from 0.6 × threshold up to
+ the threshold yield a `warn` decision rather than `block`.
69
+
70
+ ## License
71
+
72
+ MIT — see [LICENSE](../../LICENSE).
package/package.json ADDED
@@ -0,0 +1,53 @@
1
+ {
2
+ "name": "ai-shield-classifier-onnx",
3
+ "version": "0.2.0",
4
+ "license": "MIT",
5
+ "description": "Optional ONNX ML classifier for ai-shield — DeBERTa-style prompt-injection detection alongside heuristic patterns",
6
+ "author": "StudioMeyer <hello@studiomeyer.io>",
7
+ "repository": {
8
+ "type": "git",
9
+ "url": "https://github.com/studiomeyer-io/ai-shield",
10
+ "directory": "packages/classifier-onnx"
11
+ },
12
+ "homepage": "https://github.com/studiomeyer-io/ai-shield/tree/main/packages/classifier-onnx",
13
+ "bugs": {
14
+ "url": "https://github.com/studiomeyer-io/ai-shield/issues"
15
+ },
16
+ "keywords": [
17
+ "llm",
18
+ "security",
19
+ "prompt-injection",
20
+ "onnx",
21
+ "deberta",
22
+ "ml",
23
+ "classifier",
24
+ "ai-shield"
25
+ ],
26
+ "type": "module",
27
+ "main": "./dist/index.js",
28
+ "types": "./dist/index.d.ts",
29
+ "exports": {
30
+ ".": {
31
+ "types": "./dist/index.d.ts",
32
+ "import": "./dist/index.js"
33
+ }
34
+ },
35
+ "scripts": {
36
+ "build": "tsc -b",
37
+ "typecheck": "tsc -b"
38
+ },
39
+ "peerDependencies": {
40
+ "ai-shield-core": "0.2.0",
41
+ "onnxruntime-node": ">=1.17.0"
42
+ },
43
+ "peerDependenciesMeta": {
44
+ "onnxruntime-node": { "optional": true }
45
+ },
46
+ "devDependencies": {
47
+ "ai-shield-core": "0.2.0",
48
+ "typescript": "^5.7.0"
49
+ },
50
+ "engines": {
51
+ "node": ">=20.0.0"
52
+ }
53
+ }
@@ -0,0 +1,326 @@
1
+ import type {
2
+ Scanner,
3
+ ScannerResult,
4
+ ScanContext,
5
+ Violation,
6
+ } from "ai-shield-core";
7
+
8
+ // ============================================================
9
+ // OnnxInjectionScanner — ML-backed prompt-injection detection
10
+ //
11
+ // Implements the `Scanner` interface from ai-shield-core. Designed to
12
+ // be added to a `ScannerChain` after the heuristic scanner so that
13
+ // known patterns short-circuit on cheap regex AND novel paraphrases
14
+ // still get a second-pass ML check.
15
+ //
16
+ // The runtime is abstracted via `OnnxInferenceRuntime` so this file
17
+ // has zero hard dependency on `onnxruntime-node` at type-check time.
18
+ // At runtime you pass in either:
19
+ // 1. An already-constructed `ort.InferenceSession`
20
+ // (`new OnnxInjectionScanner({ session, tokenizer })`)
21
+ // 2. A path to a `.onnx` model file + a `Tokenizer` instance
22
+ // (`await loadOnnxClassifier({ modelPath, tokenizer })`)
23
+ //
24
+ // Both paths keep the dep injection clean and unit-testable.
25
+ // ============================================================
26
+
27
/**
 * The slice of `onnxruntime-node`'s `InferenceSession` that this package
 * relies on. It is declared here, instead of being imported from
 * `onnxruntime-node`, so the package still type-checks when that optional
 * peer dependency is not installed.
 */
export interface OnnxInferenceRuntime {
  /** Graph input names, when the runtime exposes them. */
  readonly inputNames?: ReadonlyArray<string>;
  /** Graph output names, when the runtime exposes them. */
  readonly outputNames?: ReadonlyArray<string>;
  /** Execute the graph on named input tensors and resolve the named outputs. */
  run(
    feeds: Record<string, OnnxTensorLike>,
  ): Promise<Record<string, OnnxTensorLike>>;
}
40
+
41
/**
 * Structural tensor shape compatible with `onnxruntime-common`'s `Tensor`.
 * It holds flat element data, its dimensions, and an optional element-type
 * name.
 */
export interface OnnxTensorLike {
  readonly type?: string;
  readonly dims: ReadonlyArray<number>;
  readonly data: BigInt64Array | ArrayLike<number>;
}
47
+
48
/**
 * Text → token-id encoder abstraction.
 *
 * It is deliberately not tied to a particular Hugging Face tokenizer package.
 * Wire up `@huggingface/transformers`, `tokenizers`, or a hand-written
 * encoder, as long as it returns this shape.
 */
export interface Tokenizer {
  /** Optional maximum sequence length the tokenizer was trained for. */
  modelMaxLength?: number;
  /** Encode `text` into token ids plus attention mask (and optional segment ids). */
  encode(text: string): {
    input_ids: number[];
    attention_mask: number[];
    token_type_ids?: number[];
  };
}
62
+
63
/** Construction options for `OnnxInjectionScanner`. */
export interface OnnxClassifierConfig {
  /** A pre-constructed inference session. */
  session: OnnxInferenceRuntime;
  /** Tokenizer matching the model. */
  tokenizer: Tokenizer;
  /**
   * Probability threshold at or above which the input is flagged as
   * injection (decision "block"). Scores from 0.6 × threshold up to the
   * threshold produce a "warn" decision instead.
   * Default 0.85 (calibrated for protectai/deberta-v3-base-prompt-injection;
   * adjust per model).
   */
  threshold?: number;
  /**
   * Name of the output node that carries the classification **logits**.
   * The scanner always applies softmax to this output, so the node must not
   * already contain probabilities.
   * Default: the session's first declared output (`session.outputNames[0]`),
   * falling back to the first key of the run result.
   */
  outputName?: string;
  /**
   * Index of the "injection" class in the model output.
   * Default: 1 (binary classifier convention: 0 = SAFE, 1 = INJECTION).
   * Verify against the model's own label mapping.
   */
  injectionClassIndex?: number;
  /**
   * Maximum number of tokens fed to the model.
   * Default: `tokenizer.modelMaxLength` when set, otherwise 512.
   * Longer encodings are truncated head-only: the first `maxLength` tokens
   * are kept, so anything after that point is never scored by the model.
   */
  maxLength?: number;
}
91
+
92
/**
 * ML-backed prompt-injection scanner implementing ai-shield-core's `Scanner`.
 *
 * Each `scan()` tokenizes the input and runs the ONNX session. It then
 * softmaxes the output logits and reads p(injection) at `injectionClassIndex`:
 *   - p >= threshold        → "block"
 *   - p >= 0.6 × threshold  → "warn"
 *   - otherwise             → "allow"
 *
 * Errors thrown during inference are caught. In that case `scan()` returns
 * "allow" with a `content_policy` violation describing the error, so the
 * audit log shows the failure without blocking traffic.
 */
export class OnnxInjectionScanner implements Scanner {
  readonly name = "onnx-classifier";
  private readonly cfg: Required<Omit<OnnxClassifierConfig, "outputName">> & {
    outputName?: string;
  };

  /**
   * @param config - Session, tokenizer and optional tuning knobs.
   * @throws TypeError when `session` or `tokenizer` is missing.
   * @throws RangeError when `threshold` is not a finite number in [0, 1],
   *   `injectionClassIndex` is not a non-negative integer, or the resolved
   *   `maxLength` is NaN or below 1.
   */
  constructor(config: OnnxClassifierConfig) {
    if (!config.session) {
      throw new TypeError("OnnxInjectionScanner: 'session' is required");
    }
    if (!config.tokenizer) {
      throw new TypeError("OnnxInjectionScanner: 'tokenizer' is required");
    }

    const threshold = config.threshold ?? 0.85;
    // Every `>=` comparison against NaN is false, so a NaN threshold would
    // silently disable blocking (fail open). A value above 1 can never be
    // reached by a probability.
    if (!Number.isFinite(threshold) || threshold < 0 || threshold > 1) {
      throw new RangeError(
        `OnnxInjectionScanner: 'threshold' must be a number in [0, 1] (got ${threshold})`,
      );
    }

    const injectionClassIndex = config.injectionClassIndex ?? 1;
    // A fractional index would read `undefined` from the probability array.
    if (!Number.isInteger(injectionClassIndex) || injectionClassIndex < 0) {
      throw new RangeError(
        `OnnxInjectionScanner: 'injectionClassIndex' must be a non-negative integer (got ${injectionClassIndex})`,
      );
    }

    const maxLength =
      config.maxLength ?? config.tokenizer.modelMaxLength ?? 512;
    // `!(x >= 1)` also rejects NaN. Infinity is accepted (no truncation).
    if (!(maxLength >= 1)) {
      throw new RangeError(
        `OnnxInjectionScanner: 'maxLength' must be >= 1 (got ${maxLength})`,
      );
    }

    this.cfg = {
      session: config.session,
      tokenizer: config.tokenizer,
      threshold,
      outputName: config.outputName,
      injectionClassIndex,
      maxLength,
    };
  }

  /**
   * Classify `input` and map p(injection) onto a chain decision.
   * This method does not reject for inference errors; see the class docs.
   */
  async scan(input: string, _context: ScanContext): Promise<ScannerResult> {
    const start = performance.now();
    try {
      const probability = await this.predict(input);
      const { threshold } = this.cfg;
      const violations: Violation[] = [];
      let decision: ScannerResult["decision"] = "allow";

      if (probability >= threshold) {
        decision = "block";
        violations.push({
          type: "prompt_injection",
          scanner: this.name,
          score: probability,
          threshold,
          message: "ML classifier flagged prompt-injection",
          detail: `p(injection)=${probability.toFixed(4)} threshold=${threshold}`,
        });
      } else if (probability >= threshold * 0.6) {
        decision = "warn";
        violations.push({
          type: "prompt_injection",
          scanner: this.name,
          score: probability,
          threshold,
          message: "ML classifier flagged borderline content",
          detail: `p(injection)=${probability.toFixed(4)} threshold=${threshold}`,
        });
      }

      return {
        decision,
        violations,
        durationMs: performance.now() - start,
      };
    } catch (err) {
      // ML errors must not take down the entire chain. Degrade gracefully
      // to "allow" with a synthetic violation so the audit log shows that
      // something went wrong, without blocking traffic.
      //
      // The raw error message can contain file paths (model location),
      // native library symbols or deployment-internal strings, so it is
      // sanitized before it reaches the audit log. Dev/debug mode keeps the
      // raw message to aid debugging.
      const candidate =
        typeof err === "string"
          ? err
          : (err as { message?: unknown } | null | undefined)?.message;
      const rawMessage =
        typeof candidate === "string" ? candidate : "unknown error";
      const isDev =
        process.env.NODE_ENV === "development" ||
        process.env.AI_SHIELD_DEBUG === "1";
      const safeDetail = isDev
        ? rawMessage
        : sanitizeOnnxErrorMessage(rawMessage);
      return {
        decision: "allow",
        violations: [
          {
            type: "content_policy",
            scanner: this.name,
            score: 0,
            threshold: this.cfg.threshold,
            message: "ML classifier failed — degraded to allow",
            detail: safeDetail,
          },
        ],
        durationMs: performance.now() - start,
      };
    }
  }

  /**
   * Direct probability access, useful for tests and custom flows.
   *
   * NOTE(review): feeds are plain `{ data, dims, type }` objects. The
   * onnxruntime JS API documents `ort.Tensor` instances as `run()` inputs,
   * so confirm that a raw session passed in by the caller accepts this form.
   *
   * @returns p(injection), the softmax probability at `injectionClassIndex`.
   * @throws Error when no usable output node exists, when the class index
   *   is out of range for the output, or when the probability is not finite.
   */
  async predict(input: string): Promise<number> {
    const tokens = truncate(
      this.cfg.tokenizer.encode(input),
      this.cfg.maxLength,
    );
    const dims = [1, tokens.input_ids.length];
    const toInt64 = (ids: number[]): BigInt64Array =>
      BigInt64Array.from(ids.map((n) => BigInt(n)));

    // ORT rejects feeds whose name is not a graph input. Some exports (e.g.
    // DeBERTa-v3 models without segment embeddings) declare no
    // `token_type_ids`, even though tokenizers may return one. When the
    // session reports its input names, only feed those.
    const declared = this.cfg.session.inputNames;
    const accepts = (feedName: string): boolean =>
      declared === undefined ||
      declared.length === 0 ||
      declared.includes(feedName);

    const candidates: ReadonlyArray<readonly [string, number[] | undefined]> = [
      ["input_ids", tokens.input_ids],
      ["attention_mask", tokens.attention_mask],
      ["token_type_ids", tokens.token_type_ids],
    ];
    const feeds: Record<string, OnnxTensorLike> = {};
    for (const [feedName, ids] of candidates) {
      if (ids && accepts(feedName)) {
        feeds[feedName] = { data: toInt64(ids), dims, type: "int64" };
      }
    }

    const result = await this.cfg.session.run(feeds);
    const outputName =
      this.cfg.outputName ??
      this.cfg.session.outputNames?.[0] ??
      Object.keys(result)[0];
    if (!outputName) {
      throw new Error("OnnxInjectionScanner: no output node available");
    }
    const tensor = result[outputName];
    if (!tensor) {
      throw new Error(
        `OnnxInjectionScanner: output '${outputName}' not in result`,
      );
    }

    // NOTE(review): assumes a single-row logits output ([1, num_classes]);
    // every element of the tensor is softmaxed together.
    const logits = Array.from(tensor.data as ArrayLike<number>, (v) =>
      Number(v),
    );
    const probs = softmax(logits);
    const idx = this.cfg.injectionClassIndex;
    const probability = probs[idx];
    if (idx < 0 || idx >= probs.length || probability === undefined) {
      throw new Error(
        `OnnxInjectionScanner: injectionClassIndex ${idx} out of range (len=${probs.length})`,
      );
    }
    // Non-finite logits become NaN probabilities, and every threshold
    // comparison against NaN is false, so scan() would silently "allow".
    // Throwing routes the failure through scan()'s degraded path, which
    // records a violation.
    if (!Number.isFinite(probability)) {
      throw new Error(
        "OnnxInjectionScanner: classifier produced a non-finite probability",
      );
    }
    return probability;
  }
}
235
+
236
/**
 * Convenience loader. It dynamically imports `onnxruntime-node`, creates an
 * `InferenceSession` for `opts.modelPath`, and wraps it in an
 * `OnnxInjectionScanner`.
 *
 * The import happens at call time, so consumers who never call this function
 * do not need `onnxruntime-node` installed.
 *
 * Tokenizer loading is left to the caller because tokenizer implementations
 * vary widely between models. We don't want to pin a specific HF tokenizer
 * package.
 *
 * When the runtime module exposes `Tensor`, each feed is converted to an
 * `ort.Tensor` before `session.run`. That is the input form the onnxruntime
 * JS API documents.
 *
 * @throws Error if `onnxruntime-node` cannot be imported or does not expose
 *   `InferenceSession.create`. Errors from `InferenceSession.create` (e.g. an
 *   unreadable model file) propagate unchanged.
 */
export async function loadOnnxClassifier(opts: {
  modelPath: string;
  tokenizer: Tokenizer;
  threshold?: number;
  outputName?: string;
  injectionClassIndex?: number;
  maxLength?: number;
}): Promise<OnnxInjectionScanner> {
  // The parts of onnxruntime-node we touch. They are declared locally so the
  // package type-checks without the optional peer dependency installed.
  type OrtTensorCtor = new (
    type: string,
    data: OnnxTensorLike["data"],
    dims: readonly number[],
  ) => OnnxTensorLike;
  interface OrtModuleLike {
    InferenceSession?: {
      create?: (path: string) => Promise<OnnxInferenceRuntime>;
    };
    Tensor?: OrtTensorCtor;
    default?: OrtModuleLike;
  }

  let imported: unknown;
  try {
    // `as string` keeps TypeScript from resolving the specifier at compile time.
    imported = await import("onnxruntime-node" as string);
  } catch (err) {
    const reason = err instanceof Error ? err.message : String(err);
    throw new Error(
      "ai-shield-classifier-onnx: 'onnxruntime-node' is required " +
        "to call loadOnnxClassifier(). Install it as a peer dependency.\n" +
        `Underlying error: ${reason}`,
    );
  }

  const namespace = imported as OrtModuleLike;
  // If Node cannot statically detect a CommonJS module's named exports, the
  // ESM namespace only carries `default`. Fall back to it in that case.
  const ortModule: OrtModuleLike =
    typeof namespace.InferenceSession?.create === "function"
      ? namespace
      : (namespace.default ?? namespace);

  const factory = ortModule.InferenceSession;
  const create = factory?.create;
  if (typeof create !== "function") {
    throw new Error(
      "ai-shield-classifier-onnx: 'onnxruntime-node' did not expose InferenceSession.create",
    );
  }
  // Invoke with `factory` as the receiver, so a `this`-dependent static works.
  const rawSession = await create.call(factory, opts.modelPath);

  let session: OnnxInferenceRuntime = rawSession;
  const Tensor = ortModule.Tensor;
  if (typeof Tensor === "function") {
    const TensorClass: OrtTensorCtor = Tensor;
    // Adapter that converts plain tensor descriptors into `ort.Tensor`
    // instances. Input/output names are forwarded from the real session.
    session = {
      get inputNames() {
        return rawSession.inputNames;
      },
      get outputNames() {
        return rawSession.outputNames;
      },
      run(feeds) {
        const converted: Record<string, OnnxTensorLike> = {};
        for (const [feedName, tensor] of Object.entries(feeds)) {
          converted[feedName] =
            tensor.type === undefined
              ? tensor
              : new TensorClass(tensor.type, tensor.data, tensor.dims);
        }
        return rawSession.run(converted);
      },
    };
  }

  return new OnnxInjectionScanner({
    session,
    tokenizer: opts.tokenizer,
    threshold: opts.threshold,
    outputName: opts.outputName,
    injectionClassIndex: opts.injectionClassIndex,
    maxLength: opts.maxLength,
  });
}
281
+
282
+ // --- helpers ---
283
+
284
/**
 * Strip absolute paths and other deployment-internal strings from an
 * ONNX-runtime error message before it lands in the audit log. The short
 * error class / cause hint that helps diagnose the failure is kept.
 *
 * Replacements, applied in this order:
 *   - `file://…` URLs       → `[file-url]`
 *   - POSIX absolute paths  → `[path]`, when at the start of the message or
 *     after whitespace, `(`, `=`, or a quote character (the delimiter is kept)
 *   - Windows drive paths   → `[path]` (`C:\…` or `C:/…`)
 *   - long hex addresses    → `[addr]`
 *
 * @param message - Raw error message. Non-strings and `""` map to
 *   `"classifier_runtime_error"`.
 * @returns The sanitized message, built from at most the first 500
 *   characters of the input and trimmed.
 */
function sanitizeOnnxErrorMessage(message: string): string {
  if (typeof message !== "string" || message.length === 0) {
    return "classifier_runtime_error";
  }
  // Truncate before sanitizing — bounded work on adversarial input.
  const truncated = message.length > 500 ? message.slice(0, 500) : message;
  return (
    truncated
      // file:// URLs first, so the path inside them is not half-replaced.
      .replace(/file:\/\/\S+/g, "[file-url]")
      // POSIX absolute paths. Quotes are accepted as a delimiter because
      // Node fs errors quote paths (e.g. `open '/app/model.onnx'`).
      .replace(/(^|[\s(='"`])\/[\w./@-]+/g, "$1[path]")
      // Windows drive paths with either separator.
      .replace(/\b[A-Za-z]:[\\/][\\\w./@-]+/g, "[path]")
      // Memory addresses (0x...)
      .replace(/0x[0-9a-fA-F]{6,}/g, "[addr]")
      .trim()
  );
}
306
+
307
/**
 * Numerically stable softmax: subtracts the maximum logit before
 * exponentiating, so large logits do not overflow `Math.exp`.
 *
 * @param logits - Raw scores, any length (an empty array yields `[]`).
 * @returns Probabilities in the same order as `logits`. NaN inputs
 *   propagate as NaN outputs.
 */
function softmax(logits: number[]): number[] {
  if (logits.length === 0) return [];
  // Use a reduce rather than `Math.max(...logits)`. Spreading passes every
  // element as a call argument, which throws a RangeError once the array
  // exceeds the engine's argument limit. `Math.max` still propagates NaN.
  const max = logits.reduce((m, l) => Math.max(m, l), -Infinity);
  const exps = logits.map((l) => Math.exp(l - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  if (sum === 0) return logits.map(() => 0);
  return exps.map((e) => e / sum);
}
315
+
316
/**
 * Clip an encoding to at most `maxLength` tokens, keeping the head.
 *
 * Encodings that already fit are returned as the very same object.
 * Otherwise a new object is built with every present array sliced to
 * `maxLength`. `token_type_ids` stays `undefined` when absent.
 */
function truncate(
  tokens: ReturnType<Tokenizer["encode"]>,
  maxLength: number,
): ReturnType<Tokenizer["encode"]> {
  const { input_ids, attention_mask, token_type_ids } = tokens;
  if (input_ids.length > maxLength) {
    const clip = (seq: number[]): number[] => seq.slice(0, maxLength);
    return {
      input_ids: clip(input_ids),
      attention_mask: clip(attention_mask),
      token_type_ids: token_type_ids == null ? undefined : clip(token_type_ids),
    };
  }
  return tokens;
}
package/src/index.ts ADDED
@@ -0,0 +1,30 @@
1
+ // ============================================================
2
+ // ai-shield-classifier-onnx — Optional ML classifier package
3
+ //
4
+ // Pairs with ai-shield-core. Adds an ONNX-runtime-backed
5
+ // prompt-injection classifier (DeBERTa-style by default) that can
6
+ // be added to a ScannerChain alongside the built-in heuristics.
7
+ //
8
+ // Why a separate package?
9
+ // The core package keeps a zero-dependency promise (Node stdlib only).
10
+ // ONNX inference requires `onnxruntime-node`, which ships native
11
+ // binaries and is too heavy to force on every consumer. Install this
12
+ // package only when you actively want ML-augmented detection.
13
+ //
14
+ // Recommended models:
15
+ // - protectai/deberta-v3-base-prompt-injection (Apache-2.0)
16
+ // - protectai/deberta-v3-base-prompt-injection-v2
17
+ // - hf models exported to ONNX via `optimum-cli`
18
+ //
19
+ // The classifier is intentionally pluggable — you bring the model file
20
+ // + a Tokenizer implementation; the wrapper handles inference + thresholding +
21
+ // integration with the Scanner interface.
22
+ // ============================================================
23
+
24
// Public API. `OnnxTensorLike` is re-exported because it appears in the
// public `OnnxInferenceRuntime.run` signature. Consumers implementing a
// custom runtime need to be able to name it.
export {
  OnnxInjectionScanner,
  loadOnnxClassifier,
  type OnnxClassifierConfig,
  type OnnxInferenceRuntime,
  type OnnxTensorLike,
  type Tokenizer,
} from "./classifier.js";
package/tsconfig.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "extends": "../../tsconfig.json",
3
+ "compilerOptions": {
4
+ "outDir": "./dist",
5
+ "rootDir": "./src",
6
+ "composite": true
7
+ },
8
+ "include": ["src/**/*"],
9
+ "references": [
10
+ { "path": "../core" }
11
+ ]
12
+ }