argot-engine 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {argot_engine-0.2.0 → argot_engine-0.2.2}/PKG-INFO +1 -1
  2. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/check.py +115 -39
  3. argot_engine-0.2.2/argot/tests/test_check.py +92 -0
  4. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/PKG-INFO +1 -1
  5. {argot_engine-0.2.0 → argot_engine-0.2.2}/pyproject.toml +1 -1
  6. argot_engine-0.2.0/argot/tests/test_check.py +0 -43
  7. {argot_engine-0.2.0 → argot_engine-0.2.2}/README.md +0 -0
  8. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/__init__.py +0 -0
  9. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/__main__.py +0 -0
  10. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/dataset.py +0 -0
  11. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/explain.py +0 -0
  12. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/extract.py +0 -0
  13. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/fetch.py +0 -0
  14. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/git_walk.py +0 -0
  15. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/jepa/__init__.py +0 -0
  16. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/jepa/encoder.py +0 -0
  17. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/jepa/model.py +0 -0
  18. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/jepa/predictor.py +0 -0
  19. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/jepa/sigreg.py +0 -0
  20. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/__init__.py +0 -0
  21. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_explain.py +0 -0
  22. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_extract_smoke.py +0 -0
  23. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_fetch.py +0 -0
  24. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_git_walk.py +0 -0
  25. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_jepa.py +0 -0
  26. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_tokenize.py +0 -0
  27. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_train_smoke.py +0 -0
  28. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tests/test_validate.py +0 -0
  29. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/tokenize.py +0 -0
  30. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/train.py +0 -0
  31. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot/validate.py +0 -0
  32. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/SOURCES.txt +0 -0
  33. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/dependency_links.txt +0 -0
  34. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/entry_points.txt +0 -0
  35. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/requires.txt +0 -0
  36. {argot_engine-0.2.0 → argot_engine-0.2.2}/argot_engine.egg-info/top_level.txt +0 -0
  37. {argot_engine-0.2.0 → argot_engine-0.2.2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: argot-engine
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Requires-Python: >=3.11
5
5
  Requires-Dist: pygit2==1.19.2
6
6
  Requires-Dist: scikit-learn>=1.5.0
@@ -2,13 +2,15 @@ from __future__ import annotations
2
2
 
3
3
  import argparse
4
4
  import sys
5
+ from collections.abc import Iterator
5
6
  from pathlib import Path
7
+ from typing import Any
6
8
 
7
9
  import joblib # type: ignore[import-untyped]
8
10
  import pygit2
9
11
  import torch
10
12
 
11
- from argot.git_walk import walk_commits
13
+ from argot.git_walk import SUPPORTED_EXTENSIONS, _extension, walk_commits
12
14
  from argot.jepa.encoder import TokenEncoder
13
15
  from argot.jepa.model import JEPAArgot
14
16
  from argot.jepa.predictor import ArgotPredictor
@@ -37,42 +39,43 @@ def _resolve_shas(repo: pygit2.Repository, ref: str) -> set[str]:
37
39
  return shas
38
40
 
39
41
 
40
- def main() -> None:
41
- parser = argparse.ArgumentParser(description="Check code surprise with argot JEPA model")
42
- parser.add_argument("repo_path")
43
- parser.add_argument("ref")
44
- parser.add_argument("--model", default=".argot/model.pkl")
45
- parser.add_argument("--threshold", type=float, default=0.5)
46
- args = parser.parse_args()
47
-
48
- model_path = Path(args.model)
49
- if not model_path.exists():
50
- print(f"error: model not found at {model_path}", file=sys.stderr)
51
- sys.exit(2)
52
-
53
- bundle = joblib.load(model_path)
54
- vectorizer = bundle["vectorizer"]
55
- embed_dim: int = bundle["embed_dim"]
56
- input_dim: int = bundle["input_dim"]
57
-
58
- encoder = TokenEncoder(input_dim, embed_dim)
59
- encoder.load_state_dict(bundle["encoder_state"])
60
- predictor = ArgotPredictor(embed_dim=embed_dim)
61
- predictor.load_state_dict(bundle["predictor_state"])
62
- model = JEPAArgot(encoder, predictor)
63
- model.eval()
64
-
65
- repo = pygit2.Repository(args.repo_path)
66
- shas = _resolve_shas(repo, args.ref)
67
- if not shas:
68
- print("No commits found in range", file=sys.stderr)
69
- sys.exit(0)
70
-
42
+ def _workdir_patches(
43
+ repo_path: str,
44
+ ) -> Iterator[tuple[str, bytes, list[pygit2.DiffHunk]]]:
45
+ """Yield (file_path, content, hunks) for uncommitted changes vs HEAD."""
46
+ repo = pygit2.Repository(repo_path)
47
+ head_oid = repo.revparse_single("HEAD").id
48
+ diff = repo.diff(a=head_oid)
49
+ diff.find_similar()
50
+ workdir = Path(repo.workdir)
51
+ for patch in diff:
52
+ if patch is None:
53
+ continue
54
+ file_path = patch.delta.new_file.path
55
+ if _extension(file_path) not in SUPPORTED_EXTENSIONS:
56
+ continue
57
+ hunks = list(patch.hunks)
58
+ if not hunks:
59
+ continue
60
+ full_path = workdir / file_path
61
+ if not full_path.exists():
62
+ continue
63
+ yield file_path, full_path.read_bytes(), hunks
64
+
65
+
66
+ def _score_patches(
67
+ patches: Iterator[tuple[str, bytes, list[pygit2.DiffHunk]]],
68
+ vectorizer: Any,
69
+ model: JEPAArgot,
70
+ label: str,
71
+ ) -> tuple[list[tuple[float, str, int, str]], int]:
72
+ """Score hunk patches; returns (results, total_hunk_count)."""
73
+ context_lines = 50
71
74
  results: list[tuple[float, str, int, str]] = []
75
+ hunk_count = 0
72
76
 
73
- context_lines = 50
74
77
  with torch.no_grad():
75
- for commit, file_path, post_blob, hunks in walk_commits(args.repo_path, shas):
78
+ for file_path, post_blob, hunks in patches:
76
79
  lang = language_for_path(file_path)
77
80
  if lang is None:
78
81
  continue
@@ -82,6 +85,7 @@ def main() -> None:
82
85
  continue
83
86
 
84
87
  for hunk in hunks:
88
+ hunk_count += 1
85
89
  hunk_start = hunk.new_start - 1
86
90
  hunk_end = hunk_start + hunk.new_lines
87
91
  if hunk_start < 0 or hunk_end > len(source_lines):
@@ -102,17 +106,89 @@ def main() -> None:
102
106
  )
