argot-engine 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
argot/__init__.py ADDED
File without changes
argot/__main__.py ADDED
@@ -0,0 +1,3 @@
1
"""Support ``python -m argot`` by running the dataset extractor."""

from argot.extract import main as _extract_main

_extract_main()
argot/check.py ADDED
@@ -0,0 +1,122 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import joblib # type: ignore[import-untyped]
8
+ import pygit2
9
+ import torch
10
+
11
+ from argot.git_walk import walk_commits
12
+ from argot.jepa.encoder import TokenEncoder
13
+ from argot.jepa.model import JEPAArgot
14
+ from argot.jepa.predictor import ArgotPredictor
15
+ from argot.tokenize import language_for_path, tokenize_lines
16
+
17
+
18
def _resolve_shas(repo: pygit2.Repository, ref: str) -> set[str]:
    """Parse a git revision range into the set of commit SHAs it selects.

    ``A..B`` follows git semantics: every commit reachable from ``B`` that is
    not reachable from ``A``.  An empty side defaults to ``HEAD`` (``A..``
    means ``A..HEAD``).  A bare ref ``X`` is treated as ``X^..X``, i.e. just
    the commit ``X``; when ``X^`` cannot be resolved (e.g. a root commit)
    nothing is excluded and everything reachable from ``X`` is returned.

    Args:
        repo: Repository in which to resolve the revisions.
        ref: A range ``A..B`` or a single revision.

    Returns:
        Full hex SHAs of the selected commits.

    Raises:
        KeyError / pygit2.GitError / ValueError: if the end revision cannot be
            resolved to a commit.
    """
    if ".." in ref:
        start_ref, end_ref = ref.split("..", 1)
        start_ref = start_ref or "HEAD"
        end_ref = end_ref or "HEAD"
    else:
        start_ref = ref + "^"
        end_ref = ref

    # Peel so that annotated tags resolve to the commit they point at.
    end_oid = repo.revparse_single(end_ref).peel(pygit2.Commit).id
    try:
        start_oid = repo.revparse_single(start_ref).peel(pygit2.Commit).id
    except (pygit2.GitError, KeyError, ValueError):
        # e.g. ``<root>^`` does not exist: there is nothing to exclude.
        start_oid = None

    walker = repo.walk(end_oid, pygit2.enums.SortMode.TOPOLOGICAL)
    if start_oid is not None:
        # hide() excludes start *and all of its ancestors* -- exactly what
        # ``A..B`` means.  Stopping at the first occurrence of start instead
        # drops side-branch commits that sort after it and walks the whole
        # history when start is not an ancestor of end.
        walker.hide(start_oid)
    return {str(commit.id) for commit in walker}
