agent-sql 0.1.2 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -8,10 +8,14 @@ agent-sql works by fully parsing the supplied SQL query into an AST.
8
8
  The grammar ONLY accepts `SELECT` statements. Anything else is an error.
9
9
  CTEs and other complex things that we aren't confident of securing: error.
10
10
 
11
- Then we ensure that that the needed tenant table is somewhere in the query,
12
- and add a `WHERE` clause ensuring that only values from the supplied ID are returned.
11
+ It ensures that the needed tenant table is somewhere in the query,
12
+ and adds a `WHERE` clause ensuring that only values from the supplied ID are returned.
13
+ Then it checks that the tables and `JOIN`s follow the schema, preventing sneaky joins.
13
14
 
14
- Apparently this is how [Trigger.dev does it](https://x.com/mattaitken/status/2033928542975639785). And [Cloudflare](https://x.com/thomas_ankcorn/status/2033931057133748330).
15
+ Finally, we throw in a `LIMIT` clause (configurable) to prevent accidental LLM denial-of-service.
16
+
17
+ Apparently this is how [Trigger.dev does it](https://x.com/mattaitken/status/2033928542975639785).
18
+ And [Cloudflare](https://x.com/thomas_ankcorn/status/2033931057133748330).
15
19
 
16
20
  ## Quickstart
17
21
 
@@ -20,33 +24,47 @@ npm install agent-sql
20
24
  ```
21
25
 
22
26
  ```ts
23
- import { sanitise } from "agent-sql";
27
+ import { agentSql } from "agent-sql";
24
28
 
25
- const sql = sanitise(`SELECT id, name FROM users WHERE status = 'active' LIMIT 10`, {
26
- tables: { users: {} },
27
- where: { table: "users", col: "tenant_id", value: "acme" },
28
- });
29
+ const sql = agentSql(`SELECT * FROM msg`, "msg.user_id", 123);
29
30
 
30
31
  console.log(sql);
31
- // SELECT id, name
32
- // FROM users
33
- // WHERE (users.tenant_id = 'acme' AND status = 'active')
34
- // LIMIT 10
32
+ // SELECT *
33
+ // FROM msg
34
+ // WHERE msg.user_id = 123
35
+ // LIMIT 10000
35
36
  ```
36
37
 
37
- Or, more usefully:
38
+ `agent-sql` parses the SQL, enforces a mandatory equality filter on the given column as the outermost `AND` condition (so it cannot be short-circuited by agent-supplied `OR` clauses), and returns the sanitised SQL string.
39
+
40
+ ## Usage
41
+
42
+ ### Define once, use many times
43
+
44
+ The simple approach above is enough to get started.
45
+ But since no schema is provided, `JOIN`s will be blocked.
46
+ A schema can be passed to `agentSql`, but typically you'll want to set it up once and re-use.
38
47
 
39
48
  ```ts
40
- import { sanitiserFactory } from "agent-sql";
49
+ import { createAgentSql, defineSchema } from "agent-sql";
41
50
  import { tool } from "ai";
42
51
  import { sql } from "drizzle-orm";
43
52
  import { db } from "@/db";
44
53
 
45
- function makeSqlTool(orgId: string) {
54
+ // Define your schema.
55
+ // Only the tables listed will be permitted
56
+ // Joins can only use the FKs defined here
57
+ const schema = defineSchema({
58
+ user: { id: null },
59
+ msg: { user_id: { ft: "user", fc: "id" } },
60
+ });
61
+
62
+ function makeSqlTool(userId: string) {
46
63
  // Create a sanitiser function for this tenant
47
- const sanitise = sanitiserFactory({
48
- tables: {},
49
- where: { table: "org", col: "id", value: orgId },
64
+ const agentSql = createAgentSql({
65
+ column: "user.id",
66
+ value: userId,
67
+ schema,
50
68
  });
51
69
 
52
70
  return tool({
@@ -55,7 +73,7 @@ function makeSqlTool(orgId: string) {
55
73
  execute: async ({ query }) => {
56
74
  // The LLM can pass any query it likes, we'll sanitise it if possible
57
75
  // and return helpful error messages if not
58
- const sanitised = sanitise(query);
76
+ const sanitised = agentSql(query);
59
77
  // Now we can throw that straight at the db and be confident it'll only
60
78
  // return data from the specified tenant
61
79
  return db.execute(sql.raw(sanitised));
@@ -64,7 +82,33 @@ function makeSqlTool(orgId: string) {
64
82
  }
65
83
  ```
66
84
 
67
- `agent-sql` parses the SQL, enforces a mandatory equality filter on the given column as the outermost `AND` condition (so it cannot be short-circuited by agent-supplied `OR` clauses), and returns the sanitised SQL string.
85
+ ### It works with Drizzle
86
+
87
+ If you're using Drizzle, you can skip the schema step and use the one you already have!
88
+
89
+ Just pass it through, and `agentSql` will respect your schema.
90
+
91
+ ```ts
92
+ import { defineSchemaFromDrizzle } from "agent-sql/drizzle";
93
+ import * as drizzleSchema from "@/db/schema";
94
+
95
+ const schema = defineSchemaFromDrizzle(drizzleSchema);
96
+
97
+ // The rest as before...
98
+ const agentSql = createAgentSql({
99
+ column: "user.id",
100
+ value: userId,
101
+ schema,
102
+ });
103
+ ```
104
+
105
+ You can also exclude tables if you don't want agents to see them:
106
+
107
+ ```ts
108
+ const schema = defineSchemaFromDrizzle(drizzleSchema, {
109
+ exclude: ["api_keys"],
110
+ });
111
+ ```
68
112
 
69
113
  ## Development
70
114
 
@@ -0,0 +1,22 @@
1
+ import { t as Schema } from "./joins-Cu_0yAgN.mjs";
2
+ import { Table, TableConfig } from "drizzle-orm";
3
+
4
//#region src/drizzle.d.ts
/** Union of the DB table names of every drizzle table exported by the schema module `T`. */
type TableNames<T> = { [Key in keyof T]: T[Key] extends Table<infer Config extends TableConfig> ? Config["name"] : never }[keyof T];
/** Options accepted by `defineSchemaFromDrizzle`. */
interface DefineSchemaFromDrizzleOptions<T> {
	/** DB names of tables to leave out of the agent-visible schema. */
	exclude?: TableNames<T>[];
}
/**
 * Build an agent-sql schema from a drizzle-orm schema module.
 *
 * Usage:
 * ```ts
 * import * as drizzleSchema from "./schema";
 * const schema = defineSchemaFromDrizzle(drizzleSchema);
 * ```
 */
declare function defineSchemaFromDrizzle<T extends Record<string, unknown>>(drizzleSchema: T, options?: DefineSchemaFromDrizzleOptions<T>): Schema;
//#endregion
22
+ export { defineSchemaFromDrizzle };
@@ -0,0 +1,66 @@
1
+ import { getTableName, isTable } from "drizzle-orm";
2
+ //#region src/drizzle.ts
3
/**
 * Convert a camelCase identifier to snake_case.
 *
 * Mirrors drizzle-orm's internal conversion for columns declared without an
 * explicit DB name: apostrophes are dropped, the input is split into words
 * (lowercase/digit runs, all-caps runs not followed by a lowercase letter, and
 * Capitalised words), and the lowercased words are joined with "_".
 * Returns "" when no words are found.
 */
function camelToSnake(input) {
	const cleaned = input.replace(/['\u2019]/g, "");
	const words = cleaned.match(/[\da-z]+|[A-Z]+(?![a-z])|[A-Z][\da-z]+/g) || [];
	let snake = "";
	for (const word of words) snake += (snake === "" ? "" : "_") + word.toLowerCase();
	return snake;
}
/**
 * Resolve the database-side name of a drizzle column.
 *
 * When drizzle used the object key as the column name (`keyAsName`), the key is
 * converted with `camelToSnake`; otherwise the explicit `name` is used verbatim.
 * NOTE(review): drizzle only snake-cases keys when the db instance is created with
 * `casing: "snake_case"`; without that option the key is used as-is, so this
 * conversion may not match the real column name — confirm against your config.
 */
function getDbColumnName(col) {
	if (!col.keyAsName) return col.name;
	return camelToSnake(col.name);
}
14
+ const ColumnsSymbol = Symbol.for("drizzle:Columns");
15
+ const InlineForeignKeysSymbol = Symbol.for("drizzle:PgInlineForeignKeys");
16
+ /**
17
+ * Build an agent-sql schema from a drizzle-orm schema module.
18
+ *
19
+ * Usage:
20
+ * ```ts
21
+ * import * as drizzleSchema from "./schema";
22
+ * const schema = defineSchemaFromDrizzle(drizzleSchema);
23
+ * ```
24
+ */
25
+ function defineSchemaFromDrizzle(drizzleSchema, { exclude } = {}) {
26
+ const schema = {};
27
+ const excluded = new Set(exclude);
28
+ const tables = [];
29
+ for (const value of Object.values(drizzleSchema)) if (isTable(value) && !excluded.has(getTableName(value))) tables.push(value);
30
+ const fkMap = /* @__PURE__ */ new Map();
31
+ for (const table of tables) {
32
+ const tableName = getTableName(table);
33
+ const fks = table[InlineForeignKeysSymbol] ?? [];
34
+ const colFks = /* @__PURE__ */ new Map();
35
+ for (const fk of fks) {
36
+ const ref = fk.reference();
37
+ const fromCol = ref.columns[0];
38
+ const foreignTableName = getTableName(ref.foreignTable);
39
+ if (excluded.has(foreignTableName)) colFks.set(getDbColumnName(fromCol), "excluded");
40
+ else {
41
+ const toCol = ref.foreignColumns[0];
42
+ colFks.set(getDbColumnName(fromCol), {
43
+ ft: foreignTableName,
44
+ fc: getDbColumnName(toCol)
45
+ });
46
+ }
47
+ }
48
+ fkMap.set(tableName, colFks);
49
+ }
50
+ for (const table of tables) {
51
+ const tableName = getTableName(table);
52
+ const columns = Object.values(table[ColumnsSymbol]);
53
+ const colFks = fkMap.get(tableName);
54
+ const tableSchema = {};
55
+ for (const col of columns) {
56
+ const dbName = getDbColumnName(col);
57
+ const fk = colFks.get(dbName);
58
+ if (fk === "excluded") continue;
59
+ tableSchema[dbName] = fk ?? null;
60
+ }
61
+ schema[tableName] = tableSchema;
62
+ }
63
+ return schema;
64
+ }
65
+ //#endregion
66
+ export { defineSchemaFromDrizzle };
package/dist/index.d.mts CHANGED
@@ -1,267 +1,19 @@
1
- //#region src/ast.d.ts
2
- interface SelectStatement {
3
- readonly type: "select";
4
- distinct: Distinct | DistinctOn | null;
5
- columns: Column[];
6
- from: SelectFrom;
7
- joins: JoinClause[];
8
- where: WhereRoot | null;
9
- groupBy: GroupByClause | null;
10
- having: HavingClause | null;
11
- orderBy: OrderByClause | null;
12
- limit: LimitClause | null;
13
- offset: OffsetClause | null;
14
- }
15
- interface Distinct {
16
- readonly type: "distinct";
17
- }
18
- /** PostgreSQL: DISTINCT ON (col1, col2, ...) */
19
- interface DistinctOn {
20
- readonly type: "distinct_on";
21
- columns: WhereValue[];
22
- }
23
- interface GroupByClause {
24
- readonly type: "group_by";
25
- items: WhereValue[];
26
- }
27
- interface HavingClause {
28
- readonly type: "having";
29
- expr: WhereExpr;
30
- }
31
- interface OrderByClause {
32
- readonly type: "order_by";
33
- items: OrderByItem[];
34
- }
35
- type SortDirection = "asc" | "desc";
36
- type NullsOrder = "nulls_first" | "nulls_last";
37
- interface OrderByItem {
38
- readonly type: "order_by_item";
39
- expr: WhereValue;
40
- direction?: SortDirection;
41
- nulls?: NullsOrder;
42
- }
43
- interface OffsetClause {
44
- readonly type: "offset";
45
- value: number;
46
- }
47
- type JoinType = "inner" | "inner_outer" | "left" | "left_outer" | "right" | "right_outer" | "full" | "full_outer" | "cross" | "natural";
48
- type JoinCondition = {
49
- readonly type: "join_on";
50
- expr: WhereExpr;
51
- } | {
52
- readonly type: "join_using";
53
- columns: string[];
54
- };
55
- interface JoinClause {
56
- readonly type: "join";
57
- joinType: JoinType;
58
- table: TableRef;
59
- condition: JoinCondition | null;
60
- }
61
- type SelectFrom = {
62
- readonly type: "select_from";
63
- table: TableRef;
64
- };
65
- type LimitClause = {
66
- readonly type: "limit";
67
- value: number;
68
- };
69
- type WhereRoot = {
70
- readonly type: "where_root";
71
- inner: WhereExpr;
72
- };
73
- type WhereExpr = WhereAnd | WhereOr | WhereNot | WhereComparison | WhereIsNull | WhereIsBool | WhereBetween | WhereIn | WhereLike | WhereTsMatch;
74
- interface WhereAnd {
75
- readonly type: "where_and";
76
- left: WhereExpr;
77
- right: WhereExpr;
78
- }
79
- interface WhereOr {
80
- readonly type: "where_or";
81
- left: WhereExpr;
82
- right: WhereExpr;
83
- }
84
- type ComparisonOperator = "=" | "<>" | "!=" | "<" | ">" | "<=" | ">=";
85
- interface WhereNot {
86
- readonly type: "where_not";
87
- expr: WhereExpr;
88
- }
89
- interface WhereIsNull {
90
- readonly type: "where_is_null";
91
- not: boolean;
92
- expr: WhereValue;
93
- }
94
- type IsBoolTarget = boolean | "unknown";
95
- interface WhereIsBool {
96
- readonly type: "where_is_bool";
97
- not: boolean;
98
- expr: WhereValue;
99
- target: IsBoolTarget;
100
- }
101
- interface WhereBetween {
102
- readonly type: "where_between";
103
- not: boolean;
104
- expr: WhereValue;
105
- low: WhereValue;
106
- high: WhereValue;
107
- }
108
- interface WhereIn {
109
- readonly type: "where_in";
110
- not: boolean;
111
- expr: WhereValue;
112
- list: WhereValue[];
113
- }
114
- /** "ilike" is PostgreSQL-specific (case-insensitive LIKE) */
115
- type LikeOp = "like" | "ilike";
116
- interface WhereLike {
117
- readonly type: "where_like";
118
- not: boolean;
119
- op: LikeOp;
120
- expr: WhereValue;
121
- pattern: WhereValue;
122
- }
123
- /** PostgreSQL: JSONB operators */
124
- type JsonbOp = "->" | "->>" | "#>" | "#>>" | "?" | "?|" | "?&" | "@>";
125
- /** PostgreSQL: JSONB binary operator expression */
126
- interface WhereJsonbOp {
127
- readonly type: "where_jsonb_op";
128
- op: JsonbOp;
129
- left: WhereValue;
130
- right: WhereValue;
131
- }
132
- /** PostgreSQL: text search match (@@) */
133
- interface WhereTsMatch {
134
- readonly type: "where_ts_match";
135
- left: WhereValue;
136
- right: WhereValue;
137
- }
138
- /** pgvector: distance operators */
139
- type PgvectorOp = "<->" | "<#>" | "<=>" | "<+>" | "<~>" | "<%>";
140
- /** pgvector: distance operator expression */
141
- interface WherePgvectorOp {
142
- readonly type: "where_pgvector_op";
143
- op: PgvectorOp;
144
- left: WhereValue;
145
- right: WhereValue;
146
- }
147
- type ArithOp = "+" | "-" | "*" | "/" | "%" | "||";
148
- interface WhereArith {
149
- readonly type: "where_arith";
150
- op: ArithOp;
151
- left: WhereValue;
152
- right: WhereValue;
153
- }
154
- interface WhereUnaryMinus {
155
- readonly type: "where_unary_minus";
156
- expr: WhereValue;
157
- }
158
- interface CaseWhen {
159
- condition: WhereValue;
160
- result: WhereValue;
161
- }
162
- interface CaseExpr {
163
- readonly type: "case_expr";
164
- subject: WhereValue | null;
165
- whens: CaseWhen[];
166
- else: WhereValue | null;
167
- }
168
- interface CastExpr {
169
- readonly type: "cast_expr";
170
- expr: WhereValue;
171
- typeName: string;
172
- }
173
- type WhereValue = {
174
- readonly type: "where_value";
175
- kind: "string";
176
- value: string;
177
- } | {
178
- readonly type: "where_value";
179
- kind: "integer";
180
- value: number;
181
- } | {
182
- readonly type: "where_value";
183
- kind: "float";
184
- value: number;
185
- } | {
186
- readonly type: "where_value";
187
- kind: "bool";
188
- value: boolean;
189
- } | {
190
- readonly type: "where_value";
191
- kind: "null";
192
- } | {
193
- readonly type: "where_value";
194
- kind: "column_ref";
195
- ref: ColumnRef;
196
- } | {
197
- readonly type: "where_value";
198
- kind: "func_call";
199
- func: FuncCall;
200
- } | WhereArith | WhereUnaryMinus | WhereJsonbOp | WherePgvectorOp | CaseExpr | CastExpr;
201
- interface WhereComparison {
202
- readonly type: "where_comparison";
203
- operator: ComparisonOperator;
204
- left: WhereValue;
205
- right: WhereValue;
206
- }
207
- interface ColumnRef {
208
- readonly type: "column_ref";
209
- schema?: string;
210
- table?: string;
211
- name: string;
212
- }
213
- type FuncCallArg = {
214
- kind: "wildcard";
215
- } | {
216
- kind: "args";
217
- distinct: boolean;
218
- args: WhereValue[];
219
- };
220
- interface FuncCall {
221
- readonly type: "func_call";
222
- name: string;
223
- args: FuncCallArg;
224
- }
225
- interface ColumnExpr {
226
- readonly type: "column_expr";
227
- kind: "wildcard" | "qualified_wildcard" | "expr";
228
- table?: string;
229
- expr?: WhereValue;
230
- }
231
- interface Column {
232
- readonly type: "column";
233
- expr: ColumnExpr;
234
- alias?: Alias;
235
- }
236
- interface Alias {
237
- readonly type: "alias";
238
- name: string;
239
- }
240
- interface TableRef {
241
- readonly type: "table_ref";
1
+ import { i as SelectStatement, n as defineSchema, r as Result, t as Schema } from "./joins-Cu_0yAgN.mjs";
2
+
3
//#region src/guard.d.ts
/** Literal value the guard column must equal. */
type GuardVal = string | number;
/** Guard column, optionally schema-qualified. */
interface GuardCol {
	schema?: string;
	table: string;
	column: string;
}
/** Guard column together with the value it is pinned to. */
interface WhereGuard extends GuardCol {
	value: GuardVal;
}
/**
 * Add the mandatory WHERE guard and the row LIMIT to a parsed SELECT.
 * `limit` is optional; a default cap applies when it is omitted.
 */
declare function addGuards(ast: SelectStatement, guard: WhereGuard, limit?: number): Result<SelectStatement>;
//#endregion
259
- //#region src/graph.d.ts
260
- type TableDefs = ReturnType<typeof defineTables>;
261
- declare function defineTables<T extends { [Table in keyof T]: Record<string, null | { [FK in keyof T & string]: {
262
- ft: FK;
263
- fc: keyof T[FK] & string;
264
- } }[keyof T & string]> }>(tables: T): T;
15
//#region src/namespec.d.ts
/**
 * Resolves to `S` only when it is "table.column" or "schema.table.column" with
 * every dot-separated part non-empty; otherwise resolves to `never`, so a bad
 * literal column spec fails to type-check at the call site.
 */
type OneOrTwoDots<S extends string> = S extends `${infer A}.${infer B}.${infer C}` ? A extends "" ? never : B extends "" ? never : C extends "" | `${string}.${string}` ? never : S : S extends `${infer A}.${infer B}` ? A extends "" ? never : B extends "" | `${string}.${string}` ? never : S : never;
//#endregion
266
18
  //#region src/output.d.ts
267
19
  /** Render a SELECT AST back into SQL text (used to emit the sanitised query string). */
  declare function outputSql(ast: SelectStatement): string;
@@ -269,54 +21,41 @@ declare function outputSql(ast: SelectStatement): string;
269
21
  //#region src/parse.d.ts
270
22
  /**
   * Parse a SQL string into a SelectStatement AST, returning a failed Result when
   * the input is not a SELECT the grammar accepts.
   */
  declare function parseSql(expr: string): Result<SelectStatement>;
271
23
  //#endregion
272
- //#region src/sanitise.d.ts
273
- interface GuardCol {
274
- schema?: string;
275
- table: string;
276
- col: string;
277
- }
278
- interface WhereGuard {
279
- schema?: string;
280
- table: string;
281
- col: string;
282
- value: string | number;
283
- }
284
- declare function sanitiseSql(ast: SelectStatement, guard: WhereGuard): Result<SelectStatement>;
285
- //#endregion
286
24
  //#region src/index.d.ts
287
- declare function sanitise(sql: string, {
288
- tables,
289
- where
290
- }: {
291
- tables: TableDefs;
292
- where: WhereGuard;
25
+ declare function agentSql<S extends string>(sql: string, column: S & OneOrTwoDots<S>, value: GuardVal, {
26
+ schema,
27
+ limit
28
+ }?: {
29
+ schema?: Schema;
30
+ limit?: number;
293
31
  }): string;
294
- declare function safeSanitise(sql: string, {
295
- tables,
296
- where
297
- }: {
298
- tables: TableDefs;
299
- where: WhereGuard;
300
- }): Result<string>;
301
- declare function sanitiserFactory(_: {
302
- tables: TableDefs;
303
- where: WhereGuard;
32
+ declare function createAgentSql<S extends string>(_: {
33
+ column: S & OneOrTwoDots<S>;
34
+ value: GuardVal;
35
+ schema?: Schema;
36
+ limit?: number;
304
37
  throws: false;
305
38
  }): (expr: string) => Result<string>;
306
- declare function sanitiserFactory(_: {
307
- tables: TableDefs;
308
- where: WhereGuard;
39
+ declare function createAgentSql<S extends string>(_: {
40
+ column: S & OneOrTwoDots<S>;
41
+ value: GuardVal;
42
+ schema?: Schema;
43
+ limit?: number;
309
44
  throws?: true;
310
45
  }): (expr: string) => string;
311
- declare function makeSanitiserFactory(_: {
312
- tables: TableDefs;
313
- guardCol: GuardCol;
46
+ declare function createAgentSql<S extends string>(_: {
47
+ column: S & OneOrTwoDots<S>;
48
+ value?: undefined;
49
+ schema?: Schema;
50
+ limit?: number;
314
51
  throws: false;
315
- }): (guardVal: string | number) => (expr: string) => Result<string>;
316
- declare function makeSanitiserFactory(_: {
317
- tables: TableDefs;
318
- guardCol: GuardCol;
52
+ }): (guardVal: GuardVal) => (expr: string) => Result<string>;
53
+ declare function createAgentSql<S extends string>(_: {
54
+ column: S & OneOrTwoDots<S>;
55
+ value?: undefined;
56
+ schema?: Schema;
57
+ limit?: number;
319
58
  throws?: true;
320
- }): (guardVal: string | number) => (expr: string) => string;
59
+ }): (guardVal: GuardVal) => (expr: string) => string;
321
60
  //#endregion
322
- export { defineTables, makeSanitiserFactory, outputSql, parseSql, safeSanitise, sanitise, sanitiseSql, sanitiserFactory };
61
+ export { agentSql, createAgentSql, defineSchema, outputSql, parseSql, addGuards as sanitiseSql };
package/dist/index.mjs CHANGED
@@ -6,65 +6,9 @@ var ParseError = class extends Error {
6
6
  var SanitiseError = class extends Error {
7
7
  type = "sanitise_error";
8
8
  };
9
- //#endregion
10
- //#region src/result.ts
11
- function Err(error) {
12
- return {
13
- ok: false,
14
- error,
15
- unwrap() {
16
- throw new Error(String(error));
17
- }
18
- };
19
- }
20
- function Ok(data) {
21
- return {
22
- ok: true,
23
- data,
24
- unwrap() {
25
- return data;
26
- }
27
- };
28
- }
29
- function returnOrThrow(result, throws) {
30
- if (!throws) return result;
31
- if (result.ok) return result.data;
32
- throw result.error;
33
- }
34
- //#endregion
35
- //#region src/graph.ts
36
- function defineTables(tables) {
37
- return tables;
38
- }
39
- function checkJoins(ast, tables) {
40
- if (!(ast.from.table.name in tables)) return Err(new SanitiseError(`Table ${ast.from.table.name} is not allowed`));
41
- for (const join of ast.joins) {
42
- const joinSettings = tables[join.table.name];
43
- if (joinSettings === void 0) return Err(new SanitiseError(`Table ${join.table.name} is not allowed`));
44
- if (join.condition === null || join.condition.type === "join_using" || join.condition.expr.type !== "where_comparison" || join.condition.expr.operator !== "=" || join.condition.expr.left.type !== "where_value" || join.condition.expr.left.kind !== "column_ref" || join.condition.expr.right.type !== "where_value" || join.condition.expr.right.kind !== "column_ref") return Err(new SanitiseError("Only JOIN ON column_ref = column_ref supported"));
45
- const { joining, foreign } = getJoinTableRef(join.table.name, join.condition.expr.left.ref, join.condition.expr.right.ref);
46
- const joinTableCol = joinSettings[joining.name];
47
- if (joinTableCol === void 0) return Err(new SanitiseError(`Tried to join using ${join.table.name}.${joining.name}`));
48
- if (joinTableCol === null) {
49
- const foreignTableSettings = tables[foreign.table];
50
- if (foreignTableSettings === void 0) return Err(new SanitiseError(`Table ${foreign.name} is not allowed`));
51
- const foreignCol = foreignTableSettings[foreign.name];
52
- if (foreignCol === void 0 || foreignCol === null) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
53
- if (joining.table !== foreignCol.ft || joining.name !== foreignCol.fc) return Err(new SanitiseError(`Tried to join using ${joining.table}.${joining.name}`));
54
- } else if (foreign.table !== joinTableCol.ft || foreign.name !== joinTableCol.fc) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
55
- }
56
- return Ok(ast);
57
- }
58
- function getJoinTableRef(joinTableName, left, right) {
59
- if (left.table === joinTableName) return {
60
- joining: left,
61
- foreign: right
62
- };
63
- return {
64
- joining: right,
65
- foreign: left
66
- };
67
- }
9
/** Error thrown for developer misconfiguration, e.g. a malformed guard column spec. */
var AgentSqlError = class AgentSqlError extends Error {
	/** Discriminator for telling agent-sql error kinds apart. */
	type = "agent_sql_error";
};
68
12
  //#endregion
69
13
  //#region src/utils.ts
70
14
  function unreachable(x) {
@@ -276,6 +220,159 @@ function handleColumnExpr(node) {
276
220
  }
277
221
  }
278
222
  //#endregion
223
+ //#region src/result.ts
224
+ function Err(error) {
225
+ return {
226
+ ok: false,
227
+ error,
228
+ unwrap() {
229
+ throw new Error(String(error));
230
+ }
231
+ };
232
+ }
233
+ function Ok(data) {
234
+ return {
235
+ ok: true,
236
+ data,
237
+ unwrap() {
238
+ return data;
239
+ }
240
+ };
241
+ }
242
+ function returnOrThrow(result, throws) {
243
+ if (!throws) return result;
244
+ if (result.ok) return result.data;
245
+ throw result.error;
246
+ }
247
+ //#endregion
248
+ //#region src/guard.ts
249
+ const DEFAULT_LIMIT = 1e4;
250
+ function addGuards(ast, guard, limit = DEFAULT_LIMIT) {
251
+ const ast2 = addWhereGuard(ast, guard);
252
+ if (!ast2.ok) return ast2;
253
+ return addLimitGuard(ast2.data, limit);
254
+ }
255
+ function addWhereGuard(ast, guard) {
256
+ const { schema, table, column, value } = guard;
257
+ const tableRef = {
258
+ type: "table_ref",
259
+ schema,
260
+ name: table
261
+ };
262
+ if (!checkIfTableRefExists(ast, tableRef)) return Err(new SanitiseError(`The table '${handleTableRef(tableRef)}' must appear in the FROM or JOIN clauses.`));
263
+ const newClause = {
264
+ type: "where_comparison",
265
+ operator: "=",
266
+ left: {
267
+ type: "where_value",
268
+ kind: "column_ref",
269
+ ref: {
270
+ type: "column_ref",
271
+ schema,
272
+ table,
273
+ name: column
274
+ }
275
+ },
276
+ right: typeof value === "string" ? {
277
+ type: "where_value",
278
+ kind: "string",
279
+ value
280
+ } : {
281
+ type: "where_value",
282
+ kind: "integer",
283
+ value
284
+ }
285
+ };
286
+ return Ok({
287
+ ...ast,
288
+ where: ast.where ? {
289
+ type: "where_root",
290
+ inner: {
291
+ type: "where_and",
292
+ left: newClause,
293
+ right: ast.where.inner
294
+ }
295
+ } : {
296
+ type: "where_root",
297
+ inner: newClause
298
+ }
299
+ });
300
+ }
301
+ function addLimitGuard(ast, limit) {
302
+ const limitClause = {
303
+ type: "limit",
304
+ value: limit
305
+ };
306
+ if (ast.limit === null) return Ok({
307
+ ...ast,
308
+ limit: limitClause
309
+ });
310
+ if (ast.limit.value > limit) return Ok({
311
+ ...ast,
312
+ limit: limitClause
313
+ });
314
+ return Ok(ast);
315
+ }
316
/** True when `tableRef` is the FROM table or the table of any JOIN in `ast`. */
function checkIfTableRefExists(ast, tableRef) {
	if (tableEquals(ast.from.table, tableRef)) return true;
	for (const join of ast.joins) {
		if (tableEquals(join.table, tableRef)) return true;
	}
	return false;
}
/**
 * Compare two table refs by name and schema. Loose equality (`==`) means a
 * missing schema (`undefined`) also matches `null`.
 */
function tableEquals(a, b) {
	return a.name == b.name && a.schema == b.schema;
}
322
+ //#endregion
323
+ //#region src/joins.ts
324
+ function defineSchema(schema) {
325
+ return schema;
326
+ }
327
+ function checkJoins(ast, schema) {
328
+ if (schema === void 0) {
329
+ if (ast.joins.length > 0) return Err(new SanitiseError("No joins allowed when using simple API without schema."));
330
+ return Ok(ast);
331
+ }
332
+ if (!(ast.from.table.name in schema)) return Err(new SanitiseError(`Table ${ast.from.table.name} is not allowed`));
333
+ for (const join of ast.joins) {
334
+ const joinSettings = schema[join.table.name];
335
+ if (joinSettings === void 0) return Err(new SanitiseError(`Table ${join.table.name} is not allowed`));
336
+ if (join.condition === null || join.condition.type === "join_using" || join.condition.expr.type !== "where_comparison" || join.condition.expr.operator !== "=" || join.condition.expr.left.type !== "where_value" || join.condition.expr.left.kind !== "column_ref" || join.condition.expr.right.type !== "where_value" || join.condition.expr.right.kind !== "column_ref") return Err(new SanitiseError("Only JOIN ON column_ref = column_ref supported"));
337
+ const { joining, foreign } = getJoinTableRef(join.table.name, join.condition.expr.left.ref, join.condition.expr.right.ref);
338
+ const joinTableCol = joinSettings[joining.name];
339
+ if (joinTableCol === void 0) return Err(new SanitiseError(`Tried to join using ${join.table.name}.${joining.name}`));
340
+ if (joinTableCol === null) {
341
+ const foreignTableSettings = schema[foreign.table];
342
+ if (foreignTableSettings === void 0) return Err(new SanitiseError(`Table ${foreign.name} is not allowed`));
343
+ const foreignCol = foreignTableSettings[foreign.name];
344
+ if (foreignCol === void 0 || foreignCol === null) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
345
+ if (joining.table !== foreignCol.ft || joining.name !== foreignCol.fc) return Err(new SanitiseError(`Tried to join using ${joining.table}.${joining.name}`));
346
+ } else if (foreign.table !== joinTableCol.ft || foreign.name !== joinTableCol.fc) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
347
+ }
348
+ return Ok(ast);
349
+ }
350
+ function getJoinTableRef(joinTableName, left, right) {
351
+ if (left.table === joinTableName) return {
352
+ joining: left,
353
+ foreign: right
354
+ };
355
+ return {
356
+ joining: right,
357
+ foreign: left
358
+ };
359
+ }
360
+ //#endregion
361
+ //#region src/namespec.ts
362
+ function getQualifiedColumnFromString(column) {
363
+ const [a, b, c] = column.split(".");
364
+ if (a === void 0 || b === void 0) throw new AgentSqlError(`Malformed column string: '${column}'. Pass 'table.column'.`);
365
+ if (c === void 0) return Ok({
366
+ table: a,
367
+ column: b
368
+ });
369
+ return Ok({
370
+ schema: a,
371
+ table: b,
372
+ column: c
373
+ });
374
+ }
375
+ //#endregion
279
376
  //#region src/sql.ohm-bundle.js
280
377
  const result = makeRecipe([
281
378
  "grammar",
@@ -6333,114 +6430,63 @@ function parseSql(expr) {
6333
6430
  }
6334
6431
  }
6335
6432
  //#endregion
6336
- //#region src/sanitise.ts
6337
- function sanitiseSql(ast, guard) {
6338
- const { schema, table, col, value } = guard;
6339
- const tableRef = {
6340
- type: "table_ref",
6341
- schema,
6342
- name: table
6343
- };
6344
- if (!checkIfTableRefExists(ast, tableRef)) return Err(new SanitiseError(`The table '${handleTableRef(tableRef)}' must appear in the FROM or JOIN clauses.`));
6345
- const newClause = {
6346
- type: "where_comparison",
6347
- operator: "=",
6348
- left: {
6349
- type: "where_value",
6350
- kind: "column_ref",
6351
- ref: {
6352
- type: "column_ref",
6353
- schema,
6354
- table,
6355
- name: col
6356
- }
6357
- },
6358
- right: typeof value === "string" ? {
6359
- type: "where_value",
6360
- kind: "string",
6361
- value
6362
- } : {
6363
- type: "where_value",
6364
- kind: "integer",
6365
- value
6366
- }
6367
- };
6368
- return Ok({
6369
- ...ast,
6370
- where: ast.where ? {
6371
- type: "where_root",
6372
- inner: {
6373
- type: "where_and",
6374
- left: newClause,
6375
- right: ast.where.inner
6376
- }
6377
- } : {
6378
- type: "where_root",
6379
- inner: newClause
6380
- }
6381
- });
6382
- }
6383
- function checkIfTableRefExists(ast, tableRef) {
6384
- return tableEquals(ast.from.table, tableRef) || ast.joins.some((join) => tableEquals(join.table, tableRef));
6385
- }
6386
- function tableEquals(a, b) {
6387
- return a.schema == b.schema && a.name == b.name;
6388
- }
6389
- //#endregion
6390
6433
  //#region src/index.ts
6391
- function privateSanitise(sql, { tables, where, throws }) {
6392
- const ast = parseSql(sql);
6393
- if (!ast.ok) return returnOrThrow(ast, throws);
6394
- const ast2 = checkJoins(ast.data, tables);
6395
- if (!ast2.ok) return returnOrThrow(ast2, throws);
6396
- const san = sanitiseSql(ast2.data, where);
6397
- if (!san.ok) return returnOrThrow(san, throws);
6398
- const res = outputSql(san.data);
6399
- if (throws) return res;
6400
- return Ok(res);
6401
- }
6402
- function sanitise(sql, { tables, where }) {
6403
- return privateSanitise(sql, {
6404
- tables,
6405
- where,
6434
+ function agentSql(sql, column, value, { schema, limit } = {}) {
6435
+ return privateAgentSql(sql, {
6436
+ column,
6437
+ value,
6438
+ schema,
6439
+ limit,
6406
6440
  throws: true
6407
6441
  });
6408
6442
  }
6409
- function safeSanitise(sql, { tables, where }) {
6410
- return privateSanitise(sql, {
6411
- tables,
6412
- where,
6413
- throws: false
6414
- });
6415
- }
6416
- function sanitiserFactory({ tables, where, throws = true }) {
6417
- return (expr) => throws ? privateSanitise(expr, {
6418
- tables,
6419
- where,
6443
+ function createAgentSql({ column, schema, value, limit, throws = true }) {
6444
+ if (value !== void 0) return (expr) => throws ? privateAgentSql(expr, {
6445
+ column,
6446
+ value,
6447
+ schema,
6448
+ limit,
6420
6449
  throws
6421
- }) : privateSanitise(expr, {
6422
- tables,
6423
- where,
6450
+ }) : privateAgentSql(expr, {
6451
+ column,
6452
+ value,
6453
+ schema,
6454
+ limit,
6424
6455
  throws
6425
6456
  });
6426
- }
6427
- function makeSanitiserFactory({ tables, guardCol, throws = true }) {
6428
6457
  function factory(guardVal) {
6429
- const where = {
6430
- ...guardCol,
6431
- value: guardVal
6432
- };
6433
- return throws ? sanitiserFactory({
6434
- tables,
6435
- where,
6458
+ return throws ? createAgentSql({
6459
+ column,
6460
+ schema,
6461
+ value: guardVal,
6462
+ limit,
6436
6463
  throws
6437
- }) : sanitiserFactory({
6438
- tables,
6439
- where,
6464
+ }) : createAgentSql({
6465
+ column,
6466
+ schema,
6467
+ value: guardVal,
6468
+ limit,
6440
6469
  throws
6441
6470
  });
6442
6471
  }
6443
6472
  return factory;
6444
6473
  }
6474
+ function privateAgentSql(sql, { column, value, schema, limit, throws }) {
6475
+ const guardCol = getQualifiedColumnFromString(column);
6476
+ if (!guardCol.ok) throw guardCol.error;
6477
+ const ast = parseSql(sql);
6478
+ if (!ast.ok) return returnOrThrow(ast, throws);
6479
+ const ast2 = checkJoins(ast.data, schema);
6480
+ if (!ast2.ok) return returnOrThrow(ast2, throws);
6481
+ const where = {
6482
+ ...guardCol.data,
6483
+ value
6484
+ };
6485
+ const san = addGuards(ast2.data, where, limit);
6486
+ if (!san.ok) return returnOrThrow(san, throws);
6487
+ const res = outputSql(san.data);
6488
+ if (throws) return res;
6489
+ return Ok(res);
6490
+ }
6445
6491
  //#endregion
6446
- export { defineTables, makeSanitiserFactory, outputSql, parseSql, safeSanitise, sanitise, sanitiseSql, sanitiserFactory };
6492
+ export { agentSql, createAgentSql, defineSchema, outputSql, parseSql, addGuards as sanitiseSql };
@@ -0,0 +1,266 @@
1
//#region src/ast.d.ts

// ── Statement ────────────────────────────────────────────────────────────

/**
 * Root node of a parsed `SELECT`. Clauses that may be absent are typed
 * `| null`; the select list and the joins are plain arrays.
 */
interface SelectStatement {
  readonly type: "select";
  distinct: Distinct | DistinctOn | null;
  columns: Column[];
  from: SelectFrom;
  joins: JoinClause[];
  where: WhereRoot | null;
  groupBy: GroupByClause | null;
  having: HavingClause | null;
  orderBy: OrderByClause | null;
  limit: LimitClause | null;
  offset: OffsetClause | null;
}

/** Plain `DISTINCT`. */
interface Distinct {
  readonly type: "distinct";
}

/** PostgreSQL: DISTINCT ON (col1, col2, ...) */
interface DistinctOn {
  readonly type: "distinct_on";
  columns: WhereValue[];
}

// ── Select list ──────────────────────────────────────────────────────────

/** One select-list item with an optional alias. */
interface Column {
  readonly type: "column";
  expr: ColumnExpr;
  alias?: Alias;
}

/**
 * A select-list expression. Which optional field is populated presumably
 * follows `kind` (`table` for "qualified_wildcard", `expr` for "expr").
 * This type does not enforce that pairing.
 */
interface ColumnExpr {
  readonly type: "column_expr";
  kind: "wildcard" | "qualified_wildcard" | "expr";
  table?: string;
  expr?: WhereValue;
}

interface Alias {
  readonly type: "alias";
  name: string;
}

// ── FROM / JOIN ──────────────────────────────────────────────────────────

type SelectFrom = {
  readonly type: "select_from";
  table: TableRef;
};

/** A table name with an optional schema qualifier. */
interface TableRef {
  readonly type: "table_ref";
  schema?: string;
  name: string;
}

type JoinType =
  | "inner"
  | "inner_outer"
  | "left"
  | "left_outer"
  | "right"
  | "right_outer"
  | "full"
  | "full_outer"
  | "cross"
  | "natural";

type JoinCondition =
  | { readonly type: "join_on"; expr: WhereExpr }
  | { readonly type: "join_using"; columns: string[] };

interface JoinClause {
  readonly type: "join";
  joinType: JoinType;
  table: TableRef;
  /** An ON expression, a USING column list, or `null` for no condition. */
  condition: JoinCondition | null;
}

// ── Boolean conditions (WHERE / HAVING / JOIN ... ON / NOT) ──────────────

type WhereRoot = {
  readonly type: "where_root";
  inner: WhereExpr;
};

/** Condition nodes accepted by WHERE, HAVING, JOIN ... ON and NOT. */
type WhereExpr =
  | WhereAnd
  | WhereOr
  | WhereNot
  | WhereComparison
  | WhereIsNull
  | WhereIsBool
  | WhereBetween
  | WhereIn
  | WhereLike
  | WhereTsMatch;

interface WhereAnd {
  readonly type: "where_and";
  left: WhereExpr;
  right: WhereExpr;
}

interface WhereOr {
  readonly type: "where_or";
  left: WhereExpr;
  right: WhereExpr;
}

interface WhereNot {
  readonly type: "where_not";
  expr: WhereExpr;
}

type ComparisonOperator = "=" | "<>" | "!=" | "<" | ">" | "<=" | ">=";

interface WhereComparison {
  readonly type: "where_comparison";
  operator: ComparisonOperator;
  left: WhereValue;
  right: WhereValue;
}

/** `expr IS [NOT] NULL` */
interface WhereIsNull {
  readonly type: "where_is_null";
  not: boolean;
  expr: WhereValue;
}

type IsBoolTarget = boolean | "unknown";

/** `expr IS [NOT] TRUE | FALSE | UNKNOWN` */
interface WhereIsBool {
  readonly type: "where_is_bool";
  not: boolean;
  expr: WhereValue;
  target: IsBoolTarget;
}

/** `expr [NOT] BETWEEN low AND high` */
interface WhereBetween {
  readonly type: "where_between";
  not: boolean;
  expr: WhereValue;
  low: WhereValue;
  high: WhereValue;
}

/** `expr [NOT] IN (list)` */
interface WhereIn {
  readonly type: "where_in";
  not: boolean;
  expr: WhereValue;
  list: WhereValue[];
}

/** "ilike" is PostgreSQL-specific (case-insensitive LIKE) */
type LikeOp = "like" | "ilike";

interface WhereLike {
  readonly type: "where_like";
  not: boolean;
  op: LikeOp;
  expr: WhereValue;
  pattern: WhereValue;
}

/** PostgreSQL: text search match (@@) */
interface WhereTsMatch {
  readonly type: "where_ts_match";
  left: WhereValue;
  right: WhereValue;
}

// ── Value expressions ────────────────────────────────────────────────────

type WhereValue =
  | { readonly type: "where_value"; kind: "string"; value: string }
  | { readonly type: "where_value"; kind: "integer"; value: number }
  | { readonly type: "where_value"; kind: "float"; value: number }
  | { readonly type: "where_value"; kind: "bool"; value: boolean }
  | { readonly type: "where_value"; kind: "null" }
  | { readonly type: "where_value"; kind: "column_ref"; ref: ColumnRef }
  | { readonly type: "where_value"; kind: "func_call"; func: FuncCall }
  | WhereArith
  | WhereUnaryMinus
  | WhereJsonbOp
  | WherePgvectorOp
  | CaseExpr
  | CastExpr;

/** A column name, optionally qualified by table and schema. */
interface ColumnRef {
  readonly type: "column_ref";
  schema?: string;
  table?: string;
  name: string;
}

type FuncCallArg =
  | { kind: "wildcard" }
  | { kind: "args"; distinct: boolean; args: WhereValue[] };

interface FuncCall {
  readonly type: "func_call";
  name: string;
  args: FuncCallArg;
}

type ArithOp = "+" | "-" | "*" | "/" | "%" | "||";

interface WhereArith {
  readonly type: "where_arith";
  op: ArithOp;
  left: WhereValue;
  right: WhereValue;
}

interface WhereUnaryMinus {
  readonly type: "where_unary_minus";
  expr: WhereValue;
}

/** PostgreSQL: JSONB operators */
type JsonbOp = "->" | "->>" | "#>" | "#>>" | "?" | "?|" | "?&" | "@>";

/** PostgreSQL: JSONB binary operator expression */
interface WhereJsonbOp {
  readonly type: "where_jsonb_op";
  op: JsonbOp;
  left: WhereValue;
  right: WhereValue;
}

/** pgvector: distance operators */
type PgvectorOp = "<->" | "<#>" | "<=>" | "<+>" | "<~>" | "<%>";

/** pgvector: distance operator expression */
interface WherePgvectorOp {
  readonly type: "where_pgvector_op";
  op: PgvectorOp;
  left: WhereValue;
  right: WhereValue;
}

interface CaseWhen {
  condition: WhereValue;
  result: WhereValue;
}

/** `CASE [subject] WHEN ... THEN ... [ELSE ...] END`; omitted parts are `null`. */
interface CaseExpr {
  readonly type: "case_expr";
  subject: WhereValue | null;
  whens: CaseWhen[];
  else: WhereValue | null;
}

/** A cast of `expr` to the type named by `typeName`. */
interface CastExpr {
  readonly type: "cast_expr";
  expr: WhereValue;
  typeName: string;
}

// ── GROUP BY / HAVING / ORDER BY / LIMIT / OFFSET ────────────────────────

interface GroupByClause {
  readonly type: "group_by";
  items: WhereValue[];
}

interface HavingClause {
  readonly type: "having";
  expr: WhereExpr;
}

interface OrderByClause {
  readonly type: "order_by";
  items: OrderByItem[];
}

type SortDirection = "asc" | "desc";

type NullsOrder = "nulls_first" | "nulls_last";

interface OrderByItem {
  readonly type: "order_by_item";
  expr: WhereValue;
  direction?: SortDirection;
  nulls?: NullsOrder;
}

type LimitClause = {
  readonly type: "limit";
  value: number;
};

interface OffsetClause {
  readonly type: "offset";
  value: number;
}

//#endregion
246
//#region src/result.d.ts

/** Failed outcome: carries `error`; `unwrap()` is typed `never` (it cannot return). */
type Failure<E = unknown> = {
  ok: false;
  error: E;
  unwrap(): never;
};

/** Successful outcome: carries `data`; `unwrap()` yields a `T`. */
type Success<T = void> = {
  ok: true;
  data: T;
  unwrap(): T;
};

/** Outcome discriminated on `ok`; the error type defaults to `Error`. */
type Result<T = void, E = Error> = Success<T> | Failure<E>;

//#endregion
259
//#region src/joins.d.ts

/**
 * Declare the table schema used for JOIN checking.
 *
 * Each table maps column names to either `null` or `{ ft, fc }`. In that
 * object, `ft` must name a table in the same schema and `fc` one of that
 * table's columns (presumably a foreign-key reference).
 *
 * Typed as returning `T`, so the literal schema type is preserved.
 */
declare function defineSchema<T extends {
  [Table in keyof T]: Record<
    string,
    null | {
      [FK in keyof T & string]: {
        ft: FK;
        fc: keyof T[FK] & string;
      };
    }[keyof T & string]
  >;
}>(schema: T): T;

/**
 * Schema type derived from `defineSchema`'s signature.
 * NOTE(review): `ReturnType` of a generic function is resolved against the
 * type parameter's constraint, so this is likely a loose type. Confirm it is
 * precise enough for callers.
 */
type Schema = ReturnType<typeof defineSchema>;

//#endregion
// Chunk exports (short bundler-generated names, presumably re-exported by the
// package's entry-point typings).
export { SelectStatement as i, defineSchema as n, Result as r, Schema as t };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "agent-sql",
3
- "version": "0.1.2",
3
+ "version": "0.2.0",
4
4
  "description": "Sanitise LLM-generated SQL: SELECT-only parsing, tenant WHERE guards, schema-checked JOINs and a configurable LIMIT.",
5
5
  "homepage": "https://github.com/carderne/agent-sql#readme",
6
6
  "bugs": {
@@ -18,6 +18,7 @@
18
18
  "type": "module",
19
19
  "exports": {
20
20
  ".": "./dist/index.mjs",
21
+ "./drizzle": "./dist/drizzle.mjs",
21
22
  "./package.json": "./package.json"
22
23
  },
23
24
  "publishConfig": {
@@ -42,11 +43,20 @@
42
43
  "@types/pg": "^8.18.0",
43
44
  "@typescript/native-preview": "7.0.0-dev.20260316.1",
44
45
  "bumpp": "^11.0.1",
46
+ "drizzle-orm": "^0.45.1",
45
47
  "pg": "^8.20.0",
46
48
  "tsx": "^4.21.0",
47
49
  "typescript": "^5.9.3",
48
50
  "vite-plus": "^0.1.11"
49
51
  },
52
+ "peerDependencies": {
53
+ "drizzle-orm": ">=0.45"
54
+ },
55
+ "peerDependenciesMeta": {
56
+ "drizzle-orm": {
57
+ "optional": true
58
+ }
59
+ },
50
60
  "packageManager": "pnpm@10.32.1",
51
61
  "pnpm": {
52
62
  "overrides": {