omnius 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. package/README.md +4959 -0
  2. package/dist/index.d.ts +6 -0
  3. package/dist/index.js +630665 -0
  4. package/dist/launcher.cjs +78 -0
  5. package/dist/postinstall-daemon.cjs +776 -0
  6. package/dist/preinstall.cjs +92 -0
  7. package/dist/scripts/autoresearch-prepare.py +459 -0
  8. package/dist/scripts/autoresearch-train.py +661 -0
  9. package/dist/scripts/crawlee-scraper.py +358 -0
  10. package/dist/scripts/live-nemotron.py +478 -0
  11. package/dist/scripts/live-whisper.py +242 -0
  12. package/dist/scripts/ocr-advanced.py +571 -0
  13. package/dist/scripts/start-moondream.py +112 -0
  14. package/dist/scripts/tor/UPSTREAM-README.md +148 -0
  15. package/dist/scripts/tor/destroy_tor.sh +29 -0
  16. package/dist/scripts/tor/tor_setup.sh +163 -0
  17. package/dist/scripts/transcribe-file.py +63 -0
  18. package/dist/scripts/web_scrape.py +1295 -0
  19. package/npm-shrinkwrap.json +7412 -0
  20. package/package.json +142 -0
  21. package/prompts/agentic/system-large.md +569 -0
  22. package/prompts/agentic/system-medium.md +211 -0
  23. package/prompts/agentic/system-small.md +114 -0
  24. package/prompts/compaction/context-compaction.md +44 -0
  25. package/prompts/personality/level-1-minimal.md +3 -0
  26. package/prompts/personality/level-2-concise.md +3 -0
  27. package/prompts/personality/level-4-explanatory.md +3 -0
  28. package/prompts/personality/level-5-thorough.md +3 -0
  29. package/prompts/personality/level-autist.md +3 -0
  30. package/prompts/personality/level-stark.md +3 -0
  31. package/prompts/runners/dispatcher.md +24 -0
  32. package/prompts/runners/editor.md +44 -0
  33. package/prompts/runners/evaluator.md +30 -0
  34. package/prompts/runners/merge-summary.md +9 -0
  35. package/prompts/runners/normalizer.md +23 -0
  36. package/prompts/runners/planner.md +33 -0
  37. package/prompts/runners/scout.md +39 -0
  38. package/prompts/runners/verifier.md +36 -0
  39. package/prompts/skill-builder/seed-analysis.md +30 -0
  40. package/prompts/skill-builder/skill-expansion.md +76 -0
  41. package/prompts/skill-builder/skill-validation.md +31 -0
  42. package/prompts/templates/analysis.md +14 -0
  43. package/prompts/templates/code-review.md +16 -0
  44. package/prompts/templates/code.md +13 -0
  45. package/prompts/templates/document.md +13 -0
  46. package/prompts/templates/error-diagnosis.md +14 -0
  47. package/prompts/templates/general.md +9 -0
  48. package/prompts/templates/plan.md +15 -0
  49. package/prompts/templates/system.md +16 -0
  50. package/prompts/tui/dmn-gather.md +128 -0
  51. package/prompts/tui/dream-consolidate.md +48 -0
  52. package/prompts/tui/dream-lucid-eval.md +17 -0
  53. package/prompts/tui/dream-lucid-implement.md +14 -0
  54. package/prompts/tui/dream-stages.md +19 -0
  55. package/prompts/tui/emotion-behavioral.md +2 -0
  56. package/prompts/tui/emotion-center.md +12 -0
  57. package/voices/personaplex/OverBarn.pt +0 -0
  58. package/voices/personaplex/clone-voice.py +384 -0
  59. package/voices/personaplex/dequant-loader.py +174 -0
  60. package/voices/personaplex/quantize-weights.py +167 -0
