ai-shield-classifier-onnx 0.5.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,85 @@
1
+ import type { Scanner, ScannerResult, ScanContext } from "ai-shield-core";
2
+ /**
3
+ * Minimal subset of `onnxruntime-node`'s `InferenceSession` that the scanner uses.
4
+ * Declaring it locally means this package type-checks even when
5
+ * `onnxruntime-node` is not installed (which is the entire point —
6
+ * it's an optional peer dep). The scanner awaits `run()` with `input_ids`, `attention_mask` and (when the tokenizer provides them) `token_type_ids` feeds, and uses `outputNames[0]` as the output node when no `outputName` is configured.
7
+ */
8
+ export interface OnnxInferenceRuntime {
9
+ run(feeds: Record<string, OnnxTensorLike>): Promise<Record<string, OnnxTensorLike>>;
10
+ readonly inputNames?: readonly string[];
11
+ readonly outputNames?: readonly string[];
12
+ }
13
+ /** Tensor descriptor compatible with `onnxruntime-common`'s Tensor: `data` is the flat element buffer and `dims` its shape. The scanner sends its inputs as `BigInt64Array` data with dims `[1, sequenceLength]` and type `"int64"`. */
14
+ export interface OnnxTensorLike {
15
+ readonly data: ArrayLike<number> | BigInt64Array;
16
+ readonly dims: readonly number[];
17
+ readonly type?: string;
18
+ }
19
+ /**
20
+ * Tokenizer abstraction. Same trick as the runtime — we don't bind to
21
+ * any specific HF tokenizer package so the user can wire up
22
+ * `@huggingface/transformers`, `tokenizers`, or a hand-written one. `encode()` is called synchronously once per prediction; `token_type_ids` is forwarded to the model only when present.
23
+ */
24
+ export interface Tokenizer {
25
+ encode(text: string): {
26
+ input_ids: number[];
27
+ attention_mask: number[];
28
+ token_type_ids?: number[];
29
+ };
30
+ /** Optional max sequence length the tokenizer was trained for. Used as the truncation length when the scanner's `maxLength` is not configured. */
31
+ modelMaxLength?: number;
32
+ }
33
+ export interface OnnxClassifierConfig {
34
+ /** A pre-constructed inference session. */
35
+ session: OnnxInferenceRuntime;
36
+ /** Tokenizer matching the model. */
37
+ tokenizer: Tokenizer;
38
+ /**
39
+ * Probability threshold at or above which the input is blocked as
40
+ * injection. Default 0.85 (calibrated for protectai/deberta-v3-base-
41
+ * prompt-injection — adjust per model). Scores at or above 60% of the threshold produce a "warn" decision.
42
+ */
43
+ threshold?: number;
44
+ /**
45
+ * Name of the output node that carries the logits (softmax is always applied to it).
46
+ * Default: the session's first `outputNames` entry, otherwise the first key in the runtime's result map.
47
+ */
48
+ outputName?: string;
49
+ /**
50
+ * Index of the "injection" class in the model output.
51
+ * Default: 1 (binary classifier convention: 0 = SAFE, 1 = INJECTION).
52
+ */
53
+ injectionClassIndex?: number;
54
+ /**
55
+ * Maximum sequence length to feed. Default: the tokenizer's `modelMaxLength` when set, otherwise 512.
56
+ * Larger inputs are truncated head-only (start kept) which matches
57
+ * the standard DeBERTa fine-tune recipe.
58
+ */
59
+ maxLength?: number;
60
+ }
61
+ export declare class OnnxInjectionScanner implements Scanner {
62
+ readonly name = "onnx-classifier";
63
+ private readonly cfg;
64
+ constructor(config: OnnxClassifierConfig);
65
+ scan(input: string, _context: ScanContext): Promise<ScannerResult>;
66
+ /** Resolves to the softmax probability of the configured injection class for `input`; unlike `scan()`, tokenizer and inference errors propagate to the caller — useful for tests + custom flows. */
67
+ predict(input: string): Promise<number>;
68
+ }
69
+ /**
70
+ * Convenience loader. Imports `onnxruntime-node` *at runtime* so
71
+ * consumers who never call this function don't pay the install cost.
72
+ *
73
+ * Tokenizer loading is left to the caller because tokenizer
74
+ * implementations vary widely between models — we don't want to
75
+ * pin a specific HF tokenizer package. The returned promise rejects when `onnxruntime-node` cannot be imported or does not expose `InferenceSession.create`.
76
+ */
77
+ export declare function loadOnnxClassifier(opts: {
78
+ modelPath: string;
79
+ tokenizer: Tokenizer;
80
+ threshold?: number;
81
+ outputName?: string;
82
+ injectionClassIndex?: number;
83
+ maxLength?: number;
84
+ }): Promise<OnnxInjectionScanner>;
85
+ //# sourceMappingURL=classifier.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"classifier.d.ts","sourceRoot":"","sources":["../src/classifier.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EACV,OAAO,EACP,aAAa,EACb,WAAW,EAEZ,MAAM,gBAAgB,CAAC;AAqBxB;;;;;GAKG;AACH,MAAM,WAAW,oBAAoB;IACnC,GAAG,CACD,KAAK,EAAE,MAAM,CAAC,MAAM,EAAE,cAAc,CAAC,GACpC,OAAO,CAAC,MAAM,CAAC,MAAM,EAAE,cAAc,CAAC,CAAC,CAAC;IAC3C,QAAQ,CAAC,UAAU,CAAC,EAAE,SAAS,MAAM,EAAE,CAAC;IACxC,QAAQ,CAAC,WAAW,CAAC,EAAE,SAAS,MAAM,EAAE,CAAC;CAC1C;AAED,uEAAuE;AACvE,MAAM,WAAW,cAAc;IAC7B,QAAQ,CAAC,IAAI,EAAE,SAAS,CAAC,MAAM,CAAC,GAAG,aAAa,CAAC;IACjD,QAAQ,CAAC,IAAI,EAAE,SAAS,MAAM,EAAE,CAAC;IACjC,QAAQ,CAAC,IAAI,CAAC,EAAE,MAAM,CAAC;CACxB;AAED;;;;GAIG;AACH,MAAM,WAAW,SAAS;IACxB,MAAM,CAAC,IAAI,EAAE,MAAM,GAAG;QACpB,SAAS,EAAE,MAAM,EAAE,CAAC;QACpB,cAAc,EAAE,MAAM,EAAE,CAAC;QACzB,cAAc,CAAC,EAAE,MAAM,EAAE,CAAC;KAC3B,CAAC;IACF,kEAAkE;IAClE,cAAc,CAAC,EAAE,MAAM,CAAC;CACzB;AAED,MAAM,WAAW,oBAAoB;IACnC,2CAA2C;IAC3C,OAAO,EAAE,oBAAoB,CAAC;IAC9B,oCAAoC;IACpC,SAAS,EAAE,SAAS,CAAC;IACrB;;;;OAIG;IACH,SAAS,CAAC,EAAE,MAAM,CAAC;IACnB;;;OAGG;IACH,UAAU,CAAC,EAAE,MAAM,CAAC;IACpB;;;OAGG;IACH,mBAAmB,CAAC,EAAE,MAAM,CAAC;IAC7B;;;;OAIG;IACH,SAAS,CAAC,EAAE,MAAM,CAAC;CACpB;AAED,qBAAa,oBAAqB,YAAW,OAAO;IAClD,QAAQ,CAAC,IAAI,qBAAqB;IAClC,OAAO,CAAC,QAAQ,CAAC,GAAG,CAEQ;gBAEhB,MAAM,EAAE,oBAAoB;IAkBlC,IAAI,CAAC,KAAK,EAAE,MAAM,EAAE,QAAQ,EAAE,WAAW,GAAG,OAAO,CAAC,aAAa,CAAC;IAmExE,mEAAmE;IAC7D,OAAO,CAAC,KAAK,EAAE,MAAM,GAAG,OAAO,CAAC,MAAM,CAAC;CAkD9C;AAED;;;;;;;GAOG;AACH,wBAAsB,kBAAkB,CAAC,IAAI,EAAE;IAC7C,SAAS,EAAE,MAAM,CAAC;IAClB,SAAS,EAAE,SAAS,CAAC;IACrB,SAAS,CAAC,EAAE,MAAM,CAAC;IACnB,UAAU,CAAC,EAAE,MAAM,CAAC;IACpB,mBAAmB,CAAC,EAAE,MAAM,CAAC;IAC7B,SAAS,CAAC,EAAE,MAAM,CAAC;CACpB,GAAG,OAAO,CAAC,oBAAoB,CAAC,CA6BhC"}
@@ -0,0 +1,201 @@
1
+ export class OnnxInjectionScanner {
2
+ name = "onnx-classifier";
3
+ cfg;
4
+ constructor(config) {
5
+ if (!config.session) {
6
+ throw new TypeError("OnnxInjectionScanner: 'session' is required");
7
+ }
8
+ if (!config.tokenizer) {
9
+ throw new TypeError("OnnxInjectionScanner: 'tokenizer' is required");
10
+ }
11
+ this.cfg = {
12
+ session: config.session,
13
+ tokenizer: config.tokenizer,
14
+ threshold: config.threshold ?? 0.85,
15
+ outputName: config.outputName,
16
+ injectionClassIndex: config.injectionClassIndex ?? 1,
17
+ maxLength: config.maxLength ?? config.tokenizer.modelMaxLength ?? 512,
18
+ };
19
+ }
20
+ async scan(input, _context) {
21
+ const start = performance.now();
22
+ try {
23
+ const probability = await this.predict(input);
24
+ const violations = [];
25
+ let decision = "allow";
26
+ if (probability >= this.cfg.threshold) {
27
+ decision = "block";
28
+ violations.push({
29
+ type: "prompt_injection",
30
+ scanner: this.name,
31
+ score: probability,
32
+ threshold: this.cfg.threshold,
33
+ message: "ML classifier flagged prompt-injection",
34
+ detail: `p(injection)=${probability.toFixed(4)} threshold=${this.cfg.threshold}`,
35
+ });
36
+ }
37
+ else if (probability >= this.cfg.threshold * 0.6) {
38
+ decision = "warn";
39
+ violations.push({
40
+ type: "prompt_injection",
41
+ scanner: this.name,
42
+ score: probability,
43
+ threshold: this.cfg.threshold,
44
+ message: "ML classifier flagged borderline content",
45
+ detail: `p(injection)=${probability.toFixed(4)} threshold=${this.cfg.threshold}`,
46
+ });
47
+ }
48
+ return {
49
+ decision,
50
+ violations,
51
+ durationMs: performance.now() - start,
52
+ };
53
+ }
54
+ catch (err) {
55
+ // ML errors must not take down the entire chain — degrade gracefully
56
+ // to "allow" with a synthetic violation so the audit log shows
57
+ // something went wrong without blocking traffic.
58
+ //
59
+ // Critic H4 round 1 — the raw error message can contain file
60
+ // paths (model location), native library symbols, or deployment-
61
+ // internal strings. Strip absolute paths before they hit the audit
62
+ // log. In dev mode we keep more detail to aid debugging.
63
+ const rawMessage = err?.message ?? "unknown error";
64
+ const isDev = process.env.NODE_ENV === "development" ||
65
+ process.env.AI_SHIELD_DEBUG === "1";
66
+ const safeDetail = isDev
67
+ ? rawMessage
68
+ : sanitizeOnnxErrorMessage(rawMessage);
69
+ return {
70
+ decision: "allow",
71
+ violations: [
72
+ {
73
+ type: "content_policy",
74
+ scanner: this.name,
75
+ score: 0,
76
+ threshold: this.cfg.threshold,
77
+ message: "ML classifier failed — degraded to allow",
78
+ detail: safeDetail,
79
+ },
80
+ ],
81
+ durationMs: performance.now() - start,
82
+ };
83
+ }
84
+ }
85
+ /** Direct probability access — useful for tests + custom flows. */
86
+ async predict(input) {
87
+ const tokens = this.cfg.tokenizer.encode(input);
88
+ const trunc = truncate(tokens, this.cfg.maxLength);
89
+ const inputIds = BigInt64Array.from(trunc.input_ids.map((n) => BigInt(n)));
90
+ const attentionMask = BigInt64Array.from(trunc.attention_mask.map((n) => BigInt(n)));
91
+ const dims = [1, trunc.input_ids.length];
92
+ const feeds = {
93
+ input_ids: { data: inputIds, dims, type: "int64" },
94
+ attention_mask: { data: attentionMask, dims, type: "int64" },
95
+ };
96
+ if (trunc.token_type_ids) {
97
+ feeds.token_type_ids = {
98
+ data: BigInt64Array.from(trunc.token_type_ids.map((n) => BigInt(n))),
99
+ dims,
100
+ type: "int64",
101
+ };
102
+ }
103
+ const result = await this.cfg.session.run(feeds);
104
+ const outputName = this.cfg.outputName ??
105
+ this.cfg.session.outputNames?.[0] ??
106
+ Object.keys(result)[0];
107
+ if (!outputName) {
108
+ throw new Error("OnnxInjectionScanner: no output node available");
109
+ }
110
+ const tensor = result[outputName];
111
+ if (!tensor) {
112
+ throw new Error(`OnnxInjectionScanner: output '${outputName}' not in result`);
113
+ }
114
+ // Model emits logits of shape [1, num_classes]. Softmax + pick class.
115
+ const logits = Array.from(tensor.data).map((n) => Number(n));
116
+ const probs = softmax(logits);
117
+ const idx = this.cfg.injectionClassIndex;
118
+ if (idx < 0 || idx >= probs.length) {
119
+ throw new Error(`OnnxInjectionScanner: injectionClassIndex ${idx} out of range (len=${probs.length})`);
120
+ }
121
+ return probs[idx] ?? 0;
122
+ }
123
+ }
124
/**
 * Convenience loader. Imports `onnxruntime-node` *at runtime* so
 * consumers who never call this function don't pay the install cost.
 *
 * Tokenizer loading is left to the caller because tokenizer
 * implementations vary widely between models — we don't want to
 * pin a specific HF tokenizer package.
 *
 * @param opts `modelPath` (path handed to `InferenceSession.create`),
 *   `tokenizer`, and the optional scanner settings `threshold`,
 *   `outputName`, `injectionClassIndex`, `maxLength`.
 * @returns A scanner wired to a freshly created inference session.
 * @throws {Error} If `onnxruntime-node` cannot be imported or does not
 *   expose `InferenceSession.create`; session-creation errors propagate.
 */
