agent-sql 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (3) hide show
  1. package/README.md +13 -15
  2. package/dist/index.mjs +140 -2
  3. package/package.json +1 -1
package/README.md CHANGED
@@ -4,20 +4,20 @@ Sanitise agent-written SQL for multi-tenant DBs.
4
4
 
5
5
  You provide a tenant ID, and the agent supplies the query.
6
6
 
7
- agent-sql works by fully parsing the supplied SQL query into an AST.
8
- The grammar ONLY accepts `SELECT` statements. Anything else is an error.
9
- CTEs and other complex things that we aren't confident of securing: error.
10
-
11
- It ensures that the needed tenant table is somewhere in the query,
12
- and adds a `WHERE` clause ensuring that only values from the supplied ID are returned.
13
- Then it checks that the tables and `JOIN`s follow the schema, preventing sneaky joins.
7
+ Apparently this is how [Trigger.dev does it](https://x.com/mattaitken/status/2033928542975639785).
8
+ And [Cloudflare](https://x.com/thomas_ankcorn/status/2033931057133748330).
14
9
 
15
- Function calls also go through a whitelist (configurable).
10
+ ## How it works
16
11
 
17
- Finally, we throw in a `LIMIT` clause (configurable) to prevent accidental LLM denial-of-service.
12
+ agent-sql works by fully parsing the supplied SQL query into an AST and transforming it:
18
13
 
19
- Apparently this is how [Trigger.dev does it](https://x.com/mattaitken/status/2033928542975639785).
20
- And [Cloudflare](https://x.com/thomas_ankcorn/status/2033931057133748330).
14
+ - **Only `SELECT`:** `INSERT`, `DROP` and every other kind of statement are rejected.
15
+ - **Reduced subset:** CTEs, subqueries and other tricky things are rejected.
16
+ - **Limited functions:** passed through a (configurable) whitelist.
17
+ - **No DoS:** a default `LIMIT` is applied, but can be adjusted.
18
+ - **`WHERE` guards:** one or more tenant/ownership conditions are inserted into the `WHERE` clause.
19
+ - **`JOIN`s added:** automatically, when needed to reach the guarded tenant tables (so the agent doesn't spend tokens writing them).
20
+ - **No sneaky joins:** no `join secrets on true`. We have your back.
21
21
 
22
22
  ## Quickstart
23
23
 
@@ -28,17 +28,15 @@ npm install agent-sql
28
28
  ```ts
29
29
  import { agentSql } from "agent-sql";
30
30
 
31
- const sql = agentSql(`SELECT * FROM msg`, "msg.user_id", 123);
31
+ const sql = agentSql(`SELECT * FROM msg`, "msg.tenant_id", 123);
32
32
 
33
33
  console.log(sql);
34
34
  // SELECT *
35
35
  // FROM msg
36
- // WHERE msg.user_id = 123
36
+ // WHERE msg.tenant_id = 123
37
37
  // LIMIT 10000
38
38
  ```
39
39
 
40
- `agent-sql` parses the SQL, enforces a mandatory equality filter on the given column as the outermost `AND` condition (so it cannot be short-circuited by agent-supplied `OR` clauses), and returns the sanitised SQL string.
41
-
42
40
  ## Usage
43
41
 
44
42
  ### Define once, use many times
package/dist/index.mjs CHANGED
@@ -778,7 +778,12 @@ function checkWhereExpr(expr, allowed) {
778
778
  function defineSchema(schema) {
779
779
  return schema;
780
780
  }
781
- function checkJoins(ast, schema) {
781
/**
 * Run both JOIN validations on a parsed query, in order:
 *
 * 1. `checkJoinColumns` checks the JOINs against `schema`. When `schema` is
 *    undefined, a query containing any JOIN is rejected.
 * 2. `checkJoinContinuity` rejects queries whose tables are not all connected
 *    through JOIN ON predicates.
 *
 * Each stage receives the AST produced by the previous stage. This matches how
 * the rest of the sanitisation pipeline threads `result.data` between steps.
 *
 * @param ast - Parsed SELECT AST.
 * @param schema - Schema object, or undefined for the simple (schema-less) API.
 * @returns The Ok result of the last stage, or the first Err encountered.
 */
function validateJoins(ast, schema) {
  const columnsResult = checkJoinColumns(ast, schema);
  if (!columnsResult.ok) return columnsResult;
  // Feed the previous stage's output forward rather than the original input,
  // so any AST returned by checkJoinColumns is never silently discarded.
  return checkJoinContinuity(columnsResult.data);
}
786
+ function checkJoinColumns(ast, schema) {
782
787
  if (schema === void 0) {
783
788
  if (ast.joins.length > 0) return Err(new SanitiseError("No joins allowed when using simple API without schema."));
784
789
  return Ok(ast);
@@ -811,6 +816,122 @@ function getJoinTableRef(joinTableName, left, right) {
811
816
  foreign: left
812
817
  };
813
818
  }
819
/**
 * Reject queries whose FROM/JOIN tables do not form one connected graph.
 *
 * Every table named in FROM or in a JOIN is a node. Two nodes are linked when
 * a single JOIN ... ON predicate contains column references qualified with
 * both table names. A table that cannot be reached from the FROM table would
 * be combined with the others as an implicit cross product, which can leak
 * data across tenants.
 *
 * JOINs with no condition and JOIN ... USING contribute no links. A table
 * attached only that way is therefore reported as disconnected.
 *
 * @param ast - Parsed SELECT AST.
 * @returns Ok(ast) unchanged when there are no JOINs or all tables are
 *   connected. Otherwise, Err(SanitiseError) listing the unreachable tables,
 *   sorted.
 */
function checkJoinContinuity(ast) {
  if (ast.joins.length === 0) return Ok(ast);

  const root = ast.from.table.name;
  const tables = new Set([root, ...ast.joins.map((join) => join.table.name)]);
  // Only one distinct table name: nothing to connect.
  if (tables.size <= 1) return Ok(ast);

  const links = new Map([...tables].map((name) => [name, new Set()]));
  for (const { condition } of ast.joins) {
    if (condition === null || condition.type === "join_using") continue;
    const mentioned = new Set();
    collectTableRefsFromExpr(condition.expr, mentioned);
    // Qualifiers that are not tables of this query cannot form an edge.
    const endpoints = [...mentioned].filter((name) => tables.has(name));
    endpoints.forEach((a, i) => {
      for (const b of endpoints.slice(i + 1)) {
        links.get(a).add(b);
        links.get(b).add(a);
      }
    });
  }

  // Breadth-first search from the FROM table, using a read cursor instead of shift().
  const reached = new Set([root]);
  const frontier = [root];
  for (let head = 0; head < frontier.length; head++) {
    for (const next of links.get(frontier[head]) ?? []) {
      if (reached.has(next)) continue;
      reached.add(next);
      frontier.push(next);
    }
  }
  if (reached.size === tables.size) return Ok(ast);

  const disconnected = [...tables].filter((name) => !reached.has(name)).sort();
  return Err(new SanitiseError(`Disconnected table(s) in query: ${disconnected.join(", ")}. All tables must be connected via JOIN ON predicates.`));
}
864
/**
 * Recursively collect the table qualifiers of every column reference inside a
 * WhereExpr (a boolean predicate such as a JOIN ... ON condition) into the Set
 * `out`. Operand values are delegated to collectTableRefsFromValue. Predicate
 * types not listed below contribute nothing.
 */
function collectTableRefsFromExpr(expr, out) {
  const visitValues = (...values) => {
    for (const value of values) collectTableRefsFromValue(value, out);
  };
  switch (expr.type) {
    // Boolean connectives: recurse into sub-predicates.
    case "where_and":
    case "where_or":
      collectTableRefsFromExpr(expr.left, out);
      collectTableRefsFromExpr(expr.right, out);
      return;
    case "where_not":
      collectTableRefsFromExpr(expr.expr, out);
      return;
    // Leaf predicates: visit their value operands in source order.
    case "where_comparison":
    case "where_ts_match":
      visitValues(expr.left, expr.right);
      return;
    case "where_is_null":
    case "where_is_bool":
      visitValues(expr.expr);
      return;
    case "where_between":
      visitValues(expr.expr, expr.low, expr.high);
      return;
    case "where_in":
      visitValues(expr.expr);
      visitValues(...expr.list);
      return;
    case "where_like":
      visitValues(expr.expr, expr.pattern);
      return;
    default:
      return;
  }
}
904
/**
 * Recursively collect table qualifiers from column references in a WhereValue
 * (an operand inside a predicate) into the Set `out`.
 *
 * Unqualified column references (falsy `ref.table`) add nothing. Function-call
 * arguments are visited only when `func.args.kind` is "args". Value types not
 * listed below contribute nothing.
 */
function collectTableRefsFromValue(val, out) {
  const visit = (child) => collectTableRefsFromValue(child, out);
  switch (val.type) {
    case "where_value":
      if (val.kind === "column_ref") {
        if (val.ref.table) out.add(val.ref.table);
      } else if (val.kind === "func_call" && val.func.args.kind === "args") {
        for (const arg of val.func.args.args) visit(arg);
      }
      return;
    // Binary operators: left operand first, then right.
    case "where_arith":
    case "where_jsonb_op":
    case "where_pgvector_op":
      visit(val.left);
      visit(val.right);
      return;
    // Unary wrappers around a single operand.
    case "where_unary_minus":
    case "cast_expr":
      visit(val.expr);
      return;
    case "case_expr":
      if (val.subject) visit(val.subject);
      for (const { condition, result } of val.whens) {
        visit(condition);
        visit(result);
      }
      if (val.else) visit(val.else);
      return;
    default:
      return;
  }
}
814
935
  //#endregion
815
936
  //#region src/graph.ts
816
937
  function insertNeededGuardJoins(ast, schema, guards, autoJoin) {
@@ -822,6 +943,7 @@ function resolveGraphForJoins(ast, schema, guards) {
822
943
  const haveTables = /* @__PURE__ */ new Set();
823
944
  haveTables.add(ast.from.table.name);
824
945
  for (const join of ast.joins) haveTables.add(join.table.name);
946
+ const originalTables = new Set(haveTables);
825
947
  const adj = buildAdjacency(schema);
826
948
  const newJoins = [];
827
949
  for (const guard of guards) {
@@ -839,8 +961,10 @@ function resolveGraphForJoins(ast, schema, guards) {
839
961
  }
840
962
  }
841
963
  if (newJoins.length === 0) return Ok(ast);
964
+ const columns = qualifyWildcards(ast.columns, originalTables);
842
965
  return Ok({
843
966
  ...ast,
967
+ columns,
844
968
  joins: [...ast.joins, ...newJoins]
845
969
  });
846
970
  }
@@ -927,6 +1051,20 @@ function edgeToJoin(edge, fromTable) {
927
1051
  }
928
1052
  };
929
1053
  }
1054
/**
 * Replace every bare `*` in a SELECT column list with one `table.*` entry per
 * table in `tables`, in iteration order. All other columns are kept as-is and
 * keep their position.
 *
 * The caller passes only the tables the query originally named, right before
 * appending auto-inserted guard JOINs. This keeps `*` from starting to return
 * the columns of those newly joined tables.
 *
 * @param columns - SELECT column list from the AST.
 * @param tables - Iterable of table names to qualify `*` with.
 * @returns `columns` itself when it has no bare wildcard; otherwise a new array.
 */
function qualifyWildcards(columns, tables) {
  const isBareWildcard = (col) => col.expr.kind === "wildcard";
  if (!columns.some(isBareWildcard)) return columns;
  const qualifiedStar = (table) => ({
    type: "column",
    expr: { type: "column_expr", kind: "qualified_wildcard", table },
  });
  return columns.flatMap((col) => (isBareWildcard(col) ? Array.from(tables, qualifiedStar) : [col]));
}
930
1068
  //#endregion
931
1069
  //#region src/utils.ts
932
1070
  function unreachable(x) {
@@ -7358,7 +7496,7 @@ function privateAgentSql(sql, { guards: guardsRaw, schema, autoJoin, limit, pret
7358
7496
  if (!guards.ok) throw guards.error;
7359
7497
  const ast = parseSql(sql);
7360
7498
  if (!ast.ok) return returnOrThrow(ast, throws);
7361
- const ast2 = checkJoins(ast.data, schema);
7499
+ const ast2 = validateJoins(ast.data, schema);
7362
7500
  if (!ast2.ok) return returnOrThrow(ast2, throws);
7363
7501
  const ast3 = checkFunctions(ast2.data, db, allowExtraFunctions);
7364
7502
  if (!ast3.ok) return returnOrThrow(ast3, throws);
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "agent-sql",
3
- "version": "0.3.0",
3
+ "version": "0.3.2",
4
4
  "description": "A starter for creating a TypeScript package.",
5
5
  "keywords": [
6
6
  "agent",