sarj-python-lint 0.6.0__tar.gz → 0.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/PKG-INFO +1 -1
  2. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/pyproject.toml +1 -1
  3. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/__main__.py +15 -6
  4. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/no_aggregation_in_store_query.py +3 -3
  5. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/no_fat_try_blocks.py +30 -12
  6. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/no_query_with_many_joins.py +3 -3
  7. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/no_select_star.py +3 -3
  8. sarj_python_lint-0.7.0/src/sarj_python_lint/rules/no_sequential_await.py +196 -0
  9. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/store_insert_requires_on_conflict.py +3 -3
  10. sarj_python_lint-0.6.0/src/sarj_python_lint/rules/no_sequential_await.py +0 -94
  11. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/.gitignore +0 -0
  12. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/README.md +0 -0
  13. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/__init__.py +0 -0
  14. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/py.typed +0 -0
  15. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rule_base.py +0 -0
  16. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/__init__.py +0 -0
  17. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/_logging.py +0 -0
  18. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/_registry.py +0 -0
  19. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/inefficient_string_concat_in_loop.py +0 -0
  20. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/no_comment_cruft.py +0 -0
  21. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/no_fstring_in_log.py +0 -0
  22. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/no_isinstance_union_chain.py +0 -0
  23. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/no_secret_in_log.py +0 -0
  24. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/no_sentinel_return_on_except.py +0 -0
  25. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/no_unreachable_after_terminal.py +0 -0
  26. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/prefer_class_row.py +0 -0
  27. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/prefer_constant_time_secret_compare.py +0 -0
  28. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/prefer_discriminated_union.py +0 -0
  29. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/prefer_str_enum.py +0 -0
  30. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/prefer_struct_over_namedtuple.py +0 -0
  31. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/prefer_timedelta_for_durations.py +0 -0
  32. {sarj_python_lint-0.6.0 → sarj_python_lint-0.7.0}/src/sarj_python_lint/rules/pydantic_at_boundaries.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sarj-python-lint
3
- Version: 0.6.0
3
+ Version: 0.7.0
4
4
  Summary: Custom Python lint rules — AST-based, pre-commit-friendly, hypermodern defaults
5
5
  Project-URL: Homepage, https://github.com/sarj-ai/standards/tree/main/packages/python
6
6
  Project-URL: Repository, https://github.com/sarj-ai/standards
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "sarj-python-lint"
3
- version = "0.6.0"
3
+ version = "0.7.0"
4
4
  description = "Custom Python lint rules — AST-based, pre-commit-friendly, hypermodern defaults"
5
5
  readme = "README.md"
6
6
  authors = [{ name = "sarj-ai" }]
@@ -66,6 +66,18 @@ def _check(rule_ids: list[str], paths: list[Path]) -> list[Diagnostic]:
66
66
  return diags
67
67
 
68
68
 
69
+ class _Args(argparse.Namespace):
70
+ cmd: str | None
71
+ rule: list[str]
72
+ files: list[Path]
73
+
74
+ def __init__(self) -> None:
75
+ super().__init__()
76
+ self.cmd = None
77
+ self.rule = []
78
+ self.files = []
79
+
80
+
69
81
  def main(argv: list[str] | None = None) -> int:
70
82
  parser = argparse.ArgumentParser(
71
83
  prog="sarj-python-lint",
@@ -85,18 +97,15 @@ def main(argv: list[str] | None = None) -> int:
85
97
 
86
98
  sub.add_parser("list-rules", help="List available rule IDs.")
87
99
 
88
- args = parser.parse_args(argv)
89
- cmd: str | None = args.cmd
100
+ args = parser.parse_args(argv, namespace=_Args())
90
101
 
91
- if cmd == "list-rules":
102
+ if args.cmd == "list-rules":
92
103
  for rid, cls in sorted(REGISTRY.items()):
93
104
  inst = cls()
94
105
  sys.stdout.write(f"{inst.code:8} {rid:40} {inst.description}\n")
95
106
  return 0
96
107
 
97
- rule_ids: list[str] = args.rule
98
- files: list[Path] = args.files
99
- diags = _check(rule_ids, files)
108
+ diags = _check(args.rule, args.files)
100
109
  for d in diags:
101
110
  sys.stdout.write(d.format() + "\n")
102
111
  return 1 if diags else 0
@@ -82,9 +82,9 @@ def _strip_sql_comments(text: str) -> str:
82
82
  class NoAggregationInStoreQuery(Rule):
83
83
  """DISTINCT / GROUP BY / COUNT in a store query — aggregate in ClickHouse."""
84
84
 
85
- id = "no-aggregation-in-store-query"
86
- code = "SARJ020"
87
- description = (
85
+ id: str = "no-aggregation-in-store-query"
86
+ code: str = "SARJ020"
87
+ description: str = (
88
88
  "DISTINCT / GROUP BY / COUNT in a Postgres store query — push heavy "
89
89
  "aggregation to the columnar mirror (ClickHouse / BigQuery)."
90
90
  )
@@ -1,4 +1,4 @@
1
- """SARJ007: `try` block whose body has more than 3 top-level statements.
1
+ """SARJ007: `try` block with more than 3 top-level statements that can raise.
2
2
 
3
3
  A fat `try` body obscures which statement is actually expected to raise and
4
4
  widens the blast radius of the `except` handlers: unrelated failures get
@@ -6,13 +6,20 @@ caught (and often swallowed or mis-reported) by handlers written for a
6
6
  different operation. Keep the `try` skinny — isolate the throwing
7
7
  statement(s) and move the non-throwing setup and follow-up work outside.
8
8
 
9
- Only the top-level statements of the `try` body are counted; statements
10
- nested inside an `if` / `with` / loop within the body count as the single
11
- compound statement that contains them. Nested `try` blocks are checked
12
- independently. `try*` (PEP 654 except-groups) is held to the same limit.
13
-
14
- This is a direct Python port of the org's ESLint restriction
15
- `TryStatement > BlockStatement[body.length > 3]` in eslint.strict.mjs.
9
+ Two refinements keep the count aligned with that intent and avoid the
10
+ false-positive patterns that dominated real-world suppressions:
11
+
12
+ * Only top-level statements that *can raise* are counted: a statement counts
13
+ toward the limit only if its subtree contains a call or `await`. Pure
14
+ assignments / name-rebinds (`self.x = y`, `a = b.c`) don't obscure a throwing
15
+ statement and are free. Statements nested inside an `if` / `with` / loop
16
+ count as the single compound statement that contains them. Nested `try`
17
+ blocks are checked independently. `try*` (PEP 654) is held to the same limit.
18
+ * `try` blocks that carry an `else` or `finally` clause are exempt. Those
19
+ clauses are a deliberate success/cleanup contract that couples the body to
20
+ the handler (a `finally` that tears down a resource, an `else`/`finally` that
21
+ reads a status the body set) — statements can't be freely hoisted out without
22
+ changing semantics, so the length check is counterproductive there.
16
23
 
17
24
  Instead of:
18
25
  try:
@@ -52,12 +59,18 @@ if TYPE_CHECKING:
52
59
  _MAX_TRY_BODY_STATEMENTS = 3
53
60
 
54
61
 
62
+ def _can_raise(stmt: ast.stmt) -> bool:
63
+ """True if the statement's subtree contains a call or `await` — i.e. it can
64
+ plausibly raise. Pure assignments / rebinds with no call do not count."""
65
+ return any(isinstance(n, (ast.Call, ast.Await)) for n in ast.walk(stmt))
66
+
67
+
55
68
  class NoFatTryBlocks(Rule):
56
- """Try body longer than 3 statements — isolate the throwing statement(s)."""
69
+ """Try body with too many throwing statements — isolate the one that raises."""
57
70
 
58
71
  id: str = "no-fat-try-blocks"
59
72
  code: str = "SARJ007"
60
- description: str = "Try block body exceeds 3 statements — keep try blocks skinny."
73
+ description: str = "Try block has too many throwing statements — keep try blocks skinny."
61
74
 
62
75
  @override
63
76
  def check(self, path: Path, source: str) -> list[Diagnostic]:
@@ -69,7 +82,12 @@ class NoFatTryBlocks(Rule):
69
82
  for node in ast.walk(tree):
70
83
  if not isinstance(node, (ast.Try, ast.TryStar)):
71
84
  continue
72
- if len(node.body) <= _MAX_TRY_BODY_STATEMENTS:
85
+ # An `else`/`finally` clause is a deliberate success/cleanup contract
86
+ # that couples the body to the handler — don't fight it on length.
87
+ if node.orelse or node.finalbody:
88
+ continue
89
+ throwing = sum(_can_raise(stmt) for stmt in node.body)
90
+ if throwing <= _MAX_TRY_BODY_STATEMENTS:
73
91
  continue
74
92
  diags.append(
75
93
  Diagnostic(
@@ -78,7 +96,7 @@ class NoFatTryBlocks(Rule):
78
96
  col=node.col_offset + 1,
79
97
  code=self.code,
80
98
  message=(
81
- f"try block has {len(node.body)} statements "
99
+ f"try block has {throwing} statements that can raise "
82
100
  f"(max {_MAX_TRY_BODY_STATEMENTS}) — try blocks should "
83
101
  "isolate the throwing statement(s); move non-throwing "
84
102
  "work outside the try."
@@ -61,9 +61,9 @@ def _strip_sql_comments(text: str) -> str:
61
61
  class NoQueryWithManyJoins(Rule):
62
62
  """A SQL query with 3+ JOINs is too entangled — split it or denormalize."""
63
63
 
64
- id = "no-query-with-many-joins"
65
- code = "SARJ019"
66
- description = (
64
+ id: str = "no-query-with-many-joins"
65
+ code: str = "SARJ019"
66
+ description: str = (
67
67
  "SQL query with 3 or more JOINs — split the query or denormalize instead of fanning across many tables."
68
68
  )
69
69
 
@@ -67,9 +67,9 @@ def _has_real_select_star(sql: str) -> bool:
67
67
  class NoSelectStar(Rule):
68
68
  """`SELECT *` in a store query — list the columns explicitly."""
69
69
 
70
- id = "no-select-star"
71
- code = "SARJ021"
72
- description = (
70
+ id: str = "no-select-star"
71
+ code: str = "SARJ021"
72
+ description: str = (
73
73
  "SELECT * in a store query — name the columns; * over-fetches and breaks "
74
74
  "class_row mapping when the schema changes."
75
75
  )
@@ -0,0 +1,196 @@
1
+ """SARJ001: detect the `for x in xs: await f(x)` gather antipattern.
2
+
3
+ Sequential `await` in a for-loop serializes I/O that could be parallelized
4
+ with `asyncio.gather([f(x) for x in xs])`. The performance gap is often 10-100x
5
+ for network-bound work (HTTP, DB queries, LLM calls).
6
+
7
+ Deliberately narrow, to flag the textbook antipattern and almost nothing else —
8
+ an over-broad version drowned real signal under suppressions. The rule fires
9
+ only for:
10
+
11
+ * a `for` loop whose body is **straight-line** (no `if`/`try`/`with`/`return`/
12
+ `break`/`continue`/`raise`/nested loop — those signal conditional or ordered
13
+ logic, not a parallel map) and awaits a call that **uses the loop variable**
14
+ (so each iteration is a distinct, independent call); or
15
+ * a comprehension / generator expression with an `await` in its element or a
16
+ per-element `if` (those have no ordered side effects).
17
+
18
+ It does NOT fire for: `while` loops (pagination, polling, queue drains — length
19
+ unknown, inherently sequential), a loop's once-evaluated iterable
20
+ (`for x in await fetch()`), `async for`, test modules (intentional ordering),
21
+ or a `for` body containing control flow. Those were the false-positive sources.
22
+
23
+ References:
24
+ - https://docs.python.org/3/library/asyncio-task.html#running-tasks-concurrently
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import ast
30
+ from typing import TYPE_CHECKING, override
31
+
32
+ from sarj_python_lint.rule_base import Diagnostic, Rule
33
+
34
+
35
+ if TYPE_CHECKING:
36
+ from pathlib import Path
37
+
38
+
39
+ def _is_test_path(path: Path) -> bool:
40
+ name = path.name
41
+ if name == "conftest.py" or name.startswith("test_") or name.endswith("_test.py"):
42
+ return True
43
+ return any(part in {"tests", "test"} for part in path.parts)
44
+
45
+
46
+ class NoSequentialAwait(Rule):
47
+ """Sequential await calls in a loop that could be parallelized."""
48
+
49
+ id: str = "no-sequential-await"
50
+ code: str = "SARJ001"
51
+ description: str = "Sequential `await` in a for-loop — prefer asyncio.gather."
52
+
53
+ @override
54
+ def check(self, path: Path, source: str) -> list[Diagnostic]:
55
+ if _is_test_path(path):
56
+ return []
57
+ try:
58
+ tree = ast.parse(source, filename=str(path))
59
+ except SyntaxError:
60
+ return []
61
+ visitor = _SequentialAwaitVisitor()
62
+ visitor.visit(tree)
63
+ diags = [
64
+ Diagnostic(
65
+ path=path,
66
+ line=node.lineno,
67
+ col=node.col_offset + 1,
68
+ code=self.code,
69
+ message=("Sequential `await` inside `for` — prefer `asyncio.gather([f(x) for x in xs])`."),
70
+ )
71
+ for node in visitor.hits
72
+ ]
73
+ diags.sort(key=lambda d: (d.line, d.col))
74
+ return diags
75
+
76
+
77
+ # A loop's *iterable* is evaluated once in the enclosing scope, NOT per element:
78
+ # `for x in await fetch()` / `{x for x in await fetch()}` await once. Iterables
79
+ # are visited *before* the loop is pushed, so an await there attributes to an
80
+ # enclosing loop (if any), not this one.
81
+ _SCOPES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)
82
+
83
+ # Top-level body statements that signal conditional or ordered logic rather than
84
+ # a straight-line parallel map. A `for` whose body contains any of these is not
85
+ # treated as the gather antipattern.
86
+ _CONTROL_FLOW = (
87
+ ast.If,
88
+ ast.For,
89
+ ast.AsyncFor,
90
+ ast.While,
91
+ ast.With,
92
+ ast.AsyncWith,
93
+ ast.Try,
94
+ ast.Match,
95
+ ast.Return,
96
+ ast.Break,
97
+ ast.Continue,
98
+ ast.Raise,
99
+ )
100
+
101
+
102
+ def _names(node: ast.AST) -> set[str]:
103
+ return {n.id for n in ast.walk(node) if isinstance(n, ast.Name)}
104
+
105
+
106
+ def _is_gather_antipattern(node: ast.For) -> bool:
107
+ """True for `for x in xs: <straight-line body awaiting a call that uses x>`."""
108
+ if any(isinstance(stmt, _CONTROL_FLOW) for stmt in node.body):
109
+ return False
110
+ targets = _names(node.target)
111
+ for stmt in node.body:
112
+ for inner in ast.walk(stmt):
113
+ if isinstance(inner, ast.Await) and _names(inner) & targets:
114
+ return True
115
+ return False
116
+
117
+
118
+ class _SequentialAwaitVisitor(ast.NodeVisitor):
119
+ """Single O(n) pass: flag the first per-iteration `await` of each loop.
120
+
121
+ Maintains a stack of enclosing loops within the current function. The stack
122
+ resets at function boundaries so a loop in an outer function never claims an
123
+ `await` in a nested one. Each loop is flagged at most once. A loop's
124
+ once-evaluated iterable is excluded (see module comment).
125
+ """
126
+
127
+ def __init__(self) -> None:
128
+ super().__init__()
129
+ self._loops: list[ast.AST] = []
130
+ self._flagged: set[int] = set()
131
+ self.hits: list[ast.Await] = []
132
+
133
+ def _flag_if_in_loop(self, node: ast.Await) -> None:
134
+ if self._loops:
135
+ loop = self._loops[-1]
136
+ if id(loop) not in self._flagged:
137
+ self._flagged.add(id(loop))
138
+ self.hits.append(node)
139
+
140
+ def visit_For(self, node: ast.For) -> None:
141
+ # `<iter>` runs once in the enclosing scope; visit it before entering.
142
+ self.visit(node.iter)
143
+ # Only a straight-line per-element-await body is the gather antipattern;
144
+ # control-flow bodies (conditional/ordered) are not pushed, so awaits in
145
+ # them are not flagged for this loop.
146
+ antipattern = _is_gather_antipattern(node)
147
+ if antipattern:
148
+ self._loops.append(node)
149
+ self.visit(node.target)
150
+ for stmt in (*node.body, *node.orelse):
151
+ self.visit(stmt)
152
+ if antipattern:
153
+ self._loops.pop()
154
+
155
+ def _visit_comprehension(self, node: ast.AST, elements: tuple[ast.expr, ...]) -> None:
156
+ gens: list[ast.comprehension] = node.generators # pyright: ignore[reportAttributeAccessIssue]
157
+ # Outermost iterable is evaluated once in the enclosing scope.
158
+ self.visit(gens[0].iter)
159
+ self._loops.append(node)
160
+ for elt in elements:
161
+ self.visit(elt)
162
+ self.visit(gens[0].target)
163
+ for cond in gens[0].ifs:
164
+ self.visit(cond)
165
+ # Later generators iterate per element of the preceding one.
166
+ for gen in gens[1:]:
167
+ self.visit(gen.iter)
168
+ self.visit(gen.target)
169
+ for cond in gen.ifs:
170
+ self.visit(cond)
171
+ self._loops.pop()
172
+
173
+ def visit_ListComp(self, node: ast.ListComp) -> None:
174
+ self._visit_comprehension(node, (node.elt,))
175
+
176
+ def visit_SetComp(self, node: ast.SetComp) -> None:
177
+ self._visit_comprehension(node, (node.elt,))
178
+
179
+ def visit_GeneratorExp(self, node: ast.GeneratorExp) -> None:
180
+ self._visit_comprehension(node, (node.elt,))
181
+
182
+ def visit_DictComp(self, node: ast.DictComp) -> None:
183
+ self._visit_comprehension(node, (node.key, node.value))
184
+
185
+ @override
186
+ def generic_visit(self, node: ast.AST) -> None:
187
+ if isinstance(node, _SCOPES):
188
+ saved = self._loops
189
+ self._loops = []
190
+ super().generic_visit(node)
191
+ self._loops = saved
192
+ elif isinstance(node, ast.Await):
193
+ self._flag_if_in_loop(node)
194
+ super().generic_visit(node)
195
+ else:
196
+ super().generic_visit(node)
@@ -61,9 +61,9 @@ def _strip_sql_comments(text: str) -> str:
61
61
  class StoreInsertRequiresOnConflict(Rule):
62
62
  """Embedded INSERT in store code without ON CONFLICT — store writes must be upserts."""
63
63
 
64
- id = "store-insert-requires-on-conflict"
65
- code = "SARJ018"
66
- description = "Embedded SQL INSERT in store code without ON CONFLICT — store writes must be idempotent upserts."
64
+ id: str = "store-insert-requires-on-conflict"
65
+ code: str = "SARJ018"
66
+ description: str = "Embedded SQL INSERT in store code without ON CONFLICT — store writes must be idempotent upserts."
67
67
 
68
68
  @override
69
69
  def check(self, path: Path, source: str) -> list[Diagnostic]:
@@ -1,94 +0,0 @@
1
- """SARJ001: detect `for x in xs: await f(x)` patterns.
2
-
3
- Sequential `await` in a for-loop serializes I/O that could be parallelized
4
- with `asyncio.gather([f(x) for x in xs])`. The performance gap is often 10-100x
5
- for network-bound work (HTTP, DB queries, LLM calls).
6
-
7
- References:
8
- - https://docs.python.org/3/library/asyncio-task.html#running-tasks-concurrently
9
- """
10
-
11
- from __future__ import annotations
12
-
13
- import ast
14
- from typing import TYPE_CHECKING, override
15
-
16
- from sarj_python_lint.rule_base import Diagnostic, Rule
17
-
18
-
19
- if TYPE_CHECKING:
20
- from pathlib import Path
21
-
22
-
23
- class NoSequentialAwait(Rule):
24
- """Sequential await calls in a loop that could be parallelized."""
25
-
26
- id: str = "no-sequential-await"
27
- code: str = "SARJ001"
28
- description: str = "Sequential `await` in a for-loop — prefer asyncio.gather."
29
-
30
- @override
31
- def check(self, path: Path, source: str) -> list[Diagnostic]:
32
- try:
33
- tree = ast.parse(source, filename=str(path))
34
- except SyntaxError:
35
- return []
36
- visitor = _SequentialAwaitVisitor()
37
- visitor.visit(tree)
38
- diags = [
39
- Diagnostic(
40
- path=path,
41
- line=node.lineno,
42
- col=node.col_offset + 1,
43
- code=self.code,
44
- message=(
45
- "Sequential `await` inside `for` — prefer "
46
- "`asyncio.gather([f(x) for x in xs])`."
47
- ),
48
- )
49
- for node in visitor.hits
50
- ]
51
- diags.sort(key=lambda d: (d.line, d.col))
52
- return diags
53
-
54
-
55
- # Loop-like constructs whose body runs once per element: `await` inside one of
56
- # them serializes the iterations. `async for` is deliberately absent — it is the
57
- # parallel-iteration construct, not the antipattern.
58
- _LOOPS = (ast.For, ast.While, ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp)
59
- _SCOPES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)
60
-
61
-
62
- class _SequentialAwaitVisitor(ast.NodeVisitor):
63
- """Single O(n) pass: flag the first `await` of each enclosing loop.
64
-
65
- Maintains a stack of enclosing loops within the current function. The stack
66
- resets at function boundaries so a loop in an outer function never claims an
67
- `await` in a nested one. Each loop is flagged at most once.
68
- """
69
-
70
- def __init__(self) -> None:
71
- self._loops: list[ast.AST] = []
72
- self._flagged: set[int] = set()
73
- self.hits: list[ast.Await] = []
74
-
75
- @override
76
- def generic_visit(self, node: ast.AST) -> None:
77
- if isinstance(node, _SCOPES):
78
- saved = self._loops
79
- self._loops = []
80
- super().generic_visit(node)
81
- self._loops = saved
82
- elif isinstance(node, _LOOPS):
83
- self._loops.append(node)
84
- super().generic_visit(node)
85
- self._loops.pop()
86
- elif isinstance(node, ast.Await):
87
- if self._loops:
88
- loop = self._loops[-1]
89
- if id(loop) not in self._flagged:
90
- self._flagged.add(id(loop))
91
- self.hits.append(node)
92
- super().generic_visit(node)
93
- else:
94
- super().generic_visit(node)