ast-pattern-engine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+ import ast
3
+ from typing import Any
4
+
5
+ from ast_pattern_engine.plumbing import Message
6
+
7
class Pattern(ast.AST):
    """Common base for every AST matching pattern.

    Patterns form a tree: composite patterns register their children through
    ``_set_parent``. Messages (for example binding announcements) travel from
    a child towards the root via ``bubble_message``. While a pattern has no
    parent, messages that reach it are stored in ``_buffered_messages`` and
    replayed to the parent once one is attached.
    """

    _buffered_messages: list[Message] | None = None
    _forces_list: bool = False  # repetition wrappers set this to True
    _parent: Pattern | None = None

    # ---------------------------------- public API ----------------------------------
    def match(self, node: object, bindings: dict[str, object] | None = None):
        """Match *node* and return the updated *bindings*, or ``None`` on failure.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def handle_message(self, message: Message) -> None:
        """React to *message* sent up by a child; by default pass it further up."""
        self.bubble_message(message)

    def bubble_message(self, message: Message) -> None:
        """Forward *message* to the parent, or buffer it while there is none."""
        parent = self._parent
        if parent is None:
            if self._buffered_messages is None:
                self._buffered_messages = []
            self._buffered_messages.append(message)
            return
        parent.handle_message(message)

    # ------------------------------- parent navigation -------------------------------
    def _on_parent_init(self) -> None:
        """Replay every buffered message to the freshly attached parent.

        The buffer is not cleared, so attaching the same pattern to another
        parent replays the messages to that parent as well.
        """
        for buffered in self._buffered_messages or ():
            self.bubble_message(buffered)

    def _set_parent(self, parent: Pattern) -> None:
        """Attach *parent* and flush any messages buffered so far to it."""
        self._parent = parent
        self._on_parent_init()

    def _ancestor_forces_list(self) -> bool:
        """Return ``_forces_list`` of the nearest ancestor with non-empty ``expected_bindings``.

        Ancestors lacking ``expected_bindings`` (or holding an empty one) are
        skipped. Returns ``False`` when no such ancestor exists.
        """
        ancestor = self._parent
        while ancestor is not None:
            if getattr(ancestor, 'expected_bindings', False):
                return getattr(ancestor, '_forces_list', False)
            ancestor = ancestor._parent  # type: ignore[attr-defined]
        return False

    # ----------------------------------- helpers ------------------------------------
    @staticmethod
    def _to_list(val: Any) -> list[Any]:
        """Return *val* itself if it is a list, otherwise a one-element list."""
        if isinstance(val, list):
            return val
        return [val]
@@ -0,0 +1,59 @@
1
+ from os import SEEK_CUR
2
+ from typing import Sequence, Any
3
+
4
+ from ast_pattern_engine.core import Pattern
5
+
6
+ from ast_pattern_engine.nodes.sequences import PatternGroup, Repetition
7
+
8
def _match_patterns(pattern_nodes: Sequence[Pattern], nodes: list[Any], pos: int, bindings: dict[str, Any]):
    """Match *pattern_nodes* against *nodes* starting at index *pos*.

    Patterns are consumed left to right:

    * ``PatternGroup`` matches its inner pattern sequence. When the group has a
      ``key``, the inner bindings dict is stored under that key.
    * ``Repetition`` greedily matches its inner pattern at most ``max_matches``
      times (``None``: bounded only by the number of nodes) and fails when it
      matched fewer than ``min_matches`` times.
    * Any other pattern matches exactly one node via ``Pattern.match``.

    After the leading pattern succeeds, the remaining patterns are matched
    from the position where it stopped.

    :param pattern_nodes: Pattern sequence to match.
    :param nodes: Candidate nodes.
    :param pos: Index in *nodes* at which matching starts.
    :param bindings: Bindings accumulated so far. Groups and repetitions copy
        it; single patterns receive a shallow copy.
    :return: List of ``(bindings, end_pos)`` tuples, empty when there is no match.
    """
    if not pattern_nodes:
        return [(bindings, pos)]
    first, *rest = pattern_nodes

    if isinstance(first, PatternGroup):
        res = _match_patterns(first.pattern, nodes, pos, {})
        if not res:
            return []
        inner_bindings, end_pos = res[-1]
        new_bindings = dict(bindings)
        if first.key is not None:
            # Store the group's bindings in the dict that is returned,
            # not in the caller's dict.
            new_bindings[first.key] = inner_bindings
        return _match_patterns(rest, nodes, end_pos, new_bindings)

    if isinstance(first, Repetition):
        new_bindings = dict(bindings)
        # `is not None` so that max_matches=0 is honoured rather than treated
        # as "unbounded".
        limit = first.max_matches if first.max_matches is not None else len(nodes)
        n_reps = 0
        while n_reps < limit and pos < len(nodes):
            res = _match_patterns([first.pattern], nodes, pos, dict(new_bindings))
            if not res:
                break
            new_bindings, pos = res[-1]
            n_reps += 1
        if n_reps < first.min_matches:
            # Too few repetitions: no match. Falling through to the
            # single-node case would call Repetition.match, which always raises.
            return []
        return _match_patterns(rest, nodes, pos, new_bindings)

    # ---------- single ----------
    if pos >= len(nodes):
        return []
    res = first.match(nodes[pos], dict(bindings))
    if res is None:
        return []
    return _match_patterns(rest, nodes, pos + 1, res)
46
+
47
def match_sequence(patterns: Sequence[Pattern], nodes: list[Any]):
    """Return list of binding dicts for non-overlapping matches in `nodes`.

    Matching is attempted at each index from left to right. After a match,
    the scan resumes where the match ended. A zero-width match (one that
    consumes no nodes) advances the scan by one position so the loop always
    terminates.
    """
    results: list[dict[str, Any]] = []
    i = 0
    while i < len(nodes):
        m = _match_patterns(patterns, nodes, i, {})
        if not m:
            i += 1
            continue
        b, new_pos = m[0]
        results.append(b)
        # A zero-width match (e.g. a Repetition with min_matches=0) returns
        # new_pos == i; without this guard `i` would never change.
        i = new_pos if new_pos > i else i + 1
    return results
@@ -0,0 +1,172 @@
1
+ from __future__ import annotations
2
+ import ast
3
+ from typing import Any, Callable
4
+
5
+ from ast_pattern_engine.core import Pattern
6
+ from ast_pattern_engine.src.ast_pattern_engine.plumbing import (
7
+ AnnounceBinding, Message, t_expected_bindings
8
+ )
9
+ from ast_pattern_engine.engine import _match_patterns
10
+
11
class Bind(Pattern):
    """Bind the current node to `name`.

    When an ancestor forces list collection (see
    ``Pattern._ancestor_forces_list``), repeated matches accumulate into a list.
    Otherwise a second binding of the same key makes the match fail.
    """

    key: str  # name under which the matched node is stored

    def __init__(self, name: str):
        self.key = name

    def _on_parent_init(self) -> None:
        # Replay buffered messages first, then announce this binding's key.
        super()._on_parent_init()
        self.bubble_message(AnnounceBinding(self.key, self))

    def match(self, node: Any, bindings: dict[str, Any] | None = None):
        """Store *node* under ``self.key`` and return the bindings dict.

        A non-empty *bindings* dict is updated in place. Returns ``None`` when
        the key is already bound and no ancestor forces list collection.
        """
        result = bindings or {}
        as_list = self._ancestor_forces_list()
        if self.key not in result:
            result[self.key] = [node] if as_list else node
            return result
        if not as_list:
            return None
        result[self.key] = self._to_list(result[self.key]) + [node]
        return result
33
+
34
class WildCard(Pattern):
    """Pattern that accepts any node without adding bindings."""

    def __init__(self):
        pass

    def match(self, node: Any, bindings: dict[str, Any] | None = None):
        """Return *bindings* as-is, or a fresh dict when it is empty or ``None``."""
        if bindings:
            return bindings
        return {}
42
+
43
class NodePattern(Pattern):
    """Match an AST node of *node_type* with constraints on its fields.

    Each keyword argument constrains the node attribute of the same name:

    * A ``Pattern`` value is matched against the attribute. For list-valued
      attributes (unless the pattern is a ``Bind``, which binds the whole
      list), the pattern is matched as a sequence starting at the first
      element. The list does not have to be fully consumed.
    * Any other value must compare equal (``==``) to the attribute.
    """

    def __init__(self, node_type: type[ast.AST], **field_patterns: Pattern | Any):
        self.node_type = node_type
        self.field_patterns = field_patterns
        for pat in field_patterns.values():
            if isinstance(pat, Pattern):
                pat._set_parent(self)

    def match(self, node: Any, bindings: dict[str, Any] | None = None):
        """Return *bindings* merged with all field sub-bindings, or ``None``.

        A key already present (in *bindings* or from an earlier field) makes
        the match fail, unless an ancestor forces list collection. In that case
        the values are concatenated into a list. *bindings* is not mutated.
        """
        bindings = bindings or {}
        if not isinstance(node, self.node_type):
            return None
        force = self._ancestor_forces_list()  # loop-invariant
        merged = dict(bindings)
        for field, pat in self.field_patterns.items():
            val = getattr(node, field, None)
            if not isinstance(pat, Pattern):
                # Literal constraint.
                if val != pat:
                    return None
                continue
            if val is None:
                return None
            if isinstance(val, list) and not isinstance(pat, Bind):
                # Match list-valued field as a sequence. An empty result
                # (empty list, or a Repetition below min_matches) means the
                # field does not match; indexing it would raise IndexError.
                res = _match_patterns([pat], val, 0, {})
                if not res:
                    return None
                sub_bind = res[-1][0]
            else:
                sub_bind = pat.match(val, {})
                if sub_bind is None:
                    return None
            # merge sub bindings
            for k, v in sub_bind.items():
                if k in merged:
                    if not force:
                        return None
                    merged[k] = self._to_list(merged[k]) + self._to_list(v)
                else:
                    merged[k] = self._to_list(v) if force else v
        return merged
85
+
86
+ class Collect(Pattern):
87
+ """Collect *node* under *key* and merge sub-bindings into current scope."""
88
+
89
+ expected_bindings: t_expected_bindings
90
+
91
+ def __init__(self, pattern: Pattern, key: str):
92
+ self.pattern = pattern
93
+ self.key = key
94
+ self.expected_bindings = []
95
+ pattern._set_parent(self)
96
+
97
+ def _on_parent_init(self) -> None:
98
+ super()._on_parent_init()
99
+ if self._ancestor_forces_list():
100
+ if self.expected_bindings:
101
+ self.bubble_message(AnnounceBinding([self.key, self.expected_bindings], self))
102
+ else:
103
+ self.bubble_message(AnnounceBinding(self.key, self))
104
+ else:
105
+ for binding in [*self.expected_bindings,self.key]:
106
+ self.bubble_message(message=AnnounceBinding(binding, self))
107
+
108
+ def match(self, node: Any, bindings: dict[str, Any] | None = None) -> None | dict[str, Any]:
109
+ bindings = bindings or {}
110
+ inner = self.pattern.match(node, {})
111
+ if inner is None:
112
+ return None
113
+ force = self._ancestor_forces_list()
114
+ merged = dict(bindings)
115
+
116
+ if force:
117
+ if self.expected_bindings:
118
+ # Inside a repetition wrapper, do not merge inner bindings;
119
+ # instead, append the inner-dict itself to the list under key.
120
+ if self.key in merged:
121
+ merged[self.key].append(inner)
122
+ else:
123
+ merged[self.key] = [inner]
124
+ return merged
125
+ else:
126
+ if self.key in merged:
127
+ merged[self.key].append(node)
128
+ else:
129
+ merged[self.key] = [node]
130
+ return merged
131
+
132
+ # Outside repetition - store node and merge inner bindings
133
+ if self.key in merged:
134
+ return None # scalar expected, duplicate found
135
+ merged[self.key] = node
136
+ for k, v in inner.items():
137
+ if k in merged:
138
+ return None # scalar expected, duplicate found
139
+ merged[k] = v
140
+ return merged
141
+
142
+ def handle_message(self, message: Message) -> None:
143
+ match message:
144
+ case AnnounceBinding():
145
+ self.expected_bindings.append(message.expected_bindings)
146
+ # print(self.key, self.expected_bindings)
147
+
148
class Filter(Pattern):
    """Match nodes where `predicate(node)` returns `True` and optionally bind `node` to `name`.

    The optional binding is stored under ``key`` and follows the same rules
    as ``Bind``. With a list-forcing ancestor, repeated matches accumulate
    into a list; otherwise a duplicate key makes the match fail.
    """

    def __init__(self, predicate: Callable[[Any], bool], key: str | None = None):
        self.predicate = predicate
        self.key = key

    def _on_parent_init(self) -> None:
        # Flush buffered messages like every other pattern, then announce the
        # key. Test `is not None` (not truthiness) so an empty-string key,
        # which `match` does bind, is announced as well.
        super()._on_parent_init()
        if self.key is not None:
            self.bubble_message(AnnounceBinding(self.key, self))

    def match(self, node: Any, bindings: dict[str, Any] | None = None):
        """Return *bindings* (updated with ``key`` if set) when the predicate holds, else ``None``.

        A non-empty *bindings* dict is updated in place.
        """
        bindings = bindings or {}
        if not self.predicate(node):
            return None
        if self.key is None:
            return bindings
        force = self._ancestor_forces_list()
        if self.key in bindings:
            if not force:
                return None
            bindings[self.key] = self._to_list(bindings[self.key]) + [node]
        else:
            bindings[self.key] = [node] if force else node
        return bindings
@@ -0,0 +1,68 @@
1
+ from __future__ import annotations
2
+ from typing import Sequence
3
+
4
+ from ast_pattern_engine.core import Pattern
5
+ from ast_pattern_engine.src.ast_pattern_engine.plumbing import (
6
+ AnnounceBinding, Message, t_expected_bindings
7
+ )
8
+
9
+ class Repetition(Pattern):
10
+ """Matches a single pattern to an AST node sequence.
11
+ Also supports specifying min and max match count threshold"""
12
+
13
+ _forces_list = True
14
+ expected_bindings: list[str | t_expected_bindings]
15
+
16
+ def __init__(
17
+ self,
18
+ pattern: Pattern,
19
+ min_matches: int = 1,
20
+ max_matches: int | None = None,
21
+ ):
22
+ """
23
+ :param pattern: The AST pattern
24
+ :type pattern: Pattern
25
+ :param min_matches: min number of matches required, defaults to 1
26
+ :type min_matches: int, optional
27
+ :param max_matches: max number of allowed matches, defaults to None
28
+ :type max_matches: int | None, optional
29
+ """
30
+
31
+ self.expected_bindings = []
32
+ self.pattern = pattern
33
+
34
+ self.min_matches = min_matches
35
+ self.max_matches = max_matches
36
+
37
+ pattern._set_parent(self)
38
+
39
+ def match(self, node: object, bindings: dict[str, object] | None = None):
40
+ # Matching is handled by engine._match_sequence
41
+ raise NotImplementedError(
42
+ "Repetition node does not support matching single AST node."""
43
+ )
44
+
45
+ def handle_message(self, message: Message) -> None:
46
+ match message:
47
+ case AnnounceBinding():
48
+ self.expected_bindings.append(message.expected_bindings)
49
+
50
+ super().handle_message(message)
51
+
52
class PatternGroup(Pattern):
    """Matches a pattern group to an AST node sequence.

    ``pattern`` is the inner pattern sequence. When ``key`` is given, it is
    announced to the parent as a binding name.
    """

    def __init__(self, pattern: Sequence[Pattern], key: str | None = None) -> None:
        # NOTE(review): unlike Repetition/Collect/NodePattern, the sub-patterns
        # are not attached via `_set_parent`, so they announce nothing through
        # this group and see no ancestors — confirm this is intended.
        self.key = key
        self.pattern = pattern

    def _on_parent_init(self) -> None:
        super()._on_parent_init()
        if self.key is None:
            return
        self.bubble_message(AnnounceBinding(self.key, self))

    def match(self, node: object, bindings: dict[str, object] | None = None):
        # Groups only apply to node sequences; the sequence engine handles them.
        raise NotImplementedError(
            "PatternGroup node does not support matching single AST node."
        )