38
+
39
+
40
def main() -> None:
    """Score the hunks of a git range with a trained argot JEPA model.

    Usage: ``check REPO_PATH REF [--model PATH] [--threshold FLOAT]``.

    Prints one row per added/modified hunk, most surprising first.

    Exit status:
        0 -- no hunk scored above ``--threshold`` (or there was nothing to score)
        1 -- at least one hunk scored above ``--threshold``
        2 -- the model file does not exist
    """
    parser = argparse.ArgumentParser(description="Check code surprise with argot JEPA model")
    parser.add_argument("repo_path")
    parser.add_argument("ref")
    parser.add_argument("--model", default=".argot/model.pkl")
    parser.add_argument("--threshold", type=float, default=0.5)
    args = parser.parse_args()

    model_path = Path(args.model)
    if not model_path.exists():
        print(f"error: model not found at {model_path}", file=sys.stderr)
        sys.exit(2)

    # NOTE: joblib.load unpickles arbitrary objects -- only load model files you trust.
    bundle = joblib.load(model_path)
    vectorizer = bundle["vectorizer"]
    embed_dim: int = bundle["embed_dim"]
    input_dim: int = bundle["input_dim"]

    encoder = TokenEncoder(input_dim, embed_dim)
    encoder.load_state_dict(bundle["encoder_state"])
    predictor = ArgotPredictor(embed_dim=embed_dim)
    predictor.load_state_dict(bundle["predictor_state"])
    model = JEPAArgot(encoder, predictor)
    model.eval()

    repo = pygit2.Repository(args.repo_path)
    shas = _resolve_shas(repo, args.ref)
    if not shas:
        print("No commits found in range", file=sys.stderr)
        sys.exit(0)

    # (surprise, file path, 1-based first line of the hunk, short commit SHA)
    results: list[tuple[float, str, int, str]] = []

    context_lines = 50
    with torch.no_grad():
        for commit, file_path, post_blob, hunks in walk_commits(args.repo_path, shas):
            lang = language_for_path(file_path)
            if lang is None:
                continue
            # errors="replace" never raises, so no try/except is needed here.
            source_lines = post_blob.decode("utf-8", errors="replace").splitlines()

            for hunk in hunks:
                if hunk.new_lines == 0:
                    # Pure deletion: the new side has no lines, so there is
                    # nothing to score (an empty hunk only adds noise).
                    continue
                hunk_start = hunk.new_start - 1  # 0-based, inclusive
                hunk_end = hunk_start + hunk.new_lines  # 0-based, exclusive
                if hunk_start < 0 or hunk_end > len(source_lines):
                    continue

                before_start = max(0, hunk_start - context_lines)
                ctx_tokens = tokenize_lines(source_lines, lang, before_start, hunk_start)
                hunk_tokens = tokenize_lines(source_lines, lang, hunk_start, hunk_end)

                ctx_text = " ".join(t.text for t in ctx_tokens)
                hunk_text = " ".join(t.text for t in hunk_tokens)

                ctx_vec = torch.tensor(
                    vectorizer.transform([ctx_text]).toarray(), dtype=torch.float32
                )
                hunk_vec = torch.tensor(
                    vectorizer.transform([hunk_text]).toarray(), dtype=torch.float32
                )

                score = model.surprise(ctx_vec, hunk_vec).item()
                results.append((score, file_path, hunk.new_start, str(commit.id)[:8]))

    if not results:
        print("No hunks found in range")
        sys.exit(0)

    results.sort(key=lambda r: r[0], reverse=True)

    print(f"{'SURPRISE':>9} {'FILE':<48} {'LINE':>5} COMMIT")
    for score, fp, line, sha in results:
        print(f"{score:>9.4f} {fp:<48} {line:>5} {sha}")

    if any(s > args.threshold for s, *_ in results):
        sys.exit(1)


if __name__ == "__main__":
    main()
argot/dataset.py ADDED
@@ -0,0 +1,29 @@
1
+ # Wire format for the dataset JSONL. Every hunk emitted conforms to this schema.
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Literal
6
+
7
# Source languages a HunkRecord can be tagged with.
Language = Literal["typescript", "javascript", "python"]


@dataclass(frozen=True, slots=True)
class Token:
    """One token of tokenized source code, as stored in the dataset JSONL."""

    text: str  # token text; consumers join these with spaces before vectorizing
    node_type: str  # tree-sitter node kind, e.g. "function_declaration"
    start_line: int  # NOTE(review): presumably the line the token starts on; indexing base set by argot.tokenize -- confirm
    end_line: int  # NOTE(review): presumably the line the token ends on; inclusive/exclusive not visible here -- confirm
16
+
17
+
18
@dataclass(frozen=True, slots=True)
class HunkRecord:
    """One diff hunk with tokenized surrounding context; one JSONL line in the dataset.

    argot.extract serializes instances with ``dataclasses.asdict`` + ``json.dumps``.
    """

    commit_sha: str  # full hex SHA of the commit that introduced the hunk
    file_path: str  # path of the changed file on the new side of the diff
    language: Language
    hunk_start_line: int  # 0-based index of the hunk's first line in the post-commit file
    hunk_end_line: int  # 0-based exclusive end of the hunk in the post-commit file
    context_before: list[Token]  # up to 50 lines before, tokenized
    hunk_tokens: list[Token]
    context_after: list[Token]  # up to 50 lines after
    parent_sha: str | None  # first parent's SHA; None when the commit has no parents
    author_date_iso: str  # author date as text -- NOTE(review): verify the producer writes ISO-8601 as the name says
