threadcheck 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,60 @@
1
+ import pytest
2
+
3
+ from .dynamic.hook import install_hook, uninstall_hook
4
+ from .dynamic.tracker import ThreadCheckTracker
5
+
6
# Handle returned by install_hook() in pytest_configure, kept so that
# pytest_unconfigure can hand it back to uninstall_hook(); None while inactive.
_hook_instance = None
7
+
8
+
9
def pytest_addoption(parser):
    """Register the ``--threadcheck`` boolean flag in its own option group."""
    flag_spec = {
        "action": "store_true",
        "default": False,
        "help": "Run dynamic race detection via AST instrumentation",
    }
    threadcheck_group = parser.getgroup("threadcheck")
    threadcheck_group.addoption("--threadcheck", **flag_spec)
17
+
18
+
19
def pytest_configure(config):
    """Install the instrumentation import hook and start the race tracker.

    Does nothing unless ``--threadcheck`` was given. The pytest rootdir is
    passed to ``install_hook`` as its include path, and the returned handle is
    stored in the module-level ``_hook_instance`` for ``pytest_unconfigure``.
    Two diagnostic lines are printed to stdout afterwards.
    """
    global _hook_instance
    if not config.getoption("--threadcheck"):
        return
    _hook_instance = install_hook(include_paths=[config.rootpath])
    ThreadCheckTracker.start()
    print(f"[threadcheck] hook installed, rootpath={config.rootpath}", flush=True)
    # NOTE(review): reads the hook's private ``_include_paths``; consider a
    # public accessor on the hook object.
    print(
        f"[threadcheck] include_paths={[str(p) for p in _hook_instance._include_paths]}",
        flush=True,
    )
27
+
28
+
29
def pytest_unconfigure(config):
    """Stop the race tracker and remove the import hook installed at configure time.

    Does nothing unless ``--threadcheck`` was given. The hook is uninstalled
    in a ``finally`` block, so it is removed even when
    ``ThreadCheckTracker.stop()`` raises; that exception still propagates.
    """
    global _hook_instance
    if not config.getoption("--threadcheck"):
        return
    try:
        ThreadCheckTracker.stop()
    finally:
        # Previously a failing stop() skipped this, leaving the hook installed.
        if _hook_instance is not None:
            uninstall_hook(_hook_instance)
            _hook_instance = None
36
+
37
+
38
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item):
    """Wrap the test body: clear tracker logs, run the test, then look for races.

    If races are detected, the formatted report is stored on the item as
    ``_threadcheck_race_report`` for ``pytest_runtest_makereport`` to read.
    Logs are cleared again after the check.
    """
    enabled = item.config.getoption("--threadcheck")
    if enabled:
        ThreadCheckTracker.reset_logs()
        # Drop a report left by an earlier run of this same item (e.g. under a
        # rerun plugin) so a clean rerun is not failed by stale results.
        item._threadcheck_race_report = None
    yield
    if enabled:
        races = ThreadCheckTracker.detect_races()
        if races:
            item._threadcheck_race_report = ThreadCheckTracker.format_races()
        ThreadCheckTracker.reset_logs()
48
+
49
+
50
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
    """Turn a race report stored by ``pytest_runtest_call`` into a test failure.

    The code before ``yield`` runs before the report is built. When the
    test's "call" phase raised nothing but a race report is stored on the
    item, ``call.excinfo`` is set to a ``Failed`` exception carrying that
    report. The resulting report is therefore a failure. A test that already
    failed keeps its own exception.
    """
    if call.when == "call" and not call.excinfo:
        report = getattr(item, "_threadcheck_race_report", None)
        if report:
            try:
                from pytest import ExceptionInfo
            except ImportError:  # older pytest only exposes it privately
                from _pytest._code import ExceptionInfo
            try:
                # pytrace=False: show just the race report instead of a
                # traceback into this plugin, which says nothing about the test.
                raise pytest.fail.Exception(report, pytrace=False)
            except pytest.fail.Exception:
                call.excinfo = ExceptionInfo.from_current()
    yield
