@zenstackhq/plugin-policy 3.3.0-beta.2 → 3.3.0-beta.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +268 -121
- package/dist/index.cjs.map +1 -1
- package/dist/index.js +269 -122
- package/dist/index.js.map +1 -1
- package/package.json +6 -6
- package/dist/index.mjs +0 -1798
- package/dist/index.mjs.map +0 -1
package/dist/index.js
CHANGED
|
@@ -10,7 +10,7 @@ import { ExpressionWrapper as ExpressionWrapper2, ValueNode as ValueNode4 } from
|
|
|
10
10
|
import { invariant as invariant3 } from "@zenstackhq/common-helpers";
|
|
11
11
|
import { getCrudDialect as getCrudDialect2, QueryUtils as QueryUtils2, RejectedByPolicyReason, SchemaUtils as SchemaUtils2 } from "@zenstackhq/orm";
|
|
12
12
|
import { ExpressionUtils as ExpressionUtils4 } from "@zenstackhq/orm/schema";
|
|
13
|
-
import { AliasNode as AliasNode3, BinaryOperationNode as BinaryOperationNode3, ColumnNode as ColumnNode3, DeleteQueryNode, expressionBuilder as expressionBuilder2, ExpressionWrapper, FromNode as FromNode2,
|
|
13
|
+
import { AliasNode as AliasNode3, BinaryOperationNode as BinaryOperationNode3, ColumnNode as ColumnNode3, DeleteQueryNode, expressionBuilder as expressionBuilder2, ExpressionWrapper, FromNode as FromNode2, IdentifierNode as IdentifierNode2, InsertQueryNode, OperationNodeTransformer, OperatorNode as OperatorNode3, ParensNode as ParensNode2, PrimitiveValueListNode, ReferenceNode as ReferenceNode3, ReturningNode, SelectAllNode, SelectionNode as SelectionNode2, SelectQueryNode as SelectQueryNode2, sql, TableNode as TableNode3, UpdateQueryNode, ValueNode as ValueNode3, ValuesNode, WhereNode as WhereNode2 } from "kysely";
|
|
14
14
|
import { match as match3 } from "ts-pattern";
|
|
15
15
|
|
|
16
16
|
// src/column-collector.ts
|
|
@@ -41,14 +41,14 @@ import { match as match2 } from "ts-pattern";
|
|
|
41
41
|
|
|
42
42
|
// src/expression-evaluator.ts
|
|
43
43
|
import { invariant } from "@zenstackhq/common-helpers";
|
|
44
|
-
import { match } from "ts-pattern";
|
|
45
44
|
import { ExpressionUtils } from "@zenstackhq/orm/schema";
|
|
45
|
+
import { match } from "ts-pattern";
|
|
46
46
|
var ExpressionEvaluator = class {
|
|
47
47
|
static {
|
|
48
48
|
__name(this, "ExpressionEvaluator");
|
|
49
49
|
}
|
|
50
50
|
evaluate(expression, context) {
|
|
51
|
-
const result = match(expression).when(ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(ExpressionUtils.isThis, () => context.thisValue).when(ExpressionUtils.isNull, () => null).exhaustive();
|
|
51
|
+
const result = match(expression).when(ExpressionUtils.isArray, (expr2) => this.evaluateArray(expr2, context)).when(ExpressionUtils.isBinary, (expr2) => this.evaluateBinary(expr2, context)).when(ExpressionUtils.isField, (expr2) => this.evaluateField(expr2, context)).when(ExpressionUtils.isLiteral, (expr2) => this.evaluateLiteral(expr2)).when(ExpressionUtils.isMember, (expr2) => this.evaluateMember(expr2, context)).when(ExpressionUtils.isUnary, (expr2) => this.evaluateUnary(expr2, context)).when(ExpressionUtils.isCall, (expr2) => this.evaluateCall(expr2, context)).when(ExpressionUtils.isBinding, (expr2) => this.evaluateBinding(expr2, context)).when(ExpressionUtils.isThis, () => context.thisValue).when(ExpressionUtils.isNull, () => null).exhaustive();
|
|
52
52
|
return result ?? null;
|
|
53
53
|
}
|
|
54
54
|
evaluateCall(expr2, context) {
|
|
@@ -72,6 +72,9 @@ var ExpressionEvaluator = class {
|
|
|
72
72
|
return expr2.value;
|
|
73
73
|
}
|
|
74
74
|
evaluateField(expr2, context) {
|
|
75
|
+
if (context.bindingScope && expr2.field in context.bindingScope) {
|
|
76
|
+
return context.bindingScope[expr2.field];
|
|
77
|
+
}
|
|
75
78
|
return context.thisValue?.[expr2.field];
|
|
76
79
|
}
|
|
77
80
|
evaluateArray(expr2, context) {
|
|
@@ -105,15 +108,33 @@ var ExpressionEvaluator = class {
|
|
|
105
108
|
invariant(Array.isArray(left), "expected array");
|
|
106
109
|
return match(op).with("?", () => left.some((item) => this.evaluate(expr2.right, {
|
|
107
110
|
...context,
|
|
108
|
-
thisValue: item
|
|
111
|
+
thisValue: item,
|
|
112
|
+
bindingScope: expr2.binding ? {
|
|
113
|
+
...context.bindingScope ?? {},
|
|
114
|
+
[expr2.binding]: item
|
|
115
|
+
} : context.bindingScope
|
|
109
116
|
}))).with("!", () => left.every((item) => this.evaluate(expr2.right, {
|
|
110
117
|
...context,
|
|
111
|
-
thisValue: item
|
|
118
|
+
thisValue: item,
|
|
119
|
+
bindingScope: expr2.binding ? {
|
|
120
|
+
...context.bindingScope ?? {},
|
|
121
|
+
[expr2.binding]: item
|
|
122
|
+
} : context.bindingScope
|
|
112
123
|
}))).with("^", () => !left.some((item) => this.evaluate(expr2.right, {
|
|
113
124
|
...context,
|
|
114
|
-
thisValue: item
|
|
125
|
+
thisValue: item,
|
|
126
|
+
bindingScope: expr2.binding ? {
|
|
127
|
+
...context.bindingScope ?? {},
|
|
128
|
+
[expr2.binding]: item
|
|
129
|
+
} : context.bindingScope
|
|
115
130
|
}))).exhaustive();
|
|
116
131
|
}
|
|
132
|
+
evaluateBinding(expr2, context) {
|
|
133
|
+
if (!context.bindingScope || !(expr2.name in context.bindingScope)) {
|
|
134
|
+
throw new Error(`Unresolved binding: ${expr2.name}`);
|
|
135
|
+
}
|
|
136
|
+
return context.bindingScope[expr2.name];
|
|
137
|
+
}
|
|
117
138
|
};
|
|
118
139
|
|
|
119
140
|
// src/types.ts
|
|
@@ -128,11 +149,11 @@ import { ORMError, ORMErrorReason } from "@zenstackhq/orm";
|
|
|
128
149
|
import { ExpressionUtils as ExpressionUtils2 } from "@zenstackhq/orm/schema";
|
|
129
150
|
import { AliasNode, AndNode, BinaryOperationNode, ColumnNode, FunctionNode, OperatorNode, OrNode, ParensNode, ReferenceNode, TableNode, UnaryOperationNode, ValueNode } from "kysely";
|
|
130
151
|
function trueNode(dialect) {
|
|
131
|
-
return ValueNode.createImmediate(dialect.
|
|
152
|
+
return ValueNode.createImmediate(dialect.transformInput(true, "Boolean", false));
|
|
132
153
|
}
|
|
133
154
|
__name(trueNode, "trueNode");
|
|
134
155
|
function falseNode(dialect) {
|
|
135
|
-
return ValueNode.createImmediate(dialect.
|
|
156
|
+
return ValueNode.createImmediate(dialect.transformInput(false, "Boolean", false));
|
|
136
157
|
}
|
|
137
158
|
__name(falseNode, "falseNode");
|
|
138
159
|
function isTrueNode(node) {
|
|
@@ -426,7 +447,8 @@ var ExpressionTransformer = class {
|
|
|
426
447
|
const evaluator = new ExpressionEvaluator();
|
|
427
448
|
const receiver = evaluator.evaluate(expr2.left, {
|
|
428
449
|
thisValue: context.contextValue,
|
|
429
|
-
auth: this.auth
|
|
450
|
+
auth: this.auth,
|
|
451
|
+
bindingScope: this.getEvaluationBindingScope(context.bindingScope)
|
|
430
452
|
});
|
|
431
453
|
const baseType = this.isAuthMember(expr2.left) ? this.authType : context.modelOrType;
|
|
432
454
|
const memberType = this.getMemberType(baseType, expr2.left);
|
|
@@ -442,18 +464,31 @@ var ExpressionTransformer = class {
|
|
|
442
464
|
invariant2(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr2.left)}`);
|
|
443
465
|
newContextModel = fieldDef.type;
|
|
444
466
|
} else {
|
|
445
|
-
invariant2(ExpressionUtils3.isMember(expr2.left) && ExpressionUtils3.isField(expr2.left.receiver), "left operand must be member access with field receiver");
|
|
446
|
-
|
|
447
|
-
|
|
467
|
+
invariant2(ExpressionUtils3.isMember(expr2.left) && (ExpressionUtils3.isField(expr2.left.receiver) || ExpressionUtils3.isBinding(expr2.left.receiver)), "left operand must be member access with field receiver");
|
|
468
|
+
if (ExpressionUtils3.isField(expr2.left.receiver)) {
|
|
469
|
+
const fieldDef2 = QueryUtils.requireField(this.schema, context.modelOrType, expr2.left.receiver.field);
|
|
470
|
+
newContextModel = fieldDef2.type;
|
|
471
|
+
} else {
|
|
472
|
+
const binding = this.requireBindingScope(expr2.left.receiver, context);
|
|
473
|
+
newContextModel = binding.type;
|
|
474
|
+
}
|
|
448
475
|
for (const member of expr2.left.members) {
|
|
449
476
|
const memberDef = QueryUtils.requireField(this.schema, newContextModel, member);
|
|
450
477
|
newContextModel = memberDef.type;
|
|
451
478
|
}
|
|
452
479
|
}
|
|
480
|
+
const bindingScope = expr2.binding ? {
|
|
481
|
+
...context.bindingScope ?? {},
|
|
482
|
+
[expr2.binding]: {
|
|
483
|
+
type: newContextModel,
|
|
484
|
+
alias: newContextModel
|
|
485
|
+
}
|
|
486
|
+
} : context.bindingScope;
|
|
453
487
|
let predicateFilter = this.transform(expr2.right, {
|
|
454
488
|
...context,
|
|
455
489
|
modelOrType: newContextModel,
|
|
456
|
-
alias: void 0
|
|
490
|
+
alias: void 0,
|
|
491
|
+
bindingScope
|
|
457
492
|
});
|
|
458
493
|
if (expr2.op === "!") {
|
|
459
494
|
predicateFilter = logicalNot(this.dialect, predicateFilter);
|
|
@@ -480,18 +515,30 @@ var ExpressionTransformer = class {
|
|
|
480
515
|
if (!visitor.find(expr2.right)) {
|
|
481
516
|
const value = new ExpressionEvaluator().evaluate(expr2, {
|
|
482
517
|
auth: this.auth,
|
|
483
|
-
thisValue: context.contextValue
|
|
518
|
+
thisValue: context.contextValue,
|
|
519
|
+
bindingScope: this.getEvaluationBindingScope(context.bindingScope)
|
|
484
520
|
});
|
|
485
521
|
return this.transformValue(value, "Boolean");
|
|
486
522
|
} else {
|
|
487
523
|
invariant2(Array.isArray(receiver), "array value is expected");
|
|
488
|
-
const components = receiver.map((item) =>
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
524
|
+
const components = receiver.map((item) => {
|
|
525
|
+
const bindingScope = expr2.binding ? {
|
|
526
|
+
...context.bindingScope ?? {},
|
|
527
|
+
[expr2.binding]: {
|
|
528
|
+
type: context.modelOrType,
|
|
529
|
+
alias: context.thisAlias ?? context.modelOrType,
|
|
530
|
+
value: item
|
|
531
|
+
}
|
|
532
|
+
} : context.bindingScope;
|
|
533
|
+
return this.transform(expr2.right, {
|
|
534
|
+
operation: context.operation,
|
|
535
|
+
thisType: context.thisType,
|
|
536
|
+
thisAlias: context.thisAlias,
|
|
537
|
+
modelOrType: context.modelOrType,
|
|
538
|
+
contextValue: item,
|
|
539
|
+
bindingScope
|
|
540
|
+
});
|
|
541
|
+
});
|
|
495
542
|
return match2(expr2.op).with("?", () => disjunction(this.dialect, components)).with("!", () => conjunction(this.dialect, components)).with("^", () => logicalNot(this.dialect, disjunction(this.dialect, components))).exhaustive();
|
|
496
543
|
}
|
|
497
544
|
}
|
|
@@ -558,7 +605,7 @@ var ExpressionTransformer = class {
|
|
|
558
605
|
} else if (value === false) {
|
|
559
606
|
return falseNode(this.dialect);
|
|
560
607
|
} else {
|
|
561
|
-
const transformed = this.dialect.
|
|
608
|
+
const transformed = this.dialect.transformInput(value, type, false) ?? null;
|
|
562
609
|
if (!Array.isArray(transformed)) {
|
|
563
610
|
return ValueNode2.createImmediate(transformed);
|
|
564
611
|
} else {
|
|
@@ -621,6 +668,12 @@ var ExpressionTransformer = class {
|
|
|
621
668
|
throw createUnsupportedError(`Unsupported argument expression: ${arg.kind}`);
|
|
622
669
|
}
|
|
623
670
|
_member(expr2, context) {
|
|
671
|
+
if (ExpressionUtils3.isBinding(expr2.receiver)) {
|
|
672
|
+
const scope = this.requireBindingScope(expr2.receiver, context);
|
|
673
|
+
if (scope.value !== void 0) {
|
|
674
|
+
return this.valueMemberAccess(scope.value, expr2, scope.type);
|
|
675
|
+
}
|
|
676
|
+
}
|
|
624
677
|
if (this.isAuthCall(expr2.receiver)) {
|
|
625
678
|
return this.valueMemberAccess(this.auth, expr2, this.authType);
|
|
626
679
|
}
|
|
@@ -629,9 +682,10 @@ var ExpressionTransformer = class {
|
|
|
629
682
|
invariant2(expr2.members.length === 1, "before() can only be followed by a scalar field access");
|
|
630
683
|
return ReferenceNode2.create(ColumnNode2.create(expr2.members[0]), TableNode2.create("$before"));
|
|
631
684
|
}
|
|
632
|
-
invariant2(ExpressionUtils3.isField(expr2.receiver) || ExpressionUtils3.isThis(expr2.receiver), 'expect receiver to be field expression or "this"');
|
|
685
|
+
invariant2(ExpressionUtils3.isField(expr2.receiver) || ExpressionUtils3.isThis(expr2.receiver) || ExpressionUtils3.isBinding(expr2.receiver), 'expect receiver to be field expression, collection predicate binding, or "this"');
|
|
633
686
|
let members = expr2.members;
|
|
634
687
|
let receiver;
|
|
688
|
+
let startType;
|
|
635
689
|
const { memberFilter, memberSelect, ...restContext } = context;
|
|
636
690
|
if (ExpressionUtils3.isThis(expr2.receiver)) {
|
|
637
691
|
if (expr2.members.length === 1) {
|
|
@@ -646,17 +700,40 @@ var ExpressionTransformer = class {
|
|
|
646
700
|
const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.thisType, expr2.members[0]);
|
|
647
701
|
receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, restContext);
|
|
648
702
|
members = expr2.members.slice(1);
|
|
703
|
+
startType = firstMemberFieldDef.type;
|
|
704
|
+
}
|
|
705
|
+
} else if (ExpressionUtils3.isBinding(expr2.receiver)) {
|
|
706
|
+
if (expr2.members.length === 1) {
|
|
707
|
+
const bindingScope = this.requireBindingScope(expr2.receiver, context);
|
|
708
|
+
return this._field(ExpressionUtils3.field(expr2.members[0]), {
|
|
709
|
+
...context,
|
|
710
|
+
modelOrType: bindingScope.type,
|
|
711
|
+
alias: bindingScope.alias,
|
|
712
|
+
thisType: context.thisType,
|
|
713
|
+
contextValue: void 0
|
|
714
|
+
});
|
|
715
|
+
} else {
|
|
716
|
+
const bindingScope = this.requireBindingScope(expr2.receiver, context);
|
|
717
|
+
const firstMemberFieldDef = QueryUtils.requireField(this.schema, bindingScope.type, expr2.members[0]);
|
|
718
|
+
receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, {
|
|
719
|
+
...restContext,
|
|
720
|
+
modelOrType: bindingScope.type,
|
|
721
|
+
alias: bindingScope.alias
|
|
722
|
+
});
|
|
723
|
+
members = expr2.members.slice(1);
|
|
724
|
+
startType = firstMemberFieldDef.type;
|
|
649
725
|
}
|
|
650
726
|
} else {
|
|
651
727
|
receiver = this.transform(expr2.receiver, restContext);
|
|
652
728
|
}
|
|
653
729
|
invariant2(SelectQueryNode.is(receiver), "expected receiver to be select query");
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
730
|
+
if (startType === void 0) {
|
|
731
|
+
if (ExpressionUtils3.isField(expr2.receiver)) {
|
|
732
|
+
const receiverField = QueryUtils.requireField(this.schema, context.modelOrType, expr2.receiver.field);
|
|
733
|
+
startType = receiverField.type;
|
|
734
|
+
} else {
|
|
735
|
+
startType = context.thisType;
|
|
736
|
+
}
|
|
660
737
|
}
|
|
661
738
|
const memberFields = [];
|
|
662
739
|
let currType = startType;
|
|
@@ -707,6 +784,11 @@ var ExpressionTransformer = class {
|
|
|
707
784
|
]
|
|
708
785
|
};
|
|
709
786
|
}
|
|
787
|
+
requireBindingScope(expr2, context) {
|
|
788
|
+
const binding = context.bindingScope?.[expr2.name];
|
|
789
|
+
invariant2(binding, `binding not found: ${expr2.name}`);
|
|
790
|
+
return binding;
|
|
791
|
+
}
|
|
710
792
|
valueMemberAccess(receiver, expr2, receiverType) {
|
|
711
793
|
if (!receiver) {
|
|
712
794
|
return ValueNode2.createImmediate(null);
|
|
@@ -772,6 +854,19 @@ var ExpressionTransformer = class {
|
|
|
772
854
|
}
|
|
773
855
|
return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel);
|
|
774
856
|
}
|
|
857
|
+
// convert transformer's binding scope to equivalent expression evaluator binding scope
|
|
858
|
+
getEvaluationBindingScope(scope) {
|
|
859
|
+
if (!scope) {
|
|
860
|
+
return void 0;
|
|
861
|
+
}
|
|
862
|
+
const result = {};
|
|
863
|
+
for (const [key, value] of Object.entries(scope)) {
|
|
864
|
+
if (value.value !== void 0) {
|
|
865
|
+
result[key] = value.value;
|
|
866
|
+
}
|
|
867
|
+
}
|
|
868
|
+
return Object.keys(result).length > 0 ? result : void 0;
|
|
869
|
+
}
|
|
775
870
|
buildDelegateBaseFieldSelect(model, modelAlias, field, baseModel) {
|
|
776
871
|
const idFields = QueryUtils.requireIdFields(this.client.$schema, model);
|
|
777
872
|
return {
|
|
@@ -896,6 +991,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
896
991
|
}
|
|
897
992
|
client;
|
|
898
993
|
dialect;
|
|
994
|
+
eb = expressionBuilder2();
|
|
899
995
|
constructor(client) {
|
|
900
996
|
super(), this.client = client;
|
|
901
997
|
this.dialect = getCrudDialect2(this.client.$schema, this.client.$options);
|
|
@@ -919,52 +1015,21 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
919
1015
|
if (UpdateQueryNode.is(node)) {
|
|
920
1016
|
await this.preUpdateCheck(mutationModel, node, proceed);
|
|
921
1017
|
}
|
|
922
|
-
const
|
|
1018
|
+
const needsPostUpdateCheck = UpdateQueryNode.is(node) && this.hasPostUpdatePolicies(mutationModel);
|
|
923
1019
|
let beforeUpdateInfo;
|
|
924
|
-
if (
|
|
925
|
-
beforeUpdateInfo = await this.loadBeforeUpdateEntities(
|
|
1020
|
+
if (needsPostUpdateCheck) {
|
|
1021
|
+
beforeUpdateInfo = await this.loadBeforeUpdateEntities(
|
|
1022
|
+
mutationModel,
|
|
1023
|
+
node.where,
|
|
1024
|
+
proceed,
|
|
1025
|
+
// force load pre-update entities if dialect doesn't support returning,
|
|
1026
|
+
// so we can rely on pre-update ids to read back updated entities
|
|
1027
|
+
!this.dialect.supportsReturning
|
|
1028
|
+
);
|
|
926
1029
|
}
|
|
927
1030
|
const result = await proceed(this.transformNode(node));
|
|
928
|
-
if (
|
|
929
|
-
|
|
930
|
-
invariant3(beforeUpdateInfo.rows.length === result.rows.length);
|
|
931
|
-
const idFields = QueryUtils2.requireIdFields(this.client.$schema, mutationModel);
|
|
932
|
-
for (const postRow of result.rows) {
|
|
933
|
-
const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
|
|
934
|
-
if (!beforeRow) {
|
|
935
|
-
throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match by id. If you have post-update policies on a model, updating id fields is not supported.");
|
|
936
|
-
}
|
|
937
|
-
}
|
|
938
|
-
}
|
|
939
|
-
const idConditions = this.buildIdConditions(mutationModel, result.rows);
|
|
940
|
-
const postUpdateFilter = this.buildPolicyFilter(mutationModel, void 0, "post-update");
|
|
941
|
-
const eb = expressionBuilder2();
|
|
942
|
-
const beforeUpdateTable = beforeUpdateInfo ? {
|
|
943
|
-
kind: "SelectQueryNode",
|
|
944
|
-
from: FromNode2.create([
|
|
945
|
-
ParensNode2.create(ValuesNode.create(beforeUpdateInfo.rows.map((r) => PrimitiveValueListNode.create(beforeUpdateInfo.fields.map((f) => r[f])))))
|
|
946
|
-
]),
|
|
947
|
-
selections: beforeUpdateInfo.fields.map((name, index) => {
|
|
948
|
-
const def = QueryUtils2.requireField(this.client.$schema, mutationModel, name);
|
|
949
|
-
const castedColumnRef = sql`CAST(${eb.ref(`column${index + 1}`)} as ${sql.raw(this.dialect.getFieldSqlType(def))})`.as(name);
|
|
950
|
-
return SelectionNode2.create(castedColumnRef.toOperationNode());
|
|
951
|
-
})
|
|
952
|
-
} : void 0;
|
|
953
|
-
const postUpdateQuery = eb.selectFrom(mutationModel).select(() => [
|
|
954
|
-
eb(eb.fn("COUNT", [
|
|
955
|
-
eb.lit(1)
|
|
956
|
-
]), "=", result.rows.length).as("$condition")
|
|
957
|
-
]).where(() => new ExpressionWrapper(conjunction(this.dialect, [
|
|
958
|
-
idConditions,
|
|
959
|
-
postUpdateFilter
|
|
960
|
-
]))).$if(!!beforeUpdateInfo, (qb) => qb.leftJoin(() => new ExpressionWrapper(beforeUpdateTable).as("$before"), (join) => {
|
|
961
|
-
const idFields = QueryUtils2.requireIdFields(this.client.$schema, mutationModel);
|
|
962
|
-
return idFields.reduce((acc, f) => acc.onRef(`${mutationModel}.${f}`, "=", `$before.${f}`), join);
|
|
963
|
-
}));
|
|
964
|
-
const postUpdateResult = await proceed(postUpdateQuery.toOperationNode());
|
|
965
|
-
if (!postUpdateResult.rows[0]?.$condition) {
|
|
966
|
-
throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS, "some or all updated rows failed to pass post-update policy check");
|
|
967
|
-
}
|
|
1031
|
+
if ((result.numAffectedRows ?? 0) > 0 && needsPostUpdateCheck) {
|
|
1032
|
+
await this.postUpdateCheck(mutationModel, beforeUpdateInfo, result, proceed);
|
|
968
1033
|
}
|
|
969
1034
|
if (!node.returning || this.onlyReturningId(node)) {
|
|
970
1035
|
return this.postProcessMutationResult(result, node);
|
|
@@ -1003,12 +1068,69 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
1003
1068
|
modelLevelFilter,
|
|
1004
1069
|
node.where?.where ?? trueNode(this.dialect)
|
|
1005
1070
|
]);
|
|
1006
|
-
const preUpdateCheckQuery =
|
|
1071
|
+
const preUpdateCheckQuery = this.eb.selectFrom(mutationModel).select((eb) => eb.fn.coalesce(eb.fn.sum(this.dialect.castInt(new ExpressionWrapper(logicalNot(this.dialect, fieldLevelFilter)))), eb.lit(0)).as("$filteredCount")).where(() => new ExpressionWrapper(updateFilter));
|
|
1007
1072
|
const preUpdateResult = await proceed(preUpdateCheckQuery.toOperationNode());
|
|
1008
1073
|
if (preUpdateResult.rows[0].$filteredCount > 0) {
|
|
1009
1074
|
throw createRejectedByPolicyError(mutationModel, RejectedByPolicyReason.NO_ACCESS, "some rows cannot be updated due to field policies");
|
|
1010
1075
|
}
|
|
1011
1076
|
}
|
|
1077
|
+
async postUpdateCheck(model, beforeUpdateInfo, updateResult, proceed) {
|
|
1078
|
+
let postUpdateRows;
|
|
1079
|
+
if (this.dialect.supportsReturning) {
|
|
1080
|
+
postUpdateRows = updateResult.rows;
|
|
1081
|
+
} else {
|
|
1082
|
+
invariant3(beforeUpdateInfo, "beforeUpdateInfo must be defined for dialects not supporting returning");
|
|
1083
|
+
const idConditions2 = this.buildIdConditions(model, beforeUpdateInfo.rows);
|
|
1084
|
+
const idFields = QueryUtils2.requireIdFields(this.client.$schema, model);
|
|
1085
|
+
const postUpdateQuery2 = {
|
|
1086
|
+
kind: "SelectQueryNode",
|
|
1087
|
+
from: FromNode2.create([
|
|
1088
|
+
TableNode3.create(model)
|
|
1089
|
+
]),
|
|
1090
|
+
where: WhereNode2.create(idConditions2),
|
|
1091
|
+
selections: idFields.map((field) => SelectionNode2.create(ColumnNode3.create(field)))
|
|
1092
|
+
};
|
|
1093
|
+
const postUpdateQueryResult = await proceed(postUpdateQuery2);
|
|
1094
|
+
postUpdateRows = postUpdateQueryResult.rows;
|
|
1095
|
+
}
|
|
1096
|
+
if (beforeUpdateInfo) {
|
|
1097
|
+
if (beforeUpdateInfo.rows.length !== postUpdateRows.length) {
|
|
1098
|
+
throw createRejectedByPolicyError(model, RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.");
|
|
1099
|
+
}
|
|
1100
|
+
const idFields = QueryUtils2.requireIdFields(this.client.$schema, model);
|
|
1101
|
+
for (const postRow of postUpdateRows) {
|
|
1102
|
+
const beforeRow = beforeUpdateInfo.rows.find((r) => idFields.every((f) => r[f] === postRow[f]));
|
|
1103
|
+
if (!beforeRow) {
|
|
1104
|
+
throw createRejectedByPolicyError(model, RejectedByPolicyReason.OTHER, "Before-update and after-update rows do not match. If you have post-update policies on a model, updating id fields is not supported.");
|
|
1105
|
+
}
|
|
1106
|
+
}
|
|
1107
|
+
}
|
|
1108
|
+
const idConditions = this.buildIdConditions(model, postUpdateRows);
|
|
1109
|
+
const postUpdateFilter = this.buildPolicyFilter(model, void 0, "post-update");
|
|
1110
|
+
const eb = expressionBuilder2();
|
|
1111
|
+
const needsBeforeUpdateJoin = !!beforeUpdateInfo?.fields;
|
|
1112
|
+
let beforeUpdateTable = void 0;
|
|
1113
|
+
if (needsBeforeUpdateJoin) {
|
|
1114
|
+
const fieldDefs = beforeUpdateInfo.fields.map((name) => QueryUtils2.requireField(this.client.$schema, model, name));
|
|
1115
|
+
const rows = beforeUpdateInfo.rows.map((r) => beforeUpdateInfo.fields.map((f) => r[f]));
|
|
1116
|
+
beforeUpdateTable = this.dialect.buildValuesTableSelect(fieldDefs, rows).toOperationNode();
|
|
1117
|
+
}
|
|
1118
|
+
const postUpdateQuery = eb.selectFrom(model).select(() => [
|
|
1119
|
+
eb(eb.fn("COUNT", [
|
|
1120
|
+
eb.lit(1)
|
|
1121
|
+
]), "=", Number(updateResult.numAffectedRows ?? 0)).as("$condition")
|
|
1122
|
+
]).where(() => new ExpressionWrapper(conjunction(this.dialect, [
|
|
1123
|
+
idConditions,
|
|
1124
|
+
postUpdateFilter
|
|
1125
|
+
]))).$if(needsBeforeUpdateJoin, (qb) => qb.leftJoin(() => new ExpressionWrapper(beforeUpdateTable).as("$before"), (join) => {
|
|
1126
|
+
const idFields = QueryUtils2.requireIdFields(this.client.$schema, model);
|
|
1127
|
+
return idFields.reduce((acc, f) => acc.onRef(`${model}.${f}`, "=", `$before.${f}`), join);
|
|
1128
|
+
}));
|
|
1129
|
+
const postUpdateResult = await proceed(postUpdateQuery.toOperationNode());
|
|
1130
|
+
if (!postUpdateResult.rows[0]?.$condition) {
|
|
1131
|
+
throw createRejectedByPolicyError(model, RejectedByPolicyReason.NO_ACCESS, "some or all updated rows failed to pass post-update policy check");
|
|
1132
|
+
}
|
|
1133
|
+
}
|
|
1012
1134
|
// #endregion
|
|
1013
1135
|
// #region Transformations
|
|
1014
1136
|
transformSelectQuery(node) {
|
|
@@ -1076,6 +1198,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
1076
1198
|
};
|
|
1077
1199
|
}
|
|
1078
1200
|
transformInsertQuery(node) {
|
|
1201
|
+
let processedNode = node;
|
|
1079
1202
|
let onConflict = node.onConflict;
|
|
1080
1203
|
if (onConflict?.updates) {
|
|
1081
1204
|
const { mutationModel, alias } = this.getMutationModel(node);
|
|
@@ -1094,11 +1217,36 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
1094
1217
|
updateWhere: WhereNode2.create(filter)
|
|
1095
1218
|
};
|
|
1096
1219
|
}
|
|
1220
|
+
processedNode = {
|
|
1221
|
+
...node,
|
|
1222
|
+
onConflict
|
|
1223
|
+
};
|
|
1224
|
+
}
|
|
1225
|
+
let onDuplicateKey = node.onDuplicateKey;
|
|
1226
|
+
if (onDuplicateKey?.updates) {
|
|
1227
|
+
const { mutationModel } = this.getMutationModel(node);
|
|
1228
|
+
const filterWithTableRef = this.buildPolicyFilter(mutationModel, void 0, "update");
|
|
1229
|
+
const filter = this.stripTableReferences(filterWithTableRef, mutationModel);
|
|
1230
|
+
const wrappedUpdates = onDuplicateKey.updates.map((update) => {
|
|
1231
|
+
const columnName = ColumnNode3.is(update.column) ? update.column.column.name : void 0;
|
|
1232
|
+
if (!columnName) {
|
|
1233
|
+
return update;
|
|
1234
|
+
}
|
|
1235
|
+
const wrappedValue = sql`IF(${new ExpressionWrapper(filter)}, ${new ExpressionWrapper(update.value)}, ${sql.ref(columnName)})`.toOperationNode();
|
|
1236
|
+
return {
|
|
1237
|
+
...update,
|
|
1238
|
+
value: wrappedValue
|
|
1239
|
+
};
|
|
1240
|
+
});
|
|
1241
|
+
onDuplicateKey = {
|
|
1242
|
+
...onDuplicateKey,
|
|
1243
|
+
updates: wrappedUpdates
|
|
1244
|
+
};
|
|
1245
|
+
processedNode = {
|
|
1246
|
+
...processedNode,
|
|
1247
|
+
onDuplicateKey
|
|
1248
|
+
};
|
|
1097
1249
|
}
|
|
1098
|
-
const processedNode = onConflict ? {
|
|
1099
|
-
...node,
|
|
1100
|
-
onConflict
|
|
1101
|
-
} : node;
|
|
1102
1250
|
const result = super.transformInsertQuery(processedNode);
|
|
1103
1251
|
let returning = result.returning;
|
|
1104
1252
|
if (returning) {
|
|
@@ -1126,7 +1274,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
1126
1274
|
}
|
|
1127
1275
|
}
|
|
1128
1276
|
let returning = result.returning;
|
|
1129
|
-
if (returning || this.hasPostUpdatePolicies(mutationModel)) {
|
|
1277
|
+
if (this.dialect.supportsReturning && (returning || this.hasPostUpdatePolicies(mutationModel))) {
|
|
1130
1278
|
const idFields = QueryUtils2.requireIdFields(this.client.$schema, mutationModel);
|
|
1131
1279
|
returning = ReturningNode.create(idFields.map((f) => SelectionNode2.create(ColumnNode3.create(f))));
|
|
1132
1280
|
}
|
|
@@ -1163,9 +1311,9 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
1163
1311
|
}
|
|
1164
1312
|
// #endregion
|
|
1165
1313
|
// #region post-update
|
|
1166
|
-
async loadBeforeUpdateEntities(model, where, proceed) {
|
|
1314
|
+
async loadBeforeUpdateEntities(model, where, proceed, forceLoad = false) {
|
|
1167
1315
|
const beforeUpdateAccessFields = this.getFieldsAccessForBeforeUpdatePolicies(model);
|
|
1168
|
-
if (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0) {
|
|
1316
|
+
if (!forceLoad && (!beforeUpdateAccessFields || beforeUpdateAccessFields.length === 0)) {
|
|
1169
1317
|
return void 0;
|
|
1170
1318
|
}
|
|
1171
1319
|
const policyFilter = this.buildPolicyFilter(model, model, "update");
|
|
@@ -1173,15 +1321,14 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
1173
1321
|
where.where,
|
|
1174
1322
|
policyFilter
|
|
1175
1323
|
]) : policyFilter;
|
|
1324
|
+
const selections = beforeUpdateAccessFields ?? QueryUtils2.requireIdFields(this.client.$schema, model);
|
|
1176
1325
|
const query = {
|
|
1177
1326
|
kind: "SelectQueryNode",
|
|
1178
1327
|
from: FromNode2.create([
|
|
1179
1328
|
TableNode3.create(model)
|
|
1180
1329
|
]),
|
|
1181
1330
|
where: WhereNode2.create(combinedFilter),
|
|
1182
|
-
selections:
|
|
1183
|
-
...beforeUpdateAccessFields.map((f) => SelectionNode2.create(ColumnNode3.create(f)))
|
|
1184
|
-
]
|
|
1331
|
+
selections: selections.map((f) => SelectionNode2.create(ColumnNode3.create(f)))
|
|
1185
1332
|
};
|
|
1186
1333
|
const result = await proceed(query);
|
|
1187
1334
|
return {
|
|
@@ -1361,43 +1508,24 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
1361
1508
|
}
|
|
1362
1509
|
}
|
|
1363
1510
|
async enforcePreCreatePolicyForOne(model, fields, values, proceed) {
|
|
1364
|
-
const allFields =
|
|
1511
|
+
const allFields = QueryUtils2.getModelFields(this.client.$schema, model, {
|
|
1512
|
+
inherited: true
|
|
1513
|
+
});
|
|
1365
1514
|
const allValues = [];
|
|
1366
|
-
for (const
|
|
1367
|
-
const index = fields.indexOf(name);
|
|
1515
|
+
for (const def of allFields) {
|
|
1516
|
+
const index = fields.indexOf(def.name);
|
|
1368
1517
|
if (index >= 0) {
|
|
1369
|
-
allValues.push(values[index]);
|
|
1518
|
+
allValues.push(new ExpressionWrapper(values[index]));
|
|
1370
1519
|
} else {
|
|
1371
|
-
allValues.push(
|
|
1520
|
+
allValues.push(this.eb.lit(null));
|
|
1372
1521
|
}
|
|
1373
1522
|
}
|
|
1374
|
-
const
|
|
1375
|
-
|
|
1376
|
-
|
|
1377
|
-
from: FromNode2.create([
|
|
1378
|
-
AliasNode3.create(ParensNode2.create(ValuesNode.create([
|
|
1379
|
-
ValueListNode2.create(allValues)
|
|
1380
|
-
])), IdentifierNode2.create("$t"))
|
|
1381
|
-
]),
|
|
1382
|
-
selections: allFields.map(([name, def], index) => {
|
|
1383
|
-
const castedColumnRef = sql`CAST(${eb.ref(`column${index + 1}`)} as ${sql.raw(this.dialect.getFieldSqlType(def))})`.as(name);
|
|
1384
|
-
return SelectionNode2.create(castedColumnRef.toOperationNode());
|
|
1385
|
-
})
|
|
1386
|
-
};
|
|
1523
|
+
const valuesTable = this.dialect.buildValuesTableSelect(allFields, [
|
|
1524
|
+
allValues
|
|
1525
|
+
]);
|
|
1387
1526
|
const filter = this.buildPolicyFilter(model, void 0, "create");
|
|
1388
|
-
const preCreateCheck =
|
|
1389
|
-
|
|
1390
|
-
from: FromNode2.create([
|
|
1391
|
-
AliasNode3.create(constTable, IdentifierNode2.create(model))
|
|
1392
|
-
]),
|
|
1393
|
-
selections: [
|
|
1394
|
-
SelectionNode2.create(AliasNode3.create(BinaryOperationNode3.create(FunctionNode3.create("COUNT", [
|
|
1395
|
-
ValueNode3.createImmediate(1)
|
|
1396
|
-
]), OperatorNode3.create(">"), ValueNode3.createImmediate(0)), IdentifierNode2.create("$condition")))
|
|
1397
|
-
],
|
|
1398
|
-
where: WhereNode2.create(filter)
|
|
1399
|
-
};
|
|
1400
|
-
const result = await proceed(preCreateCheck);
|
|
1527
|
+
const preCreateCheck = this.eb.selectFrom(valuesTable.as(model)).select(this.eb(this.eb.fn.count(this.eb.lit(1)), ">", 0).as("$condition")).where(() => new ExpressionWrapper(filter));
|
|
1528
|
+
const result = await proceed(preCreateCheck.toOperationNode());
|
|
1401
1529
|
if (!result.rows[0]?.$condition) {
|
|
1402
1530
|
throw createRejectedByPolicyError(model, RejectedByPolicyReason.NO_ACCESS);
|
|
1403
1531
|
}
|
|
@@ -1422,18 +1550,18 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
1422
1550
|
const fieldDef = QueryUtils2.requireField(this.client.$schema, model, fields[i]);
|
|
1423
1551
|
invariant3(item.kind === "ValueNode", "expecting a ValueNode");
|
|
1424
1552
|
result.push({
|
|
1425
|
-
node: ValueNode3.create(this.dialect.
|
|
1553
|
+
node: ValueNode3.create(this.dialect.transformInput(item.value, fieldDef.type, !!fieldDef.array)),
|
|
1426
1554
|
raw: item.value
|
|
1427
1555
|
});
|
|
1428
1556
|
} else {
|
|
1429
1557
|
let value = item;
|
|
1430
1558
|
if (!isImplicitManyToManyJoinTable) {
|
|
1431
1559
|
const fieldDef = QueryUtils2.requireField(this.client.$schema, model, fields[i]);
|
|
1432
|
-
value = this.dialect.
|
|
1560
|
+
value = this.dialect.transformInput(item, fieldDef.type, !!fieldDef.array);
|
|
1433
1561
|
}
|
|
1434
1562
|
if (Array.isArray(value)) {
|
|
1435
1563
|
result.push({
|
|
1436
|
-
node:
|
|
1564
|
+
node: this.dialect.buildArrayLiteralSQL(value).toOperationNode(),
|
|
1437
1565
|
raw: value
|
|
1438
1566
|
});
|
|
1439
1567
|
} else {
|
|
@@ -1681,11 +1809,10 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
1681
1809
|
return void 0;
|
|
1682
1810
|
}
|
|
1683
1811
|
const checkForOperation = operation === "read" ? "read" : "update";
|
|
1684
|
-
const eb = expressionBuilder2();
|
|
1685
1812
|
const joinTable = alias ?? tableName;
|
|
1686
|
-
const aQuery = eb.selectFrom(m2m.firstModel).whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, "=", `${joinTable}.A`).select(() => new ExpressionWrapper(this.buildPolicyFilter(m2m.firstModel, void 0, checkForOperation)).as("$conditionA"));
|
|
1687
|
-
const bQuery = eb.selectFrom(m2m.secondModel).whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, "=", `${joinTable}.B`).select(() => new ExpressionWrapper(this.buildPolicyFilter(m2m.secondModel, void 0, checkForOperation)).as("$conditionB"));
|
|
1688
|
-
return eb.and([
|
|
1813
|
+
const aQuery = this.eb.selectFrom(m2m.firstModel).whereRef(`${m2m.firstModel}.${m2m.firstIdField}`, "=", `${joinTable}.A`).select(() => new ExpressionWrapper(this.buildPolicyFilter(m2m.firstModel, void 0, checkForOperation)).as("$conditionA"));
|
|
1814
|
+
const bQuery = this.eb.selectFrom(m2m.secondModel).whereRef(`${m2m.secondModel}.${m2m.secondIdField}`, "=", `${joinTable}.B`).select(() => new ExpressionWrapper(this.buildPolicyFilter(m2m.secondModel, void 0, checkForOperation)).as("$conditionB"));
|
|
1815
|
+
return this.eb.and([
|
|
1689
1816
|
aQuery,
|
|
1690
1817
|
bQuery
|
|
1691
1818
|
]).toOperationNode();
|
|
@@ -1716,6 +1843,26 @@ var PolicyHandler = class extends OperationNodeTransformer {
|
|
|
1716
1843
|
};
|
|
1717
1844
|
}
|
|
1718
1845
|
}
|
|
1846
|
+
// strips table references from an OperationNode
|
|
1847
|
+
stripTableReferences(node, modelName) {
|
|
1848
|
+
return new TableReferenceStripper().strip(node, modelName);
|
|
1849
|
+
}
|
|
1850
|
+
};
|
|
1851
|
+
var TableReferenceStripper = class TableReferenceStripper2 extends OperationNodeTransformer {
|
|
1852
|
+
static {
|
|
1853
|
+
__name(this, "TableReferenceStripper");
|
|
1854
|
+
}
|
|
1855
|
+
tableName = "";
|
|
1856
|
+
strip(node, tableName) {
|
|
1857
|
+
this.tableName = tableName;
|
|
1858
|
+
return this.transformNode(node);
|
|
1859
|
+
}
|
|
1860
|
+
transformReference(node) {
|
|
1861
|
+
if (ColumnNode3.is(node.column) && node.table?.table.identifier.name === this.tableName) {
|
|
1862
|
+
return ReferenceNode3.create(this.transformNode(node.column));
|
|
1863
|
+
}
|
|
1864
|
+
return super.transformReference(node);
|
|
1865
|
+
}
|
|
1719
1866
|
};
|
|
1720
1867
|
|
|
1721
1868
|
// src/functions.ts
|
|
@@ -1764,7 +1911,7 @@ var check = /* @__PURE__ */ __name((eb, args, { client, model, modelAlias, opera
|
|
|
1764
1911
|
const policyHandler = new PolicyHandler(client);
|
|
1765
1912
|
const op = arg2Node ? arg2Node.value : operation;
|
|
1766
1913
|
const policyCondition = policyHandler.buildPolicyFilter(relationModel, void 0, op);
|
|
1767
|
-
const result = eb.selectFrom(relationModel).where(joinCondition).select(new ExpressionWrapper2(policyCondition).as("$condition"));
|
|
1914
|
+
const result = eb.selectFrom(eb.selectFrom(relationModel).where(joinCondition).select(new ExpressionWrapper2(policyCondition).as("$condition")).as("$sub")).selectAll();
|
|
1768
1915
|
return result;
|
|
1769
1916
|
}, "check");
|
|
1770
1917
|
|