threadcheck 0.0.1.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,177 @@
1
+ import threading
2
+ from collections import Counter, defaultdict
3
+ from dataclasses import dataclass, field
4
+
5
+ from .clock import VectorClock
6
+ from .._tid import current_tid as _current_tid
7
+
8
+
9
@dataclass
class AccessRecord:
    """One observed access to a tracked variable.

    ``ThreadCheckTracker`` creates a record for every reported read or write
    and later compares records pairwise to find races.
    """

    # Name of the variable that was accessed.
    var_name: str
    # Kind of access; the tracker in this module records "write" or "read".
    operation: str
    # Identifier of the accessing thread.
    thread_id: int
    # Vector-clock snapshot taken when the access was reported.
    clock: VectorClock = field(default_factory=VectorClock)
    # ``(file, line)`` of the instrumented statement; ``("", 0)`` by default.
    location: tuple = ("", 0)
16
+
17
+
18
class ThreadCheckTracker:
    """Process-wide log of variable accesses and lock events.

    All state lives on the class and every method is a classmethod, so
    instrumented code can report events without holding an instance.  Each
    thread gets its own vector clock; ``lock_release`` stores a copy of the
    releasing thread's clock under the lock's name and ``lock_acquire`` merges
    it into the acquiring thread's clock.  Two accesses to the same variable
    are reported as a race when they come from different threads, at least one
    of them is a write, and their clocks ``conflicts_with`` each other.
    """

    # Guards the three dictionaries below.
    _lock = threading.Lock()
    # var_name -> access records in the order they were reported.
    _access_log: dict[str, list[AccessRecord]] = defaultdict(list)
    # thread id -> that thread's live (mutable) vector clock.
    _thread_clocks: dict[int, VectorClock] = {}
    # lock name -> clock snapshot taken at the most recent release.
    _lock_clocks: dict[str, VectorClock] = {}
    # Events are ignored while this is False.
    _active = False

    @classmethod
    def start(cls):
        """Begin recording events."""
        cls._active = True

    @classmethod
    def stop(cls):
        """Stop recording events; already recorded data is kept."""
        cls._active = False

    @classmethod
    def _get_clock(cls) -> tuple[VectorClock, int]:
        """Return ``(clock, tid)`` for the calling thread, creating the clock on first use."""
        tid = _current_tid()
        # Keep the clock in a local: reset()/reset_logs() may clear
        # _thread_clocks at any moment, so indexing the dict again after the
        # membership test could raise KeyError.
        clock = cls._thread_clocks.get(tid)
        if clock is None:
            with cls._lock:
                # Look again now that we hold the lock.
                clock = cls._thread_clocks.get(tid)
                if clock is None:
                    clock = VectorClock()
                    cls._thread_clocks[tid] = clock
        return clock, tid

    @classmethod
    def _record_access(cls, operation: str, var_name: str, file: str, line: int):
        """Tick the calling thread's clock and log one access of kind *operation*."""
        if not cls._active:
            return
        clock, tid = cls._get_clock()
        clock.tick(tid)
        record = AccessRecord(
            var_name=var_name,
            operation=operation,
            thread_id=tid,
            clock=clock.copy(),
            location=(file, line),
        )
        with cls._lock:
            cls._access_log[var_name].append(record)

    @classmethod
    def write_before(cls, var_name: str, file: str = "", line: int = 0):
        """Report that the calling thread is about to write *var_name*."""
        cls._record_access("write", var_name, file, line)

    @classmethod
    def read_before(cls, var_name: str, file: str = "", line: int = 0):
        """Report that the calling thread is about to read *var_name*."""
        cls._record_access("read", var_name, file, line)

    @classmethod
    def lock_acquire(cls, lock_name: str, file: str = "", line: int = 0):
        """Merge the clock stored at the last release of *lock_name*, then tick.

        *file* and *line* are accepted for signature parity with the other
        event hooks and are not used.
        """
        if not cls._active:
            return
        clock, tid = cls._get_clock()
        with cls._lock:
            released = cls._lock_clocks.get(lock_name)
            if released is not None:
                clock.merge(released)
        clock.tick(tid)

    @classmethod
    def lock_release(cls, lock_name: str, file: str = "", line: int = 0):
        """Store a copy of the calling thread's clock under *lock_name*.

        *file* and *line* are accepted for signature parity with the other
        event hooks and are not used.
        """
        if not cls._active:
            return
        clock, _tid = cls._get_clock()
        with cls._lock:
            cls._lock_clocks[lock_name] = clock.copy()

    @classmethod
    def _clear_locked(cls):
        """Drop all recorded state; the caller must hold ``cls._lock``."""
        cls._access_log.clear()
        cls._thread_clocks.clear()
        cls._lock_clocks.clear()

    @classmethod
    def reset(cls):
        """Drop all recorded state and stop recording."""
        with cls._lock:
            cls._clear_locked()
            cls._active = False

    @classmethod
    def reset_logs(cls):
        """Drop all recorded state (accesses and clocks) but keep recording."""
        with cls._lock:
            cls._clear_locked()

    @classmethod
    def _race_key(cls, r1: AccessRecord, r2: AccessRecord) -> tuple:
        """Return an order-independent key identifying a racing pair."""
        tid1, tid2 = sorted([r1.thread_id, r2.thread_id])
        loc1, loc2 = sorted([r1.location, r2.location])
        return (r1.var_name, tid1, tid2, loc1, loc2)

    @classmethod
    def _conflicting_pairs(cls) -> list[tuple[str, AccessRecord, AccessRecord]]:
        """Return every ``(var_name, earlier, later)`` pair that races.

        The whole scan runs under the lock so the result describes one
        consistent state of the access log.
        """
        pairs: list[tuple[str, AccessRecord, AccessRecord]] = []
        with cls._lock:
            for var_name, records in cls._access_log.items():
                for i, r1 in enumerate(records):
                    for r2 in records[i + 1 :]:
                        if r1.thread_id == r2.thread_id:
                            continue
                        if r1.operation != "write" and r2.operation != "write":
                            continue
                        if r1.clock.conflicts_with(r2.clock):
                            pairs.append((var_name, r1, r2))
        return pairs

    @classmethod
    def _unique_races(cls, raw) -> list[tuple[str, AccessRecord, AccessRecord]]:
        """Keep the first pair seen for each ``_race_key``, preserving order."""
        seen: set[tuple] = set()
        races: list[tuple[str, AccessRecord, AccessRecord]] = []
        for entry in raw:
            _, r1, r2 = entry
            key = cls._race_key(r1, r2)
            if key not in seen:
                seen.add(key)
                races.append(entry)
        return races

    @classmethod
    def detect_races(cls) -> list[tuple[str, AccessRecord, AccessRecord]]:
        """Return one ``(var_name, record_a, record_b)`` tuple per distinct race."""
        return cls._unique_races(cls._conflicting_pairs())

    @classmethod
    def format_races(cls) -> str:
        """Return a plain-text report of the races currently in the log.

        The unique races and the per-race overlap counts are derived from a
        single scan of the log, so they always agree with each other even if
        other threads keep reporting accesses while the report is built.
        """
        raw = cls._conflicting_pairs()
        races = cls._unique_races(raw)
        if not races:
            return "No data races detected"

        overlap = Counter(cls._race_key(r1, r2) for _, r1, r2 in raw)

        lines = ["Data races detected:", ""]
        for var_name, r1, r2 in races:
            f1, l1 = r1.location
            f2, l2 = r2.location
            count = overlap.get(cls._race_key(r1, r2), 0)
            lines.append(f" [!] `{var_name}`")
            lines.append(
                f" Thread-{r1.thread_id} ({r1.operation})"
                f" at {f1}:{l1}"
            )
            lines.append(
                f" Thread-{r2.thread_id} ({r2.operation})"
                f" at {f2}:{l2}"
            )
            if count > 1:
                lines.append(f" ({count} overlapping accesses)")
            lines.append("")

        total_unique = len(races)
        total_overlap = sum(overlap.values())
        lines.append(
            f"Summary: {total_unique} unique race pair(s), "
            f"{total_overlap} total overlapping access(es)"
        )
        return "\n".join(lines)
