zshshellcheck 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zshcheck/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """ZshCheck - A static analysis tool for zsh shell scripts."""
2
+
3
+ __version__ = "0.1.0"
zshcheck/analyzer.py ADDED
@@ -0,0 +1,269 @@
1
+ """Analyzer module for zshcheck.
2
+
3
+ This module provides the main analysis engine that orchestrates parsing,
4
+ running checks, and collecting diagnostics.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import re
10
+ from pathlib import Path
11
+
12
+ from tree_sitter import Node
13
+
14
+ from zshcheck.checks.base import AnalysisContext, CheckRegistry, get_registry
15
+ from zshcheck.diagnostics import Diagnostic, Position, Range, Severity
16
+ from zshcheck.parser import ZshParser
17
+
18
+ NON_ASCII_PATTERN = re.compile(r"[^\x00-\x7F]")
19
+
20
+
21
+ def _check_non_ascii(source: str) -> list[Diagnostic]:
22
+ """Check for non-ASCII characters (including emojis) that may cause parsing issues.
23
+
24
+ Args:
25
+ source: The source code to check.
26
+
27
+ Returns:
28
+ List of diagnostics for non-ASCII characters found.
29
+ """
30
+ diagnostics: list[Diagnostic] = []
31
+ matches = list(NON_ASCII_PATTERN.finditer(source))
32
+
33
+ if not matches:
34
+ return diagnostics
35
+
36
+ chars = [m.group() for m in matches[:5]]
37
+ chars_str = " ".join(f"'{c}'" for c in chars)
38
+
39
+ diagnostic = Diagnostic(
40
+ code="ZC9004",
41
+ severity=Severity.INFO,
42
+ message=(
43
+ f"Non-ASCII characters detected ({len(matches)} found). "
44
+ "tree-sitter-zsh grammar may not parse these correctly. "
45
+ f"Found: {chars_str}"
46
+ ),
47
+ range=Range(Position(1, 1), Position(1, 1)),
48
+ )
49
+ diagnostics.append(diagnostic)
50
+ return diagnostics
51
+
52
+
53
def apply_fixes(source: str, fixes: list[Diagnostic]) -> str:
    """Apply the auto-fix replacements carried by diagnostics to source code.

    Every replacement of every fixable diagnostic is applied, starting with the
    one that begins closest to the end of the text. This way, positions of
    replacements not yet applied are never shifted by earlier edits.

    Replacement ranges use 1-based lines and columns. The replaced text is
    ``start.column - 1`` up to (but not including) index ``end.column`` of the
    respective lines, i.e. the end column is inclusive. Text before the start
    on the first line and after the end on the last line is preserved.
    Replacements whose lines fall outside the source, or whose end line
    precedes the start line, are skipped.

    Args:
        source: Original source code.
        fixes: Diagnostics whose fixes should be applied. Non-fixable ones and
            those without a fix are ignored.

    Returns:
        Source code with fixes applied.
    """
    if not fixes:
        return source

    # Flatten first: ordering by the diagnostic's own start is not enough when
    # one diagnostic carries several replacements (or ranges differing from it).
    replacements = [
        replacement
        for diagnostic in fixes
        if diagnostic.fixable and diagnostic.fix is not None
        for replacement in diagnostic.fix.replacements
    ]
    # Back-to-front; the sort is stable, so ties keep their given order.
    replacements.sort(
        key=lambda r: (r.range.start.line, r.range.start.column),
        reverse=True,
    )

    result = source
    for replacement in replacements:
        start_line = replacement.range.start.line - 1
        start_col = replacement.range.start.column - 1
        end_line = replacement.range.end.line - 1
        end_col = replacement.range.end.column

        lines = result.splitlines(keepends=True)
        if not (0 <= start_line <= end_line < len(lines)):
            continue

        # Keep what precedes the range on its first line and what follows it
        # on its last line (including that line's newline); drop the rest.
        prefix = lines[start_line][:start_col]
        suffix = lines[end_line][end_col:]
        lines[start_line : end_line + 1] = [prefix + replacement.text + suffix]
        result = "".join(lines)

    return result
99
+
100
+
101
class Analyzer:
    """Main analysis engine for zsh shell scripts.

    Parses a script with ``ZshParser``, runs every selected check from the
    registry on each node of the syntax tree, and returns all diagnostics
    (parse errors, the non-ASCII notice and check results) sorted.
    """

    def __init__(self, registry: CheckRegistry | None = None) -> None:
        """Initialize the analyzer.

        Args:
            registry: Check registry to use. When None, the module-level
                registry returned by ``get_registry()`` is used.
        """
        self._parser = ZshParser()
        self._registry = registry or get_registry()

    def analyze_string(
        self,
        source: str,
        filename: str | None = None,
        include: list[str] | None = None,
        exclude: list[str] | None = None,
    ) -> list[Diagnostic]:
        """Analyze a zsh script from a string.

        Args:
            source: The zsh script source code.
            filename: Optional filename for context.
            include: Optional list of check codes to run (None runs all).
            exclude: Optional list of check codes to skip.

        Returns:
            List of diagnostics found, sorted.
        """
        parse_result = self._parser.parse(source)
        return self._collect_diagnostics(parse_result, source, filename, include, exclude)

    def analyze_file(
        self,
        path: str | Path,
        include: list[str] | None = None,
        exclude: list[str] | None = None,
    ) -> list[Diagnostic]:
        """Analyze a zsh script from a file.

        Args:
            path: Path to the zsh script file.
            include: Optional list of check codes to run (None runs all).
            exclude: Optional list of check codes to skip.

        Returns:
            List of diagnostics found, sorted. The source text analyzed is the
            one returned by ``ZshParser.parse_file`` (``parse_result.source``).
        """
        file_path = Path(path)
        parse_result = self._parser.parse_file(file_path)
        return self._collect_diagnostics(
            parse_result, parse_result.source, str(file_path), include, exclude
        )

    def _collect_diagnostics(
        self,
        parse_result,
        source: str,
        filename: str | None,
        include: list[str] | None,
        exclude: list[str] | None,
    ) -> list[Diagnostic]:
        """Build the sorted diagnostic list for one parsed script.

        Args:
            parse_result: Result of ``ZshParser.parse``/``parse_file``; its
                ``diagnostics``, ``success`` and ``root_node`` are read.
            source: Source text the result was produced from.
            filename: Name recorded in the analysis context.
            include: Optional list of check codes to run.
            exclude: Optional list of check codes to skip.

        Returns:
            Parse diagnostics, the non-ASCII notice and (when parsing
            succeeded) check diagnostics, sorted on every path.
        """
        diagnostics: list[Diagnostic] = list(parse_result.diagnostics)
        diagnostics.extend(_check_non_ascii(source))

        # Checks need a usable syntax tree.
        if parse_result.success and parse_result.root_node is not None:
            context = AnalysisContext(source=source, filename=filename)
            diagnostics.extend(
                self._run_checks(
                    parse_result.root_node,
                    context,
                    include=include,
                    exclude=exclude,
                )
            )

        diagnostics.sort()
        return diagnostics

    def _run_checks(
        self,
        root_node: Node,
        context: AnalysisContext,
        include: list[str] | None = None,
        exclude: list[str] | None = None,
    ) -> list[Diagnostic]:
        """Run the selected checks on every node of the tree.

        Nodes are visited in pre-order (a node, then its children left to
        right), and on each node the checks run in registry order.

        Args:
            root_node: Root node of the AST.
            context: Analysis context.
            include: Optional list of check codes to run (None runs all).
            exclude: Optional list of check codes to skip.

        Returns:
            Every truthy result returned by a check, in visit order.
        """
        exclude_set = set(exclude or [])
        # Select once: the registry's ``checks`` property copies its list on
        # each access, so evaluating it per node is wasted work.
        active_checks = [
            check
            for check in self._registry.checks
            if check.code not in exclude_set
            and (include is None or check.code in include)
        ]
        if not active_checks:
            return []

        diagnostics: list[Diagnostic] = []
        # Explicit stack instead of recursion so deeply nested scripts cannot
        # exceed Python's recursion limit. Children are pushed in reverse to
        # keep the left-to-right pre-order of the recursive walk.
        stack = [root_node]
        while stack:
            node = stack.pop()
            for check in active_checks:
                if diagnostic := check.check(node, context):
                    diagnostics.append(diagnostic)
            stack.extend(reversed(node.children))
        return diagnostics
250
+
251
+
252
def create_default_analyzer() -> Analyzer:
    """Return an Analyzer backed by a fresh registry of the built-in checks.

    The registry is a brand-new ``CheckRegistry`` (not the module-level one
    from ``get_registry()``). It holds, in this order: UnquotedVariableCheck,
    UnusedVariableCheck, DeprecatedCommandCheck and DoubleBracketCheck.
    """
    # Check modules are imported here rather than at module top, presumably
    # to avoid an import cycle with this module -- verify before hoisting.
    from zshcheck.checks.commands import DeprecatedCommandCheck
    from zshcheck.checks.quoting import UnquotedVariableCheck
    from zshcheck.checks.style import DoubleBracketCheck
    from zshcheck.checks.variables import UnusedVariableCheck

    builtin_checks = [
        UnquotedVariableCheck(),
        UnusedVariableCheck(),
        DeprecatedCommandCheck(),
        DoubleBracketCheck(),
    ]

    default_registry = CheckRegistry()
    default_registry.register_all(builtin_checks)
    return Analyzer(registry=default_registry)
@@ -0,0 +1 @@
1
+ """Checks package for zshcheck."""
@@ -0,0 +1,225 @@
1
+ """Base check module for zshcheck.
2
+
3
+ This module defines the abstract base class that all checks must implement,
4
+ along with the check registry for managing and running checks.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from abc import ABC, abstractmethod
10
+ from dataclasses import dataclass, field
11
+
12
+ from tree_sitter import Node
13
+
14
+ from zshcheck.diagnostics import Diagnostic, Severity
15
+
16
+
17
@dataclass
class AnalysisContext:
    """State shared by checks while a single script is being analyzed.

    Attributes:
        source: Complete text of the script under analysis.
        filename: Name/path of the script, or None when it is not known.
        variables: Known variables, mapping each name to its scope.
        functions: Names of defined functions.
        aliases: Known aliases (presumably name -> expansion; verify with
            the code that fills it).
        shell_options: Shell options considered enabled.
    """

    source: str
    filename: str | None = None
    variables: dict[str, str] = field(default_factory=dict)  # variable name -> scope
    functions: set[str] = field(default_factory=set)
    aliases: dict[str, str] = field(default_factory=dict)
    shell_options: set[str] = field(default_factory=set)

    def is_variable_defined(self, name: str) -> bool:
        """Return True when ``name`` is a key of ``variables``."""
        known_names = self.variables.keys()
        return name in known_names

    def is_function_defined(self, name: str) -> bool:
        """Return True when ``name`` is among the recorded ``functions``."""
        found = name in self.functions
        return found
