ast-pattern-engine 1.0.0__tar.gz → 1.0.1__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/.github/workflows/ci.yml +35 -10
  2. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/PKG-INFO +1 -1
  3. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/pyproject.toml +5 -1
  4. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/src/ast_pattern_engine/core.py +10 -2
  5. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/src/ast_pattern_engine/engine.py +18 -6
  6. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/src/ast_pattern_engine/nodes/basic.py +65 -13
  7. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/src/ast_pattern_engine/nodes/sequences.py +5 -1
  8. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/src/ast_pattern_engine/visitors.py +2 -1
  9. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_any_of.py +8 -18
  10. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_collect.py +1 -0
  11. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_filter.py +4 -3
  12. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_one_of.py +4 -3
  13. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_pattern_group.py +1 -0
  14. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_repetition.py +1 -0
  15. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/test_engine.py +1 -0
  16. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/visitors/test_bottom_up_pattern_transformer.py +2 -1
  17. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/visitors/test_pattern_finder.py +4 -1
  18. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/visitors/test_pattern_transformer.py +14 -10
  19. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/visitors/test_single_occurrence_finder.py +2 -2
  20. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/.gitignore +0 -0
  21. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/LICENSE +0 -0
  22. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/README.md +0 -0
  23. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/examples/dict_get_rewrite.py +0 -0
  24. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/src/ast_pattern_engine/__init__.py +0 -0
  25. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/src/ast_pattern_engine/nodes/__init__.py +0 -0
  26. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/src/ast_pattern_engine/plumbing.py +0 -0
  27. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/src/ast_pattern_engine/py.typed +0 -0
  28. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/src/ast_pattern_engine/templates.py +0 -0
  29. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/__init__.py +0 -0
  30. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_all_of.py +0 -0
  31. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_bind.py +0 -0
  32. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_contains.py +0 -0
  33. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_not.py +0 -0
  34. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_optional.py +0 -0
  35. {ast_pattern_engine-1.0.0 → ast_pattern_engine-1.0.1}/tests/patterns/test_templates.py +0 -0
@@ -7,37 +7,62 @@ on:
7
7
  branches: [ "main" ]
8
8
 
9
9
  jobs:
10
+ format:
11
+ runs-on: ubuntu-latest
12
+ if: github.event_name == 'push'
13
+ permissions:
14
+ contents: write
15
+ steps:
16
+ - uses: actions/checkout@v4
17
+
18
+ - name: Install uv
19
+ uses: astral-sh/setup-uv@v5
20
+ with:
21
+ enable-cache: true
22
+
23
+ - name: Auto-format with Ruff
24
+ run: uv run ruff format
25
+
26
+ - name: Auto-fix lint with Ruff
27
+ run: uv run ruff check --fix
28
+
29
+ - name: Commit formatting changes
30
+ uses: stefanzweifel/git-auto-commit-action@v5
31
+ with:
32
+ commit_message: "style: auto-format with ruff"
33
+
10
34
  test:
11
35
  runs-on: ubuntu-latest
36
+ needs: format
37
+ if: always()
12
38
  strategy:
13
39
  matrix:
14
40
  python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
15
41
 
16
42
  steps:
17
43
  - uses: actions/checkout@v4
18
-
44
+ with:
45
+ ref: ${{ github.ref }}
46
+
19
47
  - name: Set up Python ${{ matrix.python-version }}
20
48
  uses: actions/setup-python@v5
21
49
  with:
22
50
  python-version: ${{ matrix.python-version }}
23
-
51
+
24
52
  - name: Install uv
25
53
  uses: astral-sh/setup-uv@v5
26
54
  with:
27
55
  enable-cache: true
28
-
56
+
29
57
  - name: Install dependencies
30
- run: uv sync --all-extras
31
-
32
- - name: Check formatting with Ruff
33
- run: uv run ruff format --check
34
-
58
+ run: uv sync --all-extras --all-groups
59
+
35
60
  - name: Lint with Ruff
36
61
  run: uv run ruff check
37
-
62
+
38
63
  - name: Run tests with pytest
39
64
  run: uv run pytest --cov=src --cov-report=xml
