sarj-python-lint 0.6.0__tar.gz → 0.8.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/PKG-INFO +1 -2
  2. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/README.md +0 -1
  3. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/pyproject.toml +1 -1
  4. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/__main__.py +15 -6
  5. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/_registry.py +0 -2
  6. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/no_aggregation_in_store_query.py +3 -3
  7. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/no_fat_try_blocks.py +30 -12
  8. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/no_query_with_many_joins.py +3 -3
  9. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/no_select_star.py +3 -3
  10. sarj_python_lint-0.8.0/src/sarj_python_lint/rules/no_sequential_await.py +196 -0
  11. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/store_insert_requires_on_conflict.py +3 -3
  12. sarj_python_lint-0.6.0/src/sarj_python_lint/rules/no_sequential_await.py +0 -94
  13. sarj_python_lint-0.6.0/src/sarj_python_lint/rules/prefer_discriminated_union.py +0 -391
  14. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/.gitignore +0 -0
  15. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/__init__.py +0 -0
  16. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/py.typed +0 -0
  17. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rule_base.py +0 -0
  18. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/__init__.py +0 -0
  19. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/_logging.py +0 -0
  20. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/inefficient_string_concat_in_loop.py +0 -0
  21. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/no_comment_cruft.py +0 -0
  22. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/no_fstring_in_log.py +0 -0
  23. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/no_isinstance_union_chain.py +0 -0
  24. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/no_secret_in_log.py +0 -0
  25. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/no_sentinel_return_on_except.py +0 -0
  26. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/no_unreachable_after_terminal.py +0 -0
  27. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/prefer_class_row.py +0 -0
  28. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/prefer_constant_time_secret_compare.py +0 -0
  29. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/prefer_str_enum.py +0 -0
  30. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/prefer_struct_over_namedtuple.py +0 -0
  31. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/prefer_timedelta_for_durations.py +0 -0
  32. {sarj_python_lint-0.6.0 → sarj_python_lint-0.8.0}/src/sarj_python_lint/rules/pydantic_at_boundaries.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sarj-python-lint
3
- Version: 0.6.0
3
+ Version: 0.8.0
4
4
  Summary: Custom Python lint rules — AST-based, pre-commit-friendly, hypermodern defaults
5
5
  Project-URL: Homepage, https://github.com/sarj-ai/standards/tree/main/packages/python
6
6
  Project-URL: Repository, https://github.com/sarj-ai/standards
@@ -32,7 +32,6 @@ uv tool install sarj-python-lint
32
32
  hooks:
33
33
  - id: sarj-no-sequential-await
34
34
  - id: sarj-inefficient-string-concat-in-loop
35
- - id: sarj-prefer-discriminated-union
36
35
  - id: sarj-prefer-str-enum
37
36
  - id: sarj-no-fat-try-blocks
38
37
  - id: sarj-pydantic-at-boundaries
@@ -14,7 +14,6 @@ uv tool install sarj-python-lint
14
14
  hooks:
15
15
  - id: sarj-no-sequential-await
16
16
  - id: sarj-inefficient-string-concat-in-loop
17
- - id: sarj-prefer-discriminated-union
18
17
  - id: sarj-prefer-str-enum
19
18
  - id: sarj-no-fat-try-blocks
20
19
  - id: sarj-pydantic-at-boundaries
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "sarj-python-lint"
3
- version = "0.6.0"
3
+ version = "0.8.0"
4
4
  description = "Custom Python lint rules — AST-based, pre-commit-friendly, hypermodern defaults"
5
5
  readme = "README.md"
6
6
  authors = [{ name = "sarj-ai" }]
@@ -66,6 +66,18 @@ def _check(rule_ids: list[str], paths: list[Path]) -> list[Diagnostic]:
66
66
  return diags
67
67
 
68
68
 
69
+ class _Args(argparse.Namespace):
70
+ cmd: str | None
71
+ rule: list[str]
72
+ files: list[Path]
73
+
74
+ def __init__(self) -> None:
75
+ super().__init__()
76
+ self.cmd = None
77
+ self.rule = []
78
+ self.files = []
79
+
80
+
69
81
  def main(argv: list[str] | None = None) -> int:
70
82
  parser = argparse.ArgumentParser(
71
83
  prog="sarj-python-lint",
@@ -85,18 +97,15 @@ def main(argv: list[str] | None = None) -> int:
85
97
 
86
98
  sub.add_parser("list-rules", help="List available rule IDs.")
87
99
 
88
- args = parser.parse_args(argv)
89
- cmd: str | None = args.cmd
100
+ args = parser.parse_args(argv, namespace=_Args())
90
101
 
91
- if cmd == "list-rules":
102
+ if args.cmd == "list-rules":
92
103
  for rid, cls in sorted(REGISTRY.items()):
93
104
  inst = cls()
94
105
  sys.stdout.write(f"{inst.code:8} {rid:40} {inst.description}\n")
95
106
  return 0
96
107
 
97
- rule_ids: list[str] = args.rule
98
- files: list[Path] = args.files
99
- diags = _check(rule_ids, files)
108
+ diags = _check(args.rule, args.files)
100
109
  for d in diags:
101
110
  sys.stdout.write(d.format() + "\n")
102
111
  return 1 if diags else 0
@@ -24,7 +24,6 @@ from sarj_python_lint.rules.prefer_class_row import PreferClassRow
24
24
  from sarj_python_lint.rules.prefer_constant_time_secret_compare import (
25
25
  PreferConstantTimeSecretCompare,
26
26
  )
27
- from sarj_python_lint.rules.prefer_discriminated_union import PreferDiscriminatedUnion
28
27
  from sarj_python_lint.rules.prefer_str_enum import PreferStrEnum
29
28
  from sarj_python_lint.rules.prefer_struct_over_namedtuple import (
30
29
  PreferStructOverNamedtuple,
@@ -45,7 +44,6 @@ if TYPE_CHECKING:
45
44
  REGISTRY: dict[str, type[Rule]] = {
46
45
  NoSequentialAwait.id: NoSequentialAwait,
47
46
  InefficientStringConcatInLoop.id: InefficientStringConcatInLoop,
48
- PreferDiscriminatedUnion.id: PreferDiscriminatedUnion,
49
47
  PreferClassRow.id: PreferClassRow,
50
48
  PreferStrEnum.id: PreferStrEnum,
51
49
  NoFatTryBlocks.id: NoFatTryBlocks,
@@ -82,9 +82,9 @@ def _strip_sql_comments(text: str) -> str:
82
82
  class NoAggregationInStoreQuery(Rule):
83
83
  """DISTINCT / GROUP BY / COUNT in a store query — aggregate in ClickHouse."""
84
84
 
85
- id = "no-aggregation-in-store-query"
86
- code = "SARJ020"
87
- description = (
85
+ id: str = "no-aggregation-in-store-query"
86
+ code: str = "SARJ020"
87
+ description: str = (
88
88
  "DISTINCT / GROUP BY / COUNT in a Postgres store query — push heavy "
89
89
  "aggregation to the columnar mirror (ClickHouse / BigQuery)."
90
90
  )
@@ -1,4 +1,4 @@
1
- """SARJ007: `try` block whose body has more than 3 top-level statements.
1
+ """SARJ007: `try` block with more than 3 top-level statements that can raise.
2
2
 
3
3
  A fat `try` body obscures which statement is actually expected to raise and
4
4
  widens the blast radius of the `except` handlers: unrelated failures get
@@ -6,13 +6,20 @@ caught (and often swallowed or mis-reported) by handlers written for a
6
6
  different operation. Keep the `try` skinny — isolate the throwing
7
7
  statement(s) and move the non-throwing setup and follow-up work outside.
8
8
 
9
- Only the top-level statements of the `try` body are counted; statements
10
- nested inside an `if` / `with` / loop within the body count as the single
11
- compound statement that contains them. Nested `try` blocks are checked
12
- independently. `try*` (PEP 654 except-groups) is held to the same limit.
13
-
14
- This is a direct Python port of the org's ESLint restriction
15
- `TryStatement > BlockStatement[body.length > 3]` in eslint.strict.mjs.
9
+ Two refinements keep the count aligned with that intent and avoid the
10
+ false-positive patterns that dominated real-world suppressions:
11
+
12
+ * Only top-level statements that *can raise* are counted — a statement counts
13
+ toward the limit only if its subtree contains a call or `await`. Pure
14
+ assignments / name-rebinds (`self.x = y`, `a = b.c`) don't obscure a throwing
15
+ statement and are free. Statements nested inside an `if` / `with` / loop
16
+ count as the single compound statement that contains them. Nested `try`
17
+ blocks are checked independently. `try*` (PEP 654) is held to the same limit.
18
+ * `try` blocks that carry an `else` or `finally` clause are exempt. Those
19
+ clauses are a deliberate success/cleanup contract that couples the body to
20
+ the handler (a `finally` that tears down a resource, an `else`/`finally` that
21
+ reads a status the body set) — statements can't be freely hoisted out without
22
+ changing semantics, so the length check is counterproductive there.
16
23
 
17
24
  Instead of:
18
25
  try:
@@ -52,12 +59,18 @@ if TYPE_CHECKING:
52
59
  _MAX_TRY_BODY_STATEMENTS = 3
53
60
 
54
61
 
62
+ def _can_raise(stmt: ast.stmt) -> bool:
63
+ """True if the statement's subtree contains a call or `await` — i.e. it can
64
+ plausibly raise. Pure assignments / rebinds with no call do not count."""
65
+ return any(isinstance(n, (ast.Call, ast.Await)) for n in ast.walk(stmt))
66
+
67
+
55
68
  class NoFatTryBlocks(Rule):
56
- """Try body longer than 3 statements — isolate the throwing statement(s)."""
69
+ """Try body with too many throwing statements — isolate the one that raises."""
57
70
 
58
71
  id: str = "no-fat-try-blocks"
59
72
  code: str = "SARJ007"
60
- description: str = "Try block body exceeds 3 statements — keep try blocks skinny."
73
+ description: str = "Try block has too many throwing statements — keep try blocks skinny."
61
74
 
62
75
  @override
63
76
  def check(self, path: Path, source: str) -> list[Diagnostic]:
@@ -69,7 +82,12 @@ class NoFatTryBlocks(Rule):
69
82
  for node in ast.walk(tree):
70
83
  if not isinstance(node, (ast.Try, ast.TryStar)):
71
84
  continue
72
- if len(node.body) <= _MAX_TRY_BODY_STATEMENTS:
85
+ # An `else`/`finally` clause is a deliberate success/cleanup contract
86
+ # that couples the body to the handler — don't fight it on length.
87
+ if node.orelse or node.finalbody:
88
+ continue
89
+ throwing = sum(_can_raise(stmt) for stmt in node.body)
90
+ if throwing <= _MAX_TRY_BODY_STATEMENTS:
73
91
  continue
74
92
  diags.append(
75
93
  Diagnostic(
@@ -78,7 +96,7 @@ class NoFatTryBlocks(Rule):
78
96
  col=node.col_offset + 1,
79
97
  code=self.code,
80
98
  message=(
81
- f"try block has {len(node.body)} statements "
99
+ f"try block has {throwing} statements that can raise "
82
100
  f"(max {_MAX_TRY_BODY_STATEMENTS}) — try blocks should "
83
101
  "isolate the throwing statement(s); move non-throwing "
84
102
  "work outside the try."
@@ -61,9 +61,9 @@ def _strip_sql_comments(text: str) -> str:
61
61
  class NoQueryWithManyJoins(Rule):
62
62
  """A SQL query with 3+ JOINs is too entangled — split it or denormalize."""
63
63
 
64
- id = "no-query-with-many-joins"
65
- code = "SARJ019"
66
- description = (
64
+ id: str = "no-query-with-many-joins"
65
+ code: str = "SARJ019"
66
+ description: str = (
67
67
  "SQL query with 3 or more JOINs — split the query or denormalize instead of fanning across many tables."
68
68
  )
69
69
 
@@ -67,9 +67,9 @@ def _has_real_select_star(sql: str) -> bool:
67
67
  class NoSelectStar(Rule):
68
68
  """`SELECT *` in a store query — list the columns explicitly."""
69
69
 
70
- id = "no-select-star"
71
- code = "SARJ021"
72
- description = (
70
+ id: str = "no-select-star"
71
+ code: str = "SARJ021"
72
+ description: str = (
73
73
  "SELECT * in a store query — name the columns; * over-fetches and breaks "
74
74
  "class_row mapping when the schema changes."
75
75
  )
@@ -0,0 +1,196 @@
1
+ """SARJ001: detect the `for x in xs: await f(x)` gather antipattern.
2
+
3
+ Sequential `await` in a for-loop serializes I/O that could be parallelized
4
+ with `asyncio.gather([f(x) for x in xs])`. The performance gap is often 10-100x
5
+ for network-bound work (HTTP, DB queries, LLM calls).
6
+
7
+ Deliberately narrow, to flag the textbook antipattern and almost nothing else —
8
+ an over-broad version drowned real signal under suppressions. The rule fires
9
+ only for:
10
+
11
+ * a `for` loop whose body is **straight-line** (no `if`/`try`/`with`/`return`/
12
+ `break`/`continue`/`raise`/nested loop — those signal conditional or ordered
13
+ logic, not a parallel map) and awaits a call that **uses the loop variable**
14
+ (so each iteration is a distinct, independent call); or
15
+ * a comprehension / generator expression with an `await` in its element or a
16
+ per-element `if` (those have no ordered side effects).
17
+
18
+ It does NOT fire for: `while` loops (pagination, polling, queue drains — length
19
+ unknown, inherently sequential), a loop's once-evaluated iterable
20
+ (`for x in await fetch()`), `async for`, test modules (intentional ordering),
21
+ or a `for` body containing control flow. Those were the false-positive sources.
22
+
23
+ References:
24
+ - https://docs.python.org/3/library/asyncio-task.html#running-tasks-concurrently
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import ast
30
+ from typing import TYPE_CHECKING, override
31
+
32
+ from sarj_python_lint.rule_base import Diagnostic, Rule
33
+
34
+
35
+ if TYPE_CHECKING:
36
+ from pathlib import Path
37
+
38
+
39
+ def _is_test_path(path: Path) -> bool:
40
+ name = path.name
41
+ if name == "conftest.py" or name.startswith("test_") or name.endswith("_test.py"):
42
+ return True
43
+ return any(part in {"tests", "test"} for part in path.parts)
44
+
45
+
46
class NoSequentialAwait(Rule):
    """Flags loops whose per-element `await` serializes I/O that could run concurrently."""

    id: str = "no-sequential-await"
    code: str = "SARJ001"
    description: str = "Sequential `await` in a for-loop — prefer asyncio.gather."

    @override
    def check(self, path: Path, source: str) -> list[Diagnostic]:
        """Return one SARJ001 diagnostic per flagged loop in ``source``.

        Test modules and sources that fail to parse produce no diagnostics.
        The result is ordered by (line, column).
        """
        if _is_test_path(path):
            return []
        try:
            tree = ast.parse(source, filename=str(path))
        except SyntaxError:
            return []

        finder = _SequentialAwaitVisitor()
        finder.visit(tree)

        hint = "Sequential `await` inside `for` — prefer `asyncio.gather([f(x) for x in xs])`."
        found = [
            Diagnostic(
                path=path,
                line=hit.lineno,
                col=hit.col_offset + 1,
                code=self.code,
                message=hint,
            )
            for hit in finder.hits
        ]
        found.sort(key=lambda diag: (diag.line, diag.col))
        return found
75
+
76
+
77
+ # A loop's *iterable* is evaluated once in the enclosing scope, NOT per element:
78
+ # `for x in await fetch()` / `{x for x in await fetch()}` await once. Iterables
79
+ # are visited *before* the loop is pushed, so an await there attributes to an
80
+ # enclosing loop (if any), not this one.
81
+ _SCOPES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)
82
+
83
+ # Top-level body statements that signal conditional or ordered logic rather than
84
+ # a straight-line parallel map. A `for` whose body contains any of these is not
85
+ # treated as the gather antipattern.
86
+ _CONTROL_FLOW = (
87
+ ast.If,
88
+ ast.For,
89
+ ast.AsyncFor,
90
+ ast.While,
91
+ ast.With,
92
+ ast.AsyncWith,
93
+ ast.Try,
94
+ ast.Match,
95
+ ast.Return,
96
+ ast.Break,
97
+ ast.Continue,
98
+ ast.Raise,
99
+ )
100
+
101
+
102
+ def _names(node: ast.AST) -> set[str]:
103
+ return {n.id for n in ast.walk(node) if isinstance(n, ast.Name)}
104
+
105
+
106
def _is_gather_antipattern(node: ast.For) -> bool:
    """True for `for x in xs: <straight-line body awaiting a call that uses x>`.

    Returns False when any top-level body statement is control flow
    (`_CONTROL_FLOW`). Otherwise returns True iff the body contains an
    `await <call>` whose subtree references a name bound by the loop target.

    Two shapes that a plain `ast.walk` over the body would accept are excluded:

    * Awaiting something that is not a call (`for t in tasks: await t`). That
      does not match "awaits a call" in the module docstring: the loop body
      is not issuing a call per element, only waiting on existing awaitables.
    * Awaits inside a nested `def` / `async def` / `lambda` in the body. Those
      bodies do not run per iteration, and the visitor does not attribute
      awaits across function scopes either.
    """
    if any(isinstance(stmt, _CONTROL_FLOW) for stmt in node.body):
        return False
    targets = _names(node.target)
    # Explicit DFS so nested function scopes can be pruned (ast.walk can't).
    pending: list[ast.AST] = list(node.body)
    while pending:
        current = pending.pop()
        if isinstance(current, _SCOPES):
            continue
        if (
            isinstance(current, ast.Await)
            and isinstance(current.value, ast.Call)
            and _names(current) & targets
        ):
            return True
        pending.extend(ast.iter_child_nodes(current))
    return False
116
+
117
+
118
class _SequentialAwaitVisitor(ast.NodeVisitor):
    """Single O(n) pass: flag the first per-iteration `await` of each loop.

    Maintains a stack of enclosing loops within the current function. The stack
    resets at function boundaries so a loop in an outer function never claims an
    `await` in a nested one. Each loop is flagged at most once.

    Parts of a loop that run only once are visited while the loop is NOT on the
    stack, so awaits there are never attributed to it:

    * a `for` loop's iterable and its `else:` clause;
    * a comprehension's outermost iterable.
    """

    def __init__(self) -> None:
        super().__init__()
        # Innermost-last stack of loops whose awaits are currently attributed.
        self._loops: list[ast.AST] = []
        # id() of loops already reported, so each loop yields a single hit.
        self._flagged: set[int] = set()
        # The `await` node reported for each flagged loop, in visit order.
        self.hits: list[ast.Await] = []

    def _flag_if_in_loop(self, node: ast.Await) -> None:
        """Record ``node`` as the hit for the innermost loop, if not yet flagged."""
        if self._loops:
            loop = self._loops[-1]
            if id(loop) not in self._flagged:
                self._flagged.add(id(loop))
                self.hits.append(node)

    def visit_For(self, node: ast.For) -> None:
        # `<iter>` runs once in the enclosing scope; visit it before entering.
        self.visit(node.iter)
        # Only a straight-line per-element-await body is the gather antipattern;
        # control-flow bodies (conditional/ordered) are not pushed, so awaits in
        # them are not flagged for this loop.
        antipattern = _is_gather_antipattern(node)
        if antipattern:
            self._loops.append(node)
        self.visit(node.target)
        for stmt in node.body:
            self.visit(stmt)
        if antipattern:
            self._loops.pop()
        # `for ... else:` runs once after the loop completes, not per element,
        # so it is visited with this loop already popped.
        for stmt in node.orelse:
            self.visit(stmt)

    def _visit_comprehension(self, node: ast.AST, elements: tuple[ast.expr, ...]) -> None:
        """Visit a comprehension with itself pushed as the innermost loop.

        ``elements`` are the per-element expressions (``elt``, or ``key`` and
        ``value`` for a dict comprehension).
        """
        gens: list[ast.comprehension] = node.generators  # pyright: ignore[reportAttributeAccessIssue]
        # Outermost iterable is evaluated once in the enclosing scope.
        self.visit(gens[0].iter)
        self._loops.append(node)
        for elt in elements:
            self.visit(elt)
        self.visit(gens[0].target)
        for cond in gens[0].ifs:
            self.visit(cond)
        # Later generators iterate per element of the preceding one.
        for gen in gens[1:]:
            self.visit(gen.iter)
            self.visit(gen.target)
            for cond in gen.ifs:
                self.visit(cond)
        self._loops.pop()

    def visit_ListComp(self, node: ast.ListComp) -> None:
        self._visit_comprehension(node, (node.elt,))

    def visit_SetComp(self, node: ast.SetComp) -> None:
        self._visit_comprehension(node, (node.elt,))

    def visit_GeneratorExp(self, node: ast.GeneratorExp) -> None:
        self._visit_comprehension(node, (node.elt,))

    def visit_DictComp(self, node: ast.DictComp) -> None:
        self._visit_comprehension(node, (node.key, node.value))

    @override
    def generic_visit(self, node: ast.AST) -> None:
        if isinstance(node, _SCOPES):
            # New function scope: outer loops must not claim awaits in here.
            saved = self._loops
            self._loops = []
            super().generic_visit(node)
            self._loops = saved
        elif isinstance(node, ast.Await):
            self._flag_if_in_loop(node)
            super().generic_visit(node)
        else:
            super().generic_visit(node)
@@ -61,9 +61,9 @@ def _strip_sql_comments(text: str) -> str:
61
61
  class StoreInsertRequiresOnConflict(Rule):
62
62
  """Embedded INSERT in store code without ON CONFLICT — store writes must be upserts."""
63
63
 
64
- id = "store-insert-requires-on-conflict"
65
- code = "SARJ018"
66
- description = "Embedded SQL INSERT in store code without ON CONFLICT — store writes must be idempotent upserts."
64
+ id: str = "store-insert-requires-on-conflict"
65
+ code: str = "SARJ018"
66
+ description: str = "Embedded SQL INSERT in store code without ON CONFLICT — store writes must be idempotent upserts."
67
67
 
68
68
  @override
69
69
  def check(self, path: Path, source: str) -> list[Diagnostic]:
@@ -1,94 +0,0 @@
1
- """SARJ001: detect `for x in xs: await f(x)` patterns.
2
-
3
- Sequential `await` in a for-loop serializes I/O that could be parallelized
4
- with `asyncio.gather([f(x) for x in xs])`. The performance gap is often 10-100x
5
- for network-bound work (HTTP, DB queries, LLM calls).
6
-
7
- References:
8
- - https://docs.python.org/3/library/asyncio-task.html#running-tasks-concurrently
9
- """
10
-
11
- from __future__ import annotations
12
-
13
- import ast
14
- from typing import TYPE_CHECKING, override
15
-
16
- from sarj_python_lint.rule_base import Diagnostic, Rule
17
-
18
-
19
- if TYPE_CHECKING:
20
- from pathlib import Path
21
-
22
-
23
- class NoSequentialAwait(Rule):
24
- """Sequential await calls in a loop that could be parallelized."""
25
-
26
- id: str = "no-sequential-await"
27
- code: str = "SARJ001"
28
- description: str = "Sequential `await` in a for-loop — prefer asyncio.gather."
29
-
30
- @override
31
- def check(self, path: Path, source: str) -> list[Diagnostic]:
32
- try:
33
- tree = ast.parse(source, filename=str(path))
34
- except SyntaxError:
35
- return []
36
- visitor = _SequentialAwaitVisitor()
37
- visitor.visit(tree)
38
- diags = [
39
- Diagnostic(
40
- path=path,
41
- line=node.lineno,
42
- col=node.col_offset + 1,
43
- code=self.code,
44
- message=(
45
- "Sequential `await` inside `for` — prefer "
46
- "`asyncio.gather([f(x) for x in xs])`."
47
- ),
48
- )
49
- for node in visitor.hits
50
- ]
51
- diags.sort(key=lambda d: (d.line, d.col))
52
- return diags
53
-
54
-
55
- # Loop-like constructs whose body runs once per element: `await` inside one of
56
- # them serializes the iterations. `async for` is deliberately absent — it is the
57
- # parallel-iteration construct, not the antipattern.
58
- _LOOPS = (ast.For, ast.While, ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp)
59
- _SCOPES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)
60
-
61
-
62
- class _SequentialAwaitVisitor(ast.NodeVisitor):
63
- """Single O(n) pass: flag the first `await` of each enclosing loop.
64
-
65
- Maintains a stack of enclosing loops within the current function. The stack
66
- resets at function boundaries so a loop in an outer function never claims an
67
- `await` in a nested one. Each loop is flagged at most once.
68
- """
69
-
70
- def __init__(self) -> None:
71
- self._loops: list[ast.AST] = []
72
- self._flagged: set[int] = set()
73
- self.hits: list[ast.Await] = []
74
-
75
- @override
76
- def generic_visit(self, node: ast.AST) -> None:
77
- if isinstance(node, _SCOPES):
78
- saved = self._loops
79
- self._loops = []
80
- super().generic_visit(node)
81
- self._loops = saved
82
- elif isinstance(node, _LOOPS):
83
- self._loops.append(node)
84
- super().generic_visit(node)
85
- self._loops.pop()
86
- elif isinstance(node, ast.Await):
87
- if self._loops:
88
- loop = self._loops[-1]
89
- if id(loop) not in self._flagged:
90
- self._flagged.add(id(loop))
91
- self.hits.append(node)
92
- super().generic_visit(node)
93
- else:
94
- super().generic_visit(node)
@@ -1,391 +0,0 @@
1
- """SARJ005: flag poor-man's-result shapes — prefer a discriminated union.
2
-
3
- Three triggers:
4
-
5
- 1. **success-bool model** — a pydantic BaseModel with a bool status field plus
6
- Optional siblings:
7
-
8
- class Result(BaseModel):
9
- success: bool
10
- data: Optional[Data] = None
11
- error: Optional[str] = None
12
-
13
- allows illegal states (success=True with data=None, or success=False with
14
- data set). Use a discriminated union:
15
-
16
- class Success(BaseModel): data: Data
17
- class Failure(BaseModel): error: str
18
- Result = Union[Success, Failure]
19
-
20
- 2. **bool-tuple result** — a function whose return annotation is a two-element
21
- `tuple[bool, X]` / `tuple[X, bool]` (also `Tuple[...]` and `X | None`
22
- payloads): the classic `(ok, value)` poor-man's-result. Model
23
- success/failure as a discriminated union (e.g. `Ok[T] | Err`) instead of a
24
- bool-tuple — the bool and the payload can disagree.
25
-
26
- 3. **nullable cluster with a discriminator** — a pydantic BaseModel or
27
- dataclass with 3+ `X | None` / `Optional[X]` fields AND a str / StrEnum /
28
- Literal field named like a discriminator (`status`, `state`, `type`,
29
- `kind`, `result`, `outcome`):
30
-
31
- class Call(BaseModel):
32
- status: str
33
- started_at: datetime | None = None
34
- ended_at: datetime | None = None
35
- error: str | None = None
36
-
37
- Split into per-state models in a discriminated union (the CallState
38
- pattern: `PendingCall | ActiveCall | CompletedCall | FailedCall`) so each
39
- state carries exactly the fields that are valid for it.
40
-
41
- Query/filter inputs and PATCH-style partial-update DTOs legitimately hold
42
- many optional fields, so class names matching those conventions
43
- (`*Input` / `*Params` / `*Filter` / `*Query` / `Update*` / `Patch*` /
44
- `Upsert*`) are excluded from this trigger.
45
-
46
- A single-value `Literal` tag (e.g. `type: Literal["complete"]`) marks a model
47
- that is already an arm of a discriminated union, so it is excluded too — a
48
- multi-value `Literal[...]` is still treated as a poor-man's discriminator.
49
-
50
- References:
51
- - https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
52
- - https://en.wikipedia.org/wiki/Tagged_union
53
- """
54
-
55
- from __future__ import annotations
56
-
57
- import ast
58
- from typing import TYPE_CHECKING, override
59
-
60
- from sarj_python_lint.rule_base import Diagnostic, Rule
61
-
62
-
63
- if TYPE_CHECKING:
64
- from pathlib import Path
65
-
66
-
67
- STATUS_FIELDS = {"success", "ok", "is_success", "succeeded", "successful", "failed", "failure"}
68
- IGNORED_OPTIONAL_FIELDS = {
69
- "metadata",
70
- "meta",
71
- "debug",
72
- "debug_logs",
73
- "extra",
74
- "log",
75
- "logs",
76
- "traceback",
77
- "request_id",
78
- "trace_id",
79
- }
80
- DISCRIMINATOR_FIELD_NAMES = {"status", "state", "type", "kind", "result", "outcome"}
81
- NULLABLE_CLUSTER_THRESHOLD = 3
82
- # A bool status field plus this many Optional siblings trips the original trigger.
83
- OPTIONAL_SIBLINGS_THRESHOLD = 2
84
- # An (ok, value) bool-tuple has exactly two elements.
85
- _BOOL_TUPLE_LEN = 2
86
- # Query/filter inputs and partial-update DTOs are all-optional by design.
87
- DTO_CLASS_NAME_SUFFIXES = ("Input", "Params", "Filter", "Query")
88
- DTO_CLASS_NAME_PREFIXES = ("Update", "Patch", "Upsert")
89
-
90
-
91
- class PreferDiscriminatedUnion(Rule):
92
- """Bool-status models, bool-tuple results, status+Optionals — prefer a discriminated union."""
93
-
94
- id: str = "prefer-discriminated-union"
95
- code: str = "SARJ005"
96
- description: str = (
97
- "success:bool + Optionals, tuple[bool, X] results, or status + nullable "
98
- "cluster — use a discriminated union."
99
- )
100
-
101
- @override
102
- def check(self, path: Path, source: str) -> list[Diagnostic]:
103
- try:
104
- tree = ast.parse(source, filename=str(path))
105
- except SyntaxError:
106
- return []
107
- diags: list[Diagnostic] = []
108
- str_enum_names = _collect_str_enum_names(tree)
109
- for node in ast.walk(tree):
110
- if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
111
- diag = self._check_bool_tuple_return(path, node)
112
- if diag is not None:
113
- diags.append(diag)
114
- continue
115
- if not isinstance(node, ast.ClassDef):
116
- continue
117
- diag = self._check_class(path, node, str_enum_names)
118
- if diag is not None:
119
- diags.append(diag)
120
- return diags
121
-
122
- def _check_bool_tuple_return(
123
- self, path: Path, node: ast.FunctionDef | ast.AsyncFunctionDef
124
- ) -> Diagnostic | None:
125
- if not _is_bool_tuple(node.returns):
126
- return None
127
- returns_text = ast.unparse(node.returns) if node.returns else ""
128
- return Diagnostic(
129
- path=path,
130
- line=node.lineno,
131
- col=node.col_offset + 1,
132
- code=self.code,
133
- message=(
134
- f"`{node.name}` returns `{returns_text}` — a (ok, value) bool-tuple. "
135
- "Model success/failure as a discriminated union "
136
- "(e.g. `Ok[T] | Err`), not a bool-tuple."
137
- ),
138
- )
139
-
140
- def _check_class(
141
- self, path: Path, node: ast.ClassDef, str_enum_names: set[str]
142
- ) -> Diagnostic | None:
143
- is_model = _inherits_basemodel(node)
144
- is_dc = _is_dataclass(node)
145
- if not (is_model or is_dc):
146
- return None
147
- has_status_bool = False
148
- has_literal_tag = False
149
- optional_fields: list[str] = []
150
- discriminator_fields: list[str] = []
151
- for stmt in node.body:
152
- if not isinstance(stmt, ast.AnnAssign):
153
- continue
154
- if not isinstance(stmt.target, ast.Name):
155
- continue
156
- name = stmt.target.id
157
- if name in STATUS_FIELDS and _is_bool_annotation(stmt.annotation):
158
- has_status_bool = True
159
- if name in DISCRIMINATOR_FIELD_NAMES and _is_discriminator_type(
160
- stmt.annotation, str_enum_names
161
- ):
162
- discriminator_fields.append(name)
163
- if _is_single_value_literal(stmt.annotation):
164
- has_literal_tag = True
165
- if _is_optional(stmt.annotation) and name not in IGNORED_OPTIONAL_FIELDS:
166
- optional_fields.append(name)
167
- # Original trigger: bool status field + Optional siblings (BaseModel only).
168
- if is_model and has_status_bool and len(optional_fields) >= OPTIONAL_SIBLINGS_THRESHOLD:
169
- return Diagnostic(
170
- path=path,
171
- line=node.lineno,
172
- col=node.col_offset + 1,
173
- code=self.code,
174
- message=(
175
- f"`{node.name}` has a bool status field plus "
176
- f"Optional fields ({', '.join(optional_fields)}). "
177
- "Model as `Union[Success, Failure]` to make illegal "
178
- "states unrepresentable."
179
- ),
180
- )
181
- # Nullable-cluster trigger: discriminator-ish field + 3 or more nullables.
182
- # A single-value `Literal` tag (e.g. `type: Literal["complete"]`) marks a
183
- # model that is already a discriminated-union arm, not a poor-man's result.
184
- if (
185
- discriminator_fields
186
- and len(optional_fields) >= NULLABLE_CLUSTER_THRESHOLD
187
- and not _is_dto_class_name(node.name)
188
- and not has_literal_tag
189
- ):
190
- return Diagnostic(
191
- path=path,
192
- line=node.lineno,
193
- col=node.col_offset + 1,
194
- code=self.code,
195
- message=(
196
- f"`{node.name}` has a discriminator-ish field "
197
- f"(`{discriminator_fields[0]}`) plus {len(optional_fields)} nullable "
198
- f"fields ({', '.join(optional_fields)}). Split into per-state models "
199
- "in a discriminated union (the CallState pattern: "
200
- "`PendingCall | ActiveCall | CompletedCall | FailedCall`)."
201
- ),
202
- )
203
- return None
204
-
205
-
206
- def _is_dto_class_name(name: str) -> bool:
207
- """Query/filter input and partial-update DTO names are all-optional by design."""
208
- return name.endswith(DTO_CLASS_NAME_SUFFIXES) or name.startswith(DTO_CLASS_NAME_PREFIXES)
209
-
210
-
211
def _is_single_value_literal(node: ast.AST | None) -> bool:
    """Return True when *node* is a ``Literal`` annotation with exactly one member.

    A one-member ``Literal`` (e.g. ``type: Literal["complete"]``) is the canonical
    tag of a discriminated-union arm, so a model carrying one is already split per
    state. A multi-member ``Literal[...]`` is still a poor-man's discriminator and
    yields False.

    Quoted (forward-reference) annotations are parsed and judged by their
    contents; unparseable ones yield False. The subscript head may be bare or
    dotted (``typing.Literal``).
    """
    if node is None:
        return False

    is_forward_ref = isinstance(node, ast.Constant) and isinstance(node.value, str)
    if is_forward_ref:
        try:
            expr = ast.parse(node.value, mode="eval").body
        except SyntaxError:
            return False
        return _is_single_value_literal(expr)

    if not isinstance(node, ast.Subscript):
        return False

    head = _get_name_flat(node.value).rsplit(".", 1)[-1]
    if head != "Literal":
        return False

    # `Literal["a", "b"]` stores its members in a Tuple slice, while
    # `Literal["a"]` stores the lone member directly.
    members = node.slice
    return not isinstance(members, ast.Tuple) or len(members.elts) == 1
234
-
235
-
236
def _inherits_basemodel(node: ast.ClassDef) -> bool:
    """Return True if any direct base of *node* is spelled ``BaseModel``.

    Matches the bare name (``class M(BaseModel)``) and any dotted form whose last
    component is ``BaseModel`` (``class M(pydantic.BaseModel)``). Other base
    expressions, such as subscripted generics, do not match.
    """

    def _names_basemodel(base: ast.expr) -> bool:
        if isinstance(base, ast.Name):
            return base.id == "BaseModel"
        if isinstance(base, ast.Attribute):
            return base.attr == "BaseModel"
        return False

    return any(_names_basemodel(base) for base in node.bases)
243
-
244
-
245
def _is_dataclass(node: ast.ClassDef) -> bool:
    """Return True if *node* carries a ``dataclass`` decorator.

    Recognises the bare form (``@dataclass``), dotted forms ending in
    ``.dataclass`` (``@dataclasses.dataclass``), and the called variants of
    either (``@dataclass(frozen=True)``).
    """
    # For the called form, the decorator name lives on the callee, not the Call.
    dotted_names = (
        _get_name_flat(deco.func if isinstance(deco, ast.Call) else deco)
        for deco in node.decorator_list
    )
    return any(
        dotted == "dataclass" or dotted.endswith(".dataclass")
        for dotted in dotted_names
    )
253
-
254
-
255
def _collect_str_enum_names(tree: ast.Module) -> set[str]:
    """Return the names of classes in *tree* that look like string enums.

    Every class is considered, at any nesting depth (``ast.walk``). A class
    qualifies when one of its direct bases is ``StrEnum`` (bare or dotted, e.g.
    ``enum.StrEnum``), or when its direct bases include both ``str`` and ``Enum``,
    which is the pre-3.11 ``class X(str, Enum)`` spelling. Only base names are
    compared, so a subclass of another local string enum is not picked up.
    """

    def _base_short_names(cls: ast.ClassDef) -> set[str]:
        # Keep only the last dotted component: `enum.StrEnum` -> `StrEnum`.
        return {_get_name_flat(base).rsplit(".", 1)[-1] for base in cls.bases}

    str_enums: set[str] = set()
    for cls in (n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)):
        short = _base_short_names(cls)
        if "StrEnum" in short or ("str" in short and "Enum" in short):
            str_enums.add(cls.name)
    return str_enums
269
-
270
-
271
def _is_bool_tuple(node: ast.AST | None) -> bool:
    """Return True for a ``tuple[bool, X]`` / ``tuple[X, bool]`` style annotation.

    The subscript head must be ``tuple`` or ``Tuple`` (bare or dotted). It must
    have exactly ``_BOOL_TUPLE_LEN`` parameters (the two-element case), at least
    one of which is ``bool``.

    Variadic ``tuple[bool, ...]`` is rejected because it is a homogeneous
    sequence, not an ``(ok, value)`` pair. Quoted annotations are parsed first;
    unparseable ones yield False.
    """
    if node is None:
        return False

    if isinstance(node, ast.Constant) and isinstance(node.value, str):
        # String (forward-reference) annotation: judge the expression it spells.
        try:
            inner = ast.parse(node.value, mode="eval")
        except SyntaxError:
            return False
        return _is_bool_tuple(inner.body)

    if not isinstance(node, ast.Subscript):
        return False
    if _get_name_flat(node.value).rsplit(".", 1)[-1] not in {"tuple", "Tuple"}:
        return False

    params = node.slice
    if not isinstance(params, ast.Tuple) or len(params.elts) != _BOOL_TUPLE_LEN:
        return False

    is_variadic = any(
        isinstance(p, ast.Constant) and p.value is Ellipsis for p in params.elts
    )
    return not is_variadic and any(_is_bool(p) for p in params.elts)
294
-
295
-
296
def _is_bool(node: ast.AST) -> bool:
    """Return True if *node* is a name whose final component is ``bool``.

    Matches a bare ``bool`` and any attribute access ending in ``.bool`` (such as
    ``builtins.bool``). Every other node kind yields False.
    """
    if isinstance(node, ast.Name):
        ident = node.id
    elif isinstance(node, ast.Attribute):
        ident = node.attr
    else:
        return False
    return ident == "bool"
302
-
303
-
304
def _is_bool_annotation(node: ast.AST | None) -> bool:
    """Return True if the annotation is ``bool``, possibly as a union member.

    Accepted spellings:

    * ``bool``;
    * ``X | bool`` / ``bool | None`` at any ``|`` nesting;
    * ``Optional[bool]`` and ``Union[bool, ...]``, bare or dotted such as
      ``typing.Optional``, which are the pre-PEP 604 spellings of the ``|`` form;
    * any of these written as a quoted forward reference.

    This is a parsed-node check, so ``success: BoolishFlag`` does not trip the
    old substring ``"bool" in ast.unparse(...)`` heuristic.
    """
    if node is None:
        return False
    if _is_bool(node):
        return True
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        return _is_bool_annotation(node.left) or _is_bool_annotation(node.right)
    if isinstance(node, ast.Subscript):
        # Without this branch `Optional[bool]` was rejected even though the
        # equivalent `bool | None` above was accepted.
        head = _get_name_flat(node.value).rsplit(".", 1)[-1]
        if head in {"Optional", "Union"}:
            args = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
            return any(_is_bool_annotation(arg) for arg in args)
        return False
    if isinstance(node, ast.Constant) and isinstance(node.value, str):
        try:
            parsed = ast.parse(node.value, mode="eval")
        except SyntaxError:
            return False
        return _is_bool_annotation(parsed.body)
    return False
323
-
324
-
325
def _is_discriminator_type(node: ast.AST | None, str_enum_names: set[str]) -> bool:
    """Return True if the annotation looks like a string-valued discriminator.

    Accepted:

    * ``str``;
    * a name listed in *str_enum_names*, either bare or as the last component of
      a dotted name;
    * any ``Literal[...]``;
    * any of the above combined via ``|``, ``Optional[...]`` or ``Union[...]``;
    * any of these written as a quoted forward reference.

    Args:
        node: Annotation expression, or None when the field is unannotated.
        str_enum_names: Class names to treat as string enums.
    """
    if node is None:
        return False
    if isinstance(node, ast.Constant) and isinstance(node.value, str):
        try:
            parsed = ast.parse(node.value, mode="eval")
        except SyntaxError:
            return False
        return _is_discriminator_type(parsed.body, str_enum_names)
    if isinstance(node, ast.Name):
        return node.id == "str" or node.id in str_enum_names
    if isinstance(node, ast.Attribute):
        return node.attr in str_enum_names
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        return _is_discriminator_type(node.left, str_enum_names) or _is_discriminator_type(
            node.right, str_enum_names
        )
    if isinstance(node, ast.Subscript):
        name = _get_name_flat(node.value).rsplit(".", 1)[-1]
        if name == "Literal":
            return True
        if name == "Optional":
            return _is_discriminator_type(node.slice, str_enum_names)
        if name == "Union":
            # `Union[Status, None]` is the pre-PEP 604 spelling of `Status | None`,
            # which the BitOr branch above already accepts.
            members = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
            return any(_is_discriminator_type(m, str_enum_names) for m in members)
    return False
350
-
351
-
352
def _get_name_flat(node: ast.AST) -> str:
    """Flatten a ``Name``/``Attribute`` chain into a dotted string.

    ``foo`` -> ``"foo"`` and ``a.b.c`` -> ``"a.b.c"``. Returns ``""`` for any
    other node kind. It also returns ``""`` for attribute chains whose root is
    not a plain name (e.g. ``f().x`` or ``xs[0].y``).
    """
    attrs: list[str] = []
    current = node
    # Walk down the chain (a.b.c -> a.b -> a), collecting attributes outermost-first.
    while isinstance(current, ast.Attribute):
        attrs.append(current.attr)
        current = current.value
    if not isinstance(current, ast.Name) or not current.id:
        # A root that is not a plain, non-empty name collapses the whole chain.
        return ""
    dotted = current.id
    for attr in reversed(attrs):
        dotted = f"{dotted}.{attr}"
    return dotted
360
-
361
-
362
def _is_optional(node: ast.AST | None) -> bool:
    """Return True if the annotation admits ``None``.

    Recognised spellings:

    * ``Optional[X]``;
    * ``Union[..., None, ...]``;
    * ``X | None`` at any ``|`` nesting;
    * a bare ``None``;
    * any of these written as a quoted forward reference.

    ``Optional``/``Union`` may be dotted (``typing.Optional``). Unparseable
    quoted annotations yield False.
    """
    if node is None:
        return False

    if isinstance(node, ast.Constant):
        if node.value is None:
            return True
        if not isinstance(node.value, str):
            return False
        # Forward reference: parse the string and inspect the inner expression.
        try:
            inner = ast.parse(node.value, mode="eval").body
        except SyntaxError:
            return False
        return _is_optional(inner)

    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        return _is_optional(node.left) or _is_optional(node.right)

    if isinstance(node, ast.Subscript):
        dotted = _get_name_flat(node.value)
        if dotted == "Optional" or dotted.endswith(".Optional"):
            return True
        if dotted == "Union" or dotted.endswith(".Union"):
            arg = node.slice
            members = arg.elts if isinstance(arg, ast.Tuple) else (arg,)
            return any(_is_optional(member) for member in members)
        return False

    return isinstance(node, ast.Name) and node.id == "None"