@@ -0,0 +1,16 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from ast_pattern_engine.core import Pattern
7
+
8
+ class Message:
9
+ ...
10
+
11
+ type t_expected_bindings = list[str | t_expected_bindings]
12
+
13
+ @dataclass
14
+ class AnnounceBinding(Message):
15
+ expected_bindings: str | t_expected_bindings
16
+ node: Pattern
@@ -0,0 +1,543 @@
1
+ import ast
2
+ from typing import Sequence, Callable, Any
3
+
4
+ from ast_pattern_engine.core import Pattern
5
+ from ast_pattern_engine.src.ast_pattern_engine.engine import _match_patterns
6
+
7
# Result of a single-node replacement: one node, several nodes to splice
# into a list field, or None.
type ReplaceResult = ast.AST | list[ast.AST] | None
8
+
9
class PatternTransformer(ast.NodeTransformer):
    """
    Walk the tree, find non-overlapping matches of *pattern* and run *actions*.
    In addition, every successful set of bindings is appended to
    `self.matches` (like PatternFinder).

    actions = {collect_key: handler | None}

    • handler(bindings) → list[ast.AST]
        - `bindings` is the FULL match dict
        - the returned list replaces the *anchor* node (see below)

    • None
        - delete every node collected under *collect_key*

    Anchors: in list fields the pattern is matched as a sequence (see
    ``_plan``). Collected nodes that are elements of the matched span are
    replaced or removed directly; otherwise the edit is applied to the
    collected node nested inside the span. In single-node fields the pattern
    is matched against that child alone (see ``_maybe_replace``).
    """

    def __init__(
        self,
        pattern: Sequence[Pattern],
        actions: dict[str, Callable[[dict[str, Any]], list[ast.AST]] | None],
    ):
        super().__init__()
        self.pattern = pattern
        self.actions = actions
        self.matches: list[dict[str, Any]] = []

    def _record(self, bindings: dict[str, Any]) -> None:
        """Internal helper so we have just one place that appends."""
        self.matches.append(bindings)

    def _normalize_replace_for_nonlist(self, rep: ReplaceResult, field: str) -> ast.AST:
        """
        Normalize a replacement for a non-list field.

        - ast.AST -> ast.AST
        - [ast.AST] with len==1 -> element
        - None or list with len!=1 -> ValueError
        """
        if isinstance(rep, ast.AST):
            return rep
        if isinstance(rep, list):
            if len(rep) == 1 and isinstance(rep[0], ast.AST):
                return rep[0]
            raise ValueError(f'Cannot replace non-list field {field} with {len(rep)} nodes')
        raise ValueError(f'Cannot delete non-list field {field}')

    # ------------- helpers to interpret collected values ------------------

    @staticmethod
    def _as_nodes(value: Any) -> list[ast.AST]:
        """
        Extract AST nodes from arbitrarily nested values.

        The value may be:
            - ast.AST: returned as a single-element list
            - list: flattened recursively
            - dict: values scanned recursively (no special key required)
            - other: ignored

        Returns:
            list[ast.AST]: All AST nodes found, in encounter order.
        """
        out: list[ast.AST] = []

        def rec(v: Any) -> None:
            if isinstance(v, ast.AST):
                out.append(v)
            elif isinstance(v, list):
                for it in v:
                    rec(it)
            elif isinstance(v, dict):
                for it in v.values():
                    rec(it)

        rec(value)
        return out

    def _replace_within(self, parent: ast.AST, target: ast.AST, repl: list[ast.AST]) -> bool:
        """
        Replace `target` somewhere inside `parent` with `repl`.

        In a list field `repl` is spliced in; in a single-node field `repl`
        must hold exactly one node (ValueError otherwise).
        Returns True if a replacement was performed.
        """
        def walk(node: ast.AST) -> bool:
            for field, val in ast.iter_fields(node):
                if val is target:
                    if len(repl) != 1:
                        raise ValueError(f'Cannot replace non-list field {field} with {len(repl)} nodes')
                    setattr(node, field, repl[0])
                    return True
                if isinstance(val, list):
                    for i, elem in enumerate(val):
                        if elem is target:
                            val[i:i + 1] = repl
                            return True
                if isinstance(val, ast.AST) and walk(val):
                    return True
                if isinstance(val, list):
                    for elem in val:
                        if isinstance(elem, ast.AST) and walk(elem):
                            return True
            return False

        return walk(parent)

    def _delete_within(self, parent: ast.AST, target: ast.AST) -> bool:
        """
        Delete `target` if it appears in a list field somewhere inside `parent`.
        Returns True if a deletion was performed.
        """
        def walk(node: ast.AST) -> bool:
            for _, val in ast.iter_fields(node):
                if isinstance(val, list):
                    i = 0
                    while i < len(val):
                        elem = val[i]
                        if elem is target:
                            del val[i]
                            return True
                        if isinstance(elem, ast.AST) and walk(elem):
                            return True
                        i += 1
                elif isinstance(val, ast.AST) and walk(val):
                    return True
            return False

        return walk(parent)

    def _contains(self, root: ast.AST, target: ast.AST) -> bool:
        """Return True if target occurs anywhere in root's subtree."""
        if root is target:
            return True
        for _, v in ast.iter_fields(root):
            if isinstance(v, ast.AST):
                if self._contains(v, target):
                    return True
            elif isinstance(v, list):
                for it in v:
                    if isinstance(it, ast.AST) and self._contains(it, target):
                        return True
        return False

    def _find_owner_in_span(self, span: list[ast.AST], anchor: ast.AST) -> ast.AST | None:
        """Return the first node in span whose subtree contains anchor."""
        for n in span:
            if self._contains(n, anchor):
                return n
        return None

    # ---------------- plan replacements / removals for a list -------------

    def _plan(self, seq: list[ast.AST]) -> tuple[dict[int, list[ast.AST]], set[int]]:
        """Find non-overlapping sequence matches in *seq* and plan their edits.

        Edits to nodes nested inside a matched element are applied in place
        immediately. Edits to elements of *seq* itself are returned so the
        caller can splice them.

        Returns:
            (replacements, removals): ``replacements`` maps ``id(element)`` to
            the nodes that take its place; ``removals`` holds the ``id`` of
            each element to drop.
        """
        repl: dict[int, list[ast.AST]] = {}
        remove: set[int] = set()
        i = 0
        while i < len(seq):
            mtch = _match_patterns(self.pattern, seq, i, {})
            if not mtch:
                i += 1
                continue

            bindings, new_pos = mtch[0]
            self._record(bindings)
            span_nodes = seq[i:new_pos]
            if not span_nodes:
                # Zero-width match: there is no anchor to edit, and staying at
                # `i` would loop forever (span_nodes[0] below would also fail).
                i += 1
                continue
            span_ids = {id(n) for n in span_nodes}

            for key, action in self.actions.items():
                if key not in bindings:
                    continue

                collected = self._as_nodes(bindings[key])
                if not collected:
                    continue

                # Collected nodes that are direct elements of this list.
                list_anchors = [n for n in collected if id(n) in span_ids]

                if action is None:
                    if list_anchors:
                        remove.update(id(n) for n in list_anchors)
                    else:
                        # Best-effort nested delete: nodes not held in a list
                        # field somewhere below the owner are left alone.
                        owner = self._find_owner_in_span(span_nodes, collected[0]) or span_nodes[0]
                        for n in collected:
                            self._delete_within(owner, n)
                    continue

                # Replacement
                r = action(bindings) or []
                if not isinstance(r, list):
                    raise TypeError('Handler must return list[ast.AST].')

                if list_anchors:
                    # Replace the first list element anchor and remove the rest
                    repl[id(list_anchors[0])] = r
                    remove.update(id(n) for n in list_anchors[1:])
                else:
                    # Replace the first collected node found inside its owner.
                    owner = self._find_owner_in_span(span_nodes, collected[0]) or span_nodes[0]
                    if not any(self._replace_within(owner, n, r) for n in collected):
                        raise ValueError('Could not locate collected child in matched subtree for in-place replacement.')

            i = new_pos

        return repl, remove

    def generic_visit(self, node: ast.AST) -> ast.AST:
        """Transform children, then apply sequential matching in list fields."""
        for field, old_value in ast.iter_fields(node):
            if isinstance(old_value, list):
                self._visit_list_field(node, field, old_value)
            elif isinstance(old_value, ast.AST):
                self._visit_node_field(node, field, old_value)
        return node

    def _visit_list_field(self, node: ast.AST, field: str, old_value: list[Any]) -> None:
        """Visit the elements of a list field, then splice planned edits into it."""
        # 1) visit children
        children: list[Any] = []
        for v in old_value:
            if not isinstance(v, ast.AST):
                children.append(v)
                continue
            nv = self.visit(v)
            if nv is None:
                continue
            if isinstance(nv, list):
                # Same contract as ast.NodeTransformer: a list returned for an
                # element of a list field is spliced into that list.
                if not all(isinstance(x, ast.AST) for x in nv):
                    raise TypeError('List fields must contain AST nodes.')
                children.extend(nv)
                continue
            if not isinstance(nv, ast.AST):
                raise TypeError('List fields must contain AST nodes.')
            children.append(nv)

        # 2) plan and splice sequential replacements/removals
        if children and all(isinstance(c, ast.AST) for c in children):
            replace, remove = self._plan(children)
            new_children: list[ast.AST] = []
            for ch in children:
                rid = id(ch)
                if rid in replace:
                    new_children.extend(replace[rid])
                elif rid not in remove:
                    new_children.append(ch)
            setattr(node, field, new_children)
        else:
            setattr(node, field, children)

    def _visit_node_field(self, node: ast.AST, field: str, old_value: ast.AST) -> None:
        """Visit a single-node field and apply a single-node pattern replacement."""
        visited = self.visit(old_value)
        if not isinstance(visited, ast.AST):
            # Deleting or expanding is not valid for non-list fields.
            setattr(node, field, old_value)
            return
        rep = self._maybe_replace(visited)
        if rep is not None:
            visited = self._normalize_replace_for_nonlist(rep, field)
        # Store whatever `visit` produced (a visit_* override may return a new
        # node) rather than silently keeping the old child.
        setattr(node, field, visited)

    def _maybe_replace(self, node: ast.AST) -> ast.AST | list[ast.AST] | None:
        """
        Try to match `self.pattern` against just `node` and apply actions.

        Returns
        -------
        ast.AST | list[ast.AST] | None:
            - ast.AST: replace this node with a single node
            - list[ast.AST]: splice multiple nodes (only valid in list fields)
            - None: no top-level replacement. This covers "no match", a
              requested deletion of the node itself (which a single-node field
              cannot honour), and the case where only nested anchors were
              edited in place.
        """
        res = _match_patterns(self.pattern, [node], 0, {})
        if not res:
            return None

        bindings, _ = res[0]
        self._record(bindings)

        for key, action in self.actions.items():
            if key not in bindings:
                continue

            collected = self._as_nodes(bindings[key])

            # Case 1: the action targets THIS node -> replace/delete the node itself
            if node in collected:
                if action is None:
                    return None
                repl = action(bindings) or []
                if not isinstance(repl, list):
                    raise TypeError('Handler must return list[ast.AST].')
                if not repl:
                    return None
                if len(repl) == 1:
                    return repl[0]
                return repl  # multi-element expansion only valid in list fields

            # Case 2: the action targets nested anchors -> mutate in place
            if action is None:
                for n in collected:
                    self._delete_within(node, n)
            else:
                repl = action(bindings) or []
                if not isinstance(repl, list):
                    raise TypeError('Handler must return list[ast.AST].')
                for n in collected:
                    self._replace_within(node, n, repl)

        # Only in-place nested edits (or none) happened: keep this node.
        return None
