@openrewrite/rewrite 8.66.1 → 8.66.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. package/dist/java/tree.d.ts +10 -1
  2. package/dist/java/tree.d.ts.map +1 -1
  3. package/dist/java/tree.js +21 -5
  4. package/dist/java/tree.js.map +1 -1
  5. package/dist/java/type-visitor.d.ts +1 -1
  6. package/dist/java/type-visitor.d.ts.map +1 -1
  7. package/dist/java/visitor.d.ts +2 -2
  8. package/dist/java/visitor.d.ts.map +1 -1
  9. package/dist/java/visitor.js +8 -2
  10. package/dist/java/visitor.js.map +1 -1
  11. package/dist/javascript/assertions.d.ts +6 -0
  12. package/dist/javascript/assertions.d.ts.map +1 -1
  13. package/dist/javascript/assertions.js +14 -6
  14. package/dist/javascript/assertions.js.map +1 -1
  15. package/dist/javascript/comparator.d.ts +154 -7
  16. package/dist/javascript/comparator.d.ts.map +1 -1
  17. package/dist/javascript/comparator.js +623 -180
  18. package/dist/javascript/comparator.js.map +1 -1
  19. package/dist/javascript/format.d.ts +5 -3
  20. package/dist/javascript/format.d.ts.map +1 -1
  21. package/dist/javascript/format.js +85 -43
  22. package/dist/javascript/format.js.map +1 -1
  23. package/dist/javascript/index.d.ts +1 -0
  24. package/dist/javascript/index.d.ts.map +1 -1
  25. package/dist/javascript/index.js +1 -0
  26. package/dist/javascript/index.js.map +1 -1
  27. package/dist/javascript/parser.d.ts +2 -1
  28. package/dist/javascript/parser.d.ts.map +1 -1
  29. package/dist/javascript/parser.js +39 -30
  30. package/dist/javascript/parser.js.map +1 -1
  31. package/dist/javascript/templating/capture.d.ts +81 -14
  32. package/dist/javascript/templating/capture.d.ts.map +1 -1
  33. package/dist/javascript/templating/capture.js +98 -8
  34. package/dist/javascript/templating/capture.js.map +1 -1
  35. package/dist/javascript/templating/comparator.d.ts +125 -15
  36. package/dist/javascript/templating/comparator.d.ts.map +1 -1
  37. package/dist/javascript/templating/comparator.js +946 -118
  38. package/dist/javascript/templating/comparator.js.map +1 -1
  39. package/dist/javascript/templating/engine.d.ts +58 -25
  40. package/dist/javascript/templating/engine.d.ts.map +1 -1
  41. package/dist/javascript/templating/engine.js +527 -94
  42. package/dist/javascript/templating/engine.js.map +1 -1
  43. package/dist/javascript/templating/index.d.ts +3 -3
  44. package/dist/javascript/templating/index.d.ts.map +1 -1
  45. package/dist/javascript/templating/index.js +3 -1
  46. package/dist/javascript/templating/index.js.map +1 -1
  47. package/dist/javascript/templating/pattern.d.ts +121 -16
  48. package/dist/javascript/templating/pattern.d.ts.map +1 -1
  49. package/dist/javascript/templating/pattern.js +528 -257
  50. package/dist/javascript/templating/pattern.js.map +1 -1
  51. package/dist/javascript/templating/placeholder-replacement.d.ts +30 -5
  52. package/dist/javascript/templating/placeholder-replacement.d.ts.map +1 -1
  53. package/dist/javascript/templating/placeholder-replacement.js +183 -81
  54. package/dist/javascript/templating/placeholder-replacement.js.map +1 -1
  55. package/dist/javascript/templating/rewrite.d.ts +56 -11
  56. package/dist/javascript/templating/rewrite.d.ts.map +1 -1
  57. package/dist/javascript/templating/rewrite.js +143 -16
  58. package/dist/javascript/templating/rewrite.js.map +1 -1
  59. package/dist/javascript/templating/template.d.ts +31 -5
  60. package/dist/javascript/templating/template.d.ts.map +1 -1
  61. package/dist/javascript/templating/template.js +89 -15
  62. package/dist/javascript/templating/template.js.map +1 -1
  63. package/dist/javascript/templating/types.d.ts +359 -12
  64. package/dist/javascript/templating/types.d.ts.map +1 -1
  65. package/dist/javascript/templating/utils.d.ts +52 -35
  66. package/dist/javascript/templating/utils.d.ts.map +1 -1
  67. package/dist/javascript/templating/utils.js +107 -109
  68. package/dist/javascript/templating/utils.js.map +1 -1
  69. package/dist/javascript/type-mapping.d.ts.map +1 -1
  70. package/dist/javascript/type-mapping.js +21 -11
  71. package/dist/javascript/type-mapping.js.map +1 -1
  72. package/dist/json/rpc.js +2 -2
  73. package/dist/json/rpc.js.map +1 -1
  74. package/dist/recipe/order-imports.js.map +1 -1
  75. package/dist/test/rewrite-test.d.ts.map +1 -1
  76. package/dist/test/rewrite-test.js +10 -6
  77. package/dist/test/rewrite-test.js.map +1 -1
  78. package/dist/version.txt +1 -1
  79. package/dist/visitor.d.ts +4 -4
  80. package/dist/visitor.d.ts.map +1 -1
  81. package/dist/visitor.js +8 -3
  82. package/dist/visitor.js.map +1 -1
  83. package/package.json +4 -2
  84. package/src/java/tree.ts +10 -3
  85. package/src/java/type-visitor.ts +1 -1
  86. package/src/java/visitor.ts +11 -5
  87. package/src/javascript/assertions.ts +9 -3
  88. package/src/javascript/comparator.ts +676 -185
  89. package/src/javascript/format.ts +72 -34
  90. package/src/javascript/index.ts +1 -0
  91. package/src/javascript/parser.ts +51 -31
  92. package/src/javascript/templating/capture.ts +107 -15
  93. package/src/javascript/templating/comparator.ts +1087 -134
  94. package/src/javascript/templating/engine.ts +601 -103
  95. package/src/javascript/templating/index.ts +9 -2
  96. package/src/javascript/templating/pattern.ts +655 -281
  97. package/src/javascript/templating/placeholder-replacement.ts +183 -80
  98. package/src/javascript/templating/rewrite.ts +152 -18
  99. package/src/javascript/templating/template.ts +110 -22
  100. package/src/javascript/templating/types.ts +386 -12
  101. package/src/javascript/templating/utils.ts +116 -102
  102. package/src/javascript/type-mapping.ts +20 -11
  103. package/src/json/rpc.ts +2 -2
  104. package/src/recipe/order-imports.ts +1 -1
  105. package/src/test/rewrite-test.ts +12 -7
  106. package/src/visitor.ts +14 -6
