@zenstackhq/plugin-policy 3.1.0 → 3.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -10,7 +10,7 @@ import { ExpressionWrapper as ExpressionWrapper2, ValueNode as ValueNode4 } from
10
10
  import { invariant as invariant3 } from "@zenstackhq/common-helpers";
11
11
  import { getCrudDialect as getCrudDialect2, QueryUtils as QueryUtils2, RejectedByPolicyReason, SchemaUtils as SchemaUtils2 } from "@zenstackhq/orm";
12
12
  import { ExpressionUtils as ExpressionUtils4 } from "@zenstackhq/orm/schema";
13
- import { AliasNode as AliasNode3, BinaryOperationNode as BinaryOperationNode3, ColumnNode as ColumnNode2, DeleteQueryNode, expressionBuilder as expressionBuilder2, ExpressionWrapper, FromNode as FromNode2, FunctionNode as FunctionNode3, IdentifierNode as IdentifierNode2, InsertQueryNode, OperationNodeTransformer, OperatorNode as OperatorNode3, ParensNode as ParensNode2, PrimitiveValueListNode, RawNode, ReferenceNode as ReferenceNode3, ReturningNode, SelectAllNode, SelectionNode as SelectionNode2, SelectQueryNode as SelectQueryNode2, sql, TableNode as TableNode3, UpdateQueryNode, ValueListNode as ValueListNode2, ValueNode as ValueNode3, ValuesNode, WhereNode as WhereNode2 } from "kysely";
13
+ import { AliasNode as AliasNode3, BinaryOperationNode as BinaryOperationNode3, ColumnNode as ColumnNode3, DeleteQueryNode, expressionBuilder as expressionBuilder2, ExpressionWrapper, FromNode as FromNode2, FunctionNode as FunctionNode3, IdentifierNode as IdentifierNode2, InsertQueryNode, OperationNodeTransformer, OperatorNode as OperatorNode3, ParensNode as ParensNode2, PrimitiveValueListNode, RawNode, ReferenceNode as ReferenceNode3, ReturningNode, SelectAllNode, SelectionNode as SelectionNode2, SelectQueryNode as SelectQueryNode2, sql, TableNode as TableNode3, UpdateQueryNode, ValueListNode as ValueListNode2, ValueNode as ValueNode3, ValuesNode, WhereNode as WhereNode2 } from "kysely";
14
14
  import { match as match3 } from "ts-pattern";
15
15
 
16
16
  // src/column-collector.ts
