@zenstackhq/plugin-policy 3.1.0 → 3.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -927,6 +927,7 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
927
927
  get kysely() {
928
928
  return this.client.$qb;
929
929
  }
930
+ // #region main entry point
930
931
  async handle(node, proceed) {
931
932
  if (!this.isCrudQueryNode(node)) {
932
933
  throw createRejectedByPolicyError(void 0, import_orm4.RejectedByPolicyReason.OTHER, "non-CRUD queries are not allowed");
@@ -937,19 +938,10 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
937
938
  const { mutationModel } = this.getMutationModel(node);
938
939
  this.tryRejectNonexistentModel(mutationModel);
939
940
  if (import_kysely3.InsertQueryNode.is(node)) {
940
- const isManyToManyJoinTable = this.isManyToManyJoinTable(mutationModel);
941
- let needCheckPreCreate = true;
942
- if (!isManyToManyJoinTable) {
943
- const constCondition = this.tryGetConstantPolicy(mutationModel, "create");
944
- if (constCondition === true) {
945
- needCheckPreCreate = false;
946
- } else if (constCondition === false) {
947
- throw createRejectedByPolicyError(mutationModel, import_orm4.RejectedByPolicyReason.NO_ACCESS);
948
- }
949
- }
950
- if (needCheckPreCreate) {
951
- await this.enforcePreCreatePolicy(node, mutationModel, isManyToManyJoinTable, proceed);
952
- }
941
+ await this.preCreateCheck(mutationModel, node, proceed);
942
+ }
943
+ if (import_kysely3.UpdateQueryNode.is(node)) {
944
+ await this.preUpdateCheck(mutationModel, node, proceed);
953
945
  }
954
946
  const hasPostUpdatePolicies = import_kysely3.UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
955
947
  let beforeUpdateInfo;
@@ -1008,94 +1000,85 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1008
1000
  return readBackResult;
1009
1001
  }
1010
1002
  }