@@ -0,0 +1,92 @@
1
+ #!/usr/bin/env node
2
+ /* eslint-disable */
3
+ /**
4
+ * preinstall — runs BEFORE npm replaces files on disk. Gracefully stops
5
+ * the running OA daemon so the install window doesn't strand it with
6
+ * a half-replaced binary that crashes on next relaunch.
7
+ *
8
+ * The postinstall hook restarts the daemon after install completes with
9
+ * the new code in place.
10
+ *
11
+ * Opt-out: OA_SKIP_DAEMON_INSTALL=1 (matches postinstall semantics).
12
+ */
13
+
14
+ "use strict";
15
+
16
// Honour the same opt-out switch that the postinstall hook documents.
if (process.env.OA_SKIP_DAEMON_INSTALL === "1") {
  process.exit(0);
}

var cp = require("child_process");
var fs = require("fs");
var path = require("path");
var os = require("os");

// Platform detection, evaluated once.
var PLATFORM = os.platform();
var IS_WIN = PLATFORM === "win32";
var IS_LINUX = PLATFORM === "linux";
var IS_MAC = PLATFORM === "darwin";
var HOME = os.homedir();

// Identifiers of the daemon in each platform's service manager.
var SERVICE_LABEL = "open-agents-daemon";    // systemd --user unit name
var LAUNCHD_LABEL = "ai.open-agents.daemon"; // launchd agent label / plist name
var WIN_TASK_NAME = "OpenAgentsDaemon";      // Task Scheduler task name
32
+
33
/**
 * Run a shell command synchronously and discard its output.
 * @param {string} cmd - Command line passed to child_process.execSync.
 * @returns {boolean} true if the command exited 0 within 8 s; false if it
 *   exited non-zero, timed out, or could not be spawned.
 */
function runQuiet(cmd) {
  var ok = true;
  try {
    cp.execSync(cmd, { timeout: 8000, stdio: "pipe" });
  } catch (err) {
    ok = false;
  }
  return ok;
}
36
+
37
/** Write one " [preinstall]"-prefixed line to stdout. */
function log(msg) {
  var line = " [preinstall] " + msg + "\n";
  process.stdout.write(line);
}
38
+
39
/**
 * Ask the platform's service manager to stop the daemon, if it is registered.
 *
 * Going through the manager (rather than only signalling a PID) is graceful
 * and also keeps the manager from auto-restarting the daemon mid-install.
 * Best-effort: every failure is ignored.
 */
