agent-sql 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +99 -0
- package/dist/index.d.mts +322 -0
- package/dist/index.mjs +6446 -0
- package/package.json +58 -0
package/README.md
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
# agent-sql
|
|
2
|
+
|
|
3
|
+
Sanitise agent-written SQL for multi-tenant DBs.
|
|
4
|
+
|
|
5
|
+
You provide a tenant ID, and the agent supplies the query.
|
|
6
|
+
|
|
7
|
+
agent-sql works by fully parsing the supplied SQL query into an AST.
|
|
8
|
+
The grammar ONLY accepts `SELECT` statements. Anything else is an error.
|
|
9
|
+
CTEs and other complex constructs that we aren't confident we can secure are also rejected with an error.
|
|
10
|
+
|
|
11
|
+
Then we ensure that the needed tenant table is somewhere in the query,
|
|
12
|
+
and add a `WHERE` clause ensuring that only rows belonging to the supplied tenant ID are returned.
|
|
13
|
+
|
|
14
|
+
Apparently this is how [Trigger.dev does it](https://x.com/mattaitken/status/2033928542975639785), and [Cloudflare](https://x.com/thomas_ankcorn/status/2033931057133748330) too.
|
|
15
|
+
|
|
16
|
+
## Quickstart
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
npm install agent-sql
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
```ts
|
|
23
|
+
import { sanitise } from "agent-sql";
|
|
24
|
+
|
|
25
|
+
const sql = sanitise(`SELECT id, name FROM users WHERE status = 'active' LIMIT 10`, {
|
|
26
|
+
tables: { users: {} },
|
|
27
|
+
where: { table: "users", col: "tenant_id", value: "acme" },
|
|
28
|
+
});
|
|
29
|
+
|
|
30
|
+
console.log(sql);
|
|
31
|
+
// SELECT id, name
|
|
32
|
+
// FROM users
|
|
33
|
+
// WHERE (users.tenant_id = 'acme' AND status = 'active')
|
|
34
|
+
// LIMIT 10
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
Or, more usefully, as a tenant-scoped tool for an AI agent:
|
|
38
|
+
|
|
39
|
+
```ts
|
|
40
|
+
import { sanitiserFactory } from "agent-sql";
|
|
41
|
+
import { tool } from "ai";
|
|
42
|
+
import { sql } from "drizzle-orm";
|
|
43
|
+
import { z } from "zod";
import { db } from "@/db";
|
|
44
|
+
|
|
45
|
+
function makeSqlTool(orgId: string) {
|
|
46
|
+
// Create a sanitiser function for this tenant
|
|
47
|
+
const sanitise = sanitiserFactory({
|
|
48
|
+
tables: {},
|
|
49
|
+
where: { table: "org", col: "id", value: orgId },
|
|
50
|
+
});
|
|
51
|
+
|
|
52
|
+
return tool({
|
|
53
|
+
description: "Run raw SQL against the DB",
|
|
54
|
+
inputSchema: z.object({ query: z.string() }),
|
|
55
|
+
execute: async ({ query }) => {
|
|
56
|
+
// The LLM can pass any query it likes, we'll sanitise it if possible
|
|
57
|
+
// and throw a helpful error message if not
|
|
58
|
+
const sanitised = sanitise(query);
|
|
59
|
+
// Now we can pass that straight to the db and be confident it'll only
|
|
60
|
+
// return data from the specified tenant
|
|
61
|
+
return db.execute(sql.raw(sanitised));
|
|
62
|
+
},
|
|
63
|
+
});
|
|
64
|
+
}
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
`agent-sql` parses the SQL, enforces a mandatory equality filter on the given column as the outermost `AND` condition (so it cannot be short-circuited by agent-supplied `OR` clauses), and returns the sanitised SQL string.
|
|
68
|
+
|
|
69
|
+
## Development
|
|
70
|
+
|
|
71
|
+
First install [Vite+](https://viteplus.dev/guide/):
|
|
72
|
+
|
|
73
|
+
```bash
|
|
74
|
+
curl -fsSL https://vite.plus | bash
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
Install dependencies:
|
|
78
|
+
|
|
79
|
+
```bash
|
|
80
|
+
vp install
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
Format, lint, typecheck:
|
|
84
|
+
|
|
85
|
+
```bash
|
|
86
|
+
vp check --fix
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
Run the unit tests:
|
|
90
|
+
|
|
91
|
+
```bash
|
|
92
|
+
vp test
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
Build the library:
|
|
96
|
+
|
|
97
|
+
```bash
|
|
98
|
+
vp pack
|
|
99
|
+
```
|
package/dist/index.d.mts
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
//#region src/ast.d.ts
/**
 * Root AST node returned (inside a `Result`) by `parseSql`: one `SELECT`
 * statement. Every clause that may be missing is typed `| null`.
 */
interface SelectStatement {
  readonly type: "select";
  /** `DISTINCT` / PostgreSQL `DISTINCT ON (...)`, or `null`. */
  distinct: Distinct | DistinctOn | null;
  /** The select list. */
  columns: Array<Column>;
  /** The `FROM` target. */
  from: SelectFrom;
  /** Zero or more `JOIN` clauses. */
  joins: Array<JoinClause>;
  where: WhereRoot | null;
  groupBy: GroupByClause | null;
  having: HavingClause | null;
  orderBy: OrderByClause | null;
  limit: LimitClause | null;
  offset: OffsetClause | null;
}
|
|
15
|
+
/** Marker for a plain `SELECT DISTINCT`. */
interface Distinct {
  readonly type: "distinct";
}
/** PostgreSQL: DISTINCT ON (col1, col2, ...) */
interface DistinctOn {
  readonly type: "distinct_on";
  columns: Array<WhereValue>;
}
/** `GROUP BY` clause; each item is a value expression. */
interface GroupByClause {
  readonly type: "group_by";
  items: Array<WhereValue>;
}
/** `HAVING` clause wrapping a boolean predicate. */
interface HavingClause {
  readonly type: "having";
  expr: WhereExpr;
}
/** `ORDER BY` clause: one entry per sort key. */
interface OrderByClause {
  readonly type: "order_by";
  items: Array<OrderByItem>;
}
/** Sort direction of an `ORDER BY` key. */
type SortDirection = "asc" | "desc";
/** `NULLS FIRST` / `NULLS LAST` placement of an `ORDER BY` key. */
type NullsOrder = "nulls_first" | "nulls_last";
/** A single `ORDER BY` key; `direction` and `nulls` are optional. */
interface OrderByItem {
  readonly type: "order_by_item";
  expr: WhereValue;
  direction?: SortDirection;
  nulls?: NullsOrder;
}
/** `OFFSET n` clause. */
interface OffsetClause {
  readonly type: "offset";
  value: number;
}
|
|
47
|
+
/**
 * Join flavours the AST can represent.
 * NOTE(review): "inner_outer" is not standard SQL; kept as-is because it
 * mirrors the parser's output type.
 */
type JoinType =
  | "inner"
  | "inner_outer"
  | "left"
  | "left_outer"
  | "right"
  | "right_outer"
  | "full"
  | "full_outer"
  | "cross"
  | "natural";
/** Join constraint: `ON <predicate>` or `USING (col, ...)`. */
type JoinCondition =
  | { readonly type: "join_on"; expr: WhereExpr }
  | { readonly type: "join_using"; columns: Array<string> };
/**
 * One `JOIN` clause. `condition` is nullable (presumably for `CROSS` /
 * `NATURAL` joins, which take no constraint — confirm against the parser).
 */
interface JoinClause {
  readonly type: "join";
  joinType: JoinType;
  table: TableRef;
  condition: JoinCondition | null;
}
/** The `FROM` target of a statement. */
type SelectFrom = {
  readonly type: "select_from";
  table: TableRef;
};
/** `LIMIT n` clause. */
type LimitClause = {
  readonly type: "limit";
  value: number;
};
/** Wrapper around the top-level `WHERE` predicate. */
type WhereRoot = {
  readonly type: "where_root";
  inner: WhereExpr;
};
|
|
73
|
+
/**
 * Boolean predicate node. Used for the top-level `WHERE` (`WhereRoot.inner`),
 * `HAVING` (`HavingClause.expr`) and `JOIN ... ON` (`JoinCondition`).
 */
type WhereExpr =
  | WhereAnd
  | WhereOr
  | WhereNot
  | WhereComparison
  | WhereIsNull
  | WhereIsBool
  | WhereBetween
  | WhereIn
  | WhereLike
  | WhereTsMatch;
/** `left AND right`. */
interface WhereAnd {
  readonly type: "where_and";
  left: WhereExpr;
  right: WhereExpr;
}
/** `left OR right`. */
interface WhereOr {
  readonly type: "where_or";
  left: WhereExpr;
  right: WhereExpr;
}
/** Operators accepted by `WhereComparison`. */
type ComparisonOperator = "=" | "<>" | "!=" | "<" | ">" | "<=" | ">=";
/** `NOT expr`. */
interface WhereNot {
  readonly type: "where_not";
  expr: WhereExpr;
}
/*
 * For the nodes below, the `not` flag presumably selects the negated SQL form
 * (`IS NOT NULL`, `NOT BETWEEN`, `NOT IN`, `NOT LIKE`) — confirm against the parser.
 */
/** `expr IS [NOT] NULL`. */
interface WhereIsNull {
  readonly type: "where_is_null";
  not: boolean;
  expr: WhereValue;
}
/** Right-hand side of `IS [NOT] TRUE | FALSE | UNKNOWN`. */
type IsBoolTarget = boolean | "unknown";
/** `expr IS [NOT] TRUE | FALSE | UNKNOWN`. */
interface WhereIsBool {
  readonly type: "where_is_bool";
  not: boolean;
  expr: WhereValue;
  target: IsBoolTarget;
}
/** `expr [NOT] BETWEEN low AND high`. */
interface WhereBetween {
  readonly type: "where_between";
  not: boolean;
  expr: WhereValue;
  low: WhereValue;
  high: WhereValue;
}
/** `expr [NOT] IN (list, ...)`. */
interface WhereIn {
  readonly type: "where_in";
  not: boolean;
  expr: WhereValue;
  list: Array<WhereValue>;
}
/** "ilike" is PostgreSQL-specific (case-insensitive LIKE) */
type LikeOp = "like" | "ilike";
/** `expr [NOT] LIKE | ILIKE pattern`. */
interface WhereLike {
  readonly type: "where_like";
  not: boolean;
  op: LikeOp;
  expr: WhereValue;
  pattern: WhereValue;
}
|
|
123
|
+
/** PostgreSQL: JSONB operators */
type JsonbOp = "->" | "->>" | "#>" | "#>>" | "?" | "?|" | "?&" | "@>";
/** PostgreSQL: JSONB binary operator expression */
interface WhereJsonbOp {
  readonly type: "where_jsonb_op";
  op: JsonbOp;
  left: WhereValue;
  right: WhereValue;
}
/** PostgreSQL: text search match (@@) */
interface WhereTsMatch {
  readonly type: "where_ts_match";
  left: WhereValue;
  right: WhereValue;
}
/** pgvector: distance operators */
type PgvectorOp = "<->" | "<#>" | "<=>" | "<+>" | "<~>" | "<%>";
/** pgvector: distance operator expression */
interface WherePgvectorOp {
  readonly type: "where_pgvector_op";
  op: PgvectorOp;
  left: WhereValue;
  right: WhereValue;
}
/** Binary arithmetic operators, plus `||` (SQL concatenation). */
type ArithOp = "+" | "-" | "*" | "/" | "%" | "||";
/** `left <op> right` for an `ArithOp`. */
interface WhereArith {
  readonly type: "where_arith";
  op: ArithOp;
  left: WhereValue;
  right: WhereValue;
}
/** Unary negation, `-expr`. */
interface WhereUnaryMinus {
  readonly type: "where_unary_minus";
  expr: WhereValue;
}
/** One `WHEN condition THEN result` arm of a `CASE` expression. */
interface CaseWhen {
  condition: WhereValue;
  result: WhereValue;
}
/** `CASE [subject] WHEN ... THEN ... [ELSE ...] END`; `subject` and `else` are nullable. */
interface CaseExpr {
  readonly type: "case_expr";
  subject: WhereValue | null;
  whens: Array<CaseWhen>;
  else: WhereValue | null;
}
/** Cast of `expr` to the type named by `typeName` (`CAST(... AS ...)`, and presumably `::` — confirm). */
interface CastExpr {
  readonly type: "cast_expr";
  expr: WhereValue;
  typeName: string;
}
|
|
173
|
+
/**
 * Scalar value expression: a tagged literal, column reference or function
 * call (all `type: "where_value"`, discriminated by `kind`), or one of the
 * operator / CASE / CAST node types.
 */
type WhereValue =
  | { readonly type: "where_value"; kind: "string"; value: string }
  | { readonly type: "where_value"; kind: "integer"; value: number }
  | { readonly type: "where_value"; kind: "float"; value: number }
  | { readonly type: "where_value"; kind: "bool"; value: boolean }
  | { readonly type: "where_value"; kind: "null" }
  | { readonly type: "where_value"; kind: "column_ref"; ref: ColumnRef }
  | { readonly type: "where_value"; kind: "func_call"; func: FuncCall }
  | WhereArith
  | WhereUnaryMinus
  | WhereJsonbOp
  | WherePgvectorOp
  | CaseExpr
  | CastExpr;
/** Binary comparison `left <operator> right`. */
interface WhereComparison {
  readonly type: "where_comparison";
  operator: ComparisonOperator;
  left: WhereValue;
  right: WhereValue;
}
|
|
207
|
+
/** Column reference, optionally qualified: `[schema.][table.]name`. */
interface ColumnRef {
  readonly type: "column_ref";
  schema?: string;
  table?: string;
  name: string;
}
/** Function arguments: a bare `*`, or an explicit list with an optional `DISTINCT`. */
type FuncCallArg =
  | { kind: "wildcard" }
  | { kind: "args"; distinct: boolean; args: Array<WhereValue> };
/** Function invocation `name(args)`. */
interface FuncCall {
  readonly type: "func_call";
  name: string;
  args: FuncCallArg;
}
/**
 * Projection expression: `*`, `table.*`, or a value expression.
 * NOTE(review): presumably `table` is set for "qualified_wildcard" and `expr`
 * for "expr" — confirm against the parser.
 */
interface ColumnExpr {
  readonly type: "column_expr";
  kind: "wildcard" | "qualified_wildcard" | "expr";
  table?: string;
  expr?: WhereValue;
}
/** One select-list item with an optional alias. */
interface Column {
  readonly type: "column";
  expr: ColumnExpr;
  alias?: Alias;
}
/** `AS name` alias. */
interface Alias {
  readonly type: "alias";
  name: string;
}
/** Table reference, optionally schema-qualified. */
interface TableRef {
  readonly type: "table_ref";
  schema?: string;
  name: string;
}
//#endregion
|
|
246
|
+
//#region src/result.d.ts
/** Failed outcome: `ok` is `false`; `unwrap()` is typed `never` (it does not return normally). */
type Failure<E = unknown> = {
  ok: false;
  error: E;
  unwrap(): never;
};
/** Successful outcome carrying `data`; `unwrap()` is typed to return `T`. */
type Success<T = void> = {
  ok: true;
  data: T;
  unwrap(): T;
};
/** Discriminated on `ok`: narrow before reading `data` or `error`. */
type Result<T = void, E = Error> = Success<T> | Failure<E>;
//#endregion
|
|
259
|
+
//#region src/graph.d.ts
/** Table-definition map accepted by `sanitise` and the factories (the return type of `defineTables`). */
type TableDefs = ReturnType<typeof defineTables>;
/**
 * Type-checks a table map and is typed to return it unchanged.
 *
 * Each table maps column names to either `null` or a descriptor `{ ft, fc }`
 * (presumably foreign table / foreign column). `ft` must be a key of the map,
 * and `fc` must be a column of that table.
 */
declare function defineTables<
  TTables extends {
    [TableName in keyof TTables]: Record<
      string,
      | null
      | {
          [Target in keyof TTables & string]: {
            ft: Target;
            fc: keyof TTables[Target] & string;
          };
        }[keyof TTables & string]
    >;
  }
>(tables: TTables): TTables;
//#endregion
|
|
266
|
+
//#region src/output.d.ts
/**
 * Renders a `SelectStatement` AST back into SQL text.
 *
 * @param ast - Statement to serialise.
 * @returns The SQL string.
 */
declare function outputSql(ast: SelectStatement): string;
//#endregion
|
|
269
|
+
//#region src/parse.d.ts
/**
 * Parses SQL text into a `SelectStatement` AST.
 *
 * Only `SELECT` statements are accepted. Failures are reported through the
 * returned `Result` (presumably rather than by throwing — confirm).
 *
 * @param expr - Raw SQL text.
 */
declare function parseSql(expr: string): Result<SelectStatement>;
//#endregion
|
|
272
|
+
//#region src/sanitise.d.ts
/** Locates the tenant column: `[schema.]table.col`. */
interface GuardCol {
  schema?: string;
  table: string;
  col: string;
}
/** A `GuardCol` plus the tenant value the mandatory equality filter compares against. */
interface WhereGuard extends GuardCol {
  value: string | number;
}
/**
 * Applies the tenant guard to an already-parsed statement.
 *
 * @param ast - Parsed statement (e.g. from `parseSql`).
 * @param guard - Tenant column and required value.
 * @returns The guarded AST, or a failure `Result`.
 */
declare function sanitiseSql(ast: SelectStatement, guard: WhereGuard): Result<SelectStatement>;
//#endregion
|
|
286
|
+
//#region src/index.d.ts
/**
 * Parses `sql`, applies the `where` tenant guard, and returns the rewritten SQL.
 *
 * The return type is a plain `string`, so failures are presumably thrown.
 * Use `safeSanitise` to receive a `Result` instead.
 *
 * @param sql - Agent-supplied SQL; only `SELECT` statements are accepted.
 * @param tables - Table definitions (see `defineTables`).
 * @param where - Tenant guard: the column to filter on and its required value.
 */
declare function sanitise(sql: string, {
  tables,
  where
}: {
  tables: TableDefs;
  where: WhereGuard;
}): string;
/** Same as `sanitise`, but failures are reported through the returned `Result`. */
declare function safeSanitise(sql: string, {
  tables,
  where
}: {
  tables: TableDefs;
  where: WhereGuard;
}): Result<string>;
/**
 * Binds `tables` and `where` once and returns a one-argument sanitiser.
 * With `throws: false`, the sanitiser returns a `Result`.
 */
declare function sanitiserFactory(_: {
  tables: TableDefs;
  where: WhereGuard;
  throws: false;
}): (expr: string) => Result<string>;
/** With `throws` omitted or `true`, the sanitiser returns a plain `string`. */
declare function sanitiserFactory(_: {
  tables: TableDefs;
  where: WhereGuard;
  throws?: true;
}): (expr: string) => string;
/**
 * Fallback for a non-literal flag (e.g. a `boolean` variable). Without it,
 * such calls fail with "No overload matches this call". The sanitiser's
 * result is one of the two shapes above, so callers must narrow it.
 */
declare function sanitiserFactory(_: {
  tables: TableDefs;
  where: WhereGuard;
  throws?: boolean;
}): (expr: string) => string | Result<string>;
/**
 * Curried variant: fix the tables and guard column up front, supply the
 * tenant value per request, then the query. With `throws: false`, it yields a `Result`.
 */
declare function makeSanitiserFactory(_: {
  tables: TableDefs;
  guardCol: GuardCol;
  throws: false;
}): (guardVal: string | number) => (expr: string) => Result<string>;
/** With `throws` omitted or `true`, the innermost function returns a plain `string`. */
declare function makeSanitiserFactory(_: {
  tables: TableDefs;
  guardCol: GuardCol;
  throws?: true;
}): (guardVal: string | number) => (expr: string) => string;
/** Fallback for a non-literal `throws` flag; the result must be narrowed by the caller. */
declare function makeSanitiserFactory(_: {
  tables: TableDefs;
  guardCol: GuardCol;
  throws?: boolean;
}): (guardVal: string | number) => (expr: string) => string | Result<string>;
//#endregion
|
|
322
|
+
// Public API surface of agent-sql.
export {
  defineTables,
  makeSanitiserFactory,
  outputSql,
  parseSql,
  safeSanitise,
  sanitise,
  sanitiseSql,
  sanitiserFactory,
};
|