flake8-stepdown 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
1
+ """flake8-stepdown: enforce top-down function ordering in Python."""
2
+
3
# Public package version string (this file ships in the 0.1.0 wheel).
__version__ = "0.1.0"
4
+
5
+ from flake8_stepdown.core.ordering import order_module
flake8_stepdown/cli.py ADDED
@@ -0,0 +1,168 @@
1
+ """CLI entry point for flake8-stepdown."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import sys
7
+ from fnmatch import fnmatch
8
+ from pathlib import Path
9
+
10
+ from flake8_stepdown.core.ordering import order_module
11
+ from flake8_stepdown.reporter import format_diff, format_violations
12
+
13
# Process exit codes returned by main().
EXIT_OK = 0  # clean: no violations / nothing to reorder
EXIT_VIOLATIONS = 1  # violations found (check) or reordering needed (diff/fix)
EXIT_ERROR = 2  # usage error or a file that could not be processed
16
+
17
+
18
def main(argv: list[str] | None = None) -> int:
    """CLI entry point.

    Args:
        argv: Command-line arguments without the program name; ``None`` lets
            argparse read ``sys.argv[1:]``.

    Returns:
        Exit code: 0 (clean), 1 (violations/changes), 2 (error).

    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Handle stdin: results always go to stdout.
    if args.stdin_filename:
        source = sys.stdin.read()
        code, output = _process_source(source, args.stdin_filename, args)
        if output:
            _write_output(output)
        elif args.command == "fix":
            # Nothing to reorder: echo the input unchanged. Otherwise piping a
            # buffer through `fix` (e.g. from an editor) would replace it with
            # empty output whenever the module is already correctly ordered.
            sys.stdout.write(source)
        return code

    if not args.files:
        sys.stderr.write("Error: no files specified\n")
        return EXIT_ERROR

    filepaths = _resolve_paths(args.files, args.exclude)
    if not filepaths:
        return EXIT_OK

    exit_code = EXIT_OK
    for filepath in filepaths:
        code = _process_file(filepath, args)
        if code == EXIT_ERROR:
            # Stop at the first file that cannot be processed.
            return EXIT_ERROR
        exit_code = max(exit_code, code)

    return exit_code
52
+
53
+
54
+ def _build_parser() -> argparse.ArgumentParser:
55
+ """Build the argument parser."""
56
+ parser = argparse.ArgumentParser(
57
+ prog="stepdown",
58
+ description="Enforce top-down function ordering in Python",
59
+ )
60
+ subparsers = parser.add_subparsers(dest="command", required=True)
61
+
62
+ common = argparse.ArgumentParser(add_help=False)
63
+ common.add_argument("files", nargs="*", help="Files or directories to check")
64
+ common.add_argument("--exclude", action="append", default=[], help="Glob patterns to exclude")
65
+ common.add_argument(
66
+ "-v",
67
+ "--verbose",
68
+ action="store_true",
69
+ help="Show debug info (mutual recursion info on stderr)",
70
+ )
71
+ common.add_argument(
72
+ "--stdin-filename",
73
+ help="Read from stdin, use this filename for output",
74
+ )
75
+
76
+ check_parser = subparsers.add_parser("check", parents=[common], help="Report violations")
77
+ check_parser.add_argument(
78
+ "--format",
79
+ dest="fmt",
80
+ choices=["text", "json"],
81
+ default="text",
82
+ help="Output format",
83
+ )
84
+
85
+ subparsers.add_parser("diff", parents=[common], help="Show unified diff")
86
+ subparsers.add_parser("fix", parents=[common], help="Rewrite files in place")
87
+
88
+ return parser
89
+
90
+
91
+ def _resolve_paths(paths: list[str], exclude: list[str]) -> list[str]:
92
+ """Expand directories to .py files and apply exclude patterns."""
93
+ resolved: list[str] = []
94
+ for entry in paths:
95
+ p = Path(entry)
96
+ if p.is_dir():
97
+ for py_file in sorted(p.rglob("*.py")):
98
+ filepath = str(py_file)
99
+ if not any(fnmatch(filepath, pat) for pat in exclude):
100
+ resolved.append(filepath)
101
+ elif not any(fnmatch(entry, pat) for pat in exclude):
102
+ resolved.append(entry)
103
+ return resolved
104
+
105
+
106
def _process_file(filepath: str, args: argparse.Namespace) -> int:
    """Process a single file and handle its output.

    For ``fix``, a reordered module is written back to ``filepath``; for every
    other case any output goes to stdout.

    The file is read and written as UTF-8, the default Python source encoding
    (PEP 3120), rather than the locale's preferred encoding. This keeps results
    independent of the platform the tool runs on.

    Returns:
        EXIT_OK, EXIT_VIOLATIONS, or EXIT_ERROR when the file is missing,
        cannot be read/decoded, or cannot be written back.

    """
    path = Path(filepath)
    if not path.exists():
        sys.stderr.write(f"Error: {filepath} not found\n")
        return EXIT_ERROR

    try:
        source = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as e:
        sys.stderr.write(f"Error reading {filepath}: {e}\n")
        return EXIT_ERROR

    code, output = _process_source(source, filepath, args)

    if args.command == "fix" and code == EXIT_VIOLATIONS and output:
        try:
            path.write_text(output, encoding="utf-8")
        except OSError as e:
            # e.g. read-only file: report it like a read failure instead of
            # crashing with a traceback.
            sys.stderr.write(f"Error writing {filepath}: {e}\n")
            return EXIT_ERROR
    elif output:
        _write_output(output)

    return code
