zenstack 1.0.16 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (154)
  1. package/README.md +93 -24
  2. package/bin/post-install.js +1 -1
  3. package/cli/actions/generate.d.ts +13 -0
  4. package/cli/actions/generate.js +71 -0
  5. package/cli/actions/generate.js.map +1 -0
  6. package/cli/actions/index.d.ts +3 -0
  7. package/cli/actions/index.js +20 -0
  8. package/cli/actions/index.js.map +1 -0
  9. package/cli/actions/info.d.ts +4 -0
  10. package/cli/actions/info.js +63 -0
  11. package/cli/actions/info.js.map +1 -0
  12. package/cli/actions/init.d.ts +12 -0
  13. package/cli/actions/init.js +83 -0
  14. package/cli/actions/init.js.map +1 -0
  15. package/cli/cli-util.d.ts +14 -11
  16. package/cli/cli-util.js +150 -82
  17. package/cli/cli-util.js.map +1 -1
  18. package/cli/config.d.ts +10 -0
  19. package/cli/config.js +62 -0
  20. package/cli/config.js.map +1 -0
  21. package/cli/index.d.ts +4 -12
  22. package/cli/index.js +36 -29
  23. package/cli/index.js.map +1 -1
  24. package/cli/plugin-runner.d.ts +13 -3
  25. package/cli/plugin-runner.js +165 -59
  26. package/cli/plugin-runner.js.map +1 -1
  27. package/constants.js +1 -1
  28. package/language-server/constants.d.ts +7 -0
  29. package/language-server/constants.js +8 -1
  30. package/language-server/constants.js.map +1 -1
  31. package/language-server/utils.d.ts +1 -14
  32. package/language-server/utils.js +2 -38
  33. package/language-server/utils.js.map +1 -1
  34. package/language-server/validator/attribute-application-validator.d.ts +15 -0
  35. package/language-server/validator/attribute-application-validator.js +246 -0
  36. package/language-server/validator/attribute-application-validator.js.map +1 -0
  37. package/language-server/validator/attribute-validator.d.ts +1 -1
  38. package/language-server/validator/attribute-validator.js +4 -1
  39. package/language-server/validator/attribute-validator.js.map +1 -1
  40. package/language-server/validator/datamodel-validator.d.ts +7 -0
  41. package/language-server/validator/datamodel-validator.js +84 -33
  42. package/language-server/validator/datamodel-validator.js.map +1 -1
  43. package/language-server/validator/enum-validator.js +3 -6
  44. package/language-server/validator/enum-validator.js.map +1 -1
  45. package/language-server/validator/expression-validator.d.ts +1 -1
  46. package/language-server/validator/expression-validator.js +108 -3
  47. package/language-server/validator/expression-validator.js.map +1 -1
  48. package/language-server/validator/function-decl-validator.d.ts +9 -0
  49. package/language-server/validator/function-decl-validator.js +13 -0
  50. package/language-server/validator/function-decl-validator.js.map +1 -0
  51. package/language-server/validator/function-invocation-validator.d.ts +11 -0
  52. package/language-server/validator/function-invocation-validator.js +135 -0
  53. package/language-server/validator/function-invocation-validator.js.map +1 -0
  54. package/language-server/validator/schema-validator.d.ts +4 -1
  55. package/language-server/validator/schema-validator.js +28 -7
  56. package/language-server/validator/schema-validator.js.map +1 -1
  57. package/language-server/validator/utils.d.ts +3 -4
  58. package/language-server/validator/utils.js +20 -123
  59. package/language-server/validator/utils.js.map +1 -1
  60. package/language-server/validator/zmodel-validator.d.ts +5 -1
  61. package/language-server/validator/zmodel-validator.js +18 -4
  62. package/language-server/validator/zmodel-validator.js.map +1 -1
  63. package/language-server/zmodel-code-action.d.ts +2 -1
  64. package/language-server/zmodel-code-action.js +46 -21
  65. package/language-server/zmodel-code-action.js.map +1 -1
  66. package/language-server/zmodel-definition.d.ts +7 -0
  67. package/language-server/zmodel-definition.js +31 -0
  68. package/language-server/zmodel-definition.js.map +1 -0
  69. package/language-server/zmodel-formatter.js +2 -2
  70. package/language-server/zmodel-formatter.js.map +1 -1
  71. package/language-server/zmodel-linker.d.ts +3 -0
  72. package/language-server/zmodel-linker.js +122 -41
  73. package/language-server/zmodel-linker.js.map +1 -1
  74. package/language-server/zmodel-module.js +4 -1
  75. package/language-server/zmodel-module.js.map +1 -1
  76. package/language-server/zmodel-scope.d.ts +7 -1
  77. package/language-server/zmodel-scope.js +57 -1
  78. package/language-server/zmodel-scope.js.map +1 -1
  79. package/language-server/zmodel-workspace-manager.d.ts +5 -1
  80. package/language-server/zmodel-workspace-manager.js +101 -0
  81. package/language-server/zmodel-workspace-manager.js.map +1 -1
  82. package/package.json +27 -20
  83. package/plugins/access-policy/expression-writer.d.ts +7 -0
  84. package/plugins/access-policy/expression-writer.js +325 -106
  85. package/plugins/access-policy/expression-writer.js.map +1 -1
  86. package/plugins/access-policy/index.d.ts +3 -3
  87. package/plugins/access-policy/index.js +3 -5
  88. package/plugins/access-policy/index.js.map +1 -1
  89. package/plugins/access-policy/policy-guard-generator.d.ts +10 -3
  90. package/plugins/access-policy/policy-guard-generator.js +406 -121
  91. package/plugins/access-policy/policy-guard-generator.js.map +1 -1
  92. package/plugins/model-meta/index.d.ts +3 -3
  93. package/plugins/model-meta/index.js +110 -46
  94. package/plugins/model-meta/index.js.map +1 -1
  95. package/plugins/plugin-utils.d.ts +8 -7
  96. package/plugins/plugin-utils.js +55 -21
  97. package/plugins/plugin-utils.js.map +1 -1
  98. package/plugins/prisma/index.d.ts +3 -3
  99. package/plugins/prisma/index.js +3 -5
  100. package/plugins/prisma/index.js.map +1 -1
  101. package/plugins/prisma/prisma-builder.d.ts +7 -14
  102. package/plugins/prisma/prisma-builder.js +29 -34
  103. package/plugins/prisma/prisma-builder.js.map +1 -1
  104. package/plugins/prisma/schema-generator.d.ts +7 -3
  105. package/plugins/prisma/schema-generator.js +146 -102
  106. package/plugins/prisma/schema-generator.js.map +1 -1
  107. package/plugins/prisma/zmodel-code-generator.d.ts +3 -1
  108. package/plugins/prisma/zmodel-code-generator.js +12 -2
  109. package/plugins/prisma/zmodel-code-generator.js.map +1 -1
  110. package/plugins/zod/generator.d.ts +4 -0
  111. package/plugins/zod/generator.js +298 -0
  112. package/plugins/zod/generator.js.map +1 -0
  113. package/plugins/zod/index.d.ts +4 -0
  114. package/plugins/zod/index.js +24 -0
  115. package/plugins/zod/index.js.map +1 -0
  116. package/plugins/zod/transformer.d.ts +68 -0
  117. package/plugins/zod/transformer.js +554 -0
  118. package/plugins/zod/transformer.js.map +1 -0
  119. package/plugins/zod/types.d.ts +25 -0
  120. package/plugins/zod/types.js.map +1 -0
  121. package/plugins/zod/utils/removeDir.d.ts +1 -0
  122. package/plugins/zod/utils/removeDir.js +30 -0
  123. package/plugins/zod/utils/removeDir.js.map +1 -0
  124. package/plugins/zod/utils/schema-gen.d.ts +3 -0
  125. package/plugins/zod/utils/schema-gen.js +188 -0
  126. package/plugins/zod/utils/schema-gen.js.map +1 -0
  127. package/res/starter.zmodel +6 -8
  128. package/res/stdlib.zmodel +238 -74
  129. package/telemetry.d.ts +2 -1
  130. package/telemetry.js +21 -11
  131. package/telemetry.js.map +1 -1
  132. package/utils/ast-utils.d.ts +12 -15
  133. package/utils/ast-utils.js +117 -66
  134. package/utils/ast-utils.js.map +1 -1
  135. package/utils/pkg-utils.d.ts +2 -2
  136. package/utils/pkg-utils.js +34 -16
  137. package/utils/pkg-utils.js.map +1 -1
  138. package/utils/typescript-expression-transformer.d.ts +54 -0
  139. package/utils/typescript-expression-transformer.js +326 -0
  140. package/utils/typescript-expression-transformer.js.map +1 -0
  141. package/utils/version-utils.js +7 -1
  142. package/utils/version-utils.js.map +1 -1
  143. package/plugins/access-policy/typescript-expression-transformer.d.ts +0 -26
  144. package/plugins/access-policy/typescript-expression-transformer.js +0 -111
  145. package/plugins/access-policy/typescript-expression-transformer.js.map +0 -1
  146. package/plugins/access-policy/utils.d.ts +0 -5
  147. package/plugins/access-policy/utils.js +0 -14
  148. package/plugins/access-policy/utils.js.map +0 -1
  149. package/plugins/access-policy/zod-schema-generator.d.ts +0 -12
  150. package/plugins/access-policy/zod-schema-generator.js +0 -158
  151. package/plugins/access-policy/zod-schema-generator.js.map +0 -1
  152. package/types.d.ts +0 -12
  153. package/types.js.map +0 -1
  154. /package/{types.js → plugins/zod/types.js} +0 -0