103
107
 
104
108
  score = model.surprise(ctx_vec, hunk_vec).item()
105
- results.append((score, file_path, hunk.new_start, str(commit.id)[:8]))
109
+ results.append((score, file_path, hunk.new_start, label))
110
+
111
+ return results, hunk_count
112
+
113
+
114
+ def main() -> None:
115
+ parser = argparse.ArgumentParser(description="Check code surprise with argot JEPA model")
116
+ parser.add_argument("repo_path")
117
+ parser.add_argument("ref", nargs="?", default="")
118
+ parser.add_argument("--model", default=".argot/model.pkl")
119
+ parser.add_argument("--threshold", type=float, default=0.5)
120
+ args = parser.parse_args()
121
+
122
+ model_path = Path(args.model)
123
+ if not model_path.exists():
124
+ print(f"error: model not found at {model_path}", file=sys.stderr)
125
+ sys.exit(2)
126
+
127
+ bundle = joblib.load(model_path)
128
+ vectorizer = bundle["vectorizer"]
129
+ embed_dim: int = bundle["embed_dim"]
130
+ input_dim: int = bundle["input_dim"]
131
+
132
+ encoder = TokenEncoder(input_dim, embed_dim)
133
+ encoder.load_state_dict(bundle["encoder_state"])
134
+ predictor = ArgotPredictor(embed_dim=embed_dim)
135
+ predictor.load_state_dict(bundle["predictor_state"])
136
+ model = JEPAArgot(encoder, predictor)
137
+ model.eval()
138
+
139
+ if args.ref == "":
140
+ patches: Iterator[tuple[str, bytes, list[pygit2.DiffHunk]]] = _workdir_patches(
141
+ args.repo_path
142
+ )
143
+ context_label = "workdir"
144
+ commit_info = "working tree"
145
+ else:
146
+ repo = pygit2.Repository(args.repo_path)
147
+ shas = _resolve_shas(repo, args.ref)
148
+ if not shas:
149
+ print("No commits found in range", file=sys.stderr)
150
+ sys.exit(0)
151
+
152
+ def _committed_patches() -> Iterator[tuple[str, bytes, list[pygit2.DiffHunk]]]:
153
+ for _commit, file_path, post_blob, hunks in walk_commits(args.repo_path, shas):
154
+ yield file_path, post_blob, hunks
155
+
156
+ patches = _committed_patches()
157
+ context_label = args.ref
158
+ commit_info = f"{len(shas)} commit(s)"
159
+
160
+ results, hunk_count = _score_patches(patches, vectorizer, model, context_label)
106
161
 
