ai-shield-classifier-onnx 0.4.0 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/classifier.d.ts +85 -0
- package/dist/classifier.d.ts.map +1 -0
- package/dist/classifier.js +201 -0
- package/dist/classifier.js.map +1 -0
- package/dist/index.d.ts +2 -0
- package/dist/index.d.ts.map +1 -0
- package/{src/index.ts → dist/index.js} +2 -8
- package/dist/index.js.map +1 -0
- package/package.json +7 -4
- package/src/classifier.ts +0 -326
- package/tsconfig.json +0 -12
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
import type { Scanner, ScannerResult, ScanContext } from "ai-shield-core";
|
|
2
|
+
/**
|
|
3
|
+
* Minimal subset of `onnxruntime-node`'s `InferenceSession` we use.
|
|
4
|
+
* Declaring it locally means this package type-checks even when
|
|
5
|
+
* `onnxruntime-node` is not installed (which is the entire point —
|
|
6
|
+
* it's an optional peer dep).
|
|
7
|
+
*/
|
|
8
|
+
export interface OnnxInferenceRuntime {
|
|
9
|
+
run(feeds: Record<string, OnnxTensorLike>): Promise<Record<string, OnnxTensorLike>>;
|
|
10
|
+
readonly inputNames?: readonly string[];
|
|
11
|
+
readonly outputNames?: readonly string[];
|
|
12
|
+
}
|
|
13
|
+
/** Tensor descriptor compatible with `onnxruntime-common`'s Tensor. */
|
|
14
|
+
export interface OnnxTensorLike {
|
|
15
|
+
readonly data: ArrayLike<number> | BigInt64Array;
|
|
16
|
+
readonly dims: readonly number[];
|
|
17
|
+
readonly type?: string;
|
|
18
|
+
}
|
|
19
|
+
/**
|
|
20
|
+
* Tokenizer abstraction. Same trick as the runtime — we don't bind to
|
|
21
|
+
* any specific HF tokenizer package so the user can wire up
|
|
22
|
+
* `@huggingface/transformers`, `tokenizers`, or a hand-written one.
|
|
23
|
+
*/
|
|
24
|
+
export interface Tokenizer {
|
|
25
|
+
encode(text: string): {
|
|
26
|
+
input_ids: number[];
|
|
27
|
+
attention_mask: number[];
|
|
28
|
+
token_type_ids?: number[];
|
|
29
|
+
};
|
|
30
|
+
/** Optional max sequence length the tokenizer was trained for. */
|
|
31
|
+
modelMaxLength?: number;
|
|
32
|
+
}
|
|
33
|
+
export interface OnnxClassifierConfig {
|
|
34
|
+
/** A pre-constructed inference session. */
|
|
35
|
+
session: OnnxInferenceRuntime;
|
|
36
|
+
/** Tokenizer matching the model. */
|
|
37
|
+
tokenizer: Tokenizer;
|
|
38
|
+
/**
|
|
39
|
+
* Probability threshold at or above which the input is flagged as
|
|
40
|
+
* injection. Default 0.85 (calibrated for protectai/deberta-v3-base-
|
|
41
|
+
* prompt-injection — adjust per model).
|
|
42
|
+
*/
|
|
43
|
+
threshold?: number;
|
|
44
|
+
/**
|
|
45
|
+
* Name of the output node that carries logits/probabilities.
|
|
46
|
+
* Default: the session's first `outputNames` entry, else the first key in the runtime's result map.
|
|
47
|
+
*/
|
|
48
|
+
outputName?: string;
|
|
49
|
+
/**
|
|
50
|
+
* Index of the "injection" class in the model output.
|
|
51
|
+
* Default: 1 (binary classifier convention: 0 = SAFE, 1 = INJECTION).
|
|
52
|
+
*/
|
|
53
|
+
injectionClassIndex?: number;
|
|
54
|
+
/**
|
|
55
|
+
* Maximum sequence length to feed. Default: the tokenizer's `modelMaxLength`, else 512.
|
|
56
|
+
* Larger inputs are truncated head-only (start kept) which matches
|
|
57
|
+
* the standard DeBERTa fine-tune recipe.
|
|
58
|
+
*/
|
|
59
|
+
maxLength?: number;
|
|
60
|
+
}
|
|
61
|
+
export declare class OnnxInjectionScanner implements Scanner {
|
|
62
|
+
readonly name = "onnx-classifier";
|
|
63
|
+
private readonly cfg;
|
|
64
|
+
constructor(config: OnnxClassifierConfig);
|
|
65
|
+
scan(input: string, _context: ScanContext): Promise<ScannerResult>;
|
|
66
|
+
/** Direct probability access — useful for tests + custom flows. */
|
|
67
|
+
predict(input: string): Promise<number>;
|
|
68
|
+
}
|
|
69
|
+
/**
|
|
70
|
+
* Convenience loader. Imports `onnxruntime-node` *at runtime* so
|
|
71
|
+
* consumers who never call this function don't pay the install cost.
|
|
72
|
+
*
|
|
73
|
+
* Tokenizer loading is left to the caller because tokenizer
|
|
74
|
+
* implementations vary widely between models — we don't want to
|
|
75
|
+
* pin a specific HF tokenizer package.
|
|
76
|
+
*/
|
|
77
|
+
export declare function loadOnnxClassifier(opts: {
|
|
78
|
+
modelPath: string;
|
|
79
|
+
tokenizer: Tokenizer;
|
|
80
|
+
threshold?: number;
|
|
81
|
+
outputName?: string;
|
|
82
|
+
injectionClassIndex?: number;
|
|
83
|
+
maxLength?: number;
|
|
84
|
+
}): Promise<OnnxInjectionScanner>;
|
|
85
|
+
//# sourceMappingURL=classifier.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"classifier.d.ts","sourceRoot":"","sources":["../src/classifier.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EACV,OAAO,EACP,aAAa,EACb,WAAW,EAEZ,MAAM,gBAAgB,CAAC;AAqBxB;;;;;GAKG;AACH,MAAM,WAAW,oBAAoB;IACnC,GAAG,CACD,KAAK,EAAE,MAAM,CAAC,MAAM,EAAE,cAAc,CAAC,GACpC,OAAO,CAAC,MAAM,CAAC,MAAM,EAAE,cAAc,CAAC,CAAC,CAAC;IAC3C,QAAQ,CAAC,UAAU,CAAC,EAAE,SAAS,MAAM,EAAE,CAAC;IACxC,QAAQ,CAAC,WAAW,CAAC,EAAE,SAAS,MAAM,EAAE,CAAC;CAC1C;AAED,uEAAuE;AACvE,MAAM,WAAW,cAAc;IAC7B,QAAQ,CAAC,IAAI,EAAE,SAAS,CAAC,MAAM,CAAC,GAAG,aAAa,CAAC;IACjD,QAAQ,CAAC,IAAI,EAAE,SAAS,MAAM,EAAE,CAAC;IACjC,QAAQ,CAAC,IAAI,CAAC,EAAE,MAAM,CAAC;CACxB;AAED;;;;GAIG;AACH,MAAM,WAAW,SAAS;IACxB,MAAM,CAAC,IAAI,EAAE,MAAM,GAAG;QACpB,SAAS,EAAE,MAAM,EAAE,CAAC;QACpB,cAAc,EAAE,MAAM,EAAE,CAAC;QACzB,cAAc,CAAC,EAAE,MAAM,EAAE,CAAC;KAC3B,CAAC;IACF,kEAAkE;IAClE,cAAc,CAAC,EAAE,MAAM,CAAC;CACzB;AAED,MAAM,WAAW,oBAAoB;IACnC,2CAA2C;IAC3C,OAAO,EAAE,oBAAoB,CAAC;IAC9B,oCAAoC;IACpC,SAAS,EAAE,SAAS,CAAC;IACrB;;;;OAIG;IACH,SAAS,CAAC,EAAE,MAAM,CAAC;IACnB;;;OAGG;IACH,UAAU,CAAC,EAAE,MAAM,CAAC;IACpB;;;OAGG;IACH,mBAAmB,CAAC,EAAE,MAAM,CAAC;IAC7B;;;;OAIG;IACH,SAAS,CAAC,EAAE,MAAM,CAAC;CACpB;AAED,qBAAa,oBAAqB,YAAW,OAAO;IAClD,QAAQ,CAAC,IAAI,qBAAqB;IAClC,OAAO,CAAC,QAAQ,CAAC,GAAG,CAEQ;gBAEhB,MAAM,EAAE,oBAAoB;IAkBlC,IAAI,CAAC,KAAK,EAAE,MAAM,EAAE,QAAQ,EAAE,WAAW,GAAG,OAAO,CAAC,aAAa,CAAC;IAmExE,mEAAmE;IAC7D,OAAO,CAAC,KAAK,EAAE,MAAM,GAAG,OAAO,CAAC,MAAM,CAAC;CAkD9C;AAED;;;;;;;GAOG;AACH,wBAAsB,kBAAkB,CAAC,IAAI,EAAE;IAC7C,SAAS,EAAE,MAAM,CAAC;IAClB,SAAS,EAAE,SAAS,CAAC;IACrB,SAAS,CAAC,EAAE,MAAM,CAAC;IACnB,UAAU,CAAC,EAAE,MAAM,CAAC;IACpB,mBAAmB,CAAC,EAAE,MAAM,CAAC;IAC7B,SAAS,CAAC,EAAE,MAAM,CAAC;CACpB,GAAG,OAAO,CAAC,oBAAoB,CAAC,CA6BhC"}
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
/**
 * ML-backed prompt-injection scanner.
 *
 * Tokenizes the input, runs it through the injected inference session,
 * applies a softmax to the returned values and compares the probability of
 * the configured "injection" class against the configured threshold.
 */