47
+
48
+
49
class BaseCheck(ABC):
    """Abstract interface shared by every zshcheck check.

    A concrete check provides:

    * ``code`` -- unique identifier such as ``"ZC1001"``;
    * ``description`` -- short explanation of what the check detects;
    * ``severity`` -- default severity of the diagnostics it emits;
    * ``check`` -- the per-node detection logic.
    """

    @property
    @abstractmethod
    def code(self) -> str:
        """Return the unique check code (e.g. ``"ZC1001"``).

        Code ranges by convention:

        * ``ZC1xxx`` -- quoting and word splitting issues
        * ``ZC2xxx`` -- variable tracking issues
        * ``ZC3xxx`` -- zsh-specific issues
        * ``ZC4xxx`` -- command usage issues
        * ``ZC5xxx`` -- style suggestions
        * ``ZC9xxx`` -- internal/parsing errors
        """
        ...

    @property
    @abstractmethod
    def description(self) -> str:
        """Return a short, human-readable summary of what the check detects."""
        ...

    @property
    @abstractmethod
    def severity(self) -> Severity:
        """Return the severity this check assigns to its diagnostics by default."""
        ...

    @abstractmethod
    def check(self, node: Node, context: AnalysisContext) -> Diagnostic | None:
        """Inspect a single syntax-tree node.

        Args:
            node: The node to inspect.
            context: Shared state for the script being analyzed.

        Returns:
            A Diagnostic describing the problem, or None when the node is fine.
        """
        ...

    def __repr__(self) -> str:
        return "{}(code={!r})".format(type(self).__name__, self.code)
