codeshield-ai 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,298 @@
1
+ """
2
+ StyleForge Corrector - Convention enforcement
3
+
4
+ Enforces YOUR codebase conventions:
5
+ - Naming pattern detection
6
+ - Variable name correction
7
+ - Function name matching
8
+ """
9
+
10
+ import ast
11
+ import re
12
+ from pathlib import Path
13
+ from dataclasses import dataclass, field
14
+ from typing import Optional
15
+ from collections import Counter
16
+
17
+
18
@dataclass
class NamingConvention:
    """A naming style inferred from a sample of identifiers.

    ``pattern`` is the winning style label, ``confidence`` the share of the
    sampled names that used it, and ``examples`` a few sampled names.
    """

    # "snake_case", "camelCase", "PascalCase", "SCREAMING_SNAKE"
    # (detect_naming_pattern can also yield "mixed").
    pattern: str
    # Fraction (0..1) of the sampled names that matched ``pattern``.
    confidence: float
    # Sample identifiers; every instance gets its own fresh list.
    examples: list[str] = field(default_factory=list)
24
+
25
+
26
@dataclass
class StyleIssue:
    """One convention violation together with the name proposed instead."""

    # Human-readable explanation of the problem.
    message: str
    # The identifier as written in the checked code.
    original: str
    # Replacement identifier proposed by the checker.
    suggested: str
    # Source line number if the checker recorded one (check_style leaves it None).
    line: Optional[int] = None