function stopServiceManager() {
  try {
    if (IS_LINUX) {
      // Stop, but do not disable: the unit should come back once
      // postinstall restarts it.
      runQuiet("systemctl --user stop " + SERVICE_LABEL + ".service");
    } else if (IS_MAC) {
      var plist = path.join(HOME, "Library", "LaunchAgents", LAUNCHD_LABEL + ".plist");
      if (fs.existsSync(plist)) {
        // POSIX single-quote the path: an unquoted HOME containing spaces
        // (or shell metacharacters) would be split/expanded by the shell.
        var quoted = "'" + plist.replace(/'/g, "'\\''") + "'";
        runQuiet("launchctl unload " + quoted);
      }
    } else if (IS_WIN) {
      runQuiet('schtasks /End /TN "' + WIN_TASK_NAME + '"');
    }
  } catch (e) { /* best-effort */ }
}
56
+
57
/**
 * SIGTERM the process whose PID is recorded in `pidFile`.
 *
 * @param {string} pidFile - Path to a file containing a decimal PID.
 * @returns {boolean} true if a signal was delivered; false if the file is
 *   missing or unreadable, holds no positive PID, names this installer
 *   process or its parent, or the process could not be signalled.
 */
function killPidFile(pidFile) {
  try {
    if (!fs.existsSync(pidFile)) return false;
    var n = parseInt(fs.readFileSync(pidFile, "utf8").trim(), 10);
    if (!n || n <= 0) return false;
    // A stale PID file may name a recycled PID; never SIGTERM ourselves or
    // the process that launched this script.
    if (n === process.pid || n === process.ppid) return false;
    try {
      process.kill(n, "SIGTERM");
      log("SIGTERM " + pidFile + " (pid " + n + ")");
      return true;
    } catch (e) { /* not running, or not ours to signal */ }
  } catch (e) { /* unreadable pid file: nothing to do */ }
  return false;
}
67
+
68
// Main sequence: service manager first, then the PID file, then whatever is
// still listening on the daemon port.
stopServiceManager();
killPidFile(path.join(HOME, ".open-agents", "daemon.pid"));

// Final sweep: SIGTERM any process still *listening* on the daemon port
// (TCP assumed). `-sTCP:LISTEN` is essential: a bare `lsof -ti :PORT` also
// reports processes that merely hold a client connection to the port (for
// example an open CLI session talking to the daemon), and those must not be
// killed. -n/-P skip DNS and service-name lookups so the 3 s timeout is spent
// on the query itself. lsof does not exist on Windows (and `2>/dev/null` is
// not cmd.exe syntax), so the sweep is skipped there.
try {
  var port = parseInt(process.env.OA_PORT || "11435", 10);
  if (!IS_WIN && Number.isFinite(port) && port > 0 && port <= 65535) {
    var out = "";
    try {
      out = cp.execSync("lsof -nP -t -iTCP:" + port + " -sTCP:LISTEN 2>/dev/null || true", {
        encoding: "utf8", timeout: 3000,
      }).trim();
    } catch (e) { /* lsof missing or timed out: nothing to sweep */ }
    if (out) {
      out.split(/\s+/).forEach(function (s) {
        var n = parseInt(s, 10);
        // Never signal this installer or its parent process.
        if (Number.isFinite(n) && n > 0 && n !== process.pid && n !== process.ppid) {
          try {
            process.kill(n, "SIGTERM");
            log("SIGTERM port-holder pid " + n);
          } catch (e) { /* already gone, or not ours */ }
        }
      });
    }
  }
} catch (e) { /* best-effort */ }

// 1.5 s grace so SIGTERM handlers can flush state before files are replaced.
setTimeout(function () { process.exit(0); }, 1500);
92
+
@@ -0,0 +1,459 @@
1
+ """
2
+ One-time data preparation for autoresearch experiments.
3
+ Downloads data shards and trains a BPE tokenizer.
4
+
5
+ Usage:
6
+ uv run prepare.py # recommended (uses pyproject.toml deps)
7
+ python prepare.py # auto-installs missing deps into venv
8
+ python prepare.py --num-shards 8 # download only 8 shards (for testing)
9
+
10
+ Data and tokenizer are stored in ~/.cache/autoresearch/.
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import time
16
+ import math
17
+ import argparse
18
+ import pickle
19
+ import subprocess
20
+ from multiprocessing import Pool
21
+
22
+ # ---------------------------------------------------------------------------
23
+ # Auto-bootstrap: if running outside uv, ensure deps are installed
24
+ # ---------------------------------------------------------------------------
25
+
26
# Maps import name -> pip distribution name for every third-party module this
# script needs; _bootstrap_deps() probes the keys and installs the values.
# (For every current entry the two names happen to coincide.)
_REQUIRED_PACKAGES = dict(
    requests="requests",
    pyarrow="pyarrow",
    rustbpe="rustbpe",
    tiktoken="tiktoken",
    torch="torch",
)
33
+
34
def _bootstrap_deps():
    """Install missing third-party dependencies into ``<script dir>/.venv`` and re-exec.

    Each module in ``_REQUIRED_PACKAGES`` is probed with ``__import__``. If all
    import, this returns immediately. Otherwise the function:

    1. creates ``.venv`` next to this script if it does not exist,
    2. pip-installs the missing distributions into it, and
    3. replaces the current process with the venv's Python running the same
       argv (``os.execv`` does not return on success).

    Raises:
        subprocess.CalledProcessError: if venv creation or a pip install fails.
    """
    missing = []
    for module_name, pip_name in _REQUIRED_PACKAGES.items():
        try:
            __import__(module_name)
        except ImportError:
            missing.append(pip_name)

    if not missing:
        return  # All deps available

    print(f"Missing packages: {', '.join(missing)}")
    print("Auto-installing into local venv...")

    script_dir = os.path.dirname(os.path.abspath(__file__))
    venv_dir = os.path.join(script_dir, ".venv")

    # Create venv if needed
    if not os.path.exists(venv_dir):
        print(f"Creating venv at {venv_dir}...")
        subprocess.check_call([sys.executable, "-m", "venv", venv_dir])

    # POSIX venvs put executables in bin/, Windows venvs in Scripts/.
    pip_path = os.path.join(venv_dir, "bin", "pip")
    if not os.path.exists(pip_path):
        pip_path = os.path.join(venv_dir, "Scripts", "pip.exe")

    # torch gets its own install step because of the index choice below.
    torch_pkgs = [p for p in missing if p == "torch"]
    other_pkgs = [p for p in missing if p != "torch"]

    if other_pkgs:
        print(f"Installing: {', '.join(other_pkgs)}")
        subprocess.check_call([pip_path, "install", "--quiet"] + other_pkgs)

    if torch_pkgs:
        if sys.platform == "darwin":
            # The cu128 index only hosts Linux/Windows CUDA wheels and
            # --index-url disables PyPI, so on macOS it cannot satisfy
            # "torch" at all. Use the default PyPI build instead.
            print("Installing torch...")
            subprocess.check_call([pip_path, "install", "--quiet", "torch"])
        else:
            print("Installing torch (CUDA 12.8)...")
            subprocess.check_call([
                pip_path, "install", "--quiet", "torch",
                "--index-url", "https://download.pytorch.org/whl/cu128",
            ])

    # Re-exec with the venv's Python so the fresh packages are importable.
    # Inside the venv, sys.base_prefix != sys.prefix, so the module-level
    # guard will not bootstrap again.
    venv_python = os.path.join(venv_dir, "bin", "python")
    if not os.path.exists(venv_python):
        venv_python = os.path.join(venv_dir, "Scripts", "python.exe")

    print(f"Re-launching with venv Python: {venv_python}")
    os.execv(venv_python, [venv_python] + sys.argv)
85
+
86
# Bootstrap only under a "bare" interpreter. Inside a virtualenv/uv-managed
# environment the environment's own packages are used as-is: legacy
# virtualenv sets sys.real_prefix, and a venv makes sys.base_prefix differ
# from sys.prefix.
_in_virtualenv = hasattr(sys, "real_prefix") or (
    hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix
)
if not _in_virtualenv:
    _bootstrap_deps()
89
+
90
+ import requests
91
+ import pyarrow.parquet as pq
92
+ import rustbpe
93
+ import tiktoken
94
+ import torch
95
+
96
+ # ---------------------------------------------------------------------------
97
+ # Constants (fixed, do not modify)
98
+ # ---------------------------------------------------------------------------
99
+
100
MAX_SEQ_LEN = 2048          # context length
TIME_BUDGET = 5 * 60        # training time budget in seconds (300 s)
EVAL_TOKENS = 40 * 2**19    # tokens used by the validation eval (40 * 524288)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Everything (shards + tokenizer) is cached under ~/.cache/autoresearch/.
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
DATA_DIR = os.path.join(CACHE_DIR, "data")
TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
MAX_SHARD = 6542       # index of the last shard: shard_06542.parquet
VAL_SHARD = MAX_SHARD  # the last shard is pinned as the validation shard
VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet"
VOCAB_SIZE = 8192      # total vocabulary, special tokens included

# BPE pre-tokenization split pattern: GPT-4 style, except numbers are split
# into runs of at most two digits (\p{N}{1,2}) instead of three.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

# Four reserved special tokens; the first one doubles as BOS.
SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(4)]
BOS_TOKEN = "<|reserved_0|>"
122
+
123
+ # ---------------------------------------------------------------------------
124
+ # Data download
125
+ # ---------------------------------------------------------------------------
126
+
127
def download_single_shard(index):
    """Download one parquet shard into DATA_DIR, with retries.

    The body is streamed into ``<file>.tmp`` and moved to its final name only
    after it has been written completely, so a partial download never appears
    under the real shard name. A shard that already exists is not fetched
    again. Up to 5 attempts are made, sleeping 2, 4, 8 and 16 s between them;
    after a failed attempt both the temp file and the final file are removed.

    Args:
        index: Shard number (formatted as ``shard_{index:05d}.parquet``).

    Returns:
        True if the shard is present locally afterwards, False if every
        attempt failed.
    """
    filename = f"shard_{index:05d}.parquet"
    filepath = os.path.join(DATA_DIR, filename)
    if os.path.exists(filepath):
        return True

    url = f"{BASE_URL}/{filename}"
    temp_path = filepath + ".tmp"
    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        try:
            # The context manager releases the streamed connection even when
            # raise_for_status() or a mid-stream read raises.
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()
                with open(temp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            f.write(chunk)
            # os.replace overwrites atomically on every platform, whereas
            # os.rename raises on Windows if the target already exists
            # (e.g. another process finished the same shard first).
            os.replace(temp_path, filepath)
            print(f" Downloaded {filename}")
            return True
        except (requests.RequestException, IOError) as e:
            print(f" Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
            for path in [temp_path, filepath]:
                if os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        pass
            if attempt < max_attempts:
                time.sleep(2 ** attempt)
    return False
159
+
160
+
161
def download_data(num_shards, download_workers=8):
    """Make sure the training shards and the pinned validation shard are in DATA_DIR.

    Args:
        num_shards: Number of training shards to fetch (indices
            ``0 .. num_shards-1``), capped at MAX_SHARD. VAL_SHARD is always
            added.
        download_workers: Upper bound on the number of worker processes.
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    shard_ids = list(range(min(num_shards, MAX_SHARD)))
    if VAL_SHARD not in shard_ids:
        shard_ids.append(VAL_SHARD)
    total = len(shard_ids)

    def _present(i):
        return os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet"))

    already = len([i for i in shard_ids if _present(i)])
    if already == total:
        print(f"Data: all {total} shards already downloaded at {DATA_DIR}")
        return

    to_fetch = total - already
    print(f"Data: downloading {to_fetch} shards ({already} already exist)...")

    # Every id is submitted; download_single_shard skips shards already present.
    with Pool(processes=max(1, min(download_workers, to_fetch))) as pool:
        outcomes = pool.map(download_single_shard, shard_ids)

    print(f"Data: {sum(map(bool, outcomes))}/{total} shards ready at {DATA_DIR}")
184
+
185
+ # ---------------------------------------------------------------------------
186
+ # Tokenizer training
187
+ # ---------------------------------------------------------------------------
188
+
189
def list_parquet_files():
    """Return the paths of all ``*.parquet`` files in DATA_DIR, sorted by file name.

    In-progress downloads are named ``*.parquet.tmp`` and therefore never match.
    """
    names = [name for name in os.listdir(DATA_DIR) if name.endswith(".parquet")]
    names.sort()
    return [os.path.join(DATA_DIR, name) for name in names]
193
+
194
+
195
def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
    """Yield training documents for tokenizer training.

    Walks every shard except the pinned validation shard, one row group at a
    time, and yields each document truncated to ``doc_cap`` characters.
    Stops once the total length of yielded documents reaches ``max_chars``;
    the document that crosses the limit is still yielded. Null entries in the
    ``text`` column are skipped.
    """
    parquet_paths = [p for p in list_parquet_files() if not p.endswith(VAL_FILENAME)]
    nchars = 0
    for filepath in parquet_paths:
        pf = pq.ParquetFile(filepath)
        for rg_idx in range(pf.num_row_groups):
            # Only the text column is needed; skip decoding the others.
            rg = pf.read_row_group(rg_idx, columns=["text"])
            for text in rg.column("text").to_pylist():
                if text is None:
                    # to_pylist() maps parquet nulls to None; len(None) would raise.
                    continue
                doc = text[:doc_cap]  # slicing is a no-op for short documents
                nchars += len(doc)
                yield doc
                if nchars >= max_chars:
                    return
209
+
210
+
211
def train_tokenizer():
    """Train the BPE tokenizer with rustbpe and persist it under TOKENIZER_DIR.

    Two artifacts are written:
      * ``tokenizer.pkl``: a pickled ``tiktoken.Encoding`` built from the
        learned merges, with SPECIAL_TOKENS given the ids right after them;
      * ``token_bytes.pt``: an int32 tensor giving a byte length per token id
        (0 for special tokens), used for bits-per-byte evaluation.

    Returns early if both files already exist. Calls ``sys.exit(1)`` when
    fewer than two parquet shards are available.
    """
    out_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
    out_bytes = os.path.join(TOKENIZER_DIR, "token_bytes.pt")

    if all(os.path.exists(p) for p in (out_pkl, out_bytes)):
        print(f"Tokenizer: already trained at {TOKENIZER_DIR}")
        return

    os.makedirs(TOKENIZER_DIR, exist_ok=True)

    if len(list_parquet_files()) < 2:
        print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.")
        sys.exit(1)

    # --- 1. learn merges (special tokens are added afterwards) ---
    print("Tokenizer: training BPE tokenizer...")
    started = time.time()
    trainer = rustbpe.Tokenizer()
    trainer.train_from_iterator(
        text_iterator(), VOCAB_SIZE - len(SPECIAL_TOKENS), pattern=SPLIT_PATTERN
    )

    # --- 2. wrap the merges in a tiktoken Encoding ---
    split_regex = trainer.get_pattern()
    ranks = {bytes(tok): rank for tok, rank in trainer.get_mergeable_ranks()}
    first_special_id = len(ranks)
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=split_regex,
        mergeable_ranks=ranks,
        special_tokens={tok: first_special_id + k for k, tok in enumerate(SPECIAL_TOKENS)},
    )

    with open(out_pkl, "wb") as fh:
        pickle.dump(enc, fh)
    print(f"Tokenizer: trained in {time.time() - started:.1f}s, saved to {out_pkl}")

    # --- 3. per-token byte lengths for BPB evaluation ---
    print("Tokenizer: building token_bytes lookup...")
    specials = set(SPECIAL_TOKENS)
    # NOTE(review): tiktoken's decode() replaces invalid UTF-8 by default, so a
    # token that is not a complete UTF-8 sequence on its own (e.g. a lone byte
    # >= 0x80) is presumably measured as U+FFFD (3 bytes), not its true length.
    # Kept unchanged because the BPB metric is declared fixed;
    # len(enc.decode_single_token_bytes(i)) would give exact lengths.
    lengths = []
    for tid in range(enc.n_vocab):
        piece = enc.decode([tid])
        lengths.append(0 if piece in specials else len(piece.encode("utf-8")))
    torch.save(torch.tensor(lengths, dtype=torch.int32), out_bytes)
    print(f"Tokenizer: saved token_bytes to {out_bytes}")

    # --- 4. round-trip smoke test ---
    sample = "Hello world! Numbers: 123. Unicode: 你好"
    roundtrip = enc.decode(enc.encode_ordinary(sample))
    assert roundtrip == sample, f"Tokenizer roundtrip failed: {sample!r} -> {roundtrip!r}"
    print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})")