127
+
128
+
129
def _process_source(
    source: str,
    filename: str,
    args: argparse.Namespace,
) -> tuple[int, str]:
    """Analyze ``source`` for the selected subcommand.

    Returns:
        ``(exit_code, output)``: the formatted violation report for ``check``,
        a unified diff for ``diff``, or the reordered module for ``fix``.
        ``diff`` and ``fix`` return ``(EXIT_OK, "")`` when nothing needs
        reordering.

    """
    command = args.command
    result = order_module(source, compute_rewrite=(command != "check"))

    if args.verbose:
        for group in result.mutual_recursion_groups or ():
            sys.stderr.write(
                f"{filename}: mutual recursion between {', '.join(group)}; original order preserved\n"
            )

    if command == "check":
        report = format_violations(result.violations, filename=filename, fmt=args.fmt)
        exit_code = EXIT_VIOLATIONS if result.violations else EXIT_OK
        return exit_code, report

    reordered = result.reordered_source
    if reordered is None:
        return EXIT_OK, ""
    if command == "diff":
        return EXIT_VIOLATIONS, format_diff(source, reordered, filename=filename)
    # fix: hand the rewritten module back to the caller.
    return EXIT_VIOLATIONS, reordered
158
+
159
+
160
+ def _write_output(output: str) -> None:
161
+ """Write output to stdout with trailing newline if needed."""
162
+ sys.stdout.write(output)
163
+ if not output.endswith("\n"):
164
+ sys.stdout.write("\n")
165
+
166
+
167
if __name__ == "__main__":
    # Running this module directly exits with main()'s return code
    # (0 clean, 1 violations/changes, 2 error).
    sys.exit(main())
