nested-learning 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nested_learning/__init__.py +12 -0
  2. nested_learning/__main__.py +12 -0
  3. nested_learning/assoc_memory.py +23 -0
  4. nested_learning/backbones.py +147 -0
  5. nested_learning/capabilities.py +104 -0
  6. nested_learning/cli.py +253 -0
  7. nested_learning/cms.py +92 -0
  8. nested_learning/config_utils.py +50 -0
  9. nested_learning/configs/ablations/cms_sparse.yaml +46 -0
  10. nested_learning/configs/ablations/selfmod_chunked_8_64.yaml +24 -0
  11. nested_learning/configs/ablations/selfmod_momentum_off.yaml +23 -0
  12. nested_learning/configs/ablations/selfmod_momentum_on.yaml +23 -0
  13. nested_learning/configs/ablations/selfmod_no_alpha.yaml +23 -0
  14. nested_learning/configs/ablations/selfmod_no_cms.yaml +23 -0
  15. nested_learning/configs/ablations/selfmod_rank1_precond_off.yaml +23 -0
  16. nested_learning/configs/data/continual_segments_sample.yaml +9 -0
  17. nested_learning/configs/data/fineweb_edu_longdoc_filtered_sample.yaml +14 -0
  18. nested_learning/configs/data/fineweb_edu_mixture_full.yaml +14 -0
  19. nested_learning/configs/data/fineweb_edu_mixture_sample.yaml +14 -0
  20. nested_learning/configs/data/refinedweb_mixture.yaml +48 -0
  21. nested_learning/configs/data/refinedweb_mixture_filtered.yaml +48 -0
  22. nested_learning/configs/data/refinedweb_mixture_full.yaml +48 -0
  23. nested_learning/configs/data/refinedweb_mixture_sample.yaml +51 -0
  24. nested_learning/configs/deepspeed/zero3.json +25 -0
  25. nested_learning/configs/hope/mid.yaml +118 -0
  26. nested_learning/configs/hope/mid_fsdp.yaml +47 -0
  27. nested_learning/configs/hope/pilot.yaml +2 -0
  28. nested_learning/configs/hope/pilot_attention.yaml +9 -0
  29. nested_learning/configs/hope/pilot_selfmod.yaml +20 -0
  30. nested_learning/configs/hope/pilot_transformer.yaml +9 -0
  31. nested_learning/configs/hope/target.yaml +145 -0
  32. nested_learning/configs/hope/target_fsdp.yaml +47 -0
  33. nested_learning/configs/mid_smoke.yaml +99 -0
  34. nested_learning/configs/mid_stage2.yaml +110 -0
  35. nested_learning/configs/mid_stage2_smoke.yaml +102 -0
  36. nested_learning/configs/mid_titan_baseline.yaml +92 -0
  37. nested_learning/configs/pilot.yaml +127 -0
  38. nested_learning/configs/pilot_paper_faithful.yaml +42 -0
  39. nested_learning/configs/pilot_selfmod_paper_faithful.yaml +18 -0
  40. nested_learning/configs/pilot_smoke.yaml +80 -0
  41. nested_learning/configs/resolved/cms_sparse_eval.yaml +105 -0
  42. nested_learning/configs/resolved/phase2_pilot_attention_eval.yaml +49 -0
  43. nested_learning/configs/resolved/phase2_pilot_transformer_eval.yaml +49 -0
  44. nested_learning/continual_classification.py +136 -0
  45. nested_learning/continual_streaming.py +283 -0
  46. nested_learning/data.py +153 -0
  47. nested_learning/device.py +21 -0
  48. nested_learning/eval_state.py +72 -0
  49. nested_learning/fast_state.py +108 -0
  50. nested_learning/functional.py +69 -0
  51. nested_learning/hope/__init__.py +0 -0
  52. nested_learning/hope/block.py +1973 -0
  53. nested_learning/hope/self_mod.py +40 -0
  54. nested_learning/instrumentation.py +38 -0
  55. nested_learning/levels.py +94 -0
  56. nested_learning/logging_utils.py +64 -0
  57. nested_learning/memorize.py +382 -0
  58. nested_learning/model.py +604 -0
  59. nested_learning/optim/__init__.py +0 -0
  60. nested_learning/optim/deep.py +102 -0
  61. nested_learning/optim/factory.py +13 -0
  62. nested_learning/optim/m3.py +121 -0
  63. nested_learning/optim/manager.py +151 -0
  64. nested_learning/titan/__init__.py +0 -0
  65. nested_learning/titan/memory.py +88 -0
  66. nested_learning/titan/model.py +412 -0
  67. nested_learning/titan/self_modifying.py +724 -0
  68. nested_learning/tokenizer.py +28 -0
  69. nested_learning/tokenizer_coverage.py +77 -0
  70. nested_learning/training.py +1600 -0
  71. nested_learning/transformer.py +104 -0
  72. nested_learning-0.2.0.dist-info/METADATA +390 -0
  73. nested_learning-0.2.0.dist-info/RECORD +76 -0
  74. nested_learning-0.2.0.dist-info/WHEEL +4 -0
  75. nested_learning-0.2.0.dist-info/entry_points.txt +2 -0
  76. nested_learning-0.2.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,28 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Sequence
