agent-sql 0.2.3 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -4,20 +4,20 @@ Sanitise agent-written SQL for multi-tenant DBs.
4
4
 
5
5
  You provide a tenant ID, and the agent supplies the query.
6
6
 
7
- agent-sql works by fully parsing the supplied SQL query into an AST.
8
- The grammar ONLY accepts `SELECT` statements. Anything else is an error.
9
- CTEs and other complex things that we aren't confident of securing: error.
10
-
11
- It ensures that the needed tenant table is somewhere in the query,
12
- and adds a `WHERE` clause ensuring that only values from the supplied ID are returned.
13
- Then it checks that the tables and `JOIN`s follow the schema, preventing sneaky joins.
7
+ Apparently this is how [Trigger.dev does it](https://x.com/mattaitken/status/2033928542975639785).
8
+ And [Cloudflare](https://x.com/thomas_ankcorn/status/2033931057133748330).
14
9
 
15
- Function calls also go through a whitelist (configurable).
10
+ ## How it works
16
11
 
17
- Finally, we throw in a `LIMIT` clause (configurable) to prevent accidental LLM denial-of-service.
12
+ agent-sql works by fully parsing the supplied SQL query into an AST and transforming it:
18
13
 
19
- Apparently this is how [Trigger.dev does it](https://x.com/mattaitken/status/2033928542975639785).
20
- And [Cloudflare](https://x.com/thomas_ankcorn/status/2033931057133748330).
14
+ - **Only `SELECT`:** it's impossible to insert, drop or do anything else.
15
+ - **Reduced subset:** CTEs, subqueries and other tricky things are rejected.
16
+ - **Limited functions:** passed through a (configurable) whitelist.
17
+ - **No DoS:** a default `LIMIT` is applied, but can be adjusted.
18
+ - **`WHERE` guards:** allows multiple tenant/ownership conditions to be inserted.
19
+ - **`JOIN`s added:** if needed to reach the guard tenant tables (saves on tokens).
20
+ - **No sneaky joins:** no `join secrets on true`. We have your back.
21
21
 
22
22
  ## Quickstart
23
23
 
@@ -28,17 +28,15 @@ npm install agent-sql
28
28
  ```ts
29
29
  import { agentSql } from "agent-sql";
30
30
 
31
- const sql = agentSql(`SELECT * FROM msg`, "msg.user_id", 123);
31
+ const sql = agentSql(`SELECT * FROM msg`, "msg.tenant_id", 123);
32
32
 
33
33
  console.log(sql);
34
34
  // SELECT *
35
35
  // FROM msg
36
- // WHERE msg.user_id = 123
36
+ // WHERE msg.tenant_id = 123
37
37
  // LIMIT 10000
38
38
  ```
39
39
 
40
- `agent-sql` parses the SQL, enforces a mandatory equality filter on the given column as the outermost `AND` condition (so it cannot be short-circuited by agent-supplied `OR` clauses), and returns the sanitised SQL string.
41
-
42
40
  ## Usage
43
41
 
44
42
  ### Define once, use many times
package/dist/index.d.mts CHANGED
@@ -26,18 +26,21 @@ declare function parseSql(expr: string): Result<SelectStatement>;
26
26
  //#region src/index.d.ts
27
27
  declare function agentSql<S extends string>(sql: string, column: S & OneOrTwoDots<S>, value: GuardVal, {
28
28
  schema,
29
+ autoJoin,
29
30
  limit,
30
31
  pretty,
31
32
  db,
32
33
  allowExtraFunctions
33
34
  }?: {
34
35
  schema?: Schema;
36
+ autoJoin?: boolean;
35
37
  limit?: number;
36
38
  pretty?: boolean;
37
39
  db?: DbType;
38
40
  allowExtraFunctions?: string[];
39
41
  }): string;
40
42
  declare function createAgentSql<T extends Schema, S extends SchemaGuardKeys<T>>(schema: T, guards: Record<S, GuardVal>, opts: {
43
+ autoJoin?: boolean;
41
44
  limit?: number;
42
45
  pretty?: boolean;
43
46
  throws: false;
@@ -45,6 +48,7 @@ declare function createAgentSql<T extends Schema, S extends SchemaGuardKeys<T>>(
45
48
  allowExtraFunctions?: string[];
46
49
  }): (expr: string) => Result<string>;
47
50
  declare function createAgentSql<T extends Schema, S extends SchemaGuardKeys<T>>(schema: T, guards: Record<S, GuardVal>, opts?: {
51
+ autoJoin?: boolean;
48
52
  limit?: number;
49
53
  pretty?: boolean;
50
54
  throws?: true;
package/dist/index.mjs CHANGED
@@ -774,6 +774,177 @@ function checkWhereExpr(expr, allowed) {
774
774
  }
775
775
  }
776
776
  //#endregion
777
+ //#region src/joins.ts
778
+ function defineSchema(schema) {
779
+ return schema;
780
+ }
781
+ function checkJoins(ast, schema) {
782
+ if (schema === void 0) {
783
+ if (ast.joins.length > 0) return Err(new SanitiseError("No joins allowed when using simple API without schema."));
784
+ return Ok(ast);
785
+ }
786
+ if (!(ast.from.table.name in schema)) return Err(new SanitiseError(`Table ${ast.from.table.name} is not allowed`));
787
+ for (const join of ast.joins) {
788
+ const joinSettings = schema[join.table.name];
789
+ if (joinSettings === void 0) return Err(new SanitiseError(`Table ${join.table.name} is not allowed`));
790
+ if (join.condition === null || join.condition.type === "join_using" || join.condition.expr.type !== "where_comparison" || join.condition.expr.operator !== "=" || join.condition.expr.left.type !== "where_value" || join.condition.expr.left.kind !== "column_ref" || join.condition.expr.right.type !== "where_value" || join.condition.expr.right.kind !== "column_ref") return Err(new SanitiseError("Only JOIN ON column_ref = column_ref supported"));
791
+ const { joining, foreign } = getJoinTableRef(join.table.name, join.condition.expr.left.ref, join.condition.expr.right.ref);
792
+ const joinTableCol = joinSettings[joining.name];
793
+ if (joinTableCol === void 0) return Err(new SanitiseError(`Tried to join using ${join.table.name}.${joining.name}`));
794
+ if (joinTableCol === null) {
795
+ const foreignTableSettings = schema[foreign.table];
796
+ if (foreignTableSettings === void 0) return Err(new SanitiseError(`Table ${foreign.name} is not allowed`));
797
+ const foreignCol = foreignTableSettings[foreign.name];
798
+ if (foreignCol === void 0 || foreignCol === null) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
799
+ if (joining.table !== foreignCol.ft || joining.name !== foreignCol.fc) return Err(new SanitiseError(`Tried to join using ${joining.table}.${joining.name}`));
800
+ } else if (foreign.table !== joinTableCol.ft || foreign.name !== joinTableCol.fc) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
801
+ }
802
+ return Ok(ast);
803
+ }
804
+ function getJoinTableRef(joinTableName, left, right) {
805
+ if (left.table === joinTableName) return {
806
+ joining: left,
807
+ foreign: right
808
+ };
809
+ return {
810
+ joining: right,
811
+ foreign: left
812
+ };
813
+ }
814
+ //#endregion
815
+ //#region src/graph.ts
816
+ function insertNeededGuardJoins(ast, schema, guards, autoJoin) {
817
+ if (schema === void 0) return Ok(ast);
818
+ if (!autoJoin) return Ok(ast);
819
+ return resolveGraphForJoins(ast, schema, guards);
820
+ }
821
+ function resolveGraphForJoins(ast, schema, guards) {
822
+ const haveTables = /* @__PURE__ */ new Set();
823
+ haveTables.add(ast.from.table.name);
824
+ for (const join of ast.joins) haveTables.add(join.table.name);
825
+ const originalTables = new Set(haveTables);
826
+ const adj = buildAdjacency(schema);
827
+ const newJoins = [];
828
+ for (const guard of guards) {
829
+ if (haveTables.has(guard.table)) continue;
830
+ const path = bfsPath(adj, haveTables, guard.table);
831
+ if (path === null) return Err(new SanitiseError(`No join path from query tables to guard table '${guard.table}'`));
832
+ let current = [...haveTables].find((t) => path[0] && (path[0].tableA === t || path[0].tableB === t));
833
+ for (const edge of path) {
834
+ const neighbor = edge.tableA === current ? edge.tableB : edge.tableA;
835
+ if (!haveTables.has(neighbor)) {
836
+ newJoins.push(edgeToJoin(edge, neighbor));
837
+ haveTables.add(neighbor);
838
+ }
839
+ current = neighbor;
840
+ }
841
+ }
842
+ if (newJoins.length === 0) return Ok(ast);
843
+ const columns = qualifyWildcards(ast.columns, originalTables);
844
+ return Ok({
845
+ ...ast,
846
+ columns,
847
+ joins: [...ast.joins, ...newJoins]
848
+ });
849
+ }
850
+ function buildAdjacency(schema) {
851
+ const adj = /* @__PURE__ */ new Map();
852
+ const tables = schema;
853
+ for (const [tableName, cols] of Object.entries(tables)) {
854
+ if (!adj.has(tableName)) adj.set(tableName, []);
855
+ for (const [colName, def] of Object.entries(cols)) if (def && typeof def === "object" && "ft" in def && "fc" in def) {
856
+ const edge = {
857
+ tableA: tableName,
858
+ colA: colName,
859
+ tableB: def.ft,
860
+ colB: def.fc
861
+ };
862
+ adj.get(tableName).push(edge);
863
+ if (!adj.has(edge.tableB)) adj.set(edge.tableB, []);
864
+ adj.get(edge.tableB).push(edge);
865
+ }
866
+ }
867
+ return adj;
868
+ }
869
+ function bfsPath(adj, startTables, target) {
870
+ if (startTables.has(target)) return [];
871
+ const visited = new Set(startTables);
872
+ const queue = [];
873
+ for (const t of startTables) queue.push([t, []]);
874
+ while (queue.length > 0) {
875
+ const [current, path] = queue.shift();
876
+ for (const edge of adj.get(current) ?? []) {
877
+ const neighbor = edge.tableA === current ? edge.tableB : edge.tableA;
878
+ if (visited.has(neighbor)) continue;
879
+ visited.add(neighbor);
880
+ const newPath = [...path, edge];
881
+ if (neighbor === target) return newPath;
882
+ queue.push([neighbor, newPath]);
883
+ }
884
+ }
885
+ return null;
886
+ }
887
+ function edgeToJoin(edge, fromTable) {
888
+ const [localTable, localCol, foreignTable, foreignCol] = edge.tableA === fromTable ? [
889
+ edge.tableA,
890
+ edge.colA,
891
+ edge.tableB,
892
+ edge.colB
893
+ ] : [
894
+ edge.tableB,
895
+ edge.colB,
896
+ edge.tableA,
897
+ edge.colA
898
+ ];
899
+ return {
900
+ type: "join",
901
+ joinType: "inner",
902
+ table: {
903
+ type: "table_ref",
904
+ name: localTable
905
+ },
906
+ condition: {
907
+ type: "join_on",
908
+ expr: {
909
+ type: "where_comparison",
910
+ operator: "=",
911
+ left: {
912
+ type: "where_value",
913
+ kind: "column_ref",
914
+ ref: {
915
+ type: "column_ref",
916
+ table: localTable,
917
+ name: localCol
918
+ }
919
+ },
920
+ right: {
921
+ type: "where_value",
922
+ kind: "column_ref",
923
+ ref: {
924
+ type: "column_ref",
925
+ table: foreignTable,
926
+ name: foreignCol
927
+ }
928
+ }
929
+ }
930
+ }
931
+ };
932
+ }
933
+ function qualifyWildcards(columns, tables) {
934
+ if (!columns.some((c) => c.expr.kind === "wildcard")) return columns;
935
+ const qualified = [];
936
+ for (const col of columns) if (col.expr.kind === "wildcard") for (const table of tables) qualified.push({
937
+ type: "column",
938
+ expr: {
939
+ type: "column_expr",
940
+ kind: "qualified_wildcard",
941
+ table
942
+ }
943
+ });
944
+ else qualified.push(col);
945
+ return qualified;
946
+ }
947
+ //#endregion
777
948
  //#region src/utils.ts
778
949
  function unreachable(x) {
779
950
  throw new Error(`Unhandled variant: ${JSON.stringify(x)}`);
@@ -1108,44 +1279,6 @@ function tableEquals(a, b) {
1108
1279
  return a.schema == b.schema && a.name == b.name;
1109
1280
  }
1110
1281
  //#endregion
1111
- //#region src/joins.ts
1112
- function defineSchema(schema) {
1113
- return schema;
1114
- }
1115
- function checkJoins(ast, schema) {
1116
- if (schema === void 0) {
1117
- if (ast.joins.length > 0) return Err(new SanitiseError("No joins allowed when using simple API without schema."));
1118
- return Ok(ast);
1119
- }
1120
- if (!(ast.from.table.name in schema)) return Err(new SanitiseError(`Table ${ast.from.table.name} is not allowed`));
1121
- for (const join of ast.joins) {
1122
- const joinSettings = schema[join.table.name];
1123
- if (joinSettings === void 0) return Err(new SanitiseError(`Table ${join.table.name} is not allowed`));
1124
- if (join.condition === null || join.condition.type === "join_using" || join.condition.expr.type !== "where_comparison" || join.condition.expr.operator !== "=" || join.condition.expr.left.type !== "where_value" || join.condition.expr.left.kind !== "column_ref" || join.condition.expr.right.type !== "where_value" || join.condition.expr.right.kind !== "column_ref") return Err(new SanitiseError("Only JOIN ON column_ref = column_ref supported"));
1125
- const { joining, foreign } = getJoinTableRef(join.table.name, join.condition.expr.left.ref, join.condition.expr.right.ref);
1126
- const joinTableCol = joinSettings[joining.name];
1127
- if (joinTableCol === void 0) return Err(new SanitiseError(`Tried to join using ${join.table.name}.${joining.name}`));
1128
- if (joinTableCol === null) {
1129
- const foreignTableSettings = schema[foreign.table];
1130
- if (foreignTableSettings === void 0) return Err(new SanitiseError(`Table ${foreign.name} is not allowed`));
1131
- const foreignCol = foreignTableSettings[foreign.name];
1132
- if (foreignCol === void 0 || foreignCol === null) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
1133
- if (joining.table !== foreignCol.ft || joining.name !== foreignCol.fc) return Err(new SanitiseError(`Tried to join using ${joining.table}.${joining.name}`));
1134
- } else if (foreign.table !== joinTableCol.ft || foreign.name !== joinTableCol.fc) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
1135
- }
1136
- return Ok(ast);
1137
- }
1138
- function getJoinTableRef(joinTableName, left, right) {
1139
- if (left.table === joinTableName) return {
1140
- joining: left,
1141
- foreign: right
1142
- };
1143
- return {
1144
- joining: right,
1145
- foreign: left
1146
- };
1147
- }
1148
- //#endregion
1149
1282
  //#region src/sql.ohm-bundle.js
1150
1283
  const result = makeRecipe([
1151
1284
  "grammar",
@@ -7204,10 +7337,11 @@ function parseSql(expr) {
7204
7337
  }
7205
7338
  //#endregion
7206
7339
  //#region src/index.ts
7207
- function agentSql(sql, column, value, { schema, limit = DEFAULT_LIMIT, pretty = false, db = DEFAULT_DB, allowExtraFunctions = [] } = {}) {
7340
+ function agentSql(sql, column, value, { schema, autoJoin = true, limit = DEFAULT_LIMIT, pretty = false, db = DEFAULT_DB, allowExtraFunctions = [] } = {}) {
7208
7341
  return privateAgentSql(sql, {
7209
7342
  guards: { [column]: value },
7210
7343
  schema,
7344
+ autoJoin,
7211
7345
  limit,
7212
7346
  pretty,
7213
7347
  db,
@@ -7215,10 +7349,11 @@ function agentSql(sql, column, value, { schema, limit = DEFAULT_LIMIT, pretty =
7215
7349
  throws: true
7216
7350
  });
7217
7351
  }
7218
- function createAgentSql(schema, guards, { limit = DEFAULT_LIMIT, pretty = false, db = DEFAULT_DB, allowExtraFunctions = [], throws = true } = {}) {
7352
+ function createAgentSql(schema, guards, { autoJoin = true, limit = DEFAULT_LIMIT, pretty = false, db = DEFAULT_DB, allowExtraFunctions = [], throws = true } = {}) {
7219
7353
  return (expr) => throws ? privateAgentSql(expr, {
7220
7354
  guards,
7221
7355
  schema,
7356
+ autoJoin,
7222
7357
  limit,
7223
7358
  pretty,
7224
7359
  db,
@@ -7227,6 +7362,7 @@ function createAgentSql(schema, guards, { limit = DEFAULT_LIMIT, pretty = false,
7227
7362
  }) : privateAgentSql(expr, {
7228
7363
  guards,
7229
7364
  schema,
7365
+ autoJoin,
7230
7366
  limit,
7231
7367
  pretty,
7232
7368
  db,
@@ -7234,7 +7370,7 @@ function createAgentSql(schema, guards, { limit = DEFAULT_LIMIT, pretty = false,
7234
7370
  throws
7235
7371
  });
7236
7372
  }
7237
- function privateAgentSql(sql, { guards: guardsRaw, schema, limit, pretty, db, allowExtraFunctions, throws }) {
7373
+ function privateAgentSql(sql, { guards: guardsRaw, schema, autoJoin, limit, pretty, db, allowExtraFunctions, throws }) {
7238
7374
  const guards = resolveGuards(guardsRaw);
7239
7375
  if (!guards.ok) throw guards.error;
7240
7376
  const ast = parseSql(sql);
@@ -7243,7 +7379,9 @@ function privateAgentSql(sql, { guards: guardsRaw, schema, limit, pretty, db, al
7243
7379
  if (!ast2.ok) return returnOrThrow(ast2, throws);
7244
7380
  const ast3 = checkFunctions(ast2.data, db, allowExtraFunctions);
7245
7381
  if (!ast3.ok) return returnOrThrow(ast3, throws);
7246
- const san = applyGuards(ast3.data, guards.data, limit);
7382
+ const ast4 = insertNeededGuardJoins(ast3.data, schema, guards.data, autoJoin);
7383
+ if (!ast4.ok) return returnOrThrow(ast4, throws);
7384
+ const san = applyGuards(ast4.data, guards.data, limit);
7247
7385
  if (!san.ok) return returnOrThrow(san, throws);
7248
7386
  const res = outputSql(san.data, pretty);
7249
7387
  if (throws) return res;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "agent-sql",
3
- "version": "0.2.3",
3
+ "version": "0.3.1",
4
4
  "description": "A starter for creating a TypeScript package.",
5
5
  "keywords": [
6
6
  "agent",
@@ -36,7 +36,7 @@
36
36
  "test": "vp test",
37
37
  "check": "vp check",
38
38
  "prepublishOnly": "vp run build",
39
- "gen:types": "ohm generateBundles --withTypes --esm 'src/sql.ohm'",
39
+ "ohm": "ohm generateBundles --withTypes --esm 'src/sql.ohm'",
40
40
  "script": "vp exec tsx"
41
41
  },
42
42
  "dependencies": {