File without changes
@@ -0,0 +1,33 @@
1
+ from ..static.models import RaceWarning, Severity, Confidence
2
+
3
# Bracketed confidence label printed after the severity tag in text reports.
_CONFIDENCE_TAG = {
    Confidence.HIGH: "[HIGH]",
    Confidence.MEDIUM: "[MED]",
    Confidence.LOW: "[LOW]",
}

# Bracketed severity label that starts each warning's header line.
_SEVERITY_TAG = {
    Severity.ERROR: "[ERROR]",
    Severity.WARNING: "[WARNING]",
    Severity.INFO: "[INFO]",
}
14
+
15
+
16
def format_report(warnings: list[RaceWarning]) -> str:
    """Render warnings as a plain-text report.

    Each warning contributes the following lines:
    - a header ``<severity> <confidence> [<category>] <file>:<line>:<col>``;
    - a message line;
    - an optional suggestion line;
    - a blank separator.

    All lines are joined with newlines. An empty input yields a fixed
    "no issues" sentence.
    """
    if not warnings:
        return "No data-race issues detected"

    def _lines_for(warning):
        severity_tag = _SEVERITY_TAG.get(warning.severity, "[?]")
        confidence_tag = _CONFIDENCE_TAG.get(warning.confidence, "")
        yield f"{severity_tag} {confidence_tag} [{warning.category.value}] {warning.file}:{warning.line}:{warning.col}"
        yield f" {warning.message}"
        if warning.suggestion:
            yield f" suggestion: {warning.suggestion}"
        yield ""

    return "\n".join(line for warning in warnings for line in _lines_for(warning))
@@ -0,0 +1,100 @@
1
+ import json
2
+ from pathlib import Path
3
+
4
+ from ..static.models import RaceWarning, Severity, WarningCategory, Confidence
5
+ from .._version import __version__
6
+
7
# JSON-schema URI written to the "$schema" field of every SARIF log.
_SCHEMA = "https://raw.githubusercontent.com/oasis-tcs/sarif-spec/master/Schemata/sarif-schema-2.1.0.json"
# Project page reported as the tool driver's "informationUri".
_TOOL_INFO_URI = "https://github.com/ChidcGithub/Threadcheck"

# threadcheck severity -> SARIF "level" string (INFO is reported as "note").
_SEVERITY_TO_LEVEL = {
    Severity.ERROR: "error",
    Severity.WARNING: "warning",
    Severity.INFO: "note",
}

# category -> (rule shortDescription, rule fullDescription) for SARIF rules.
_CATEGORY_LABELS = {
    WarningCategory.UNSAFE_GLOBAL: (
        "Global variable modified without lock",
        "Detects modifications of `global` variables inside functions without lock protection.",
    ),
    WarningCategory.UNSAFE_NONLOCAL: (
        "Nonlocal variable modified without lock",
        "Detects modifications of `nonlocal` variables inside nested functions without lock protection.",
    ),
    WarningCategory.UNPROTECTED_ACCESS: (
        "Unprotected shared access",
        "Detects shared variable access without synchronization.",
    ),
    WarningCategory.THREAD_USAGE: (
        "Thread creation detected",
        "Reports sites where `threading.Thread` objects are created.",
    ),
    WarningCategory.SHARED_MUTABLE: (
        "Module-level mutable object modified",
        "Detects modification of module-level mutable objects (lists, dicts, sets) inside functions.",
    ),
    WarningCategory.CLASS_ATTRIBUTE: (
        "Class attribute modified without lock",
        "Detects unsafe modification of instance attributes (`self.x`) without lock protection.",
    ),
}
24
+
25
+
26
def format_sarif(
    warnings: list[RaceWarning],
    tool_version: str = __version__,
) -> str:
    """Serialize warnings as a pretty-printed SARIF 2.1.0 JSON document.

    The log holds a single run whose driver is ``threadcheck``. It contains
    one rule per distinct warning category and one result per warning.
    Non-ASCII text is emitted as-is (``ensure_ascii=False``).
    """
    driver = {
        "name": "threadcheck",
        "version": tool_version,
        "informationUri": _TOOL_INFO_URI,
        "rules": _build_rules(warnings),
    }
    run = {
        "tool": {"driver": driver},
        "results": [_warning_to_result(w) for w in warnings],
    }
    sarif_log = {"$schema": _SCHEMA, "version": "2.1.0", "runs": [run]}
    return json.dumps(sarif_log, indent=2, ensure_ascii=False)