33
+
34
+
35
@dataclass
class StyleCheckResult:
    """Outcome of checking a code snippet against the codebase's conventions."""

    # True when no issues were reported.
    matches_convention: bool
    issues: list[StyleIssue] = field(default_factory=list)
    # Convention name -> {"pattern", "confidence", "examples"} summary.
    conventions_detected: dict = field(default_factory=dict)
    # Rewritten source, or None when no correction was produced.
    corrected_code: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize to plain built-in types.

        The corrected source itself is not included; ``has_corrections``
        only reports whether one exists.
        """
        serialized_issues = []
        for issue in self.issues:
            serialized_issues.append({
                "message": issue.message,
                "original": issue.original,
                "suggested": issue.suggested,
                "line": issue.line,
            })

        payload = {"matches_convention": self.matches_convention}
        payload["issues"] = serialized_issues
        payload["conventions_detected"] = self.conventions_detected
        payload["has_corrections"] = self.corrected_code is not None
        return payload
58
+
59
+
60
def detect_naming_pattern(name: str) -> str:
    """Classify the naming style of a single identifier.

    Leading underscores (the "private" marker) are ignored when deciding
    between PascalCase and camelCase, so ``_privateVar`` is camelCase and
    ``_Private`` is PascalCase.

    Args:
        name: The identifier to classify.

    Returns:
        One of "SCREAMING_SNAKE", "snake_case", "PascalCase", "camelCase",
        or "mixed" (including for the empty string and all-underscore names).
    """
    if not name:
        # Nothing to classify; previously this raised IndexError on name[0].
        return "mixed"
    if name.isupper() and '_' in name:
        return "SCREAMING_SNAKE"
    if '_' in name and name.islower():
        return "snake_case"

    # Look past leading underscores so "_privateVar" is judged by "privateVar".
    core = name.lstrip('_')
    if not core:
        return "mixed"
    if core[0].isupper() and '_' not in core:
        return "PascalCase"
    if core[0].islower() and any(c.isupper() for c in core):
        return "camelCase"
    if name.islower():
        return "snake_case"
    return "mixed"
74
+
75
+
76
def convert_to_snake_case(name: str) -> str:
    """Convert a camelCase or PascalCase identifier to snake_case.

    Acronym runs stay together (``HTTPResponse`` -> ``http_response``),
    an underscore already present is not doubled (``my_Var`` -> ``my_var``),
    and leading underscores are preserved (``_privateVar`` -> ``_private_var``).
    """
    # Split before a capitalised word ("Response" in "HTTPResponse"), but not
    # directly after an existing underscore, which would produce "__".
    spaced = re.sub(r'([^_])([A-Z][a-z]+)', r'\1_\2', name)
    # Split a lowercase letter/digit followed by an uppercase letter ("tH").
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', spaced).lower()
81
+
82
+
83
def convert_to_camel_case(name: str) -> str:
    """Convert a snake_case identifier to camelCase.

    Names without underscores are returned unchanged. Leading and trailing
    underscores (private markers, dunders, PEP 8 ``class_``-style suffixes)
    are preserved rather than dropped. Each later word is capitalised with
    ``str.capitalize`` so digits don't trigger extra capitals (``2nd`` stays
    ``2nd``, not ``2Nd``).
    """
    if '_' not in name:
        return name
    core = name.strip('_')
    if not core:
        # All underscores: nothing to convert.
        return name
    leading = name[:len(name) - len(name.lstrip('_'))]
    trailing = name[len(name.rstrip('_')):]
    head, *rest = core.split('_')
    return leading + head.lower() + ''.join(part.capitalize() for part in rest) + trailing
89
+
90
+
91
class CodebaseAnalyzer(ast.NodeVisitor):
    """AST visitor that gathers identifiers, bucketed by kind.

    Only assignment targets (``Name`` nodes in ``Store`` context) and
    function/class definitions are recorded. Names are appended in visit
    order, so repeated assignments appear more than once.
    """

    def __init__(self):
        self.variable_names: list[str] = []
        self.function_names: list[str] = []
        self.class_names: list[str] = []
        # Assignment targets whose name is entirely uppercase (e.g. MAX_SIZE).
        self.constant_names: list[str] = []

    def visit_Name(self, node: ast.Name):
        if isinstance(node.ctx, ast.Store):
            bucket = self.constant_names if node.id.isupper() else self.variable_names
            bucket.append(node.id)
        self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        self.function_names.append(node.name)
        self.generic_visit(node)

    # ``async def`` is recorded exactly like a plain ``def``.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_ClassDef(self, node: ast.ClassDef):
        self.class_names.append(node.name)
        self.generic_visit(node)
119
+
120
+
121
def analyze_codebase(path: str) -> dict[str, NamingConvention]:
    """Infer the dominant variable and function naming styles under *path*.

    Args:
        path: A directory (scanned recursively for ``*.py``, at most 50 files)
            or a single file.

    Returns:
        A dict with up to two keys, ``"variables"`` and ``"functions"``, each
        mapping to a NamingConvention. Empty if *path* does not exist or no
        usable names were found. Files that cannot be read or parsed are
        skipped.
    """
    root = Path(path)
    if not root.exists():
        return {}

    analyzer = CodebaseAnalyzer()
    py_files = list(root.rglob("*.py")) if root.is_dir() else [root]
    for py_file in py_files[:50]:  # Limit for performance
        try:
            tree = ast.parse(py_file.read_text(encoding='utf-8'))
        except (SyntaxError, ValueError, OSError):
            # SyntaxError: not valid Python. ValueError also covers
            # UnicodeDecodeError and (before Python 3.12) null bytes.
            # OSError: unreadable, e.g. a directory named "x.py".
            continue
        analyzer.visit(tree)

    def dominant(names: list[str], keep) -> Optional[NamingConvention]:
        # Most common pattern among the names accepted by *keep*; None when
        # the filter leaves nothing to count.
        patterns = [detect_naming_pattern(n) for n in names if keep(n)]
        if not patterns:
            return None
        pattern, count = Counter(patterns).most_common(1)[0]
        return NamingConvention(
            pattern=pattern,
            confidence=count / len(patterns),
            examples=names[:5],
        )

    conventions: dict[str, NamingConvention] = {}

    # Single-character names are excluded from the variable statistics.
    variables = dominant(analyzer.variable_names, lambda n: len(n) > 1)
    if variables is not None:
        conventions["variables"] = variables

    # Underscore-prefixed (private/dunder) functions are excluded.
    functions = dominant(analyzer.function_names, lambda n: not n.startswith('_'))
    if functions is not None:
        conventions["functions"] = functions

    return conventions
169
+
170
+
171
def build_name_registry(path: str) -> dict[str, set[str]]:
    """Collect the variable, function and class names defined under *path*.

    Args:
        path: A directory (scanned recursively for ``*.py``, at most 50 files)
            or a single file.

    Returns:
        ``{"variables": ..., "functions": ..., "classes": ...}`` name sets,
        or an empty dict if *path* does not exist. Files that cannot be read
        or parsed are skipped.
    """
    root = Path(path)
    if not root.exists():
        return {}

    analyzer = CodebaseAnalyzer()
    py_files = list(root.rglob("*.py")) if root.is_dir() else [root]
    for py_file in py_files[:50]:  # Limit for performance
        try:
            tree = ast.parse(py_file.read_text(encoding='utf-8'))
        except (SyntaxError, ValueError, OSError):
            # ValueError covers UnicodeDecodeError; OSError covers unreadable
            # paths such as a directory named "x.py".
            continue
        analyzer.visit(tree)

    # Constant names collected by the analyzer are not part of the registry.
    return {
        "variables": set(analyzer.variable_names),
        "functions": set(analyzer.function_names),
        "classes": set(analyzer.class_names),
    }
195
+
196
+
197
def _camel_case_issues(
    names: list[str],
    convention: Optional[NamingConvention],
    kind: str,
) -> list[StyleIssue]:
    """Flag camelCase *names* when the detected convention is snake_case."""
    if convention is None or convention.pattern != "snake_case":
        return []
    return [
        StyleIssue(
            message=f"{kind} '{name}' uses camelCase, codebase uses snake_case",
            original=name,
            suggested=convert_to_snake_case(name),
        )
        for name in names
        if detect_naming_pattern(name) == "camelCase"
    ]


def _typo_issues(variables: list[str], registry: dict[str, set[str]]) -> list[StyleIssue]:
    """Flag variables differing from an existing codebase name by one trailing character."""
    existing_names: set[str] = set().union(*registry.values())
    # Comparison key ignores case and underscores. Compute it once per
    # existing name, and sort so reported issues come out in a stable order.
    candidates = sorted((name.lower().replace('_', ''), name) for name in existing_names)

    issues: list[StyleIssue] = []
    for var in variables:
        if var in existing_names:
            # The codebase already uses this exact name, so it is not a typo of it.
            continue
        key = var.lower().replace('_', '')
        if len(key) <= 3:
            continue  # Only names longer than 3 characters are checked.
        for existing_key, existing in candidates:
            if key == existing_key:
                continue
            if key[:-1] == existing_key or key == existing_key[:-1]:
                issues.append(StyleIssue(
                    message=f"'{var}' might be a typo of existing '{existing}'",
                    original=var,
                    suggested=existing,
                ))
    return issues


def _apply_corrections(code: str, issues: list[StyleIssue]) -> str:
    """Replace each flagged identifier with its suggestion, in issue order."""
    corrected = code
    for issue in issues:
        if not (issue.original and issue.suggested):
            continue
        pattern = r'\b' + re.escape(issue.original) + r'\b'
        # NOTE(review): this is a textual whole-word replace, so it also rewrites
        # matches inside strings, comments and attribute accesses (obj.myVar);
        # a tokenize/AST-based rewrite would be safer.
        # A callable replacement is inserted literally (no group/backslash expansion).
        corrected = re.sub(pattern, lambda _m, new=issue.suggested: new, corrected)
    return corrected


def check_style(code: str, codebase_path: str = ".") -> StyleCheckResult:
    """
    Check code against codebase conventions.

    Args:
        code: Code to check
        codebase_path: Path to codebase for convention extraction

    Returns:
        StyleCheckResult with issues and suggestions. Each offending name is
        reported once per rule. ``corrected_code`` is set only when applying
        the suggestions changed the code.
    """
    # Parse first: there is no point scanning the codebase for invalid input.
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        return StyleCheckResult(
            matches_convention=False,
            issues=[StyleIssue(
                message="Cannot parse code",
                original="",
                suggested="",
            )],
        )

    conventions = analyze_codebase(codebase_path)
    registry = build_name_registry(codebase_path)

    code_analyzer = CodebaseAnalyzer()
    code_analyzer.visit(tree)
    # The analyzer records every assignment; deduplicate, keeping first-seen order.
    variables = list(dict.fromkeys(code_analyzer.variable_names))
    functions = list(dict.fromkeys(code_analyzer.function_names))

    issues: list[StyleIssue] = []
    issues += _camel_case_issues(variables, conventions.get("variables"), "Variable")
    issues += _camel_case_issues(functions, conventions.get("functions"), "Function")
    issues += _typo_issues(variables, registry)

    corrected_code = _apply_corrections(code, issues)

    conv_dict = {
        name: {
            "pattern": conv.pattern,
            "confidence": conv.confidence,
            "examples": conv.examples[:3],
        }
        for name, conv in conventions.items()
    }

    return StyleCheckResult(
        matches_convention=not issues,
        issues=issues,
        conventions_detected=conv_dict,
        corrected_code=corrected_code if corrected_code != code else None,
    )
@@ -0,0 +1 @@
1
+ """TrustGate - Code verification module"""
@@ -0,0 +1,384 @@
1
+ """
2
+ TrustGate Checker - Core verification engine
3
+
4
+ Verifies code before you see it:
5
+ - Syntax checking
6
+ - Missing import detection
7
+ - Type checking (basic)
8
+ - Execution testing
9
+ """
10
+
11
+ import ast
12
+ import sys
13
+ import re
14
+ from dataclasses import dataclass, field
15
+ from typing import Optional
16
+ from pathlib import Path
17
+
18
# Standard library modules that are commonly used
STDLIB_MODULES = {
    'os', 'sys', 'json', 're', 'math', 'random', 'datetime', 'time',
    'collections', 'itertools', 'functools', 'typing', 'pathlib',
    'subprocess', 'threading', 'multiprocessing', 'asyncio',
    'urllib', 'http', 'socket', 'email', 'html', 'xml',
    'sqlite3', 'csv', 'configparser', 'logging', 'unittest',
    'hashlib', 'hmac', 'secrets', 'base64', 'pickle', 'copy',
    'io', 'tempfile', 'shutil', 'glob', 'fnmatch',
    'argparse', 'getopt', 'textwrap', 'string',
    'dataclasses', 'enum', 'abc', 'contextlib',
}

# Common third-party modules mapped to the name used with ``pip install``.
# Keys must be importable module names: they are compared against names
# used in the checked code and turned into ``import <key>`` suggestions.
# (There is no ``beautifulsoup4`` module; it is installed as ``bs4``'s
# distribution, hence the ``bs4`` entry.)
COMMON_PACKAGES = {
    'requests': 'requests',
    'numpy': 'numpy',
    'pandas': 'pandas',
    'flask': 'flask',
    'django': 'django',
    'fastapi': 'fastapi',
    'httpx': 'httpx',
    'aiohttp': 'aiohttp',
    'pydantic': 'pydantic',
    'sqlalchemy': 'sqlalchemy',
    'pytest': 'pytest',
    'rich': 'rich',
    'click': 'click',
    'typer': 'typer',
    'bs4': 'beautifulsoup4',
    'PIL': 'pillow',
    'cv2': 'opencv-python',
    'sklearn': 'scikit-learn',
    'torch': 'torch',
    'tensorflow': 'tensorflow',
}
55
+
56
+
57
@dataclass
class VerificationIssue:
    """A single problem reported by the verifier."""

    # "error", "warning" or "info".
    severity: str
    message: str
    # Location, when known (check_syntax fills these for syntax errors).
    line: Optional[int] = None
    column: Optional[int] = None
    # Whether a fix is offered, and a short description of it.
    fix_available: bool = False
    fix_description: Optional[str] = None
66
+
67
+
68
@dataclass
class VerificationResult:
    """Result of code verification"""
    is_valid: bool
    issues: list[VerificationIssue] = field(default_factory=list)
    # Auto-fixed source, if a fix was produced.
    fixed_code: Optional[str] = None
    confidence_score: float = 0.0

    def to_dict(self) -> dict:
        """Serialize to JSON-friendly built-in types.

        ``fixed_code`` itself is not included; ``has_fixes`` reports whether
        it is present. Each issue carries its severity under both the
        ``"type"`` and ``"severity"`` keys.
        """
        return {
            "is_valid": self.is_valid,
            "issues": [
                {
                    "type": i.severity,
                    "severity": i.severity,
                    "message": i.message,
                    "line": i.line,
                    # Syntax errors record a column; expose it with the line.
                    "column": i.column,
                    "fix_available": i.fix_available,
                    "fix_description": i.fix_description,
                }
                for i in self.issues
            ],
            "confidence_score": self.confidence_score,
            "has_fixes": self.fixed_code is not None,
        }
93
+
94
+
95
class ImportVisitor(ast.NodeVisitor):
    """Collects what a module imports.

    ``imports`` holds every name an import statement makes available:
    the top-level package of each imported module (``os`` for
    ``import os.path``, ``flask`` for ``from flask import json``), plus the
    local name each statement binds (the ``as`` alias, or the member
    imported by ``from ... import``). ``from_imports`` maps the top-level
    package of each ``from X import ...`` (with a module part) to the member
    names requested from it (``"*"`` for star imports).
    """

    def __init__(self):
        self.imports: set[str] = set()
        self.from_imports: dict[str, set[str]] = {}

    def visit_Import(self, node: ast.Import):
        for alias in node.names:
            top_level = alias.name.split('.')[0]
            self.imports.add(top_level)
            # ``import numpy as np`` binds ``np``; without an alias the bound
            # name is the top-level package itself.
            self.imports.add(alias.asname or top_level)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        if node.module:
            base_module = node.module.split('.')[0]
            self.imports.add(base_module)
            members = self.from_imports.setdefault(base_module, set())
            for alias in node.names:
                members.add(alias.name)
        for alias in node.names:
            # ``from flask import json`` binds ``json`` locally (also for
            # relative ``from . import x``), so treat it as already imported.
            if alias.name != '*':
                self.imports.add(alias.asname or alias.name)
        self.generic_visit(node)
116
+
117
+
118
class NameVisitor(ast.NodeVisitor):
    """Collects the names a module reads, binds and calls.

    - ``used_names``: every ``Name`` read (Load context).
    - ``defined_names``: names bound by assignment targets, function/class
      definitions, function and lambda parameters, and ``except ... as``.
    - ``function_calls``: called names, as the dotted path rooted at a plain
      name (``print``, ``requests.get``, ``os.path.join``).
    """

    def __init__(self):
        self.used_names: set[str] = set()
        self.defined_names: set[str] = set()
        self.function_calls: set[str] = set()

    def visit_Name(self, node: ast.Name):
        if isinstance(node.ctx, ast.Load):
            self.used_names.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            self.defined_names.add(node.id)
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        func = node.func
        if isinstance(func, ast.Name):
            self.function_calls.add(func.id)
        elif isinstance(func, ast.Attribute):
            # Walk ``a.b.c`` back to its root so nested calls such as
            # ``os.path.join()`` are recorded, not only ``module.function()``.
            parts = [func.attr]
            base = func.value
            while isinstance(base, ast.Attribute):
                parts.append(base.attr)
                base = base.value
            if isinstance(base, ast.Name):
                parts.append(base.id)
                self.function_calls.add('.'.join(reversed(parts)))
        self.generic_visit(node)

    def visit_arg(self, node: ast.arg):
        # Parameters are bound names: ``def f(json): json.loads(x)`` must not
        # look like use of an unimported module.
        self.defined_names.add(node.arg)
        self.generic_visit(node)

    def visit_ExceptHandler(self, node: ast.ExceptHandler):
        # ``except E as name`` binds ``name`` (stored as a str, not a Name node).
        if node.name:
            self.defined_names.add(node.name)
        self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        self.defined_names.add(node.name)
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        self.defined_names.add(node.name)
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef):
        self.defined_names.add(node.name)
        self.generic_visit(node)
153
+
154
+
155
def check_syntax(code: str) -> tuple[bool, Optional[VerificationIssue]]:
    """Check whether *code* parses as Python.

    Returns:
        ``(True, None)`` if it parses; otherwise ``(False, issue)`` where the
        issue is an error with the parser's message and, for syntax errors,
        the reported line and column.
    """
    try:
        ast.parse(code)
    except SyntaxError as e:
        return False, VerificationIssue(
            severity="error",
            message=f"Syntax error: {e.msg}",
            line=e.lineno,
            column=e.offset,
            fix_available=False,
        )
    except ValueError as e:
        # Before Python 3.12, ast.parse raises ValueError (not SyntaxError)
        # for source containing null bytes.
        return False, VerificationIssue(
            severity="error",
            message=f"Syntax error: {e}",
            fix_available=False,
        )
    return True, None
168
+
169
+
170
def detect_missing_imports(code: str) -> list[VerificationIssue]:
    """Report known modules called as ``module.attr(...)`` without being imported.

    Only roots listed in STDLIB_MODULES or COMMON_PACKAGES are reported, and
    each module at most once, in alphabetical order.

    Returns:
        Error-severity issues whose message starts with
        ``"Missing import: <module>"`` (the form auto_fix_imports parses).
        Empty if *code* cannot be parsed.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        return []  # Can't parse, syntax check will catch this

    import_visitor = ImportVisitor()
    import_visitor.visit(tree)

    name_visitor = NameVisitor()
    name_visitor.visit(tree)

    # Root names of dotted calls. A set, so json.loads() plus json.dumps()
    # yields one issue (and one inserted import); sorted for stable output.
    called_roots = {
        call.split('.')[0] for call in name_visitor.function_calls if '.' in call
    }
    unresolved = sorted(called_roots - import_visitor.imports - name_visitor.defined_names)

    issues = []
    for module in unresolved:
        if module in STDLIB_MODULES:
            issues.append(VerificationIssue(
                severity="error",
                message=f"Missing import: {module}",
                fix_available=True,
                fix_description=f"Add 'import {module}'",
            ))
        elif module in COMMON_PACKAGES:
            issues.append(VerificationIssue(
                severity="error",
                message=f"Missing import: {module} (pip install {COMMON_PACKAGES[module]})",
                fix_available=True,
                fix_description=f"Add 'import {module}'",
            ))
    return issues
