threadcheck 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,13 @@
1
+ """
2
+ threadcheck — Data-race detector for multi-threaded Python.
3
+
4
+ Supports both AST-based static analysis and runtime dynamic
5
+ detection via bytecode instrumentation.
6
+
7
+ Targets Python 3.14+ free-threading builds.
8
+ """
9
+
10
+ from ._version import __version__
11
+
12
+ from .static.analyzer import analyze_path, analyze_file
13
+ from .static.models import RaceWarning, Severity, WarningCategory
@@ -0,0 +1,3 @@
1
+ from .cli import main
2
+
3
+ main()
@@ -0,0 +1 @@
1
+ __version__ = "0.0.1"
threadcheck/cli.py ADDED
@@ -0,0 +1,89 @@
1
+ import argparse
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ from ._version import __version__
7
+ from .static.analyzer import analyze_path
8
+ from .reporting.formatter import format_report
9
+ from .reporting.sarif import format_sarif
10
+ from .dynamic.__main__ import run_script
11
+
12
+
13
+ def main():
14
+ parser = argparse.ArgumentParser(
15
+ prog="threadcheck",
16
+ description="Data Race Detector for Python",
17
+ )
18
+
19
+ parser.add_argument(
20
+ "--version", action="version", version=f"threadcheck {__version__}"
21
+ )
22
+
23
+ sub = parser.add_subparsers(dest="command", required=True)
24
+
25
+ scan = sub.add_parser("scan", help="Static analysis for data races")
26
+ scan.add_argument("path", help="File or directory to scan")
27
+ fmt = scan.add_mutually_exclusive_group()
28
+ fmt.add_argument("--json", action="store_true", help="Output in JSON format")
29
+ fmt.add_argument("--sarif", action="store_true", help="Output in SARIF v2.1.0 format")
30
+ scan.add_argument("-o", "--output", help="Write output to file (default: stdout)")
31
+
32
+ run = sub.add_parser("run", help="Dynamic race detection (Phase 3)")
33
+ run.add_argument("script", help="Python script to execute")
34
+
35
+ compat = sub.add_parser("check-compat", help="Check free-threading compatibility (Phase 7)")
36
+ compat.add_argument("path", nargs="?", default=".", help="Project path")
37
+
38
+ args = parser.parse_args()
39
+
40
+ if args.command == "scan":
41
+ _do_scan(args)
42
+ elif args.command == "run":
43
+ run_script(args.script)
44
+ elif args.command == "check-compat":
45
+ print("Not implemented: free-threading compatibility check (Phase 7)", file=sys.stderr)
46
+ sys.exit(1)
47
+
48
+
49
def _do_scan(args):
    """Run the static analyzer on ``args.path`` and emit the report.

    Args:
        args: Parsed ``scan`` arguments; reads ``path``, ``json``, ``sarif``
            and ``output``.

    The report is rendered as JSON, SARIF or plain text and handed to
    ``_write_output``.  A banner and a summary line are printed around it.
    When JSON/SARIF is written to stdout those status lines go to stderr
    instead, so that stdout stays a single parseable document.

    Exits with status 1 if the path does not exist.
    """
    path = Path(args.path).resolve()
    if not path.exists():
        print(f"Path does not exist: {path}", file=sys.stderr)
        sys.exit(1)

    # Mixing human-readable status lines into machine-readable stdout would
    # make `threadcheck scan --json . > report.json` produce invalid JSON.
    structured_on_stdout = (args.json or args.sarif) and not args.output
    status = sys.stderr if structured_on_stdout else sys.stdout

    print(f"threadcheck scan -- analysing {path}", file=status)
    print(file=status)

    warnings = analyze_path(str(path))

    # Single pass over the findings; other severity values are only counted
    # in the total.
    counts = {"error": 0, "warning": 0, "info": 0}
    for warning in warnings:
        level = warning.severity.value
        if level in counts:
            counts[level] += 1

    if args.json:
        rendered = json.dumps(
            [w.to_dict() for w in warnings], indent=2, ensure_ascii=False
        )
    elif args.sarif:
        rendered = format_sarif(warnings)
    else:
        rendered = format_report(warnings)
    _write_output(args.output, rendered)

    print(file=status)
    print(
        f"Total: {len(warnings)} issue(s) ({counts['error']} error(s), "
        f"{counts['warning']} warning(s), {counts['info']} info)",
        file=status,
    )
