aofire-python-agent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,694 @@
1
+ """Source-to-sink taint analysis via call graph construction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import ast
7
+ import json
8
+ import os
9
+ import re
10
+ import sys
11
+ from collections import defaultdict
12
+
13
+ from pydantic import BaseModel
14
+
15
+ _SUPPRESS_RE = re.compile(
16
+ r"#\s*taint:\s*ignore\[([A-Z]+-\d+)\]\s*--\s*(.+)",
17
+ )
18
+
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Data structures
22
+ # ---------------------------------------------------------------------------
23
+
24
class FunctionInfo(BaseModel):
    """Per-function record stored in the call graph."""

    # Dotted name of the module the function was found in.
    module: str
    # Identifier used as the graph key; the function visitor in this file
    # stores the "<module>.<function>" form here.
    name: str
    # Line of the definition in its source file.
    lineno: int
    # Classification flags; all default to "not classified".
    is_source: bool = False
    is_sink: bool = False
    is_sanitizer: bool = False
    # CWE ids tied to the source / sink classification; empty when not set.
    source_cwe: str = ""
    sink_cwe: str = ""
35
+
36
+
37
class CallEdge(BaseModel):
    """One directed caller -> callee edge of the call graph."""

    # Name of the function containing the call.
    caller: str
    # Name of the function being called.
    callee: str
    # Line number of the call expression.
    lineno: int
43
+
44
+
45
class Suppression(BaseModel):
    """A taint finding the user has explicitly acknowledged."""

    # Name of the function the suppression is attached to.
    function: str
    # CWE id being suppressed (e.g. "CWE-200").
    cwe: str
    # Free-text justification supplied with the suppression.
    reason: str
    # Line number associated with the suppression.
    lineno: int
52
+
53
+
54
class TaintPath(BaseModel):
    """One route through the call graph from a taint source to a taint sink."""

    # Name of the function where the path starts.
    source: str
    # Name of the function where the path ends.
    sink: str
    # Function names along the route, in order.
    path: list[str]
    # CWE id reported for this path.
    cwe: str
    # Whether the path is considered sanitized.
    sanitized: bool
    # Whether the path is covered by a suppression, and the reason if so.
    suppressed: bool = False
    suppression_reason: str = ""
64
+
65
+
66
class CallGraph(BaseModel):
    """Container for everything collected while scanning a source tree."""

    # Function records keyed by function name.
    functions: dict[str, FunctionInfo] = {}
    # All caller -> callee edges.
    edges: list[CallEdge] = []
    # All suppression records.
    suppressions: list[Suppression] = []
72
+
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # Pattern tables
76
+ # ---------------------------------------------------------------------------
77
+
78
# Call names that mark a taint source when matched exactly -> CWE id.
SOURCE_EXACT: dict[str, str] = {
    "input": "CWE-20",
    "json.loads": "CWE-502",
    "json.load": "CWE-502",
    "open": "CWE-73",
}

# Call-name suffixes that mark a taint source -> CWE id.
SOURCE_SUFFIX: dict[str, str] = {
    ".read": "CWE-73",
    ".model_validate": "CWE-502",
    ".model_validate_json": "CWE-502",
    ".parse_args": "CWE-20",
    ".query": "CWE-74",
}

# Call names that mark a taint sink when matched exactly -> CWE id.
SINK_EXACT: dict[str, str] = {
    "eval": "CWE-94",
    "exec": "CWE-94",
    "os.system": "CWE-78",
    "os.popen": "CWE-78",
    "subprocess.run": "CWE-78",
    "subprocess.call": "CWE-78",
    "subprocess.Popen": "CWE-78",
    "subprocess.check_output": "CWE-78",
    "subprocess.check_call": "CWE-78",
    "print": "CWE-200",
}

# Call-name suffixes that mark a taint sink -> CWE id.
SINK_SUFFIX: dict[str, str] = {
    ".write": "CWE-73",
    ".query": "CWE-74",
}

# Bare function names that are treated as sanitizers.
SANITIZER_NAMES: set[str] = {
    "frame_data",
    "validate_ontology_strict",
    "is_safe_bash",
    "is_safe_path",
    "model_validate",
}
118
+
119
+
120
+ # ---------------------------------------------------------------------------
121
+ # AST helpers
122
+ # ---------------------------------------------------------------------------
123
+
124
def _resolve_call_name(node: ast.Call) -> str:
    """Return the dotted name being called, or "" when it cannot be named.

    A bare name yields its identifier; an attribute chain yields the joined
    parts from ``_attr_parts``; any other callee expression yields "".
    """
    target = node.func
    if isinstance(target, ast.Name):
        return target.id
    if not isinstance(target, ast.Attribute):
        return ""
    # An unresolvable chain comes back as [], which joins to "".
    return ".".join(_attr_parts(target))
133
+
134
+
135
+ def _attr_parts(node: ast.Attribute) -> list[str]:
136
+ """Recursively collect dotted attribute parts."""
137
+ if isinstance(node.value, ast.Name):
138
+ return [node.value.id, node.attr]
139
+ if isinstance(node.value, ast.Attribute):
140
+ inner = _attr_parts(node.value)
141
+ return inner + [node.attr] if inner else []
142
+ return []
143
+
144
+
145
def _is_source_call(name: str) -> tuple[bool, str]:
    """Return (True, cwe) if *name* is a taint source, else (False, "").

    Exact matches in SOURCE_EXACT win; otherwise the first SOURCE_SUFFIX
    entry that *name* ends with decides the CWE.
    """
    if name in SOURCE_EXACT:
        return True, SOURCE_EXACT[name]
    suffix_cwe = next(
        (cwe for suffix, cwe in SOURCE_SUFFIX.items() if name.endswith(suffix)),
        None,
    )
    if suffix_cwe is None:
        return False, ""
    return True, suffix_cwe
153
+
154
+
155
def _is_sink_call(name: str) -> tuple[bool, str]:
    """Return (True, cwe) if *name* is a taint sink, else (False, "").

    Exact matches in SINK_EXACT win; otherwise the first SINK_SUFFIX entry
    that *name* ends with decides the CWE.
    """
    if name in SINK_EXACT:
        return True, SINK_EXACT[name]
    suffix_cwe = next(
        (cwe for suffix, cwe in SINK_SUFFIX.items() if name.endswith(suffix)),
        None,
    )
    if suffix_cwe is None:
        return False, ""
    return True, suffix_cwe
163
+
164
+
165
def _is_sanitizer_name(name: str) -> bool:
    """Return True if the last dotted component of *name* is a sanitizer."""
    # rpartition leaves the whole string in slot 2 when there is no dot.
    return name.rpartition(".")[2] in SANITIZER_NAMES
169
+
170
+
171
def _collect_calls_in_body(
    body: list[ast.stmt],
) -> list[str]:
    """Return the resolvable names of every call found anywhere in *body*.

    The statements are wrapped in a throwaway Module so ``ast.walk`` can
    traverse them; calls whose name cannot be resolved are dropped.
    """
    wrapper = ast.Module(body=body, type_ignores=[])
    resolved = (
        _resolve_call_name(node)
        for node in ast.walk(wrapper)
        if isinstance(node, ast.Call)
    )
    return [call_name for call_name in resolved if call_name]
182
+
183
+
184
def _classify_as_source(
    calls: list[str],
) -> tuple[bool, str]:
    """Return (True, cwe) for the first source call in *calls*, else (False, "")."""
    verdicts = (_is_source_call(call_name) for call_name in calls)
    first_hit = next((cwe for matched, cwe in verdicts if matched), None)
    if first_hit is None:
        return False, ""
    return True, first_hit
193
+
194
+
195
def _classify_as_sink(
    calls: list[str],
) -> tuple[bool, str]:
    """Return (True, cwe) for the first sink call in *calls*, else (False, "")."""
    verdicts = (_is_sink_call(call_name) for call_name in calls)
    first_hit = next((cwe for matched, cwe in verdicts if matched), None)
    if first_hit is None:
        return False, ""
    return True, first_hit
204
+
205
+
206
+ # ---------------------------------------------------------------------------
207
+ # AST visitors
208
+ # ---------------------------------------------------------------------------
209
+
210
+ class _ImportCollector(ast.NodeVisitor):
211
+ """Collect import aliases from a module AST."""
212
+
213
+ def __init__(self) -> None:
214
+ self.aliases: dict[str, str] = {}
215
+
216
+ def visit_Import(self, node: ast.Import) -> None:
217
+ for alias in node.names:
218
+ local = alias.asname if alias.asname else alias.name
219
+ self.aliases[local] = alias.name
220
+
221
+ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
222
+ mod = node.module or ""
223
+ for alias in node.names:
224
+ local = alias.asname if alias.asname else alias.name
225
+ self.aliases[local] = f"{mod}.{alias.name}" if mod else alias.name
226
+
227
+
228
class _FunctionVisitor(ast.NodeVisitor):
    """Record one FunctionInfo per function definition the visitor reaches.

    The handlers do not call ``generic_visit``, so traversal stops at each
    recorded function and definitions nested inside it get no entry of
    their own.
    """

    def __init__(self, module: str) -> None:
        self.module = module
        self.functions: list[FunctionInfo] = []

    def _visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        """Classify *node* from the calls in its body and store the record."""
        body_calls = _collect_calls_in_body(node.body)
        source_hit, source_cwe = _classify_as_source(body_calls)
        sink_hit, sink_cwe = _classify_as_sink(body_calls)
        record = FunctionInfo(
            module=self.module,
            # The record is keyed by "<module>.<function>".
            name=f"{self.module}.{node.name}",
            lineno=node.lineno,
            is_source=source_hit,
            is_sink=sink_hit,
            # Sanitizer status depends on the function's own name only.
            is_sanitizer=_is_sanitizer_name(node.name),
            source_cwe=source_cwe,
            sink_cwe=sink_cwe,
        )
        self.functions.append(record)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Handle a synchronous ``def``."""
        self._visit_func(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Handle an ``async def``."""
        self._visit_func(node)
256
+
257
+
258
class _CallVisitor(ast.NodeVisitor):
    """Record a CallEdge for every nameable call made inside a function.

    Calls outside any function definition are ignored. Calls inside a
    nested definition are attributed to that nested function.
    """

    def __init__(
        self,
        module: str,
        imports: dict[str, str],
    ) -> None:
        self.module = module
        self.imports = imports
        self.edges: list[CallEdge] = []
        # Name of the function currently being traversed, if any.
        self._current_func: str | None = None

    def _visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        """Traverse *node* with it set as the current caller."""
        enclosing = self._current_func
        self._current_func = f"{self.module}.{node.name}"
        self.generic_visit(node)
        # Restore the enclosing function once the body has been walked.
        self._current_func = enclosing

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Handle a synchronous ``def``."""
        self._visit_func(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Handle an ``async def``."""
        self._visit_func(node)

    def visit_Call(self, node: ast.Call) -> None:
        """Record an edge for *node*, then keep descending into it."""
        caller = self._current_func
        raw_name = _resolve_call_name(node) if caller is not None else ""
        if raw_name:
            edge = CallEdge(
                caller=caller,
                callee=_resolve_callee(raw_name, self.imports, self.module),
                lineno=node.lineno,
            )
            self.edges.append(edge)
        # Arguments and the callee expression may contain further calls.
        self.generic_visit(node)
296
+
297
+
298
+ # ---------------------------------------------------------------------------
299
+ # Graph construction helpers
300
+ # ---------------------------------------------------------------------------
301
+
302
def _collect_imports(tree: ast.Module) -> dict[str, str]:
    """Return the {local name: imported dotted name} map for *tree*."""
    visitor = _ImportCollector()
    visitor.visit(tree)
    alias_map = visitor.aliases
    return alias_map
307
+
308
+
309
+ def _resolve_callee(
310
+ raw: str, imports: dict[str, str], module: str,
311
+ ) -> str:
312
+ """Resolve a raw call name to a fully qualified name."""
313
+ parts = raw.split(".", 1)
314
+ head = parts[0]
315
+ if head in imports:
316
+ base = imports[head]
317
+ if len(parts) > 1:
318
+ return f"{base}.{parts[1]}"
319
+ return base
320
+ if "." not in raw:
321
+ return f"{module}.{raw}"
322
+ return raw
323
+
324
+
325
+ def _module_name_from_path(path: str, root: str) -> str:
326
+ """Derive a dotted module name from a file path."""
327
+ rel = os.path.relpath(path, root)
328
+ no_ext = rel.removesuffix(".py")
329
+ return no_ext.replace(os.sep, ".")
330
+
331
+
332
+ def _collect_python_files(root: str) -> list[str]:
333
+ """Return sorted list of .py files under *root*."""
334
+ result: list[str] = []
335
+ for dirpath, _dirs, files in os.walk(root):
336
+ for f in files:
337
+ if f.endswith(".py"):
338
+ result.append(os.path.join(dirpath, f))
339
+ result.sort()
340
+ return result
341
+
342
+
343
def _collect_suppressions(
    source: str, module_name: str,
) -> list[Suppression]:
    """Collect taint suppression comments from *source*.

    Format: # taint: ignore[CWE-200] -- reason text
    The comment may sit on a def line or on the line immediately before
    one; comments that cannot be tied to a def are ignored.
    """
    lines = source.splitlines()
    found: list[Suppression] = []
    for idx, text in enumerate(lines):
        hit = _SUPPRESS_RE.search(text)
        if hit is None:
            continue
        owner = _find_func_for_suppress(lines, idx)
        if not owner:
            continue
        cwe, reason = hit.group(1), hit.group(2)
        found.append(Suppression(
            function=f"{module_name}.{owner}",
            cwe=cwe,
            reason=reason.strip(),
            # Report the 1-based line of the comment itself.
            lineno=idx + 1,
        ))
    return found
367
+
368
+
369
def _find_func_for_suppress(
    lines: list[str], comment_idx: int,
) -> str:
    """Return the function a suppression comment belongs to, or "".

    The comment's own line is tried first, then the line right after it.
    """
    same_line = _extract_func_name(lines[comment_idx])
    if same_line:
        return same_line
    following = comment_idx + 1
    if following >= len(lines):
        # The comment is on the last line; nothing can follow it.
        return ""
    return _extract_func_name(lines[following])
382
+
383
+
384
+ def _extract_func_name(line: str) -> str:
385
+ """Extract function name from a def line."""
386
+ stripped = line.strip()
387
+ for prefix in ("async def ", "def "):
388
+ if stripped.startswith(prefix):
389
+ rest = stripped[len(prefix):]
390
+ paren = rest.find("(")
391
+ if paren > 0:
392
+ return rest[:paren]
393
+ return ""
394
+
395
+
396
def parse_file(
    path: str, module_name: str,
) -> tuple[
    list[FunctionInfo], list[CallEdge],
    list[Suppression],
]:
    """Parse a .py file. Return functions, edges, suppressions.

    The file is opened with ``tokenize.open`` so it is decoded the way the
    interpreter would decode it (encoding cookie or BOM, UTF-8 otherwise)
    rather than with the platform's locale encoding.

    Returns empty results with a warning on SyntaxError.
    """
    import tokenize
    import warnings

    try:
        with tokenize.open(path) as fh:
            source = fh.read()
        tree = ast.parse(source, filename=path)
    except SyntaxError:
        # Best effort: one unparsable file must not abort the whole scan.
        warnings.warn(
            f"SyntaxError in {path}, skipping",
            stacklevel=2,
        )
        return [], [], []
    imports = _collect_imports(tree)
    fv = _FunctionVisitor(module_name)
    fv.visit(tree)
    cv = _CallVisitor(module_name, imports)
    cv.visit(tree)
    supps = _collect_suppressions(source, module_name)
    return fv.functions, cv.edges, supps
424
+
425
+
426
def build_graph(root_dir: str) -> CallGraph:
    """Scan every .py file under *root_dir* and merge the results.

    Function records are keyed by name, so a later file overwrites an
    earlier record with the same name; edges and suppressions accumulate.
    """
    result = CallGraph()
    for file_path in _collect_python_files(root_dir):
        module = _module_name_from_path(file_path, root_dir)
        functions, edges, suppressions = parse_file(file_path, module)
        result.functions.update({info.name: info for info in functions})
        result.edges.extend(edges)
        result.suppressions.extend(suppressions)
    return result
437
+
438
+
439
+ # ---------------------------------------------------------------------------
440
+ # Taint tracing
441
+ # ---------------------------------------------------------------------------
442
+
443
+ def _build_forward_adj(graph: CallGraph) -> dict[str, list[str]]:
444
+ """Build caller -> [callee] adjacency list."""
445
+ adj: dict[str, list[str]] = defaultdict(list)
446
+ for e in graph.edges:
447
+ adj[e.caller].append(e.callee)
448
+ return dict(adj)
449
+
450
+
451
+ def _find_sources(graph: CallGraph) -> list[str]:
452
+ """Return names of all source functions in the graph."""
453
+ return [
454
+ name for name, info in graph.functions.items()
455
+ if info.is_source
456
+ ]
457
+
458
+
459
+ def _find_sinks(graph: CallGraph) -> dict[str, str]:
460
+ """Return {name: cwe} for all sink functions in the graph."""
461
+ return {
462
+ name: info.sink_cwe
463
+ for name, info in graph.functions.items()
464
+ if info.is_sink
465
+ }
466
+
467
+
468
+ def _check_sink_hit(
469
+ node: str,
470
+ start: str,
471
+ path: list[str],
472
+ sinks: dict[str, str],
473
+ ) -> tuple[str, list[str], str] | None:
474
+ """Return a sink hit tuple if *node* is a sink (and not start)."""
475
+ if node in sinks and node != start:
476
+ return (node, path, sinks[node])
477
+ return None
478
+
479
+
480
+ def _enqueue_neighbors(
481
+ node: str,
482
+ path: list[str],
483
+ forward_adj: dict[str, list[str]],
484
+ visited: set[str],
485
+ queue: list[list[str]],
486
+ ) -> None:
487
+ """Append unvisited neighbor paths to *queue*."""
488
+ for nb in forward_adj.get(node, []):
489
+ if nb not in visited:
490
+ queue.append(path + [nb])
491
+
492
+
493
+ def _bfs_to_sinks(
494
+ start: str,
495
+ forward_adj: dict[str, list[str]],
496
+ sinks: dict[str, str],
497
+ graph: CallGraph,
498
+ ) -> list[tuple[str, list[str], str]]:
499
+ """Forward BFS from *start*, return paths reaching sinks."""
500
+ results: list[tuple[str, list[str], str]] = []
501
+ queue: list[list[str]] = [[start]]
502
+ visited: set[str] = set()
503
+ while queue:
504
+ path = queue.pop(0)
505
+ node = path[-1]
506
+ if node in visited:
507
+ continue
508
+ visited.add(node)
509
+ hit = _check_sink_hit(node, start, path, sinks)
510
+ if hit:
511
+ results.append(hit)
512
+ _enqueue_neighbors(node, path, forward_adj, visited, queue)
513
+ return results
514
+
515
+
516
+ def _path_has_sanitizer(
517
+ path: list[str], graph: CallGraph,
518
+ ) -> bool:
519
+ """Return True if any node on *path* is a sanitizer."""
520
+ for node in path:
521
+ info = graph.functions.get(node)
522
+ if info and info.is_sanitizer:
523
+ return True
524
+ return False
525
+
526
+
527
+ def _check_suppressed(
528
+ path: list[str], cwe: str,
529
+ suppressions: list[Suppression],
530
+ ) -> tuple[bool, str]:
531
+ """Check if any function on path has a suppression.
532
+
533
+ Returns (suppressed, reason).
534
+ """
535
+ for s in suppressions:
536
+ if s.cwe == cwe and s.function in path:
537
+ return (True, s.reason)
538
+ return (False, "")
539
+
540
+
541
def find_taint_paths(graph: CallGraph) -> list[TaintPath]:
    """Return a TaintPath for every sink reachable from every source.

    Each path is annotated with whether it passes through a sanitizer and
    whether a suppression for its CWE applies to a function on it.
    """
    adjacency = _build_forward_adj(graph)
    sink_cwes = _find_sinks(graph)
    found: list[TaintPath] = []
    for origin in _find_sources(graph):
        reached = _bfs_to_sinks(origin, adjacency, sink_cwes, graph)
        for sink_name, route, cwe in reached:
            is_sanitized = _path_has_sanitizer(route, graph)
            is_suppressed, why = _check_suppressed(
                route, cwe, graph.suppressions,
            )
            found.append(TaintPath(
                source=origin,
                sink=sink_name,
                path=route,
                cwe=cwe,
                sanitized=is_sanitized,
                suppressed=is_suppressed,
                suppression_reason=why,
            ))
    return found
564
+
565
+
566
+ # ---------------------------------------------------------------------------
567
+ # Reporting
568
+ # ---------------------------------------------------------------------------
569
+
570
+ def _should_skip(
571
+ tp: TaintPath, include_sanitized: bool,
572
+ ) -> bool:
573
+ """Check if a path should be skipped in output."""
574
+ if tp.suppressed:
575
+ return True
576
+ if tp.sanitized and not include_sanitized:
577
+ return True
578
+ return False
579
+
580
+
581
def format_text_report(
    paths: list[TaintPath],
    include_sanitized: bool = False,
) -> str:
    """Render the reportable taint paths as plain text.

    Each path contributes a summary line and a line with the full call
    chain.  When nothing is reportable a fixed message is returned.
    """
    visible = [tp for tp in paths if not _should_skip(tp, include_sanitized)]
    if not visible:
        return "No taint paths found."
    out: list[str] = []
    for tp in visible:
        status = " [SANITIZED]" if tp.sanitized else ""
        chain = " -> ".join(tp.path)
        out.append(f"{tp.cwe}: {tp.source} -> {tp.sink}{status}")
        out.append(f" path: {chain}")
    return "\n".join(out)
599
+
600
+
601
+ def _sarif_result(tp: TaintPath) -> dict[str, object]:
602
+ """Build a single SARIF result dict for a TaintPath."""
603
+ return {
604
+ "ruleId": tp.cwe,
605
+ "message": {
606
+ "text": f"Taint flow: {tp.source} -> {tp.sink}",
607
+ },
608
+ "locations": [
609
+ {
610
+ "physicalLocation": {
611
+ "artifactLocation": {"uri": tp.source},
612
+ },
613
+ },
614
+ ],
615
+ "properties": {
616
+ "sanitized": tp.sanitized,
617
+ "path": tp.path,
618
+ },
619
+ }
620
+
621
+
622
def format_sarif(
    paths: list[TaintPath],
    include_sanitized: bool = False,
) -> dict[str, object]:
    """Build a SARIF 2.1.0 log containing the reportable taint paths.

    Paths rejected by ``_should_skip`` are left out; the log holds a
    single run attributed to the ``aofire-call-graph`` driver.
    """
    results = [
        _sarif_result(tp)
        for tp in paths
        if not _should_skip(tp, include_sanitized)
    ]
    schema_url = (
        "https://raw.githubusercontent.com/oasis-tcs/"
        "sarif-spec/main/sarif-2.1/"
        "schema/sarif-schema-2.1.0.json"
    )
    driver = {
        "name": "aofire-call-graph",
        "version": "0.1.0",
    }
    run = {
        "tool": {"driver": driver},
        "results": results,
    }
    return {
        "version": "2.1.0",
        "$schema": schema_url,
        "runs": [run],
    }
651
+
652
+
653
+ # ---------------------------------------------------------------------------
654
+ # CLI
655
+ # ---------------------------------------------------------------------------
656
+
657
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the CLI arguments.

    *argv* is handed to argparse as-is; None makes argparse read the
    process arguments.  The result has ``directory`` plus the boolean
    flags ``sarif`` and ``include_sanitized``.
    """
    cli = argparse.ArgumentParser(
        description="Source-to-sink taint analysis via call graph",
    )
    cli.add_argument(
        "directory",
        help="Root directory of Python source to analyze",
    )
    switches = (
        ("--sarif", "Output in SARIF format"),
        ("--include-sanitized", "Include sanitized paths in output"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)
    return cli.parse_args(argv)
677
+
678
+
679
def main(argv: list[str] | None = None) -> int:
    """Run the aofire-call-graph CLI and return the process exit code.

    Builds the call graph for the requested directory, finds taint paths
    and prints them as SARIF JSON or as a text report.  Always returns 0.
    """
    options = parse_args(argv)
    taint_paths = find_taint_paths(build_graph(options.directory))
    if options.sarif:
        sarif_log = format_sarif(taint_paths, options.include_sanitized)
        output = json.dumps(sarif_log, indent=2)
    else:
        output = format_text_report(taint_paths, options.include_sanitized)
    print(output)
    return 0
691
+
692
+
693
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)