314
+
315
+
316
+ # class BottomUpPatternTransformer(PatternTransformer):
317
+ # """
318
+ # Post-order transformer that applies actions to the current node only
319
+ # after all of its children have been transformed.
320
+
321
+ # Key difference from PatternTransformer:
322
+ # - We disable the parent class's child-level _maybe_replace(...) to avoid
323
+ # re-matching children on the way back up. Sequence-level planning in list
324
+ # fields still runs as usual.
325
+ # """
326
+
327
+ # # --- IMPORTANT: neutralize parent child-level matching ---
328
+ # def _maybe_replace(self, node: ast.AST) -> ast.AST | list[ast.AST] | None:
329
+ # # Parent generic_visit calls this for visited children; returning None
330
+ # # prevents a second match on those children.
331
+ # return None
332
+
333
+ # def generic_visit(self, node: ast.AST) -> ast.AST:
334
+ # # 1) Transform children with parent's machinery (list-field sequence
335
+ # # planning still works), but _maybe_replace is now neutralized.
336
+ # node = super().generic_visit(node)
337
+
338
+ # # 2) Bottom-up: match and act on the fully-processed *current* node.
339
+ # res = _match_patterns(self.pattern, [node], 0, {})
340
+ # if not res:
341
+ # return node
342
+
343
+ # bindings, _ = res[0]
344
+ # self._record(bindings)
345
+
346
+ # delete_self = False
347
+ # replace_with: ast.AST | None = None
348
+
349
+ # for key, action in self.actions.items():
350
+ # if key not in bindings:
351
+ # continue
352
+
353
+ # collected = self._as_nodes(bindings[key])
354
+
355
+ # # Target THIS node directly
356
+ # if node in collected:
357
+ # if action is None:
358
+ # delete_self = True
359
+ # continue
360
+ # repl = action(bindings) or []
361
+ # if not isinstance(repl, list):
362
+ # raise TypeError('Handler must return list[ast.AST].')
363
+ # if len(repl) == 0:
364
+ # delete_self = True
365
+ # continue
366
+ # if len(repl) != 1:
367
+ # raise ValueError('Cannot replace a non-list node with multiple nodes.')
368
+ # replace_with = repl[0]
369
+ # continue
370
+
371
+ # # Target nested anchors inside this node
372
+ # if action is None:
373
+ # for n in collected:
374
+ # self._delete_within(node, n)
375
+ # else:
376
+ # repl = action(bindings) or []
377
+ # if not isinstance(repl, list):
378
+ # raise TypeError('Handler must return list[ast.AST].')
379
+ # for n in collected:
380
+ # self._replace_within(node, n, repl)
381
+
382
+ # if replace_with is not None:
383
+ # return replace_with
384
+ # if delete_self:
385
+ # # Effective only when parent holds this node inside a list field.
386
+ # return None
387
+
388
+ # return node
389
+
390
+
391
class BottomUpPatternTransformer(ast.NodeTransformer):
    """
    Like PatternTransformer, but applies pattern matching *after* visiting all children.
    This enables bottom-up rewriting (i.e., transforming from the leaves upward).

    Each node is matched on its own. On a match, the first action (in
    ``actions`` order) whose key is bound decides the result for the node:

    * handler(bindings) -> list[ast.AST]: ``[]`` (or ``None``) deletes the
      node, one element replaces it, several elements are spliced into the
      parent list;
    * None: delete the node.

    For single-node fields, a one-element list is unwrapped and a multi-node
    result raises ``ValueError``. A deletion sets the field to ``None``, which
    is only valid where the AST field is optional.
    """

    def __init__(
        self,
        pattern: Sequence[Pattern],
        actions: dict[str, Callable[[dict[str, Any]], list[ast.AST]] | None],
    ):
        super().__init__()
        self.pattern = pattern
        self.actions = actions
        self.matches: list[dict[str, Any]] = []

    def _record(self, bindings: dict[str, Any]) -> None:
        self.matches.append(bindings)

    @staticmethod
    def _as_nodes(value: Any) -> list[ast.AST]:
        """
        Extract AST nodes from arbitrarily nested values.

        Returns:
            list[ast.AST]: All AST nodes found within value.
        """
        out: list[ast.AST] = []

        def rec(v: Any) -> None:
            if isinstance(v, ast.AST):
                out.append(v)
            elif isinstance(v, list):
                for it in v:
                    rec(it)
            elif isinstance(v, dict):
                for it in v.values():
                    rec(it)

        rec(value)
        return out

    @staticmethod
    def _single_child(result: ast.AST | list[ast.AST] | None, field: str) -> ast.AST | None:
        """Convert a ``visit`` result into a value valid for a single-node field."""
        if isinstance(result, list):
            if len(result) == 1:
                return result[0]
            raise ValueError(f'Cannot replace non-list field {field} with {len(result)} nodes')
        return result  # an AST node, or None (clears an optional field)

    def visit(self, node: ast.AST) -> ast.AST | list[ast.AST] | None:
        """Transform *node*'s children, then try the pattern on *node* itself.

        Returns the (possibly replaced) node, a list of nodes to splice, or
        ``None`` to delete it.
        """
        # First, transform children recursively (bottom-up)
        for field, value in list(ast.iter_fields(node)):
            if isinstance(value, list):
                new_list = []
                for item in value:
                    if not isinstance(item, ast.AST):
                        new_list.append(item)
                        continue
                    new_item = self.visit(item)
                    if new_item is None:
                        continue
                    if isinstance(new_item, list):
                        new_list.extend(new_item)
                    else:
                        new_list.append(new_item)
                setattr(node, field, new_list)
            elif isinstance(value, ast.AST):
                # A list must never be stored in a single-node field.
                setattr(node, field, self._single_child(self.visit(value), field))

        # Then apply pattern matching to this node
        res = _match_patterns(self.pattern, [node], 0, {})
        if not res:
            return node
        bindings, _ = res[0]
        self._record(bindings)
        for key, action in self.actions.items():
            if key not in bindings:
                continue
            if action is None:
                return None  # delete the matched node
            # Like PatternTransformer, a handler returning None counts as [].
            replacements = action(bindings) or []
            if len(replacements) == 0:
                return None
            return replacements[0] if len(replacements) == 1 else replacements

        return node