40
-
65
+
41
66
  - name: Upload coverage reports to Codecov
42
67
  uses: codecov/codecov-action@v4
43
68
  env:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ast-pattern-engine
3
- Version: 1.0.0
3
+ Version: 1.0.1
4
4
  Summary: A library for regex-inspired fine-grained AST pattern matching and replacing
5
5
  Project-URL: Homepage, https://github.com/80sVectorz/ast_pattern_engine
6
6
  Project-URL: Repository, https://github.com/80sVectorz/ast_pattern_engine
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ast-pattern-engine"
3
- version = "1.0.0"
3
+ version = "1.0.1"
4
4
  description = "A library for regex-inspired fine-grained AST pattern matching and replacing"
5
5
  readme = "README.md"
6
6
  authors = [
@@ -40,4 +40,8 @@ pythonpath = ["src"]
40
40
  [dependency-groups]
41
41
  dev = [
42
42
  "pytest-cov>=7.1.0",
43
+ "ruff>=0.11.0",
43
44
  ]
45
+
46
+ [tool.ruff.lint]
47
+ ignore = ["F841"] # Ignore unused variable assignments
@@ -8,7 +8,11 @@ class Pattern(ast.AST):
8
8
 
9
9
  # public API
10
10
  def match_node(
11
- self, node: object, bindings: dict[str, object] | None = None, *, _force_list: bool = False
11
+ self,
12
+ node: object,
13
+ bindings: dict[str, object] | None = None,
14
+ *,
15
+ _force_list: bool = False,
12
16
  ):
13
17
  """Match *node* and return updated *bindings* or *None*."""
14
18
  raise NotImplementedError
@@ -21,7 +25,11 @@ class Pattern(ast.AST):
21
25
 
22
26
  class SequencePattern(Pattern):
23
27
  def match_node(
24
- self, node: object, bindings: dict[str, object] | None = None, *, _force_list: bool = False
28
+ self,
29
+ node: object,
30
+ bindings: dict[str, object] | None = None,
31
+ *,
32
+ _force_list: bool = False,
25
33
  ):
26
34
  # Matching is handled by engine._match_sequence
27
35
  raise NotImplementedError(
@@ -26,7 +26,9 @@ def _match_patterns(
26
26
 
27
27
  match first:
28
28
  case PatternGroup(pattern=sub_pattern, key=key):
29
- res = _match_patterns(sub_pattern, nodes, pos, dict(bindings), _force_list=_force_list)
29
+ res = _match_patterns(
30
+ sub_pattern, nodes, pos, dict(bindings), _force_list=_force_list
31
+ )
30
32
  if res:
31
33
  new_bindings = res[-1][0]
32
34
  if key is not None:
@@ -40,7 +42,9 @@ def _match_patterns(
40
42
  new_bindings = dict(bindings)
41
43
  n_reps = 0
42
44
  while n_reps < (max_matches or len(nodes)) and pos < len(nodes):
43
- res = _match_patterns([sub_pattern], nodes, pos, dict(new_bindings), _force_list=True)
45
+ res = _match_patterns(
46
+ [sub_pattern], nodes, pos, dict(new_bindings), _force_list=True
47
+ )
44
48
  if not res:
45
49
  break
46
50
  new_bindings, pos = res[-1]
@@ -54,7 +58,9 @@ def _match_patterns(
54
58
  new_pos = pos
55
59
  n_matches = 0
56
60
  for pattern in sub_patterns:
57
- res = _match_patterns([pattern], nodes, pos, dict(bindings), _force_list=_force_list)
61
+ res = _match_patterns(
62
+ [pattern], nodes, pos, dict(bindings), _force_list=_force_list
63
+ )
58
64
  if res:
59
65
  n_matches += 1
60
66
  if n_matches == 1:
@@ -74,7 +80,9 @@ def _match_patterns(
74
80
  out.append((new_bindings, pos))
75
81
 
76
82
  case Optional(pattern=sub_pattern, key=key):
77
- res = _match_patterns([sub_pattern], nodes, pos, dict(bindings), _force_list=_force_list)
83
+ res = _match_patterns(
84
+ [sub_pattern], nodes, pos, dict(bindings), _force_list=_force_list
85
+ )
78
86
  if res:
79
87
  new_bindings, new_pos = res[-1]
80
88
  if key is not None:
@@ -88,13 +96,17 @@ def _match_patterns(
88
96
  case _:
89
97
  # single node pattern
90
98
  if pos < len(nodes):
91
- res = first.match_node(nodes[pos], dict(bindings), _force_list=_force_list)
99
+ res = first.match_node(
100
+ nodes[pos], dict(bindings), _force_list=_force_list
101
+ )
92
102
  if res is not None:
93
103
  out.append((res, pos + 1))
94
104
 
95
105
  # Match remaining patterns
96
106
  if out and remaining:
97
- rem_res = _match_patterns(remaining, nodes, out[-1][1], out[-1][0], _force_list=_force_list)
107
+ rem_res = _match_patterns(
108
+ remaining, nodes, out[-1][1], out[-1][0], _force_list=_force_list
109
+ )
98
110
  if not rem_res:
99
111
  return []
100
112
  out.extend(rem_res)
@@ -25,7 +25,13 @@ class Bind(Pattern):
25
25
  def __init__(self, key: str):
26
26
  self.key = key
27
27
 
28
- def match_node(self, node: Any, bindings: dict[str, Any] | None = None, *, _force_list: bool = False):
28
+ def match_node(
29
+ self,
30
+ node: Any,
31
+ bindings: dict[str, Any] | None = None,
32
+ *,
33
+ _force_list: bool = False,
34
+ ):
29
35
  bindings = bindings or {}
30
36
  if self.key in bindings:
31
37
  if not _force_list:
@@ -41,7 +47,13 @@ class WildCard(Pattern):
41
47
 
42
48
  def __init__(self): ...
43
49
 
44
- def match_node(self, node: Any, bindings: dict[str, Any] | None = None, *, _force_list: bool = False):
50
+ def match_node(
51
+ self,
52
+ node: Any,
53
+ bindings: dict[str, Any] | None = None,
54
+ *,
55
+ _force_list: bool = False,
56
+ ):
45
57
  bindings = bindings or {}
46
58
  return bindings
47
59
 
@@ -58,7 +70,13 @@ class NodePattern(Pattern):
58
70
  self.node_type = node_type
59
71
  self.field_patterns = field_patterns
60
72
 
61
- def match_node(self, node: Any, bindings: dict[str, Any] | None = None, *, _force_list: bool = False):
73
+ def match_node(
74
+ self,
75
+ node: Any,
76
+ bindings: dict[str, Any] | None = None,
77
+ *,
78
+ _force_list: bool = False,
79
+ ):
62
80
  bindings = bindings or {}
63
81
  if not isinstance(node, self.node_type):
64
82
  return None
@@ -85,9 +103,7 @@ class NodePattern(Pattern):
85
103
  return None
86
104
  merged[k] = self._to_list(merged[k]) + self._to_list(v)
87
105
  else:
88
- merged[k] = (
89
- self._to_list(v) if _force_list else v
90
- )
106
+ merged[k] = self._to_list(v) if _force_list else v
91
107
  else:
92
108
  if val != pat:
93
109
  return None
@@ -113,7 +129,11 @@ class Collect(Pattern):
113
129
  self.key = key
114
130
 
115
131
  def match_node(
116
- self, node: Any, bindings: dict[str, Any] | None = None, *, _force_list: bool = False
132
+ self,
133
+ node: Any,
134
+ bindings: dict[str, Any] | None = None,
135
+ *,
136
+ _force_list: bool = False,
117
137
  ) -> None | dict[str, Any]:
118
138
  bindings = bindings or {}
119
139
  # Collect is a binding boundary — inner patterns always see _force_list=False
@@ -167,7 +187,13 @@ class Filter(Pattern):
167
187
  self.predicate = predicate
168
188
  self.key = key
169
189
 
170
- def match_node(self, node: Any, bindings: dict[str, Any] | None = None, *, _force_list: bool = False):
190
+ def match_node(
191
+ self,
192
+ node: Any,
193
+ bindings: dict[str, Any] | None = None,
194
+ *,
195
+ _force_list: bool = False,
196
+ ):
171
197
  bindings = bindings or {}
172
198
  if not self.predicate(node):
173
199
  return None
@@ -192,7 +218,13 @@ class Not(Pattern):
192
218
  def __init__(self, pattern: Pattern):
193
219
  self.pattern = pattern
194
220
 
195
- def match_node(self, node: Any, bindings: dict[str, Any] | None = None, *, _force_list: bool = False):
221
+ def match_node(
222
+ self,
223
+ node: Any,
224
+ bindings: dict[str, Any] | None = None,
225
+ *,
226
+ _force_list: bool = False,
227
+ ):
196
228
  bindings = bindings or {}
197
229
 
198
230
  # Use _match_patterns so that SequencePatterns (like OneOf) don't
@@ -214,7 +246,13 @@ class Contains(Pattern):
214
246
  def __init__(self, pattern: Sequence[Pattern]):
215
247
  self.pattern = list(pattern)
216
248
 
217
- def match_node(self, node: Any, bindings: dict[str, Any] | None = None, *, _force_list: bool = False):
249
+ def match_node(
250
+ self,
251
+ node: Any,
252
+ bindings: dict[str, Any] | None = None,
253
+ *,
254
+ _force_list: bool = False,
255
+ ):
218
256
  bindings = bindings or {}
219
257
  finder = SingleOccurrenceFinder(self.pattern)
220
258
  finder.visit(node)
@@ -245,12 +283,20 @@ class AllOf(Pattern):
245
283
  def __init__(self, patterns: Sequence[Pattern]):
246
284
  self.patterns = list(patterns)
247
285
 
248
- def match_node(self, node: Any, bindings: dict[str, Any] | None = None, *, _force_list: bool = False):
286
+ def match_node(
287
+ self,
288
+ node: Any,
289
+ bindings: dict[str, Any] | None = None,
290
+ *,
291
+ _force_list: bool = False,
292
+ ):
249
293
  bindings = bindings or {}
250
294
  new_bindings = dict(bindings)
251
295
 
252
296
  for pattern in self.patterns:
253
- new_bindings = pattern.match_node(node, new_bindings, _force_list=_force_list)
297
+ new_bindings = pattern.match_node(
298
+ node, new_bindings, _force_list=_force_list
299
+ )
254
300
  if new_bindings is None:
255
301
  return None
256
302
  return new_bindings
@@ -269,7 +315,13 @@ class AnyOf(Pattern):
269
315
  def __init__(self, patterns: Sequence[Pattern]):
270
316
  self.patterns = list(patterns)
271
317
 
272
- def match_node(self, node: Any, bindings: dict[str, Any] | None = None, *, _force_list: bool = False):
318
+ def match_node(
319
+ self,
320
+ node: Any,
321
+ bindings: dict[str, Any] | None = None,
322
+ *,
323
+ _force_list: bool = False,
324
+ ):
273
325
  bindings = bindings or {}
274
326
  merged = dict(bindings)
275
327
  matched_any = False
@@ -6,7 +6,11 @@ from ast_pattern_engine.core import Pattern
6
6
 
7
7
  class SequencePattern(Pattern):
8
8
  def match_node(
9
- self, node: object, bindings: dict[str, object] | None = None, *, _force_list: bool = False
9
+ self,
10
+ node: object,
11
+ bindings: dict[str, object] | None = None,
12
+ *,
13
+ _force_list: bool = False,
10
14
  ):
11
15
  # Matching is handled by engine._match_sequence
12
16
  raise NotImplementedError(
@@ -1,3 +1,4 @@
1
+ from typing import TypeAlias
1
2
  import ast
2
3
  from typing import Any
3
4
  from collections.abc import Sequence, Callable
@@ -5,7 +6,7 @@ from collections.abc import Sequence, Callable
5
6
  from ast_pattern_engine.core import Pattern
6
7
  from ast_pattern_engine.engine import _match_patterns
7
8
 
8
- type ReplaceResult = ast.AST | list[ast.AST] | None
9
+ ReplaceResult: TypeAlias = ast.AST | list[ast.AST] | None
9
10
 
10
11
 
11
12
  class PatternTransformer(ast.NodeTransformer):
@@ -1,32 +1,22 @@
1
1
  import ast
2
2
  from ast_pattern_engine.nodes.basic import AnyOf, NodePattern, Collect
3
3
 
4
+
4
5
  def test_any_of_matches_and_conflicts():
5
6
  node = ast.parse("1").body[0].value
6
-
7
+
7
8
  # First matches
8
- pattern1 = AnyOf([
9
- NodePattern(ast.Constant),
10
- NodePattern(ast.Name)
11
- ])
9
+ pattern1 = AnyOf([NodePattern(ast.Constant), NodePattern(ast.Name)])
12
10
  assert pattern1.match_node(node, {"existing": 1}) == {"existing": 1}
13
-
11
+
14
12
  # Second matches
15
- pattern2 = AnyOf([
16
- NodePattern(ast.Name),
17
- NodePattern(ast.Constant)
18
- ])
13
+ pattern2 = AnyOf([NodePattern(ast.Name), NodePattern(ast.Constant)])
19
14
  assert pattern2.match_node(node, {}) == {}
20
-
15
+
21
16
  # None matches
22
- pattern3 = AnyOf([
23
- NodePattern(ast.Name),
24
- NodePattern(ast.Assign)
25
- ])
17
+ pattern3 = AnyOf([NodePattern(ast.Name), NodePattern(ast.Assign)])
26
18
  assert pattern3.match_node(node, {}) is None
27
19
 
28
20
  # Match produces conflicting key
29
- pattern4 = AnyOf([
30
- Collect(NodePattern(ast.Constant), "c")
31
- ])
21
+ pattern4 = AnyOf([Collect(NodePattern(ast.Constant), "c")])
32
22
  assert pattern4.match_node(node, {"c": "conflict"}) is None
@@ -1,6 +1,7 @@
1
1
  import ast
2
2
  from ast_pattern_engine.nodes.basic import Collect, NodePattern
3
3
 
4
+
4
5
  def test_collect_match_single_constant():
5
6
  node = ast.parse("1").body[0]
6
7
  assert isinstance(node, ast.Expr)
@@ -3,6 +3,7 @@ from ast_pattern_engine.nodes.basic import Filter
3
3
  from ast_pattern_engine.nodes.sequences import Repetition
4
4
  from ast_pattern_engine.engine import match_sequence
5
5
 
6
+
6
7
  def test_filter_no_key():
7
8
  node = ast.parse("1").body[0].value
8
9
  # Matches, no key bound
@@ -17,14 +18,14 @@ def test_filter_no_key():
17
18
  def test_filter_with_key_and_conflicts():
18
19
  node = ast.parse("1").body[0].value
19
20
  pattern = Filter(lambda x: isinstance(x, ast.Constant), key="my_filter")
20
-
21
+
21
22
  # Matches and binds
22
23
  res = pattern.match_node(node, {})
23
24
  assert res == {"my_filter": node}
24
-
25
+
25
26
  # Conflict on duplicate key without force
26
27
  assert pattern.match_node(node, {"my_filter": "existing"}) is None
27
-
28
+
28
29
  # When inside a Repetition, ancestor forces list
29
30
  rep_pattern = Repetition(pattern)
30
31
  # Match node directly against the filter while simulating the Repetition context
@@ -3,6 +3,7 @@ from ast_pattern_engine.nodes.basic import Collect, WildCard, NodePattern
3
3
  from ast_pattern_engine.nodes.sequences import OneOf
4
4
  from ast_pattern_engine.engine import _match_patterns, match_sequence
5
5
 
6
+
6
7
  def test_one_of_non_strict_returns_first_match():
7
8
  nodes = [ast.parse(src).body[0] for src in ("1", "2")]
8
9
  pattern = [
@@ -49,9 +50,9 @@ def test_one_of_strict_matches_exactly_one_pattern():
49
50
  ]
50
51
 
51
52
  result = _match_patterns(pattern, nodes, 0, {})
52
- assert (
53
- len(result) == 0
54
- ), "Expected strict OneOf not to match because multiple sub patterns match"
53
+ assert len(result) == 0, (
54
+ "Expected strict OneOf not to match because multiple sub patterns match"
55
+ )
55
56
 
56
57
  # Section B: Test strict mode with exactly one matching pattern for each line
57
58
  nodes = [ast.parse(src).body[0] for src in ("1", "x=2")]
@@ -3,6 +3,7 @@ from ast_pattern_engine.nodes.basic import Collect, WildCard, NodePattern
3
3
  from ast_pattern_engine.nodes.sequences import PatternGroup
4
4
  from ast_pattern_engine.engine import _match_patterns, match_sequence
5
5
 
6
+
6
7
  def test_pattern_group_collects_inner_bindings_under_key():
7
8
  nodes = [ast.parse(src).body[0].value for src in ("1", "2")] # type: ignore
8
9
  pattern = [
@@ -3,6 +3,7 @@ from ast_pattern_engine.nodes.basic import Collect, WildCard
3
3
  from ast_pattern_engine.nodes.sequences import Repetition
4
4
  from ast_pattern_engine.engine import _match_patterns
5
5
 
6
+
6
7
  def test_collect_inside_one_or_more_accumulates_nodes():
7
8
  nodes = [ast.parse(str(i)).body[0].value for i in range(3)] # type: ignore
8
9
  pattern = [Repetition(Collect(WildCard(), "item"))]
@@ -2,6 +2,7 @@ import ast
2
2
  from ast_pattern_engine.nodes.basic import Collect, WildCard
3
3
  from ast_pattern_engine.engine import match_sequence
4
4
 
5
+
5
6
  def test_match_sequence_returns_non_overlapping_bindings():
6
7
  nodes = [ast.parse(text).body[0] for text in ("a = 1", "b = 2", "c = 3")]
7
8
  pattern = [Collect(WildCard(), "assign")]
@@ -14,7 +14,7 @@ def _constant(value, template):
14
14
 
15
15
 
16
16
  def test_bottom_up_pattern_transformer_collapses_children_before_parent():
17
- source = "def foo():\n" " return (1 + 2) + (3 + 4)\n"
17
+ source = "def foo():\n return (1 + 2) + (3 + 4)\n"
18
18
  tree = ast.parse(source)
19
19
  pattern = [Collect(NodePattern(ast.BinOp), "expr")]
20
20
 
@@ -60,6 +60,7 @@ def test_bu_transformer_list_manipulation():
60
60
  # Return multiple nodes (expands)
61
61
  def expand(_):
62
62
  return [_parse_stmt("pass"), _parse_stmt("pass")]
63
+
63
64
  transformer3 = BottomUpPatternTransformer(pattern, {"a": expand})
64
65
  res3 = transformer3.visit(ast.parse("x = 1"))
65
66
  assert len(res3.body) == 2
@@ -17,7 +17,10 @@ def test_pattern_finder_collects_node_matches():
17
17
  def test_pattern_finder_scan_list():
18
18
  tree = ast.parse("a = 1\nb = 2\nc = 3")
19
19
  # A sequence of length 2 to trigger `_scan_list` len(self.pattern) > 1 branch
20
- pattern = [Collect(NodePattern(ast.Assign), "a"), Collect(NodePattern(ast.Assign), "b")]
20
+ pattern = [
21
+ Collect(NodePattern(ast.Assign), "a"),
22
+ Collect(NodePattern(ast.Assign), "b"),
23
+ ]
21
24
  finder = PatternFinder(pattern)
22
25
  finder.visit(tree)
23
26
  # Print the matches to debug
@@ -123,7 +123,7 @@ def test_pt_nonlist_replace_errors():
123
123
 
124
124
  # Error: handler returns non-list
125
125
  def bad_handler(_):
126
- return _constant_val(2) # type: ignore
126
+ return _constant_val(2) # type: ignore
127
127
 
128
128
  transformer2 = PatternTransformer(pattern, {"c": bad_handler})
129
129
  with pytest.raises(TypeError, match="Handler must return list"):
@@ -133,7 +133,7 @@ def test_pt_nonlist_replace_errors():
133
133
  def test_pt_plan_errors():
134
134
  tree = ast.parse("a = 1\nb = 2")
135
135
  pattern = [Collect(NodePattern(ast.Assign), "a")]
136
-
136
+
137
137
  # key not in bindings (should silently continue)
138
138
  transformer1 = PatternTransformer(pattern, {"missing_key": lambda b: []})
139
139
  transformer1.visit(ast.parse("a = 1"))
@@ -142,11 +142,12 @@ def test_pt_plan_errors():
142
142
  transformer2 = PatternTransformer(pattern, {"a": lambda b: []})
143
143
  # override matching to return empty list for testing
144
144
  b = {"a": []}
145
- transformer2.matches.append(b)
146
-
145
+ transformer2.matches.append(b)
146
+
147
147
  # handler returns non-list
148
148
  def bad_plan_handler(_):
149
149
  return "not a list"
150
+
150
151
  transformer3 = PatternTransformer(pattern, {"a": bad_plan_handler})
151
152
  with pytest.raises(TypeError, match="must return `list"):
152
153
  transformer3.visit(ast.parse("a = 1"))
@@ -166,6 +167,7 @@ def test_pt_nested_replace_and_delete():
166
167
  # Replace nested nodes
167
168
  def replace_with_99(_):
168
169
  return [_constant_val(99)]
170
+
169
171
  transformer2 = PatternTransformer(pattern, {"c": replace_with_99})
170
172
  res2 = transformer2.visit(ast.parse("x = [1, 2, 3]"))
171
173
  assert all(elt.value == 99 for elt in res2.body[0].value.elts)
@@ -174,18 +176,20 @@ def test_pt_nested_replace_and_delete():
174
176
  def test_pt_dict_as_nodes():
175
177
  tree = ast.parse("x = 1")
176
178
  pattern = [Collect(NodePattern(ast.Assign), "a")]
177
-
179
+
178
180
  # We force the binding to be a dict to trigger dict handling in _as_nodes
179
181
  class DictTransformer(PatternTransformer):
180
182
  def _plan(self, seq):
181
183
  # Intercept and mutate bindings
182
184
  res = super()._plan(seq)
183
185
  return res
184
-
186
+
185
187
  transformer = DictTransformer(pattern, {"a": lambda b: [_parse_stmt("x = 2")]})
186
188
  # Override match manually
187
- mtch = transformer._match_patterns = lambda p, s, i, b: [({"a": {"nested": s[i]}}, i+1)] if i < len(s) else []
188
-
189
+ mtch = transformer._match_patterns = lambda p, s, i, b: (
190
+ [({"a": {"nested": s[i]}}, i + 1)] if i < len(s) else []
191
+ )
192
+
189
193
  res = transformer.visit(tree)
190
194
  assert isinstance(res.body[0], ast.Assign)
191
195
  assert res.body[0].value.value == 2
@@ -195,14 +199,14 @@ def test_pt_generic_visit_list_field_error():
195
199
  tree = ast.parse("a = 1")
196
200
  pattern = [Collect(NodePattern(ast.Assign), "a")]
197
201
  transformer = PatternTransformer(pattern, {})
198
-
202
+
199
203
  # Force a child to not be an AST node
200
204
  class BadTransformer(PatternTransformer):
201
205
  def visit(self, node):
202
206
  if isinstance(node, ast.Assign):
203
207
  return "Not an AST node"
204
208
  return super().visit(node)
205
-
209
+
206
210
  bad_transformer = BadTransformer(pattern, {})
207
211
  with pytest.raises(TypeError, match="must contain AST nodes"):
208
212
  bad_transformer.visit(ast.parse("a = 1"))
@@ -29,8 +29,8 @@ def test_single_occurrence_finder_early_exit():
29
29
  tree = ast.parse("a = 1\nb = 2")
30
30
  pattern = [NodePattern(ast.Assign)]
31
31
  finder = SingleOccurrenceFinder(pattern)
32
-
32
+
33
33
  # We manually set found to True to test early exit in visit
34
34
  finder.found = True
35
35
  finder.visit(tree)
36
- assert finder.match_node is None # Didn't actually match because it early exited
36
+ assert finder.match_node is None # Didn't actually match because it early exited