@@ -0,0 +1,192 @@
1
import ast
import types
2
+
3
+
4
# Callable names that mark a ``with <call>(...):`` context expression as a lock.
_LOCK_NAMES = frozenset(("Lock", "RLock", "Semaphore", "BoundedSemaphore"))


# Import statement, parsed once, that TrackInjector adds to instrumented
# modules so the injected ``_threadcheck_tracker.*`` calls resolve.
(_TRACKER_IMPORT,) = ast.parse(
    "from threadcheck.dynamic.tracker import ThreadCheckTracker as _threadcheck_tracker"
).body
10
+
11
+
12
class TrackInjector:
    """Rewrite a module AST so shared-variable writes and lock scopes are reported.

    Inside every function, assignments / augmented assignments / deletions of
    names declared ``global`` or ``nonlocal`` get a preceding
    ``_threadcheck_tracker.write_before(...)`` call, and ``with`` blocks whose
    context expression resolves to a lock name are bracketed with
    ``lock_acquire`` / ``lock_release`` calls.  Module-level code is left
    untouched apart from the added tracker import.
    """

    def __init__(self, filename: str = "<unknown>"):
        # Recorded in every injected call so reports can point at the source.
        self.filename = filename

    def transform(self, tree: ast.Module) -> ast.Module:
        """Instrument *tree* in place and return it."""
        # The import must not displace the module docstring or precede
        # ``from __future__`` imports: the latter have to stay at the top of
        # the module or compilation fails with a SyntaxError.
        tree.body.insert(self._prelude_length(tree.body), _TRACKER_IMPORT)
        scopes = {}
        self._collect_scopes(tree, scopes)
        self._inject(tree, scopes)
        return tree

    @staticmethod
    def _prelude_length(body) -> int:
        """Return how many leading statements are the docstring or ``__future__`` imports."""
        index = 0
        if (
            body
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        ):
            index = 1
        while index < len(body):
            stmt = body[index]
            is_future = (
                isinstance(stmt, ast.ImportFrom)
                and stmt.module == "__future__"
                and stmt.level == 0
            )
            if not is_future:
                break
            index += 1
        return index

    def _collect_scopes(self, node, scopes):
        """Fill *scopes* with ``id(function) -> {"globals", "nonlocals"}`` name sets.

        NOTE(review): ``ast.walk`` also visits nested functions, so names
        declared ``global``/``nonlocal`` in an inner function are attributed
        to every enclosing function as well.  Kept as-is; confirm this is the
        intended heuristic.
        """
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            declared = {"globals": set(), "nonlocals": set()}
            for inner in ast.walk(node):
                if isinstance(inner, ast.Global):
                    declared["globals"].update(inner.names)
                elif isinstance(inner, ast.Nonlocal):
                    declared["nonlocals"].update(inner.names)
            scopes[id(node)] = declared
        for child in ast.iter_child_nodes(node):
            self._collect_scopes(child, scopes)

    def _inject(self, node, scopes, func_id=None):
        """Recursively instrument every statement list owned by *node*.

        *func_id* is the ``id()`` of the innermost enclosing function, or
        ``None`` while still at module/class level.
        """
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            func_id = id(node)

        # ``except`` handler bodies are deliberately not handled here: each
        # ExceptHandler is visited by the recursion below and owns a ``body``
        # field itself, so treating it here as well would instrument it twice.
        for field in ("body", "orelse", "finalbody"):
            old = getattr(node, field, None)
            if isinstance(old, list):
                setattr(node, field, self._transform_list(old, scopes, func_id))

        for child in ast.iter_child_nodes(node):
            self._inject(child, scopes, func_id)

    def _transform_list(self, stmts, scopes, func_id):
        """Return *stmts* with tracker calls inserted; unchanged outside functions."""
        if func_id is None or func_id not in scopes:
            return stmts

        info = scopes[func_id]
        shared = info["globals"] | info["nonlocals"]

        new: list[ast.stmt] = []
        for stmt in stmts:
            if isinstance(stmt, (ast.Assign, ast.Delete)):
                rebound = stmt.targets
            elif isinstance(stmt, ast.AugAssign):
                rebound = [stmt.target]
            else:
                rebound = []

            # Only plain names are tracked; attribute/subscript/tuple targets
            # are passed through without a write_before call.
            for target in rebound:
                if isinstance(target, ast.Name) and target.id in shared:
                    new.append(
                        _make_write_before(target.id, self.filename, stmt.lineno)
                    )

            if isinstance(stmt, ast.With):
                lock_name = _resolve_lock_name(stmt)
                if lock_name:
                    self._wrap_lock_body(stmt, lock_name)

            new.append(stmt)

        return new

    def _wrap_lock_body(self, stmt, lock_name):
        """Bracket the body of ``with`` statement *stmt* with acquire/release calls.

        The release call sits in a ``finally`` clause so it is still reported
        when the body leaves early through ``return``, ``break``, ``continue``
        or an exception.
        """
        guarded = ast.Try(
            body=stmt.body,
            handlers=[],
            orelse=[],
            finalbody=[_make_lock_release(lock_name, self.filename, stmt.lineno)],
        )
        stmt.body = [
            _make_lock_acquire(lock_name, self.filename, stmt.lineno),
            guarded,
        ]
