gabion 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3173 @@
1
+ #!/usr/bin/env python3
2
+ """Infer forwarding-based parameter bundles and propagate them across calls.
3
+
4
+ This script performs a two-stage analysis:
5
+ 1) Local grouping: within a function, parameters used *only* as direct
6
+ call arguments are grouped by identical forwarding signatures.
7
+ 2) Propagation: if a function f calls g, and g has local bundles, then
8
+ f's parameters passed into g's bundled positions are linked as a
9
+ candidate bundle. This is iterated to a fixed point.
10
+
11
+ The goal is to surface "dataflow grammar" candidates for config dataclasses.
12
+
13
+ It can also emit a DOT graph (see --dot) so downstream tooling can render
14
+ bundle candidates as a dependency graph.
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import ast
20
+ import json
21
+ import os
22
+ import sys
23
+ from collections import defaultdict, deque
24
+ from dataclasses import dataclass, field, replace
25
+ from pathlib import Path
26
+ from typing import Iterable, Iterator
27
+ import re
28
+
29
+ from gabion.analysis.visitors import ImportVisitor, ParentAnnotator, UseVisitor
30
+ from gabion.config import dataflow_defaults, merge_payload
31
+ from gabion.schema import SynthesisResponse
32
+ from gabion.synthesis import NamingContext, SynthesisConfig, Synthesizer
33
+ from gabion.synthesis.merge import merge_bundles
34
+ from gabion.synthesis.schedule import topological_schedule
35
+
36
@dataclass
class ParamUse:
    """How a single parameter is used inside its defining function."""

    # (callee name, argument slot) pairs where the param is passed directly.
    direct_forward: set[tuple[str, str]]
    # True once the param is used in any way other than direct forwarding.
    non_forward: bool
    # Names currently aliasing the parameter; starts as {param_name}.
    current_aliases: set[str]
41
+
42
+
43
@dataclass(frozen=True)
class CallArgs:
    """One recorded call site, mapping callee argument slots to caller data."""

    # Callee name as written (self./cls. calls normalized to Class.method).
    callee: str
    # Positional index (stringified) -> caller parameter name.
    pos_map: dict[str, str]
    # Keyword name -> caller parameter name.
    kw_map: dict[str, str]
    # Positional slots fed by constants — presumably index -> const repr; verify in UseVisitor.
    const_pos: dict[str, str]
    # Keyword slots fed by constants — presumably kw -> const repr; verify in UseVisitor.
    const_kw: dict[str, str]
    # Positional slots fed by non-constant, non-parameter expressions.
    non_const_pos: set[str]
    # Keyword slots fed by non-constant, non-parameter expressions.
    non_const_kw: set[str]
    # *args unpackings: (position index, caller parameter name).
    star_pos: list[tuple[int, str]]
    # **kwargs unpackings: caller parameter names.
    star_kw: list[str]
    # Whether this call site lives in test code.
    is_test: bool
55
+
56
+
57
+ @dataclass
58
+ class SymbolTable:
59
+ imports: dict[tuple[str, str], str] = field(default_factory=dict)
60
+ internal_roots: set[str] = field(default_factory=set)
61
+ external_filter: bool = True
62
+ star_imports: dict[str, set[str]] = field(default_factory=dict)
63
+ module_exports: dict[str, set[str]] = field(default_factory=dict)
64
+ module_export_map: dict[str, dict[str, str]] = field(default_factory=dict)
65
+
66
+ def resolve(self, current_module: str, name: str) -> str | None:
67
+ if (current_module, name) in self.imports:
68
+ fqn = self.imports[(current_module, name)]
69
+ if self.external_filter:
70
+ root = fqn.split(".")[0]
71
+ if root not in self.internal_roots:
72
+ return None
73
+ return fqn
74
+ return f"{current_module}.{name}"
75
+
76
+ def resolve_star(self, current_module: str, name: str) -> str | None:
77
+ candidates = self.star_imports.get(current_module, set())
78
+ if not candidates:
79
+ return None
80
+ for module in sorted(candidates):
81
+ exports = self.module_exports.get(module)
82
+ if exports is None or name not in exports:
83
+ continue
84
+ export_map = self.module_export_map.get(module, {})
85
+ mapped = export_map.get(name)
86
+ if mapped:
87
+ if self.external_filter and mapped:
88
+ root = mapped.split(".")[0]
89
+ if root not in self.internal_roots:
90
+ continue
91
+ return mapped
92
+ if self.external_filter and module:
93
+ root = module.split(".")[0]
94
+ if root not in self.internal_roots:
95
+ continue
96
+ if module:
97
+ return f"{module}.{name}"
98
+ return name
99
+ return None
100
+
101
+
102
@dataclass
class AuditConfig:
    """Knobs controlling which files and parameters the audit considers."""

    # Root used to derive module names; None keeps paths as-is.
    project_root: Path | None = None
    # Directory names excluded from traversal (matched against path parts).
    exclude_dirs: set[str] = field(default_factory=set)
    # Parameter names excluded from analysis everywhere.
    ignore_params: set[str] = field(default_factory=set)
    # Drop resolutions pointing outside the project's internal roots.
    external_filter: bool = True
    # Propagation strictness; "low" additionally maps through */** unpacking.
    strictness: str = "high"
    # Decorators that do not hide a function's signature; None/empty = all do.
    transparent_decorators: set[str] | None = None

    def is_ignored_path(self, path: Path) -> bool:
        """True when any component of *path* matches an excluded dir name."""
        parts = set(path.parts)
        return bool(self.exclude_dirs & parts)
114
+
115
+
116
+ def _call_context(node: ast.AST, parents: dict[ast.AST, ast.AST]) -> tuple[ast.Call | None, bool]:
117
+ child = node
118
+ parent = parents.get(child)
119
+ while parent is not None:
120
+ if isinstance(parent, ast.Call):
121
+ if child in parent.args:
122
+ return parent, True
123
+ for kw in parent.keywords:
124
+ if child is kw or child is kw.value:
125
+ return parent, True
126
+ return parent, False
127
+ child = parent
128
+ parent = parents.get(child)
129
+ return None, False
130
+
131
+
132
@dataclass
class AnalysisResult:
    """Aggregated audit output across all analyzed files."""

    # file -> function key -> candidate parameter bundles.
    groups_by_path: dict[Path, dict[str, list[set[str]]]]
    # file -> function key -> param -> (start_line, start_col, end_line, end_col).
    param_spans_by_path: dict[Path, dict[str, dict[str, tuple[int, int, int, int]]]]
    # Human-readable findings, one string per entry.
    type_suggestions: list[str]
    type_ambiguities: list[str]
    constant_smells: list[str]
    unused_arg_smells: list[str]
140
+
141
+
142
+ def _callee_name(call: ast.Call) -> str:
143
+ try:
144
+ return ast.unparse(call.func)
145
+ except Exception:
146
+ return "<call>"
147
+
148
+
149
+ def _normalize_callee(name: str, class_name: str | None) -> str:
150
+ if not class_name:
151
+ return name
152
+ if name.startswith("self.") or name.startswith("cls."):
153
+ parts = name.split(".")
154
+ if len(parts) == 2:
155
+ return f"{class_name}.{parts[1]}"
156
+ return name
157
+
158
+
159
def _iter_paths(paths: Iterable[str], config: AuditConfig) -> list[Path]:
    """Expand files/directories into a filtered, ordered list of paths.

    Directories are searched recursively for ``*.py``; excluded directories
    (per config) are skipped in both cases.
    """
    collected: list[Path] = []
    for raw in paths:
        entry = Path(raw)
        candidates = sorted(entry.rglob("*.py")) if entry.is_dir() else [entry]
        collected.extend(c for c in candidates if not config.is_ignored_path(c))
    return collected
173
+
174
+
175
+ def _collect_functions(tree: ast.AST):
176
+ funcs = []
177
+ for node in ast.walk(tree):
178
+ if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
179
+ funcs.append(node)
180
+ return funcs
181
+
182
+
183
+ def _decorator_name(node: ast.AST) -> str | None:
184
+ if isinstance(node, ast.Name):
185
+ return node.id
186
+ if isinstance(node, ast.Attribute):
187
+ parts: list[str] = []
188
+ current: ast.AST = node
189
+ while isinstance(current, ast.Attribute):
190
+ parts.append(current.attr)
191
+ current = current.value
192
+ if isinstance(current, ast.Name):
193
+ parts.append(current.id)
194
+ return ".".join(reversed(parts))
195
+ return None
196
+ if isinstance(node, ast.Call):
197
+ return _decorator_name(node.func)
198
+ return None
199
+
200
+
201
+ def _decorator_matches(name: str, allowlist: set[str]) -> bool:
202
+ if name in allowlist:
203
+ return True
204
+ if "." in name and name.split(".")[-1] in allowlist:
205
+ return True
206
+ return False
207
+
208
+
209
def _decorators_transparent(
    fn: ast.FunctionDef | ast.AsyncFunctionDef,
    transparent_decorators: set[str] | None,
) -> bool:
    """True when every decorator on *fn* is in the transparency allowlist.

    No decorators, or an empty/None allowlist, counts as transparent.
    """
    if not fn.decorator_list or not transparent_decorators:
        return True
    for deco in fn.decorator_list:
        name = _decorator_name(deco)
        # Unrenderable decorators are conservatively treated as opaque.
        if not name or not _decorator_matches(name, transparent_decorators):
            return False
    return True
224
+
225
+
226
def _collect_local_class_bases(
    tree: ast.AST, parents: dict[ast.AST, ast.AST]
) -> dict[str, list[str]]:
    """Map each class's (class-nesting qualified) name to its base identifiers."""
    class_bases: dict[str, list[str]] = {}
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        # Qualify by enclosing classes only; function nesting is ignored here.
        scopes = _enclosing_class_scopes(node, parents)
        qual_parts = list(scopes)
        qual_parts.append(node.name)
        qual = ".".join(qual_parts)
        bases: list[str] = []
        for base in node.bases:
            base_name = _base_identifier(base)
            if base_name:
                bases.append(base_name)
        class_bases[qual] = bases
    return class_bases
244
+
245
+
246
+ def _local_class_name(base: str, class_bases: dict[str, list[str]]) -> str | None:
247
+ if base in class_bases:
248
+ return base
249
+ if "." in base:
250
+ tail = base.split(".")[-1]
251
+ if tail in class_bases:
252
+ return tail
253
+ return None
254
+
255
+
256
def _resolve_local_method_in_hierarchy(
    class_name: str,
    method: str,
    *,
    class_bases: dict[str, list[str]],
    local_functions: set[str],
    seen: set[str],
) -> str | None:
    """Depth-first search for ``Class.method`` along locally known bases.

    ``seen`` guards against inheritance cycles; the first hit wins.
    """
    if class_name in seen:
        return None
    seen.add(class_name)
    direct = f"{class_name}.{method}"
    if direct in local_functions:
        return direct
    for base in class_bases.get(class_name, []):
        local_base = _local_class_name(base, class_bases)
        if local_base is None:
            continue
        found = _resolve_local_method_in_hierarchy(
            local_base,
            method,
            class_bases=class_bases,
            local_functions=local_functions,
            seen=seen,
        )
        if found is not None:
            return found
    return None
284
+
285
+
286
+ def _param_names(
287
+ fn: ast.FunctionDef | ast.AsyncFunctionDef,
288
+ ignore_params: set[str] | None = None,
289
+ ) -> list[str]:
290
+ args = (
291
+ fn.args.posonlyargs + fn.args.args + fn.args.kwonlyargs
292
+ )
293
+ names = [a.arg for a in args]
294
+ if fn.args.vararg:
295
+ names.append(fn.args.vararg.arg)
296
+ if fn.args.kwarg:
297
+ names.append(fn.args.kwarg.arg)
298
+ if names and names[0] in {"self", "cls"}:
299
+ names = names[1:]
300
+ if ignore_params:
301
+ names = [name for name in names if name not in ignore_params]
302
+ return names
303
+
304
+
305
+ def _node_span(node: ast.AST) -> tuple[int, int, int, int] | None:
306
+ if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
307
+ return None
308
+ start_line = max(getattr(node, "lineno", 1) - 1, 0)
309
+ start_col = max(getattr(node, "col_offset", 0), 0)
310
+ end_line = max(getattr(node, "end_lineno", getattr(node, "lineno", 1)) - 1, 0)
311
+ end_col = getattr(node, "end_col_offset", start_col + 1)
312
+ if end_line == start_line and end_col <= start_col:
313
+ end_col = start_col + 1
314
+ return (start_line, start_col, end_line, end_col)
315
+
316
+
317
def _param_spans(
    fn: ast.FunctionDef | ast.AsyncFunctionDef,
    ignore_params: set[str] | None = None,
) -> dict[str, tuple[int, int, int, int]]:
    """Map each parameter name to its 0-based source span.

    Skips a leading self/cls and any ignored params; *args/**kwargs are
    included when present and not ignored.
    """
    spans: dict[str, tuple[int, int, int, int]] = {}
    args = fn.args.posonlyargs + fn.args.args + fn.args.kwonlyargs
    names = [a.arg for a in args]
    if names and names[0] in {"self", "cls"}:
        args = args[1:]
        names = names[1:]
    for arg in args:
        if ignore_params and arg.arg in ignore_params:
            continue
        span = _node_span(arg)
        if span is not None:
            spans[arg.arg] = span
    if fn.args.vararg:
        name = fn.args.vararg.arg
        if not ignore_params or name not in ignore_params:
            span = _node_span(fn.args.vararg)
            if span is not None:
                spans[name] = span
    if fn.args.kwarg:
        name = fn.args.kwarg.arg
        if not ignore_params or name not in ignore_params:
            span = _node_span(fn.args.kwarg)
            if span is not None:
                spans[name] = span
    return spans