79
+
80
+
81
+ def _write_output(path_arg: str | None, content: str):
82
+ if path_arg:
83
+ Path(path_arg).write_text(content, encoding="utf-8")
84
+ else:
85
+ print(content)
86
+
87
+
88
+ if __name__ == "__main__":
89
+ main()
File without changes
@@ -0,0 +1,38 @@
1
+ import ast
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ from .tracker import ThreadCheckTracker
6
+ from .transform import TrackInjector
7
+
8
+
9
def run_script(script_path: str):
    """Instrument and execute a Python script, then print the race report.

    Args:
        script_path: Path of the script to run.

    The script is parsed, rewritten by ``TrackInjector``, compiled and
    executed with the tracker active.  A ``SystemExit`` raised by the script
    is swallowed so the report is still printed.  Afterwards the tracker is
    stopped, the report is printed to stdout and the tracker state is reset.

    Exits with status 1 if the path is not an existing file or the script
    has a syntax error.
    """
    path = Path(script_path).resolve()
    # is_file() also rejects directories, which would otherwise crash in
    # read_text() with an unhandled traceback.
    if not path.is_file():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)

    source = path.read_text(encoding="utf-8")
    filename = str(path)

    try:
        tree = ast.parse(source, filename=filename)
    except SyntaxError as e:
        print(f"Syntax error: {e}", file=sys.stderr)
        sys.exit(1)

    TrackInjector(filename=filename).transform(tree)
    ast.fix_missing_locations(tree)

    code = compile(tree, filename, "exec")

    script_globals = {
        # Without an explicit __name__ the lookup falls through to the
        # builtins module ("builtins"), so the usual
        # `if __name__ == "__main__":` guard in the script would never run.
        "__name__": "__main__",
        "__file__": filename,
        "_threadcheck_tracker": ThreadCheckTracker,
    }

    ThreadCheckTracker.start()
    try:
        exec(code, script_globals)
    except SystemExit:
        # The script asked to exit; still report what was observed.
        pass
    finally:
        ThreadCheckTracker.stop()

    print(ThreadCheckTracker.format_races())
    ThreadCheckTracker.reset()
@@ -0,0 +1,31 @@
1
+ from collections import defaultdict
2
+
3
+
4
class VectorClock:
    """Vector clock: one logical counter per thread id.

    Two clocks are ordered when one is component-wise less than or equal to
    the other; otherwise they are concurrent (see ``conflicts_with``).
    Missing components count as 0.
    """

    def __init__(self):
        self._clock: dict[int, int] = defaultdict(int)

    def tick(self, thread_id: int) -> int:
        """Advance the component of *thread_id* by one and return its new value."""
        advanced = self._clock[thread_id] + 1
        self._clock[thread_id] = advanced
        return advanced

    def merge(self, other: "VectorClock"):
        """Raise every component to at least the matching one in *other*."""
        mine = self._clock
        mine.update(
            {
                tid: max(mine.get(tid, 0), theirs)
                for tid, theirs in other._clock.items()
            }
        )

    def conflicts_with(self, other: "VectorClock") -> bool:
        """Return True when neither clock is ordered before the other."""
        return not self._leq(other) and not other._leq(self)

    def _leq(self, other: "VectorClock") -> bool:
        """Return True when every component of self is <= that of *other*."""
        theirs = other._clock
        return all(
            value <= theirs.get(tid, 0) for tid, value in self._clock.items()
        )

    def copy(self) -> "VectorClock":
        """Return an independent snapshot of this clock."""
        snapshot = VectorClock()
        snapshot._clock = self._clock.copy()
        return snapshot

    def __repr__(self) -> str:
        return f"VectorClock({dict(self._clock)})"