113
+
114
+
115
def _make_tracker_call(method: str, name: str, filename: str, lineno: int) -> ast.Expr:
    """Build the statement ``_threadcheck_tracker.<method>(name, filename, lineno)``.

    The returned node carries no source positions; callers that compile the
    tree must run ``ast.fix_missing_locations`` first.
    """
    tracker = ast.Name(id="_threadcheck_tracker", ctx=ast.Load())
    call = ast.Call(
        func=ast.Attribute(value=tracker, attr=method, ctx=ast.Load()),
        args=[
            ast.Constant(value=name),
            ast.Constant(value=filename),
            ast.Constant(value=lineno),
        ],
        keywords=[],
    )
    return ast.Expr(value=call)


def _make_write_before(var_name: str, filename: str, lineno: int) -> ast.Expr:
    """Return a ``write_before(var_name, filename, lineno)`` call statement."""
    return _make_tracker_call("write_before", var_name, filename, lineno)


def _make_lock_acquire(lock_name: str, filename: str, lineno: int) -> ast.Expr:
    """Return a ``lock_acquire(lock_name, filename, lineno)`` call statement."""
    return _make_tracker_call("lock_acquire", lock_name, filename, lineno)


def _make_lock_release(lock_name: str, filename: str, lineno: int) -> ast.Expr:
    """Return a ``lock_release(lock_name, filename, lineno)`` call statement."""
    return _make_tracker_call("lock_release", lock_name, filename, lineno)
