@zenstackhq/plugin-policy 3.3.0-beta.1 → 3.3.0-beta.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -65,14 +65,14 @@ var import_ts_pattern2 = require("ts-pattern");
65
65
 
66
66
  // src/expression-evaluator.ts
67
67
  var import_common_helpers = require("@zenstackhq/common-helpers");
68
- var import_ts_pattern = require("ts-pattern");
69
68
  var import_schema = require("@zenstackhq/orm/schema");
69
+ var import_ts_pattern = require("ts-pattern");
70
70
  var ExpressionEvaluator = class {
71
71
  static {
72
72
  __name(this, "ExpressionEvaluator");
73
73
  }
74
74
  evaluate(expression, context) {
75
- const result = (0, import_ts_pattern.match)(expression).when(import_schema.ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(import_schema.ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(import_schema.ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(import_schema.ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(import_schema.ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(import_schema.ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(import_schema.ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(import_schema.ExpressionUtils.isThis, () => context.thisValue).when(import_schema.ExpressionUtils.isNull, () => null).exhaustive();
75
+ const result = (0, import_ts_pattern.match)(expression).when(import_schema.ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(import_schema.ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(import_schema.ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(import_schema.ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(import_schema.ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(import_schema.ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(import_schema.ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(import_schema.ExpressionUtils.isBinding, (expr2) => this.evaluateBinding(expr2, context)).when(import_schema.ExpressionUtils.isThis, () => context.thisValue).when(import_schema.ExpressionUtils.isNull, () => null).exhaustive();
76
76
  return result ?? null;
77
77
  }
78
78
  evaluateCall(expr2, context) {
@@ -96,6 +96,9 @@ var ExpressionEvaluator = class {
96
96
  return expr2.value;
97
97
  }
98
98
  evaluateField(expr2, context) {
99
+ if (context.bindingScope && expr2.field in context.bindingScope) {
100
+ return context.bindingScope[expr2.field];
101
+ }
99
102
  return context.thisValue?.[expr2.field];
100
103
  }
101
104
  evaluateArray(expr2, context) {
@@ -129,15 +132,33 @@ var ExpressionEvaluator = class {
129
132
  (0, import_common_helpers.invariant)(Array.isArray(left), "expected array");
130
133
  return (0, import_ts_pattern.match)(op).with("?", () => left.some((item) => this.evaluate(expr2.right, {
131
134
  ...context,
132
- thisValue: item
135
+ thisValue: item,
136
+ bindingScope: expr2.binding ? {
137
+ ...context.bindingScope ?? {},
138
+ [expr2.binding]: item
139
+ } : context.bindingScope
133
140
  }))).with("!", () => left.every((item) => this.evaluate(expr2.right, {
134
141
  ...context,
135
- thisValue: item
142
+ thisValue: item,
143
+ bindingScope: expr2.binding ? {
144
+ ...context.bindingScope ?? {},
145
+ [expr2.binding]: item
146
+ } : context.bindingScope
136
147
  }))).with("^", () => !left.some((item) => this.evaluate(expr2.right, {
137
148
  ...context,
138
- thisValue: item
149
+ thisValue: item,
150
+ bindingScope: expr2.binding ? {
151
+ ...context.bindingScope ?? {},
152
+ [expr2.binding]: item
153
+ } : context.bindingScope
139
154
  }))).exhaustive();
140
155
  }
156
+ evaluateBinding(expr2, context) {
157
+ if (!context.bindingScope || !(expr2.name in context.bindingScope)) {
158
+ throw new Error(`Unresolved binding: ${expr2.name}`);
159
+ }
160
+ return context.bindingScope[expr2.name];
161
+ }
141
162
  };
142
163
 
143
164
  // src/types.ts
@@ -152,11 +173,11 @@ var import_orm2 = require("@zenstackhq/orm");
152
173
  var import_schema2 = require("@zenstackhq/orm/schema");
153
174
  var import_kysely = require("kysely");
154
175
  function trueNode(dialect) {
155
- return import_kysely.ValueNode.createImmediate(dialect.transformPrimitive(true, "Boolean", false));
176
+ return import_kysely.ValueNode.createImmediate(dialect.transformInput(true, "Boolean", false));
156
177
  }
157
178
  __name(trueNode, "trueNode");
158
179
  function falseNode(dialect) {
159
- return import_kysely.ValueNode.createImmediate(dialect.transformPrimitive(false, "Boolean", false));
180
+ return import_kysely.ValueNode.createImmediate(dialect.transformInput(false, "Boolean", false));
160
181
  }
161
182
  __name(falseNode, "falseNode");
162
183
  function isTrueNode(node) {
@@ -450,7 +471,8 @@ var ExpressionTransformer = class {
450
471
  const evaluator = new ExpressionEvaluator();
451
472
  const receiver = evaluator.evaluate(expr2.left, {
452
473
  thisValue: context.contextValue,
453
- auth: this.auth
474
+ auth: this.auth,
475
+ bindingScope: this.getEvaluationBindingScope(context.bindingScope)
454
476
  });
455
477
  const baseType = this.isAuthMember(expr2.left) ? this.authType : context.modelOrType;
456
478
  const memberType = this.getMemberType(baseType, expr2.left);
@@ -466,18 +488,31 @@ var ExpressionTransformer = class {
466
488
  (0, import_common_helpers2.invariant)(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr2.left)}`);
467
489
  newContextModel = fieldDef.type;
468
490
  } else {
469
- (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isMember(expr2.left) && import_schema3.ExpressionUtils.isField(expr2.left.receiver), "left operand must be member access with field receiver");
470
- const fieldDef2 = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.left.receiver.field);
471
- newContextModel = fieldDef2.type;
491
+ (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isMember(expr2.left) && (import_schema3.ExpressionUtils.isField(expr2.left.receiver) || import_schema3.ExpressionUtils.isBinding(expr2.left.receiver)), "left operand must be member access with field receiver");
492
+ if (import_schema3.ExpressionUtils.isField(expr2.left.receiver)) {
493
+ const fieldDef2 = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.left.receiver.field);
494
+ newContextModel = fieldDef2.type;
495
+ } else {
496
+ const binding = this.requireBindingScope(expr2.left.receiver, context);
497
+ newContextModel = binding.type;
498
+ }
472
499
  for (const member of expr2.left.members) {
473
500
  const memberDef = import_orm3.QueryUtils.requireField(this.schema, newContextModel, member);
474
501
  newContextModel = memberDef.type;
475
502
  }
476
503
  }
504
+ const bindingScope = expr2.binding ? {
505
+ ...context.bindingScope ?? {},
506
+ [expr2.binding]: {
507
+ type: newContextModel,
508
+ alias: newContextModel
509
+ }
510
+ } : context.bindingScope;
477
511
  let predicateFilter = this.transform(expr2.right, {
478
512
  ...context,
479
513
  modelOrType: newContextModel,
480
- alias: void 0
514
+ alias: void 0,
515
+ bindingScope
481
516
  });
482
517
  if (expr2.op === "!") {
483
518
  predicateFilter = logicalNot(this.dialect, predicateFilter);
@@ -504,18 +539,30 @@ var ExpressionTransformer = class {
504
539
  if (!visitor.find(expr2.right)) {
505
540
  const value = new ExpressionEvaluator().evaluate(expr2, {
506
541
  auth: this.auth,
507
- thisValue: context.contextValue
542
+ thisValue: context.contextValue,
543
+ bindingScope: this.getEvaluationBindingScope(context.bindingScope)
508
544
  });
509
545
  return this.transformValue(value, "Boolean");
510
546
  } else {
511
547
  (0, import_common_helpers2.invariant)(Array.isArray(receiver), "array value is expected");
512
- const components = receiver.map((item) => this.transform(expr2.right, {
513
- operation: context.operation,
514
- thisType: context.thisType,
515
- thisAlias: context.thisAlias,
516
- modelOrType: context.modelOrType,
517
- contextValue: item
518
- }));
548
+ const components = receiver.map((item) => {
549
+ const bindingScope = expr2.binding ? {
550
+ ...context.bindingScope ?? {},
551
+ [expr2.binding]: {
552
+ type: context.modelOrType,
553
+ alias: context.thisAlias ?? context.modelOrType,
554
+ value: item
555
+ }
556
+ } : context.bindingScope;
557
+ return this.transform(expr2.right, {
558
+ operation: context.operation,
559
+ thisType: context.thisType,
560
+ thisAlias: context.thisAlias,
561
+ modelOrType: context.modelOrType,
562
+ contextValue: item,
563
+ bindingScope
564
+ });
565
+ });
519
566
  return (0, import_ts_pattern2.match)(expr2.op).with("?", () => disjunction(this.dialect, components)).with("!", () => conjunction(this.dialect, components)).with("^", () => logicalNot(this.dialect, disjunction(this.dialect, components))).exhaustive();
520
567
  }
521
568
  }
@@ -582,7 +629,7 @@ var ExpressionTransformer = class {
582
629
  } else if (value === false) {
583
630
  return falseNode(this.dialect);
584
631
  } else {
585
- const transformed = this.dialect.transformPrimitive(value, type, false) ?? null;
632
+ const transformed = this.dialect.transformInput(value, type, false) ?? null;
586
633
  if (!Array.isArray(transformed)) {
587
634
  return import_kysely2.ValueNode.createImmediate(transformed);
588
635
  } else {
@@ -645,6 +692,12 @@ var ExpressionTransformer = class {
645
692
  throw createUnsupportedError(`Unsupported argument expression: ${arg.kind}`);
646
693
  }
647
694
  _member(expr2, context) {
695
+ if (import_schema3.ExpressionUtils.isBinding(expr2.receiver)) {
696
+ const scope = this.requireBindingScope(expr2.receiver, context);
697
+ if (scope.value !== void 0) {
698
+ return this.valueMemberAccess(scope.value, expr2, scope.type);
699
+ }
700
+ }
648
701
  if (this.isAuthCall(expr2.receiver)) {
649
702
  return this.valueMemberAccess(this.auth, expr2, this.authType);
650
703
  }
@@ -653,9 +706,10 @@ var ExpressionTransformer = class {
653
706
  (0, import_common_helpers2.invariant)(expr2.members.length === 1, "before() can only be followed by a scalar field access");
654
707
  return import_kysely2.ReferenceNode.create(import_kysely2.ColumnNode.create(expr2.members[0]), import_kysely2.TableNode.create("$before"));
655
708
  }
656
- (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isField(expr2.receiver) || import_schema3.ExpressionUtils.isThis(expr2.receiver), 'expect receiver to be field expression or "this"');
709
+ (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isField(expr2.receiver) || import_schema3.ExpressionUtils.isThis(expr2.receiver) || import_schema3.ExpressionUtils.isBinding(expr2.receiver), 'expect receiver to be field expression, collection predicate binding, or "this"');
657
710
  let members = expr2.members;
658
711
  let receiver;
712
+ let startType;
659
713
  const { memberFilter, memberSelect, ...restContext } = context;
660
714
  if (import_schema3.ExpressionUtils.isThis(expr2.receiver)) {
661
715
  if (expr2.members.length === 1) {
@@ -670,17 +724,40 @@ var ExpressionTransformer = class {
670
724
  const firstMemberFieldDef = import_orm3.QueryUtils.requireField(this.schema, context.thisType, expr2.members[0]);
671
725
  receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, restContext);
672
726
  members = expr2.members.slice(1);
727
+ startType = firstMemberFieldDef.type;
728
+ }
729
+ } else if (import_schema3.ExpressionUtils.isBinding(expr2.receiver)) {
730
+ if (expr2.members.length === 1) {
731
+ const bindingScope = this.requireBindingScope(expr2.receiver, context);
732
+ return this._field(import_schema3.ExpressionUtils.field(expr2.members[0]), {
733
+ ...context,
734
+ modelOrType: bindingScope.type,
735
+ alias: bindingScope.alias,
736
+ thisType: context.thisType,
737
+ contextValue: void 0
738
+ });
739
+ } else {
740
+ const bindingScope = this.requireBindingScope(expr2.receiver, context);
741
+ const firstMemberFieldDef = import_orm3.QueryUtils.requireField(this.schema, bindingScope.type, expr2.members[0]);
742
+ receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, {
743
+ ...restContext,
744
+ modelOrType: bindingScope.type,
745
+ alias: bindingScope.alias
746
+ });
747
+ members = expr2.members.slice(1);
748
+ startType = firstMemberFieldDef.type;
673
749
  }
674
750
  } else {
675
751
  receiver = this.transform(expr2.receiver, restContext);
676
752
  }
677
753
  (0, import_common_helpers2.invariant)(import_kysely2.SelectQueryNode.is(receiver), "expected receiver to be select query");
678
- let startType;
679
- if (import_schema3.ExpressionUtils.isField(expr2.receiver)) {
680
- const receiverField = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.receiver.field);
681
- startType = receiverField.type;
682
- } else {
683
- startType = context.thisType;
754
+ if (startType === void 0) {
755
+ if (import_schema3.ExpressionUtils.isField(expr2.receiver)) {
756
+ const receiverField = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.receiver.field);
757
+ startType = receiverField.type;
758
+ } else {
759
+ startType = context.thisType;
760
+ }
684
761
  }
685
762
  const memberFields = [];
686
763
  let currType = startType;
@@ -731,6 +808,11 @@ var ExpressionTransformer = class {
731
808
  ]
732
809
  };
733
810
  }
811
+ requireBindingScope(expr2, context) {
812
+ const binding = context.bindingScope?.[expr2.name];
813
+ (0, import_common_helpers2.invariant)(binding, `binding not found: ${expr2.name}`);
814
+ return binding;
815
+ }
734
816
  valueMemberAccess(receiver, expr2, receiverType) {
735
817
  if (!receiver) {
736
818
  return import_kysely2.ValueNode.createImmediate(null);
@@ -796,6 +878,19 @@ var ExpressionTransformer = class {
796
878
  }
797
879
  return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel);
798
880
  }
881
+ // convert transformer's binding scope to equivalent expression evaluator binding scope
882
+ getEvaluationBindingScope(scope) {
883
+ if (!scope) {
884
+ return void 0;
885
+ }
886
+ const result = {};
887
+ for (const [key, value] of Object.entries(scope)) {
888
+ if (value.value !== void 0) {
889
+ result[key] = value.value;
890
+ }
891
+ }
892
+ return Object.keys(result).length > 0 ? result : void 0;
893
+ }
799
894
  buildDelegateBaseFieldSelect(model, modelAlias, field, baseModel) {
800
895
  const idFields = import_orm3.QueryUtils.requireIdFields(this.client.$schema, model);
801
896
  return {
@@ -920,6 +1015,7 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
920
1015
  }
921
1016
  client;
922
1017
  dialect;
1018
+ eb = (0, import_kysely3.expressionBuilder)();
923
1019
  constructor(client) {
924
1020
  super(), this.client = client;
925
1021
  this.dialect = (0, import_orm4.getCrudDialect)(this.client.$schema, this.client.$options);
@@ -943,52 +1039,21 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
943
1039
  if (import_kysely3.UpdateQueryNode.is(node)) {
944
1040
  await this.preUpdateCheck(mutationModel, node, proceed);
945
1041
  }
946
- const hasPostUpdatePolicies = import_kysely3.UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
1042
+ const needsPostUpdateCheck = import_kysely3.UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
947
1043
  let beforeUpdateInfo;
948
- if (hasPostUpdatePolicies) {
949
- beforeUpdateInfo = await this.loadBeforeUpdateEntities(mutationModel, node.where, proceed);
1044
+ if (needsPostUpdateCheck) {
1045
+ beforeUpdateInfo = await this.loadBeforeUpdateEntities(
1046
+ mutationModel,
1047
+ node.where,
1048
+ proceed,
1049
+ // force load pre-update entities if dialect doesn't support returning,
1050
+ // so we can rely on pre-update ids to read back updated entities
1051
+ !this.dialect.supportsReturning
1052
+ );
950
1053
  }
951
1054
  const result = await proceed(this.transformNode(node));
952
- if (hasPostUpdatePolicies && result.rows.length > 0) {
953
- if (beforeUpdateInfo) {
954
- (0, import_common_helpers3.invariant)(beforeUpdateInfo.rows.length === result.rows.length);
955
- const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, mutationModel);
956
- for (const postRow of result.rows) {
957
- const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
958
- if (!beforeRow) {
959
- throw createRejectedByPolicyError(mutationModel, import_orm4.RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match by id. If you have post-update policies on a model, updating id fields is not supported.");
960
- }
961
- }
962
- }
963
- const idConditions = this.buildIdConditions(mutationModel, result.rows);
964
- const postUpdateFilter = this.buildPolicyFilter(mutationModel, void 0, "post-update");
965
- const eb = (0, import_kysely3.expressionBuilder)();
966
- const beforeUpdateTable = beforeUpdateInfo ? {
967
- kind: "SelectQueryNode",
968
- from: import_kysely3.FromNode.create([
969
- import_kysely3.ParensNode.create(import_kysely3.ValuesNode.create(beforeUpdateInfo.rows.map((r) => import_kysely3.PrimitiveValueListNode.create(beforeUpdateInfo.fields.map((f) => r[f])))))
970
- ]),
971
- selections: beforeUpdateInfo.fields.map((name, index) => {
972
- const def = import_orm4.QueryUtils.requireField(this.client.$schema, mutationModel, name);
973
- const castedColumnRef = import_kysely3.sql`CAST(${eb.ref(`column${index + 1}`)} as ${import_kysely3.sql.raw(this.dialect.getFieldSqlType(def))})`.as(name);
974
- return import_kysely3.SelectionNode.create(castedColumnRef.toOperationNode());
975
- })
976
- } : void 0;
977
- const postUpdateQuery = eb.selectFrom(mutationModel).select(() => [
978
- eb(eb.fn("COUNT", [
979
- eb.lit(1)
980
- ]), "=", result.rows.length).as("$condition")
981
- ]).where(() => new import_kysely3.ExpressionWrapper(conjunction(this.dialect, [
982
- idConditions,
983
- postUpdateFilter
984
- ]))).$if(!!beforeUpdateInfo, (qb) => qb.leftJoin(() => new import_kysely3.ExpressionWrapper(beforeUpdateTable).as("$before"), (join) => {
985
- const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, mutationModel);
986
- return idFields.reduce((acc, f) => acc.onRef(`${mutationModel}.${f}`, "=", `$before.${f}`), join);
987
- }));
988
- const postUpdateResult = await proceed(postUpdateQuery.toOperationNode());
989
- if (!postUpdateResult.rows[0]?.$condition) {
990
- throw createRejectedByPolicyError(mutationModel, import_orm4.RejectedByPolicyReason.NO_ACCESS, "some or all updated rows failed to pass post-update policy check");
991
- }
1055
+ if ((result.numAffectedRows ?? 0) > 0 && needsPostUpdateCheck) {
1056
+ await this.postUpdateCheck(mutationModel, beforeUpdateInfo, result, proceed);
992
1057
  }
993
1058
  if (!node.returning || this.onlyReturningId(node)) {
994
1059
  return this.postProcessMutationResult(result, node);
@@ -1027,12 +1092,69 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1027
1092
  modelLevelFilter,
1028
1093
  node.where?.where ?? trueNode(this.dialect)
1029
1094
  ]);
1030
- const preUpdateCheckQuery = (0, import_kysely3.expressionBuilder)().selectFrom(mutationModel).select((eb) => eb.fn.coalesce(eb.fn.sum(eb.cast(new import_kysely3.ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter)), "integer")), eb.lit(0)).as("$filteredCount")).where(() => new import_kysely3.ExpressionWrapper(updateFilter));
1095
+ const preUpdateCheckQuery = this.eb.selectFrom(mutationModel).select((eb) => eb.fn.coalesce(eb.fn.sum(this.dialect.castInt(new import_kysely3.ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter)))), eb.lit(0)).as("$filteredCount")).where(() => new import_kysely3.ExpressionWrapper(updateFilter));
1031
1096
  const preUpdateResult = await proceed(preUpdateCheckQuery.toOperationNode());
1032
1097
  if (preUpdateResult.rows[0].$filteredCount > 0) {
1033
1098
  throw createRejectedByPolicyError(mutationModel, import_orm4.RejectedByPolicyReason.NO_ACCESS, "some rows cannot be updated due to field policies");
1034
1099
  }
1035
1100
  }
1101
+ async postUpdateCheck(model, beforeUpdateInfo, updateResult, proceed) {
1102
+ let postUpdateRows;
1103
+ if (this.dialect.supportsReturning) {
1104
+ postUpdateRows = updateResult.rows;
1105
+ } else {
1106
+ (0, import_common_helpers3.invariant)(beforeUpdateInfo, "beforeUpdateInfo must be defined for dialects not supporting returning");
1107
+ const idConditions2 = this.buildIdConditions(model, beforeUpdateInfo.rows);
1108
+ const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, model);
1109
+ const postUpdateQuery2 = {
1110
+ kind: "SelectQueryNode",
1111
+ from: import_kysely3.FromNode.create([
1112
+ import_kysely3.TableNode.create(model)
1113
+ ]),
1114
+ where: import_kysely3.WhereNode.create(idConditions2),
1115
+ selections: idFields.map((field) => import_kysely3.SelectionNode.create(import_kysely3.ColumnNode.create(field)))
1116
+ };
1117
+ const postUpdateQueryResult = await proceed(postUpdateQuery2);
1118
+ postUpdateRows = postUpdateQueryResult.rows;
1119
+ }
1120
+ if (beforeUpdateInfo) {
1121
+ if (beforeUpdateInfo.rows.length !== postUpdateRows.length) {
1122
+ throw createRejectedByPolicyError(model, import_orm4.RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.");
1123
+ }
1124
+ const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, model);
1125
+ for (const postRow of postUpdateRows) {
1126
+ const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
1127
+ if (!beforeRow) {
1128
+ throw createRejectedByPolicyError(model, import_orm4.RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.");
1129
+ }
1130
+ }
1131
+ }
1132
+ const idConditions = this.buildIdConditions(model, postUpdateRows);
1133
+ const postUpdateFilter = this.buildPolicyFilter(model, void 0, "post-update");
1134
+ const eb = (0, import_kysely3.expressionBuilder)();
1135
+ const needsBeforeUpdateJoin = !!beforeUpdateInfo?.fields;
1136
+ let beforeUpdateTable = void 0;
1137
+ if (needsBeforeUpdateJoin) {
1138
+ const fieldDefs = beforeUpdateInfo.fields.map((name) => import_orm4.QueryUtils.requireField(this.client.$schema, model, name));
1139
+ const rows = beforeUpdateInfo.rows.map((r) => beforeUpdateInfo.fields.map((f) => r[f]));
1140
+ beforeUpdateTable = this.dialect.buildValuesTableSelect(fieldDefs, rows).toOperationNode();
1141
+ }
1142
+ const postUpdateQuery = eb.selectFrom(model).select(() => [
1143
+ eb(eb.fn("COUNT", [
1144
+ eb.lit(1)
1145
+ ]), "=", Number(updateResult.numAffectedRows ?? 0)).as("$condition")
1146
+ ]).where(() => new import_kysely3.ExpressionWrapper(conjunction(this.dialect, [
1147
+ idConditions,
1148
+ postUpdateFilter
1149
+ ]))).$if(needsBeforeUpdateJoin, (qb) => qb.leftJoin(() => new import_kysely3.ExpressionWrapper(beforeUpdateTable).as("$before"), (join) => {
1150
+ const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, model);
1151
+ return idFields.reduce((acc, f) => acc.onRef(`${model}.${f}`, "=", `$before.${f}`), join);
1152
+ }));
1153
+ const postUpdateResult = await proceed(postUpdateQuery.toOperationNode());
1154
+ if (!postUpdateResult.rows[0]?.$condition) {
1155
+ throw createRejectedByPolicyError(model, import_orm4.RejectedByPolicyReason.NO_ACCESS, "some or all updated rows failed to pass post-update policy check");
1156
+ }
1157
+ }
1036
1158
  // #endregion
1037
1159
  // #region Transformations
1038
1160
  transformSelectQuery(node) {
@@ -1100,6 +1222,7 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1100
1222
  };
1101
1223
  }
1102
1224
  transformInsertQuery(node) {
1225
+ let processedNode = node;
1103
1226
  let onConflict = node.onConflict;
1104
1227
  if (onConflict?.updates) {
1105
1228
  const { mutationModel, alias } = this.getMutationModel(node);
@@ -1118,11 +1241,36 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1118
1241
  updateWhere: import_kysely3.WhereNode.create(filter)
1119
1242
  };
1120
1243
  }
1244
+ processedNode = {
1245
+ ...node,
1246
+ onConflict
1247
+ };
1248
+ }
1249
+ let onDuplicateKey = node.onDuplicateKey;
1250
+ if (onDuplicateKey?.updates) {
1251
+ const { mutationModel } = this.getMutationModel(node);
1252
+ const filterWithTableRef = this.buildPolicyFilter(mutationModel, void 0, "update");
1253
+ const filter = this.stripTableReferences(filterWithTableRef, mutationModel);
1254
+ const wrappedUpdates = onDuplicateKey.updates.map((update) => {
1255
+ const columnName = import_kysely3.ColumnNode.is(update.column) ? update.column.column.name : void 0;
1256
+ if (!columnName) {
1257
+ return update;
1258
+ }
1259
+ const wrappedValue = import_kysely3.sql`IF(${new import_kysely3.ExpressionWrapper(filter)}, ${new import_kysely3.ExpressionWrapper(update.value)}, ${import_kysely3.sql.ref(columnName)})`.toOperationNode();
1260
+ return {
1261
+ ...update,
1262
+ value: wrappedValue
1263
+ };
1264
+ });
1265
+ onDuplicateKey = {
1266
+ ...onDuplicateKey,
1267
+ updates: wrappedUpdates
1268
+ };
1269
+ processedNode = {
1270
+ ...processedNode,
1271
+ onDuplicateKey
1272
+ };
1121
1273
  }
1122
- const processedNode = onConflict ? {
1123
- ...node,
1124
- onConflict
1125
- } : node;
1126
1274
  const result = super.transformInsertQuery(processedNode);
1127
1275
  let returning = result.returning;
1128
1276
  if (returning) {
@@ -1150,7 +1298,7 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1150
1298
  }
1151
1299
  }
1152
1300
  let returning = result.returning;
1153
- if (returning || this.hasPostUpdatePolicies(mutationModel)) {
1301
+ if (this.dialect.supportsReturning && (returning || this.hasPostUpdatePolicies(mutationModel))) {
1154
1302
  const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, mutationModel);
1155
1303
  returning = import_kysely3.ReturningNode.create(idFields.map((f) => import_kysely3.SelectionNode.create(import_kysely3.ColumnNode.create(f))));
1156
1304
  }
@@ -1187,9 +1335,9 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1187
1335
  }
1188
1336
  // #endregion
1189
1337
  // #region post-update
1190
- async loadBeforeUpdateEntities(model, where, proceed) {
1338
+ async loadBeforeUpdateEntities(model, where, proceed, forceLoad = false) {
1191
1339
  const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model);
1192
- if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) {
1340
+ if (!forceLoad && (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0)) {
1193
1341
  return void 0;
1194
1342
  }
1195
1343
  const policyFilter = this.buildPolicyFilter(model, model, "update");
@@ -1197,15 +1345,14 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1197
1345
  where.where,
1198
1346
  policyFilter
1199
1347
  ]) : policyFilter;
1348
+ const selections = beforeUpdateAccessFields ?? import_orm4.QueryUtils.requireIdFields(this.client.$schema, model);
1200
1349
  const query = {
1201
1350
  kind: "SelectQueryNode",
1202
1351
  from: import_kysely3.FromNode.create([
1203
1352
  import_kysely3.TableNode.create(model)
1204
1353
  ]),
1205
1354
  where: import_kysely3.WhereNode.create(combinedFilter),
1206
- selections: [
1207
- ...beforeUpdateAccessFields.map((f) => import_kysely3.SelectionNode.create(import_kysely3.ColumnNode.create(f)))
1208
- ]
1355
+ selections: selections.map((f) => import_kysely3.SelectionNode.create(import_kysely3.ColumnNode.create(f)))
1209
1356
  };
1210
1357
  const result = await proceed(query);
1211
1358
  return {
@@ -1385,43 +1532,24 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1385
1532
  }
1386
1533
  }
1387
1534
  async enforcePreCreatePolicyForOne(model, fields, values, proceed) {
1388
- const allFields = Object.entries(import_orm4.QueryUtils.requireModel(this.client.$schema, model).fields).filter(([, def]) => !def.relation);
1535
+ const allFields = import_orm4.QueryUtils.getModelFields(this.client.$schema, model, {
1536
+ inherited: true
1537
+ });
1389
1538
  const allValues = [];
1390
- for (const [name, _def] of allFields) {
1391
- const index = fields.indexOf(name);
1539
+ for (const def of allFields) {
1540
+ const index = fields.indexOf(def.name);
1392
1541
  if (index >= 0) {
1393
- allValues.push(values[index]);
1542
+ allValues.push(new import_kysely3.ExpressionWrapper(values[index]));
1394
1543
  } else {
1395
- allValues.push(import_kysely3.ValueNode.createImmediate(null));
1544
+ allValues.push(this.eb.lit(null));
1396
1545
  }
1397
1546
  }
1398
- const eb = (0, import_kysely3.expressionBuilder)();
1399
- const constTable = {
1400
- kind: "SelectQueryNode",
1401
- from: import_kysely3.FromNode.create([
1402
- import_kysely3.AliasNode.create(import_kysely3.ParensNode.create(import_kysely3.ValuesNode.create([
1403
- import_kysely3.ValueListNode.create(allValues)
1404
- ])), import_kysely3.IdentifierNode.create("$t"))
1405
- ]),
1406
- selections: allFields.map(([name, def], index) => {
1407
- const castedColumnRef = import_kysely3.sql`CAST(${eb.ref(`column${index + 1}`)} as ${import_kysely3.sql.raw(this.dialect.getFieldSqlType(def))})`.as(name);
1408
- return import_kysely3.SelectionNode.create(castedColumnRef.toOperationNode());
1409
- })
1410
- };
1547
+ const valuesTable = this.dialect.buildValuesTableSelect(allFields, [
1548
+ allValues
1549
+ ]);
1411
1550
  const filter = this.buildPolicyFilter(model, void 0, "create");
1412
- const preCreateCheck = {
1413
- kind: "SelectQueryNode",
1414
- from: import_kysely3.FromNode.create([
1415
- import_kysely3.AliasNode.create(constTable, import_kysely3.IdentifierNode.create(model))
1416
- ]),
1417
- selections: [
1418
- import_kysely3.SelectionNode.create(import_kysely3.AliasNode.create(import_kysely3.BinaryOperationNode.create(import_kysely3.FunctionNode.create("COUNT", [
1419
- import_kysely3.ValueNode.createImmediate(1)
1420
- ]), import_kysely3.OperatorNode.create(">"), import_kysely3.ValueNode.createImmediate(0)), import_kysely3.IdentifierNode.create("$condition")))
1421
- ],
1422
- where: import_kysely3.WhereNode.create(filter)
1423
- };
1424
- const result = await proceed(preCreateCheck);
1551
+ const preCreateCheck = this.eb.selectFrom(valuesTable.as(model)).select(this.eb(this.eb.fn.count(this.eb.lit(1)), ">", 0).as("$condition")).where(() => new import_kysely3.ExpressionWrapper(filter));
1552
+ const result = await proceed(preCreateCheck.toOperationNode());
1425
1553
  if (!result.rows[0]?.$condition) {
1426
1554
  throw createRejectedByPolicyError(model, import_orm4.RejectedByPolicyReason.NO_ACCESS);
1427
1555
  }
@@ -1446,18 +1574,18 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1446
1574
  const fieldDef = import_orm4.QueryUtils.requireField(this.client.$schema, model, fields[i]);
1447
1575
  (0, import_common_helpers3.invariant)(item.kind === "ValueNode", "expecting a ValueNode");
1448
1576
  result.push({
1449
- node: import_kysely3.ValueNode.create(this.dialect.transformPrimitive(item.value, fieldDef.type, !!fieldDef.array)),
1577
+ node: import_kysely3.ValueNode.create(this.dialect.transformInput(item.value, fieldDef.type, !!fieldDef.array)),
1450
1578
  raw: item.value
1451
1579
  });
1452
1580
  } else {
1453
1581
  let value = item;
1454
1582
  if (!isImplicitManyToManyJoinTable) {
1455
1583
  const fieldDef = import_orm4.QueryUtils.requireField(this.client.$schema, model, fields[i]);
1456
- value = this.dialect.transformPrimitive(item, fieldDef.type, !!fieldDef.array);
1584
+ value = this.dialect.transformInput(item, fieldDef.type, !!fieldDef.array);
1457
1585
  }
1458
1586
  if (Array.isArray(value)) {
1459
1587
  result.push({
1460
- node: import_kysely3.RawNode.createWithSql(this.dialect.buildArrayLiteralSQL(value)),
1588
+ node: this.dialect.buildArrayLiteralSQL(value).toOperationNode(),
1461
1589
  raw: value
1462
1590
  });
1463
1591
  } else {
@@ -1705,11 +1833,10 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1705
1833
  return void 0;
1706
1834
  }
1707
1835
  const checkForOperation = operation === "read" ? "read" : "update";
1708
- const eb = (0, import_kysely3.expressionBuilder)();
1709
1836
  const joinTable = alias ?? tableName;
1710
- const aQuery = eb.selectFrom(m2m.firstModel).whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, "=", `${joinTable}.A`).select(() => new import_kysely3.ExpressionWrapper(this.buildPolicyFilter(m2m.firstModel, void 0, checkForOperation)).as("$conditionA"));
1711
- const bQuery = eb.selectFrom(m2m.secondModel).whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, "=", `${joinTable}.B`).select(() => new import_kysely3.ExpressionWrapper(this.buildPolicyFilter(m2m.secondModel, void 0, checkForOperation)).as("$conditionB"));
1712
- return eb.and([
1837
+ const aQuery = this.eb.selectFrom(m2m.firstModel).whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, "=", `${joinTable}.A`).select(() => new import_kysely3.ExpressionWrapper(this.buildPolicyFilter(m2m.firstModel, void 0, checkForOperation)).as("$conditionA"));
1838
+ const bQuery = this.eb.selectFrom(m2m.secondModel).whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, "=", `${joinTable}.B`).select(() => new import_kysely3.ExpressionWrapper(this.buildPolicyFilter(m2m.secondModel, void 0, checkForOperation)).as("$conditionB"));
1839
+ return this.eb.and([
1713
1840
  aQuery,
1714
1841
  bQuery
1715
1842
  ]).toOperationNode();
@@ -1740,6 +1867,26 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1740
1867
  };
1741
1868
  }
1742
1869
  }
1870
+ // strips table references from an OperationNode
1871
+ stripTableReferences(node, modelName) {
1872
+ return new TableReferenceStripper().strip(node, modelName);
1873
+ }
1874
+ };
1875
+ var TableReferenceStripper = class TableReferenceStripper2 extends import_kysely3.OperationNodeTransformer {
1876
+ static {
1877
+ __name(this, "TableReferenceStripper");
1878
+ }
1879
+ tableName = "";
1880
+ strip(node, tableName) {
1881
+ this.tableName = tableName;
1882
+ return this.transformNode(node);
1883
+ }
1884
+ transformReference(node) {
1885
+ if (import_kysely3.ColumnNode.is(node.column) && node.table?.table.identifier.name === this.tableName) {
1886
+ return import_kysely3.ReferenceNode.create(this.transformNode(node.column));
1887
+ }
1888
+ return super.transformReference(node);
1889
+ }
1743
1890
  };
1744
1891
 
1745
1892
  // src/functions.ts
@@ -1788,7 +1935,7 @@ var check = /* @__PURE__ */ __name((eb, args, { client, model, modelAlias, opera
1788
1935
  const policyHandler = new PolicyHandler(client);
1789
1936
  const op = arg2Node ? arg2Node.value : operation;
1790
1937
  const policyCondition = policyHandler.buildPolicyFilter(relationModel, void 0, op);
1791
- const result = eb.selectFrom(relationModel).where(joinCondition).select(new import_kysely4.ExpressionWrapper(policyCondition).as("$condition"));
1938
+ const result = eb.selectFrom(eb.selectFrom(relationModel).where(joinCondition).select(new import_kysely4.ExpressionWrapper(policyCondition).as("$condition")).as("$sub")).selectAll();
1792
1939
  return result;
1793
1940
  }, "check");
1794
1941