export class OnnxInjectionScanner {
    name = "onnx-classifier";
    /** Resolved configuration with defaults applied by the constructor. */
    cfg;
    /**
     * @param config Session, tokenizer and optional tuning knobs.
     * @throws {TypeError} When `session` or `tokenizer` is missing.
     */
    constructor(config) {
        if (!config.session) {
            throw new TypeError("OnnxInjectionScanner: 'session' is required");
        }
        if (!config.tokenizer) {
            throw new TypeError("OnnxInjectionScanner: 'tokenizer' is required");
        }
        this.cfg = {
            session: config.session,
            tokenizer: config.tokenizer,
            threshold: config.threshold ?? 0.85,
            outputName: config.outputName,
            injectionClassIndex: config.injectionClassIndex ?? 1,
            // Explicit setting wins, then the tokenizer's own limit, then 512.
            maxLength: config.maxLength ?? config.tokenizer.modelMaxLength ?? 512,
        };
    }
    /**
     * Classify `input` and map the probability onto a scanner decision.
     *
     * - probability >= threshold        -> "block"
     * - probability >= threshold * 0.6  -> "warn"
     * - otherwise                       -> "allow" with no violations
     *
     * Any error raised while predicting is converted into an "allow" result
     * carrying a synthetic `content_policy` violation, so a classifier
     * failure is visible to the caller without rejecting the promise.
     *
     * @param input Text to classify.
     * @param _context Unused by this scanner.
     * @returns Decision, violations and elapsed time in milliseconds.
     */
    async scan(input, _context) {
        const start = performance.now();
        try {
            const probability = await this.predict(input);
            const threshold = this.cfg.threshold;
            const violations = [];
            let decision = "allow";
            let message;
            if (probability >= threshold) {
                decision = "block";
                message = "ML classifier flagged prompt-injection";
            }
            else if (probability >= threshold * 0.6) {
                decision = "warn";
                message = "ML classifier flagged borderline content";
            }
            if (message !== undefined) {
                violations.push({
                    type: "prompt_injection",
                    scanner: this.name,
                    score: probability,
                    threshold,
                    message,
                    detail: `p(injection)=${probability.toFixed(4)} threshold=${threshold}`,
                });
            }
            return {
                decision,
                violations,
                durationMs: performance.now() - start,
            };
        }
        catch (err) {
            // The raw message may contain file paths or other
            // deployment-internal strings, so it is sanitized unless a
            // debug/development flag is set.
            const rawMessage = err?.message ?? "unknown error";
            // `process` does not exist in every JavaScript runtime; touching
            // it unguarded would throw from inside this fallback path.
            const env = typeof process !== "undefined" ? process.env : undefined;
            const isDev = env?.NODE_ENV === "development" ||
                env?.AI_SHIELD_DEBUG === "1";
            const safeDetail = isDev
                ? rawMessage
                : sanitizeOnnxErrorMessage(rawMessage);
            return {
                decision: "allow",
                violations: [
                    {
                        type: "content_policy",
                        scanner: this.name,
                        score: 0,
                        threshold: this.cfg.threshold,
                        message: "ML classifier failed — degraded to allow",
                        detail: safeDetail,
                    },
                ],
                durationMs: performance.now() - start,
            };
        }
    }
    /**
     * Direct probability access — useful for tests + custom flows.
     *
     * @param input Text to classify.
     * @returns Softmax probability of the configured injection class.
     * @throws {Error} When no usable output tensor is found, when
     *   `injectionClassIndex` is outside the output, or when the model
     *   yields a non-finite probability.
     */
    async predict(input) {
        const tokens = this.cfg.tokenizer.encode(input);
        const trunc = truncate(tokens, this.cfg.maxLength);
        const toInt64 = (values) => BigInt64Array.from(values.map((n) => BigInt(n)));
        // Single-row batch: [1, sequence_length].
        const dims = [1, trunc.input_ids.length];
        const feeds = {
            input_ids: { data: toInt64(trunc.input_ids), dims, type: "int64" },
            attention_mask: { data: toInt64(trunc.attention_mask), dims, type: "int64" },
        };
        if (trunc.token_type_ids) {
            feeds.token_type_ids = {
                data: toInt64(trunc.token_type_ids),
                dims,
                type: "int64",
            };
        }
        const result = await this.cfg.session.run(feeds);
        // Explicit config wins, then the session's first declared output,
        // then whatever key the result map lists first.
        const outputName = this.cfg.outputName ??
            this.cfg.session.outputNames?.[0] ??
            Object.keys(result)[0];
        if (!outputName) {
            throw new Error("OnnxInjectionScanner: no output node available");
        }
        const tensor = result[outputName];
        if (!tensor) {
            throw new Error(`OnnxInjectionScanner: output '${outputName}' not in result`);
        }
        // Model emits logits of shape [1, num_classes]. Softmax + pick class.
        const logits = Array.from(tensor.data).map((n) => Number(n));
        const probs = softmax(logits);
        const idx = this.cfg.injectionClassIndex;
        if (idx < 0 || idx >= probs.length) {
            throw new Error(`OnnxInjectionScanner: injectionClassIndex ${idx} out of range (len=${probs.length})`);
        }
        const probability = probs[idx] ?? 0;
        // NaN compares false against every threshold, which would turn a
        // broken model output into a silent "allow" with no violation.
        // Raising here routes it through scan()'s failure reporting instead.
        if (!Number.isFinite(probability)) {
            throw new Error("OnnxInjectionScanner: model produced a non-finite probability");
        }
        return probability;
    }
}
|
|
124
|
+
/**
 * Convenience loader. Imports `onnxruntime-node` *at runtime* so
 * consumers who never call this function don't pay the install cost.
 *
 * Tokenizer loading is left to the caller because tokenizer
 * implementations vary widely between models — we don't want to
 * pin a specific HF tokenizer package.
 *
 * @param opts `modelPath` of the model file to load, the `tokenizer`, and
 *   the optional scanner settings (`threshold`, `outputName`,
 *   `injectionClassIndex`, `maxLength`), which are passed through unchanged.
 * @returns A scanner wired to a freshly created inference session.
 * @throws {Error} When `onnxruntime-node` cannot be imported or does not
 *   expose `InferenceSession.create`.
 */