346
+
347
+
348
+ def _function_key(scope: Iterable[str], name: str) -> str:
349
+ parts = list(scope)
350
+ parts.append(name)
351
+ if not parts:
352
+ return name
353
+ return ".".join(parts)
354
+
355
+
356
+ def _enclosing_class(
357
+ node: ast.AST, parents: dict[ast.AST, ast.AST]
358
+ ) -> str | None:
359
+ current = parents.get(node)
360
+ while current is not None:
361
+ if isinstance(current, ast.ClassDef):
362
+ return current.name
363
+ current = parents.get(current)
364
+ return None
365
+
366
+
367
+ def _enclosing_scopes(
368
+ node: ast.AST, parents: dict[ast.AST, ast.AST]
369
+ ) -> list[str]:
370
+ scopes: list[str] = []
371
+ current = parents.get(node)
372
+ while current is not None:
373
+ if isinstance(current, ast.ClassDef):
374
+ scopes.append(current.name)
375
+ elif isinstance(current, (ast.FunctionDef, ast.AsyncFunctionDef)):
376
+ scopes.append(current.name)
377
+ current = parents.get(current)
378
+ return list(reversed(scopes))
379
+
380
+
381
+ def _enclosing_class_scopes(
382
+ node: ast.AST, parents: dict[ast.AST, ast.AST]
383
+ ) -> list[str]:
384
+ scopes: list[str] = []
385
+ current = parents.get(node)
386
+ while current is not None:
387
+ if isinstance(current, ast.ClassDef):
388
+ scopes.append(current.name)
389
+ current = parents.get(current)
390
+ return list(reversed(scopes))
391
+
392
+
393
+ def _enclosing_function_scopes(
394
+ node: ast.AST, parents: dict[ast.AST, ast.AST]
395
+ ) -> list[str]:
396
+ scopes: list[str] = []
397
+ current = parents.get(node)
398
+ while current is not None:
399
+ if isinstance(current, (ast.FunctionDef, ast.AsyncFunctionDef)):
400
+ scopes.append(current.name)
401
+ current = parents.get(current)
402
+ return list(reversed(scopes))
403
+
404
+
405
+ def _param_annotations(
406
+ fn: ast.FunctionDef | ast.AsyncFunctionDef,
407
+ ignore_params: set[str] | None = None,
408
+ ) -> dict[str, str | None]:
409
+ args = fn.args.posonlyargs + fn.args.args + fn.args.kwonlyargs
410
+ names = [a.arg for a in args]
411
+ annots: dict[str, str | None] = {}
412
+ for name, arg in zip(names, args):
413
+ if arg.annotation is None:
414
+ annots[name] = None
415
+ else:
416
+ try:
417
+ annots[name] = ast.unparse(arg.annotation)
418
+ except Exception:
419
+ annots[name] = None
420
+ if fn.args.vararg:
421
+ annots[fn.args.vararg.arg] = None
422
+ if fn.args.kwarg:
423
+ annots[fn.args.kwarg.arg] = None
424
+ if names and names[0] in {"self", "cls"}:
425
+ annots.pop(names[0], None)
426
+ if ignore_params:
427
+ for name in list(annots.keys()):
428
+ if name in ignore_params:
429
+ annots.pop(name, None)
430
+ return annots
431
+
432
+
433
+ def _const_repr(node: ast.AST) -> str | None:
434
+ if isinstance(node, ast.Constant):
435
+ return repr(node.value)
436
+ if isinstance(node, ast.UnaryOp) and isinstance(
437
+ node.op, (ast.USub, ast.UAdd)
438
+ ) and isinstance(node.operand, ast.Constant):
439
+ try:
440
+ return ast.unparse(node)
441
+ except Exception:
442
+ return None
443
+ if isinstance(node, ast.Attribute):
444
+ if node.attr.isupper():
445
+ try:
446
+ return ast.unparse(node)
447
+ except Exception:
448
+ return None
449
+ return None
450
+ return None
451
+
452
+
453
+ def _type_from_const_repr(value: str) -> str | None:
454
+ try:
455
+ literal = ast.literal_eval(value)
456
+ except Exception:
457
+ return None
458
+ if literal is None:
459
+ return "None"
460
+ if isinstance(literal, bool):
461
+ return "bool"
462
+ if isinstance(literal, int):
463
+ return "int"
464
+ if isinstance(literal, float):
465
+ return "float"
466
+ if isinstance(literal, complex):
467
+ return "complex"
468
+ if isinstance(literal, str):
469
+ return "str"
470
+ if isinstance(literal, bytes):
471
+ return "bytes"
472
+ if isinstance(literal, list):
473
+ return "list"
474
+ if isinstance(literal, tuple):
475
+ return "tuple"
476
+ if isinstance(literal, set):
477
+ return "set"
478
+ if isinstance(literal, dict):
479
+ return "dict"
480
+ return None
481
+
482
+
483
+ def _is_test_path(path: Path) -> bool:
484
+ if "tests" in path.parts:
485
+ return True
486
+ return path.name.startswith("test_")
487
+
488
+
489
def _analyze_function(
    fn: ast.FunctionDef | ast.AsyncFunctionDef,
    parents: dict[ast.AST, ast.AST],
    *,
    is_test: bool,
    ignore_params: set[str] | None = None,
    strictness: str = "high",
    class_name: str | None = None,
) -> tuple[dict[str, ParamUse], list[CallArgs]]:
    """Collect per-parameter usage and call-site records for one function.

    Delegates AST traversal to UseVisitor, which fills ``use_map`` and
    ``call_args`` in place.
    """
    params = _param_names(fn, ignore_params)
    # Each parameter starts unused, aliased only by itself.
    use_map = {p: ParamUse(set(), False, {p}) for p in params}
    alias_to_param: dict[str, str] = {p: p for p in params}
    call_args: list[CallArgs] = []

    visitor = UseVisitor(
        parents=parents,
        use_map=use_map,
        call_args=call_args,
        alias_to_param=alias_to_param,
        is_test=is_test,
        strictness=strictness,
        const_repr=_const_repr,
        # Method bodies get self./cls. calls rewritten to Class.method.
        callee_name=lambda call: _normalize_callee(_callee_name(call), class_name),
        call_args_factory=CallArgs,
        call_context=_call_context,
    )
    visitor.visit(fn)
    return use_map, call_args
517
+
518
+
519
+ def _unused_params(use_map: dict[str, ParamUse]) -> set[str]:
520
+ unused: set[str] = set()
521
+ for name, info in use_map.items():
522
+ if info.non_forward:
523
+ continue
524
+ if info.direct_forward:
525
+ continue
526
+ unused.add(name)
527
+ return unused
528
+
529
+
530
+ def _group_by_signature(use_map: dict[str, ParamUse]) -> list[set[str]]:
531
+ sig_map: dict[tuple[tuple[str, str], ...], list[str]] = defaultdict(list)
532
+ for name, info in use_map.items():
533
+ if info.non_forward:
534
+ continue
535
+ sig = tuple(sorted(info.direct_forward))
536
+ sig_map[sig].append(name)
537
+ groups = [set(names) for names in sig_map.values() if len(names) > 1]
538
+ return groups
539
+
540
+
541
+ def _union_groups(groups: list[set[str]]) -> list[set[str]]:
542
+ changed = True
543
+ while changed:
544
+ changed = False
545
+ out = []
546
+ while groups:
547
+ base = groups.pop()
548
+ merged = True
549
+ while merged:
550
+ merged = False
551
+ for i, other in enumerate(groups):
552
+ if base & other:
553
+ base |= other
554
+ groups.pop(i)
555
+ merged = True
556
+ changed = True
557
+ break
558
+ out.append(base)
559
+ groups = out
560
+ return groups
561
+
562
+
563
def _propagate_groups(
    call_args: list[CallArgs],
    callee_groups: dict[str, list[set[str]]],
    callee_param_orders: dict[str, list[str]],
    strictness: str,
    opaque_callees: set[str] | None = None,
) -> list[set[str]]:
    """Lift callee parameter bundles back onto caller parameters.

    For each call into a function with known bundles, translate the bundled
    callee parameters to the caller parameters supplying them; projections
    that still contain more than one caller parameter become new candidate
    groups for the caller.
    """
    groups: list[set[str]] = []
    for call in call_args:
        # Calls through non-transparent decorators are not propagated.
        if opaque_callees and call.callee in opaque_callees:
            continue
        if call.callee not in callee_groups:
            continue
        callee_params = callee_param_orders[call.callee]
        # Build mapping from callee param to caller param.
        callee_to_caller: dict[str, str] = {}
        for idx, pname in enumerate(callee_params):
            key = str(idx)
            if key in call.pos_map:
                callee_to_caller[pname] = call.pos_map[key]
        for kw, caller_name in call.kw_map.items():
            callee_to_caller[kw] = caller_name
        if strictness == "low":
            # Low strictness: a single *args/**kwargs unpacking is assumed
            # to cover every otherwise-unmapped callee parameter.
            mapped = set(callee_to_caller.keys())
            remaining = [p for p in callee_params if p not in mapped]
            if len(call.star_pos) == 1:
                _, star_param = call.star_pos[0]
                for param in remaining:
                    callee_to_caller.setdefault(param, star_param)
            if len(call.star_kw) == 1:
                star_param = call.star_kw[0]
                for param in remaining:
                    callee_to_caller.setdefault(param, star_param)
        for group in callee_groups[call.callee]:
            mapped = {callee_to_caller.get(p) for p in group}
            mapped.discard(None)
            # A projection collapsing to <=1 caller param is no bundle.
            if len(mapped) > 1:
                groups.append(set(mapped))
    return groups
602
+
603
+
604
def analyze_file(
    path: Path,
    recursive: bool = True,
    *,
    config: AuditConfig | None = None,
) -> tuple[dict[str, list[set[str]]], dict[str, dict[str, tuple[int, int, int, int]]]]:
    """Analyze one file: local bundle grouping plus in-file propagation.

    Returns (function key -> bundles, function key -> param spans). When
    ``recursive`` is False only the local grouping pass runs.
    """
    if config is None:
        config = AuditConfig()
    tree = ast.parse(path.read_text())
    parent = ParentAnnotator()
    parent.visit(tree)
    parents = parent.parents
    is_test = _is_test_path(path)

    # Pass 1: per-function usage, call sites, and scope metadata.
    funcs = _collect_functions(tree)
    fn_param_orders: dict[str, list[str]] = {}
    fn_param_spans: dict[str, dict[str, tuple[int, int, int, int]]] = {}
    fn_use = {}
    fn_calls = {}
    fn_names: dict[str, str] = {}
    fn_lexical_scopes: dict[str, tuple[str, ...]] = {}
    fn_class_names: dict[str, str | None] = {}
    opaque_callees: set[str] = set()
    for f in funcs:
        class_name = _enclosing_class(f, parents)
        scopes = _enclosing_scopes(f, parents)
        lexical_scopes = _enclosing_function_scopes(f, parents)
        fn_key = _function_key(scopes, f.name)
        # Non-allowlisted decorators may rewrite the signature: keep but
        # exclude from propagation.
        if not _decorators_transparent(f, config.transparent_decorators):
            opaque_callees.add(fn_key)
        use_map, call_args = _analyze_function(
            f,
            parents,
            is_test=is_test,
            ignore_params=config.ignore_params,
            strictness=config.strictness,
            class_name=class_name,
        )
        fn_use[fn_key] = use_map
        fn_calls[fn_key] = call_args
        fn_param_orders[fn_key] = _param_names(f, config.ignore_params)
        fn_param_spans[fn_key] = _param_spans(f, config.ignore_params)
        fn_names[fn_key] = f.name
        fn_lexical_scopes[fn_key] = tuple(lexical_scopes)
        fn_class_names[fn_key] = class_name

    local_by_name: dict[str, list[str]] = defaultdict(list)
    for key, name in fn_names.items():
        local_by_name[name].append(key)

    def _resolve_local_callee(callee: str, caller_key: str) -> str | None:
        # Resolve a bare name to a unique local function, walking outward
        # through the caller's lexical scopes, then module globals.
        if "." in callee:
            return None
        candidates = local_by_name.get(callee, [])
        if not candidates:
            return None
        effective_scope = list(fn_lexical_scopes.get(caller_key, ())) + [fn_names[caller_key]]
        while True:
            scoped = [
                key
                for key in candidates
                if fn_lexical_scopes.get(key, ()) == tuple(effective_scope)
                and not (fn_class_names.get(key) and not fn_lexical_scopes.get(key))
            ]
            if len(scoped) == 1:
                return scoped[0]
            if len(scoped) > 1:
                return None  # ambiguous at this scope level
            if not effective_scope:
                break
            effective_scope = effective_scope[:-1]
        globals_only = [
            key
            for key in candidates
            if not fn_lexical_scopes.get(key)
            and not (fn_class_names.get(key) and not fn_lexical_scopes.get(key))
        ]
        if len(globals_only) == 1:
            return globals_only[0]
        return None

    # Rewrite call records to point at resolved local function keys.
    for caller_key, calls in list(fn_calls.items()):
        resolved_calls: list[CallArgs] = []
        for call in calls:
            resolved = _resolve_local_callee(call.callee, caller_key)
            if resolved:
                resolved_calls.append(replace(call, callee=resolved))
            else:
                resolved_calls.append(call)
        fn_calls[caller_key] = resolved_calls

    # Resolve Class.method calls through the local inheritance hierarchy.
    class_bases = _collect_local_class_bases(tree, parents)
    if class_bases:
        local_functions = set(fn_use.keys())

        def _resolve_local_method(callee: str) -> str | None:
            if "." not in callee:
                return None
            class_part, method = callee.rsplit(".", 1)
            return _resolve_local_method_in_hierarchy(
                class_part,
                method,
                class_bases=class_bases,
                local_functions=local_functions,
                seen=set(),
            )

        for caller_key, calls in list(fn_calls.items()):
            resolved_calls = []
            for call in calls:
                if "." in call.callee:
                    resolved = _resolve_local_method(call.callee)
                    if resolved and resolved != call.callee:
                        resolved_calls.append(replace(call, callee=resolved))
                        continue
                resolved_calls.append(call)
            fn_calls[caller_key] = resolved_calls

    # Pass 2: local grouping by identical forwarding signatures.
    groups_by_fn = {fn: _group_by_signature(use_map) for fn, use_map in fn_use.items()}

    if not recursive:
        return groups_by_fn, fn_param_spans

    # Pass 3: propagate callee bundles to callers until a fixed point.
    changed = True
    while changed:
        changed = False
        for fn in fn_use:
            propagated = _propagate_groups(
                fn_calls[fn],
                groups_by_fn,
                fn_param_orders,
                config.strictness,
                opaque_callees,
            )
            if not propagated:
                continue
            combined = _union_groups(groups_by_fn.get(fn, []) + propagated)
            if combined != groups_by_fn.get(fn, []):
                groups_by_fn[fn] = combined
                changed = True
    return groups_by_fn, fn_param_spans
745
+
746
+
747
+ def _callee_key(name: str) -> str:
748
+ if not name:
749
+ return name
750
+ return name.split(".")[-1]
751
+
752
+
753
+ def _is_broad_type(annot: str | None) -> bool:
754
+ if annot is None:
755
+ return True
756
+ base = annot.replace("typing.", "")
757
+ return base in {"Any", "object"}
758
+
759
+
760
+ _NONE_TYPES = {"None", "NoneType", "type(None)"}
761
+
762
+
763
+ def _split_top_level(value: str, sep: str) -> list[str]:
764
+ parts: list[str] = []
765
+ buf: list[str] = []
766
+ depth = 0
767
+ for ch in value:
768
+ if ch in "[({":
769
+ depth += 1
770
+ elif ch in "])}":
771
+ depth = max(depth - 1, 0)
772
+ if ch == sep and depth == 0:
773
+ part = "".join(buf).strip()
774
+ if part:
775
+ parts.append(part)
776
+ buf = []
777
+ continue
778
+ buf.append(ch)
779
+ tail = "".join(buf).strip()
780
+ if tail:
781
+ parts.append(tail)
782
+ return parts
783
+
784
+
785
def _expand_type_hint(hint: str) -> set[str]:
    """Expand a hint into its set of union members (Optional adds "None").

    Fixes a bug where a top-level ``|`` union whose last member ends in
    ``]`` (e.g. ``Optional[int] | list[str]``) matched the ``Optional[...]``
    prefix/suffix test first and was split at the wrong bracket. Top-level
    ``|`` is now handled before the Optional/Union forms, and members are
    expanded recursively so nested unions flatten.
    """
    hint = hint.strip()
    if not hint:
        return set()
    # Top-level "A | B" must win over the Optional[...] / Union[...] checks.
    pipe_parts = _split_top_level(hint, "|")
    if len(pipe_parts) > 1:
        members: set[str] = set()
        for part in pipe_parts:
            members |= _expand_type_hint(part)
        return members
    if hint.startswith("Optional[") and hint.endswith("]"):
        inner = hint[len("Optional[") : -1]
        members = set()
        for part in _split_top_level(inner, ","):
            members |= _expand_type_hint(part)
        return members | {"None"}
    if hint.startswith("Union[") and hint.endswith("]"):
        inner = hint[len("Union[") : -1]
        members = set()
        for part in _split_top_level(inner, ","):
            members |= _expand_type_hint(part)
        return members
    return {hint}