@@ -0,0 +1,97 @@
1
+ import sys
2
+ import ast
3
+ import builtins
4
+ import importlib.util
5
+ import importlib.abc
6
+ from pathlib import Path
7
+
8
+ from .transform import TrackInjector
9
+ from .tracker import ThreadCheckTracker
10
+
11
+
12
class ThreadCheckLoader(importlib.abc.Loader):
    """Import loader that instruments a module's source before running it."""

    def __init__(self, tracker=None):
        # Fall back to the shared tracker class when none (or a falsy one)
        # is supplied.
        self.tracker = tracker if tracker else ThreadCheckTracker

    def create_module(self, spec):
        """Return None so the import system creates the module object itself."""
        return None

    def exec_module(self, module):
        """Read, instrument, compile and execute the source of *module*.

        Raises:
            ImportError: If no source text could be obtained.
        """
        spec = module.__spec__
        text = self._get_source(spec, module)
        if text is None:
            raise ImportError(f"cannot load source for {spec.name}")

        origin = spec.origin
        module.__file__ = origin

        tree = ast.parse(text, filename=origin)
        TrackInjector(filename=str(origin)).transform(tree)
        ast.fix_missing_locations(tree)
        compiled = compile(tree, origin, "exec")

        namespace = module.__dict__
        # Expose the tracker both in the module namespace and as a builtin.
        namespace["_threadcheck_tracker"] = self.tracker
        builtins._threadcheck_tracker = self.tracker
        exec(compiled, namespace)

    @staticmethod
    def _get_source(spec, module=None):
        """Return the module's source text, or None if it cannot be obtained.

        Tries ``spec.origin`` and then ``module.__file__`` as UTF-8 ``.py``
        files; if neither can be read, falls back to
        ``spec.loader.get_source`` when the loader provides it.  Failures at
        any step are ignored and the next option is tried.
        """
        for location in (spec.origin, getattr(module, "__file__", None)):
            if not location or Path(location).suffix != ".py":
                continue
            try:
                return Path(location).read_text(encoding="utf-8")
            except Exception:
                continue

        delegate = spec.loader
        if hasattr(delegate, "get_source"):
            try:
                return delegate.get_source(spec.name)
            except Exception:
                pass
        return None
51
+
52
+
53
class ThreadCheckFinder(importlib.abc.MetaPathFinder):
    """Meta-path finder that routes plain ``.py`` modules to ThreadCheckLoader.

    Only single-file modules (``<name>.py``) are matched.  Anything else,
    including packages (``__init__.py``), yields ``None`` so that the
    remaining finders on ``sys.meta_path`` handle it.
    """

    def __init__(self, tracker=None, include_paths=None):
        """
        Args:
            tracker: Tracker handed to each loader; defaults to
                ``ThreadCheckTracker``.
            include_paths: Optional iterable of directories.  When given,
                only files located under one of them are instrumented; when
                empty or omitted, every matching file is.
        """
        self.tracker = tracker or ThreadCheckTracker
        self._include_paths = (
            [Path(p).resolve() for p in include_paths] if include_paths else []
        )

    def _should_instrument(self, filepath: Path) -> bool:
        """Return True if *filepath* passes the include-path filter."""
        if not self._include_paths:
            return True
        resolved = filepath.resolve()
        return any(_is_under(resolved, inc) for inc in self._include_paths)

    def find_spec(self, fullname, path, target=None):
        """Return a spec using ThreadCheckLoader, or None if not applicable.

        Args:
            fullname: Fully qualified module name.
            path: ``None`` for a top-level import, otherwise the parent
                package's ``__path__``.
            target: Unused.
        """
        # For a submodule the import system already passes the parent
        # package's directories in *path*, so only the last name component
        # identifies the file.  Joining the full dotted name would look for
        # "<pkg dir>/pkg/mod.py" and never find submodules.
        leaf = fullname.rpartition(".")[2]
        # Fall back to sys.path only for top-level imports; an empty parent
        # __path__ must not widen the search.
        search_path = sys.path if path is None else path

        for entry in search_path:
            if entry == "":
                entry = "."  # empty sys.path entry means the current directory
            candidate = Path(entry) / f"{leaf}.py"
            if candidate.exists() and self._should_instrument(candidate):
                return importlib.util.spec_from_file_location(
                    fullname,
                    str(candidate),
                    loader=ThreadCheckLoader(self.tracker),
                )
        return None
79
+
80
+
81
def install_hook(tracker=None, include_paths=None):
    """Create a ThreadCheckFinder and put it first on ``sys.meta_path``.

    Args:
        tracker: Forwarded to ``ThreadCheckFinder``.
        include_paths: Forwarded to ``ThreadCheckFinder``.

    Returns:
        The installed finder; pass it to ``uninstall_hook`` to remove it.
    """
    finder = ThreadCheckFinder(tracker, include_paths)
    # Front of the list, so it is consulted before the other finders.
    sys.meta_path[:0] = [finder]
    return finder
85
+
86
+
87
def uninstall_hook(hook):
    """Remove *hook* from ``sys.meta_path``; do nothing if it is not there."""
    try:
        sys.meta_path.remove(hook)
    except ValueError:
        # Not installed (or already removed).
        pass
