flake8-stepdown 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flake8_stepdown/__init__.py +5 -0
- flake8_stepdown/cli.py +168 -0
- flake8_stepdown/core/__init__.py +1 -0
- flake8_stepdown/core/bindings.py +148 -0
- flake8_stepdown/core/graph.py +156 -0
- flake8_stepdown/core/ordering.py +184 -0
- flake8_stepdown/core/parser.py +102 -0
- flake8_stepdown/core/references.py +260 -0
- flake8_stepdown/flake8_plugin.py +42 -0
- flake8_stepdown/py.typed +0 -0
- flake8_stepdown/reporter.py +75 -0
- flake8_stepdown/rewriter.py +99 -0
- flake8_stepdown/types.py +78 -0
- flake8_stepdown-0.1.0.dist-info/METADATA +126 -0
- flake8_stepdown-0.1.0.dist-info/RECORD +17 -0
- flake8_stepdown-0.1.0.dist-info/WHEEL +4 -0
- flake8_stepdown-0.1.0.dist-info/entry_points.txt +5 -0
flake8_stepdown/cli.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""CLI entry point for flake8-stepdown."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import sys
|
|
7
|
+
from fnmatch import fnmatch
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
from flake8_stepdown.core.ordering import order_module
|
|
11
|
+
from flake8_stepdown.reporter import format_diff, format_violations
|
|
12
|
+
|
|
13
|
+
# Process exit statuses returned by ``main``:
#   EXIT_OK         - nothing to report or change
#   EXIT_VIOLATIONS - violations found, or a diff/fix was produced
#   EXIT_ERROR      - error (e.g. no files given, missing or unreadable file)
EXIT_OK, EXIT_VIOLATIONS, EXIT_ERROR = 0, 1, 2
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def main(argv: list[str] | None = None) -> int:
    """Run the ``stepdown`` command line tool.

    Args:
        argv: Arguments to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
        Exit code: 0 (clean), 1 (violations/changes), 2 (error).

    """
    args = _build_parser().parse_args(argv)

    # --stdin-filename switches to single-source mode: the code is read from
    # stdin and reported under the given name; positional files are ignored.
    if args.stdin_filename:
        status, text = _process_source(sys.stdin.read(), args.stdin_filename, args)
        if text:
            _write_output(text)
        return status

    if not args.files:
        sys.stderr.write("Error: no files specified\n")
        return EXIT_ERROR

    worst = EXIT_OK
    for filepath in _resolve_paths(args.files, args.exclude):
        status = _process_file(filepath, args)
        if status == EXIT_ERROR:
            # Stop at the first file-level error; remaining files are skipped.
            return EXIT_ERROR
        worst = max(worst, status)
    return worst
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _build_parser() -> argparse.ArgumentParser:
    """Create the ``stepdown`` argument parser.

    Three required subcommands (stored as ``command``) share the same options:
    positional ``files``, ``--exclude``, ``-v/--verbose`` and ``--stdin-filename``.
    The subcommands are ``check``, ``diff`` and ``fix``. ``check`` additionally
    accepts ``--format text|json``, stored as ``fmt``.
    """
    shared = argparse.ArgumentParser(add_help=False)
    shared.add_argument("files", nargs="*", help="Files or directories to check")
    shared.add_argument(
        "--exclude",
        action="append",
        default=[],
        help="Glob patterns to exclude",
    )
    shared.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Show debug info (mutual recursion info on stderr)",
    )
    shared.add_argument("--stdin-filename", help="Read from stdin, use this filename for output")

    root = argparse.ArgumentParser(
        prog="stepdown",
        description="Enforce top-down function ordering in Python",
    )
    commands = root.add_subparsers(dest="command", required=True)

    check_cmd = commands.add_parser("check", parents=[shared], help="Report violations")
    check_cmd.add_argument(
        "--format",
        dest="fmt",
        choices=["text", "json"],
        default="text",
        help="Output format",
    )

    for cmd_name, cmd_help in (
        ("diff", "Show unified diff"),
        ("fix", "Rewrite files in place"),
    ):
        commands.add_parser(cmd_name, parents=[shared], help=cmd_help)

    return root
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _resolve_paths(paths: list[str], exclude: list[str]) -> list[str]:
    """Expand directories to .py files, apply exclude patterns, and drop duplicates.

    Directories are searched recursively for ``*.py`` files, in sorted order.
    Any other entry is passed through as given, even if it does not exist, so
    the per-file step can report it. Each candidate path string is checked
    against every ``exclude`` pattern with :func:`fnmatch.fnmatch`.

    A path string reached more than once (e.g. ``stepdown check src src/a.py``)
    is kept only at its first occurrence. It is then not reported, diffed or
    rewritten twice.

    Args:
        paths: Files and/or directories from the command line.
        exclude: Glob patterns; matching paths are skipped.

    Returns:
        The file paths to process, in discovery order, without duplicates.

    """
    resolved: list[str] = []
    seen: set[str] = set()

    def _add(candidate: str) -> None:
        # Skip repeats and excluded paths; otherwise keep first-seen order.
        if candidate in seen or any(fnmatch(candidate, pat) for pat in exclude):
            return
        seen.add(candidate)
        resolved.append(candidate)

    for entry in paths:
        root = Path(entry)
        if root.is_dir():
            for py_file in sorted(root.rglob("*.py")):
                _add(str(py_file))
        else:
            _add(entry)
    return resolved
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _process_file(filepath: str, args: argparse.Namespace) -> int:
    """Run the selected command on one file, then print or apply its output.

    The file is decoded the way Python decodes source files: a PEP 263 coding
    cookie or UTF-8 BOM if present, otherwise UTF-8. The locale's default
    encoding is not used. ``fix`` writes the result back in that same encoding.

    Args:
        filepath: Path of the file to process.
        args: Parsed command-line arguments; ``args.command`` selects the mode.

    Returns:
        EXIT_OK, EXIT_VIOLATIONS or EXIT_ERROR. Missing, unreadable, undecodable
        or unwritable files are reported on stderr and yield EXIT_ERROR.

    """
    import tokenize  # only needed here, for source-encoding detection

    path = Path(filepath)
    if not path.exists():
        sys.stderr.write(f"Error: {filepath} not found\n")
        return EXIT_ERROR

    try:
        with tokenize.open(path) as fh:
            source = fh.read()
            encoding = fh.encoding
    except (OSError, SyntaxError, UnicodeDecodeError) as e:
        # tokenize raises SyntaxError for an invalid or unknown coding cookie.
        sys.stderr.write(f"Error reading {filepath}: {e}\n")
        return EXIT_ERROR

    code, output = _process_source(source, filepath, args)

    if args.command == "fix" and code == EXIT_VIOLATIONS and output:
        # fix rewrites the file in place instead of printing the new source.
        try:
            path.write_text(output, encoding=encoding)
        except (OSError, UnicodeEncodeError) as e:
            sys.stderr.write(f"Error writing {filepath}: {e}\n")
            return EXIT_ERROR
    elif output:
        _write_output(output)

    return code
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _process_source(
    source: str,
    filename: str,
    args: argparse.Namespace,
) -> tuple[int, str]:
    """Analyze one module's source and render the result for ``args.command``.

    - ``check`` renders the violations in ``args.fmt``.
    - ``diff`` renders a diff between the original and the reordered source.
    - ``fix`` returns the reordered source itself.

    Returns:
        ``(exit_code, output)``. The code is EXIT_VIOLATIONS when there is
        something to report or change, otherwise EXIT_OK. For ``diff`` and
        ``fix`` the output is ``""`` when nothing needs reordering.

    """
    command = args.command
    # Only diff/fix need the rewritten source; check skips that work.
    result = order_module(source, compute_rewrite=command != "check")

    if args.verbose:
        for group in result.mutual_recursion_groups:
            sys.stderr.write(
                f"{filename}: mutual recursion between {', '.join(group)}; original order preserved\n"
            )

    if command == "check":
        rendered = format_violations(result.violations, filename=filename, fmt=args.fmt)
        status = EXIT_VIOLATIONS if result.violations else EXIT_OK
        return status, rendered

    reordered = result.reordered_source
    if reordered is None:
        return EXIT_OK, ""
    if command == "diff":
        return EXIT_VIOLATIONS, format_diff(source, reordered, filename=filename)
    # fix command
    return EXIT_VIOLATIONS, reordered
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _write_output(output: str) -> None:
    """Print ``output`` to stdout, adding a final newline if it lacks one."""
    stream = sys.stdout
    stream.write(output)
    if output[-1:] != "\n":
        stream.write("\n")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