798
+
799
+
800
+ def _strip_type(value: str) -> str:
801
+ return value.strip()
802
+
803
+
804
def _combine_type_hints(types: set[str]) -> tuple[str, bool]:
    """Merge observed annotation strings into a single combined hint.

    Returns ``(hint, conflicted)`` where ``conflicted`` is True when the
    observed non-None member sets genuinely disagree (``Optional[int]`` and
    ``int | None`` agree; ``int`` and ``str`` do not).
    """
    # Normalize each hint to its sorted non-None member tuple so spelling
    # differences don't count as conflicts.
    normalized_sets = []
    for hint in types:
        expanded = _expand_type_hint(hint)
        normalized_sets.append(
            tuple(sorted(t for t in expanded if t not in _NONE_TYPES))
        )
    unique_normalized = {norm for norm in normalized_sets if norm}
    expanded: set[str] = set()
    for hint in types:
        expanded.update(_expand_type_hint(hint))
    none_types = {t for t in expanded if t in _NONE_TYPES}
    expanded -= none_types
    if not expanded:
        # Only None-ish hints (or none at all) were observed.
        return "Any", bool(types)
    sorted_types = sorted(expanded)
    if len(sorted_types) == 1:
        base = sorted_types[0]
        if none_types:
            conflicted = len(unique_normalized) > 1
            return f"Optional[{base}]", conflicted
        return base, len(unique_normalized) > 1
    union = f"Union[{', '.join(sorted_types)}]"
    if none_types:
        return f"Optional[{union}]", len(unique_normalized) > 1
    return union, len(unique_normalized) > 1
830
+
831
+
832
@dataclass
class FunctionInfo:
    """Cross-file record of one analyzed function definition."""

    # Bare function name.
    name: str
    # Qualified name (scope chain + name).
    qual: str
    # Source file the function was parsed from.
    path: Path
    # Ordered parameter names (self/cls and ignored params removed).
    params: list[str]
    # Param name -> annotation text, or None when unannotated.
    annots: dict[str, str | None]
    # Outgoing call sites recorded for this function.
    calls: list[CallArgs]
    # Params with no recorded use at all.
    unused_params: set[str]
    # False when a non-allowlisted decorator may rewrite the signature.
    transparent: bool = True
    # Enclosing class name, if the function is a method.
    class_name: str | None = None
    # Enclosing class+function scope chain, outermost first.
    scope: tuple[str, ...] = ()
    # Enclosing function-only scope chain, outermost first.
    lexical_scope: tuple[str, ...] = ()
845
+
846
+
847
@dataclass
class ClassInfo:
    """Cross-file record of one class definition."""

    # Fully qualified class name (module + nesting + name).
    qual: str
    # Module the class is defined in.
    module: str
    # Base-class identifiers as written in the source.
    bases: list[str]
    # Names of methods defined directly in the class body.
    methods: set[str]
853
+
854
+
855
+ def _module_name(path: Path, project_root: Path | None = None) -> str:
856
+ rel = path.with_suffix("")
857
+ if project_root is not None:
858
+ try:
859
+ rel = rel.relative_to(project_root)
860
+ except ValueError:
861
+ pass
862
+ parts = list(rel.parts)
863
+ if parts and parts[0] == "src":
864
+ parts = parts[1:]
865
+ return ".".join(parts)
866
+
867
+
868
+ def _string_list(node: ast.AST) -> list[str] | None:
869
+ if isinstance(node, (ast.List, ast.Tuple)):
870
+ values: list[str] = []
871
+ for elt in node.elts:
872
+ if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
873
+ values.append(elt.value)
874
+ else:
875
+ return None
876
+ return values
877
+ return None
878
+
879
+
880
+ def _base_identifier(node: ast.AST) -> str | None:
881
+ if isinstance(node, ast.Name):
882
+ return node.id
883
+ if isinstance(node, ast.Attribute):
884
+ try:
885
+ return ast.unparse(node)
886
+ except Exception:
887
+ return None
888
+ if isinstance(node, ast.Subscript):
889
+ return _base_identifier(node.value)
890
+ if isinstance(node, ast.Call):
891
+ return _base_identifier(node.func)
892
+ return None
893
+
894
+
895
def _collect_module_exports(
    tree: ast.AST,
    *,
    module_name: str,
    import_map: dict[str, str],
) -> tuple[set[str], dict[str, str]]:
    """Compute a module's exported names and their fully qualified targets.

    Honors ``__all__`` (plain, annotated, and ``+=`` assignments of string
    lists); otherwise exports all public top-level definitions and imports.
    Returns (export names, export name -> FQN map).
    """
    # First pass: gather an explicit __all__, if the module declares one.
    explicit_all: list[str] | None = None
    for stmt in getattr(tree, "body", []):
        if isinstance(stmt, ast.Assign):
            targets = stmt.targets
            if any(isinstance(t, ast.Name) and t.id == "__all__" for t in targets):
                values = _string_list(stmt.value)
                if values is not None:
                    explicit_all = list(values)
        elif isinstance(stmt, ast.AnnAssign):
            if isinstance(stmt.target, ast.Name) and stmt.target.id == "__all__":
                values = _string_list(stmt.value) if stmt.value is not None else None
                if values is not None:
                    explicit_all = list(values)
        elif isinstance(stmt, ast.AugAssign):
            # __all__ += [...] extends whatever was collected so far.
            if (
                isinstance(stmt.target, ast.Name)
                and stmt.target.id == "__all__"
                and isinstance(stmt.op, ast.Add)
            ):
                values = _string_list(stmt.value)
                if values is not None:
                    if explicit_all is None:
                        explicit_all = []
                    explicit_all.extend(values)

    # Second pass: public top-level definitions and assignments.
    local_defs: set[str] = set()
    for stmt in getattr(tree, "body", []):
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            if not stmt.name.startswith("_"):
                local_defs.add(stmt.name)
        elif isinstance(stmt, ast.Assign):
            for target in stmt.targets:
                if isinstance(target, ast.Name) and not target.id.startswith("_"):
                    local_defs.add(target.id)
        elif isinstance(stmt, ast.AnnAssign):
            if isinstance(stmt.target, ast.Name) and not stmt.target.id.startswith("_"):
                local_defs.add(stmt.target.id)

    if explicit_all is not None:
        export_names = set(explicit_all)
    else:
        # Without __all__, public imports re-export too.
        export_names = set(local_defs) | {
            name for name in import_map.keys() if not name.startswith("_")
        }
    export_names = {name for name in export_names if not name.startswith("_")}

    # Map each export to its FQN: imported names keep their import target,
    # local definitions are qualified by the module name.
    export_map: dict[str, str] = {}
    for name in export_names:
        if name in import_map:
            export_map[name] = import_map[name]
        elif name in local_defs:
            export_map[name] = f"{module_name}.{name}" if module_name else name
    return export_names, export_map
954
+
955
def _build_symbol_table(
    paths: list[Path],
    project_root: Path | None,
    *,
    external_filter: bool,
) -> SymbolTable:
    """Parse every file and assemble the project-wide SymbolTable.

    Unparseable files are skipped silently (best effort); import edges are
    collected by ImportVisitor, exports by _collect_module_exports.
    """
    table = SymbolTable(external_filter=external_filter)
    for path in paths:
        try:
            tree = ast.parse(path.read_text())
        except Exception:
            continue
        module = _module_name(path, project_root)
        if module:
            # Every analyzed module's top package counts as internal.
            table.internal_roots.add(module.split(".")[0])
        visitor = ImportVisitor(module, table)
        visitor.visit(tree)
        if module:
            # Restrict the table's imports to this module for export mapping.
            import_map = {
                local: fqn
                for (mod, local), fqn in table.imports.items()
                if mod == module
            }
            exports, export_map = _collect_module_exports(
                tree,
                module_name=module,
                import_map=import_map,
            )
            table.module_exports[module] = exports
            table.module_export_map[module] = export_map
    return table
986
+
987
+
988
def _collect_class_index(
    paths: list[Path],
    project_root: Path | None,
) -> dict[str, ClassInfo]:
    """Index every class across *paths* by its fully qualified name.

    Unparseable files are skipped silently. Only methods defined directly in
    the class body are recorded (inheritance is resolved elsewhere).
    """
    class_index: dict[str, ClassInfo] = {}
    for path in paths:
        try:
            tree = ast.parse(path.read_text())
        except Exception:
            continue
        parents = ParentAnnotator()
        parents.visit(tree)
        module = _module_name(path, project_root)
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            # Qualify as module + enclosing classes + class name.
            scopes = _enclosing_class_scopes(node, parents.parents)
            qual_parts = [module] if module else []
            qual_parts.extend(scopes)
            qual_parts.append(node.name)
            qual = ".".join(qual_parts)
            bases: list[str] = []
            for base in node.bases:
                base_name = _base_identifier(base)
                if base_name:
                    bases.append(base_name)
            methods: set[str] = set()
            for stmt in node.body:
                if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    methods.add(stmt.name)
            class_index[qual] = ClassInfo(
                qual=qual,
                module=module,
                bases=bases,
                methods=methods,
            )
    return class_index
1025
+
1026
+
1027
+ def _resolve_class_candidates(
1028
+ base: str,
1029
+ *,
1030
+ module: str,
1031
+ symbol_table: SymbolTable | None,
1032
+ class_index: dict[str, ClassInfo],
1033
+ ) -> list[str]:
1034
+ if not base:
1035
+ return []
1036
+ candidates: list[str] = []
1037
+ if "." in base:
1038
+ parts = base.split(".")
1039
+ head = parts[0]
1040
+ tail = ".".join(parts[1:])
1041
+ if symbol_table is not None:
1042
+ resolved_head = symbol_table.resolve(module, head)
1043
+ if resolved_head:
1044
+ candidates.append(f"{resolved_head}.{tail}")
1045
+ if module:
1046
+ candidates.append(f"{module}.{base}")
1047
+ candidates.append(base)
1048
+ else:
1049
+ if symbol_table is not None:
1050
+ resolved = symbol_table.resolve(module, base)
1051
+ if resolved:
1052
+ candidates.append(resolved)
1053
+ resolved_star = symbol_table.resolve_star(module, base)
1054
+ if resolved_star:
1055
+ candidates.append(resolved_star)
1056
+ if module:
1057
+ candidates.append(f"{module}.{base}")
1058
+ candidates.append(base)
1059
+ seen: set[str] = set()
1060
+ resolved: list[str] = []
1061
+ for candidate in candidates:
1062
+ if candidate in seen:
1063
+ continue
1064
+ seen.add(candidate)
1065
+ if candidate in class_index:
1066
+ resolved.append(candidate)
1067
+ return resolved
1068
+
1069
+
1070
+ def _resolve_method_in_hierarchy(
1071
+ class_qual: str,
1072
+ method: str,
1073
+ *,
1074
+ class_index: dict[str, ClassInfo],
1075
+ by_qual: dict[str, FunctionInfo],
1076
+ symbol_table: SymbolTable | None,
1077
+ seen: set[str],
1078
+ ) -> FunctionInfo | None:
1079
+ if class_qual in seen:
1080
+ return None
1081
+ seen.add(class_qual)
1082
+ candidate = f"{class_qual}.{method}"
1083
+ if candidate in by_qual:
1084
+ return by_qual[candidate]
1085
+ info = class_index.get(class_qual)
1086
+ if info is None:
1087
+ return None
1088
+ for base in info.bases:
1089
+ for base_qual in _resolve_class_candidates(
1090
+ base,
1091
+ module=info.module,
1092
+ symbol_table=symbol_table,
1093
+ class_index=class_index,
1094
+ ):
1095
+ resolved = _resolve_method_in_hierarchy(
1096
+ base_qual,
1097
+ method,
1098
+ class_index=class_index,
1099
+ by_qual=by_qual,
1100
+ symbol_table=symbol_table,
1101
+ seen=seen,
1102
+ )
1103
+ if resolved is not None:
1104
+ return resolved
1105
+ return None
1106
+
1107
+
1108
def _build_function_index(
    paths: list[Path],
    project_root: Path | None,
    ignore_params: set[str],
    strictness: str,
    transparent_decorators: set[str] | None = None,
) -> tuple[dict[str, list[FunctionInfo]], dict[str, FunctionInfo]]:
    """Index every function in ``paths`` by bare name and qualified name.

    Parses each file (skipping unparseable ones), analyzes each function's
    parameter usage and outgoing calls, and builds a FunctionInfo record.
    Returns ``(by_name, by_qual)``: ``by_name`` maps a bare function name
    to all matching records; ``by_qual`` maps the dotted qualified name to
    exactly one record.
    """
    by_name: dict[str, list[FunctionInfo]] = defaultdict(list)
    by_qual: dict[str, FunctionInfo] = {}
    for path in paths:
        try:
            tree = ast.parse(path.read_text())
        except Exception:
            # Best-effort: skip files that fail to parse.
            continue
        funcs = _collect_functions(tree)
        if not funcs:
            continue
        parents = ParentAnnotator()
        parents.visit(tree)
        parent_map = parents.parents
        module = _module_name(path, project_root)
        for fn in funcs:
            class_name = _enclosing_class(fn, parent_map)
            scopes = _enclosing_scopes(fn, parent_map)
            lexical_scopes = _enclosing_function_scopes(fn, parent_map)
            use_map, call_args = _analyze_function(
                fn,
                parent_map,
                is_test=_is_test_path(path),
                ignore_params=ignore_params,
                strictness=strictness,
                class_name=class_name,
            )
            unused_params = _unused_params(use_map)
            # Qualified name: module, enclosing scopes, then the function.
            qual_parts = [module] if module else []
            if scopes:
                qual_parts.extend(scopes)
            qual_parts.append(fn.name)
            qual = ".".join(qual_parts)
            info = FunctionInfo(
                name=fn.name,
                qual=qual,
                path=path,
                params=_param_names(fn, ignore_params),
                annots=_param_annotations(fn, ignore_params),
                calls=call_args,
                unused_params=unused_params,
                transparent=_decorators_transparent(fn, transparent_decorators),
                class_name=class_name,
                scope=tuple(scopes),
                lexical_scope=tuple(lexical_scopes),
            )
            by_name[fn.name].append(info)
            by_qual[info.qual] = info
    return by_name, by_qual
