ai-shield-classifier-onnx 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +72 -0
- package/package.json +53 -0
- package/src/classifier.ts +326 -0
- package/src/index.ts +30 -0
- package/tsconfig.json +12 -0
package/README.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# ai-shield-classifier-onnx
|
|
2
|
+
|
|
3
|
+
Optional ONNX-runtime ML classifier for [ai-shield](https://github.com/studiomeyer-io/ai-shield).
|
|
4
|
+
Pairs with `ai-shield-core` to add a DeBERTa-style prompt-injection classifier
|
|
5
|
+
alongside the heuristic patterns.
|
|
6
|
+
|
|
7
|
+
## Why a separate package?
|
|
8
|
+
|
|
9
|
+
`ai-shield-core` is zero-dependency by design. ONNX inference requires
|
|
10
|
+
`onnxruntime-node`, which ships native binaries. Install this package only
|
|
11
|
+
when you actively want ML-augmented detection on top of the regex layer.
|
|
12
|
+
|
|
13
|
+
## Install
|
|
14
|
+
|
|
15
|
+
```bash
|
|
16
|
+
npm install ai-shield-core ai-shield-classifier-onnx onnxruntime-node
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
## Usage
|
|
20
|
+
|
|
21
|
+
```ts
|
|
22
|
+
import { ScannerChain, HeuristicScanner } from "ai-shield-core";
|
|
23
|
+
import { loadOnnxClassifier } from "ai-shield-classifier-onnx";
|
|
24
|
+
|
|
25
|
+
// Bring your own tokenizer. Example: protectai/deberta-v3-base-prompt-injection
|
|
26
|
+
const tokenizer = await yourTokenizerFor("protectai/deberta-v3-base-prompt-injection");
|
|
27
|
+
|
|
28
|
+
const ml = await loadOnnxClassifier({
|
|
29
|
+
modelPath: "./models/deberta-injection.onnx",
|
|
30
|
+
tokenizer,
|
|
31
|
+
threshold: 0.85, // tune per model
|
|
32
|
+
});
|
|
33
|
+
|
|
34
|
+
const chain = new ScannerChain({ earlyExit: true });
|
|
35
|
+
chain.add(new HeuristicScanner({ strictness: "high" })); // cheap regex first
|
|
36
|
+
chain.add(ml); // ML fallback
|
|
37
|
+
|
|
38
|
+
const result = await chain.run("Ignore previous instructions...");
|
|
39
|
+
console.log(result.decision); // "block"
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
Or manually with an already-constructed `InferenceSession`:
|
|
43
|
+
|
|
44
|
+
```ts
|
|
45
|
+
import * as ort from "onnxruntime-node";
|
|
46
|
+
import { OnnxInjectionScanner } from "ai-shield-classifier-onnx";
|
|
47
|
+
|
|
48
|
+
const session = await ort.InferenceSession.create("./models/deberta.onnx");
|
|
49
|
+
const scanner = new OnnxInjectionScanner({ session, tokenizer, threshold: 0.85 });
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
## Recommended models
|
|
53
|
+
|
|
54
|
+
- [`protectai/deberta-v3-base-prompt-injection`](https://huggingface.co/protectai/deberta-v3-base-prompt-injection) (Apache-2.0)
|
|
55
|
+
- [`protectai/deberta-v3-base-prompt-injection-v2`](https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2)
|
|
56
|
+
- Any HF model exported to ONNX via `optimum-cli export onnx`
|
|
57
|
+
|
|
58
|
+
## Notes
|
|
59
|
+
|
|
60
|
+
- The scanner fails **open** on inference errors: a failed scan returns
  `allow` plus a `content_policy` violation, so the failure shows up in the
  audit log without blocking traffic. This keeps a runtime or
  hardware-specific inference error from taking down the whole chain.
  Errors while loading the model (e.g. a missing model file) are not
  swallowed — `loadOnnxClassifier` rejects instead.
|
|
64
|
+
- Use after the heuristic scanner. Most known attacks short-circuit on
|
|
65
|
+
cheap regex; the ML pass catches paraphrases and obfuscations that slip
|
|
66
|
+
through.
|
|
67
|
+
- The probability threshold is **calibrated per model**. Start at 0.85 and
  tune against your false-positive budget. Scores from 0.6 × the threshold
  up to the threshold produce a `warn` instead of a `block`.
|
|
69
|
+
|
|
70
|
+
## License
|
|
71
|
+
|
|
72
|
+
MIT — see [LICENSE](https://github.com/studiomeyer-io/ai-shield/blob/main/LICENSE).
|
package/package.json
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "ai-shield-classifier-onnx",
|
|
3
|
+
"version": "0.2.0",
|
|
4
|
+
"license": "MIT",
|
|
5
|
+
"description": "Optional ONNX ML classifier for ai-shield — DeBERTa-style prompt-injection detection alongside heuristic patterns",
|
|
6
|
+
"author": "StudioMeyer <hello@studiomeyer.io>",
|
|
7
|
+
"repository": {
|
|
8
|
+
"type": "git",
|
|
9
|
+
"url": "https://github.com/studiomeyer-io/ai-shield",
|
|
10
|
+
"directory": "packages/classifier-onnx"
|
|
11
|
+
},
|
|
12
|
+
"homepage": "https://github.com/studiomeyer-io/ai-shield/tree/main/packages/classifier-onnx",
|
|
13
|
+
"bugs": {
|
|
14
|
+
"url": "https://github.com/studiomeyer-io/ai-shield/issues"
|
|
15
|
+
},
|
|
16
|
+
"keywords": [
|
|
17
|
+
"llm",
|
|
18
|
+
"security",
|
|
19
|
+
"prompt-injection",
|
|
20
|
+
"onnx",
|
|
21
|
+
"deberta",
|
|
22
|
+
"ml",
|
|
23
|
+
"classifier",
|
|
24
|
+
"ai-shield"
|
|
25
|
+
],
|
|
26
|
+
"type": "module",
|
|
27
|
+
"main": "./dist/index.js",
|
|
28
|
+
"types": "./dist/index.d.ts",
|
|
29
|
+
"exports": {
|
|
30
|
+
".": {
|
|
31
|
+
"types": "./dist/index.d.ts",
|
|
32
|
+
"import": "./dist/index.js"
|
|
33
|
+
}
|
|
34
|
+
},
|
|
35
|
+
"scripts": {
|
|
36
|
+
"build": "tsc -b",
|
|
37
|
+
"typecheck": "tsc -b"
|
|
38
|
+
},
|
|
39
|
+
"peerDependencies": {
|
|
40
|
+
"ai-shield-core": "0.2.0",
|
|
41
|
+
"onnxruntime-node": ">=1.17.0"
|
|
42
|
+
},
|
|
43
|
+
"peerDependenciesMeta": {
|
|
44
|
+
"onnxruntime-node": { "optional": true }
|
|
45
|
+
},
|
|
46
|
+
"devDependencies": {
|
|
47
|
+
"ai-shield-core": "0.2.0",
|
|
48
|
+
"typescript": "^5.7.0"
|
|
49
|
+
},
|
|
50
|
+
"engines": {
|
|
51
|
+
"node": ">=20.0.0"
|
|
52
|
+
}
|
|
53
|
+
}
|
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
import type {
|
|
2
|
+
Scanner,
|
|
3
|
+
ScannerResult,
|
|
4
|
+
ScanContext,
|
|
5
|
+
Violation,
|
|
6
|
+
} from "ai-shield-core";
|
|
7
|
+
|
|
8
|
+
// ============================================================
|
|
9
|
+
// OnnxInjectionScanner — ML-backed prompt-injection detection
|
|
10
|
+
//
|
|
11
|
+
// Implements the `Scanner` interface from ai-shield-core. Designed to
|
|
12
|
+
// be added to a `ScannerChain` after the heuristic scanner so that
|
|
13
|
+
// known patterns short-circuit on cheap regex AND novel paraphrases
|
|
14
|
+
// still get a second-pass ML check.
|
|
15
|
+
//
|
|
16
|
+
// The runtime is abstracted via `OnnxInferenceRuntime` so this file
|
|
17
|
+
// has zero hard dependency on `onnxruntime-node` at type-check time.
|
|
18
|
+
// At runtime you pass in either:
|
|
19
|
+
// 1. An already-constructed `ort.InferenceSession`
|
|
20
|
+
// (`new OnnxInjectionScanner({ session, tokenizer })`)
|
|
21
|
+
// 2. A path to a `.onnx` model file + your own `Tokenizer` instance
//    (`await loadOnnxClassifier({ modelPath, tokenizer })`)
|
|
23
|
+
//
|
|
24
|
+
// Both paths keep the dep injection clean and unit-testable.
|
|
25
|
+
// ============================================================
|
|
26
|
+
|
|
27
|
+
/**
 * Structural stand-in for the parts of `onnxruntime-node`'s
 * `InferenceSession` this package touches. Because the shape is declared
 * here rather than imported, the package type-checks even when
 * `onnxruntime-node` is absent — it is only an optional peer dependency.
 */
export interface OnnxInferenceRuntime {
  /** Model output names; the scanner uses the first one when no `outputName` is configured. */
  readonly outputNames?: readonly string[];
  /** Model input names (declared for compatibility; not read by the scanner). */
  readonly inputNames?: readonly string[];
  /** Execute the model on named input tensors and resolve to named outputs. */
  run(
    feeds: Record<string, OnnxTensorLike>,
  ): Promise<Record<string, OnnxTensorLike>>;
}
|
|
40
|
+
|
|
41
|
+
/**
 * Plain-object view of a tensor, structurally compatible with
 * `onnxruntime-common`'s `Tensor`.
 */
export interface OnnxTensorLike {
  /** Element type tag, e.g. "int64" for the scanner's input feeds. */
  readonly type?: string;
  /** Shape of the tensor. */
  readonly dims: readonly number[];
  /** Flat element buffer; the scanner's int64 feeds use a `BigInt64Array`. */
  readonly data: ArrayLike<number> | BigInt64Array;
}
|
|
47
|
+
|
|
48
|
+
/**
 * Text-to-token adapter, called once per scanned input. Declared as a
 * plain structural interface (instead of depending on a specific HF
 * tokenizer package) so callers can plug in `@huggingface/transformers`,
 * `tokenizers`, or a hand-written implementation.
 */
export interface Tokenizer {
  /** Optional max sequence length the tokenizer was trained for. */
  modelMaxLength?: number;
  /**
   * Encode `text` into parallel id arrays. The scanner feeds every array
   * with the shape of `input_ids`, so `attention_mask` (and
   * `token_type_ids`, when present) must have the same length.
   */
  encode(text: string): {
    input_ids: number[];
    attention_mask: number[];
    token_type_ids?: number[];
  };
}
|
|
62
|
+
|
|
63
|
+
/** Construction options for `OnnxInjectionScanner`. */
export interface OnnxClassifierConfig {
  /** A pre-constructed inference session (e.g. an `ort.InferenceSession`). */
  session: OnnxInferenceRuntime;
  /** Tokenizer matching the model; its ids are fed to the session as int64 tensors. */
  tokenizer: Tokenizer;
  /**
   * Probability threshold at or above which the input is blocked as
   * injection. Scores from 0.6 × threshold up to (but not including) the
   * threshold are reported as "warn". Default 0.85 (calibrated for
   * protectai/deberta-v3-base-prompt-injection — adjust per model).
   */
  threshold?: number;
  /**
   * Name of the output node that carries the class logits. The scanner
   * always applies softmax to this output, so it must hold raw logits,
   * not probabilities. Default: the session's first `outputNames` entry,
   * or, failing that, the first key in the runtime's result map.
   */
  outputName?: string;
  /**
   * Index of the "injection" class in the model output.
   * Default: 1 (binary classifier convention: 0 = SAFE, 1 = INJECTION).
   */
  injectionClassIndex?: number;
  /**
   * Maximum number of tokens to feed. Default: the tokenizer's
   * `modelMaxLength`, or 512 when that is not set.
   * Longer inputs are truncated head-only (start kept), which matches
   * the standard DeBERTa fine-tune recipe — text beyond this limit is
   * never seen by the classifier.
   */
  maxLength?: number;
}
|
|
91
|
+
|
|
92
|
+
/**
 * ML-backed prompt-injection scanner for an ai-shield `ScannerChain`.
 *
 * Each scan tokenizes the input, runs the ONNX session, applies softmax to
 * the selected output and compares p(injection) with the threshold:
 *   - p >= threshold                      -> "block"
 *   - p >= threshold * WARN_FRACTION      -> "warn"
 *   - otherwise                           -> "allow"
 * Errors during inference fail OPEN: the result is "allow" plus a
 * `content_policy` violation describing the (sanitized) failure.
 */
export class OnnxInjectionScanner implements Scanner {
  readonly name = "onnx-classifier";

  /** Scores in [threshold * WARN_FRACTION, threshold) are reported as "warn". */
  private static readonly WARN_FRACTION = 0.6;

  private readonly cfg: Required<
    Omit<OnnxClassifierConfig, "outputName">
  > & { outputName?: string };

  /**
   * @param config - session, tokenizer and optional tuning knobs.
   * @throws {TypeError} if `session` or `tokenizer` is missing, if
   *   `threshold` is not a finite number, if `injectionClassIndex` is not a
   *   non-negative integer, or if the resolved `maxLength` is not >= 1.
   */
  constructor(config: OnnxClassifierConfig) {
    if (!config.session) {
      throw new TypeError("OnnxInjectionScanner: 'session' is required");
    }
    if (!config.tokenizer) {
      throw new TypeError("OnnxInjectionScanner: 'tokenizer' is required");
    }

    const threshold = config.threshold ?? 0.85;
    const injectionClassIndex = config.injectionClassIndex ?? 1;
    const maxLength =
      config.maxLength ?? config.tokenizer.modelMaxLength ?? 512;

    // A NaN threshold makes every comparison false, which would silently
    // turn the scanner into a permanent "allow".
    if (!Number.isFinite(threshold)) {
      throw new TypeError(
        `OnnxInjectionScanner: 'threshold' must be a finite number (got ${threshold})`,
      );
    }
    // A fractional or negative index would read `undefined` from the
    // probability array instead of a class score.
    if (!Number.isInteger(injectionClassIndex) || injectionClassIndex < 0) {
      throw new TypeError(
        `OnnxInjectionScanner: 'injectionClassIndex' must be a non-negative integer (got ${injectionClassIndex})`,
      );
    }
    // `!(x >= 1)` also rejects NaN; Infinity is accepted (= never truncate).
    if (!(maxLength >= 1)) {
      throw new TypeError(
        `OnnxInjectionScanner: 'maxLength' must be >= 1 (got ${maxLength})`,
      );
    }

    this.cfg = {
      session: config.session,
      tokenizer: config.tokenizer,
      threshold,
      outputName: config.outputName,
      injectionClassIndex,
      maxLength,
    };
  }

  /**
   * Score `input` and map the probability to a chain decision.
   * Inference errors are converted into an "allow" result with a
   * `content_policy` violation instead of rejecting.
   */
  async scan(input: string, _context: ScanContext): Promise<ScannerResult> {
    const start = performance.now();
    try {
      const probability = await this.predict(input);
      const { threshold } = this.cfg;

      let decision: ScannerResult["decision"] = "allow";
      let message: string | undefined;
      if (probability >= threshold) {
        decision = "block";
        message = "ML classifier flagged prompt-injection";
      } else if (
        probability >=
        threshold * OnnxInjectionScanner.WARN_FRACTION
      ) {
        decision = "warn";
        message = "ML classifier flagged borderline content";
      }

      const violations: Violation[] = [];
      if (message !== undefined) {
        violations.push({
          type: "prompt_injection",
          scanner: this.name,
          score: probability,
          threshold,
          message,
          detail: `p(injection)=${probability.toFixed(4)} threshold=${threshold}`,
        });
      }

      return {
        decision,
        violations,
        durationMs: performance.now() - start,
      };
    } catch (err) {
      // ML errors must not take down the entire chain — degrade gracefully
      // to "allow" with a synthetic violation so the audit log shows
      // something went wrong without blocking traffic.
      //
      // Critic H4 round 1 — the raw error message can contain file
      // paths (model location), native library symbols, or deployment-
      // internal strings. Strip absolute paths before they hit the audit
      // log. In dev mode we keep more detail to aid debugging.
      const rawMessage = (err as Error)?.message ?? "unknown error";
      const isDev =
        process.env.NODE_ENV === "development" ||
        process.env.AI_SHIELD_DEBUG === "1";
      const safeDetail = isDev
        ? rawMessage
        : sanitizeOnnxErrorMessage(rawMessage);
      return {
        decision: "allow",
        violations: [
          {
            type: "content_policy",
            scanner: this.name,
            score: 0,
            threshold: this.cfg.threshold,
            message: "ML classifier failed — degraded to allow",
            detail: safeDetail,
          },
        ],
        durationMs: performance.now() - start,
      };
    }
  }

  /**
   * Direct probability access — useful for tests + custom flows.
   *
   * @returns p(injection) for `input`.
   * @throws {Error} when the session has no usable output node, when the
   *   output has too few classes for `injectionClassIndex`, or when the
   *   model produced a non-finite score (e.g. NaN logits).
   */
  async predict(input: string): Promise<number> {
    const tokens = truncate(
      this.cfg.tokenizer.encode(input),
      this.cfg.maxLength,
    );
    const dims = [1, tokens.input_ids.length];

    // NOTE(review): recent onnxruntime-node native bindings appear to read a
    // `location` field from feeds that are plain objects rather than real
    // `ort.Tensor`s, rejecting feeds without it; "cpu" is supplied for that
    // case — confirm against the supported runtime range.
    const int64Tensor = (values: number[]): OnnxTensorLike => {
      const feed = {
        data: BigInt64Array.from(values.map((n) => BigInt(n))),
        dims,
        type: "int64",
        location: "cpu",
      };
      return feed;
    };

    const feeds: Record<string, OnnxTensorLike> = {
      input_ids: int64Tensor(tokens.input_ids),
      attention_mask: int64Tensor(tokens.attention_mask),
    };
    if (tokens.token_type_ids) {
      feeds.token_type_ids = int64Tensor(tokens.token_type_ids);
    }

    const result = await this.cfg.session.run(feeds);
    const outputName =
      this.cfg.outputName ??
      this.cfg.session.outputNames?.[0] ??
      Object.keys(result)[0];
    if (!outputName) {
      throw new Error("OnnxInjectionScanner: no output node available");
    }
    const tensor = result[outputName];
    if (!tensor) {
      throw new Error(
        `OnnxInjectionScanner: output '${outputName}' not in result`,
      );
    }

    // Expected: logits of shape [1, num_classes]. Softmax + pick class.
    const logits = Array.from(tensor.data as ArrayLike<number>).map((n) =>
      Number(n),
    );
    const probs = softmax(logits);
    const idx = this.cfg.injectionClassIndex;
    if (idx < 0 || idx >= probs.length) {
      throw new Error(
        `OnnxInjectionScanner: injectionClassIndex ${idx} out of range (len=${probs.length})`,
      );
    }
    const probability = probs[idx];
    // NaN/Infinity logits propagate through softmax as NaN, and a NaN score
    // compares false against every threshold, which would silently allow
    // the input. Surface it so scan() records a classifier failure instead.
    if (probability === undefined || !Number.isFinite(probability)) {
      throw new Error(
        "OnnxInjectionScanner: classifier produced a non-finite probability",
      );
    }
    return probability;
  }
}
|
|
235
|
+
|
|
236
|
+
/**
 * Convenience loader. Imports `onnxruntime-node` *at runtime*, creates an
 * `InferenceSession` from `opts.modelPath` and wraps it in an
 * `OnnxInjectionScanner`, so consumers who never call this function don't
 * pay the install cost.
 *
 * Tokenizer loading is left to the caller because tokenizer
 * implementations vary widely between models — we don't want to
 * pin a specific HF tokenizer package.
 *
 * @throws {Error} if `onnxruntime-node` cannot be imported or does not
 *   expose `InferenceSession.create`. Errors from session creation (e.g. an
 *   unreadable model file) propagate unchanged, as does any error thrown by
 *   the `OnnxInjectionScanner` constructor.
 */
export async function loadOnnxClassifier(opts: {
  modelPath: string;
  tokenizer: Tokenizer;
  threshold?: number;
  outputName?: string;
  injectionClassIndex?: number;
  maxLength?: number;
}): Promise<OnnxInjectionScanner> {
  // Dynamic import keeps the dep optional — TypeScript can't see it
  // at compile time, which is the whole point.
  let ort: unknown;
  try {
    ort = await import("onnxruntime-node" as string);
  } catch (err) {
    throw new Error(
      "ai-shield-classifier-onnx: 'onnxruntime-node' is required " +
        "to call loadOnnxClassifier(). Install it as a peer dependency.\n" +
        `Underlying error: ${(err as Error).message}`,
    );
  }

  type SessionFactory = {
    create?: (path: string) => Promise<OnnxInferenceRuntime>;
  };
  const ortModule = ort as {
    InferenceSession?: SessionFactory;
    default?: { InferenceSession?: SessionFactory };
  };
  // `onnxruntime-node` is CommonJS. Importing it from ESM exposes
  // `module.exports` as `default`; the named export only exists when Node
  // manages to detect it statically, so fall back to `default`.
  const factory =
    ortModule.InferenceSession ?? ortModule.default?.InferenceSession;
  const create = factory?.create;
  if (typeof create !== "function") {
    throw new Error(
      "ai-shield-classifier-onnx: 'onnxruntime-node' did not expose InferenceSession.create",
    );
  }
  // Invoke as a method so `this` stays bound to InferenceSession.
  const session = await create.call(factory, opts.modelPath);
  return new OnnxInjectionScanner({
    session,
    tokenizer: opts.tokenizer,
    threshold: opts.threshold,
    outputName: opts.outputName,
    injectionClassIndex: opts.injectionClassIndex,
    maxLength: opts.maxLength,
  });
}
|
|
281
|
+
|
|
282
|
+
// --- helpers ---
|
|
283
|
+
|
|
284
|
+
/**
 * Strip absolute paths, file URLs and raw memory addresses from an
 * ONNX-runtime error message before it lands in the audit log, keeping
 * the surrounding wording (error class / cause hint) that helps diagnose
 * the failure.
 *
 * @param message - raw error message; only the first 500 chars are kept.
 * @returns the redacted message, or "classifier_runtime_error" when the
 *   input is empty or not a string.
 */
function sanitizeOnnxErrorMessage(message: string): string {
  if (typeof message !== "string" || message.length === 0) {
    return "classifier_runtime_error";
  }
  // Truncate before sanitizing — bounded work on adversarial input.
  const truncated = message.length > 500 ? message.slice(0, 500) : message;
  return (
    truncated
      // file:// URLs first, so the POSIX rule below cannot consume their
      // path component and leave a half-redacted "file:" behind.
      .replace(/file:\/\/\S+/g, "[file-url]")
      // Windows drive paths with either separator. `\b` stops URL schemes
      // such as "http:" from being read as a drive letter.
      .replace(/\b[A-Za-z]:[\\/][\w\\./@-]+/g, "[path]")
      // POSIX absolute paths, including quoted ones ("open '/x/y'") and
      // ones after "=" or ":". The delimiter before the path is captured
      // and re-emitted so quotes/parens stay balanced. This also redacts
      // the path part of other URLs (e.g. internal hosts).
      .replace(/(^|[\s('"`=:,[])\/[\w./@-]+/g, "$1[path]")
      // Memory addresses (0x...)
      .replace(/0x[0-9a-fA-F]{6,}/g, "[addr]")
      .trim()
  );
}
|
|
306
|
+
|
|
307
|
+
/**
 * Numerically stable softmax: shifts by the largest logit before
 * exponentiating. Returns [] for empty input and all zeros if the
 * exponent sum is exactly 0.
 */
