threadcheck 0.0.1.1__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- threadcheck/__init__.py +14 -0
- threadcheck/__main__.py +3 -0
- threadcheck/_tid.py +5 -0
- threadcheck/_version.py +1 -0
- threadcheck/cli.py +213 -0
- threadcheck/compat/__init__.py +4 -0
- threadcheck/compat/checker.py +95 -0
- threadcheck/compat/models.py +40 -0
- threadcheck/dynamic/__init__.py +0 -0
- threadcheck/dynamic/__main__.py +47 -0
- threadcheck/dynamic/clock.py +31 -0
- threadcheck/dynamic/hook.py +97 -0
- threadcheck/dynamic/tracker.py +177 -0
- threadcheck/dynamic/transform.py +192 -0
- threadcheck/pytest_plugin.py +60 -0
- threadcheck/reporting/__init__.py +10 -0
- threadcheck/reporting/formatter.py +174 -0
- threadcheck/reporting/sarif.py +100 -0
- threadcheck/reporting/types.py +8 -0
- threadcheck/static/__init__.py +0 -0
- threadcheck/static/analyzer.py +104 -0
- threadcheck/static/lock_tracker.py +42 -0
- threadcheck/static/models.py +48 -0
- threadcheck/static/visitors.py +324 -0
- threadcheck-0.0.1.1.dist-info/METADATA +248 -0
- threadcheck-0.0.1.1.dist-info/RECORD +29 -0
- threadcheck-0.0.1.1.dist-info/WHEEL +4 -0
- threadcheck-0.0.1.1.dist-info/entry_points.txt +5 -0
- threadcheck-0.0.1.1.dist-info/licenses/LICENSE +21 -0
threadcheck/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""
|
|
2
|
+
threadcheck — Data-race detector for multi-threaded Python.
|
|
3
|
+
|
|
4
|
+
Supports both AST-based static analysis and runtime dynamic
|
|
5
|
+
detection via bytecode instrumentation.
|
|
6
|
+
|
|
7
|
+
Targets Python 3.14+ free-threading builds.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from ._version import __version__
|
|
11
|
+
|
|
12
|
+
from .static.analyzer import analyze_path, analyze_file
|
|
13
|
+
from .static.models import RaceWarning, Severity, WarningCategory
|
|
14
|
+
from .compat import check_compat, FTCompatResult, CompatStatus
|
threadcheck/__main__.py
ADDED
threadcheck/_tid.py
ADDED
threadcheck/_version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Package version string; imported by the package ``__init__`` and by the CLI's --version flag.
__version__ = "0.0.1.1"
|
threadcheck/cli.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json as _json
|
|
3
|
+
import re
|
|
4
|
+
import sys
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from ._version import __version__
|
|
8
|
+
from .static.analyzer import analyze_path
|
|
9
|
+
from .reporting.formatter import format_report, format_warnings_json
|
|
10
|
+
from .reporting.sarif import format_sarif
|
|
11
|
+
from .dynamic.__main__ import run_script
|
|
12
|
+
from .compat import check_compat
|
|
13
|
+
from .compat.models import CompatStatus
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def main():
    """Parse command-line arguments and dispatch to the selected subcommand.

    Subcommands:
        scan    -- static analysis of a file or directory.
        run     -- dynamic detection while executing a script.
        compat  -- free-threading compatibility check of dependencies.
    """
    json_help = "Output in JSON format"
    output_help = "Write output to file (default: stdout)"

    parser = argparse.ArgumentParser(
        prog="threadcheck",
        description="Data Race Detector for Python",
    )
    parser.add_argument(
        "--version", action="version", version=f"threadcheck {__version__}"
    )
    commands = parser.add_subparsers(dest="command", required=True)

    # scan: --json and --sarif are mutually exclusive report formats.
    scan_cmd = commands.add_parser("scan", help="Static analysis for data races")
    scan_cmd.add_argument("path", help="File or directory to scan")
    scan_formats = scan_cmd.add_mutually_exclusive_group()
    scan_formats.add_argument("--json", action="store_true", help=json_help)
    scan_formats.add_argument(
        "--sarif", action="store_true", help="Output in SARIF v2.1.0 format"
    )
    scan_cmd.add_argument("-o", "--output", help=output_help)

    # run: execute a script under instrumentation.
    run_cmd = commands.add_parser("run", help="Dynamic race detection")
    run_cmd.add_argument("script", help="Python script to execute")
    run_formats = run_cmd.add_mutually_exclusive_group()
    run_formats.add_argument("--json", action="store_true", help=json_help)
    run_cmd.add_argument("-o", "--output", help=output_help)

    # compat: path defaults to the current directory.
    compat_cmd = commands.add_parser("compat", help="Check free-threading compatibility")
    compat_cmd.add_argument("path", nargs="?", default=".", help="Project path")
    compat_cmd.add_argument("--json", action="store_true", help=json_help)

    args = parser.parse_args()

    handlers = {"scan": _do_scan, "run": _do_run, "compat": _do_compat}
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _detect_format(output_path: str | None) -> str:
|
|
56
|
+
if output_path:
|
|
57
|
+
ext = Path(output_path).suffix.lower()
|
|
58
|
+
if ext == ".json":
|
|
59
|
+
return "json"
|
|
60
|
+
if ext == ".sarif":
|
|
61
|
+
return "sarif"
|
|
62
|
+
return "text"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _do_scan(args):
    """Run static race analysis on ``args.path`` and emit a report.

    The report format comes from ``--json`` / ``--sarif`` or, failing that,
    from the extension of ``-o/--output`` (see ``_detect_format``).

    The progress banner and the final "Total:" summary are informational. When
    a JSON or SARIF report is written to stdout they are sent to stderr instead,
    so the machine-readable output stays parseable.

    Exits with status 1 if the path does not exist.
    """
    path = Path(args.path).resolve()
    if not path.exists():
        print(f"Path does not exist: {path}", file=sys.stderr)
        sys.exit(1)

    fmt = _detect_format(args.output)
    if args.json:
        fmt = "json"
    elif args.sarif:
        fmt = "sarif"

    # Keep stdout clean when it carries a machine-readable report.
    machine_report_on_stdout = fmt in ("json", "sarif") and not args.output
    info = sys.stderr if machine_report_on_stdout else sys.stdout

    print(f"threadcheck scan -- analysing {path}", file=info)
    print(file=info)

    warnings = analyze_path(str(path))

    if fmt == "json":
        output = format_warnings_json(warnings)
    elif fmt == "sarif":
        output = format_sarif(warnings)
    else:
        output = format_report(warnings)
    _write_output(args.output, output)

    severities = [w.severity.value for w in warnings]
    errors = severities.count("error")
    warns = severities.count("warning")
    infos = severities.count("info")
    print(file=info)
    print(
        f"Total: {len(warnings)} issue(s) ({errors} error(s), {warns} warning(s), {infos} info)",
        file=info,
    )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _do_run(args):
    """Execute ``args.script`` under the dynamic race detector.

    ``--json`` forces JSON output. Otherwise the format is inferred from the
    extension of ``-o/--output``.

    NOTE(review): only the *format* is derived from ``args.output``. The path
    itself is not passed to ``run_script``, so the report is not written to
    that file.
    """
    output_format = "json" if args.json else _detect_format(args.output)
    run_script(args.script, output_format=output_format)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _do_compat(args):
    """Report free-threading compatibility for a project's dependencies.

    ``args.path`` may be:
      * a directory -- its ``pyproject.toml`` is used if present, otherwise its
        ``requirements.txt``;
      * a ``pyproject.toml`` file;
      * a requirements-style ``.txt`` file.

    If no dependency list is found, every installed distribution is checked.
    With ``--json`` the results are printed as a JSON array; otherwise a
    human-readable table and a summary line are printed.

    Exits with status 1 if the path does not exist.
    """
    path = Path(args.path).resolve()
    if not path.exists():
        print(f"Path does not exist: {path}", file=sys.stderr)
        sys.exit(1)

    names: list[str] | None = None
    toml = path / "pyproject.toml"
    req = path / "requirements.txt"
    if path.is_dir():
        if toml.is_file():
            names = _read_deps_from_pyproject(toml)
        elif req.is_file():
            # Previously computed but never consulted, so a requirements-only
            # project fell back to checking every installed package.
            names = _read_deps_from_requirements(req)
    elif path.is_file():
        if path.suffix == ".txt":
            names = _read_deps_from_requirements(path)
        elif path.name == "pyproject.toml":
            names = _read_deps_from_pyproject(path)

    results = check_compat(names)

    if args.json:
        obj = [r.to_dict() for r in results]
        print(_json.dumps(obj, indent=2, ensure_ascii=False))
        return

    print("threadcheck compat - Free-threading compatibility check")
    print(f"Python {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}")
    print()

    can_emoji = _can_print_emoji()
    for r in results:
        if r.status == CompatStatus.COMPATIBLE:
            icon = _icon("✅", "[OK]", can_emoji)
        elif r.status == CompatStatus.NEEDS_VERIFICATION:
            icon = _icon("⚠️", "[??]", can_emoji)
        else:
            icon = _icon("❌", "[--]", can_emoji)
        print(f" {icon} {r.name:<20} {r.reason}")

    print()
    total = len(results)
    compat_count = sum(1 for r in results if r.status == CompatStatus.COMPATIBLE)
    needs_v = sum(1 for r in results if r.status == CompatStatus.NEEDS_VERIFICATION)
    not_inst = sum(1 for r in results if r.status == CompatStatus.NOT_INSTALLED)
    print(f"Total: {total} package(s) - {compat_count} compatible, {needs_v} need verification, {not_inst} not installed")
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _read_deps_from_pyproject(path: Path) -> list[str]:
    """Collect dependency names declared in a ``pyproject.toml``.

    Reads ``[project].dependencies`` and every group in
    ``[project].optional-dependencies``.

    Returns:
        A sorted, de-duplicated list of distribution names. If the file cannot
        be read or parsed, a warning is printed to stderr and ``[]`` is
        returned, so callers keep working.
    """
    try:
        import tomllib
    except ImportError:
        # tomllib is only in the standard library from Python 3.11 on.
        print(f"warning: tomllib unavailable, cannot read {path}", file=sys.stderr)
        return []

    try:
        data = tomllib.loads(path.read_text(encoding="utf-8"))
    except (OSError, UnicodeDecodeError, tomllib.TOMLDecodeError) as exc:
        print(f"warning: could not read {path}: {exc}", file=sys.stderr)
        return []

    project = data.get("project", {})
    if not isinstance(project, dict):
        return []

    deps: set[str] = set()
    for key in ("dependencies", "optional-dependencies"):
        section = project.get(key, {})
        if isinstance(section, dict):
            for group in section.values():
                # A malformed non-list group would otherwise be iterated char by char.
                if isinstance(group, list):
                    deps.update(_extract_names(group))
        elif isinstance(section, list):
            deps.update(_extract_names(section))
    return sorted(deps)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _read_deps_from_requirements(path: Path) -> list[str]:
    """Collect distribution names from a pip requirements file.

    Skips the following:
      * blank lines and comments (a line starting with ``#``, or ``#``
        preceded by whitespace, as pip does);
      * option lines (``-r``, ``-e``, ...);
      * direct VCS/URL lines.

    Version specifiers, extras, URL references and environment markers are
    stripped from each name. A UTF-8 BOM (common on Windows) is tolerated.

    Returns:
        A sorted, de-duplicated list of names.
    """
    names: set[str] = set()
    for raw in path.read_text(encoding="utf-8-sig").splitlines():
        # pip treats "#" at line start, or after whitespace, as a comment.
        # A "#" inside a URL fragment (no preceding space) is kept.
        line = re.split(r"(?:^|\s)#", raw, maxsplit=1)[0].strip()
        if not line or line.startswith(("-", "git+", "http")):
            continue
        name = re.split(r"[<>=!~@;\[(\s]", line, maxsplit=1)[0]
        if name:
            names.add(name)
    return sorted(names)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _extract_names(entries: list) -> list[str]:
    """Extract bare distribution names from PEP 508 requirement strings.

    The following are removed from each entry:
      * version specifiers;
      * extras (``pkg[extra]``);
      * direct references (``pkg @ url``);
      * legacy parenthesised specifiers (``pkg (>=1)``);
      * environment markers (``; ...``).

    Non-string entries, and entries that yield an empty name, are skipped.
    Input order is preserved and duplicates are kept.
    """
    names: list[str] = []
    for entry in entries:
        if not isinstance(entry, str):
            continue
        # Cut at the first character that cannot be part of a project name.
        name = re.split(r"[<>=!~@;\[(\s]", entry.strip(), maxsplit=1)[0]
        if name:
            names.append(name)
    return names
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _can_print_emoji() -> bool:
    """Return True when stdout's encoding can represent the status emoji.

    Returns False in any of these cases:
      * stdout is missing (e.g. ``pythonw``);
      * stdout reports no encoding;
      * the encoding names an unknown codec;
      * the codec cannot encode the check / warning / cross symbols.
    """
    encoding = getattr(sys.stdout, "encoding", None)
    if not encoding:
        # str.encode(None) would raise TypeError, which was not caught before.
        return False
    try:
        "\u2705\u26a0\u274c".encode(encoding)
    except (UnicodeError, LookupError):
        return False
    return True
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _icon(emoji: str, fallback: str, can_emoji: bool) -> str:
    """Pick *emoji* when the terminal can render it, else the ASCII *fallback*."""
    if can_emoji:
        return emoji
    return fallback
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _write_output(path_arg: str | None, content: str):
|
|
206
|
+
if path_arg:
|
|
207
|
+
Path(path_arg).write_text(content, encoding="utf-8")
|
|
208
|
+
else:
|
|
209
|
+
print(content)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# Entry point when this module is executed directly (e.g. ``python -m threadcheck.cli``).
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
import sys
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Sequence
|
|
7
|
+
|
|
8
|
+
from .models import CExtInfo, CompatStatus, FTCompatResult
|
|
9
|
+
|
|
10
|
+
# Compiled extension-module files: Windows ``.pyd`` or POSIX ``.so`` (case-insensitive).
_C_EXT_RE = re.compile(r"\.(pyd|so)$", flags=re.IGNORECASE)
# Free-threading ABI tag inside an extension filename, e.g. ``.cp313t-`` or ``.cpython-313t-``.
_FT_TAG_RE = re.compile(r"\.(cp\d+t|cpython-\d+t)-")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _check_single(name: str) -> FTCompatResult:
    """Classify one distribution's readiness for free-threaded CPython.

    Classification rules:
      * If the distribution cannot be found or inspected, the status is
        NOT_INSTALLED.
      * If it ships no compiled extension modules, it is COMPATIBLE.
      * If any extension module's filename lacks a free-threading ABI tag
        (``cp313t`` / ``cpython-313t``), it is NEEDS_VERIFICATION.
      * Otherwise it is COMPATIBLE.

    Shared libraries vendored into ``<pkg>.libs/`` directories (as produced by
    wheel-repair tools) are plain native libraries, not Python extension
    modules, and never carry an ABI tag. They are ignored so they do not cause
    false "needs verification" results.

    Args:
        name: Distribution name as accepted by ``importlib.metadata.distribution``.
    """
    try:
        from importlib.metadata import PackageNotFoundError, distribution
    except ImportError:
        return FTCompatResult(
            name=name,
            status=CompatStatus.NOT_INSTALLED,
            reason="importlib.metadata not available",
        )

    try:
        dist = distribution(name)
    except PackageNotFoundError:
        return FTCompatResult(
            name=name,
            status=CompatStatus.NOT_INSTALLED,
            reason="not installed",
        )
    except Exception as exc:
        return FTCompatResult(
            name=name,
            status=CompatStatus.NOT_INSTALLED,
            reason=str(exc),
        )

    try:
        files = dist.files or []
    except Exception:
        files = []

    c_exts: list[CExtInfo] = [
        CExtInfo(filename=fname, has_ft_tag=bool(_FT_TAG_RE.search(fname)))
        for fname in map(str, files)
        if _C_EXT_RE.search(fname) and not _is_vendored_library(fname)
    ]

    if not c_exts:
        return FTCompatResult(
            name=name,
            status=CompatStatus.COMPATIBLE,
            reason="pure Python, no C extensions",
        )

    ft_missing = [e for e in c_exts if not e.has_ft_tag]
    if ft_missing:
        return FTCompatResult(
            name=name,
            status=CompatStatus.NEEDS_VERIFICATION,
            c_exts=c_exts,
            reason=f"{len(ft_missing)} C extension(s) not compiled for free-threading ABI",
        )

    return FTCompatResult(
        name=name,
        status=CompatStatus.COMPATIBLE,
        c_exts=c_exts,
        reason="all C extensions compiled for free-threading ABI",
    )