@@ -13,64 +13,66 @@ var __importDefault = (this && this.__importDefault) || function (mod) {
13
13
  };
14
14
  Object.defineProperty(exports, "__esModule", { value: true });
15
15
  const ast_1 = require("@zenstackhq/language/ast");
16
+ const runtime_1 = require("@zenstackhq/runtime");
16
17
  const sdk_1 = require("@zenstackhq/sdk");
17
- const change_case_1 = require("change-case");
18
18
  const langium_1 = require("langium");
19
+ const lower_case_first_1 = require("lower-case-first");
19
20
  const path_1 = __importDefault(require("path"));
20
21
  const ts_morph_1 = require("ts-morph");
21
22
  const _1 = require(".");
22
- const utils_1 = require("../../language-server/utils");
23
23
  const ast_utils_1 = require("../../utils/ast-utils");
24
+ const typescript_expression_transformer_1 = require("../../utils/typescript-expression-transformer");
24
25
  const plugin_utils_1 = require("../plugin-utils");
25
26
  const expression_writer_1 = require("./expression-writer");
26
- const utils_2 = require("./utils");
27
- const zod_schema_generator_1 = require("./zod-schema-generator");
28
27
  /**
29
28
  * Generates source file that contains Prisma query guard objects used for injecting database queries
30
29
  */
