agent-sql 0.3.1 → 0.3.2

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
Files changed (2)
  1. package/dist/index.mjs +123 -2
  2. package/package.json +1 -1
package/dist/index.mjs CHANGED
@@ -778,7 +778,12 @@ function checkWhereExpr(expr, allowed) {
778
778
  function defineSchema(schema) {
779
779
  return schema;
780
780
  }
781
- function checkJoins(ast, schema) {
781
/**
 * Runs the join-related sanitisation passes in order:
 *   1. checkJoinColumns — among other things, rejects any join when no schema
 *      is supplied;
 *   2. checkJoinContinuity — requires every table to be reachable from the
 *      FROM table through JOIN ON predicates.
 *
 * Returns the first Err encountered. Otherwise returns the Ok produced by the
 * last pass. Each pass receives the AST returned by the previous one.
 */
function validateJoins(ast, schema) {
  const columnsResult = checkJoinColumns(ast, schema);
  if (!columnsResult.ok) return columnsResult;
  // Thread the (possibly rewritten) AST forward, matching how the rest of the
  // sanitisation pipeline passes `.data` from one step to the next.
  return checkJoinContinuity(columnsResult.data);
}
786
+ function checkJoinColumns(ast, schema) {
782
787
  if (schema === void 0) {
783
788
  if (ast.joins.length > 0) return Err(new SanitiseError("No joins allowed when using simple API without schema."));
784
789
  return Ok(ast);
@@ -811,6 +816,122 @@ function getJoinTableRef(joinTableName, left, right) {
811
816
  foreign: left
812
817
  };
813
818
  }
819
/**
 * Validates that all tables in the query's FROM/JOIN clauses form a single
 * connected component via the ON predicates. A disconnected table would produce
 * an implicit cross product, which can leak data across tenants.
 *
 * Edges come from the logical structure of each ON clause, not from the union
 * of every column it mentions. The union approach let both of these pass, even
 * though each yields a cross product:
 *   - `ON a.id = b.a_id OR 1 = 1`
 *   - `ON a.x = 1 AND b.y = 2`
 *
 * Edge rules:
 *   - an atomic predicate (comparison, BETWEEN, IN, LIKE, ...) links every pair
 *     of query tables whose columns it references;
 *   - `X AND Y` links every pair that X or Y links;
 *   - `X OR Y` links only the pairs that BOTH X and Y link;
 *   - `NOT X` links nothing.
 * Joins without a condition, and USING joins, contribute no edges.
 *
 * Returns Ok(ast) with the AST unchanged if valid. Otherwise returns an Err
 * wrapping a SanitiseError that lists the disconnected tables in sorted order.
 */
function checkJoinContinuity(ast) {
  if (ast.joins.length === 0) return Ok(ast);
  const tables = new Set([ast.from.table.name]);
  for (const join of ast.joins) tables.add(join.table.name);
  // A single distinct table name has nothing to connect to.
  if (tables.size <= 1) return Ok(ast);

  const adjacency = new Map();
  for (const table of tables) adjacency.set(table, new Set());
  for (const join of ast.joins) {
    if (join.condition === null) continue;
    // USING joins add no edges, so a table joined only via USING is reported
    // as disconnected (the error message asks for ON predicates).
    if (join.condition.type === "join_using") continue;
    for (const [left, right] of collectJoinEdges(join.condition.expr, tables).values()) {
      adjacency.get(left).add(right);
      adjacency.get(right).add(left);
    }
  }

  // Breadth-first search from the FROM table. An index cursor avoids the
  // repeated re-indexing cost of Array#shift.
  const start = ast.from.table.name;
  const visited = new Set([start]);
  const queue = [start];
  for (let head = 0; head < queue.length; head++) {
    for (const neighbor of adjacency.get(queue[head]) ?? []) {
      if (visited.has(neighbor)) continue;
      visited.add(neighbor);
      queue.push(neighbor);
    }
  }
  if (visited.size === tables.size) return Ok(ast);

  const disconnected = [...tables].filter((table) => !visited.has(table)).sort();
  return Err(new SanitiseError(`Disconnected table(s) in query: ${disconnected.join(", ")}. All tables must be connected via JOIN ON predicates.`));
}

/**
 * Returns the table pairs that `expr` is guaranteed to relate. The result is a
 * Map from a canonical pair key to the `[tableA, tableB]` pair. Only names in
 * `tables` are considered. See checkJoinContinuity for the AND/OR/NOT rules.
 */
function collectJoinEdges(expr, tables) {
  switch (expr.type) {
    case "where_and": {
      const edges = collectJoinEdges(expr.left, tables);
      for (const [key, pair] of collectJoinEdges(expr.right, tables)) edges.set(key, pair);
      return edges;
    }
    case "where_or": {
      // A disjunction only guarantees a relationship that both branches
      // enforce. Rows satisfying a branch that ignores the pair would
      // otherwise be combined as a cross product.
      const leftEdges = collectJoinEdges(expr.left, tables);
      const rightEdges = collectJoinEdges(expr.right, tables);
      return new Map([...leftEdges].filter(([key]) => rightEdges.has(key)));
    }
    case "where_not":
      // A negated predicate typically matches most row pairs; treat it as
      // providing no link.
      return new Map();
    default: {
      const referenced = new Set();
      collectTableRefsFromExpr(expr, referenced);
      const relevant = [...referenced].filter((table) => tables.has(table));
      const edges = new Map();
      for (let i = 0; i < relevant.length; i++) {
        for (let j = i + 1; j < relevant.length; j++) {
          edges.set(joinEdgeKey(relevant[i], relevant[j]), [relevant[i], relevant[j]]);
        }
      }
      return edges;
    }
  }
}

/** Canonical, order-independent key for an unordered pair of table names. */
function joinEdgeKey(a, b) {
  return JSON.stringify(a < b ? [a, b] : [b, a]);
}
864
/**
 * Walks a WhereExpr tree and adds to `out` the table qualifier of every column
 * reference it finds. Unqualified columns and unrecognised node types
 * contribute nothing.
 */
function collectTableRefsFromExpr(expr, out) {
  const kind = expr.type;
  if (kind === "where_and" || kind === "where_or") {
    collectTableRefsFromExpr(expr.left, out);
    collectTableRefsFromExpr(expr.right, out);
    return;
  }
  if (kind === "where_not") {
    collectTableRefsFromExpr(expr.expr, out);
    return;
  }
  for (const operand of iterWhereExprOperands(expr)) collectTableRefsFromValue(operand, out);
}

/**
 * Yields, in evaluation order, the WhereValue operands of a leaf WhereExpr
 * node. Yields nothing for node types it does not recognise.
 */
function* iterWhereExprOperands(expr) {
  switch (expr.type) {
    case "where_comparison":
    case "where_ts_match":
      yield expr.left;
      yield expr.right;
      break;
    case "where_is_null":
    case "where_is_bool":
      yield expr.expr;
      break;
    case "where_between":
      yield expr.expr;
      yield expr.low;
      yield expr.high;
      break;
    case "where_in":
      yield expr.expr;
      yield* expr.list;
      break;
    case "where_like":
      yield expr.expr;
      yield expr.pattern;
      break;
  }
}

/**
 * Walks a WhereValue tree and adds to `out` the table qualifier of every column
 * reference it finds. Function-call arguments are visited only for the
 * `args` argument form.
 */
function collectTableRefsFromValue(val, out) {
  if (val.type === "where_value") {
    if (val.kind === "column_ref" && val.ref.table) {
      out.add(val.ref.table);
    } else if (val.kind === "func_call" && val.func.args.kind === "args") {
      for (const arg of val.func.args.args) collectTableRefsFromValue(arg, out);
    }
    return;
  }
  for (const child of iterNestedWhereValues(val)) collectTableRefsFromValue(child, out);
}

/**
 * Yields, in evaluation order, the child WhereValues of a composite WhereValue
 * node (arithmetic, operators, unary minus, CAST, CASE). Yields nothing for
 * other node types.
 */
function* iterNestedWhereValues(val) {
  switch (val.type) {
    case "where_arith":
    case "where_jsonb_op":
    case "where_pgvector_op":
      yield val.left;
      yield val.right;
      break;
    case "where_unary_minus":
    case "cast_expr":
      yield val.expr;
      break;
    case "case_expr":
      if (val.subject) yield val.subject;
      for (const when of val.whens) {
        yield when.condition;
        yield when.result;
      }
      if (val.else) yield val.else;
      break;
  }
}
814
935
  //#endregion
815
936
  //#region src/graph.ts
816
937
  function insertNeededGuardJoins(ast, schema, guards, autoJoin) {
@@ -7375,7 +7496,7 @@ function privateAgentSql(sql, { guards: guardsRaw, schema, autoJoin, limit, pret
7375
7496
  if (!guards.ok) throw guards.error;
7376
7497
  const ast = parseSql(sql);
7377
7498
  if (!ast.ok) return returnOrThrow(ast, throws);
7378
- const ast2 = checkJoins(ast.data, schema);
7499
+ const ast2 = validateJoins(ast.data, schema);
7379
7500
  if (!ast2.ok) return returnOrThrow(ast2, throws);
7380
7501
  const ast3 = checkFunctions(ast2.data, db, allowExtraFunctions);
7381
7502
  if (!ast3.ok) return returnOrThrow(ast3, throws);
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "agent-sql",
3
- "version": "0.3.1",
3
+ "version": "0.3.2",
4
4
  "description": "A starter for creating a TypeScript package.",
5
5
  "keywords": [
6
6
  "agent",