@@ -13,13 +13,12 @@
13
13
  * See the License for the specific language governing permissions and
14
14
  * limitations under the License.
15
15
  */
16
+ import {Cursor} from '../..';
16
17
  import {J} from '../../java';
17
18
  import {JS} from '../index';
18
- import {JavaScriptParser} from '../parser';
19
- import {DependencyWorkspace} from '../dependency-workspace';
20
- import {Marker} from '../../markers';
19
+ import {Marker, Markers} from '../../markers';
21
20
  import {randomId} from '../../uuid';
22
- import {VariadicOptions, Capture, Any} from './types';
21
+ import {ConstraintFunction, VariadicOptions} from './types';
23
22
 
24
23
  /**
25
24
  * Internal storage value type for pattern match captures.
@@ -36,84 +35,86 @@ export type CaptureStorageValue = J | J.RightPadded<J> | J[] | J.RightPadded<J>[
36
35
  export const WRAPPERS_MAP_SYMBOL = Symbol('wrappersMap');
37
36
 
38
37
  /**
39
- * Cache for compiled templates and patterns.
40
- * Stores parsed ASTs to avoid expensive re-parsing and dependency resolution.
38
+ * Shared wrapper function name used by both patterns and templates.
39
+ * Using the same name allows cache sharing when pattern and template code is identical.
41
40
  */