210
+
211
+
212
def detect_undefined_names(code: str) -> list[VerificationIssue]:
    """Warn about known module names that are read but never imported or bound.

    A name is reported when it is read (Load context), is not a builtin, is
    not bound by an import, assignment, definition or parameter, does not
    start with ``_``, and is listed in STDLIB_MODULES or COMMON_PACKAGES.
    Issues come out in alphabetical order; empty if *code* cannot be parsed.
    """
    import builtins as builtins_module

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        return []

    import_visitor = ImportVisitor()
    import_visitor.visit(tree)

    name_visitor = NameVisitor()
    name_visitor.visit(tree)

    # Use the builtins module directly: ``__builtins__`` can be a dict (in
    # imported modules), and dir() of a dict lists dict methods, not builtins.
    builtin_names = set(dir(builtins_module)) | {'True', 'False', 'None', 'self', 'cls'}

    all_defined = name_visitor.defined_names | import_visitor.imports | builtin_names

    issues = []
    for name in sorted(name_visitor.used_names - all_defined):
        if name.startswith('_'):
            continue
        if name in STDLIB_MODULES or name in COMMON_PACKAGES:
            issues.append(VerificationIssue(
                severity="warning",
                message=f"'{name}' used but not imported",
                fix_available=True,
                fix_description=f"Add 'import {name}'",
            ))
    return issues
