@zenstackhq/plugin-policy 3.3.0-beta.2 → 3.3.0-beta.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -4,13 +4,13 @@ var __name = (target, value) => __defProp(target, "name", { value, configurable:
4
4
  // src/functions.ts
5
5
  import { invariant as invariant4 } from "@zenstackhq/common-helpers";
6
6
  import { CRUD, QueryUtils as QueryUtils3 } from "@zenstackhq/orm";
7
- import { ExpressionWrapper as ExpressionWrapper2, ValueNode as ValueNode4 } from "kysely";
7
+ import { ExpressionWrapper as ExpressionWrapper3, ValueNode as ValueNode4 } from "kysely";
8
8
 
9
9
  // src/policy-handler.ts
10
10
  import { invariant as invariant3 } from "@zenstackhq/common-helpers";
11
11
  import { getCrudDialect as getCrudDialect2, QueryUtils as QueryUtils2, RejectedByPolicyReason, SchemaUtils as SchemaUtils2 } from "@zenstackhq/orm";
12
12
  import { ExpressionUtils as ExpressionUtils4 } from "@zenstackhq/orm/schema";
13
- import { AliasNode as AliasNode3, BinaryOperationNode as BinaryOperationNode3, ColumnNode as ColumnNode3, DeleteQueryNode, expressionBuilder as expressionBuilder2, ExpressionWrapper, FromNode as FromNode2, FunctionNode as FunctionNode3, IdentifierNode as IdentifierNode2, InsertQueryNode, OperationNodeTransformer, OperatorNode as OperatorNode3, ParensNode as ParensNode2, PrimitiveValueListNode, RawNode, ReferenceNode as ReferenceNode3, ReturningNode, SelectAllNode, SelectionNode as SelectionNode2, SelectQueryNode as SelectQueryNode2, sql, TableNode as TableNode3, UpdateQueryNode, ValueListNode as ValueListNode2, ValueNode as ValueNode3, ValuesNode, WhereNode as WhereNode2 } from "kysely";
13
+ import { AliasNode as AliasNode3, BinaryOperationNode as BinaryOperationNode3, ColumnNode as ColumnNode3, DeleteQueryNode, expressionBuilder as expressionBuilder2, ExpressionWrapper as ExpressionWrapper2, FromNode as FromNode2, IdentifierNode as IdentifierNode2, InsertQueryNode, OperationNodeTransformer, OperatorNode as OperatorNode3, ParensNode as ParensNode2, PrimitiveValueListNode, ReferenceNode as ReferenceNode3, ReturningNode, SelectAllNode, SelectionNode as SelectionNode2, SelectQueryNode as SelectQueryNode2, sql, TableNode as TableNode3, UpdateQueryNode, ValueNode as ValueNode3, ValuesNode, WhereNode as WhereNode2 } from "kysely";
14
14
  import { match as match3 } from "ts-pattern";
15
15
 
16
16
  // src/column-collector.ts
@@ -36,19 +36,19 @@ var ColumnCollector = class extends KyselyUtils.DefaultOperationNodeVisitor {
36
36
  import { invariant as invariant2 } from "@zenstackhq/common-helpers";
37
37
  import { getCrudDialect, QueryUtils, SchemaUtils } from "@zenstackhq/orm";
38
38
  import { ExpressionUtils as ExpressionUtils3 } from "@zenstackhq/orm/schema";
39
- import { AliasNode as AliasNode2, BinaryOperationNode as BinaryOperationNode2, ColumnNode as ColumnNode2, expressionBuilder, FromNode, FunctionNode as FunctionNode2, IdentifierNode, OperatorNode as OperatorNode2, ReferenceNode as ReferenceNode2, SelectionNode, SelectQueryNode, TableNode as TableNode2, ValueListNode, ValueNode as ValueNode2, WhereNode } from "kysely";
39
+ import { AliasNode as AliasNode2, BinaryOperationNode as BinaryOperationNode2, ColumnNode as ColumnNode2, expressionBuilder, ExpressionWrapper, FromNode, FunctionNode as FunctionNode2, IdentifierNode, OperatorNode as OperatorNode2, ReferenceNode as ReferenceNode2, SelectionNode, SelectQueryNode, TableNode as TableNode2, ValueListNode, ValueNode as ValueNode2, WhereNode } from "kysely";
40
40
  import { match as match2 } from "ts-pattern";
41
41
 
42
42
  // src/expression-evaluator.ts
43
43
  import { invariant } from "@zenstackhq/common-helpers";
44
- import { match } from "ts-pattern";
45
44
  import { ExpressionUtils } from "@zenstackhq/orm/schema";
45
+ import { match } from "ts-pattern";
46
46
  var ExpressionEvaluator = class {
47
47
  static {
48
48
  __name(this, "ExpressionEvaluator");
49
49
  }
50
50
  evaluate(expression, context) {
51
- const result = match(expression).when(ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(ExpressionUtils.isThis, () => context.thisValue).when(ExpressionUtils.isNull, () => null).exhaustive();
51
+ const result = match(expression).when(ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(ExpressionUtils.isBinding, (expr2) => this.evaluateBinding(expr2, context)).when(ExpressionUtils.isThis, () => context.thisValue).when(ExpressionUtils.isNull, () => null).exhaustive();
52
52
  return result ?? null;
53
53
  }
54
54
  evaluateCall(expr2, context) {
@@ -72,6 +72,9 @@ var ExpressionEvaluator = class {
72
72
  return expr2.value;
73
73
  }
74
74
  evaluateField(expr2, context) {
75
+ if (context.bindingScope && expr2.field in context.bindingScope) {
76
+ return context.bindingScope[expr2.field];
77
+ }
75
78
  return context.thisValue?.[expr2.field];
76
79
  }
77
80
  evaluateArray(expr2, context) {
@@ -105,15 +108,33 @@ var ExpressionEvaluator = class {
105
108
  invariant(Array.isArray(left), "expected array");
106
109
  return match(op).with("?", () => left.some((item) => this.evaluate(expr2.right, {
107
110
  ...context,
108
- thisValue: item
111
+ thisValue: item,
112
+ bindingScope: expr2.binding ? {
113
+ ...context.bindingScope ?? {},
114
+ [expr2.binding]: item
115
+ } : context.bindingScope
109
116
  }))).with("!", () => left.every((item) => this.evaluate(expr2.right, {
110
117
  ...context,
111
- thisValue: item
118
+ thisValue: item,
119
+ bindingScope: expr2.binding ? {
120
+ ...context.bindingScope ?? {},
121
+ [expr2.binding]: item
122
+ } : context.bindingScope
112
123
  }))).with("^", () => !left.some((item) => this.evaluate(expr2.right, {
113
124
  ...context,
114
- thisValue: item
125
+ thisValue: item,
126
+ bindingScope: expr2.binding ? {
127
+ ...context.bindingScope ?? {},
128
+ [expr2.binding]: item
129
+ } : context.bindingScope
115
130
  }))).exhaustive();
116
131
  }
132
+ evaluateBinding(expr2, context) {
133
+ if (!context.bindingScope || !(expr2.name in context.bindingScope)) {
134
+ throw new Error(`Unresolved binding: ${expr2.name}`);
135
+ }
136
+ return context.bindingScope[expr2.name];
137
+ }
117
138
  };
118
139
 
119
140
  // src/types.ts
@@ -128,11 +149,11 @@ import { ORMError, ORMErrorReason } from "@zenstackhq/orm";
128
149
  import { ExpressionUtils as ExpressionUtils2 } from "@zenstackhq/orm/schema";
129
150
  import { AliasNode, AndNode, BinaryOperationNode, ColumnNode, FunctionNode, OperatorNode, OrNode, ParensNode, ReferenceNode, TableNode, UnaryOperationNode, ValueNode } from "kysely";
130
151
  function trueNode(dialect) {
131
- return ValueNode.createImmediate(dialect.transformPrimitive(true, "Boolean", false));
152
+ return ValueNode.createImmediate(dialect.transformInput(true, "Boolean", false));
132
153
  }
133
154
  __name(trueNode, "trueNode");
134
155
  function falseNode(dialect) {
135
- return ValueNode.createImmediate(dialect.transformPrimitive(false, "Boolean", false));
156
+ return ValueNode.createImmediate(dialect.transformInput(false, "Boolean", false));
136
157
  }
137
158
  __name(falseNode, "falseNode");
138
159
  function isTrueNode(node) {
@@ -266,6 +287,7 @@ var ExpressionTransformer = class {
266
287
  }
267
288
  client;
268
289
  dialect;
290
+ eb = expressionBuilder();
269
291
  constructor(client) {
270
292
  this.client = client;
271
293
  this.dialect = getCrudDialect(this.schema, this.clientOptions);
@@ -290,13 +312,15 @@ var ExpressionTransformer = class {
290
312
  if (!handler) {
291
313
  throw new Error(`Unsupported expression kind: ${expression.kind}`);
292
314
  }
293
- return handler.value.call(this, expression, context);
315
+ const result = handler.value.call(this, expression, context);
316
+ invariant2("kind" in result, `expression handler must return an OperationNode: transforming ${expression.kind}`);
317
+ return result;
294
318
  }
295
319
  _literal(expr2) {
296
320
  return this.transformValue(expr2.value, typeof expr2.value === "string" ? "String" : typeof expr2.value === "boolean" ? "Boolean" : "Int");
297
321
  }
298
322
  _array(expr2, context) {
299
- return ValueListNode.create(expr2.items.map((item) => this.transform(item, context)));
323
+ return this.dialect.buildArrayValue(expr2.items.map((item) => new ExpressionWrapper(this.transform(item, context))), expr2.type).toOperationNode();
300
324
  }
301
325
  _field(expr2, context) {
302
326
  if (context.contextValue) {
@@ -426,7 +450,8 @@ var ExpressionTransformer = class {
426
450
  const evaluator = new ExpressionEvaluator();
427
451
  const receiver = evaluator.evaluate(expr2.left, {
428
452
  thisValue: context.contextValue,
429
- auth: this.auth
453
+ auth: this.auth,
454
+ bindingScope: this.getEvaluationBindingScope(context.bindingScope)
430
455
  });
431
456
  const baseType = this.isAuthMember(expr2.left) ? this.authType : context.modelOrType;
432
457
  const memberType = this.getMemberType(baseType, expr2.left);
@@ -442,18 +467,31 @@ var ExpressionTransformer = class {
442
467
  invariant2(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr2.left)}`);
443
468
  newContextModel = fieldDef.type;
444
469
  } else {
445
- invariant2(ExpressionUtils3.isMember(expr2.left) && ExpressionUtils3.isField(expr2.left.receiver), "left operand must be member access with field receiver");
446
- const fieldDef2 = QueryUtils.requireField(this.schema, context.modelOrType, expr2.left.receiver.field);
447
- newContextModel = fieldDef2.type;
470
+ invariant2(ExpressionUtils3.isMember(expr2.left) && (ExpressionUtils3.isField(expr2.left.receiver) || ExpressionUtils3.isBinding(expr2.left.receiver)), "left operand must be member access with field receiver");
471
+ if (ExpressionUtils3.isField(expr2.left.receiver)) {
472
+ const fieldDef2 = QueryUtils.requireField(this.schema, context.modelOrType, expr2.left.receiver.field);
473
+ newContextModel = fieldDef2.type;
474
+ } else {
475
+ const binding = this.requireBindingScope(expr2.left.receiver, context);
476
+ newContextModel = binding.type;
477
+ }
448
478
  for (const member of expr2.left.members) {
449
479
  const memberDef = QueryUtils.requireField(this.schema, newContextModel, member);
450
480
  newContextModel = memberDef.type;
451
481
  }
452
482
  }
483
+ const bindingScope = expr2.binding ? {
484
+ ...context.bindingScope ?? {},
485
+ [expr2.binding]: {
486
+ type: newContextModel,
487
+ alias: newContextModel
488
+ }
489
+ } : context.bindingScope;
453
490
  let predicateFilter = this.transform(expr2.right, {
454
491
  ...context,
455
492
  modelOrType: newContextModel,
456
- alias: void 0
493
+ alias: void 0,
494
+ bindingScope
457
495
  });
458
496
  if (expr2.op === "!") {
459
497
  predicateFilter = logicalNot(this.dialect, predicateFilter);
@@ -480,18 +518,30 @@ var ExpressionTransformer = class {
480
518
  if (!visitor.find(expr2.right)) {
481
519
  const value = new ExpressionEvaluator().evaluate(expr2, {
482
520
  auth: this.auth,
483
- thisValue: context.contextValue
521
+ thisValue: context.contextValue,
522
+ bindingScope: this.getEvaluationBindingScope(context.bindingScope)
484
523
  });
485
524
  return this.transformValue(value, "Boolean");
486
525
  } else {
487
526
  invariant2(Array.isArray(receiver), "array value is expected");
488
- const components = receiver.map((item) => this.transform(expr2.right, {
489
- operation: context.operation,
490
- thisType: context.thisType,
491
- thisAlias: context.thisAlias,
492
- modelOrType: context.modelOrType,
493
- contextValue: item
494
- }));
527
+ const components = receiver.map((item) => {
528
+ const bindingScope = expr2.binding ? {
529
+ ...context.bindingScope ?? {},
530
+ [expr2.binding]: {
531
+ type: context.modelOrType,
532
+ alias: context.thisAlias ?? context.modelOrType,
533
+ value: item
534
+ }
535
+ } : context.bindingScope;
536
+ return this.transform(expr2.right, {
537
+ operation: context.operation,
538
+ thisType: context.thisType,
539
+ thisAlias: context.thisAlias,
540
+ modelOrType: context.modelOrType,
541
+ contextValue: item,
542
+ bindingScope
543
+ });
544
+ });
495
545
  return match2(expr2.op).with("?", () => disjunction(this.dialect, components)).with("!", () => conjunction(this.dialect, components)).with("^", () => logicalNot(this.dialect, disjunction(this.dialect, components))).exhaustive();
496
546
  }
497
547
  }
@@ -557,9 +607,11 @@ var ExpressionTransformer = class {
557
607
  return trueNode(this.dialect);
558
608
  } else if (value === false) {
559
609
  return falseNode(this.dialect);
610
+ } else if (Array.isArray(value)) {
611
+ return this.dialect.buildArrayValue(value.map((v) => new ExpressionWrapper(this.transformValue(v, type))), type).toOperationNode();
560
612
  } else {
561
- const transformed = this.dialect.transformPrimitive(value, type, false) ?? null;
562
- if (!Array.isArray(transformed)) {
613
+ const transformed = this.dialect.transformInput(value, type, false) ?? null;
614
+ if (typeof transformed !== "string") {
563
615
  return ValueNode2.createImmediate(transformed);
564
616
  } else {
565
617
  return ValueNode2.create(transformed);
@@ -583,8 +635,7 @@ var ExpressionTransformer = class {
583
635
  if (!func) {
584
636
  throw createUnsupportedError(`Function not implemented: ${expr2.function}`);
585
637
  }
586
- const eb = expressionBuilder();
587
- return func(eb, (expr2.args ?? []).map((arg) => this.transformCallArg(eb, arg, context)), {
638
+ return func(this.eb, (expr2.args ?? []).map((arg) => this.transformCallArg(arg, context)), {
588
639
  client: this.client,
589
640
  dialect: this.dialect,
590
641
  model: context.modelOrType,
@@ -604,23 +655,20 @@ var ExpressionTransformer = class {
604
655
  }
605
656
  return func;
606
657
  }
607
- transformCallArg(eb, arg, context) {
608
- if (ExpressionUtils3.isLiteral(arg)) {
609
- return eb.val(arg.value);
610
- }
658
+ transformCallArg(arg, context) {
611
659
  if (ExpressionUtils3.isField(arg)) {
612
- return eb.ref(arg.field);
613
- }
614
- if (ExpressionUtils3.isCall(arg)) {
615
- return this.transformCall(arg, context);
616
- }
617
- if (this.isAuthMember(arg)) {
618
- const valNode = this.valueMemberAccess(this.auth, arg, this.authType);
619
- return valNode ? eb.val(valNode.value) : eb.val(null);
660
+ return this.eb.ref(arg.field);
661
+ } else {
662
+ return new ExpressionWrapper(this.transform(arg, context));
620
663
  }
621
- throw createUnsupportedError(`Unsupported argument expression: ${arg.kind}`);
622
664
  }
623
665
  _member(expr2, context) {
666
+ if (ExpressionUtils3.isBinding(expr2.receiver)) {
667
+ const scope = this.requireBindingScope(expr2.receiver, context);
668
+ if (scope.value !== void 0) {
669
+ return this.valueMemberAccess(scope.value, expr2, scope.type);
670
+ }
671
+ }
624
672
  if (this.isAuthCall(expr2.receiver)) {
625
673
  return this.valueMemberAccess(this.auth, expr2, this.authType);
626
674
  }
@@ -629,9 +677,10 @@ var ExpressionTransformer = class {
629
677
  invariant2(expr2.members.length === 1, "before() can only be followed by a scalar field access");
630
678
  return ReferenceNode2.create(ColumnNode2.create(expr2.members[0]), TableNode2.create("$before"));
631
679
  }
632
- invariant2(ExpressionUtils3.isField(expr2.receiver) || ExpressionUtils3.isThis(expr2.receiver), 'expect receiver to be field expression or "this"');
680
+ invariant2(ExpressionUtils3.isField(expr2.receiver) || ExpressionUtils3.isThis(expr2.receiver) || ExpressionUtils3.isBinding(expr2.receiver), 'expect receiver to be field expression, collection predicate binding, or "this"');
633
681
  let members = expr2.members;
634
682
  let receiver;
683
+ let startType;
635
684
  const { memberFilter, memberSelect, ...restContext } = context;
636
685
  if (ExpressionUtils3.isThis(expr2.receiver)) {
637
686
  if (expr2.members.length === 1) {
@@ -646,17 +695,40 @@ var ExpressionTransformer = class {
646
695
  const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.thisType, expr2.members[0]);
647
696
  receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, restContext);
648
697
  members = expr2.members.slice(1);
698
+ startType = firstMemberFieldDef.type;
699
+ }
700
+ } else if (ExpressionUtils3.isBinding(expr2.receiver)) {
701
+ if (expr2.members.length === 1) {
702
+ const bindingScope = this.requireBindingScope(expr2.receiver, context);
703
+ return this._field(ExpressionUtils3.field(expr2.members[0]), {
704
+ ...context,
705
+ modelOrType: bindingScope.type,
706
+ alias: bindingScope.alias,
707
+ thisType: context.thisType,
708
+ contextValue: void 0
709
+ });
710
+ } else {
711
+ const bindingScope = this.requireBindingScope(expr2.receiver, context);
712
+ const firstMemberFieldDef = QueryUtils.requireField(this.schema, bindingScope.type, expr2.members[0]);
713
+ receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, {
714
+ ...restContext,
715
+ modelOrType: bindingScope.type,
716
+ alias: bindingScope.alias
717
+ });
718
+ members = expr2.members.slice(1);
719
+ startType = firstMemberFieldDef.type;
649
720
  }
650
721
  } else {
651
722
  receiver = this.transform(expr2.receiver, restContext);
652
723
  }
653
724
  invariant2(SelectQueryNode.is(receiver), "expected receiver to be select query");
654
- let startType;
655
- if (ExpressionUtils3.isField(expr2.receiver)) {
656
- const receiverField = QueryUtils.requireField(this.schema, context.modelOrType, expr2.receiver.field);
657
- startType = receiverField.type;
658
- } else {
659
- startType = context.thisType;
725
+ if (startType === void 0) {
726
+ if (ExpressionUtils3.isField(expr2.receiver)) {
727
+ const receiverField = QueryUtils.requireField(this.schema, context.modelOrType, expr2.receiver.field);
728
+ startType = receiverField.type;
729
+ } else {
730
+ startType = context.thisType;
731
+ }
660
732
  }
661
733
  const memberFields = [];
662
734
  let currType = startType;
@@ -707,6 +779,11 @@ var ExpressionTransformer = class {
707
779
  ]
708
780
  };
709
781
  }
782
+ requireBindingScope(expr2, context) {
783
+ const binding = context.bindingScope?.[expr2.name];
784
+ invariant2(binding, `binding not found: ${expr2.name}`);
785
+ return binding;
786
+ }
710
787
  valueMemberAccess(receiver, expr2, receiverType) {
711
788
  if (!receiver) {
712
789
  return ValueNode2.createImmediate(null);
@@ -772,6 +849,19 @@ var ExpressionTransformer = class {
772
849
  }
773
850
  return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel);
774
851
  }
852
+ // convert transformer's binding scope to equivalent expression evaluator binding scope
853
+ getEvaluationBindingScope(scope) {
854
+ if (!scope) {
855
+ return void 0;
856
+ }
857
+ const result = {};
858
+ for (const [key, value] of Object.entries(scope)) {
859
+ if (value.value !== void 0) {
860
+ result[key] = value.value;
861
+ }
862
+ }
863
+ return Object.keys(result).length > 0 ? result : void 0;
864
+ }
775
865
  buildDelegateBaseFieldSelect(model, modelAlias, field, baseModel) {
776
866
  const idFields = QueryUtils.requireIdFields(this.client.$schema, model);
777
867
  return {
@@ -896,6 +986,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
896
986
  }
897
987
  client;
898
988
  dialect;
989
+ eb = expressionBuilder2();
899
990
  constructor(client) {
900
991
  super(), this.client = client;
901
992
  this.dialect = getCrudDialect2(this.client.$schema, this.client.$options);
@@ -919,52 +1010,21 @@ var PolicyHandler = class extends OperationNodeTransformer {
919
1010
  if (UpdateQueryNode.is(node)) {
920
1011
  await this.preUpdateCheck(mutationModel, node, proceed);
921
1012
  }
922
- const hasPostUpdatePolicies = UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
1013
+ const needsPostUpdateCheck = UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
923
1014
  let beforeUpdateInfo;
924
- if (hasPostUpdatePolicies) {
925
- beforeUpdateInfo = await this.loadBeforeUpdateEntities(mutationModel, node.where, proceed);
1015
+ if (needsPostUpdateCheck) {
1016
+ beforeUpdateInfo = await this.loadBeforeUpdateEntities(
1017
+ mutationModel,
1018
+ node.where,
1019
+ proceed,
1020
+ // force load pre-update entities if dialect doesn't support returning,
1021
+ // so we can rely on pre-update ids to read back updated entities
1022
+ !this.dialect.supportsReturning
1023
+ );
926
1024
  }
927
1025
  const result = await proceed(this.transformNode(node));
928
- if (hasPostUpdatePolicies && result.rows.length > 0) {
929
- if (beforeUpdateInfo) {
930
- invariant3(beforeUpdateInfo.rows.length === result.rows.length);
931
- const idFields = QueryUtils2.requireIdFields(this.client.$schema, mutationModel);
932
- for (const postRow of result.rows) {
933
- const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
934
- if (!beforeRow) {
935
- throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match by id. If you have post-update policies on a model, updating id fields is not supported.");
936
- }
937
- }
938
- }
939
- const idConditions = this.buildIdConditions(mutationModel, result.rows);
940
- const postUpdateFilter = this.buildPolicyFilter(mutationModel, void 0, "post-update");
941
- const eb = expressionBuilder2();
942
- const beforeUpdateTable = beforeUpdateInfo ? {
943
- kind: "SelectQueryNode",
944
- from: FromNode2.create([
945
- ParensNode2.create(ValuesNode.create(beforeUpdateInfo.rows.map((r) => PrimitiveValueListNode.create(beforeUpdateInfo.fields.map((f) => r[f])))))
946
- ]),
947
- selections: beforeUpdateInfo.fields.map((name, index) => {
948
- const def = QueryUtils2.requireField(this.client.$schema, mutationModel, name);
949
- const castedColumnRef = sql`CAST(${eb.ref(`column${index + 1}`)} as ${sql.raw(this.dialect.getFieldSqlType(def))})`.as(name);
950
- return SelectionNode2.create(castedColumnRef.toOperationNode());
951
- })
952
- } : void 0;
953
- const postUpdateQuery = eb.selectFrom(mutationModel).select(() => [
954
- eb(eb.fn("COUNT", [
955
- eb.lit(1)
956
- ]), "=", result.rows.length).as("$condition")
957
- ]).where(() => new ExpressionWrapper(conjunction(this.dialect, [
958
- idConditions,
959
- postUpdateFilter
960
- ]))).$if(!!beforeUpdateInfo, (qb) => qb.leftJoin(() => new ExpressionWrapper(beforeUpdateTable).as("$before"), (join) => {
961
- const idFields = QueryUtils2.requireIdFields(this.client.$schema, mutationModel);
962
- return idFields.reduce((acc, f) => acc.onRef(`${mutationModel}.${f}`, "=", `$before.${f}`), join);
963
- }));
964
- const postUpdateResult = await proceed(postUpdateQuery.toOperationNode());
965
- if (!postUpdateResult.rows[0]?.$condition) {
966
- throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS, "some or all updated rows failed to pass post-update policy check");
967
- }
1026
+ if ((result.numAffectedRows ?? 0) > 0 && needsPostUpdateCheck) {
1027
+ await this.postUpdateCheck(mutationModel, beforeUpdateInfo, result, proceed);
968
1028
  }
969
1029
  if (!node.returning || this.onlyReturningId(node)) {
970
1030
  return this.postProcessMutationResult(result, node);
@@ -1003,12 +1063,69 @@ var PolicyHandler = class extends OperationNodeTransformer {
1003
1063
  modelLevelFilter,
1004
1064
  node.where?.where ?? trueNode(this.dialect)
1005
1065
  ]);
1006
- const preUpdateCheckQuery = expressionBuilder2().selectFrom(mutationModel).select((eb) => eb.fn.coalesce(eb.fn.sum(eb.cast(new ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter)), "integer")), eb.lit(0)).as("$filteredCount")).where(() => new ExpressionWrapper(updateFilter));
1066
+ const preUpdateCheckQuery = this.eb.selectFrom(mutationModel).select((eb) => eb.fn.coalesce(eb.fn.sum(this.dialect.castInt(new ExpressionWrapper2(logicalNot(this.dialect, fieldLevelFilter)))), eb.lit(0)).as("$filteredCount")).where(() => new ExpressionWrapper2(updateFilter));
1007
1067
  const preUpdateResult = await proceed(preUpdateCheckQuery.toOperationNode());
1008
1068
  if (preUpdateResult.rows[0].$filteredCount > 0) {
1009
1069
  throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS, "some rows cannot be updated due to field policies");
1010
1070
  }
1011
1071
  }
1072
+ async postUpdateCheck(model, beforeUpdateInfo, updateResult, proceed) {
1073
+ let postUpdateRows;
1074
+ if (this.dialect.supportsReturning) {
1075
+ postUpdateRows = updateResult.rows;
1076
+ } else {
1077
+ invariant3(beforeUpdateInfo, "beforeUpdateInfo must be defined for dialects not supporting returning");
1078
+ const idConditions2 = this.buildIdConditions(model, beforeUpdateInfo.rows);
1079
+ const idFields = QueryUtils2.requireIdFields(this.client.$schema, model);
1080
+ const postUpdateQuery2 = {
1081
+ kind: "SelectQueryNode",
1082
+ from: FromNode2.create([
1083
+ TableNode3.create(model)
1084
+ ]),
1085
+ where: WhereNode2.create(idConditions2),
1086
+ selections: idFields.map((field) => SelectionNode2.create(ColumnNode3.create(field)))
1087
+ };
1088
+ const postUpdateQueryResult = await proceed(postUpdateQuery2);
1089
+ postUpdateRows = postUpdateQueryResult.rows;
1090
+ }
1091
+ if (beforeUpdateInfo) {
1092
+ if (beforeUpdateInfo.rows.length !== postUpdateRows.length) {
1093
+ throw createRejectedByPolicyError(model, RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.");
1094
+ }
1095
+ const idFields = QueryUtils2.requireIdFields(this.client.$schema, model);
1096
+ for (const postRow of postUpdateRows) {
1097
+ const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
1098
+ if (!beforeRow) {
1099
+ throw createRejectedByPolicyError(model, RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.");
1100
+ }
1101
+ }
1102
+ }
1103
+ const idConditions = this.buildIdConditions(model, postUpdateRows);
1104
+ const postUpdateFilter = this.buildPolicyFilter(model, void 0, "post-update");
1105
+ const eb = expressionBuilder2();
1106
+ const needsBeforeUpdateJoin = !!beforeUpdateInfo?.fields;
1107
+ let beforeUpdateTable = void 0;
1108
+ if (needsBeforeUpdateJoin) {
1109
+ const fieldDefs = beforeUpdateInfo.fields.map((name) => QueryUtils2.requireField(this.client.$schema, model, name));
1110
+ const rows = beforeUpdateInfo.rows.map((r) => beforeUpdateInfo.fields.map((f) => r[f]));
1111
+ beforeUpdateTable = this.dialect.buildValuesTableSelect(fieldDefs, rows).toOperationNode();
1112
+ }
1113
+ const postUpdateQuery = eb.selectFrom(model).select(() => [
1114
+ eb(eb.fn("COUNT", [
1115
+ eb.lit(1)
1116
+ ]), "=", Number(updateResult.numAffectedRows ?? 0)).as("$condition")
1117
+ ]).where(() => new ExpressionWrapper2(conjunction(this.dialect, [
1118
+ idConditions,
1119
+ postUpdateFilter
1120
+ ]))).$if(needsBeforeUpdateJoin, (qb) => qb.leftJoin(() => new ExpressionWrapper2(beforeUpdateTable).as("$before"), (join) => {
1121
+ const idFields = QueryUtils2.requireIdFields(this.client.$schema, model);
1122
+ return idFields.reduce((acc, f) => acc.onRef(`${model}.${f}`, "=", `$before.${f}`), join);
1123
+ }));
1124
+ const postUpdateResult = await proceed(postUpdateQuery.toOperationNode());
1125
+ if (!postUpdateResult.rows[0]?.$condition) {
1126
+ throw createRejectedByPolicyError(model, RejectedByPolicyReason.NO_ACCESS, "some or all updated rows failed to pass post-update policy check");
1127
+ }
1128
+ }
1012
1129
  // #endregion
1013
1130
  // #region Transformations
1014
1131
  transformSelectQuery(node) {
@@ -1076,6 +1193,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
1076
1193
  };
1077
1194
  }
1078
1195
  transformInsertQuery(node) {
1196
+ let processedNode = node;
1079
1197
  let onConflict = node.onConflict;
1080
1198
  if (onConflict?.updates) {
1081
1199
  const { mutationModel, alias } = this.getMutationModel(node);
@@ -1094,11 +1212,36 @@ var PolicyHandler = class extends OperationNodeTransformer {
1094
1212
  updateWhere: WhereNode2.create(filter)
1095
1213
  };
1096
1214
  }
1215
+ processedNode = {
1216
+ ...node,
1217
+ onConflict
1218
+ };
1219
+ }
1220
+ let onDuplicateKey = node.onDuplicateKey;
1221
+ if (onDuplicateKey?.updates) {
1222
+ const { mutationModel } = this.getMutationModel(node);
1223
+ const filterWithTableRef = this.buildPolicyFilter(mutationModel, void 0, "update");
1224
+ const filter = this.stripTableReferences(filterWithTableRef, mutationModel);
1225
+ const wrappedUpdates = onDuplicateKey.updates.map((update) => {
1226
+ const columnName = ColumnNode3.is(update.column) ? update.column.column.name : void 0;
1227
+ if (!columnName) {
1228
+ return update;
1229
+ }
1230
+ const wrappedValue = sql`IF(${new ExpressionWrapper2(filter)}, ${new ExpressionWrapper2(update.value)}, ${sql.ref(columnName)})`.toOperationNode();
1231
+ return {
1232
+ ...update,
1233
+ value: wrappedValue
1234
+ };
1235
+ });
1236
+ onDuplicateKey = {
1237
+ ...onDuplicateKey,
1238
+ updates: wrappedUpdates
1239
+ };
1240
+ processedNode = {
1241
+ ...processedNode,
1242
+ onDuplicateKey
1243
+ };
1097
1244
  }
1098
- const processedNode = onConflict ? {
1099
- ...node,
1100
- onConflict
1101
- } : node;
1102
1245
  const result = super.transformInsertQuery(processedNode);
1103
1246
  let returning = result.returning;
1104
1247
  if (returning) {
@@ -1126,7 +1269,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
1126
1269
  }
1127
1270
  }
1128
1271
  let returning = result.returning;
1129
- if (returning || this.hasPostUpdatePolicies(mutationModel)) {
1272
+ if (this.dialect.supportsReturning && (returning || this.hasPostUpdatePolicies(mutationModel))) {
1130
1273
  const idFields = QueryUtils2.requireIdFields(this.client.$schema, mutationModel);
1131
1274
  returning = ReturningNode.create(idFields.map((f) => SelectionNode2.create(ColumnNode3.create(f))));
1132
1275
  }
@@ -1163,9 +1306,9 @@ var PolicyHandler = class extends OperationNodeTransformer {
1163
1306
  }
1164
1307
  // #endregion
1165
1308
  // #region post-update
1166
- async loadBeforeUpdateEntities(model, where, proceed) {
1309
+ async loadBeforeUpdateEntities(model, where, proceed, forceLoad = false) {
1167
1310
  const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model);
1168
- if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) {
1311
+ if (!forceLoad && (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0)) {
1169
1312
  return void 0;
1170
1313
  }
1171
1314
  const policyFilter = this.buildPolicyFilter(model, model, "update");
@@ -1173,15 +1316,14 @@ var PolicyHandler = class extends OperationNodeTransformer {
1173
1316
  where.where,
1174
1317
  policyFilter
1175
1318
  ]) : policyFilter;
1319
+ const selections = beforeUpdateAccessFields ?? QueryUtils2.requireIdFields(this.client.$schema, model);
1176
1320
  const query = {
1177
1321
  kind: "SelectQueryNode",
1178
1322
  from: FromNode2.create([
1179
1323
  TableNode3.create(model)
1180
1324
  ]),
1181
1325
  where: WhereNode2.create(combinedFilter),
1182
- selections: [
1183
- ...beforeUpdateAccessFields.map((f) => SelectionNode2.create(ColumnNode3.create(f)))
1184
- ]
1326
+ selections: selections.map((f) => SelectionNode2.create(ColumnNode3.create(f)))
1185
1327
  };
1186
1328
  const result = await proceed(query);
1187
1329
  return {
@@ -1262,7 +1404,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
1262
1404
  };
1263
1405
  }
1264
1406
  const eb = expressionBuilder2();
1265
- const selection = eb.case().when(new ExpressionWrapper(filter)).then(eb.ref(field)).else(null).end().as(field).toOperationNode();
1407
+ const selection = eb.case().when(new ExpressionWrapper2(filter)).then(eb.ref(field)).else(null).end().as(field).toOperationNode();
1266
1408
  return {
1267
1409
  hasPolicies: true,
1268
1410
  selection: SelectionNode2.create(selection)
@@ -1342,9 +1484,9 @@ var PolicyHandler = class extends OperationNodeTransformer {
1342
1484
  invariant3(bValue !== null && bValue !== void 0, "B value cannot be null or undefined");
1343
1485
  const eb = expressionBuilder2();
1344
1486
  const filterA = this.buildPolicyFilter(m2m.firstModel, void 0, "update");
1345
- const queryA = eb.selectFrom(m2m.firstModel).where(eb(eb.ref(`${m2m.firstModel}.${m2m.firstIdField}`), "=", aValue)).select(() => new ExpressionWrapper(filterA).as("$t"));
1487
+ const queryA = eb.selectFrom(m2m.firstModel).where(eb(eb.ref(`${m2m.firstModel}.${m2m.firstIdField}`), "=", aValue)).select(() => new ExpressionWrapper2(filterA).as("$t"));
1346
1488
  const filterB = this.buildPolicyFilter(m2m.secondModel, void 0, "update");
1347
- const queryB = eb.selectFrom(m2m.secondModel).where(eb(eb.ref(`${m2m.secondModel}.${m2m.secondIdField}`), "=", bValue)).select(() => new ExpressionWrapper(filterB).as("$t"));
1489
+ const queryB = eb.selectFrom(m2m.secondModel).where(eb(eb.ref(`${m2m.secondModel}.${m2m.secondIdField}`), "=", bValue)).select(() => new ExpressionWrapper2(filterB).as("$t"));
1348
1490
  const queryNode = {
1349
1491
  kind: "SelectQueryNode",
1350
1492
  selections: [
@@ -1361,43 +1503,24 @@ var PolicyHandler = class extends OperationNodeTransformer {
1361
1503
  }
1362
1504
  }
1363
1505
  async enforcePreCreatePolicyForOne(model, fields, values, proceed) {
1364
- const allFields = Object.entries(QueryUtils2.requireModel(this.client.$schema, model).fields).filter(([, def]) => !def.relation);
1506
+ const allFields = QueryUtils2.getModelFields(this.client.$schema, model, {
1507
+ inherited: true
1508
+ });
1365
1509
  const allValues = [];
1366
- for (const [name, _def] of allFields) {
1367
- const index = fields.indexOf(name);
1510
+ for (const def of allFields) {
1511
+ const index = fields.indexOf(def.name);
1368
1512
  if (index >= 0) {
1369
- allValues.push(values[index]);
1513
+ allValues.push(new ExpressionWrapper2(values[index]));
1370
1514
  } else {
1371
- allValues.push(ValueNode3.createImmediate(null));
1515
+ allValues.push(this.eb.lit(null));
1372
1516
  }
1373
1517
  }
1374
- const eb = expressionBuilder2();
1375
- const constTable = {
1376
- kind: "SelectQueryNode",
1377
- from: FromNode2.create([
1378
- AliasNode3.create(ParensNode2.create(ValuesNode.create([
1379
- ValueListNode2.create(allValues)
1380
- ])), IdentifierNode2.create("$t"))
1381
- ]),
1382
- selections: allFields.map(([name, def], index) => {
1383
- const castedColumnRef = sql`CAST(${eb.ref(`column${index + 1}`)} as ${sql.raw(this.dialect.getFieldSqlType(def))})`.as(name);
1384
- return SelectionNode2.create(castedColumnRef.toOperationNode());
1385
- })
1386
- };
1518
+ const valuesTable = this.dialect.buildValuesTableSelect(allFields, [
1519
+ allValues
1520
+ ]);
1387
1521
  const filter = this.buildPolicyFilter(model, void 0, "create");
1388
- const preCreateCheck = {
1389
- kind: "SelectQueryNode",
1390
- from: FromNode2.create([
1391
- AliasNode3.create(constTable, IdentifierNode2.create(model))
1392
- ]),
1393
- selections: [
1394
- SelectionNode2.create(AliasNode3.create(BinaryOperationNode3.create(FunctionNode3.create("COUNT", [
1395
- ValueNode3.createImmediate(1)
1396
- ]), OperatorNode3.create(">"), ValueNode3.createImmediate(0)), IdentifierNode2.create("$condition")))
1397
- ],
1398
- where: WhereNode2.create(filter)
1399
- };
1400
- const result = await proceed(preCreateCheck);
1522
+ const preCreateCheck = this.eb.selectFrom(valuesTable.as(model)).select(this.eb(this.eb.fn.count(this.eb.lit(1)), ">", 0).as("$condition")).where(() => new ExpressionWrapper2(filter));
1523
+ const result = await proceed(preCreateCheck.toOperationNode());
1401
1524
  if (!result.rows[0]?.$condition) {
1402
1525
  throw createRejectedByPolicyError(model, RejectedByPolicyReason.NO_ACCESS);
1403
1526
  }
@@ -1422,18 +1545,19 @@ var PolicyHandler = class extends OperationNodeTransformer {
1422
1545
  const fieldDef = QueryUtils2.requireField(this.client.$schema, model, fields[i]);
1423
1546
  invariant3(item.kind === "ValueNode", "expecting a ValueNode");
1424
1547
  result.push({
1425
- node: ValueNode3.create(this.dialect.transformPrimitive(item.value, fieldDef.type, !!fieldDef.array)),
1548
+ node: ValueNode3.create(this.dialect.transformInput(item.value, fieldDef.type, !!fieldDef.array)),
1426
1549
  raw: item.value
1427
1550
  });
1428
1551
  } else {
1429
1552
  let value = item;
1430
1553
  if (!isImplicitManyToManyJoinTable) {
1431
1554
  const fieldDef = QueryUtils2.requireField(this.client.$schema, model, fields[i]);
1432
- value = this.dialect.transformPrimitive(item, fieldDef.type, !!fieldDef.array);
1555
+ value = this.dialect.transformInput(item, fieldDef.type, !!fieldDef.array);
1433
1556
  }
1434
1557
  if (Array.isArray(value)) {
1558
+ const fieldDef = QueryUtils2.requireField(this.client.$schema, model, fields[i]);
1435
1559
  result.push({
1436
- node: RawNode.createWithSql(this.dialect.buildArrayLiteralSQL(value)),
1560
+ node: this.dialect.buildArrayValue(value, fieldDef.type).toOperationNode(),
1437
1561
  raw: value
1438
1562
  });
1439
1563
  } else {
@@ -1681,11 +1805,10 @@ var PolicyHandler = class extends OperationNodeTransformer {
1681
1805
  return void 0;
1682
1806
  }
1683
1807
  const checkForOperation = operation === "read" ? "read" : "update";
1684
- const eb = expressionBuilder2();
1685
1808
  const joinTable = alias ?? tableName;
1686
- const aQuery = eb.selectFrom(m2m.firstModel).whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, "=", `${joinTable}.A`).select(() => new ExpressionWrapper(this.buildPolicyFilter(m2m.firstModel, void 0, checkForOperation)).as("$conditionA"));
1687
- const bQuery = eb.selectFrom(m2m.secondModel).whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, "=", `${joinTable}.B`).select(() => new ExpressionWrapper(this.buildPolicyFilter(m2m.secondModel, void 0, checkForOperation)).as("$conditionB"));
1688
- return eb.and([
1809
+ const aQuery = this.eb.selectFrom(m2m.firstModel).whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, "=", `${joinTable}.A`).select(() => new ExpressionWrapper2(this.buildPolicyFilter(m2m.firstModel, void 0, checkForOperation)).as("$conditionA"));
1810
+ const bQuery = this.eb.selectFrom(m2m.secondModel).whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, "=", `${joinTable}.B`).select(() => new ExpressionWrapper2(this.buildPolicyFilter(m2m.secondModel, void 0, checkForOperation)).as("$conditionB"));
1811
+ return this.eb.and([
1689
1812
  aQuery,
1690
1813
  bQuery
1691
1814
  ]).toOperationNode();
@@ -1716,6 +1839,26 @@ var PolicyHandler = class extends OperationNodeTransformer {
1716
1839
  };
1717
1840
  }
1718
1841
  }
1842
+ // strips table references from an OperationNode
1843
+ stripTableReferences(node, modelName) {
1844
+ return new TableReferenceStripper().strip(node, modelName);
1845
+ }
1846
+ };
1847
+ var TableReferenceStripper = class TableReferenceStripper2 extends OperationNodeTransformer {
1848
+ static {
1849
+ __name(this, "TableReferenceStripper");
1850
+ }
1851
+ tableName = "";
1852
+ strip(node, tableName) {
1853
+ this.tableName = tableName;
1854
+ return this.transformNode(node);
1855
+ }
1856
+ transformReference(node) {
1857
+ if (ColumnNode3.is(node.column) && node.table?.table.identifier.name === this.tableName) {
1858
+ return ReferenceNode3.create(this.transformNode(node.column));
1859
+ }
1860
+ return super.transformReference(node);
1861
+ }
1719
1862
  };
1720
1863
 
1721
1864
  // src/functions.ts
@@ -1764,7 +1907,7 @@ var check = /* @__PURE__ */ __name((eb, args, { client, model, modelAlias, opera
1764
1907
  const policyHandler = new PolicyHandler(client);
1765
1908
  const op = arg2Node ? arg2Node.value : operation;
1766
1909
  const policyCondition = policyHandler.buildPolicyFilter(relationModel, void 0, op);
1767
- const result = eb.selectFrom(relationModel).where(joinCondition).select(new ExpressionWrapper2(policyCondition).as("$condition"));
1910
+ const result = eb.selectFrom(eb.selectFrom(relationModel).where(joinCondition).select(new ExpressionWrapper3(policyCondition).as("$condition")).as("$sub")).selectAll();
1768
1911
  return result;
1769
1912
  }, "check");
1770
1913