ai-shield-middleware 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +43 -0
- package/src/express.ts +94 -0
- package/src/hono.ts +99 -0
- package/src/index.ts +13 -0
- package/src/shared.ts +102 -0
- package/tsconfig.json +8 -0
package/package.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "ai-shield-middleware",
|
|
3
|
+
"version": "0.1.0",
|
|
4
|
+
"description": "AI Shield middleware for Express and Hono — route-level LLM protection",
|
|
5
|
+
"type": "module",
|
|
6
|
+
"main": "./dist/index.js",
|
|
7
|
+
"types": "./dist/index.d.ts",
|
|
8
|
+
"exports": {
|
|
9
|
+
".": {
|
|
10
|
+
"types": "./dist/index.d.ts",
|
|
11
|
+
"import": "./dist/index.js"
|
|
12
|
+
},
|
|
13
|
+
"./express": {
|
|
14
|
+
"types": "./dist/express.d.ts",
|
|
15
|
+
"import": "./dist/express.js"
|
|
16
|
+
},
|
|
17
|
+
"./hono": {
|
|
18
|
+
"types": "./dist/hono.d.ts",
|
|
19
|
+
"import": "./dist/hono.js"
|
|
20
|
+
}
|
|
21
|
+
},
|
|
22
|
+
"scripts": {
|
|
23
|
+
"build": "tsc",
|
|
24
|
+
"typecheck": "tsc --noEmit"
|
|
25
|
+
},
|
|
26
|
+
"peerDependencies": {
|
|
27
|
+
"express": ">=4.0.0",
|
|
28
|
+
"hono": ">=4.0.0"
|
|
29
|
+
},
|
|
30
|
+
"peerDependenciesMeta": {
|
|
31
|
+
"express": { "optional": true },
|
|
32
|
+
"hono": { "optional": true }
|
|
33
|
+
},
|
|
34
|
+
"dependencies": {
|
|
35
|
+
"ai-shield-core": "0.1.0"
|
|
36
|
+
},
|
|
37
|
+
"devDependencies": {
|
|
38
|
+
"@types/express": "^5.0.0",
|
|
39
|
+
"express": "^4.21.0",
|
|
40
|
+
"hono": "^4.7.0",
|
|
41
|
+
"typescript": "^5.7.0"
|
|
42
|
+
}
|
|
43
|
+
}
|
package/src/express.ts
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import type { ScanResult } from "ai-shield-core";
|
|
2
|
+
import type { ShieldMiddlewareConfig } from "./shared.js";
|
|
3
|
+
import { defaultGetInput, scanRequest } from "./shared.js";
|
|
4
|
+
|
|
5
|
+
// ============================================================
|
|
6
|
+
// Express Middleware — AI Shield route guard
|
|
7
|
+
// ============================================================
|
|
8
|
+
|
|
9
|
+
// Minimal Express types (avoid hard dependency on @types/express)
interface ExpressRequest {
  body?: unknown;
  /** Request path relative to the mount point of the current middleware. */
  path: string;
  url: string;
  method: string;
  headers: Record<string, string | string[] | undefined>;
  /**
   * Mount path of the current router/middleware (e.g. `"/api/chat"` inside
   * `app.use("/api/chat", ...)`); may be empty or absent at the application root.
   */
  baseUrl?: string;
}

interface ExpressResponse {
  status(code: number): ExpressResponse;
  json(body: unknown): void;
  locals: Record<string, unknown>;
}