export async function loadOnnxClassifier(opts) {
    // Dynamic import keeps the dependency optional.
    let ort;
    try {
        ort = await import("onnxruntime-node");
    }
    catch (err) {
        // Anything can be thrown; avoid printing "undefined" for non-Errors.
        const reason = err instanceof Error ? err.message : String(err);
        throw new Error("ai-shield-classifier-onnx: 'onnxruntime-node' is required " +
            "to call loadOnnxClassifier(). Install it as a peer dependency.\n" +
            `Underlying error: ${reason}`);
    }
    // When a CommonJS package is loaded through import(), named exports are
    // only available if they could be detected statically; the whole module
    // is always reachable under `default`, so fall back to it.
    const sessionFactory = ort?.InferenceSession ?? ort?.default?.InferenceSession;
    if (typeof sessionFactory?.create !== "function") {
        throw new Error("ai-shield-classifier-onnx: 'onnxruntime-node' did not expose InferenceSession.create");
    }
    // Invoke as a method so `create` keeps its `this` binding.
    const session = await sessionFactory.create(opts.modelPath);
    return new OnnxInjectionScanner({
        session,
        tokenizer: opts.tokenizer,
        threshold: opts.threshold,
        outputName: opts.outputName,
        injectionClassIndex: opts.injectionClassIndex,
        maxLength: opts.maxLength,
    });
}
|
|
159
|
+
// --- helpers ---
|
|
160
|
+
/**
 * Strip absolute paths and other deployment-internal strings from an
 * error message before it lands in the audit log, keeping the surrounding
 * wording that helps diagnose the failure.
 *
 * Replaced spans:
 * - `file://` URLs                          -> `[file-url]`
 * - Windows drive paths (`C:\..`, `C:/..`)  -> `[path]`
 * - POSIX absolute paths                    -> `[path]`
 * - hex addresses with 6+ digits            -> `[addr]`
 *
 * @param message Raw error message.
 * @returns The scrubbed message, or `"classifier_runtime_error"` when
 *   `message` is empty or not a string.
 */