@@ -0,0 +1 @@
1
+ """Core analysis modules for flake8-stepdown."""
@@ -0,0 +1,148 @@
1
+ """Extract bindings (defined names) from module-level statements."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import libcst as cst
8
+ import libcst.matchers as m
9
+
10
+ from flake8_stepdown.types import Statement
11
+
12
+ if TYPE_CHECKING:
13
+ from collections.abc import Mapping
14
+
15
+
16
def extract_bindings(
    statements: list[cst.CSTNode],
    positions: Mapping[cst.CSTNode, cst.metadata.CodeRange],
) -> list[Statement]:
    """Extract bindings from module-level statements.

    Groups consecutive @overload stubs with their implementation into a single
    Statement. Returns Statement objects with empty refs (to be populated by the
    references module).

    Args:
        statements: The module-level CST nodes to analyze (functions, classes,
            and assignments from the reorderable zone between preamble and postamble).
        positions: Position mapping from MetadataWrapper.resolve(PositionProvider).

    Returns:
        One Statement per input node in original order, except that a run of
        same-name ``@overload`` stubs followed by a non-overload implementation
        becomes a single Statement (``is_overload_group=True``). Line numbers
        fall back to 0 for nodes missing from ``positions``.

    """
    result: list[Statement] = []
    i = 0
    nodes = list(statements)

    while i < len(nodes):
        node = nodes[i]

        # Check for @overload grouping
        if isinstance(node, cst.FunctionDef) and _has_overload_decorator(node):
            func_name = node.name.value
            group_nodes: list[cst.CSTNode] = [node]

            # Collect consecutive same-name functions. ``j`` ends one past the
            # last collected node; collection stops right after the first
            # same-name function that lacks @overload (the implementation).
            j = i + 1
            while j < len(nodes):
                next_node = nodes[j]
                if isinstance(next_node, cst.FunctionDef) and next_node.name.value == func_name:
                    group_nodes.append(next_node)
                    if not _has_overload_decorator(next_node):
                        j += 1
                        break
                    j += 1
                else:
                    break

            # Merge if stubs + implementation (>1 node and last is not overload)
            last_node = group_nodes[-1]
            if (
                len(group_nodes) > 1
                and isinstance(last_node, cst.FunctionDef)
                and not _has_overload_decorator(last_node)
            ):
                first_pos = positions.get(group_nodes[0])
                last_pos = positions.get(last_node)
                # The line span runs from the first stub to the end of the
                # implementation, but only the implementation is kept as
                # ``node``. NOTE(review): order_module builds the rewriter's
                # node list from ``Statement.node``; confirm the stub nodes are
                # not dropped when a module containing overloads is rewritten.
                result.append(
                    Statement(
                        node=last_node,
                        start_line=first_pos.start.line if first_pos else 0,
                        end_line=last_pos.end.line if last_pos else 0,
                        bindings=frozenset({func_name}),
                        immediate_refs=frozenset(),
                        deferred_refs=frozenset(),
                        is_overload_group=True,
                    ),
                )
                # Consume every node collected into the group.
                i = j
                continue

            # Not a complete overload group — fall through to normal handling.
            # Only this stub is consumed here; following stubs are examined
            # again on later iterations.

        # Normal statement
        pos = positions.get(node)
        start_line = pos.start.line if pos else 0
        end_line = pos.end.line if pos else 0

        # Nodes that are not cst.BaseStatement bind nothing.
        bindings = (
            _extract_binding_names(node) if isinstance(node, cst.BaseStatement) else frozenset()
        )
        result.append(
            Statement(
                node=node,
                start_line=start_line,
                end_line=end_line,
                bindings=bindings,
                immediate_refs=frozenset(),
                deferred_refs=frozenset(),
                is_overload_group=False,
            ),
        )
        i += 1

    return result
103
+
104
+
105
def _extract_binding_names(node: cst.BaseStatement) -> frozenset[str]:
    """Return the module-level names that ``node`` binds.

    ``def`` and ``class`` bind their own name. A simple statement line binds
    every name assigned by its plain ``=`` assignments and by annotated
    assignments that carry a value; a bare annotation (``x: int``) binds
    nothing. Any other statement binds nothing.
    """
    if isinstance(node, (cst.FunctionDef, cst.ClassDef)):
        return frozenset((node.name.value,))

    if not isinstance(node, cst.SimpleStatementLine):
        return frozenset()

    found: set[str] = set()
    for small_stmt in node.body:
        if isinstance(small_stmt, cst.Assign):
            for assign_target in small_stmt.targets:
                found.update(_collect_names(assign_target.target))
        elif isinstance(small_stmt, cst.AnnAssign) and small_stmt.value is not None:
            found.update(_collect_names(small_stmt.target))
    return frozenset(found)
124
+
125
+
126
def _collect_names(target: cst.BaseExpression) -> set[str]:
    """Recursively collect all Name identifiers bound by an assignment target.

    Handles plain names, tuple and list unpacking (``a, b = ...`` and
    ``[a, b] = ...``, arbitrarily nested), and starred elements
    (``a, *rest = ...``). Attribute and subscript targets (``obj.x = ...``,
    ``d[k] = ...``) bind no module-level name and yield an empty set.
    """
    if isinstance(target, cst.Name):
        return {target.value}
    # List targets unpack exactly like tuple targets; both hold elements
    # (Element / StarredElement) whose ``.value`` is the inner target.
    if isinstance(target, (cst.Tuple, cst.List)):
        names: set[str] = set()
        for element in target.elements:
            names |= _collect_names(element.value)
        return names
    if isinstance(target, cst.StarredElement):
        return _collect_names(target.value)
    return set()
138
+
139
+
140
def _has_overload_decorator(node: cst.FunctionDef) -> bool:
    """Check if a FunctionDef is decorated with ``@overload``.

    Recognizes the bare name ``@overload`` (however it was imported) and the
    qualified forms ``@typing.overload`` and ``@typing_extensions.overload``.
    Aliased module imports (e.g. ``import typing as t``) are not resolved.
    """
    overload_matcher = m.OneOf(
        m.Name("overload"),
        m.Attribute(
            value=m.OneOf(m.Name("typing"), m.Name("typing_extensions")),
            attr=m.Name("overload"),
        ),
    )
    return any(m.matches(decorator.decorator, overload_matcher) for decorator in node.decorators)
@@ -0,0 +1,156 @@
1
+ """Dependency graph construction, topological sort, and SCC detection."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import heapq
6
+ from typing import TYPE_CHECKING
7
+
8
+ from flake8_stepdown.core.parser import is_docstring, is_simple_assignment
9
+
10
+ if TYPE_CHECKING:
11
+ from flake8_stepdown.types import Statement
12
+
13
+
14
def build_normalized_graph(statements: list[Statement]) -> dict[int, set[int]]:
    """Build a dependency graph in which edge A->B means "A must appear before B".

    - Deferred ref (A calls B): edge A->B, so the caller precedes the callee.
    - Immediate ref (A uses @B): edge B->A, so the dependency precedes its dependent.

    Names bound by no statement and self-references add no edge. When several
    statements bind the same name, the last of them owns it.
    """
    owner: dict[str, int] = {
        name: position for position, stmt in enumerate(statements) for name in stmt.bindings
    }
    edges: dict[int, set[int]] = {position: set() for position in range(len(statements))}

    for position, stmt in enumerate(statements):
        for ref in stmt.deferred_refs:
            callee = owner.get(ref)
            if callee is not None and callee != position:
                edges[position].add(callee)
        for ref in stmt.immediate_refs:
            dependency = owner.get(ref)
            if dependency is not None and dependency != position:
                edges[dependency].add(position)

    return edges