if __name__ == "__main__":
    # Direct execution (e.g. ``python -m flake8_stepdown.cli``): the process
    # exit status is whatever ``main`` returns.
    raise SystemExit(main())
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Core analysis modules for flake8-stepdown."""
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""Extract bindings (defined names) from module-level statements."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
import libcst as cst
|
|
8
|
+
import libcst.matchers as m
|
|
9
|
+
|
|
10
|
+
from flake8_stepdown.types import Statement
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from collections.abc import Mapping
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def extract_bindings(
    statements: list[cst.CSTNode],
    positions: Mapping[cst.CSTNode, cst.metadata.CodeRange],
) -> list[Statement]:
    """Build one ``Statement`` record per reorderable unit of the module body.

    Args:
        statements: The module-level CST nodes to analyze (functions, classes,
            and assignments from the reorderable zone between preamble and postamble).
        positions: Position mapping from MetadataWrapper.resolve(PositionProvider).

    Returns:
        Statements in input order. ``immediate_refs`` and ``deferred_refs`` are
        left empty; the references module populates them. A node missing from
        ``positions`` gets line 0 for the corresponding bound.

    Overload grouping:
        A function decorated with ``@overload`` starts a run. The run absorbs
        the directly following defs of the same name, up to and including the
        first one without ``@overload`` (the implementation).

        If such an implementation exists, the whole run becomes a single
        ``Statement`` with ``is_overload_group=True``. Its node is the
        implementation and its lines span the run. Otherwise the stub is
        handled like any other statement.

    """
    nodes = list(statements)
    count = len(nodes)
    records: list[Statement] = []

    def line_span(first: cst.CSTNode, last: cst.CSTNode) -> tuple[int, int]:
        # Start line of ``first`` and end line of ``last`` (0 when unknown).
        first_range = positions.get(first)
        last_range = positions.get(last)
        return (
            first_range.start.line if first_range else 0,
            last_range.end.line if last_range else 0,
        )

    cursor = 0
    while cursor < count:
        node = nodes[cursor]

        if isinstance(node, cst.FunctionDef) and _has_overload_decorator(node):
            func_name = node.name.value
            run: list[cst.FunctionDef] = [node]
            stop = cursor + 1
            while stop < count:
                candidate = nodes[stop]
                if not isinstance(candidate, cst.FunctionDef) or candidate.name.value != func_name:
                    break
                run.append(candidate)
                stop += 1
                if not _has_overload_decorator(candidate):
                    break  # implementation reached; the run ends with it

            impl = run[-1]
            if len(run) > 1 and not _has_overload_decorator(impl):
                start_line, end_line = line_span(node, impl)
                records.append(
                    Statement(
                        node=impl,
                        start_line=start_line,
                        end_line=end_line,
                        bindings=frozenset({func_name}),
                        immediate_refs=frozenset(),
                        deferred_refs=frozenset(),
                        is_overload_group=True,
                    ),
                )
                cursor = stop
                continue
            # No implementation followed: treat this stub as a plain statement.

        start_line, end_line = line_span(node, node)
        if isinstance(node, cst.BaseStatement):
            bindings = _extract_binding_names(node)
        else:
            bindings = frozenset()
        records.append(
            Statement(
                node=node,
                start_line=start_line,
                end_line=end_line,
                bindings=bindings,
                immediate_refs=frozenset(),
                deferred_refs=frozenset(),
                is_overload_group=False,
            ),
        )
        cursor += 1

    return records
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _extract_binding_names(node: cst.BaseStatement) -> frozenset[str]:
    """Return the module-level names that ``node`` defines.

    Binding rules:
    - ``def`` and ``class`` statements bind their own name.
    - A simple-statement line binds the names ``_collect_names`` finds in the
      targets of its plain assignments.
    - Annotated assignments bind their target only when they carry a value,
      so a bare ``x: int`` binds nothing.
    - Every other statement binds nothing.
    """
    if isinstance(node, (cst.FunctionDef, cst.ClassDef)):
        return frozenset({node.name.value})
    if not isinstance(node, cst.SimpleStatementLine):
        return frozenset()

    bound: set[str] = set()
    for small_stmt in node.body:
        if isinstance(small_stmt, cst.Assign):
            for assign_target in small_stmt.targets:
                bound.update(_collect_names(assign_target.target))
        elif isinstance(small_stmt, cst.AnnAssign) and small_stmt.value is not None:
            bound.update(_collect_names(small_stmt.target))
    return frozenset(bound)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _collect_names(target: cst.BaseExpression) -> set[str]:
    """Recursively collect all Name identifiers from an assignment target.

    Handled target forms:
    - plain names;
    - tuple and list destructuring, nested to any depth
      (``a, b = ...``, ``[a, b] = ...``, ``(a, [b, c]) = ...``);
    - starred elements (``a, *rest = ...``).

    Attribute and subscript targets (``obj.x = ...``, ``d[k] = ...``) do not
    bind a module-level name and yield an empty set.
    """
    if isinstance(target, cst.Name):
        return {target.value}
    # ``[a, b] = ...`` is as valid a destructuring target as ``a, b = ...``.
    if isinstance(target, (cst.Tuple, cst.List)):
        names: set[str] = set()
        for element in target.elements:
            names |= _collect_names(element.value)
        return names
    if isinstance(target, cst.StarredElement):
        return _collect_names(target.value)
    return set()
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _has_overload_decorator(node: cst.FunctionDef) -> bool:
    """Check if a FunctionDef is decorated with ``overload``.

    Recognized spellings:
    - the bare name ``@overload``;
    - ``@typing.overload``;
    - ``@typing_extensions.overload``, the backport with the same semantics.
    """
    qualified_overload = m.Attribute(
        value=m.OneOf(m.Name("typing"), m.Name("typing_extensions")),
        attr=m.Name("overload"),
    )
    for decorator in node.decorators:
        dec = decorator.decorator
        if m.matches(dec, m.Name("overload")):
            return True
        if m.matches(dec, qualified_overload):
            return True
    return False
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""Dependency graph construction, topological sort, and SCC detection."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import heapq
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from flake8_stepdown.core.parser import is_docstring, is_simple_assignment
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from flake8_stepdown.types import Statement
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def build_normalized_graph(statements: list[Statement]) -> dict[int, set[int]]:
    """Build a normalized dependency graph where edge A->B means "A must appear before B".

    Edge rules:
    - Deferred ref: A calls B -> edge A->B (caller before callee).
    - Immediate ref: A uses @B -> edge B->A (dependency before dependent).

    Every statement index gets an entry, possibly empty. References to names no
    statement binds, and references to the statement itself, add no edge. When
    several statements bind the same name, the last one owns it.
    """
    owner: dict[str, int] = {}
    for position, stmt in enumerate(statements):
        owner.update(dict.fromkeys(stmt.bindings, position))

    graph: dict[int, set[int]] = {position: set() for position in range(len(statements))}

    for position, stmt in enumerate(statements):
        # Falling back to ``position`` filters unknown names and self-references alike.
        for ref in stmt.deferred_refs:
            callee = owner.get(ref, position)
            if callee != position:
                graph[position].add(callee)
        for ref in stmt.immediate_refs:
            dependency = owner.get(ref, position)
            if dependency != position:
                graph[dependency].add(position)

    return graph
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def topological_sort(graph: dict[int, set[int]], num_nodes: int) -> list[int] | None:
    """Order nodes 0..num_nodes-1 so that each edge u->v puts u before v (Kahn's algorithm).

    When several nodes are ready at once, the smallest index is emitted first.
    This keeps the result stable with respect to the original order.

    Returns the ordered node indices, or None if a cycle is detected.
    """
    pending_preds = [0] * num_nodes
    for targets in graph.values():
        for target in targets:
            pending_preds[target] += 1

    ready = [node for node, count in enumerate(pending_preds) if count == 0]
    heapq.heapify(ready)

    order: list[int] = []
    while ready:
        current = heapq.heappop(ready)
        order.append(current)
        for target in graph.get(current, ()):
            pending_preds[target] -= 1
            if pending_preds[target] == 0:
                heapq.heappush(ready, target)

    # Nodes stuck with predecessors left over belong to a cycle.
    return order if len(order) == num_nodes else None
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def find_sccs(graph: dict[int, set[int]], num_nodes: int) -> list[list[int]]:  # noqa: C901
    """Find strongly connected components with size > 1 using Tarjan's algorithm.

    The depth-first search uses an explicit stack of frames instead of
    recursion, so long dependency chains (more nodes than Python's recursion
    limit) do not raise RecursionError.

    Visit order matches the recursive formulation. Components are returned in
    the order Tarjan's algorithm completes them. Each component lists its
    members in the order they are popped off the Tarjan stack.

    Args:
        graph: Adjacency sets keyed by node index; missing keys mean no edges.
        num_nodes: Number of nodes; nodes are ``0..num_nodes-1``.

    Returns:
        The components with more than one member.

    """
    next_index = 0
    tarjan_stack: list[int] = []
    on_stack = [False] * num_nodes
    indices = [-1] * num_nodes
    lowlinks = [-1] * num_nodes
    result: list[list[int]] = []
    # Each frame is (node, iterator over its not-yet-explored successors).
    frames: list[tuple[int, Iterator[int]]] = []

    def _enter(node: int) -> None:
        # Equivalent of the start of a recursive strongconnect(node) call.
        nonlocal next_index
        indices[node] = next_index
        lowlinks[node] = next_index
        next_index += 1
        tarjan_stack.append(node)
        on_stack[node] = True
        frames.append((node, iter(graph.get(node, set()))))

    for root in range(num_nodes):
        if indices[root] != -1:
            continue
        _enter(root)

        while frames:
            v, successors = frames[-1]
            descended = False
            for w in successors:
                if indices[w] == -1:
                    # Tree edge: explore w first; v resumes after w later.
                    _enter(w)
                    descended = True
                    break
                if on_stack[w]:
                    lowlinks[v] = min(lowlinks[v], indices[w])
            if descended:
                continue

            # Every successor of v has been explored: v is finished.
            frames.pop()
            if lowlinks[v] == indices[v]:
                scc: list[int] = []
                while True:
                    w = tarjan_stack.pop()
                    on_stack[w] = False
                    scc.append(w)
                    if w == v:
                        break
                if len(scc) > 1:
                    result.append(scc)
            if frames:
                # Propagate v's lowlink to its DFS parent (the "return" step).
                parent = frames[-1][0]
                lowlinks[parent] = min(lowlinks[parent], lowlinks[v])

    return result
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def attach_no_binding_stmts(statements: list[Statement]) -> list[list[Statement]]:
    """Group statements so that those with no bindings travel with a neighbor.

    Grouping rules:
    - A statement without bindings joins the next statement that has bindings.
    - Exception: a docstring that directly follows a simple assignment, with
      nothing pending in between, joins that assignment's group.
    - No-binding statements at the very end join the last group.
    - If no statement has bindings, they all form a single group.

    Returns a list of groups, where each group is one or more statements
    that move together.
    """
    groups: list[list[Statement]] = []
    carry: list[Statement] = []

    for stmt in statements:
        if stmt.bindings:
            # Close the current group: pending no-binding stmts become its prefix.
            carry.append(stmt)
            groups.append(carry)
            carry = []
            continue

        follows_constant = (
            bool(groups)
            and not carry
            and is_docstring(stmt.node)
            and is_simple_assignment(groups[-1][-1].node)
        )
        if follows_constant:
            # Docstring immediately after a constant stays with the constant.
            groups[-1].append(stmt)
        else:
            carry.append(stmt)

    if carry:
        if groups:
            groups[-1].extend(carry)
        else:
            groups.append(carry)
    return groups
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""Orchestrator: parse -> segment -> bindings -> refs -> graph -> sort -> violations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import libcst as cst
|
|
6
|
+
|
|
7
|
+
from flake8_stepdown.core.bindings import extract_bindings
|
|
8
|
+
from flake8_stepdown.core.graph import (
|
|
9
|
+
attach_no_binding_stmts,
|
|
10
|
+
build_normalized_graph,
|
|
11
|
+
find_sccs,
|
|
12
|
+
topological_sort,
|
|
13
|
+
)
|
|
14
|
+
from flake8_stepdown.core.parser import parse_source, segment
|
|
15
|
+
from flake8_stepdown.core.references import detect_future_annotations, extract_refs
|
|
16
|
+
from flake8_stepdown.rewriter import rewrite
|
|
17
|
+
from flake8_stepdown.types import OrderingResult, Statement, Violation
|
|
18
|
+
|
|
19
|
+
# Result describing "nothing to do": no violations, no rewritten source and no
# mutual-recursion groups.
# NOTE: this is one shared instance; its lists must never be mutated.
_EMPTY_RESULT = OrderingResult(
    mutual_recursion_groups=[],
    reordered_source=None,
    violations=[],
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def order_module(source: str, *, compute_rewrite: bool = True) -> OrderingResult:
    """Analyze and determine the correct ordering for a Python module.

    Args:
        source: Python source code.
        compute_rewrite: Whether to compute the reordered source (default True).
            Set to False when only violations are needed (e.g. flake8 plugin, check command).

    Returns:
        OrderingResult with violations and optionally reordered source.
        ``reordered_source`` is None when the order is already correct or
        ``compute_rewrite`` is False. Every call returns a new result object,
        empty modules included, so callers may mutate its lists freely.

    """
    if not source.strip():
        return _empty_result()

    module = parse_source(source)
    wrapper = cst.metadata.MetadataWrapper(module)
    positions = wrapper.resolve(cst.metadata.PositionProvider)
    # Segment wrapper.module, not ``module``: MetadataWrapper works on a copy of
    # the tree by default, and ``positions`` is keyed by the copy's nodes.
    seg = segment(wrapper.module)

    if not seg.interstitials:
        return _empty_result()

    # Extract bindings, then references.
    statements = extract_bindings(seg.interstitials, positions)
    has_future = detect_future_annotations(seg.preamble)
    statements = extract_refs(statements, has_future_annotations=has_future)

    # Statements without bindings move together with a neighbor. The graph has
    # one node per resulting group.
    groups = attach_no_binding_stmts(statements)
    merged = [_merge_group(group) for group in groups]
    num_nodes = len(merged)
    graph = build_normalized_graph(merged)

    # Mutual recursion: drop the edges inside each SCC so its members no longer
    # constrain each other. The sort's tie-break on original index decides instead.
    sccs = find_sccs(graph, num_nodes)
    for scc in sccs:
        scc_set = set(scc)
        for node in scc:
            graph[node] = {s for s in graph[node] if s not in scc_set}

    original_order = list(range(num_nodes))
    new_order = topological_sort(graph, num_nodes)
    if new_order is None:
        # Remaining cycles after SCC removal — shouldn't happen but handle gracefully
        new_order = original_order

    violations = _generate_violations(merged, new_order)
    mutual_recursion_groups = _extract_mutual_recursion_groups(merged, sccs)

    reordered_source = None
    if compute_rewrite and new_order != original_order:
        all_nodes = [s.node for group in groups for s in group]
        reordered_source = rewrite(
            seg.module,
            seg.preamble,
            all_nodes,
            seg.postamble,
            _expand_group_order(groups, new_order),
        )

    return OrderingResult(
        violations=violations,
        reordered_source=reordered_source,
        mutual_recursion_groups=mutual_recursion_groups,
    )


def _empty_result() -> OrderingResult:
    """Return a new result with no violations, no rewrite and no recursion groups."""
    return OrderingResult(violations=[], reordered_source=None, mutual_recursion_groups=[])


def _merge_group(group: list[Statement]) -> Statement:
    """Collapse a group of statements that move together into one graph node.

    The merged node:
    - unions the bindings and both kinds of refs of every member;
    - spans from the first member's start line to the last member's end line;
    - takes ``node`` and ``is_overload_group`` from the first member.
    """
    return Statement(
        node=group[0].node,
        start_line=group[0].start_line,
        end_line=group[-1].end_line,
        bindings=frozenset().union(*(s.bindings for s in group)),
        immediate_refs=frozenset().union(*(s.immediate_refs for s in group)),
        deferred_refs=frozenset().union(*(s.deferred_refs for s in group)),
        is_overload_group=group[0].is_overload_group,
    )


def _expand_group_order(groups: list[list[Statement]], group_order: list[int]) -> list[int]:
    """Translate an order of groups into an order of the flattened statements.

    Statement ``k`` of group ``g`` has flat index ``offsets[g] + k``, where
    ``offsets[g]`` counts the statements in all groups before ``g``.
    """
    offsets: list[int] = []
    total = 0
    for group in groups:
        offsets.append(total)
        total += len(group)
    return [offsets[g] + k for g in group_order for k in range(len(groups[g]))]
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _generate_violations(
    statements: list[Statement],
    new_order: list[int],
) -> list[Violation]:
    """Generate TDP001 violations from ordering differences.

    A statement whose position changes is reported when some statement that
    originally came *after* it is placed *before* it in ``new_order``. The
    first such statement, in the new order, is named as the one it should
    follow.

    Names are chosen deterministically (the alphabetically first binding).
    Iteration order of a frozenset of strings varies with hash randomization,
    so picking an arbitrary element would make output unstable across runs.

    Args:
        statements: Statements (graph nodes) in their original order.
        new_order: Original indices of ``statements`` in their corrected order.

    Returns:
        At most one violation per statement, in original statement order.

    """
    violations: list[Violation] = []

    # Map from original index to new position
    new_position = {orig_idx: new_pos for new_pos, orig_idx in enumerate(new_order)}

    for orig_idx, stmt in enumerate(statements):
        if new_position[orig_idx] == orig_idx:
            continue
        # Walk the statements placed before this one in the new order.
        for other_idx in new_order:
            if other_idx == orig_idx:
                break  # nothing that was originally later jumped ahead of it
            if other_idx < orig_idx:
                continue  # already before it originally: not an inversion
            name = _statement_name(stmt)
            other_name = _statement_name(statements[other_idx])
            violations.append(
                Violation(
                    code="TDP001",
                    lineno=stmt.start_line,
                    col_offset=0,
                    name=name,
                    message=f"{name} should appear after {other_name}",
                    dependency=other_name,
                ),
            )
            break

    return violations


def _statement_name(stmt: Statement) -> str:
    """Return a stable display name for ``stmt``: its smallest binding, or ``<unnamed>``."""
    return min(stmt.bindings, default="<unnamed>")
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _extract_mutual_recursion_groups(
    statements: list[Statement],
    sccs: list[list[int]],
) -> list[list[str]]:
    """Turn each mutual-recursion SCC into the names it binds.

    Each SCC becomes the sorted, de-duplicated list of names bound by its
    statements. SCCs whose statements bind nothing are omitted, and the input
    SCC order is kept.
    """
    named = (
        sorted({name for idx in scc for name in statements[idx].bindings})
        for scc in sccs
    )
    return [names for names in named if names]
|