466
+
467
+
468
class PatternFinder(ast.NodeVisitor):
    """Collect bindings for every occurrence of *pattern* in an AST.

    Every node is matched on its own. When the pattern has more than one
    element, each list field is also scanned for non-overlapping sequence
    matches. All bindings are appended to ``self.matches``.
    """

    def __init__(self, pattern: Sequence[Pattern]):
        super().__init__()
        self.visited: set[int] = set()  # ids of nodes already processed
        self.pattern = pattern
        self.matches: list[dict[str, Any]] = []

    def generic_visit(self, node: ast.AST):
        if id(node) in self.visited:
            return

        self.visited.add(id(node))
        res = _match_patterns(self.pattern, [node], 0, {})
        if res:
            self.matches.append(res[0][0])
        for _, val in ast.iter_fields(node):
            if isinstance(val, list):
                if len(self.pattern) > 1:
                    self._scan_list(val)
                for elem in val:
                    if isinstance(elem, ast.AST):
                        self.visit(elem)
            elif isinstance(val, ast.AST):
                self.visit(val)

    def _scan_list(self, seq: list[ast.AST]):
        """Record non-overlapping sequence matches of the pattern within *seq*."""
        i = 0
        while i < len(seq):
            res = _match_patterns(self.pattern, seq, i, {})
            if not res:
                i += 1
                continue
            binds, new_pos = res[0]
            self.matches.append(binds)
            # Always advance: a zero-width match would otherwise loop forever.
            i = new_pos if new_pos > i else i + 1