101
+
102
+
103
class CheckRegistry:
    """Ordered collection of checks, keyed by their unique codes.

    Checks are kept in registration order; that order is the order in which
    ``run_all`` executes them.
    """

    def __init__(self) -> None:
        """Create a registry with no checks."""
        self._checks: list[BaseCheck] = []

    def register(self, check: BaseCheck) -> None:
        """Add ``check`` to the registry.

        Args:
            check: The check instance to register.

        Raises:
            ValueError: If a check with the same code is already registered.
        """
        if any(existing.code == check.code for existing in self._checks):
            raise ValueError(f"Check with code {check.code!r} already registered")
        self._checks.append(check)

    def register_all(self, checks: list[BaseCheck]) -> None:
        """Register each of ``checks`` in order (stops at the first duplicate).

        Args:
            checks: Check instances to register.
        """
        for new_check in checks:
            self.register(new_check)

    def get_check(self, code: str) -> BaseCheck | None:
        """Look up a registered check by code.

        Args:
            code: The check code to look up.

        Returns:
            The matching check, or None when no check has that code.
        """
        return next((candidate for candidate in self._checks if candidate.code == code), None)

    @property
    def checks(self) -> list[BaseCheck]:
        """A shallow copy of the registered checks, in registration order."""
        return list(self._checks)

    @property
    def codes(self) -> list[str]:
        """Codes of the registered checks, in registration order."""
        return [registered.code for registered in self._checks]

    def run_check(self, code: str, node: Node, context: AnalysisContext) -> Diagnostic | None:
        """Run the check identified by ``code`` on one node.

        Args:
            code: Code of the check to run.
            node: The AST node to check.
            context: Analysis context.

        Returns:
            Whatever the check returns (a Diagnostic or None).

        Raises:
            ValueError: If no registered check has ``code``.
        """
        found = self.get_check(code)
        if found is not None:
            return found.check(node, context)
        raise ValueError(f"Unknown check code: {code!r}")

    def run_all(
        self,
        node: Node,
        context: AnalysisContext,
        include: list[str] | None = None,
        exclude: list[str] | None = None,
    ) -> list[Diagnostic]:
        """Run every selected check on one node.

        Args:
            node: The AST node to check.
            context: Analysis context.
            include: Codes to run; None means every registered check.
            exclude: Codes to skip (takes effect even if also included).

        Returns:
            The truthy results of the selected checks, in registration order.
        """
        skipped = set(exclude or [])
        selected = (
            candidate
            for candidate in self._checks
            if (include is None or candidate.code in include) and candidate.code not in skipped
        )
        return [outcome for candidate in selected if (outcome := candidate.check(node, context))]