argot/explain.py ADDED
@@ -0,0 +1,167 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import sys
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import joblib # type: ignore[import-untyped]
10
+ import numpy as np
11
+ import pygit2
12
+ import torch
13
+
14
+ from argot.check import _resolve_shas
15
+ from argot.git_walk import walk_commits
16
+ from argot.jepa.encoder import TokenEncoder
17
+ from argot.jepa.model import JEPAArgot
18
+ from argot.jepa.predictor import ArgotPredictor
19
+ from argot.tokenize import language_for_path, tokenize_lines
20
+
21
+
22
def percentile_rank(value: float, distribution: list[float]) -> float:
    """Return the percentage of *distribution* that lies strictly below *value*.

    The result is in ``[0, 100]``.  Ties count as "not below", so a value equal
    to every element ranks at 0.

    Raises:
        ValueError: if *distribution* is empty.  The rank is undefined there;
            returning NaN would compare False against any threshold and
            serialize as invalid JSON.
    """
    arr = np.asarray(distribution, dtype=float)
    if arr.size == 0:
        raise ValueError("percentile_rank requires a non-empty distribution")
    return float(np.mean(arr < value) * 100)
25
+
26
+
27
def select_style_examples(records: list[dict[str, Any]], *, n: int = 5) -> list[dict[str, Any]]:
    """Choose up to *n* of the least surprising records, preferring distinct files.

    Records are ranked by ``_score`` ascending.  The first record seen for each
    ``file_path`` is taken first.  If fewer than *n* distinct files exist, the
    gap is filled with the best-ranked repeats from already-used files.
    """
    ranked = sorted(records, key=lambda rec: rec["_score"])
    files_used: set[str] = set()
    picked: list[dict[str, Any]] = []
    duplicates: list[dict[str, Any]] = []

    for rec in ranked:
        path = rec.get("file_path", "")
        if path in files_used:
            duplicates.append(rec)
        else:
            files_used.add(path)
            picked.append(rec)
        if len(picked) >= n:
            break

    chosen = picked[:n]
    shortfall = n - len(chosen)
    if shortfall > 0:
        chosen.extend(duplicates[:shortfall])
    return chosen