1163
+
1164
+
1165
def _resolve_callee(
    callee_key: str,
    caller: FunctionInfo,
    by_name: dict[str, list[FunctionInfo]],
    by_qual: dict[str, FunctionInfo],
    symbol_table: SymbolTable | None = None,
    project_root: Path | None = None,
    class_index: dict[str, ClassInfo] | None = None,
) -> FunctionInfo | None:
    """Resolve a textual call target to a FunctionInfo, or None.

    Resolution proceeds in stages: (1) for bare names, lexical scoping —
    innermost enclosing function scope outward, then same-file globals;
    (2) symbol-table lookups of imports / star-imports, and ``self``/``cls``
    or ``Class.method`` attribute calls; (3) exact qualified-name match;
    (4) method lookup through the class hierarchy via ``class_index``.
    """
    # dataflow-bundle: by_name, caller
    if not callee_key:
        return None
    caller_module = _module_name(caller.path, project_root=project_root)
    candidates = by_name.get(_callee_key(callee_key), [])
    if "." not in callee_key:
        # Stage 1: lexical scope search, peeling one scope level per pass.
        ambiguous = False
        effective_scope = list(caller.lexical_scope) + [caller.name]
        while True:
            # Methods without a lexical scope are class attributes, not
            # reachable as bare names; exclude them here.
            scoped = [
                info
                for info in candidates
                if list(info.lexical_scope) == effective_scope
                and not (info.class_name and not info.lexical_scope)
            ]
            if len(scoped) == 1:
                return scoped[0]
            if len(scoped) > 1:
                ambiguous = True
                break
            if not effective_scope:
                break
            effective_scope = effective_scope[:-1]
        if ambiguous:
            # NOTE(review): ambiguity is detected but currently ignored;
            # resolution falls through to the remaining stages.
            pass
        # Same-file module-level functions as a fallback.
        globals_only = [
            info
            for info in candidates
            if not info.lexical_scope
            and not (info.class_name and not info.lexical_scope)
            and info.path == caller.path
        ]
        if len(globals_only) == 1:
            return globals_only[0]
    if symbol_table is not None:
        if "." not in callee_key:
            # Stage 2a: bare name imported into the caller's module.
            if (caller_module, callee_key) in symbol_table.imports:
                fqn = symbol_table.resolve(caller_module, callee_key)
                if fqn is None:
                    return None
                if fqn in by_qual:
                    return by_qual[fqn]
            resolved = symbol_table.resolve_star(caller_module, callee_key)
            if resolved is not None and resolved in by_qual:
                return by_qual[resolved]
        else:
            # Stage 2b: dotted call — self/cls methods, local Class.method,
            # or a method reached through an imported base name.
            parts = callee_key.split(".")
            base = parts[0]
            if base in ("self", "cls"):
                method = parts[-1]
                if caller.class_name:
                    candidate = f"{caller_module}.{caller.class_name}.{method}"
                    if candidate in by_qual:
                        return by_qual[candidate]
            elif len(parts) == 2:
                candidate = f"{caller_module}.{base}.{parts[1]}"
                if candidate in by_qual:
                    return by_qual[candidate]
            if (caller_module, base) in symbol_table.imports:
                base_fqn = symbol_table.resolve(caller_module, base)
                if base_fqn is None:
                    return None
                candidate = base_fqn + "." + ".".join(parts[1:])
                if candidate in by_qual:
                    return by_qual[candidate]
    # Exact qualified name match.
    if callee_key in by_qual:
        return by_qual[callee_key]
    if class_index is not None and "." in callee_key:
        # Stage 4: resolve the class part, then search its hierarchy.
        parts = callee_key.split(".")
        if len(parts) >= 2:
            method = parts[-1]
            class_part = ".".join(parts[:-1])
            if class_part in {"self", "cls"} and caller.class_name:
                class_candidates = _resolve_class_candidates(
                    caller.class_name,
                    module=caller_module,
                    symbol_table=symbol_table,
                    class_index=class_index,
                )
            else:
                class_candidates = _resolve_class_candidates(
                    class_part,
                    module=caller_module,
                    symbol_table=symbol_table,
                    class_index=class_index,
                )
            for class_qual in class_candidates:
                resolved = _resolve_method_in_hierarchy(
                    class_qual,
                    method,
                    class_index=class_index,
                    by_qual=by_qual,
                    symbol_table=symbol_table,
                    seen=set(),
                )
                if resolved is not None:
                    return resolved
    return None
1273
+
1274
+
1275
def analyze_type_flow_repo_with_map(
    paths: list[Path],
    *,
    project_root: Path | None,
    ignore_params: set[str],
    strictness: str,
    external_filter: bool,
    transparent_decorators: set[str] | None = None,
) -> tuple[dict[str, dict[str, str | None]], list[str], list[str]]:
    """Repo-wide fixed-point pass for downstream type tightening.

    Seeds each function's parameter-type map from declared annotations,
    then repeatedly propagates callee annotations back to caller params
    (through transparent callees only) until nothing changes. Returns
    ``(inferred, suggestions, ambiguities)``: ``inferred`` maps qualified
    function name -> {param -> annotation or None}; the lists are sorted
    human-readable messages.
    """
    by_name, by_qual = _build_function_index(
        paths,
        project_root,
        ignore_params,
        strictness,
        transparent_decorators,
    )
    symbol_table = _build_symbol_table(
        paths, project_root, external_filter=external_filter
    )
    class_index = _collect_class_index(paths, project_root)
    # Seed the inferred map with each function's declared annotations.
    inferred: dict[str, dict[str, str | None]] = {}
    for infos in by_name.values():
        for info in infos:
            inferred[info.qual] = dict(info.annots)

    def _get_annot(info: FunctionInfo, param: str) -> str | None:
        # Current (possibly already tightened) annotation for a parameter.
        return inferred.get(info.qual, {}).get(param)

    suggestions: set[str] = set()
    ambiguities: set[str] = set()
    changed = True
    while changed:  # iterate until a fixed point is reached
        changed = False
        for infos in by_name.values():
            for info in infos:
                if _is_test_path(info.path):
                    continue
                # caller param -> annotations observed at callee positions.
                downstream: dict[str, set[str]] = defaultdict(set)
                for call in info.calls:
                    callee = _resolve_callee(
                        call.callee,
                        info,
                        by_name,
                        by_qual,
                        symbol_table,
                        project_root,
                        class_index,
                    )
                    if callee is None:
                        continue
                    if not callee.transparent:
                        continue
                    callee_params = callee.params
                    mapped_params: set[str] = set()
                    callee_to_caller: dict[str, set[str]] = defaultdict(set)
                    # Positional forwards: map index -> callee param name.
                    for pos_idx, param in call.pos_map.items():
                        try:
                            idx = int(pos_idx)
                        except ValueError:
                            continue
                        if idx >= len(callee_params):
                            continue
                        callee_param = callee_params[idx]
                        mapped_params.add(callee_param)
                        callee_to_caller[callee_param].add(param)
                    # Keyword forwards.
                    for kw_name, param in call.kw_map.items():
                        if kw_name not in callee_params:
                            continue
                        mapped_params.add(kw_name)
                        callee_to_caller[kw_name].add(param)
                    if strictness == "low":
                        # Low strictness: a single *args/**kwargs forward is
                        # assumed to feed every otherwise-unmapped param.
                        remaining = [p for p in callee_params if p not in mapped_params]
                        if len(call.star_pos) == 1:
                            _, star_param = call.star_pos[0]
                            for param in remaining:
                                callee_to_caller[param].add(star_param)
                        if len(call.star_kw) == 1:
                            star_param = call.star_kw[0]
                            for param in remaining:
                                callee_to_caller[param].add(star_param)
                    for callee_param, callers in callee_to_caller.items():
                        annot = _get_annot(callee, callee_param)
                        if not annot:
                            continue
                        for caller_param in callers:
                            downstream[caller_param].add(annot)
                for param, annots in downstream.items():
                    if not annots:
                        continue
                    if len(annots) > 1:
                        # Conflicting downstream types: report, don't tighten.
                        ambiguities.add(
                            f"{info.path.name}:{info.name}.{param} downstream types conflict: {sorted(annots)}"
                        )
                        continue
                    downstream_annot = next(iter(annots))
                    current = _get_annot(info, param)
                    # Only tighten broad/missing annotations.
                    if _is_broad_type(current) and downstream_annot:
                        if inferred[info.qual].get(param) != downstream_annot:
                            inferred[info.qual][param] = downstream_annot
                            changed = True
                            suggestions.add(
                                f"{info.path.name}:{info.name}.{param} can tighten to {downstream_annot}"
                            )
    return inferred, sorted(suggestions), sorted(ambiguities)
1380
+
1381
+
1382
def analyze_type_flow_repo(
    paths: list[Path],
    *,
    project_root: Path | None,
    ignore_params: set[str],
    strictness: str,
    external_filter: bool,
    transparent_decorators: set[str] | None = None,
) -> tuple[list[str], list[str]]:
    """Run the repo type-flow pass, returning only suggestions/ambiguities.

    Thin wrapper over ``analyze_type_flow_repo_with_map`` that discards the
    inferred annotation map.
    """
    _inferred, suggestions, ambiguities = analyze_type_flow_repo_with_map(
        paths,
        project_root=project_root,
        ignore_params=ignore_params,
        strictness=strictness,
        external_filter=external_filter,
        transparent_decorators=transparent_decorators,
    )
    return suggestions, ambiguities
1400
+
1401
+
1402
def analyze_constant_flow_repo(
    paths: list[Path],
    *,
    project_root: Path | None,
    ignore_params: set[str],
    strictness: str,
    external_filter: bool,
    transparent_decorators: set[str] | None = None,
) -> list[str]:
    """Detect parameters that only receive a single constant value (non-test).

    For every non-test call into a transparent callee, records which
    (callee, param) slots were passed constants vs. non-constant values.
    A smell is emitted for each slot that was only ever observed with one
    distinct constant. Returns sorted human-readable messages.
    """
    by_name, by_qual = _build_function_index(
        paths,
        project_root,
        ignore_params,
        strictness,
        transparent_decorators,
    )
    symbol_table = _build_symbol_table(
        paths, project_root, external_filter=external_filter
    )
    class_index = _collect_class_index(paths, project_root)
    # (callee qual, param) -> distinct constant values observed.
    const_values: dict[tuple[str, str], set[str]] = defaultdict(set)
    # (callee qual, param) -> a non-constant argument was seen.
    non_const: dict[tuple[str, str], bool] = defaultdict(bool)
    # (callee qual, param) -> number of observed argument passes.
    call_counts: dict[tuple[str, str], int] = defaultdict(int)

    for infos in by_name.values():
        for info in infos:
            for call in info.calls:
                if call.is_test:
                    continue
                callee = _resolve_callee(
                    call.callee,
                    info,
                    by_name,
                    by_qual,
                    symbol_table,
                    project_root,
                    class_index,
                )
                if callee is None:
                    continue
                if not callee.transparent:
                    continue
                callee_params = callee.params
                # Params covered by explicit positional/keyword forwards.
                mapped_params = set()
                for idx_str in call.pos_map:
                    try:
                        idx = int(idx_str)
                    except ValueError:
                        continue
                    if idx >= len(callee_params):
                        continue
                    mapped_params.add(callee_params[idx])
                for kw in call.kw_map:
                    if kw in callee_params:
                        mapped_params.add(kw)
                remaining = [p for p in callee_params if p not in mapped_params]

                # Positional constants.
                for idx_str, value in call.const_pos.items():
                    try:
                        idx = int(idx_str)
                    except ValueError:
                        continue
                    if idx >= len(callee_params):
                        continue
                    key = (callee.qual, callee_params[idx])
                    const_values[key].add(value)
                    call_counts[key] += 1
                # Forwarded caller params count as non-constant.
                for idx_str in call.pos_map:
                    try:
                        idx = int(idx_str)
                    except ValueError:
                        continue
                    if idx >= len(callee_params):
                        continue
                    key = (callee.qual, callee_params[idx])
                    non_const[key] = True
                    call_counts[key] += 1
                # Explicitly non-constant positional expressions.
                for idx_str in call.non_const_pos:
                    try:
                        idx = int(idx_str)
                    except ValueError:
                        continue
                    if idx >= len(callee_params):
                        continue
                    key = (callee.qual, callee_params[idx])
                    non_const[key] = True
                    call_counts[key] += 1
                if strictness == "low":
                    # A lone *args forward taints all unmapped params.
                    if len(call.star_pos) == 1:
                        for param in remaining:
                            key = (callee.qual, param)
                            non_const[key] = True
                            call_counts[key] += 1

                # Keyword constants.
                for kw, value in call.const_kw.items():
                    if kw not in callee_params:
                        continue
                    key = (callee.qual, kw)
                    const_values[key].add(value)
                    call_counts[key] += 1
                for kw in call.kw_map:
                    if kw not in callee_params:
                        continue
                    key = (callee.qual, kw)
                    non_const[key] = True
                    call_counts[key] += 1
                for kw in call.non_const_kw:
                    if kw not in callee_params:
                        continue
                    key = (callee.qual, kw)
                    non_const[key] = True
                    call_counts[key] += 1
                if strictness == "low":
                    # A lone **kwargs forward taints all unmapped params.
                    if len(call.star_kw) == 1:
                        for param in remaining:
                            key = (callee.qual, param)
                            non_const[key] = True
                            call_counts[key] += 1

    smells: list[str] = []
    for key, values in const_values.items():
        if non_const.get(key):
            continue
        if not values:
            continue
        if len(values) == 1:
            qual, param = key
            info = by_qual.get(qual)
            path_name = info.path.name if info is not None else qual
            count = call_counts.get(key, 0)
            smells.append(
                f"{path_name}:{qual.split('.')[-1]}.{param} only observed constant {next(iter(values))} across {count} non-test call(s)"
            )
    return sorted(smells)
