agent-sql 0.2.3 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +13 -15
- package/dist/index.d.mts +4 -0
- package/dist/index.mjs +180 -42
- package/package.json +2 -2
package/README.md
CHANGED
|
@@ -4,20 +4,20 @@ Sanitise agent-written SQL for multi-tenant DBs.
|
|
|
4
4
|
|
|
5
5
|
You provide a tenant ID, and the agent supplies the query.
|
|
6
6
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
CTEs and other complex things that we aren't confident of securing: error.
|
|
10
|
-
|
|
11
|
-
It ensures that the needed tenant table is somewhere in the query,
|
|
12
|
-
and adds a `WHERE` clause ensuring that only values from the supplied ID are returned.
|
|
13
|
-
Then it checks that the tables and `JOIN`s follow the schema, preventing sneaky joins.
|
|
7
|
+
Apparently this is how [Trigger.dev does it](https://x.com/mattaitken/status/2033928542975639785).
|
|
8
|
+
And [Cloudflare](https://x.com/thomas_ankcorn/status/2033931057133748330).
|
|
14
9
|
|
|
15
|
-
|
|
10
|
+
## How it works
|
|
16
11
|
|
|
17
|
-
|
|
12
|
+
agent-sql works by fully parsing the supplied SQL query into an AST and transforming it:
|
|
18
13
|
|
|
19
|
-
|
|
20
|
-
|
|
14
|
+
- **Only `SELECT`:** it's impossible to insert, drop or anything else.
|
|
15
|
+
- **Reduced subset:** CTEs, subqueries and other tricky things are rejected.
|
|
16
|
+
- **Limited functions:** passed through a (configurable) whitelist.
|
|
17
|
+
- **No DoS:** a default `LIMIT` is applied, but can be adjusted.
|
|
18
|
+
- **`WHERE` guards:** multiple tenant/ownership conditions can be inserted.
|
|
19
|
+
- **`JOIN`s added:** if needed to reach the guard tenant tables (save on tokens).
|
|
20
|
+
- **No sneaky joins:** no `join secrets on true`. We have your back.
|
|
21
21
|
|
|
22
22
|
## Quickstart
|
|
23
23
|
|
|
@@ -28,17 +28,15 @@ npm install agent-sql
|
|
|
28
28
|
```ts
|
|
29
29
|
import { agentSql } from "agent-sql";
|
|
30
30
|
|
|
31
|
-
const sql = agentSql(`SELECT * FROM msg`, "msg.
|
|
31
|
+
const sql = agentSql(`SELECT * FROM msg`, "msg.tenant_id", 123);
|
|
32
32
|
|
|
33
33
|
console.log(sql);
|
|
34
34
|
// SELECT *
|
|
35
35
|
// FROM msg
|
|
36
|
-
// WHERE msg.
|
|
36
|
+
// WHERE msg.tenant_id = 123
|
|
37
37
|
// LIMIT 10000
|
|
38
38
|
```
|
|
39
39
|
|
|
40
|
-
`agent-sql` parses the SQL, enforces a mandatory equality filter on the given column as the outermost `AND` condition (so it cannot be short-circuited by agent-supplied `OR` clauses), and returns the sanitised SQL string.
|
|
41
|
-
|
|
42
40
|
## Usage
|
|
43
41
|
|
|
44
42
|
### Define once, use many times
|
package/dist/index.d.mts
CHANGED
|
@@ -26,18 +26,21 @@ declare function parseSql(expr: string): Result<SelectStatement>;
|
|
|
26
26
|
//#region src/index.d.ts
|
|
27
27
|
declare function agentSql<S extends string>(sql: string, column: S & OneOrTwoDots<S>, value: GuardVal, {
|
|
28
28
|
schema,
|
|
29
|
+
autoJoin,
|
|
29
30
|
limit,
|
|
30
31
|
pretty,
|
|
31
32
|
db,
|
|
32
33
|
allowExtraFunctions
|
|
33
34
|
}?: {
|
|
34
35
|
schema?: Schema;
|
|
36
|
+
autoJoin?: boolean;
|
|
35
37
|
limit?: number;
|
|
36
38
|
pretty?: boolean;
|
|
37
39
|
db?: DbType;
|
|
38
40
|
allowExtraFunctions?: string[];
|
|
39
41
|
}): string;
|
|
40
42
|
declare function createAgentSql<T extends Schema, S extends SchemaGuardKeys<T>>(schema: T, guards: Record<S, GuardVal>, opts: {
|
|
43
|
+
autoJoin?: boolean;
|
|
41
44
|
limit?: number;
|
|
42
45
|
pretty?: boolean;
|
|
43
46
|
throws: false;
|
|
@@ -45,6 +48,7 @@ declare function createAgentSql<T extends Schema, S extends SchemaGuardKeys<T>>(
|
|
|
45
48
|
allowExtraFunctions?: string[];
|
|
46
49
|
}): (expr: string) => Result<string>;
|
|
47
50
|
declare function createAgentSql<T extends Schema, S extends SchemaGuardKeys<T>>(schema: T, guards: Record<S, GuardVal>, opts?: {
|
|
51
|
+
autoJoin?: boolean;
|
|
48
52
|
limit?: number;
|
|
49
53
|
pretty?: boolean;
|
|
50
54
|
throws?: true;
|
package/dist/index.mjs
CHANGED
|
@@ -774,6 +774,177 @@ function checkWhereExpr(expr, allowed) {
|
|
|
774
774
|
}
|
|
775
775
|
}
|
|
776
776
|
//#endregion
|
|
777
|
+
//#region src/joins.ts
|
|
778
|
+
function defineSchema(schema) {
|
|
779
|
+
return schema;
|
|
780
|
+
}
|
|
781
|
+
function checkJoins(ast, schema) {
|
|
782
|
+
if (schema === void 0) {
|
|
783
|
+
if (ast.joins.length > 0) return Err(new SanitiseError("No joins allowed when using simple API without schema."));
|
|
784
|
+
return Ok(ast);
|
|
785
|
+
}
|
|
786
|
+
if (!(ast.from.table.name in schema)) return Err(new SanitiseError(`Table ${ast.from.table.name} is not allowed`));
|
|
787
|
+
for (const join of ast.joins) {
|
|
788
|
+
const joinSettings = schema[join.table.name];
|
|
789
|
+
if (joinSettings === void 0) return Err(new SanitiseError(`Table ${join.table.name} is not allowed`));
|
|
790
|
+
if (join.condition === null || join.condition.type === "join_using" || join.condition.expr.type !== "where_comparison" || join.condition.expr.operator !== "=" || join.condition.expr.left.type !== "where_value" || join.condition.expr.left.kind !== "column_ref" || join.condition.expr.right.type !== "where_value" || join.condition.expr.right.kind !== "column_ref") return Err(new SanitiseError("Only JOIN ON column_ref = column_ref supported"));
|
|
791
|
+
const { joining, foreign } = getJoinTableRef(join.table.name, join.condition.expr.left.ref, join.condition.expr.right.ref);
|
|
792
|
+
const joinTableCol = joinSettings[joining.name];
|
|
793
|
+
if (joinTableCol === void 0) return Err(new SanitiseError(`Tried to join using ${join.table.name}.${joining.name}`));
|
|
794
|
+
if (joinTableCol === null) {
|
|
795
|
+
const foreignTableSettings = schema[foreign.table];
|
|
796
|
+
if (foreignTableSettings === void 0) return Err(new SanitiseError(`Table ${foreign.name} is not allowed`));
|
|
797
|
+
const foreignCol = foreignTableSettings[foreign.name];
|
|
798
|
+
if (foreignCol === void 0 || foreignCol === null) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
|
|
799
|
+
if (joining.table !== foreignCol.ft || joining.name !== foreignCol.fc) return Err(new SanitiseError(`Tried to join using ${joining.table}.${joining.name}`));
|
|
800
|
+
} else if (foreign.table !== joinTableCol.ft || foreign.name !== joinTableCol.fc) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
|
|
801
|
+
}
|
|
802
|
+
return Ok(ast);
|
|
803
|
+
}
|
|
804
|
+
function getJoinTableRef(joinTableName, left, right) {
|
|
805
|
+
if (left.table === joinTableName) return {
|
|
806
|
+
joining: left,
|
|
807
|
+
foreign: right
|
|
808
|
+
};
|
|
809
|
+
return {
|
|
810
|
+
joining: right,
|
|
811
|
+
foreign: left
|
|
812
|
+
};
|
|
813
|
+
}
|
|
814
|
+
//#endregion
|
|
815
|
+
//#region src/graph.ts
|
|
816
|
+
function insertNeededGuardJoins(ast, schema, guards, autoJoin) {
|
|
817
|
+
if (schema === void 0) return Ok(ast);
|
|
818
|
+
if (!autoJoin) return Ok(ast);
|
|
819
|
+
return resolveGraphForJoins(ast, schema, guards);
|
|
820
|
+
}
|
|
821
|
+
function resolveGraphForJoins(ast, schema, guards) {
|
|
822
|
+
const haveTables = /* @__PURE__ */ new Set();
|
|
823
|
+
haveTables.add(ast.from.table.name);
|
|
824
|
+
for (const join of ast.joins) haveTables.add(join.table.name);
|
|
825
|
+
const originalTables = new Set(haveTables);
|
|
826
|
+
const adj = buildAdjacency(schema);
|
|
827
|
+
const newJoins = [];
|
|
828
|
+
for (const guard of guards) {
|
|
829
|
+
if (haveTables.has(guard.table)) continue;
|
|
830
|
+
const path = bfsPath(adj, haveTables, guard.table);
|
|
831
|
+
if (path === null) return Err(new SanitiseError(`No join path from query tables to guard table '${guard.table}'`));
|
|
832
|
+
let current = [...haveTables].find((t) => path[0] && (path[0].tableA === t || path[0].tableB === t));
|
|
833
|
+
for (const edge of path) {
|
|
834
|
+
const neighbor = edge.tableA === current ? edge.tableB : edge.tableA;
|
|
835
|
+
if (!haveTables.has(neighbor)) {
|
|
836
|
+
newJoins.push(edgeToJoin(edge, neighbor));
|
|
837
|
+
haveTables.add(neighbor);
|
|
838
|
+
}
|
|
839
|
+
current = neighbor;
|
|
840
|
+
}
|
|
841
|
+
}
|
|
842
|
+
if (newJoins.length === 0) return Ok(ast);
|
|
843
|
+
const columns = qualifyWildcards(ast.columns, originalTables);
|
|
844
|
+
return Ok({
|
|
845
|
+
...ast,
|
|
846
|
+
columns,
|
|
847
|
+
joins: [...ast.joins, ...newJoins]
|
|
848
|
+
});
|
|
849
|
+
}
|
|
850
|
+
function buildAdjacency(schema) {
|
|
851
|
+
const adj = /* @__PURE__ */ new Map();
|
|
852
|
+
const tables = schema;
|
|
853
|
+
for (const [tableName, cols] of Object.entries(tables)) {
|
|
854
|
+
if (!adj.has(tableName)) adj.set(tableName, []);
|
|
855
|
+
for (const [colName, def] of Object.entries(cols)) if (def && typeof def === "object" && "ft" in def && "fc" in def) {
|
|
856
|
+
const edge = {
|
|
857
|
+
tableA: tableName,
|
|
858
|
+
colA: colName,
|
|
859
|
+
tableB: def.ft,
|
|
860
|
+
colB: def.fc
|
|
861
|
+
};
|
|
862
|
+
adj.get(tableName).push(edge);
|
|
863
|
+
if (!adj.has(edge.tableB)) adj.set(edge.tableB, []);
|
|
864
|
+
adj.get(edge.tableB).push(edge);
|
|
865
|
+
}
|
|
866
|
+
}
|
|
867
|
+
return adj;
|
|
868
|
+
}
|
|
869
|
+
function bfsPath(adj, startTables, target) {
|
|
870
|
+
if (startTables.has(target)) return [];
|
|
871
|
+
const visited = new Set(startTables);
|
|
872
|
+
const queue = [];
|
|
873
|
+
for (const t of startTables) queue.push([t, []]);
|
|
874
|
+
while (queue.length > 0) {
|
|
875
|
+
const [current, path] = queue.shift();
|
|
876
|
+
for (const edge of adj.get(current) ?? []) {
|
|
877
|
+
const neighbor = edge.tableA === current ? edge.tableB : edge.tableA;
|
|
878
|
+
if (visited.has(neighbor)) continue;
|
|
879
|
+
visited.add(neighbor);
|
|
880
|
+
const newPath = [...path, edge];
|
|
881
|
+
if (neighbor === target) return newPath;
|
|
882
|
+
queue.push([neighbor, newPath]);
|
|
883
|
+
}
|
|
884
|
+
}
|
|
885
|
+
return null;
|
|
886
|
+
}
|
|
887
|
+
function edgeToJoin(edge, fromTable) {
|
|
888
|
+
const [localTable, localCol, foreignTable, foreignCol] = edge.tableA === fromTable ? [
|
|
889
|
+
edge.tableA,
|
|
890
|
+
edge.colA,
|
|
891
|
+
edge.tableB,
|
|
892
|
+
edge.colB
|
|
893
|
+
] : [
|
|
894
|
+
edge.tableB,
|
|
895
|
+
edge.colB,
|
|
896
|
+
edge.tableA,
|
|
897
|
+
edge.colA
|
|
898
|
+
];
|
|
899
|
+
return {
|
|
900
|
+
type: "join",
|
|
901
|
+
joinType: "inner",
|
|
902
|
+
table: {
|
|
903
|
+
type: "table_ref",
|
|
904
|
+
name: localTable
|
|
905
|
+
},
|
|
906
|
+
condition: {
|
|
907
|
+
type: "join_on",
|
|
908
|
+
expr: {
|
|
909
|
+
type: "where_comparison",
|
|
910
|
+
operator: "=",
|
|
911
|
+
left: {
|
|
912
|
+
type: "where_value",
|
|
913
|
+
kind: "column_ref",
|
|
914
|
+
ref: {
|
|
915
|
+
type: "column_ref",
|
|
916
|
+
table: localTable,
|
|
917
|
+
name: localCol
|
|
918
|
+
}
|
|
919
|
+
},
|
|
920
|
+
right: {
|
|
921
|
+
type: "where_value",
|
|
922
|
+
kind: "column_ref",
|
|
923
|
+
ref: {
|
|
924
|
+
type: "column_ref",
|
|
925
|
+
table: foreignTable,
|
|
926
|
+
name: foreignCol
|
|
927
|
+
}
|
|
928
|
+
}
|
|
929
|
+
}
|
|
930
|
+
}
|
|
931
|
+
};
|
|
932
|
+
}
|
|
933
|
+
function qualifyWildcards(columns, tables) {
|
|
934
|
+
if (!columns.some((c) => c.expr.kind === "wildcard")) return columns;
|
|
935
|
+
const qualified = [];
|
|
936
|
+
for (const col of columns) if (col.expr.kind === "wildcard") for (const table of tables) qualified.push({
|
|
937
|
+
type: "column",
|
|
938
|
+
expr: {
|
|
939
|
+
type: "column_expr",
|
|
940
|
+
kind: "qualified_wildcard",
|
|
941
|
+
table
|
|
942
|
+
}
|
|
943
|
+
});
|
|
944
|
+
else qualified.push(col);
|
|
945
|
+
return qualified;
|
|
946
|
+
}
|
|
947
|
+
//#endregion
|
|
777
948
|
//#region src/utils.ts
|
|
778
949
|
function unreachable(x) {
|
|
779
950
|
throw new Error(`Unhandled variant: ${JSON.stringify(x)}`);
|
|
@@ -1108,44 +1279,6 @@ function tableEquals(a, b) {
|
|
|
1108
1279
|
return a.schema == b.schema && a.name == b.name;
|
|
1109
1280
|
}
|
|
1110
1281
|
//#endregion
|
|
1111
|
-
//#region src/joins.ts
|
|
1112
|
-
function defineSchema(schema) {
|
|
1113
|
-
return schema;
|
|
1114
|
-
}
|
|
1115
|
-
function checkJoins(ast, schema) {
|
|
1116
|
-
if (schema === void 0) {
|
|
1117
|
-
if (ast.joins.length > 0) return Err(new SanitiseError("No joins allowed when using simple API without schema."));
|
|
1118
|
-
return Ok(ast);
|
|
1119
|
-
}
|
|
1120
|
-
if (!(ast.from.table.name in schema)) return Err(new SanitiseError(`Table ${ast.from.table.name} is not allowed`));
|
|
1121
|
-
for (const join of ast.joins) {
|
|
1122
|
-
const joinSettings = schema[join.table.name];
|
|
1123
|
-
if (joinSettings === void 0) return Err(new SanitiseError(`Table ${join.table.name} is not allowed`));
|
|
1124
|
-
if (join.condition === null || join.condition.type === "join_using" || join.condition.expr.type !== "where_comparison" || join.condition.expr.operator !== "=" || join.condition.expr.left.type !== "where_value" || join.condition.expr.left.kind !== "column_ref" || join.condition.expr.right.type !== "where_value" || join.condition.expr.right.kind !== "column_ref") return Err(new SanitiseError("Only JOIN ON column_ref = column_ref supported"));
|
|
1125
|
-
const { joining, foreign } = getJoinTableRef(join.table.name, join.condition.expr.left.ref, join.condition.expr.right.ref);
|
|
1126
|
-
const joinTableCol = joinSettings[joining.name];
|
|
1127
|
-
if (joinTableCol === void 0) return Err(new SanitiseError(`Tried to join using ${join.table.name}.${joining.name}`));
|
|
1128
|
-
if (joinTableCol === null) {
|
|
1129
|
-
const foreignTableSettings = schema[foreign.table];
|
|
1130
|
-
if (foreignTableSettings === void 0) return Err(new SanitiseError(`Table ${foreign.name} is not allowed`));
|
|
1131
|
-
const foreignCol = foreignTableSettings[foreign.name];
|
|
1132
|
-
if (foreignCol === void 0 || foreignCol === null) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
|
|
1133
|
-
if (joining.table !== foreignCol.ft || joining.name !== foreignCol.fc) return Err(new SanitiseError(`Tried to join using ${joining.table}.${joining.name}`));
|
|
1134
|
-
} else if (foreign.table !== joinTableCol.ft || foreign.name !== joinTableCol.fc) return Err(new SanitiseError(`Tried to join using ${foreign.table}.${foreign.name}`));
|
|
1135
|
-
}
|
|
1136
|
-
return Ok(ast);
|
|
1137
|
-
}
|
|
1138
|
-
function getJoinTableRef(joinTableName, left, right) {
|
|
1139
|
-
if (left.table === joinTableName) return {
|
|
1140
|
-
joining: left,
|
|
1141
|
-
foreign: right
|
|
1142
|
-
};
|
|
1143
|
-
return {
|
|
1144
|
-
joining: right,
|
|
1145
|
-
foreign: left
|
|
1146
|
-
};
|
|
1147
|
-
}
|
|
1148
|
-
//#endregion
|
|
1149
1282
|
//#region src/sql.ohm-bundle.js
|
|
1150
1283
|
const result = makeRecipe([
|
|
1151
1284
|
"grammar",
|
|
@@ -7204,10 +7337,11 @@ function parseSql(expr) {
|
|
|
7204
7337
|
}
|
|
7205
7338
|
//#endregion
|
|
7206
7339
|
//#region src/index.ts
|
|
7207
|
-
function agentSql(sql, column, value, { schema, limit = DEFAULT_LIMIT, pretty = false, db = DEFAULT_DB, allowExtraFunctions = [] } = {}) {
|
|
7340
|
+
function agentSql(sql, column, value, { schema, autoJoin = true, limit = DEFAULT_LIMIT, pretty = false, db = DEFAULT_DB, allowExtraFunctions = [] } = {}) {
|
|
7208
7341
|
return privateAgentSql(sql, {
|
|
7209
7342
|
guards: { [column]: value },
|
|
7210
7343
|
schema,
|
|
7344
|
+
autoJoin,
|
|
7211
7345
|
limit,
|
|
7212
7346
|
pretty,
|
|
7213
7347
|
db,
|
|
@@ -7215,10 +7349,11 @@ function agentSql(sql, column, value, { schema, limit = DEFAULT_LIMIT, pretty =
|
|
|
7215
7349
|
throws: true
|
|
7216
7350
|
});
|
|
7217
7351
|
}
|
|
7218
|
-
function createAgentSql(schema, guards, { limit = DEFAULT_LIMIT, pretty = false, db = DEFAULT_DB, allowExtraFunctions = [], throws = true } = {}) {
|
|
7352
|
+
function createAgentSql(schema, guards, { autoJoin = true, limit = DEFAULT_LIMIT, pretty = false, db = DEFAULT_DB, allowExtraFunctions = [], throws = true } = {}) {
|
|
7219
7353
|
return (expr) => throws ? privateAgentSql(expr, {
|
|
7220
7354
|
guards,
|
|
7221
7355
|
schema,
|
|
7356
|
+
autoJoin,
|
|
7222
7357
|
limit,
|
|
7223
7358
|
pretty,
|
|
7224
7359
|
db,
|
|
@@ -7227,6 +7362,7 @@ function createAgentSql(schema, guards, { limit = DEFAULT_LIMIT, pretty = false,
|
|
|
7227
7362
|
}) : privateAgentSql(expr, {
|
|
7228
7363
|
guards,
|
|
7229
7364
|
schema,
|
|
7365
|
+
autoJoin,
|
|
7230
7366
|
limit,
|
|
7231
7367
|
pretty,
|
|
7232
7368
|
db,
|
|
@@ -7234,7 +7370,7 @@ function createAgentSql(schema, guards, { limit = DEFAULT_LIMIT, pretty = false,
|
|
|
7234
7370
|
throws
|
|
7235
7371
|
});
|
|
7236
7372
|
}
|
|
7237
|
-
function privateAgentSql(sql, { guards: guardsRaw, schema, limit, pretty, db, allowExtraFunctions, throws }) {
|
|
7373
|
+
function privateAgentSql(sql, { guards: guardsRaw, schema, autoJoin, limit, pretty, db, allowExtraFunctions, throws }) {
|
|
7238
7374
|
const guards = resolveGuards(guardsRaw);
|
|
7239
7375
|
if (!guards.ok) throw guards.error;
|
|
7240
7376
|
const ast = parseSql(sql);
|
|
@@ -7243,7 +7379,9 @@ function privateAgentSql(sql, { guards: guardsRaw, schema, limit, pretty, db, al
|
|
|
7243
7379
|
if (!ast2.ok) return returnOrThrow(ast2, throws);
|
|
7244
7380
|
const ast3 = checkFunctions(ast2.data, db, allowExtraFunctions);
|
|
7245
7381
|
if (!ast3.ok) return returnOrThrow(ast3, throws);
|
|
7246
|
-
const san = applyGuards(ast3.data, guards.data, limit);
|
|
7382
|
+
const ast4 = insertNeededGuardJoins(ast3.data, schema, guards.data, autoJoin);
|
|
7383
|
+
if (!ast4.ok) return returnOrThrow(ast4, throws);
|
|
7384
|
+
const san = applyGuards(ast4.data, guards.data, limit);
|
|
7247
7385
|
if (!san.ok) return returnOrThrow(san, throws);
|
|
7248
7386
|
const res = outputSql(san.data, pretty);
|
|
7249
7387
|
if (throws) return res;
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "agent-sql",
|
|
3
|
-
"version": "0.2.3",
|
|
3
|
+
"version": "0.3.1",
|
|
4
4
|
"description": "A starter for creating a TypeScript package.",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"agent",
|
|
@@ -36,7 +36,7 @@
|
|
|
36
36
|
"test": "vp test",
|
|
37
37
|
"check": "vp check",
|
|
38
38
|
"prepublishOnly": "vp run build",
|
|
39
|
-
"
|
|
39
|
+
"ohm": "ohm generateBundles --withTypes --esm 'src/sql.ohm'",
|
|
40
40
|
"script": "vp exec tsx"
|
|
41
41
|
},
|
|
42
42
|
"dependencies": {
|