274
+
275
+ # ---------------------------------------------------------------------------
276
+ # Runtime utilities (imported by train.py)
277
+ # ---------------------------------------------------------------------------
278
+
279
class Tokenizer:
    """Thin wrapper around a tiktoken ``Encoding`` used at train/eval time.

    Training happens in ``train_tokenizer``; this class only loads and applies
    the result.
    """

    def __init__(self, enc):
        self.enc = enc
        # Id of the BOS marker, resolved once up front.
        self.bos_token_id = enc.encode_single_token(BOS_TOKEN)

    @classmethod
    def from_directory(cls, tokenizer_dir=TOKENIZER_DIR):
        """Load ``tokenizer.pkl`` from ``tokenizer_dir`` and wrap it.

        NOTE(review): pickle.load can execute arbitrary code from the file;
        only point this at a directory you control (such as the local cache).
        """
        pkl_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
        with open(pkl_path, "rb") as fh:
            return cls(pickle.load(fh))

    def get_vocab_size(self):
        """Vocabulary size reported by the encoding (special tokens included)."""
        return self.enc.n_vocab

    def get_bos_token_id(self):
        return self.bos_token_id

    def encode(self, text, prepend=None, num_threads=8):
        """Encode a ``str`` into ``list[int]`` or a ``list[str]`` into ``list[list[int]]``.

        ``encode_ordinary`` is used, so special-token text is not recognised
        inside ``text``. If ``prepend`` is given (a token id, or a token string
        resolved via ``encode_single_token``), it is inserted at the front of
        every resulting sequence. Raises ValueError for any other input type.
        """
        if prepend is None:
            head = None
        elif isinstance(prepend, int):
            head = prepend
        else:
            head = self.enc.encode_single_token(prepend)

        if isinstance(text, str):
            ids = self.enc.encode_ordinary(text)
            if head is not None:
                ids.insert(0, head)
            return ids
        if isinstance(text, list):
            ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
            if head is not None:
                for row in ids:
                    row.insert(0, head)
            return ids
        raise ValueError(f"Invalid input type: {type(text)}")

    def decode(self, ids):
        return self.enc.decode(ids)