48
+
49
+
50
+ def _score_dataset(
51
+ model: JEPAArgot,
52
+ vectorizer: Any,
53
+ records: list[dict[str, Any]],
54
+ ) -> list[dict[str, Any]]:
55
+ ctx_texts = [" ".join(t["text"] for t in r["context_before"]) for r in records]
56
+ hunk_texts = [" ".join(t["text"] for t in r["hunk_tokens"]) for r in records]
57
+ ctx_x = torch.tensor(vectorizer.transform(ctx_texts).toarray(), dtype=torch.float32)
58
+ hunk_x = torch.tensor(vectorizer.transform(hunk_texts).toarray(), dtype=torch.float32)
59
+ model.eval()
60
+ scored = []
61
+ with torch.no_grad():
62
+ for i, record in enumerate(records):
63
+ score = model.surprise(ctx_x[i : i + 1], hunk_x[i : i + 1]).item()
64
+ scored.append({**record, "_score": score})
65
+ return scored
66
+
67
+
68
def main() -> None:
    """Emit JSON lines describing unusually surprising hunks in a git ref.

    Usage: ``explain REPO_PATH REF [--model PATH] [--dataset PATH]
    [--threshold-percentile P] [--examples N]``.

    Every dataset record is scored to build a reference distribution of
    surprise values.  Each added/modified hunk in REF is then scored.  If its
    percentile rank within that distribution is at least
    ``--threshold-percentile``, it is printed as one JSON object, together with
    the lowest-surprise dataset hunks as style examples.

    Exits with status 2 when the model or dataset file is missing, or when the
    dataset holds no records.
    """
    parser = argparse.ArgumentParser(description="Explain style anomalies in a git ref")
    parser.add_argument("repo_path")
    parser.add_argument("ref")
    parser.add_argument("--model", default=".argot/model.pkl")
    parser.add_argument("--dataset", default=".argot/dataset.jsonl")
    parser.add_argument("--threshold-percentile", type=float, default=75.0)
    parser.add_argument("--examples", type=int, default=5)
    args = parser.parse_args()

    model_path = Path(args.model)
    if not model_path.exists():
        print(f"error: model not found at {model_path}", file=sys.stderr)
        sys.exit(2)

    dataset_path = Path(args.dataset)
    if not dataset_path.exists():
        print(f"error: dataset not found at {dataset_path}", file=sys.stderr)
        sys.exit(2)

    # NOTE: joblib.load unpickles arbitrary objects -- only load model files you trust.
    bundle = joblib.load(model_path)
    vectorizer = bundle["vectorizer"]
    input_dim: int = bundle["input_dim"]
    embed_dim: int = bundle["embed_dim"]

    encoder = TokenEncoder(input_dim, embed_dim)
    encoder.load_state_dict(bundle["encoder_state"])
    predictor = ArgotPredictor(embed_dim=embed_dim)
    predictor.load_state_dict(bundle["predictor_state"])
    model = JEPAArgot(encoder, predictor)
    model.eval()

    # Stream the JSONL as UTF-8 rather than relying on the locale encoding.
    with dataset_path.open(encoding="utf-8") as fh:
        raw_dataset = [json.loads(line) for line in fh if line.strip()]
    if not raw_dataset:
        # With no reference scores every percentile rank would be undefined (NaN).
        print(f"error: dataset at {dataset_path} contains no records", file=sys.stderr)
        sys.exit(2)

    scored_dataset = _score_dataset(model, vectorizer, raw_dataset)
    distribution = [r["_score"] for r in scored_dataset]
    style_examples = select_style_examples(scored_dataset, n=args.examples)
    example_texts = [" ".join(t["text"] for t in r["hunk_tokens"]) for r in style_examples]

    repo = pygit2.Repository(args.repo_path)
    shas = _resolve_shas(repo, args.ref)
    if not shas:
        sys.exit(0)

    context_lines = 50
    with torch.no_grad():
        for commit, file_path, post_blob, hunks in walk_commits(args.repo_path, shas):
            lang = language_for_path(file_path)
            if lang is None:
                continue
            # errors="replace" never raises, so no try/except is needed here.
            source_lines = post_blob.decode("utf-8", errors="replace").splitlines()

            for hunk in hunks:
                if hunk.new_lines == 0:
                    # Pure deletion: no new-side lines to explain.
                    continue
                hunk_start = hunk.new_start - 1  # 0-based, inclusive
                hunk_end = hunk_start + hunk.new_lines  # 0-based, exclusive
                if hunk_start < 0 or hunk_end > len(source_lines):
                    continue

                before_start = max(0, hunk_start - context_lines)
                ctx_tokens = tokenize_lines(source_lines, lang, before_start, hunk_start)
                hunk_tokens = tokenize_lines(source_lines, lang, hunk_start, hunk_end)

                ctx_text = " ".join(t.text for t in ctx_tokens)
                hunk_text = " ".join(t.text for t in hunk_tokens)

                ctx_vec = torch.tensor(
                    vectorizer.transform([ctx_text]).toarray(), dtype=torch.float32
                )
                hunk_vec = torch.tensor(
                    vectorizer.transform([hunk_text]).toarray(), dtype=torch.float32
                )

                score = model.surprise(ctx_vec, hunk_vec).item()
                pct = percentile_rank(score, distribution)

                if pct < args.threshold_percentile:
                    continue

                print(
                    json.dumps(
                        {
                            "file_path": file_path,
                            "line": hunk.new_start,
                            "commit": str(commit.id)[:8],
                            "surprise": round(score, 4),
                            "percentile": round(pct, 1),
                            "hunk_text": hunk_text,
                            "context_text": ctx_text,
                            "style_examples": example_texts,
                        }
                    )
                )


