sarj-python-lint 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,32 @@
1
+ node_modules/
2
+ dist/
3
+ build/
4
+ *.tgz
5
+ .npm/
6
+ *.tsbuildinfo
7
+ .turbo/
8
+ .tsup/
9
+ .vitest-cache/
10
+ .eslintcache
11
+ coverage/
12
+
13
+ __pycache__/
14
+ *.py[co]
15
+ .pytest_cache/
16
+ .ruff_cache/
17
+ .mypy_cache/
18
+ .basedpyright/
19
+ .coverage
20
+ .coverage.*
21
+ htmlcov/
22
+ .tox/
23
+ .venv/
24
+ venv/
25
+ *.egg-info/
26
+
27
+ .DS_Store
28
+ Thumbs.db
29
+ .idea/
30
+ .vscode/
31
+ .zed/
32
+ *.swp
@@ -0,0 +1,52 @@
1
+ Metadata-Version: 2.4
2
+ Name: sarj-python-lint
3
+ Version: 0.2.0
4
+ Summary: Custom Python lint rules — AST-based, pre-commit-friendly, hypermodern defaults
5
+ Project-URL: Homepage, https://github.com/sarj-ai/linting/tree/main/packages/python
6
+ Project-URL: Repository, https://github.com/sarj-ai/linting
7
+ Project-URL: Issues, https://github.com/sarj-ai/linting/issues
8
+ Author: sarj-ai
9
+ License: MIT
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.13
15
+ Classifier: Topic :: Software Development :: Quality Assurance
16
+ Requires-Python: >=3.13
17
+ Description-Content-Type: text/markdown
18
+
19
+ # sarj-python-lint
20
+
21
+ Custom Python lint rules via stdlib `ast`. Designed for pre-commit. For SQL rules see [`sarj-sql-lint`](../sql/).
22
+
23
+ ```bash
24
+ uv tool install sarj-python-lint
25
+ ```
26
+
27
+ ## Pre-commit
28
+
29
+ ```yaml
30
+ - repo: https://github.com/sarj-ai/linting
31
+ rev: python-v0.2.0
32
+ hooks:
33
+ - id: sarj-no-sequential-await
34
+ - id: sarj-inefficient-string-concat-in-loop
35
+ - id: sarj-prefer-discriminated-union
36
+ - id: sarj-prefer-str-enum
37
+ ```
38
+
39
+ ## CLI
40
+
41
+ ```bash
42
+ sarj-python-lint check --rule no-sequential-await path/to/file.py
43
+ sarj-python-lint list-rules
44
+ ```
45
+
46
+ Diagnostic format is `path:line:col: CODE message` — Ruff-compatible.
47
+
48
+ ## Suppression
49
+
50
+ Inline `# sarj-noqa: SARJ00X — <reason>` on the offending line.
51
+
52
+ Each rule's source under `src/sarj_python_lint/rules/` carries its own `description` and diagnostic message.
@@ -0,0 +1,34 @@
1
+ # sarj-python-lint
2
+
3
+ Custom Python lint rules via stdlib `ast`. Designed for pre-commit. For SQL rules see [`sarj-sql-lint`](https://github.com/sarj-ai/linting/tree/main/packages/sql).
4
+
5
+ ```bash
6
+ uv tool install sarj-python-lint
7
+ ```
8
+
9
+ ## Pre-commit
10
+
11
+ ```yaml
12
+ - repo: https://github.com/sarj-ai/linting
13
+ rev: python-v0.2.0
14
+ hooks:
15
+ - id: sarj-no-sequential-await
16
+ - id: sarj-inefficient-string-concat-in-loop
17
+ - id: sarj-prefer-discriminated-union
18
+ - id: sarj-prefer-str-enum
19
+ ```
20
+
21
+ ## CLI
22
+
23
+ ```bash
24
+ sarj-python-lint check --rule no-sequential-await path/to/file.py
25
+ sarj-python-lint list-rules
26
+ ```
27
+
28
+ Diagnostic format is `path:line:col: CODE message` — Ruff-compatible.
29
+
30
+ ## Suppression
31
+
32
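+ Inline `# sarj-noqa: SARJ00X — <reason>` on the offending line (the line the diagnostic reports). List several codes separated by commas; a bare `# sarj-noqa` suppresses every SARJ code on that line.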
33
+
34
+ Each rule's source under `src/sarj_python_lint/rules/` carries its own `description` and diagnostic message; `sarj-python-lint list-rules` prints every rule's code, ID, and description.
@@ -0,0 +1,55 @@
1
+ [project]
2
+ name = "sarj-python-lint"
3
+ version = "0.2.0"
4
+ description = "Custom Python lint rules — AST-based, pre-commit-friendly, hypermodern defaults"
5
+ readme = "README.md"
6
+ authors = [{ name = "sarj-ai" }]
7
+ license = { text = "MIT" }
8
+ requires-python = ">=3.13"
9
+ classifiers = [
10
+ "Development Status :: 4 - Beta",
11
+ "Intended Audience :: Developers",
12
+ "License :: OSI Approved :: MIT License",
13
+ "Programming Language :: Python :: 3",
14
+ "Programming Language :: Python :: 3.13",
15
+ "Topic :: Software Development :: Quality Assurance",
16
+ ]
17
+ dependencies = []
18
+
19
+ [project.scripts]
20
+ sarj-python-lint = "sarj_python_lint.__main__:main"
21
+
22
+ [project.urls]
23
+ Homepage = "https://github.com/sarj-ai/linting/tree/main/packages/python"
24
+ Repository = "https://github.com/sarj-ai/linting"
25
+ Issues = "https://github.com/sarj-ai/linting/issues"
26
+
27
+ [dependency-groups]
28
+ dev = [
29
+ "pytest>=8.0",
30
+ "pytest-benchmark>=4.0",
31
+ "ruff>=0.6",
32
+ "basedpyright>=1.20",
33
+ ]
34
+
35
+ [build-system]
36
+ requires = ["hatchling"]
37
+ build-backend = "hatchling.build"
38
+
39
+ [tool.hatch.build.targets.wheel]
40
+ packages = ["src/sarj_python_lint"]
41
+
42
+ [tool.hatch.build.targets.sdist]
43
+ include = [
44
+ "src",
45
+ "README.md",
46
+ "pyproject.toml",
47
+ ]
48
+ exclude = [
49
+ "tests/",
50
+ "**/__pycache__/",
51
+ "**/*.pyc",
52
+ ]
53
+
54
+ [tool.pytest.ini_options]
55
+ testpaths = ["tests"]
@@ -0,0 +1,3 @@
1
+ """sarj-python-lint — custom Python + SQL lint rules."""
2
+
3
+ __version__ = "0.1.4"
@@ -0,0 +1,99 @@
1
+ """CLI: sarj-python-lint check --rule <id> [--rule <id2>] <files>"""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ from sarj_python_lint import __version__
9
+ from sarj_python_lint.rule_base import Diagnostic, is_suppressed
10
+ from sarj_python_lint.rules import REGISTRY
11
+
12
+
13
+ SKIP_DIR_NAMES = {
14
+ "node_modules", ".venv", "venv", ".git", "dist", "build", ".next",
15
+ "coverage", "__pycache__", ".pytest_cache", ".ruff_cache", ".mypy_cache",
16
+ ".turbo", ".yarn", ".pnpm-store",
17
+ }
18
+
19
+
20
+ def _expand_paths(paths: list[Path]) -> list[Path]:
21
+ out: list[Path] = []
22
+ for p in paths:
23
+ if not p.exists():
24
+ continue
25
+ if p.is_file():
26
+ out.append(p)
27
+ continue
28
+ for child in p.rglob("*.py"):
29
+ if not child.is_file():
30
+ continue
31
+ if any(part in SKIP_DIR_NAMES for part in child.parts):
32
+ continue
33
+ try:
34
+ if child.stat().st_size > 500_000:
35
+ continue
36
+ except OSError:
37
+ continue
38
+ out.append(child)
39
+ return out
40
+
41
+
42
def _check(rule_ids: list[str], paths: list[Path]) -> list[Diagnostic]:
    """Run the named rules over `paths` and return every unsuppressed diagnostic.

    Args:
        rule_ids: Rule IDs as given via `--rule`. Repeated IDs run only once.
        paths: Files and/or directories, expanded by `_expand_paths`.

    Returns:
        Diagnostics grouped by file, then by rule in `rule_ids` order. Findings
        whose line carries a matching `# sarj-noqa` comment are dropped.

    Raises:
        SystemExit: status 2, after listing unknown and available IDs on
            stderr, if any ID is not in REGISTRY.
    """
    unknown = [rid for rid in rule_ids if rid not in REGISTRY]
    if unknown:
        sys.stderr.write(f"unknown rule(s): {', '.join(unknown)}\n")
        sys.stderr.write(f"available: {', '.join(sorted(REGISTRY))}\n")
        raise SystemExit(2)
    # dict.fromkeys de-duplicates while keeping order, so `--rule x --rule x`
    # does not report every finding twice.
    rules = [REGISTRY[rid]() for rid in dict.fromkeys(rule_ids)]
    diags: list[Diagnostic] = []
    for p in _expand_paths(paths):
        try:
            # utf-8-sig drops a leading BOM so it never reaches the rules'
            # ast.parse as a stray U+FEFF character; BOM-less files decode
            # exactly as with plain utf-8.
            source = p.read_text(encoding="utf-8-sig", errors="replace")
        except OSError:
            continue
        # Split only on \n, \r\n and \r, the terminators Python's tokenizer
        # counts, so index i lines up with ast line i + 1. str.splitlines()
        # also breaks on \f, \v, \x1c-\x1e, \x85, \u2028 and \u2029, which
        # shifts later lines and misplaces `# sarj-noqa` lookups.
        source_lines = source.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        for rule in rules:
            diags.extend(
                d
                for d in rule.check(p, source)
                if not is_suppressed(source_lines, d.line, d.code)
            )
    return diags
63
+
64
+
65
def main(argv: list[str] | None = None) -> int:
    """Console-script entry point for `sarj-python-lint`.

    Sub-commands:
        check --rule ID [--rule ID ...] FILES...
            Run the given rules and print one diagnostic per line to stdout.
        list-rules
            Print code, rule ID and description for every registered rule.

    Args:
        argv: Arguments without the program name. None lets argparse read
            sys.argv.

    Returns:
        0 for `list-rules`. For `check`, 1 if any diagnostic was printed,
        else 0. Argparse usage errors and unknown rule IDs exit via SystemExit
        with status 2.
    """
    parser = argparse.ArgumentParser(
        prog="sarj-python-lint",
        description="Custom Python lint rules.",
    )
    parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}")
    sub = parser.add_subparsers(dest="cmd", required=True)

    check_p = sub.add_parser("check", help="Run rules over files.")
    check_p.add_argument(
        "--rule",
        action="append",
        required=True,
        help="Rule ID (repeat for multiple).",
    )
    check_p.add_argument(
        "files",
        nargs="+",
        type=Path,
        help="Files or directories to lint (directories are searched for *.py).",
    )

    sub.add_parser("list-rules", help="List available rule IDs.")

    args = parser.parse_args(argv)

    if args.cmd == "list-rules":
        for rid, cls in sorted(REGISTRY.items()):
            inst = cls()
            sys.stdout.write(f"{inst.code:8} {rid:40} {inst.description}\n")
        return 0

    diags = _check(args.rule, args.files)
    for d in diags:
        sys.stdout.write(d.format() + "\n")
    return 1 if diags else 0
96
+
97
+
98
if __name__ == "__main__":
    # Support `python -m sarj_python_lint`: main()'s return value becomes the exit status.
    raise SystemExit(main())
File without changes
@@ -0,0 +1,70 @@
1
+ """Base types for sarj-python-lint rules."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from abc import ABC, abstractmethod
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+
10
# Suppression syntax. Three forms are supported:
#   # sarj-noqa                               (every SARJ code on the line)
#   # sarj-noqa: SARJ001 — reason
#   # sarj-noqa: SARJ001, SARJ002 — reason
# Codes may be separated by commas and/or spaces, and matching is
# case-insensitive.
# We deliberately do NOT use `# noqa` because ruff aggressively cleans
# unrecognized noqa codes (RUF100/RUF102) even with `external` set, which
# silently breaks suppressions across runs. Distinct prefix = no conflict.
_SARJ_NOQA_RE = re.compile(
    r"#\s*sarj-noqa(?::\s*([A-Za-z0-9_, ]+))?",
    re.IGNORECASE,
)
# Separators between codes inside the captured code list.
_SARJ_CODE_SEP_RE = re.compile(r"[\s,]+")


def is_suppressed(source_lines: list[str], line: int, code: str) -> bool:
    """Return True if the diagnostic's line carries a matching `# sarj-noqa` comment.

    Args:
        source_lines: The file's lines, where index 0 is line 1.
        line: 1-based line number, matching Diagnostic.line. Out-of-range
            values are never suppressed.
        code: Diagnostic code to test, e.g. "SARJ001". Case-insensitive.

    Returns:
        True for a bare `# sarj-noqa`, or when `code` appears in the listed
        codes. Tokens are split on commas and whitespace, so
        `SARJ001 SARJ002` and `SARJ001 reason` both suppress SARJ001.
    """
    if not 1 <= line <= len(source_lines):
        return False
    m = _SARJ_NOQA_RE.search(source_lines[line - 1])
    if m is None:
        return False
    codes_str = m.group(1)
    if not codes_str:
        # Bare `# sarj-noqa` suppresses every SARJ code on the line.
        return True
    codes = {token.upper() for token in _SARJ_CODE_SEP_RE.split(codes_str) if token}
    return code.upper() in codes
39
+
40
+
41
+ @dataclass(frozen=True, slots=True)
42
+ class Diagnostic:
43
+ """A single lint finding."""
44
+
45
+ path: Path
46
+ line: int
47
+ col: int
48
+ code: str
49
+ message: str
50
+
51
+ def format(self) -> str:
52
+ """Ruff-compatible: `path:line:col: CODE message`."""
53
+ return f"{self.path}:{self.line}:{self.col}: {self.code} {self.message}"
54
+
55
+
56
class Rule(ABC):
    """Abstract base that every sarj-python-lint rule derives from.

    Concrete rules declare `id`, `code` and `description` as class attributes
    and override `check`. `id` is the key used with `--rule`.
    """

    # Kebab-case rule identifier, e.g. "no-sequential-await".
    id: str
    # Diagnostic code emitted by the rule, e.g. "SARJ001".
    code: str
    # One-line summary printed by `list-rules`.
    description: str

    @abstractmethod
    def check(self, path: Path, source: str) -> list[Diagnostic]:
        """Analyse `source` (the text of `path`) and return its findings, possibly none."""
        raise NotImplementedError
@@ -0,0 +1,19 @@
1
+ from __future__ import annotations
2
+
3
+ from sarj_python_lint.rule_base import Rule
4
+ from sarj_python_lint.rules.inefficient_string_concat_in_loop import (
5
+ InefficientStringConcatInLoop,
6
+ )
7
+ from sarj_python_lint.rules.no_sequential_await import NoSequentialAwait
8
+ from sarj_python_lint.rules.prefer_discriminated_union import PreferDiscriminatedUnion
9
+ from sarj_python_lint.rules.prefer_str_enum import PreferStrEnum
10
+
11
+
12
# Rule ID (the value passed to `--rule`) -> rule class. Insertion order is
# SARJ001, SARJ002, SARJ005, SARJ006.
REGISTRY: dict[str, type[Rule]] = {
    rule_cls.id: rule_cls
    for rule_cls in (
        NoSequentialAwait,
        InefficientStringConcatInLoop,
        PreferDiscriminatedUnion,
        PreferStrEnum,
    )
}

__all__ = ["REGISTRY"]
@@ -0,0 +1,74 @@
1
+ """SARJ002: detect `s += "..."` inside loops.
2
+
3
+ String concatenation with `+=` inside a loop is O(n²) in CPython because
4
+ strings are immutable — each `+=` allocates a new string and copies the
5
+ previous one. Append to a list and `"".join(parts)` at the end for O(n).
6
+
7
+ References:
8
+ - https://docs.python.org/3/library/stdtypes.html#str.join
9
+ - https://wiki.python.org/moin/PythonSpeed/PerformanceTips
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import ast
15
+ from pathlib import Path
16
+
17
+ from sarj_python_lint.rule_base import Diagnostic, Rule
18
+
19
+
20
class InefficientStringConcatInLoop(Rule):
    """O(n²) string concatenation in a loop.

    Flags each `x += <string-looking expression>` that sits in the body of a
    `for`, `async for` or `while` loop. Each statement is reported at most
    once, however many loops enclose it. A loop's `else:` clause runs once,
    and the loop does not execute the bodies of functions, classes or lambdas
    defined inside it, so neither counts as being "in" the loop.
    """

    id = "inefficient-string-concat-in-loop"
    code = "SARJ002"
    description = "`s += '...'` in a loop is O(n²); append to a list and join."

    def check(self, path: Path, source: str) -> list[Diagnostic]:
        try:
            tree = ast.parse(source, filename=str(path))
        except SyntaxError:
            return []
        diags: list[Diagnostic] = []
        for node in self._augassigns_in_loops(tree):
            if not isinstance(node.op, ast.Add):
                continue
            # Heuristic: the RHS is a string-like value
            if not _looks_like_string(node.value):
                continue
            diags.append(
                Diagnostic(
                    path=path,
                    line=node.lineno,
                    col=node.col_offset + 1,
                    code=self.code,
                    message=(
                        "`+=` string concat in a loop is O(n²). "
                        "Append to a list and `''.join(...)`."
                    ),
                )
            )
        return diags

    @staticmethod
    def _augassigns_in_loops(tree: ast.AST) -> list[ast.AugAssign]:
        """Return each AugAssign executed once per loop iteration, exactly once, in source order."""
        found: list[ast.AugAssign] = []
        # Explicit stack instead of recursion so deeply nested code cannot hit
        # the interpreter's recursion limit.
        stack: list[tuple[ast.AST, bool]] = [(tree, False)]
        while stack:
            node, in_loop = stack.pop()
            if isinstance(node, ast.AugAssign):
                if in_loop:
                    found.append(node)
                continue
            if isinstance(node, (ast.For, ast.AsyncFor, ast.While)):
                # Only the body repeats; the `else:` clause inherits the outer context.
                stack.extend((stmt, True) for stmt in node.body)
                stack.extend((stmt, in_loop) for stmt in node.orelse)
                continue
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)):
                # A new scope: statements here do not accumulate across the outer loop.
                in_loop = False
            stack.extend((child, in_loop) for child in ast.iter_child_nodes(node))
        found.sort(key=lambda n: (n.lineno, n.col_offset))
        return found
57
+
58
+
59
def _looks_like_string(node: ast.AST) -> bool:
    """Heuristic for 'this RHS is probably a string at runtime'.

    Recognised shapes:
    - string literals and f-strings;
    - calls to `str`/`repr`/`format`, or to a `.format`/`.strftime`/`.join` method;
    - `a + b` or `a * b` where either operand looks like a string
      (concatenation / repetition such as `"-" * n`);
    - `"..." % args` (printf-style formatting with a string-looking left side).
    Anything else (bare names, attributes, other calls) is assumed not to be a string.
    """
    if isinstance(node, ast.Constant):
        return isinstance(node.value, str)
    if isinstance(node, ast.JoinedStr):  # f-string
        return True
    if isinstance(node, ast.Call):
        func = node.func
        if isinstance(func, ast.Name):
            return func.id in {"str", "repr", "format"}
        if isinstance(func, ast.Attribute):
            return func.attr in {"format", "strftime", "join"}
        return False
    if isinstance(node, ast.BinOp):
        if isinstance(node.op, (ast.Add, ast.Mult)):
            return _looks_like_string(node.left) or _looks_like_string(node.right)
        if isinstance(node.op, ast.Mod):
            return _looks_like_string(node.left)
    return False
@@ -0,0 +1,71 @@
1
+ """SARJ001: detect `for x in xs: await f(x)` patterns.
2
+
3
+ Sequential `await` in a for-loop serializes I/O that could be parallelized
4
+ with `asyncio.gather(*[f(x) for x in xs])`. The performance gap is often 10-100x
5
+ for network-bound work (HTTP, DB queries, LLM calls).
6
+
7
+ References:
8
+ - https://docs.python.org/3/library/asyncio-task.html#running-tasks-concurrently
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import ast
14
+ from pathlib import Path
15
+
16
+ from sarj_python_lint.rule_base import Diagnostic, Rule
17
+
18
+
19
class NoSequentialAwait(Rule):
    """Sequential await calls in a loop that could be parallelized.

    Only synchronous `for` statements are examined; `async for` and `while`
    loops are not checked themselves. An `await` inside one of them that
    sits in a plain `for` body is attributed to that `for`. Each `for` gets
    at most one diagnostic: the first `await`, in breadth-first order, that
    runs on every iteration of that loop.
    """

    id = "no-sequential-await"
    code = "SARJ001"
    description = "Sequential `await` in a for-loop — prefer asyncio.gather."

    def check(self, path: Path, source: str) -> list[Diagnostic]:
        try:
            tree = ast.parse(source, filename=str(path))
        except SyntaxError:
            return []
        diags: list[Diagnostic] = []
        for node in ast.walk(tree):
            if not isinstance(node, ast.For):
                continue
            awaited = self._first_per_iteration_await(node)
            if awaited is None:
                continue
            diags.append(
                Diagnostic(
                    path=path,
                    line=awaited.lineno,
                    col=awaited.col_offset + 1,
                    code=self.code,
                    # asyncio.gather takes awaitables as positional arguments,
                    # so the suggested list must be unpacked with `*`.
                    message=(
                        "Sequential `await` inside `for` — prefer "
                        "`asyncio.gather(*[f(x) for x in xs])`."
                    ),
                )
            )
        return diags

    @staticmethod
    def _first_per_iteration_await(loop: ast.For) -> ast.Await | None:
        """Return the first `await` (breadth-first) executed on every iteration of `loop`.

        - Only `loop.body` is searched. The iterable expression and the
          `else:` clause run once, so an `await` there is not sequential.
        - A nested `for` owns its own body, which gets its own diagnostic. Its
          target, iterable and `else:` clause still run on every iteration of
          `loop`, so those are searched.
        - Nested function and lambda definitions are not entered, because the
          loop does not execute their bodies.

        This walks only the loop body (level by level, like ast.walk), so no
        module-wide parent map is needed.
        """
        frontier: list[ast.AST] = list(loop.body)
        while frontier:
            next_level: list[ast.AST] = []
            for node in frontier:
                if isinstance(node, ast.Await):
                    return node
                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)):
                    continue
                if isinstance(node, ast.For):
                    next_level.extend([node.target, node.iter, *node.orelse])
                    continue
                next_level.extend(ast.iter_child_nodes(node))
            frontier = next_level
        return None
53
+
54
+
55
+ def _enclosing_loop(node: ast.AST, root: ast.AST) -> ast.AST | None:
56
+ """Walk parents until we find the nearest enclosing `ast.For`."""
57
+ # Build a child→parent map lazily on the first call. ast doesn't track
58
+ # parents, so we walk the whole tree once per check.
59
+ parent: dict[int, ast.AST] = {}
60
+ for child in ast.walk(root):
61
+ for grandchild in ast.iter_child_nodes(child):
62
+ parent[id(grandchild)] = child
63
+ cur: ast.AST | None = node
64
+ while cur is not None:
65
+ cur = parent.get(id(cur))
66
+ if isinstance(cur, ast.For):
67
+ return cur
68
+ if isinstance(cur, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)):
69
+ # Don't escape function boundaries
70
+ return None
71
+ return None
@@ -0,0 +1,148 @@
1
+ """SARJ005: flag BaseModel with `success: bool` + Optional fields.
2
+
3
+ The anti-pattern:
4
+
5
+ class Result(BaseModel):
6
+ success: bool
7
+ data: Optional[Data] = None
8
+ error: Optional[str] = None
9
+
10
+ allows illegal states (success=True with data=None, or success=False with
11
+ data set). Use a discriminated union:
12
+
13
+ class Success(BaseModel): data: Data
14
+ class Failure(BaseModel): error: str
15
+ Result = Union[Success, Failure]
16
+
17
+ References:
18
+ - https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
19
+ - https://en.wikipedia.org/wiki/Tagged_union
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import ast
25
+ from pathlib import Path
26
+
27
+ from sarj_python_lint.rule_base import Diagnostic, Rule
28
+
29
# Field names treated as the boolean success/failure flag when their
# annotation text mentions `bool`.
STATUS_FIELDS = {
    "success",
    "successful",
    "succeeded",
    "is_success",
    "ok",
    "failed",
    "failure",
}

# Optional fields that are never counted as siblings of the status flag
# (bookkeeping/diagnostic fields).
IGNORED_OPTIONAL_FIELDS = {
    "request_id",
    "trace_id",
    "traceback",
    "debug",
    "debug_logs",
    "log",
    "logs",
    "meta",
    "metadata",
    "extra",
}
42
+
43
+
44
class PreferDiscriminatedUnion(Rule):
    """Pydantic BaseModel with success:bool — prefer a discriminated union.

    A class is flagged when all of the following hold:
    - it lists `BaseModel` (bare or dotted) among its direct bases;
    - it declares a status field, i.e. a name in STATUS_FIELDS whose
      annotation text contains `bool`;
    - it declares at least two *other* Optional-typed fields whose names are
      not in IGNORED_OPTIONAL_FIELDS.
    """

    id = "prefer-discriminated-union"
    code = "SARJ005"
    description = "BaseModel with `success: bool` + Optional siblings — use a discriminated union."

    def check(self, path: Path, source: str) -> list[Diagnostic]:
        try:
            tree = ast.parse(source, filename=str(path))
        except SyntaxError:
            return []
        diags: list[Diagnostic] = []
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef) or not _inherits_basemodel(node):
                continue
            has_status_bool = False
            optional_fields: list[str] = []
            for stmt in node.body:
                if not isinstance(stmt, ast.AnnAssign) or not isinstance(stmt.target, ast.Name):
                    continue
                name = stmt.target.id
                if name in STATUS_FIELDS and "bool" in ast.unparse(stmt.annotation):
                    has_status_bool = True
                    # The flag itself (e.g. `success: bool | None`) is not a
                    # sibling. Counting it would let one Optional field trip the rule.
                    continue
                if name not in IGNORED_OPTIONAL_FIELDS and _is_optional(stmt.annotation):
                    optional_fields.append(name)
            if has_status_bool and len(optional_fields) >= 2:
                diags.append(
                    Diagnostic(
                        path=path,
                        line=node.lineno,
                        col=node.col_offset + 1,
                        code=self.code,
                        message=(
                            f"`{node.name}` has a bool status field plus "
                            f"Optional fields ({', '.join(optional_fields)}). "
                            "Model as `Union[Success, Failure]` to make illegal "
                            "states unrepresentable."
                        ),
                    )
                )
        return diags
92
+
93
+
94
+ def _inherits_basemodel(node: ast.ClassDef) -> bool:
95
+ for base in node.bases:
96
+ if isinstance(base, ast.Name) and base.id == "BaseModel":
97
+ return True
98
+ if isinstance(base, ast.Attribute) and base.attr == "BaseModel":
99
+ return True
100
+ return False
101
+
102
+
103
+ def _get_name_flat(node: ast.AST) -> str:
104
+ if isinstance(node, ast.Name):
105
+ return node.id
106
+ if isinstance(node, ast.Attribute):
107
+ val = _get_name_flat(node.value)
108
+ if val:
109
+ return f"{val}.{node.attr}"
110
+ return ""
111
+
112
+
113
+ def _is_optional(node: ast.AST | None) -> bool:
114
+ """Detect if an annotation represents an Optional type or Union with None."""
115
+ if node is None:
116
+ return False
117
+
118
+ # If it's a string literal (forward ref), parse it and check the inner AST
119
+ if isinstance(node, ast.Constant) and isinstance(node.value, str):
120
+ try:
121
+ parsed = ast.parse(node.value, mode="eval")
122
+ return _is_optional(parsed.body)
123
+ except SyntaxError:
124
+ pass
125
+
126
+ if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
127
+ return _is_optional(node.left) or _is_optional(node.right)
128
+
129
+ if isinstance(node, ast.Subscript):
130
+ name = _get_name_flat(node.value)
131
+ if name == "Optional" or name.endswith(".Optional"):
132
+ return True
133
+ if name == "Union" or name.endswith(".Union"):
134
+ slice_node = node.slice
135
+ # Handle Python < 3.9 Index wrapper safely
136
+ if type(slice_node).__name__ == "Index":
137
+ slice_node = getattr(slice_node, "value", slice_node)
138
+ if isinstance(slice_node, ast.Tuple):
139
+ return any(_is_optional(elt) for elt in slice_node.elts)
140
+ return _is_optional(slice_node)
141
+
142
+ if isinstance(node, ast.Constant) and node.value is None:
143
+ return True
144
+
145
+ if isinstance(node, ast.Name) and node.id == "None":
146
+ return True
147
+
148
+ return False
@@ -0,0 +1,107 @@
1
+ """SARJ006: Pydantic field with raw `str` annotation that looks like a choice/enum field.
2
+
3
+ `Literal["a", "b", "c"]` is acceptable — that's a proper closed set. This rule
4
+ only flags **raw `str`** annotations on fields whose name (`*_status`, `*_state`,
5
+ `*_type`, `*_kind`) or sibling class attribute (`choices`, `states`, `STATUSES`,
6
+ `values`, `allowed`) strongly suggests a closed enumeration is intended.
7
+
8
+ Replace with:
9
+ class Status(StrEnum):
10
+ ACTIVE = "active"
11
+ INACTIVE = "inactive"
12
+
13
+ References:
14
+ - https://docs.python.org/3/library/enum.html#enum.StrEnum
15
+ - https://docs.pydantic.dev/latest/concepts/types/#enums
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import ast
21
+ from pathlib import Path
22
+
23
+ from sarj_python_lint.rule_base import Diagnostic, Rule
24
+
25
+
26
class PreferStrEnum(Rule):
    """Choice-shaped str field — prefer StrEnum or Literal.

    In every class that does not list an enum type among its direct bases, an
    annotated field whose annotation is exactly `str` is flagged when either:
    - its name ends in `_status`, `_state`, `_type` or `_kind`; or
    - the class also binds `choices`/`states`/`statuses`/`values`/`allowed`
      (case-insensitive) to a list/tuple/set of string literals.
    Other annotations (`Literal[...]`, `Optional[str]`, ...) are never flagged.
    The class does not have to be a Pydantic model.
    """

    id = "prefer-str-enum"
    code = "SARJ006"
    description = "Choice-like `str` field (by name or sibling choices list) — prefer StrEnum."

    # Direct base names that mark a class as an enum definition (skipped entirely).
    _ENUM_BASES = frozenset({"Enum", "StrEnum", "IntEnum", "Flag", "IntFlag", "ReprEnum"})
    # Lower-cased class-attribute names that signal a closed set of string choices.
    _CHOICE_ATTRS = frozenset({"choices", "states", "statuses", "values", "allowed"})
    # Field-name suffixes that suggest a closed enumeration.
    _CHOICE_SUFFIXES = ("_status", "_state", "_type", "_kind")

    def check(self, path: Path, source: str) -> list[Diagnostic]:
        try:
            tree = ast.parse(source, filename=str(path))
        except SyntaxError:
            return []
        diags: list[Diagnostic] = []
        for cls in ast.walk(tree):
            if not isinstance(cls, ast.ClassDef):
                continue
            if any(_base_name(b) in self._ENUM_BASES for b in cls.bases):
                continue
            has_choices_attr = self._has_choices_attr(cls)
            for stmt in cls.body:
                if not isinstance(stmt, ast.AnnAssign) or not isinstance(stmt.target, ast.Name):
                    continue
                # Only a bare `str`; `Literal[...]` etc. already is a closed set.
                if ast.unparse(stmt.annotation).strip() != "str":
                    continue
                name = stmt.target.id
                if has_choices_attr or name.endswith(self._CHOICE_SUFFIXES):
                    diags.append(
                        Diagnostic(
                            path=path,
                            line=stmt.lineno,
                            col=stmt.col_offset + 1,
                            code=self.code,
                            message=(
                                f"`{name}: str` looks like a choice field — "
                                "prefer `StrEnum`. (`Literal[...]` is also acceptable.)"
                            ),
                        )
                    )
        return diags

    def _has_choices_attr(self, cls: ast.ClassDef) -> bool:
        """True if `cls` binds a choices-like name to a collection of string literals."""
        for stmt in cls.body:
            if isinstance(stmt, ast.Assign):
                target, value = stmt.targets[0], stmt.value
            elif isinstance(stmt, ast.AnnAssign):
                target, value = stmt.target, stmt.value
            else:
                continue
            if (
                isinstance(target, ast.Name)
                and target.id.lower() in self._CHOICE_ATTRS
                and _is_string_collection(value)
            ):
                return True
        return False
91
+
92
+
93
+ def _base_name(base: ast.AST) -> str | None:
94
+ if isinstance(base, ast.Name):
95
+ return base.id
96
+ if isinstance(base, ast.Attribute):
97
+ return base.attr
98
+ return None
99
+
100
+
101
+ def _is_string_collection(node: ast.AST | None) -> bool:
102
+ if not isinstance(node, (ast.List, ast.Tuple, ast.Set)):
103
+ return False
104
+ return all(
105
+ isinstance(elt, ast.Constant) and isinstance(elt.value, str)
106
+ for elt in node.elts
107
+ )