function sanitizeOnnxErrorMessage(message) {
    if (typeof message !== "string" || message.length === 0) {
        return "classifier_runtime_error";
    }
    // Truncate before sanitizing — bounded work on adversarial input.
    const truncated = message.length > 500 ? message.slice(0, 500) : message;
    return truncated
        // file:// URLs go first so the later path rules never see them.
        .replace(/file:\/\/\S+/g, "[file-url]")
        // Windows drive paths, with either separator. The leading \b stops
        // the rule from firing inside URL schemes such as "http://".
        .replace(/\b[A-Za-z]:[\\/][\\\w./@-]+/g, "[path]")
        // POSIX absolute paths at the start of the message or right after
        // whitespace, an opening bracket, a quote, "=" or ",". The
        // delimiter is captured and written back so it is not lost.
        .replace(/(^|[\s("'`=\[<,])\/[\w./@-]+/g, "$1[path]")
        // Memory addresses (0x...)
        .replace(/0x[0-9a-fA-F]{6,}/g, "[addr]")
        .trim();
}
|
|
182
|
+
/**
 * Numerically stable softmax.
 *
 * The largest logit is subtracted before exponentiating so `Math.exp`
 * cannot overflow. The maximum is found with a loop rather than
 * `Math.max(...logits)`, because spreading a very large array into a call
 * can exceed the engine's argument limit and throw a RangeError.
 *
 * Non-finite inputs:
 * - any NaN          -> every entry is NaN (the caller must treat it as a failure)
 * - any `+Infinity`  -> the mass is split evenly across the `+Infinity`
 *                       entries and all other entries are 0
 * - all `-Infinity`  -> NaN entries (no distribution is defined)
 *
 * @param logits Raw scores.
 * @returns Probabilities in the same order as `logits`; `[]` for empty input.
 */
function softmax(logits) {
    const count = logits.length;
    if (count === 0)
        return [];
    let peak = -Infinity;
    let saturated = 0;
    for (let i = 0; i < count; i++) {
        const value = logits[i];
        if (Number.isNaN(value)) {
            return Array.from({ length: count }, () => NaN);
        }
        if (value === Infinity)
            saturated++;
        if (value > peak)
            peak = value;
    }
    if (saturated > 0) {
        // Infinity - Infinity is NaN, so the saturated case is resolved
        // directly instead of going through exp().
        return logits.map((l) => (l === Infinity ? 1 / saturated : 0));
    }
    const exps = new Array(count);
    let total = 0;
    for (let i = 0; i < count; i++) {
        const e = Math.exp(logits[i] - peak);
        exps[i] = e;
        total += e;
    }
    // The peak element contributes exp(0) = 1, so `total` is at least 1
    // whenever the peak is finite.
    return exps.map((e) => e / total);
}
|
|
192
|
+
/**
 * Head-only truncation of a tokenizer encoding: the first `maxLength`
 * entries of every array are kept.
 *
 * All arrays are checked, not just `input_ids`, so an over-long companion
 * array is cut even when `input_ids` already fits. The input object is
 * returned unchanged when nothing exceeds the limit. `token_type_ids` only
 * appears in the result when the input had it.
 *
 * @param tokens Encoding with `input_ids`, `attention_mask` and optional
 *   `token_type_ids`.
 * @param maxLength Maximum number of tokens to keep; must be >= 1.
 * @returns `tokens` itself, or a new object with the shortened arrays.
 * @throws {RangeError} When `maxLength` is NaN or below 1. Such values
 *   would otherwise produce empty or wrongly trimmed slices.
 */
function truncate(tokens, maxLength) {
    // Written as a negated comparison so NaN is rejected as well.
    if (!(maxLength >= 1)) {
        throw new RangeError(`truncate: maxLength must be >= 1 (got ${maxLength})`);
    }
    const typeIds = tokens.token_type_ids;
    const fits = tokens.input_ids.length <= maxLength &&
        tokens.attention_mask.length <= maxLength &&
        (typeIds === undefined || typeIds.length <= maxLength);
    if (fits)
        return tokens;
    const shortened = {
        input_ids: tokens.input_ids.slice(0, maxLength),
        attention_mask: tokens.attention_mask.slice(0, maxLength),
    };
    if (typeIds !== undefined) {
        shortened.token_type_ids = typeIds.slice(0, maxLength);
    }
    return shortened;
}
|
|
201
|
+
//# sourceMappingURL=classifier.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"classifier.js","sourceRoot":"","sources":["../src/classifier.ts"],"names":[],"mappings":"AA2FA,MAAM,OAAO,oBAAoB;IACtB,IAAI,GAAG,iBAAiB,CAAC;IACjB,GAAG,CAEQ;IAE5B,YAAY,MAA4B;QACtC,IAAI,CAAC,MAAM,CAAC,OAAO,EAAE,CAAC;YACpB,MAAM,IAAI,SAAS,CAAC,6CAA6C,CAAC,CAAC;QACrE,CAAC;QACD,IAAI,CAAC,MAAM,CAAC,SAAS,EAAE,CAAC;YACtB,MAAM,IAAI,SAAS,CAAC,+CAA+C,CAAC,CAAC;QACvE,CAAC;QACD,IAAI,CAAC,GAAG,GAAG;YACT,OAAO,EAAE,MAAM,CAAC,OAAO;YACvB,SAAS,EAAE,MAAM,CAAC,SAAS;YAC3B,SAAS,EAAE,MAAM,CAAC,SAAS,IAAI,IAAI;YACnC,UAAU,EAAE,MAAM,CAAC,UAAU;YAC7B,mBAAmB,EAAE,MAAM,CAAC,mBAAmB,IAAI,CAAC;YACpD,SAAS,EACP,MAAM,CAAC,SAAS,IAAI,MAAM,CAAC,SAAS,CAAC,cAAc,IAAI,GAAG;SAC7D,CAAC;IACJ,CAAC;IAED,KAAK,CAAC,IAAI,CAAC,KAAa,EAAE,QAAqB;QAC7C,MAAM,KAAK,GAAG,WAAW,CAAC,GAAG,EAAE,CAAC;QAChC,IAAI,CAAC;YACH,MAAM,WAAW,GAAG,MAAM,IAAI,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC;YAC9C,MAAM,UAAU,GAAgB,EAAE,CAAC;YACnC,IAAI,QAAQ,GAA8B,OAAO,CAAC;YAElD,IAAI,WAAW,IAAI,IAAI,CAAC,GAAG,CAAC,SAAS,EAAE,CAAC;gBACtC,QAAQ,GAAG,OAAO,CAAC;gBACnB,UAAU,CAAC,IAAI,CAAC;oBACd,IAAI,EAAE,kBAAkB;oBACxB,OAAO,EAAE,IAAI,CAAC,IAAI;oBAClB,KAAK,EAAE,WAAW;oBAClB,SAAS,EAAE,IAAI,CAAC,GAAG,CAAC,SAAS;oBAC7B,OAAO,EAAE,wCAAwC;oBACjD,MAAM,EAAE,gBAAgB,WAAW,CAAC,OAAO,CAAC,CAAC,CAAC,cAAc,IAAI,CAAC,GAAG,CAAC,SAAS,EAAE;iBACjF,CAAC,CAAC;YACL,CAAC;iBAAM,IAAI,WAAW,IAAI,IAAI,CAAC,GAAG,CAAC,SAAS,GAAG,GAAG,EAAE,CAAC;gBACnD,QAAQ,GAAG,MAAM,CAAC;gBAClB,UAAU,CAAC,IAAI,CAAC;oBACd,IAAI,EAAE,kBAAkB;oBACxB,OAAO,EAAE,IAAI,CAAC,IAAI;oBAClB,KAAK,EAAE,WAAW;oBAClB,SAAS,EAAE,IAAI,CAAC,GAAG,CAAC,SAAS;oBAC7B,OAAO,EAAE,0CAA0C;oBACnD,MAAM,EAAE,gBAAgB,WAAW,CAAC,OAAO,CAAC,CAAC,CAAC,cAAc,IAAI,CAAC,GAAG,CAAC,SAAS,EAAE;iBACjF,CAAC,CAAC;YACL,CAAC;YAED,OAAO;gBACL,QAAQ;gBACR,UAAU;gBACV,UAAU,EAAE,WAAW,CAAC,GAAG,EAAE,GAAG,KAAK;aACtC,CAAC;QACJ,CAAC;QAAC,OAAO,GAAG,EAAE,CAAC;YACb,qEAAqE;YACrE,+DAA+D;YAC/D,iDAAiD;YACjD,EAAE;YACF,6DAA6D;YAC7D,iEAAiE;YACjE,mEAAmE;YACnE,yDAAyD;YACzD,MAAM,UAAU,GAAI,GAAa,EAAE,OAAO,IAAI,eAAe,CAAC;YAC9D,MAAM,KAAK,GACT,OAAO,CAAC,GAAG,CAAC,QAAQ,KAAK,aAAa;gBACtC,OAAO,CAAC
,GAAG,CAAC,eAAe,KAAK,GAAG,CAAC;YACtC,MAAM,UAAU,GAAG,KAAK;gBACtB,CAAC,CAAC,UAAU;gBACZ,CAAC,CAAC,wBAAwB,CAAC,UAAU,CAAC,CAAC;YACzC,OAAO;gBACL,QAAQ,EAAE,OAAO;gBACjB,UAAU,EAAE;oBACV;wBACE,IAAI,EAAE,gBAAgB;wBACtB,OAAO,EAAE,IAAI,CAAC,IAAI;wBAClB,KAAK,EAAE,CAAC;wBACR,SAAS,EAAE,IAAI,CAAC,GAAG,CAAC,SAAS;wBAC7B,OAAO,EAAE,0CAA0C;wBACnD,MAAM,EAAE,UAAU;qBACnB;iBACF;gBACD,UAAU,EAAE,WAAW,CAAC,GAAG,EAAE,GAAG,KAAK;aACtC,CAAC;QACJ,CAAC;IACH,CAAC;IAED,mEAAmE;IACnE,KAAK,CAAC,OAAO,CAAC,KAAa;QACzB,MAAM,MAAM,GAAG,IAAI,CAAC,GAAG,CAAC,SAAS,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC;QAChD,MAAM,KAAK,GAAG,QAAQ,CAAC,MAAM,EAAE,IAAI,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC;QAEnD,MAAM,QAAQ,GAAG,aAAa,CAAC,IAAI,CAAC,KAAK,CAAC,SAAS,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,MAAM,aAAa,GAAG,aAAa,CAAC,IAAI,CACtC,KAAK,CAAC,cAAc,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAC3C,CAAC;QACF,MAAM,IAAI,GAAG,CAAC,CAAC,EAAE,KAAK,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC;QAEzC,MAAM,KAAK,GAAmC;YAC5C,SAAS,EAAE,EAAE,IAAI,EAAE,QAAQ,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE;YAClD,cAAc,EAAE,EAAE,IAAI,EAAE,aAAa,EAAE,IAAI,EAAE,IAAI,EAAE,OAAO,EAAE;SAC7D,CAAC;QACF,IAAI,KAAK,CAAC,cAAc,EAAE,CAAC;YACzB,KAAK,CAAC,cAAc,GAAG;gBACrB,IAAI,EAAE,aAAa,CAAC,IAAI,CAAC,KAAK,CAAC,cAAc,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;gBACpE,IAAI;gBACJ,IAAI,EAAE,OAAO;aACd,CAAC;QACJ,CAAC;QAED,MAAM,MAAM,GAAG,MAAM,IAAI,CAAC,GAAG,CAAC,OAAO,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC;QACjD,MAAM,UAAU,GACd,IAAI,CAAC,GAAG,CAAC,UAAU;YACnB,IAAI,CAAC,GAAG,CAAC,OAAO,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC;YACjC,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;QACzB,IAAI,CAAC,UAAU,EAAE,CAAC;YAChB,MAAM,IAAI,KAAK,CAAC,gDAAgD,CAAC,CAAC;QACpE,CAAC;QACD,MAAM,MAAM,GAAG,MAAM,CAAC,UAAU,CAAC,CAAC;QAClC,IAAI,CAAC,MAAM,EAAE,CAAC;YACZ,MAAM,IAAI,KAAK,CACb,iCAAiC,UAAU,iBAAiB,CAC7D,CAAC;QACJ,CAAC;QAED,sEAAsE;QACtE,MAAM,MAAM,GAAG,KAAK,CAAC,IAAI,CAAC,MAAM,CAAC,IAAyB,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CACpE,MAAM,CAAC,CAAC,CAAC,CACV,CAAC;QACF,MAAM,KAAK,
GAAG,OAAO,CAAC,MAAM,CAAC,CAAC;QAC9B,MAAM,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,mBAAmB,CAAC;QACzC,IAAI,GAAG,GAAG,CAAC,IAAI,GAAG,IAAI,KAAK,CAAC,MAAM,EAAE,CAAC;YACnC,MAAM,IAAI,KAAK,CACb,6CAA6C,GAAG,sBAAsB,KAAK,CAAC,MAAM,GAAG,CACtF,CAAC;QACJ,CAAC;QACD,OAAO,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;IACzB,CAAC;CACF;AAED;;;;;;;GAOG;AACH,MAAM,CAAC,KAAK,UAAU,kBAAkB,CAAC,IAOxC;IACC,kEAAkE;IAClE,6CAA6C;IAC7C,IAAI,GAAY,CAAC;IACjB,IAAI,CAAC;QACH,GAAG,GAAG,MAAM,MAAM,CAAC,kBAA4B,CAAC,CAAC;IACnD,CAAC;IAAC,OAAO,GAAG,EAAE,CAAC;QACb,MAAM,IAAI,KAAK,CACb,4DAA4D;YAC1D,kEAAkE;YAClE,qBAAsB,GAAa,CAAC,OAAO,EAAE,CAChD,CAAC;IACJ,CAAC;IACD,MAAM,SAAS,GAAI,GAA2F,CAAC;IAC/G,MAAM,MAAM,GAAG,SAAS,CAAC,gBAAgB,EAAE,MAAM,CAAC;IAClD,IAAI,OAAO,MAAM,KAAK,UAAU,EAAE,CAAC;QACjC,MAAM,IAAI,KAAK,CACb,sFAAsF,CACvF,CAAC;IACJ,CAAC;IACD,MAAM,OAAO,GAAG,MAAM,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IAC7C,OAAO,IAAI,oBAAoB,CAAC;QAC9B,OAAO;QACP,SAAS,EAAE,IAAI,CAAC,SAAS;QACzB,SAAS,EAAE,IAAI,CAAC,SAAS;QACzB,UAAU,EAAE,IAAI,CAAC,UAAU;QAC3B,mBAAmB,EAAE,IAAI,CAAC,mBAAmB;QAC7C,SAAS,EAAE,IAAI,CAAC,SAAS;KAC1B,CAAC,CAAC;AACL,CAAC;AAED,kBAAkB;AAElB;;;;GAIG;AACH,SAAS,wBAAwB,CAAC,OAAe;IAC/C,IAAI,OAAO,OAAO,KAAK,QAAQ,IAAI,OAAO,CAAC,MAAM,KAAK,CAAC,EAAE,CAAC;QACxD,OAAO,0BAA0B,CAAC;IACpC,CAAC;IACD,kEAAkE;IAClE,MAAM,SAAS,GAAG,OAAO,CAAC,MAAM,GAAG,GAAG,CAAC,CAAC,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC;IACzE,OAAO,SAAS;QACd,uBAAuB;SACtB,OAAO,CAAC,2BAA2B,EAAE,SAAS,CAAC;QAChD,sBAAsB;SACrB,OAAO,CAAC,yBAAyB,EAAE,QAAQ,CAAC;QAC7C,eAAe;SACd,OAAO,CAAC,eAAe,EAAE,YAAY,CAAC;QACvC,2BAA2B;SAC1B,OAAO,CAAC,oBAAoB,EAAE,QAAQ,CAAC;SACvC,IAAI,EAAE,CAAC;AACZ,CAAC;AAED,SAAS,OAAO,CAAC,MAAgB;IAC/B,IAAI,MAAM,CAAC,MAAM,KAAK,CAAC;QAAE,OAAO,EAAE,CAAC;IACnC,MAAM,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,MAAM,CAAC,CAAC;IAChC,MAAM,IAAI,GAAG,MAAM,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC;IAClD,MAAM,GAAG,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC,CAAC;IAC5C,IAAI,GAAG,KAAK,CAAC;QA
AE,OAAO,MAAM,CAAC,GAAG,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,CAAC;IAC1C,OAAO,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC;AAClC,CAAC;AAED,SAAS,QAAQ,CACf,MAAuC,EACvC,SAAiB;IAEjB,IAAI,MAAM,CAAC,SAAS,CAAC,MAAM,IAAI,SAAS;QAAE,OAAO,MAAM,CAAC;IACxD,OAAO;QACL,SAAS,EAAE,MAAM,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,EAAE,SAAS,CAAC;QAC/C,cAAc,EAAE,MAAM,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC,EAAE,SAAS,CAAC;QACzD,cAAc,EAAE,MAAM,CAAC,cAAc,EAAE,KAAK,CAAC,CAAC,EAAE,SAAS,CAAC;KAC3D,CAAC;AACJ,CAAC"}
|
package/dist/index.d.ts
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAuBA,OAAO,EACL,oBAAoB,EACpB,kBAAkB,EAClB,KAAK,oBAAoB,EACzB,KAAK,oBAAoB,EACzB,KAAK,SAAS,GACf,MAAM,iBAAiB,CAAC"}
|
|
@@ -20,11 +20,5 @@
|
|
|
20
20
|
// + tokenizer JSON, the wrapper handles inference + thresholding +
|
|
21
21
|
// integration with the Scanner interface.
|
|
22
22
|
// ============================================================
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
OnnxInjectionScanner,
|
|
26
|
-
loadOnnxClassifier,
|
|
27
|
-
type OnnxClassifierConfig,
|
|
28
|
-
type OnnxInferenceRuntime,
|
|
29
|
-
type Tokenizer,
|
|
30
|
-
} from "./classifier.js";
|
|
23
|
+
export { OnnxInjectionScanner, loadOnnxClassifier, } from "./classifier.js";
|
|
24
|
+
//# sourceMappingURL=index.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"index.js","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA,+DAA+D;AAC/D,6DAA6D;AAC7D,EAAE;AACF,yDAAyD;AACzD,kEAAkE;AAClE,gEAAgE;AAChE,EAAE;AACF,0BAA0B;AAC1B,yEAAyE;AACzE,mEAAmE;AACnE,uEAAuE;AACvE,gEAAgE;AAChE,EAAE;AACF,sBAAsB;AACtB,8DAA8D;AAC9D,oDAAoD;AACpD,mDAAmD;AACnD,EAAE;AACF,uEAAuE;AACvE,mEAAmE;AACnE,0CAA0C;AAC1C,+DAA+D;AAE/D,OAAO,EACL,oBAAoB,EACpB,kBAAkB,GAInB,MAAM,iBAAiB,CAAC"}
|
package/package.json
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "ai-shield-classifier-onnx",
|
|
3
|
-
"version": "0.4.0",
|
|
3
|
+
"version": "0.5.1",
|
|
4
4
|
"license": "MIT",
|
|
5
5
|
"description": "Optional ONNX ML classifier for ai-shield — DeBERTa-style prompt-injection detection alongside heuristic patterns",
|
|
6
6
|
"author": "StudioMeyer <hello@studiomeyer.io>",
|
|
7
7
|
"repository": {
|
|
8
8
|
"type": "git",
|
|
9
|
-
"url": "https://github.com/studiomeyer-io/ai-shield",
|
|
9
|
+
"url": "git+https://github.com/studiomeyer-io/ai-shield.git",
|
|
10
10
|
"directory": "packages/classifier-onnx"
|
|
11
11
|
},
|
|
12
12
|
"homepage": "https://github.com/studiomeyer-io/ai-shield/tree/main/packages/classifier-onnx",
|
|
@@ -24,6 +24,9 @@
|
|
|
24
24
|
"ai-shield"
|
|
25
25
|
],
|
|
26
26
|
"type": "module",
|
|
27
|
+
"files": [
|
|
28
|
+
"dist"
|
|
29
|
+
],
|
|
27
30
|
"main": "./dist/index.js",
|
|
28
31
|
"types": "./dist/index.d.ts",
|
|
29
32
|
"exports": {
|
|
@@ -37,7 +40,7 @@
|
|
|
37
40
|
"typecheck": "tsc -b"
|
|
38
41
|
},
|
|
39
42
|
"peerDependencies": {
|
|
40
|
-
"ai-shield-core": "0.
|
|
43
|
+
"ai-shield-core": "0.5.0",
|
|
41
44
|
"onnxruntime-node": ">=1.17.0"
|
|
42
45
|
},
|
|
43
46
|
"peerDependenciesMeta": {
|
|
@@ -46,7 +49,7 @@
|
|
|
46
49
|
}
|
|
47
50
|
},
|
|
48
51
|
"devDependencies": {
|
|
49
|
-
"ai-shield-core": "0.
|
|
52
|
+
"ai-shield-core": "0.5.0",
|
|
50
53
|
"typescript": "^5.7.0"
|
|
51
54
|
},
|
|
52
55
|
"engines": {
|
package/src/classifier.ts
DELETED
|
@@ -1,326 +0,0 @@
|
|
|
1
|
-
import type {
|
|
2
|
-
Scanner,
|
|
3
|
-
ScannerResult,
|
|
4
|
-
ScanContext,
|
|
5
|
-
Violation,
|
|
6
|
-
} from "ai-shield-core";
|
|
7
|
-
|
|
8
|
-
// ============================================================
|
|
9
|
-
// OnnxInjectionScanner — ML-backed prompt-injection detection
|
|
10
|
-
//
|
|
11
|
-
// Implements the `Scanner` interface from ai-shield-core. Designed to
|
|
12
|
-
// be added to a `ScannerChain` after the heuristic scanner so that
|
|
13
|
-
// known patterns short-circuit on cheap regex AND novel paraphrases
|
|
14
|
-
// still get a second-pass ML check.
|
|
15
|
-
//
|
|
16
|
-
// The runtime is abstracted via `OnnxInferenceRuntime` so this file
|
|
17
|
-
// has zero hard dependency on `onnxruntime-node` at type-check time.
|
|
18
|
-
// At runtime you pass in either:
|
|
19
|
-
// 1. An already-constructed `ort.InferenceSession`
|
|
20
|
-
// (`new OnnxInjectionScanner({ session, tokenizer })`)
|
|
21
|
-
// 2. A path to a `.onnx` model file + a path to a `tokenizer.json`
|
|
22
|
-
// (`await loadOnnxClassifier({ modelPath, tokenizerPath })`)
|
|
23
|
-
//
|
|
24
|
-
// Both paths keep the dep injection clean and unit-testable.
|
|
25
|
-
// ============================================================
|
|
26
|
-
|
|
27
|
-
/**
|
|
28
|
-
* Minimal subset of `onnxruntime-node`'s `InferenceSession` we use.
|
|
29
|
-
* Declaring it locally means this package type-checks even when
|
|
30
|
-
* `onnxruntime-node` is not installed (which is the entire point —
|
|
31
|
-
* it's an optional peer dep).
|
|
32
|
-
*/
|
|
33
|
-
export interface OnnxInferenceRuntime {
|
|
34
|
-
run(
|
|
35
|
-
feeds: Record<string, OnnxTensorLike>,
|
|
36
|
-
): Promise<Record<string, OnnxTensorLike>>;
|
|
37
|
-
readonly inputNames?: readonly string[];
|
|
38
|
-
readonly outputNames?: readonly string[];
|
|
39
|
-
}
|
|
40
|
-
|
|
41
|
-
/** Tensor descriptor compatible with `onnxruntime-common`'s Tensor. */
|
|
42
|
-
export interface OnnxTensorLike {
|
|
43
|
-
readonly data: ArrayLike<number> | BigInt64Array;
|
|
44
|
-
readonly dims: readonly number[];
|
|
45
|
-
readonly type?: string;
|
|
46
|
-
}
|
|
47
|
-
|
|
48
|
-
/**
|
|
49
|
-
* Tokenizer abstraction. Same trick as the runtime — we don't bind to
|
|
50
|
-
* any specific HF tokenizer package so the user can wire up
|
|
51
|
-
* `@huggingface/transformers`, `tokenizers`, or a hand-written one.
|
|
52
|
-
*/
|
|
53
|
-
export interface Tokenizer {
|
|
54
|
-
encode(text: string): {
|
|
55
|
-
input_ids: number[];
|
|
56
|
-
attention_mask: number[];
|
|
57
|
-
token_type_ids?: number[];
|
|
58
|
-
};
|
|
59
|
-
/** Optional max sequence length the tokenizer was trained for. */
|
|
60
|
-
modelMaxLength?: number;
|
|
61
|
-
}
|
|
62
|
-
|
|
63
|
-
export interface OnnxClassifierConfig {
|
|
64
|
-
/** A pre-constructed inference session. */
|
|
65
|
-
session: OnnxInferenceRuntime;
|
|
66
|
-
/** Tokenizer matching the model. */
|
|
67
|
-
tokenizer: Tokenizer;
|
|
68
|
-
/**
|
|
69
|
-
* Probability threshold above which the input is flagged as
|
|
70
|
-
* injection. Default 0.85 (calibrated for protectai/deberta-v3-base-
|
|
71
|
-
* prompt-injection — adjust per model).
|
|
72
|
-
*/
|
|
73
|
-
threshold?: number;
|
|
74
|
-
/**
|
|
75
|
-
* Name of the output node that carries logits/probabilities.
|
|
76
|
-
* Default: first key in the runtime's result map.
|
|
77
|
-
*/
|
|
78
|
-
outputName?: string;
|
|
79
|
-
/**
|
|
80
|
-
* Index of the "injection" class in the model output.
|
|
81
|
-
* Default: 1 (binary classifier convention: 0 = SAFE, 1 = INJECTION).
|
|
82
|
-
*/
|
|
83
|
-
injectionClassIndex?: number;
|
|
84
|
-
/**
|
|
85
|
-
* Maximum sequence length to feed. Default: 512.
|
|
86
|
-
* Larger inputs are truncated head-only (start kept) which matches
|
|
87
|
-
* the standard DeBERTa fine-tune recipe.
|
|
88
|
-
*/
|
|
89
|
-
maxLength?: number;
|
|
90
|
-
}
|
|
91
|
-
|
|
92
|
-
export class OnnxInjectionScanner implements Scanner {
|
|
93
|
-
readonly name = "onnx-classifier";
|
|
94
|
-
private readonly cfg: Required<
|
|
95
|
-
Omit<OnnxClassifierConfig, "outputName">
|
|
96
|
-
> & { outputName?: string };
|
|
97
|
-
|
|
98
|
-
constructor(config: OnnxClassifierConfig) {
|
|
99
|
-
if (!config.session) {
|
|
100
|
-
throw new TypeError("OnnxInjectionScanner: 'session' is required");
|
|
101
|
-
}
|
|
102
|
-
if (!config.tokenizer) {
|
|
103
|
-
throw new TypeError("OnnxInjectionScanner: 'tokenizer' is required");
|
|
104
|
-
}
|
|
105
|
-
this.cfg = {
|
|
106
|
-
session: config.session,
|
|
107
|
-
tokenizer: config.tokenizer,
|
|
108
|
-
threshold: config.threshold ?? 0.85,
|
|
109
|
-
outputName: config.outputName,
|
|
110
|
-
injectionClassIndex: config.injectionClassIndex ?? 1,
|
|
111
|
-
maxLength:
|
|
112
|
-
config.maxLength ?? config.tokenizer.modelMaxLength ?? 512,
|
|
113
|
-
};
|
|
114
|
-
}
|
|
115
|
-
|
|
116
|
-
async scan(input: string, _context: ScanContext): Promise<ScannerResult> {
|
|
117
|
-
const start = performance.now();
|
|
118
|
-
try {
|
|
119
|
-
const probability = await this.predict(input);
|
|
120
|
-
const violations: Violation[] = [];
|
|
121
|
-
let decision: ScannerResult["decision"] = "allow";
|
|
122
|
-
|
|
123
|
-
if (probability >= this.cfg.threshold) {
|
|
124
|
-
decision = "block";
|
|
125
|
-
violations.push({
|
|
126
|
-
type: "prompt_injection",
|
|
127
|
-
scanner: this.name,
|
|
128
|
-
score: probability,
|
|
129
|
-
threshold: this.cfg.threshold,
|
|
130
|
-
message: "ML classifier flagged prompt-injection",
|
|
131
|
-
detail: `p(injection)=${probability.toFixed(4)} threshold=${this.cfg.threshold}`,
|
|
132
|
-
});
|
|
133
|
-
} else if (probability >= this.cfg.threshold * 0.6) {
|
|
134
|
-
decision = "warn";
|
|
135
|
-
violations.push({
|
|
136
|
-
type: "prompt_injection",
|
|
137
|
-
scanner: this.name,
|
|
138
|
-
score: probability,
|
|
139
|
-
threshold: this.cfg.threshold,
|
|
140
|
-
message: "ML classifier flagged borderline content",
|
|
141
|
-
detail: `p(injection)=${probability.toFixed(4)} threshold=${this.cfg.threshold}`,
|
|
142
|
-
});
|
|
143
|
-
}
|
|
144
|
-
|
|
145
|
-
return {
|
|
146
|
-
decision,
|
|
147
|
-
violations,
|
|
148
|
-
durationMs: performance.now() - start,
|
|
149
|
-
};
|
|
150
|
-
} catch (err) {
|
|
151
|
-
// ML errors must not take down the entire chain — degrade gracefully
|
|
152
|
-
// to "allow" with a synthetic violation so the audit log shows
|
|
153
|
-
// something went wrong without blocking traffic.
|
|
154
|
-
//
|
|
155
|
-
// Critic H4 round 1 — the raw error message can contain file
|
|
156
|
-
// paths (model location), native library symbols, or deployment-
|
|
157
|
-
// internal strings. Strip absolute paths before they hit the audit
|
|
158
|
-
// log. In dev mode we keep more detail to aid debugging.
|
|
159
|
-
const rawMessage = (err as Error)?.message ?? "unknown error";
|
|
160
|
-
const isDev =
|
|
161
|
-
process.env.NODE_ENV === "development" ||
|
|
162
|
-
process.env.AI_SHIELD_DEBUG === "1";
|
|
163
|
-
const safeDetail = isDev
|
|
164
|
-
? rawMessage
|
|
165
|
-
: sanitizeOnnxErrorMessage(rawMessage);
|
|
166
|
-
return {
|
|
167
|
-
decision: "allow",
|
|
168
|
-
violations: [
|
|
169
|
-
{
|
|
170
|
-
type: "content_policy",
|
|
171
|
-
scanner: this.name,
|
|
172
|
-
score: 0,
|
|
173
|
-
threshold: this.cfg.threshold,
|
|
174
|
-
message: "ML classifier failed — degraded to allow",
|
|
175
|
-
detail: safeDetail,
|
|
176
|
-
},
|
|
177
|
-
],
|
|
178
|
-
durationMs: performance.now() - start,
|
|
179
|
-
};
|
|
180
|
-
}
|
|
181
|
-
}
|
|
182
|
-
|
|
183
|
-
/**
 * Direct probability access — useful for tests + custom flows.
 *
 * Encodes `input` with the configured tokenizer, truncates the encoding to
 * `cfg.maxLength`, feeds `input_ids` / `attention_mask` (and
 * `token_type_ids` when the tokenizer supplies them) to the session as
 * int64 tensors with dims `[1, sequenceLength]`, then applies softmax to the
 * selected output tensor and returns the entry at `cfg.injectionClassIndex`.
 *
 * The output tensor is chosen from, in order: `cfg.outputName`, the
 * session's first declared output name, the first key of the run result.
 *
 * @param input Text to classify.
 * @returns Softmax probability at `cfg.injectionClassIndex`.
 * @throws Error if no output name can be resolved, the resolved output is
 *   absent from the run result, `injectionClassIndex` is not an integer
 *   inside the output's range, or the resulting probability is not finite.
 */
async predict(input: string): Promise<number> {
  const tokens = this.cfg.tokenizer.encode(input);
  const trunc = truncate(tokens, this.cfg.maxLength);

  const inputIds = BigInt64Array.from(trunc.input_ids.map((n) => BigInt(n)));
  const attentionMask = BigInt64Array.from(
    trunc.attention_mask.map((n) => BigInt(n)),
  );
  const dims = [1, trunc.input_ids.length];

  const feeds: Record<string, OnnxTensorLike> = {
    input_ids: { data: inputIds, dims, type: "int64" },
    attention_mask: { data: attentionMask, dims, type: "int64" },
  };
  if (trunc.token_type_ids) {
    feeds.token_type_ids = {
      data: BigInt64Array.from(trunc.token_type_ids.map((n) => BigInt(n))),
      dims,
      type: "int64",
    };
  }

  const result = await this.cfg.session.run(feeds);
  const outputName =
    this.cfg.outputName ??
    this.cfg.session.outputNames?.[0] ??
    Object.keys(result)[0];
  if (!outputName) {
    throw new Error("OnnxInjectionScanner: no output node available");
  }
  const tensor = result[outputName];
  if (!tensor) {
    throw new Error(
      `OnnxInjectionScanner: output '${outputName}' not in result`,
    );
  }

  // Model emits logits of shape [1, num_classes]. Softmax + pick class.
  const logits = Array.from(tensor.data as ArrayLike<number>).map((n) =>
    Number(n),
  );
  const probs = softmax(logits);
  const idx = this.cfg.injectionClassIndex;
  // A fractional or NaN index passes a plain `< 0 || >= length` comparison
  // and would then read `undefined`; reject it explicitly instead of
  // silently reporting probability 0.
  if (!Number.isInteger(idx) || idx < 0 || idx >= probs.length) {
    throw new Error(
      `OnnxInjectionScanner: injectionClassIndex ${idx} out of range (len=${probs.length})`,
    );
  }
  const probability = probs[idx];
  // NaN / Infinity logits make softmax return NaN, and NaN compares false
  // against every threshold — a caller doing `p >= threshold` would treat
  // the input as clean. Surface the broken output as an error instead.
  if (probability === undefined || !Number.isFinite(probability)) {
    throw new Error(
      "OnnxInjectionScanner: model output produced a non-finite probability",
    );
  }
  return probability;
}
|
|
234
|
-
}
|
|
235
|
-
|
|
236
|
-
/**
 * Convenience loader. Imports `onnxruntime-node` *at runtime* so
 * consumers who never call this function don't pay the install cost.
 *
 * Tokenizer loading is left to the caller because tokenizer
 * implementations vary widely between models — we don't want to
 * pin a specific HF tokenizer package.
 *
 * @param opts.modelPath Path handed to `InferenceSession.create`.
 * @param opts.tokenizer Tokenizer passed through to the scanner.
 * @param opts.threshold Passed through to the scanner unchanged.
 * @param opts.outputName Passed through to the scanner unchanged.
 * @param opts.injectionClassIndex Passed through to the scanner unchanged.
 * @param opts.maxLength Passed through to the scanner unchanged.
 * @returns A scanner wired to the freshly created inference session.
 * @throws Error if `onnxruntime-node` cannot be imported or does not expose
 *   `InferenceSession.create`; errors from session creation propagate as-is.
 */
export async function loadOnnxClassifier(opts: {
  modelPath: string;
  tokenizer: Tokenizer;
  threshold?: number;
  outputName?: string;
  injectionClassIndex?: number;
  maxLength?: number;
}): Promise<OnnxInjectionScanner> {
  type SessionFactory = {
    create?: (path: string) => Promise<OnnxInferenceRuntime>;
  };
  type OrtModuleShape = { InferenceSession?: SessionFactory };

  // Dynamic import keeps the dep optional — TypeScript can't see it
  // at compile time, which is the whole point.
  let ort: unknown;
  try {
    ort = await import("onnxruntime-node" as string);
  } catch (err) {
    // The caught value is not guaranteed to be an Error instance.
    const reason = err instanceof Error ? err.message : String(err);
    throw new Error(
      "ai-shield-classifier-onnx: 'onnxruntime-node' is required " +
        "to call loadOnnxClassifier(). Install it as a peer dependency.\n" +
        `Underlying error: ${reason}`,
    );
  }

  // Some loader/bundler combinations expose a CommonJS package's exports
  // only under `default`, so fall back to that when the named export is
  // missing.
  const ortModule = ort as OrtModuleShape & { default?: OrtModuleShape };
  const sessionFactory =
    ortModule.InferenceSession ?? ortModule.default?.InferenceSession;
  if (!sessionFactory || typeof sessionFactory.create !== "function") {
    throw new Error(
      "ai-shield-classifier-onnx: 'onnxruntime-node' did not expose InferenceSession.create",
    );
  }

  // Call `create` through its owner so `this` stays bound to
  // `InferenceSession` rather than invoking a detached method reference.
  const session = await sessionFactory.create(opts.modelPath);
  return new OnnxInjectionScanner({
    session,
    tokenizer: opts.tokenizer,
    threshold: opts.threshold,
    outputName: opts.outputName,
    injectionClassIndex: opts.injectionClassIndex,
    maxLength: opts.maxLength,
  });
}
|
|
281
|
-
|
|
282
|
-
// --- helpers ---
|
|
283
|
-
|
|
284
|
-
/**
 * Strip absolute paths and other deployment-internal strings from an
 * ONNX-runtime error message before it lands in the audit log. Keeps
 * the short error class / cause hint that helps diagnose the failure.
 *
 * Redactions, applied in this order:
 *  1. `file://` URLs            → `[file-url]`
 *  2. Windows drive paths       → `[path]` (both `C:\dir` and `C:/dir`)
 *  3. POSIX absolute paths      → `[path]`, when they start the message or
 *     follow whitespace, an opening bracket, a quote, `=`, `,` or `:`.
 *     The preceding delimiter is kept, and `://` is left alone so ordinary
 *     URLs are not mangled.
 *  4. Hex addresses (`0x` + 6 or more hex digits) → `[addr]`
 *
 * @param message Raw error message.
 * @returns The redacted, trimmed message (at most the first 500 characters
 *   of the input are considered), or `"classifier_runtime_error"` when the
 *   input is empty or not a string.
 */
function sanitizeOnnxErrorMessage(message: string): string {
  if (typeof message !== "string" || message.length === 0) {
    return "classifier_runtime_error";
  }
  // Truncate before sanitizing — bounded work on adversarial input.
  const bounded = message.length > 500 ? message.slice(0, 500) : message;
  return (
    bounded
      // file:// URLs — handled first so the path rules never see their tail.
      .replace(/file:\/\/\S+/g, "[file-url]")
      // Windows drive paths, with either separator after the drive letter.
      // The forward-slash form needs a word boundary and must not be
      // followed by a second slash, so `http://…` is not treated as a drive.
      .replace(/(?:[A-Za-z]:\\|\b[A-Za-z]:\/(?!\/))[\\\w./@-]+/g, "[path]")
      // POSIX absolute paths; `$1` restores the delimiter that preceded them.
      .replace(/(^|[\s('"`=,\[<]|:(?!\/\/))\/[\w./@-]+/g, "$1[path]")
      // Memory addresses (0x...)
      .replace(/0x[0-9a-fA-F]{6,}/g, "[addr]")
      .trim()
  );
}
|
|
306
|
-
|
|
307
|
-
/**
 * Softmax over a flat list of logits.
 *
 * The largest logit is subtracted from every entry before exponentiating,
 * so large inputs do not overflow `Math.exp`.
 *
 * @param logits Raw scores; may be empty.
 * @returns One weight per logit (an empty array for empty input). If the
 *   exponentials sum to exactly 0, every weight is reported as 0.
 */
function softmax(logits: number[]): number[] {
  if (logits.length === 0) {
    return [];
  }

  const peak = Math.max(...logits);
  const weights = logits.map((logit) => Math.exp(logit - peak));

  // Accumulate left-to-right starting from 0.
  let total = 0;
  for (const weight of weights) {
    total += weight;
  }

  if (total === 0) {
    return logits.map(() => 0);
  }
  return weights.map((weight) => weight / total);
}
|
|
315
|
-
|
|
316
|
-
/**
 * Clip a tokenizer encoding to at most `maxLength` positions.
 *
 * An encoding that already fits is returned as the same object. Otherwise a
 * new object is built whose `input_ids`, `attention_mask` and (when
 * present) `token_type_ids` each keep only their first `maxLength` entries.
 *
 * @param tokens Encoding produced by `Tokenizer.encode`.
 * @param maxLength Maximum number of positions to keep.
 * @returns The original encoding, or a clipped copy of it.
 */
function truncate(
  tokens: ReturnType<Tokenizer["encode"]>,
  maxLength: number,
): ReturnType<Tokenizer["encode"]> {
  const alreadyFits = tokens.input_ids.length <= maxLength;
  if (alreadyFits) {
    return tokens;
  }

  const {
    input_ids: ids,
    attention_mask: mask,
    token_type_ids: segments,
  } = tokens;

  return {
    input_ids: ids.slice(0, maxLength),
    attention_mask: mask.slice(0, maxLength),
    token_type_ids: segments?.slice(0, maxLength),
  };
}
|