1537
+
1538
+
1539
def analyze_unused_arg_flow_repo(
    paths: list[Path],
    *,
    project_root: Path | None,
    ignore_params: set[str],
    strictness: str,
    external_filter: bool,
    transparent_decorators: set[str] | None = None,
) -> list[str]:
    """Detect non-constant arguments passed into unused callee parameters.

    For every non-test call into a transparent callee that has unused
    parameters, reports each forwarded caller parameter or non-constant
    expression that lands on one of those unused slots. Returns sorted
    human-readable messages.
    """
    by_name, by_qual = _build_function_index(
        paths,
        project_root,
        ignore_params,
        strictness,
        transparent_decorators,
    )
    symbol_table = _build_symbol_table(
        paths, project_root, external_filter=external_filter
    )
    class_index = _collect_class_index(paths, project_root)
    smells: set[str] = set()

    def _format(
        caller: FunctionInfo,
        callee: FunctionInfo,
        callee_param: str,
        arg_desc: str,
    ) -> str:
        # dataflow-bundle: callee, caller
        # One-line smell message naming both endpoints of the flow.
        return (
            f"{caller.path.name}:{caller.name} passes {arg_desc} "
            f"to unused {callee.path.name}:{callee.name}.{callee_param}"
        )

    for infos in by_name.values():
        for info in infos:
            for call in info.calls:
                if call.is_test:
                    continue
                callee = _resolve_callee(
                    call.callee,
                    info,
                    by_name,
                    by_qual,
                    symbol_table,
                    project_root,
                    class_index,
                )
                if callee is None:
                    continue
                if not callee.transparent:
                    continue
                if not callee.unused_params:
                    continue
                callee_params = callee.params
                # Params covered by explicit positional/keyword forwards.
                mapped_params = set()
                for idx_str in call.pos_map:
                    try:
                        idx = int(idx_str)
                    except ValueError:
                        continue
                    if idx >= len(callee_params):
                        continue
                    mapped_params.add(callee_params[idx])
                for kw in call.kw_map:
                    if kw in callee_params:
                        mapped_params.add(kw)
                remaining = [
                    (idx, name)
                    for idx, name in enumerate(callee_params)
                    if name not in mapped_params
                ]

                # Forwarded caller params hitting unused positional slots.
                for idx_str, caller_param in call.pos_map.items():
                    try:
                        idx = int(idx_str)
                    except ValueError:
                        continue
                    if idx >= len(callee_params):
                        continue
                    callee_param = callee_params[idx]
                    if callee_param in callee.unused_params:
                        smells.add(
                            _format(
                                info,
                                callee,
                                callee_param,
                                f"param {caller_param}",
                            )
                        )
                # Non-constant expressions hitting unused positional slots.
                for idx_str in call.non_const_pos:
                    try:
                        idx = int(idx_str)
                    except ValueError:
                        continue
                    if idx >= len(callee_params):
                        continue
                    callee_param = callee_params[idx]
                    if callee_param in callee.unused_params:
                        smells.add(
                            _format(
                                info,
                                callee,
                                callee_param,
                                f"non-constant arg at position {idx}",
                            )
                        )
                # Forwarded caller params hitting unused keyword slots.
                for kw, caller_param in call.kw_map.items():
                    if kw not in callee_params:
                        continue
                    if kw in callee.unused_params:
                        smells.add(
                            _format(
                                info,
                                callee,
                                kw,
                                f"param {caller_param}",
                            )
                        )
                # Non-constant expressions hitting unused keyword slots.
                for kw in call.non_const_kw:
                    if kw not in callee_params:
                        continue
                    if kw in callee.unused_params:
                        smells.add(
                            _format(
                                info,
                                callee,
                                kw,
                                f"non-constant kw '{kw}'",
                            )
                        )
                if strictness == "low":
                    # Low strictness: a lone *args/**kwargs forward may feed
                    # any remaining (unmapped) parameter.
                    if len(call.star_pos) == 1:
                        for idx, param in remaining:
                            if param in callee.unused_params:
                                smells.add(
                                    _format(
                                        info,
                                        callee,
                                        param,
                                        f"non-constant arg at position {idx}",
                                    )
                                )
                    if len(call.star_kw) == 1:
                        for _, param in remaining:
                            if param in callee.unused_params:
                                smells.add(
                                    _format(
                                        info,
                                        callee,
                                        param,
                                        f"non-constant kw '{param}'",
                                    )
                                )
    return sorted(smells)
1695
+
1696
+
1697
+ def _iter_config_fields(path: Path) -> dict[str, set[str]]:
1698
+ """Best-effort extraction of config bundles from dataclasses."""
1699
+ try:
1700
+ tree = ast.parse(path.read_text())
1701
+ except Exception:
1702
+ return {}
1703
+ bundles: dict[str, set[str]] = {}
1704
+ for node in ast.walk(tree):
1705
+ if not isinstance(node, ast.ClassDef):
1706
+ continue
1707
+ decorators = {getattr(d, "id", None) for d in node.decorator_list}
1708
+ is_dataclass = "dataclass" in decorators
1709
+ is_config = node.name.endswith("Config")
1710
+ if not is_dataclass and not is_config:
1711
+ continue
1712
+ fields: set[str] = set()
1713
+ for stmt in node.body:
1714
+ if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
1715
+ name = stmt.target.id
1716
+ if is_config or name.endswith("_fn"):
1717
+ fields.add(name)
1718
+ elif isinstance(stmt, ast.Assign):
1719
+ for target in stmt.targets:
1720
+ if isinstance(target, ast.Name):
1721
+ if is_config or target.id.endswith("_fn"):
1722
+ fields.add(target.id)
1723
+ if fields:
1724
+ bundles[node.name] = fields
1725
+ return bundles
1726
+
1727
+
1728
def _collect_config_bundles(paths: list[Path]) -> dict[Path, dict[str, set[str]]]:
    """Map each path to its extracted config bundles, omitting empty files."""
    per_path = {candidate: _iter_config_fields(candidate) for candidate in paths}
    return {candidate: found for candidate, found in per_path.items() if found}
1735
+
1736
+
1737
# Matches '# dataflow-bundle: a, b' documentation markers; group 1 captures
# the comma/whitespace-separated payload after the colon.
_BUNDLE_MARKER = re.compile(r"dataflow-bundle:\s*(.*)")
1738
+
1739
+
1740
def _iter_documented_bundles(path: Path) -> set[tuple[str, ...]]:
    """Return bundles documented via '# dataflow-bundle: a, b' markers.

    Each marker's payload is split on commas/whitespace; payloads with
    fewer than two names are ignored. Bundles are returned as sorted
    tuples; unreadable files yield an empty set.
    """
    documented: set[tuple[str, ...]] = set()
    try:
        text = path.read_text()
    except Exception:
        return documented
    for raw_line in text.splitlines():
        found = _BUNDLE_MARKER.search(raw_line)
        if found is None:
            continue
        payload = found.group(1)
        if not payload:
            continue
        names = [piece.strip() for piece in re.split(r"[,\s]+", payload)]
        names = [piece for piece in names if piece]
        if len(names) >= 2:
            documented.add(tuple(sorted(names)))
    return documented
1759
+
1760
+
1761
def _collect_dataclass_registry(
    paths: list[Path],
    *,
    project_root: Path | None,
) -> dict[str, list[str]]:
    """Map qualified dataclass names to their declared field order.

    Only classes carrying a decorator whose source text contains
    ``dataclass`` are included; unparseable files are skipped.
    """
    registry: dict[str, list[str]] = {}
    for source_path in paths:
        try:
            tree = ast.parse(source_path.read_text())
        except Exception:
            continue
        module = _module_name(source_path, project_root)
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            decorator_texts = {
                ast.unparse(dec) if hasattr(ast, "unparse") else ""
                for dec in node.decorator_list
            }
            if not any("dataclass" in text for text in decorator_texts):
                continue
            fields: list[str] = []
            for stmt in node.body:
                if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
                    fields.append(stmt.target.id)
                elif isinstance(stmt, ast.Assign):
                    fields.extend(
                        target.id
                        for target in stmt.targets
                        if isinstance(target, ast.Name)
                    )
            if not fields:
                continue
            key = f"{module}.{node.name}" if module else node.name
            registry[key] = fields
    return registry
1797
+
1798
+
1799
def _iter_dataclass_call_bundles(
    path: Path,
    *,
    project_root: Path | None = None,
    symbol_table: SymbolTable | None = None,
    dataclass_registry: dict[str, list[str]] | None = None,
) -> set[tuple[str, ...]]:
    """Return bundles promoted via @dataclass constructor calls.

    Scans ``path`` for dataclass definitions, registers their field order
    (note: mutates the caller-provided ``dataclass_registry`` in place),
    then maps every constructor call's positional/keyword arguments back to
    field names. Calls with starred args, unknown keywords, or fewer than
    two named fields are skipped. Bundles are sorted field-name tuples.
    """
    bundles: set[tuple[str, ...]] = set()
    try:
        tree = ast.parse(path.read_text())
    except Exception:
        return bundles
    module = _module_name(path, project_root)
    # Dataclasses defined in this file, by bare class name.
    local_dataclasses: dict[str, list[str]] = {}
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        decorators = {
            ast.unparse(dec) if hasattr(ast, "unparse") else ""
            for dec in node.decorator_list
        }
        if any("dataclass" in dec for dec in decorators):
            fields: list[str] = []
            for stmt in node.body:
                if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
                    fields.append(stmt.target.id)
                elif isinstance(stmt, ast.Assign):
                    for target in stmt.targets:
                        if isinstance(target, ast.Name):
                            fields.append(target.id)
            if fields:
                local_dataclasses[node.name] = fields
    if dataclass_registry is None:
        dataclass_registry = {}
    # Publish local dataclasses under their qualified names.
    for name, fields in local_dataclasses.items():
        if module:
            dataclass_registry[f"{module}.{name}"] = fields
        else:
            dataclass_registry[name] = fields

    def _callee_name(call: ast.Call) -> str | None:
        # NOTE(review): currently unused helper — kept as-is.
        if isinstance(call.func, ast.Name):
            return call.func.id
        if isinstance(call.func, ast.Attribute):
            return call.func.attr
        return None

    def _resolve_fields(call: ast.Call) -> list[str] | None:
        # Resolve a call target to a known dataclass's field list, trying
        # local classes, module-qualified names, symbol-table resolution,
        # and finally the raw name.
        if isinstance(call.func, ast.Name):
            name = call.func.id
            if name in local_dataclasses:
                return local_dataclasses[name]
            if module:
                candidate = f"{module}.{name}"
                if candidate in dataclass_registry:
                    return dataclass_registry[candidate]
            if symbol_table is not None and module:
                resolved = symbol_table.resolve(module, name)
                if resolved in dataclass_registry:
                    return dataclass_registry[resolved]
                resolved_star = symbol_table.resolve_star(module, name)
                if resolved_star in dataclass_registry:
                    return dataclass_registry[resolved_star]
            if name in dataclass_registry:
                return dataclass_registry[name]
        if isinstance(call.func, ast.Attribute):
            # e.g. pkg.ClassName(...) — resolve the base, then append attr.
            if isinstance(call.func.value, ast.Name):
                base = call.func.value.id
                attr = call.func.attr
                if symbol_table is not None and module:
                    base_fqn = symbol_table.resolve(module, base)
                    if base_fqn:
                        candidate = f"{base_fqn}.{attr}"
                        if candidate in dataclass_registry:
                            return dataclass_registry[candidate]
                    base_star = symbol_table.resolve_star(module, base)
                    if base_star:
                        candidate = f"{base_star}.{attr}"
                        if candidate in dataclass_registry:
                            return dataclass_registry[candidate]
        return None

    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        fields = _resolve_fields(node)
        if not fields:
            continue
        names: list[str] = []
        ok = True
        # Positional args map to fields by index; *args aborts the call.
        for idx, arg in enumerate(node.args):
            if isinstance(arg, ast.Starred):
                ok = False
                break
            if idx < len(fields):
                names.append(fields[idx])
            else:
                ok = False
                break
        if not ok:
            continue
        # Keyword args map by name; **kwargs (arg is None) aborts the call.
        for kw in node.keywords:
            if kw.arg is None:
                ok = False
                break
            names.append(kw.arg)
        if not ok or len(names) < 2:
            continue
        bundles.add(tuple(sorted(names)))
    return bundles
1910
+
1911
+
1912
+ def _emit_dot(groups_by_path: dict[Path, dict[str, list[set[str]]]]) -> str:
1913
+ lines = [
1914
+ "digraph dataflow_grammar {",
1915
+ " rankdir=LR;",
1916
+ " node [fontsize=10];",
1917
+ ]
1918
+ for path, groups in groups_by_path.items():
1919
+ file_id = str(path).replace("/", "_").replace(".", "_")
1920
+ lines.append(f" subgraph cluster_{file_id} {{")
1921
+ lines.append(f" label=\"{path}\";")
1922
+ for fn, bundles in groups.items():
1923
+ if not bundles:
1924
+ continue
1925
+ fn_id = f"fn_{file_id}_{fn}"
1926
+ lines.append(f" {fn_id} [shape=box,label=\"{fn}\"];")
1927
+ for idx, bundle in enumerate(bundles):
1928
+ bundle_id = f"b_{file_id}_{fn}_{idx}"
1929
+ label = ", ".join(sorted(bundle))
1930
+ lines.append(
1931
+ f" {bundle_id} [shape=ellipse,label=\"{label}\"];"
1932
+ )
1933
+ lines.append(f" {fn_id} -> {bundle_id};")
1934
+ lines.append(" }")
1935
+ lines.append("}")
1936
+ return "\n".join(lines)
1937
+
1938
+
1939
+ def _component_graph(groups_by_path: dict[Path, dict[str, list[set[str]]]]):
1940
+ nodes: dict[str, dict[str, str]] = {}
1941
+ adj: dict[str, set[str]] = defaultdict(set)
1942
+ bundle_map: dict[str, set[str]] = {}
1943
+ for path, groups in groups_by_path.items():
1944
+ file_id = str(path)
1945
+ for fn, bundles in groups.items():
1946
+ if not bundles:
1947
+ continue
1948
+ fn_id = f"fn::{file_id}::{fn}"
1949
+ nodes[fn_id] = {"kind": "fn", "label": f"{path.name}:{fn}"}
1950
+ for idx, bundle in enumerate(bundles):
1951
+ bundle_id = f"b::{file_id}::{fn}::{idx}"
1952
+ nodes[bundle_id] = {
1953
+ "kind": "bundle",
1954
+ "label": ", ".join(sorted(bundle)),
1955
+ }
1956
+ bundle_map[bundle_id] = bundle
1957
+ adj[fn_id].add(bundle_id)
1958
+ adj[bundle_id].add(fn_id)
1959
+ return nodes, adj, bundle_map
1960
+
1961
+
1962
+ def _connected_components(nodes: dict[str, dict[str, str]], adj: dict[str, set[str]]) -> list[list[str]]:
1963
+ seen: set[str] = set()
1964
+ comps: list[list[str]] = []
1965
+ for node in nodes:
1966
+ if node in seen:
1967
+ continue
1968
+ q: deque[str] = deque([node])
1969
+ seen.add(node)
1970
+ comp: list[str] = []
1971
+ while q:
1972
+ curr = q.popleft()
1973
+ comp.append(curr)
1974
+ for nxt in adj.get(curr, ()):
1975
+ if nxt not in seen:
1976
+ seen.add(nxt)
1977
+ q.append(nxt)
1978
+ comps.append(sorted(comp))
1979
+ return comps
1980
+
1981
+
1982
def _render_mermaid_component(
    nodes: dict[str, dict[str, str]],
    bundle_map: dict[str, set[str]],
    adj: dict[str, set[str]],
    component: list[str],
    config_bundles_by_path: dict[Path, dict[str, set[str]]],
    documented_bundles_by_path: dict[Path, set[tuple[str, ...]]],
) -> tuple[str, str]:
    # dataflow-bundle: adj, config_bundles_by_path, documented_bundles_by_path, nodes
    """Render one connected component as a mermaid flowchart plus a summary.

    Returns a tuple ``(mermaid_block, summary_text)``: the fenced mermaid
    diagram for the component's function/bundle nodes and edges, and a
    plain-text summary classifying the observed bundles against declared
    Config bundles and documented bundles.
    """
    lines = ["```mermaid", "flowchart LR"]
    # Split the component into function nodes (boxes) and bundle nodes (circles).
    fn_nodes = [n for n in component if nodes[n]["kind"] == "fn"]
    bundle_nodes = [n for n in component if nodes[n]["kind"] == "bundle"]
    # NOTE(review): mermaid node ids come from abs(hash(n)); with Python hash
    # randomization these differ between runs, so rendered ids are not stable
    # across processes — confirm whether diff-stability of reports matters.
    for n in fn_nodes:
        label = nodes[n]["label"].replace('"', "'")
        lines.append(f' {abs(hash(n))}["{label}"]')
    for n in bundle_nodes:
        label = nodes[n]["label"].replace('"', "'")
        lines.append(f' {abs(hash(n))}(({label}))')
    # Only draw edges that stay inside this component and originate at a function.
    for n in component:
        for nxt in adj.get(n, ()):
            if nxt in component and nodes[n]["kind"] == "fn":
                lines.append(f" {abs(hash(n))} --> {abs(hash(nxt))}")
    lines.append(" classDef fn fill:#cfe8ff,stroke:#2b6cb0,stroke-width:1px;")
    lines.append(" classDef bundle fill:#ffe9c6,stroke:#c05621,stroke-width:1px;")
    if fn_nodes:
        lines.append(
            " class "
            + ",".join(str(abs(hash(n))) for n in fn_nodes)
            + " fn;"
        )
    if bundle_nodes:
        lines.append(
            " class "
            + ",".join(str(abs(hash(n))) for n in bundle_nodes)
            + " bundle;"
        )
    lines.append("```")

    # Count each observed bundle (normalized to a sorted field tuple).
    observed = [bundle_map[n] for n in bundle_nodes if n in bundle_map]
    bundle_counts: dict[tuple[str, ...], int] = defaultdict(int)
    for bundle in observed:
        bundle_counts[tuple(sorted(bundle))] += 1
    # Files contributing functions here; node ids carry the file path as the
    # second "::"-separated segment.
    component_paths: set[Path] = set()
    for n in fn_nodes:
        parts = n.split("::", 2)
        if len(parts) == 3:
            component_paths.add(Path(parts[1]))
    # Bundles declared in any Config anywhere in the project.
    declared_global = set()
    for bundles in config_bundles_by_path.values():
        for fields in bundles.values():
            declared_global.add(tuple(sorted(fields)))
    # Bundles declared/documented in this component's own files.
    declared_local = set()
    documented = set()
    for path in component_paths:
        bundles = config_bundles_by_path.get(path)
        if bundles:
            for fields in bundles.values():
                declared_local.add(tuple(sorted(fields)))
        documented |= documented_bundles_by_path.get(path, set())
    observed_norm = {tuple(sorted(b)) for b in observed}
    observed_only = (
        sorted(observed_norm - declared_global)
        if declared_global
        else sorted(observed_norm)
    )
    declared_only = sorted(declared_local - observed_norm)
    # Bundles both observed in this component and documented in its files.
    documented_only = sorted(observed_norm & documented)
    def _tier(bundle: tuple[str, ...]) -> str:
        # tier-1: declared in a Config; tier-2: observed more than once;
        # tier-3: a one-off observation.
        count = bundle_counts.get(bundle, 1)
        if bundle in declared_global:
            return "tier-1"
        if count > 1:
            return "tier-2"
        return "tier-3"
    summary_lines = [
        f"Functions: {len(fn_nodes)}",
        f"Observed bundles: {len(observed_norm)}",
    ]
    if not declared_local:
        summary_lines.append("Declared Config bundles: none found for this component.")
    if observed_only:
        summary_lines.append("Observed-only bundles (not declared in Configs):")
        for bundle in observed_only:
            tier = _tier(bundle)
            documented_flag = "documented" if bundle in documented else "undocumented"
            summary_lines.append(
                f" - {', '.join(bundle)} ({tier}, {documented_flag})"
            )
    if documented_only:
        summary_lines.append(
            "Documented bundles (dataflow-bundle markers or local dataclass calls):"
        )
        summary_lines.extend(f" - {', '.join(bundle)}" for bundle in documented_only)
    if declared_only:
        summary_lines.append("Declared Config bundles not observed in this component:")
        summary_lines.extend(f" - {', '.join(bundle)}" for bundle in declared_only)
    summary = "\n".join(summary_lines)
    return "\n".join(lines), summary