90
+
91
+
92
+ def _is_under(child: Path, parent: Path) -> bool:
93
+ try:
94
+ child.relative_to(parent)
95
+ return True
96
+ except ValueError:
97
+ return False
@@ -0,0 +1,191 @@
1
+ import os
2
+ import sys
3
+ import threading
4
+ from collections import Counter, defaultdict
5
+ from dataclasses import dataclass, field
6
+
7
+ from .clock import VectorClock
8
+
9
+
10
@dataclass
class AccessRecord:
    """One instrumented access to a tracked variable."""

    # Name of the variable that was accessed.
    var_name: str
    # Kind of access; the tracker records "write" or "read".
    operation: str
    # threading.get_ident() of the accessing thread.
    thread_id: int
    # Snapshot of the accessing thread's vector clock at the access.
    clock: VectorClock = field(default_factory=VectorClock)
    # (file, line) of the instrumented statement.
    location: tuple[str, int] = ("", 0)
17
+
18
+
19
class ThreadCheckTracker:
    """Process-wide recorder of instrumented accesses with race detection.

    All state lives on the class and every entry point is a classmethod, so
    instrumented code calls e.g. ``ThreadCheckTracker.write_before(...)``
    without creating an instance.

    Each thread owns a vector clock that ticks on every recorded access.
    The only ordering edges recorded between threads are lock hand-offs:
    ``lock_release`` stores the releasing thread's clock under the lock name
    and ``lock_acquire`` merges that clock into the acquiring thread's.
    Two accesses to the same variable from different threads, at least one
    of them a write, whose clocks are unordered are reported as a race.
    """

    _lock = threading.Lock()
    _access_log: dict[str, list[AccessRecord]] = defaultdict(list)
    _thread_clocks: dict[int, VectorClock] = {}
    _lock_clocks: dict[str, VectorClock] = {}
    _active = False

    # Diagnostic tracing of the first few writes.  Off by default so the
    # tracker does not write debug lines into the instrumented program's
    # stderr; set ``ThreadCheckTracker._diag_enabled = True`` to turn it on.
    _diag_enabled = False
    _diag_count = 0
    _DIAG_LIMIT = 10

    @classmethod
    def start(cls):
        """Begin recording; tracker hooks are no-ops while inactive."""
        cls._active = True

    @classmethod
    def stop(cls):
        """Stop recording.  Already recorded data is kept."""
        cls._active = False

    @classmethod
    def _get_clock(cls) -> VectorClock:
        """Return the calling thread's clock, creating it on first use."""
        tid = threading.get_ident()
        if tid not in cls._thread_clocks:
            with cls._lock:
                # Re-check under the lock before inserting.
                if tid not in cls._thread_clocks:
                    cls._thread_clocks[tid] = VectorClock()
        return cls._thread_clocks[tid]

    @classmethod
    def _record(cls, operation: str, var_name: str, file: str, line: int) -> int:
        """Tick the caller's clock, log one access and return the thread id."""
        clock = cls._get_clock()
        tid = threading.get_ident()
        clock.tick(tid)
        record = AccessRecord(
            var_name=var_name,
            operation=operation,
            thread_id=tid,
            clock=clock.copy(),
            location=(file, line),
        )
        with cls._lock:
            cls._access_log[var_name].append(record)
        return tid

    @classmethod
    def _emit_diag(cls, tid: int, var_name: str):
        """Print one diagnostic line for a write, up to ``_DIAG_LIMIT`` lines."""
        with cls._lock:
            if cls._diag_count >= cls._DIAG_LIMIT:
                return
            cls._diag_count += 1
        # Print outside the lock so slow stderr does not stall other threads.
        ct = threading.current_thread()
        print(
            f"[TC_DIAG] write_before tid={tid} ct_name={ct.name} ct_ident={ct.ident} var={var_name}",
            file=sys.stderr, flush=True,
        )

    @classmethod
    def write_before(cls, var_name: str, file: str = "", line: int = 0):
        """Record a write to *var_name* at ``file:line`` by the calling thread."""
        if not cls._active:
            return
        tid = cls._record("write", var_name, file, line)
        if cls._diag_enabled:
            cls._emit_diag(tid, var_name)

    @classmethod
    def read_before(cls, var_name: str, file: str = "", line: int = 0):
        """Record a read of *var_name* at ``file:line`` by the calling thread."""
        if not cls._active:
            return
        cls._record("read", var_name, file, line)

    @classmethod
    def lock_acquire(cls, lock_name: str, file: str = "", line: int = 0):
        """Merge the clock stored at the last release of *lock_name*, then tick.

        *file* and *line* are accepted for symmetry with the other hooks and
        are not used.
        """
        if not cls._active:
            return
        tid = threading.get_ident()
        clock = cls._get_clock()
        with cls._lock:
            if lock_name in cls._lock_clocks:
                clock.merge(cls._lock_clocks[lock_name])
            clock.tick(tid)

    @classmethod
    def lock_release(cls, lock_name: str, file: str = "", line: int = 0):
        """Store a snapshot of the caller's clock under *lock_name*.

        *file* and *line* are accepted for symmetry with the other hooks and
        are not used.
        """
        if not cls._active:
            return
        clock = cls._get_clock()
        with cls._lock:
            cls._lock_clocks[lock_name] = clock.copy()

    @classmethod
    def _clear_locked(cls):
        """Drop all recorded data.  The caller must hold ``_lock``."""
        cls._access_log.clear()
        cls._thread_clocks.clear()
        cls._lock_clocks.clear()

    @classmethod
    def reset(cls):
        """Drop all recorded data and deactivate the tracker."""
        with cls._lock:
            cls._clear_locked()
            cls._active = False
            cls._diag_count = 0

    @classmethod
    def reset_logs(cls):
        """Drop all recorded data but leave the active flag untouched."""
        with cls._lock:
            cls._clear_locked()

    @classmethod
    def _race_key(cls, r1: AccessRecord, r2: AccessRecord) -> tuple:
        """Return an order-independent key identifying a racing pair.

        Pairs with the same variable, the same two threads and the same two
        source locations share a key regardless of argument order.
        """
        tid1, tid2 = sorted([r1.thread_id, r2.thread_id])
        loc1, loc2 = sorted([r1.location, r2.location])
        return (r1.var_name, tid1, tid2, loc1, loc2)

    @classmethod
    def _conflicting_pairs(cls):
        """Yield ``(var_name, r1, r2)`` for every racing pair in the log.

        The caller must hold ``_lock`` while consuming the generator.
        """
        for var_name, records in cls._access_log.items():
            for i, r1 in enumerate(records):
                for r2 in records[i + 1:]:
                    if r1.thread_id == r2.thread_id:
                        continue
                    if r1.operation != "write" and r2.operation != "write":
                        continue  # two reads never race
                    if r1.clock.conflicts_with(r2.clock):
                        yield var_name, r1, r2

    @classmethod
    def _scan(cls):
        """Scan the log once; return ``(unique_races, overlap_counts)``.

        ``unique_races`` keeps the first pair seen for each race key, in
        discovery order; ``overlap_counts`` maps each race key to the number
        of racing pairs that share it.
        """
        overlap = Counter()
        races: list[tuple[str, AccessRecord, AccessRecord]] = []
        with cls._lock:
            for entry in cls._conflicting_pairs():
                key = cls._race_key(entry[1], entry[2])
                if key not in overlap:
                    races.append(entry)
                overlap[key] += 1
        return races, overlap

    @classmethod
    def detect_races(cls) -> list[tuple[str, AccessRecord, AccessRecord]]:
        """Return one ``(var_name, r1, r2)`` entry per distinct race."""
        races, _ = cls._scan()
        return races

    @classmethod
    def format_races(cls) -> str:
        """Return a human-readable report of the detected races."""
        races, overlap = cls._scan()
        if not races:
            return "No data races detected"

        lines = ["Data races detected:", ""]
        for var_name, r1, r2 in races:
            f1, l1 = r1.location
            f2, l2 = r2.location
            count = overlap.get(cls._race_key(r1, r2), 0)
            lines.append(f"  [!] `{var_name}`")
            lines.append(
                f"      Thread-{r1.thread_id} ({r1.operation})"
                f" at {f1}:{l1}"
            )
            lines.append(
                f"      Thread-{r2.thread_id} ({r2.operation})"
                f" at {f2}:{l2}"
            )
            if count > 1:
                lines.append(f"      ({count} overlapping accesses)")
            lines.append("")

        total_unique = len(races)
        total_overlap = sum(overlap.values())
        lines.append(
            f"Summary: {total_unique} unique race pair(s), "
            f"{total_overlap} total overlapping access(es)"
        )
        return "\n".join(lines)