44
+
45
+
46
def topological_sort(graph: dict[int, set[int]], num_nodes: int) -> list[int] | None:
    """Topologically sort nodes ``0..num_nodes-1`` with Kahn's algorithm.

    Whenever several nodes are ready, the one with the smallest original
    index is emitted first (min-heap), which keeps the result as close to
    the input order as the edges allow.

    Returns:
        The node indices in sorted order, or ``None`` if a cycle is detected.
    """
    remaining_preds = [0] * num_nodes
    for targets in graph.values():
        for target in targets:
            remaining_preds[target] += 1

    ready = [node for node, count in enumerate(remaining_preds) if not count]
    heapq.heapify(ready)

    ordered: list[int] = []
    while ready:
        current = heapq.heappop(ready)
        ordered.append(current)
        for target in graph.get(current, ()):
            remaining_preds[target] -= 1
            if not remaining_preds[target]:
                heapq.heappush(ready, target)

    # Nodes left unemitted still have predecessors: they sit on a cycle.
    return ordered if len(ordered) == num_nodes else None
74
+
75
+
76
def find_sccs(graph: dict[int, set[int]], num_nodes: int) -> list[list[int]]:  # noqa: C901
    """Find strongly connected components with size > 1 using Tarjan's algorithm.

    Implemented iteratively with an explicit work stack. A recursive version
    hits Python's recursion limit (RecursionError) once a dependency chain is
    about 1000 statements deep. Traversal order, and therefore the order of
    the returned components and of their members, matches the classic
    recursive formulation.

    Args:
        graph: Adjacency sets; nodes missing as keys have no successors.
        num_nodes: Number of nodes, labelled ``0..num_nodes-1``.

    Returns:
        Each SCC with more than one node, as a list of node indices.
    """
    index_of = [-1] * num_nodes
    lowlink = [-1] * num_nodes
    on_stack = [False] * num_nodes
    scc_stack: list[int] = []
    result: list[list[int]] = []
    next_index = 0

    for root in range(num_nodes):
        if index_of[root] != -1:
            continue

        index_of[root] = lowlink[root] = next_index
        next_index += 1
        scc_stack.append(root)
        on_stack[root] = True
        # Each frame holds a node and the (partially consumed) iterator over
        # its successors, standing in for a recursive call's local state.
        work = [(root, iter(graph.get(root, ())))]

        while work:
            v, successors = work[-1]
            descended = False
            for w in successors:
                if index_of[w] == -1:
                    # "Recurse" into w; v's iterator resumes afterwards.
                    index_of[w] = lowlink[w] = next_index
                    next_index += 1
                    scc_stack.append(w)
                    on_stack[w] = True
                    work.append((w, iter(graph.get(w, ()))))
                    descended = True
                    break
                if on_stack[w]:
                    lowlink[v] = min(lowlink[v], index_of[w])
            if descended:
                continue

            # All successors of v handled: finish v.
            work.pop()
            if lowlink[v] == index_of[v]:
                component: list[int] = []
                while True:
                    w = scc_stack.pop()
                    on_stack[w] = False
                    component.append(w)
                    if w == v:
                        break
                if len(component) > 1:
                    result.append(component)
            if work:
                # Equivalent of the post-recursion lowlink update in the caller.
                parent = work[-1][0]
                lowlink[parent] = min(lowlink[parent], lowlink[v])

    return result