export async function loadOnnxClassifier(opts) {
    // Dynamic import keeps the dependency optional.
    let ort;
    try {
        ort = await import("onnxruntime-node");
    }
    catch (err) {
        // The rejection value is not guaranteed to be an Error instance.
        const reason = err instanceof Error ? err.message : String(err);
        throw new Error("ai-shield-classifier-onnx: 'onnxruntime-node' is required " +
            "to call loadOnnxClassifier(). Install it as a peer dependency.\n" +
            `Underlying error: ${reason}`);
    }
    // When a CommonJS package is imported from ESM, its exports are always
    // available on `default`, while named exports depend on static analysis
    // of the package — so check both places.
    const InferenceSession = ort?.InferenceSession ?? ort?.default?.InferenceSession;
    if (typeof InferenceSession?.create !== "function") {
        throw new Error("ai-shield-classifier-onnx: 'onnxruntime-node' did not expose InferenceSession.create");
    }
    // Call it as a method so `this` stays bound to the class, in case the
    // runtime's implementation relies on it.
    const session = await InferenceSession.create(opts.modelPath);
    return new OnnxInjectionScanner({
        session,
        tokenizer: opts.tokenizer,
        threshold: opts.threshold,
        outputName: opts.outputName,
        injectionClassIndex: opts.injectionClassIndex,
        maxLength: opts.maxLength,
    });
}
159
+ // --- helpers ---
160
+ /**
161
+ * Strip absolute paths and other deployment-internal strings from an
162
+ * ONNX-runtime error message before it lands in the audit log. Keeps
163
+ * the short error class / cause hint that helps diagnose the failure.
164
+ */
165
+ function sanitizeOnnxErrorMessage(message) {
166
+ if (typeof message !== "string" || message.length === 0) {
167
+ return "classifier_runtime_error";
168
+ }
169
+ // Truncate before sanitizing — bounded work on adversarial input.
170
+ const truncated = message.length > 500 ? message.slice(0, 500) : message;
171
+ return truncated
172
+ // POSIX absolute paths
173
+ .replace(/(?:^|[\s(])(\/[\w./@-]+)/g, " [path]")
174
+ // Windows drive paths
175
+ .replace(/[A-Za-z]:\\[\\\w./@-]+/g, "[path]")
176
+ // file:// URLs
177
+ .replace(/file:\/\/\S+/g, "[file-url]")
178
+ // Memory addresses (0x...)
179
+ .replace(/0x[0-9a-fA-F]{6,}/g, "[addr]")
180
+ .trim();
181
+ }
182
+ function softmax(logits) {
183
+ if (logits.length === 0)
184
+ return [];
185
+ const max = Math.max(...logits);
186
+ const exps = logits.map((l) => Math.exp(l - max));
187
+ const sum = exps.reduce((a, b) => a + b, 0);
188
+ if (sum === 0)
189
+ return logits.map(() => 0);
190
+ return exps.map((e) => e / sum);
191
+ }
192
+ function truncate(tokens, maxLength) {
193
+ if (tokens.input_ids.length <= maxLength)
194
+ return tokens;
195
+ return {
196
+ input_ids: tokens.input_ids.slice(0, maxLength),
197
+ attention_mask: tokens.attention_mask.slice(0, maxLength),
198
+ token_type_ids: tokens.token_type_ids?.slice(0, maxLength),
199
+ };
200
+ }
201
+ //# sourceMappingURL=classifier.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"classifier.js","sourceRoot":"","sources":["../src/classifier.ts"],"names":[],"mappings":"AA2FA,MAAM,OAAO,oBAAoB;IACtB,IAAI,GAAG,iBAAiB,CAAC;IACjB,GAAG,CAEQ;IAE5B,YAAY,MAA4B;QACtC,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC;YACpB,MAAM,IAAI,SAAS,CAAC,6CAA6C,CAAC,CAAC;QACrE,CAAC;QACD,IAAI,CAAC,MAAM,CAAC,SAAS,EAAE,CAAC;YACtB,MAAM,IAAI,SAAS,CAAC,+CAA+C,CAAC,CAAC;QACvE,CAAC;QACD,IAAI,CAAC,GAAG,GAAG;YACT,OAAO,EAAE,MAAM,CAAC,OAAO;YACvB,SAAS,EAAE,MAAM,CAAC,SAAS;YAC3B,SAAS,EAAE,MAAM,CAAC,SAAS,IAAI,IAAI;YACnC,UAAU,EAAE,MAAM,CAAC,UAAU;YAC7B,mBAAmB,EAAE,MAAM,CAAC,mBAAmB,IAAI,CAAC;YACpD,SAAS,EACP,MAAM,CAAC,SAAS,IAAI,MAAM,CAAC,SAAS,CAAC,cAAc,IAAI,GAAG;SAC7D,CAAC;IACJ,CAAC;IAED,KAAK,CAAC,IAAI,CAAC,KAAa,EAAE,QAAqB;QAC7C,MAAM,KAAK,GAAG,WAAW,CAAC,GAAG,EAAE,CAAC;QAChC,IAAI,CAAC;YACH,MAAM,WAAW,GAAG,MAAM,IAAI,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;YAC9C,MAAM,UAAU,GAAgB,EAAE,CAAC;YACnC,IAAI,QAAQ,GAA8B,OAAO,CAAC;YAElD,IAAI,WAAW,IAAI,IAAI,CAAC,GAAG,CAAC,SAAS,EAAE,CAAC;gBACtC,QAAQ,GAAG,OAAO,CAAC;gBACnB,UAAU,CAAC,IAAI,CAAC;oBACd,IAAI,EAAE,kBAAkB;oBACxB,OAAO,EAAE,IAAI,CAAC,IAAI;oBAClB,KAAK,EAAE,WAAW;oBAClB,SAAS,EAAE,IAAI,CAAC,GAAG,CAAC,SAAS;oBAC7B,OAAO,EAAE,wCAAwC;oBACjD,MAAM,EAAE,gBAAgB,WAAW,CAAC,OAAO,CAAC,CAAC,CAAC,cAAc,IAAI,CAAC,GAAG,CAAC,SAAS,EAAE;iBACjF,CAAC,CAAC;YACL,CAAC;iBAAM,IAAI,WAAW,IAAI,IAAI,CAAC,GAAG,CAAC,SAAS,GAAG,GAAG,EAAE,CAAC;gBACnD,QAAQ,GAAG,MAAM,CAAC;gBAClB,UAAU,CAAC,IAAI,CAAC;oBACd,IAAI,EAAE,kBAAkB;oBACxB,OAAO,EAAE,IAAI,CAAC,IAAI;oBAClB,KAAK,EAAE,WAAW;oBAClB,SAAS,EAAE,IAAI,CAAC,GAAG,CAAC,SAAS;oBAC7B,OAAO,EAAE,0CAA0C;oBACnD,MAAM,EAAE,gBAAgB,WAAW,CAAC,OAAO,CAAC,CAAC,CAAC,cAAc,IAAI,CAAC,GAAG,CAAC,SAAS,EAAE;iBACjF,CAAC,CAAC;YACL,CAAC;YAED,OAAO;gBACL,QAAQ;gBACR,UAAU;gBACV,UAAU,EAAE,WAAW,CAAC,GAAG,EAAE,GAAG,KAAK;aACtC,CAAC;QACJ,CAAC;QAAC,OAAO,GAAG,EAAE,CAAC;YACb,qEAAqE;YACrE,+DAA+D;YAC/D,iDAAiD;YACjD,EAAE;YACF,6DAA6D;YAC7D,iEAAiE;YACjE,mEAAmE;YACnE,yDAAyD;YACzD,MAAM,UAAU,GAAI,GAAa,EAAE,OAAO,IAAI,eAAe,CAAC;YAC9D,MAAM,KAAK,GACT,OAAO,CAAC,GAAG,CAAC,QAAQ,KAAK,aAAa;gBACtC,OAAO,CA
AC,GAAG,CAAC,eAAe,KAAK,GAAG,CAAC;YACtC,MAAM,UAAU,GAAG,KAAK;gBACtB,CAAC,CAAC,UAAU;gBACZ,CAAC,CAAC,wBAAwB,CAAC,UAAU,CAAC,CAAC;YACzC,OAAO;gBACL,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE;oBACV;wBACE,IAAI,EAAE,gBAAgB;wBACtB,OAAO,EAAE,IAAI,CAAC,IAAI;wBAClB,KAAK,EAAE,CAAC;wBACR,SAAS,EAAE,IAAI,CAAC,GAAG,CAAC,SAAS;wBAC7B,OAAO,EAAE,0CAA0C;wBACnD,MAAM,EAAE,UAAU;qBACnB;iBACF;gBACD,UAAU,EAAE,WAAW,CAAC,GAAG,EAAE,GAAG,KAAK;aACtC,CAAC;QACJ,CAAC;IACH,CAAC;IAED,mEAAmE;IACnE,KAAK,CAAC,OAAO,CAAC,KAAa;QACzB,MAAM,MAAM,GAAG,IAAI,CAAC,GAAG,CAAC,SAAS,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;QAChD,MAAM,KAAK,GAAG,QAAQ,CAAC,MAAM,EAAE,IAAI,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC;QAEnD,MAAM,QAAQ,GAAG,aAAa,CAAC,IAAI,CAAC,KAAK,CAAC,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,MAAM,aAAa,GAAG,aAAa,CAAC,IAAI,CACtC,KAAK,CAAC,cAAc,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAC3C,CAAC;QACF,MAAM,IAAI,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC;QAEzC,MAAM,KAAK,GAAmC;YAC5C,SAAS,EAAE,EAAE,IAAI,EAAE,QAAQ,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE;YAClD,cAAc,EAAE,EAAE,IAAI,EAAE,aAAa,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE;SAC7D,CAAC;QACF,IAAI,KAAK,CAAC,cAAc,EAAE,CAAC;YACzB,KAAK,CAAC,cAAc,GAAG;gBACrB,IAAI,EAAE,aAAa,CAAC,IAAI,CAAC,KAAK,CAAC,cAAc,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;gBACpE,IAAI;gBACJ,IAAI,EAAE,OAAO;aACd,CAAC;QACJ,CAAC;QAED,MAAM,MAAM,GAAG,MAAM,IAAI,CAAC,GAAG,CAAC,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC;QACjD,MAAM,UAAU,GACd,IAAI,CAAC,GAAG,CAAC,UAAU;YACnB,IAAI,CAAC,GAAG,CAAC,OAAO,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC;YACjC,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;QACzB,IAAI,CAAC,UAAU,EAAE,CAAC;YAChB,MAAM,IAAI,KAAK,CAAC,gDAAgD,CAAC,CAAC;QACpE,CAAC;QACD,MAAM,MAAM,GAAG,MAAM,CAAC,UAAU,CAAC,CAAC;QAClC,IAAI,CAAC,MAAM,EAAE,CAAC;YACZ,MAAM,IAAI,KAAK,CACb,iCAAiC,UAAU,iBAAiB,CAC7D,CAAC;QACJ,CAAC;QAED,sEAAsE;QACtE,MAAM,MAAM,GAAG,KAAK,CAAC,IAAI,CAAC,MAAM,CAAC,IAAyB,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CACpE,MAAM,CAAC,CAAC,CAAC,CACV,CAAC;QACF,MAAM,KAA
K,GAAG,OAAO,CAAC,MAAM,CAAC,CAAC;QAC9B,MAAM,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,mBAAmB,CAAC;QACzC,IAAI,GAAG,GAAG,CAAC,IAAI,GAAG,IAAI,KAAK,CAAC,MAAM,EAAE,CAAC;YACnC,MAAM,IAAI,KAAK,CACb,6CAA6C,GAAG,sBAAsB,KAAK,CAAC,MAAM,GAAG,CACtF,CAAC;QACJ,CAAC;QACD,OAAO,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;IACzB,CAAC;CACF;AAED;;;;;;;GAOG;AACH,MAAM,CAAC,KAAK,UAAU,kBAAkB,CAAC,IAOxC;IACC,kEAAkE;IAClE,6CAA6C;IAC7C,IAAI,GAAY,CAAC;IACjB,IAAI,CAAC;QACH,GAAG,GAAG,MAAM,MAAM,CAAC,kBAA4B,CAAC,CAAC;IACnD,CAAC;IAAC,OAAO,GAAG,EAAE,CAAC;QACb,MAAM,IAAI,KAAK,CACb,4DAA4D;YAC1D,kEAAkE;YAClE,qBAAsB,GAAa,CAAC,OAAO,EAAE,CAChD,CAAC;IACJ,CAAC;IACD,MAAM,SAAS,GAAI,GAA2F,CAAC;IAC/G,MAAM,MAAM,GAAG,SAAS,CAAC,gBAAgB,EAAE,MAAM,CAAC;IAClD,IAAI,OAAO,MAAM,KAAK,UAAU,EAAE,CAAC;QACjC,MAAM,IAAI,KAAK,CACb,sFAAsF,CACvF,CAAC;IACJ,CAAC;IACD,MAAM,OAAO,GAAG,MAAM,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IAC7C,OAAO,IAAI,oBAAoB,CAAC;QAC9B,OAAO;QACP,SAAS,EAAE,IAAI,CAAC,SAAS;QACzB,SAAS,EAAE,IAAI,CAAC,SAAS;QACzB,UAAU,EAAE,IAAI,CAAC,UAAU;QAC3B,mBAAmB,EAAE,IAAI,CAAC,mBAAmB;QAC7C,SAAS,EAAE,IAAI,CAAC,SAAS;KAC1B,CAAC,CAAC;AACL,CAAC;AAED,kBAAkB;AAElB;;;;GAIG;AACH,SAAS,wBAAwB,CAAC,OAAe;IAC/C,IAAI,OAAO,OAAO,KAAK,QAAQ,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE,CAAC;QACxD,OAAO,0BAA0B,CAAC;IACpC,CAAC;IACD,kEAAkE;IAClE,MAAM,SAAS,GAAG,OAAO,CAAC,MAAM,GAAG,GAAG,CAAC,CAAC,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC;IACzE,OAAO,SAAS;QACd,uBAAuB;SACtB,OAAO,CAAC,2BAA2B,EAAE,SAAS,CAAC;QAChD,sBAAsB;SACrB,OAAO,CAAC,yBAAyB,EAAE,QAAQ,CAAC;QAC7C,eAAe;SACd,OAAO,CAAC,eAAe,EAAE,YAAY,CAAC;QACvC,2BAA2B;SAC1B,OAAO,CAAC,oBAAoB,EAAE,QAAQ,CAAC;SACvC,IAAI,EAAE,CAAC;AACZ,CAAC;AAED,SAAS,OAAO,CAAC,MAAgB;IAC/B,IAAI,MAAM,CAAC,MAAM,KAAK,CAAC;QAAE,OAAO,EAAE,CAAC;IACnC,MAAM,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,MAAM,CAAC,CAAC;IAChC,MAAM,IAAI,GAAG,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC;IAClD,MAAM,GAAG,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAC;IAC5C,IAAI,GAAG,KAAK,CAAC;
QAAE,OAAO,MAAM,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CAAC;IAC1C,OAAO,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC;AAClC,CAAC;AAED,SAAS,QAAQ,CACf,MAAuC,EACvC,SAAiB;IAEjB,IAAI,MAAM,CAAC,SAAS,CAAC,MAAM,IAAI,SAAS;QAAE,OAAO,MAAM,CAAC;IACxD,OAAO;QACL,SAAS,EAAE,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,EAAE,SAAS,CAAC;QAC/C,cAAc,EAAE,MAAM,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,EAAE,SAAS,CAAC;QACzD,cAAc,EAAE,MAAM,CAAC,cAAc,EAAE,KAAK,CAAC,CAAC,EAAE,SAAS,CAAC;KAC3D,CAAC;AACJ,CAAC"}
@@ -0,0 +1,2 @@
1
+ export { OnnxInjectionScanner, loadOnnxClassifier, type OnnxClassifierConfig, type OnnxInferenceRuntime, type Tokenizer, } from "./classifier.js";
2
+ //# sourceMappingURL=index.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAuBA,OAAO,EACL,oBAAoB,EACpB,kBAAkB,EAClB,KAAK,oBAAoB,EACzB,KAAK,oBAAoB,EACzB,KAAK,SAAS,GACf,MAAM,iBAAiB,CAAC"}
@@ -20,11 +20,5 @@
20
20
  // + tokenizer JSON, the wrapper handles inference + thresholding +
21
21
  // integration with the Scanner interface.
22
22
  // ============================================================
23
-
24
- export {
25
- OnnxInjectionScanner,
26
- loadOnnxClassifier,
27
- type OnnxClassifierConfig,
28
- type OnnxInferenceRuntime,
29
- type Tokenizer,
30
- } from "./classifier.js";
23
+ export { OnnxInjectionScanner, loadOnnxClassifier, } from "./classifier.js";
24
+ //# sourceMappingURL=index.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,+DAA+D;AAC/D,6DAA6D;AAC7D,EAAE;AACF,yDAAyD;AACzD,kEAAkE;AAClE,gEAAgE;AAChE,EAAE;AACF,0BAA0B;AAC1B,yEAAyE;AACzE,mEAAmE;AACnE,uEAAuE;AACvE,gEAAgE;AAChE,EAAE;AACF,sBAAsB;AACtB,8DAA8D;AAC9D,oDAAoD;AACpD,mDAAmD;AACnD,EAAE;AACF,uEAAuE;AACvE,mEAAmE;AACnE,0CAA0C;AAC1C,+DAA+D;AAE/D,OAAO,EACL,oBAAoB,EACpB,kBAAkB,GAInB,MAAM,iBAAiB,CAAC"}
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "ai-shield-classifier-onnx",
3
- "version": "0.5.0",
3
+ "version": "0.5.1",
4
4
  "license": "MIT",
5
5
  "description": "Optional ONNX ML classifier for ai-shield — DeBERTa-style prompt-injection detection alongside heuristic patterns",
6
6
  "author": "StudioMeyer <hello@studiomeyer.io>",
@@ -24,6 +24,9 @@
24
24
  "ai-shield"
25
25
  ],