167
+
168
+
169
+ def _resolve_lock_name(with_stmt: ast.With) -> str | None:
170
+ for item in with_stmt.items:
171
+ expr = item.context_expr
172
+ if isinstance(expr, ast.Name):
173
+ return expr.id
174
+ if isinstance(expr, ast.Call):
175
+ if isinstance(expr.func, ast.Name) and expr.func.id in _LOCK_NAMES:
176
+ return ast.unparse(expr)
177
+ if isinstance(expr.func, ast.Attribute) and expr.func.attr in _LOCK_NAMES:
178
+ return ast.unparse(expr)
179
+ return None
180
+
181
+
182
def transform_source(source: str, filename: str = "<unknown>") -> str:
    """Return *source* re-rendered as text with tracker calls injected.

    *filename* is used both for parse errors and as the file reported by the
    injected calls.
    """
    module = ast.parse(source, filename=filename)
    injector = TrackInjector(filename=filename)
    injector.transform(module)  # instruments ``module`` in place
    return ast.unparse(module)
186
+
187
+
188
def transform_and_compile(source: str, filename: str = "<unknown>") -> types.CodeType:
    """Instrument *source* and compile it to a module code object.

    The result is what ``compile(..., "exec")`` returns, i.e. a code object
    ready for ``exec`` -- not source text.
    """
    tree = ast.parse(source, filename=filename)
    TrackInjector(filename=filename).transform(tree)
    # Injected nodes are built without line/column information, which
    # compile() requires.
    ast.fix_missing_locations(tree)
    return compile(tree, filename, "exec")