5
+
6
+ import sentencepiece as spm
7
+ import torch
8
+
9
+
10
class SentencePieceTokenizer:
    """Wrap a SentencePiece model and return encoded text as ``torch.long`` tensors.

    The loaded ``spm.SentencePieceProcessor`` is kept on ``self.processor``.
    """

    def __init__(self, model_path: str | Path):
        """Load the SentencePiece model file located at ``model_path``."""
        self.processor = spm.SentencePieceProcessor(model_file=str(model_path))

    @property
    def vocab_size(self) -> int:
        """Return the vocabulary size reported by the underlying processor."""
        return self.processor.vocab_size()

    def _special_id(self, name: str) -> int:
        """Return the id of the ``bos``/``eos`` token, rejecting disabled ones.

        SentencePiece reports ``-1`` for a special token that the model was
        trained without; putting that value into a token tensor only fails much
        later (e.g. as an out-of-range embedding index), so fail early instead.

        Raises:
            ValueError: if the model does not define the requested token.
        """
        token_id = getattr(self.processor, f"{name}_id")()
        if token_id < 0:
            raise ValueError(
                f"Tokenizer model defines no {name.upper()} token "
                f"({name}_id() returned {token_id}); call with add_{name}=False."
            )
        return token_id

    def encode(self, text: str, add_bos: bool = False, add_eos: bool = True) -> torch.Tensor:
        """Encode ``text`` into a 1-D ``torch.long`` tensor of token ids.

        Args:
            text: The string to tokenize.
            add_bos: Prepend the model's BOS id when true.
            add_eos: Append the model's EOS id when true.

        Returns:
            A tensor holding ``[bos?] + ids(text) + [eos?]``.

        Raises:
            ValueError: if a requested BOS/EOS token is disabled in the model.
        """
        tokens: list[int] = []
        if add_bos:
            tokens.append(self._special_id("bos"))
        tokens.extend(self.processor.encode(text))
        if add_eos:
            tokens.append(self._special_id("eos"))
        return torch.tensor(tokens, dtype=torch.long)

    def batch_encode(
        self,
        texts: Sequence[str],
        add_bos: bool = False,
        add_eos: bool = True,
    ) -> list[torch.Tensor]:
        """Encode each string in ``texts`` independently (no padding is applied).

        ``add_bos`` and ``add_eos`` are forwarded to :meth:`encode`; their
        defaults match ``encode`` so existing single-argument calls behave as
        before.
        """
        return [self.encode(text, add_bos=add_bos, add_eos=add_eos) for text in texts]
@@ -0,0 +1,77 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import Counter
4
+ from pathlib import Path
5
+ from typing import Dict
6
+
7
+ from .tokenizer import SentencePieceTokenizer
8
+
9
+
10
def compute_tokenizer_coverage_stats(
    tokenizer_path: Path,
    sample_file: Path,
    max_lines: int = 10_000,
) -> Dict[str, object]:
    """
    Measure how a SentencePiece tokenizer segments a sample of text.

    At most the first ``max_lines`` lines of ``sample_file`` are examined (blank
    lines count towards that limit but are otherwise skipped). Each non-blank
    line is encoded as a whole for the token totals and the piece-length
    histogram, and each whitespace-separated word is encoded on its own for the
    per-word statistics.

    The result is a JSON-serialisable dictionary. It is shared by the coverage
    CLI and the regression guard so the two cannot drift apart silently.

    Raises:
        ValueError: if the sample yields no words (or no word produces tokens).
    """

    tok = SentencePieceTokenizer(tokenizer_path)
    n_words = 0
    n_tokens = 0
    n_chars = 0
    n_lines = 0
    tokens_per_word: list[int] = []
    piece_len_counts: Counter[int] = Counter()

    with sample_file.open("r", encoding="utf-8") as handle:
        for line_no, raw_line in enumerate(handle):
            if line_no >= max_lines:
                break
            text = raw_line.strip()
            if not text:
                continue
            n_lines += 1
            # NOTE(review): this counts the whole stripped line, so whitespace
            # between words is included in "avg_chars_per_word" below.
            n_chars += len(text)
            words = text.split()
            if not words:
                continue
            n_words += len(words)

            # Line-level encoding drives the token total and piece histogram.
            line_ids = tok.encode(text, add_bos=False, add_eos=False).tolist()
            n_tokens += len(line_ids)

            # Word-level encoding: words that produce no tokens are ignored.
            for word in words:
                n_word_tokens = len(tok.encode(word, add_bos=False, add_eos=False).tolist())
                if n_word_tokens:
                    tokens_per_word.append(n_word_tokens)

            piece_len_counts.update(
                len(tok.processor.id_to_piece(token_id)) for token_id in line_ids
            )

    if not (n_words and tokens_per_word):
        raise ValueError("Sample produced no words; double-check the sample_file path.")

    n_scored_words = len(tokens_per_word)
    n_single = sum(1 for count in tokens_per_word if count == 1)
    n_two_or_less = sum(1 for count in tokens_per_word if count <= 2)

    stats: Dict[str, object] = {
        "tokenizer": str(tokenizer_path),
        "sample_file": str(sample_file),
        "lines_processed": n_lines,
        "total_words": n_words,
        "total_tokens": n_tokens,
        "avg_tokens_per_word": n_tokens / n_words,
        "pct_single_token_words": n_single / n_scored_words,
        "pct_two_or_less_tokens_words": n_two_or_less / n_scored_words,
        "avg_chars_per_word": n_chars / n_words,
        # Only the 20 most frequent piece lengths are reported.
        "piece_length_histogram": dict(piece_len_counts.most_common(20)),
    }
    return stats