26
26
  "type": "module",
27
+ "files": [
28
+ "dist"
29
+ ],
27
30
  "main": "./dist/index.js",
28
31
  "types": "./dist/index.d.ts",
29
32
  "exports": {
package/src/classifier.ts DELETED
@@ -1,326 +0,0 @@
1
- import type {
2
- Scanner,
3
- ScannerResult,
4
- ScanContext,
5
- Violation,
6
- } from "ai-shield-core";
7
-
8
- // ============================================================
9
- // OnnxInjectionScanner — ML-backed prompt-injection detection
10
- //
11
- // Implements the `Scanner` interface from ai-shield-core. Designed to
12
- // be added to a `ScannerChain` after the heuristic scanner so that
13
- // known patterns short-circuit on cheap regex AND novel paraphrases
14
- // still get a second-pass ML check.
15
- //
16
- // The runtime is abstracted via `OnnxInferenceRuntime` so this file
17
- // has zero hard dependency on `onnxruntime-node` at type-check time.
18
- // At runtime you pass in either:
19
- // 1. An already-constructed `ort.InferenceSession`
20
- // (`new OnnxInjectionScanner({ session, tokenizer })`)
21
- // 2. A path to a `.onnx` model file + a path to a `tokenizer.json`
22
- // (`await loadOnnxClassifier({ modelPath, tokenizerPath })`)
23
- //
24
- // Both paths keep the dep injection clean and unit-testable.
25
- // ============================================================
26
-
27
- /**
28
- * Minimal subset of `onnxruntime-node`'s `InferenceSession` we use.
29
- * Declaring it locally means this package type-checks even when
30
- * `onnxruntime-node` is not installed (which is the entire point —
31
- * it's an optional peer dep).
32
- */
33
- export interface OnnxInferenceRuntime {
34
- run(
35
- feeds: Record<string, OnnxTensorLike>,
36
- ): Promise<Record<string, OnnxTensorLike>>;
37
- readonly inputNames?: readonly string[];
38
- readonly outputNames?: readonly string[];
39
- }
40
-
41
- /** Tensor descriptor compatible with `onnxruntime-common`'s Tensor. */
42
- export interface OnnxTensorLike {
43
- readonly data: ArrayLike<number> | BigInt64Array;
44
- readonly dims: readonly number[];
45
- readonly type?: string;
46
- }
47
-
48
- /**
49
- * Tokenizer abstraction. Same trick as the runtime — we don't bind to
50
- * any specific HF tokenizer package so the user can wire up
51
- * `@huggingface/transformers`, `tokenizers`, or a hand-written one.
52
- */
53
- export interface Tokenizer {
54
- encode(text: string): {
55
- input_ids: number[];
56
- attention_mask: number[];
57
- token_type_ids?: number[];
58
- };
59
- /** Optional max sequence length the tokenizer was trained for. */
60
- modelMaxLength?: number;
61
- }
62
-
63
- export interface OnnxClassifierConfig {
64
- /** A pre-constructed inference session. */
65
- session: OnnxInferenceRuntime;
66
- /** Tokenizer matching the model. */
67
- tokenizer: Tokenizer;
68
- /**
69
- * Probability threshold above which the input is flagged as
70
- * injection. Default 0.85 (calibrated for protectai/deberta-v3-base-
71
- * prompt-injection — adjust per model).
72
- */
73
- threshold?: number;
74
- /**
75
- * Name of the output node that carries logits/probabilities.
76
- * Default: first key in the runtime's result map.
77
- */
78
- outputName?: string;
79
- /**
80
- * Index of the "injection" class in the model output.
81
- * Default: 1 (binary classifier convention: 0 = SAFE, 1 = INJECTION).
82
- */
83
- injectionClassIndex?: number;
84
- /**
85
- * Maximum sequence length to feed. Default: 512.
86
- * Larger inputs are truncated head-only (start kept) which matches
87
- * the standard DeBERTa fine-tune recipe.
88
- */
89
- maxLength?: number;
90
- }
91
-
92
- export class OnnxInjectionScanner implements Scanner {
93
- readonly name = "onnx-classifier";
94
- private readonly cfg: Required<
95
- Omit<OnnxClassifierConfig, "outputName">
96
- > & { outputName?: string };
97
-
98
- constructor(config: OnnxClassifierConfig) {
99
- if (!config.session) {
100
- throw new TypeError("OnnxInjectionScanner: 'session' is required");
101
- }
102
- if (!config.tokenizer) {
103
- throw new TypeError("OnnxInjectionScanner: 'tokenizer' is required");
104
- }
105
- this.cfg = {
106
- session: config.session,
107
- tokenizer: config.tokenizer,
108
- threshold: config.threshold ?? 0.85,
109
- outputName: config.outputName,
110
- injectionClassIndex: config.injectionClassIndex ?? 1,
111
- maxLength:
112
- config.maxLength ?? config.tokenizer.modelMaxLength ?? 512,
113
- };
114
- }
115
-
116
- async scan(input: string, _context: ScanContext): Promise<ScannerResult> {
117
- const start = performance.now();
118
- try {
119
- const probability = await this.predict(input);
120
- const violations: Violation[] = [];
121
- let decision: ScannerResult["decision"] = "allow";
122
-
123
- if (probability >= this.cfg.threshold) {
124
- decision = "block";
125
- violations.push({
126
- type: "prompt_injection",
127
- scanner: this.name,
128
- score: probability,
129
- threshold: this.cfg.threshold,
130
- message: "ML classifier flagged prompt-injection",
131
- detail: `p(injection)=${probability.toFixed(4)} threshold=${this.cfg.threshold}`,
132
- });
133
- } else if (probability >= this.cfg.threshold * 0.6) {
134
- decision = "warn";
135
- violations.push({
136
- type: "prompt_injection",
137
- scanner: this.name,
138
- score: probability,
139
- threshold: this.cfg.threshold,
140
- message: "ML classifier flagged borderline content",
141
- detail: `p(injection)=${probability.toFixed(4)} threshold=${this.cfg.threshold}`,
142
- });
143
- }
144
-
145
- return {
146
- decision,
147
- violations,
148
- durationMs: performance.now() - start,
149
- };
150
- } catch (err) {
151
- // ML errors must not take down the entire chain — degrade gracefully
152
- // to "allow" with a synthetic violation so the audit log shows
153
- // something went wrong without blocking traffic.
154
- //
155
- // Critic H4 round 1 — the raw error message can contain file
156
- // paths (model location), native library symbols, or deployment-
157
- // internal strings. Strip absolute paths before they hit the audit
158
- // log. In dev mode we keep more detail to aid debugging.
159
- const rawMessage = (err as Error)?.message ?? "unknown error";
160
- const isDev =
161
- process.env.NODE_ENV === "development" ||
162
- process.env.AI_SHIELD_DEBUG === "1";
163
- const safeDetail = isDev
164
- ? rawMessage
165
- : sanitizeOnnxErrorMessage(rawMessage);
166
- return {
167
- decision: "allow",
168
- violations: [
169
- {
170
- type: "content_policy",
171
- scanner: this.name,
172
- score: 0,
173
- threshold: this.cfg.threshold,
174
- message: "ML classifier failed — degraded to allow",
175
- detail: safeDetail,
176
- },
177
- ],
178
- durationMs: performance.now() - start,
179
- };
180
- }
181
- }
182
-
183
- /** Direct probability access — useful for tests + custom flows. */
184
- async predict(input: string): Promise<number> {
185
- const tokens = this.cfg.tokenizer.encode(input);
186
- const trunc = truncate(tokens, this.cfg.maxLength);
187
-
188
- const inputIds = BigInt64Array.from(trunc.input_ids.map((n) => BigInt(n)));
189
- const attentionMask = BigInt64Array.from(
190
- trunc.attention_mask.map((n) => BigInt(n)),
191
- );
192
- const dims = [1, trunc.input_ids.length];
193
-
194
- const feeds: Record<string, OnnxTensorLike> = {
195
- input_ids: { data: inputIds, dims, type: "int64" },
196
- attention_mask: { data: attentionMask, dims, type: "int64" },
197
- };
198
- if (trunc.token_type_ids) {
199
- feeds.token_type_ids = {
200
- data: BigInt64Array.from(trunc.token_type_ids.map((n) => BigInt(n))),
201
- dims,
202
- type: "int64",
203
- };
204
- }
205
-
206
- const result = await this.cfg.session.run(feeds);
207
- const outputName =
208
- this.cfg.outputName ??
209
- this.cfg.session.outputNames?.[0] ??
210
- Object.keys(result)[0];
211
- if (!outputName) {
212
- throw new Error("OnnxInjectionScanner: no output node available");
213
- }
214
- const tensor = result[outputName];
215
- if (!tensor) {
216
- throw new Error(
217
- `OnnxInjectionScanner: output '${outputName}' not in result`,
218
- );
219
- }
220
-
221
- // Model emits logits of shape [1, num_classes]. Softmax + pick class.
222
- const logits = Array.from(tensor.data as ArrayLike<number>).map((n) =>
223
- Number(n),
224
- );
225
- const probs = softmax(logits);
226
- const idx = this.cfg.injectionClassIndex;
227
- if (idx < 0 || idx >= probs.length) {
228
- throw new Error(
229
- `OnnxInjectionScanner: injectionClassIndex ${idx} out of range (len=${probs.length})`,
230
- );
231
- }
232
- return probs[idx] ?? 0;
233
- }
234
- }
235
-
236
- /**
237
- * Convenience loader. Imports `onnxruntime-node` *at runtime* so
238
- * consumers who never call this function don't pay the install cost.
239
- *
240
- * Tokenizer loading is left to the caller because tokenizer
241
- * implementations vary widely between models — we don't want to
242
- * pin a specific HF tokenizer package.
243
- */
244
- export async function loadOnnxClassifier(opts: {
245
- modelPath: string;
246
- tokenizer: Tokenizer;
247
- threshold?: number;
248
- outputName?: string;
249
- injectionClassIndex?: number;
250
- maxLength?: number;
251
- }): Promise<OnnxInjectionScanner> {
252
- // Dynamic import keeps the dep optional — TypeScript can't see it
253
- // at compile time, which is the whole point.
254
- let ort: unknown;
255
- try {
256
- ort = await import("onnxruntime-node" as string);
257
- } catch (err) {
258
- throw new Error(
259
- "ai-shield-classifier-onnx: 'onnxruntime-node' is required " +
260
- "to call loadOnnxClassifier(). Install it as a peer dependency.\n" +
261
- `Underlying error: ${(err as Error).message}`,
262
- );
263
- }
264
- const ortModule = (ort as { InferenceSession?: { create?: (path: string) => Promise<OnnxInferenceRuntime> } });
265
- const create = ortModule.InferenceSession?.create;
266
- if (typeof create !== "function") {
267
- throw new Error(
268
- "ai-shield-classifier-onnx: 'onnxruntime-node' did not expose InferenceSession.create",
269
- );
270
- }
271
- const session = await create(opts.modelPath);
272
- return new OnnxInjectionScanner({
273
- session,
274
- tokenizer: opts.tokenizer,
275
- threshold: opts.threshold,
276
- outputName: opts.outputName,
277
- injectionClassIndex: opts.injectionClassIndex,
278
- maxLength: opts.maxLength,
279
- });
280
- }
281
-
282
- // --- helpers ---
283
-
284
- /**
285
- * Strip absolute paths and other deployment-internal strings from an
286
- * ONNX-runtime error message before it lands in the audit log. Keeps
287
- * the short error class / cause hint that helps diagnose the failure.
288
- */
289
- function sanitizeOnnxErrorMessage(message: string): string {
290
- if (typeof message !== "string" || message.length === 0) {
291
- return "classifier_runtime_error";
292
- }
293
- // Truncate before sanitizing — bounded work on adversarial input.
294
- const truncated = message.length > 500 ? message.slice(0, 500) : message;
295
- return truncated
296
- // POSIX absolute paths
297
- .replace(/(?:^|[\s(])(\/[\w./@-]+)/g, " [path]")
298
- // Windows drive paths
299
- .replace(/[A-Za-z]:\\[\\\w./@-]+/g, "[path]")
300
- // file:// URLs
301
- .replace(/file:\/\/\S+/g, "[file-url]")
302
- // Memory addresses (0x...)
303
- .replace(/0x[0-9a-fA-F]{6,}/g, "[addr]")
304
- .trim();
305
- }
306
-
307
- function softmax(logits: number[]): number[] {
308
- if (logits.length === 0) return [];
309
- const max = Math.max(...logits);
310
- const exps = logits.map((l) => Math.exp(l - max));
311
- const sum = exps.reduce((a, b) => a + b, 0);
312
- if (sum === 0) return logits.map(() => 0);
313
- return exps.map((e) => e / sum);
314
- }
315
-
316
- function truncate(
317
- tokens: ReturnType<Tokenizer["encode"]>,
318
- maxLength: number,
319
- ): ReturnType<Tokenizer["encode"]> {
320
- if (tokens.input_ids.length <= maxLength) return tokens;
321
- return {
322
- input_ids: tokens.input_ids.slice(0, maxLength),
323
- attention_mask: tokens.attention_mask.slice(0, maxLength),
324
- token_type_ids: tokens.token_type_ids?.slice(0, maxLength),
325
- };
326
- }
package/tsconfig.json DELETED
@@ -1,12 +0,0 @@
1
- {
2
- "extends": "../../tsconfig.json",
3
- "compilerOptions": {
4
- "outDir": "./dist",
5
- "rootDir": "./src",
6
- "composite": true
7
- },
8
- "include": ["src/**/*"],
9
- "references": [
10
- { "path": "../core" }
11
- ]
12
- }