@@ -36,7 +36,7 @@ var ColumnCollector = class extends KyselyUtils.DefaultOperationNodeVisitor {
36
36
  import { invariant as invariant2 } from "@zenstackhq/common-helpers";
37
37
  import { getCrudDialect, QueryUtils, SchemaUtils } from "@zenstackhq/orm";
38
38
  import { ExpressionUtils as ExpressionUtils3 } from "@zenstackhq/orm/schema";
39
- import { AliasNode as AliasNode2, BinaryOperationNode as BinaryOperationNode2, ColumnNode, expressionBuilder, FromNode, FunctionNode as FunctionNode2, IdentifierNode, OperatorNode as OperatorNode2, ReferenceNode as ReferenceNode2, SelectionNode, SelectQueryNode, TableNode as TableNode2, ValueListNode, ValueNode as ValueNode2, WhereNode } from "kysely";
39
+ import { AliasNode as AliasNode2, BinaryOperationNode as BinaryOperationNode2, ColumnNode as ColumnNode2, expressionBuilder, FromNode, FunctionNode as FunctionNode2, IdentifierNode, OperatorNode as OperatorNode2, ReferenceNode as ReferenceNode2, SelectionNode, SelectQueryNode, TableNode as TableNode2, ValueListNode, ValueNode as ValueNode2, WhereNode } from "kysely";
40
40
  import { match as match2 } from "ts-pattern";
41
41
 
42
42
  // src/expression-evaluator.ts
@@ -126,7 +126,7 @@ var CollectionPredicateOperator = [
126
126
  // src/utils.ts
127
127
  import { ORMError, ORMErrorReason } from "@zenstackhq/orm";
128
128
  import { ExpressionUtils as ExpressionUtils2 } from "@zenstackhq/orm/schema";
129
- import { AliasNode, AndNode, BinaryOperationNode, FunctionNode, OperatorNode, OrNode, ParensNode, ReferenceNode, TableNode, UnaryOperationNode, ValueNode } from "kysely";
129
+ import { AliasNode, AndNode, BinaryOperationNode, ColumnNode, FunctionNode, OperatorNode, OrNode, ParensNode, ReferenceNode, TableNode, UnaryOperationNode, ValueNode } from "kysely";
130
130
  function trueNode(dialect) {
131
131
  return ValueNode.createImmediate(dialect.transformPrimitive(true, "Boolean", false));
132
132
  }
@@ -627,7 +627,7 @@ var ExpressionTransformer = class {
627
627
  if (isBeforeInvocation(expr2.receiver)) {
628
628
  invariant2(context.operation === "post-update", "before() can only be used in post-update policy");
629
629
  invariant2(expr2.members.length === 1, "before() can only be followed by a scalar field access");
630
- return ReferenceNode2.create(ColumnNode.create(expr2.members[0]), TableNode2.create("$before"));
630
+ return ReferenceNode2.create(ColumnNode2.create(expr2.members[0]), TableNode2.create("$before"));
631
631
  }
632
632
  invariant2(ExpressionUtils3.isField(expr2.receiver) || ExpressionUtils3.isThis(expr2.receiver), 'expect receiver to be field expression or "this"');
633
633
  let members = expr2.members;
@@ -697,7 +697,7 @@ var ExpressionTransformer = class {
697
697
  } else {
698
698
  invariant2(i === members.length - 1, "plain field access must be the last segment");
699
699
  invariant2(!currNode, "plain field access must be the last segment");
700
- currNode = ColumnNode.create(member);
700
+ currNode = ColumnNode2.create(member);
701
701
  }
702
702
  }
703
703
  return {
@@ -739,14 +739,14 @@ var ExpressionTransformer = class {
739
739
  let condition;
740
740
  if (ownedByModel) {
741
741
  condition = conjunction(this.dialect, keyPairs.map(({ fk, pk }) => {
742
- let fkRef = ReferenceNode2.create(ColumnNode.create(fk), TableNode2.create(context.alias ?? fromModel));
742
+ let fkRef = ReferenceNode2.create(ColumnNode2.create(fk), TableNode2.create(context.alias ?? fromModel));
743
743
  if (relationFieldDef.originModel && relationFieldDef.originModel !== fromModel) {
744
744
  fkRef = this.buildDelegateBaseFieldSelect(fromModel, context.alias ?? fromModel, fk, relationFieldDef.originModel);
745
745
  }
746
- return BinaryOperationNode2.create(fkRef, OperatorNode2.create("="), ReferenceNode2.create(ColumnNode.create(pk), TableNode2.create(relationModel)));
746
+ return BinaryOperationNode2.create(fkRef, OperatorNode2.create("="), ReferenceNode2.create(ColumnNode2.create(pk), TableNode2.create(relationModel)));
747
747
  }));
748
748
  } else {
749
- condition = conjunction(this.dialect, keyPairs.map(({ fk, pk }) => BinaryOperationNode2.create(ReferenceNode2.create(ColumnNode.create(pk), TableNode2.create(context.alias ?? fromModel)), OperatorNode2.create("="), ReferenceNode2.create(ColumnNode.create(fk), TableNode2.create(relationModel)))));
749
+ condition = conjunction(this.dialect, keyPairs.map(({ fk, pk }) => BinaryOperationNode2.create(ReferenceNode2.create(ColumnNode2.create(pk), TableNode2.create(context.alias ?? fromModel)), OperatorNode2.create("="), ReferenceNode2.create(ColumnNode2.create(fk), TableNode2.create(relationModel)))));
750
750
  }
751
751
  return {
752
752
  kind: "SelectQueryNode",
@@ -764,11 +764,11 @@ var ExpressionTransformer = class {
764
764
  createColumnRef(column, context) {
765
765
  const tableName = context.alias ?? context.modelOrType;
766
766
  if (context.operation === "create") {
767
- return ReferenceNode2.create(ColumnNode.create(column), TableNode2.create(tableName));
767
+ return ReferenceNode2.create(ColumnNode2.create(column), TableNode2.create(tableName));
768
768
  }
769
769
  const fieldDef = QueryUtils.requireField(this.schema, context.modelOrType, column);
770
770
  if (!fieldDef.originModel || fieldDef.originModel === context.modelOrType) {
771
- return ReferenceNode2.create(ColumnNode.create(column), TableNode2.create(tableName));
771
+ return ReferenceNode2.create(ColumnNode2.create(column), TableNode2.create(tableName));
772
772
  }
773
773
  return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel);
774
774
  }
@@ -780,9 +780,9 @@ var ExpressionTransformer = class {
780
780
  TableNode2.create(baseModel)
781
781
  ]),
782
782
  selections: [
783
- SelectionNode.create(ReferenceNode2.create(ColumnNode.create(field), TableNode2.create(baseModel)))
783
+ SelectionNode.create(ReferenceNode2.create(ColumnNode2.create(field), TableNode2.create(baseModel)))
784
784
  ],
785
- where: WhereNode.create(conjunction(this.dialect, idFields.map((idField) => BinaryOperationNode2.create(ReferenceNode2.create(ColumnNode.create(idField), TableNode2.create(baseModel)), OperatorNode2.create("="), ReferenceNode2.create(ColumnNode.create(idField), TableNode2.create(modelAlias))))))
785
+ where: WhereNode.create(conjunction(this.dialect, idFields.map((idField) => BinaryOperationNode2.create(ReferenceNode2.create(ColumnNode2.create(idField), TableNode2.create(baseModel)), OperatorNode2.create("="), ReferenceNode2.create(ColumnNode2.create(idField), TableNode2.create(modelAlias))))))
786
786
  };
787
787
  }
788
788
  isAuthCall(value) {
@@ -903,6 +903,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
903
903
  get kysely() {
904
904
  return this.client.$qb;
905
905
  }
906
+ // #region main entry point
906
907
  async handle(node, proceed) {
907
908
  if (!this.isCrudQueryNode(node)) {
908
909
  throw createRejectedByPolicyError(void 0, RejectedByPolicyReason.OTHER, "non-CRUD queries are not allowed");
@@ -913,19 +914,10 @@ var PolicyHandler = class extends OperationNodeTransformer {
913
914
  const { mutationModel } = this.getMutationModel(node);
914
915
  this.tryRejectNonexistentModel(mutationModel);
915
916
  if (InsertQueryNode.is(node)) {
916
- const isManyToManyJoinTable = this.isManyToManyJoinTable(mutationModel);
917
- let needCheckPreCreate = true;
918
- if (!isManyToManyJoinTable) {
919
- const constCondition = this.tryGetConstantPolicy(mutationModel, "create");
920
- if (constCondition === true) {
921
- needCheckPreCreate = false;
922
- } else if (constCondition === false) {
923
- throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS);
924
- }
925
- }
926
- if (needCheckPreCreate) {
927
- await this.enforcePreCreatePolicy(node, mutationModel, isManyToManyJoinTable, proceed);
928
- }
917
+ await this.preCreateCheck(mutationModel, node, proceed);
918
+ }
919
+ if (UpdateQueryNode.is(node)) {
920
+ await this.preUpdateCheck(mutationModel, node, proceed);
929
921
  }
930
922
  const hasPostUpdatePolicies = UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
931
923
  let beforeUpdateInfo;
@@ -984,94 +976,85 @@ var PolicyHandler = class extends OperationNodeTransformer {
984
976
  return readBackResult;
985
977
  }
986
978
  }
987
- // correction to kysely mutation result may be needed because we might have added
988
- // returning clause to the query and caused changes to the result shape
989
- postProcessMutationResult(result, node) {
990
- if (node.returning) {
991
- return result;
992
- } else {
993
- return {
994
- ...result,
995
- rows: [],
996
- numAffectedRows: result.numAffectedRows ?? BigInt(result.rows.length)
997
- };
979
+ async preCreateCheck(mutationModel, node, proceed) {
980
+ const isManyToManyJoinTable = this.isManyToManyJoinTable(mutationModel);
981
+ let needCheckPreCreate = true;
982
+ if (!isManyToManyJoinTable) {
983
+ const constCondition = this.tryGetConstantPolicy(mutationModel, "create");
984
+ if (constCondition === true) {
985
+ needCheckPreCreate = false;
986
+ } else if (constCondition === false) {
987
+ throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS);
988
+ }
998
989
  }
999
- }
1000
- hasPostUpdatePolicies(model) {
1001
- const policies = this.getModelPolicies(model, "post-update");
1002
- return policies.length > 0;
1003
- }
1004
- async loadBeforeUpdateEntities(model, where, proceed) {
1005
- const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model);
1006
- if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) {
1007
- return void 0;
990
+ if (needCheckPreCreate) {
991
+ await this.enforcePreCreatePolicy(node, mutationModel, isManyToManyJoinTable, proceed);
1008
992
  }
1009
- const policyFilter = this.buildPolicyFilter(model, model, "update");
1010
- const combinedFilter = where ? conjunction(this.dialect, [
1011
- where.where,
1012
- policyFilter
1013
- ]) : policyFilter;
1014
- const query = {
1015
- kind: "SelectQueryNode",
1016
- from: FromNode2.create([
1017
- TableNode3.create(model)
1018
- ]),
1019
- where: WhereNode2.create(combinedFilter),
1020
- selections: [
1021
- ...beforeUpdateAccessFields.map((f) => SelectionNode2.create(ColumnNode2.create(f)))
1022
- ]
1023
- };
1024
- const result = await proceed(query);
1025
- return {
1026
- fields: beforeUpdateAccessFields,
1027
- rows: result.rows
1028
- };
1029
993
  }
1030
- getFieldsAccessForBeforeUpdatePolicies(model) {
1031
- const policies = this.getModelPolicies(model, "post-update");
1032
- if (policies.length === 0) {
1033
- return void 0;
1034
- }
1035
- const fields = /* @__PURE__ */ new Set();
1036
- const fieldCollector = new class extends SchemaUtils2.ExpressionVisitor {
1037
- visitMember(e) {
1038
- if (isBeforeInvocation(e.receiver)) {
1039
- invariant3(e.members.length === 1, "before() can only be followed by a scalar field access");
1040
- fields.add(e.members[0]);
1041
- }
1042
- super.visitMember(e);
1043
- }
1044
- }();
1045
- for (const policy of policies) {
1046
- fieldCollector.visit(policy.condition);
994
+ async preUpdateCheck(mutationModel, node, proceed) {
995
+ const fieldsToUpdate = node.updates?.map((u) => ColumnNode3.is(u.column) ? u.column.column.name : void 0).filter((f) => !!f) ?? [];
996
+ const fieldUpdatePolicies = fieldsToUpdate.map((f) => this.buildFieldPolicyFilter(mutationModel, f, "update"));
997
+ const fieldLevelFilter = conjunction(this.dialect, fieldUpdatePolicies);
998
+ if (isTrueNode(fieldLevelFilter)) {
999
+ return;
1047
1000
  }
1048
- if (fields.size === 0) {
1049
- return void 0;
1001
+ const modelLevelFilter = this.buildPolicyFilter(mutationModel, void 0, "update");
1002
+ const updateFilter = conjunction(this.dialect, [
1003
+ modelLevelFilter,
1004
+ node.where?.where ?? trueNode(this.dialect)
1005
+ ]);
1006
+ const preUpdateCheckQuery = expressionBuilder2().selectFrom(mutationModel).select((eb) => eb.fn.coalesce(eb.fn.sum(eb.cast(new ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter)), "integer")), eb.lit(0)).as("$filteredCount")).where(() => new ExpressionWrapper(updateFilter));
1007
+ const preUpdateResult = await proceed(preUpdateCheckQuery.toOperationNode());
1008
+ if (preUpdateResult.rows[0].$filteredCount > 0) {
1009
+ throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS, "some rows cannot be updated due to field policies");
1050
1010
  }
1051
- QueryUtils2.requireIdFields(this.client.$schema, model).forEach((f) => fields.add(f));
1052
- return Array.from(fields).sort();
1053
1011
  }
1054
- // #region overrides
1012
+ // #endregion
1013
+ // #region Transformations
1055
1014
  transformSelectQuery(node) {
1056
1015
  if (!node.from) {
1057
1016
  return super.transformSelectQuery(node);
1058
1017
  }
1059
- let whereNode = this.transformNode(node.where);
1060
- const policyFilter = this.createPolicyFilterForFrom(node.from);
1061
- if (policyFilter) {
1062
- whereNode = WhereNode2.create(whereNode?.where ? conjunction(this.dialect, [
1063
- whereNode.where,
1064
- policyFilter
1065
- ]) : policyFilter);
1066
- }
1067
- const baseResult = super.transformSelectQuery({
1068
- ...node,
1069
- where: void 0
1018
+ this.tryRejectNonexistingTables(node.from.froms);
1019
+ let result = super.transformSelectQuery(node);
1020
+ const hasFieldLevelPolicies = node.from.froms.some((table) => {
1021
+ const extractedTable = this.extractTableName(table);
1022
+ if (extractedTable) {
1023
+ return this.hasFieldLevelPolicies(extractedTable.model, "read");
1024
+ } else {
1025
+ return false;
1026
+ }
1070
1027
  });
1071
- return {
1072
- ...baseResult,
1073
- where: whereNode
1074
- };
1028
+ if (hasFieldLevelPolicies) {
1029
+ const updatedFroms = [];
1030
+ for (const table of result.from.froms) {
1031
+ const extractedTable = this.extractTableName(table);
1032
+ if (extractedTable?.model && QueryUtils2.getModel(this.client.$schema, extractedTable.model)) {
1033
+ const { query } = this.createSelectAllFieldsWithPolicies(extractedTable.model, extractedTable.alias, "read");
1034
+ updatedFroms.push(query);
1035
+ } else {
1036
+ updatedFroms.push(table);
1037
+ }
1038
+ }
1039
+ result = {
1040
+ ...result,
1041
+ from: FromNode2.create(updatedFroms)
1042
+ };
1043
+ } else {
1044
+ let whereNode = result.where;
1045
+ const policyFilter = this.createPolicyFilterForFrom(result.from);
1046
+ if (policyFilter && !isTrueNode(policyFilter)) {
1047
+ whereNode = WhereNode2.create(whereNode?.where ? conjunction(this.dialect, [
1048
+ whereNode.where,
1049
+ policyFilter
1050
+ ]) : policyFilter);
1051
+ }
1052
+ result = {
1053
+ ...result,
1054
+ where: whereNode
1055
+ };
1056
+ }
1057
+ return result;
1075
1058
  }
1076
1059
  transformJoin(node) {
1077
1060
  const table = this.extractTableName(node.table);
@@ -1079,20 +1062,17 @@ var PolicyHandler = class extends OperationNodeTransformer {
1079
1062
  return super.transformJoin(node);
1080
1063
  }
1081
1064
  this.tryRejectNonexistentModel(table.model);
1082
- const filter = this.buildPolicyFilter(table.model, table.alias, "read");
1083
- const nestedSelect = {
1084
- kind: "SelectQueryNode",
1085
- from: FromNode2.create([
1086
- node.table
1087
- ]),
1088
- selections: [
1089
- SelectionNode2.createSelectAll()
1090
- ],
1091
- where: WhereNode2.create(filter)
1092
- };
1065
+ if (!QueryUtils2.getModel(this.client.$schema, table.model)) {
1066
+ return super.transformJoin(node);
1067
+ }
1068
+ const result = super.transformJoin(node);
1069
+ const { hasPolicies, query: nestedQuery } = this.createSelectAllFieldsWithPolicies(table.model, table.alias, "read");
1070
+ if (!hasPolicies) {
1071
+ return result;
1072
+ }
1093
1073
  return {
1094
- ...node,
1095
- table: AliasNode3.create(ParensNode2.create(nestedSelect), IdentifierNode2.create(table.alias ?? table.model))
1074
+ ...result,
1075
+ table: nestedQuery
1096
1076
  };
1097
1077
  }
1098
1078
  transformInsertQuery(node) {
@@ -1124,7 +1104,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
1124
1104
  if (returning) {
1125
1105
  const { mutationModel } = this.getMutationModel(node);
1126
1106
  const idFields = QueryUtils2.requireIdFields(this.client.$schema, mutationModel);
1127
- returning = ReturningNode.create(idFields.map((f) => SelectionNode2.create(ColumnNode2.create(f))));
1107
+ returning = ReturningNode.create(idFields.map((f) => SelectionNode2.create(ColumnNode3.create(f))));
1128
1108
  }
1129
1109
  return {
1130
1110
  ...result,
@@ -1136,6 +1116,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
1136
1116
  const { mutationModel, alias } = this.getMutationModel(node);
1137
1117
  let filter = this.buildPolicyFilter(mutationModel, alias, "update");
1138
1118
  if (node.from) {
1119
+ this.tryRejectNonexistingTables(node.from.froms);
1139
1120
  const joinFilter = this.createPolicyFilterForFrom(node.from);
1140
1121
  if (joinFilter) {
1141
1122
  filter = conjunction(this.dialect, [
@@ -1147,7 +1128,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
1147
1128
  let returning = result.returning;
1148
1129
  if (returning || this.hasPostUpdatePolicies(mutationModel)) {
1149
1130
  const idFields = QueryUtils2.requireIdFields(this.client.$schema, mutationModel);
1150
- returning = ReturningNode.create(idFields.map((f) => SelectionNode2.create(ColumnNode2.create(f))));
1131
+ returning = ReturningNode.create(idFields.map((f) => SelectionNode2.create(ColumnNode3.create(f))));
1151
1132
  }
1152
1133
  return {
1153
1134
  ...result,
@@ -1163,6 +1144,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
1163
1144
  const { mutationModel, alias } = this.getMutationModel(node);
1164
1145
  let filter = this.buildPolicyFilter(mutationModel, alias, "delete");
1165
1146
  if (node.using) {
1147
+ this.tryRejectNonexistingTables(node.using.tables);
1166
1148
  const joinFilter = this.createPolicyFilterForTables(node.using.tables);
1167
1149
  if (joinFilter) {
1168
1150
  filter = conjunction(this.dialect, [
@@ -1180,6 +1162,139 @@ var PolicyHandler = class extends OperationNodeTransformer {
1180
1162
  };
1181
1163
  }
1182
1164
  // #endregion
1165
+ // #region post-update
1166
+ async loadBeforeUpdateEntities(model, where, proceed) {
1167
+ const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model);
1168
+ if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) {
1169
+ return void 0;
1170
+ }
1171
+ const policyFilter = this.buildPolicyFilter(model, model, "update");
1172
+ const combinedFilter = where ? conjunction(this.dialect, [
1173
+ where.where,
1174
+ policyFilter
1175
+ ]) : policyFilter;
1176
+ const query = {
1177
+ kind: "SelectQueryNode",
1178
+ from: FromNode2.create([
1179
+ TableNode3.create(model)
1180
+ ]),
1181
+ where: WhereNode2.create(combinedFilter),
1182
+ selections: [
1183
+ ...beforeUpdateAccessFields.map((f) => SelectionNode2.create(ColumnNode3.create(f)))
1184
+ ]
1185
+ };
1186
+ const result = await proceed(query);
1187
+ return {
1188
+ fields: beforeUpdateAccessFields,
1189
+ rows: result.rows
1190
+ };
1191
+ }
1192
+ getFieldsAccessForBeforeUpdatePolicies(model) {
1193
+ const policies = this.getModelPolicies(model, "post-update");
1194
+ if (policies.length === 0) {
1195
+ return void 0;
1196
+ }
1197
+ const fields = /* @__PURE__ */ new Set();
1198
+ const fieldCollector = new class extends SchemaUtils2.ExpressionVisitor {
1199
+ visitMember(e) {
1200
+ if (isBeforeInvocation(e.receiver)) {
1201
+ invariant3(e.members.length === 1, "before() can only be followed by a scalar field access");
1202
+ fields.add(e.members[0]);
1203
+ }
1204
+ super.visitMember(e);
1205
+ }
1206
+ }();
1207
+ for (const policy of policies) {
1208
+ fieldCollector.visit(policy.condition);
1209
+ }
1210
+ if (fields.size === 0) {
1211
+ return void 0;
1212
+ }
1213
+ QueryUtils2.requireIdFields(this.client.$schema, model).forEach((f) => fields.add(f));
1214
+ return Array.from(fields).sort();
1215
+ }
1216
+ hasPostUpdatePolicies(model) {
1217
+ const policies = this.getModelPolicies(model, "post-update");
1218
+ return policies.length > 0;
1219
+ }
1220
+ // #endregion
1221
+ // #region field-level policies
1222
+ createSelectAllFieldsWithPolicies(model, alias, operation) {
1223
+ let hasPolicies = false;
1224
+ const modelDef = QueryUtils2.requireModel(this.client.$schema, model);
1225
+ let selections = [];
1226
+ for (const fieldDef of Object.values(modelDef.fields).filter(
1227
+ // exclude relation/computed/inherited fields
1228
+ (f) => !f.relation && !f.computed && !f.originModel
1229
+ )) {
1230
+ const { hasPolicies: fieldHasPolicies, selection } = this.createFieldSelectionWithPolicy(model, fieldDef.name, operation);
1231
+ hasPolicies = hasPolicies || fieldHasPolicies;
1232
+ selections.push(selection);
1233
+ }
1234
+ if (!hasPolicies) {
1235
+ selections = [
1236
+ SelectionNode2.create(SelectAllNode.create())
1237
+ ];
1238
+ }
1239
+ const modelPolicyFilter = this.buildPolicyFilter(model, alias, operation);
1240
+ if (!isTrueNode(modelPolicyFilter)) {
1241
+ hasPolicies = true;
1242
+ }
1243
+ const nestedQuery = {
1244
+ kind: "SelectQueryNode",
1245
+ from: FromNode2.create([
1246
+ TableNode3.create(model)
1247
+ ]),
1248
+ where: isTrueNode(modelPolicyFilter) ? void 0 : WhereNode2.create(modelPolicyFilter),
1249
+ selections
1250
+ };
1251
+ return {
1252
+ hasPolicies,
1253
+ query: AliasNode3.create(ParensNode2.create(nestedQuery), IdentifierNode2.create(alias ?? model))
1254
+ };
1255
+ }
1256
+ createFieldSelectionWithPolicy(model, field, operation) {
1257
+ const filter = this.buildFieldPolicyFilter(model, field, operation);
1258
+ if (isTrueNode(filter)) {
1259
+ return {
1260
+ hasPolicies: false,
1261
+ selection: SelectionNode2.create(ColumnNode3.create(field))
1262
+ };
1263
+ }
1264
+ const eb = expressionBuilder2();
1265
+ const selection = eb.case().when(new ExpressionWrapper(filter)).then(eb.ref(field)).else(null).end().as(field).toOperationNode();
1266
+ return {
1267
+ hasPolicies: true,
1268
+ selection: SelectionNode2.create(selection)
1269
+ };
1270
+ }
1271
+ hasFieldLevelPolicies(model, operation) {
1272
+ const modelDef = QueryUtils2.getModel(this.client.$schema, model);
1273
+ if (!modelDef) {
1274
+ return false;
1275
+ }
1276
+ return Object.keys(modelDef.fields).some((field) => this.getFieldPolicies(model, field, operation).length > 0);
1277
+ }
1278
+ buildFieldPolicyFilter(model, field, operation) {
1279
+ const policies = this.getFieldPolicies(model, field, operation);
1280
+ const allows = policies.filter((policy) => policy.kind === "allow").map((policy) => this.compilePolicyCondition(model, model, operation, policy));
1281
+ const denies = policies.filter((policy) => policy.kind === "deny").map((policy) => this.compilePolicyCondition(model, model, operation, policy));
1282
+ let combinedPolicy;
1283
+ if (allows.length === 0) {
1284
+ combinedPolicy = trueNode(this.dialect);
1285
+ } else {
1286
+ combinedPolicy = disjunction(this.dialect, allows);
1287
+ }
1288
+ if (denies.length !== 0) {
1289
+ const combinedDenies = conjunction(this.dialect, denies.map((d) => buildIsFalse(d, this.dialect)));
1290
+ combinedPolicy = conjunction(this.dialect, [
1291
+ combinedPolicy,
1292
+ combinedDenies
1293
+ ]);
1294
+ }
1295
+ return combinedPolicy;
1296
+ }
1297
+ // #endregion
1183
1298
  // #region helpers
1184
1299
  onlyReturningId(node) {
1185
1300
  if (!node.returning) {
@@ -1378,7 +1493,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
1378
1493
  }
1379
1494
  buildIdConditions(table, rows) {
1380
1495
  const idFields = QueryUtils2.requireIdFields(this.client.$schema, table);
1381
- return disjunction(this.dialect, rows.map((row) => conjunction(this.dialect, idFields.map((field) => BinaryOperationNode3.create(ReferenceNode3.create(ColumnNode2.create(field), TableNode3.create(table)), OperatorNode3.create("="), ValueNode3.create(row[field]))))));
1496
+ return disjunction(this.dialect, rows.map((row) => conjunction(this.dialect, idFields.map((field) => BinaryOperationNode3.create(ReferenceNode3.create(ColumnNode3.create(field), TableNode3.create(table)), OperatorNode3.create("="), ValueNode3.create(row[field]))))));
1382
1497
  }
1383
1498
  getMutationModel(node) {
1384
1499
  const r = match3(node).when(InsertQueryNode.is, (node2) => ({
@@ -1471,7 +1586,6 @@ var PolicyHandler = class extends OperationNodeTransformer {
1471
1586
  const extractResult = this.extractTableName(table);
1472
1587
  if (extractResult) {
1473
1588
  const { model, alias } = extractResult;
1474
- this.tryRejectNonexistentModel(model);
1475
1589
  const filter = this.buildPolicyFilter(model, alias, "read");
1476
1590
  return acc ? conjunction(this.dialect, [
1477
1591
  acc,
@@ -1507,6 +1621,23 @@ var PolicyHandler = class extends OperationNodeTransformer {
1507
1621
  }
1508
1622
  return result;
1509
1623
  }
1624
+ getFieldPolicies(model, field, operation) {
1625
+ const fieldDef = QueryUtils2.requireField(this.client.$schema, model, field);
1626
+ const result = [];
1627
+ const extractOperations = /* @__PURE__ */ __name((expr2) => {
1628
+ invariant3(ExpressionUtils4.isLiteral(expr2), "expecting a literal");
1629
+ invariant3(typeof expr2.value === "string", "expecting a string literal");
1630
+ return expr2.value.split(",").filter((v) => !!v).map((v) => v.trim());
1631
+ }, "extractOperations");
1632
+ if (fieldDef.attributes) {
1633
+ result.push(...fieldDef.attributes.filter((attr) => attr.name === "@allow" || attr.name === "@deny").map((attr) => ({
1634
+ kind: attr.name === "@allow" ? "allow" : "deny",
1635
+ operations: extractOperations(attr.args[0].value),
1636
+ condition: attr.args[1].value
1637
+ })).filter((policy) => policy.operations.includes("all") || policy.operations.includes(operation)));
1638
+ }
1639
+ return result;
1640
+ }
1510
1641
  resolveManyToManyJoinTable(tableName) {
1511
1642
  for (const model of Object.values(this.client.$schema.models)) {
1512
1643
  for (const field of Object.values(model.fields)) {
@@ -1564,6 +1695,27 @@ var PolicyHandler = class extends OperationNodeTransformer {
1564
1695
  throw createRejectedByPolicyError(model, RejectedByPolicyReason.NO_ACCESS);
1565
1696
  }
1566
1697
  }
1698
+ tryRejectNonexistingTables(tables) {
1699
+ for (const table of tables) {
1700
+ const extractResult = this.extractTableName(table);
1701
+ if (extractResult) {
1702
+ this.tryRejectNonexistentModel(extractResult.model);
1703
+ }
1704
+ }
1705
+ }
1706
+ // correction to kysely mutation result may be needed because we might have added
1707
+ // returning clause to the query and caused changes to the result shape
1708
+ postProcessMutationResult(result, node) {
1709
+ if (node.returning) {
1710
+ return result;
1711
+ } else {
1712
+ return {
1713
+ ...result,
1714
+ rows: [],
1715
+ numAffectedRows: result.numAffectedRows ?? BigInt(result.rows.length)
1716
+ };
1717
+ }
1718
+ }
1567
1719
  };
1568
1720
 
1569
1721
  // src/functions.ts