@@ -0,0 +1,60 @@
1
+ import pytest
2
+
3
+ from .dynamic.hook import install_hook, uninstall_hook
4
+ from .dynamic.tracker import ThreadCheckTracker
5
+
6
# Value returned by install_hook() while --threadcheck is active; None otherwise.
_hook_instance = None
7
+
8
+
9
def pytest_addoption(parser):
    """Register the ``--threadcheck`` flag (off by default) in its own option group."""
    parser.getgroup("threadcheck").addoption(
        "--threadcheck",
        action="store_true",
        default=False,
        help="Run dynamic race detection via AST instrumentation",
    )
17
+
18
+
19
def pytest_configure(config):
    """Install the instrumentation hook and start the tracker when ``--threadcheck`` is given."""
    global _hook_instance
    if not config.getoption("--threadcheck"):
        return

    # Only code under the pytest root directory is passed as an include path.
    _hook_instance = install_hook(include_paths=[config.rootpath])
    ThreadCheckTracker.start()
    print(f"[threadcheck] hook installed, rootpath={config.rootpath}", flush=True)
    print(f"[threadcheck] include_paths={[str(p) for p in _hook_instance._include_paths]}", flush=True)
27
+
28
+
29
def pytest_unconfigure(config):
    """Stop the tracker and remove the hook installed by ``pytest_configure``."""
    global _hook_instance
    if not config.getoption("--threadcheck"):
        return

    ThreadCheckTracker.stop()
    installed = _hook_instance
    if installed is not None:
        uninstall_hook(installed)
        _hook_instance = None
36
+
37
+
38
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item):
    """Run one test with a clean tracker and stash a race report on *item*.

    The report is stored as ``item._threadcheck_race_report`` and is picked
    up by ``pytest_runtest_makereport``.
    """
    enabled = item.config.getoption("--threadcheck")
    if enabled:
        # Start from an empty log so races are attributed to this test only.
        ThreadCheckTracker.reset_logs()

    yield

    if not enabled:
        return
    if ThreadCheckTracker.detect_races():
        item._threadcheck_race_report = ThreadCheckTracker.format_races()
    ThreadCheckTracker.reset_logs()
48
+
49
+
50
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
    """Turn an otherwise passing test into a failure when a race report exists.

    Only the "call" phase is considered, and only when the test body itself
    raised nothing.  ``call.excinfo`` is replaced before yielding so the
    report built by the remaining hook implementations sees the failure.
    """
    report = None
    if call.when == "call" and not call.excinfo:
        report = getattr(item, "_threadcheck_race_report", None)

    if report:
        from _pytest._code import ExceptionInfo

        # Raise and immediately catch so ExceptionInfo can capture a real,
        # currently-handled exception.
        try:
            raise pytest.fail.Exception(report)
        except pytest.fail.Exception:
            call.excinfo = ExceptionInfo.from_current()

    yield
@@ -0,0 +1,10 @@
1
+ from .formatter import format_report, format_dynamic_races, format_warnings_json, format_dynamic_json
2
+ from .sarif import format_sarif
3
+
4
# Names re-exported as this package's public formatting API.
__all__ = [
    "format_report",
    "format_dynamic_races",
    "format_warnings_json",
    "format_dynamic_json",
    "format_sarif",
]
@@ -0,0 +1,174 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import sys
6
+ from collections import Counter
7
+ from typing import Any
8
+
9
+ from ..static.models import RaceWarning, Severity, Confidence
10
+
11
+
12
def _use_color() -> bool:
    """Return True when ANSI colours should be emitted on standard output.

    Colour is enabled only when ``sys.stdout`` is a terminal and ``TERM``
    does not contain "dumb" (case-insensitive).
    """
    # sys.stdout can be None (for example under pythonw on Windows) or may be
    # replaced by an object without isatty(); this function runs at import
    # time, so it must not raise in either case.
    isatty = getattr(sys.stdout, "isatty", None)
    if isatty is None or not isatty():
        return False
    return "dumb" not in os.environ.get("TERM", "").lower()
