@zenstackhq/runtime 3.0.0-beta.7 → 3.0.0-beta.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -11,7 +11,7 @@ import { CompiledQuery, DefaultConnectionProvider, DefaultQueryExecutor as Defau
11
11
 
12
12
  // src/client/crud/operations/aggregate.ts
13
13
  import { sql as sql6 } from "kysely";
14
- import { match as match10 } from "ts-pattern";
14
+ import { match as match11 } from "ts-pattern";
15
15
 
16
16
  // src/client/query-utils.ts
17
17
  import { invariant } from "@zenstackhq/common-helpers";
@@ -116,7 +116,12 @@ function fieldsToSelectObject(fields) {
116
116
  __name(fieldsToSelectObject, "fieldsToSelectObject");
117
117
 
118
118
  // src/client/errors.ts
119
- var InputValidationError = class extends Error {
119
+ var ZenStackError = class extends Error {
120
+ static {
121
+ __name(this, "ZenStackError");
122
+ }
123
+ };
124
+ var InputValidationError = class extends ZenStackError {
120
125
  static {
121
126
  __name(this, "InputValidationError");
122
127
  }
@@ -126,7 +131,7 @@ var InputValidationError = class extends Error {
126
131
  });
127
132
  }
128
133
  };
129
- var QueryError = class extends Error {
134
+ var QueryError = class extends ZenStackError {
130
135
  static {
131
136
  __name(this, "QueryError");
132
137
  }
@@ -136,12 +141,12 @@ var QueryError = class extends Error {
136
141
  });
137
142
  }
138
143
  };