252
+
253
+
254
def auto_fix_imports(code: str, issues: "list[VerificationIssue]") -> str:
    """Insert ``import`` lines for issues reported as missing imports.

    Only issues with ``fix_available`` set and a ``"Missing import: <module>"``
    message are acted on, and each module is imported once. New imports go
    after the module docstring and the leading run of top-level import
    statements. The position is found from the AST, so one-line docstrings,
    function docstrings and parenthesised multi-line imports can't misplace
    them.

    Args:
        code: Source to patch.
        issues: Issues produced by the detectors.

    Returns:
        The patched source, or *code* unchanged if there is nothing to add.
    """
    modules: list[str] = []
    for issue in issues:
        if not (issue.fix_available and "Missing import:" in issue.message):
            continue
        match = re.search(r"Missing import: (\w+)", issue.message)
        if match and match.group(1) not in modules:
            modules.append(match.group(1))

    if not modules:
        return code

    try:
        body = ast.parse(code).body
    except (SyntaxError, ValueError):
        body = []  # Unparseable: fall back to the very top of the file.

    # insert_pos is a 0-based index into ``lines``; ast end_lineno is 1-based,
    # so inserting at end_lineno places the import right after that statement.
    insert_pos = 0
    index = 0
    if (
        body
        and isinstance(body[0], ast.Expr)
        and isinstance(body[0].value, ast.Constant)
        and isinstance(body[0].value.value, str)
    ):
        insert_pos = body[0].end_lineno  # Module docstring
        index = 1
    while index < len(body) and isinstance(body[index], (ast.Import, ast.ImportFrom)):
        insert_pos = body[index].end_lineno
        index += 1

    lines = code.split('\n')
    lines[insert_pos:insert_pos] = [f"import {module}" for module in modules]
    return '\n'.join(lines)
