@zenstackhq/plugin-policy 3.3.0-beta.2 → 3.3.0-beta.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -10,7 +10,7 @@ import { ExpressionWrapper as ExpressionWrapper2, ValueNode as ValueNode4 } from
10
10
  import { invariant as invariant3 } from "@zenstackhq/common-helpers";
11
11
  import { getCrudDialect as getCrudDialect2, QueryUtils as QueryUtils2, RejectedByPolicyReason, SchemaUtils as SchemaUtils2 } from "@zenstackhq/orm";
12
12
  import { ExpressionUtils as ExpressionUtils4 } from "@zenstackhq/orm/schema";
13
- import { AliasNode as AliasNode3, BinaryOperationNode as BinaryOperationNode3, ColumnNode as ColumnNode3, DeleteQueryNode, expressionBuilder as expressionBuilder2, ExpressionWrapper, FromNode as FromNode2, FunctionNode as FunctionNode3, IdentifierNode as IdentifierNode2, InsertQueryNode, OperationNodeTransformer, OperatorNode as OperatorNode3, ParensNode as ParensNode2, PrimitiveValueListNode, RawNode, ReferenceNode as ReferenceNode3, ReturningNode, SelectAllNode, SelectionNode as SelectionNode2, SelectQueryNode as SelectQueryNode2, sql, TableNode as TableNode3, UpdateQueryNode, ValueListNode as ValueListNode2, ValueNode as ValueNode3, ValuesNode, WhereNode as WhereNode2 } from "kysely";
13
+ import { AliasNode as AliasNode3, BinaryOperationNode as BinaryOperationNode3, ColumnNode as ColumnNode3, DeleteQueryNode, expressionBuilder as expressionBuilder2, ExpressionWrapper, FromNode as FromNode2, IdentifierNode as IdentifierNode2, InsertQueryNode, OperationNodeTransformer, OperatorNode as OperatorNode3, ParensNode as ParensNode2, PrimitiveValueListNode, ReferenceNode as ReferenceNode3, ReturningNode, SelectAllNode, SelectionNode as SelectionNode2, SelectQueryNode as SelectQueryNode2, sql, TableNode as TableNode3, UpdateQueryNode, ValueNode as ValueNode3, ValuesNode, WhereNode as WhereNode2 } from "kysely";
14
14
  import { match as match3 } from "ts-pattern";
15
15
 
16
16
  // src/column-collector.ts
@@ -41,14 +41,14 @@ import { match as match2 } from "ts-pattern";
41
41
 
42
42
  // src/expression-evaluator.ts
43
43
  import { invariant } from "@zenstackhq/common-helpers";
44
- import { match } from "ts-pattern";
45
44
  import { ExpressionUtils } from "@zenstackhq/orm/schema";