if __name__ == "__main__":
    main()
argot/extract.py ADDED
@@ -0,0 +1,121 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import sys
6
+ from dataclasses import asdict
7
+ from pathlib import Path
8
+
9
+ import pygit2
10
+
11
+ from argot.dataset import HunkRecord
12
+ from argot.git_walk import walk_repo
13
+ from argot.tokenize import language_for_path
14
+
15
+ CONTEXT_LINES = 50
16
+
17
+
18
def _extract_context(
    source_lines: list[str],
    hunk_start: int,
    hunk_end: int,
) -> tuple[list[str], list[str], list[str]]:
    """Split *source_lines* into (before, hunk, after) windows around a hunk.

    ``before`` holds up to ``CONTEXT_LINES`` lines preceding ``hunk_start``.
    ``hunk`` is ``source_lines[hunk_start:hunk_end]``.  ``after`` holds up to
    ``CONTEXT_LINES`` lines starting at ``hunk_end``.  Windows are clipped to
    the file's bounds.
    """
    window_lo = hunk_start - CONTEXT_LINES
    if window_lo < 0:
        window_lo = 0
    window_hi = hunk_end + CONTEXT_LINES
    total = len(source_lines)
    if window_hi > total:
        window_hi = total

    before = source_lines[window_lo:hunk_start]
    body = source_lines[hunk_start:hunk_end]
    after = source_lines[hunk_end:window_hi]
    return before, body, after
30
+
31
+
32
def main() -> None:
    """Extract a JSONL dataset of tokenized hunks from a repository's history.

    Usage: ``extract REPO_PATH [--out PATH] [--limit N]``.

    For every added/modified hunk in a supported language, one ``HunkRecord``
    is written as a JSON line.  It holds the hunk's tokens plus up to
    ``CONTEXT_LINES`` lines of tokenized context on either side.

    Exit status 2 when the repository cannot be opened, ``--limit`` is not
    positive, or no hunks were found.
    """
    # Imported once here: tokenize_lines used to be re-imported on every hunk.
    from datetime import datetime, timedelta, timezone

    from argot.tokenize import tokenize_lines

    parser = argparse.ArgumentParser(description="Extract dataset from git history")
    parser.add_argument("repo_path", help="Path to git repository")
    parser.add_argument("--out", default=".argot/dataset.jsonl", help="Output JSONL path")
    parser.add_argument("--limit", type=int, default=None, help="Max number of records to emit")
    args = parser.parse_args()
    if args.limit is not None and args.limit < 1:
        # Previously --limit 0 still wrote one record before the check fired.
        parser.error("--limit must be a positive integer")

    repo_path = args.repo_path
    out_path = Path(args.out)

    try:
        pygit2.Repository(repo_path)
    except pygit2.GitError:
        print(f"error: repository not found at {repo_path!r}", file=sys.stderr)
        sys.exit(2)

    out_path.parent.mkdir(parents=True, exist_ok=True)

    count = 0

    with open(out_path, "w", encoding="utf-8") as fh:
        for commit, file_path, post_blob, hunks in walk_repo(repo_path):
            lang = language_for_path(file_path)
            if lang is None:
                continue

            # errors="replace" never raises, so no try/except is needed here.
            source_lines = post_blob.decode("utf-8", errors="replace").splitlines()

            parent_sha = str(commit.parents[0].id) if commit.parents else None

            # pygit2 signatures carry Unix seconds plus a UTC offset in minutes.
            # Render them as ISO-8601 to match the field name; previously the raw
            # integer was stringified.
            author = commit.author
            try:
                author_tz = timezone(timedelta(minutes=author.offset))
            except ValueError:
                author_tz = timezone.utc  # malformed offset outside +/-24h
            author_date_iso = datetime.fromtimestamp(author.time, tz=author_tz).isoformat()

            for hunk in hunks:
                if hunk.new_lines == 0:
                    # Pure deletion: there are no new-side lines to tokenize.
                    continue
                hunk_start = hunk.new_start - 1  # convert to 0-indexed
                hunk_end = hunk_start + hunk.new_lines  # exclusive

                if hunk_start < 0 or hunk_end > len(source_lines):
                    continue

                before_start = max(0, hunk_start - CONTEXT_LINES)
                after_end = min(len(source_lines), hunk_end + CONTEXT_LINES)

                context_before = tokenize_lines(source_lines, lang, before_start, hunk_start)
                hunk_tokens = tokenize_lines(source_lines, lang, hunk_start, hunk_end)
                context_after = tokenize_lines(source_lines, lang, hunk_end, after_end)

                record = HunkRecord(
                    commit_sha=str(commit.id),
                    file_path=file_path,
                    language=lang,
                    hunk_start_line=hunk_start,
                    hunk_end_line=hunk_end,
                    context_before=context_before,
                    hunk_tokens=hunk_tokens,
                    context_after=context_after,
                    parent_sha=parent_sha,
                    author_date_iso=author_date_iso,
                )

                fh.write(json.dumps(asdict(record)))
                fh.write("\n")
                count += 1

                if args.limit is not None and count >= args.limit:
                    print(f"Reached limit of {args.limit} records", file=sys.stderr)
                    print(f"Wrote {count} records to {out_path}")
                    return

    if count == 0:
        print("error: no hunks found — repository may have no history", file=sys.stderr)
        sys.exit(2)

    print(f"Wrote {count} records to {out_path}")