107
162
  if not results:
108
- print("No hunks found in range")
163
+ if hunk_count == 0:
164
+ exts = " ".join(sorted(SUPPORTED_EXTENSIONS))
165
+ print(
166
+ f"No changes to supported files found ({commit_info} scanned).\n"
167
+ f"Supported extensions: {exts}"
168
+ )
169
+ if args.ref != "":
170
+ print("Try a wider range, e.g.: argot check HEAD~20..HEAD")
171
+ else:
172
+ print(
173
+ f"All {hunk_count} hunk(s) scored below threshold {args.threshold:.2f}"
174
+ " — looks clean."
175
+ )
109
176
  sys.exit(0)
110
177
 
111
178
  results.sort(key=lambda r: r[0], reverse=True)
112
179
 
113
- print(f"{'SURPRISE':>9} {'FILE':<48} {'LINE':>5} COMMIT")
114
- for score, fp, line, sha in results:
115
- print(f"{score:>9.4f} {fp:<48} {line:>5} {sha}")
180
+ t = args.threshold
181
+ print(f"{'SURPRISE':>9} {'TAG':<10} {'FILE':<48} {'LINE':>5} REF")
182
+ for score, fp, line, ref in results:
183
+ if score <= t:
184
+ tag = "ok"
185
+ elif score <= t + 0.3:
186
+ tag = "unusual"
187
+ elif score <= t + 0.6:
188
+ tag = "suspicious"
189
+ else:
190
+ tag = "foreign"
191
+ print(f"{score:>9.4f} {tag:<10} {fp:<48} {line:>5} {ref}")
116
192
 
117
193
  if any(s > args.threshold for s, *_ in results):