def _is_vendored_library(fname: str) -> bool:
    """Return True if *fname* lives in a ``*.libs`` directory of bundled native libraries."""
    directories = fname.replace("\\", "/").split("/")[:-1]
    return any(part.endswith(".libs") for part in directories)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def check_compat(
    names: Sequence[str] | None = None,
) -> list[FTCompatResult]:
    """Check free-threading compatibility for the given or all installed distributions.

    Args:
        names: Distribution names to check, in the given order. When ``None``,
            every distribution visible to ``importlib.metadata.distributions()``
            is checked.

    Returns:
        One FTCompatResult per name. In the ``names=None`` case, results are
        sorted case-insensitively by name, and a distribution found on several
        ``sys.path`` entries is reported only once (deduplicated by its
        normalized project name).
    """
    if names is not None:
        return [_check_single(n) for n in names]

    try:
        from importlib.metadata import distributions
    except ImportError:
        return []

    seen: set[str] = set()
    results: list[FTCompatResult] = []
    for dist in distributions():
        meta = dist.metadata
        name = meta.get("Name", "") if meta is not None else ""
        if not name:
            continue
        # Normalize so "Foo_Bar" and "foo-bar" count as the same project.
        key = re.sub(r"[-_.]+", "-", name).lower()
        if key in seen:
            continue
        seen.add(key)
        results.append(_check_single(name))

    results.sort(key=lambda r: r.name.lower())
    return results
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class CompatStatus(Enum):
    """Verdict categories for a free-threading compatibility check.

    The string values are the ones emitted in JSON output via ``to_dict``.
    """

    # No extension modules, or every extension module carries a free-threading ABI tag.
    COMPATIBLE = "compatible"
    # At least one extension module lacks a free-threading ABI tag.
    NEEDS_VERIFICATION = "needs_verification"
    # The distribution could not be located or inspected.
    NOT_INSTALLED = "not_installed"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class CExtInfo:
    """A compiled extension file found in a distribution."""

    filename: str  # path of the extension file
    has_ft_tag: bool  # True if the filename carries a free-threading ABI tag

    @property
    def tag(self) -> str:
        """Return the ABI tag embedded in the filename (e.g. ``cp313t``), or ``""``."""
        found = re.search(r"\.(cpython-\d+t?|cp\d+t?)-", self.filename)
        return found.group(1) if found else ""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class FTCompatResult:
    """Free-threading compatibility verdict for a single distribution."""

    name: str  # distribution name
    status: CompatStatus  # overall verdict
    c_exts: list[CExtInfo] = field(default_factory=list)  # extension files inspected
    reason: str = ""  # human-readable explanation of the verdict

    def to_dict(self) -> dict:
        """Return a JSON-serialisable mapping.

        Keys appear in this order: name, status, c_exts, reason.
        """
        extensions = [
            {"filename": ext.filename, "has_ft_tag": ext.has_ft_tag}
            for ext in self.c_exts
        ]
        return {
            "name": self.name,
            "status": self.status.value,
            "c_exts": extensions,
            "reason": self.reason,
        }