294
+
295
+
296
def verify_code(code: str, auto_fix: bool = True) -> VerificationResult:
    """
    Main verification function.

    Args:
        code: Python code to verify
        auto_fix: Whether to attempt auto-fixes

    Returns:
        VerificationResult with issues. ``fixed_code`` is set only when an
        auto-fix actually changed the code and the result still parses.
        ``is_valid`` is False if any error-severity issue was found.
    """
    # 1. Syntax check; nothing else is meaningful on unparseable code.
    syntax_ok, syntax_issue = check_syntax(code)
    if not syntax_ok and syntax_issue:
        return VerificationResult(
            is_valid=False,
            issues=[syntax_issue],
            confidence_score=0.0,
        )

    # 2. Missing imports, 3. undefined names
    issues: list[VerificationIssue] = []
    issues.extend(detect_missing_imports(code))
    issues.extend(detect_undefined_names(code))

    # Confidence: -0.2 per error, -0.05 per warning, floored at 0.
    error_count = sum(1 for i in issues if i.severity == "error")
    warning_count = sum(1 for i in issues if i.severity == "warning")
    confidence = max(0.0, 1.0 - (error_count * 0.2) - (warning_count * 0.05))

    fixed_code = None
    if auto_fix and issues:
        candidate = auto_fix_imports(code, issues)
        # Only report a fix that changed something and still parses; otherwise
        # has_fixes would be true for an untouched or broken copy.
        if candidate != code and check_syntax(candidate)[0]:
            fixed_code = candidate
            confidence = min(confidence + 0.1, 1.0)

    return VerificationResult(
        is_valid=error_count == 0,
        issues=issues,
        fixed_code=fixed_code,
        confidence_score=confidence,
    )