19
+
20
+
21
# Decided once, at import time.
_COLOR = _use_color()

# Style name -> ANSI escape sequence; left empty when colour is disabled so
# that ``_s`` degrades to plain text.
_STYLES: dict[str, str] = (
    {
        "reset": "\033[0m",
        "bold": "\033[1m",
        "red": "\033[91m",
        "green": "\033[92m",
        "yellow": "\033[93m",
        "blue": "\033[94m",
        "magenta": "\033[95m",
        "cyan": "\033[96m",
        "dim": "\033[2m",
    }
    if _COLOR
    else {}
)
34
+
35
+
36
def _s(name: str, text: str = "") -> str:
    """Wrap *text* in the escape code for style *name* followed by a reset.

    Style names missing from ``_STYLES`` contribute an empty prefix, and the
    reset suffix is empty as well when ``_STYLES`` has no "reset" entry.
    """
    prefix = _STYLES.get(name, "")
    suffix = _STYLES.get("reset", "")
    return "{}{}{}".format(prefix, text, suffix)
40
+
41
+
42
# Severity -> style name understood by ``_s``.
_SEVERITY_COLOR = {
    Severity.ERROR: "red",
    Severity.WARNING: "yellow",
    Severity.INFO: "cyan",
}

# Confidence -> short tag printed in text reports.
_CONFIDENCE_TAG = {
    Confidence.HIGH: "HIGH",
    Confidence.MEDIUM: "MED",
    Confidence.LOW: "LOW",
}

# Severity -> label printed in text reports.
_SEVERITY_TAG = {
    Severity.ERROR: "ERROR",
    Severity.WARNING: "WARNING",
    Severity.INFO: "INFO",
}
59
+
60
+
61
def format_report(warnings: list[RaceWarning]) -> str:
    """Render static-analysis warnings as a text report.

    Each warning produces a header line (severity, confidence, category and
    ``file:line:col``), its message, an optional suggestion and a blank line.
    The report ends with a separator and a per-severity total.  An empty
    *warnings* list yields a single "nothing found" line.
    """
    if not warnings:
        all_clear = "No data-race issues detected"
        return _s("green", all_clear) if _COLOR else all_clear

    def tally(level) -> int:
        # Number of warnings whose severity compares equal to ``level``.
        return sum(1 for w in warnings if w.severity == level)

    lines: list[str] = []
    for w in warnings:
        colour = _SEVERITY_COLOR.get(w.severity, "")
        severity_tag = _s(colour, _SEVERITY_TAG.get(w.severity, '?'))
        confidence_tag = _s('bold', _CONFIDENCE_TAG.get(w.confidence, ''))
        lines.append(
            f" {severity_tag} "
            f"{confidence_tag} "
            f"[{w.category.value}] {w.file}:{w.line}:{w.col}"
        )
        lines.append(f" {w.message}")
        if w.suggestion:
            lines.append(f" {_s('dim', 'suggestion:')} {w.suggestion}")
        lines.append("")

    lines.append(_s('dim', '---'))
    total = len(warnings)
    errors = tally(Severity.ERROR)
    warns = tally(Severity.WARNING)
    infos = tally(Severity.INFO)
    lines.append(f"Total: {total} issue(s) ({errors} error(s), {warns} warning(s), {infos} info(s))")
    return "\n".join(lines)