45
+ import { match } from "ts-pattern";
46
46
  var ExpressionEvaluator = class {
47
47
  static {
48
48
  __name(this, "ExpressionEvaluator");
49
49
  }
50
50
  evaluate(expression, context) {
51
- const result = match(expression).when(ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(ExpressionUtils.isThis, () => context.thisValue).when(ExpressionUtils.isNull, () => null).exhaustive();
51
+ const result = match(expression).when(ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(ExpressionUtils.isBinding, (expr2) => this.evaluateBinding(expr2, context)).when(ExpressionUtils.isThis, () => context.thisValue).when(ExpressionUtils.isNull, () => null).exhaustive();
52
52
  return result ?? null;
53
53
  }
54
54
  evaluateCall(expr2, context) {
@@ -72,6 +72,9 @@ var ExpressionEvaluator = class {
72
72
  return expr2.value;
73
73
  }
74
74
  evaluateField(expr2, context) {
75
+ if (context.bindingScope && expr2.field in context.bindingScope) {
76
+ return context.bindingScope[expr2.field];
77
+ }
75
78
  return context.thisValue?.[expr2.field];
76
79
  }
77
80
  evaluateArray(expr2, context) {
@@ -105,15 +108,33 @@ var ExpressionEvaluator = class {
105
108
  invariant(Array.isArray(left), "expected array");
106
109
  return match(op).with("?", () => left.some((item) => this.evaluate(expr2.right, {
107
110
  ...context,
108
- thisValue: item
111
+ thisValue: item,
112
+ bindingScope: expr2.binding ? {
113
+ ...context.bindingScope ?? {},
114
+ [expr2.binding]: item
115
+ } : context.bindingScope
109
116
  }))).with("!", () => left.every((item) => this.evaluate(expr2.right, {
110
117
  ...context,
111
- thisValue: item
118
+ thisValue: item,
119
+ bindingScope: expr2.binding ? {
120
+ ...context.bindingScope ?? {},
121
+ [expr2.binding]: item
122
+ } : context.bindingScope
112
123
  }))).with("^", () => !left.some((item) => this.evaluate(expr2.right, {
113
124
  ...context,
114
- thisValue: item
125
+ thisValue: item,
126
+ bindingScope: expr2.binding ? {
127
+ ...context.bindingScope ?? {},
128
+ [expr2.binding]: item
129
+ } : context.bindingScope
115
130
  }))).exhaustive();
116
131
  }
132
+ evaluateBinding(expr2, context) {
133
+ if (!context.bindingScope || !(expr2.name in context.bindingScope)) {
134
+ throw new Error(`Unresolved binding: ${expr2.name}`);
135
+ }
136
+ return context.bindingScope[expr2.name];
137
+ }
117
138
  };
118
139
 
119
140
  // src/types.ts
@@ -128,11 +149,11 @@ import { ORMError, ORMErrorReason } from "@zenstackhq/orm";
128
149
  import { ExpressionUtils as ExpressionUtils2 } from "@zenstackhq/orm/schema";
129
150
  import { AliasNode, AndNode, BinaryOperationNode, ColumnNode, FunctionNode, OperatorNode, OrNode, ParensNode, ReferenceNode, TableNode, UnaryOperationNode, ValueNode } from "kysely";
130
151
  function trueNode(dialect) {
131
- return ValueNode.createImmediate(dialect.transformPrimitive(true, "Boolean", false));
152
+ return ValueNode.createImmediate(dialect.transformInput(true, "Boolean", false));
132
153
  }
133
154
  __name(trueNode, "trueNode");
134
155
  function falseNode(dialect) {
135
- return ValueNode.createImmediate(dialect.transformPrimitive(false, "Boolean", false));
156
+ return ValueNode.createImmediate(dialect.transformInput(false, "Boolean", false));
136
157
  }
137
158
  __name(falseNode, "falseNode");
138
159
  function isTrueNode(node) {
@@ -426,7 +447,8 @@ var ExpressionTransformer = class {
426
447
  const evaluator = new ExpressionEvaluator();
427
448
  const receiver = evaluator.evaluate(expr2.left, {
428
449
  thisValue: context.contextValue,
429
- auth: this.auth
450
+ auth: this.auth,
451
+ bindingScope: this.getEvaluationBindingScope(context.bindingScope)
430
452
  });
431
453
  const baseType = this.isAuthMember(expr2.left) ? this.authType : context.modelOrType;
432
454
  const memberType = this.getMemberType(baseType, expr2.left);
@@ -442,18 +464,31 @@ var ExpressionTransformer = class {
442
464
  invariant2(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr2.left)}`);
443
465
  newContextModel = fieldDef.type;
444
466
  } else {
445
- invariant2(ExpressionUtils3.isMember(expr2.left) && ExpressionUtils3.isField(expr2.left.receiver), "left operand must be member access with field receiver");
446
- const fieldDef2 = QueryUtils.requireField(this.schema, context.modelOrType, expr2.left.receiver.field);
447
- newContextModel = fieldDef2.type;
467
+ invariant2(ExpressionUtils3.isMember(expr2.left) && (ExpressionUtils3.isField(expr2.left.receiver) || ExpressionUtils3.isBinding(expr2.left.receiver)), "left operand must be member access with field receiver");
468
+ if (ExpressionUtils3.isField(expr2.left.receiver)) {
469
+ const fieldDef2 = QueryUtils.requireField(this.schema, context.modelOrType, expr2.left.receiver.field);
470
+ newContextModel = fieldDef2.type;
471
+ } else {
472
+ const binding = this.requireBindingScope(expr2.left.receiver, context);
473
+ newContextModel = binding.type;
474
+ }
448
475
  for (const member of expr2.left.members) {
449
476
  const memberDef = QueryUtils.requireField(this.schema, newContextModel, member);
450
477
  newContextModel = memberDef.type;
451
478
  }
452
479
  }
480
+ const bindingScope = expr2.binding ? {
481
+ ...context.bindingScope ?? {},
482
+ [expr2.binding]: {
483
+ type: newContextModel,
484
+ alias: newContextModel
485
+ }
486
+ } : context.bindingScope;
453
487
  let predicateFilter = this.transform(expr2.right, {
454
488
  ...context,
455
489
  modelOrType: newContextModel,
456
- alias: void 0
490
+ alias: void 0,
491
+ bindingScope
457
492
  });
458
493
  if (expr2.op === "!") {
459
494
  predicateFilter = logicalNot(this.dialect, predicateFilter);
@@ -480,18 +515,30 @@ var ExpressionTransformer = class {
480
515
  if (!visitor.find(expr2.right)) {
481
516
  const value = new ExpressionEvaluator().evaluate(expr2, {
482
517
  auth: this.auth,
483
- thisValue: context.contextValue
518
+ thisValue: context.contextValue,
519
+ bindingScope: this.getEvaluationBindingScope(context.bindingScope)
484
520
  });
485
521
  return this.transformValue(value, "Boolean");
486
522
  } else {
487
523
  invariant2(Array.isArray(receiver), "array value is expected");
488
- const components = receiver.map((item) => this.transform(expr2.right, {
489
- operation: context.operation,
490
- thisType: context.thisType,
491
- thisAlias: context.thisAlias,
492
- modelOrType: context.modelOrType,
493
- contextValue: item
494
- }));
524
+ const components = receiver.map((item) => {
525
+ const bindingScope = expr2.binding ? {
526
+ ...context.bindingScope ?? {},
527
+ [expr2.binding]: {
528
+ type: context.modelOrType,
529
+ alias: context.thisAlias ?? context.modelOrType,
530
+ value: item
531
+ }
532
+ } : context.bindingScope;
533
+ return this.transform(expr2.right, {
534
+ operation: context.operation,
535
+ thisType: context.thisType,
536
+ thisAlias: context.thisAlias,
537
+ modelOrType: context.modelOrType,
538
+ contextValue: item,
539
+ bindingScope
540
+ });
541
+ });
495
542
  return match2(expr2.op).with("?", () => disjunction(this.dialect, components)).with("!", () => conjunction(this.dialect, components)).with("^", () => logicalNot(this.dialect, disjunction(this.dialect, components))).exhaustive();
496
543
  }
497
544
  }
@@ -558,7 +605,7 @@ var ExpressionTransformer = class {
558
605
  } else if (value === false) {
559
606
  return falseNode(this.dialect);
560
607
  } else {
561
- const transformed = this.dialect.transformPrimitive(value, type, false) ?? null;
608
+ const transformed = this.dialect.transformInput(value, type, false) ?? null;
562
609
  if (!Array.isArray(transformed)) {
563
610
  return ValueNode2.createImmediate(transformed);
564
611
  } else {
@@ -621,6 +668,12 @@ var ExpressionTransformer = class {
621
668
  throw createUnsupportedError(`Unsupported argument expression: ${arg.kind}`);
622
669
  }
623
670
  _member(expr2, context) {
671
+ if (ExpressionUtils3.isBinding(expr2.receiver)) {
672
+ const scope = this.requireBindingScope(expr2.receiver, context);
673
+ if (scope.value !== void 0) {
674
+ return this.valueMemberAccess(scope.value, expr2, scope.type);
675
+ }
676
+ }
624
677
  if (this.isAuthCall(expr2.receiver)) {
625
678
  return this.valueMemberAccess(this.auth, expr2, this.authType);
626
679
  }
@@ -629,9 +682,10 @@ var ExpressionTransformer = class {
629
682
  invariant2(expr2.members.length === 1, "before() can only be followed by a scalar field access");
630
683
  return ReferenceNode2.create(ColumnNode2.create(expr2.members[0]), TableNode2.create("$before"));
631
684
  }
632
- invariant2(ExpressionUtils3.isField(expr2.receiver) || ExpressionUtils3.isThis(expr2.receiver), 'expect receiver to be field expression or "this"');
685
+ invariant2(ExpressionUtils3.isField(expr2.receiver) || ExpressionUtils3.isThis(expr2.receiver) || ExpressionUtils3.isBinding(expr2.receiver), 'expect receiver to be field expression, collection predicate binding, or "this"');
633
686
  let members = expr2.members;
634
687
  let receiver;
688
+ let startType;
635
689
  const { memberFilter, memberSelect, ...restContext } = context;
636
690
  if (ExpressionUtils3.isThis(expr2.receiver)) {
637
691
  if (expr2.members.length === 1) {
@@ -646,17 +700,40 @@ var ExpressionTransformer = class {
646
700
  const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.thisType, expr2.members[0]);
647
701
  receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, restContext);
648
702
  members = expr2.members.slice(1);
703
+ startType = firstMemberFieldDef.type;
704
+ }
705
+ } else if (ExpressionUtils3.isBinding(expr2.receiver)) {
706
+ if (expr2.members.length === 1) {
707
+ const bindingScope = this.requireBindingScope(expr2.receiver, context);
708
+ return this._field(ExpressionUtils3.field(expr2.members[0]), {
709
+ ...context,
710
+ modelOrType: bindingScope.type,
711
+ alias: bindingScope.alias,
712
+ thisType: context.thisType,
713
+ contextValue: void 0
714
+ });
715
+ } else {
716
+ const bindingScope = this.requireBindingScope(expr2.receiver, context);
717
+ const firstMemberFieldDef = QueryUtils.requireField(this.schema, bindingScope.type, expr2.members[0]);
718
+ receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, {
719
+ ...restContext,
720
+ modelOrType: bindingScope.type,
721
+ alias: bindingScope.alias
722
+ });
723
+ members = expr2.members.slice(1);
724
+ startType = firstMemberFieldDef.type;
649
725
  }
650
726
  } else {
651
727
  receiver = this.transform(expr2.receiver, restContext);
652
728
  }
653
729
  invariant2(SelectQueryNode.is(receiver), "expected receiver to be select query");
654
- let startType;
655
- if (ExpressionUtils3.isField(expr2.receiver)) {
656
- const receiverField = QueryUtils.requireField(this.schema, context.modelOrType, expr2.receiver.field);
657
- startType = receiverField.type;
658
- } else {
659
- startType = context.thisType;
730
+ if (startType === void 0) {
731
+ if (ExpressionUtils3.isField(expr2.receiver)) {
732
+ const receiverField = QueryUtils.requireField(this.schema, context.modelOrType, expr2.receiver.field);
733
+ startType = receiverField.type;
734
+ } else {
735
+ startType = context.thisType;
736
+ }
660
737
  }
661
738
  const memberFields = [];
662
739
  let currType = startType;
@@ -707,6 +784,11 @@ var ExpressionTransformer = class {
707
784
  ]
708
785
  };
709
786
  }
787
+ requireBindingScope(expr2, context) {
788
+ const binding = context.bindingScope?.[expr2.name];
789
+ invariant2(binding, `binding not found: ${expr2.name}`);
790
+ return binding;
791
+ }
710
792
  valueMemberAccess(receiver, expr2, receiverType) {
711
793
  if (!receiver) {
712
794
  return ValueNode2.createImmediate(null);
@@ -772,6 +854,19 @@ var ExpressionTransformer = class {
772
854
  }
773
855
  return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel);
774
856
  }
857
+ // convert transformer's binding scope to equivalent expression evaluator binding scope
858
+ getEvaluationBindingScope(scope) {
859
+ if (!scope) {
860
+ return void 0;
861
+ }
862
+ const result = {};
863
+ for (const [key, value] of Object.entries(scope)) {
864
+ if (value.value !== void 0) {
865
+ result[key] = value.value;
866
+ }
867
+ }
868
+ return Object.keys(result).length > 0 ? result : void 0;
869
+ }
775
870
  buildDelegateBaseFieldSelect(model, modelAlias, field, baseModel) {
776
871
  const idFields = QueryUtils.requireIdFields(this.client.$schema, model);
777
872
  return {
@@ -896,6 +991,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
896
991
  }
897
992
  client;
898
993
  dialect;
994
+ eb = expressionBuilder2();
899
995
  constructor(client) {
900
996
  super(), this.client = client;
901
997
  this.dialect = getCrudDialect2(this.client.$schema, this.client.$options);
@@ -919,52 +1015,21 @@ var PolicyHandler = class extends OperationNodeTransformer {
919
1015
  if (UpdateQueryNode.is(node)) {
920
1016
  await this.preUpdateCheck(mutationModel, node, proceed);
921
1017
  }
922
- const hasPostUpdatePolicies = UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
1018
+ const needsPostUpdateCheck = UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
923
1019
  let beforeUpdateInfo;
924
- if (hasPostUpdatePolicies) {
925
- beforeUpdateInfo = await this.loadBeforeUpdateEntities(mutationModel, node.where, proceed);
1020
+ if (needsPostUpdateCheck) {
1021
+ beforeUpdateInfo = await this.loadBeforeUpdateEntities(
1022
+ mutationModel,
1023
+ node.where,
1024
+ proceed,
1025
+ // force load pre-update entities if dialect doesn't support returning,
1026
+ // so we can rely on pre-update ids to read back updated entities
1027
+ !this.dialect.supportsReturning
1028
+ );
926
1029
  }
927
1030
  const result = await proceed(this.transformNode(node));
928
- if (hasPostUpdatePolicies && result.rows.length > 0) {
929
- if (beforeUpdateInfo) {
930
- invariant3(beforeUpdateInfo.rows.length === result.rows.length);
931
- const idFields = QueryUtils2.requireIdFields(this.client.$schema, mutationModel);
932
- for (const postRow of result.rows) {
933
- const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
934
- if (!beforeRow) {
935
- throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match by id. If you have post-update policies on a model, updating id fields is not supported.");
936
- }
937
- }
938
- }
939
- const idConditions = this.buildIdConditions(mutationModel, result.rows);
940
- const postUpdateFilter = this.buildPolicyFilter(mutationModel, void 0, "post-update");
941
- const eb = expressionBuilder2();
942
- const beforeUpdateTable = beforeUpdateInfo ? {
943
- kind: "SelectQueryNode",
944
- from: FromNode2.create([
945
- ParensNode2.create(ValuesNode.create(beforeUpdateInfo.rows.map((r) => PrimitiveValueListNode.create(beforeUpdateInfo.fields.map((f) => r[f])))))
946
- ]),
947
- selections: beforeUpdateInfo.fields.map((name, index) => {
948
- const def = QueryUtils2.requireField(this.client.$schema, mutationModel, name);
949
- const castedColumnRef = sql`CAST(${eb.ref(`column${index + 1}`)} as ${sql.raw(this.dialect.getFieldSqlType(def))})`.as(name);
950
- return SelectionNode2.create(castedColumnRef.toOperationNode());
951
- })
952
- } : void 0;
953
- const postUpdateQuery = eb.selectFrom(mutationModel).select(() => [
954
- eb(eb.fn("COUNT", [
955
- eb.lit(1)
956
- ]), "=", result.rows.length).as("$condition")
957
- ]).where(() => new ExpressionWrapper(conjunction(this.dialect, [
958
- idConditions,
959
- postUpdateFilter
960
- ]))).$if(!!beforeUpdateInfo, (qb) => qb.leftJoin(() => new ExpressionWrapper(beforeUpdateTable).as("$before"), (join) => {
961
- const idFields = QueryUtils2.requireIdFields(this.client.$schema, mutationModel);
962
- return idFields.reduce((acc, f) => acc.onRef(`${mutationModel}.${f}`, "=", `$before.${f}`), join);
963
- }));
964
- const postUpdateResult = await proceed(postUpdateQuery.toOperationNode());
965
- if (!postUpdateResult.rows[0]?.$condition) {
966
- throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS, "some or all updated rows failed to pass post-update policy check");
967
- }
1031
+ if ((result.numAffectedRows ?? 0) > 0 && needsPostUpdateCheck) {
1032
+ await this.postUpdateCheck(mutationModel, beforeUpdateInfo, result, proceed);
968
1033
  }
969
1034
  if (!node.returning || this.onlyReturningId(node)) {
970
1035
  return this.postProcessMutationResult(result, node);
@@ -1003,12 +1068,69 @@ var PolicyHandler = class extends OperationNodeTransformer {
1003
1068
  modelLevelFilter,
1004
1069
  node.where?.where ?? trueNode(this.dialect)
1005
1070
  ]);
1006
- const preUpdateCheckQuery = expressionBuilder2().selectFrom(mutationModel).select((eb) => eb.fn.coalesce(eb.fn.sum(eb.cast(new ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter)), "integer")), eb.lit(0)).as("$filteredCount")).where(() => new ExpressionWrapper(updateFilter));
1071
+ const preUpdateCheckQuery = this.eb.selectFrom(mutationModel).select((eb) => eb.fn.coalesce(eb.fn.sum(this.dialect.castInt(new ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter)))), eb.lit(0)).as("$filteredCount")).where(() => new ExpressionWrapper(updateFilter));
1007
1072
  const preUpdateResult = await proceed(preUpdateCheckQuery.toOperationNode());
1008
1073
  if (preUpdateResult.rows[0].$filteredCount > 0) {
1009
1074
  throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS, "some rows cannot be updated due to field policies");
1010
1075
  }
1011
1076
  }
1077
+ async postUpdateCheck(model, beforeUpdateInfo, updateResult, proceed) {
1078
+ let postUpdateRows;
1079
+ if (this.dialect.supportsReturning) {
1080
+ postUpdateRows = updateResult.rows;
1081
+ } else {
1082
+ invariant3(beforeUpdateInfo, "beforeUpdateInfo must be defined for dialects not supporting returning");
1083
+ const idConditions2 = this.buildIdConditions(model, beforeUpdateInfo.rows);
1084
+ const idFields = QueryUtils2.requireIdFields(this.client.$schema, model);
1085
+ const postUpdateQuery2 = {
1086
+ kind: "SelectQueryNode",
1087
+ from: FromNode2.create([
1088
+ TableNode3.create(model)
1089
+ ]),
1090
+ where: WhereNode2.create(idConditions2),
1091
+ selections: idFields.map((field) => SelectionNode2.create(ColumnNode3.create(field)))
1092
+ };
1093
+ const postUpdateQueryResult = await proceed(postUpdateQuery2);
1094
+ postUpdateRows = postUpdateQueryResult.rows;
1095
+ }
1096
+ if (beforeUpdateInfo) {
1097
+ if (beforeUpdateInfo.rows.length !== postUpdateRows.length) {
1098
+ throw createRejectedByPolicyError(model, RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.");
1099
+ }
1100
+ const idFields = QueryUtils2.requireIdFields(this.client.$schema, model);
1101
+ for (const postRow of postUpdateRows) {
1102
+ const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
1103
+ if (!beforeRow) {
1104
+ throw createRejectedByPolicyError(model, RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.");
1105
+ }
1106
+ }
1107
+ }
1108
+ const idConditions = this.buildIdConditions(model, postUpdateRows);
1109
+ const postUpdateFilter = this.buildPolicyFilter(model, void 0, "post-update");
1110
+ const eb = expressionBuilder2();
1111
+ const needsBeforeUpdateJoin = !!beforeUpdateInfo?.fields;
1112
+ let beforeUpdateTable = void 0;
1113
+ if (needsBeforeUpdateJoin) {
1114
+ const fieldDefs = beforeUpdateInfo.fields.map((name) => QueryUtils2.requireField(this.client.$schema, model, name));
1115
+ const rows = beforeUpdateInfo.rows.map((r) => beforeUpdateInfo.fields.map((f) => r[f]));
1116
+ beforeUpdateTable = this.dialect.buildValuesTableSelect(fieldDefs, rows).toOperationNode();
1117
+ }
1118
+ const postUpdateQuery = eb.selectFrom(model).select(() => [
1119
+ eb(eb.fn("COUNT", [
1120
+ eb.lit(1)
1121
+ ]), "=", Number(updateResult.numAffectedRows ?? 0)).as("$condition")
1122
+ ]).where(() => new ExpressionWrapper(conjunction(this.dialect, [
1123
+ idConditions,
1124
+ postUpdateFilter
1125
+ ]))).$if(needsBeforeUpdateJoin, (qb) => qb.leftJoin(() => new ExpressionWrapper(beforeUpdateTable).as("$before"), (join) => {
1126
+ const idFields = QueryUtils2.requireIdFields(this.client.$schema, model);
1127
+ return idFields.reduce((acc, f) => acc.onRef(`${model}.${f}`, "=", `$before.${f}`), join);
1128
+ }));
1129
+ const postUpdateResult = await proceed(postUpdateQuery.toOperationNode());
1130
+ if (!postUpdateResult.rows[0]?.$condition) {
1131
+ throw createRejectedByPolicyError(model, RejectedByPolicyReason.NO_ACCESS, "some or all updated rows failed to pass post-update policy check");
1132
+ }
1133
+ }
1012
1134
  // #endregion
1013
1135
  // #region Transformations
1014
1136
  transformSelectQuery(node) {
@@ -1076,6 +1198,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
1076
1198
  };
1077
1199
  }
1078
1200
  transformInsertQuery(node) {
1201
+ let processedNode = node;
1079
1202
  let onConflict = node.onConflict;
1080
1203
  if (onConflict?.updates) {
1081
1204
  const { mutationModel, alias } = this.getMutationModel(node);
@@ -1094,11 +1217,36 @@ var PolicyHandler = class extends OperationNodeTransformer {
1094
1217
  updateWhere: WhereNode2.create(filter)
1095
1218
  };
1096
1219
  }
1220
+ processedNode = {
1221
+ ...node,
1222
+ onConflict
1223
+ };
1224
+ }
1225
+ let onDuplicateKey = node.onDuplicateKey;
1226
+ if (onDuplicateKey?.updates) {
1227
+ const { mutationModel } = this.getMutationModel(node);
1228
+ const filterWithTableRef = this.buildPolicyFilter(mutationModel, void 0, "update");
1229
+ const filter = this.stripTableReferences(filterWithTableRef, mutationModel);
1230
+ const wrappedUpdates = onDuplicateKey.updates.map((update) => {
1231
+ const columnName = ColumnNode3.is(update.column) ? update.column.column.name : void 0;
1232
+ if (!columnName) {
1233
+ return update;
1234
+ }
1235
+ const wrappedValue = sql`IF(${new ExpressionWrapper(filter)}, ${new ExpressionWrapper(update.value)}, ${sql.ref(columnName)})`.toOperationNode();
1236
+ return {
1237
+ ...update,
1238
+ value: wrappedValue
1239
+ };
1240
+ });
1241
+ onDuplicateKey = {
1242
+ ...onDuplicateKey,
1243
+ updates: wrappedUpdates
1244
+ };
1245
+ processedNode = {
1246
+ ...processedNode,
1247
+ onDuplicateKey
1248
+ };
1097
1249
  }
1098
- const processedNode = onConflict ? {
1099
- ...node,
1100
- onConflict
1101
- } : node;
1102
1250
  const result = super.transformInsertQuery(processedNode);
1103
1251
  let returning = result.returning;
1104
1252
  if (returning) {
@@ -1126,7 +1274,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
1126
1274
  }
1127
1275
  }
1128
1276
  let returning = result.returning;
1129
- if (returning || this.hasPostUpdatePolicies(mutationModel)) {
1277
+ if (this.dialect.supportsReturning && (returning || this.hasPostUpdatePolicies(mutationModel))) {
1130
1278
  const idFields = QueryUtils2.requireIdFields(this.client.$schema, mutationModel);
1131
1279
  returning = ReturningNode.create(idFields.map((f) => SelectionNode2.create(ColumnNode3.create(f))));
1132
1280
  }
@@ -1163,9 +1311,9 @@ var PolicyHandler = class extends OperationNodeTransformer {
1163
1311
  }
1164
1312
  // #endregion
1165
1313
  // #region post-update
1166
- async loadBeforeUpdateEntities(model, where, proceed) {
1314
+ async loadBeforeUpdateEntities(model, where, proceed, forceLoad = false) {
1167
1315
  const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model);
1168
- if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) {
1316
+ if (!forceLoad && (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0)) {
1169
1317
  return void 0;
1170
1318
  }
1171
1319
  const policyFilter = this.buildPolicyFilter(model, model, "update");
@@ -1173,15 +1321,14 @@ var PolicyHandler = class extends OperationNodeTransformer {
1173
1321
  where.where,
1174
1322
  policyFilter
1175
1323
  ]) : policyFilter;
1324
+ const selections = beforeUpdateAccessFields ?? QueryUtils2.requireIdFields(this.client.$schema, model);
1176
1325
  const query = {
1177
1326
  kind: "SelectQueryNode",
1178
1327
  from: FromNode2.create([
1179
1328
  TableNode3.create(model)
1180
1329
  ]),
1181
1330
  where: WhereNode2.create(combinedFilter),
1182
- selections: [
1183
- ...beforeUpdateAccessFields.map((f) => SelectionNode2.create(ColumnNode3.create(f)))
1184
- ]
1331
+ selections: selections.map((f) => SelectionNode2.create(ColumnNode3.create(f)))
1185
1332
  };
1186
1333
  const result = await proceed(query);
1187
1334
  return {
@@ -1361,43 +1508,24 @@ var PolicyHandler = class extends OperationNodeTransformer {
1361
1508
  }
1362
1509
  }
1363
1510
  async enforcePreCreatePolicyForOne(model, fields, values, proceed) {
1364
- const allFields = Object.entries(QueryUtils2.requireModel(this.client.$schema, model).fields).filter(([, def]) => !def.relation);
1511
+ const allFields = QueryUtils2.getModelFields(this.client.$schema, model, {
1512
+ inherited: true
1513
+ });
1365
1514
  const allValues = [];
1366
- for (const [name, _def] of allFields) {
1367
- const index = fields.indexOf(name);
1515
+ for (const def of allFields) {
1516
+ const index = fields.indexOf(def.name);
1368
1517
  if (index >= 0) {
1369
- allValues.push(values[index]);
1518
+ allValues.push(new ExpressionWrapper(values[index]));
1370
1519
  } else {
1371
- allValues.push(ValueNode3.createImmediate(null));
1520
+ allValues.push(this.eb.lit(null));
1372
1521
  }
1373
1522
  }
1374
- const eb = expressionBuilder2();
1375
- const constTable = {
1376
- kind: "SelectQueryNode",
1377
- from: FromNode2.create([
1378
- AliasNode3.create(ParensNode2.create(ValuesNode.create([
1379
- ValueListNode2.create(allValues)
1380
- ])), IdentifierNode2.create("$t"))
1381
- ]),
1382
- selections: allFields.map(([name, def], index) => {
1383
- const castedColumnRef = sql`CAST(${eb.ref(`column${index + 1}`)} as ${sql.raw(this.dialect.getFieldSqlType(def))})`.as(name);
1384
- return SelectionNode2.create(castedColumnRef.toOperationNode());
1385
- })
1386
- };
1523
+ const valuesTable = this.dialect.buildValuesTableSelect(allFields, [
1524
+ allValues
1525
+ ]);
1387
1526
  const filter = this.buildPolicyFilter(model, void 0, "create");
1388
- const preCreateCheck = {
1389
- kind: "SelectQueryNode",
1390
- from: FromNode2.create([
1391
- AliasNode3.create(constTable, IdentifierNode2.create(model))
1392
- ]),
1393
- selections: [
1394
- SelectionNode2.create(AliasNode3.create(BinaryOperationNode3.create(FunctionNode3.create("COUNT", [
1395
- ValueNode3.createImmediate(1)
1396
- ]), OperatorNode3.create(">"), ValueNode3.createImmediate(0)), IdentifierNode2.create("$condition")))
1397
- ],
1398
- where: WhereNode2.create(filter)
1399
- };
1400
- const result = await proceed(preCreateCheck);
1527
+ const preCreateCheck = this.eb.selectFrom(valuesTable.as(model)).select(this.eb(this.eb.fn.count(this.eb.lit(1)), ">", 0).as("$condition")).where(() => new ExpressionWrapper(filter));
1528
+ const result = await proceed(preCreateCheck.toOperationNode());
1401
1529
  if (!result.rows[0]?.$condition) {
1402
1530
  throw createRejectedByPolicyError(model, RejectedByPolicyReason.NO_ACCESS);
1403
1531
  }
@@ -1422,18 +1550,18 @@ var PolicyHandler = class extends OperationNodeTransformer {
1422
1550
  const fieldDef = QueryUtils2.requireField(this.client.$schema, model, fields[i]);
1423
1551
  invariant3(item.kind === "ValueNode", "expecting a ValueNode");
1424
1552
  result.push({
1425
- node: ValueNode3.create(this.dialect.transformPrimitive(item.value, fieldDef.type, !!fieldDef.array)),
1553
+ node: ValueNode3.create(this.dialect.transformInput(item.value, fieldDef.type, !!fieldDef.array)),
1426
1554
  raw: item.value
1427
1555
  });
1428
1556
  } else {
1429
1557
  let value = item;
1430
1558
  if (!isImplicitManyToManyJoinTable) {
1431
1559
  const fieldDef = QueryUtils2.requireField(this.client.$schema, model, fields[i]);
1432
- value = this.dialect.transformPrimitive(item, fieldDef.type, !!fieldDef.array);
1560
+ value = this.dialect.transformInput(item, fieldDef.type, !!fieldDef.array);
1433
1561
  }
1434
1562
  if (Array.isArray(value)) {
1435
1563
  result.push({
1436
- node: RawNode.createWithSql(this.dialect.buildArrayLiteralSQL(value)),
1564
+ node: this.dialect.buildArrayLiteralSQL(value).toOperationNode(),
1437
1565
  raw: value
1438
1566
  });
1439
1567
  } else {
@@ -1681,11 +1809,10 @@ var PolicyHandler = class extends OperationNodeTransformer {
1681
1809
  return void 0;
1682
1810
  }
1683
1811
  const checkForOperation = operation === "read" ? "read" : "update";
1684
- const eb = expressionBuilder2();
1685
1812
  const joinTable = alias ?? tableName;
1686
- const aQuery = eb.selectFrom(m2m.firstModel).whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, "=", `${joinTable}.A`).select(() => new ExpressionWrapper(this.buildPolicyFilter(m2m.firstModel, void 0, checkForOperation)).as("$conditionA"));
1687
- const bQuery = eb.selectFrom(m2m.secondModel).whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, "=", `${joinTable}.B`).select(() => new ExpressionWrapper(this.buildPolicyFilter(m2m.secondModel, void 0, checkForOperation)).as("$conditionB"));
1688
- return eb.and([
1813
+ const aQuery = this.eb.selectFrom(m2m.firstModel).whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, "=", `${joinTable}.A`).select(() => new ExpressionWrapper(this.buildPolicyFilter(m2m.firstModel, void 0, checkForOperation)).as("$conditionA"));
1814
+ const bQuery = this.eb.selectFrom(m2m.secondModel).whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, "=", `${joinTable}.B`).select(() => new ExpressionWrapper(this.buildPolicyFilter(m2m.secondModel, void 0, checkForOperation)).as("$conditionB"));
1815
+ return this.eb.and([
1689
1816
  aQuery,
1690
1817
  bQuery
1691
1818
  ]).toOperationNode();
@@ -1716,6 +1843,26 @@ var PolicyHandler = class extends OperationNodeTransformer {
1716
1843
  };
1717
1844
  }
1718
1845
  }
1846
+ // strips table references from an OperationNode
1847
+ stripTableReferences(node, modelName) {
1848
+ return new TableReferenceStripper().strip(node, modelName);
1849
+ }
1850
+ };
1851
+ var TableReferenceStripper = class TableReferenceStripper2 extends OperationNodeTransformer {
1852
+ static {
1853
+ __name(this, "TableReferenceStripper");
1854
+ }
1855
+ tableName = "";
1856
+ strip(node, tableName) {
1857
+ this.tableName = tableName;
1858
+ return this.transformNode(node);
1859
+ }
1860
+ transformReference(node) {
1861
+ if (ColumnNode3.is(node.column) && node.table?.table.identifier.name === this.tableName) {
1862
+ return ReferenceNode3.create(this.transformNode(node.column));
1863
+ }
1864
+ return super.transformReference(node);
1865
+ }
1719
1866
  };
1720
1867
 
1721
1868
  // src/functions.ts
@@ -1764,7 +1911,7 @@ var check = /* @__PURE__ */ __name((eb, args, { client, model, modelAlias, opera
1764
1911
  const policyHandler = new PolicyHandler(client);
1765
1912
  const op = arg2Node ? arg2Node.value : operation;
1766
1913
  const policyCondition = policyHandler.buildPolicyFilter(relationModel, void 0, op);
1767
- const result = eb.selectFrom(relationModel).where(joinCondition).select(new ExpressionWrapper2(policyCondition).as("$condition"));
1914
+ const result = eb.selectFrom(eb.selectFrom(relationModel).where(joinCondition).select(new ExpressionWrapper2(policyCondition).as("$condition")).as("$sub")).selectAll();
1768
1915
  return result;
1769
1916
  }, "check");
1770
1917