42
- export class TemplateCache {
43
- private cache = new Map<string, JS.CompilationUnit>();
41
+ export const WRAPPER_FUNCTION_NAME = '__WRAPPER__';
44
42
 
45
- /**
46
- * Generates a cache key from template string, captures, and options.
47
- */
48
- private generateKey(
49
- templateString: string,
50
- captures: (Capture | Any<any>)[],
51
- contextStatements: string[],
52
- dependencies: Record<string, string>
53
- ): string {
54
- // Use the actual template string (with placeholders) as the primary key
55
- const templateKey = templateString;
56
-
57
- // Capture names
58
- const capturesKey = captures.map(c => c.getName()).join(',');
59
-
60
- // Context statements
61
- const contextKey = contextStatements.join(';');
62
-
63
- // Dependencies
64
- const depsKey = JSON.stringify(dependencies || {});
43
+ /**
44
+ * Simple LRU (Least Recently Used) cache implementation using Map's insertion order.
45
+ * JavaScript Map maintains insertion order, so the first entry is the oldest.
46
+ *
47
+ * Used by both Pattern and Template caching to provide bounded memory usage.
48
+ */
49
+ export class LRUCache<K, V> {
50
+ private cache = new Map<K, V>();
65
51
 
66
- return `${templateKey}::${capturesKey}::${contextKey}::${depsKey}`;
52
+ constructor(private maxSize: number) {
67
53
  }
68
54
 
69
- /**
70
- * Gets a cached compilation unit or creates and caches a new one.
71
- */
72
- async getOrParse(
73
- templateString: string,
74
- captures: (Capture | Any<any>)[],
75
- contextStatements: string[],
76
- dependencies: Record<string, string>
77
- ): Promise<JS.CompilationUnit> {
78
- const key = this.generateKey(templateString, captures, contextStatements, dependencies);
79
-
80
- let cu = this.cache.get(key);
81
- if (cu) {
82
- return cu;
83
- }
84
-
85
- // Create workspace if dependencies are provided
86
- // DependencyWorkspace has its own cache, so multiple templates with
87
- // the same dependencies will automatically share the same workspace
88
- let workspaceDir: string | undefined;
89
- if (dependencies && Object.keys(dependencies).length > 0) {
90
- workspaceDir = await DependencyWorkspace.getOrCreateWorkspace(dependencies);
55
+ get(key: K): V | undefined {
56
+ const value = this.cache.get(key);
57
+ if (value !== undefined) {
58
+ // Move to end (most recently used)
59
+ this.cache.delete(key);
60
+ this.cache.set(key, value);
91
61
  }
62
+ return value;
63
+ }
92
64
 
93
- // Prepend context statements for type attribution context
94
- const fullTemplateString = contextStatements.length > 0
95
- ? contextStatements.join('\n') + '\n' + templateString
96
- : templateString;
65
+ set(key: K, value: V): void {
66
+ // Remove if exists (to update position)
67
+ this.cache.delete(key);
97
68
 
98
- // Parse and cache (workspace only needed during parsing)
99
- const parser = new JavaScriptParser({relativeTo: workspaceDir});
100
- const parseGenerator = parser.parse({text: fullTemplateString, sourcePath: 'template.ts'});
101
- cu = (await parseGenerator.next()).value as JS.CompilationUnit;
69
+ // Add to end
70
+ this.cache.set(key, value);
102
71
 
103
- this.cache.set(key, cu);
104
- return cu;
72
+ // Evict oldest if over capacity
73
+ if (this.cache.size > this.maxSize) {
74
+ const iterator = this.cache.keys();
75
+ const firstEntry = iterator.next();
76
+ if (!firstEntry.done) {
77
+ this.cache.delete(firstEntry.value);
78
+ }
79
+ }
105
80
  }
106
81
 
107
- /**
108
- * Clears the cache.
109
- */
110
82
  clear(): void {
111
83
  this.cache.clear();
112
84
  }
113
85
  }
114
86
 
115
- // Global cache instance
116
- export const templateCache = new TemplateCache();
87
+ /**
88
+ * Shared global LRU cache for both pattern and template ASTs.
89
+ * When pattern and template code is identical, they share the same cached AST.
90
+ * This mirrors JavaTemplate's unified approach in the Java implementation.
91
+ * Bounded to 100 entries using LRU eviction.
92
+ */
93
+ export const globalAstCache = new LRUCache<string, J>(100);
94
+
95
+ /**
96
+ * Generates a cache key for template/pattern processing.
97
+ * Used by both Pattern and Template for consistent cache key generation.
98
+ *
99
+ * @param templateParts The template string parts
100
+ * @param itemsKey String representing the captures/parameters (comma-separated)
101
+ * @param contextStatements Context declarations
102
+ * @param dependencies NPM dependencies
103
+ * @returns A cache key string
104
+ */
105
+ export function generateCacheKey(
106
+ templateParts: string[] | TemplateStringsArray,
107
+ itemsKey: string,
108
+ contextStatements: string[],
109
+ dependencies: Record<string, string>
110
+ ): string {
111
+ return [
112
+ Array.from(templateParts).join('|'),
113
+ itemsKey,
114
+ contextStatements.join(';'),
115
+ JSON.stringify(dependencies)
116
+ ].join('::');
117
+ }
117
118
 
118
119
  /**
119
120
  * Marker that stores capture metadata on pattern AST nodes.
@@ -125,7 +126,8 @@ export class CaptureMarker implements Marker {
125
126
 
126
127
  constructor(
127
128
  public readonly captureName: string,
128
- public readonly variadicOptions?: VariadicOptions
129
+ public readonly variadicOptions?: VariadicOptions,
130
+ public readonly constraint?: ConstraintFunction<any>
129
131
  ) {
130
132
  }
131
133
  }
@@ -154,30 +156,13 @@ export class PlaceholderUtils {
154
156
  return false;
155
157
  }
156
158
 
157
- /**
158
- * Gets the capture name from a node with a CaptureMarker.
159
- *
160
- * @param node The node to extract capture name from
161
- * @returns The capture name, or null if not a capture
162
- */
163
- static getCaptureName(node: J): string | undefined {
164
- // Check for CaptureMarker
165
- for (const marker of node.markers.markers) {
166
- if (marker instanceof CaptureMarker) {
167
- return marker.captureName;
168
- }
169
- }
170
-
171
- return undefined;
172
- }
173
-
174
159
  /**
175
160
  * Gets the CaptureMarker from a node, if present.
176
161
  *
177
162
  * @param node The node to check
178
163
  * @returns The CaptureMarker or undefined
179
164
  */
180
- static getCaptureMarker(node: J): CaptureMarker | undefined {
165
+ static getCaptureMarker(node: { markers: Markers }): CaptureMarker | undefined {
181
166
  for (const marker of node.markers.markers) {
182
167
  if (marker instanceof CaptureMarker) {
183
168
  return marker;
@@ -235,7 +220,7 @@ export class PlaceholderUtils {
235
220
  * @param node The node to check
236
221
  * @returns true if the node has a variadic CaptureMarker, false otherwise
237
222
  */
238
- static isVariadicCapture(node: J): boolean {
223
+ static isVariadicCapture(node: { markers: Markers }): boolean {
239
224
  for (const marker of node.markers.markers) {
240
225
  if (marker instanceof CaptureMarker && marker.variadicOptions) {
241
226
  return true;
@@ -250,7 +235,7 @@ export class PlaceholderUtils {
250
235
  * @param node The node to extract variadic options from
251
236
  * @returns The VariadicOptions, or undefined if not a variadic capture
252
237
  */
253
- static getVariadicOptions(node: J): VariadicOptions | undefined {
238
+ static getVariadicOptions(node: { markers: Markers }): VariadicOptions | undefined {
254
239
  for (const marker of node.markers.markers) {
255
240
  if (marker instanceof CaptureMarker) {
256
241
  return marker.variadicOptions;
@@ -260,25 +245,54 @@ export class PlaceholderUtils {
260
245
  }
261
246
 
262
247
  /**
263
- * Checks if a statement is an ExpressionStatement wrapping a capture identifier.
264
- * When a capture placeholder appears in statement position, the parser wraps it as
265
- * an ExpressionStatement. This method unwraps it to get the identifier.
248
+ * Extracts the relevant AST node from a wrapper function.
249
+ * Used by both pattern and template processors to intelligently extract
250
+ * code from `function __WRAPPER__() { code }` wrappers.
266
251
  *
267
- * @param stmt The statement to check
268
- * @returns The unwrapped capture identifier, or the original statement if not wrapped
252
+ * @param lastStatement The last statement from the compilation unit
253
+ * @param contextName Context name for error messages (e.g., 'Pattern', 'Template')
254
+ * @returns The extracted AST node
269
255
  */
270
- static unwrapStatementCapture(stmt: J): J {
271
- // Check if it's an ExpressionStatement containing a capture identifier
272
- if (stmt.kind === JS.Kind.ExpressionStatement) {
273
- const exprStmt = stmt as JS.ExpressionStatement;
274
- if (exprStmt.expression?.kind === J.Kind.Identifier) {
275
- const identifier = exprStmt.expression as J.Identifier;
276
- // Check if this is a capture placeholder
277
- if (identifier.simpleName?.startsWith(this.CAPTURE_PREFIX)) {
278
- return identifier;
256
+ static extractFromWrapper(lastStatement: J, contextName: string): J {
257
+ let extracted: J;
258
+
259
+ // Since we always wrap in function __WRAPPER__() { code }, look for it
260
+ if (lastStatement.kind === J.Kind.MethodDeclaration) {
261
+ const method = lastStatement as J.MethodDeclaration;
262
+ if (method.name?.simpleName === WRAPPER_FUNCTION_NAME && method.body) {
263
+ const body = method.body;
264
+
265
+ // Intelligently extract based on what's in the function body
266
+ if (body.statements.length === 0) {
267
+ throw new Error(`${contextName} function body is empty`);
268
+ } else if (body.statements.length === 1) {
269
+ const stmt = body.statements[0].element;
270
+
271
+ // Single expression statement → extract the expression
272
+ if (stmt.kind === JS.Kind.ExpressionStatement) {
273
+ extracted = (stmt as JS.ExpressionStatement).expression;
274
+ }
275
+ // Single block statement → keep the block
276
+ else if (stmt.kind === J.Kind.Block) {
277
+ extracted = stmt;
278
+ }
279
+ // Other single statement → keep it
280
+ else {
281
+ extracted = stmt;
282
+ }
283
+ } else {
284
+ // Multiple statements → keep the block
285
+ extracted = body;
279
286
  }
287
+ } else {
288
+ // Not our wrapper function
289
+ extracted = lastStatement;
280
290
  }
291
+ } else {
292
+ // Shouldn't happen with our wrapping strategy, but handle it
293
+ extracted = lastStatement;
281
294
  }
282
- return stmt;
295
+
296
+ return extracted;
283
297
  }
284
298
  }
@@ -402,19 +402,28 @@ export class JavaScriptTypeMapping {
402
402
  // If getAliasedSymbol returns something different, it's an import
403
403
  if (aliasedSymbol && aliasedSymbol !== symbol) {
404
404
  // This is definitely an imported symbol
405
- // Now find the import declaration to get the module specifier
406
- if (symbol.declarations && symbol.declarations.length > 0) {
407
- let importNode: ts.Node = symbol.declarations[0];
405
+ const aliasedParentSymbol = (aliasedSymbol as any).parent as ts.Symbol | undefined;
408
406
 
409
- // Traverse up to find the ImportDeclaration
410
- while (importNode && !ts.isImportDeclaration(importNode)) {
411
- importNode = importNode.parent;
412
- }
407
+ if (aliasedParentSymbol && aliasedParentSymbol.declarations?.[0] &&
408
+ ts.isModuleDeclaration(aliasedParentSymbol.declarations[0]) &&
409
+ ts.isIdentifier(aliasedParentSymbol.declarations[0].name)) {
410
+ // For namespace imports, use the namespace symbol's `name` as the module specifier (e.g. `React` instead of `react`)
411
+ moduleSpecifier = aliasedParentSymbol.name;
412
+ } else {
413
+ // Now find the import declaration to get the module specifier
414
+ if (symbol.declarations && symbol.declarations.length > 0) {
415
+ let importNode: ts.Node = symbol.declarations[0];
413
416
 
414
- if (importNode && ts.isImportDeclaration(importNode)) {
415
- const importDeclNode = importNode as ts.ImportDeclaration;
416
- if (ts.isStringLiteral(importDeclNode.moduleSpecifier)) {
417
- moduleSpecifier = importDeclNode.moduleSpecifier.text;
417
+ // Traverse up to find the ImportDeclaration
418
+ while (importNode && !ts.isImportDeclaration(importNode)) {
419
+ importNode = importNode.parent;
420
+ }
421
+
422
+ if (importNode && ts.isImportDeclaration(importNode)) {
423
+ const importDeclNode = importNode as ts.ImportDeclaration;
424
+ if (ts.isStringLiteral(importDeclNode.moduleSpecifier)) {
425
+ moduleSpecifier = importDeclNode.moduleSpecifier.text;
426
+ }
418
427
  }
419
428
  }
420
429
  }
package/src/json/rpc.ts CHANGED
@@ -158,7 +158,7 @@ class JsonReceiver extends JsonVisitor<RpcReceiveQueue> {
158
158
  }
159
159
 
160
160
  public async visitSpace(space: Json.Space, q: RpcReceiveQueue): Promise<Json.Space> {
161
- return produceAsync<Json.Space>(space, async draft => {
161
+ return (await produceAsync<Json.Space>(space, async draft => {
162
162
  draft.comments = await q.receiveListDefined(space.comments, async c => {
163
163
  return await produceAsync(c, async draft => {
164
164
  draft.multiline = await q.receive(c.multiline);
@@ -168,7 +168,7 @@ class JsonReceiver extends JsonVisitor<RpcReceiveQueue> {
168
168
  })
169
169
  });
170
170
  draft.whitespace = await q.receive(space.whitespace);
171
- });
171
+ }))!;
172
172
  }
173
173
 
174
174
  public async visitRightPadded<T extends Json>(right: Json.RightPadded<T>, p: RpcReceiveQueue): Promise<Json.RightPadded<T> | undefined> {
@@ -63,7 +63,7 @@ export class OrderImports extends Recipe {
63
63
  const cuWithImportsSorted = await produceAsync(cu, async draft => {
64
64
  draft.statements = [...sortedImports, ...restStatements];
65
65
  });
66
- return produce(cuWithImportsSorted, draft => {
66
+ return produce(cuWithImportsSorted!, draft => {
67
67
  for (let i = 0; i < importCount; i++) {
68
68
  draft.statements[i].element.prefix.whitespace = i > 0 ? "\n" : "";
69
69
  }
@@ -175,7 +175,7 @@ export class RecipeSpec {
175
175
  (spec.after as (actual: string) => string)(actualAfter) : spec.after as string;
176
176
  expect(actualAfter).toEqual(afterSource);
177
177
  if (spec.afterRecipe) {
178
- await spec.afterRecipe(actualAfter);
178
+ await spec.afterRecipe(after);
179
179
  }
180
180
  }
181
181
 
@@ -279,9 +279,14 @@ function dedent(s: string): string {
279
279
  const str = start > 0 || end < s.length ? s.slice(start, end) : s;
280
280
  const lines = str.split('\n');
281
281
 
282
- // Find minimum indentation (avoid regex for performance)
282
+ // If we removed a leading newline, consider all lines for minIndent
283
+ // Otherwise, skip the first line (it's on the same line as the opening quote)
284
+ const startLine = start > 0 ? 0 : 1;
285
+
286
+ // Find minimum indentation
283
287
  let minIndent = Infinity;
284
- for (const line of lines) {
288
+ for (let i = startLine; i < lines.length; i++) {
289
+ const line = lines[i];
285
290
  let indent = 0;
286
291
  for (let j = 0; j < line.length; j++) {
287
292
  const ch = line.charCodeAt(j);
@@ -297,12 +302,12 @@ function dedent(s: string): string {
297
302
 
298
303
  // If all lines are empty or no indentation
299
304
  if (minIndent === Infinity || minIndent === 0) {
300
- return lines.map(line => line.trim() || '').join('\n');
305
+ return lines.join('\n');
301
306
  }
302
307
 
303
- // Remove common indentation from each line
304
- return lines.map(line =>
305
- line.length >= minIndent ? line.slice(minIndent) : ''
308
+ // Remove common indentation from lines (skip first line only if we didn't remove leading newline)
309
+ return lines.map((line, i) =>
310
+ (i === 0 && startLine === 1) ? line : (line.length >= minIndent ? line.slice(minIndent) : '')
306
311
  ).join('\n');
307
312
  }
308
313
 
package/src/visitor.ts CHANGED
@@ -15,7 +15,7 @@
15
15
  */
16
16
  import {emptyMarkers, Marker, Markers} from "./markers";
17
17
  import {Cursor, isSourceFile, rootCursor, SourceFile, Tree} from "./tree";
18
- import {createDraft, Draft, finishDraft, Objectish} from "immer";
18
+ import {createDraft, Draft, finishDraft, nothing, Objectish} from "immer";
19
19
  import {mapAsync} from "./util";
20
20
 
21
21
  /* Not exported beyond the internal immer module */
@@ -23,15 +23,23 @@ export type ValidImmerRecipeReturnType<State> =
23
23
  | State
24
24
  | void
25
25
  | undefined
26
+ | typeof nothing
26
27
 
27
28
  export async function produceAsync<Base extends Objectish>(
28
29
  before: Promise<Base> | Base,
29
30
  recipe: (draft: Draft<Base>) => ValidImmerRecipeReturnType<Draft<Base>> |
30
31
  PromiseLike<ValidImmerRecipeReturnType<Draft<Base>>>
31
- ): Promise<Base> {
32
+ ): Promise<Base | undefined> {
32
33
  const b: Base = await before;
33
34
  const draft = createDraft(b);
34
- await recipe(draft);
35
+ const result = await recipe(draft);
36
+
37
+ // If recipe explicitly returned Immer's nothing, return undefined
38
+ if (result === nothing) {
39
+ return undefined;
40
+ }
41
+
42
+ // Otherwise, return the finished draft (void/undefined means use draft)
35
43
  return finishDraft(draft) as Base;
36
44
  }
37
45
 
@@ -128,9 +136,9 @@ export abstract class TreeVisitor<T extends Tree, P> {
128
136
  } else if ((markers.markers?.length || 0) === 0) {
129
137
  return markers;
130
138
  }
131
- return produceAsync<Markers>(markers, async (draft) => {
139
+ return (await produceAsync<Markers>(markers, async (draft) => {
132
140
  draft.markers = await mapAsync(markers.markers, m => this.visitMarker(m, p))
133
- });
141
+ }))!;
134
142
  }
135
143
 
136
144
  protected async visitMarker<M extends Marker>(marker: M, p: P): Promise<M> {
@@ -143,7 +151,7 @@ export abstract class TreeVisitor<T extends Tree, P> {
143
151
  recipe?:
144
152
  ((draft: Draft<T>) => ValidImmerRecipeReturnType<Draft<T>>) |
145
153
  ((draft: Draft<T>) => Promise<ValidImmerRecipeReturnType<Draft<T>>>)
146
- ): Promise<T> {
154
+ ): Promise<T | undefined> {
147
155
  return produceAsync(before, async draft => {
148
156
  draft.markers = await this.visitMarkers(before.markers, p);
149
157
  if (recipe) {