52
+
53
+
54
def _build_rules(warnings: list[RaceWarning]) -> list[dict]:
    """Build one SARIF rule descriptor per distinct warning category.

    Rules are listed in first-seen order. The rule's default level and its
    "confidence" property come from the first warning of that category.
    Categories missing from ``_CATEGORY_LABELS`` get a title-cased id as the
    short description and an empty full description.
    """
    first_by_rule: dict[str, RaceWarning] = {}
    for warning in warnings:
        first_by_rule.setdefault(warning.category.value, warning)

    def _descriptor(rule_id: str, warning) -> dict:
        fallback = (rule_id.replace("_", " ").title(), "")
        short_text, full_text = _CATEGORY_LABELS.get(warning.category, fallback)
        return {
            "id": rule_id,
            "shortDescription": {"text": short_text},
            "fullDescription": {"text": full_text},
            "defaultConfiguration": {
                "level": _SEVERITY_TO_LEVEL.get(warning.severity, "warning"),
            },
            "properties": {
                "category": "Concurrency",
                "confidence": warning.confidence.value,
            },
        }

    return [_descriptor(rule_id, warning) for rule_id, warning in first_by_rule.items()]
79
+
80
+
81
def _warning_to_result(w: RaceWarning) -> dict:
    """Convert one warning into a SARIF 2.1.0 ``result`` object.

    The artifact location is the ``file://`` URI of the resolved file path.
    SARIF requires ``region.startLine`` and ``region.startColumn`` to be at
    least 1, so smaller values are clamped to 1 to keep the log schema-valid.
    Previously, a warning at column 0 produced an invalid region.
    """
    # NOTE(review): if RaceWarning.col comes from ast ``col_offset`` it is
    # 0-based while SARIF columns are 1-based; confirm against the visitors
    # and consider emitting ``w.col + 1`` instead of clamping.
    return {
        "ruleId": w.category.value,
        "level": _SEVERITY_TO_LEVEL.get(w.severity, "warning"),
        "message": {"text": w.message},
        "locations": [
            {
                "physicalLocation": {
                    "artifactLocation": {
                        "uri": w.file.resolve().as_uri(),
                    },
                    "region": {
                        "startLine": max(w.line, 1),
                        "startColumn": max(w.col, 1),
                    },
                }
            }
        ],
        "properties": {},
    }
