agent-sql 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md ADDED
@@ -0,0 +1,99 @@
1
+ # agent-sql
2
+
3
+ Sanitise agent-written SQL for multi-tenant DBs.
4
+
5
+ You provide a tenant ID, and the agent supplies the query.
6
+
7
+ agent-sql works by fully parsing the supplied SQL query into an AST.
8
+ The grammar ONLY accepts `SELECT` statements. Anything else is an error.
9
+ CTEs and other complex things that we aren't confident of securing: error.
10
+
11
+ Then we ensure that the needed tenant table is somewhere in the query,
12
+ and add a `WHERE` clause ensuring that only rows belonging to the supplied tenant ID are returned.
13
+
14
+ Apparently this is how [Trigger.dev does it](https://x.com/mattaitken/status/2033928542975639785). And [Cloudflare](https://x.com/thomas_ankcorn/status/2033931057133748330).
15
+
16
+ ## Quickstart
17
+
18
+ ```bash
19
+ npm install agent-sql
20
+ ```
21
+
22
+ ```ts
23
+ import { sanitise } from "agent-sql";
24
+
25
+ const sql = sanitise(`SELECT id, name FROM users WHERE status = 'active' LIMIT 10`, {
26
+ tables: { users: {} },
27
+ where: { table: "users", col: "tenant_id", value: "acme" },
28
+ });
29
+
30
+ console.log(sql);
31
+ // SELECT id, name
32
+ // FROM users
33
+ // WHERE (users.tenant_id = 'acme' AND status = 'active')
34
+ // LIMIT 10
35
+ ```
36
+
37
+ Or, more usefully:
38
+
39
+ ```ts
40
+ import { sanitiserFactory } from "agent-sql";
41
+ import { tool } from "ai";
42
+ import { sql } from "drizzle-orm";
43
+ import { db } from "@/db";
44
+
45
+ function makeSqlTool(orgId: string) {
46
+ // Create a sanitiser function for this tenant
47
+ const sanitise = sanitiserFactory({
48
+ tables: {},
49
+ where: { table: "org", col: "id", value: orgId },
50
+ });
51
+
52
+ return tool({
53
+ description: "Run raw SQL against the DB",
54
+ inputSchema: z.object({ query: z.string() }), // `z` is not imported above; it is zod's export: import { z } from "zod"
55
+ execute: async ({ query }) => {
56
+ // The LLM can pass any query it likes, we'll sanitise it if possible
57
+ // and return helpful error messages if not
58
+ const sanitised = sanitise(query);
59
+ // Now we can throw that straight at the db and be confident it'll only
60
+ // return data from the specified tenant
61
+ return db.execute(sql.raw(sanitised));
62
+ },
63
+ });
64
+ }
65
+ ```
66
+
67
+ `agent-sql` parses the SQL, enforces a mandatory equality filter on the given column as the outermost `AND` condition (so it cannot be short-circuited by agent-supplied `OR` clauses), and returns the sanitised SQL string.
68
+
69
+ ## Development
70
+
71
+ First install [Vite+](https://viteplus.dev/guide/):
72
+
73
+ ```bash
74
+ curl -fsSL https://vite.plus | bash
75
+ ```
76
+
77
+ Install dependencies:
78
+
79
+ ```bash
80
+ vp install
81
+ ```
82
+
83
+ Format, lint, typecheck:
84
+
85
+ ```bash
86
+ vp check --fix
87
+ ```
88
+
89
+ Run the unit tests:
90
+
91
+ ```bash
92
+ vp test
93
+ ```
94
+
95
+ Build the library:
96
+
97
+ ```bash
98
+ vp pack
99
+ ```
@@ -0,0 +1,322 @@
1
+ //#region src/ast.d.ts
2
+ interface SelectStatement { // Root AST node for a single SELECT query; this is what parseSql yields and what sanitiseSql/outputSql consume.
3
+ readonly type: "select";
4
+ distinct: Distinct | DistinctOn | null; // plain DISTINCT, PostgreSQL DISTINCT ON, or null (presumably when neither is present)
5
+ columns: Column[];
6
+ from: SelectFrom;
7
+ joins: JoinClause[];
8
+ where: WhereRoot | null; // nullable clauses below are presumably null when the clause is absent from the query — confirm against the parser
9
+ groupBy: GroupByClause | null;
10
+ having: HavingClause | null;
11
+ orderBy: OrderByClause | null;
12
+ limit: LimitClause | null;
13
+ offset: OffsetClause | null;
14
+ }
15
+ interface Distinct {
16
+ readonly type: "distinct";
17
+ }
18
+ /** PostgreSQL: DISTINCT ON (col1, col2, ...) */
19
+ interface DistinctOn {
20
+ readonly type: "distinct_on";
21
+ columns: WhereValue[];
22
+ }
23
+ interface GroupByClause {
24
+ readonly type: "group_by";
25
+ items: WhereValue[];
26
+ }
27
+ interface HavingClause {
28
+ readonly type: "having";
29
+ expr: WhereExpr;
30
+ }
31
+ interface OrderByClause {
32
+ readonly type: "order_by";
33
+ items: OrderByItem[];
34
+ }
35
+ type SortDirection = "asc" | "desc";
36
+ type NullsOrder = "nulls_first" | "nulls_last";
37
+ interface OrderByItem {
38
+ readonly type: "order_by_item";
39
+ expr: WhereValue;
40
+ direction?: SortDirection;
41
+ nulls?: NullsOrder;
42
+ }
43
+ interface OffsetClause {
44
+ readonly type: "offset";
45
+ value: number;
46
+ }
47
+ type JoinType = "inner" | "inner_outer" | "left" | "left_outer" | "right" | "right_outer" | "full" | "full_outer" | "cross" | "natural"; // NOTE(review): "inner_outer" has no obvious SQL counterpart (INNER OUTER JOIN is not valid syntax) — confirm it is intentional
48
+ type JoinCondition = {
49
+ readonly type: "join_on";
50
+ expr: WhereExpr;
51
+ } | {
52
+ readonly type: "join_using";
53
+ columns: string[];
54
+ };
55
+ interface JoinClause {
56
+ readonly type: "join";
57
+ joinType: JoinType;
58
+ table: TableRef;
59
+ condition: JoinCondition | null;
60
+ }
61
+ type SelectFrom = {
62
+ readonly type: "select_from";
63
+ table: TableRef;
64
+ };
65
+ type LimitClause = {
66
+ readonly type: "limit";
67
+ value: number;
68
+ };
69
+ type WhereRoot = {
70
+ readonly type: "where_root";
71
+ inner: WhereExpr;
72
+ };
73
+ type WhereExpr = WhereAnd | WhereOr | WhereNot | WhereComparison | WhereIsNull | WhereIsBool | WhereBetween | WhereIn | WhereLike | WhereTsMatch;
74
+ interface WhereAnd {
75
+ readonly type: "where_and";
76
+ left: WhereExpr;
77
+ right: WhereExpr;
78
+ }
79
+ interface WhereOr {
80
+ readonly type: "where_or";
81
+ left: WhereExpr;
82
+ right: WhereExpr;
83
+ }
84
+ type ComparisonOperator = "=" | "<>" | "!=" | "<" | ">" | "<=" | ">=";
85
+ interface WhereNot {
86
+ readonly type: "where_not";
87
+ expr: WhereExpr;
88
+ }
89
+ interface WhereIsNull {
90
+ readonly type: "where_is_null";
91
+ not: boolean;
92
+ expr: WhereValue;
93
+ }
94
+ type IsBoolTarget = boolean | "unknown";
95
+ interface WhereIsBool {
96
+ readonly type: "where_is_bool";
97
+ not: boolean;
98
+ expr: WhereValue;
99
+ target: IsBoolTarget;
100
+ }
101
+ interface WhereBetween {
102
+ readonly type: "where_between";
103
+ not: boolean;
104
+ expr: WhereValue;
105
+ low: WhereValue;
106
+ high: WhereValue;
107
+ }
108
+ interface WhereIn {
109
+ readonly type: "where_in";
110
+ not: boolean;
111
+ expr: WhereValue;
112
+ list: WhereValue[];
113
+ }
114
+ /** "ilike" is PostgreSQL-specific (case-insensitive LIKE) */
115
+ type LikeOp = "like" | "ilike";
116
+ interface WhereLike {
117
+ readonly type: "where_like";
118
+ not: boolean;
119
+ op: LikeOp;
120
+ expr: WhereValue;
121
+ pattern: WhereValue;
122
+ }
123
+ /** PostgreSQL: JSONB operators */
124
+ type JsonbOp = "->" | "->>" | "#>" | "#>>" | "?" | "?|" | "?&" | "@>";
125
+ /** PostgreSQL: JSONB binary operator expression */
126
+ interface WhereJsonbOp {
127
+ readonly type: "where_jsonb_op";
128
+ op: JsonbOp;
129
+ left: WhereValue;
130
+ right: WhereValue;
131
+ }
132
+ /** PostgreSQL: text search match (@@) */
133
+ interface WhereTsMatch {
134
+ readonly type: "where_ts_match";
135
+ left: WhereValue;
136
+ right: WhereValue;
137
+ }
138
+ /** pgvector: distance operators */
139
+ type PgvectorOp = "<->" | "<#>" | "<=>" | "<+>" | "<~>" | "<%>";
140
+ /** pgvector: distance operator expression */
141
+ interface WherePgvectorOp {
142
+ readonly type: "where_pgvector_op";
143
+ op: PgvectorOp;
144
+ left: WhereValue;
145
+ right: WhereValue;
146
+ }
147
+ type ArithOp = "+" | "-" | "*" | "/" | "%" | "||";
148
+ interface WhereArith {
149
+ readonly type: "where_arith";
150
+ op: ArithOp;
151
+ left: WhereValue;
152
+ right: WhereValue;
153
+ }
154
+ interface WhereUnaryMinus {
155
+ readonly type: "where_unary_minus";
156
+ expr: WhereValue;
157
+ }
158
+ interface CaseWhen {
159
+ condition: WhereValue;
160
+ result: WhereValue;
161
+ }
162
+ interface CaseExpr { // CASE expression node: an optional subject, a list of WHEN branches, and an optional ELSE value.
163
+ readonly type: "case_expr";
164
+ subject: WhereValue | null; // nullable — presumably null for a searched CASE with no operand; confirm against the parser
165
+ whens: CaseWhen[];
166
+ else: WhereValue | null; // nullable — presumably null when no ELSE branch is written
167
+ }
168
+ interface CastExpr {
169
+ readonly type: "cast_expr";
170
+ expr: WhereValue;
171
+ typeName: string;
172
+ }
173
+ type WhereValue = {
174
+ readonly type: "where_value";
175
+ kind: "string";
176
+ value: string;
177
+ } | {
178
+ readonly type: "where_value";
179
+ kind: "integer";
180
+ value: number;
181
+ } | {
182
+ readonly type: "where_value";
183
+ kind: "float";
184
+ value: number;
185
+ } | {
186
+ readonly type: "where_value";
187
+ kind: "bool";
188
+ value: boolean;
189
+ } | {
190
+ readonly type: "where_value";
191
+ kind: "null";
192
+ } | {
193
+ readonly type: "where_value";
194
+ kind: "column_ref";
195
+ ref: ColumnRef;
196
+ } | {
197
+ readonly type: "where_value";
198
+ kind: "func_call";
199
+ func: FuncCall;
200
+ } | WhereArith | WhereUnaryMinus | WhereJsonbOp | WherePgvectorOp | CaseExpr | CastExpr;
201
+ interface WhereComparison {
202
+ readonly type: "where_comparison";
203
+ operator: ComparisonOperator;
204
+ left: WhereValue;
205
+ right: WhereValue;
206
+ }
207
+ interface ColumnRef {
208
+ readonly type: "column_ref";
209
+ schema?: string;
210
+ table?: string;
211
+ name: string;
212
+ }
213
+ type FuncCallArg = {
214
+ kind: "wildcard";
215
+ } | {
216
+ kind: "args";
217
+ distinct: boolean;
218
+ args: WhereValue[];
219
+ };
220
+ interface FuncCall {
221
+ readonly type: "func_call";
222
+ name: string;
223
+ args: FuncCallArg;
224
+ }
225
+ interface ColumnExpr {
226
+ readonly type: "column_expr";
227
+ kind: "wildcard" | "qualified_wildcard" | "expr";
228
+ table?: string;
229
+ expr?: WhereValue;
230
+ }
231
+ interface Column {
232
+ readonly type: "column";
233
+ expr: ColumnExpr;
234
+ alias?: Alias;
235
+ }
236
+ interface Alias {
237
+ readonly type: "alias";
238
+ name: string;
239
+ }
240
+ interface TableRef {
241
+ readonly type: "table_ref";
242
+ schema?: string;
243
+ name: string;
244
+ }
245
+ //#endregion
246
+ //#region src/result.d.ts
247
+ type Failure<E = unknown> = { // Failed branch of Result: ok is the literal false and error carries the E payload.
248
+ ok: false;
249
+ error: E;
250
+ unwrap(): never; // typed never, so it cannot return normally — presumably it throws; confirm in the implementation
251
+ };
252
+ type Success<T = void> = { // Successful branch of Result: ok is the literal true and data carries the T payload.
253
+ ok: true;
254
+ data: T;
255
+ unwrap(): T; // returns a T (the same type as data)
256
+ };
257
+ type Result<T = void, E = Error> = Failure<E> | Success<T>; // Union discriminated on `ok`; note E defaults to Error here, whereas Failure alone defaults E to unknown.
258
+ //#endregion
259
+ //#region src/graph.d.ts
260
+ type TableDefs = ReturnType<typeof defineTables>; // Type of the `tables` option accepted by the sanitise functions; derived from defineTables' return type.
261
+ declare function defineTables<T extends { [Table in keyof T]: Record<string, null | { [FK in keyof T & string]: { // Returns its argument typed as T; each table maps column names to null or to an { ft, fc } reference.
262
+ ft: FK; // must be a table key of T — presumably "foreign table"; confirm naming
263
+ fc: keyof T[FK] & string; // must be a column key of the table named by ft — presumably "foreign column"; confirm naming
264
+ } }[keyof T & string]> }>(tables: T): T;
265
+ //#endregion
266
+ //#region src/output.d.ts
267
+ declare function outputSql(ast: SelectStatement): string; // Takes a SelectStatement AST and returns a string — presumably the rendered SQL text; confirm in the implementation.
268
+ //#endregion
269
+ //#region src/parse.d.ts
270
+ declare function parseSql(expr: string): Result<SelectStatement>; // Parses SQL text into a SelectStatement wrapped in a Result; per the README the grammar only accepts SELECT and anything else is an error.
271
+ //#endregion
272
+ //#region src/sanitise.d.ts
273
+ interface GuardCol { // Identifies the guard column without a value; same fields as WhereGuard minus `value`. Used by makeSanitiserFactory.
274
+ schema?: string; // optional schema qualifier
275
+ table: string;
276
+ col: string;
277
+ }
278
+ interface WhereGuard { // Guard column plus the value it is matched against; the README describes the result as a mandatory equality filter.
279
+ schema?: string; // optional schema qualifier
280
+ table: string;
281
+ col: string;
282
+ value: string | number; // e.g. the tenant ID, as in the README's { table: "users", col: "tenant_id", value: "acme" }
283
+ }
284
+ declare function sanitiseSql(ast: SelectStatement, guard: WhereGuard): Result<SelectStatement>; // AST-level variant: takes a parsed statement and a guard, returns a SelectStatement in a Result. Whether `ast` is mutated is not visible here — confirm.
285
+ //#endregion
286
+ //#region src/index.d.ts
287
+ declare function sanitise(sql: string, { // String-in, string-out entry point; the README quickstart shows the guard added as `users.tenant_id = 'acme' AND ...`.
288
+ tables,
289
+ where
290
+ }: {
291
+ tables: TableDefs;
292
+ where: WhereGuard;
293
+ }): string; // failure mode is not visible in this declaration — presumably throws; safeSanitise is the Result-returning counterpart
294
+ declare function safeSanitise(sql: string, { // Same parameters as sanitise, but the outcome is returned as a Result<string>.
295
+ tables,
296
+ where
297
+ }: {
298
+ tables: TableDefs;
299
+ where: WhereGuard;
300
+ }): Result<string>; // callers branch on `ok` to read `data` or `error`
301
+ declare function sanitiserFactory(_: { // Overload 1: binds tables and guard once and returns a reusable (expr) => Result<string> sanitiser.
302
+ tables: TableDefs;
303
+ where: WhereGuard;
304
+ throws: false; // explicit false selects the Result-returning sanitiser
305
+ }): (expr: string) => Result<string>;
306
+ declare function sanitiserFactory(_: { // Overload 2: same options, but the returned sanitiser yields a plain string (as used in the README tool example).
307
+ tables: TableDefs;
308
+ where: WhereGuard;
309
+ throws?: true; // omitted or true selects the string-returning sanitiser — presumably it throws on failure; confirm
310
+ }): (expr: string) => string;
311
+ declare function makeSanitiserFactory(_: { // Overload 1: binds tables and the guard column; the result takes a guard value and returns an (expr) => Result<string> sanitiser.
312
+ tables: TableDefs;
313
+ guardCol: GuardCol;
314
+ throws: false; // explicit false selects Result-returning sanitisers
315
+ }): (guardVal: string | number) => (expr: string) => Result<string>;
316
+ declare function makeSanitiserFactory(_: { // Overload 2: same options, but the sanitisers produced return a plain string.
317
+ tables: TableDefs;
318
+ guardCol: GuardCol;
319
+ throws?: true; // omitted or true selects string-returning sanitisers — presumably they throw on failure; confirm
320
+ }): (guardVal: string | number) => (expr: string) => string;
321
+ //#endregion
322
+ export { defineTables, makeSanitiserFactory, outputSql, parseSql, safeSanitise, sanitise, sanitiseSql, sanitiserFactory };