codeshield-ai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codeshield/__init__.py +62 -0
- codeshield/api_server.py +438 -0
- codeshield/cli.py +48 -0
- codeshield/contextvault/__init__.py +1 -0
- codeshield/contextvault/capture.py +174 -0
- codeshield/contextvault/restore.py +115 -0
- codeshield/mcp/__init__.py +1 -0
- codeshield/mcp/hooks.py +65 -0
- codeshield/mcp/server.py +319 -0
- codeshield/styleforge/__init__.py +1 -0
- codeshield/styleforge/corrector.py +298 -0
- codeshield/trustgate/__init__.py +1 -0
- codeshield/trustgate/checker.py +384 -0
- codeshield/trustgate/sandbox.py +101 -0
- codeshield/utils/__init__.py +9 -0
- codeshield/utils/daytona.py +233 -0
- codeshield/utils/leanmcp.py +258 -0
- codeshield/utils/llm.py +423 -0
- codeshield/utils/metrics.py +543 -0
- codeshield/utils/token_optimizer.py +605 -0
- codeshield_ai-0.1.0.dist-info/METADATA +565 -0
- codeshield_ai-0.1.0.dist-info/RECORD +24 -0
- codeshield_ai-0.1.0.dist-info/WHEEL +4 -0
- codeshield_ai-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
"""
|
|
2
|
+
StyleForge Corrector - Convention enforcement
|
|
3
|
+
|
|
4
|
+
Enforces YOUR codebase conventions:
|
|
5
|
+
- Naming pattern detection
|
|
6
|
+
- Variable name correction
|
|
7
|
+
- Function name matching
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import ast
|
|
11
|
+
import re
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from typing import Optional
|
|
15
|
+
from collections import Counter
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class NamingConvention:
    """Naming convention inferred for one category of identifiers."""

    # Dominant pattern label: "snake_case", "camelCase", "PascalCase" or
    # "SCREAMING_SNAKE".
    pattern: str
    # Share of the classified names that follow ``pattern``.
    confidence: float
    # A few sample identifiers taken from the analysed code.
    examples: list[str] = field(default_factory=list)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class StyleIssue:
    """One style finding: what is wrong, the offending text and a replacement."""

    # Human-readable description of the finding.
    message: str
    # Identifier as it appears in the checked code.
    original: str
    # Replacement proposed for ``original``.
    suggested: str
    # Source line of the finding, when known.
    line: Optional[int] = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class StyleCheckResult:
    """Outcome of checking a piece of code against the detected conventions."""

    # True when no issue was reported.
    matches_convention: bool
    # Findings, in the order they were produced.
    issues: list[StyleIssue] = field(default_factory=list)
    # Per-category summary of the conventions that were detected.
    conventions_detected: dict = field(default_factory=dict)
    # Rewritten source, or None when nothing was rewritten.
    corrected_code: Optional[str] = None

    def to_dict(self) -> dict:
        """Return a plain-dict view of the result.

        The corrected source itself is not included; ``has_corrections``
        only reports whether one is attached.
        """
        serialized_issues = []
        for issue in self.issues:
            serialized_issues.append({
                "message": issue.message,
                "original": issue.original,
                "suggested": issue.suggested,
                "line": issue.line,
            })

        has_corrections = self.corrected_code is not None
        return {
            "matches_convention": self.matches_convention,
            "issues": serialized_issues,
            "conventions_detected": self.conventions_detected,
            "has_corrections": has_corrections,
        }
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def detect_naming_pattern(name: str) -> str:
    """Classify the naming pattern of a single identifier.

    Args:
        name: Identifier to classify.

    Returns:
        One of "SCREAMING_SNAKE", "snake_case", "PascalCase", "camelCase"
        or "mixed". An empty string is reported as "mixed".
    """
    if not name:
        # Nothing to classify; also keeps the name[0] checks below from
        # raising IndexError.
        return "mixed"

    has_underscore = '_' in name

    if has_underscore and name.isupper():
        return "SCREAMING_SNAKE"
    if name.islower():
        # Covers "foo_bar" as well as a plain lowercase word such as "foo".
        return "snake_case"
    if name[0].isupper() and not has_underscore:
        return "PascalCase"
    if name[0].islower() and any(c.isupper() for c in name):
        return "camelCase"
    return "mixed"
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def convert_to_snake_case(name: str) -> str:
    """Convert a camelCase or PascalCase name to snake_case.

    Args:
        name: Identifier to convert.

    Returns:
        The lowercased name with an underscore inserted at each word
        boundary. Existing underscores are kept and never doubled.
    """
    # Break before each capitalised word ("HTTPResponse" -> "HTTP_Response").
    # The preceding character must not be "_" so that "my_Var" does not turn
    # into "my__Var".
    partial = re.sub(r'([^_])([A-Z][a-z]+)', r'\1_\2', name)
    # Break between a lowercase letter or digit and the capital that follows.
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', partial).lower()
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def convert_to_camel_case(name: str) -> str:
    """Convert a snake_case name to camelCase.

    Leading and trailing underscores are kept, so a private name stays
    private ("_my_name" -> "_myName") and a keyword-avoiding name such as
    "class_" is not turned into the keyword itself.

    Args:
        name: Identifier to convert.

    Returns:
        The camelCase form, or ``name`` unchanged when it contains no
        underscore between words.
    """
    core = name.strip('_')
    if '_' not in core:
        return name

    prefix = name[:len(name) - len(name.lstrip('_'))]
    suffix = name[len(name.rstrip('_')):]

    # Empty parts come from doubled underscores ("foo__bar"); skip them.
    head, *tail = [part for part in core.split('_') if part]
    return prefix + head.lower() + ''.join(part.title() for part in tail) + suffix
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class CodebaseAnalyzer(ast.NodeVisitor):
    """AST visitor that collects identifier names, grouped by kind.

    Names are appended in visiting order and are not de-duplicated.
    """

    def __init__(self):
        self.variable_names: list[str] = []
        self.function_names: list[str] = []
        self.class_names: list[str] = []
        self.constant_names: list[str] = []

    def visit_Name(self, node: ast.Name):
        """Record assignment targets; all-uppercase names count as constants."""
        if isinstance(node.ctx, ast.Store):
            bucket = self.constant_names if node.id.isupper() else self.variable_names
            bucket.append(node.id)
        self.generic_visit(node)

    def _record_function(self, node):
        """Record a function name and keep walking into its body."""
        self.function_names.append(node.name)
        self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        self._record_function(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        self._record_function(node)

    def visit_ClassDef(self, node: ast.ClassDef):
        """Record a class name and keep walking into its body."""
        self.class_names.append(node.name)
        self.generic_visit(node)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _summarize_convention(sample: list[str], examples: list[str]) -> Optional[NamingConvention]:
    """Return the most frequent naming pattern in *sample*, or None if it is empty."""
    if not sample:
        return None
    patterns = [detect_naming_pattern(n) for n in sample]
    pattern, count = Counter(patterns).most_common(1)[0]
    return NamingConvention(
        pattern=pattern,
        confidence=count / len(patterns),
        examples=examples[:5],
    )


def analyze_codebase(path: str) -> dict[str, NamingConvention]:
    """Analyze a codebase and extract its naming conventions.

    Args:
        path: A Python file, or a directory that is searched recursively
            for ``*.py`` files. At most 50 files are read.

    Returns:
        A dict with a "variables" and/or "functions" entry describing the
        dominant naming pattern of each kind. Empty when the path does not
        exist or nothing could be classified.
    """
    root = Path(path)

    if not root.exists():
        return {}

    analyzer = CodebaseAnalyzer()

    # Sorted so that the 50-file sample is the same on every run instead of
    # depending on directory iteration order.
    py_files = sorted(root.rglob("*.py")) if root.is_dir() else [root]

    for py_file in py_files[:50]:  # Limit for performance
        try:
            tree = ast.parse(py_file.read_text(encoding='utf-8'))
        except (SyntaxError, ValueError, OSError):
            # Skip files that do not parse, are not valid UTF-8
            # (UnicodeDecodeError is a ValueError), contain null bytes, or
            # cannot be read at all (broken symlink, permissions, a
            # directory named "*.py").
            continue
        analyzer.visit(tree)

    # Single-letter variables and private functions are left out of the
    # sample; the examples are still drawn from the full name lists.
    variable_sample = [n for n in analyzer.variable_names if len(n) > 1]
    function_sample = [n for n in analyzer.function_names if not n.startswith('_')]

    conventions = {}
    for key, sample, names in (
        ("variables", variable_sample, analyzer.variable_names),
        ("functions", function_sample, analyzer.function_names),
    ):
        convention = _summarize_convention(sample, names)
        if convention is not None:
            conventions[key] = convention

    return conventions
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def build_name_registry(path: str) -> dict[str, set[str]]:
    """Build a registry of the names defined in a codebase.

    Args:
        path: A Python file, or a directory that is searched recursively
            for ``*.py`` files. At most 50 files are read.

    Returns:
        A dict with "variables", "functions" and "classes" sets, or an
        empty dict when the path does not exist.
    """
    root = Path(path)

    if not root.exists():
        return {}

    analyzer = CodebaseAnalyzer()

    # Sorted so that the 50-file sample is the same on every run instead of
    # depending on directory iteration order.
    py_files = sorted(root.rglob("*.py")) if root.is_dir() else [root]

    for py_file in py_files[:50]:
        try:
            tree = ast.parse(py_file.read_text(encoding='utf-8'))
        except (SyntaxError, ValueError, OSError):
            # Skip files that do not parse, are not valid UTF-8
            # (UnicodeDecodeError is a ValueError), contain null bytes, or
            # cannot be read at all.
            continue
        analyzer.visit(tree)

    return {
        "variables": set(analyzer.variable_names),
        "functions": set(analyzer.function_names),
        "classes": set(analyzer.class_names),
    }
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def check_style(code: str, codebase_path: str = ".") -> StyleCheckResult:
    """
    Check code against codebase conventions.

    Args:
        code: Code to check
        codebase_path: Path to codebase for convention extraction

    Returns:
        StyleCheckResult with issues and suggestions. Each offending name is
        reported once, however many times it is assigned.
    """
    issues: list[StyleIssue] = []

    # Analyze codebase conventions
    conventions = analyze_codebase(codebase_path)
    registry = build_name_registry(codebase_path)

    # Parse code
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return StyleCheckResult(
            matches_convention=False,
            issues=[StyleIssue(
                message="Cannot parse code",
                original="",
                suggested="",
            )],
        )

    # Analyze code names
    code_analyzer = CodebaseAnalyzer()
    code_analyzer.visit(tree)

    # The analyzer records a name every time it is assigned; keep the first
    # occurrence only so one name does not yield several identical issues.
    variables = list(dict.fromkeys(code_analyzer.variable_names))
    functions = list(dict.fromkeys(code_analyzer.function_names))

    # Check variable naming, then function naming
    for key, label, names in (
        ("variables", "Variable", variables),
        ("functions", "Function", functions),
    ):
        convention = conventions.get(key)
        if not (convention and convention.pattern == "snake_case"):
            continue
        for name in names:
            if detect_naming_pattern(name) == "camelCase":
                issues.append(StyleIssue(
                    message=f"{label} '{name}' uses camelCase, codebase uses snake_case",
                    original=name,
                    suggested=convert_to_snake_case(name),
                ))

    # Check for similar existing names (typo detection)
    all_existing_names = set().union(*registry.values())
    # Normalise once, and sort so issues come out in a stable order.
    existing_forms = [
        (existing, existing.lower().replace('_', ''))
        for existing in sorted(all_existing_names)
    ]

    for var in variables:
        if var in all_existing_names:
            # The name itself is already established in the codebase; treating
            # it as a typo of a neighbour ("items" vs "item") would rewrite
            # correct code below.
            continue
        normalized = var.lower().replace('_', '')
        if len(normalized) <= 3:
            continue
        for existing, existing_normalized in existing_forms:
            if normalized == existing_normalized:
                continue
            # Check for slight variations: one extra or one missing final character
            if normalized[:-1] == existing_normalized or normalized == existing_normalized[:-1]:
                issues.append(StyleIssue(
                    message=f"'{var}' might be a typo of existing '{existing}'",
                    original=var,
                    suggested=existing,
                ))

    # Apply corrections
    corrected_code = code
    for issue in issues:
        if issue.original and issue.suggested:
            # Simple replacement (word boundary aware)
            pattern = r'\b' + re.escape(issue.original) + r'\b'
            corrected_code = re.sub(pattern, issue.suggested, corrected_code)

    # Build conventions dict for response
    conv_dict = {
        name: {
            "pattern": conv.pattern,
            "confidence": conv.confidence,
            "examples": conv.examples[:3],
        }
        for name, conv in conventions.items()
    }

    return StyleCheckResult(
        matches_convention=len(issues) == 0,
        issues=issues,
        conventions_detected=conv_dict,
        corrected_code=corrected_code if corrected_code != code else None,
    )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""TrustGate - Code verification module"""