350
+
351
+
352
def verify_file(file_path: str, auto_fix: bool = True) -> VerificationResult:
    """Verify the Python source file at *file_path*.

    Args:
        file_path: Path of the file to read (decoded as UTF-8).
        auto_fix: Passed through to verify_code.

    Returns:
        The result of verify_code, or an invalid result with a single error
        issue if the file is missing or cannot be read or decoded.
    """
    def _failure(message: str) -> VerificationResult:
        # Invalid result carrying one error-severity issue.
        return VerificationResult(
            is_valid=False,
            issues=[VerificationIssue(severity="error", message=message)],
            confidence_score=0.0,
        )

    try:
        code = Path(file_path).read_text(encoding='utf-8')
    except FileNotFoundError:
        return _failure(f"File not found: {file_path}")
    except (OSError, UnicodeDecodeError) as exc:
        # e.g. the path is a directory, is not readable, or is not UTF-8.
        return _failure(f"Cannot read file: {file_path} ({exc})")
    return verify_code(code, auto_fix=auto_fix)
367
+
368
+
369
# Quick test: verify a snippet that uses requests/json without importing them.
if __name__ == "__main__":
    sample = '''
def fetch_data():
    response = requests.get("https://api.example.com")
    return json.loads(response.text)
'''

    report = verify_code(sample)
    print(f"Valid: {report.is_valid}")
    print(f"Confidence: {report.confidence_score:.0%}")
    for found in report.issues:
        print(f" [{found.severity}] {found.message}")
    if report.fixed_code:
        print("\nFixed code:")
        print(report.fixed_code)