2080
+
2081
+
2082
def _emit_report(
    groups_by_path: dict[Path, dict[str, list[set[str]]]],
    max_components: int,
    *,
    type_suggestions: list[str] | None = None,
    type_ambiguities: list[str] | None = None,
    constant_smells: list[str] | None = None,
    unused_arg_smells: list[str] | None = None,
) -> tuple[str, list[str]]:
    """Build the markdown audit report and collect violation lines.

    Returns ``(markdown, violations)`` where ``violations`` are summary lines
    for observed bundles that are undocumented (any tier).

    Fixes applied: reuse ``_infer_root`` instead of re-deriving the common
    root inline, and collapse the duplicated ``if type_suggestions or
    type_ambiguities`` check that appeared twice in a row.
    """
    nodes, adj, bundle_map = _component_graph(groups_by_path)
    components = _connected_components(nodes, adj)
    # Common ancestor of all analyzed files (or "." when nothing analyzed).
    root = _infer_root(groups_by_path)
    file_paths = sorted(root.rglob("*.py"))
    config_bundles_by_path = _collect_config_bundles(file_paths)
    documented_bundles_by_path: dict[Path, set[tuple[str, ...]]] = {}
    symbol_table = _build_symbol_table(
        file_paths,
        root,
        external_filter=True,
    )
    dataclass_registry = _collect_dataclass_registry(
        file_paths,
        project_root=root,
    )
    for path in file_paths:
        # A bundle counts as "documented" if it has a dataflow-bundle marker
        # or is promoted from a local dataclass construction.
        documented = _iter_documented_bundles(path)
        promoted = _iter_dataclass_call_bundles(
            path,
            project_root=root,
            symbol_table=symbol_table,
            dataclass_registry=dataclass_registry,
        )
        documented_bundles_by_path[path] = documented | promoted
    lines = [
        "<!-- dataflow-grammar -->",
        "Dataflow grammar audit (observed forwarding bundles).",
        "",
    ]
    if not components:
        return "\n".join(lines + ["No bundle components detected."]), []
    if len(components) > max_components:
        lines.append(
            f"Showing top {max_components} components of {len(components)}."
        )
    violations: list[str] = []
    for idx, comp in enumerate(components[:max_components], start=1):
        lines.append(f"### Component {idx}")
        mermaid, summary = _render_mermaid_component(
            nodes,
            bundle_map,
            adj,
            comp,
            config_bundles_by_path,
            documented_bundles_by_path,
        )
        lines.append(mermaid)
        lines.append("")
        lines.append("Summary:")
        lines.append("```")
        lines.append(summary)
        lines.append("```")
        lines.append("")
        for line in summary.splitlines():
            # Any undocumented observed bundle is a violation, whatever its tier.
            if "(tier-3, undocumented)" in line:
                violations.append(line.strip())
            if "(tier-1," in line or "(tier-2," in line:
                if "undocumented" in line:
                    violations.append(line.strip())
    if violations:
        lines.append("Violations:")
        lines.append("```")
        lines.extend(violations)
        lines.append("```")
    if type_suggestions or type_ambiguities:
        # Previously this condition was evaluated twice back-to-back.
        lines.append("Type-flow audit:")
        lines.append(_render_type_mermaid(type_suggestions or [], type_ambiguities or []))
    if type_suggestions:
        lines.append("Type tightening candidates:")
        lines.append("```")
        lines.extend(type_suggestions)
        lines.append("```")
    if type_ambiguities:
        lines.append("Type ambiguities (conflicting downstream expectations):")
        lines.append("```")
        lines.extend(type_ambiguities)
        lines.append("```")
    if constant_smells:
        lines.append("Constant-propagation smells (non-test call sites):")
        lines.append("```")
        lines.extend(constant_smells)
        lines.append("```")
    if unused_arg_smells:
        lines.append("Unused-argument smells (non-test call sites):")
        lines.append("```")
        lines.extend(unused_arg_smells)
        lines.append("```")
    return "\n".join(lines), violations
2184
+
2185
+
2186
+ def _infer_root(groups_by_path: dict[Path, dict[str, list[set[str]]]]) -> Path:
2187
+ if groups_by_path:
2188
+ common = os.path.commonpath([str(p) for p in groups_by_path])
2189
+ return Path(common)
2190
+ return Path(".")
2191
+
2192
+
2193
+ def _bundle_counts(
2194
+ groups_by_path: dict[Path, dict[str, list[set[str]]]]
2195
+ ) -> dict[tuple[str, ...], int]:
2196
+ counts: dict[tuple[str, ...], int] = defaultdict(int)
2197
+ for groups in groups_by_path.values():
2198
+ for bundles in groups.values():
2199
+ for bundle in bundles:
2200
+ counts[tuple(sorted(bundle))] += 1
2201
+ return counts
2202
+
2203
+
2204
def _collect_declared_bundles(root: Path) -> set[tuple[str, ...]]:
    """Collect every Config-declared bundle under *root* as sorted field tuples."""
    python_files = sorted(root.rglob("*.py"))
    declared: set[tuple[str, ...]] = set()
    for bundles in _collect_config_bundles(python_files).values():
        declared.update(tuple(sorted(fields)) for fields in bundles.values())
    return declared
2212
+
2213
+
2214
def build_synthesis_plan(
    groups_by_path: dict[Path, dict[str, list[set[str]]]],
    *,
    project_root: Path | None = None,
    max_tier: int = 2,
    min_bundle_size: int = 2,
    allow_singletons: bool = False,
    config: AuditConfig | None = None,
) -> dict[str, object]:
    """Plan protocol/dataclass candidates from observed bundles.

    Tiers bundles (1 = declared in a Config, 2 = observed more than once,
    3 = one-off), merges overlapping bundles, infers field type hints from
    downstream annotations and constant call arguments, then delegates to
    ``Synthesizer`` and returns the serialized ``SynthesisResponse``.
    """
    audit_config = config or AuditConfig(
        project_root=project_root or _infer_root(groups_by_path)
    )
    root = project_root or audit_config.project_root or _infer_root(groups_by_path)
    counts = _bundle_counts(groups_by_path)
    if not counts:
        # Nothing observed: return an empty plan with an explanatory warning.
        response = SynthesisResponse(
            protocols=[],
            warnings=["No bundles observed for synthesis."],
            errors=[],
        )
        return response.model_dump()

    declared = _collect_declared_bundles(root)
    bundle_tiers: dict[frozenset[str], int] = {}
    frequency: dict[str, int] = defaultdict(int)
    bundle_fields: set[str] = set()
    for bundle, count in counts.items():
        # tier-1: declared in a Config; tier-2: seen repeatedly; tier-3: one-off.
        tier = 1 if bundle in declared else (2 if count > 1 else 3)
        bundle_tiers[frozenset(bundle)] = tier
        for field in bundle:
            frequency[field] += count
            bundle_fields.add(field)

    # Merge overlapping bundles; a merged bundle inherits the best (lowest)
    # tier among the original bundles it covers.
    merged_bundle_tiers: dict[frozenset[str], int] = {}
    original_bundles = [set(bundle) for bundle in counts]
    merged_bundles = merge_bundles(original_bundles)
    if merged_bundles:
        for merged in merged_bundles:
            members = [
                bundle
                for bundle in original_bundles
                if bundle and bundle.issubset(merged)
            ]
            if not members:
                continue
            tier = min(
                bundle_tiers[frozenset(member)] for member in members
            )
            merged_bundle_tiers[frozenset(merged)] = tier
    if merged_bundle_tiers:
        bundle_tiers = merged_bundle_tiers

    naming_context = NamingContext(frequency=dict(frequency))
    synth_config = SynthesisConfig(
        max_tier=max_tier,
        min_bundle_size=min_bundle_size,
        allow_singletons=allow_singletons,
    )
    field_types: dict[str, str] = {}
    type_warnings: list[str] = []
    if bundle_fields:
        # Gather candidate type hints for bundle fields from the type-flow
        # analysis (downstream annotations).
        inferred, _, _ = analyze_type_flow_repo_with_map(
            list(groups_by_path.keys()),
            project_root=root,
            ignore_params=audit_config.ignore_params,
            strictness=audit_config.strictness,
            external_filter=audit_config.external_filter,
            transparent_decorators=audit_config.transparent_decorators,
        )
        type_sets: dict[str, set[str]] = defaultdict(set)
        for annots in inferred.values():
            for name, annot in annots.items():
                if name not in bundle_fields or not annot:
                    continue
                type_sets[name].add(annot)
        by_name, by_qual = _build_function_index(
            list(groups_by_path.keys()),
            root,
            audit_config.ignore_params,
            audit_config.strictness,
            audit_config.transparent_decorators,
        )
        symbol_table = _build_symbol_table(
            list(groups_by_path.keys()),
            root,
            external_filter=audit_config.external_filter,
        )
        class_index = _collect_class_index(list(groups_by_path.keys()), root)
        # Add type hints derived from constant arguments at non-test call
        # sites into transparent callees.
        for infos in by_name.values():
            for info in infos:
                for call in info.calls:
                    if call.is_test:
                        continue
                    callee = _resolve_callee(
                        call.callee,
                        info,
                        by_name,
                        by_qual,
                        symbol_table,
                        root,
                        class_index,
                    )
                    if callee is None or not callee.transparent:
                        continue
                    callee_params = callee.params
                    # Positional constants: map index -> parameter name.
                    for idx_str, value in call.const_pos.items():
                        try:
                            idx = int(idx_str)
                        except ValueError:
                            continue
                        if idx >= len(callee_params):
                            continue
                        param = callee_params[idx]
                        if param not in bundle_fields:
                            continue
                        hint = _type_from_const_repr(value)
                        if hint:
                            type_sets[param].add(hint)
                    # Keyword constants: match by name directly.
                    for kw, value in call.const_kw.items():
                        if kw not in callee_params or kw not in bundle_fields:
                            continue
                        hint = _type_from_const_repr(value)
                        if hint:
                            type_sets[kw].add(hint)
        # Collapse each field's hint set to one hint; surface conflicts.
        for name, types in type_sets.items():
            if not types:
                continue
            combined, conflicted = _combine_type_hints(types)
            field_types[name] = combined
            if conflicted and len(types) > 1:
                type_warnings.append(
                    f"Conflicting type hints for '{name}': {sorted(types)} -> {combined}"
                )
    plan = Synthesizer(config=synth_config).plan(
        bundle_tiers=bundle_tiers,
        field_types=field_types,
        naming_context=naming_context,
    )
    # Serialize protocol specs into plain dicts for the JSON-facing response.
    response = SynthesisResponse(
        protocols=[
            {
                "name": spec.name,
                "fields": [
                    {
                        "name": field.name,
                        "type_hint": field.type_hint,
                        "source_params": sorted(field.source_params),
                    }
                    for field in spec.fields
                ],
                "bundle": sorted(spec.bundle),
                "tier": spec.tier,
                "rationale": spec.rationale,
            }
            for spec in plan.protocols
        ],
        warnings=plan.warnings + type_warnings,
        errors=plan.errors,
    )
    return response.model_dump()