|
|
@@ -0,0 +1,384 @@
|
|
|
1
|
+
"""
|
|
2
|
+
TrustGate Checker - Core verification engine
|
|
3
|
+
|
|
4
|
+
Verifies code before you see it:
|
|
5
|
+
- Syntax checking
|
|
6
|
+
- Missing import detection
|
|
7
|
+
- Type checking (basic)
|
|
8
|
+
- Execution testing
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import ast
|
|
12
|
+
import sys
|
|
13
|
+
import re
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from typing import Optional
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
# Commonly used standard-library modules (a curated subset, not the full list)
STDLIB_MODULES = {
    # core and data handling
    'os', 'sys', 'json', 're', 'math', 'random', 'datetime', 'time',
    'collections', 'itertools', 'functools', 'typing', 'pathlib',
    # processes and concurrency
    'subprocess', 'threading', 'multiprocessing', 'asyncio',
    # networking and markup
    'urllib', 'http', 'socket', 'email', 'html', 'xml',
    # storage, configuration, diagnostics
    'sqlite3', 'csv', 'configparser', 'logging', 'unittest',
    # hashing and serialization
    'hashlib', 'hmac', 'secrets', 'base64', 'pickle', 'copy',
    # files
    'io', 'tempfile', 'shutil', 'glob', 'fnmatch',
    # command line and text
    'argparse', 'getopt', 'textwrap', 'string',
    # language helpers
    'dataclasses', 'enum', 'abc', 'contextlib',
}
|
|
30
|
+
|
|
31
|
+
# Common third-party modules, keyed by import name, mapped to the name to
# pass to ``pip install``.
COMMON_PACKAGES = {
    'requests': 'requests',
    'numpy': 'numpy',
    'pandas': 'pandas',
    'flask': 'flask',
    'django': 'django',
    'fastapi': 'fastapi',
    'httpx': 'httpx',
    'aiohttp': 'aiohttp',
    'pydantic': 'pydantic',
    'sqlalchemy': 'sqlalchemy',
    'pytest': 'pytest',
    'rich': 'rich',
    'click': 'click',
    'typer': 'typer',
    # Import names that differ from the distribution name. The reversed
    # 'beautifulsoup4' -> 'bs4' entry was dropped: 'beautifulsoup4' is not an
    # importable module, so it must not appear as a key.
    'bs4': 'beautifulsoup4',
    'PIL': 'pillow',
    'cv2': 'opencv-python',
    'sklearn': 'scikit-learn',
    'torch': 'torch',
    'tensorflow': 'tensorflow',
}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class VerificationIssue:
    """One finding produced while verifying a piece of code."""

    # "error", "warning" or "info"
    severity: str
    # Human-readable description of the finding.
    message: str
    # Position in the checked source, when known.
    line: Optional[int] = None
    column: Optional[int] = None
    # Whether a fix is offered, and what it would do.
    fix_available: bool = False
    fix_description: Optional[str] = None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class VerificationResult:
    """Result of code verification"""

    # True when verification reported no error-level issue.
    is_valid: bool
    # Findings, in the order they were produced.
    issues: list[VerificationIssue] = field(default_factory=list)
    # Source with automatic fixes applied, or None when there is none.
    fixed_code: Optional[str] = None
    # Score between 0.0 and 1.0; lower means more or worse findings.
    confidence_score: float = 0.0

    def to_dict(self) -> dict:
        """Return a plain-dict view of the result.

        The fixed source itself is not included; ``has_fixes`` only reports
        whether one is attached.
        """
        return {
            "is_valid": self.is_valid,
            "issues": [
                {
                    # "type" carries the same value as "severity".
                    "type": i.severity,
                    "severity": i.severity,
                    "message": i.message,
                    "line": i.line,
                    # Previously dropped even though syntax errors record it.
                    "column": i.column,
                    "fix_available": i.fix_available,
                    "fix_description": i.fix_description,
                }
                for i in self.issues
            ],
            "confidence_score": self.confidence_score,
            "has_fixes": self.fixed_code is not None,
        }
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class ImportVisitor(ast.NodeVisitor):
    """AST visitor that records which modules a piece of code imports.

    ``imports`` holds top-level module names ("os" for "import os.path").
    ``from_imports`` maps a top-level module to the names imported from it.
    Aliases ("as x") are not recorded, and relative imports without a
    module name ("from . import x") are skipped.
    """

    def __init__(self):
        self.imports: set[str] = set()
        self.from_imports: dict[str, set[str]] = {}

    def visit_Import(self, node: ast.Import):
        top_level = {alias.name.split('.')[0] for alias in node.names}
        self.imports.update(top_level)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        if node.module:
            root_module = node.module.split('.')[0]
            self.imports.add(root_module)
            imported_names = self.from_imports.setdefault(root_module, set())
            imported_names.update(alias.name for alias in node.names)
        self.generic_visit(node)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class NameVisitor(ast.NodeVisitor):
    """AST visitor that records which names a piece of code uses and defines.

    ``used_names`` holds names that are read. ``defined_names`` holds
    assignment targets, function and class names, function parameters and
    ``except ... as name`` targets. ``function_calls`` holds called names,
    either "func" or "obj.attr" when the call goes through a plain name.
    """

    def __init__(self):
        self.used_names: set[str] = set()
        self.defined_names: set[str] = set()
        self.function_calls: set[str] = set()

    def visit_Name(self, node: ast.Name):
        if isinstance(node.ctx, ast.Load):
            self.used_names.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            self.defined_names.add(node.id)
        self.generic_visit(node)

    def visit_arg(self, node: ast.arg):
        # Parameters are ast.arg nodes, not ast.Name, so visit_Name never
        # sees them. Without this, "def f(socket): socket.send()" looks like
        # a use of an undefined module.
        self.defined_names.add(node.arg)
        self.generic_visit(node)

    def visit_ExceptHandler(self, node: ast.ExceptHandler):
        # "except E as err" stores the target as a plain string.
        if node.name:
            self.defined_names.add(node.name)
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        if isinstance(node.func, ast.Attribute):
            # module.function() calls
            if isinstance(node.func.value, ast.Name):
                self.function_calls.add(f"{node.func.value.id}.{node.func.attr}")
        elif isinstance(node.func, ast.Name):
            self.function_calls.add(node.func.id)
        self.generic_visit(node)

    def _define(self, node):
        """Record the name of a function or class and walk into its body."""
        self.defined_names.add(node.name)
        self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        self._define(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        self._define(node)

    def visit_ClassDef(self, node: ast.ClassDef):
        self._define(node)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def check_syntax(code: str) -> tuple[bool, Optional[VerificationIssue]]:
    """Check if code has valid Python syntax.

    Args:
        code: Python source to parse.

    Returns:
        ``(True, None)`` when the code parses, otherwise ``(False, issue)``
        with an error-level issue describing why it does not.
    """
    try:
        ast.parse(code)
    except SyntaxError as e:
        return False, VerificationIssue(
            severity="error",
            message=f"Syntax error: {e.msg}",
            line=e.lineno,
            column=e.offset,
            fix_available=False,
        )
    except (ValueError, RecursionError) as e:
        # Source containing null bytes raises ValueError on some Python
        # versions, and very deeply nested code can exhaust the parser's
        # recursion limit. Both mean "cannot be parsed", so report them like
        # a syntax error instead of letting them escape.
        return False, VerificationIssue(
            severity="error",
            message=f"Syntax error: {e}",
            fix_available=False,
        )
    return True, None
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def detect_missing_imports(code: str) -> list[VerificationIssue]:
    """Detect missing imports in code.

    Looks at calls of the form ``name.attr(...)`` and reports ``name`` when
    it is a known standard-library or common third-party module that the
    code neither imports nor defines.

    Args:
        code: Python source to inspect.

    Returns:
        One error-level issue per missing module, sorted by module name.
        Empty when the code does not parse.
    """
    issues = []

    try:
        tree = ast.parse(code)
    except SyntaxError:
        return issues  # Can't parse, syntax check will catch this

    # Get all imports
    import_visitor = ImportVisitor()
    import_visitor.visit(tree)

    # Get all used names
    name_visitor = NameVisitor()
    name_visitor.visit(tree)

    # Names that import statements actually bind, including aliases and
    # from-imports ("from flask import json" makes "json" available).
    bound_by_import = {
        (alias.asname or alias.name).split('.')[0]
        for node in ast.walk(tree)
        if isinstance(node, (ast.Import, ast.ImportFrom))
        for alias in node.names
    }
    known = import_visitor.imports | bound_by_import | name_visitor.defined_names

    # Several calls on one module ("json.loads", "json.dumps") must yield a
    # single issue; sorting keeps the report order stable between runs.
    candidates = sorted({
        call.split('.')[0]
        for call in name_visitor.function_calls
        if '.' in call
    })

    for module in candidates:
        if module in known:
            continue
        # Check if it's a known module
        if module in STDLIB_MODULES:
            issues.append(VerificationIssue(
                severity="error",
                message=f"Missing import: {module}",
                fix_available=True,
                fix_description=f"Add 'import {module}'",
            ))
        elif module in COMMON_PACKAGES:
            issues.append(VerificationIssue(
                severity="error",
                message=f"Missing import: {module} (pip install {COMMON_PACKAGES[module]})",
                fix_available=True,
                fix_description=f"Add 'import {module}'",
            ))

    return issues
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def detect_undefined_names(code: str) -> list[VerificationIssue]:
    """Detect names that look like modules but were never imported.

    Args:
        code: Python source to inspect.

    Returns:
        One warning-level issue per name that is read, is not defined,
        imported or built in, and matches a known standard-library or common
        third-party module. Sorted by name; empty when the code does not
        parse.
    """
    # Imported here because __builtins__ may be either a module or a dict
    # depending on how this module is loaded; dir() of the dict form lists
    # dict methods rather than built-in names.
    import builtins

    issues = []

    try:
        tree = ast.parse(code)
    except SyntaxError:
        return issues

    # Get imports
    import_visitor = ImportVisitor()
    import_visitor.visit(tree)

    # Get names
    name_visitor = NameVisitor()
    name_visitor.visit(tree)

    # Names that import statements actually bind, including aliases and
    # from-imports ("from flask import json" makes "json" available).
    bound_by_import = {
        (alias.asname or alias.name).split('.')[0]
        for node in ast.walk(tree)
        if isinstance(node, (ast.Import, ast.ImportFrom))
        for alias in node.names
    }

    # Built-in names that don't need imports
    builtin_names = set(dir(builtins))
    builtin_names.update({'True', 'False', 'None', 'self', 'cls'})

    # Find undefined names
    all_defined = (
        name_visitor.defined_names |
        import_visitor.imports |
        bound_by_import |
        builtin_names
    )

    # Sorted so the report order is stable between runs.
    for name in sorted(name_visitor.used_names):
        if name in all_defined or name.startswith('_'):
            continue
        # Check if it might be a module
        if name in STDLIB_MODULES or name in COMMON_PACKAGES:
            issues.append(VerificationIssue(
                severity="warning",
                message=f"'{name}' used but not imported",
                fix_available=True,
                fix_description=f"Add 'import {name}'",
            ))

    return issues
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def _import_insertion_index(code: str, lines: list[str]) -> int:
    """Return the index in *lines* at which new import lines should be inserted."""
    try:
        body = ast.parse(code).body
    except (SyntaxError, ValueError):
        body = []  # Unparseable: fall back to the top of the file

    insert_pos = 0
    for position, node in enumerate(body):
        is_docstring = (
            position == 0
            and isinstance(node, ast.Expr)
            and isinstance(node.value, ast.Constant)
            and isinstance(node.value.value, str)
        )
        if not (is_docstring or isinstance(node, (ast.Import, ast.ImportFrom))):
            break
        # end_lineno is 1-based and inclusive, so it is also the 0-based
        # index of the line after the statement. This stays correct for
        # one-line docstrings and for imports wrapped over several lines.
        insert_pos = node.end_lineno

    if insert_pos == 0:
        # Nothing to insert after: at least keep a shebang and an encoding
        # declaration (only meaningful on the first two lines) at the top.
        for line in lines[:2]:
            if line.startswith('#!') or re.match(r'#.*coding[:=]', line):
                insert_pos += 1
            else:
                break

    return insert_pos


def auto_fix_imports(code: str, issues: list[VerificationIssue]) -> str:
    """Auto-fix missing imports.

    Args:
        code: Python source to fix.
        issues: Verification issues; only fixable ones whose message contains
            "Missing import:" are acted on.

    Returns:
        The source with one ``import <module>`` line per missing module
        inserted after the leading docstring and imports, or ``code``
        unchanged when there is nothing to add.
    """
    modules = []
    for issue in issues:
        if issue.fix_available and "Missing import:" in issue.message:
            # Extract module name
            match = re.search(r"Missing import: (\w+)", issue.message)
            if match:
                modules.append(match.group(1))

    # Keep first-seen order, but never add the same import twice.
    imports_to_add = [f"import {module}" for module in dict.fromkeys(modules)]

    if not imports_to_add:
        return code

    lines = code.split('\n')
    insert_pos = _import_insertion_index(code, lines)
    lines[insert_pos:insert_pos] = imports_to_add

    return '\n'.join(lines)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def verify_code(code: str, auto_fix: bool = True) -> VerificationResult:
    """
    Main verification function.

    Args:
        code: Python code to verify
        auto_fix: Whether to attempt auto-fixes

    Returns:
        VerificationResult with issues and optionally fixed code.
        ``fixed_code`` is set only when an auto-fix changed the source and
        the result still parses. ``is_valid`` and ``issues`` always describe
        the code as it was passed in.
    """
    # 1. Syntax check
    syntax_ok, syntax_issue = check_syntax(code)
    if not syntax_ok and syntax_issue:
        # Nothing else can be analysed without a syntax tree.
        return VerificationResult(
            is_valid=False,
            issues=[syntax_issue],
            confidence_score=0.0,
        )

    issues: list[VerificationIssue] = []

    # 2. Missing imports
    issues.extend(detect_missing_imports(code))

    # 3. Undefined names
    issues.extend(detect_undefined_names(code))

    # Calculate confidence score: each error costs 0.2, each warning 0.05
    error_count = sum(1 for i in issues if i.severity == "error")
    warning_count = sum(1 for i in issues if i.severity == "warning")
    confidence = max(0.0, 1.0 - (error_count * 0.2) - (warning_count * 0.05))

    # Auto-fix if requested
    fixed_code = None
    if auto_fix and issues:
        candidate = auto_fix_imports(code, issues)
        if candidate != code:
            # Re-verify fixed code; a fix that no longer parses is discarded
            # rather than handed back to the caller.
            recheck_syntax, _ = check_syntax(candidate)
            if recheck_syntax:
                fixed_code = candidate
                confidence = min(confidence + 0.1, 1.0)

    return VerificationResult(
        is_valid=error_count == 0,
        issues=issues,
        fixed_code=fixed_code,
        confidence_score=confidence,
    )
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def verify_file(file_path: str) -> VerificationResult:
    """Verify a Python file.

    Args:
        file_path: Path of the file to read (as UTF-8) and verify.

    Returns:
        The verification result for the file's content, or an invalid
        result carrying a single error when the file is missing or cannot
        be read.
    """
    path = Path(file_path)
    if not path.exists():
        return VerificationResult(
            is_valid=False,
            issues=[VerificationIssue(
                severity="error",
                message=f"File not found: {file_path}",
            )],
            confidence_score=0.0,
        )

    try:
        code = path.read_text(encoding='utf-8')
    except (OSError, UnicodeDecodeError) as e:
        # A directory, an unreadable file or non-UTF-8 content is reported
        # the same way as a missing file instead of raising.
        return VerificationResult(
            is_valid=False,
            issues=[VerificationIssue(
                severity="error",
                message=f"Cannot read file: {file_path} ({e})",
            )],
            confidence_score=0.0,
        )

    return verify_code(code)
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
# Manual smoke test: verify a snippet that uses two modules it never imports
if __name__ == "__main__":
    sample_source = '''
def fetch_data():
    response = requests.get("https://api.example.com")
    return json.loads(response.text)
'''

    outcome = verify_code(sample_source)

    # Collect the report first, then emit it in one go.
    report = [
        f"Valid: {outcome.is_valid}",
        f"Confidence: {outcome.confidence_score:.0%}",
    ]
    report.extend(f"  [{found.severity}] {found.message}" for found in outcome.issues)
    if outcome.fixed_code:
        report.append("\nFixed code:")
        report.append(outcome.fixed_code)
    print("\n".join(report))
|