ai-shield-middleware 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "name": "ai-shield-middleware",
3
+ "version": "0.1.0",
4
+ "description": "AI Shield middleware for Express and Hono — route-level LLM protection",
5
+ "type": "module",
6
+ "main": "./dist/index.js",
7
+ "types": "./dist/index.d.ts",
8
+ "exports": {
9
+ ".": {
10
+ "types": "./dist/index.d.ts",
11
+ "import": "./dist/index.js"
12
+ },
13
+ "./express": {
14
+ "types": "./dist/express.d.ts",
15
+ "import": "./dist/express.js"
16
+ },
17
+ "./hono": {
18
+ "types": "./dist/hono.d.ts",
19
+ "import": "./dist/hono.js"
20
+ }
21
+ },
22
+ "scripts": {
23
+ "build": "tsc",
24
+ "typecheck": "tsc --noEmit"
25
+ },
26
+ "peerDependencies": {
27
+ "express": ">=4.0.0",
28
+ "hono": ">=4.0.0"
29
+ },
30
+ "peerDependenciesMeta": {
31
+ "express": { "optional": true },
32
+ "hono": { "optional": true }
33
+ },
34
+ "dependencies": {
35
+ "ai-shield-core": "0.1.0"
36
+ },
37
+ "devDependencies": {
38
+ "@types/express": "^5.0.0",
39
+ "express": "^4.21.0",
40
+ "hono": "^4.7.0",
41
+ "typescript": "^5.7.0"
42
+ }
43
+ }
package/src/express.ts ADDED
@@ -0,0 +1,94 @@
1
+ import type { ScanResult } from "ai-shield-core";
2
+ import type { ShieldMiddlewareConfig } from "./shared.js";
3
+ import { defaultGetInput, scanRequest } from "./shared.js";
4
+
5
+ // ============================================================
6
+ // Express Middleware — AI Shield route guard
7
+ // ============================================================
8
+
9
// Minimal Express types (avoid hard dependency on @types/express)
interface ExpressRequest {
  body?: unknown;
  /** Request path; inside `app.use("/prefix", ...)` Express reports it relative to the mount point. */
  path: string;
  url: string;
  method: string;
  headers: Record<string, string | string[] | undefined>;
  /** Mount point of the router/app that dispatched to this middleware ("" when mounted at the root). */
  baseUrl?: string;
}

interface ExpressResponse {
  status(code: number): ExpressResponse;
  json(body: unknown): void;
  locals: Record<string, unknown>;
}

