@zenstackhq/plugin-policy 3.0.0-beta.25 → 3.0.0-beta.27

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -8,7 +8,7 @@ import { ExpressionWrapper as ExpressionWrapper2, ValueNode as ValueNode4 } from
8
8
 
9
9
  // src/policy-handler.ts
10
10
  import { invariant as invariant3 } from "@zenstackhq/common-helpers";
11
- import { getCrudDialect as getCrudDialect2, QueryUtils as QueryUtils2, RejectedByPolicyReason, SchemaUtils } from "@zenstackhq/orm";
11
+ import { getCrudDialect as getCrudDialect2, QueryUtils as QueryUtils2, RejectedByPolicyReason, SchemaUtils as SchemaUtils2 } from "@zenstackhq/orm";
12
12
  import { ExpressionUtils as ExpressionUtils4 } from "@zenstackhq/orm/schema";
13
13
  import { AliasNode as AliasNode3, BinaryOperationNode as BinaryOperationNode3, ColumnNode as ColumnNode2, DeleteQueryNode, expressionBuilder as expressionBuilder2, ExpressionWrapper, FromNode as FromNode2, FunctionNode as FunctionNode3, IdentifierNode as IdentifierNode2, InsertQueryNode, OperationNodeTransformer, OperatorNode as OperatorNode3, ParensNode as ParensNode2, PrimitiveValueListNode, RawNode, ReferenceNode as ReferenceNode3, ReturningNode, SelectAllNode, SelectionNode as SelectionNode2, SelectQueryNode as SelectQueryNode2, sql, TableNode as TableNode3, UpdateQueryNode, ValueListNode as ValueListNode2, ValueNode as ValueNode3, ValuesNode, WhereNode as WhereNode2 } from "kysely";
14
14
  import { match as match3 } from "ts-pattern";