316
+
317
+
318
def get_token_bytes(device="cpu"):
    """Load ``token_bytes.pt`` (per-token byte lengths) from TOKENIZER_DIR onto ``device``."""
    source = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
    with open(source, "rb") as fh:
        table = torch.load(fh, map_location=device)
    return table
322
+
323
+
324
def _document_batches(split, tokenizer_batch_size=128):
    """Yield ``(docs, epoch)`` pairs of raw text documents forever.

    With ``split == "train"`` it cycles over every shard except the pinned
    validation shard; with any other value it uses only the validation shard.
    Each row group's ``text`` column is cut into chunks of at most
    ``tokenizer_batch_size`` documents. ``epoch`` starts at 1 and increases
    after each full pass over the shard list.
    """
    all_paths = list_parquet_files()
    assert len(all_paths) > 0, "No parquet files found. Run prepare.py first."
    val_path = os.path.join(DATA_DIR, VAL_FILENAME)
    if split != "train":
        paths = [val_path]
    else:
        paths = [p for p in all_paths if p != val_path]
        assert len(paths) > 0, "No training shards found."

    epoch = 1
    while True:
        for path in paths:
            parquet = pq.ParquetFile(path)
            for group in range(parquet.num_row_groups):
                docs = parquet.read_row_group(group).column('text').to_pylist()
                for start in range(0, len(docs), tokenizer_batch_size):
                    yield docs[start:start + tokenizer_batch_size], epoch
        epoch += 1