@@ -0,0 +1,192 @@
1
import ast
import types
2
+
3
+
4
+ _LOCK_NAMES = frozenset({"Lock", "RLock", "Semaphore", "BoundedSemaphore"})
5
+
6
+
7
+ _TRACKER_IMPORT = ast.parse(
8
+ "from threadcheck.dynamic.tracker import ThreadCheckTracker as _threadcheck_tracker"
9
+ ).body[0]
10
+
11
+
12
class TrackInjector:
    """AST rewriter that inserts tracker calls into function bodies.

    Inside every function that declares names ``global`` or ``nonlocal``:
      * a ``write_before`` call is placed before each assignment, augmented
        assignment or ``del`` whose target is one of those names;
      * the body of each ``with`` statement recognised by
        ``_resolve_lock_name`` is bracketed by ``lock_acquire`` and
        ``lock_release`` calls.
    Module-level statements are left untouched apart from the tracker import.
    """

    def __init__(self, filename: str = "<unknown>"):
        # Reported as the source location in the injected tracker calls.
        self.filename = filename

    def transform(self, tree: ast.Module) -> ast.Module:
        """Instrument *tree* in place and return it."""
        tree.body.insert(self._import_index(tree.body), _TRACKER_IMPORT)
        scopes = {}
        self._collect_scopes(tree, scopes)
        self._inject(tree, scopes)
        return tree

    @staticmethod
    def _import_index(body) -> int:
        """Return the index at which the tracker import can be inserted.

        The import goes after the module docstring and any
        ``from __future__`` imports: a future statement that is not at the
        start of the module is a SyntaxError, and a docstring that is not the
        first statement is no longer the module's ``__doc__``.
        """
        index = 0
        if (
            body
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        ):
            index = 1
        while (
            index < len(body)
            and isinstance(body[index], ast.ImportFrom)
            and body[index].module == "__future__"
            and body[index].level == 0
        ):
            index += 1
        return index

    def _collect_scopes(self, node, scopes):
        """Map ``id(function node)`` to the names it declares shared.

        The walk over a function includes nested functions, so names declared
        ``global``/``nonlocal`` by an inner function are also attributed to
        the enclosing ones.
        """
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            info = {"globals": set(), "nonlocals": set()}
            for child in ast.walk(node):
                if isinstance(child, ast.Global):
                    info["globals"].update(child.names)
                elif isinstance(child, ast.Nonlocal):
                    info["nonlocals"].update(child.names)
            scopes[id(node)] = info
        for child in ast.iter_child_nodes(node):
            self._collect_scopes(child, scopes)

    def _inject(self, node, scopes, func_id=None):
        """Rewrite every statement list under *node*, tracking the owner function."""
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            func_id = id(node)

        for field in ("body", "orelse", "finalbody"):
            old = getattr(node, field, None)
            # Expression nodes such as IfExp also have body/orelse, but not
            # as statement lists.
            if isinstance(old, list):
                setattr(node, field, self._transform_list(old, scopes, func_id))

        # Except handlers are ordinary child nodes with their own ``body``,
        # so the recursion below instruments them exactly once.  Handling
        # them here as well would inject every tracker call twice.
        for child in ast.iter_child_nodes(node):
            self._inject(child, scopes, func_id)

    def _transform_list(self, stmts, scopes, func_id):
        """Return *stmts* with tracker calls inserted for function *func_id*.

        Statements outside any collected function are returned unchanged.
        """
        if func_id is None or func_id not in scopes:
            return stmts

        info = scopes[func_id]
        shared = info["globals"] | info["nonlocals"]

        new: list[ast.stmt] = []
        for stmt in stmts:
            # Only plain-name targets are tracked; attribute, subscript and
            # tuple targets are ignored.
            if isinstance(stmt, (ast.Assign, ast.Delete)):
                written = [
                    t.id
                    for t in stmt.targets
                    if isinstance(t, ast.Name) and t.id in shared
                ]
            elif isinstance(stmt, ast.AugAssign):
                target = stmt.target
                hit = isinstance(target, ast.Name) and target.id in shared
                written = [target.id] if hit else []
            else:
                written = []

            for name in written:
                new.append(_make_write_before(name, self.filename, stmt.lineno))

            if isinstance(stmt, ast.With):
                self._instrument_with(stmt)
            new.append(stmt)

        return new

    def _instrument_with(self, stmt: ast.With):
        """Bracket the body of a lock ``with`` by acquire/release tracker calls."""
        lock_name = _resolve_lock_name(stmt)
        if not lock_name:
            return
        acquire = _make_lock_acquire(lock_name, self.filename, stmt.lineno)
        release = _make_lock_release(lock_name, self.filename, stmt.lineno)
        # try/finally so the release is recorded even when the body leaves
        # via return, break, continue or an exception; a call appended after
        # the body would be skipped in those cases.
        guarded = ast.Try(
            body=stmt.body, handlers=[], orelse=[], finalbody=[release]
        )
        ast.copy_location(guarded, stmt)
        stmt.body = [acquire, guarded]
