argot-engine 0.2.0__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {argot_engine-0.2.0 → argot_engine-0.2.2}/PKG-INFO +1 -1
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/check.py +115 -39
- argot_engine-0.2.2/argot/tests/test_check.py +92 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/PKG-INFO +1 -1
- {argot_engine-0.2.0 → argot_engine-0.2.2}/pyproject.toml +1 -1
- argot_engine-0.2.0/argot/tests/test_check.py +0 -43
- {argot_engine-0.2.0 → argot_engine-0.2.2}/README.md +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/__init__.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/__main__.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/dataset.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/explain.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/extract.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/fetch.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/git_walk.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/jepa/__init__.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/jepa/encoder.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/jepa/model.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/jepa/predictor.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/jepa/sigreg.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/__init__.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_explain.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_extract_smoke.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_fetch.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_git_walk.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_jepa.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_tokenize.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_train_smoke.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_validate.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tokenize.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/train.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/validate.py +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/SOURCES.txt +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/dependency_links.txt +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/entry_points.txt +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/requires.txt +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/top_level.txt +0 -0
- {argot_engine-0.2.0 → argot_engine-0.2.2}/setup.cfg +0 -0
|
@@ -2,13 +2,15 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import argparse
|
|
4
4
|
import sys
|
|
5
|
+
from collections.abc import Iterator
|
|
5
6
|
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
6
8
|
|
|
7
9
|
import joblib # type: ignore[import-untyped]
|
|
8
10
|
import pygit2
|
|
9
11
|
import torch
|
|
10
12
|
|
|
11
|
-
from argot.git_walk import walk_commits
|
|
13
|
+
from argot.git_walk import SUPPORTED_EXTENSIONS, _extension, walk_commits
|
|
12
14
|
from argot.jepa.encoder import TokenEncoder
|
|
13
15
|
from argot.jepa.model import JEPAArgot
|
|
14
16
|
from argot.jepa.predictor import ArgotPredictor
|
|
@@ -37,42 +39,43 @@ def _resolve_shas(repo: pygit2.Repository, ref: str) -> set[str]:
|
|
|
37
39
|
return shas
|
|
38
40
|
|
|
39
41
|
|
|
40
|
-
def
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
42
|
+
def _workdir_patches(
    repo_path: str,
) -> Iterator[tuple[str, bytes, list[pygit2.DiffHunk]]]:
    """Yield ``(file_path, content, hunks)`` for working-tree changes vs HEAD.

    The diff is taken between HEAD's tree and the files on disk
    (``repo.diff(a=<HEAD oid>)``), with rename detection enabled.
    ``content`` is the current on-disk bytes of the file. ``hunks`` are that
    file's diff hunks, whose ``new_*`` side is the working-tree version.

    Skipped:
    - files whose extension is not in ``SUPPORTED_EXTENSIONS``;
    - deltas that produced no hunks;
    - paths that are no longer regular files on disk (e.g. deletions).

    Nothing is yielded for a bare repository (no working tree) or for a
    repository whose HEAD has no commit yet, since there is nothing to diff
    against.

    NOTE(review): a tree-to-workdir diff does not consult the index the way
    ``git diff HEAD`` does. Untracked files are excluded with the default
    diff flags. Confirm how files that are only staged (new, not in HEAD)
    are reported before relying on them being checked.
    """
    repo = pygit2.Repository(repo_path)
    # Bare repo: no working tree to read.
    # Unborn HEAD: "HEAD" cannot be resolved to a commit.
    if repo.workdir is None or repo.head_is_unborn:
        return

    head_oid = repo.revparse_single("HEAD").id
    diff = repo.diff(a=head_oid)
    diff.find_similar()  # detect renames so new_file.path is the current name
    workdir = Path(repo.workdir)

    for patch in diff:
        if patch is None:
            continue
        file_path = patch.delta.new_file.path
        if _extension(file_path) not in SUPPORTED_EXTENSIONS:
            continue
        hunks = list(patch.hunks)
        if not hunks:
            continue
        full_path = workdir / file_path
        # is_file() (not exists()) so a directory at this path is never read.
        if not full_path.is_file():
            continue
        yield file_path, full_path.read_bytes(), hunks
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _score_patches(
|
|
67
|
+
patches: Iterator[tuple[str, bytes, list[pygit2.DiffHunk]]],
|
|
68
|
+
vectorizer: Any,
|
|
69
|
+
model: JEPAArgot,
|
|
70
|
+
label: str,
|
|
71
|
+
) -> tuple[list[tuple[float, str, int, str]], int]:
|
|
72
|
+
"""Score hunk patches; returns (results, total_hunk_count)."""
|
|
73
|
+
context_lines = 50
|
|
71
74
|
results: list[tuple[float, str, int, str]] = []
|
|
75
|
+
hunk_count = 0
|
|
72
76
|
|
|
73
|
-
context_lines = 50
|
|
74
77
|
with torch.no_grad():
|
|
75
|
-
for
|
|
78
|
+
for file_path, post_blob, hunks in patches:
|
|
76
79
|
lang = language_for_path(file_path)
|
|
77
80
|
if lang is None:
|
|
78
81
|
continue
|
|
@@ -82,6 +85,7 @@ def main() -> None:
|
|
|
82
85
|
continue
|
|
83
86
|
|
|
84
87
|
for hunk in hunks:
|
|
88
|
+
hunk_count += 1
|
|
85
89
|
hunk_start = hunk.new_start - 1
|
|
86
90
|
hunk_end = hunk_start + hunk.new_lines
|
|
87
91
|
if hunk_start < 0 or hunk_end > len(source_lines):
|
|
@@ -102,17 +106,89 @@ def main() -> None:
|
|
|
102
106
|
)
|
|
103
107
|
|
|
104
108
|
score = model.surprise(ctx_vec, hunk_vec).item()
|
|
105
|
-
results.append((score, file_path, hunk.new_start,
|
|
109
|
+
results.append((score, file_path, hunk.new_start, label))
|
|
110
|
+
|
|
111
|
+
return results, hunk_count
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def main() -> None:
    """Entry point for ``argot check``: score changed hunks for surprise.

    Without a ``ref`` argument, the uncommitted working-tree changes are
    scored. Otherwise every commit selected by ``ref`` is scored; ``ref`` is
    resolved with ``_resolve_shas`` (e.g. ``refs/heads/main`` or ``A..B``).

    Prints one row per scored hunk, most surprising first. Each row is tagged
    ok / unusual / suspicious / foreign according to how far its score
    exceeds ``--threshold`` (bands of 0.3).

    Exit behaviour:
    - status 2 if the model bundle does not exist;
    - status 1 if any hunk scores above the threshold;
    - status 0 when there is nothing to score;
    - otherwise the function returns normally.
    """
    parser = argparse.ArgumentParser(description="Check code surprise with argot JEPA model")
    parser.add_argument("repo_path")
    parser.add_argument("ref", nargs="?", default="")
    parser.add_argument("--model", default=".argot/model.pkl")
    parser.add_argument("--threshold", type=float, default=0.5)
    args = parser.parse_args()

    model_path = Path(args.model)
    if not model_path.exists():
        print(f"error: model not found at {model_path}", file=sys.stderr)
        sys.exit(2)

    # SECURITY NOTE: joblib.load unpickles the file, which can execute
    # arbitrary code. Only point --model at bundles you trust.
    bundle = joblib.load(model_path)
    vectorizer = bundle["vectorizer"]
    embed_dim: int = bundle["embed_dim"]
    input_dim: int = bundle["input_dim"]

    encoder = TokenEncoder(input_dim, embed_dim)
    encoder.load_state_dict(bundle["encoder_state"])
    predictor = ArgotPredictor(embed_dim=embed_dim)
    predictor.load_state_dict(bundle["predictor_state"])
    model = JEPAArgot(encoder, predictor)
    model.eval()

    patches: Iterator[tuple[str, bytes, list[pygit2.DiffHunk]]]
    if args.ref == "":
        patches = _workdir_patches(args.repo_path)
        context_label = "workdir"
        commit_info = "working tree"
    else:
        repo = pygit2.Repository(args.repo_path)
        shas = _resolve_shas(repo, args.ref)
        if not shas:
            print("No commits found in range", file=sys.stderr)
            sys.exit(0)

        def _committed_patches() -> Iterator[tuple[str, bytes, list[pygit2.DiffHunk]]]:
            # Drop the commit object: scoring only needs path, blob and hunks.
            for _commit, file_path, post_blob, hunks in walk_commits(args.repo_path, shas):
                yield file_path, post_blob, hunks

        patches = _committed_patches()
        context_label = args.ref
        commit_info = f"{len(shas)} commit(s)"

    results, hunk_count = _score_patches(patches, vectorizer, model, context_label)
    # Hunks that _score_patches counted but did not score
    # (e.g. a hunk whose line range falls outside the file).
    skipped = hunk_count - len(results)

    if not results:
        if hunk_count == 0:
            exts = " ".join(sorted(SUPPORTED_EXTENSIONS))
            print(
                f"No changes to supported files found ({commit_info} scanned).\n"
                f"Supported extensions: {exts}"
            )
            if args.ref != "":
                print("Try a wider range, e.g.: argot check HEAD~20..HEAD")
        else:
            # Hunks were found but every one was skipped before scoring,
            # so nothing was compared against the threshold.
            print(f"Found {skipped} hunk(s) but none could be scored.")
        sys.exit(0)

    results.sort(key=lambda r: r[0], reverse=True)

    t = args.threshold
    print(f"{'SURPRISE':>9} {'TAG':<10} {'FILE':<48} {'LINE':>5} REF")
    for score, fp, line, ref in results:
        if score <= t:
            tag = "ok"
        elif score <= t + 0.3:
            tag = "unusual"
        elif score <= t + 0.6:
            tag = "suspicious"
        else:
            tag = "foreign"
        print(f"{score:>9.4f} {tag:<10} {fp:<48} {line:>5} {ref}")

    if skipped:
        print(f"{skipped} hunk(s) could not be scored and were skipped.")

    if any(score > t for score, *_ in results):
        sys.exit(1)
    print(f"All {len(results)} scored hunk(s) are at or below threshold {t:.2f} — looks clean.")
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import cast
|
|
5
|
+
|
|
6
|
+
import pygit2
|
|
7
|
+
|
|
8
|
+
from argot.check import _resolve_shas, _workdir_patches
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _make_repo(tmp_path: Path, files: dict[str, str]) -> pygit2.Repository:
    """Create a repo whose ``main`` branch holds one commit containing ``files``.

    ``files`` maps repo-relative POSIX paths to text content. Parent
    directories are created as needed, so nested paths such as
    ``"pkg/mod.py"`` work. HEAD is pointed at ``refs/heads/main``, so
    HEAD-relative code (e.g. ``_workdir_patches``) sees the commit.
    """
    repo = pygit2.init_repository(str(tmp_path))
    sig = pygit2.Signature("Test", "test@example.com")
    for name, content in files.items():
        target = tmp_path / name
        # Allow nested paths; write_text fails if the parent dir is missing.
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(content)
        repo.index.add(name)
    repo.index.write()
    tree = repo.index.write_tree()
    repo.create_commit("refs/heads/main", sig, sig, "init", tree, [])
    repo.set_head("refs/heads/main")
    return repo
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _make_two_commit_repo(tmp_path: Path) -> pygit2.Repository:
    """Create a repo whose ``main`` branch has two commits editing ``main.py``.

    The first commit ("first") writes ``x = 1``. The second ("second") is
    its child and changes the line to ``x = 2``. HEAD is not set by this
    helper.
    """
    repo = pygit2.init_repository(str(tmp_path))
    author = pygit2.Signature("Test", "test@example.com")
    target = tmp_path / "main.py"
    parents: list[pygit2.Oid] = []
    for message, body in (("first", "x = 1\n"), ("second", "x = 2\n")):
        target.write_text(body)
        repo.index.add("main.py")
        repo.index.write()
        tree_id = repo.index.write_tree()
        # Each new commit becomes the sole parent of the next one.
        parents = [repo.create_commit("refs/heads/main", author, author, message, tree_id, parents)]
    return repo
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_resolve_shas_range(tmp_path: Path) -> None:
    """A ``<parent>..refs/heads/main`` range selects the tip but not its parent."""
    repo = _make_two_commit_repo(tmp_path)
    tip = repo.references["refs/heads/main"].target
    parent = cast(pygit2.Commit, repo.get(tip)).parents[0].id

    selected = _resolve_shas(repo, f"{parent}..refs/heads/main")

    assert str(tip) in selected
    assert str(parent) not in selected
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def test_resolve_shas_bare_ref(tmp_path: Path) -> None:
    """A single ref (no ``..``) still includes the commit it points to."""
    repo = _make_two_commit_repo(tmp_path)
    tip_sha = str(repo.references["refs/heads/main"].target)
    assert tip_sha in _resolve_shas(repo, "refs/heads/main")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_workdir_patches_detects_modification(tmp_path: Path) -> None:
    """Appending a line to a committed ``main.py`` yields one patch with the new text."""
    _make_repo(tmp_path, {"main.py": "x = 1\n"})
    (tmp_path / "main.py").write_text("x = 1\ny = 2\n")

    found = list(_workdir_patches(str(tmp_path)))
    assert len(found) == 1

    path, blob, hunks = found[0]
    assert path == "main.py"
    assert b"y = 2" in blob
    assert hunks  # at least one hunk
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def test_workdir_patches_ignores_unsupported_extension(tmp_path: Path) -> None:
    """Editing a ``.json`` file produces no patches."""
    _make_repo(tmp_path, {"config.json": "{}\n"})
    (tmp_path / "config.json").write_text('{"key": "value"}\n')
    assert list(_workdir_patches(str(tmp_path))) == []
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def test_workdir_patches_ignores_deleted_files(tmp_path: Path) -> None:
    """Deleting a committed file from disk produces no patches."""
    _make_repo(tmp_path, {"main.py": "x = 1\n"})
    tracked = tmp_path / "main.py"
    tracked.unlink()
    assert not list(_workdir_patches(str(tmp_path)))
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def test_workdir_patches_no_changes(tmp_path: Path) -> None:
    """A freshly committed, untouched working tree yields no patches."""
    _make_repo(tmp_path, {"main.py": "x = 1\n"})
    remaining = list(_workdir_patches(str(tmp_path)))
    assert remaining == []
|
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
from pathlib import Path
|
|
4
|
-
from typing import cast
|
|
5
|
-
|
|
6
|
-
import pygit2
|
|
7
|
-
|
|
8
|
-
from argot.check import _resolve_shas
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def _make_two_commit_repo(tmp_path: Path) -> pygit2.Repository:
|
|
12
|
-
repo = pygit2.init_repository(str(tmp_path))
|
|
13
|
-
sig = pygit2.Signature("Test", "test@example.com")
|
|
14
|
-
f = tmp_path / "main.py"
|
|
15
|
-
f.write_text("x = 1\n")
|
|
16
|
-
repo.index.add("main.py")
|
|
17
|
-
repo.index.write()
|
|
18
|
-
tree1 = repo.index.write_tree()
|
|
19
|
-
c1 = repo.create_commit("refs/heads/main", sig, sig, "first", tree1, [])
|
|
20
|
-
f.write_text("x = 2\n")
|
|
21
|
-
repo.index.add("main.py")
|
|
22
|
-
repo.index.write()
|
|
23
|
-
tree2 = repo.index.write_tree()
|
|
24
|
-
repo.create_commit("refs/heads/main", sig, sig, "second", tree2, [c1])
|
|
25
|
-
return repo
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def test_resolve_shas_range(tmp_path: Path) -> None:
|
|
29
|
-
repo = _make_two_commit_repo(tmp_path)
|
|
30
|
-
head_oid = repo.references["refs/heads/main"].target
|
|
31
|
-
commit = cast(pygit2.Commit, repo.get(head_oid))
|
|
32
|
-
parent_oid = commit.parents[0].id
|
|
33
|
-
|
|
34
|
-
shas = _resolve_shas(repo, f"{parent_oid}..refs/heads/main")
|
|
35
|
-
assert str(head_oid) in shas
|
|
36
|
-
assert str(parent_oid) not in shas
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def test_resolve_shas_bare_ref(tmp_path: Path) -> None:
|
|
40
|
-
repo = _make_two_commit_repo(tmp_path)
|
|
41
|
-
head_oid = str(repo.references["refs/heads/main"].target)
|
|
42
|
-
shas = _resolve_shas(repo, "refs/heads/main")
|
|
43
|
-
assert head_oid in shas
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|