31
30
  class PolicyGenerator {
32
- generate(model, options) {
31
+ generate(model, options, globalOptions) {
33
32
  return __awaiter(this, void 0, void 0, function* () {
34
- const output = options.output ? options.output : (0, plugin_utils_1.getDefaultOutputFolder)();
33
+ let output = options.output ? options.output : (0, plugin_utils_1.getDefaultOutputFolder)(globalOptions);
35
34
  if (!output) {
36
- console.error(`Unable to determine output path, not running plugin ${_1.name}`);
37
- return;
35
+ throw new sdk_1.PluginError(options.name, `Unable to determine output path, not running plugin`);
38
36
  }
39
- const project = new ts_morph_1.Project();
37
+ output = (0, sdk_1.resolvePath)(output, options);
38
+ const project = (0, sdk_1.createProject)();
40
39
  const sf = project.createSourceFile(path_1.default.join(output, 'policy.ts'), undefined, { overwrite: true });
40
+ sf.addStatements('/* eslint-disable */');
41
41
  sf.addImportDeclaration({
42
- namedImports: [{ name: 'QueryContext' }],
43
- moduleSpecifier: `${plugin_utils_1.RUNTIME_PACKAGE}`,
44
- isTypeOnly: true,
45
- });
46
- sf.addImportDeclaration({
47
- namedImports: [{ name: 'z' }],
48
- moduleSpecifier: 'zod',
42
+ namedImports: [
43
+ { name: 'type QueryContext' },
44
+ { name: 'type DbOperations' },
45
+ { name: 'hasAllFields' },
46
+ { name: 'allFieldsEqual' },
47
+ { name: 'type PolicyDef' },
48
+ ],
49
+ moduleSpecifier: `${sdk_1.RUNTIME_PACKAGE}`,
49
50
  });
50
51
  // import enums
51
- for (const e of model.declarations.filter((d) => (0, ast_1.isEnum)(d))) {
52
+ const prismaImport = (0, sdk_1.getPrismaClientImportSpec)(model, output);
53
+ for (const e of model.declarations.filter((d) => (0, ast_1.isEnum)(d) && this.isEnumReferenced(model, d))) {
52
54
  sf.addImportDeclaration({
53
55
  namedImports: [{ name: e.name }],
54
- moduleSpecifier: '@prisma/client',
56
+ moduleSpecifier: prismaImport,
55
57
  });
56
58
  }
57
- const models = model.declarations.filter((d) => (0, ast_1.isDataModel)(d));
59
+ const models = (0, sdk_1.getDataModels)(model);
58
60
  const policyMap = {};
59
61
  for (const model of models) {
60
62
  policyMap[model.name] = yield this.generateQueryGuardForModel(model, sf);
61
63
  }
62
- const zodGenerator = new zod_schema_generator_1.ZodSchemaGenerator();
63
64
  sf.addVariableStatement({
64
65
  declarationKind: ts_morph_1.VariableDeclarationKind.Const,
65
66
  declarations: [
66
67
  {
67
68
  name: 'policy',
69
+ type: 'PolicyDef',
68
70
  initializer: (writer) => {
69
71
  writer.block(() => {
70
72
  writer.write('guard:');
71
73
  writer.inlineBlock(() => {
72
74
  for (const [model, map] of Object.entries(policyMap)) {
73
- writer.write(`${(0, change_case_1.camelCase)(model)}:`);
75
+ writer.write(`${(0, lower_case_first_1.lowerCaseFirst)(model)}:`);
74
76
  writer.inlineBlock(() => {
75
77
  for (const [op, func] of Object.entries(map)) {
76
78
  if (typeof func === 'object') {
@@ -85,21 +87,57 @@ class PolicyGenerator {
85
87
  }
86
88
  });
87
89
  writer.writeLine(',');
88
- writer.write('schema:');
89
- zodGenerator.generate(writer, models);
90
+ writer.write('validation:');
91
+ writer.inlineBlock(() => {
92
+ for (const model of models) {
93
+ writer.write(`${(0, lower_case_first_1.lowerCaseFirst)(model.name)}:`);
94
+ writer.inlineBlock(() => {
95
+ writer.write(`hasValidation: ${(0, sdk_1.hasValidationAttributes)(model)}`);
96
+ });
97
+ writer.writeLine(',');
98
+ }
99
+ });
90
100
  });
91
101
  },
92
102
  },
93
103
  ],
94
104
  });
95
105
  sf.addStatements('export default policy');
96
- sf.formatText();
97
- yield project.save();
98
- yield project.emit();
106
+ let shouldCompile = true;
107
+ if (typeof options.compile === 'boolean') {
108
+ // explicit override
109
+ shouldCompile = options.compile;
110
+ }
111
+ else if (globalOptions) {
112
+ shouldCompile = globalOptions.compile;
113
+ }
114
+ if (!shouldCompile || options.preserveTsFiles === true) {
115
+ // save ts files
116
+ yield (0, sdk_1.saveProject)(project);
117
+ }
118
+ if (shouldCompile) {
119
+ yield (0, sdk_1.emitProject)(project);
120
+ }
121
+ });
122
+ }
123
+ isEnumReferenced(model, decl) {
124
+ return (0, langium_1.streamAllContents)(model).some((node) => {
125
+ var _a, _b;
126
+ if ((0, ast_1.isDataModelField)(node) && ((_a = node.type.reference) === null || _a === void 0 ? void 0 : _a.ref) === decl) {
127
+ // referenced as field type
128
+ return true;
129
+ }
130
+ if ((0, sdk_1.isEnumFieldReference)(node) && ((_b = node.target.ref) === null || _b === void 0 ? void 0 : _b.$container) === decl) {
131
+ // enum field is referenced
132
+ return true;
133
+ }
134
+ return false;
99
135
  });
100
136
  }
101
- getPolicyExpressions(model, kind, operation) {
102
- const attrs = model.attributes.filter((attr) => { var _a; return ((_a = attr.decl.ref) === null || _a === void 0 ? void 0 : _a.name) === `@@${kind}`; });
137
+ getPolicyExpressions(target, kind, operation) {
138
+ const attributes = target.attributes;
139
+ const attrName = (0, ast_1.isDataModel)(target) ? `@@${kind}` : `@${kind}`;
140
+ const attrs = attributes.filter((attr) => { var _a; return ((_a = attr.decl.ref) === null || _a === void 0 ? void 0 : _a.name) === attrName; });
103
141
  const checkOperation = operation === 'postUpdate' ? 'update' : operation;
104
142
  let result = attrs
105
143
  .filter((attr) => {
@@ -150,8 +188,8 @@ class PolicyGenerator {
150
188
  }
151
189
  hasFutureReference(expr) {
152
190
  var _a;
153
- for (const node of (0, langium_1.streamAllContents)(expr)) {
154
- if ((0, ast_1.isInvocationExpr)(node) && ((_a = node.function.ref) === null || _a === void 0 ? void 0 : _a.name) === 'future' && (0, utils_1.isFromStdlib)(node.function.ref)) {
191
+ for (const node of (0, langium_1.streamAst)(expr)) {
192
+ if ((0, ast_1.isInvocationExpr)(node) && ((_a = node.function.ref) === null || _a === void 0 ? void 0 : _a.name) === 'future' && (0, sdk_1.isFromStdlib)(node.function.ref)) {
155
193
  return true;
156
194
  }
157
195
  }
@@ -161,10 +199,13 @@ class PolicyGenerator {
161
199
  return __awaiter(this, void 0, void 0, function* () {
162
200
  const result = {};
163
201
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
164
- const policies = (0, ast_utils_1.analyzePolicies)(model);
202
+ const policies = (0, sdk_1.analyzePolicies)(model);
165
203
  for (const kind of plugin_utils_1.ALL_OPERATION_KINDS) {
166
204
  if (policies[kind] === true || policies[kind] === false) {
167
205
  result[kind] = policies[kind];
206
+ if (kind === 'create') {
207
+ result[kind + '_input'] = policies[kind];
208
+ }
168
209
  continue;
169
210
  }
170
211
  const denies = this.getPolicyExpressions(model, 'deny', kind);
@@ -186,23 +227,156 @@ class PolicyGenerator {
186
227
  result[kind] = true;
187
228
  continue;
188
229
  }
189
- const func = this.generateQueryGuardFunction(sourceFile, model, kind, allows, denies);
230
+ const guardFunc = this.generateQueryGuardFunction(sourceFile, model, kind, allows, denies);
190
231
  // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
191
- result[kind] = func.getName();
232
+ result[kind] = guardFunc.getName();
192
233
  if (kind === 'postUpdate') {
193
- const preValueSelect = this.generatePreValueSelect(model, allows, denies);
234
+ const preValueSelect = this.generateSelectForRules(allows, denies);
194
235
  if (preValueSelect) {
195
- result['preValueSelect'] = preValueSelect;
236
+ result[runtime_1.PRE_UPDATE_VALUE_SELECTOR] = preValueSelect;
196
237
  }
197
238
  }
239
+ if (kind === 'create' && this.canCheckCreateBasedOnInput(model, allows, denies)) {
240
+ const inputCheckFunc = this.generateInputCheckFunction(sourceFile, model, kind, allows, denies);
241
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
242
+ result[kind + '_input'] = inputCheckFunc.getName();
243
+ }
198
244
  }
245
+ // generate field read checkers
246
+ this.generateReadFieldsGuards(model, sourceFile, result);
247
+ // generate field update guards
248
+ this.generateUpdateFieldsGuards(model, sourceFile, result);
199
249
  return result;
200
250
  });
201
251
  }
202
- // generates an object that can be used as the 'select' argument when fetching pre-update
203
- // entity value
204
- generatePreValueSelect(model, allows, denies) {
205
- var _a;
252
+ generateReadFieldsGuards(model, sourceFile, result) {
253
+ const allFieldsAllows = [];
254
+ const allFieldsDenies = [];
255
+ for (const field of model.fields) {
256
+ const allows = this.getPolicyExpressions(field, 'allow', 'read');
257
+ const denies = this.getPolicyExpressions(field, 'deny', 'read');
258
+ if (denies.length === 0 && allows.length === 0) {
259
+ continue;
260
+ }
261
+ allFieldsAllows.push(...allows);
262
+ allFieldsDenies.push(...denies);
263
+ const guardFunc = this.generateReadFieldGuardFunction(sourceFile, field, allows, denies);
264
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
265
+ result[`${runtime_1.FIELD_LEVEL_READ_CHECKER_PREFIX}${field.name}`] = guardFunc.getName();
266
+ }
267
+ if (allFieldsAllows.length > 0 || allFieldsDenies.length > 0) {
268
+ result[runtime_1.HAS_FIELD_LEVEL_POLICY_FLAG] = true;
269
+ const readFieldCheckSelect = this.generateSelectForRules(allFieldsAllows, allFieldsDenies);
270
+ if (readFieldCheckSelect) {
271
+ result[runtime_1.FIELD_LEVEL_READ_CHECKER_SELECTOR] = readFieldCheckSelect;
272
+ }
273
+ }
274
+ }
275
+ generateReadFieldGuardFunction(sourceFile, field, allows, denies) {
276
+ const statements = [];
277
+ this.generateNormalizedAuthRef(field.$container, allows, denies, statements);
278
+ // compile rules down to typescript expressions
279
+ statements.push((writer) => {
280
+ const transformer = new typescript_expression_transformer_1.TypeScriptExpressionTransformer({
281
+ context: sdk_1.ExpressionContext.AccessPolicy,
282
+ fieldReferenceContext: 'input',
283
+ });
284
+ const denyStmt = denies.length > 0
285
+ ? '!(' +
286
+ denies
287
+ .map((deny) => {
288
+ return transformer.transform(deny);
289
+ })
290
+ .join(' || ') +
291
+ ')'
292
+ : undefined;
293
+ const allowStmt = allows.length > 0
294
+ ? '(' +
295
+ allows
296
+ .map((allow) => {
297
+ return transformer.transform(allow);
298
+ })
299
+ .join(' || ') +
300
+ ')'
301
+ : undefined;
302
+ let expr;
303
+ if (denyStmt && allowStmt) {
304
+ expr = `${denyStmt} && ${allowStmt}`;
305
+ }
306
+ else if (denyStmt) {
307
+ expr = denyStmt;
308
+ }
309
+ else if (allowStmt) {
310
+ expr = allowStmt;
311
+ }
312
+ else {
313
+ throw new Error('should not happen');
314
+ }
315
+ writer.write('return ' + expr);
316
+ });
317
+ const func = sourceFile.addFunction({
318
+ name: `${field.$container.name}$${field.name}_read`,
319
+ returnType: 'boolean',
320
+ parameters: [
321
+ {
322
+ name: 'input',
323
+ type: 'any',
324
+ },
325
+ {
326
+ name: 'context',
327
+ type: 'QueryContext',
328
+ },
329
+ ],
330
+ statements,
331
+ });
332
+ return func;
333
+ }
334
+ generateUpdateFieldsGuards(model, sourceFile, result) {
335
+ for (const field of model.fields) {
336
+ const allows = this.getPolicyExpressions(field, 'allow', 'update');
337
+ const denies = this.getPolicyExpressions(field, 'deny', 'update');
338
+ if (denies.length === 0 && allows.length === 0) {
339
+ continue;
340
+ }
341
+ const guardFunc = this.generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field);
342
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
343
+ result[`${runtime_1.FIELD_LEVEL_UPDATE_GUARD_PREFIX}${field.name}`] = guardFunc.getName();
344
+ }
345
+ }
346
+ canCheckCreateBasedOnInput(model, allows, denies) {
347
+ return [...allows, ...denies].every((rule) => {
348
+ return (0, langium_1.streamAst)(rule).every((expr) => {
349
+ var _a;
350
+ if ((0, ast_1.isThisExpr)(expr)) {
351
+ return false;
352
+ }
353
+ if ((0, ast_1.isReferenceExpr)(expr)) {
354
+ if ((0, ast_1.isDataModel)((_a = expr.$resolvedType) === null || _a === void 0 ? void 0 : _a.decl)) {
355
+ // if policy rules uses relation fields,
356
+ // we can't check based on create input
357
+ return false;
358
+ }
359
+ if ((0, ast_1.isDataModelField)(expr.target.ref) &&
360
+ expr.target.ref.$container === model &&
361
+ (0, sdk_1.hasAttribute)(expr.target.ref, '@default')) {
362
+ // reference to field of current model
363
+ // if it has default value, we can't check
364
+ // based on create input
365
+ return false;
366
+ }
367
+ if ((0, ast_1.isDataModelField)(expr.target.ref) && (0, sdk_1.isForeignKeyField)(expr.target.ref)) {
368
+ // reference to foreign key field
369
+ // we can't check based on create input
370
+ return false;
371
+ }
372
+ }
373
+ return true;
374
+ });
375
+ });
376
+ }
377
+ // generates a "select" object that contains (recursively) fields referenced by the
378
+ // given policy rules
379
+ generateSelectForRules(allows, denies) {
206
380
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
207
381
  const result = {};
208
382
  const addPath = (path) => {
@@ -219,6 +393,8 @@ class PolicyGenerator {
219
393
  }
220
394
  });
221
395
  };
396
+ // visit a reference or member access expression to build a
397
+ // selection path
222
398
  const visit = (node) => {
223
399
  if ((0, ast_1.isReferenceExpr)(node)) {
224
400
  const target = (0, sdk_1.resolved)(node.target);
@@ -228,7 +404,7 @@ class PolicyGenerator {
228
404
  }
229
405
  }
230
406
  else if ((0, ast_1.isMemberAccessExpr)(node)) {
231
- if ((0, utils_2.isFutureExpr)(node.operand)) {
407
+ if ((0, sdk_1.isFutureExpr)(node.operand)) {
232
408
  // future().field is not subject to pre-update select
233
409
  return undefined;
234
410
  }
@@ -240,109 +416,218 @@ class PolicyGenerator {
240
416
  }
241
417
  return undefined;
242
418
  };
243
- for (const rule of [...allows, ...denies]) {
244
- for (const expr of (0, langium_1.streamAllContents)(rule).filter((node) => (0, ast_1.isExpression)(node))) {
245
- // only care about member access and reference expressions
246
- if (!(0, ast_1.isMemberAccessExpr)(expr) && !(0, ast_1.isReferenceExpr)(expr)) {
247
- continue;
248
- }
249
- if (expr.$container.$type === ast_1.MemberAccessExpr) {
250
- // only visit top-level member access
251
- continue;
252
- }
419
+ // collect selection paths from the given expression
420
+ const collectReferencePaths = (expr) => {
421
+ var _a, _b, _c;
422
+ if ((0, ast_1.isThisExpr)(expr) && !(0, ast_1.isMemberAccessExpr)(expr.$container)) {
423
+ // a standalone `this` expression, include all id fields
424
+ const model = (_a = expr.$resolvedType) === null || _a === void 0 ? void 0 : _a.decl;
425
+ const idFields = (0, ast_utils_1.getIdFields)(model);
426
+ return idFields.map((field) => [field.name]);
427
+ }
428
+ if ((0, ast_1.isMemberAccessExpr)(expr) || (0, ast_1.isReferenceExpr)(expr)) {
253
429
  const path = visit(expr);
254
430
  if (path) {
255
- if ((0, ast_1.isDataModel)((_a = expr.$resolvedType) === null || _a === void 0 ? void 0 : _a.decl)) {
256
- // member selection ended at a data model field, include its 'id'
257
- path.push('id');
431
+ if ((0, ast_1.isDataModel)((_b = expr.$resolvedType) === null || _b === void 0 ? void 0 : _b.decl)) {
432
+ // member selection ended at a data model field, include its id fields
433
+ const idFields = (0, ast_utils_1.getIdFields)((_c = expr.$resolvedType) === null || _c === void 0 ? void 0 : _c.decl);
434
+ return idFields.map((field) => [...path, field.name]);
435
+ }
436
+ else {
437
+ return [path];
258
438
  }
259
- addPath(path);
260
439
  }
440
+ else {
441
+ return [];
442
+ }
443
+ }
444
+ else if ((0, ast_utils_1.isCollectionPredicate)(expr)) {
445
+ const path = visit(expr.left);
446
+ if (path) {
447
+ // recurse into RHS
448
+ const rhs = collectReferencePaths(expr.right);
449
+ // combine path of LHS and RHS
450
+ return rhs.map((r) => [...path, ...r]);
451
+ }
452
+ else {
453
+ return [];
454
+ }
455
+ }
456
+ else {
457
+ // recurse
458
+ const children = (0, langium_1.streamContents)(expr)
459
+ .filter((child) => (0, ast_1.isExpression)(child))
460
+ .toArray();
461
+ return children.flatMap((child) => collectReferencePaths(child));
261
462
  }
463
+ };
464
+ for (const rule of [...allows, ...denies]) {
465
+ const paths = collectReferencePaths(rule);
466
+ paths.forEach((p) => addPath(p));
262
467
  }
263
- return Object.keys(result).length === 0 ? null : result;
468
+ return Object.keys(result).length === 0 ? undefined : result;
264
469
  }
265
- generateQueryGuardFunction(sourceFile, model, kind, allows, denies) {
266
- const func = sourceFile
267
- .addFunction({
268
- name: model.name + '_' + kind,
470
+ generateQueryGuardFunction(sourceFile, model, kind, allows, denies, forField) {
471
+ const statements = [];
472
+ this.generateNormalizedAuthRef(model, allows, denies, statements);
473
+ const hasFieldAccess = [...denies, ...allows].some((rule) => (0, langium_1.streamAst)(rule).some((child) =>
474
+ // this.???
475
+ (0, ast_1.isThisExpr)(child) ||
476
+ // future().???
477
+ (0, sdk_1.isFutureExpr)(child) ||
478
+ // field reference
479
+ ((0, ast_1.isReferenceExpr)(child) && (0, ast_1.isDataModelField)(child.target.ref))));
480
+ if (!hasFieldAccess) {
481
+ // none of the rules reference model fields, we can compile down to a plain boolean
482
+ // function in this case (so we can skip doing SQL queries when validating)
483
+ statements.push((writer) => {
484
+ const transformer = new typescript_expression_transformer_1.TypeScriptExpressionTransformer({
485
+ context: sdk_1.ExpressionContext.AccessPolicy,
486
+ isPostGuard: kind === 'postUpdate',
487
+ });
488
+ try {
489
+ denies.forEach((rule) => {
490
+ writer.write(`if (${transformer.transform(rule, false)}) { return ${expression_writer_1.FALSE}; }`);
491
+ });
492
+ allows.forEach((rule) => {
493
+ writer.write(`if (${transformer.transform(rule, false)}) { return ${expression_writer_1.TRUE}; }`);
494
+ });
495
+ }
496
+ catch (err) {
497
+ if (err instanceof typescript_expression_transformer_1.TypeScriptExpressionTransformerError) {
498
+ throw new sdk_1.PluginError(_1.name, err.message);
499
+ }
500
+ else {
501
+ throw err;
502
+ }
503
+ }
504
+ writer.write(`return ${expression_writer_1.FALSE};`);
505
+ });
506
+ }
507
+ else {
508
+ statements.push((writer) => {
509
+ writer.write('return ');
510
+ const exprWriter = new expression_writer_1.ExpressionWriter(writer, kind === 'postUpdate');
511
+ const writeDenies = () => {
512
+ writer.conditionalWrite(denies.length > 1, '{ AND: [');
513
+ denies.forEach((expr, i) => {
514
+ writer.inlineBlock(() => {
515
+ writer.write('NOT: ');
516
+ exprWriter.write(expr);
517
+ });
518
+ writer.conditionalWrite(i !== denies.length - 1, ',');
519
+ });
520
+ writer.conditionalWrite(denies.length > 1, ']}');
521
+ };
522
+ const writeAllows = () => {
523
+ writer.conditionalWrite(allows.length > 1, '{ OR: [');
524
+ allows.forEach((expr, i) => {
525
+ exprWriter.write(expr);
526
+ writer.conditionalWrite(i !== allows.length - 1, ',');
527
+ });
528
+ writer.conditionalWrite(allows.length > 1, ']}');
529
+ };
530
+ if (allows.length > 0 && denies.length > 0) {
531
+ writer.write('{ AND: [');
532
+ writeDenies();
533
+ writer.write(',');
534
+ writeAllows();
535
+ writer.write(']}');
536
+ }
537
+ else if (denies.length > 0) {
538
+ writeDenies();
539
+ }
540
+ else if (allows.length > 0) {
541
+ writeAllows();
542
+ }
543
+ else {
544
+ // disallow any operation
545
+ writer.write(`{ OR: [] }`);
546
+ }
547
+ writer.write(';');
548
+ });
549
+ }
550
+ const func = sourceFile.addFunction({
551
+ name: `${model.name}${forField ? '$' + forField.name : ''}_${kind}`,
269
552
  returnType: 'any',
270
553
  parameters: [
271
554
  {
272
555
  name: 'context',
273
556
  type: 'QueryContext',
274
557
  },
558
+ {
559
+ // for generating field references used by field comparison in the same model
560
+ name: 'db',
561
+ type: 'Record<string, DbOperations>',
562
+ },
275
563
  ],
276
- })
277
- .addBody();
278
- // check if any allow or deny rule contains 'auth()' invocation
279
- let hasAuthRef = false;
280
- for (const node of [...denies, ...allows]) {
281
- for (const child of (0, langium_1.streamAllContents)(node)) {
282
- if ((0, ast_1.isInvocationExpr)(child) && (0, sdk_1.resolved)(child.function).name === 'auth') {
283
- hasAuthRef = true;
284
- break;
285
- }
286
- }
287
- if (hasAuthRef) {
288
- break;
564
+ statements,
565
+ });
566
+ return func;
567
+ }
568
+ generateInputCheckFunction(sourceFile, model, kind, allows, denies) {
569
+ const statements = [];
570
+ this.generateNormalizedAuthRef(model, allows, denies, statements);
571
+ statements.push((writer) => {
572
+ if (allows.length === 0) {
573
+ writer.write('return false;');
574
+ return;
289
575
  }
290
- }
576
+ const transformer = new typescript_expression_transformer_1.TypeScriptExpressionTransformer({
577
+ context: sdk_1.ExpressionContext.AccessPolicy,
578
+ fieldReferenceContext: 'input',
579
+ });
580
+ let expr = denies.length > 0
581
+ ? '!(' +
582
+ denies
583
+ .map((deny) => {
584
+ return transformer.transform(deny);
585
+ })
586
+ .join(' || ') +
587
+ ')'
588
+ : undefined;
589
+ const allowStmt = allows
590
+ .map((allow) => {
591
+ return transformer.transform(allow);
592
+ })
593
+ .join(' || ');
594
+ expr = expr ? `${expr} && (${allowStmt})` : allowStmt;
595
+ writer.write('return ' + expr);
596
+ });
597
+ const func = sourceFile.addFunction({
598
+ name: model.name + '_' + kind + '_input',
599
+ returnType: 'boolean',
600
+ parameters: [
601
+ {
602
+ name: 'input',
603
+ type: 'any',
604
+ },
605
+ {
606
+ name: 'context',
607
+ type: 'QueryContext',
608
+ },
609
+ ],
610
+ statements,
611
+ });
612
+ return func;
613
+ }
614
+ generateNormalizedAuthRef(model, allows, denies, statements) {
615
+ // check if any allow or deny rule contains 'auth()' invocation
616
+ const hasAuthRef = [...allows, ...denies].some((rule) => (0, langium_1.streamAst)(rule).some((child) => (0, ast_utils_1.isAuthInvocation)(child)));
291
617
  if (hasAuthRef) {
292
618
  const userModel = model.$container.declarations.find((decl) => (0, ast_1.isDataModel)(decl) && decl.name === 'User');
293
619
  if (!userModel) {
294
- throw new sdk_1.PluginError('User model not found');
620
+ throw new sdk_1.PluginError(_1.name, 'User model not found');
295
621
  }
296
- const userIdField = (0, ast_utils_1.getIdField)(userModel);
297
- if (!userIdField) {
298
- throw new sdk_1.PluginError('User model does not have an id field');
622
+ const userIdFields = (0, ast_utils_1.getIdFields)(userModel);
623
+ if (!userIdFields || userIdFields.length === 0) {
624
+ throw new sdk_1.PluginError(_1.name, 'User model does not have an id field');
299
625
  }
300
626
  // normalize user to null to avoid accidentally use undefined in filter
301
- func.addStatements(`const user = context.user ?? null;`);
627
+ statements.push(`const user = hasAllFields(context.user, [${userIdFields
628
+ .map((f) => "'" + f.name + "'")
629
+ .join(', ')}]) ? context.user as any : null;`);
302
630
  }
303
- // r = <guard object>;
304
- func.addStatements((writer) => {
305
- writer.write('return ');
306
- const exprWriter = new expression_writer_1.ExpressionWriter(writer, kind === 'postUpdate');
307
- const writeDenies = () => {
308
- writer.conditionalWrite(denies.length > 1, '{ AND: [');
309
- denies.forEach((expr, i) => {
310
- writer.inlineBlock(() => {
311
- writer.write('NOT: ');
312
- exprWriter.write(expr);
313
- });
314
- writer.conditionalWrite(i !== denies.length - 1, ',');
315
- });
316
- writer.conditionalWrite(denies.length > 1, ']}');
317
- };
318
- const writeAllows = () => {
319
- writer.conditionalWrite(allows.length > 1, '{ OR: [');
320
- allows.forEach((expr, i) => {
321
- exprWriter.write(expr);
322
- writer.conditionalWrite(i !== allows.length - 1, ',');
323
- });
324
- writer.conditionalWrite(allows.length > 1, ']}');
325
- };
326
- if (allows.length > 0 && denies.length > 0) {
327
- writer.write('{ AND: [');
328
- writeDenies();
329
- writer.write(',');
330
- writeAllows();
331
- writer.write(']}');
332
- }
333
- else if (denies.length > 0) {
334
- writeDenies();
335
- }
336
- else if (allows.length > 0) {
337
- writeAllows();
338
- }
339
- else {
340
- // disallow any operation
341
- writer.write(`{ ${sdk_1.GUARD_FIELD_NAME}: false }`);
342
- }
343
- writer.write(';');
344
- });
345
- return func;
346
631
  }
347
632
  }
348
633
  exports.default = PolicyGenerator;