2374
+
2375
+
2376
def render_synthesis_section(plan: dict[str, object]) -> str:
    """Render the synthesis plan as a markdown section string."""
    out: list[str] = ["", "## Synthesis plan (prototype)", ""]
    specs = plan.get("protocols", [])
    if specs:
        for spec in specs:
            rendered_fields = [
                f"{entry.get('name', '')}: {entry.get('type_hint') or 'Any'}"
                for entry in spec.get("fields", [])
                if entry.get("name", "")
            ]
            summary = ", ".join(rendered_fields) if rendered_fields else "(no fields)"
            out.append(
                f"- {spec.get('name', 'Bundle')} (tier {spec.get('tier', '?')}): {summary}"
            )
    else:
        out.append("No protocol candidates.")
    for heading, entries in (
        ("Warnings:", plan.get("warnings", [])),
        ("Errors:", plan.get("errors", [])),
    ):
        if entries:
            out.extend(["", heading, "```"])
            out.extend(str(entry) for entry in entries)
            out.append("```")
    return "\n".join(out)
2409
+
2410
+
2411
def render_protocol_stubs(plan: dict[str, object], kind: str = "dataclass") -> str:
    """Render placeholder dataclass/Protocol stub source for planned bundles."""
    specs = plan.get("protocols", [])
    if kind not in {"dataclass", "protocol"}:
        kind = "dataclass"
    # Work out which typing names the stub file must import.
    wanted = {"Any"}
    if kind == "protocol":
        wanted.add("Protocol")
    for spec in specs:
        for entry in spec.get("fields", []) or []:
            hint = entry.get("type_hint") or "Any"
            for marker in ("Optional", "Union"):
                if f"{marker}[" in hint:
                    wanted.add(marker)
    out = [
        "# Auto-generated by gabion dataflow audit.",
        "from __future__ import annotations",
        "",
    ]
    if kind == "dataclass":
        out.append("from dataclasses import dataclass")
    out.append(f"from typing import {', '.join(sorted(wanted))}")
    out.append("")
    if not specs:
        out.append("# No protocol candidates.")
        return "\n".join(out)
    base = "TODO_Name_Me"
    for idx, spec in enumerate(specs, start=1):
        cls_name = base if idx == 1 else f"{base}{idx}"
        if kind == "dataclass":
            out.append("@dataclass")
            out.append(f"class {cls_name}:")
        else:
            out.append(f"class {cls_name}(Protocol):")
        doc = [
            "TODO: Rename this Protocol.",
            f"Suggested name: {spec.get('name', 'Bundle')}",
            f"Tier: {spec.get('tier', '?')}",
        ]
        bundle = spec.get("bundle", [])
        if bundle:
            doc.append(f"Bundle: {', '.join(bundle)}")
        rationale = spec.get("rationale", "")
        if rationale:
            doc.append(f"Rationale: {rationale}")
        entries = spec.get("fields", [])
        rendered = [
            f"{entry.get('name') or 'field'}: {entry.get('type_hint') or 'Any'}"
            for entry in entries
        ]
        if entries:
            doc.append("Fields: " + ", ".join(rendered))
        out.append(' """')
        out.extend(f" {line}" for line in doc)
        out.append(' """')
        if entries:
            out.extend(f" {line}" for line in rendered)
        else:
            out.append(" pass")
        out.append("")
    return "\n".join(out)
2480
+
2481
+
2482
def build_refactor_plan(
    groups_by_path: dict[Path, dict[str, list[set[str]]]],
    paths: list[Path],
    *,
    config: AuditConfig,
) -> dict[str, object]:
    """Plan a callee-first refactoring order for each observed bundle.

    For every bundle, collects the functions that forward it, builds a
    call-dependency graph restricted to those functions, and topologically
    schedules them (cycles reported separately). Returns a dict with
    ``bundles`` (one plan per bundle) and ``warnings``.
    """
    file_paths = _iter_paths([str(p) for p in paths], config)
    if not file_paths:
        return {"bundles": [], "warnings": ["No files available for refactor plan."]}

    by_name, by_qual = _build_function_index(
        file_paths,
        config.project_root,
        config.ignore_params,
        config.strictness,
        config.transparent_decorators,
    )
    symbol_table = _build_symbol_table(
        file_paths, config.project_root, external_filter=config.external_filter
    )
    class_index = _collect_class_index(file_paths, config.project_root)
    # Index FunctionInfo by (file, scoped function key) for bundle lookup.
    info_by_path_name: dict[tuple[Path, str], FunctionInfo] = {}
    for infos in by_name.values():
        for info in infos:
            key = _function_key(info.scope, info.name)
            info_by_path_name[(info.path, key)] = info

    # Map each normalized bundle to the functions (by qualname) that carry it.
    bundle_map: dict[tuple[str, ...], dict[str, FunctionInfo]] = defaultdict(dict)
    for path, groups in groups_by_path.items():
        for fn, bundles in groups.items():
            for bundle in bundles:
                key = tuple(sorted(bundle))
                info = info_by_path_name.get((path, fn))
                if info is not None:
                    bundle_map[key][info.qual] = info

    plans: list[dict[str, object]] = []
    # Smallest bundles first for deterministic output ordering.
    for bundle, infos in sorted(bundle_map.items(), key=lambda item: (len(item[0]), item[0])):
        if not infos:
            continue
        comp = dict(infos)
        # deps[f] = set of functions f calls within this bundle's component.
        deps: dict[str, set[str]] = {qual: set() for qual in comp}
        for info in infos.values():
            for call in info.calls:
                callee = _resolve_callee(
                    call.callee,
                    info,
                    by_name,
                    by_qual,
                    symbol_table,
                    config.project_root,
                    class_index,
                )
                if callee is None:
                    continue
                if not callee.transparent:
                    continue
                if callee.qual in comp:
                    deps[info.qual].add(callee.qual)
        schedule = topological_schedule(deps)
        plans.append(
            {
                "bundle": list(bundle),
                "functions": sorted(comp.keys()),
                "order": schedule.order,
                "cycles": [sorted(list(cycle)) for cycle in schedule.cycles],
            }
        )

    warnings: list[str] = []
    if not plans:
        warnings.append("No bundle components available for refactor plan.")
    return {"bundles": plans, "warnings": warnings}
2555
+
2556
+
2557
def render_refactor_plan(plan: dict[str, object]) -> str:
    """Render the refactoring plan as a markdown section string."""
    out: list[str] = ["", "## Refactoring plan (prototype)", ""]
    entries = plan.get("bundles", [])
    if not entries:
        out.append("No refactoring plan available.")
    else:
        for entry in entries:
            fields = entry.get("bundle", [])
            out.append(
                f"### Bundle: {', '.join(fields) if fields else '(unknown bundle)'}"
            )
            order = entry.get("order", [])
            if order:
                out.append("Order (callee-first):")
                out.append("```")
                out.extend(f"- {item}" for item in order)
                out.append("```")
            cycles = entry.get("cycles", [])
            if cycles:
                out.append("Cycles:")
                out.append("```")
                out.extend(", ".join(cycle) for cycle in cycles)
                out.append("```")
    notes = plan.get("warnings", [])
    if notes:
        out.extend(["", "Warnings:", "```"])
        out.extend(str(note) for note in notes)
        out.append("```")
    return "\n".join(out)
2589
+
2590
+
2591
+ def _render_type_mermaid(
2592
+ suggestions: list[str],
2593
+ ambiguities: list[str],
2594
+ ) -> str:
2595
+ lines = ["```mermaid", "flowchart LR"]
2596
+ node_id = 0
2597
+ def _node(label: str) -> str:
2598
+ nonlocal node_id
2599
+ node_id += 1
2600
+ node = f"type_{node_id}"
2601
+ safe = label.replace('"', "'")
2602
+ lines.append(f' {node}["{safe}"]')
2603
+ return node
2604
+
2605
+ for entry in suggestions:
2606
+ # Format: file:func.param can tighten to Type
2607
+ if " can tighten to " not in entry:
2608
+ continue
2609
+ lhs, rhs = entry.split(" can tighten to ", 1)
2610
+ src = _node(lhs)
2611
+ dst = _node(rhs)
2612
+ lines.append(f" {src} --> {dst}")
2613
+ for entry in ambiguities:
2614
+ if " downstream types conflict: " not in entry:
2615
+ continue
2616
+ lhs, rhs = entry.split(" downstream types conflict: ", 1)
2617
+ src = _node(lhs)
2618
+ # rhs is a repr of list; keep as string nodes per type
2619
+ rhs = rhs.strip()
2620
+ if rhs.startswith("[") and rhs.endswith("]"):
2621
+ rhs = rhs[1:-1]
2622
+ type_names = []
2623
+ for item in rhs.split(","):
2624
+ item = item.strip()
2625
+ if not item:
2626
+ continue
2627
+ item = item.strip("'\"")
2628
+ type_names.append(item)
2629
+ for type_name in type_names:
2630
+ dst = _node(type_name)
2631
+ lines.append(f" {src} -.-> {dst}")
2632
+ lines.append("```")
2633
+ return "\n".join(lines)
2634
+
2635
+
2636
def _compute_violations(
    groups_by_path: dict[Path, dict[str, list[set[str]]]],
    max_components: int,
    *,
    type_suggestions: list[str] | None = None,
    type_ambiguities: list[str] | None = None,
) -> list[str]:
    """Run the report pipeline purely to collect its violation lines."""
    report_and_violations = _emit_report(
        groups_by_path,
        max_components,
        type_suggestions=type_suggestions,
        type_ambiguities=type_ambiguities,
        constant_smells=[],
        unused_arg_smells=[],
    )
    return report_and_violations[1]
2652
+
2653
+
2654
+ def _resolve_baseline_path(path: str | None, root: Path) -> Path | None:
2655
+ if not path:
2656
+ return None
2657
+ baseline = Path(path)
2658
+ if not baseline.is_absolute():
2659
+ baseline = root / baseline
2660
+ return baseline
2661
+
2662
+
2663
+ def _load_baseline(path: Path) -> set[str]:
2664
+ if not path.exists():
2665
+ return set()
2666
+ try:
2667
+ raw = path.read_text()
2668
+ except OSError:
2669
+ return set()
2670
+ entries: set[str] = set()
2671
+ for line in raw.splitlines():
2672
+ line = line.strip()
2673
+ if not line or line.startswith("#"):
2674
+ continue
2675
+ entries.add(line)
2676
+ return entries
2677
+
2678
+
2679
+ def _write_baseline(path: Path, violations: list[str]) -> None:
2680
+ path.parent.mkdir(parents=True, exist_ok=True)
2681
+ unique = sorted(set(violations))
2682
+ header = [
2683
+ "# gabion baseline (ratchet)",
2684
+ "# Lines list known violations to allow; new ones should fail.",
2685
+ "",
2686
+ ]
2687
+ path.write_text("\n".join(header + unique) + "\n")
2688
+
2689
+
2690
+ def _apply_baseline(
2691
+ violations: list[str], baseline: set[str]
2692
+ ) -> tuple[list[str], list[str]]:
2693
+ if not baseline:
2694
+ return violations, []
2695
+ new = [line for line in violations if line not in baseline]
2696
+ suppressed = [line for line in violations if line in baseline]
2697
+ return new, suppressed
2698
+
2699
+
2700
def resolve_baseline_path(path: str | None, root: Path) -> Path | None:
    """Public wrapper delegating to `_resolve_baseline_path`."""
    return _resolve_baseline_path(path, root)
2702
+
2703
+
2704
def load_baseline(path: Path) -> set[str]:
    """Public wrapper delegating to `_load_baseline`."""
    return _load_baseline(path)
2706
+
2707
+
2708
def write_baseline(path: Path, violations: list[str]) -> None:
    """Public wrapper delegating to `_write_baseline`."""
    _write_baseline(path, violations)
2710
+
2711
+
2712
def apply_baseline(
    violations: list[str], baseline: set[str]
) -> tuple[list[str], list[str]]:
    """Public wrapper delegating to `_apply_baseline`."""
    return _apply_baseline(violations, baseline)
2716
+
2717
+
2718
def render_dot(groups_by_path: dict[Path, dict[str, list[set[str]]]]) -> str:
    """Public wrapper delegating to `_emit_dot`."""
    return _emit_dot(groups_by_path)
2720
+
2721
+
2722
def render_report(
    groups_by_path: dict[Path, dict[str, list[set[str]]]],
    max_components: int,
    *,
    type_suggestions: list[str] | None = None,
    type_ambiguities: list[str] | None = None,
    constant_smells: list[str] | None = None,
    unused_arg_smells: list[str] | None = None,
) -> tuple[str, list[str]]:
    """Public wrapper delegating to `_emit_report`.

    Returns the (markdown report, violations) pair produced by the report
    pipeline.
    """
    return _emit_report(
        groups_by_path,
        max_components,
        type_suggestions=type_suggestions,
        type_ambiguities=type_ambiguities,
        constant_smells=constant_smells,
        unused_arg_smells=unused_arg_smells,
    )
2739
+
2740
+
2741
def compute_violations(
    groups_by_path: dict[Path, dict[str, list[set[str]]]],
    max_components: int,
    *,
    type_suggestions: list[str] | None = None,
    type_ambiguities: list[str] | None = None,
) -> list[str]:
    """Public wrapper delegating to `_compute_violations`."""
    return _compute_violations(
        groups_by_path,
        max_components,
        type_suggestions=type_suggestions,
        type_ambiguities=type_ambiguities,
    )
2754
+
2755
+
2756
def analyze_paths(
    paths: list[Path],
    *,
    recursive: bool,
    type_audit: bool,
    type_audit_report: bool,
    type_audit_max: int,
    include_constant_smells: bool,
    include_unused_arg_smells: bool,
    config: AuditConfig | None = None,
) -> AnalysisResult:
    """Run the full per-file bundle analysis plus optional repo-wide audits.

    Analyzes every file under *paths* for forwarding bundles, then optionally
    runs the type-flow, constant-flow, and unused-argument analyses, and
    packages everything into an ``AnalysisResult``.
    """
    if config is None:
        config = AuditConfig()
    file_paths = _iter_paths([str(p) for p in paths], config)
    groups_by_path: dict[Path, dict[str, list[set[str]]]] = {}
    param_spans_by_path: dict[Path, dict[str, dict[str, tuple[int, int, int, int]]]] = {}
    for path in file_paths:
        groups, spans = analyze_file(path, recursive=recursive, config=config)
        groups_by_path[path] = groups
        param_spans_by_path[path] = spans

    type_suggestions: list[str] = []
    type_ambiguities: list[str] = []
    if type_audit or type_audit_report:
        type_suggestions, type_ambiguities = analyze_type_flow_repo(
            file_paths,
            project_root=config.project_root,
            ignore_params=config.ignore_params,
            strictness=config.strictness,
            external_filter=config.external_filter,
            transparent_decorators=config.transparent_decorators,
        )
        if type_audit_report:
            # The report variant caps entries; the CLI audit keeps them all.
            type_suggestions = type_suggestions[:type_audit_max]
            type_ambiguities = type_ambiguities[:type_audit_max]

    constant_smells: list[str] = []
    if include_constant_smells:
        constant_smells = analyze_constant_flow_repo(
            file_paths,
            project_root=config.project_root,
            ignore_params=config.ignore_params,
            strictness=config.strictness,
            external_filter=config.external_filter,
            transparent_decorators=config.transparent_decorators,
        )

    unused_arg_smells: list[str] = []
    if include_unused_arg_smells:
        unused_arg_smells = analyze_unused_arg_flow_repo(
            file_paths,
            project_root=config.project_root,
            ignore_params=config.ignore_params,
            strictness=config.strictness,
            external_filter=config.external_filter,
            transparent_decorators=config.transparent_decorators,
        )

    return AnalysisResult(
        groups_by_path=groups_by_path,
        param_spans_by_path=param_spans_by_path,
        type_suggestions=type_suggestions,
        type_ambiguities=type_ambiguities,
        constant_smells=constant_smells,
        unused_arg_smells=unused_arg_smells,
    )
