@zenstackhq/plugin-policy 3.3.0-beta.2 → 3.3.0-beta.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -65,14 +65,14 @@ var import_ts_pattern2 = require("ts-pattern");
65
65
 
66
66
  // src/expression-evaluator.ts
67
67
  var import_common_helpers = require("@zenstackhq/common-helpers");
68
- var import_ts_pattern = require("ts-pattern");
69
68
  var import_schema = require("@zenstackhq/orm/schema");
69
+ var import_ts_pattern = require("ts-pattern");
70
70
  var ExpressionEvaluator = class {
71
71
  static {
72
72
  __name(this, "ExpressionEvaluator");
73
73
  }
74
74
  evaluate(expression, context) {
75
- const result = (0, import_ts_pattern.match)(expression).when(import_schema.ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(import_schema.ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(import_schema.ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(import_schema.ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(import_schema.ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(import_schema.ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(import_schema.ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(import_schema.ExpressionUtils.isThis, () => context.thisValue).when(import_schema.ExpressionUtils.isNull, () => null).exhaustive();
75
+ const result = (0, import_ts_pattern.match)(expression).when(import_schema.ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(import_schema.ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(import_schema.ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(import_schema.ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(import_schema.ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(import_schema.ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(import_schema.ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(import_schema.ExpressionUtils.isBinding, (expr2) => this.evaluateBinding(expr2, context)).when(import_schema.ExpressionUtils.isThis, () => context.thisValue).when(import_schema.ExpressionUtils.isNull, () => null).exhaustive();
76
76
  return result ?? null;
77
77
  }
78
78
  evaluateCall(expr2, context) {
@@ -96,6 +96,9 @@ var ExpressionEvaluator = class {
96
96
  return expr2.value;
97
97
  }
98
98
  evaluateField(expr2, context) {
99
+ if (context.bindingScope && expr2.field in context.bindingScope) {
100
+ return context.bindingScope[expr2.field];
101
+ }
99
102
  return context.thisValue?.[expr2.field];
100
103
  }
101
104
  evaluateArray(expr2, context) {
@@ -129,15 +132,33 @@ var ExpressionEvaluator = class {
129
132
  (0, import_common_helpers.invariant)(Array.isArray(left), "expected array");
130
133
  return (0, import_ts_pattern.match)(op).with("?", () => left.some((item) => this.evaluate(expr2.right, {
131
134
  ...context,
132
- thisValue: item
135
+ thisValue: item,
136
+ bindingScope: expr2.binding ? {
137
+ ...context.bindingScope ?? {},
138
+ [expr2.binding]: item
139
+ } : context.bindingScope
133
140
  }))).with("!", () => left.every((item) => this.evaluate(expr2.right, {
134
141
  ...context,
135
- thisValue: item
142
+ thisValue: item,
143
+ bindingScope: expr2.binding ? {
144
+ ...context.bindingScope ?? {},
145
+ [expr2.binding]: item
146
+ } : context.bindingScope
136
147
  }))).with("^", () => !left.some((item) => this.evaluate(expr2.right, {
137
148
  ...context,
138
- thisValue: item
149
+ thisValue: item,
150
+ bindingScope: expr2.binding ? {
151
+ ...context.bindingScope ?? {},
152
+ [expr2.binding]: item
153
+ } : context.bindingScope
139
154
  }))).exhaustive();
140
155
  }
156
+ evaluateBinding(expr2, context) {
157
+ if (!context.bindingScope || !(expr2.name in context.bindingScope)) {
158
+ throw new Error(`Unresolved binding: ${expr2.name}`);
159
+ }
160
+ return context.bindingScope[expr2.name];
161
+ }
141
162
  };
142
163
 
143
164
  // src/types.ts
@@ -152,11 +173,11 @@ var import_orm2 = require("@zenstackhq/orm");
152
173
  var import_schema2 = require("@zenstackhq/orm/schema");
153
174
  var import_kysely = require("kysely");
154
175
  function trueNode(dialect) {
155
- return import_kysely.ValueNode.createImmediate(dialect.transformPrimitive(true, "Boolean", false));
176
+ return import_kysely.ValueNode.createImmediate(dialect.transformInput(true, "Boolean", false));
156
177
  }
157
178
  __name(trueNode, "trueNode");
158
179
  function falseNode(dialect) {
159
- return import_kysely.ValueNode.createImmediate(dialect.transformPrimitive(false, "Boolean", false));
180
+ return import_kysely.ValueNode.createImmediate(dialect.transformInput(false, "Boolean", false));
160
181
  }
161
182
  __name(falseNode, "falseNode");
162
183
  function isTrueNode(node) {
@@ -290,6 +311,7 @@ var ExpressionTransformer = class {
290
311
  }
291
312
  client;
292
313
  dialect;
314
+ eb = (0, import_kysely2.expressionBuilder)();
293
315
  constructor(client) {
294
316
  this.client = client;
295
317
  this.dialect = (0, import_orm3.getCrudDialect)(this.schema, this.clientOptions);
@@ -314,13 +336,15 @@ var ExpressionTransformer = class {
314
336
  if (!handler) {
315
337
  throw new Error(`Unsupported expression kind: ${expression.kind}`);
316
338
  }
317
- return handler.value.call(this, expression, context);
339
+ const result = handler.value.call(this, expression, context);
340
+ (0, import_common_helpers2.invariant)("kind" in result, `expression handler must return an OperationNode: transforming ${expression.kind}`);
341
+ return result;
318
342
  }
319
343
  _literal(expr2) {
320
344
  return this.transformValue(expr2.value, typeof expr2.value === "string" ? "String" : typeof expr2.value === "boolean" ? "Boolean" : "Int");
321
345
  }
322
346
  _array(expr2, context) {
323
- return import_kysely2.ValueListNode.create(expr2.items.map((item) => this.transform(item, context)));
347
+ return this.dialect.buildArrayValue(expr2.items.map((item) => new import_kysely2.ExpressionWrapper(this.transform(item, context))), expr2.type).toOperationNode();
324
348
  }
325
349
  _field(expr2, context) {
326
350
  if (context.contextValue) {
@@ -450,7 +474,8 @@ var ExpressionTransformer = class {
450
474
  const evaluator = new ExpressionEvaluator();
451
475
  const receiver = evaluator.evaluate(expr2.left, {
452
476
  thisValue: context.contextValue,
453
- auth: this.auth
477
+ auth: this.auth,
478
+ bindingScope: this.getEvaluationBindingScope(context.bindingScope)
454
479
  });
455
480
  const baseType = this.isAuthMember(expr2.left) ? this.authType : context.modelOrType;
456
481
  const memberType = this.getMemberType(baseType, expr2.left);
@@ -466,18 +491,31 @@ var ExpressionTransformer = class {
466
491
  (0, import_common_helpers2.invariant)(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr2.left)}`);
467
492
  newContextModel = fieldDef.type;
468
493
  } else {
469
- (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isMember(expr2.left) && import_schema3.ExpressionUtils.isField(expr2.left.receiver), "left operand must be member access with field receiver");
470
- const fieldDef2 = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.left.receiver.field);
471
- newContextModel = fieldDef2.type;
494
+ (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isMember(expr2.left) && (import_schema3.ExpressionUtils.isField(expr2.left.receiver) || import_schema3.ExpressionUtils.isBinding(expr2.left.receiver)), "left operand must be member access with field receiver");
495
+ if (import_schema3.ExpressionUtils.isField(expr2.left.receiver)) {
496
+ const fieldDef2 = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.left.receiver.field);
497
+ newContextModel = fieldDef2.type;
498
+ } else {
499
+ const binding = this.requireBindingScope(expr2.left.receiver, context);
500
+ newContextModel = binding.type;
501
+ }
472
502
  for (const member of expr2.left.members) {
473
503
  const memberDef = import_orm3.QueryUtils.requireField(this.schema, newContextModel, member);
474
504
  newContextModel = memberDef.type;
475
505
  }
476
506
  }
507
+ const bindingScope = expr2.binding ? {
508
+ ...context.bindingScope ?? {},
509
+ [expr2.binding]: {
510
+ type: newContextModel,
511
+ alias: newContextModel
512
+ }
513
+ } : context.bindingScope;
477
514
  let predicateFilter = this.transform(expr2.right, {
478
515
  ...context,
479
516
  modelOrType: newContextModel,
480
- alias: void 0
517
+ alias: void 0,
518
+ bindingScope
481
519
  });
482
520
  if (expr2.op === "!") {
483
521
  predicateFilter = logicalNot(this.dialect, predicateFilter);
@@ -504,18 +542,30 @@ var ExpressionTransformer = class {
504
542
  if (!visitor.find(expr2.right)) {
505
543
  const value = new ExpressionEvaluator().evaluate(expr2, {
506
544
  auth: this.auth,
507
- thisValue: context.contextValue
545
+ thisValue: context.contextValue,
546
+ bindingScope: this.getEvaluationBindingScope(context.bindingScope)
508
547
  });
509
548
  return this.transformValue(value, "Boolean");
510
549
  } else {
511
550
  (0, import_common_helpers2.invariant)(Array.isArray(receiver), "array value is expected");
512
- const components = receiver.map((item) => this.transform(expr2.right, {
513
- operation: context.operation,
514
- thisType: context.thisType,
515
- thisAlias: context.thisAlias,
516
- modelOrType: context.modelOrType,
517
- contextValue: item
518
- }));
551
+ const components = receiver.map((item) => {
552
+ const bindingScope = expr2.binding ? {
553
+ ...context.bindingScope ?? {},
554
+ [expr2.binding]: {
555
+ type: context.modelOrType,
556
+ alias: context.thisAlias ?? context.modelOrType,
557
+ value: item
558
+ }
559
+ } : context.bindingScope;
560
+ return this.transform(expr2.right, {
561
+ operation: context.operation,
562
+ thisType: context.thisType,
563
+ thisAlias: context.thisAlias,
564
+ modelOrType: context.modelOrType,
565
+ contextValue: item,
566
+ bindingScope
567
+ });
568
+ });
519
569
  return (0, import_ts_pattern2.match)(expr2.op).with("?", () => disjunction(this.dialect, components)).with("!", () => conjunction(this.dialect, components)).with("^", () => logicalNot(this.dialect, disjunction(this.dialect, components))).exhaustive();
520
570
  }
521
571
  }
@@ -581,9 +631,11 @@ var ExpressionTransformer = class {
581
631
  return trueNode(this.dialect);
582
632
  } else if (value === false) {
583
633
  return falseNode(this.dialect);
634
+ } else if (Array.isArray(value)) {
635
+ return this.dialect.buildArrayValue(value.map((v) => new import_kysely2.ExpressionWrapper(this.transformValue(v, type))), type).toOperationNode();
584
636
  } else {
585
- const transformed = this.dialect.transformPrimitive(value, type, false) ?? null;
586
- if (!Array.isArray(transformed)) {
637
+ const transformed = this.dialect.transformInput(value, type, false) ?? null;
638
+ if (typeof transformed !== "string") {
587
639
  return import_kysely2.ValueNode.createImmediate(transformed);
588
640
  } else {
589
641
  return import_kysely2.ValueNode.create(transformed);
@@ -607,8 +659,7 @@ var ExpressionTransformer = class {
607
659
  if (!func) {
608
660
  throw createUnsupportedError(`Function not implemented: ${expr2.function}`);
609
661
  }
610
- const eb = (0, import_kysely2.expressionBuilder)();
611
- return func(eb, (expr2.args ?? []).map((arg) => this.transformCallArg(eb, arg, context)), {
662
+ return func(this.eb, (expr2.args ?? []).map((arg) => this.transformCallArg(arg, context)), {
612
663
  client: this.client,
613
664
  dialect: this.dialect,
614
665
  model: context.modelOrType,
@@ -628,23 +679,20 @@ var ExpressionTransformer = class {
628
679
  }
629
680
  return func;
630
681
  }
631
- transformCallArg(eb, arg, context) {
632
- if (import_schema3.ExpressionUtils.isLiteral(arg)) {
633
- return eb.val(arg.value);
634
- }
682
+ transformCallArg(arg, context) {
635
683
  if (import_schema3.ExpressionUtils.isField(arg)) {
636
- return eb.ref(arg.field);
637
- }
638
- if (import_schema3.ExpressionUtils.isCall(arg)) {
639
- return this.transformCall(arg, context);
640
- }
641
- if (this.isAuthMember(arg)) {
642
- const valNode = this.valueMemberAccess(this.auth, arg, this.authType);
643
- return valNode ? eb.val(valNode.value) : eb.val(null);
684
+ return this.eb.ref(arg.field);
685
+ } else {
686
+ return new import_kysely2.ExpressionWrapper(this.transform(arg, context));
644
687
  }
645
- throw createUnsupportedError(`Unsupported argument expression: ${arg.kind}`);
646
688
  }
647
689
  _member(expr2, context) {
690
+ if (import_schema3.ExpressionUtils.isBinding(expr2.receiver)) {
691
+ const scope = this.requireBindingScope(expr2.receiver, context);
692
+ if (scope.value !== void 0) {
693
+ return this.valueMemberAccess(scope.value, expr2, scope.type);
694
+ }
695
+ }
648
696
  if (this.isAuthCall(expr2.receiver)) {
649
697
  return this.valueMemberAccess(this.auth, expr2, this.authType);
650
698
  }
@@ -653,9 +701,10 @@ var ExpressionTransformer = class {
653
701
  (0, import_common_helpers2.invariant)(expr2.members.length === 1, "before() can only be followed by a scalar field access");
654
702
  return import_kysely2.ReferenceNode.create(import_kysely2.ColumnNode.create(expr2.members[0]), import_kysely2.TableNode.create("$before"));
655
703
  }
656
- (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isField(expr2.receiver) || import_schema3.ExpressionUtils.isThis(expr2.receiver), 'expect receiver to be field expression or "this"');
704
+ (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isField(expr2.receiver) || import_schema3.ExpressionUtils.isThis(expr2.receiver) || import_schema3.ExpressionUtils.isBinding(expr2.receiver), 'expect receiver to be field expression, collection predicate binding, or "this"');
657
705
  let members = expr2.members;
658
706
  let receiver;
707
+ let startType;
659
708
  const { memberFilter, memberSelect, ...restContext } = context;
660
709
  if (import_schema3.ExpressionUtils.isThis(expr2.receiver)) {
661
710
  if (expr2.members.length === 1) {
@@ -670,17 +719,40 @@ var ExpressionTransformer = class {
670
719
  const firstMemberFieldDef = import_orm3.QueryUtils.requireField(this.schema, context.thisType, expr2.members[0]);
671
720
  receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, restContext);
672
721
  members = expr2.members.slice(1);
722
+ startType = firstMemberFieldDef.type;
723
+ }
724
+ } else if (import_schema3.ExpressionUtils.isBinding(expr2.receiver)) {
725
+ if (expr2.members.length === 1) {
726
+ const bindingScope = this.requireBindingScope(expr2.receiver, context);
727
+ return this._field(import_schema3.ExpressionUtils.field(expr2.members[0]), {
728
+ ...context,
729
+ modelOrType: bindingScope.type,
730
+ alias: bindingScope.alias,
731
+ thisType: context.thisType,
732
+ contextValue: void 0
733
+ });
734
+ } else {
735
+ const bindingScope = this.requireBindingScope(expr2.receiver, context);
736
+ const firstMemberFieldDef = import_orm3.QueryUtils.requireField(this.schema, bindingScope.type, expr2.members[0]);
737
+ receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, {
738
+ ...restContext,
739
+ modelOrType: bindingScope.type,
740
+ alias: bindingScope.alias
741
+ });
742
+ members = expr2.members.slice(1);
743
+ startType = firstMemberFieldDef.type;
673
744
  }
674
745
  } else {
675
746
  receiver = this.transform(expr2.receiver, restContext);
676
747
  }
677
748
  (0, import_common_helpers2.invariant)(import_kysely2.SelectQueryNode.is(receiver), "expected receiver to be select query");
678
- let startType;
679
- if (import_schema3.ExpressionUtils.isField(expr2.receiver)) {
680
- const receiverField = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.receiver.field);
681
- startType = receiverField.type;
682
- } else {
683
- startType = context.thisType;
749
+ if (startType === void 0) {
750
+ if (import_schema3.ExpressionUtils.isField(expr2.receiver)) {
751
+ const receiverField = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.receiver.field);
752
+ startType = receiverField.type;
753
+ } else {
754
+ startType = context.thisType;
755
+ }
684
756
  }
685
757
  const memberFields = [];
686
758
  let currType = startType;
@@ -731,6 +803,11 @@ var ExpressionTransformer = class {
731
803
  ]
732
804
  };
733
805
  }
806
+ requireBindingScope(expr2, context) {
807
+ const binding = context.bindingScope?.[expr2.name];
808
+ (0, import_common_helpers2.invariant)(binding, `binding not found: ${expr2.name}`);
809
+ return binding;
810
+ }
734
811
  valueMemberAccess(receiver, expr2, receiverType) {
735
812
  if (!receiver) {
736
813
  return import_kysely2.ValueNode.createImmediate(null);
@@ -796,6 +873,19 @@ var ExpressionTransformer = class {
796
873
  }
797
874
  return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel);
798
875
  }
876
+ // convert transformer's binding scope to equivalent expression evaluator binding scope
877
+ getEvaluationBindingScope(scope) {
878
+ if (!scope) {
879
+ return void 0;
880
+ }
881
+ const result = {};
882
+ for (const [key, value] of Object.entries(scope)) {
883
+ if (value.value !== void 0) {
884
+ result[key] = value.value;
885
+ }
886
+ }
887
+ return Object.keys(result).length > 0 ? result : void 0;
888
+ }
799
889
  buildDelegateBaseFieldSelect(model, modelAlias, field, baseModel) {
800
890
  const idFields = import_orm3.QueryUtils.requireIdFields(this.client.$schema, model);
801
891
  return {
@@ -920,6 +1010,7 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
920
1010
  }
921
1011
  client;
922
1012
  dialect;
1013
+ eb = (0, import_kysely3.expressionBuilder)();
923
1014
  constructor(client) {
924
1015
  super(), this.client = client;
925
1016
  this.dialect = (0, import_orm4.getCrudDialect)(this.client.$schema, this.client.$options);
@@ -943,52 +1034,21 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
943
1034
  if (import_kysely3.UpdateQueryNode.is(node)) {
944
1035
  await this.preUpdateCheck(mutationModel, node, proceed);
945
1036
  }
946
- const hasPostUpdatePolicies = import_kysely3.UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
1037
+ const needsPostUpdateCheck = import_kysely3.UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
947
1038
  let beforeUpdateInfo;
948
- if (hasPostUpdatePolicies) {
949
- beforeUpdateInfo = await this.loadBeforeUpdateEntities(mutationModel, node.where, proceed);
1039
+ if (needsPostUpdateCheck) {
1040
+ beforeUpdateInfo = await this.loadBeforeUpdateEntities(
1041
+ mutationModel,
1042
+ node.where,
1043
+ proceed,
1044
+ // force load pre-update entities if dialect doesn't support returning,
1045
+ // so we can rely on pre-update ids to read back updated entities
1046
+ !this.dialect.supportsReturning
1047
+ );
950
1048
  }
951
1049
  const result = await proceed(this.transformNode(node));
952
- if (hasPostUpdatePolicies && result.rows.length > 0) {
953
- if (beforeUpdateInfo) {
954
- (0, import_common_helpers3.invariant)(beforeUpdateInfo.rows.length === result.rows.length);
955
- const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, mutationModel);
956
- for (const postRow of result.rows) {
957
- const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
958
- if (!beforeRow) {
959
- throw createRejectedByPolicyError(mutationModel, import_orm4.RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match by id. If you have post-update policies on a model, updating id fields is not supported.");
960
- }
961
- }
962
- }
963
- const idConditions = this.buildIdConditions(mutationModel, result.rows);
964
- const postUpdateFilter = this.buildPolicyFilter(mutationModel, void 0, "post-update");
965
- const eb = (0, import_kysely3.expressionBuilder)();
966
- const beforeUpdateTable = beforeUpdateInfo ? {
967
- kind: "SelectQueryNode",
968
- from: import_kysely3.FromNode.create([
969
- import_kysely3.ParensNode.create(import_kysely3.ValuesNode.create(beforeUpdateInfo.rows.map((r) => import_kysely3.PrimitiveValueListNode.create(beforeUpdateInfo.fields.map((f) => r[f])))))
970
- ]),
971
- selections: beforeUpdateInfo.fields.map((name, index) => {
972
- const def = import_orm4.QueryUtils.requireField(this.client.$schema, mutationModel, name);
973
- const castedColumnRef = import_kysely3.sql`CAST(${eb.ref(`column${index + 1}`)} as ${import_kysely3.sql.raw(this.dialect.getFieldSqlType(def))})`.as(name);
974
- return import_kysely3.SelectionNode.create(castedColumnRef.toOperationNode());
975
- })
976
- } : void 0;
977
- const postUpdateQuery = eb.selectFrom(mutationModel).select(() => [
978
- eb(eb.fn("COUNT", [
979
- eb.lit(1)
980
- ]), "=", result.rows.length).as("$condition")
981
- ]).where(() => new import_kysely3.ExpressionWrapper(conjunction(this.dialect, [
982
- idConditions,
983
- postUpdateFilter
984
- ]))).$if(!!beforeUpdateInfo, (qb) => qb.leftJoin(() => new import_kysely3.ExpressionWrapper(beforeUpdateTable).as("$before"), (join) => {
985
- const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, mutationModel);
986
- return idFields.reduce((acc, f) => acc.onRef(`${mutationModel}.${f}`, "=", `$before.${f}`), join);
987
- }));
988
- const postUpdateResult = await proceed(postUpdateQuery.toOperationNode());
989
- if (!postUpdateResult.rows[0]?.$condition) {
990
- throw createRejectedByPolicyError(mutationModel, import_orm4.RejectedByPolicyReason.NO_ACCESS, "some or all updated rows failed to pass post-update policy check");
991
- }
1050
+ if ((result.numAffectedRows ?? 0) > 0 && needsPostUpdateCheck) {
1051
+ await this.postUpdateCheck(mutationModel, beforeUpdateInfo, result, proceed);
992
1052
  }
993
1053
  if (!node.returning || this.onlyReturningId(node)) {
994
1054
  return this.postProcessMutationResult(result, node);
@@ -1027,12 +1087,69 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1027
1087
  modelLevelFilter,
1028
1088
  node.where?.where ?? trueNode(this.dialect)
1029
1089
  ]);
1030
- const preUpdateCheckQuery = (0, import_kysely3.expressionBuilder)().selectFrom(mutationModel).select((eb) => eb.fn.coalesce(eb.fn.sum(eb.cast(new import_kysely3.ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter)), "integer")), eb.lit(0)).as("$filteredCount")).where(() => new import_kysely3.ExpressionWrapper(updateFilter));
1090
+ const preUpdateCheckQuery = this.eb.selectFrom(mutationModel).select((eb) => eb.fn.coalesce(eb.fn.sum(this.dialect.castInt(new import_kysely3.ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter)))), eb.lit(0)).as("$filteredCount")).where(() => new import_kysely3.ExpressionWrapper(updateFilter));
1031
1091
  const preUpdateResult = await proceed(preUpdateCheckQuery.toOperationNode());
1032
1092
  if (preUpdateResult.rows[0].$filteredCount > 0) {
1033
1093
  throw createRejectedByPolicyError(mutationModel, import_orm4.RejectedByPolicyReason.NO_ACCESS, "some rows cannot be updated due to field policies");
1034
1094
  }
1035
1095
  }
1096
+ async postUpdateCheck(model, beforeUpdateInfo, updateResult, proceed) {
1097
+ let postUpdateRows;
1098
+ if (this.dialect.supportsReturning) {
1099
+ postUpdateRows = updateResult.rows;
1100
+ } else {
1101
+ (0, import_common_helpers3.invariant)(beforeUpdateInfo, "beforeUpdateInfo must be defined for dialects not supporting returning");
1102
+ const idConditions2 = this.buildIdConditions(model, beforeUpdateInfo.rows);
1103
+ const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, model);
1104
+ const postUpdateQuery2 = {
1105
+ kind: "SelectQueryNode",
1106
+ from: import_kysely3.FromNode.create([
1107
+ import_kysely3.TableNode.create(model)
1108
+ ]),
1109
+ where: import_kysely3.WhereNode.create(idConditions2),
1110
+ selections: idFields.map((field) => import_kysely3.SelectionNode.create(import_kysely3.ColumnNode.create(field)))
1111
+ };
1112
+ const postUpdateQueryResult = await proceed(postUpdateQuery2);
1113
+ postUpdateRows = postUpdateQueryResult.rows;
1114
+ }
1115
+ if (beforeUpdateInfo) {
1116
+ if (beforeUpdateInfo.rows.length !== postUpdateRows.length) {
1117
+ throw createRejectedByPolicyError(model, import_orm4.RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.");
1118
+ }
1119
+ const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, model);
1120
+ for (const postRow of postUpdateRows) {
1121
+ const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
1122
+ if (!beforeRow) {
1123
+ throw createRejectedByPolicyError(model, import_orm4.RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.");
1124
+ }
1125
+ }
1126
+ }
1127
+ const idConditions = this.buildIdConditions(model, postUpdateRows);
1128
+ const postUpdateFilter = this.buildPolicyFilter(model, void 0, "post-update");
1129
+ const eb = (0, import_kysely3.expressionBuilder)();
1130
+ const needsBeforeUpdateJoin = !!beforeUpdateInfo?.fields;
1131
+ let beforeUpdateTable = void 0;
1132
+ if (needsBeforeUpdateJoin) {
1133
+ const fieldDefs = beforeUpdateInfo.fields.map((name) => import_orm4.QueryUtils.requireField(this.client.$schema, model, name));
1134
+ const rows = beforeUpdateInfo.rows.map((r) => beforeUpdateInfo.fields.map((f) => r[f]));
1135
+ beforeUpdateTable = this.dialect.buildValuesTableSelect(fieldDefs, rows).toOperationNode();
1136
+ }
1137
+ const postUpdateQuery = eb.selectFrom(model).select(() => [
1138
+ eb(eb.fn("COUNT", [
1139
+ eb.lit(1)
1140
+ ]), "=", Number(updateResult.numAffectedRows ?? 0)).as("$condition")
1141
+ ]).where(() => new import_kysely3.ExpressionWrapper(conjunction(this.dialect, [
1142
+ idConditions,
1143
+ postUpdateFilter
1144
+ ]))).$if(needsBeforeUpdateJoin, (qb) => qb.leftJoin(() => new import_kysely3.ExpressionWrapper(beforeUpdateTable).as("$before"), (join) => {
1145
+ const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, model);
1146
+ return idFields.reduce((acc, f) => acc.onRef(`${model}.${f}`, "=", `$before.${f}`), join);
1147
+ }));
1148
+ const postUpdateResult = await proceed(postUpdateQuery.toOperationNode());
1149
+ if (!postUpdateResult.rows[0]?.$condition) {
1150
+ throw createRejectedByPolicyError(model, import_orm4.RejectedByPolicyReason.NO_ACCESS, "some or all updated rows failed to pass post-update policy check");
1151
+ }
1152
+ }
1036
1153
  // #endregion
1037
1154
  // #region Transformations
1038
1155
  transformSelectQuery(node) {
@@ -1100,6 +1217,7 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1100
1217
  };
1101
1218
  }
1102
1219
  transformInsertQuery(node) {
1220
+ let processedNode = node;
1103
1221
  let onConflict = node.onConflict;
1104
1222
  if (onConflict?.updates) {
1105
1223
  const { mutationModel, alias } = this.getMutationModel(node);
@@ -1118,11 +1236,36 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1118
1236
  updateWhere: import_kysely3.WhereNode.create(filter)
1119
1237
  };
1120
1238
  }
1239
+ processedNode = {
1240
+ ...node,
1241
+ onConflict
1242
+ };
1243
+ }
1244
+ let onDuplicateKey = node.onDuplicateKey;
1245
+ if (onDuplicateKey?.updates) {
1246
+ const { mutationModel } = this.getMutationModel(node);
1247
+ const filterWithTableRef = this.buildPolicyFilter(mutationModel, void 0, "update");
1248
+ const filter = this.stripTableReferences(filterWithTableRef, mutationModel);
1249
+ const wrappedUpdates = onDuplicateKey.updates.map((update) => {
1250
+ const columnName = import_kysely3.ColumnNode.is(update.column) ? update.column.column.name : void 0;
1251
+ if (!columnName) {
1252
+ return update;
1253
+ }
1254
+ const wrappedValue = import_kysely3.sql`IF(${new import_kysely3.ExpressionWrapper(filter)}, ${new import_kysely3.ExpressionWrapper(update.value)}, ${import_kysely3.sql.ref(columnName)})`.toOperationNode();
1255
+ return {
1256
+ ...update,
1257
+ value: wrappedValue
1258
+ };
1259
+ });
1260
+ onDuplicateKey = {
1261
+ ...onDuplicateKey,
1262
+ updates: wrappedUpdates
1263
+ };
1264
+ processedNode = {
1265
+ ...processedNode,
1266
+ onDuplicateKey
1267
+ };
1121
1268
  }
1122
- const processedNode = onConflict ? {
1123
- ...node,
1124
- onConflict
1125
- } : node;
1126
1269
  const result = super.transformInsertQuery(processedNode);
1127
1270
  let returning = result.returning;
1128
1271
  if (returning) {
@@ -1150,7 +1293,7 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1150
1293
  }
1151
1294
  }
1152
1295
  let returning = result.returning;
1153
- if (returning || this.hasPostUpdatePolicies(mutationModel)) {
1296
+ if (this.dialect.supportsReturning && (returning || this.hasPostUpdatePolicies(mutationModel))) {
1154
1297
  const idFields = import_orm4.QueryUtils.requireIdFields(this.client.$schema, mutationModel);
1155
1298
  returning = import_kysely3.ReturningNode.create(idFields.map((f) => import_kysely3.SelectionNode.create(import_kysely3.ColumnNode.create(f))));
1156
1299
  }
@@ -1187,9 +1330,9 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1187
1330
  }
1188
1331
  // #endregion
1189
1332
  // #region post-update
1190
- async loadBeforeUpdateEntities(model, where, proceed) {
1333
+ async loadBeforeUpdateEntities(model, where, proceed, forceLoad = false) {
1191
1334
  const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model);
1192
- if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) {
1335
+ if (!forceLoad && (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0)) {
1193
1336
  return void 0;
1194
1337
  }
1195
1338
  const policyFilter = this.buildPolicyFilter(model, model, "update");
@@ -1197,15 +1340,14 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1197
1340
  where.where,
1198
1341
  policyFilter
1199
1342
  ]) : policyFilter;
1343
+ const selections = beforeUpdateAccessFields ?? import_orm4.QueryUtils.requireIdFields(this.client.$schema, model);
1200
1344
  const query = {
1201
1345
  kind: "SelectQueryNode",
1202
1346
  from: import_kysely3.FromNode.create([
1203
1347
  import_kysely3.TableNode.create(model)
1204
1348
  ]),
1205
1349
  where: import_kysely3.WhereNode.create(combinedFilter),
1206
- selections: [
1207
- ...beforeUpdateAccessFields.map((f) => import_kysely3.SelectionNode.create(import_kysely3.ColumnNode.create(f)))
1208
- ]
1350
+ selections: selections.map((f) => import_kysely3.SelectionNode.create(import_kysely3.ColumnNode.create(f)))
1209
1351
  };
1210
1352
  const result = await proceed(query);
1211
1353
  return {
@@ -1385,43 +1527,24 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1385
1527
  }
1386
1528
  }
1387
1529
  async enforcePreCreatePolicyForOne(model, fields, values, proceed) {
1388
- const allFields = Object.entries(import_orm4.QueryUtils.requireModel(this.client.$schema, model).fields).filter(([, def]) => !def.relation);
1530
+ const allFields = import_orm4.QueryUtils.getModelFields(this.client.$schema, model, {
1531
+ inherited: true
1532
+ });
1389
1533
  const allValues = [];
1390
- for (const [name, _def] of allFields) {
1391
- const index = fields.indexOf(name);
1534
+ for (const def of allFields) {
1535
+ const index = fields.indexOf(def.name);
1392
1536
  if (index >= 0) {
1393
- allValues.push(values[index]);
1537
+ allValues.push(new import_kysely3.ExpressionWrapper(values[index]));
1394
1538
  } else {
1395
- allValues.push(import_kysely3.ValueNode.createImmediate(null));
1539
+ allValues.push(this.eb.lit(null));
1396
1540
  }
1397
1541
  }
1398
- const eb = (0, import_kysely3.expressionBuilder)();
1399
- const constTable = {
1400
- kind: "SelectQueryNode",
1401
- from: import_kysely3.FromNode.create([
1402
- import_kysely3.AliasNode.create(import_kysely3.ParensNode.create(import_kysely3.ValuesNode.create([
1403
- import_kysely3.ValueListNode.create(allValues)
1404
- ])), import_kysely3.IdentifierNode.create("$t"))
1405
- ]),
1406
- selections: allFields.map(([name, def], index) => {
1407
- const castedColumnRef = import_kysely3.sql`CAST(${eb.ref(`column${index + 1}`)} as ${import_kysely3.sql.raw(this.dialect.getFieldSqlType(def))})`.as(name);
1408
- return import_kysely3.SelectionNode.create(castedColumnRef.toOperationNode());
1409
- })
1410
- };
1542
+ const valuesTable = this.dialect.buildValuesTableSelect(allFields, [
1543
+ allValues
1544
+ ]);
1411
1545
  const filter = this.buildPolicyFilter(model, void 0, "create");
1412
- const preCreateCheck = {
1413
- kind: "SelectQueryNode",
1414
- from: import_kysely3.FromNode.create([
1415
- import_kysely3.AliasNode.create(constTable, import_kysely3.IdentifierNode.create(model))
1416
- ]),
1417
- selections: [
1418
- import_kysely3.SelectionNode.create(import_kysely3.AliasNode.create(import_kysely3.BinaryOperationNode.create(import_kysely3.FunctionNode.create("COUNT", [
1419
- import_kysely3.ValueNode.createImmediate(1)
1420
- ]), import_kysely3.OperatorNode.create(">"), import_kysely3.ValueNode.createImmediate(0)), import_kysely3.IdentifierNode.create("$condition")))
1421
- ],
1422
- where: import_kysely3.WhereNode.create(filter)
1423
- };
1424
- const result = await proceed(preCreateCheck);
1546
+ const preCreateCheck = this.eb.selectFrom(valuesTable.as(model)).select(this.eb(this.eb.fn.count(this.eb.lit(1)), ">", 0).as("$condition")).where(() => new import_kysely3.ExpressionWrapper(filter));
1547
+ const result = await proceed(preCreateCheck.toOperationNode());
1425
1548
  if (!result.rows[0]?.$condition) {
1426
1549
  throw createRejectedByPolicyError(model, import_orm4.RejectedByPolicyReason.NO_ACCESS);
1427
1550
  }
@@ -1446,18 +1569,19 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1446
1569
  const fieldDef = import_orm4.QueryUtils.requireField(this.client.$schema, model, fields[i]);
1447
1570
  (0, import_common_helpers3.invariant)(item.kind === "ValueNode", "expecting a ValueNode");
1448
1571
  result.push({
1449
- node: import_kysely3.ValueNode.create(this.dialect.transformPrimitive(item.value, fieldDef.type, !!fieldDef.array)),
1572
+ node: import_kysely3.ValueNode.create(this.dialect.transformInput(item.value, fieldDef.type, !!fieldDef.array)),
1450
1573
  raw: item.value
1451
1574
  });
1452
1575
  } else {
1453
1576
  let value = item;
1454
1577
  if (!isImplicitManyToManyJoinTable) {
1455
1578
  const fieldDef = import_orm4.QueryUtils.requireField(this.client.$schema, model, fields[i]);
1456
- value = this.dialect.transformPrimitive(item, fieldDef.type, !!fieldDef.array);
1579
+ value = this.dialect.transformInput(item, fieldDef.type, !!fieldDef.array);
1457
1580
  }
1458
1581
  if (Array.isArray(value)) {
1582
+ const fieldDef = import_orm4.QueryUtils.requireField(this.client.$schema, model, fields[i]);
1459
1583
  result.push({
1460
- node: import_kysely3.RawNode.createWithSql(this.dialect.buildArrayLiteralSQL(value)),
1584
+ node: this.dialect.buildArrayValue(value, fieldDef.type).toOperationNode(),
1461
1585
  raw: value
1462
1586
  });
1463
1587
  } else {
@@ -1705,11 +1829,10 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1705
1829
  return void 0;
1706
1830
  }
1707
1831
  const checkForOperation = operation === "read" ? "read" : "update";
1708
- const eb = (0, import_kysely3.expressionBuilder)();
1709
1832
  const joinTable = alias ?? tableName;
1710
- const aQuery = eb.selectFrom(m2m.firstModel).whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, "=", `${joinTable}.A`).select(() => new import_kysely3.ExpressionWrapper(this.buildPolicyFilter(m2m.firstModel, void 0, checkForOperation)).as("$conditionA"));
1711
- const bQuery = eb.selectFrom(m2m.secondModel).whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, "=", `${joinTable}.B`).select(() => new import_kysely3.ExpressionWrapper(this.buildPolicyFilter(m2m.secondModel, void 0, checkForOperation)).as("$conditionB"));
1712
- return eb.and([
1833
+ const aQuery = this.eb.selectFrom(m2m.firstModel).whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, "=", `${joinTable}.A`).select(() => new import_kysely3.ExpressionWrapper(this.buildPolicyFilter(m2m.firstModel, void 0, checkForOperation)).as("$conditionA"));
1834
+ const bQuery = this.eb.selectFrom(m2m.secondModel).whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, "=", `${joinTable}.B`).select(() => new import_kysely3.ExpressionWrapper(this.buildPolicyFilter(m2m.secondModel, void 0, checkForOperation)).as("$conditionB"));
1835
+ return this.eb.and([
1713
1836
  aQuery,
1714
1837
  bQuery
1715
1838
  ]).toOperationNode();
@@ -1740,6 +1863,26 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1740
1863
  };
1741
1864
  }
1742
1865
  }
1866
+ // strips table references from an OperationNode
1867
+ stripTableReferences(node, modelName) {
1868
+ return new TableReferenceStripper().strip(node, modelName);
1869
+ }
1870
+ };
1871
+ var TableReferenceStripper = class TableReferenceStripper2 extends import_kysely3.OperationNodeTransformer {
1872
+ static {
1873
+ __name(this, "TableReferenceStripper");
1874
+ }
1875
+ tableName = "";
1876
+ strip(node, tableName) {
1877
+ this.tableName = tableName;
1878
+ return this.transformNode(node);
1879
+ }
1880
+ transformReference(node) {
1881
+ if (import_kysely3.ColumnNode.is(node.column) && node.table?.table.identifier.name === this.tableName) {
1882
+ return import_kysely3.ReferenceNode.create(this.transformNode(node.column));
1883
+ }
1884
+ return super.transformReference(node);
1885
+ }
1743
1886
  };
1744
1887
 
1745
1888
  // src/functions.ts
@@ -1788,7 +1931,7 @@ var check = /* @__PURE__ */ __name((eb, args, { client, model, modelAlias, opera
1788
1931
  const policyHandler = new PolicyHandler(client);
1789
1932
  const op = arg2Node ? arg2Node.value : operation;
1790
1933
  const policyCondition = policyHandler.buildPolicyFilter(relationModel, void 0, op);
1791
- const result = eb.selectFrom(relationModel).where(joinCondition).select(new import_kysely4.ExpressionWrapper(policyCondition).as("$condition"));
1934
+ const result = eb.selectFrom(eb.selectFrom(relationModel).where(joinCondition).select(new import_kysely4.ExpressionWrapper(policyCondition).as("$condition")).as("$sub")).selectAll();
1792
1935
  return result;
1793
1936
  }, "check");
1794
1937