139
- var InternalError = class extends Error {
144
+ var InternalError = class extends ZenStackError {
140
145
  static {
141
146
  __name(this, "InternalError");
142
147
  }
143
148
  };
144
- var NotFoundError = class extends Error {
149
+ var NotFoundError = class extends ZenStackError {
145
150
  static {
146
151
  __name(this, "NotFoundError");
147
152
  }
@@ -482,7 +487,7 @@ import { createId } from "@paralleldrive/cuid2";
482
487
  import { invariant as invariant9, isPlainObject as isPlainObject3 } from "@zenstackhq/common-helpers";
483
488
  import { expressionBuilder as expressionBuilder4, sql as sql5 } from "kysely";
484
489
  import { nanoid } from "nanoid";
485
- import { match as match9 } from "ts-pattern";
490
+ import { match as match10 } from "ts-pattern";
486
491
  import { ulid } from "ulid";
487
492
  import * as uuid from "uuid";
488
493
 
@@ -493,7 +498,7 @@ var RejectedByPolicyReason = /* @__PURE__ */ function(RejectedByPolicyReason2) {
493
498
  RejectedByPolicyReason2["OTHER"] = "other";
494
499
  return RejectedByPolicyReason2;
495
500
  }({});
496
- var RejectedByPolicyError = class extends Error {
501
+ var RejectedByPolicyError = class extends ZenStackError {
497
502
  static {
498
503
  __name(this, "RejectedByPolicyError");
499
504
  }
@@ -523,6 +528,10 @@ var CRUD = [
523
528
  "update",
524
529
  "delete"
525
530
  ];
531
+ var CRUD_EXT = [
532
+ ...CRUD,
533
+ "post-update"
534
+ ];
526
535
 
527
536
  // src/client/kysely-utils.ts
528
537
  import { AliasNode, ColumnNode, ReferenceNode, TableNode } from "kysely";
@@ -558,8 +567,8 @@ __name(extractFieldName, "extractFieldName");
558
567
 
559
568
  // src/plugins/policy/policy-handler.ts
560
569
  import { invariant as invariant7 } from "@zenstackhq/common-helpers";
561
- import { AliasNode as AliasNode4, BinaryOperationNode as BinaryOperationNode3, ColumnNode as ColumnNode3, DeleteQueryNode, expressionBuilder as expressionBuilder3, ExpressionWrapper, FromNode as FromNode2, FunctionNode as FunctionNode3, IdentifierNode as IdentifierNode2, InsertQueryNode, OperationNodeTransformer, OperatorNode as OperatorNode3, ParensNode as ParensNode2, PrimitiveValueListNode, RawNode, ReturningNode, SelectionNode as SelectionNode2, SelectQueryNode as SelectQueryNode2, sql as sql4, TableNode as TableNode4, UpdateQueryNode, ValueListNode as ValueListNode2, ValueNode as ValueNode3, ValuesNode, WhereNode as WhereNode2 } from "kysely";
562
- import { match as match8 } from "ts-pattern";
570
+ import { AliasNode as AliasNode4, BinaryOperationNode as BinaryOperationNode3, ColumnNode as ColumnNode3, DeleteQueryNode, expressionBuilder as expressionBuilder3, ExpressionWrapper, FromNode as FromNode2, FunctionNode as FunctionNode3, IdentifierNode as IdentifierNode2, InsertQueryNode, OperationNodeTransformer, OperatorNode as OperatorNode3, ParensNode as ParensNode2, PrimitiveValueListNode, RawNode, ReferenceNode as ReferenceNode4, ReturningNode, SelectAllNode, SelectionNode as SelectionNode2, SelectQueryNode as SelectQueryNode2, sql as sql4, TableNode as TableNode4, UpdateQueryNode, ValueListNode as ValueListNode2, ValueNode as ValueNode3, ValuesNode, WhereNode as WhereNode2 } from "kysely";
571
+ import { match as match9 } from "ts-pattern";
563
572
 
564
573
  // src/client/crud/dialects/index.ts
565
574
  import { match as match5 } from "ts-pattern";
@@ -1706,6 +1715,59 @@ function getCrudDialect(schema, options) {
1706
1715
  }
1707
1716
  __name(getCrudDialect, "getCrudDialect");
1708
1717
 
1718
+ // src/utils/expression-utils.ts
1719
+ import { match as match6 } from "ts-pattern";
1720
+ var ExpressionVisitor = class {
1721
+ static {
1722
+ __name(this, "ExpressionVisitor");
1723
+ }
1724
+ visit(expr2) {
1725
+ match6(expr2).with({
1726
+ kind: "literal"
1727
+ }, (e) => this.visitLiteral(e)).with({
1728
+ kind: "array"
1729
+ }, (e) => this.visitArray(e)).with({
1730
+ kind: "field"
1731
+ }, (e) => this.visitField(e)).with({
1732
+ kind: "member"
1733
+ }, (e) => this.visitMember(e)).with({
1734
+ kind: "binary"
1735
+ }, (e) => this.visitBinary(e)).with({
1736
+ kind: "unary"
1737
+ }, (e) => this.visitUnary(e)).with({
1738
+ kind: "call"
1739
+ }, (e) => this.visitCall(e)).with({
1740
+ kind: "this"
1741
+ }, (e) => this.visitThis(e)).with({
1742
+ kind: "null"
1743
+ }, (e) => this.visitNull(e)).exhaustive();
1744
+ }
1745
+ visitLiteral(_e) {
1746
+ }
1747
+ visitArray(e) {
1748
+ e.items.forEach((item) => this.visit(item));
1749
+ }
1750
+ visitField(_e) {
1751
+ }
1752
+ visitMember(e) {
1753
+ this.visit(e.receiver);
1754
+ }
1755
+ visitBinary(e) {
1756
+ this.visit(e.left);
1757
+ this.visit(e.right);
1758
+ }
1759
+ visitUnary(e) {
1760
+ this.visit(e.operand);
1761
+ }
1762
+ visitCall(e) {
1763
+ e.args?.forEach((arg) => this.visit(arg));
1764
+ }
1765
+ visitThis(_e) {
1766
+ }
1767
+ visitNull(_e) {
1768
+ }
1769
+ };
1770
+
1709
1771
  // src/utils/default-operation-node-visitor.ts
1710
1772
  import { OperationNodeVisitor } from "kysely";
1711
1773
  var DefaultOperationNodeVisitor = class extends OperationNodeVisitor {
@@ -2027,17 +2089,17 @@ var ColumnCollector = class extends DefaultOperationNodeVisitor {
2027
2089
  // src/plugins/policy/expression-transformer.ts
2028
2090
  import { invariant as invariant6 } from "@zenstackhq/common-helpers";
2029
2091
  import { AliasNode as AliasNode3, BinaryOperationNode as BinaryOperationNode2, ColumnNode as ColumnNode2, expressionBuilder as expressionBuilder2, FromNode, FunctionNode as FunctionNode2, IdentifierNode, OperatorNode as OperatorNode2, ReferenceNode as ReferenceNode3, SelectionNode, SelectQueryNode, TableNode as TableNode3, ValueListNode, ValueNode as ValueNode2, WhereNode } from "kysely";
2030
- import { match as match7 } from "ts-pattern";
2092
+ import { match as match8 } from "ts-pattern";
2031
2093
 
2032
2094
  // src/plugins/policy/expression-evaluator.ts
2033
2095
  import { invariant as invariant5 } from "@zenstackhq/common-helpers";
2034
- import { match as match6 } from "ts-pattern";
2096
+ import { match as match7 } from "ts-pattern";
2035
2097
  var ExpressionEvaluator = class {
2036
2098
  static {
2037
2099
  __name(this, "ExpressionEvaluator");
2038
2100
  }
2039
2101
  evaluate(expression, context) {
2040
- const result = match6(expression).when(ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(ExpressionUtils.isThis, () => context.thisValue).when(ExpressionUtils.isNull, () => null).exhaustive();
2102
+ const result = match7(expression).when(ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(ExpressionUtils.isThis, () => context.thisValue).when(ExpressionUtils.isNull, () => null).exhaustive();
2041
2103
  return result ?? null;
2042
2104
  }
2043
2105
  evaluateCall(expr2, context) {
@@ -2048,7 +2110,7 @@ var ExpressionEvaluator = class {
2048
2110
  }
2049
2111
  }
2050
2112
  evaluateUnary(expr2, context) {
2051
- return match6(expr2.op).with("!", () => !this.evaluate(expr2.operand, context)).exhaustive();
2113
+ return match7(expr2.op).with("!", () => !this.evaluate(expr2.operand, context)).exhaustive();
2052
2114
  }
2053
2115
  evaluateMember(expr2, context) {
2054
2116
  let val = this.evaluate(expr2.receiver, context);
@@ -2072,7 +2134,7 @@ var ExpressionEvaluator = class {
2072
2134
  }
2073
2135
  const left = this.evaluate(expr2.left, context);
2074
2136
  const right = this.evaluate(expr2.right, context);
2075
- return match6(expr2.op).with("==", () => left === right).with("!=", () => left !== right).with(">", () => left > right).with(">=", () => left >= right).with("<", () => left < right).with("<=", () => left <= right).with("&&", () => left && right).with("||", () => left || right).with("in", () => {
2137
+ return match7(expr2.op).with("==", () => left === right).with("!=", () => left !== right).with(">", () => left > right).with(">=", () => left >= right).with("<", () => left < right).with("<=", () => left <= right).with("&&", () => left && right).with("||", () => left || right).with("in", () => {
2076
2138
  const _right = right ?? [];
2077
2139
  invariant5(Array.isArray(_right), 'expected array for "in" operator');
2078
2140
  return _right.includes(left);
@@ -2086,7 +2148,7 @@ var ExpressionEvaluator = class {
2086
2148
  return false;
2087
2149
  }
2088
2150
  invariant5(Array.isArray(left), "expected array");
2089
- return match6(op).with("?", () => left.some((item) => this.evaluate(expr2.right, {
2151
+ return match7(op).with("?", () => left.some((item) => this.evaluate(expr2.right, {
2090
2152
  ...context,
2091
2153
  thisValue: item
2092
2154
  }))).with("!", () => left.every((item) => this.evaluate(expr2.right, {
@@ -2196,6 +2258,10 @@ function getTableName(node) {
2196
2258
  return void 0;
2197
2259
  }
2198
2260
  __name(getTableName, "getTableName");
2261
+ function isBeforeInvocation(expr2) {
2262
+ return ExpressionUtils.isCall(expr2) && expr2.function === "before";
2263
+ }
2264
+ __name(isBeforeInvocation, "isBeforeInvocation");
2199
2265
 
2200
2266
  // src/plugins/policy/expression-transformer.ts
2201
2267
  function _ts_decorate(decorators, target, key, desc) {
@@ -2399,7 +2465,7 @@ var ExpressionTransformer = class {
2399
2465
  const count = FunctionNode2.create("count", [
2400
2466
  ValueNode2.createImmediate(1)
2401
2467
  ]);
2402
- const predicateResult = match7(expr2.op).with("?", () => BinaryOperationNode2.create(count, OperatorNode2.create(">"), ValueNode2.createImmediate(0))).with("!", () => BinaryOperationNode2.create(count, OperatorNode2.create("="), ValueNode2.createImmediate(0))).with("^", () => BinaryOperationNode2.create(count, OperatorNode2.create("="), ValueNode2.createImmediate(0))).exhaustive();
2468
+ const predicateResult = match8(expr2.op).with("?", () => BinaryOperationNode2.create(count, OperatorNode2.create(">"), ValueNode2.createImmediate(0))).with("!", () => BinaryOperationNode2.create(count, OperatorNode2.create("="), ValueNode2.createImmediate(0))).with("^", () => BinaryOperationNode2.create(count, OperatorNode2.create("="), ValueNode2.createImmediate(0))).exhaustive();
2403
2469
  return this.transform(expr2.left, {
2404
2470
  ...context,
2405
2471
  memberSelect: SelectionNode.create(AliasNode3.create(predicateResult, IdentifierNode.create("$t"))),
@@ -2464,7 +2530,7 @@ var ExpressionTransformer = class {
2464
2530
  return logicalNot(this.dialect, this.transform(expr2.operand, context));
2465
2531
  }
2466
2532
  transformOperator(op) {
2467
- const mappedOp = match7(op).with("==", () => "=").otherwise(() => op);
2533
+ const mappedOp = match8(op).with("==", () => "=").otherwise(() => op);
2468
2534
  return OperatorNode2.create(mappedOp);
2469
2535
  }
2470
2536
  _call(expr2, context) {
@@ -2508,7 +2574,7 @@ var ExpressionTransformer = class {
2508
2574
  return this.transformCall(arg, context);
2509
2575
  }
2510
2576
  if (this.isAuthMember(arg)) {
2511
- const valNode = this.valueMemberAccess(context.auth, arg, this.authType);
2577
+ const valNode = this.valueMemberAccess(this.auth, arg, this.authType);
2512
2578
  return valNode ? eb.val(valNode.value) : eb.val(null);
2513
2579
  }
2514
2580
  throw new InternalError(`Unsupported argument expression: ${arg.kind}`);
@@ -2517,6 +2583,11 @@ var ExpressionTransformer = class {
2517
2583
  if (this.isAuthCall(expr2.receiver)) {
2518
2584
  return this.valueMemberAccess(this.auth, expr2, this.authType);
2519
2585
  }
2586
+ if (isBeforeInvocation(expr2.receiver)) {
2587
+ invariant6(context.operation === "post-update", "before() can only be used in post-update policy");
2588
+ invariant6(expr2.members.length === 1, "before() can only be followed by a scalar field access");
2589
+ return ReferenceNode3.create(ColumnNode2.create(expr2.members[0]), TableNode3.create("$before"));
2590
+ }
2520
2591
  invariant6(ExpressionUtils.isField(expr2.receiver) || ExpressionUtils.isThis(expr2.receiver), 'expect receiver to be field expression or "this"');
2521
2592
  let members = expr2.members;
2522
2593
  let receiver;
@@ -2561,7 +2632,6 @@ var ExpressionTransformer = class {
2561
2632
  alias: void 0
2562
2633
  });
2563
2634
  if (currNode) {
2564
- invariant6(SelectQueryNode.is(currNode), "expected select query node");
2565
2635
  currNode = {
2566
2636
  ...relation,
2567
2637
  selections: [
@@ -2772,9 +2842,45 @@ var PolicyHandler = class extends OperationNodeTransformer {
2772
2842
  await this.enforcePreCreatePolicy(node, mutationModel, isManyToManyJoinTable, proceed);
2773
2843
  }
2774
2844
  }
2845
+ const hasPostUpdatePolicies = UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
2846
+ let beforeUpdateInfo;
2847
+ if (hasPostUpdatePolicies) {
2848
+ beforeUpdateInfo = await this.loadBeforeUpdateEntities(mutationModel, node.where, proceed);
2849
+ }
2775
2850
  const result = await proceed(this.transformNode(node));
2851
+ if (hasPostUpdatePolicies && result.rows.length > 0) {
2852
+ const idConditions = this.buildIdConditions(mutationModel, result.rows);
2853
+ const postUpdateFilter = this.buildPolicyFilter(mutationModel, void 0, "post-update");
2854
+ const eb = expressionBuilder3();
2855
+ const beforeUpdateTable = beforeUpdateInfo ? {
2856
+ kind: "SelectQueryNode",
2857
+ from: FromNode2.create([
2858
+ ParensNode2.create(ValuesNode.create(beforeUpdateInfo.rows.map((r) => PrimitiveValueListNode.create(beforeUpdateInfo.fields.map((f) => r[f])))))
2859
+ ]),
2860
+ selections: beforeUpdateInfo.fields.map((name, index) => {
2861
+ const def = requireField(this.client.$schema, mutationModel, name);
2862
+ const castedColumnRef = sql4`CAST(${eb.ref(`column${index + 1}`)} as ${sql4.raw(this.dialect.getFieldSqlType(def))})`.as(name);
2863
+ return SelectionNode2.create(castedColumnRef.toOperationNode());
2864
+ })
2865
+ } : void 0;
2866
+ const postUpdateQuery = eb.selectFrom(mutationModel).select(() => [
2867
+ eb(eb.fn("COUNT", [
2868
+ eb.lit(1)
2869
+ ]), "=", result.rows.length).as("$condition")
2870
+ ]).where(() => new ExpressionWrapper(conjunction(this.dialect, [
2871
+ idConditions,
2872
+ postUpdateFilter
2873
+ ]))).$if(!!beforeUpdateInfo, (qb) => qb.leftJoin(() => new ExpressionWrapper(beforeUpdateTable).as("$before"), (join) => {
2874
+ const idFields = requireIdFields(this.client.$schema, mutationModel);
2875
+ return idFields.reduce((acc, f) => acc.onRef(`${mutationModel}.${f}`, "=", `$before.${f}`), join);
2876
+ }));
2877
+ const postUpdateResult = await proceed(postUpdateQuery.toOperationNode());
2878
+ if (!postUpdateResult.rows[0]?.$condition) {
2879
+ throw new RejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS, "some or all updated rows failed to pass post-update policy check");
2880
+ }
2881
+ }
2776
2882
  if (!node.returning || this.onlyReturningId(node)) {
2777
- return result;
2883
+ return this.postProcessMutationResult(result, node);
2778
2884
  } else {
2779
2885
  const readBackResult = await this.processReadBack(node, result, proceed);
2780
2886
  if (readBackResult.rows.length !== result.rows.length) {
@@ -2783,6 +2889,68 @@ var PolicyHandler = class extends OperationNodeTransformer {
2783
2889
  return readBackResult;
2784
2890
  }
2785
2891
  }
2892
+ // correction to kysely mutation result may be needed because we might have added
2893
+ // returning clause to the query and caused changes to the result shape
2894
+ postProcessMutationResult(result, node) {
2895
+ if (node.returning) {
2896
+ return result;
2897
+ } else {
2898
+ return {
2899
+ ...result,
2900
+ rows: [],
2901
+ numAffectedRows: result.numAffectedRows ?? BigInt(result.rows.length)
2902
+ };
2903
+ }
2904
+ }
2905
+ hasPostUpdatePolicies(model) {
2906
+ const policies = this.getModelPolicies(model, "post-update");
2907
+ return policies.length > 0;
2908
+ }
2909
+ async loadBeforeUpdateEntities(model, where, proceed) {
2910
+ const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model);
2911
+ if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) {
2912
+ return void 0;
2913
+ }
2914
+ const query = {
2915
+ kind: "SelectQueryNode",
2916
+ from: FromNode2.create([
2917
+ TableNode4.create(model)
2918
+ ]),
2919
+ where,
2920
+ selections: [
2921
+ ...beforeUpdateAccessFields.map((f) => SelectionNode2.create(ColumnNode3.create(f)))
2922
+ ]
2923
+ };
2924
+ const result = await proceed(query);
2925
+ return {
2926
+ fields: beforeUpdateAccessFields,
2927
+ rows: result.rows
2928
+ };
2929
+ }
2930
+ getFieldsAccessForBeforeUpdatePolicies(model) {
2931
+ const policies = this.getModelPolicies(model, "post-update");
2932
+ if (policies.length === 0) {
2933
+ return void 0;
2934
+ }
2935
+ const fields = /* @__PURE__ */ new Set();
2936
+ const fieldCollector = new class extends ExpressionVisitor {
2937
+ visitMember(e) {
2938
+ if (isBeforeInvocation(e.receiver)) {
2939
+ invariant7(e.members.length === 1, "before() can only be followed by a scalar field access");
2940
+ fields.add(e.members[0]);
2941
+ }
2942
+ super.visitMember(e);
2943
+ }
2944
+ }();
2945
+ for (const policy of policies) {
2946
+ fieldCollector.visit(policy.condition);
2947
+ }
2948
+ if (fields.size === 0) {
2949
+ return void 0;
2950
+ }
2951
+ requireIdFields(this.client.$schema, model).forEach((f) => fields.add(f));
2952
+ return Array.from(fields).sort();
2953
+ }
2786
2954
  // #region overrides
2787
2955
  transformSelectQuery(node) {
2788
2956
  let whereNode = this.transformNode(node.where);
@@ -2848,19 +3016,16 @@ var PolicyHandler = class extends OperationNodeTransformer {
2848
3016
  onConflict
2849
3017
  } : node;
2850
3018
  const result = super.transformInsertQuery(processedNode);
2851
- if (!node.returning) {
2852
- return result;
2853
- }
2854
- if (this.onlyReturningId(node)) {
2855
- return result;
2856
- } else {
3019
+ let returning = result.returning;
3020
+ if (returning) {
2857
3021
  const { mutationModel } = this.getMutationModel(node);
2858
3022
  const idFields = requireIdFields(this.client.$schema, mutationModel);
2859
- return {
2860
- ...result,
2861
- returning: ReturningNode.create(idFields.map((field) => SelectionNode2.create(ColumnNode3.create(field))))
2862
- };
3023
+ returning = ReturningNode.create(idFields.map((f) => SelectionNode2.create(ColumnNode3.create(f))));
2863
3024
  }
3025
+ return {
3026
+ ...result,
3027
+ returning
3028
+ };
2864
3029
  }
2865
3030
  transformUpdateQuery(node) {
2866
3031
  const result = super.transformUpdateQuery(node);
@@ -2875,12 +3040,18 @@ var PolicyHandler = class extends OperationNodeTransformer {
2875
3040
  ]);
2876
3041
  }
2877
3042
  }
3043
+ let returning = result.returning;
3044
+ if (returning || this.hasPostUpdatePolicies(mutationModel)) {
3045
+ const idFields = requireIdFields(this.client.$schema, mutationModel);
3046
+ returning = ReturningNode.create(idFields.map((f) => SelectionNode2.create(ColumnNode3.create(f))));
3047
+ }
2878
3048
  return {
2879
3049
  ...result,
2880
3050
  where: WhereNode2.create(result.where ? conjunction(this.dialect, [
2881
3051
  result.where.where,
2882
3052
  filter
2883
- ]) : filter)
3053
+ ]) : filter),
3054
+ returning
2884
3055
  };
2885
3056
  }
2886
3057
  transformDeleteQuery(node) {
@@ -2912,6 +3083,14 @@ var PolicyHandler = class extends OperationNodeTransformer {
2912
3083
  }
2913
3084
  const { mutationModel } = this.getMutationModel(node);
2914
3085
  const idFields = requireIdFields(this.client.$schema, mutationModel);
3086
+ if (node.returning.selections.some((s) => SelectAllNode.is(s.selection))) {
3087
+ const modelDef = requireModel(this.client.$schema, mutationModel);
3088
+ if (Object.keys(modelDef.fields).some((f) => !idFields.includes(f))) {
3089
+ return false;
3090
+ } else {
3091
+ return true;
3092
+ }
3093
+ }
2915
3094
  const collector = new ColumnCollector();
2916
3095
  const selectedColumns = collector.collect(node.returning);
2917
3096
  return selectedColumns.every((c) => idFields.includes(c));
@@ -3095,10 +3274,10 @@ var PolicyHandler = class extends OperationNodeTransformer {
3095
3274
  }
3096
3275
  buildIdConditions(table, rows) {
3097
3276
  const idFields = requireIdFields(this.client.$schema, table);
3098
- return disjunction(this.dialect, rows.map((row) => conjunction(this.dialect, idFields.map((field) => BinaryOperationNode3.create(ColumnNode3.create(field), OperatorNode3.create("="), ValueNode3.create(row[field]))))));
3277
+ return disjunction(this.dialect, rows.map((row) => conjunction(this.dialect, idFields.map((field) => BinaryOperationNode3.create(ReferenceNode4.create(ColumnNode3.create(field), TableNode4.create(table)), OperatorNode3.create("="), ValueNode3.create(row[field]))))));
3099
3278
  }
3100
3279
  getMutationModel(node) {
3101
- const r = match8(node).when(InsertQueryNode.is, (node2) => ({
3280
+ const r = match9(node).when(InsertQueryNode.is, (node2) => ({
3102
3281
  mutationModel: getTableName(node2.into),
3103
3282
  alias: void 0
3104
3283
  })).when(UpdateQueryNode.is, (node2) => {
@@ -3137,23 +3316,24 @@ var PolicyHandler = class extends OperationNodeTransformer {
3137
3316
  return m2mFilter;
3138
3317
  }
3139
3318
  const policies = this.getModelPolicies(model, operation);
3140
- if (policies.length === 0) {
3141
- return falseNode(this.dialect);
3142
- }
3143
3319
  const allows = policies.filter((policy) => policy.kind === "allow").map((policy) => this.compilePolicyCondition(model, alias, operation, policy));
3144
3320
  const denies = policies.filter((policy) => policy.kind === "deny").map((policy) => this.compilePolicyCondition(model, alias, operation, policy));
3145
3321
  let combinedPolicy;
3146
3322
  if (allows.length === 0) {
3147
- combinedPolicy = falseNode(this.dialect);
3323
+ if (operation === "post-update") {
3324
+ combinedPolicy = trueNode(this.dialect);
3325
+ } else {
3326
+ combinedPolicy = falseNode(this.dialect);
3327
+ }
3148
3328
  } else {
3149
3329
  combinedPolicy = disjunction(this.dialect, allows);
3150
- if (denies.length !== 0) {
3151
- const combinedDenies = conjunction(this.dialect, denies.map((d) => buildIsFalse(d, this.dialect)));
3152
- combinedPolicy = conjunction(this.dialect, [
3153
- combinedPolicy,
3154
- combinedDenies
3155
- ]);
3156
- }
3330
+ }
3331
+ if (denies.length !== 0) {
3332
+ const combinedDenies = conjunction(this.dialect, denies.map((d) => buildIsFalse(d, this.dialect)));
3333
+ combinedPolicy = conjunction(this.dialect, [
3334
+ combinedPolicy,
3335
+ combinedDenies
3336
+ ]);
3157
3337
  }
3158
3338
  return combinedPolicy;
3159
3339
  }
@@ -3200,8 +3380,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
3200
3380
  return new ExpressionTransformer(this.client).transform(policy.condition, {
3201
3381
  model,
3202
3382
  alias,
3203
- operation,
3204
- auth: this.client.$auth
3383
+ operation
3205
3384
  });
3206
3385
  }
3207
3386
  getModelPolicies(model, operation) {
@@ -3217,7 +3396,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
3217
3396
  kind: attr.name === "@@allow" ? "allow" : "deny",
3218
3397
  operations: extractOperations(attr.args[0].value),
3219
3398
  condition: attr.args[1].value
3220
- })).filter((policy) => policy.operations.includes("all") || policy.operations.includes(operation)));
3399
+ })).filter((policy) => operation !== "post-update" && policy.operations.includes("all") || policy.operations.includes(operation)));
3221
3400
  }
3222
3401
  return result;
3223
3402
  }
@@ -3318,18 +3497,9 @@ var PolicyPlugin = class {
3318
3497
  check
3319
3498
  };
3320
3499
  }
3321
- onKyselyQuery({
3322
- query,
3323
- client,
3324
- proceed
3325
- /*, transaction*/
3326
- }) {
3500
+ onKyselyQuery({ query, client, proceed }) {
3327
3501
  const handler = new PolicyHandler(client);
3328
- return handler.handle(
3329
- query,
3330
- proceed
3331
- /*, transaction*/
3332
- );
3502
+ return handler.handle(query, proceed);
3333
3503
  }
3334
3504
  };
3335
3505
 
@@ -3870,7 +4040,7 @@ var BaseOperationHandler = class {
3870
4040
  }
3871
4041
  evalGenerator(defaultValue) {
3872
4042
  if (ExpressionUtils.isCall(defaultValue)) {
3873
- return match9(defaultValue.function).with("cuid", () => createId()).with("uuid", () => defaultValue.args?.[0] && ExpressionUtils.isLiteral(defaultValue.args?.[0]) && defaultValue.args[0].value === 7 ? uuid.v7() : uuid.v4()).with("nanoid", () => defaultValue.args?.[0] && ExpressionUtils.isLiteral(defaultValue.args[0]) && typeof defaultValue.args[0].value === "number" ? nanoid(defaultValue.args[0].value) : nanoid()).with("ulid", () => ulid()).otherwise(() => void 0);
4043
+ return match10(defaultValue.function).with("cuid", () => createId()).with("uuid", () => defaultValue.args?.[0] && ExpressionUtils.isLiteral(defaultValue.args?.[0]) && defaultValue.args[0].value === 7 ? uuid.v7() : uuid.v4()).with("nanoid", () => defaultValue.args?.[0] && ExpressionUtils.isLiteral(defaultValue.args[0]) && typeof defaultValue.args[0].value === "number" ? nanoid(defaultValue.args[0].value) : nanoid()).with("ulid", () => ulid()).otherwise(() => void 0);
3874
4044
  } else if (ExpressionUtils.isMember(defaultValue) && ExpressionUtils.isCall(defaultValue.receiver) && defaultValue.receiver.function === "auth") {
3875
4045
  let val = this.client.$auth;
3876
4046
  for (const member of defaultValue.members) {
@@ -4052,7 +4222,7 @@ var BaseOperationHandler = class {
4052
4222
  const value = this.dialect.transformPrimitive(payload[key], fieldDef.type, false);
4053
4223
  const eb = expressionBuilder4();
4054
4224
  const fieldRef = this.dialect.fieldRef(model, field, eb);
4055
- return match9(key).with("set", () => value).with("increment", () => eb(fieldRef, "+", value)).with("decrement", () => eb(fieldRef, "-", value)).with("multiply", () => eb(fieldRef, "*", value)).with("divide", () => eb(fieldRef, "/", value)).otherwise(() => {
4225
+ return match10(key).with("set", () => value).with("increment", () => eb(fieldRef, "+", value)).with("decrement", () => eb(fieldRef, "-", value)).with("multiply", () => eb(fieldRef, "*", value)).with("divide", () => eb(fieldRef, "/", value)).otherwise(() => {
4056
4226
  throw new InternalError(`Invalid incremental update operation: ${key}`);
4057
4227
  });
4058
4228
  }
@@ -4062,7 +4232,7 @@ var BaseOperationHandler = class {
4062
4232
  const value = this.dialect.transformPrimitive(payload[key], fieldDef.type, true);
4063
4233
  const eb = expressionBuilder4();
4064
4234
  const fieldRef = this.dialect.fieldRef(model, field, eb);
4065
- return match9(key).with("set", () => value).with("push", () => {
4235
+ return match10(key).with("set", () => value).with("push", () => {
4066
4236
  return eb(fieldRef, "||", eb.val(ensureArray(value)));
4067
4237
  }).otherwise(() => {
4068
4238
  throw new InternalError(`Invalid array update operation: ${key}`);
@@ -4132,8 +4302,9 @@ var BaseOperationHandler = class {
4132
4302
  };
4133
4303
  } else {
4134
4304
  const idFields = requireIdFields(this.schema, model);
4135
- const result = await query.returning(idFields).execute();
4136
- return result;
4305
+ const finalQuery = query.returning(idFields);
4306
+ const result = await this.executeQuery(kysely, finalQuery, "update");
4307
+ return result.rows;
4137
4308
  }
4138
4309
  }
4139
4310
  async processBaseModelUpdateMany(kysely, model, where, updateFields, filterModel) {
@@ -4753,7 +4924,7 @@ var AggregateOperationHandler = class extends BaseOperationHandler {
4753
4924
  Object.entries(value).forEach(([field, val]) => {
4754
4925
  if (val === true) {
4755
4926
  query = query.select((eb) => {
4756
- const fn = match10(key).with("_sum", () => eb.fn.sum).with("_avg", () => eb.fn.avg).with("_max", () => eb.fn.max).with("_min", () => eb.fn.min).exhaustive();
4927
+ const fn = match11(key).with("_sum", () => eb.fn.sum).with("_avg", () => eb.fn.avg).with("_max", () => eb.fn.max).with("_min", () => eb.fn.min).exhaustive();
4757
4928
  return fn(sql6.ref(`$sub.${field}`)).as(`${key}.${field}`);
4758
4929
  });
4759
4930
  }
@@ -4786,7 +4957,7 @@ var AggregateOperationHandler = class extends BaseOperationHandler {
4786
4957
  val = parseFloat(val);
4787
4958
  } else {
4788
4959
  if (op === "_sum" || op === "_min" || op === "_max") {
4789
- val = match10(type).with("Int", () => parseInt(value, 10)).with("BigInt", () => BigInt(value)).with("Float", () => parseFloat(value)).with("Decimal", () => parseFloat(value)).otherwise(() => value);
4960
+ val = match11(type).with("Int", () => parseInt(value, 10)).with("BigInt", () => BigInt(value)).with("Float", () => parseFloat(value)).with("Decimal", () => parseFloat(value)).otherwise(() => value);
4790
4961
  }
4791
4962
  }
4792
4963
  }
@@ -4837,14 +5008,14 @@ var CountOperationHandler = class extends BaseOperationHandler {
4837
5008
  };
4838
5009
 
4839
5010
  // src/client/crud/operations/create.ts
4840
- import { match as match11 } from "ts-pattern";
5011
+ import { match as match12 } from "ts-pattern";
4841
5012
  var CreateOperationHandler = class extends BaseOperationHandler {
4842
5013
  static {
4843
5014
  __name(this, "CreateOperationHandler");
4844
5015
  }
4845
5016
  async handle(operation, args) {
4846
5017
  const normalizedArgs = this.normalizeArgs(args);
4847
- return match11(operation).with("create", () => this.runCreate(this.inputValidator.validateCreateArgs(this.model, normalizedArgs))).with("createMany", () => {
5018
+ return match12(operation).with("create", () => this.runCreate(this.inputValidator.validateCreateArgs(this.model, normalizedArgs))).with("createMany", () => {
4848
5019
  return this.runCreateMany(this.inputValidator.validateCreateManyArgs(this.model, normalizedArgs));
4849
5020
  }).with("createManyAndReturn", () => {
4850
5021
  return this.runCreateManyAndReturn(this.inputValidator.validateCreateManyAndReturnArgs(this.model, normalizedArgs));
@@ -4891,14 +5062,14 @@ var CreateOperationHandler = class extends BaseOperationHandler {
4891
5062
  };
4892
5063
 
4893
5064
  // src/client/crud/operations/delete.ts
4894
- import { match as match12 } from "ts-pattern";
5065
+ import { match as match13 } from "ts-pattern";
4895
5066
  var DeleteOperationHandler = class extends BaseOperationHandler {
4896
5067
  static {
4897
5068
  __name(this, "DeleteOperationHandler");
4898
5069
  }
4899
5070
  async handle(operation, args) {
4900
5071
  const normalizedArgs = this.normalizeArgs(args);
4901
- return match12(operation).with("delete", () => this.runDelete(this.inputValidator.validateDeleteArgs(this.model, normalizedArgs))).with("deleteMany", () => this.runDeleteMany(this.inputValidator.validateDeleteManyArgs(this.model, normalizedArgs))).exhaustive();
5072
+ return match13(operation).with("delete", () => this.runDelete(this.inputValidator.validateDeleteArgs(this.model, normalizedArgs))).with("deleteMany", () => this.runDeleteMany(this.inputValidator.validateDeleteManyArgs(this.model, normalizedArgs))).exhaustive();
4902
5073
  }
4903
5074
  async runDelete(args) {
4904
5075
  const existing = await this.readUnique(this.kysely, this.model, {
@@ -4950,7 +5121,7 @@ var FindOperationHandler = class extends BaseOperationHandler {
4950
5121
 
4951
5122
  // src/client/crud/operations/group-by.ts
4952
5123
  import { expressionBuilder as expressionBuilder5 } from "kysely";
4953
- import { match as match13 } from "ts-pattern";
5124
+ import { match as match14 } from "ts-pattern";
4954
5125
  var GroupByOperationHandler = class extends BaseOperationHandler {
4955
5126
  static {
4956
5127
  __name(this, "GroupByOperationHandler");
@@ -5044,7 +5215,7 @@ var GroupByOperationHandler = class extends BaseOperationHandler {
5044
5215
  val = parseFloat(val);
5045
5216
  } else {
5046
5217
  if (op === "_sum" || op === "_min" || op === "_max") {
5047
- val = match13(type).with("Int", () => parseInt(value, 10)).with("BigInt", () => BigInt(value)).with("Float", () => parseFloat(value)).with("Decimal", () => parseFloat(value)).otherwise(() => value);
5218
+ val = match14(type).with("Int", () => parseInt(value, 10)).with("BigInt", () => BigInt(value)).with("Float", () => parseFloat(value)).with("Decimal", () => parseFloat(value)).otherwise(() => value);
5048
5219
  }
5049
5220
  }
5050
5221
  }
@@ -5059,14 +5230,14 @@ var GroupByOperationHandler = class extends BaseOperationHandler {
5059
5230
  };
5060
5231
 
5061
5232
  // src/client/crud/operations/update.ts
5062
- import { match as match14 } from "ts-pattern";
5233
+ import { match as match15 } from "ts-pattern";
5063
5234
  var UpdateOperationHandler = class extends BaseOperationHandler {
5064
5235
  static {
5065
5236
  __name(this, "UpdateOperationHandler");
5066
5237
  }
5067
5238
  async handle(operation, args) {
5068
5239
  const normalizedArgs = this.normalizeArgs(args);
5069
- return match14(operation).with("update", () => this.runUpdate(this.inputValidator.validateUpdateArgs(this.model, normalizedArgs))).with("updateMany", () => this.runUpdateMany(this.inputValidator.validateUpdateManyArgs(this.model, normalizedArgs))).with("updateManyAndReturn", () => this.runUpdateManyAndReturn(this.inputValidator.validateUpdateManyAndReturnArgs(this.model, normalizedArgs))).with("upsert", () => this.runUpsert(this.inputValidator.validateUpsertArgs(this.model, normalizedArgs))).exhaustive();
5240
+ return match15(operation).with("update", () => this.runUpdate(this.inputValidator.validateUpdateArgs(this.model, normalizedArgs))).with("updateMany", () => this.runUpdateMany(this.inputValidator.validateUpdateManyArgs(this.model, normalizedArgs))).with("updateManyAndReturn", () => this.runUpdateManyAndReturn(this.inputValidator.validateUpdateManyAndReturnArgs(this.model, normalizedArgs))).with("upsert", () => this.runUpsert(this.inputValidator.validateUpsertArgs(this.model, normalizedArgs))).exhaustive();
5070
5241
  }
5071
5242
  async runUpdate(args) {
5072
5243
  const readBackResult = await this.safeTransaction(async (tx) => {
@@ -5146,7 +5317,7 @@ var UpdateOperationHandler = class extends BaseOperationHandler {
5146
5317
  import { invariant as invariant10 } from "@zenstackhq/common-helpers";
5147
5318
  import Decimal3 from "decimal.js";
5148
5319
  import stableStringify from "json-stable-stringify";
5149
- import { match as match15, P as P2 } from "ts-pattern";
5320
+ import { match as match16, P as P2 } from "ts-pattern";
5150
5321
  import { z } from "zod";
5151
5322
 
5152
5323
  // src/utils/zod-utils.ts
@@ -5262,7 +5433,7 @@ var InputValidator = class {
5262
5433
  if (this.schema.typeDefs && type in this.schema.typeDefs) {
5263
5434
  return this.makeTypeDefSchema(type);
5264
5435
  } else {
5265
- return match15(type).with("String", () => z.string()).with("Int", () => z.number().int()).with("Float", () => z.number()).with("Boolean", () => z.boolean()).with("BigInt", () => z.union([
5436
+ return match16(type).with("String", () => z.string()).with("Int", () => z.number().int()).with("Float", () => z.number()).with("Boolean", () => z.boolean()).with("BigInt", () => z.union([
5266
5437
  z.number().int(),
5267
5438
  z.bigint()
5268
5439
  ])).with("Decimal", () => z.union([
@@ -5423,7 +5594,7 @@ var InputValidator = class {
5423
5594
  if (this.schema.typeDefs && type in this.schema.typeDefs) {
5424
5595
  return this.makeTypeDefFilterSchema(type, optional);
5425
5596
  }
5426
- return match15(type).with("String", () => this.makeStringFilterSchema(optional, withAggregations)).with(P2.union("Int", "Float", "Decimal", "BigInt"), (type2) => this.makeNumberFilterSchema(this.makePrimitiveSchema(type2), optional, withAggregations)).with("Boolean", () => this.makeBooleanFilterSchema(optional, withAggregations)).with("DateTime", () => this.makeDateTimeFilterSchema(optional, withAggregations)).with("Bytes", () => this.makeBytesFilterSchema(optional, withAggregations)).with("Json", () => z.any()).with("Unsupported", () => z.never()).exhaustive();
5597
+ return match16(type).with("String", () => this.makeStringFilterSchema(optional, withAggregations)).with(P2.union("Int", "Float", "Decimal", "BigInt"), (type2) => this.makeNumberFilterSchema(this.makePrimitiveSchema(type2), optional, withAggregations)).with("Boolean", () => this.makeBooleanFilterSchema(optional, withAggregations)).with("DateTime", () => this.makeDateTimeFilterSchema(optional, withAggregations)).with("Bytes", () => this.makeBytesFilterSchema(optional, withAggregations)).with("Json", () => z.any()).with("Unsupported", () => z.never()).exhaustive();
5427
5598
  }
5428
5599
  makeTypeDefFilterSchema(_type, _optional) {
5429
5600
  return z.never();
@@ -6369,11 +6540,11 @@ __name(performanceNow, "performanceNow");
6369
6540
  // src/client/executor/zenstack-query-executor.ts
6370
6541
  import { invariant as invariant12 } from "@zenstackhq/common-helpers";
6371
6542
  import { AndNode as AndNode2, DefaultQueryExecutor, DeleteQueryNode as DeleteQueryNode2, InsertQueryNode as InsertQueryNode2, ReturningNode as ReturningNode2, SelectionNode as SelectionNode4, SingleConnectionProvider, TableNode as TableNode6, UpdateQueryNode as UpdateQueryNode2, WhereNode as WhereNode3 } from "kysely";
6372
- import { match as match16 } from "ts-pattern";
6543
+ import { match as match17 } from "ts-pattern";
6373
6544
 
6374
6545
  // src/client/executor/name-mapper.ts
6375
6546
  import { invariant as invariant11 } from "@zenstackhq/common-helpers";
6376
- import { AliasNode as AliasNode5, ColumnNode as ColumnNode4, FromNode as FromNode3, IdentifierNode as IdentifierNode3, OperationNodeTransformer as OperationNodeTransformer2, ReferenceNode as ReferenceNode4, SelectAllNode, SelectionNode as SelectionNode3, TableNode as TableNode5 } from "kysely";
6547
+ import { AliasNode as AliasNode5, ColumnNode as ColumnNode4, FromNode as FromNode3, IdentifierNode as IdentifierNode3, OperationNodeTransformer as OperationNodeTransformer2, ReferenceNode as ReferenceNode5, SelectAllNode as SelectAllNode2, SelectionNode as SelectionNode3, TableNode as TableNode5 } from "kysely";
6377
6548
  var QueryNameMapper = class extends OperationNodeTransformer2 {
6378
6549
  static {
6379
6550
  __name(this, "QueryNameMapper");
@@ -6455,7 +6626,7 @@ var QueryNameMapper = class extends OperationNodeTransformer2 {
6455
6626
  mappedTableName = this.mapTableName(scope.model);
6456
6627
  }
6457
6628
  }
6458
- return ReferenceNode4.create(ColumnNode4.create(mappedFieldName), mappedTableName ? TableNode5.create(mappedTableName) : void 0);
6629
+ return ReferenceNode5.create(ColumnNode4.create(mappedFieldName), mappedTableName ? TableNode5.create(mappedTableName) : void 0);
6459
6630
  } else {
6460
6631
  return super.transformReference(node);
6461
6632
  }
@@ -6516,14 +6687,14 @@ var QueryNameMapper = class extends OperationNodeTransformer2 {
6516
6687
  processSelectQuerySelections(node) {
6517
6688
  const selections = [];
6518
6689
  for (const selection of node.selections ?? []) {
6519
- if (SelectAllNode.is(selection.selection)) {
6690
+ if (SelectAllNode2.is(selection.selection)) {
6520
6691
  const scope = this.scopes[this.scopes.length - 1];
6521
6692
  if (scope?.model && !scope.namesMapped) {
6522
6693
  selections.push(...this.createSelectAllFields(scope.model, scope.alias));
6523
6694
  } else {
6524
6695
  selections.push(super.transformSelection(selection));
6525
6696
  }
6526
- } else if (ReferenceNode4.is(selection.selection) || ColumnNode4.is(selection.selection)) {
6697
+ } else if (ReferenceNode5.is(selection.selection) || ColumnNode4.is(selection.selection)) {
6527
6698
  const transformed = this.transformNode(selection.selection);
6528
6699
  if (AliasNode5.is(transformed)) {
6529
6700
  selections.push(SelectionNode3.create(transformed));
@@ -6665,7 +6836,7 @@ var QueryNameMapper = class extends OperationNodeTransformer2 {
6665
6836
  const modelDef = requireModel(this.schema, model);
6666
6837
  return this.getModelFields(modelDef).map((fieldDef) => {
6667
6838
  const columnName = this.mapFieldName(model, fieldDef.name);
6668
- const columnRef = ReferenceNode4.create(ColumnNode4.create(columnName), alias && IdentifierNode3.is(alias) ? TableNode5.create(alias.name) : void 0);
6839
+ const columnRef = ReferenceNode5.create(ColumnNode4.create(columnName), alias && IdentifierNode3.is(alias) ? TableNode5.create(alias.name) : void 0);
6669
6840
  if (columnName !== fieldDef.name) {
6670
6841
  const aliased = AliasNode5.create(columnRef, IdentifierNode3.create(fieldDef.name));
6671
6842
  return SelectionNode3.create(aliased);
@@ -6680,7 +6851,7 @@ var QueryNameMapper = class extends OperationNodeTransformer2 {
6680
6851
  processSelections(selections) {
6681
6852
  const result = [];
6682
6853
  selections.forEach((selection) => {
6683
- if (SelectAllNode.is(selection.selection)) {
6854
+ if (SelectAllNode2.is(selection.selection)) {
6684
6855
  const processed = this.processSelectAll(selection.selection);
6685
6856
  if (Array.isArray(processed)) {
6686
6857
  result.push(...processed.map((s) => SelectionNode3.create(s)));
@@ -6710,7 +6881,7 @@ var QueryNameMapper = class extends OperationNodeTransformer2 {
6710
6881
  const modelDef = requireModel(this.schema, scope.model);
6711
6882
  return this.getModelFields(modelDef).map((fieldDef) => {
6712
6883
  const columnName = this.mapFieldName(modelDef.name, fieldDef.name);
6713
- const columnRef = ReferenceNode4.create(ColumnNode4.create(columnName));
6884
+ const columnRef = ReferenceNode5.create(ColumnNode4.create(columnName));
6714
6885
  return columnName !== fieldDef.name ? this.wrapAlias(columnRef, IdentifierNode3.create(fieldDef.name)) : columnRef;
6715
6886
  });
6716
6887
  }
@@ -6737,13 +6908,37 @@ var ZenStackQueryExecutor = class _ZenStackQueryExecutor extends DefaultQueryExe
6737
6908
  get options() {
6738
6909
  return this.client.$options;
6739
6910
  }
6740
- async executeQuery(compiledQuery, queryId) {
6911
+ executeQuery(compiledQuery, queryId) {
6741
6912
  const queryParams = compiledQuery.$raw ? compiledQuery.parameters : void 0;
6742
- const result = await this.proceedQueryWithKyselyInterceptors(compiledQuery.query, queryParams, queryId.queryId);
6743
- return result.result;
6913
+ return this.provideConnection(async (connection) => {
6914
+ let startedTx = false;
6915
+ try {
6916
+ if (this.isMutationNode(compiledQuery.query) && !this.driver.isTransactionConnection(connection)) {
6917
+ await this.driver.beginTransaction(connection, {
6918
+ isolationLevel: TransactionIsolationLevel.RepeatableRead
6919
+ });
6920
+ startedTx = true;
6921
+ }
6922
+ const result = await this.proceedQueryWithKyselyInterceptors(connection, compiledQuery.query, queryParams, queryId.queryId);
6923
+ if (startedTx) {
6924
+ await this.driver.commitTransaction(connection);
6925
+ }
6926
+ return result;
6927
+ } catch (err) {
6928
+ if (startedTx) {
6929
+ await this.driver.rollbackTransaction(connection);
6930
+ }
6931
+ if (err instanceof ZenStackError) {
6932
+ throw err;
6933
+ } else {
6934
+ const message = `Failed to execute query: ${err}, sql: ${compiledQuery?.sql}`;
6935
+ throw new QueryError(message, err);
6936
+ }
6937
+ }
6938
+ });
6744
6939
  }
6745
- async proceedQueryWithKyselyInterceptors(queryNode, parameters, queryId) {
6746
- let proceed = /* @__PURE__ */ __name((q) => this.proceedQuery(q, parameters, queryId), "proceed");
6940
+ async proceedQueryWithKyselyInterceptors(connection, queryNode, parameters, queryId) {
6941
+ let proceed = /* @__PURE__ */ __name((q) => this.proceedQuery(connection, q, parameters, queryId), "proceed");
6747
6942
  const hooks = [];
6748
6943
  for (const plugin of this.client.$options.plugins ?? []) {
6749
6944
  if (plugin.onKyselyQuery) {
@@ -6753,19 +6948,14 @@ var ZenStackQueryExecutor = class _ZenStackQueryExecutor extends DefaultQueryExe
6753
6948
  for (const hook of hooks) {
6754
6949
  const _proceed = proceed;
6755
6950
  proceed = /* @__PURE__ */ __name(async (query) => {
6756
- const _p = /* @__PURE__ */ __name(async (q) => {
6757
- const r = await _proceed(q);
6758
- return r.result;
6759
- }, "_p");
6951
+ const _p = /* @__PURE__ */ __name((q) => _proceed(q), "_p");
6760
6952
  const hookResult = await hook({
6761
6953
  client: this.client,
6762
6954
  schema: this.client.$schema,
6763
6955
  query,
6764
6956
  proceed: _p
6765
6957
  });
6766
- return {
6767
- result: hookResult
6768
- };
6958
+ return hookResult;
6769
6959
  }, "proceed");
6770
6960
  }
6771
6961
  const result = await proceed(queryNode);
@@ -6773,7 +6963,7 @@ var ZenStackQueryExecutor = class _ZenStackQueryExecutor extends DefaultQueryExe
6773
6963
  }
6774
6964
  getMutationInfo(queryNode) {
6775
6965
  const model = this.getMutationModel(queryNode);
6776
- const { action, where } = match16(queryNode).when(InsertQueryNode2.is, () => ({
6966
+ const { action, where } = match17(queryNode).when(InsertQueryNode2.is, () => ({
6777
6967
  action: "create",
6778
6968
  where: void 0
6779
6969
  })).when(UpdateQueryNode2.is, (node) => ({
@@ -6789,85 +6979,54 @@ var ZenStackQueryExecutor = class _ZenStackQueryExecutor extends DefaultQueryExe
6789
6979
  where
6790
6980
  };
6791
6981
  }
6792
- async proceedQuery(query, parameters, queryId) {
6982
+ async proceedQuery(connection, query, parameters, queryId) {
6793
6983
  let compiled;
6794
- try {
6795
- return await this.provideConnection(async (connection) => {
6796
- if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) {
6797
- const finalQuery2 = this.nameMapper.transformNode(query);
6798
- compiled = this.compileQuery(finalQuery2);
6799
- if (parameters) {
6800
- compiled = {
6801
- ...compiled,
6802
- parameters
6803
- };
6804
- }
6805
- const result = await connection.executeQuery(compiled);
6806
- return {
6807
- result
6808
- };
6809
- }
6810
- if ((InsertQueryNode2.is(query) || UpdateQueryNode2.is(query)) && this.hasEntityMutationPluginsWithAfterMutationHooks) {
6811
- query = {
6812
- ...query,
6813
- returning: ReturningNode2.create([
6814
- SelectionNode4.createSelectAll()
6815
- ])
6816
- };
6817
- }
6818
- const finalQuery = this.nameMapper.transformNode(query);
6819
- compiled = this.compileQuery(finalQuery);
6820
- if (parameters) {
6821
- compiled = {
6822
- ...compiled,
6823
- parameters
6824
- };
6825
- }
6826
- const currentlyInTx = this.driver.isTransactionConnection(connection);
6827
- const connectionClient = this.createClientForConnection(connection, currentlyInTx);
6828
- const mutationInfo = this.getMutationInfo(finalQuery);
6829
- let beforeMutationEntities;
6830
- const loadBeforeMutationEntities = /* @__PURE__ */ __name(async () => {
6831
- if (beforeMutationEntities === void 0 && (UpdateQueryNode2.is(query) || DeleteQueryNode2.is(query))) {
6832
- beforeMutationEntities = await this.loadEntities(mutationInfo.model, mutationInfo.where, connection);
6833
- }
6834
- return beforeMutationEntities;
6835
- }, "loadBeforeMutationEntities");
6836
- await this.callBeforeMutationHooks(finalQuery, mutationInfo, loadBeforeMutationEntities, connectionClient, queryId);
6837
- const shouldCreateTx = this.hasPluginRequestingAfterMutationWithinTransaction && !this.driver.isTransactionConnection(connection);
6838
- if (!shouldCreateTx) {
6839
- const result = await connection.executeQuery(compiled);
6840
- if (!this.driver.isTransactionConnection(connection)) {
6841
- await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, "all", queryId);
6842
- } else {
6843
- await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, "inTx", queryId);
6844
- this.driver.registerTransactionCommitCallback(connection, () => this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, "outTx", queryId));
6845
- }
6846
- return {
6847
- result
6848
- };
6849
- } else {
6850
- await this.driver.beginTransaction(connection, {
6851
- isolationLevel: TransactionIsolationLevel.ReadCommitted
6852
- });
6853
- try {
6854
- const result = await connection.executeQuery(compiled);
6855
- await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, "inTx", queryId);
6856
- await this.driver.commitTransaction(connection);
6857
- await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, "outTx", queryId);
6858
- return {
6859
- result
6860
- };
6861
- } catch (err) {
6862
- await this.driver.rollbackTransaction(connection);
6863
- throw err;
6864
- }
6865
- }
6866
- });
6867
- } catch (err) {
6868
- const message = `Failed to execute query: ${err}, sql: ${compiled?.sql}`;
6869
- throw new QueryError(message, err);
6984
+ if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) {
6985
+ const finalQuery2 = this.nameMapper.transformNode(query);
6986
+ compiled = this.compileQuery(finalQuery2);
6987
+ if (parameters) {
6988
+ compiled = {
6989
+ ...compiled,
6990
+ parameters
6991
+ };
6992
+ }
6993
+ return connection.executeQuery(compiled);
6870
6994
  }
6995
+ if ((InsertQueryNode2.is(query) || UpdateQueryNode2.is(query)) && this.hasEntityMutationPluginsWithAfterMutationHooks) {
6996
+ query = {
6997
+ ...query,
6998
+ returning: ReturningNode2.create([
6999
+ SelectionNode4.createSelectAll()
7000
+ ])
7001
+ };
7002
+ }
7003
+ const finalQuery = this.nameMapper.transformNode(query);
7004
+ compiled = this.compileQuery(finalQuery);
7005
+ if (parameters) {
7006
+ compiled = {
7007
+ ...compiled,
7008
+ parameters
7009
+ };
7010
+ }
7011
+ const currentlyInTx = this.driver.isTransactionConnection(connection);
7012
+ const connectionClient = this.createClientForConnection(connection, currentlyInTx);
7013
+ const mutationInfo = this.getMutationInfo(finalQuery);
7014
+ let beforeMutationEntities;
7015
+ const loadBeforeMutationEntities = /* @__PURE__ */ __name(async () => {
7016
+ if (beforeMutationEntities === void 0 && (UpdateQueryNode2.is(query) || DeleteQueryNode2.is(query))) {
7017
+ beforeMutationEntities = await this.loadEntities(mutationInfo.model, mutationInfo.where, connection);
7018
+ }
7019
+ return beforeMutationEntities;
7020
+ }, "loadBeforeMutationEntities");
7021
+ await this.callBeforeMutationHooks(finalQuery, mutationInfo, loadBeforeMutationEntities, connectionClient, queryId);
7022
+ const result = await connection.executeQuery(compiled);
7023
+ if (!this.driver.isTransactionConnection(connection)) {
7024
+ await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, "all", queryId);
7025
+ } else {
7026
+ await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, "inTx", queryId);
7027
+ this.driver.registerTransactionCommitCallback(connection, () => this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, "outTx", queryId));
7028
+ }
7029
+ return result;
6871
7030
  }
6872
7031
  createClientForConnection(connection, inTx) {
6873
7032
  const innerExecutor = this.withConnectionProvider(new SingleConnectionProvider(connection));
@@ -6884,9 +7043,6 @@ var ZenStackQueryExecutor = class _ZenStackQueryExecutor extends DefaultQueryExe
6884
7043
  get hasEntityMutationPluginsWithAfterMutationHooks() {
6885
7044
  return (this.client.$options.plugins ?? []).some((plugin) => plugin.onEntityMutation?.afterEntityMutation);
6886
7045
  }
6887
- get hasPluginRequestingAfterMutationWithinTransaction() {
6888
- return (this.client.$options.plugins ?? []).some((plugin) => plugin.onEntityMutation?.runAfterMutationWithinTransaction);
6889
- }
6890
7046
  isMutationNode(queryNode) {
6891
7047
  return InsertQueryNode2.is(queryNode) || UpdateQueryNode2.is(queryNode) || DeleteQueryNode2.is(queryNode);
6892
7048
  }
@@ -6917,7 +7073,7 @@ var ZenStackQueryExecutor = class _ZenStackQueryExecutor extends DefaultQueryExe
6917
7073
  return newExecutor;
6918
7074
  }
6919
7075
  getMutationModel(queryNode) {
6920
- return match16(queryNode).when(InsertQueryNode2.is, (node) => {
7076
+ return match17(queryNode).when(InsertQueryNode2.is, (node) => {
6921
7077
  invariant12(node.into, "InsertQueryNode must have an into clause");
6922
7078
  return node.into.table.identifier.name;
6923
7079
  }).when(UpdateQueryNode2.is, (node) => {
@@ -7028,7 +7184,7 @@ __export(functions_exports, {
7028
7184
  });
7029
7185
  import { invariant as invariant13, lowerCaseFirst, upperCaseFirst } from "@zenstackhq/common-helpers";
7030
7186
  import { sql as sql8, ValueNode as ValueNode5 } from "kysely";
7031
- import { match as match17 } from "ts-pattern";
7187
+ import { match as match18 } from "ts-pattern";
7032
7188
  var contains = /* @__PURE__ */ __name((eb, args, context) => textMatch(eb, args, context, "contains"), "contains");
7033
7189
  var search = /* @__PURE__ */ __name((_eb, _args) => {
7034
7190
  throw new Error(`"search" function is not implemented yet`);
@@ -7065,7 +7221,7 @@ var textMatch = /* @__PURE__ */ __name((eb, args, { dialect }, method) => {
7065
7221
  } else {
7066
7222
  op = "like";
7067
7223
  }
7068
- searchExpr = match17(method).with("contains", () => eb.fn("CONCAT", [
7224
+ searchExpr = match18(method).with("contains", () => eb.fn("CONCAT", [
7069
7225
  sql8.lit("%"),
7070
7226
  sql8`CAST(${searchExpr} as text)`,
7071
7227
  sql8.lit("%")
@@ -7137,7 +7293,7 @@ var currentOperation = /* @__PURE__ */ __name((_eb, args, { operation }) => {
7137
7293
  function processCasing(casing, result, model) {
7138
7294
  const opNode = casing.toOperationNode();
7139
7295
  invariant13(ValueNode5.is(opNode) && typeof opNode.value === "string", '"casting" parameter must be a string value');
7140
- result = match17(opNode.value).with("original", () => model).with("upper", () => result.toUpperCase()).with("lower", () => result.toLowerCase()).with("capitalize", () => upperCaseFirst(result)).with("uncapitalize", () => lowerCaseFirst(result)).otherwise(() => {
7296
+ result = match18(opNode.value).with("original", () => model).with("upper", () => result.toUpperCase()).with("lower", () => result.toLowerCase()).with("capitalize", () => upperCaseFirst(result)).with("uncapitalize", () => lowerCaseFirst(result)).otherwise(() => {
7141
7297
  throw new Error(`Invalid casing value: ${opNode.value}. Must be "original", "upper", "lower", "capitalize", or "uncapitalize".`);
7142
7298
  });
7143
7299
  return result;
@@ -7157,7 +7313,7 @@ __name(readBoolean, "readBoolean");
7157
7313
  import { invariant as invariant14 } from "@zenstackhq/common-helpers";
7158
7314
  import { sql as sql9 } from "kysely";
7159
7315
  import toposort from "toposort";
7160
- import { match as match18 } from "ts-pattern";
7316
+ import { match as match19 } from "ts-pattern";
7161
7317
  var SchemaDbPusher = class {
7162
7318
  static {
7163
7319
  __name(this, "SchemaDbPusher");
@@ -7323,7 +7479,7 @@ var SchemaDbPusher = class {
7323
7479
  return "jsonb";
7324
7480
  }
7325
7481
  const type = fieldDef.type;
7326
- const result = match18(type).with("String", () => "text").with("Boolean", () => "boolean").with("Int", () => "integer").with("Float", () => "real").with("BigInt", () => "bigint").with("Decimal", () => "decimal").with("DateTime", () => "timestamp").with("Bytes", () => this.schema.provider.type === "postgresql" ? "bytea" : "blob").with("Json", () => "jsonb").otherwise(() => {
7482
+ const result = match19(type).with("String", () => "text").with("Boolean", () => "boolean").with("Int", () => "integer").with("Float", () => "real").with("BigInt", () => "bigint").with("Decimal", () => "decimal").with("DateTime", () => "timestamp").with("Bytes", () => this.schema.provider.type === "postgresql" ? "bytea" : "blob").with("Json", () => "jsonb").otherwise(() => {
7327
7483
  throw new Error(`Unsupported field type: ${type}`);
7328
7484
  });
7329
7485
  if (fieldDef.array) {
@@ -7357,7 +7513,7 @@ var SchemaDbPusher = class {
7357
7513
  return table;
7358
7514
  }
7359
7515
  mapCascadeAction(action) {
7360
- return match18(action).with("SetNull", () => "set null").with("Cascade", () => "cascade").with("Restrict", () => "restrict").with("NoAction", () => "no action").with("SetDefault", () => "set default").exhaustive();
7516
+ return match19(action).with("SetNull", () => "set null").with("Cascade", () => "cascade").with("Restrict", () => "restrict").with("NoAction", () => "no action").with("SetDefault", () => "set default").exhaustive();
7361
7517
  }
7362
7518
  };
7363
7519
 
@@ -7856,6 +8012,7 @@ export {
7856
8012
  NotFoundError,
7857
8013
  QueryError,
7858
8014
  ZenStackClient,
8015
+ ZenStackError,
7859
8016
  definePlugin,
7860
8017
  sql11 as sql
7861
8018
  };