function softmax(logits: number[]): number[] {
  if (logits.length === 0) {
    return [];
  }
  const peak = logits.reduce((m, v) => Math.max(m, v), -Infinity);
  const weights: number[] = [];
  let total = 0;
  for (const logit of logits) {
    const weight = Math.exp(logit - peak);
    weights.push(weight);
    total += weight;
  }
  if (total === 0) {
    return weights.map(() => 0);
  }
  return weights.map((weight) => weight / total);
}
|
|
315
|
+
|
|
316
|
+
/**
 * Clip an encoding to at most `maxLength` tokens, keeping the head.
 * Returns the same object when it already fits; otherwise a new object
 * whose id arrays are all cut at the same position.
 *
 * NOTE(review): tokens past `maxLength` are never seen by the classifier,
 * and a trailing special token appended by the tokenizer (if any) is
 * dropped — consider windowed scoring for long inputs.
 */
function truncate(
  tokens: ReturnType<Tokenizer["encode"]>,
  maxLength: number,
): ReturnType<Tokenizer["encode"]> {
  const fits = tokens.input_ids.length <= maxLength;
  if (fits) {
    return tokens;
  }
  const head = (ids: number[]): number[] => ids.slice(0, maxLength);
  const { input_ids, attention_mask, token_type_ids } = tokens;
  return {
    input_ids: head(input_ids),
    attention_mask: head(attention_mask),
    token_type_ids: token_type_ids == null ? undefined : head(token_type_ids),
  };
}
|
package/src/index.ts
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
// ============================================================
|
|
2
|
+
// ai-shield-classifier-onnx — Optional ML classifier package
|
|
3
|
+
//
|
|
4
|
+
// Pairs with ai-shield-core. Adds an ONNX-runtime-backed
|
|
5
|
+
// prompt-injection classifier (DeBERTa-style by default) that can
|
|
6
|
+
// be added to a ScannerChain alongside the built-in heuristics.
|
|
7
|
+
//
|
|
8
|
+
// Why a separate package?
|
|
9
|
+
// The core package keeps a zero-dependency promise (Node stdlib only).
|
|
10
|
+
// ONNX inference requires `onnxruntime-node`, which ships native
|
|
11
|
+
// binaries and is too heavy to force on every consumer. Install this
|
|
12
|
+
// package only when you actively want ML-augmented detection.
|
|
13
|
+
//
|
|
14
|
+
// Recommended models:
|
|
15
|
+
// - protectai/deberta-v3-base-prompt-injection (Apache-2.0)
|
|
16
|
+
// - protectai/deberta-v3-base-prompt-injection-v2
|
|
17
|
+
// - hf models exported to ONNX via `optimum-cli`
|
|
18
|
+
//
|
|
19
|
+
// The classifier is intentionally pluggable — you bring the model file
// + a Tokenizer implementation; the wrapper handles inference + thresholding +
|
|
21
|
+
// integration with the Scanner interface.
|
|
22
|
+
// ============================================================
|
|
23
|
+
|
|
24
|
+
// OnnxTensorLike is part of OnnxInferenceRuntime.run()'s signature, so it
// is re-exported for callers implementing a custom runtime.
export {
  OnnxInjectionScanner,
  loadOnnxClassifier,
  type OnnxClassifierConfig,
  type OnnxInferenceRuntime,
  type OnnxTensorLike,
  type Tokenizer,
} from "./classifier.js";
|