2822
+
2823
+
2824
+ def _build_parser() -> argparse.ArgumentParser:
2825
+ parser = argparse.ArgumentParser()
2826
+ parser.add_argument("paths", nargs="+")
2827
+ parser.add_argument("--root", default=".", help="Project root for module resolution.")
2828
+ parser.add_argument("--config", default=None, help="Path to gabion.toml.")
2829
+ parser.add_argument(
2830
+ "--exclude",
2831
+ action="append",
2832
+ default=None,
2833
+ help="Comma-separated directory names to exclude (repeatable).",
2834
+ )
2835
+ parser.add_argument(
2836
+ "--ignore-params",
2837
+ default=None,
2838
+ help="Comma-separated parameter names to ignore.",
2839
+ )
2840
+ parser.add_argument(
2841
+ "--transparent-decorators",
2842
+ default=None,
2843
+ help="Comma-separated decorator names treated as transparent.",
2844
+ )
2845
+ parser.add_argument(
2846
+ "--allow-external",
2847
+ action=argparse.BooleanOptionalAction,
2848
+ default=None,
2849
+ help="Allow resolving calls into external libraries.",
2850
+ )
2851
+ parser.add_argument(
2852
+ "--strictness",
2853
+ choices=["high", "low"],
2854
+ default=None,
2855
+ help="Wildcard forwarding strictness (default: high).",
2856
+ )
2857
+ parser.add_argument("--no-recursive", action="store_true")
2858
+ parser.add_argument("--dot", default=None, help="Write DOT graph to file or '-' for stdout.")
2859
+ parser.add_argument("--report", default=None, help="Write Markdown report (mermaid) to file.")
2860
+ parser.add_argument("--max-components", type=int, default=10, help="Max components in report.")
2861
+ parser.add_argument(
2862
+ "--type-audit",
2863
+ action="store_true",
2864
+ help="Emit type-tightening suggestions based on downstream annotations.",
2865
+ )
2866
+ parser.add_argument(
2867
+ "--type-audit-max",
2868
+ type=int,
2869
+ default=50,
2870
+ help="Max type-tightening entries to print.",
2871
+ )
2872
+ parser.add_argument(
2873
+ "--type-audit-report",
2874
+ action="store_true",
2875
+ help="Include type-flow audit summary in the markdown report.",
2876
+ )
2877
+ parser.add_argument(
2878
+ "--fail-on-type-ambiguities",
2879
+ action="store_true",
2880
+ help="Exit non-zero if type ambiguities are detected.",
2881
+ )
2882
+ parser.add_argument(
2883
+ "--fail-on-violations",
2884
+ action="store_true",
2885
+ help="Exit non-zero if undocumented/undeclared bundle violations are detected.",
2886
+ )
2887
+ parser.add_argument(
2888
+ "--baseline",
2889
+ default=None,
2890
+ help="Baseline file of violations to allow (ratchet mode).",
2891
+ )
2892
+ parser.add_argument(
2893
+ "--baseline-write",
2894
+ action="store_true",
2895
+ help="Write the current violations to the baseline file and exit zero.",
2896
+ )
2897
+ parser.add_argument(
2898
+ "--synthesis-plan",
2899
+ default=None,
2900
+ help="Write synthesis plan JSON to file or '-' for stdout.",
2901
+ )
2902
+ parser.add_argument(
2903
+ "--synthesis-report",
2904
+ action="store_true",
2905
+ help="Include synthesis plan summary in the markdown report.",
2906
+ )
2907
+ parser.add_argument(
2908
+ "--synthesis-protocols",
2909
+ default=None,
2910
+ help="Write protocol/dataclass stubs to file or '-' for stdout.",
2911
+ )
2912
+ parser.add_argument(
2913
+ "--synthesis-protocols-kind",
2914
+ choices=["dataclass", "protocol"],
2915
+ default="dataclass",
2916
+ help="Emit dataclass or typing.Protocol stubs (default: dataclass).",
2917
+ )
2918
+ parser.add_argument(
2919
+ "--refactor-plan",
2920
+ action="store_true",
2921
+ help="Include refactoring plan summary in the markdown report.",
2922
+ )
2923
+ parser.add_argument(
2924
+ "--refactor-plan-json",
2925
+ default=None,
2926
+ help="Write refactoring plan JSON to file or '-' for stdout.",
2927
+ )
2928
+ parser.add_argument(
2929
+ "--synthesis-max-tier",
2930
+ type=int,
2931
+ default=2,
2932
+ help="Max tier to include in synthesis plan.",
2933
+ )
2934
+ parser.add_argument(
2935
+ "--synthesis-min-bundle-size",
2936
+ type=int,
2937
+ default=2,
2938
+ help="Min bundle size to include in synthesis plan.",
2939
+ )
2940
+ parser.add_argument(
2941
+ "--synthesis-allow-singletons",
2942
+ action="store_true",
2943
+ help="Allow single-field bundles in synthesis plan.",
2944
+ )
2945
+ return parser
2946
+
2947
+
2948
+ def _normalize_transparent_decorators(
2949
+ value: object,
2950
+ ) -> set[str] | None:
2951
+ if value is None:
2952
+ return None
2953
+ items: list[str] = []
2954
+ if isinstance(value, str):
2955
+ items = [part.strip() for part in value.split(",") if part.strip()]
2956
+ elif isinstance(value, (list, tuple, set)):
2957
+ for item in value:
2958
+ if isinstance(item, str):
2959
+ items.extend([part.strip() for part in item.split(",") if part.strip()])
2960
+ if not items:
2961
+ return None
2962
+ return set(items)
2963
+
2964
+
2965
def run(argv: list[str] | None = None) -> int:
    """CLI driver: parse args, run the audit, emit plans/reports, return an exit code.

    Exit codes (as established by the returns below):
      0 — success (or --baseline-write completed);
      1 — new violations, or type ambiguities under --fail-on-type-ambiguities;
      2 — usage error (--baseline-write without a resolvable baseline path).

    NOTE(review): this chunk comes from a diff rendering where indentation was
    lost; block nesting below is reconstructed from the control-flow logic —
    confirm against the original file.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    # --fail-on-type-ambiguities implies running the type audit.
    if args.fail_on_type_ambiguities:
        args.type_audit = True
    # Split comma-separated --exclude entries into a flat directory list.
    exclude_dirs: list[str] | None = None
    if args.exclude is not None:
        exclude_dirs = []
        for entry in args.exclude:
            for part in entry.split(","):
                part = part.strip()
                if part:
                    exclude_dirs.append(part)
    ignore_params: list[str] | None = None
    if args.ignore_params is not None:
        ignore_params = [p.strip() for p in args.ignore_params.split(",") if p.strip()]
    transparent_decorators: list[str] | None = None
    if args.transparent_decorators is not None:
        transparent_decorators = [
            p.strip() for p in args.transparent_decorators.split(",") if p.strip()
        ]
    # Merge CLI values over config-file defaults; None CLI values defer to defaults.
    config_path = Path(args.config) if args.config else None
    defaults = dataflow_defaults(Path(args.root), config_path)
    merged = merge_payload(
        {
            "exclude": exclude_dirs,
            "ignore_params": ignore_params,
            "allow_external": args.allow_external,
            "strictness": args.strictness,
            "baseline": args.baseline,
            "transparent_decorators": transparent_decorators,
        },
        defaults,
    )
    exclude_dirs = set(merged.get("exclude", []) or [])
    ignore_params_set = set(merged.get("ignore_params", []) or [])
    allow_external = bool(merged.get("allow_external", False))
    # Unknown strictness values silently fall back to "high".
    strictness = merged.get("strictness") or "high"
    if strictness not in {"high", "low"}:
        strictness = "high"
    transparent_decorators = _normalize_transparent_decorators(
        merged.get("transparent_decorators")
    )
    config = AuditConfig(
        project_root=Path(args.root),
        exclude_dirs=exclude_dirs,
        ignore_params=ignore_params_set,
        external_filter=not allow_external,
        strictness=strictness,
        transparent_decorators=transparent_decorators,
    )
    baseline_path = _resolve_baseline_path(merged.get("baseline"), Path(args.root))
    baseline_write = args.baseline_write
    if baseline_write and baseline_path is None:
        print("Baseline path required for --baseline-write.", file=sys.stderr)
        return 2  # usage error
    paths = _iter_paths(args.paths, config)
    # Smell collection is only needed when a markdown report will be written.
    analysis = analyze_paths(
        paths,
        recursive=not args.no_recursive,
        type_audit=args.type_audit or args.type_audit_report,
        type_audit_report=args.type_audit_report,
        type_audit_max=args.type_audit_max,
        include_constant_smells=bool(args.report),
        include_unused_arg_smells=bool(args.report),
        config=config,
    )
    # Optional synthesis plan (JSON file / stdout) and protocol stubs.
    synthesis_plan: dict[str, object] | None = None
    if args.synthesis_plan or args.synthesis_report or args.synthesis_protocols:
        synthesis_plan = build_synthesis_plan(
            analysis.groups_by_path,
            project_root=config.project_root,
            max_tier=args.synthesis_max_tier,
            min_bundle_size=args.synthesis_min_bundle_size,
            allow_singletons=args.synthesis_allow_singletons,
            config=config,
        )
    if args.synthesis_plan:
        payload = json.dumps(synthesis_plan, indent=2, sort_keys=True)
        # "-" means write to stdout, matching the --dot / --refactor-plan-json convention.
        if args.synthesis_plan.strip() == "-":
            print(payload)
        else:
            Path(args.synthesis_plan).write_text(payload)
    if args.synthesis_protocols:
        stubs = render_protocol_stubs(
            synthesis_plan, kind=args.synthesis_protocols_kind
        )
        if args.synthesis_protocols.strip() == "-":
            print(stubs)
        else:
            Path(args.synthesis_protocols).write_text(stubs)
    # Optional refactoring plan (JSON file / stdout).
    refactor_plan: dict[str, object] | None = None
    if args.refactor_plan or args.refactor_plan_json:
        refactor_plan = build_refactor_plan(
            analysis.groups_by_path,
            paths,
            config=config,
        )
    if args.refactor_plan_json:
        payload = json.dumps(refactor_plan, indent=2, sort_keys=True)
        if args.refactor_plan_json.strip() == "-":
            print(payload)
        else:
            Path(args.refactor_plan_json).write_text(payload)
    # Optional DOT graph of bundle candidates.
    if args.dot is not None:
        dot = _emit_dot(analysis.groups_by_path)
        if args.dot.strip() == "-":
            print(dot)
        else:
            Path(args.dot).write_text(dot)
        # presumably: dot-only invocations stop here when no report was
        # requested — TODO confirm this nesting against the original file.
        if args.report is None:
            return 0
    # Console output of the type audit findings.
    if args.type_audit:
        if analysis.type_suggestions:
            print("Type tightening candidates:")
            for line in analysis.type_suggestions[: args.type_audit_max]:
                print(f"- {line}")
        if analysis.type_ambiguities:
            print("Type ambiguities (conflicting downstream expectations):")
            for line in analysis.type_ambiguities[: args.type_audit_max]:
                print(f"- {line}")
    # No report and no plan outputs requested: nothing more to do.
    if args.report is None and not (
        args.synthesis_plan
        or args.synthesis_report
        or args.synthesis_protocols
        or args.refactor_plan
        or args.refactor_plan_json
    ):
        return 0
    if args.report is not None:
        # Markdown report path: emit report, apply baseline ratchet, append
        # optional synthesis/refactor sections, then decide the exit code.
        report, violations = _emit_report(
            analysis.groups_by_path,
            args.max_components,
            type_suggestions=analysis.type_suggestions if args.type_audit_report else None,
            type_ambiguities=analysis.type_ambiguities if args.type_audit_report else None,
            constant_smells=analysis.constant_smells,
            unused_arg_smells=analysis.unused_arg_smells,
        )
        suppressed: list[str] = []
        new_violations = violations
        if baseline_path is not None:
            baseline_entries = _load_baseline(baseline_path)
            if baseline_write:
                # Rewrite the baseline to the current violations; by
                # definition nothing is "new" after a baseline write.
                _write_baseline(baseline_path, violations)
                baseline_entries = set(violations)
                new_violations = []
            else:
                new_violations, suppressed = _apply_baseline(
                    violations, baseline_entries
                )
            report = (
                report
                + "\n\nBaseline/Ratchet:\n```\n"
                + f"Baseline: {baseline_path}\n"
                + f"Baseline entries: {len(baseline_entries)}\n"
                + f"Suppressed: {len(suppressed)}\n"
                + f"New violations: {len(new_violations)}\n"
                + "```\n"
            )
        if synthesis_plan and (
            args.synthesis_report or args.synthesis_plan or args.synthesis_protocols
        ):
            report = report + render_synthesis_section(synthesis_plan)
        if refactor_plan and (args.refactor_plan or args.refactor_plan_json):
            report = report + render_refactor_plan(refactor_plan)
        Path(args.report).write_text(report)
        if args.fail_on_violations and violations:
            if baseline_write:
                return 0
            if new_violations:
                return 1
        return 0
    # Plain listing mode (no report): print bundles per file/function.
    for path, groups in analysis.groups_by_path.items():
        print(f"# {path}")
        for fn, bundles in groups.items():
            if not bundles:
                continue
            print(f"{fn}:")
            for bundle in bundles:
                print(f"  bundle: {sorted(bundle)}")
        print()
    if args.fail_on_type_ambiguities and analysis.type_ambiguities:
        return 1
    if args.fail_on_violations:
        # Violations were not computed by _emit_report in this mode, so
        # compute them directly before applying the baseline ratchet.
        violations = _compute_violations(
            analysis.groups_by_path,
            args.max_components,
            type_suggestions=analysis.type_suggestions if args.type_audit_report else None,
            type_ambiguities=analysis.type_ambiguities if args.type_audit_report else None,
        )
        if baseline_path is not None:
            baseline_entries = _load_baseline(baseline_path)
            if baseline_write:
                _write_baseline(baseline_path, violations)
                return 0
            new_violations, _ = _apply_baseline(violations, baseline_entries)
            if new_violations:
                return 1
        elif violations:
            return 1
    return 0
3166
+
3167
+
3168
def main() -> None:
    """Console-script entry point: exit the process with run()'s status code."""
    # sys.exit raises SystemExit, identical to `raise SystemExit(run())`.
    sys.exit(run())
3170
+
3171
+
3172
+ if __name__ == "__main__":
3173
+ main()