344
+
345
+
346
def make_dataloader(tokenizer, B, T, split, buffer_size=1000, device="cuda"):
    """
    BOS-aligned dataloader with best-fit packing.
    Every row starts with BOS. Documents are packed best-fit to minimize cropping.
    When no document fits the remaining space, the shortest buffered document
    is cropped to fill the row exactly (its tail is discarded).
    100% utilization (no padding).

    Args:
        tokenizer: object providing ``get_bos_token_id()`` and
            ``encode(list[str], prepend=...) -> list[list[int]]``.
        B: rows per batch.
        T: tokens per row fed to the model. Each row packs T + 1 tokens so that
            targets are the inputs shifted by one position.
        split: "train" or "val".
        buffer_size: number of tokenized documents kept available for best-fit
            selection (at least one is always kept).
        device: where the yielded tensors live. Defaults to "cuda", the
            previously hard-coded value. Host memory is pinned only for CUDA
            targets, since pinning may fail on machines without CUDA.

    Yields:
        ``(inputs, targets, epoch)``: two (B, T) int64 tensors on ``device`` and
        the epoch of the most recently read document batch. The same tensors
        are overwritten on every step, so clone them if they must be kept.
    """
    assert split in ["train", "val"]
    row_capacity = T + 1
    batches = _document_batches(split)
    bos_token = tokenizer.get_bos_token_id()
    doc_buffer = []
    epoch = 1
    # buffer_size <= 0 would leave doc_buffer empty and make min() below fail.
    min_buffered = max(buffer_size, 1)

    def refill_buffer():
        nonlocal epoch
        doc_batch, epoch = next(batches)
        token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
        doc_buffer.extend(token_lists)

    pin = torch.device(device).type == "cuda"

    # Pre-allocate buffers: [inputs (B*T) | targets (B*T)]
    row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
    cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=pin)
    gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device)
    cpu_inputs = cpu_buffer[:B * T].view(B, T)
    cpu_targets = cpu_buffer[B * T:].view(B, T)
    inputs = gpu_buffer[:B * T].view(B, T)
    targets = gpu_buffer[B * T:].view(B, T)

    while True:
        for row_idx in range(B):
            pos = 0
            while pos < row_capacity:
                while len(doc_buffer) < min_buffered:
                    refill_buffer()

                remaining = row_capacity - pos

                # Find the largest document that fits entirely.
                best_idx = -1
                best_len = 0
                for i, doc in enumerate(doc_buffer):
                    doc_len = len(doc)
                    if doc_len <= remaining and doc_len > best_len:
                        best_idx = i
                        best_len = doc_len

                if best_idx >= 0:
                    doc = doc_buffer.pop(best_idx)
                    row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
                    pos += len(doc)
                else:
                    # No document fits: crop the shortest to fill the row.
                    shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
                    doc = doc_buffer.pop(shortest_idx)
                    row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
                    pos += remaining

        cpu_inputs.copy_(row_buffer[:, :-1])
        cpu_targets.copy_(row_buffer[:, 1:])
        gpu_buffer.copy_(cpu_buffer, non_blocking=True)
        yield inputs, targets, epoch