208
+
209
+
210
# Module-level registry, created lazily by get_registry().
_registry: CheckRegistry | None = None


def get_registry() -> CheckRegistry:
    """Return the module-level check registry, creating an empty one on first use."""
    global _registry
    if _registry is not None:
        return _registry
    _registry = CheckRegistry()
    return _registry


def reset_registry() -> None:
    """Replace the module-level registry with a new, empty one (mainly for tests)."""
    global _registry
    _registry = CheckRegistry()
@@ -0,0 +1,179 @@
1
+ """Command usage checks for zshcheck.
2
+
3
+ This module contains checks for deprecated commands, unsafe command usage,
4
+ and command-specific issues.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from tree_sitter import Node
10
+
11
+ from zshcheck.checks.base import AnalysisContext, BaseCheck
12
+ from zshcheck.diagnostics import Diagnostic, Severity
13
+ from zshcheck.parser import get_node_text, node_to_range
14
+
15
+
16
class DeprecatedCommandCheck(BaseCheck):
    """Check for deprecated or obsolete commands (ZC4001).

    Only the command word is inspected (its first whitespace-separated word,
    with any directory prefix removed). Arguments such as ``ls /usr/bin/which``
    or ``cat ~/mail`` are therefore not reported. When a ``command`` node has a
    ``command_name`` child, only the child is reported, so one invocation
    yields one diagnostic.
    """

    # Map of deprecated command -> (suggested replacement, reason).
    DEPRECATED_COMMANDS: dict[str, tuple[str, str]] = {
        "which": ("command -v or type", "'which' is non-standard and not portable"),
        "whereis": ("command -v or type", "'whereis' is not portable"),
        "finger": ("other user lookup methods", "'finger' is outdated and often unavailable"),
        "mail": ("sendmail or modern mail clients", "'mail' command is deprecated"),
        "uudecode": ("base64", "'uudecode' is obsolete, use base64"),
        "uuencode": ("base64", "'uuencode' is obsolete, use base64"),
    }

    @property
    def code(self) -> str:
        return "ZC4001"

    @property
    def description(self) -> str:
        return "Deprecated or obsolete command used"

    @property
    def severity(self) -> Severity:
        return Severity.WARNING

    def check(self, node: Node, context: AnalysisContext) -> Diagnostic | None:
        """Report ``node`` if its command word is in ``DEPRECATED_COMMANDS``."""
        if node.type == "command":
            # Let the dedicated command_name child be reported instead;
            # otherwise a bare `which` would be flagged twice.
            if any(child.type == "command_name" for child in node.children):
                return None
        elif node.type != "command_name":
            return None

        base_cmd = self._command_basename(get_node_text(node, context.source))
        entry = self.DEPRECATED_COMMANDS.get(base_cmd)
        if entry is None:
            return None

        replacement, reason = entry
        return Diagnostic(
            code=self.code,
            severity=self.severity,
            message=f"{reason}. Consider using '{replacement}' instead.",
            range=node_to_range(node),
            source=self._get_source_line(node, context.source),
        )

    @staticmethod
    def _command_basename(text: str) -> str:
        """Return the first word of ``text`` without any directory prefix ("" if blank)."""
        words = text.split(maxsplit=1)
        if not words:
            return ""
        return words[0].rsplit("/", 1)[-1]

    def _get_source_line(self, node: Node, source: str) -> str | None:
        """Get the source line containing this node."""
        lines = source.split("\n")
        line_idx = node.start_point.row
        if 0 <= line_idx < len(lines):
            return lines[line_idx]
        return None
74
+
75
+
76
class BacktickCheck(BaseCheck):
    """Check for backtick command substitution and offer a ``$()`` fix (ZC4002)."""

    @property
    def code(self) -> str:
        return "ZC4002"

    @property
    def description(self) -> str:
        return "Use $() instead of backticks for command substitution"

    @property
    def severity(self) -> Severity:
        return Severity.STYLE

    def check(self, node: Node, context: AnalysisContext) -> Diagnostic | None:
        """Report a backtick-quoted ``command_substitution`` node with a ``$()`` fix."""
        # Presumably the grammar uses this node type for both `...` and $(...)
        # (as tree-sitter-bash does), so the source text decides the form.
        if node.type != "command_substitution":
            return None

        node_text = get_node_text(node, context.source)
        # The length check stops a lone "`" from passing both tests below.
        if len(node_text) < 2 or not (node_text.startswith("`") and node_text.endswith("`")):
            return None

        import re

        from zshcheck.diagnostics import Fix, Replacement

        node_range = node_to_range(node)
        source_line = self._get_source_line(node, context.source)

        # Inside backquotes a backslash escapes only "$", "`" and "\" and is
        # removed before the command runs; inside $(...) it would be kept, so
        # unescape those to preserve meaning (`echo \$HOME` -> $(echo $HOME)).
        # NOTE(review): within double quotes backquote processing also
        # unescapes \" -- that case is not handled here.
        inner = re.sub(r"\\([$`\\])", r"\1", node_text[1:-1])
        replacement = f"$({inner})"

        fix = Fix(
            message=f"Replace with {replacement}",
            replacements=[Replacement(range=node_range, text=replacement)],
        )

        return Diagnostic(
            code=self.code,
            severity=self.severity,
            message="Backticks are deprecated. Use $() for command substitution instead.",
            range=node_range,
            fix=fix,
            source=source_line,
        )

    def _get_source_line(self, node: Node, source: str) -> str | None:
        """Get the source line containing this node."""
        lines = source.split("\n")
        line_idx = node.start_point.row
        if 0 <= line_idx < len(lines):
            return lines[line_idx]
        return None
133
+
134
+
135
class EchoWithEscapesCheck(BaseCheck):
    """Check for ``echo`` invoked with ``-e``/``-E`` (ZC4003).

    The leading option words are parsed the way echo reads them: each must be
    ``-`` followed only by ``n``, ``e`` or ``E`` (so ``-e``, ``-ne`` and
    ``-n -e`` all count). The first word that is not such a cluster ends
    option parsing.
    """

    # Letters echo accepts in an option cluster such as "-n", "-e" or "-ne".
    _ECHO_OPTION_LETTERS = frozenset("neE")

    @property
    def code(self) -> str:
        return "ZC4003"

    @property
    def description(self) -> str:
        return "echo with escape sequences may behave differently across systems"

    @property
    def severity(self) -> Severity:
        return Severity.WARNING

    def check(self, node: Node, context: AnalysisContext) -> Diagnostic | None:
        """Report a ``command`` node that runs echo with an -e or -E option."""
        if node.type != "command":
            return None

        node_text = get_node_text(node, context.source)
        if not self._uses_escape_flag(node_text):
            return None

        return Diagnostic(
            code=self.code,
            severity=self.severity,
            message="echo with -e/-E flags is non-portable. Consider using printf instead.",
            range=node_to_range(node),
            source=self._get_source_line(node, context.source),
        )

    def _uses_escape_flag(self, text: str) -> bool:
        """Return True if ``text`` is an echo command with -e/-E among its leading options."""
        words = text.split()
        if not words or words[0] != "echo":
            return False
        for word in words[1:]:
            letters = word[1:]
            if not word.startswith("-") or not letters or not set(letters) <= self._ECHO_OPTION_LETTERS:
                return False  # first operand: no more options follow
            if "e" in letters or "E" in letters:
                return True
        return False

    def _get_source_line(self, node: Node, source: str) -> str | None:
        """Get the source line containing this node."""
        lines = source.split("\n")
        line_idx = node.start_point.row
        if 0 <= line_idx < len(lines):
            return lines[line_idx]
        return None