threadcheck 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- threadcheck/__init__.py +13 -0
- threadcheck/__main__.py +3 -0
- threadcheck/_version.py +1 -0
- threadcheck/cli.py +89 -0
- threadcheck/dynamic/__init__.py +0 -0
- threadcheck/dynamic/__main__.py +38 -0
- threadcheck/dynamic/clock.py +31 -0
- threadcheck/dynamic/hook.py +97 -0
- threadcheck/dynamic/tracker.py +191 -0
- threadcheck/dynamic/transform.py +192 -0
- threadcheck/pytest_plugin.py +60 -0
- threadcheck/reporting/__init__.py +0 -0
- threadcheck/reporting/formatter.py +33 -0
- threadcheck/reporting/sarif.py +100 -0
- threadcheck/reporting/types.py +3 -0
- threadcheck/static/__init__.py +0 -0
- threadcheck/static/analyzer.py +104 -0
- threadcheck/static/lock_tracker.py +42 -0
- threadcheck/static/models.py +48 -0
- threadcheck/static/visitors.py +324 -0
- threadcheck-0.0.1.dist-info/METADATA +248 -0
- threadcheck-0.0.1.dist-info/RECORD +25 -0
- threadcheck-0.0.1.dist-info/WHEEL +4 -0
- threadcheck-0.0.1.dist-info/entry_points.txt +5 -0
- threadcheck-0.0.1.dist-info/licenses/LICENSE +21 -0
threadcheck/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""
|
|
2
|
+
threadcheck — Data-race detector for multi-threaded Python.
|
|
3
|
+
|
|
4
|
+
Supports both AST-based static analysis and runtime dynamic
|
|
5
|
+
detection via bytecode instrumentation.
|
|
6
|
+
|
|
7
|
+
Targets Python 3.14+ free-threading builds.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from ._version import __version__
|
|
11
|
+
|
|
12
|
+
from .static.analyzer import analyze_path, analyze_file
|
|
13
|
+
from .static.models import RaceWarning, Severity, WarningCategory
|
threadcheck/__main__.py
ADDED
threadcheck/_version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Package version string; imported by threadcheck/__init__.py and shown by `threadcheck --version`.
__version__ = "0.0.1"
|
threadcheck/cli.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
import sys
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from ._version import __version__
|
|
7
|
+
from .static.analyzer import analyze_path
|
|
8
|
+
from .reporting.formatter import format_report
|
|
9
|
+
from .reporting.sarif import format_sarif
|
|
10
|
+
from .dynamic.__main__ import run_script
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def main():
|
|
14
|
+
parser = argparse.ArgumentParser(
|
|
15
|
+
prog="threadcheck",
|
|
16
|
+
description="Data Race Detector for Python",
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
parser.add_argument(
|
|
20
|
+
"--version", action="version", version=f"threadcheck {__version__}"
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
sub = parser.add_subparsers(dest="command", required=True)
|
|
24
|
+
|
|
25
|
+
scan = sub.add_parser("scan", help="Static analysis for data races")
|
|
26
|
+
scan.add_argument("path", help="File or directory to scan")
|
|
27
|
+
fmt = scan.add_mutually_exclusive_group()
|
|
28
|
+
fmt.add_argument("--json", action="store_true", help="Output in JSON format")
|
|
29
|
+
fmt.add_argument("--sarif", action="store_true", help="Output in SARIF v2.1.0 format")
|
|
30
|
+
scan.add_argument("-o", "--output", help="Write output to file (default: stdout)")
|
|
31
|
+
|
|
32
|
+
run = sub.add_parser("run", help="Dynamic race detection (Phase 3)")
|
|
33
|
+
run.add_argument("script", help="Python script to execute")
|
|
34
|
+
|
|
35
|
+
compat = sub.add_parser("check-compat", help="Check free-threading compatibility (Phase 7)")
|
|
36
|
+
compat.add_argument("path", nargs="?", default=".", help="Project path")
|
|
37
|
+
|
|
38
|
+
args = parser.parse_args()
|
|
39
|
+
|
|
40
|
+
if args.command == "scan":
|
|
41
|
+
_do_scan(args)
|
|
42
|
+
elif args.command == "run":
|
|
43
|
+
run_script(args.script)
|
|
44
|
+
elif args.command == "check-compat":
|
|
45
|
+
print("Not implemented: free-threading compatibility check (Phase 7)", file=sys.stderr)
|
|
46
|
+
sys.exit(1)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _do_scan(args):
    """Run the static analyzer on ``args.path`` and emit a report.

    The report is written to ``args.output`` (or stdout). Its format is JSON
    (``--json``), SARIF (``--sarif``) or the human-readable text format.
    A banner is printed before the report and a per-severity summary after it.
    When a machine-readable report goes to stdout, the banner and summary are
    sent to stderr instead, so stdout stays valid JSON/SARIF.

    Exits with status 1 if ``args.path`` does not exist.
    """
    from collections import Counter

    path = Path(args.path).resolve()
    if not path.exists():
        print(f"Path does not exist: {path}", file=sys.stderr)
        sys.exit(1)

    # Keep stdout parseable when it carries JSON/SARIF.
    machine_readable = args.json or args.sarif
    info = sys.stderr if machine_readable and not args.output else sys.stdout

    print(f"threadcheck scan -- analysing {path}", file=info)
    print(file=info)

    findings = analyze_path(str(path))
    # One pass over the findings; missing severities count as 0.
    by_severity = Counter(w.severity.value for w in findings)

    if args.json:
        output = json.dumps(
            [w.to_dict() for w in findings], indent=2, ensure_ascii=False
        )
    elif args.sarif:
        output = format_sarif(findings)
    else:
        output = format_report(findings)
    _write_output(args.output, output)

    print(file=info)
    print(
        f"Total: {len(findings)} issue(s) ({by_severity['error']} error(s), "
        f"{by_severity['warning']} warning(s), {by_severity['info']} info)",
        file=info,
    )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _write_output(path_arg: str | None, content: str):
|
|
82
|
+
if path_arg:
|
|
83
|
+
Path(path_arg).write_text(content, encoding="utf-8")
|
|
84
|
+
else:
|
|
85
|
+
print(content)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# Run the CLI when this module is executed directly as a script.
if __name__ == "__main__":
    main()
|
|
File without changes
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import sys
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from .tracker import ThreadCheckTracker
|
|
6
|
+
from .transform import TrackInjector
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def run_script(script_path: str):
    """Run a Python script under dynamic race detection and print the race report.

    The script is parsed, instrumented with ``TrackInjector``, compiled and
    executed while ``ThreadCheckTracker`` is active. It sees the same
    ``__name__`` ("__main__"), ``sys.argv[0]`` and ``sys.path[0]`` (its own
    directory) as ``python script.py`` would. This means main guards and
    sibling imports work. ``sys.argv`` and ``sys.path`` are restored afterwards.

    The race report is printed and the tracker reset even if the script
    raises. A ``SystemExit`` raised by the script is swallowed; any other
    exception propagates after the report.

    Exits with status 1 if the file does not exist or fails to parse.
    """
    path = Path(script_path).resolve()
    if not path.exists():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)

    source = path.read_text(encoding="utf-8")
    filename = str(path)

    try:
        tree = ast.parse(source, filename=filename)
    except SyntaxError as e:
        print(f"Syntax error: {e}", file=sys.stderr)
        sys.exit(1)

    TrackInjector(filename=filename).transform(tree)
    ast.fix_missing_locations(tree)

    code = compile(tree, filename, "exec")

    # Without "__name__" the script would see builtins' __name__ and its
    # `if __name__ == "__main__":` block would never run.
    script_globals = {
        "__name__": "__main__",
        "__file__": filename,
        "_threadcheck_tracker": ThreadCheckTracker,
    }
    script_dir = str(path.parent)
    saved_argv = sys.argv
    sys.argv = [filename]
    sys.path.insert(0, script_dir)

    ThreadCheckTracker.start()
    try:
        exec(code, script_globals)
    except SystemExit:
        pass
    finally:
        ThreadCheckTracker.stop()
        sys.argv = saved_argv
        try:
            sys.path.remove(script_dir)
        except ValueError:
            pass  # the script itself already removed it
        # Report even when the script raised: races may explain the failure.
        print(ThreadCheckTracker.format_races())
        ThreadCheckTracker.reset()
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class VectorClock:
    """Logical clock mapping a thread id to that thread's event counter.

    Missing thread ids read as 0. Two clocks are ordered when every entry of
    one is <= the matching entry of the other; otherwise they conflict.
    """

    def __init__(self):
        self._clock: dict[int, int] = defaultdict(int)

    def tick(self, thread_id: int) -> int:
        """Advance *thread_id*'s counter by one and return the new value."""
        updated = self._clock[thread_id] + 1
        self._clock[thread_id] = updated
        return updated

    def merge(self, other: "VectorClock"):
        """Raise each entry to the element-wise maximum with *other*, in place."""
        for tid, count in other._clock.items():
            mine = self._clock.get(tid, 0)
            self._clock[tid] = count if count > mine else mine

    def conflicts_with(self, other: "VectorClock") -> bool:
        """Return True when neither clock is <= the other (concurrent events)."""
        ordered = self._leq(other) or other._leq(self)
        return not ordered

    def _leq(self, other: "VectorClock") -> bool:
        # .get() rather than [] so the comparison never inserts keys into
        # the other clock's defaultdict.
        return all(
            count <= other._clock.get(tid, 0)
            for tid, count in self._clock.items()
        )

    def copy(self) -> "VectorClock":
        """Return an independent snapshot of this clock."""
        clone = VectorClock()
        clone._clock = self._clock.copy()
        return clone

    def __repr__(self) -> str:
        return f"VectorClock({dict(self._clock)})"
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import ast
|
|
3
|
+
import builtins
|
|
4
|
+
import importlib.util
|
|
5
|
+
import importlib.abc
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from .transform import TrackInjector
|
|
9
|
+
from .tracker import ThreadCheckTracker
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ThreadCheckLoader(importlib.abc.Loader):
    """Import loader that executes a module's source after threadcheck instrumentation.

    The source is parsed, rewritten by ``TrackInjector``, compiled and executed
    in the module namespace. The tracker is bound as ``_threadcheck_tracker``
    both in that namespace and on ``builtins``.
    """

    def __init__(self, tracker=None):
        self.tracker = tracker if tracker else ThreadCheckTracker

    def create_module(self, spec):
        # None tells the import system to create the default module object.
        return None

    def exec_module(self, module):
        spec = module.__spec__
        origin = spec.origin
        source = self._get_source(spec, module)
        if source is None:
            raise ImportError(f"cannot load source for {spec.name}")

        module.__file__ = origin

        tree = ast.parse(source, filename=origin)
        TrackInjector(filename=str(origin)).transform(tree)
        ast.fix_missing_locations(tree)
        code = compile(tree, origin, "exec")

        namespace = module.__dict__
        namespace["_threadcheck_tracker"] = self.tracker
        builtins._threadcheck_tracker = self.tracker
        exec(code, namespace)

    @staticmethod
    def _get_source(spec, module=None):
        """Return the module's source text, or None if it cannot be obtained.

        Tries ``spec.origin`` and then ``module.__file__`` when they name a
        ``.py`` file (read as UTF-8). After that it falls back to the spec
        loader's ``get_source`` if it has one. Read errors are deliberately
        ignored so the next source is tried.
        """
        candidates = (spec.origin, getattr(module, "__file__", None))
        for candidate in candidates:
            if not candidate or Path(candidate).suffix != ".py":
                continue
            try:
                return Path(candidate).read_text(encoding="utf-8")
            except Exception:
                pass
        loader = spec.loader
        if hasattr(loader, "get_source"):
            try:
                return loader.get_source(spec.name)
            except Exception:
                pass
        return None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ThreadCheckFinder(importlib.abc.MetaPathFinder):
    """Meta-path finder that routes plain ``.py`` modules through ThreadCheckLoader.

    When *include_paths* is given, only files located under one of those
    directories are claimed. Otherwise every ``.py`` module found on the
    search path is claimed. Anything not claimed falls through to the
    regular import machinery.
    """

    def __init__(self, tracker=None, include_paths=None):
        self.tracker = tracker or ThreadCheckTracker
        self._include_paths = (
            [Path(p).resolve() for p in include_paths] if include_paths else []
        )

    def _should_instrument(self, filepath: Path) -> bool:
        """Return True if *filepath* falls under an include path (or none are set)."""
        if not self._include_paths:
            return True
        resolved = filepath.resolve()
        return any(_is_under(resolved, inc) for inc in self._include_paths)

    def find_spec(self, fullname, path, target=None):
        """Return an instrumenting spec for a single-file module, else None.

        For a submodule the import system passes the parent package's
        ``__path__`` as *path*. Only the last dotted component is therefore
        looked up inside each entry: joining the full dotted name would look
        for ``<pkg_dir>/pkg/sub.py`` and never match. Packages themselves
        (``__init__.py``) are not claimed.
        """
        module_file = f"{fullname.rpartition('.')[2]}.py"
        for entry in (path or sys.path):
            if entry == "":
                entry = "."
            candidate = Path(entry) / module_file
            if candidate.is_file() and self._should_instrument(candidate):
                return importlib.util.spec_from_file_location(
                    fullname,
                    str(candidate),
                    loader=ThreadCheckLoader(self.tracker),
                )
        return None
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def install_hook(tracker=None, include_paths=None):
    """Put a new ThreadCheckFinder at the front of ``sys.meta_path`` and return it.

    Pass the returned finder to ``uninstall_hook`` to remove it again.
    """
    finder = ThreadCheckFinder(tracker, include_paths)
    sys.meta_path[:0] = [finder]
    return finder
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def uninstall_hook(hook):
    """Remove *hook* from ``sys.meta_path``; a hook that is not installed is ignored."""
    try:
        sys.meta_path.remove(hook)
    except ValueError:
        pass
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _is_under(child: Path, parent: Path) -> bool:
|
|
93
|
+
try:
|
|
94
|
+
child.relative_to(parent)
|
|
95
|
+
return True
|
|
96
|
+
except ValueError:
|
|
97
|
+
return False
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import threading
|
|
4
|
+
from collections import Counter, defaultdict
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
from .clock import VectorClock
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class AccessRecord:
    """One observed access to a tracked variable.

    ThreadCheckTracker creates one per instrumented read or write. Races are
    found by comparing the ``clock`` snapshots of records from different threads.
    """

    var_name: str  # accessed variable name; also the key in the tracker's access log
    operation: str  # "write" or "read", as passed by ThreadCheckTracker
    thread_id: int  # threading.get_ident() of the accessing thread
    clock: VectorClock = field(default_factory=VectorClock)  # copy of the thread's clock at the access
    location: tuple[str, int] = ("", 0)  # (file, line) of the instrumented statement
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ThreadCheckTracker:
    """Process-wide recorder of instrumented variable accesses and lock hand-offs.

    All state is held in class attributes and every method is a classmethod,
    so instrumented code calls the class itself (bound as
    ``_threadcheck_tracker``). Events are only recorded between ``start()``
    and ``stop()``.

    Each thread has its own VectorClock. Every recorded access ticks the
    accessing thread's clock and stores a snapshot. Releasing a tracked lock
    publishes the releaser's clock, and acquiring it merges that clock into
    the acquirer's. That is the only happens-before edge modelled here.

    NOTE(review): Thread.start()/join() are not modelled, so accesses ordered
    only by starting or joining a thread are still reported as concurrent.
    """

    _lock = threading.Lock()
    # var_name -> every AccessRecord logged for it, in arrival order.
    _access_log: dict[str, list[AccessRecord]] = defaultdict(list)
    # threading.get_ident() -> that thread's VectorClock.
    _thread_clocks: dict[int, VectorClock] = {}
    # lock name -> snapshot of the releasing thread's clock at the last release.
    _lock_clocks: dict[str, VectorClock] = {}
    _active = False

    @classmethod
    def start(cls):
        """Begin recording accesses and lock events."""
        cls._active = True

    @classmethod
    def stop(cls):
        """Stop recording; logged data is kept until reset()/reset_logs()."""
        cls._active = False

    @classmethod
    def _get_clock(cls) -> VectorClock:
        """Return the calling thread's clock, creating it on first use."""
        tid = threading.get_ident()
        clock = cls._thread_clocks.get(tid)
        if clock is None:
            with cls._lock:
                # Fetch-or-create in one step so a concurrent reset() cannot
                # make a later lookup fail with KeyError.
                clock = cls._thread_clocks.setdefault(tid, VectorClock())
        return clock

    @classmethod
    def _record(cls, var_name: str, operation: str, file: str, line: int):
        """Log one *operation* on *var_name* by the calling thread, if active."""
        if not cls._active:
            return
        tid = threading.get_ident()
        clock = cls._get_clock()
        clock.tick(tid)
        record = AccessRecord(
            var_name=var_name,
            operation=operation,
            thread_id=tid,
            clock=clock.copy(),  # snapshot; the live clock keeps advancing
            location=(file, line),
        )
        with cls._lock:
            cls._access_log[var_name].append(record)

    @classmethod
    def write_before(cls, var_name: str, file: str = "", line: int = 0):
        """Record a write to *var_name* at *file*:*line* by the calling thread."""
        cls._record(var_name, "write", file, line)

    @classmethod
    def read_before(cls, var_name: str, file: str = "", line: int = 0):
        """Record a read of *var_name* at *file*:*line* by the calling thread."""
        cls._record(var_name, "read", file, line)

    @classmethod
    def lock_acquire(cls, lock_name: str, file: str = "", line: int = 0):
        """Merge the clock published by the last release of *lock_name*, then tick.

        *file* and *line* are accepted for call-site symmetry but are unused.
        """
        if not cls._active:
            return
        tid = threading.get_ident()
        clock = cls._get_clock()
        with cls._lock:
            if lock_name in cls._lock_clocks:
                clock.merge(cls._lock_clocks[lock_name])
        clock.tick(tid)

    @classmethod
    def lock_release(cls, lock_name: str, file: str = "", line: int = 0):
        """Publish a snapshot of the calling thread's clock for *lock_name*.

        *file* and *line* are accepted for call-site symmetry but are unused.
        """
        if not cls._active:
            return
        clock = cls._get_clock()
        with cls._lock:
            cls._lock_clocks[lock_name] = clock.copy()

    @classmethod
    def reset(cls):
        """Clear all logs and clocks and deactivate tracking."""
        with cls._lock:
            cls._access_log.clear()
            cls._thread_clocks.clear()
            cls._lock_clocks.clear()
            cls._active = False

    @classmethod
    def reset_logs(cls):
        """Clear all logs and clocks but leave the active flag unchanged."""
        with cls._lock:
            cls._access_log.clear()
            cls._thread_clocks.clear()
            cls._lock_clocks.clear()

    @classmethod
    def _race_key(cls, r1: AccessRecord, r2: AccessRecord) -> tuple:
        """Order-independent key used to deduplicate race pairs."""
        tid1, tid2 = sorted([r1.thread_id, r2.thread_id])
        loc1, loc2 = sorted([r1.location, r2.location])
        return (r1.var_name, tid1, tid2, loc1, loc2)

    @classmethod
    def _conflicting_pairs(cls) -> list[tuple[str, AccessRecord, AccessRecord]]:
        """Return every racing (var_name, r1, r2) pair, in log order.

        Two records race when they come from different threads, at least one
        is a write, and neither clock snapshot is ordered before the other.
        The log is scanned pairwise (quadratic per variable) under the lock,
        so the result is one consistent snapshot.
        """
        from itertools import combinations

        pairs: list[tuple[str, AccessRecord, AccessRecord]] = []
        with cls._lock:
            for var_name, records in cls._access_log.items():
                for r1, r2 in combinations(records, 2):
                    if (
                        r1.thread_id != r2.thread_id
                        and "write" in (r1.operation, r2.operation)
                        and r1.clock.conflicts_with(r2.clock)
                    ):
                        pairs.append((var_name, r1, r2))
        return pairs

    @classmethod
    def _dedupe(cls, pairs):
        """Keep the first pair for each _race_key, preserving order."""
        seen: set[tuple] = set()
        unique: list[tuple[str, AccessRecord, AccessRecord]] = []
        for entry in pairs:
            key = cls._race_key(entry[1], entry[2])
            if key not in seen:
                seen.add(key)
                unique.append(entry)
        return unique

    @classmethod
    def detect_races(cls) -> list[tuple[str, AccessRecord, AccessRecord]]:
        """Return one (var_name, r1, r2) entry per distinct racing pair."""
        return cls._dedupe(cls._conflicting_pairs())

    @classmethod
    def format_races(cls) -> str:
        """Return a human-readable report of detected races.

        Each unique pair is listed once. The number of overlapping accesses
        behind it is shown when it is more than one.
        """
        pairs = cls._conflicting_pairs()
        races = cls._dedupe(pairs)
        if not races:
            return "No data races detected"

        # Computed from the same snapshot as `races`, in a single scan.
        overlap = Counter(cls._race_key(r1, r2) for _, r1, r2 in pairs)

        lines = ["Data races detected:", ""]
        for var_name, r1, r2 in races:
            f1, l1 = r1.location
            f2, l2 = r2.location
            count = overlap[cls._race_key(r1, r2)]
            lines.append(f" [!] `{var_name}`")
            lines.append(
                f" Thread-{r1.thread_id} ({r1.operation})"
                f" at {f1}:{l1}"
            )
            lines.append(
                f" Thread-{r2.thread_id} ({r2.operation})"
                f" at {f2}:{l2}"
            )
            if count > 1:
                lines.append(f" ({count} overlapping accesses)")
            lines.append("")

        lines.append(
            f"Summary: {len(races)} unique race pair(s), "
            f"{sum(overlap.values())} total overlapping access(es)"
        )
        return "\n".join(lines)
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
import ast
import types
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# Constructor names whose call, used directly as a `with` context manager,
# is treated as a lock region by _resolve_lock_name.
_LOCK_NAMES = frozenset({"Lock", "RLock", "Semaphore", "BoundedSemaphore"})


# Import statement added to every instrumented module so that the injected
# `_threadcheck_tracker.<method>(...)` calls resolve to the tracker class.
_TRACKER_IMPORT = ast.parse(
    "from threadcheck.dynamic.tracker import ThreadCheckTracker as _threadcheck_tracker"
).body[0]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TrackInjector:
    """Rewrite a module AST so shared-variable writes and lock regions are reported.

    Only statements inside function bodies are instrumented. Within a
    function, a name counts as shared when a ``global`` or ``nonlocal``
    statement for it appears anywhere in the function's subtree. This
    includes nested functions, because collection uses ``ast.walk``.

    * Assignments, augmented assignments and ``del`` of shared names
      (including names inside tuple/list targets) are preceded by
      ``_threadcheck_tracker.write_before(name, filename, lineno)``.
    * A ``with`` block recognised by ``_resolve_lock_name`` gets
      ``lock_acquire`` as its first statement. ``lock_release`` goes in a
      ``finally`` around the original body, so it is recorded on every exit.

    Reads are not instrumented. Inserted nodes carry no source locations;
    callers must run ``ast.fix_missing_locations`` before compiling.
    """

    def __init__(self, filename: str = "<unknown>"):
        self.filename = filename

    def transform(self, tree: ast.Module) -> ast.Module:
        """Instrument *tree* in place and return it."""
        import copy

        # Each tree gets its own copy of the module-level template node.
        tree.body.insert(self._import_index(tree), copy.deepcopy(_TRACKER_IMPORT))
        scopes = {}
        self._collect_scopes(tree, scopes)
        self._inject(tree, scopes)
        return tree

    @staticmethod
    def _import_index(tree: ast.Module) -> int:
        """Return the first body index where an ordinary import may be placed.

        The module docstring must stay first, or ``__doc__`` is lost.
        ``from __future__`` imports must precede every other statement, or
        compile() raises SyntaxError.
        """
        body = tree.body
        index = 0
        if (
            body
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        ):
            index = 1
        while (
            index < len(body)
            and isinstance(body[index], ast.ImportFrom)
            and body[index].module == "__future__"
        ):
            index += 1
        return index

    def _collect_scopes(self, node, scopes):
        """Map id(function node) -> {"globals": set, "nonlocals": set} for every function."""
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            info = {"globals": set(), "nonlocals": set()}
            for child in ast.walk(node):
                if isinstance(child, ast.Global):
                    info["globals"].update(child.names)
                elif isinstance(child, ast.Nonlocal):
                    info["nonlocals"].update(child.names)
            scopes[id(node)] = info
        for child in ast.iter_child_nodes(node):
            self._collect_scopes(child, scopes)

    def _inject(self, node, scopes, func_id=None):
        """Recursively rewrite every statement list under *node*.

        *func_id* identifies the innermost enclosing function, whose shared
        names apply.
        """
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            func_id = id(node)
        elif isinstance(node, ast.ClassDef):
            # Class-body assignments bind class attributes, not the enclosing
            # function's globals/nonlocals; methods get their own func_id.
            func_id = None

        for field in ("body", "orelse", "finalbody"):
            old = getattr(node, field, None)
            if isinstance(old, list):
                setattr(node, field, self._transform_list(old, scopes, func_id))

        # ExceptHandler and match_case nodes expose a `body` field and are
        # reached by this recursion, so they are rewritten exactly once.
        for child in ast.iter_child_nodes(node):
            self._inject(child, scopes, func_id)

    @staticmethod
    def _shared_target_names(targets, shared):
        """Return names bound or deleted by *targets* that are in *shared*, in walk order."""
        names = []
        for target in targets:
            for sub in ast.walk(target):
                if (
                    isinstance(sub, ast.Name)
                    and isinstance(sub.ctx, (ast.Store, ast.Del))
                    and sub.id in shared
                ):
                    names.append(sub.id)
        return names

    def _wrap_lock_region(self, stmt: ast.With, lock_name: str):
        """Instrument *stmt*'s body with lock_acquire and a finally-guarded lock_release."""
        release = _make_lock_release(lock_name, self.filename, stmt.lineno)
        guarded = ast.Try(body=stmt.body, handlers=[], orelse=[], finalbody=[release])
        ast.copy_location(guarded, stmt)
        stmt.body = [
            _make_lock_acquire(lock_name, self.filename, stmt.lineno),
            guarded,
        ]

    def _transform_list(self, stmts, scopes, func_id):
        """Return *stmts* with instrumentation inserted; unchanged outside functions."""
        if func_id is None or func_id not in scopes:
            return stmts

        info = scopes[func_id]
        shared = info["globals"] | info["nonlocals"]

        new: list[ast.stmt] = []
        for stmt in stmts:
            if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                # Their bodies are handled by _inject's recursion.
                new.append(stmt)
                continue

            if isinstance(stmt, (ast.Assign, ast.Delete)):
                for name in self._shared_target_names(stmt.targets, shared):
                    new.append(_make_write_before(name, self.filename, stmt.lineno))
                new.append(stmt)

            elif isinstance(stmt, ast.AugAssign):
                if isinstance(stmt.target, ast.Name) and stmt.target.id in shared:
                    new.append(
                        _make_write_before(stmt.target.id, self.filename, stmt.lineno)
                    )
                new.append(stmt)

            elif isinstance(stmt, ast.With):
                lock_name = _resolve_lock_name(stmt)
                if lock_name:
                    self._wrap_lock_region(stmt, lock_name)
                new.append(stmt)

            else:
                new.append(stmt)

        return new
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _make_write_before(var_name: str, filename: str, lineno: int) -> ast.Expr:
|
|
116
|
+
return ast.Expr(
|
|
117
|
+
value=ast.Call(
|
|
118
|
+
func=ast.Attribute(
|
|
119
|
+
value=ast.Name(id="_threadcheck_tracker", ctx=ast.Load()),
|
|
120
|
+
attr="write_before",
|
|
121
|
+
ctx=ast.Load(),
|
|
122
|
+
),
|
|
123
|
+
args=[
|
|
124
|
+
ast.Constant(value=var_name),
|
|
125
|
+
ast.Constant(value=filename),
|
|
126
|
+
ast.Constant(value=lineno),
|
|
127
|
+
],
|
|
128
|
+
keywords=[],
|
|
129
|
+
),
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _make_lock_acquire(lock_name: str, filename: str, lineno: int) -> ast.Expr:
|
|
134
|
+
return ast.Expr(
|
|
135
|
+
value=ast.Call(
|
|
136
|
+
func=ast.Attribute(
|
|
137
|
+
value=ast.Name(id="_threadcheck_tracker", ctx=ast.Load()),
|
|
138
|
+
attr="lock_acquire",
|
|
139
|
+
ctx=ast.Load(),
|
|
140
|
+
),
|
|
141
|
+
args=[
|
|
142
|
+
ast.Constant(value=lock_name),
|
|
143
|
+
ast.Constant(value=filename),
|
|
144
|
+
ast.Constant(value=lineno),
|
|
145
|
+
],
|
|
146
|
+
keywords=[],
|
|
147
|
+
),
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _make_lock_release(lock_name: str, filename: str, lineno: int) -> ast.Expr:
|
|
152
|
+
return ast.Expr(
|
|
153
|
+
value=ast.Call(
|
|
154
|
+
func=ast.Attribute(
|
|
155
|
+
value=ast.Name(id="_threadcheck_tracker", ctx=ast.Load()),
|
|
156
|
+
attr="lock_release",
|
|
157
|
+
ctx=ast.Load(),
|
|
158
|
+
),
|
|
159
|
+
args=[
|
|
160
|
+
ast.Constant(value=lock_name),
|
|
161
|
+
ast.Constant(value=filename),
|
|
162
|
+
ast.Constant(value=lineno),
|
|
163
|
+
],
|
|
164
|
+
keywords=[],
|
|
165
|
+
),
|
|
166
|
+
)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _resolve_lock_name(with_stmt: ast.With) -> str | None:
|
|
170
|
+
for item in with_stmt.items:
|
|
171
|
+
expr = item.context_expr
|
|
172
|
+
if isinstance(expr, ast.Name):
|
|
173
|
+
return expr.id
|
|
174
|
+
if isinstance(expr, ast.Call):
|
|
175
|
+
if isinstance(expr.func, ast.Name) and expr.func.id in _LOCK_NAMES:
|
|
176
|
+
return ast.unparse(expr)
|
|
177
|
+
if isinstance(expr.func, ast.Attribute) and expr.func.attr in _LOCK_NAMES:
|
|
178
|
+
return ast.unparse(expr)
|
|
179
|
+
return None
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def transform_source(source: str, filename: str = "<unknown>") -> str:
    """Return *source* rewritten with threadcheck instrumentation, as Python text.

    Raises SyntaxError if *source* does not parse.
    """
    instrumented = TrackInjector(filename=filename).transform(
        ast.parse(source, filename=filename)
    )
    return ast.unparse(instrumented)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def transform_and_compile(source: str, filename: str = "<unknown>") -> types.CodeType:
    """Instrument *source* and compile it in ``"exec"`` mode.

    Returns the code object produced by ``compile``, not source text. The
    injected tracker import binds ``_threadcheck_tracker``, so the code can be
    passed straight to ``exec``. Missing locations on inserted nodes are
    filled in before compiling.

    Raises SyntaxError if *source* does not parse.
    """
    tree = ast.parse(source, filename=filename)
    TrackInjector(filename=filename).transform(tree)
    ast.fix_missing_locations(tree)
    return compile(tree, filename, "exec")
|