1011
- // correction to kysely mutation result may be needed because we might have added
1012
- // returning clause to the query and caused changes to the result shape
1013
- postProcessMutationResult(result, node) {
1014
- if (node.returning) {
1015
- return result;
1016
- } else {
1017
- return {
1018
- ...result,
1019
- rows: [],
1020
- numAffectedRows: result.numAffectedRows ?? BigInt(result.rows.length)
1021
- };
1003
+ async preCreateCheck(mutationModel, node, proceed) {
1004
+ const isManyToManyJoinTable = this.isManyToManyJoinTable(mutationModel);
1005
+ let needCheckPreCreate = true;
1006
+ if (!isManyToManyJoinTable) {
1007
+ const constCondition = this.tryGetConstantPolicy(mutationModel, "create");
1008
+ if (constCondition === true) {
1009
+ needCheckPreCreate = false;
1010
+ } else if (constCondition === false) {
1011
+ throw createRejectedByPolicyError(mutationModel, import_orm4.RejectedByPolicyReason.NO_ACCESS);
1012
+ }
1022
1013
  }
1023
- }
1024
- hasPostUpdatePolicies(model) {
1025
- const policies = this.getModelPolicies(model, "post-update");
1026
- return policies.length > 0;
1027
- }
1028
- async loadBeforeUpdateEntities(model, where, proceed) {
1029
- const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model);
1030
- if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) {
1031
- return void 0;
1014
+ if (needCheckPreCreate) {
1015
+ await this.enforcePreCreatePolicy(node, mutationModel, isManyToManyJoinTable, proceed);
1032
1016
  }
1033
- const policyFilter = this.buildPolicyFilter(model, model, "update");
1034
- const combinedFilter = where ? conjunction(this.dialect, [
1035
- where.where,
1036
- policyFilter
1037
- ]) : policyFilter;
1038
- const query = {
1039
- kind: "SelectQueryNode",
1040
- from: import_kysely3.FromNode.create([
1041
- import_kysely3.TableNode.create(model)
1042
- ]),
1043
- where: import_kysely3.WhereNode.create(combinedFilter),
1044
- selections: [
1045
- ...beforeUpdateAccessFields.map((f) => import_kysely3.SelectionNode.create(import_kysely3.ColumnNode.create(f)))
1046
- ]
1047
- };
1048
- const result = await proceed(query);
1049
- return {
1050
- fields: beforeUpdateAccessFields,
1051
- rows: result.rows
1052
- };
1053
1017
  }
1054
- getFieldsAccessForBeforeUpdatePolicies(model) {
1055
- const policies = this.getModelPolicies(model, "post-update");
1056
- if (policies.length === 0) {
1057
- return void 0;
1018
+ async preUpdateCheck(mutationModel, node, proceed) {
1019
+ const fieldsToUpdate = node.updates?.map((u) => import_kysely3.ColumnNode.is(u.column) ? u.column.column.name : void 0).filter((f) => !!f) ?? [];
1020
+ const fieldUpdatePolicies = fieldsToUpdate.map((f) => this.buildFieldPolicyFilter(mutationModel, f, "update"));
1021
+ const fieldLevelFilter = conjunction(this.dialect, fieldUpdatePolicies);
1022
+ if (isTrueNode(fieldLevelFilter)) {
1023
+ return;
1058
1024
  }
1059
- const fields = /* @__PURE__ */ new Set();
1060
- const fieldCollector = new class extends import_orm4.SchemaUtils.ExpressionVisitor {
1061
- visitMember(e) {
1062
- if (isBeforeInvocation(e.receiver)) {
1063
- (0, import_common_helpers3.invariant)(e.members.length === 1, "before() can only be followed by a scalar field access");
1064
- fields.add(e.members[0]);
1065
- }
1066
- super.visitMember(e);
1067
- }
1068
- }();
1069
- for (const policy of policies) {
1070
- fieldCollector.visit(policy.condition);
1071
- }
1072
- if (fields.size === 0) {
1073
- return void 0;
1025
+ const modelLevelFilter = this.buildPolicyFilter(mutationModel, void 0, "update");
1026
+ const updateFilter = conjunction(this.dialect, [
1027
+ modelLevelFilter,
1028
+ node.where?.where ?? trueNode(this.dialect)
1029
+ ]);
1030
+ const preUpdateCheckQuery = (0, import_kysely3.expressionBuilder)().selectFrom(mutationModel).select((eb) => eb.fn.coalesce(eb.fn.sum(eb.cast(new import_kysely3.ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter)), "integer")), eb.lit(0)).as("$filteredCount")).where(() => new import_kysely3.ExpressionWrapper(updateFilter));
1031
+ const preUpdateResult = await proceed(preUpdateCheckQuery.toOperationNode());
1032
+ if (preUpdateResult.rows[0].$filteredCount > 0) {
1033
+ throw createRejectedByPolicyError(mutationModel, import_orm4.RejectedByPolicyReason.NO_ACCESS, "some rows cannot be updated due to field policies");
1074
1034
  }
1075
- import_orm4.QueryUtils.requireIdFields(this.client.$schema, model).forEach((f) => fields.add(f));
1076
- return Array.from(fields).sort();
1077
1035
  }
1078
- // #region overrides
1036
+ // #endregion
1037
+ // #region Transformations
1079
1038
  transformSelectQuery(node) {
1080
1039
  if (!node.from) {
1081
1040
  return super.transformSelectQuery(node);
1082
1041
  }
1083
- let whereNode = this.transformNode(node.where);
1084
- const policyFilter = this.createPolicyFilterForFrom(node.from);
1085
- if (policyFilter) {
1086
- whereNode = import_kysely3.WhereNode.create(whereNode?.where ? conjunction(this.dialect, [
1087
- whereNode.where,
1088
- policyFilter
1089
- ]) : policyFilter);
1090
- }
1091
- const baseResult = super.transformSelectQuery({
1092
- ...node,
1093
- where: void 0
1042
+ this.tryRejectNonexistingTables(node.from.froms);
1043
+ let result = super.transformSelectQuery(node);
1044
+ const hasFieldLevelPolicies = node.from.froms.some((table) => {
1045
+ const extractedTable = this.extractTableName(table);
1046
+ if (extractedTable) {
1047
+ return this.hasFieldLevelPolicies(extractedTable.model, "read");
1048
+ } else {
1049
+ return false;
1050
+ }
1094
1051
  });
1095
- return {
1096
- ...baseResult,
1097
- where: whereNode
1098
- };
1052
+ if (hasFieldLevelPolicies) {
1053
+ const updatedFroms = [];
1054
+ for (const table of result.from.froms) {
1055
+ const extractedTable = this.extractTableName(table);
1056
+ if (extractedTable?.model && import_orm4.QueryUtils.getModel(this.client.$schema, extractedTable.model)) {
1057
+ const { query } = this.createSelectAllFieldsWithPolicies(extractedTable.model, extractedTable.alias, "read");
1058
+ updatedFroms.push(query);
1059
+ } else {
1060
+ updatedFroms.push(table);
1061
+ }
1062
+ }
1063
+ result = {
1064
+ ...result,
1065
+ from: import_kysely3.FromNode.create(updatedFroms)
1066
+ };
1067
+ } else {
1068
+ let whereNode = result.where;
1069
+ const policyFilter = this.createPolicyFilterForFrom(result.from);
1070
+ if (policyFilter && !isTrueNode(policyFilter)) {
1071
+ whereNode = import_kysely3.WhereNode.create(whereNode?.where ? conjunction(this.dialect, [
1072
+ whereNode.where,
1073
+ policyFilter
1074
+ ]) : policyFilter);
1075
+ }
1076
+ result = {
1077
+ ...result,
1078
+ where: whereNode
1079
+ };
1080
+ }
1081
+ return result;
1099
1082
  }
1100
1083
  transformJoin(node) {
1101
1084
  const table = this.extractTableName(node.table);
@@ -1103,20 +1086,17 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1103
1086
  return super.transformJoin(node);
1104
1087
  }
1105
1088
  this.tryRejectNonexistentModel(table.model);
1106
- const filter = this.buildPolicyFilter(table.model, table.alias, "read");
1107
- const nestedSelect = {
1108
- kind: "SelectQueryNode",
1109
- from: import_kysely3.FromNode.create([
1110
- node.table
1111
- ]),
1112
- selections: [
1113
- import_kysely3.SelectionNode.createSelectAll()
1114
- ],
1115
- where: import_kysely3.WhereNode.create(filter)
1116
- };
1089
+ if (!import_orm4.QueryUtils.getModel(this.client.$schema, table.model)) {
1090
+ return super.transformJoin(node);
1091
+ }
1092
+ const result = super.transformJoin(node);
1093
+ const { hasPolicies, query: nestedQuery } = this.createSelectAllFieldsWithPolicies(table.model, table.alias, "read");
1094
+ if (!hasPolicies) {
1095
+ return result;
1096
+ }
1117
1097
  return {
1118
- ...node,
1119
- table: import_kysely3.AliasNode.create(import_kysely3.ParensNode.create(nestedSelect), import_kysely3.IdentifierNode.create(table.alias ?? table.model))
1098
+ ...result,
1099
+ table: nestedQuery
1120
1100
  };
1121
1101
  }
1122
1102
  transformInsertQuery(node) {
@@ -1160,6 +1140,7 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1160
1140
  const { mutationModel, alias } = this.getMutationModel(node);
1161
1141
  let filter = this.buildPolicyFilter(mutationModel, alias, "update");
1162
1142
  if (node.from) {
1143
+ this.tryRejectNonexistingTables(node.from.froms);
1163
1144
  const joinFilter = this.createPolicyFilterForFrom(node.from);
1164
1145
  if (joinFilter) {
1165
1146
  filter = conjunction(this.dialect, [
@@ -1187,6 +1168,7 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1187
1168
  const { mutationModel, alias } = this.getMutationModel(node);
1188
1169
  let filter = this.buildPolicyFilter(mutationModel, alias, "delete");
1189
1170
  if (node.using) {
1171
+ this.tryRejectNonexistingTables(node.using.tables);
1190
1172
  const joinFilter = this.createPolicyFilterForTables(node.using.tables);
1191
1173
  if (joinFilter) {
1192
1174
  filter = conjunction(this.dialect, [
@@ -1204,6 +1186,139 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1204
1186
  };
1205
1187
  }
1206
1188
  // #endregion
1189
+ // #region post-update
1190
+ async loadBeforeUpdateEntities(model, where, proceed) {
1191
+ const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model);
1192
+ if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) {
1193
+ return void 0;
1194
+ }
1195
+ const policyFilter = this.buildPolicyFilter(model, model, "update");
1196
+ const combinedFilter = where ? conjunction(this.dialect, [
1197
+ where.where,
1198
+ policyFilter
1199
+ ]) : policyFilter;
1200
+ const query = {
1201
+ kind: "SelectQueryNode",
1202
+ from: import_kysely3.FromNode.create([
1203
+ import_kysely3.TableNode.create(model)
1204
+ ]),
1205
+ where: import_kysely3.WhereNode.create(combinedFilter),
1206
+ selections: [
1207
+ ...beforeUpdateAccessFields.map((f) => import_kysely3.SelectionNode.create(import_kysely3.ColumnNode.create(f)))
1208
+ ]
1209
+ };
1210
+ const result = await proceed(query);
1211
+ return {
1212
+ fields: beforeUpdateAccessFields,
1213
+ rows: result.rows
1214
+ };
1215
+ }
1216
+ getFieldsAccessForBeforeUpdatePolicies(model) {
1217
+ const policies = this.getModelPolicies(model, "post-update");
1218
+ if (policies.length === 0) {
1219
+ return void 0;
1220
+ }
1221
+ const fields = /* @__PURE__ */ new Set();
1222
+ const fieldCollector = new class extends import_orm4.SchemaUtils.ExpressionVisitor {
1223
+ visitMember(e) {
1224
+ if (isBeforeInvocation(e.receiver)) {
1225
+ (0, import_common_helpers3.invariant)(e.members.length === 1, "before() can only be followed by a scalar field access");
1226
+ fields.add(e.members[0]);
1227
+ }
1228
+ super.visitMember(e);
1229
+ }
1230
+ }();
1231
+ for (const policy of policies) {
1232
+ fieldCollector.visit(policy.condition);
1233
+ }
1234
+ if (fields.size === 0) {
1235
+ return void 0;
1236
+ }
1237
+ import_orm4.QueryUtils.requireIdFields(this.client.$schema, model).forEach((f) => fields.add(f));
1238
+ return Array.from(fields).sort();
1239
+ }
1240
+ hasPostUpdatePolicies(model) {
1241
+ const policies = this.getModelPolicies(model, "post-update");
1242
+ return policies.length > 0;
1243
+ }
1244
+ // #endregion
1245
+ // #region field-level policies
1246
+ createSelectAllFieldsWithPolicies(model, alias, operation) {
1247
+ let hasPolicies = false;
1248
+ const modelDef = import_orm4.QueryUtils.requireModel(this.client.$schema, model);
1249
+ let selections = [];
1250
+ for (const fieldDef of Object.values(modelDef.fields).filter(
1251
+ // exclude relation/computed/inherited fields
1252
+ (f) => !f.relation && !f.computed && !f.originModel
1253
+ )) {
1254
+ const { hasPolicies: fieldHasPolicies, selection } = this.createFieldSelectionWithPolicy(model, fieldDef.name, operation);
1255
+ hasPolicies = hasPolicies || fieldHasPolicies;
1256
+ selections.push(selection);
1257
+ }
1258
+ if (!hasPolicies) {
1259
+ selections = [
1260
+ import_kysely3.SelectionNode.create(import_kysely3.SelectAllNode.create())
1261
+ ];
1262
+ }
1263
+ const modelPolicyFilter = this.buildPolicyFilter(model, alias, operation);
1264
+ if (!isTrueNode(modelPolicyFilter)) {
1265
+ hasPolicies = true;
1266
+ }
1267
+ const nestedQuery = {
1268
+ kind: "SelectQueryNode",
1269
+ from: import_kysely3.FromNode.create([
1270
+ import_kysely3.TableNode.create(model)
1271
+ ]),
1272
+ where: isTrueNode(modelPolicyFilter) ? void 0 : import_kysely3.WhereNode.create(modelPolicyFilter),
1273
+ selections
1274
+ };
1275
+ return {
1276
+ hasPolicies,
1277
+ query: import_kysely3.AliasNode.create(import_kysely3.ParensNode.create(nestedQuery), import_kysely3.IdentifierNode.create(alias ?? model))
1278
+ };
1279
+ }
1280
+ createFieldSelectionWithPolicy(model, field, operation) {
1281
+ const filter = this.buildFieldPolicyFilter(model, field, operation);
1282
+ if (isTrueNode(filter)) {
1283
+ return {
1284
+ hasPolicies: false,
1285
+ selection: import_kysely3.SelectionNode.create(import_kysely3.ColumnNode.create(field))
1286
+ };
1287
+ }
1288
+ const eb = (0, import_kysely3.expressionBuilder)();
1289
+ const selection = eb.case().when(new import_kysely3.ExpressionWrapper(filter)).then(eb.ref(field)).else(null).end().as(field).toOperationNode();
1290
+ return {
1291
+ hasPolicies: true,
1292
+ selection: import_kysely3.SelectionNode.create(selection)
1293
+ };
1294
+ }
1295
+ hasFieldLevelPolicies(model, operation) {
1296
+ const modelDef = import_orm4.QueryUtils.getModel(this.client.$schema, model);
1297
+ if (!modelDef) {
1298
+ return false;
1299
+ }
1300
+ return Object.keys(modelDef.fields).some((field) => this.getFieldPolicies(model, field, operation).length > 0);
1301
+ }
1302
+ buildFieldPolicyFilter(model, field, operation) {
1303
+ const policies = this.getFieldPolicies(model, field, operation);
1304
+ const allows = policies.filter((policy) => policy.kind === "allow").map((policy) => this.compilePolicyCondition(model, model, operation, policy));
1305
+ const denies = policies.filter((policy) => policy.kind === "deny").map((policy) => this.compilePolicyCondition(model, model, operation, policy));
1306
+ let combinedPolicy;
1307
+ if (allows.length === 0) {
1308
+ combinedPolicy = trueNode(this.dialect);
1309
+ } else {
1310
+ combinedPolicy = disjunction(this.dialect, allows);
1311
+ }
1312
+ if (denies.length !== 0) {
1313
+ const combinedDenies = conjunction(this.dialect, denies.map((d) => buildIsFalse(d, this.dialect)));
1314
+ combinedPolicy = conjunction(this.dialect, [
1315
+ combinedPolicy,
1316
+ combinedDenies
1317
+ ]);
1318
+ }
1319
+ return combinedPolicy;
1320
+ }
1321
+ // #endregion
1207
1322
  // #region helpers
1208
1323
  onlyReturningId(node) {
1209
1324
  if (!node.returning) {
@@ -1495,7 +1610,6 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1495
1610
  const extractResult = this.extractTableName(table);
1496
1611
  if (extractResult) {
1497
1612
  const { model, alias } = extractResult;
1498
- this.tryRejectNonexistentModel(model);
1499
1613
  const filter = this.buildPolicyFilter(model, alias, "read");
1500
1614
  return acc ? conjunction(this.dialect, [
1501
1615
  acc,
@@ -1531,6 +1645,23 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1531
1645
  }
1532
1646
  return result;
1533
1647
  }
1648
+ getFieldPolicies(model, field, operation) {
1649
+ const fieldDef = import_orm4.QueryUtils.requireField(this.client.$schema, model, field);
1650
+ const result = [];
1651
+ const extractOperations = /* @__PURE__ */ __name((expr2) => {
1652
+ (0, import_common_helpers3.invariant)(import_schema4.ExpressionUtils.isLiteral(expr2), "expecting a literal");
1653
+ (0, import_common_helpers3.invariant)(typeof expr2.value === "string", "expecting a string literal");
1654
+ return expr2.value.split(",").filter((v) => !!v).map((v) => v.trim());
1655
+ }, "extractOperations");
1656
+ if (fieldDef.attributes) {
1657
+ result.push(...fieldDef.attributes.filter((attr) => attr.name === "@allow" || attr.name === "@deny").map((attr) => ({
1658
+ kind: attr.name === "@allow" ? "allow" : "deny",
1659
+ operations: extractOperations(attr.args[0].value),
1660
+ condition: attr.args[1].value
1661
+ })).filter((policy) => policy.operations.includes("all") || policy.operations.includes(operation)));
1662
+ }
1663
+ return result;
1664
+ }
1534
1665
  resolveManyToManyJoinTable(tableName) {
1535
1666
  for (const model of Object.values(this.client.$schema.models)) {
1536
1667
  for (const field of Object.values(model.fields)) {
@@ -1588,6 +1719,27 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1588
1719
  throw createRejectedByPolicyError(model, import_orm4.RejectedByPolicyReason.NO_ACCESS);
1589
1720
  }
1590
1721
  }
1722
+ tryRejectNonexistingTables(tables) {
1723
+ for (const table of tables) {
1724
+ const extractResult = this.extractTableName(table);
1725
+ if (extractResult) {
1726
+ this.tryRejectNonexistentModel(extractResult.model);
1727
+ }
1728
+ }
1729
+ }
1730
+ // correction to kysely mutation result may be needed because we might have added
1731
+ // returning clause to the query and caused changes to the result shape
1732
+ postProcessMutationResult(result, node) {
1733
+ if (node.returning) {
1734
+ return result;
1735
+ } else {
1736
+ return {
1737
+ ...result,
1738
+ rows: [],
1739
+ numAffectedRows: result.numAffectedRows ?? BigInt(result.rows.length)
1740
+ };
1741
+ }
1742
+ }
1591
1743
  };
1592
1744
 
1593
1745
  // src/functions.ts