if __name__ == "__main__":
    main()
argot/fetch.py ADDED
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ import subprocess
4
+ import sys
5
+ import tomllib
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+
9
+ from argot.extract import main as extract_main
10
+
11
+
12
+ @dataclass(frozen=True)
13
+ class RepoConfig:
14
+ name: str
15
+ url: str
16
+ limit: int
17
+ label: str
18
+
19
+
20
+ def load_repos_config(config_path: Path) -> list[RepoConfig]:
21
+ data = tomllib.loads(config_path.read_text())
22
+ return [RepoConfig(**r) for r in data["repos"]]
23
+
24
+
25
def merge_jsonl(sources: list[tuple[Path, str]], dest: Path) -> int:
    """Concatenate per-repo JSONL files into *dest*, tagging each record.

    Every non-blank line of each source is parsed as JSON, given a ``_repo`` key
    holding that source's repo name, and written to *dest* as one line.

    Sources are streamed line by line instead of being read whole.  Iterating
    the file splits only on real newlines, whereas ``str.splitlines`` would also
    split on characters such as U+2028 that may appear inside JSON strings.
    All files are read and written as UTF-8 regardless of locale.

    Args:
        sources: ``(jsonl_path, repo_name)`` pairs, merged in order.
        dest: Output path; overwritten if present.  Missing parent directories
            are created.

    Returns:
        Number of records written.
    """
    import json

    dest.parent.mkdir(parents=True, exist_ok=True)
    count = 0
    with dest.open("w", encoding="utf-8") as out:
        for src, repo_name in sources:
            with src.open(encoding="utf-8") as fh:
                for line in fh:
                    if not line.strip():
                        continue
                    record = json.loads(line)
                    record["_repo"] = repo_name
                    out.write(json.dumps(record) + "\n")
                    count += 1
    return count