@@ -0,0 +1,3 @@
1
+ from ..static.models import Severity, WarningCategory
2
+
3
# Public names re-exported from the static models module.
__all__ = [
    "Severity",
    "WarningCategory",
]
File without changes
@@ -0,0 +1,104 @@
1
+ import ast
2
+ from pathlib import Path
3
+
4
+ from .models import RaceWarning
5
+ from .lock_tracker import LockTracker
6
+ from .visitors import (
7
+ GlobalVisitor,
8
+ NonlocalVisitor,
9
+ ThreadVisitor,
10
+ SharedMutableVisitor,
11
+ ClassAttributeVisitor,
12
+ )
13
+
14
# Directory names never descended into during a directory scan:
# VCS metadata, caches, virtual environments and build/packaging output.
_SKIP_DIRS = frozenset({
    ".git",
    "__pycache__",
    ".venv",
    "venv",
    "env",
    ".env",
    ".mypy_cache",
    ".pytest_cache",
    "node_modules",
    ".tox",
    "dist",
    "build",
    ".eggs",
    "site-packages",
})
19
+
20
+
21
+ class AnalysisContext:
22
+ def __init__(self, filepath: Path, tree: ast.Module):
23
+ self.filepath = filepath
24
+ self._thread_targets: set[str] = set()
25
+ self._has_thread = False
26
+ self._find_thread_targets(tree)
27
+ self.lock_tracker = LockTracker()
28
+ self.lock_tracker.visit(tree)
29
+
30
+ def is_protected(self, line: int) -> bool:
31
+ return self.lock_tracker.is_protected_by_lock(line)
32
+
33
+ def is_thread_target(self, func_name: str) -> bool:
34
+ return func_name in self._thread_targets
35
+
36
+ def has_any_thread(self) -> bool:
37
+ return self._has_thread
38
+
39
+ def _find_thread_targets(self, tree):
40
+ for node in ast.walk(tree):
41
+ if isinstance(node, ast.Call):
42
+ if isinstance(node.func, ast.Attribute) and node.func.attr == "Thread":
43
+ self._has_thread = True
44
+ for kw in node.keywords:
45
+ if kw.arg == "target":
46
+ if isinstance(kw.value, ast.Name):
47
+ self._thread_targets.add(kw.value.id)
48
+ elif isinstance(kw.value, ast.Attribute):
49
+ self._thread_targets.add(kw.value.attr)
50
+ attr_name = getattr(node.func, "attr", None)
51
+ if attr_name in ("submit", "map"):
52
+ self._has_thread = True
53
+
54
+
55
def analyze_file(filepath: Path) -> list[RaceWarning]:
    """Run every static race visitor over a single Python source file.

    Args:
        filepath: Path of the source file to analyze.

    Returns:
        The warnings from the Global, Nonlocal, Thread, SharedMutable and
        ClassAttribute visitors, concatenated in that order. A file that
        cannot be read, is not valid UTF-8, or does not parse yields ``[]``,
        so one bad file does not abort a directory scan.
    """
    try:
        source = filepath.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return []

    try:
        tree = ast.parse(source, filename=str(filepath))
    except (SyntaxError, ValueError):
        # Before Python 3.12, ast.parse reports NUL bytes in the source as
        # ValueError rather than SyntaxError; both mean "not parseable".
        return []

    context = AnalysisContext(filepath, tree)

    visitor_classes = (
        GlobalVisitor,
        NonlocalVisitor,
        ThreadVisitor,
        SharedMutableVisitor,
        ClassAttributeVisitor,
    )
    all_warnings: list[RaceWarning] = []
    for visitor_cls in visitor_classes:
        visitor = visitor_cls(filepath, context)
        visitor.visit(tree)
        all_warnings.extend(visitor.warnings)

    return all_warnings
82
+
83
+
84
def analyze_path(path: str) -> list[RaceWarning]:
    """Analyze one ``.py`` file, or every ``.py`` file under a directory.

    Args:
        path: A file or directory path; it is resolved to an absolute path.

    Returns:
        All warnings, file by file. Directory files are visited in sorted
        path order. A non-``.py`` file or a non-existent path yields ``[]``.

    During a directory scan, ``_should_skip`` is applied to each file's path
    *relative to the scanned root*. Previously the absolute path was checked,
    so a project located under e.g. ``~/.local`` or a ``build`` folder had
    every file skipped.
    """
    root = Path(path).resolve()
    all_warnings: list[RaceWarning] = []

    if root.is_file():
        if root.suffix == ".py":
            all_warnings.extend(analyze_file(root))
    elif root.is_dir():
        for py_file in sorted(root.rglob("*.py")):
            if _should_skip(py_file.relative_to(root)):
                continue
            all_warnings.extend(analyze_file(py_file))

    return all_warnings
98
+
99
+
100
def _should_skip(path: Path) -> bool:
    """Return True if any component of *path* starts with "." or is listed in ``_SKIP_DIRS``."""
    return any(
        component in _SKIP_DIRS or component.startswith(".")
        for component in path.parts
    )
@@ -0,0 +1,42 @@
1
+ import ast
2
+
3
# Constructor names (bare or as ``module.Name``) treated as creating a lock.
_LOCK_CLASSES = frozenset({"Lock", "RLock", "Semaphore", "BoundedSemaphore"})