85
+
86
+
87
def format_dynamic_races(
    races: list[tuple[str, Any, Any]],
    access_log: dict[str, list] | None = None,
) -> str:
    """Render dynamically detected races as a text report.

    *races* holds ``(var_name, record_a, record_b)`` tuples.  When
    *access_log* (variable name -> access records) is supplied, it is scanned
    to count how many conflicting record pairs share each race's key; without
    it every count is 0.
    """
    if not races:
        all_clear = "No data races detected"
        return _s("green", all_clear) if _COLOR else all_clear

    overlap = Counter()
    if access_log:
        for records in access_log.values():
            for i, earlier in enumerate(records):
                for later in records[i + 1 :]:
                    if earlier.thread_id == later.thread_id:
                        continue
                    if earlier.operation != "write" and later.operation != "write":
                        continue
                    if earlier.clock.conflicts_with(later.clock):
                        overlap[_race_key(earlier, later)] += 1

    heading = "Data races detected:"
    lines: list[str] = [
        _s("red", _s("bold", heading)) if _COLOR else heading,
        "",
    ]

    for var_name, r1, r2 in races:
        f1, l1 = r1.location
        f2, l2 = r2.location
        count = overlap.get(_race_key(r1, r2), 0)
        marker = _s("red", " [!]") if _COLOR else " [!]"

        # NOTE(review): a "No happens-before" line always follows the two
        # thread lines, yet they switch between the branch and the end glyph
        # depending on ``count`` -- confirm whether that is intended.
        first_glyph = '├─' if count > 0 else '└─'
        second_glyph = '├─' if count > 1 else '└─'

        lines.append(f"{marker} {_s('bold', f'`{var_name}`')}")
        lines.append(
            f" {first_glyph} "
            f"Thread-{r1.thread_id} ({_s('magenta', r1.operation)}) "
            f"at {f1}:{l1}"
        )
        lines.append(
            f" {second_glyph} "
            f"Thread-{r2.thread_id} ({_s('magenta', r2.operation)}) "
            f"at {f2}:{l2}"
        )
        lines.append(" └─ No happens-before relationship between accesses")
        if count > 1:
            lines.append(f" ({count} overlapping accesses)")
        lines.append("")

    total_unique = len(races)
    total_overlap = sum(overlap.values())
    summary = f"Summary: {total_unique} unique race pair(s), {total_overlap} total overlapping access(es)"
    lines.append(_s("dim", summary) if _COLOR else summary)
    return "\n".join(lines)
140
+
141
+
142
def _race_key(r1, r2) -> tuple:
    """Return an order-independent key for a pair of access records.

    The key is ``(var_name, low_tid, high_tid, low_location, high_location)``,
    with the variable name taken from *r1*.
    """
    thread_ids = sorted((r1.thread_id, r2.thread_id))
    locations = sorted((r1.location, r2.location))
    return (r1.var_name, thread_ids[0], thread_ids[1], locations[0], locations[1])
146
+
147
+
148
def format_warnings_json(warnings: list[RaceWarning]) -> str:
    """Serialise *warnings* as a JSON array built from each item's ``to_dict()``.

    Output is indented by two spaces and non-ASCII characters are written
    as-is rather than escaped.
    """
    payload = [warning.to_dict() for warning in warnings]
    return json.dumps(payload, indent=2, ensure_ascii=False)
154
+
155
+
156
def format_dynamic_json(
    races: list[tuple[str, Any, Any]],
) -> str:
    """Serialise ``(var_name, record_a, record_b)`` race tuples as a JSON array.

    Each entry carries the variable name plus, for both records, the thread
    id, the operation and the location rendered as ``"file:line"``.
    """

    def describe(record) -> dict:
        # One side of a race in its JSON shape.
        return {
            "id": record.thread_id,
            "operation": record.operation,
            "location": f"{record.location[0]}:{record.location[1]}",
        }

    entries = [
        {
            "var_name": var_name,
            "thread_1": describe(r1),
            "thread_2": describe(r2),
        }
        for var_name, r1, r2 in races
    ]
    return json.dumps(entries, indent=2, ensure_ascii=False)