@@ -34,7 +34,7 @@ var ColumnCollector = class extends KyselyUtils.DefaultOperationNodeVisitor {
34
34
 
35
35
  // src/expression-transformer.ts
36
36
  import { invariant as invariant2 } from "@zenstackhq/common-helpers";
37
- import { getCrudDialect, QueryUtils } from "@zenstackhq/orm";
37
+ import { getCrudDialect, QueryUtils, SchemaUtils } from "@zenstackhq/orm";
38
38
  import { ExpressionUtils as ExpressionUtils3 } from "@zenstackhq/orm/schema";
39
39
  import { AliasNode as AliasNode2, BinaryOperationNode as BinaryOperationNode2, ColumnNode, expressionBuilder, FromNode, FunctionNode as FunctionNode2, IdentifierNode, OperatorNode as OperatorNode2, ReferenceNode as ReferenceNode2, SelectionNode, SelectQueryNode, TableNode as TableNode2, ValueListNode, ValueNode as ValueNode2, WhereNode } from "kysely";
40
40
  import { match as match2 } from "ts-pattern";
@@ -83,6 +83,12 @@ var ExpressionEvaluator = class {
83
83
  }
84
84
  const left = this.evaluate(expr2.left, context);
85
85
  const right = this.evaluate(expr2.right, context);
86
+ if (![
87
+ "==",
88
+ "!="
89
+ ].includes(expr2.op) && (left === null || right === null)) {
90
+ return null;
91
+ }
86
92
  return match(expr2.op).with("==", () => left === right).with("!=", () => left !== right).with(">", () => left > right).with(">=", () => left >= right).with("<", () => left < right).with("<=", () => left <= right).with("&&", () => left && right).with("||", () => left || right).with("in", () => {
87
93
  const _right = right ?? [];
88
94
  invariant(Array.isArray(_right), 'expected array for "in" operator');
@@ -93,8 +99,8 @@ var ExpressionEvaluator = class {
93
99
  const op = expr2.op;
94
100
  invariant(op === "?" || op === "!" || op === "^", 'expected "?" or "!" or "^" operator');
95
101
  const left = this.evaluate(expr2.left, context);
96
- if (!left) {
97
- return false;
102
+ if (left === null || left === void 0) {
103
+ return null;
98
104
  }
99
105
  invariant(Array.isArray(left), "expected array");
100
106
  return match(op).with("?", () => left.some((item) => this.evaluate(expr2.right, {
@@ -110,6 +116,13 @@ var ExpressionEvaluator = class {
110
116
  }
111
117
  };
112
118
 
119
+ // src/types.ts
120
+ var CollectionPredicateOperator = [
121
+ "?",
122
+ "!",
123
+ "^"
124
+ ];
125
+
113
126
  // src/utils.ts
114
127
  import { ORMError, ORMErrorReason } from "@zenstackhq/orm";
115
128
  import { ExpressionUtils as ExpressionUtils2 } from "@zenstackhq/orm/schema";
@@ -286,7 +299,11 @@ var ExpressionTransformer = class {
286
299
  return ValueListNode.create(expr2.items.map((item) => this.transform(item, context)));
287
300
  }
288
301
  _field(expr2, context) {
289
- const fieldDef = QueryUtils.requireField(this.schema, context.model, expr2.field);
302
+ if (context.contextValue) {
303
+ const fieldDef2 = QueryUtils.requireField(this.schema, context.modelOrType, expr2.field);
304
+ return this.transformValue(context.contextValue[expr2.field], fieldDef2.type);
305
+ }
306
+ const fieldDef = QueryUtils.requireField(this.schema, context.modelOrType, expr2.field);
290
307
  if (!fieldDef.relation) {
291
308
  return this.createColumnRef(expr2.field, context);
292
309
  } else {
@@ -360,30 +377,39 @@ var ExpressionTransformer = class {
360
377
  }
361
378
  }
362
379
  transformNullCheck(expr2, operator) {
363
- invariant2(operator === "==" || operator === "!=", 'operator must be "==" or "!=" for null comparison');
364
- if (ValueNode2.is(expr2)) {
365
- if (expr2.value === null) {
366
- return operator === "==" ? trueNode(this.dialect) : falseNode(this.dialect);
380
+ if (operator === "==" || operator === "!=") {
381
+ if (ValueNode2.is(expr2)) {
382
+ if (expr2.value === null) {
383
+ return operator === "==" ? trueNode(this.dialect) : falseNode(this.dialect);
384
+ } else {
385
+ return operator === "==" ? falseNode(this.dialect) : trueNode(this.dialect);
386
+ }
367
387
  } else {
368
- return operator === "==" ? falseNode(this.dialect) : trueNode(this.dialect);
388
+ return operator === "==" ? BinaryOperationNode2.create(expr2, OperatorNode2.create("is"), ValueNode2.createImmediate(null)) : BinaryOperationNode2.create(expr2, OperatorNode2.create("is not"), ValueNode2.createImmediate(null));
369
389
  }
370
390
  } else {
371
- return operator === "==" ? BinaryOperationNode2.create(expr2, OperatorNode2.create("is"), ValueNode2.createImmediate(null)) : BinaryOperationNode2.create(expr2, OperatorNode2.create("is not"), ValueNode2.createImmediate(null));
391
+ return ValueNode2.createImmediate(null);
372
392
  }
373
393
  }
374
394
  normalizeBinaryOperationOperands(expr2, context) {
395
+ if (context.contextValue) {
396
+ return {
397
+ normalizedLeft: expr2.left,
398
+ normalizedRight: expr2.right
399
+ };
400
+ }
375
401
  let normalizedLeft = expr2.left;
376
- if (this.isRelationField(expr2.left, context.model)) {
402
+ if (this.isRelationField(expr2.left, context.modelOrType)) {
377
403
  invariant2(ExpressionUtils3.isNull(expr2.right), "only null comparison is supported for relation field");
378
- const leftRelDef = this.getFieldDefFromFieldRef(expr2.left, context.model);
404
+ const leftRelDef = this.getFieldDefFromFieldRef(expr2.left, context.modelOrType);
379
405
  invariant2(leftRelDef, "failed to get relation field definition");
380
406
  const idFields = QueryUtils.requireIdFields(this.schema, leftRelDef.type);
381
407
  normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]);
382
408
  }
383
409
  let normalizedRight = expr2.right;
384
- if (this.isRelationField(expr2.right, context.model)) {
410
+ if (this.isRelationField(expr2.right, context.modelOrType)) {
385
411
  invariant2(ExpressionUtils3.isNull(expr2.left), "only null comparison is supported for relation field");
386
- const rightRelDef = this.getFieldDefFromFieldRef(expr2.right, context.model);
412
+ const rightRelDef = this.getFieldDefFromFieldRef(expr2.right, context.modelOrType);
387
413
  invariant2(rightRelDef, "failed to get relation field definition");
388
414
  const idFields = QueryUtils.requireIdFields(this.schema, rightRelDef.type);
389
415
  normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]);
@@ -394,22 +420,30 @@ var ExpressionTransformer = class {
394
420
  };
395
421
  }
396
422
  transformCollectionPredicate(expr2, context) {
397
- invariant2(expr2.op === "?" || expr2.op === "!" || expr2.op === "^", 'expected "?" or "!" or "^" operator');
398
- if (this.isAuthCall(expr2.left) || this.isAuthMember(expr2.left)) {
399
- const value = new ExpressionEvaluator().evaluate(expr2, {
423
+ this.ensureCollectionPredicateOperator(expr2.op);
424
+ if (this.isAuthMember(expr2.left) || context.contextValue) {
425
+ invariant2(ExpressionUtils3.isMember(expr2.left) || ExpressionUtils3.isField(expr2.left), "expected member or field expression");
426
+ const evaluator = new ExpressionEvaluator();
427
+ const receiver = evaluator.evaluate(expr2.left, {
428
+ thisValue: context.contextValue,
400
429
  auth: this.auth
401
430
  });
402
- return this.transformValue(value, "Boolean");
431
+ const baseType = this.isAuthMember(expr2.left) ? this.authType : context.modelOrType;
432
+ const memberType = this.getMemberType(baseType, expr2.left);
433
+ return this.transformValueCollectionPredicate(receiver, expr2, {
434
+ ...context,
435
+ modelOrType: memberType
436
+ });
403
437
  }
404
438
  invariant2(ExpressionUtils3.isField(expr2.left) || ExpressionUtils3.isMember(expr2.left), "left operand must be field or member access");
405
439
  let newContextModel;
406
- const fieldDef = this.getFieldDefFromFieldRef(expr2.left, context.model);
440
+ const fieldDef = this.getFieldDefFromFieldRef(expr2.left, context.modelOrType);
407
441
  if (fieldDef) {
408
442
  invariant2(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr2.left)}`);
409
443
  newContextModel = fieldDef.type;
410
444
  } else {
411
445
  invariant2(ExpressionUtils3.isMember(expr2.left) && ExpressionUtils3.isField(expr2.left.receiver), "left operand must be member access with field receiver");
412
- const fieldDef2 = QueryUtils.requireField(this.schema, context.model, expr2.left.receiver.field);
446
+ const fieldDef2 = QueryUtils.requireField(this.schema, context.modelOrType, expr2.left.receiver.field);
413
447
  newContextModel = fieldDef2.type;
414
448
  for (const member of expr2.left.members) {
415
449
  const memberDef = QueryUtils.requireField(this.schema, newContextModel, member);
@@ -418,7 +452,7 @@ var ExpressionTransformer = class {
418
452
  }
419
453
  let predicateFilter = this.transform(expr2.right, {
420
454
  ...context,
421
- model: newContextModel,
455
+ modelOrType: newContextModel,
422
456
  alias: void 0
423
457
  });
424
458
  if (expr2.op === "!") {
@@ -434,9 +468,49 @@ var ExpressionTransformer = class {
434
468
  memberFilter: predicateFilter
435
469
  });
436
470
  }
471
+ ensureCollectionPredicateOperator(op) {
472
+ invariant2(CollectionPredicateOperator.includes(op), 'expected "?" or "!" or "^" operator');
473
+ }
474
+ transformValueCollectionPredicate(receiver, expr2, context) {
475
+ if (!receiver) {
476
+ return ValueNode2.createImmediate(null);
477
+ }
478
+ this.ensureCollectionPredicateOperator(expr2.op);
479
+ const visitor = new SchemaUtils.MatchingExpressionVisitor((e) => ExpressionUtils3.isThis(e));
480
+ if (!visitor.find(expr2.right)) {
481
+ const value = new ExpressionEvaluator().evaluate(expr2, {
482
+ auth: this.auth,
483
+ thisValue: context.contextValue
484
+ });
485
+ return this.transformValue(value, "Boolean");
486
+ } else {
487
+ invariant2(Array.isArray(receiver), "array value is expected");
488
+ const components = receiver.map((item) => this.transform(expr2.right, {
489
+ operation: context.operation,
490
+ thisType: context.thisType,
491
+ thisAlias: context.thisAlias,
492
+ modelOrType: context.modelOrType,
493
+ contextValue: item
494
+ }));
495
+ return match2(expr2.op).with("?", () => disjunction(this.dialect, components)).with("!", () => conjunction(this.dialect, components)).with("^", () => logicalNot(this.dialect, disjunction(this.dialect, components))).exhaustive();
496
+ }
497
+ }
498
+ getMemberType(receiverType, expr2) {
499
+ if (ExpressionUtils3.isField(expr2)) {
500
+ const fieldDef = QueryUtils.requireField(this.schema, receiverType, expr2.field);
501
+ return fieldDef.type;
502
+ } else {
503
+ let currType = receiverType;
504
+ for (const member of expr2.members) {
505
+ const fieldDef = QueryUtils.requireField(this.schema, currType, member);
506
+ currType = fieldDef.type;
507
+ }
508
+ return currType;
509
+ }
510
+ }
437
511
  transformAuthBinary(expr2, context) {
438
512
  if (expr2.op !== "==" && expr2.op !== "!=") {
439
- throw createUnsupportedError(`Unsupported operator for \`auth()\` in policy of model "${context.model}": ${expr2.op}`);
513
+ throw createUnsupportedError(`Unsupported operator for \`auth()\` in policy of model "${context.modelOrType}": ${expr2.op}`);
440
514
  }
441
515
  let authExpr;
442
516
  let other;
@@ -452,7 +526,7 @@ var ExpressionTransformer = class {
452
526
  } else {
453
527
  const authModel = QueryUtils.getModel(this.schema, this.authType);
454
528
  if (!authModel) {
455
- throw createUnsupportedError(`Unsupported use of \`auth()\` in policy of model "${context.model}", comparing with \`auth()\` is only possible when auth type is a model`);
529
+ throw createUnsupportedError(`Unsupported use of \`auth()\` in policy of model "${context.modelOrType}", comparing with \`auth()\` is only possible when auth type is a model`);
456
530
  }
457
531
  const idFields = Object.values(authModel.fields).filter((f) => f.id).map((f) => f.name);
458
532
  invariant2(idFields.length > 0, "auth type model must have at least one id field");
@@ -484,7 +558,12 @@ var ExpressionTransformer = class {
484
558
  } else if (value === false) {
485
559
  return falseNode(this.dialect);
486
560
  } else {
487
- return ValueNode2.create(this.dialect.transformPrimitive(value, type, false) ?? null);
561
+ const transformed = this.dialect.transformPrimitive(value, type, false) ?? null;
562
+ if (!Array.isArray(transformed)) {
563
+ return ValueNode2.createImmediate(transformed);
564
+ } else {
565
+ return ValueNode2.create(transformed);
566
+ }
488
567
  }
489
568
  }
490
569
  _unary(expr2, context) {
@@ -508,8 +587,8 @@ var ExpressionTransformer = class {
508
587
  return func(eb, (expr2.args ?? []).map((arg) => this.transformCallArg(eb, arg, context)), {
509
588
  client: this.client,
510
589
  dialect: this.dialect,
511
- model: context.model,
512
- modelAlias: context.alias ?? context.model,
590
+ model: context.modelOrType,
591
+ modelAlias: context.alias ?? context.modelOrType,
513
592
  operation: context.operation
514
593
  });
515
594
  }
@@ -556,9 +635,15 @@ var ExpressionTransformer = class {
556
635
  const { memberFilter, memberSelect, ...restContext } = context;
557
636
  if (ExpressionUtils3.isThis(expr2.receiver)) {
558
637
  if (expr2.members.length === 1) {
559
- return this._field(ExpressionUtils3.field(expr2.members[0]), context);
638
+ return this._field(ExpressionUtils3.field(expr2.members[0]), {
639
+ ...context,
640
+ alias: context.thisAlias,
641
+ modelOrType: context.thisType,
642
+ thisType: context.thisType,
643
+ contextValue: void 0
644
+ });
560
645
  } else {
561
- const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.model, expr2.members[0]);
646
+ const firstMemberFieldDef = QueryUtils.requireField(this.schema, context.thisType, expr2.members[0]);
562
647
  receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, restContext);
563
648
  members = expr2.members.slice(1);
564
649
  }
@@ -568,10 +653,10 @@ var ExpressionTransformer = class {
568
653
  invariant2(SelectQueryNode.is(receiver), "expected receiver to be select query");
569
654
  let startType;
570
655
  if (ExpressionUtils3.isField(expr2.receiver)) {
571
- const receiverField = QueryUtils.requireField(this.schema, context.model, expr2.receiver.field);
656
+ const receiverField = QueryUtils.requireField(this.schema, context.modelOrType, expr2.receiver.field);
572
657
  startType = receiverField.type;
573
658
  } else {
574
- startType = context.model;
659
+ startType = context.thisType;
575
660
  }
576
661
  const memberFields = [];
577
662
  let currType = startType;
@@ -590,7 +675,7 @@ var ExpressionTransformer = class {
590
675
  if (fieldDef.relation) {
591
676
  const relation = this.transformRelationAccess(member, fieldDef.type, {
592
677
  ...restContext,
593
- model: fromModel,
678
+ modelOrType: fromModel,
594
679
  alias: void 0
595
680
  });
596
681
  if (currNode) {
@@ -626,20 +711,29 @@ var ExpressionTransformer = class {
626
711
  if (!receiver) {
627
712
  return ValueNode2.createImmediate(null);
628
713
  }
629
- if (expr2.members.length !== 1) {
630
- throw new Error(`Only single member access is supported`);
714
+ invariant2(expr2.members.length > 0, "member expression must have at least one member");
715
+ let curr = receiver;
716
+ let currType = receiverType;
717
+ for (let i = 0; i < expr2.members.length; i++) {
718
+ const field = expr2.members[i];
719
+ curr = curr?.[field];
720
+ if (curr === void 0) {
721
+ curr = ValueNode2.createImmediate(null);
722
+ break;
723
+ }
724
+ currType = QueryUtils.requireField(this.schema, currType, field).type;
725
+ if (i === expr2.members.length - 1) {
726
+ curr = this.transformValue(curr, currType);
727
+ }
631
728
  }
632
- const field = expr2.members[0];
633
- const fieldDef = QueryUtils.requireField(this.schema, receiverType, field);
634
- const fieldValue = receiver[field] ?? null;
635
- return this.transformValue(fieldValue, fieldDef.type);
729
+ return curr;
636
730
  }
637
731
  transformRelationAccess(field, relationModel, context) {
638
- const m2m = QueryUtils.getManyToManyRelation(this.schema, context.model, field);
732
+ const m2m = QueryUtils.getManyToManyRelation(this.schema, context.modelOrType, field);
639
733
  if (m2m) {
640
734
  return this.transformManyToManyRelationAccess(m2m, context);
641
735
  }
642
- const fromModel = context.model;
736
+ const fromModel = context.modelOrType;
643
737
  const relationFieldDef = QueryUtils.requireField(this.schema, fromModel, field);
644
738
  const { keyPairs, ownedByModel } = QueryUtils.getRelationForeignKeyFieldPairs(this.schema, fromModel, field);
645
739
  let condition;
@@ -664,19 +758,19 @@ var ExpressionTransformer = class {
664
758
  }
665
759
  transformManyToManyRelationAccess(m2m, context) {
666
760
  const eb = expressionBuilder();
667
- const relationQuery = eb.selectFrom(m2m.otherModel).innerJoin(m2m.joinTable, (join) => join.onRef(`${m2m.otherModel}.${m2m.otherPKName}`, "=", `${m2m.joinTable}.${m2m.otherFkName}`).onRef(`${m2m.joinTable}.${m2m.parentFkName}`, "=", `${context.alias ?? context.model}.${m2m.parentPKName}`));
761
+ const relationQuery = eb.selectFrom(m2m.otherModel).innerJoin(m2m.joinTable, (join) => join.onRef(`${m2m.otherModel}.${m2m.otherPKName}`, "=", `${m2m.joinTable}.${m2m.otherFkName}`).onRef(`${m2m.joinTable}.${m2m.parentFkName}`, "=", `${context.alias ?? context.modelOrType}.${m2m.parentPKName}`));
668
762
  return relationQuery.toOperationNode();
669
763
  }
670
764
  createColumnRef(column, context) {
671
- const tableName = context.alias ?? context.model;
765
+ const tableName = context.alias ?? context.modelOrType;
672
766
  if (context.operation === "create") {
673
767
  return ReferenceNode2.create(ColumnNode.create(column), TableNode2.create(tableName));
674
768
  }
675
- const fieldDef = QueryUtils.requireField(this.schema, context.model, column);
676
- if (!fieldDef.originModel || fieldDef.originModel === context.model) {
769
+ const fieldDef = QueryUtils.requireField(this.schema, context.modelOrType, column);
770
+ if (!fieldDef.originModel || fieldDef.originModel === context.modelOrType) {
677
771
  return ReferenceNode2.create(ColumnNode.create(column), TableNode2.create(tableName));
678
772
  }
679
- return this.buildDelegateBaseFieldSelect(context.model, tableName, column, fieldDef.originModel);
773
+ return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel);
680
774
  }
681
775
  buildDelegateBaseFieldSelect(model, modelAlias, field, baseModel) {
682
776
  const idFields = QueryUtils.requireIdFields(this.client.$schema, model);
@@ -718,9 +812,9 @@ var ExpressionTransformer = class {
718
812
  }
719
813
  getFieldDefFromFieldRef(expr2, model) {
720
814
  if (ExpressionUtils3.isField(expr2)) {
721
- return QueryUtils.requireField(this.schema, model, expr2.field);
815
+ return QueryUtils.getField(this.schema, model, expr2.field);
722
816
  } else if (ExpressionUtils3.isMember(expr2) && expr2.members.length === 1 && ExpressionUtils3.isThis(expr2.receiver)) {
723
- return QueryUtils.requireField(this.schema, model, expr2.members[0]);
817
+ return QueryUtils.getField(this.schema, model, expr2.members[0]);
724
818
  } else {
725
819
  return void 0;
726
820
  }
@@ -938,7 +1032,7 @@ var PolicyHandler = class extends OperationNodeTransformer {
938
1032
  return void 0;
939
1033
  }
940
1034
  const fields = /* @__PURE__ */ new Set();
941
- const fieldCollector = new class extends SchemaUtils.ExpressionVisitor {
1035
+ const fieldCollector = new class extends SchemaUtils2.ExpressionVisitor {
942
1036
  visitMember(e) {
943
1037
  if (isBeforeInvocation(e.receiver)) {
944
1038
  invariant3(e.members.length === 1, "before() can only be followed by a scalar field access");
@@ -1386,7 +1480,9 @@ var PolicyHandler = class extends OperationNodeTransformer {
1386
1480
  }
1387
1481
  compilePolicyCondition(model, alias, operation, policy) {
1388
1482
  return new ExpressionTransformer(this.client).transform(policy.condition, {
1389
- model,
1483
+ modelOrType: model,
1484
+ thisType: model,
1485
+ thisAlias: alias,
1390
1486
  alias,
1391
1487
  operation
1392
1488
  });