|
|
File without changes
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import sys
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from .tracker import ThreadCheckTracker
|
|
6
|
+
from .transform import TrackInjector
|
|
7
|
+
from ..reporting.formatter import format_dynamic_races, format_dynamic_json
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def run_script(script_path: str, output_format: str = "text"):
    """Execute a Python script with race-tracking instrumentation and print a report.

    Steps:
      1. The script's AST is rewritten by ``TrackInjector``.
      2. It is executed as ``__main__``, so ``if __name__ == "__main__":``
         blocks run, with its own directory prepended to ``sys.path``
         (mirroring ``python script.py``).
      3. The races found by ``ThreadCheckTracker`` are printed as JSON
         (``output_format == "json"``) or as text.

    The tracker state is always reset afterwards. A ``SystemExit`` raised by
    the script is swallowed so the report is still produced.

    Exits with status 1 if the file is missing or has a syntax error.
    """
    path = Path(script_path).resolve()
    if not path.is_file():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)

    filename = str(path)
    # Parse from bytes so PEP 263 encoding declarations and a UTF-8 BOM are honoured.
    source = path.read_bytes()

    try:
        tree = ast.parse(source, filename=filename)
    except SyntaxError as e:
        print(f"Syntax error: {e}", file=sys.stderr)
        sys.exit(1)

    TrackInjector(filename=filename).transform(tree)
    ast.fix_missing_locations(tree)
    code = compile(tree, filename, "exec")

    script_globals = {
        # Without __name__, lookups fall through to builtins ("builtins") and
        # the script's main guard would never fire.
        "__name__": "__main__",
        "__file__": filename,
        "_threadcheck_tracker": ThreadCheckTracker,
    }
    script_dir = str(path.parent)
    sys.path.insert(0, script_dir)

    try:
        ThreadCheckTracker.start()
        try:
            exec(code, script_globals)
        except SystemExit:
            pass
        finally:
            ThreadCheckTracker.stop()
            try:
                sys.path.remove(script_dir)
            except ValueError:
                pass  # the script already removed it

        races = ThreadCheckTracker.detect_races()
        with ThreadCheckTracker._lock:
            alog = dict(ThreadCheckTracker._access_log)

        if output_format == "json":
            print(format_dynamic_json(races))
        else:
            print(format_dynamic_races(races, alog))
    finally:
        ThreadCheckTracker.reset()
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class VectorClock:
    """Per-thread logical counters keyed by thread id.

    Two clocks conflict when neither is component-wise less than or equal to
    the other (missing components count as 0).
    """

    def __init__(self):
        self._clock: dict[int, int] = defaultdict(int)

    def tick(self, thread_id: int) -> int:
        """Increment *thread_id*'s counter and return the new value."""
        updated = self._clock[thread_id] + 1
        self._clock[thread_id] = updated
        return updated

    def merge(self, other: "VectorClock"):
        """Take the component-wise maximum of *other* into this clock, in place."""
        mine = self._clock
        for tid, stamp in other._clock.items():
            current = mine.get(tid, 0)
            mine[tid] = stamp if stamp > current else current

    def conflicts_with(self, other: "VectorClock") -> bool:
        """Return True if neither clock is <= the other (i.e. they are concurrent)."""
        return not self._leq(other) and not other._leq(self)

    def _leq(self, other: "VectorClock") -> bool:
        """Return True if every component of self is <= the matching one in *other*."""
        theirs = other._clock
        return all(stamp <= theirs.get(tid, 0) for tid, stamp in self._clock.items())

    def copy(self) -> "VectorClock":
        """Return an independent copy of this clock."""
        clone = VectorClock()
        clone._clock = self._clock.copy()
        return clone

    def __repr__(self) -> str:
        return f"VectorClock({dict(self._clock)})"
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import ast
|
|
3
|
+
import builtins
|
|
4
|
+
import importlib.util
|
|
5
|
+
import importlib.abc
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from .transform import TrackInjector
|
|
9
|
+
from .tracker import ThreadCheckTracker
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ThreadCheckLoader(importlib.abc.Loader):
    """Import loader that rewrites a module's AST with ``TrackInjector`` before running it.

    Args:
        tracker: Object made available to the instrumented code as
            ``_threadcheck_tracker``. Falls back to ``ThreadCheckTracker``
            when falsy.
    """

    def __init__(self, tracker=None):
        self.tracker = tracker if tracker else ThreadCheckTracker

    def create_module(self, spec):
        # None => let the import system create the default module object.
        return None

    def exec_module(self, module):
        """Load, instrument, compile and execute *module*'s source in its namespace.

        Raises:
            ImportError: if no source text can be obtained for the module.
        """
        spec = module.__spec__
        origin = spec.origin
        source = self._get_source(spec, module)
        if source is None:
            raise ImportError(f"cannot load source for {spec.name}")

        module.__file__ = origin

        tree = ast.parse(source, filename=origin)
        TrackInjector(filename=str(origin)).transform(tree)
        ast.fix_missing_locations(tree)
        code = compile(tree, origin, "exec")

        namespace = module.__dict__
        namespace["_threadcheck_tracker"] = self.tracker
        # Also exposed via builtins so code without the module global can reach it.
        builtins._threadcheck_tracker = self.tracker
        exec(code, namespace)

    @staticmethod
    def _get_source(spec, module=None):
        """Return the module's source text, or None if it cannot be found.

        Tries, in order:
          1. reading ``spec.origin`` as UTF-8, if it is a ``.py`` path;
          2. reading ``module.__file__`` as UTF-8, if it is a ``.py`` path;
          3. ``spec.loader.get_source(spec.name)``, if the loader has it.

        Read errors at each step are ignored.
        """
        for candidate in (spec.origin, getattr(module, "__file__", None)):
            if not candidate or Path(candidate).suffix != ".py":
                continue
            try:
                return Path(candidate).read_text(encoding="utf-8")
            except Exception:
                continue
        if hasattr(spec.loader, "get_source"):
            try:
                return spec.loader.get_source(spec.name)
            except Exception:
                pass
        return None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ThreadCheckFinder(importlib.abc.MetaPathFinder):
    """Meta-path finder that routes plain ``.py`` modules through ``ThreadCheckLoader``.

    Args:
        tracker: Passed on to each ``ThreadCheckLoader``. Falls back to
            ``ThreadCheckTracker`` when falsy.
        include_paths: Optional directories. When given, only files under one
            of them are instrumented; otherwise every matching file is.

    Only ``<name>.py`` files are matched. Packages (``__init__.py``) and
    non-source modules fall through to later finders (``find_spec`` returns
    None).
    """

    def __init__(self, tracker=None, include_paths=None):
        self.tracker = tracker or ThreadCheckTracker
        self._include_paths = (
            [Path(p).resolve() for p in include_paths] if include_paths else []
        )

    def _should_instrument(self, filepath: Path) -> bool:
        """Return True if *filepath* is under an include path (or none were given)."""
        if not self._include_paths:
            return True
        resolved = filepath.resolve()
        return any(_is_under(resolved, inc) for inc in self._include_paths)

    @staticmethod
    def _module_file(entry, fullname: str) -> Path:
        """Return the candidate ``.py`` file for *fullname* inside search entry *entry*.

        For a submodule ``pkg.sub`` the import system passes the parent
        package's ``__path__`` (i.e. ``.../pkg``) as the search path, so only
        the last dotted component may be joined. Joining the full dotted name
        would look for ``.../pkg/pkg/sub.py``. For top-level modules the last
        component is the whole name.
        """
        if entry == "":
            entry = "."
        return Path(entry) / f"{fullname.rpartition('.')[2]}.py"

    def find_spec(self, fullname, path, target=None):
        """Return an instrumenting spec for *fullname*, or None to defer to other finders."""
        for entry in (path or sys.path):
            candidate = self._module_file(entry, fullname)
            if candidate.is_file() and self._should_instrument(candidate):
                return importlib.util.spec_from_file_location(
                    fullname,
                    str(candidate),
                    loader=ThreadCheckLoader(self.tracker),
                )
        return None
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def install_hook(tracker=None, include_paths=None):
    """Register a ``ThreadCheckFinder`` at the front of ``sys.meta_path``.

    Placing it first lets it claim matching modules before the default finders.
    Returns the finder so it can later be passed to ``uninstall_hook``.
    """
    finder = ThreadCheckFinder(tracker, include_paths)
    sys.meta_path[:0] = [finder]
    return finder
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def uninstall_hook(hook):
    """Remove *hook* from ``sys.meta_path``; a no-op if it is not registered."""
    try:
        sys.meta_path.remove(hook)
    except ValueError:
        pass
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _is_under(child: Path, parent: Path) -> bool:
    """Return True if *child* equals *parent* or lies beneath it (purely lexical check)."""
    try:
        child.relative_to(parent)
    except ValueError:
        return False
    else:
        return True
|