115
+
116
+
117
def attach_no_binding_stmts(statements: list[Statement]) -> list[list[Statement]]:
    """Group statements so that those binding no names travel with a neighbor.

    - A statement with bindings closes a group, taking any waiting no-binding
      statements before it as a prefix.
    - A docstring that directly follows a simple assignment (with nothing
      waiting) joins that assignment's group.
    - Any other no-binding statement waits for the next statement with
      bindings. Leftovers at the end join the last group, or form a group of
      their own when there is no group at all.

    Returns:
        Groups in original order; each group moves as one unit.
    """
    grouped: list[list[Statement]] = []
    waiting: list[Statement] = []

    for stmt in statements:
        if stmt.bindings:
            waiting.append(stmt)
            grouped.append(waiting)
            waiting = []
            continue

        follows_constant = (
            bool(grouped)
            and not waiting
            and is_docstring(stmt.node)
            and is_simple_assignment(grouped[-1][-1].node)
        )
        if follows_constant:
            grouped[-1].append(stmt)
        else:
            waiting.append(stmt)

    if waiting:
        if grouped:
            grouped[-1].extend(waiting)
        else:
            grouped.append(waiting)

    return grouped
@@ -0,0 +1,184 @@
1
+ """Orchestrator: parse -> segment -> bindings -> refs -> graph -> sort -> violations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import libcst as cst
6
+
7
+ from flake8_stepdown.core.bindings import extract_bindings
8
+ from flake8_stepdown.core.graph import (
9
+ attach_no_binding_stmts,
10
+ build_normalized_graph,
11
+ find_sccs,
12
+ topological_sort,
13
+ )
14
+ from flake8_stepdown.core.parser import parse_source, segment
15
+ from flake8_stepdown.core.references import detect_future_annotations, extract_refs
16
+ from flake8_stepdown.rewriter import rewrite
17
+ from flake8_stepdown.types import OrderingResult, Statement, Violation
18
+
19
# "Nothing to report" result. NOTE(review): this is one shared instance whose
# ``violations`` / ``mutual_recursion_groups`` lists are mutable, so any code
# returning it must not let callers modify those lists in place.
_EMPTY_RESULT = OrderingResult(violations=[], reordered_source=None, mutual_recursion_groups=[])
20
+
21
+
22
def order_module(source: str, *, compute_rewrite: bool = True) -> OrderingResult:
    """Analyze and determine the correct ordering for a Python module.

    Args:
        source: Python source code.
        compute_rewrite: Whether to compute the reordered source (default True).
            Set to False when only violations are needed (e.g. flake8 plugin, check command).

    Returns:
        OrderingResult with violations and optionally reordered source.
        ``reordered_source`` is None when the order is already correct or
        ``compute_rewrite`` is False. Every call returns a new OrderingResult,
        so callers may modify its lists without affecting other calls.

    """
    if not source.strip():
        return _empty_ordering_result()

    module = parse_source(source)
    wrapper = cst.metadata.MetadataWrapper(module)
    positions = wrapper.resolve(cst.metadata.PositionProvider)
    # Positions are keyed by the nodes of ``wrapper.module`` (MetadataWrapper
    # copies the tree by default), so segment that same tree.
    seg = segment(wrapper.module)

    if not seg.interstitials:
        return _empty_ordering_result()

    # Extract bindings
    statements = extract_bindings(seg.interstitials, positions)

    # Extract references
    has_future = detect_future_annotations(seg.preamble)
    statements = extract_refs(statements, has_future_annotations=has_future)

    # Attach no-binding statements; each group moves as a single unit.
    groups = attach_no_binding_stmts(statements)

    # Build merged statements for graph (one per group)
    merged: list[Statement] = []
    for group in groups:
        # Merge bindings and refs from all statements in the group
        all_bindings: frozenset[str] = frozenset().union(*(s.bindings for s in group))
        all_immediate: frozenset[str] = frozenset().union(*(s.immediate_refs for s in group))
        all_deferred: frozenset[str] = frozenset().union(*(s.deferred_refs for s in group))
        merged.append(
            Statement(
                node=group[0].node,
                start_line=group[0].start_line,
                end_line=group[-1].end_line,
                bindings=all_bindings,
                immediate_refs=all_immediate,
                deferred_refs=all_deferred,
                is_overload_group=group[0].is_overload_group,
            ),
        )

    # Build graph and sort
    graph = build_normalized_graph(merged)
    num_nodes = len(merged)

    # Detect SCCs (mutual recursion)
    sccs = find_sccs(graph, num_nodes)

    # For SCCs: remove internal edges so the sort falls back to original order
    for scc in sccs:
        scc_set = set(scc)
        for node in scc:
            graph[node] = {s for s in graph[node] if s not in scc_set}

    # Topological sort
    new_order = topological_sort(graph, num_nodes)

    if new_order is None:
        # Remaining cycles after SCC removal — shouldn't happen but handle gracefully
        new_order = list(range(num_nodes))

    # Check if order changed
    changed = new_order != list(range(num_nodes))

    # Generate violations and mutual recursion info
    violations = _generate_violations(merged, new_order)
    mutual_recursion_groups = _extract_mutual_recursion_groups(merged, sccs)

    # Rewrite source if order changed and rewrite requested
    reordered_source = None
    if changed and compute_rewrite:
        # Groups are contiguous runs of ``statements``: compute where each
        # group starts, then expand the group order into a statement order.
        offsets: list[int] = []
        offset = 0
        for group in groups:
            offsets.append(offset)
            offset += len(group)

        expanded_order: list[int] = []
        for group_idx in new_order:
            base = offsets[group_idx]
            expanded_order.extend(range(base, base + len(groups[group_idx])))

        # NOTE(review): an overload group's Statement stores only the
        # implementation node (extract_bindings), so confirm the rewriter
        # still emits the @overload stubs preceding it.
        all_nodes = [s.node for group in groups for s in group]

        reordered_source = rewrite(
            seg.module,
            seg.preamble,
            all_nodes,
            seg.postamble,
            expanded_order,
        )

    return OrderingResult(
        violations=violations,
        reordered_source=reordered_source,
        mutual_recursion_groups=mutual_recursion_groups,
    )


def _empty_ordering_result() -> OrderingResult:
    """Return a new "nothing to report" result.

    Built per call instead of sharing one instance, so a caller that mutates
    the returned lists cannot corrupt the result seen by later calls.
    """
    return OrderingResult(violations=[], reordered_source=None, mutual_recursion_groups=[])
+ )
132
+
133
+
134
def _generate_violations(
    statements: list[Statement],
    new_order: list[int],
) -> list[Violation]:
    """Generate TDP001 violations from ordering differences.

    A statement is reported when some statement that originally came *after*
    it must come *before* it in ``new_order``, i.e. the pair is inverted. The
    reported dependency is the earliest such statement in the new order. Each
    statement gets at most one violation.

    Each statement is labelled by the alphabetically first name it binds
    (``"<unnamed>"`` if it binds none). Iterating a frozenset of strings
    would give a different label from run to run, because string hashing is
    randomized per process.

    Args:
        statements: One statement per movable group, in original order.
        new_order: Permutation of ``range(len(statements))`` giving the
            required order.

    Returns:
        Violations in original statement order.

    """
    violations: list[Violation] = []

    for orig_idx, stmt in enumerate(statements):
        for other_idx in new_order:
            if other_idx == orig_idx:
                # Reached this statement: nothing originally later overtakes it.
                break
            if other_idx < orig_idx:
                # Already before this statement originally: not an inversion.
                continue
            name = _violation_label(stmt)
            other_name = _violation_label(statements[other_idx])
            violations.append(
                Violation(
                    code="TDP001",
                    lineno=stmt.start_line,
                    col_offset=0,
                    name=name,
                    message=f"{name} should appear after {other_name}",
                    dependency=other_name,
                ),
            )
            break

    return violations


def _violation_label(stmt: Statement) -> str:
    """Return a deterministic display name for ``stmt``: its smallest binding."""
    return min(stmt.bindings, default="<unnamed>")
170
+
171
+
172
+ def _extract_mutual_recursion_groups(
173
+ statements: list[Statement],
174
+ sccs: list[list[int]],
175
+ ) -> list[list[str]]:
176
+ """Extract mutual recursion groups from SCCs as lists of function names."""
177
+ groups: list[list[str]] = []
178
+ for scc in sccs:
179
+ names = sorted(
180
+ {n for idx in scc for n in statements[idx].bindings},
181
+ )
182
+ if names:
183
+ groups.append(names)
184
+ return groups