class LockTracker(ast.NodeVisitor):
    """Record source regions guarded by a ``with <lock>:`` statement.

    A ``with`` item counts as a lock in two cases:
    - it directly constructs one (``with threading.Lock():``);
    - its ``ast.unparse`` text equals the target of a lock assignment made
      anywhere in the visited module (``lock = threading.Lock()``,
      ``self._lock: RLock = RLock()``).

    Matching is purely textual and ignores scope.
    """

    def __init__(self):
        self.protected_regions: list[ast.With] = []
        self.lock_exprs: set[str] = set()

    def visit_Module(self, node):
        # Pre-scan every assignment in the module first. A ``with lock:`` that
        # appears before ``lock = Lock()`` in source order (e.g. inside a
        # function defined above the module-level lock) is then still recognised.
        for sub in ast.walk(node):
            if isinstance(sub, (ast.Assign, ast.AnnAssign)):
                self._record_lock_assignment(sub)
        self.generic_visit(node)

    def visit_With(self, node):
        for item in node.items:
            expr_name = ast.unparse(item.context_expr)
            if expr_name in self.lock_exprs or self._is_lock_creation(item.context_expr):
                self.lock_exprs.add(expr_name)
                self.protected_regions.append(node)
        self.generic_visit(node)

    def visit_Assign(self, node):
        self._record_lock_assignment(node)
        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        # ``name: Lock = Lock()`` was previously ignored.
        self._record_lock_assignment(node)
        self.generic_visit(node)

    def _record_lock_assignment(self, node):
        """Add the target text(s) of an Assign/AnnAssign to ``lock_exprs`` if it creates a lock."""
        if node.value is None or not self._is_lock_creation(node.value):
            return
        targets = node.targets if isinstance(node, ast.Assign) else [node.target]
        for target in targets:
            self.lock_exprs.add(ast.unparse(target))

    def is_protected_by_lock(self, line: int) -> bool:
        """Return True if *line* lies within the line span of any recorded lock ``with``."""
        for region in self.protected_regions:
            start = getattr(region, "lineno", 0)
            end = getattr(region, "end_lineno", start)
            if start <= line <= end:
                return True
        return False

    @staticmethod
    def _is_lock_creation(node) -> bool:
        """Return True if *node* is a call to a constructor named in ``_LOCK_CLASSES``."""
        if not isinstance(node, ast.Call):
            return False
        func = node.func
        if isinstance(func, ast.Name):
            return func.id in _LOCK_CLASSES
        if isinstance(func, ast.Attribute):
            return func.attr in _LOCK_CLASSES
        return False
@@ -0,0 +1,48 @@
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+ from pathlib import Path
4
+
5
+
6
class Severity(Enum):
    """Seriousness of a RaceWarning; each value is its lowercase label."""

    ERROR = "error"
    WARNING = "warning"
    INFO = "info"
10
+
11
+
12
class WarningCategory(Enum):
    """Kind of race finding; the value doubles as the SARIF rule id and text-report tag."""

    UNSAFE_GLOBAL = "unsafe_global"
    UNSAFE_NONLOCAL = "unsafe_nonlocal"
    UNPROTECTED_ACCESS = "unprotected_access"
    THREAD_USAGE = "thread_usage"
    SHARED_MUTABLE = "shared_mutable"
    CLASS_ATTRIBUTE = "class_attribute"
19
+
20
+
21
class Confidence(Enum):
    """Confidence level attached to a RaceWarning; each value is its lowercase label."""

    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
25
+
26
+
27
@dataclass
class RaceWarning:
    """One potential data-race finding.

    ``to_dict`` returns a plain dict in which the path is converted to a
    string and every enum is replaced by its ``.value``.
    """

    file: Path  # file the finding refers to
    line: int
    col: int
    severity: Severity
    category: WarningCategory
    message: str
    suggestion: str | None = None  # optional remediation hint; None when absent
    confidence: Confidence = Confidence.MEDIUM

    def to_dict(self) -> dict:
        """Return this warning as a dict of str/int/None values, in declaration order."""
        location = {"file": str(self.file), "line": self.line, "col": self.col}
        classification = {
            "severity": self.severity.value,
            "category": self.category.value,
        }
        detail = {"message": self.message, "suggestion": self.suggestion}
        return {**location, **classification, **detail, "confidence": self.confidence.value}