504
+
505
class SingleOccurrenceFinder(ast.NodeVisitor):
    """
    Quickly checks whether any node of an AST matches the given pattern sequence.

    Traversal is pre-order and stops at the first node the pattern matches on
    its own. That node is stored in ``match_node``, and ``found_match()`` then
    returns True.
    """

    match_node: ast.AST | None  # first matching node, if any

    def __init__(self, pattern: Sequence[Pattern]):
        super().__init__()
        self.match_node = None
        self.pattern = pattern
        self.found = False

    def visit(self, node: ast.AST):
        # Nothing left to do once a match has been recorded.
        if self.found:
            return

        if _match_patterns(self.pattern, [node], 0, {}):
            self.found = True
            self.match_node = node
            return

        # Descend into child nodes in field order, stopping at the first hit.
        for _, val in ast.iter_fields(node):
            candidates = val if isinstance(val, list) else [val]
            for child in candidates:
                if not isinstance(child, ast.AST):
                    continue
                self.visit(child)
                if self.found:
                    return

    def found_match(self) -> bool:
        """Return True if a matching node has been found."""
        return self.found
@@ -0,0 +1,8 @@
1
+ Metadata-Version: 2.4
2
+ Name: ast-pattern-engine
3
+ Version: 0.1.0
4
+ Summary: A library for regex level fine grained AST pattern matching and replacing
5
+ Author-email: 80sVectorz <66908776+80sVectorz@users.noreply.github.com>
6
+ Requires-Python: >=3.13
7
+ Provides-Extra: dev
8
+ Requires-Dist: pytest; extra == 'dev'
@@ -0,0 +1,10 @@
1
+ ast_pattern_engine/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ ast_pattern_engine/core.py,sha256=oO3GeboCh44EJXRl3I75rhsV19TpprC1o_yA-GlOFkQ,1926
3
+ ast_pattern_engine/engine.py,sha256=lToTzRMmknLosuOITNnZ7L9NSyIVlNzukxc_5MUhH6M,1938
4
+ ast_pattern_engine/plumbing.py,sha256=1YQ7Pu1FVKH3X7mGa6QZOg6RH3O2ukRMeKOL5smLXNQ,378
5
+ ast_pattern_engine/visitors.py,sha256=bAwTTLqlwhNH_olTf4NiFvlCL_fcnTvZD8wQNDrm6b0,20794
6
+ ast_pattern_engine/nodes/basic.py,sha256=nriLPIaP4Ea8UOedIj-n8L1hOjWgpnQahdqo_V0RIF8,6427
7
+ ast_pattern_engine/nodes/sequences.py,sha256=RaFQTLlB9gmJf6L_ftxyMSjqRnheAtIpFH1QigwisGw,2320
8
+ ast_pattern_engine-0.1.0.dist-info/METADATA,sha256=g1AE08ArsM9FG4dUKbJUMCTbKjdV0h7muuXLxFH0xAU,299
9
+ ast_pattern_engine-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ ast_pattern_engine-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.27.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any