38
+
39
+
40
+ def _clone_or_update(url: str, dest: Path, depth: int = 2000) -> None:
41
+ if (dest / ".git").exists():
42
+ print(f" already cloned, skipping: {dest}")
43
+ return
44
+ dest.parent.mkdir(parents=True, exist_ok=True)
45
+ print(f" cloning {url} → {dest}")
46
+ subprocess.run(
47
+ ["git", "clone", "--depth", str(depth), url, str(dest)],
48
+ check=True,
49
+ )
50
+
51
+
52
def _run_extract(repo_dir: Path, part_out: Path, limit: int) -> bool:
    """Run the extractor on one clone; return True if it produced a part file.

    ``extract_main`` reads its options from ``sys.argv`` and calls ``sys.exit``
    with a non-zero status on failure (unreadable repo, no hunks).  This helper
    restores ``sys.argv`` afterwards and reports such a failure as ``False``,
    so the remaining repos are still processed.
    """
    saved_argv = sys.argv
    sys.argv = [
        "extract",
        str(repo_dir),
        "--out",
        str(part_out),
        "--limit",
        str(limit),
    ]
    try:
        extract_main()
    except SystemExit as exc:
        if exc.code not in (None, 0):
            print(f" extract failed (exit {exc.code}), skipping {repo_dir}", file=sys.stderr)
            return False
    finally:
        sys.argv = saved_argv
    return True


def main() -> None:
    """Clone every repo in ``training/repos.toml`` and build one training JSONL.

    For each repo:

    1. It is shallow-cloned to ``.argot/repos/<name>``.
    2. It is extracted to ``.argot/repos/<name>.jsonl``, capped at its ``limit``.
    3. Its part file is merged into ``.argot/training.jsonl`` with a ``_repo``
       tag.

    A repo whose extraction fails is reported and skipped instead of aborting
    the whole run.  Exits with status 2 if the config file is missing.
    """
    # NOTE(review): parents[2] assumes a source checkout laid out as
    # <root>/engine/argot/fetch.py.  When run from an installed wheel it points
    # two levels above site-packages' argot/ directory -- confirm intended usage.
    project_root = Path(__file__).parents[2]
    config_path = project_root / "training" / "repos.toml"

    if not config_path.exists():
        print(f"error: config not found at {config_path}", file=sys.stderr)
        sys.exit(2)

    repos = load_repos_config(config_path)
    repos_dir = project_root / ".argot" / "repos"
    part_files: list[tuple[Path, str]] = []

    for repo in repos:
        repo_dir = repos_dir / repo.name
        part_out = repos_dir / f"{repo.name}.jsonl"
        print(f"\n[{repo.name}] label={repo.label} limit={repo.limit}")
        _clone_or_update(repo.url, repo_dir)
        if _run_extract(repo_dir, part_out, repo.limit):
            part_files.append((part_out, repo.name))

    out = project_root / ".argot" / "training.jsonl"
    total = merge_jsonl(part_files, out)
    print(f"\nMerged {total} records → {out}")


if __name__ == "__main__":
    main()