118
194
  sys.exit(1)
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import cast
5
+
6
+ import pygit2
7
+
8
+ from argot.check import _resolve_shas, _workdir_patches
9
+
10
+
11
+ def _make_repo(tmp_path: Path, files: dict[str, str]) -> pygit2.Repository:
12
+ """Create a repo with a single commit containing the given files."""
13
+ repo = pygit2.init_repository(str(tmp_path))
14
+ sig = pygit2.Signature("Test", "test@example.com")
15
+ for name, content in files.items():
16
+ (tmp_path / name).write_text(content)
17
+ repo.index.add(name)
18
+ repo.index.write()
19
+ tree = repo.index.write_tree()
20
+ repo.create_commit("refs/heads/main", sig, sig, "init", tree, [])
21
+ repo.set_head("refs/heads/main")
22
+ return repo
23
+
24
+
25
+ def _make_two_commit_repo(tmp_path: Path) -> pygit2.Repository:
26
+ repo = pygit2.init_repository(str(tmp_path))
27
+ sig = pygit2.Signature("Test", "test@example.com")
28
+ f = tmp_path / "main.py"
29
+ f.write_text("x = 1\n")
30
+ repo.index.add("main.py")
31
+ repo.index.write()
32
+ tree1 = repo.index.write_tree()
33
+ c1 = repo.create_commit("refs/heads/main", sig, sig, "first", tree1, [])
34
+ f.write_text("x = 2\n")
35
+ repo.index.add("main.py")
36
+ repo.index.write()
37
+ tree2 = repo.index.write_tree()
38
+ repo.create_commit("refs/heads/main", sig, sig, "second", tree2, [c1])
39
+ return repo
40
+
41
+
42
+ def test_resolve_shas_range(tmp_path: Path) -> None:
43
+ repo = _make_two_commit_repo(tmp_path)
44
+ head_oid = repo.references["refs/heads/main"].target
45
+ commit = cast(pygit2.Commit, repo.get(head_oid))
46
+ parent_oid = commit.parents[0].id
47
+
48
+ shas = _resolve_shas(repo, f"{parent_oid}..refs/heads/main")
49
+ assert str(head_oid) in shas
50
+ assert str(parent_oid) not in shas
51
+
52
+
53
+ def test_resolve_shas_bare_ref(tmp_path: Path) -> None:
54
+ repo = _make_two_commit_repo(tmp_path)
55
+ head_oid = str(repo.references["refs/heads/main"].target)
56
+ shas = _resolve_shas(repo, "refs/heads/main")
57
+ assert head_oid in shas
58
+
59
+
60
+ def test_workdir_patches_detects_modification(tmp_path: Path) -> None:
61
+ _make_repo(tmp_path, {"main.py": "x = 1\n"})
62
+ (tmp_path / "main.py").write_text("x = 1\ny = 2\n")
63
+
64
+ patches = list(_workdir_patches(str(tmp_path)))
65
+ assert len(patches) == 1
66
+ file_path, content, hunks = patches[0]
67
+ assert file_path == "main.py"
68
+ assert b"y = 2" in content
69
+ assert len(hunks) > 0
70
+
71
+
72
+ def test_workdir_patches_ignores_unsupported_extension(tmp_path: Path) -> None:
73
+ _make_repo(tmp_path, {"config.json": "{}\n"})
74
+ (tmp_path / "config.json").write_text('{"key": "value"}\n')
75
+
76
+ patches = list(_workdir_patches(str(tmp_path)))
77
+ assert patches == []
78
+
79
+
80
+ def test_workdir_patches_ignores_deleted_files(tmp_path: Path) -> None:
81
+ _make_repo(tmp_path, {"main.py": "x = 1\n"})
82
+ (tmp_path / "main.py").unlink()
83
+
84
+ patches = list(_workdir_patches(str(tmp_path)))
85
+ assert patches == []
86
+
87
+
88
+ def test_workdir_patches_no_changes(tmp_path: Path) -> None:
89
+ _make_repo(tmp_path, {"main.py": "x = 1\n"})
90
+
91
+ patches = list(_workdir_patches(str(tmp_path)))
92
+ assert patches == []
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: argot-engine
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Requires-Python: >=3.11
5
5
  Requires-Dist: pygit2==1.19.2
6
6
  Requires-Dist: scikit-learn>=1.5.0
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "argot-engine"
3
- version = "0.2.0"
3
+ version = "0.2.2"
4
4
  requires-python = ">=3.11"
5
5
  dependencies = [
6
6
  "pygit2==1.19.2",
@@ -1,43 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from pathlib import Path
4
- from typing import cast
5
-
6
- import pygit2
7
-
8
- from argot.check import _resolve_shas
9
-
10
-
11
- def _make_two_commit_repo(tmp_path: Path) -> pygit2.Repository:
12
- repo = pygit2.init_repository(str(tmp_path))
13
- sig = pygit2.Signature("Test", "test@example.com")
14
- f = tmp_path / "main.py"
15
- f.write_text("x = 1\n")
16
- repo.index.add("main.py")
17
- repo.index.write()
18
- tree1 = repo.index.write_tree()
19
- c1 = repo.create_commit("refs/heads/main", sig, sig, "first", tree1, [])
20
- f.write_text("x = 2\n")
21
- repo.index.add("main.py")
22
- repo.index.write()
23
- tree2 = repo.index.write_tree()
24
- repo.create_commit("refs/heads/main", sig, sig, "second", tree2, [c1])
25
- return repo
26
-
27
-
28
- def test_resolve_shas_range(tmp_path: Path) -> None:
29
- repo = _make_two_commit_repo(tmp_path)
30
- head_oid = repo.references["refs/heads/main"].target
31
- commit = cast(pygit2.Commit, repo.get(head_oid))
32
- parent_oid = commit.parents[0].id
33
-
34
- shas = _resolve_shas(repo, f"{parent_oid}..refs/heads/main")
35
- assert str(head_oid) in shas
36
- assert str(parent_oid) not in shas
37
-
38
-
39
- def test_resolve_shas_bare_ref(tmp_path: Path) -> None:
40
- repo = _make_two_commit_repo(tmp_path)
41
- head_oid = str(repo.references["refs/heads/main"].target)
42
- shas = _resolve_shas(repo, "refs/heads/main")
43
- assert head_oid in shas
File without changes
File without changes