prompture 0.0.50__py3-none-any.whl → 0.0.51.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,302 @@
1
+ """AST visitors for extracting code features.
2
+
3
+ Provides an AST node visitor that extracts security-relevant features
4
+ from Python source code.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import ast
10
+ from dataclasses import dataclass, field
11
+
12
+
13
@dataclass
class CodeFeatures:
    """Container for what AST analysis found in a piece of Python code.

    Every collection field defaults to a fresh empty container created per
    instance, and every boolean flag defaults to ``False``.

    Attributes:
        imports: Module names imported (``import x``, ``from x import y``).
        file_operations: ``(operation, args)`` tuples for file access.
        network_calls: Network-related function calls that were detected.
        system_calls: System/subprocess calls that were detected.
        exec_eval_usage: exec/eval/compile calls that were detected.
        dangerous_builtins: Names of dangerous builtins that were used.
        attribute_accesses: Attribute access patterns (e.g. ``"os.system"``).
        function_calls: Names of all function calls.
        has_global_statements: Whether the code uses ``global`` statements.
        has_nonlocal_statements: Whether the code uses ``nonlocal`` statements.
        class_definitions: Names of the classes defined.
        async_operations: Whether the code contains async/await.
    """

    # --- imports ---------------------------------------------------------
    imports: set[str] = field(default_factory=set)

    # --- classified calls, each stored as (call name, argument strings) ---
    file_operations: list[tuple[str, list[str]]] = field(default_factory=list)
    network_calls: list[tuple[str, list[str]]] = field(default_factory=list)
    system_calls: list[tuple[str, list[str]]] = field(default_factory=list)
    exec_eval_usage: list[tuple[str, list[str]]] = field(default_factory=list)

    # --- names seen while walking the tree -------------------------------
    dangerous_builtins: set[str] = field(default_factory=set)
    attribute_accesses: set[str] = field(default_factory=set)
    function_calls: set[str] = field(default_factory=set)

    # --- scope-related statements ----------------------------------------
    has_global_statements: bool = False
    has_nonlocal_statements: bool = False

    # --- definitions and async usage -------------------------------------
    class_definitions: set[str] = field(default_factory=set)
    async_operations: bool = False
44
+
45
+
46
# Names of builtins treated as dangerous when they appear in analysed code.
# Kept as a frozenset so membership tests are cheap and the set is immutable.
DANGEROUS_BUILTINS = frozenset(
    {
        # dynamic execution and importing
        "__import__",
        "compile",
        "eval",
        "exec",
        # I/O and debugger entry points
        "breakpoint",
        "input",
        "open",
        # raw buffer access
        "memoryview",
        # namespace introspection
        "dir",
        "globals",
        "locals",
        "vars",
        # dynamic attribute access
        "delattr",
        "getattr",
        "hasattr",
        "setattr",
    }
)
67
+
68
# Function/method names that count as file operations.
FILE_OPERATIONS = frozenset(
    {
        # opening and closing
        "close",
        "open",
        # reading
        "read",
        "readline",
        "readlines",
        # writing
        "flush",
        "truncate",
        "write",
        "writelines",
        # positioning
        "seek",
        "tell",
    }
)
84
+
85
# Top-level module names whose calls are treated as network-related.
NETWORK_MODULES = frozenset(
    {
        "aiohttp",
        "ftplib",
        "http",
        "httpx",
        "imaplib",
        "poplib",
        "requests",
        "smtplib",
        "socket",
        "ssl",
        "telnetlib",
        "urllib",
        "websocket",
        "websockets",
    }
)
104
+
105
# Dotted "module.function" names that identify system / process calls.
SYSTEM_CALL_PATTERNS = frozenset(
    {
        # os: shell and pipe helpers
        "os.popen",
        "os.system",
        # os: spawn family
        "os.spawn",
        "os.spawnl",
        "os.spawnle",
        "os.spawnlp",
        "os.spawnlpe",
        "os.spawnv",
        "os.spawnve",
        "os.spawnvp",
        "os.spawnvpe",
        # os: exec family
        "os.exec",
        "os.execl",
        "os.execle",
        "os.execlp",
        "os.execlpe",
        "os.execv",
        "os.execve",
        "os.execvp",
        "os.execvpe",
        # os: forking and signalling
        "os.fork",
        "os.forkpty",
        "os.kill",
        "os.killpg",
        # subprocess
        "subprocess.Popen",
        "subprocess.call",
        "subprocess.check_call",
        "subprocess.check_output",
        "subprocess.getoutput",
        "subprocess.getstatusoutput",
        "subprocess.run",
        # pty
        "pty.fork",
        "pty.spawn",
    }
)
143
+
144
+
145
class FeatureExtractor(ast.NodeVisitor):
    """AST visitor that extracts security-relevant features from Python code.

    The visitor accumulates its findings on ``self.features`` while the tree
    is walked; nothing is returned from the ``visit_*`` methods.

    Usage::

        extractor = FeatureExtractor()
        extractor.visit(ast.parse(source_code))
        features = extractor.features
    """

    def __init__(self) -> None:
        self.features = CodeFeatures()
        # Not read by any method of this class; kept so the attribute
        # remains available to existing code that may reference it.
        self._current_module_context: list[str] = []

    # ------------------------------------------------------------------
    # Node visitors
    # ------------------------------------------------------------------

    def visit_Import(self, node: ast.Import) -> None:
        """Record the top-level package of every ``import x.y`` name."""
        for alias in node.names:
            self.features.imports.add(alias.name.split(".")[0])
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record the top-level package of a ``from x.y import z`` statement.

        ``node.module`` is ``None`` for ``from . import z``; such imports
        are not recorded.
        """
        if node.module:
            self.features.imports.add(node.module.split(".")[0])
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        """Record a call by name and classify it into risk categories."""
        call_name = self._get_call_name(node)
        if call_name:
            self.features.function_calls.add(call_name)
            self._classify_call(call_name, node)
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Record the dotted chain of an attribute access, when resolvable."""
        attr_chain = self._get_attribute_chain(node)
        if attr_chain:
            self.features.attribute_accesses.add(attr_chain)
        # Recursing also records the shorter prefixes of the chain
        # (e.g. "os.path" for "os.path.join").
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        """Record any bare name that matches a dangerous builtin."""
        if node.id in DANGEROUS_BUILTINS:
            self.features.dangerous_builtins.add(node.id)
        self.generic_visit(node)

    def visit_Global(self, node: ast.Global) -> None:
        """Flag the presence of a ``global`` statement."""
        self.features.has_global_statements = True
        self.generic_visit(node)

    def visit_Nonlocal(self, node: ast.Nonlocal) -> None:
        """Flag the presence of a ``nonlocal`` statement."""
        self.features.has_nonlocal_statements = True
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Record the name of a class definition."""
        self.features.class_definitions.add(node.name)
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Flag async usage for an ``async def``."""
        self.features.async_operations = True
        self.generic_visit(node)

    def visit_Await(self, node: ast.Await) -> None:
        """Flag async usage for an ``await`` expression."""
        self.features.async_operations = True
        self.generic_visit(node)

    def visit_AsyncFor(self, node: ast.AsyncFor) -> None:
        """Flag async usage for an ``async for`` loop."""
        self.features.async_operations = True
        self.generic_visit(node)

    def visit_AsyncWith(self, node: ast.AsyncWith) -> None:
        """Flag async usage for an ``async with`` statement."""
        self.features.async_operations = True
        self.generic_visit(node)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get_call_name(self, node: ast.Call) -> str | None:
        """Return the called name for *node*.

        Returns the bare identifier for ``f(...)``, the dotted chain for
        ``a.b.f(...)``, and ``None`` for any other callee expression
        (e.g. ``get()(...)`` or ``items[0](...)``).
        """
        if isinstance(node.func, ast.Name):
            return node.func.id
        if isinstance(node.func, ast.Attribute):
            return self._get_attribute_chain(node.func)
        return None

    def _get_attribute_chain(self, node: ast.Attribute) -> str | None:
        """Build the full dotted chain for *node* (e.g. ``'os.path.join'``).

        Returns ``None`` when the chain is not rooted in a plain name, for
        example ``make().attr`` or ``"x".join``.
        """
        # Parts are collected from the outermost attribute inwards and
        # reversed at the end.
        parts: list[str] = [node.attr]
        current = node.value
        while isinstance(current, ast.Attribute):
            parts.append(current.attr)
            current = current.value
        if isinstance(current, ast.Name):
            parts.append(current.id)
            return ".".join(reversed(parts))
        return None

    def _get_call_args_as_strings(self, node: ast.Call) -> list[str]:
        """Return one string per positional argument of *node*.

        String constants are returned verbatim; every other argument is
        rendered with ``ast.dump``. Keyword arguments are not included.
        """
        return [
            arg.value
            if isinstance(arg, ast.Constant) and isinstance(arg.value, str)
            else ast.dump(arg)
            for arg in node.args
        ]

    @staticmethod
    def _matches_system_call(
        call_name: str, patterns: frozenset[str] | None = None
    ) -> bool:
        """Return whether *call_name* looks like a system/process call.

        A name matches when it is listed verbatim in *patterns*, or when it
        starts with a pattern's module part followed by a dot and ends with
        that pattern's function part. The dot is what keeps unrelated names
        such as ``ostrich.ecosystem`` or ``osx.kill`` from matching merely
        because they begin with the letters ``os``; nested or prefixed
        variants such as ``os.path.system`` or ``os.posix_spawn`` still
        match.

        Args:
            call_name: Bare or dotted call name.
            patterns: ``"module.function"`` patterns to test against;
                defaults to ``SYSTEM_CALL_PATTERNS``.
        """
        if patterns is None:
            patterns = SYSTEM_CALL_PATTERNS
        if call_name in patterns:
            return True
        for pattern in patterns:
            module, _, func = pattern.rpartition(".")
            if call_name.startswith(f"{module}.") and call_name.endswith(func):
                return True
        return False

    def _classify_call(self, call_name: str, node: ast.Call) -> None:
        """Append *call_name* to each feature category it belongs to.

        The categories are checked independently, so a single call may be
        recorded in more than one of them.
        """
        args = self._get_call_args_as_strings(node)

        # Dynamic code execution.
        if call_name in ("exec", "eval", "compile"):
            self.features.exec_eval_usage.append((call_name, args))

        # File operations: bare ``open`` or any ``<something>.<file op>``.
        if (
            call_name == "open"
            or call_name.endswith(".open")
            or any(call_name.endswith(f".{op}") for op in FILE_OPERATIONS)
        ):
            self.features.file_operations.append((call_name, args))

        # System / process calls.
        if self._matches_system_call(call_name):
            self.features.system_calls.append((call_name, args))

        # Network calls: decided by the first dotted component only.
        module_prefix = call_name.split(".")[0] if "." in call_name else ""
        if module_prefix in NETWORK_MODULES:
            self.features.network_calls.append((call_name, args))
285
+
286
+
287
def extract_features(source: str) -> CodeFeatures:
    """Parse *source* and collect its security-relevant features.

    Args:
        source: Python source code as a string.

    Returns:
        The ``CodeFeatures`` populated by a ``FeatureExtractor`` that
        visited the parsed tree.

    Raises:
        SyntaxError: If the source code is not valid Python.
    """
    visitor = FeatureExtractor()
    # ast.parse is what raises SyntaxError on invalid input.
    visitor.visit(ast.parse(source))
    return visitor.features
@@ -0,0 +1,219 @@
1
+ """Risk scoring for Python code analysis.
2
+
3
+ Calculates risk levels based on detected code features.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import enum
9
+ from dataclasses import dataclass, field
10
+
11
+ from .ast_visitors import CodeFeatures
12
+
13
# Imports that are always blocked because they represent severe security
# risks.
CRITICAL_IMPORTS = frozenset(
    {
        # foreign functions, processes, threads, interpreter internals
        "_thread",
        "builtins",
        "ctypes",
        "gc",
        "multiprocessing",
        "sys",
        "threading",
        # import machinery
        "importlib",
        "pkgutil",
        # interactive interpreters, debuggers and tracing
        "bdb",
        "code",
        "codeop",
        "pdb",
        "rlcompleter",
        "trace",
        "traceback",
        # source and bytecode introspection
        "dis",
        "inspect",
        "linecache",
        "pickletools",
        # miscellaneous
        "formatter",
        # platform-specific modules
        "_posixsubprocess",
        "_winapi",
        "msilib",
        "nt",
        "ntpath",
        "posix",
        "posixpath",
        "winreg",
    }
)
47
+
48
# Imports classified as high risk.
HIGH_RISK_IMPORTS = frozenset(
    {
        # operating system, processes and filesystem
        "fnmatch",
        "glob",
        "os",
        "pathlib",
        "shutil",
        "subprocess",
        "tempfile",
        # object serialisation
        "marshal",
        "pickle",
        "shelve",
        # networking and event loops
        "asyncio",
        "socket",
        "ssl",
        # signals, terminals and system resources
        "pty",
        "resource",
        "signal",
        "syslog",
        "termios",
        "tty",
    }
)
72
+
73
# Imports classified as medium risk.
MEDIUM_RISK_IMPORTS = frozenset(
    {
        # internet protocols and mail handling
        "email",
        "http",
        "mailbox",
        "mimetypes",
        "urllib",
        # binary/text encodings
        "base64",
        "binascii",
        "quopri",
        "uu",
        # markup and configuration parsing
        "configparser",
        "html",
        "xml",
        # runtime services and utilities
        "abc",
        "atexit",
        "contextlib",
        "copy",
        "logging",
        "pprint",
        "reprlib",
        "warnings",
        "weakref",
    }
)
99
+
100
+
101
class RiskLevel(str, enum.Enum):
    """Classification of how risky a piece of analysed code is.

    Members also subclass ``str``, so each one compares equal to its
    lowercase value (``RiskLevel.LOW == "low"``). They are declared in
    ascending order of severity.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
108
+
109
+
110
@dataclass
class RiskAssessment:
    """Outcome of scoring a piece of analysed code.

    ``level`` and ``score`` are required; the remaining fields default to
    fresh empty containers created per instance.

    Attributes:
        level: Overall risk level.
        score: Numeric risk score (0-100).
        reasons: Reasons contributing to the risk.
        blocked_imports: Imports that would be blocked.
        warnings: Non-blocking warnings about the code.
    """

    # Required verdict.
    level: RiskLevel
    score: int

    # Supporting detail; optional when constructing an assessment.
    reasons: list[str] = field(default_factory=list)
    blocked_imports: set[str] = field(default_factory=set)
    warnings: list[str] = field(default_factory=list)
127
+
128
+
129
def calculate_risk(features: CodeFeatures) -> RiskAssessment:
    """Turn extracted code features into a risk score and level.

    Each category of finding adds a fixed number of points. The total is
    capped at 100 and then mapped to a level: 70 and above is CRITICAL,
    40 and above is HIGH, 15 and above is MEDIUM, anything lower is LOW.
    Findings go to ``reasons`` or to ``warnings`` depending on category.

    Args:
        features: CodeFeatures from AST analysis.

    Returns:
        RiskAssessment with level, score, reasons, blocked imports and
        warnings.
    """

    def _call_names(records: list[tuple[str, list[str]]]) -> str:
        # Each record is (call name, args); only the name is reported.
        return ", ".join([record[0] for record in records])

    total = 0
    reason_list: list[str] = []
    warning_list: list[str] = []

    # --- imports ----------------------------------------------------------
    critical = features.imports & CRITICAL_IMPORTS
    high_risk = features.imports & HIGH_RISK_IMPORTS
    medium_risk = features.imports & MEDIUM_RISK_IMPORTS

    # Critical imports alone are enough to reach the cap, and they are the
    # only imports reported as blocked.
    blocked: set[str] = set(critical)
    if critical:
        total += 100
        reason_list.append(f"Critical imports detected: {', '.join(sorted(critical))}")
    if high_risk:
        total += 40
        reason_list.append(f"High-risk imports: {', '.join(sorted(high_risk))}")
    if medium_risk:
        total += 15
        warning_list.append(f"Medium-risk imports: {', '.join(sorted(medium_risk))}")

    # --- classified calls -------------------------------------------------
    if features.exec_eval_usage:
        total += 80
        reason_list.append(f"Dynamic code execution: {_call_names(features.exec_eval_usage)}")
    if features.system_calls:
        total += 80
        reason_list.append(f"System calls detected: {_call_names(features.system_calls)}")
    if features.network_calls:
        total += 35
        reason_list.append(f"Network operations: {_call_names(features.network_calls)}")
    if features.file_operations:
        total += 25
        warning_list.append(f"File operations: {_call_names(features.file_operations)}")

    # --- builtins and scope statements ------------------------------------
    if features.dangerous_builtins:
        total += 20
        warning_list.append(
            f"Dangerous builtins: {', '.join(sorted(features.dangerous_builtins))}"
        )
    if features.has_global_statements:
        total += 10
        warning_list.append("Uses global statements")
    if features.has_nonlocal_statements:
        total += 5
        warning_list.append("Uses nonlocal statements")

    # --- cap and classify -------------------------------------------------
    total = min(total, 100)
    for threshold, candidate in (
        (70, RiskLevel.CRITICAL),
        (40, RiskLevel.HIGH),
        (15, RiskLevel.MEDIUM),
    ):
        if total >= threshold:
            level = candidate
            break
    else:
        level = RiskLevel.LOW

    return RiskAssessment(
        level=level,
        score=total,
        reasons=reason_list,
        blocked_imports=blocked,
        warnings=warning_list,
    )