type NextFunction = (err?: unknown) => void;
type ExpressMiddleware = (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => void;

/** Methods treated as non-mutating: requests using them are never scanned. */
const PASS_THROUGH_METHODS: ReadonlySet<string> = new Set(["GET", "HEAD", "OPTIONS"]);

/** Trailing slashes stripped from skip entries before matching. */
const TRAILING_SLASHES = /\/+$/;

/**
 * Reports whether `path` is covered by one of the skip entries.
 *
 * Matching is done on whole path segments: "/api/chat/health" covers "/api/chat/health"
 * and "/api/chat/health/live" but NOT "/api/chat/healthz". A bare `startsWith` would let
 * any route that merely shares a string prefix with an entry bypass scanning.
 * Trailing slashes on entries are ignored, so "" and "/" cover every path.
 */
function matchesSkipPath(path: string, prefixes: readonly string[]): boolean {
  return prefixes.some((prefix) => {
    const base = prefix.replace(TRAILING_SLASHES, "");
    return base === "" || path === base || path.startsWith(`${base}/`);
  });
}

/**
 * Express middleware that scans request body for prompt injection and PII.
 *
 * A request passes straight through (`next()`) without scanning when:
 * - its method is GET, HEAD or OPTIONS;
 * - its path is covered by `config.skipPaths`. Entries are compared both with the
 *   mount-relative `req.path` and with the full `req.baseUrl + req.path`, so they work
 *   whether written relative to the mount point or as absolute paths;
 * - `config.getInput` (default `defaultGetInput`) yields no text for `req.body`.
 *
 * Otherwise the extracted text is scanned:
 * - blocked: responds with the `onBlocked`/default status and JSON body, and does not call `next`;
 * - allowed: stores the scan result on `res.locals.shieldResult`, then calls `next()`;
 * - scan failure: forwards the error via `next(err)`.
 *
 * The body must already be parsed (e.g. by `express.json()`); an unparsed body gives
 * `getInput` nothing to extract.
 *
 * @param config - Middleware options; see `ShieldMiddlewareConfig`.
 * @returns An Express-compatible `(req, res, next)` middleware.
 *
 * @example
 * ```ts
 * import express from "express";
 * import { shieldMiddleware, type ScanResult } from "ai-shield-middleware/express";
 *
 * const app = express();
 * app.use(express.json());
 *
 * // Protect all /api/chat routes
 * app.use("/api/chat", shieldMiddleware({
 *   shield: { injection: { strictness: "high" } },
 *   skipPaths: ["/api/chat/health"],
 * }));
 *
 * // Access scan result in route handler
 * app.post("/api/chat", (req, res) => {
 *   const shieldResult = res.locals.shieldResult as ScanResult;
 *   // shieldResult.sanitized has PII masked
 * });
 * ```
 */
export function shieldMiddleware(config: ShieldMiddlewareConfig = {}): ExpressMiddleware {
  return (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
    // Skip non-mutating methods
    if (PASS_THROUGH_METHODS.has(req.method)) {
      return next();
    }

    // Skip configured paths. The mount-relative check keeps the old behaviour for
    // entries written relative to the mount point. The full-path check makes absolute
    // entries (as in the example above) work under app.use("/prefix", ...).
    const skipPaths = config.skipPaths;
    if (
      skipPaths &&
      (matchesSkipPath(req.path, skipPaths) ||
        matchesSkipPath(`${req.baseUrl ?? ""}${req.path}`, skipPaths))
    ) {
      return next();
    }

    const getInput = config.getInput ?? defaultGetInput;
    const input = getInput(req.body);

    // No scannable content — pass through
    if (!input) {
      return next();
    }

    // Copy rather than writing agentId onto whatever getContext returned: that object
    // belongs to the caller and may be shared between requests or frozen.
    const baseContext = config.getContext?.(req) ?? {};
    const context = config.getAgentId
      ? { ...baseContext, agentId: config.getAgentId(req) }
      : baseContext;

    // Async scan: Express does not await middleware, so settle via callbacks.
    void scanRequest(config, input, context)
      .then(({ blocked, result, response }) => {
        if (blocked && response) {
          res.status(response.status).json(response.body);
          return;
        }

        // Attach result to res.locals for downstream handlers
        res.locals.shieldResult = result;
        next();
      })
      .catch((err: unknown) => {
        next(err);
      });
  };
}
93
+
94
+ export type { ShieldMiddlewareConfig, ScanResult };
package/src/hono.ts ADDED
@@ -0,0 +1,99 @@
1
+ import type { ScanResult } from "ai-shield-core";
2
+ import type { ShieldMiddlewareConfig } from "./shared.js";
3
+ import { defaultGetInput, scanRequest } from "./shared.js";
4
+
5
+ // ============================================================
6
+ // Hono Middleware — AI Shield route guard
7
+ // ============================================================
8
+
9
// Minimal Hono types (avoid hard dependency)
interface HonoContext {
  req: {
    method: string;
    path: string;
    url: string;
    header(name: string): string | undefined;
    json(): Promise<unknown>;
    raw: { headers: Headers };
  };
  json(data: unknown, status?: number): Response;
  set(key: string, value: unknown): void;
  get(key: string): unknown;
}

type HonoNext = () => Promise<void>;
type HonoMiddleware = (c: HonoContext, next: HonoNext) => Promise<Response | void>;

/** Methods treated as non-mutating: requests using them are never scanned. */
const PASS_THROUGH_METHODS: ReadonlySet<string> = new Set(["GET", "HEAD", "OPTIONS"]);

/** Trailing slashes stripped from skip entries before matching. */
const TRAILING_SLASHES = /\/+$/;

/**
 * Reports whether `path` is covered by one of the skip entries.
 *
 * Matching is done on whole path segments: "/api/chat/health" covers "/api/chat/health"
 * and "/api/chat/health/live" but NOT "/api/chat/healthz". A bare `startsWith` would let
 * any route that merely shares a string prefix with an entry bypass scanning.
 * Trailing slashes on entries are ignored, so "" and "/" cover every path.
 */
function matchesSkipPath(path: string, prefixes: readonly string[]): boolean {
  return prefixes.some((prefix) => {
    const base = prefix.replace(TRAILING_SLASHES, "");
    return base === "" || path === base || path.startsWith(`${base}/`);
  });
}

/**
 * Hono middleware that scans request body for prompt injection and PII.
 *
 * A request passes straight through (`next()`) without scanning when:
 * - its method is GET, HEAD or OPTIONS;
 * - `c.req.path` is covered by `config.skipPaths`;
 * - `c.req.json()` throws (no parsable JSON body);
 * - `config.getInput` (default `defaultGetInput`) yields no text for the body.
 *
 * Otherwise the extracted text is scanned:
 * - blocked: returns `c.json(body, status)` from `onBlocked` or the default, without calling `next`;
 * - allowed: stores the result under `c.set("shieldResult", …)` and continues with `next()`.
 *
 * Scan errors are not caught here; they reject the returned promise.
 *
 * @param config - Middleware options; see `ShieldMiddlewareConfig`.
 * @returns A Hono-compatible async `(c, next)` middleware.
 *
 * @example
 * ```ts
 * import { Hono } from "hono";
 * import { shieldMiddleware, type ScanResult } from "ai-shield-middleware/hono";
 *
 * const app = new Hono();
 *
 * // Protect chat routes
 * app.use("/api/chat/*", shieldMiddleware({
 *   shield: { injection: { strictness: "high" } },
 * }));
 *
 * app.post("/api/chat", async (c) => {
 *   const shieldResult = c.get("shieldResult") as ScanResult;
 *   // Use shieldResult.sanitized for PII-masked input
 * });
 * ```
 */
export function shieldMiddleware(config: ShieldMiddlewareConfig = {}): HonoMiddleware {
  return async (c: HonoContext, next: HonoNext): Promise<Response | void> => {
    // Skip non-mutating methods
    if (PASS_THROUGH_METHODS.has(c.req.method)) {
      return next();
    }

    // Skip configured paths (segment-aware, see matchesSkipPath)
    if (config.skipPaths && matchesSkipPath(c.req.path, config.skipPaths)) {
      return next();
    }

    // Parse body
    let body: unknown;
    try {
      body = await c.req.json();
    } catch {
      // No JSON body — pass through
      return next();
    }

    const getInput = config.getInput ?? defaultGetInput;
    const input = getInput(body);

    if (!input) {
      return next();
    }

    // Build a plain header bag for the request-inspecting callbacks
    const headers: Record<string, string | string[] | undefined> = {};
    c.req.raw.headers.forEach((value, key) => {
      headers[key] = value;
    });

    // Copy rather than writing agentId onto whatever getContext returned: that object
    // belongs to the caller and may be shared between requests or frozen.
    const baseContext = config.getContext?.({ headers, body }) ?? {};
    const context = config.getAgentId
      ? { ...baseContext, agentId: config.getAgentId({ headers, path: c.req.path, url: c.req.url }) }
      : baseContext;

    const { blocked, result, response } = await scanRequest(config, input, context);

    if (blocked && response) {
      return c.json(response.body, response.status);
    }

    // Attach result for downstream handlers
    c.set("shieldResult", result);
    return next();
  };
}
98
+
99
+ export type { ShieldMiddlewareConfig, ScanResult };
package/src/index.ts ADDED
@@ -0,0 +1,13 @@
1
+ // ============================================================
2
+ // ai-shield-middleware — Public API
3
+ // ============================================================
4
+
5
+ // Shared types and utilities
6
+ export { defaultGetInput, defaultBlockedResponse } from "./shared.js";
7
+ export type { ShieldMiddlewareConfig } from "./shared.js";
8
+
9
+ // Framework-specific exports
10
+ export { shieldMiddleware as expressShield } from "./express.js";
11
+ export { shieldMiddleware as honoShield } from "./hono.js";
12
+
13
+ export type { ScanResult } from "ai-shield-core";
package/src/shared.ts ADDED
@@ -0,0 +1,102 @@
1
+ import type { AIShield, ShieldConfig, ScanContext, ScanResult } from "ai-shield-core";
2
+
3
+ // ============================================================
4
+ // Shared middleware logic — used by Express and Hono adapters
5
+ // ============================================================
6
+
7
/** Header bag handed to the request-inspecting callbacks (header name → value or values). */
type RequestHeaders = Record<string, string | string[] | undefined>;

/**
 * Options shared by the Express and Hono `shieldMiddleware` factories.
 * Every field is optional; `{}` scans with a default-configured AIShield.
 */
export interface ShieldMiddlewareConfig {
  /** Settings used to construct an AIShield instance when `shieldInstance` is not supplied. */
  shield?: ShieldConfig;

  /**
   * Ready-made AIShield instance used instead of constructing one.
   * Pass the same instance to several middlewares to share it across routes.
   */
  shieldInstance?: AIShield;

  /** Derives an agent ID from the request; the adapters set it as `agentId` on the scan context. */
  getAgentId?: (req: { headers: RequestHeaders; path?: string; url?: string }) => string | undefined;

  /** Builds the base scan context for a request (before any `agentId` is applied). */
  getContext?: (req: { headers: RequestHeaders; body?: unknown }) => ScanContext;

  /**
   * Pulls the text to scan out of the parsed request body.
   * Returning `null` (or an empty string) lets the request through unscanned.
   * Defaults to `defaultGetInput`.
   */
  getInput?: (body: unknown) => string | null;

  /** Builds the HTTP status and JSON body sent when the scan decision is "block". */
  onBlocked?: (result: ScanResult) => { status: number; body: unknown };

  /** Notified when the scan decision is "warn"; the request still proceeds. */
  onWarning?: (result: ScanResult) => void;

  /** Path prefixes for which scanning is skipped entirely. */
  skipPaths?: string[];
}
25
+
26
/** Top-level body fields checked, in order, for a directly supplied prompt string. */
const DIRECT_INPUT_FIELDS = ["message", "input", "prompt", "text", "content", "query"] as const;

/** Narrowing helper: any non-null object (arrays included), readable by key. */
function isRecord(value: unknown): value is Record<string, unknown> {
  return typeof value === "object" && value !== null;
}

/**
 * Returns the text of one chat message's `content`.
 * - A plain string is returned as-is.
 * - For an array of content parts (OpenAI/Anthropic style), every part carrying a string
 *   `text` is collected and joined with "\n".
 * Returns null when there is no text.
 */
function messageContentText(content: unknown): string | null {
  if (typeof content === "string") return content;
  if (!Array.isArray(content)) return null;

  const parts: unknown[] = content;
  const texts: string[] = [];
  for (const part of parts) {
    // Accept any part that carries a string `text` regardless of its `type` tag
    // ("text", "input_text", …). Unknown text-bearing part types must not slip past the scan.
    if (isRecord(part) && typeof part.text === "string") texts.push(part.text);
  }
  return texts.length > 0 ? texts.join("\n") : null;
}

/**
 * Default: extract the text to scan from common chat API body shapes.
 *
 * Checked in order:
 * 1. The first non-empty string among the top-level fields `message`, `input`, `prompt`,
 *    `text`, `content`, `query`. Empty strings are skipped, so a decoy such as
 *    `{ message: "", messages: [...] }` cannot stop extraction and leave `messages` unscanned.
 * 2. An OpenAI-style `messages` array. The content of every `role: "user"` entry is
 *    collected, whether it is a string or an array of text parts, and the pieces are joined
 *    with "\n". Non-object entries (e.g. `null`) are ignored rather than throwing.
 *
 * @param body - Parsed request body (untrusted).
 * @returns The text to scan, or `null` when nothing scannable was found. The adapters
 *   treat `null` as "pass the request through unscanned".
 */
export function defaultGetInput(body: unknown): string | null {
  if (!isRecord(body)) return null;

  // Direct message field
  for (const field of DIRECT_INPUT_FIELDS) {
    const value = body[field];
    if (typeof value === "string" && value !== "") return value;
  }

  // OpenAI-style messages array
  if (Array.isArray(body.messages)) {
    const messages: unknown[] = body.messages;
    const userTexts: string[] = [];
    for (const message of messages) {
      if (!isRecord(message) || message.role !== "user") continue;
      const text = messageContentText(message.content);
      if (text !== null) userTexts.push(text);
    }
    if (userTexts.length > 0) return userTexts.join("\n");
  }

  return null;
}
50
+
51
/**
 * Fallback used when no `onBlocked` is configured.
 * Answers HTTP 403 with a JSON body naming the decision and listing each violation's
 * `type` and `message`; any other violation fields are omitted from the response.
 */
export function defaultBlockedResponse(result: ScanResult): { status: number; body: unknown } {
  const { decision } = result;
  const violations = result.violations.map(({ type, message }) => ({ type, message }));

  const body = {
    error: "Request blocked by AI Shield",
    decision,
    violations,
  };
  return { status: 403, body };
}
65
+
66
/**
 * Lazily constructed AIShield instances, keyed by the middleware config that asked for them.
 *
 * Keying by config object, instead of one module-wide singleton, means two middlewares
 * built with different `shield` settings each get an instance configured as requested.
 * Every request through the same middleware still reuses one instance. The WeakMap
 * lets an entry be collected together with its config.
 */
const shieldCache = new WeakMap<ShieldMiddlewareConfig, Promise<AIShield>>();

/**
 * Returns the AIShield instance to use for `config`.
 *
 * - `config.shieldInstance`, when set, is returned as-is. Pass one instance to several
 *   middlewares to share it across routes.
 * - Otherwise `ai-shield-core` is loaded via dynamic import on first use, and an instance
 *   is built from `config.shield` (or `{}`). The pending construction is cached per config
 *   object, so concurrent first calls share it.
 * - If loading or construction fails, the cache entry is dropped so a later call retries.
 *   Without this, every subsequent request would re-throw the same stored rejection.
 *
 * @param config - Middleware options whose `shieldInstance`/`shield` select the instance.
 * @returns The shared instance for this config.
 */
export async function getOrCreateShield(config: ShieldMiddlewareConfig): Promise<AIShield> {
  if (config.shieldInstance) return config.shieldInstance;

  const cached = shieldCache.get(config);
  if (cached) return cached;

  const created = import("ai-shield-core").then((mod) => new mod.AIShield(config.shield ?? {}));
  shieldCache.set(config, created);

  // Evict on failure (only if the entry is still ours). This branch swallows the
  // rejection for itself only; callers awaiting `created` still see the error.
  void created.catch(() => {
    if (shieldCache.get(config) === created) shieldCache.delete(config);
  });

  return created;
}
82
+
83
/**
 * Runs one scan for a request. This is the logic shared by the Express and Hono adapters.
 *
 * Resolves the AIShield instance for `config` and scans `input` with `context`. Then:
 * - "block": returns `blocked: true` with the response built by `config.onBlocked`
 *   (or `defaultBlockedResponse`);
 * - "warn": notifies `config.onWarning` and returns `blocked: false`;
 * - any other decision: returns `blocked: false`.
 *
 * The raw scan result is always included. Errors from instance creation, the scan, or
 * the callbacks propagate to the caller.
 */
export async function scanRequest(
  config: ShieldMiddlewareConfig,
  input: string,
  context: ScanContext,
): Promise<{ blocked: boolean; result: ScanResult; response?: { status: number; body: unknown } }> {
  const guard = await getOrCreateShield(config);
  const result = await guard.scan(input, context);

  switch (result.decision) {
    case "block": {
      const buildResponse = config.onBlocked ?? defaultBlockedResponse;
      return { blocked: true, result, response: buildResponse(result) };
    }
    case "warn":
      config.onWarning?.(result);
      break;
  }

  return { blocked: false, result };
}
package/tsconfig.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "extends": "../../tsconfig.json",
3
+ "compilerOptions": {
4
+ "outDir": "./dist",
5
+ "rootDir": "./src"
6
+ },
7
+ "include": ["src/**/*"]
8
+ }