@zenstackhq/plugin-policy 3.0.0-beta.25 → 3.0.0-beta.26

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -107,6 +107,12 @@ var ExpressionEvaluator = class {
107
107
  }
108
108
  const left = this.evaluate(expr2.left, context);
109
109
  const right = this.evaluate(expr2.right, context);
110
+ if (![
111
+ "==",
112
+ "!="
113
+ ].includes(expr2.op) && (left === null || right === null)) {
114
+ return null;
115
+ }
110
116
  return (0, import_ts_pattern.match)(expr2.op).with("==", () => left === right).with("!=", () => left !== right).with(">", () => left > right).with(">=", () => left >= right).with("<", () => left < right).with("<=", () => left <= right).with("&&", () => left && right).with("||", () => left || right).with("in", () => {
111
117
  const _right = right ?? [];
112
118
  (0, import_common_helpers.invariant)(Array.isArray(_right), 'expected array for "in" operator');
@@ -117,8 +123,8 @@ var ExpressionEvaluator = class {
117
123
  const op = expr2.op;
118
124
  (0, import_common_helpers.invariant)(op === "?" || op === "!" || op === "^", 'expected "?" or "!" or "^" operator');
119
125
  const left = this.evaluate(expr2.left, context);
120
- if (!left) {
121
- return false;
126
+ if (left === null || left === void 0) {
127
+ return null;
122
128
  }
123
129
  (0, import_common_helpers.invariant)(Array.isArray(left), "expected array");
124
130
  return (0, import_ts_pattern.match)(op).with("?", () => left.some((item) => this.evaluate(expr2.right, {
@@ -134,6 +140,13 @@ var ExpressionEvaluator = class {
134
140
  }
135
141
  };
136
142
 
143
+ // src/types.ts
144
+ var CollectionPredicateOperator = [
145
+ "?",
146
+ "!",
147
+ "^"
148
+ ];
149
+
137
150
  // src/utils.ts
138
151
  var import_orm2 = require("@zenstackhq/orm");
139
152
  var import_schema2 = require("@zenstackhq/orm/schema");
@@ -310,7 +323,11 @@ var ExpressionTransformer = class {
310
323
  return import_kysely2.ValueListNode.create(expr2.items.map((item) => this.transform(item, context)));
311
324
  }
312
325
  _field(expr2, context) {
313
- const fieldDef = import_orm3.QueryUtils.requireField(this.schema, context.model, expr2.field);
326
+ if (context.contextValue) {
327
+ const fieldDef2 = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.field);
328
+ return this.transformValue(context.contextValue[expr2.field], fieldDef2.type);
329
+ }
330
+ const fieldDef = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.field);
314
331
  if (!fieldDef.relation) {
315
332
  return this.createColumnRef(expr2.field, context);
316
333
  } else {
@@ -384,30 +401,39 @@ var ExpressionTransformer = class {
384
401
  }
385
402
  }
386
403
  transformNullCheck(expr2, operator) {
387
- (0, import_common_helpers2.invariant)(operator === "==" || operator === "!=", 'operator must be "==" or "!=" for null comparison');
388
- if (import_kysely2.ValueNode.is(expr2)) {
389
- if (expr2.value === null) {
390
- return operator === "==" ? trueNode(this.dialect) : falseNode(this.dialect);
404
+ if (operator === "==" || operator === "!=") {
405
+ if (import_kysely2.ValueNode.is(expr2)) {
406
+ if (expr2.value === null) {
407
+ return operator === "==" ? trueNode(this.dialect) : falseNode(this.dialect);
408
+ } else {
409
+ return operator === "==" ? falseNode(this.dialect) : trueNode(this.dialect);
410
+ }
391
411
  } else {
392
- return operator === "==" ? falseNode(this.dialect) : trueNode(this.dialect);
412
+ return operator === "==" ? import_kysely2.BinaryOperationNode.create(expr2, import_kysely2.OperatorNode.create("is"), import_kysely2.ValueNode.createImmediate(null)) : import_kysely2.BinaryOperationNode.create(expr2, import_kysely2.OperatorNode.create("is not"), import_kysely2.ValueNode.createImmediate(null));
393
413
  }
394
414
  } else {
395
- return operator === "==" ? import_kysely2.BinaryOperationNode.create(expr2, import_kysely2.OperatorNode.create("is"), import_kysely2.ValueNode.createImmediate(null)) : import_kysely2.BinaryOperationNode.create(expr2, import_kysely2.OperatorNode.create("is not"), import_kysely2.ValueNode.createImmediate(null));
415
+ return import_kysely2.ValueNode.createImmediate(null);
396
416
  }
397
417
  }
398
418
  normalizeBinaryOperationOperands(expr2, context) {
419
+ if (context.contextValue) {
420
+ return {
421
+ normalizedLeft: expr2.left,
422
+ normalizedRight: expr2.right
423
+ };
424
+ }
399
425
  let normalizedLeft = expr2.left;
400
- if (this.isRelationField(expr2.left, context.model)) {
426
+ if (this.isRelationField(expr2.left, context.modelOrType)) {
401
427
  (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isNull(expr2.right), "only null comparison is supported for relation field");
402
- const leftRelDef = this.getFieldDefFromFieldRef(expr2.left, context.model);
428
+ const leftRelDef = this.getFieldDefFromFieldRef(expr2.left, context.modelOrType);
403
429
  (0, import_common_helpers2.invariant)(leftRelDef, "failed to get relation field definition");
404
430
  const idFields = import_orm3.QueryUtils.requireIdFields(this.schema, leftRelDef.type);
405
431
  normalizedLeft = this.makeOrAppendMember(normalizedLeft, idFields[0]);
406
432
  }
407
433
  let normalizedRight = expr2.right;
408
- if (this.isRelationField(expr2.right, context.model)) {
434
+ if (this.isRelationField(expr2.right, context.modelOrType)) {
409
435
  (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isNull(expr2.left), "only null comparison is supported for relation field");
410
- const rightRelDef = this.getFieldDefFromFieldRef(expr2.right, context.model);
436
+ const rightRelDef = this.getFieldDefFromFieldRef(expr2.right, context.modelOrType);
411
437
  (0, import_common_helpers2.invariant)(rightRelDef, "failed to get relation field definition");
412
438
  const idFields = import_orm3.QueryUtils.requireIdFields(this.schema, rightRelDef.type);
413
439
  normalizedRight = this.makeOrAppendMember(normalizedRight, idFields[0]);
@@ -418,22 +444,30 @@ var ExpressionTransformer = class {
418
444
  };
419
445
  }
420
446
  transformCollectionPredicate(expr2, context) {
421
- (0, import_common_helpers2.invariant)(expr2.op === "?" || expr2.op === "!" || expr2.op === "^", 'expected "?" or "!" or "^" operator');
422
- if (this.isAuthCall(expr2.left) || this.isAuthMember(expr2.left)) {
423
- const value = new ExpressionEvaluator().evaluate(expr2, {
447
+ this.ensureCollectionPredicateOperator(expr2.op);
448
+ if (this.isAuthMember(expr2.left) || context.contextValue) {
449
+ (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isMember(expr2.left) || import_schema3.ExpressionUtils.isField(expr2.left), "expected member or field expression");
450
+ const evaluator = new ExpressionEvaluator();
451
+ const receiver = evaluator.evaluate(expr2.left, {
452
+ thisValue: context.contextValue,
424
453
  auth: this.auth
425
454
  });
426
- return this.transformValue(value, "Boolean");
455
+ const baseType = this.isAuthMember(expr2.left) ? this.authType : context.modelOrType;
456
+ const memberType = this.getMemberType(baseType, expr2.left);
457
+ return this.transformValueCollectionPredicate(receiver, expr2, {
458
+ ...context,
459
+ modelOrType: memberType
460
+ });
427
461
  }
428
462
  (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isField(expr2.left) || import_schema3.ExpressionUtils.isMember(expr2.left), "left operand must be field or member access");
429
463
  let newContextModel;
430
- const fieldDef = this.getFieldDefFromFieldRef(expr2.left, context.model);
464
+ const fieldDef = this.getFieldDefFromFieldRef(expr2.left, context.modelOrType);
431
465
  if (fieldDef) {
432
466
  (0, import_common_helpers2.invariant)(fieldDef.relation, `field is not a relation: ${JSON.stringify(expr2.left)}`);
433
467
  newContextModel = fieldDef.type;
434
468
  } else {
435
469
  (0, import_common_helpers2.invariant)(import_schema3.ExpressionUtils.isMember(expr2.left) && import_schema3.ExpressionUtils.isField(expr2.left.receiver), "left operand must be member access with field receiver");
436
- const fieldDef2 = import_orm3.QueryUtils.requireField(this.schema, context.model, expr2.left.receiver.field);
470
+ const fieldDef2 = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.left.receiver.field);
437
471
  newContextModel = fieldDef2.type;
438
472
  for (const member of expr2.left.members) {
439
473
  const memberDef = import_orm3.QueryUtils.requireField(this.schema, newContextModel, member);
@@ -442,7 +476,7 @@ var ExpressionTransformer = class {
442
476
  }
443
477
  let predicateFilter = this.transform(expr2.right, {
444
478
  ...context,
445
- model: newContextModel,
479
+ modelOrType: newContextModel,
446
480
  alias: void 0
447
481
  });
448
482
  if (expr2.op === "!") {
@@ -458,9 +492,49 @@ var ExpressionTransformer = class {
458
492
  memberFilter: predicateFilter
459
493
  });
460
494
  }
495
+ ensureCollectionPredicateOperator(op) {
496
+ (0, import_common_helpers2.invariant)(CollectionPredicateOperator.includes(op), 'expected "?" or "!" or "^" operator');
497
+ }
498
+ transformValueCollectionPredicate(receiver, expr2, context) {
499
+ if (!receiver) {
500
+ return import_kysely2.ValueNode.createImmediate(null);
501
+ }
502
+ this.ensureCollectionPredicateOperator(expr2.op);
503
+ const visitor = new import_orm3.SchemaUtils.MatchingExpressionVisitor((e) => import_schema3.ExpressionUtils.isThis(e));
504
+ if (!visitor.find(expr2.right)) {
505
+ const value = new ExpressionEvaluator().evaluate(expr2, {
506
+ auth: this.auth,
507
+ thisValue: context.contextValue
508
+ });
509
+ return this.transformValue(value, "Boolean");
510
+ } else {
511
+ (0, import_common_helpers2.invariant)(Array.isArray(receiver), "array value is expected");
512
+ const components = receiver.map((item) => this.transform(expr2.right, {
513
+ operation: context.operation,
514
+ thisType: context.thisType,
515
+ thisAlias: context.thisAlias,
516
+ modelOrType: context.modelOrType,
517
+ contextValue: item
518
+ }));
519
+ return (0, import_ts_pattern2.match)(expr2.op).with("?", () => disjunction(this.dialect, components)).with("!", () => conjunction(this.dialect, components)).with("^", () => logicalNot(this.dialect, disjunction(this.dialect, components))).exhaustive();
520
+ }
521
+ }
522
+ getMemberType(receiverType, expr2) {
523
+ if (import_schema3.ExpressionUtils.isField(expr2)) {
524
+ const fieldDef = import_orm3.QueryUtils.requireField(this.schema, receiverType, expr2.field);
525
+ return fieldDef.type;
526
+ } else {
527
+ let currType = receiverType;
528
+ for (const member of expr2.members) {
529
+ const fieldDef = import_orm3.QueryUtils.requireField(this.schema, currType, member);
530
+ currType = fieldDef.type;
531
+ }
532
+ return currType;
533
+ }
534
+ }
461
535
  transformAuthBinary(expr2, context) {
462
536
  if (expr2.op !== "==" && expr2.op !== "!=") {
463
- throw createUnsupportedError(`Unsupported operator for \`auth()\` in policy of model "${context.model}": ${expr2.op}`);
537
+ throw createUnsupportedError(`Unsupported operator for \`auth()\` in policy of model "${context.modelOrType}": ${expr2.op}`);
464
538
  }
465
539
  let authExpr;
466
540
  let other;
@@ -476,7 +550,7 @@ var ExpressionTransformer = class {
476
550
  } else {
477
551
  const authModel = import_orm3.QueryUtils.getModel(this.schema, this.authType);
478
552
  if (!authModel) {
479
- throw createUnsupportedError(`Unsupported use of \`auth()\` in policy of model "${context.model}", comparing with \`auth()\` is only possible when auth type is a model`);
553
+ throw createUnsupportedError(`Unsupported use of \`auth()\` in policy of model "${context.modelOrType}", comparing with \`auth()\` is only possible when auth type is a model`);
480
554
  }
481
555
  const idFields = Object.values(authModel.fields).filter((f) => f.id).map((f) => f.name);
482
556
  (0, import_common_helpers2.invariant)(idFields.length > 0, "auth type model must have at least one id field");
@@ -508,7 +582,12 @@ var ExpressionTransformer = class {
508
582
  } else if (value === false) {
509
583
  return falseNode(this.dialect);
510
584
  } else {
511
- return import_kysely2.ValueNode.create(this.dialect.transformPrimitive(value, type, false) ?? null);
585
+ const transformed = this.dialect.transformPrimitive(value, type, false) ?? null;
586
+ if (!Array.isArray(transformed)) {
587
+ return import_kysely2.ValueNode.createImmediate(transformed);
588
+ } else {
589
+ return import_kysely2.ValueNode.create(transformed);
590
+ }
512
591
  }
513
592
  }
514
593
  _unary(expr2, context) {
@@ -532,8 +611,8 @@ var ExpressionTransformer = class {
532
611
  return func(eb, (expr2.args ?? []).map((arg) => this.transformCallArg(eb, arg, context)), {
533
612
  client: this.client,
534
613
  dialect: this.dialect,
535
- model: context.model,
536
- modelAlias: context.alias ?? context.model,
614
+ model: context.modelOrType,
615
+ modelAlias: context.alias ?? context.modelOrType,
537
616
  operation: context.operation
538
617
  });
539
618
  }
@@ -580,9 +659,15 @@ var ExpressionTransformer = class {
580
659
  const { memberFilter, memberSelect, ...restContext } = context;
581
660
  if (import_schema3.ExpressionUtils.isThis(expr2.receiver)) {
582
661
  if (expr2.members.length === 1) {
583
- return this._field(import_schema3.ExpressionUtils.field(expr2.members[0]), context);
662
+ return this._field(import_schema3.ExpressionUtils.field(expr2.members[0]), {
663
+ ...context,
664
+ alias: context.thisAlias,
665
+ modelOrType: context.thisType,
666
+ thisType: context.thisType,
667
+ contextValue: void 0
668
+ });
584
669
  } else {
585
- const firstMemberFieldDef = import_orm3.QueryUtils.requireField(this.schema, context.model, expr2.members[0]);
670
+ const firstMemberFieldDef = import_orm3.QueryUtils.requireField(this.schema, context.thisType, expr2.members[0]);
586
671
  receiver = this.transformRelationAccess(expr2.members[0], firstMemberFieldDef.type, restContext);
587
672
  members = expr2.members.slice(1);
588
673
  }
@@ -592,10 +677,10 @@ var ExpressionTransformer = class {
592
677
  (0, import_common_helpers2.invariant)(import_kysely2.SelectQueryNode.is(receiver), "expected receiver to be select query");
593
678
  let startType;
594
679
  if (import_schema3.ExpressionUtils.isField(expr2.receiver)) {
595
- const receiverField = import_orm3.QueryUtils.requireField(this.schema, context.model, expr2.receiver.field);
680
+ const receiverField = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, expr2.receiver.field);
596
681
  startType = receiverField.type;
597
682
  } else {
598
- startType = context.model;
683
+ startType = context.thisType;
599
684
  }
600
685
  const memberFields = [];
601
686
  let currType = startType;
@@ -614,7 +699,7 @@ var ExpressionTransformer = class {
614
699
  if (fieldDef.relation) {
615
700
  const relation = this.transformRelationAccess(member, fieldDef.type, {
616
701
  ...restContext,
617
- model: fromModel,
702
+ modelOrType: fromModel,
618
703
  alias: void 0
619
704
  });
620
705
  if (currNode) {
@@ -650,20 +735,29 @@ var ExpressionTransformer = class {
650
735
  if (!receiver) {
651
736
  return import_kysely2.ValueNode.createImmediate(null);
652
737
  }
653
- if (expr2.members.length !== 1) {
654
- throw new Error(`Only single member access is supported`);
738
+ (0, import_common_helpers2.invariant)(expr2.members.length > 0, "member expression must have at least one member");
739
+ let curr = receiver;
740
+ let currType = receiverType;
741
+ for (let i = 0; i < expr2.members.length; i++) {
742
+ const field = expr2.members[i];
743
+ curr = curr?.[field];
744
+ if (curr === void 0) {
745
+ curr = import_kysely2.ValueNode.createImmediate(null);
746
+ break;
747
+ }
748
+ currType = import_orm3.QueryUtils.requireField(this.schema, currType, field).type;
749
+ if (i === expr2.members.length - 1) {
750
+ curr = this.transformValue(curr, currType);
751
+ }
655
752
  }
656
- const field = expr2.members[0];
657
- const fieldDef = import_orm3.QueryUtils.requireField(this.schema, receiverType, field);
658
- const fieldValue = receiver[field] ?? null;
659
- return this.transformValue(fieldValue, fieldDef.type);
753
+ return curr;
660
754
  }
661
755
  transformRelationAccess(field, relationModel, context) {
662
- const m2m = import_orm3.QueryUtils.getManyToManyRelation(this.schema, context.model, field);
756
+ const m2m = import_orm3.QueryUtils.getManyToManyRelation(this.schema, context.modelOrType, field);
663
757
  if (m2m) {
664
758
  return this.transformManyToManyRelationAccess(m2m, context);
665
759
  }
666
- const fromModel = context.model;
760
+ const fromModel = context.modelOrType;
667
761
  const relationFieldDef = import_orm3.QueryUtils.requireField(this.schema, fromModel, field);
668
762
  const { keyPairs, ownedByModel } = import_orm3.QueryUtils.getRelationForeignKeyFieldPairs(this.schema, fromModel, field);
669
763
  let condition;
@@ -688,19 +782,19 @@ var ExpressionTransformer = class {
688
782
  }
689
783
  transformManyToManyRelationAccess(m2m, context) {
690
784
  const eb = (0, import_kysely2.expressionBuilder)();
691
- const relationQuery = eb.selectFrom(m2m.otherModel).innerJoin(m2m.joinTable, (join) => join.onRef(`${m2m.otherModel}.${m2m.otherPKName}`, "=", `${m2m.joinTable}.${m2m.otherFkName}`).onRef(`${m2m.joinTable}.${m2m.parentFkName}`, "=", `${context.alias ?? context.model}.${m2m.parentPKName}`));
785
+ const relationQuery = eb.selectFrom(m2m.otherModel).innerJoin(m2m.joinTable, (join) => join.onRef(`${m2m.otherModel}.${m2m.otherPKName}`, "=", `${m2m.joinTable}.${m2m.otherFkName}`).onRef(`${m2m.joinTable}.${m2m.parentFkName}`, "=", `${context.alias ?? context.modelOrType}.${m2m.parentPKName}`));
692
786
  return relationQuery.toOperationNode();
693
787
  }
694
788
  createColumnRef(column, context) {
695
- const tableName = context.alias ?? context.model;
789
+ const tableName = context.alias ?? context.modelOrType;
696
790
  if (context.operation === "create") {
697
791
  return import_kysely2.ReferenceNode.create(import_kysely2.ColumnNode.create(column), import_kysely2.TableNode.create(tableName));
698
792
  }
699
- const fieldDef = import_orm3.QueryUtils.requireField(this.schema, context.model, column);
700
- if (!fieldDef.originModel || fieldDef.originModel === context.model) {
793
+ const fieldDef = import_orm3.QueryUtils.requireField(this.schema, context.modelOrType, column);
794
+ if (!fieldDef.originModel || fieldDef.originModel === context.modelOrType) {
701
795
  return import_kysely2.ReferenceNode.create(import_kysely2.ColumnNode.create(column), import_kysely2.TableNode.create(tableName));
702
796
  }
703
- return this.buildDelegateBaseFieldSelect(context.model, tableName, column, fieldDef.originModel);
797
+ return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel);
704
798
  }
705
799
  buildDelegateBaseFieldSelect(model, modelAlias, field, baseModel) {
706
800
  const idFields = import_orm3.QueryUtils.requireIdFields(this.client.$schema, model);
@@ -742,9 +836,9 @@ var ExpressionTransformer = class {
742
836
  }
743
837
  getFieldDefFromFieldRef(expr2, model) {
744
838
  if (import_schema3.ExpressionUtils.isField(expr2)) {
745
- return import_orm3.QueryUtils.requireField(this.schema, model, expr2.field);
839
+ return import_orm3.QueryUtils.getField(this.schema, model, expr2.field);
746
840
  } else if (import_schema3.ExpressionUtils.isMember(expr2) && expr2.members.length === 1 && import_schema3.ExpressionUtils.isThis(expr2.receiver)) {
747
- return import_orm3.QueryUtils.requireField(this.schema, model, expr2.members[0]);
841
+ return import_orm3.QueryUtils.getField(this.schema, model, expr2.members[0]);
748
842
  } else {
749
843
  return void 0;
750
844
  }
@@ -1410,7 +1504,9 @@ var PolicyHandler = class extends import_kysely3.OperationNodeTransformer {
1410
1504
  }
1411
1505
  compilePolicyCondition(model, alias, operation, policy) {
1412
1506
  return new ExpressionTransformer(this.client).transform(policy.condition, {
1413
- model,
1507
+ modelOrType: model,
1508
+ thisType: model,
1509
+ thisAlias: alias,
1414
1510
  alias,
1415
1511
  operation
1416
1512
  });