agent-sql 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.mjs +127 -3
- package/package.json +1 -1
package/dist/index.mjs
CHANGED
|
@@ -778,7 +778,12 @@ function checkWhereExpr(expr, allowed) {
|
|
|
778
778
|
/**
 * Identity helper for declaring a join schema: hands back exactly the object
 * it was given.
 * NOTE(review): presumably exists so TypeScript callers get inference on the
 * schema literal — confirm against the src/ sources.
 *
 * @param schema - schema object (returned unchanged)
 * @returns the same `schema` reference
 */
function defineSchema(schema) {
	const declared = schema;
	return declared;
}
|
|
781
|
-
function
|
|
781
|
+
/**
 * Run every JOIN-related validation over a parsed query.
 *
 * Step 1 (checkJoinColumns): every joined table and ON column must be allowed
 * by `schema`; with no schema (simple API) any JOIN is rejected.
 * Step 2 (checkJoinContinuity): every table must be connected to the FROM
 * table through ON predicates.
 *
 * @param ast - parsed query AST
 * @param schema - join schema, or undefined for the schema-less simple API
 * @returns Ok(ast) when both checks pass, otherwise the first Err produced.
 */
function validateJoins(ast, schema) {
	const columnsChecked = checkJoinColumns(ast, schema);
	if (!columnsChecked.ok) return columnsChecked;
	// Feed the *validated* AST forward, matching how the rest of the
	// sanitising pipeline threads `.data` from one step to the next.
	return checkJoinContinuity(columnsChecked.data);
}
|
|
786
|
+
function checkJoinColumns(ast, schema) {
|
|
782
787
|
if (schema === void 0) {
|
|
783
788
|
if (ast.joins.length > 0) return Err(new SanitiseError("No joins allowed when using simple API without schema."));
|
|
784
789
|
return Ok(ast);
|
|
@@ -788,7 +793,10 @@ function checkJoins(ast, schema) {
|
|
|
788
793
|
const joinSettings = schema[join.table.name];
|
|
789
794
|
if (joinSettings === void 0) return Err(new SanitiseError(`Table ${join.table.name} is not allowed`));
|
|
790
795
|
if (join.condition === null || join.condition.type === "join_using" || join.condition.expr.type !== "where_comparison" || join.condition.expr.operator !== "=" || join.condition.expr.left.type !== "where_value" || join.condition.expr.left.kind !== "column_ref" || join.condition.expr.right.type !== "where_value" || join.condition.expr.right.kind !== "column_ref") return Err(new SanitiseError("Only JOIN ON column_ref = column_ref supported"));
|
|
791
|
-
const
|
|
796
|
+
const leftRef = join.condition.expr.left.ref;
|
|
797
|
+
const rightRef = join.condition.expr.right.ref;
|
|
798
|
+
if (leftRef.table !== join.table.name && rightRef.table !== join.table.name) return Err(new SanitiseError(`JOIN ${join.table.name} ON clause does not reference ${join.table.name}`));
|
|
799
|
+
const { joining, foreign } = getJoinTableRef(join.table.name, leftRef, rightRef);
|
|
792
800
|
const joinTableCol = joinSettings[joining.name];
|
|
793
801
|
if (joinTableCol === void 0) return Err(new SanitiseError(`Tried to join using ${join.table.name}.${joining.name}`));
|
|
794
802
|
if (joinTableCol === null) {
|
|
@@ -811,6 +819,122 @@ function getJoinTableRef(joinTableName, left, right) {
|
|
|
811
819
|
foreign: left
|
|
812
820
|
};
|
|
813
821
|
}
|
|
822
|
+
/**
 * Ensure every table named in the FROM/JOIN clauses belongs to one connected
 * component. Two tables are linked when a single JOIN ... ON predicate
 * references columns of both. A table with no such link would be combined via
 * an implicit cross product, which can leak rows across tenants.
 *
 * JOINs whose condition is null or a USING clause contribute no links, so any
 * table reachable only through them is reported as disconnected.
 *
 * @param ast - parsed query AST
 * @returns Ok(ast) (same object) when connected; otherwise an Err wrapping a
 *   SanitiseError that lists the unreachable tables in sorted order.
 */
function checkJoinContinuity(ast) {
	if (ast.joins.length === 0) return Ok(ast);
	const root = ast.from.table.name;
	const tableNames = new Set([root, ...ast.joins.map((j) => j.table.name)]);
	// A self-join only (or nothing joined) cannot be disconnected.
	if (tableNames.size <= 1) return Ok(ast);

	// Undirected graph: table name -> names of tables it shares an ON predicate with.
	const neighbours = new Map([...tableNames].map((name) => [name, new Set()]));
	for (const { condition } of ast.joins) {
		if (condition === null || condition.type === "join_using") continue;
		const seen = new Set();
		collectTableRefsFromExpr(condition.expr, seen);
		// Ignore names that are not tables in this query.
		const linked = [...seen].filter((name) => tableNames.has(name));
		linked.forEach((a, idx) => {
			for (const b of linked.slice(idx + 1)) {
				neighbours.get(a).add(b);
				neighbours.get(b).add(a);
			}
		});
	}

	// Depth-first walk from the FROM table; visit order does not affect the reached set.
	const reached = new Set([root]);
	const pending = [root];
	while (pending.length > 0) {
		const current = pending.pop();
		for (const next of neighbours.get(current) ?? []) {
			if (reached.has(next)) continue;
			reached.add(next);
			pending.push(next);
		}
	}
	if (reached.size === tableNames.size) return Ok(ast);

	const disconnected = [...tableNames].filter((name) => !reached.has(name)).sort();
	return Err(new SanitiseError(`Disconnected table(s) in query: ${disconnected.join(", ")}. All tables must be connected via JOIN ON predicates.`));
}
|
|
867
|
+
/**
 * Walk a boolean WHERE/ON expression node and add to `out` the table
 * qualifier of every column reference found in it. Value-level operands are
 * delegated to collectTableRefsFromValue; node types not handled here are
 * silently ignored.
 *
 * @param expr - WhereExpr node
 * @param out - Set of table names, mutated in place
 */
function collectTableRefsFromExpr(expr, out) {
	const visitValue = (node) => collectTableRefsFromValue(node, out);
	const { type } = expr;
	if (type === "where_and" || type === "where_or") {
		collectTableRefsFromExpr(expr.left, out);
		collectTableRefsFromExpr(expr.right, out);
	} else if (type === "where_not") {
		collectTableRefsFromExpr(expr.expr, out);
	} else if (type === "where_comparison" || type === "where_ts_match") {
		visitValue(expr.left);
		visitValue(expr.right);
	} else if (type === "where_is_null" || type === "where_is_bool") {
		visitValue(expr.expr);
	} else if (type === "where_between") {
		visitValue(expr.expr);
		visitValue(expr.low);
		visitValue(expr.high);
	} else if (type === "where_in") {
		visitValue(expr.expr);
		for (const item of expr.list) visitValue(item);
	} else if (type === "where_like") {
		visitValue(expr.expr);
		visitValue(expr.pattern);
	}
}
|
|
907
|
+
/**
 * Walk a value-level expression node and add to `out` the table qualifier of
 * every column reference found in it. Column references without a truthy
 * table qualifier contribute nothing; unrecognised node types are ignored.
 *
 * @param val - WhereValue node
 * @param out - Set of table names, mutated in place
 */
function collectTableRefsFromValue(val, out) {
	const visit = (child) => collectTableRefsFromValue(child, out);
	const nodeType = val.type;

	if (nodeType === "where_value") {
		if (val.kind === "column_ref") {
			if (val.ref.table) out.add(val.ref.table);
		} else if (val.kind === "func_call" && val.func.args.kind === "args") {
			for (const arg of val.func.args.args) visit(arg);
		}
		return;
	}

	if (nodeType === "where_arith" || nodeType === "where_jsonb_op" || nodeType === "where_pgvector_op") {
		visit(val.left);
		visit(val.right);
		return;
	}

	if (nodeType === "where_unary_minus" || nodeType === "cast_expr") {
		visit(val.expr);
		return;
	}

	if (nodeType === "case_expr") {
		if (val.subject) visit(val.subject);
		for (const { condition, result } of val.whens) {
			visit(condition);
			visit(result);
		}
		if (val.else) visit(val.else);
	}
}
|
|
814
938
|
//#endregion
|
|
815
939
|
//#region src/graph.ts
|
|
816
940
|
function insertNeededGuardJoins(ast, schema, guards, autoJoin) {
|
|
@@ -7375,7 +7499,7 @@ function privateAgentSql(sql, { guards: guardsRaw, schema, autoJoin, limit, pret
|
|
|
7375
7499
|
if (!guards.ok) throw guards.error;
|
|
7376
7500
|
const ast = parseSql(sql);
|
|
7377
7501
|
if (!ast.ok) return returnOrThrow(ast, throws);
|
|
7378
|
-
const ast2 =
|
|
7502
|
+
const ast2 = validateJoins(ast.data, schema);
|
|
7379
7503
|
if (!ast2.ok) return returnOrThrow(ast2, throws);
|
|
7380
7504
|
const ast3 = checkFunctions(ast2.data, db, allowExtraFunctions);
|
|
7381
7505
|
if (!ast3.ok) return returnOrThrow(ast3, throws);
|