argot/git_walk.py ADDED
@@ -0,0 +1,109 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterator
4
+ from pathlib import Path
5
+
6
+ import pygit2
7
+
8
+ SUPPORTED_EXTENSIONS = frozenset({".ts", ".tsx", ".js", ".jsx", ".py"})
9
+
10
+
11
+ def _extension(path: str) -> str:
12
+ return Path(path).suffix.lower()
13
+
14
+
15
+ def walk_repo(repo_path: str) -> Iterator[tuple[pygit2.Commit, str, bytes, list[pygit2.DiffHunk]]]:
16
+ """Yield (commit, file_path, post_blob_content, hunks) for each changed supported file.
17
+
18
+ Skips merge commits. Yields nothing if the repo has no commits.
19
+ """
20
+ repo = pygit2.Repository(repo_path)
21
+ if repo.is_empty:
22
+ return
23
+
24
+ try:
25
+ start_oid = repo.head.target
26
+ except pygit2.GitError:
27
+ branch_refs = [r for r in repo.references if r.startswith("refs/heads/")]
28
+ if not branch_refs:
29
+ return
30
+ start_oid = repo.references[branch_refs[0]].target
31
+
32
+ for commit in repo.walk(start_oid, pygit2.enums.SortMode.TOPOLOGICAL):
33
+ if len(commit.parents) != 1:
34
+ # Skip merge commits and root commits
35
+ continue
36
+
37
+ parent = commit.parents[0]
38
+ diff = parent.tree.diff_to_tree(commit.tree)
39
+ diff.find_similar()
40
+
41
+ for patch in diff:
42
+ if patch is None:
43
+ continue
44
+ file_path = patch.delta.new_file.path
45
+ if _extension(file_path) not in SUPPORTED_EXTENSIONS:
46
+ continue
47
+
48
+ hunks = list(patch.hunks)
49
+ if not hunks:
50
+ continue
51
+
52
+ try:
53
+ obj = commit.tree / file_path
54
+ if not isinstance(obj, pygit2.Blob):
55
+ continue
56
+ post_blob_content = obj.data
57
+ except KeyError:
58
+ # File deleted in this commit
59
+ continue
60
+
61
+ yield commit, file_path, post_blob_content, hunks
62
+
63
+
64
+ def walk_commits(
65
+ repo_path: str, shas: set[str]
66
+ ) -> Iterator[tuple[pygit2.Commit, str, bytes, list[pygit2.DiffHunk]]]:
67
+ """Yield (commit, file_path, post_blob_content, hunks) for commits in shas."""
68
+ repo = pygit2.Repository(repo_path)
69
+ if repo.is_empty:
70
+ return
71
+
72
+ try:
73
+ start_oid = repo.head.target
74
+ except pygit2.GitError:
75
+ branch_refs = [r for r in repo.references if r.startswith("refs/heads/")]
76
+ if not branch_refs:
77
+ return
78
+ start_oid = repo.references[branch_refs[0]].target
79
+
80
+ for commit in repo.walk(start_oid, pygit2.enums.SortMode.TOPOLOGICAL):
81
+ if str(commit.id) not in shas:
82
+ continue
83
+ if len(commit.parents) != 1:
84
+ continue
85
+
86
+ parent = commit.parents[0]
87
+ diff = parent.tree.diff_to_tree(commit.tree)
88
+ diff.find_similar()
89
+
90
+ for patch in diff:
91
+ if patch is None:
92
+ continue
93
+ file_path = patch.delta.new_file.path
94
+ if _extension(file_path) not in SUPPORTED_EXTENSIONS:
95
+ continue
96
+
97
+ hunks = list(patch.hunks)
98
+ if not hunks:
99
+ continue
100
+
101
+ try:
102
+ obj = commit.tree / file_path
103
+ if not isinstance(obj, pygit2.Blob):
104
+ continue
105
+ post_blob_content = obj.data
106
+ except KeyError:
107
+ continue
108
+
109
+ yield commit, file_path, post_blob_content, hunks
argot/jepa/__init__.py ADDED
File without changes
argot/jepa/encoder.py ADDED
@@ -0,0 +1,22 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import cast
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+
9
class TokenEncoder(nn.Module):
    """MLP encoder from TF-IDF vectors to latent embeddings.

    Architecture: ``Linear(input_dim, hidden_dim) -> GELU -> Linear(hidden_dim,
    embed_dim)``, held in ``self.net``.  The attribute name and layer order fix
    the ``state_dict`` keys (``net.0.*``, ``net.2.*``) used by
    ``load_state_dict``, so keep them stable.

    Args:
        input_dim: Width of the input feature vectors.
        embed_dim: Width of the produced embedding.
        hidden_dim: Width of the hidden layer.  Keyword-only; the default of 512
            matches the previously hard-coded value.
    """

    def __init__(
        self,
        input_dim: int = 5000,
        embed_dim: int = 192,
        *,
        hidden_dim: int = 512,
    ) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, input_dim) → (B, embed_dim)"""
        return cast(torch.Tensor, self.net(x))