113
+
114
+
115
+ def _make_write_before(var_name: str, filename: str, lineno: int) -> ast.Expr:
116
+ return ast.Expr(
117
+ value=ast.Call(
118
+ func=ast.Attribute(
119
+ value=ast.Name(id="_threadcheck_tracker", ctx=ast.Load()),
120
+ attr="write_before",
121
+ ctx=ast.Load(),
122
+ ),
123
+ args=[
124
+ ast.Constant(value=var_name),
125
+ ast.Constant(value=filename),
126
+ ast.Constant(value=lineno),
127
+ ],
128
+ keywords=[],
129
+ ),
130
+ )
131
+
132
+
133
+ def _make_lock_acquire(lock_name: str, filename: str, lineno: int) -> ast.Expr:
134
+ return ast.Expr(
135
+ value=ast.Call(
136
+ func=ast.Attribute(
137
+ value=ast.Name(id="_threadcheck_tracker", ctx=ast.Load()),
138
+ attr="lock_acquire",
139
+ ctx=ast.Load(),
140
+ ),
141
+ args=[
142
+ ast.Constant(value=lock_name),
143
+ ast.Constant(value=filename),
144
+ ast.Constant(value=lineno),
145
+ ],
146
+ keywords=[],
147
+ ),
148
+ )
149
+
150
+
151
+ def _make_lock_release(lock_name: str, filename: str, lineno: int) -> ast.Expr:
152
+ return ast.Expr(
153
+ value=ast.Call(
154
+ func=ast.Attribute(
155
+ value=ast.Name(id="_threadcheck_tracker", ctx=ast.Load()),
156
+ attr="lock_release",
157
+ ctx=ast.Load(),
158
+ ),
159
+ args=[
160
+ ast.Constant(value=lock_name),
161
+ ast.Constant(value=filename),
162
+ ast.Constant(value=lineno),
163
+ ],
164
+ keywords=[],
165
+ ),
166
+ )
167
+
168
+
169
+ def _resolve_lock_name(with_stmt: ast.With) -> str | None:
170
+ for item in with_stmt.items:
171
+ expr = item.context_expr
172
+ if isinstance(expr, ast.Name):
173
+ return expr.id
174
+ if isinstance(expr, ast.Call):
175
+ if isinstance(expr.func, ast.Name) and expr.func.id in _LOCK_NAMES:
176
+ return ast.unparse(expr)
177
+ if isinstance(expr.func, ast.Attribute) and expr.func.attr in _LOCK_NAMES:
178
+ return ast.unparse(expr)
179
+ return None
180
+
181
+
182
def transform_source(source: str, filename: str = "<unknown>") -> str:
    """Instrument *source* and return the rewritten program as source text.

    Args:
        source: Python source code to instrument.
        filename: Name used when parsing and in the injected tracker calls.
    """
    module = ast.parse(source, filename=filename)
    instrumented = TrackInjector(filename=filename).transform(module)
    return ast.unparse(instrumented)
186
+
187
+
188
def transform_and_compile(source: str, filename: str = "<unknown>") -> types.CodeType:
    """Instrument *source* and compile it to a module code object.

    Args:
        source: Python source code to instrument.
        filename: Name used when parsing and compiling, and in the injected
            tracker calls.

    Returns:
        The code object produced by ``compile(..., "exec")``, ready to be
        passed to ``exec``.
    """
    tree = ast.parse(source, filename=filename)
    TrackInjector(filename=filename).transform(tree)
    # Injected nodes have no line/column information, which compile() requires.
    ast.fix_missing_locations(tree)
    return compile(tree, filename, "exec")