type NextFunction = (err?: unknown) => void;
type ExpressMiddleware = (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => void;

/**
 * True when `path` is `prefix` itself or lies beneath it on a `/` boundary:
 * `"/api/chat/health"` covers `"/api/chat/health"` and `"/api/chat/health/live"`
 * but not `"/api/chat/healthz"`, so a look-alike path cannot use a skip rule to
 * bypass the scanner.
 */
function isUnderPath(path: string, prefix: string): boolean {
  if (path === prefix) return true;
  const base = prefix.endsWith("/") ? prefix : `${prefix}/`;
  return path.startsWith(base);
}

/**
 * Express middleware that scans the request body for prompt injection and PII
 * before the route handler runs.
 *
 * Per request:
 * - `GET`, `HEAD` and `OPTIONS` requests pass straight through.
 * - Requests whose path is listed in `config.skipPaths` (exact match or a
 *   sub-path on a `/` boundary) pass through. Both the full path
 *   (`req.baseUrl + req.path`) and the mount-relative `req.path` are checked,
 *   because Express strips the mount path from `req.path` inside
 *   `app.use(mount, ...)`.
 * - Text is extracted from `req.body` with `config.getInput` (default
 *   `defaultGetInput`); when there is none the request passes through. The
 *   body must already be parsed (e.g. by `express.json()`).
 * - Otherwise the text is scanned. A blocked request is answered with the
 *   status/body from `config.onBlocked` (default `defaultBlockedResponse`) and
 *   `next` is not called; an allowed request gets the result on
 *   `res.locals.shieldResult`. Scan failures are forwarded to `next(err)`.
 *
 * @param config - Middleware options; see `ShieldMiddlewareConfig`.
 * @returns An Express `(req, res, next)` middleware.
 *
 * @example
 * ```ts
 * import express from "express";
 * import { shieldMiddleware } from "ai-shield-middleware/express";
 * import type { ScanResult } from "ai-shield-middleware/express";
 *
 * const app = express();
 * app.use(express.json());
 *
 * // Protect all /api/chat routes
 * app.use("/api/chat", shieldMiddleware({
 *   shield: { injection: { strictness: "high" } },
 *   skipPaths: ["/api/chat/health"],
 * }));
 *
 * // Access scan result in route handler
 * app.post("/api/chat", (req, res) => {
 *   const shieldResult = res.locals.shieldResult as ScanResult;
 *   // shieldResult.sanitized has PII masked
 * });
 * ```
 */
export function shieldMiddleware(config: ShieldMiddlewareConfig = {}): ExpressMiddleware {
  return (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
    // Skip non-mutating methods
    if (req.method === "GET" || req.method === "HEAD" || req.method === "OPTIONS") {
      return next();
    }

    // Skip configured paths. Check the full path (what the docs promise) and the
    // mount-relative path (what earlier versions compared against).
    const fullPath = `${req.baseUrl ?? ""}${req.path}`;
    if (config.skipPaths?.some((p) => isUnderPath(fullPath, p) || isUnderPath(req.path, p))) {
      return next();
    }

    const getInput = config.getInput ?? defaultGetInput;
    const input = getInput(req.body);

    // No scannable content — pass through
    if (!input) {
      return next();
    }

    // Shallow-copy: getContext may hand back a shared object, and writing
    // agentId into it would leak one request's agent into concurrent requests.
    const context = { ...(config.getContext?.(req) ?? {}) };
    if (config.getAgentId) {
      context.agentId = config.getAgentId(req);
    }

    // Async scan
    void scanRequest(config, input, context)
      .then(({ blocked, result, response }) => {
        if (blocked && response) {
          res.status(response.status).json(response.body);
          return;
        }

        // Attach result to res.locals for downstream handlers
        res.locals.shieldResult = result;
        next();
      })
      .catch((err: unknown) => {
        next(err);
      });
  };
}
|
|
93
|
+
|
|
94
|
+
export type { ShieldMiddlewareConfig, ScanResult };
|
package/src/hono.ts
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import type { ScanResult } from "ai-shield-core";
|
|
2
|
+
import type { ShieldMiddlewareConfig } from "./shared.js";
|
|
3
|
+
import { defaultGetInput, scanRequest } from "./shared.js";
|
|
4
|
+
|
|
5
|
+
// ============================================================
|
|
6
|
+
// Hono Middleware — AI Shield route guard
|
|
7
|
+
// ============================================================
|
|
8
|
+
|
|
9
|
+
// Minimal Hono surface used by the middleware, declared locally so the package
// does not need a hard type dependency on `hono`.

/** The parts of Hono's `c.req` object that the middleware reads. */
interface HonoRequestLike {
  method: string;
  path: string;
  url: string;
  header(name: string): string | undefined;
  json(): Promise<unknown>;
  raw: { headers: Headers };
}

/** The parts of Hono's `Context` that the middleware uses. */
interface HonoContext {
  req: HonoRequestLike;
  json(data: unknown, status?: number): Response;
  set(key: string, value: unknown): void;
  get(key: string): unknown;
}

/** Hono's `next`: a promise that settles once downstream handlers have run. */
type HonoNext = () => Promise<void>;

/** Signature of the handler produced by `shieldMiddleware`. */
type HonoMiddleware = (c: HonoContext, next: HonoNext) => Promise<Response | void>;
|
|
26
|
+
|
|
27
|
+
/**
 * True when `path` is `prefix` itself or lies beneath it on a `/` boundary:
 * `"/api/chat/health"` covers `"/api/chat/health"` and `"/api/chat/health/live"`
 * but not `"/api/chat/healthz"`, so a look-alike path cannot use a skip rule to
 * bypass the scanner.
 */
function isUnderPath(path: string, prefix: string): boolean {
  if (path === prefix) return true;
  const base = prefix.endsWith("/") ? prefix : `${prefix}/`;
  return path.startsWith(base);
}

/**
 * Hono middleware that scans the request body for prompt injection and PII
 * before the route handler runs.
 *
 * Per request:
 * - `GET`, `HEAD` and `OPTIONS` requests pass straight through.
 * - Requests whose `c.req.path` is listed in `config.skipPaths` (exact match or
 *   a sub-path on a `/` boundary) pass through.
 * - The body is read with `c.req.json()`; if that rejects (e.g. empty or
 *   non-JSON body) the request passes through.
 * - Text is extracted with `config.getInput` (default `defaultGetInput`); when
 *   there is none the request passes through.
 * - Otherwise the text is scanned. A blocked request is answered with
 *   `c.json(body, status)` built by `config.onBlocked` (default
 *   `defaultBlockedResponse`) and `next` is not called; an allowed request gets
 *   the result via `c.set("shieldResult", result)`. Scan failures propagate as
 *   a rejected promise.
 *
 * @param config - Middleware options; see `ShieldMiddlewareConfig`.
 * @returns A Hono `(c, next)` middleware.
 *
 * @example
 * ```ts
 * import { Hono } from "hono";
 * import { shieldMiddleware } from "ai-shield-middleware/hono";
 * import type { ScanResult } from "ai-shield-middleware/hono";
 *
 * const app = new Hono();
 *
 * // Protect chat routes
 * app.use("/api/chat/*", shieldMiddleware({
 *   shield: { injection: { strictness: "high" } },
 * }));
 *
 * app.post("/api/chat", async (c) => {
 *   const shieldResult = c.get("shieldResult") as ScanResult;
 *   // Use shieldResult.sanitized for PII-masked input
 * });
 * ```
 */
export function shieldMiddleware(config: ShieldMiddlewareConfig = {}): HonoMiddleware {
  return async (c: HonoContext, next: HonoNext): Promise<Response | void> => {
    // Skip non-mutating methods
    if (c.req.method === "GET" || c.req.method === "HEAD" || c.req.method === "OPTIONS") {
      return next();
    }

    // Skip configured paths (exact match or sub-path on a "/" boundary).
    const path = c.req.path;
    if (config.skipPaths?.some((p) => isUnderPath(path, p))) {
      return next();
    }

    // Parse body.
    // NOTE(review): relies on Hono caching the parsed body so downstream
    // `c.req.json()` still works after this read — confirm for supported versions.
    let body: unknown;
    try {
      body = await c.req.json();
    } catch {
      // No JSON body — pass through
      return next();
    }

    const getInput = config.getInput ?? defaultGetInput;
    const input = getInput(body);

    if (!input) {
      return next();
    }

    // Build context from headers
    const headers: Record<string, string | string[] | undefined> = {};
    c.req.raw.headers.forEach((value, key) => {
      headers[key] = value;
    });

    // Shallow-copy: getContext may hand back a shared object, and writing
    // agentId into it would leak one request's agent into concurrent requests.
    const context = { ...(config.getContext?.({ headers, body }) ?? {}) };
    if (config.getAgentId) {
      context.agentId = config.getAgentId({ headers, path, url: c.req.url });
    }

    const { blocked, result, response } = await scanRequest(config, input, context);

    if (blocked && response) {
      return c.json(response.body, response.status);
    }

    // Attach result for downstream handlers
    c.set("shieldResult", result);
    return next();
  };
}
|
|
98
|
+
|
|
99
|
+
export type { ShieldMiddlewareConfig, ScanResult };
|
package/src/index.ts
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
// ============================================================
|
|
2
|
+
// ai-shield-middleware — Public API
|
|
3
|
+
// ============================================================
|
|
4
|
+
|
|
5
|
+
// Shared types and utilities
|
|
6
|
+
export { defaultGetInput, defaultBlockedResponse } from "./shared.js";
|
|
7
|
+
export type { ShieldMiddlewareConfig } from "./shared.js";
|
|
8
|
+
|
|
9
|
+
// Framework-specific exports
|
|
10
|
+
export { shieldMiddleware as expressShield } from "./express.js";
|
|
11
|
+
export { shieldMiddleware as honoShield } from "./hono.js";
|
|
12
|
+
|
|
13
|
+
export type { ScanResult } from "ai-shield-core";
|
package/src/shared.ts
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import type { AIShield, ShieldConfig, ScanContext, ScanResult } from "ai-shield-core";
|
|
2
|
+
|
|
3
|
+
// ============================================================
|
|
4
|
+
// Shared middleware logic — used by Express and Hono adapters
|
|
5
|
+
// ============================================================
|
|
6
|
+
|
|
7
|
+
/** Header map handed to the `getAgentId` / `getContext` callbacks. */
type RequestHeaders = Record<string, string | string[] | undefined>;

/**
 * Options for the Express and Hono `shieldMiddleware` factories.
 *
 * Every field is optional. With no options the middleware constructs an
 * AIShield from an empty `ShieldConfig`, extracts text with `defaultGetInput`,
 * and answers blocked requests with `defaultBlockedResponse`.
 */
export interface ShieldMiddlewareConfig {
  // --- Shield instance ---------------------------------------------------

  /** Settings used to construct the AIShield when `shieldInstance` is not supplied. */
  shield?: ShieldConfig;

  /**
   * Ready-made AIShield to scan with instead of constructing one. Pass the same
   * instance to several middlewares to share it across routes.
   */
  shieldInstance?: AIShield;

  // --- Request extraction ------------------------------------------------

  /**
   * Pulls the text to scan out of the parsed request body. Returning `null`
   * (or an empty string) lets the request through unscanned.
   * Defaults to `defaultGetInput`.
   */
  getInput?: (body: unknown) => string | null;

  /**
   * Builds the scan context for a request. The Express adapter passes its
   * request object; the Hono adapter passes the collected headers and parsed body.
   */
  getContext?: (req: { headers: RequestHeaders; body?: unknown }) => ScanContext;

  /**
   * Derives the value stored as `agentId` on the scan context. The Express
   * adapter passes its request object; the Hono adapter passes the collected
   * headers together with the request path and URL.
   */
  getAgentId?: (req: { headers: RequestHeaders; path?: string; url?: string }) => string | undefined;

  /** Request paths that bypass scanning (matched as path prefixes by the adapters). */
  skipPaths?: string[];

  // --- Outcome hooks -----------------------------------------------------

  /** Produces the HTTP status and body sent for a blocked request. Defaults to `defaultBlockedResponse`. */
  onBlocked?: (result: ScanResult) => { status: number; body: unknown };

  /** Invoked when a scan's decision is `"warn"`; the request still proceeds. */
  onWarning?: (result: ScanResult) => void;
}
|
|
25
|
+
|
|
26
|
+
/** Scalar body fields checked, in priority order, before falling back to `messages`. */
const DIRECT_INPUT_FIELDS = ["message", "input", "prompt", "text", "content", "query"] as const;

/**
 * Default input extractor: pulls the user-supplied text out of common chat API
 * body shapes.
 *
 * Returns the first string found among the fields `message`, `input`, `prompt`,
 * `text`, `content`, `query` (in that order). Otherwise, for an OpenAI-style
 * `messages` array, returns the text of every `role: "user"` message joined with
 * `"\n"`. A message's `content` may be a plain string or an array of content
 * parts; every part carrying a string `text` is included, so array-form
 * payloads cannot bypass the scanner. Non-object entries in `messages` are
 * ignored rather than throwing, since the body is untrusted.
 *
 * @param body - Parsed request body (untrusted; any shape).
 * @returns The text to scan, or `null` when no scannable text was found.
 */
export function defaultGetInput(body: unknown): string | null {
  if (!body || typeof body !== "object") return null;

  const obj = body as Record<string, unknown>;

  // Direct message field
  for (const field of DIRECT_INPUT_FIELDS) {
    const value = obj[field];
    if (typeof value === "string") return value;
  }

  // OpenAI-style messages array
  if (Array.isArray(obj.messages)) {
    const userTexts: string[] = [];
    for (const entry of obj.messages as unknown[]) {
      // Untrusted JSON: skip null / primitive entries instead of crashing on `.role`.
      if (!entry || typeof entry !== "object") continue;
      const { role, content } = entry as { role?: unknown; content?: unknown };
      if (role !== "user") continue;
      const text = extractContentText(content);
      if (text !== null) userTexts.push(text);
    }
    if (userTexts.length > 0) return userTexts.join("\n");
  }

  return null;
}

/**
 * Text of a chat message `content`: the string itself, or the `text` of each
 * content part (joined with `"\n"`) when it is an array. `null` if there is none.
 */
function extractContentText(content: unknown): string | null {
  if (typeof content === "string") return content;
  if (!Array.isArray(content)) return null;

  const parts: string[] = [];
  for (const part of content as unknown[]) {
    if (part && typeof part === "object") {
      const text = (part as { text?: unknown }).text;
      if (typeof text === "string") parts.push(text);
    }
  }
  return parts.length > 0 ? parts.join("\n") : null;
}
|
|
50
|
+
|
|
51
|
+
/**
 * Default `onBlocked` handler: HTTP 403 with a body naming the scan decision and
 * listing each violation's `type` and `message` (other violation fields are
 * deliberately left out of the response).
 *
 * @param result - The scan result being rejected.
 * @returns Status and body for the framework adapter to send.
 */
export function defaultBlockedResponse(result: ScanResult): { status: number; body: unknown } {
  const { decision, violations } = result;
  return {
    status: 403,
    body: {
      error: "Request blocked by AI Shield",
      decision,
      violations: violations.map(({ type, message }) => ({ type, message })),
    },
  };
}
|
|
65
|
+
|
|
66
|
+
/**
 * Lazily constructed AIShield instances, one per middleware config object.
 *
 * Keying by config means each `shieldMiddleware(...)` call scans with a shield
 * built from its own `shield` options. A single module-wide instance would
 * silently apply the first middleware's options to every route. All requests
 * through the same middleware still reuse one instance. The WeakMap lets an
 * entry be garbage-collected together with its config.
 */
const shieldCache = new WeakMap<ShieldMiddlewareConfig, Promise<AIShield>>();

/**
 * Returns the AIShield to scan with for `config`.
 *
 * `config.shieldInstance` wins when set. Otherwise `ai-shield-core` is imported
 * on first use and an AIShield is constructed from `config.shield` (or `{}`).
 * Concurrent first requests share the same pending promise. A failed import or
 * construction is not cached, so the next request retries.
 *
 * @param config - Middleware options; also the cache key.
 * @returns The shield instance.
 * @throws Rejects if `ai-shield-core` cannot be imported or `AIShield` construction throws.
 */
export async function getOrCreateShield(config: ShieldMiddlewareConfig): Promise<AIShield> {
  if (config.shieldInstance) return config.shieldInstance;

  const cached = shieldCache.get(config);
  if (cached) return cached;

  const created = import("ai-shield-core").then((mod) => new mod.AIShield(config.shield ?? {}));
  shieldCache.set(config, created);

  // Evict on failure so a transient error doesn't break this middleware forever.
  // Callers still observe the rejection through `created`.
  void created.catch(() => {
    if (shieldCache.get(config) === created) shieldCache.delete(config);
  });

  return created;
}
|
|
82
|
+
|
|
83
|
+
/**
 * Scan step shared by the Express and Hono adapters.
 *
 * Resolves the shield via `getOrCreateShield`, scans `input` with `context`,
 * then acts on the decision:
 * - `"block"`: returns `blocked: true` plus the status/body from
 *   `config.onBlocked` (default `defaultBlockedResponse`).
 * - `"warn"`: calls `config.onWarning` (if set), then proceeds as allowed.
 * - anything else: returns `blocked: false` with no `response` key.
 *
 * @param config - Middleware options.
 * @param input - Text extracted from the request.
 * @param context - Scan context built by the adapter.
 * @returns The scan result, whether the request is blocked, and the response to send when it is.
 */
export async function scanRequest(
  config: ShieldMiddlewareConfig,
  input: string,
  context: ScanContext,
): Promise<{ blocked: boolean; result: ScanResult; response?: { status: number; body: unknown } }> {
  const shieldForConfig = await getOrCreateShield(config);
  const result = await shieldForConfig.scan(input, context);

  switch (result.decision) {
    case "block": {
      const buildBlockedResponse = config.onBlocked ?? defaultBlockedResponse;
      return { blocked: true, result, response: buildBlockedResponse(result) };
    }
    case "warn":
      config.onWarning?.(result);
      break;
    default:
      break;
  }

  return { blocked: false, result };
}
|