argot-engine 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- argot_engine-0.1.4/PKG-INFO +12 -0
- argot_engine-0.1.4/README.md +3 -0
- argot_engine-0.1.4/argot/__init__.py +0 -0
- argot_engine-0.1.4/argot/__main__.py +3 -0
- argot_engine-0.1.4/argot/check.py +122 -0
- argot_engine-0.1.4/argot/dataset.py +29 -0
- argot_engine-0.1.4/argot/explain.py +167 -0
- argot_engine-0.1.4/argot/extract.py +121 -0
- argot_engine-0.1.4/argot/fetch.py +87 -0
- argot_engine-0.1.4/argot/git_walk.py +109 -0
- argot_engine-0.1.4/argot/jepa/__init__.py +0 -0
- argot_engine-0.1.4/argot/jepa/encoder.py +22 -0
- argot_engine-0.1.4/argot/jepa/model.py +60 -0
- argot_engine-0.1.4/argot/jepa/predictor.py +153 -0
- argot_engine-0.1.4/argot/jepa/sigreg.py +33 -0
- argot_engine-0.1.4/argot/tests/__init__.py +0 -0
- argot_engine-0.1.4/argot/tests/test_check.py +43 -0
- argot_engine-0.1.4/argot/tests/test_explain.py +54 -0
- argot_engine-0.1.4/argot/tests/test_extract_smoke.py +38 -0
- argot_engine-0.1.4/argot/tests/test_fetch.py +44 -0
- argot_engine-0.1.4/argot/tests/test_git_walk.py +48 -0
- argot_engine-0.1.4/argot/tests/test_jepa.py +70 -0
- argot_engine-0.1.4/argot/tests/test_tokenize.py +41 -0
- argot_engine-0.1.4/argot/tests/test_train_smoke.py +73 -0
- argot_engine-0.1.4/argot/tests/test_validate.py +87 -0
- argot_engine-0.1.4/argot/tokenize.py +82 -0
- argot_engine-0.1.4/argot/train.py +121 -0
- argot_engine-0.1.4/argot/validate.py +193 -0
- argot_engine-0.1.4/argot_engine.egg-info/PKG-INFO +12 -0
- argot_engine-0.1.4/argot_engine.egg-info/SOURCES.txt +34 -0
- argot_engine-0.1.4/argot_engine.egg-info/dependency_links.txt +1 -0
- argot_engine-0.1.4/argot_engine.egg-info/entry_points.txt +7 -0
- argot_engine-0.1.4/argot_engine.egg-info/requires.txt +8 -0
- argot_engine-0.1.4/argot_engine.egg-info/top_level.txt +1 -0
- argot_engine-0.1.4/pyproject.toml +61 -0
- argot_engine-0.1.4/setup.cfg +4 -0
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: argot-engine
|
|
3
|
+
Version: 0.1.4
|
|
4
|
+
Requires-Python: >=3.11
|
|
5
|
+
Requires-Dist: pygit2==1.19.2
|
|
6
|
+
Requires-Dist: scikit-learn>=1.5.0
|
|
7
|
+
Requires-Dist: torch>=2.0.0
|
|
8
|
+
Requires-Dist: einops>=0.6.0
|
|
9
|
+
Requires-Dist: tree-sitter==0.23.2
|
|
10
|
+
Requires-Dist: tree-sitter-typescript==0.23.2
|
|
11
|
+
Requires-Dist: tree-sitter-javascript==0.23.1
|
|
12
|
+
Requires-Dist: tree-sitter-python==0.23.6
|
|
File without changes
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import sys
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import joblib # type: ignore[import-untyped]
|
|
8
|
+
import pygit2
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from argot.git_walk import walk_commits
|
|
12
|
+
from argot.jepa.encoder import TokenEncoder
|
|
13
|
+
from argot.jepa.model import JEPAArgot
|
|
14
|
+
from argot.jepa.predictor import ArgotPredictor
|
|
15
|
+
from argot.tokenize import language_for_path, tokenize_lines
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _resolve_shas(repo: pygit2.Repository, ref: str) -> set[str]:
    """Resolve a git revision range into the set of commit SHAs it contains.

    ``A..B`` follows ``git rev-list A..B`` semantics: every commit reachable
    from ``B`` that is not reachable from ``A``. A bare ``REF`` is treated as
    ``REF^..REF``. When the lower bound cannot be resolved (for example
    ``REF^`` of a root commit), the whole history reachable from the upper
    bound is returned.

    Args:
        repo: repository to resolve revisions in.
        ref: ``"A..B"`` range or a single revision.

    Returns:
        Full hex SHAs of the commits in the range.

    Raises:
        KeyError / pygit2.GitError: if the upper bound cannot be resolved
            (propagated from ``revparse_single``).
    """
    if ".." in ref:
        start_ref, end_ref = ref.split("..", 1)
    else:
        start_ref, end_ref = f"{ref}^", ref

    end_oid = repo.revparse_single(end_ref).id
    try:
        start_oid = repo.revparse_single(start_ref).id
    except (pygit2.GitError, KeyError):
        start_oid = None

    walker = repo.walk(end_oid, pygit2.enums.SortMode.TOPOLOGICAL)
    if start_oid is not None:
        # hide() excludes the lower bound and all of its ancestors. The old
        # "stop when we meet start" loop missed merged-in commits ordered after
        # it, walked all history when start was not an ancestor of end, and
        # never matched when start was an annotated tag.
        walker.hide(start_oid)
    return {str(commit.id) for commit in walker}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def main(argv: list[str] | None = None) -> None:
    """Score every changed hunk in a git range and print them by surprise.

    Loads the JEPA model bundle, resolves ``ref`` to a set of commits, and
    scores each added/modified hunk in a supported language. Each hunk is
    scored against up to 50 lines of preceding context in the post-commit file.

    Args:
        argv: command-line arguments; defaults to ``sys.argv[1:]``.

    Exit status: 2 if the model bundle is missing; 1 if any hunk scores above
    ``--threshold``; otherwise 0 (including when there is nothing to score).
    """
    parser = argparse.ArgumentParser(description="Check code surprise with argot JEPA model")
    parser.add_argument("repo_path")
    parser.add_argument("ref")
    parser.add_argument("--model", default=".argot/model.pkl")
    parser.add_argument("--threshold", type=float, default=0.5)
    args = parser.parse_args(argv)

    model_path = Path(args.model)
    if not model_path.exists():
        print(f"error: model not found at {model_path}", file=sys.stderr)
        sys.exit(2)

    # SECURITY: joblib.load unpickles arbitrary objects; only load model
    # bundles from a trusted source.
    bundle = joblib.load(model_path)
    vectorizer = bundle["vectorizer"]
    embed_dim: int = bundle["embed_dim"]
    input_dim: int = bundle["input_dim"]

    encoder = TokenEncoder(input_dim, embed_dim)
    encoder.load_state_dict(bundle["encoder_state"])
    predictor = ArgotPredictor(embed_dim=embed_dim)
    predictor.load_state_dict(bundle["predictor_state"])
    model = JEPAArgot(encoder, predictor)
    model.eval()

    repo = pygit2.Repository(args.repo_path)
    shas = _resolve_shas(repo, args.ref)
    if not shas:
        print("No commits found in range", file=sys.stderr)
        sys.exit(0)

    def vectorize(text: str) -> torch.Tensor:
        """Encode one text as a one-row float32 tensor with the bundle's vectorizer."""
        return torch.tensor(vectorizer.transform([text]).toarray(), dtype=torch.float32)

    results: list[tuple[float, str, int, str]] = []
    context_lines = 50
    with torch.no_grad():
        for commit, file_path, post_blob, hunks in walk_commits(args.repo_path, shas):
            lang = language_for_path(file_path)
            if lang is None:
                continue
            # errors="replace" cannot raise, so no try/except is needed.
            source_lines = post_blob.decode("utf-8", errors="replace").splitlines()

            for hunk in hunks:
                if hunk.new_lines == 0:
                    # Pure deletion: there is no new code to judge. Previously
                    # an empty hunk text was scored and reported.
                    continue
                hunk_start = hunk.new_start - 1  # 0-indexed
                hunk_end = hunk_start + hunk.new_lines
                if hunk_start < 0 or hunk_end > len(source_lines):
                    continue

                before_start = max(0, hunk_start - context_lines)
                ctx_tokens = tokenize_lines(source_lines, lang, before_start, hunk_start)
                hunk_tokens = tokenize_lines(source_lines, lang, hunk_start, hunk_end)
                ctx_text = " ".join(t.text for t in ctx_tokens)
                hunk_text = " ".join(t.text for t in hunk_tokens)

                score = model.surprise(vectorize(ctx_text), vectorize(hunk_text)).item()
                results.append((score, file_path, hunk.new_start, str(commit.id)[:8]))

    if not results:
        print("No hunks found in range")
        sys.exit(0)

    results.sort(key=lambda r: r[0], reverse=True)

    print(f"{'SURPRISE':>9} {'FILE':<48} {'LINE':>5} COMMIT")
    for score, fp, line, sha in results:
        print(f"{score:>9.4f} {fp:<48} {line:>5} {sha}")

    if any(s > args.threshold for s, *_ in results):
        sys.exit(1)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Wire format for the dataset JSONL. Every hunk emitted conforms to this schema.
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Literal
|
|
6
|
+
|
|
7
|
+
Language = Literal["typescript", "javascript", "python"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True, slots=True)
|
|
11
|
+
class Token:
|
|
12
|
+
text: str
|
|
13
|
+
node_type: str # tree-sitter node kind, e.g. "function_declaration"
|
|
14
|
+
start_line: int
|
|
15
|
+
end_line: int
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass(frozen=True, slots=True)
|
|
19
|
+
class HunkRecord:
|
|
20
|
+
commit_sha: str
|
|
21
|
+
file_path: str
|
|
22
|
+
language: Language
|
|
23
|
+
hunk_start_line: int
|
|
24
|
+
hunk_end_line: int
|
|
25
|
+
context_before: list[Token] # up to 50 lines before, tokenized
|
|
26
|
+
hunk_tokens: list[Token]
|
|
27
|
+
context_after: list[Token] # up to 50 lines after
|
|
28
|
+
parent_sha: str | None
|
|
29
|
+
author_date_iso: str
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import json
|
|
5
|
+
import sys
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import joblib # type: ignore[import-untyped]
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pygit2
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from argot.check import _resolve_shas
|
|
15
|
+
from argot.git_walk import walk_commits
|
|
16
|
+
from argot.jepa.encoder import TokenEncoder
|
|
17
|
+
from argot.jepa.model import JEPAArgot
|
|
18
|
+
from argot.jepa.predictor import ArgotPredictor
|
|
19
|
+
from argot.tokenize import language_for_path, tokenize_lines
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def percentile_rank(value: float, distribution: list[float]) -> float:
    """Return the percentage (0-100) of ``distribution`` strictly below ``value``.

    Args:
        value: score to rank.
        distribution: reference scores (any 1-D array-like of numbers).

    Raises:
        ValueError: if ``distribution`` is empty. The rank is undefined there;
            ``np.mean`` of an empty array would return NaN with a
            RuntimeWarning, and that NaN ends up in the JSON output.
    """
    if len(distribution) == 0:
        raise ValueError("distribution must be non-empty")
    arr = np.asarray(distribution, dtype=float)
    return float(np.mean(arr < value) * 100)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def select_style_examples(records: list[dict[str, Any]], *, n: int = 5) -> list[dict[str, Any]]:
    """Return up to ``n`` low-surprise records to show as style references.

    Records are ranked by ascending ``"_score"``. The first record seen for
    each ``file_path`` is preferred (a missing path counts as ``""``), so the
    examples come from different files. If that yields fewer than ``n``, the
    shortfall is filled with the lowest-scoring repeats.
    """
    ranked = sorted(records, key=lambda rec: rec["_score"])
    picked: list[dict[str, Any]] = []
    repeats: list[dict[str, Any]] = []
    files_used: set[str] = set()

    for rec in ranked:
        if len(picked) >= n:
            break
        path = rec.get("file_path", "")
        if path in files_used:
            repeats.append(rec)
            continue
        files_used.add(path)
        picked.append(rec)

    chosen = picked[:n]
    shortfall = n - len(chosen)
    if shortfall > 0:
        chosen.extend(repeats[:shortfall])
    return chosen
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _score_dataset(
|
|
51
|
+
model: JEPAArgot,
|
|
52
|
+
vectorizer: Any,
|
|
53
|
+
records: list[dict[str, Any]],
|
|
54
|
+
) -> list[dict[str, Any]]:
|
|
55
|
+
ctx_texts = [" ".join(t["text"] for t in r["context_before"]) for r in records]
|
|
56
|
+
hunk_texts = [" ".join(t["text"] for t in r["hunk_tokens"]) for r in records]
|
|
57
|
+
ctx_x = torch.tensor(vectorizer.transform(ctx_texts).toarray(), dtype=torch.float32)
|
|
58
|
+
hunk_x = torch.tensor(vectorizer.transform(hunk_texts).toarray(), dtype=torch.float32)
|
|
59
|
+
model.eval()
|
|
60
|
+
scored = []
|
|
61
|
+
with torch.no_grad():
|
|
62
|
+
for i, record in enumerate(records):
|
|
63
|
+
score = model.surprise(ctx_x[i : i + 1], hunk_x[i : i + 1]).item()
|
|
64
|
+
scored.append({**record, "_score": score})
|
|
65
|
+
return scored
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def main(argv: list[str] | None = None) -> None:
    """Print one JSON line per hunk in ``ref`` whose surprise is unusually high.

    The dataset at ``--dataset`` is re-scored with the model to form a
    reference distribution. Each added/modified hunk in the range is scored
    against up to 50 lines of preceding context. A hunk is reported when its
    percentile rank within the distribution is at least
    ``--threshold-percentile``. Every report carries the ``--examples``
    lowest-surprise dataset hunks as style examples.

    Args:
        argv: command-line arguments; defaults to ``sys.argv[1:]``.

    Exits with status 2 if the model or dataset is missing, or the dataset is
    empty; with 0 when the range contains no commits.
    """
    parser = argparse.ArgumentParser(description="Explain style anomalies in a git ref")
    parser.add_argument("repo_path")
    parser.add_argument("ref")
    parser.add_argument("--model", default=".argot/model.pkl")
    parser.add_argument("--dataset", default=".argot/dataset.jsonl")
    parser.add_argument("--threshold-percentile", type=float, default=75.0)
    parser.add_argument("--examples", type=int, default=5)
    args = parser.parse_args(argv)

    model_path = Path(args.model)
    if not model_path.exists():
        print(f"error: model not found at {model_path}", file=sys.stderr)
        sys.exit(2)

    dataset_path = Path(args.dataset)
    if not dataset_path.exists():
        print(f"error: dataset not found at {dataset_path}", file=sys.stderr)
        sys.exit(2)

    # SECURITY: joblib.load unpickles arbitrary objects; only load model
    # bundles from a trusted source.
    bundle = joblib.load(model_path)
    vectorizer = bundle["vectorizer"]
    input_dim: int = bundle["input_dim"]
    embed_dim: int = bundle["embed_dim"]

    encoder = TokenEncoder(input_dim, embed_dim)
    encoder.load_state_dict(bundle["encoder_state"])
    predictor = ArgotPredictor(embed_dim=embed_dim)
    predictor.load_state_dict(bundle["predictor_state"])
    model = JEPAArgot(encoder, predictor)
    model.eval()

    raw_dataset = [
        json.loads(line)
        for line in dataset_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    if not raw_dataset:
        # An empty reference distribution makes every percentile NaN, which
        # json.dumps would emit as the non-standard token NaN.
        print(f"error: dataset at {dataset_path} has no records", file=sys.stderr)
        sys.exit(2)

    scored_dataset = _score_dataset(model, vectorizer, raw_dataset)
    distribution = [r["_score"] for r in scored_dataset]
    style_examples = select_style_examples(scored_dataset, n=args.examples)
    example_texts = [" ".join(t["text"] for t in r["hunk_tokens"]) for r in style_examples]

    repo = pygit2.Repository(args.repo_path)
    shas = _resolve_shas(repo, args.ref)
    if not shas:
        sys.exit(0)

    def vectorize(text: str) -> torch.Tensor:
        """Encode one text as a one-row float32 tensor with the bundle's vectorizer."""
        return torch.tensor(vectorizer.transform([text]).toarray(), dtype=torch.float32)

    context_lines = 50
    with torch.no_grad():
        for commit, file_path, post_blob, hunks in walk_commits(args.repo_path, shas):
            lang = language_for_path(file_path)
            if lang is None:
                continue
            # errors="replace" cannot raise, so no try/except is needed.
            source_lines = post_blob.decode("utf-8", errors="replace").splitlines()

            for hunk in hunks:
                if hunk.new_lines == 0:
                    # Pure deletion: there is no new code to explain.
                    continue
                hunk_start = hunk.new_start - 1  # 0-indexed
                hunk_end = hunk_start + hunk.new_lines
                if hunk_start < 0 or hunk_end > len(source_lines):
                    continue

                before_start = max(0, hunk_start - context_lines)
                ctx_tokens = tokenize_lines(source_lines, lang, before_start, hunk_start)
                hunk_tokens = tokenize_lines(source_lines, lang, hunk_start, hunk_end)
                ctx_text = " ".join(t.text for t in ctx_tokens)
                hunk_text = " ".join(t.text for t in hunk_tokens)

                score = model.surprise(vectorize(ctx_text), vectorize(hunk_text)).item()
                pct = percentile_rank(score, distribution)
                if pct < args.threshold_percentile:
                    continue

                print(
                    json.dumps(
                        {
                            "file_path": file_path,
                            "line": hunk.new_start,
                            "commit": str(commit.id)[:8],
                            "surprise": round(score, 4),
                            "percentile": round(pct, 1),
                            "hunk_text": hunk_text,
                            "context_text": ctx_text,
                            "style_examples": example_texts,
                        }
                    )
                )


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import json
|
|
5
|
+
import sys
|
|
6
|
+
from dataclasses import asdict
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
import pygit2
|
|
10
|
+
|
|
11
|
+
from argot.dataset import HunkRecord
|
|
12
|
+
from argot.git_walk import walk_repo
|
|
13
|
+
from argot.tokenize import language_for_path
|
|
14
|
+
|
|
15
|
+
CONTEXT_LINES = 50
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _extract_context(
|
|
19
|
+
source_lines: list[str],
|
|
20
|
+
hunk_start: int,
|
|
21
|
+
hunk_end: int,
|
|
22
|
+
) -> tuple[list[str], list[str], list[str]]:
|
|
23
|
+
before_start = max(0, hunk_start - CONTEXT_LINES)
|
|
24
|
+
after_end = min(len(source_lines), hunk_end + CONTEXT_LINES)
|
|
25
|
+
return (
|
|
26
|
+
source_lines[before_start:hunk_start],
|
|
27
|
+
source_lines[hunk_start:hunk_end],
|
|
28
|
+
source_lines[hunk_end:after_end],
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def main(argv: list[str] | None = None) -> None:
    """Write one JSONL ``HunkRecord`` per changed hunk in a repository's history.

    Walks the commits yielded by ``walk_repo`` and keeps files whose language
    ``language_for_path`` recognises. For each hunk it tokenizes the hunk plus
    up to ``CONTEXT_LINES`` lines of context on each side. Records are written
    to ``--out`` until ``--limit`` records have been emitted.

    Args:
        argv: command-line arguments; defaults to ``sys.argv[1:]``.

    Exits with status 2 if the repository cannot be opened, or if it yields no
    hunks and no limit was hit.
    """
    parser = argparse.ArgumentParser(description="Extract dataset from git history")
    parser.add_argument("repo_path", help="Path to git repository")
    parser.add_argument("--out", default=".argot/dataset.jsonl", help="Output JSONL path")
    parser.add_argument("--limit", type=int, default=None, help="Max number of records to emit")
    args = parser.parse_args(argv)

    # Imported once here rather than on every hunk iteration.
    from argot.tokenize import tokenize_lines

    repo_path = args.repo_path
    out_path = Path(args.out)
    limit: int | None = args.limit

    try:
        pygit2.Repository(repo_path)
    except pygit2.GitError:
        print(f"error: repository not found at {repo_path!r}", file=sys.stderr)
        sys.exit(2)

    out_path.parent.mkdir(parents=True, exist_ok=True)

    if limit is not None and limit <= 0:
        # Nothing may be written. Previously the limit was only checked after
        # writing, so one record slipped through.
        out_path.write_text("", encoding="utf-8")
        print(f"Reached limit of {limit} records", file=sys.stderr)
        print(f"Wrote 0 records to {out_path}")
        return

    count = 0
    with out_path.open("w", encoding="utf-8") as fh:
        for commit, file_path, post_blob, hunks in walk_repo(repo_path):
            lang = language_for_path(file_path)
            if lang is None:
                continue

            # errors="replace" cannot raise, so no try/except is needed.
            source_lines = post_blob.decode("utf-8", errors="replace").splitlines()

            parent_sha = str(commit.parents[0].id) if commit.parents else None
            # NOTE(review): this is epoch seconds (Signature.time), not ISO-8601
            # as the field name says. Left unchanged because downstream readers
            # of the dataset may depend on it; confirm before converting.
            author_date = str(commit.author.time)

            for hunk in hunks:
                hunk_start = hunk.new_start - 1  # convert to 0-indexed
                hunk_end = hunk_start + hunk.new_lines
                if hunk_start < 0 or hunk_end > len(source_lines):
                    continue

                before_start = max(0, hunk_start - CONTEXT_LINES)
                after_end = min(len(source_lines), hunk_end + CONTEXT_LINES)

                record = HunkRecord(
                    commit_sha=str(commit.id),
                    file_path=file_path,
                    language=lang,
                    hunk_start_line=hunk_start,
                    hunk_end_line=hunk_end,
                    context_before=tokenize_lines(source_lines, lang, before_start, hunk_start),
                    hunk_tokens=tokenize_lines(source_lines, lang, hunk_start, hunk_end),
                    context_after=tokenize_lines(source_lines, lang, hunk_end, after_end),
                    parent_sha=parent_sha,
                    author_date_iso=author_date,
                )

                fh.write(json.dumps(asdict(record)) + "\n")
                count += 1

                if limit is not None and count >= limit:
                    print(f"Reached limit of {limit} records", file=sys.stderr)
                    print(f"Wrote {count} records to {out_path}")
                    return

    if count == 0:
        print("error: no hunks found — repository may have no history", file=sys.stderr)
        sys.exit(2)

    print(f"Wrote {count} records to {out_path}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import subprocess
|
|
4
|
+
import sys
|
|
5
|
+
import tomllib
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
from argot.extract import main as extract_main
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(frozen=True, eq=True)
class RepoConfig:
    """One ``[[repos]]`` entry from the training ``repos.toml`` file."""

    name: str  # directory / part-file stem under .argot/repos
    url: str  # clone URL handed to ``git clone``
    limit: int  # forwarded to extract as ``--limit``
    label: str  # printed in progress output
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def load_repos_config(config_path: Path) -> list[RepoConfig]:
    """Parse the ``[[repos]]`` array of a TOML file into ``RepoConfig`` objects.

    The file is opened in binary mode and handed to ``tomllib.load``, which
    decodes it as UTF-8 as the TOML spec requires. ``Path.read_text()`` would
    use the locale's encoding and can mis-decode non-ASCII text.

    Raises:
        KeyError: if the file has no top-level ``repos`` key.
        TypeError: if an entry has missing or unexpected fields.
        tomllib.TOMLDecodeError: if the file is not valid TOML.
    """
    with config_path.open("rb") as fh:
        data = tomllib.load(fh)
    return [RepoConfig(**entry) for entry in data["repos"]]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def merge_jsonl(sources: list[tuple[Path, str]], dest: Path) -> int:
    """Concatenate JSONL files into ``dest``, tagging each record with its repo.

    Blank lines are skipped. Every other line is parsed, given a ``"_repo"`` key
    set to the source's repo name, and re-serialized into ``dest``, which is
    overwritten.

    Args:
        sources: ``(jsonl_path, repo_name)`` pairs, merged in order.
        dest: output path.

    Returns:
        Number of records written.
    """
    import json

    count = 0
    with dest.open("w", encoding="utf-8") as out:
        for src, repo_name in sources:
            # Stream line by line rather than read_text(), so a large part file
            # is never held in memory as a whole.
            with src.open(encoding="utf-8") as fh:
                for line in fh:
                    if not line.strip():
                        continue
                    record = json.loads(line)
                    record["_repo"] = repo_name
                    out.write(json.dumps(record) + "\n")
                    count += 1
    return count
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _clone_or_update(url: str, dest: Path, depth: int = 2000) -> None:
|
|
41
|
+
if (dest / ".git").exists():
|
|
42
|
+
print(f" already cloned, skipping: {dest}")
|
|
43
|
+
return
|
|
44
|
+
dest.parent.mkdir(parents=True, exist_ok=True)
|
|
45
|
+
print(f" cloning {url} → {dest}")
|
|
46
|
+
subprocess.run(
|
|
47
|
+
["git", "clone", "--depth", str(depth), url, str(dest)],
|
|
48
|
+
check=True,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def main() -> None:
    """Clone each repository in ``training/repos.toml``, extract it, and merge.

    For every configured repository this clones it under
    ``.argot/repos/<name>`` (skipped if already present). It then runs
    ``argot.extract`` into ``.argot/repos/<name>.jsonl`` with the repository's
    ``limit``. Finally all parts are merged into ``.argot/training.jsonl``,
    each record tagged with ``_repo``.

    Exits with status 2 if the config file is missing. A ``SystemExit`` from
    ``extract_main`` for any repository aborts the run.
    """
    # NOTE(review): parents[2] assumes a source checkout (engine/argot/fetch.py).
    # For an installed package it resolves outside any project; confirm intent.
    project_root = Path(__file__).parents[2]  # engine/argot/fetch.py → project root
    config_path = project_root / "training" / "repos.toml"

    if not config_path.exists():
        print(f"error: config not found at {config_path}", file=sys.stderr)
        sys.exit(2)

    repos = load_repos_config(config_path)
    repos_dir = project_root / ".argot" / "repos"
    part_files: list[tuple[Path, str]] = []

    for repo in repos:
        repo_dir = repos_dir / repo.name
        part_out = repos_dir / f"{repo.name}.jsonl"
        print(f"\n[{repo.name}] label={repo.label} limit={repo.limit}")
        _clone_or_update(repo.url, repo_dir)

        # extract_main() reads its arguments from sys.argv. Swap them in, and
        # always restore the caller's argv, even if extraction exits or raises.
        # Previously sys.argv was left clobbered.
        saved_argv = sys.argv
        sys.argv = [
            "extract",
            str(repo_dir),
            "--out",
            str(part_out),
            "--limit",
            str(repo.limit),
        ]
        try:
            extract_main()
        finally:
            sys.argv = saved_argv
        part_files.append((part_out, repo.name))

    out = project_root / ".argot" / "training.jsonl"
    total = merge_jsonl(part_files, out)
    print(f"\nMerged {total} records → {out}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterator
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import pygit2
|
|
7
|
+
|
|
8
|
+
SUPPORTED_EXTENSIONS = frozenset({".ts", ".tsx", ".js", ".jsx", ".py"})
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _extension(path: str) -> str:
|
|
12
|
+
return Path(path).suffix.lower()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def walk_repo(repo_path: str) -> Iterator[tuple[pygit2.Commit, str, bytes, list[pygit2.DiffHunk]]]:
    """Yield ``(commit, file_path, post_blob_content, hunks)`` per changed supported file.

    History is walked topologically from HEAD, or from the first
    ``refs/heads/*`` reference if HEAD cannot be read. Only commits with
    exactly one parent are diffed, so merge commits and root commits are
    skipped. A file is yielded when its extension is in
    ``SUPPORTED_EXTENSIONS``, its patch has at least one hunk, and it exists as
    a blob in the commit's tree. Yields nothing for an empty repository.
    """
    repo = pygit2.Repository(repo_path)
    if repo.is_empty:
        return

    try:
        tip = repo.head.target
    except pygit2.GitError:
        heads = [name for name in repo.references if name.startswith("refs/heads/")]
        if not heads:
            return
        tip = repo.references[heads[0]].target

    for commit in repo.walk(tip, pygit2.enums.SortMode.TOPOLOGICAL):
        parents = commit.parents
        if len(parents) != 1:
            continue  # merge or root commit

        diff = parents[0].tree.diff_to_tree(commit.tree)
        diff.find_similar()  # rename detection

        for patch in diff:
            if patch is None:
                continue
            path = patch.delta.new_file.path
            if _extension(path) not in SUPPORTED_EXTENSIONS:
                continue

            hunk_list = list(patch.hunks)
            if not hunk_list:
                continue

            try:
                entry = commit.tree / path
            except KeyError:
                continue  # not present in the post-commit tree (deleted)
            if not isinstance(entry, pygit2.Blob):
                continue

            yield commit, path, entry.data, hunk_list
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def walk_commits(
    repo_path: str, shas: set[str]
) -> Iterator[tuple[pygit2.Commit, str, bytes, list[pygit2.DiffHunk]]]:
    """Yield ``(commit, file_path, post_blob_content, hunks)`` for the commits in ``shas``.

    Each SHA is looked up directly and the walk is seeded from those commits.
    Commits that HEAD cannot reach (for example a feature branch checked while
    another branch is checked out) are therefore still found. Previously the
    walk started at HEAD and silently dropped them. Commits are visited in
    topological order, and the walk stops once every requested commit has been
    seen. Entries of ``shas`` that do not name a commit are ignored.

    Merge and root commits are skipped. So are files with an unsupported
    extension, files whose patch has no hunks, and files absent from the
    commit's tree.
    """
    repo = pygit2.Repository(repo_path)
    if repo.is_empty:
        return

    tips: list[pygit2.Oid] = []
    for sha in sorted(shas):  # sorted → deterministic walker seeding
        try:
            obj = repo.get(sha)
        except (ValueError, KeyError, pygit2.GitError):
            continue  # malformed or ambiguous SHA
        if isinstance(obj, pygit2.Commit):
            tips.append(obj.id)
    if not tips:
        return

    walker = repo.walk(tips[0], pygit2.enums.SortMode.TOPOLOGICAL)
    for oid in tips[1:]:
        walker.push(oid)

    pending = {str(oid) for oid in tips}
    for commit in walker:
        sha = str(commit.id)
        if sha not in pending:
            continue  # an ancestor outside the requested set
        pending.discard(sha)

        if len(commit.parents) == 1:  # skip merge and root commits
            parent = commit.parents[0]
            diff = parent.tree.diff_to_tree(commit.tree)
            diff.find_similar()

            for patch in diff:
                if patch is None:
                    continue
                file_path = patch.delta.new_file.path
                if _extension(file_path) not in SUPPORTED_EXTENSIONS:
                    continue

                hunks = list(patch.hunks)
                if not hunks:
                    continue

                try:
                    obj = commit.tree / file_path
                except KeyError:
                    continue  # file deleted in this commit
                if not isinstance(obj, pygit2.Blob):
                    continue

                yield commit, file_path, obj.data, hunks

        if not pending:
            break  # every requested commit has been visited
|
|
File without changes
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import cast
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TokenEncoder(nn.Module):
    """Two-layer MLP mapping TF-IDF vectors to latent embeddings.

    Layout: ``Linear(input_dim → 512) → GELU → Linear(512 → embed_dim)``.
    The layers live in ``self.net`` (an ``nn.Sequential``), so ``state_dict``
    keys are ``net.0.*`` and ``net.2.*``; saved bundles depend on that.
    """

    def __init__(self, input_dim: int = 5000, embed_dim: int = 192) -> None:
        super().__init__()
        hidden_width = 512
        layers: list[nn.Module] = [
            nn.Linear(input_dim, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, embed_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape (B, input_dim) into shape (B, embed_dim)."""
        encoded = self.net(x)
        return cast(torch.Tensor, encoded)
|