nested-learning 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nested_learning/__init__.py +12 -0
- nested_learning/__main__.py +12 -0
- nested_learning/assoc_memory.py +23 -0
- nested_learning/backbones.py +147 -0
- nested_learning/capabilities.py +104 -0
- nested_learning/cli.py +253 -0
- nested_learning/cms.py +92 -0
- nested_learning/config_utils.py +50 -0
- nested_learning/configs/ablations/cms_sparse.yaml +46 -0
- nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
- nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
- nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
- nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
- nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
- nested_learning/configs/data/continual_segments_sample.yaml +9 -0
- nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
- nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
- nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
- nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
- nested_learning/configs/deepspeed/zero3.json +25 -0
- nested_learning/configs/hope/mid.yaml +118 -0
- nested_learning/configs/hope/mid_fsdp.yaml +47 -0
- nested_learning/configs/hope/pilot.yaml +2 -0
- nested_learning/configs/hope/pilot_attention.yaml +9 -0
- nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
- nested_learning/configs/hope/pilot_transformer.yaml +9 -0
- nested_learning/configs/hope/target.yaml +145 -0
- nested_learning/configs/hope/target_fsdp.yaml +47 -0
- nested_learning/configs/mid_smoke.yaml +99 -0
- nested_learning/configs/mid_stage2.yaml +110 -0
- nested_learning/configs/mid_stage2_smoke.yaml +102 -0
- nested_learning/configs/mid_titan_baseline.yaml +92 -0
- nested_learning/configs/pilot.yaml +127 -0
- nested_learning/configs/pilot_paper_faithful.yaml +42 -0
- nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
- nested_learning/configs/pilot_smoke.yaml +80 -0
- nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
- nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
- nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
- nested_learning/continual_classification.py +136 -0
- nested_learning/continual_streaming.py +283 -0
- nested_learning/data.py +153 -0
- nested_learning/device.py +21 -0
- nested_learning/eval_state.py +72 -0
- nested_learning/fast_state.py +108 -0
- nested_learning/functional.py +69 -0
- nested_learning/hope/__init__.py +0 -0
- nested_learning/hope/block.py +1973 -0
- nested_learning/hope/self_mod.py +40 -0
- nested_learning/instrumentation.py +38 -0
- nested_learning/levels.py +94 -0
- nested_learning/logging_utils.py +64 -0
- nested_learning/memorize.py +382 -0
- nested_learning/model.py +604 -0
- nested_learning/optim/__init__.py +0 -0
- nested_learning/optim/deep.py +102 -0
- nested_learning/optim/factory.py +13 -0
- nested_learning/optim/m3.py +121 -0
- nested_learning/optim/manager.py +151 -0
- nested_learning/titan/__init__.py +0 -0
- nested_learning/titan/memory.py +88 -0
- nested_learning/titan/model.py +412 -0
- nested_learning/titan/self_modifying.py +724 -0
- nested_learning/tokenizer.py +28 -0
- nested_learning/tokenizer_coverage.py +77 -0
- nested_learning/training.py +1600 -0
- nested_learning/transformer.py +104 -0
- nested_learning-0.2.0.dist-info/METADATA +390 -0
- nested_learning-0.2.0.dist-info/RECORD +76 -0
- nested_learning-0.2.0.dist-info/WHEEL +4 -0
- nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
- nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Sequence
|
|
5
|
+
|
|
6
|
+
import sentencepiece as spm
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SentencePieceTokenizer:
    """Wrap a SentencePiece processor and return token ids as ``torch.long`` tensors."""

    def __init__(self, model_path: str | Path):
        """Load the SentencePiece model file located at ``model_path``.

        The path is converted with ``str()`` before being handed to
        ``SentencePieceProcessor``, so both ``str`` and ``Path`` are accepted.
        """
        self.processor = spm.SentencePieceProcessor(model_file=str(model_path))

    @property
    def vocab_size(self) -> int:
        """Return the vocabulary size reported by the underlying processor."""
        return self.processor.vocab_size()

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = True) -> torch.Tensor:
        """Encode ``text`` into a 1-D ``torch.long`` tensor of token ids.

        Args:
            text: The string to tokenize.
            add_bos: When true, prepend the processor's ``bos_id()``.
            add_eos: When true (the default), append the processor's ``eos_id()``.

        Returns:
            A tensor holding the optional BOS id, the ids produced by the
            processor for ``text``, and the optional EOS id, in that order.
        """
        # NOTE(review): the ids returned by bos_id()/eos_id() are appended
        # unchecked; presumably a model trained without BOS/EOS reports a
        # negative id here -- confirm the model defines them before relying
        # on these flags.
        tokens: list[int] = []
        if add_bos:
            tokens.append(self.processor.bos_id())
        tokens.extend(self.processor.encode(text))
        if add_eos:
            tokens.append(self.processor.eos_id())
        return torch.tensor(tokens, dtype=torch.long)

    def batch_encode(
        self,
        texts: Sequence[str],
        add_bos: bool = False,
        add_eos: bool = True,
    ) -> list[torch.Tensor]:
        """Encode every string in ``texts`` with :meth:`encode`.

        Args:
            texts: The strings to tokenize, in order.
            add_bos: Forwarded to :meth:`encode` for every item.
            add_eos: Forwarded to :meth:`encode` for every item.

        Returns:
            One tensor per input string, in input order. The tensors are not
            padded, so their lengths may differ.
        """
        # The defaults mirror encode(), so batch_encode(texts) behaves exactly
        # as it did before the flags were exposed.
        return [self.encode(text, add_bos=add_bos, add_eos=add_eos) for text in texts]
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import Counter
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Dict
|
|
6
|
+
|
|
7
|
+
from .tokenizer import SentencePieceTokenizer
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def compute_tokenizer_coverage_stats(
    tokenizer_path: Path,
    sample_file: Path,
    max_lines: int = 10_000,
) -> Dict[str, object]:
    """
    Measure how a SentencePiece tokenizer splits a sample of text.

    The result is a JSON-serialisable dictionary; the coverage CLI and the
    regression guard both call this function so their numbers stay in sync.

    Args:
        tokenizer_path: SentencePiece model file to load.
        sample_file: UTF-8 text file, read line by line.
        max_lines: Upper bound on the number of physical lines read from
            ``sample_file``. Blank lines count towards this limit even though
            they are not included in ``lines_processed``.

    Returns:
        A dictionary with the input paths (as strings), the number of
        non-blank lines processed, word/token totals, the average number of
        tokens per whitespace-separated word, the share of words encoded as
        exactly one token and as at most two tokens, the average characters
        per word, and a histogram of piece lengths (20 most common lengths).

    Raises:
        ValueError: If the sample yields no words, or no word produces any
            token.
    """

    tokenizer = SentencePieceTokenizer(tokenizer_path)

    line_count = 0
    word_count = 0
    token_count = 0
    char_count = 0
    tokens_per_word: list[int] = []
    piece_length_counts: Counter[int] = Counter()

    with sample_file.open("r", encoding="utf-8") as stream:
        for line_no, raw_line in enumerate(stream):
            if line_no >= max_lines:
                break

            text = raw_line.strip()
            if not text:
                continue

            line_count += 1
            # Counts the stripped line as a whole, so whitespace between
            # words is part of the character total.
            char_count += len(text)

            words = text.split()
            if not words:
                continue
            word_count += len(words)

            # Whole-line encoding drives the token total and the histogram.
            line_ids = tokenizer.encode(text, add_bos=False, add_eos=False).tolist()
            token_count += len(line_ids)

            # Each word is also encoded on its own; words that yield zero
            # tokens are left out of the per-word distribution.
            lengths = (
                len(tokenizer.encode(word, add_bos=False, add_eos=False).tolist())
                for word in words
            )
            tokens_per_word.extend(length for length in lengths if length)

            piece_length_counts.update(
                len(tokenizer.processor.id_to_piece(token_id)) for token_id in line_ids
            )

    if word_count == 0 or not tokens_per_word:
        raise ValueError("Sample produced no words; double-check the sample_file path.")

    measured_words = len(tokens_per_word)
    single_token_words = sum(1 for length in tokens_per_word if length == 1)
    short_words = sum(1 for length in tokens_per_word if length <= 2)

    stats: Dict[str, object] = {
        "tokenizer": str(tokenizer_path),
        "sample_file": str(sample_file),
        "lines_processed": line_count,
        "total_words": word_count,
        "total_tokens": token_count,
        # word_count is guaranteed non-zero by the guard above.
        "avg_tokens_per_word": token_count / word_count,
        "pct_single_token_words": single_token_words / measured_words,
        "pct_two_or_less_tokens_words": short_words / measured_words,
        "avg_chars_per_word": char_count / word_count,
        "piece_length_histogram": dict(piece_length_counts.most_common(20)),
    }
    return stats
|