408
+
409
+ # ---------------------------------------------------------------------------
410
+ # Evaluation (DO NOT CHANGE — this is the fixed metric)
411
+ # ---------------------------------------------------------------------------
412
+
413
@torch.no_grad()
def evaluate_bpb(model, tokenizer, batch_size):
    """
    Bits per byte (BPB) on the pinned validation shard: a vocab-size-independent metric.

    Runs ``EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)`` batches at the fixed
    MAX_SEQ_LEN so that results are comparable across configs. For each batch
    it sums the per-token cross-entropy (nats) returned by
    ``model(x, y, reduction='none')``, masking out targets whose byte length
    is 0 (special tokens), and sums the byte lengths of the targets. The
    result is total_nats / (ln 2 * total_bytes).

    Raises ZeroDivisionError if no bytes were counted (e.g. batch_size is so
    large that zero steps run).
    """
    token_bytes = get_token_bytes(device="cuda")
    loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
    n_steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
    sum_nats = 0.0
    sum_bytes = 0
    for _step in range(n_steps):
        inputs, targets, _epoch = next(loader)
        per_token_loss = model(inputs, targets, reduction='none').view(-1)
        target_ids = targets.view(-1)
        target_bytes = token_bytes[target_ids]
        counted = target_bytes > 0
        sum_nats += (per_token_loss * counted).sum().item()
        sum_bytes += target_bytes.sum().item()
    return sum_nats / (math.log(2) * sum_bytes)
436
+
437
+ # ---------------------------------------------------------------------------
438
+ # Main
439
+ # ---------------------------------------------------------------------------
440
+
441
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
    parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.")
    parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers")
    args = parser.parse_args()

    # -1 means "all". Any other negative value would silently be treated as
    # zero training shards, so reject it up front with a usage error.
    if args.num_shards < -1:
        parser.error("--num-shards must be >= 0, or -1 for all shards")

    num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards

    print(f"Cache directory: {CACHE_DIR}")
    print()

    # Step 1: Download data
    download_data(num_shards, download_workers=args.download_workers)
    print()

    # Step 2: Train tokenizer
    train_tokenizer()
    print()
    print("Done! Ready to train.")