marin-levanter 0.99__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. marin_levanter-0.99/PKG-INFO +291 -0
  2. marin_levanter-0.99/README.md +207 -0
  3. marin_levanter-0.99/pyproject.toml +215 -0
  4. marin_levanter-0.99/src/levanter/__init__.py +50 -0
  5. marin_levanter-0.99/src/levanter/_debug_logging.py +26 -0
  6. marin_levanter-0.99/src/levanter/analysis/__init__.py +17 -0
  7. marin_levanter-0.99/src/levanter/analysis/entropy.py +222 -0
  8. marin_levanter-0.99/src/levanter/analysis/model_perplexity.py +427 -0
  9. marin_levanter-0.99/src/levanter/analysis/perplexity_gap.py +940 -0
  10. marin_levanter-0.99/src/levanter/analysis/tree_stats.py +102 -0
  11. marin_levanter-0.99/src/levanter/analysis/visualization.py +339 -0
  12. marin_levanter-0.99/src/levanter/callbacks/__init__.py +238 -0
  13. marin_levanter-0.99/src/levanter/callbacks/_core.py +119 -0
  14. marin_levanter-0.99/src/levanter/callbacks/_metrics.py +158 -0
  15. marin_levanter-0.99/src/levanter/callbacks/profiler.py +94 -0
  16. marin_levanter-0.99/src/levanter/callbacks/state_adapter.py +68 -0
  17. marin_levanter-0.99/src/levanter/callbacks/tensorstore_callbacks.py +140 -0
  18. marin_levanter-0.99/src/levanter/callbacks/watch.py +194 -0
  19. marin_levanter-0.99/src/levanter/checkpoint.py +1113 -0
  20. marin_levanter-0.99/src/levanter/compat/__init__.py +2 -0
  21. marin_levanter-0.99/src/levanter/compat/fsspec_safetensor.py +289 -0
  22. marin_levanter-0.99/src/levanter/compat/hf_checkpoints.py +1558 -0
  23. marin_levanter-0.99/src/levanter/config.py +169 -0
  24. marin_levanter-0.99/src/levanter/data/__init__.py +30 -0
  25. marin_levanter-0.99/src/levanter/data/_preprocessor.py +277 -0
  26. marin_levanter-0.99/src/levanter/data/_prp.py +271 -0
  27. marin_levanter-0.99/src/levanter/data/audio.py +558 -0
  28. marin_levanter-0.99/src/levanter/data/dataset.py +442 -0
  29. marin_levanter-0.99/src/levanter/data/loader.py +643 -0
  30. marin_levanter-0.99/src/levanter/data/mixture.py +536 -0
  31. marin_levanter-0.99/src/levanter/data/packing.py +628 -0
  32. marin_levanter-0.99/src/levanter/data/passthrough_tokenizer.py +100 -0
  33. marin_levanter-0.99/src/levanter/data/permutation.py +290 -0
  34. marin_levanter-0.99/src/levanter/data/sharded_datasource.py +581 -0
  35. marin_levanter-0.99/src/levanter/data/text/__init__.py +96 -0
  36. marin_levanter-0.99/src/levanter/data/text/_batch_tokenizer.py +191 -0
  37. marin_levanter-0.99/src/levanter/data/text/cache.py +63 -0
  38. marin_levanter-0.99/src/levanter/data/text/datasets.py +999 -0
  39. marin_levanter-0.99/src/levanter/data/text/examples.py +226 -0
  40. marin_levanter-0.99/src/levanter/data/text/formats.py +282 -0
  41. marin_levanter-0.99/src/levanter/data/text/preference.py +316 -0
  42. marin_levanter-0.99/src/levanter/data/utils.py +20 -0
  43. marin_levanter-0.99/src/levanter/distributed.py +262 -0
  44. marin_levanter-0.99/src/levanter/eval.py +574 -0
  45. marin_levanter-0.99/src/levanter/eval_harness.py +1757 -0
  46. marin_levanter-0.99/src/levanter/grad_accum.py +187 -0
  47. marin_levanter-0.99/src/levanter/grug/__init__.py +8 -0
  48. marin_levanter-0.99/src/levanter/grug/attention.py +407 -0
  49. marin_levanter-0.99/src/levanter/grug/grug_moe.py +700 -0
  50. marin_levanter-0.99/src/levanter/grug/loss.py +176 -0
  51. marin_levanter-0.99/src/levanter/grug/sharding.py +29 -0
  52. marin_levanter-0.99/src/levanter/inference/engine.py +1382 -0
  53. marin_levanter-0.99/src/levanter/inference/jit_scheduler.py +1214 -0
  54. marin_levanter-0.99/src/levanter/inference/openai.py +823 -0
  55. marin_levanter-0.99/src/levanter/inference/openai_protocol.py +87 -0
  56. marin_levanter-0.99/src/levanter/inference/page_table.py +94 -0
  57. marin_levanter-0.99/src/levanter/inference/utils.py +140 -0
  58. marin_levanter-0.99/src/levanter/infra/__init__.py +2 -0
  59. marin_levanter-0.99/src/levanter/infra/cli_helpers.py +176 -0
  60. marin_levanter-0.99/src/levanter/infra/docker.py +245 -0
  61. marin_levanter-0.99/src/levanter/infra/tpus.py +302 -0
  62. marin_levanter-0.99/src/levanter/kernels/pallas/__init__.py +10 -0
  63. marin_levanter-0.99/src/levanter/kernels/pallas/autotune_cache_utils.py +70 -0
  64. marin_levanter-0.99/src/levanter/kernels/pallas/autotune_utils.py +168 -0
  65. marin_levanter-0.99/src/levanter/kernels/pallas/cost_estimate_utils.py +36 -0
  66. marin_levanter-0.99/src/levanter/kernels/pallas/fused_cross_entropy_loss/__init__.py +28 -0
  67. marin_levanter-0.99/src/levanter/kernels/pallas/fused_cross_entropy_loss/api.py +750 -0
  68. marin_levanter-0.99/src/levanter/kernels/pallas/fused_cross_entropy_loss/config.py +22 -0
  69. marin_levanter-0.99/src/levanter/kernels/pallas/fused_cross_entropy_loss/pallas_gpu.py +879 -0
  70. marin_levanter-0.99/src/levanter/kernels/pallas/fused_cross_entropy_loss/pallas_tpu.py +785 -0
  71. marin_levanter-0.99/src/levanter/kernels/pallas/fused_cross_entropy_loss/reference.py +143 -0
  72. marin_levanter-0.99/src/levanter/kernels/pallas/fused_cross_entropy_loss/tuned_block_sizes.py +832 -0
  73. marin_levanter-0.99/src/levanter/kernels/pallas/fused_cross_entropy_loss/xla.py +501 -0
  74. marin_levanter-0.99/src/levanter/kernels/pallas/mamba3/__init__.py +66 -0
  75. marin_levanter-0.99/src/levanter/kernels/pallas/mamba3/api.py +1123 -0
  76. marin_levanter-0.99/src/levanter/kernels/pallas/mamba3/config.py +55 -0
  77. marin_levanter-0.99/src/levanter/kernels/pallas/mamba3/reference.py +599 -0
  78. marin_levanter-0.99/src/levanter/kernels/pallas/mamba3/xla.py +370 -0
  79. marin_levanter-0.99/src/levanter/kernels/pallas/ssd/__init__.py +36 -0
  80. marin_levanter-0.99/src/levanter/kernels/pallas/ssd/api.py +252 -0
  81. marin_levanter-0.99/src/levanter/kernels/pallas/ssd/config.py +18 -0
  82. marin_levanter-0.99/src/levanter/kernels/pallas/ssd/reference.py +271 -0
  83. marin_levanter-0.99/src/levanter/kernels/pallas/ssd/xla.py +252 -0
  84. marin_levanter-0.99/src/levanter/kernels/pallas/template_kernel.py +142 -0
  85. marin_levanter-0.99/src/levanter/layers/__init__.py +28 -0
  86. marin_levanter-0.99/src/levanter/layers/attention.py +2604 -0
  87. marin_levanter-0.99/src/levanter/layers/gated_deltanet.py +902 -0
  88. marin_levanter-0.99/src/levanter/layers/kv_cache.py +165 -0
  89. marin_levanter-0.99/src/levanter/layers/normalization.py +48 -0
  90. marin_levanter-0.99/src/levanter/layers/rotary.py +292 -0
  91. marin_levanter-0.99/src/levanter/layers/sampler.py +81 -0
  92. marin_levanter-0.99/src/levanter/lora.py +530 -0
  93. marin_levanter-0.99/src/levanter/main/eval_lm.py +214 -0
  94. marin_levanter-0.99/src/levanter/main/export_hf_to_lm.py +122 -0
  95. marin_levanter-0.99/src/levanter/main/export_lm_to_hf.py +95 -0
  96. marin_levanter-0.99/src/levanter/main/inference_repl.py +570 -0
  97. marin_levanter-0.99/src/levanter/main/lora_lm.py +175 -0
  98. marin_levanter-0.99/src/levanter/main/perplexity_gap.py +500 -0
  99. marin_levanter-0.99/src/levanter/main/sample_lm.py +219 -0
  100. marin_levanter-0.99/src/levanter/main/sft.py +2 -0
  101. marin_levanter-0.99/src/levanter/main/train_asr.py +210 -0
  102. marin_levanter-0.99/src/levanter/main/train_dpo.py +510 -0
  103. marin_levanter-0.99/src/levanter/main/train_lm.py +328 -0
  104. marin_levanter-0.99/src/levanter/main/viz_logprobs.py +189 -0
  105. marin_levanter-0.99/src/levanter/metrics.py +204 -0
  106. marin_levanter-0.99/src/levanter/models/__init__.py +2 -0
  107. marin_levanter-0.99/src/levanter/models/apertus.py +485 -0
  108. marin_levanter-0.99/src/levanter/models/asr_model.py +125 -0
  109. marin_levanter-0.99/src/levanter/models/flash_attention.py +466 -0
  110. marin_levanter-0.99/src/levanter/models/gemma.py +997 -0
  111. marin_levanter-0.99/src/levanter/models/gpt2.py +368 -0
  112. marin_levanter-0.99/src/levanter/models/gpt2_hyena.py +196 -0
  113. marin_levanter-0.99/src/levanter/models/hyena.py +556 -0
  114. marin_levanter-0.99/src/levanter/models/linear.py +22 -0
  115. marin_levanter-0.99/src/levanter/models/llama.py +671 -0
  116. marin_levanter-0.99/src/levanter/models/lm_model.py +297 -0
  117. marin_levanter-0.99/src/levanter/models/loss.py +295 -0
  118. marin_levanter-0.99/src/levanter/models/mistral.py +236 -0
  119. marin_levanter-0.99/src/levanter/models/mixtral.py +677 -0
  120. marin_levanter-0.99/src/levanter/models/olmo.py +572 -0
  121. marin_levanter-0.99/src/levanter/models/olmo3.py +419 -0
  122. marin_levanter-0.99/src/levanter/models/qwen.py +394 -0
  123. marin_levanter-0.99/src/levanter/models/whisper.py +518 -0
  124. marin_levanter-0.99/src/levanter/optim/__init__.py +63 -0
  125. marin_levanter-0.99/src/levanter/optim/adam_mini.py +176 -0
  126. marin_levanter-0.99/src/levanter/optim/adamh.py +177 -0
  127. marin_levanter-0.99/src/levanter/optim/adopt.py +131 -0
  128. marin_levanter-0.99/src/levanter/optim/cautious.py +130 -0
  129. marin_levanter-0.99/src/levanter/optim/clip_update_norm.py +115 -0
  130. marin_levanter-0.99/src/levanter/optim/config.py +585 -0
  131. marin_levanter-0.99/src/levanter/optim/grugmuon.py +351 -0
  132. marin_levanter-0.99/src/levanter/optim/kron.py +1482 -0
  133. marin_levanter-0.99/src/levanter/optim/mars.py +127 -0
  134. marin_levanter-0.99/src/levanter/optim/model_averaging.py +115 -0
  135. marin_levanter-0.99/src/levanter/optim/muon.py +182 -0
  136. marin_levanter-0.99/src/levanter/optim/muonh.py +203 -0
  137. marin_levanter-0.99/src/levanter/optim/namo.py +557 -0
  138. marin_levanter-0.99/src/levanter/optim/rmsprop.py +136 -0
  139. marin_levanter-0.99/src/levanter/optim/scion.py +164 -0
  140. marin_levanter-0.99/src/levanter/optim/skipstep.py +201 -0
  141. marin_levanter-0.99/src/levanter/optim/soap.py +1039 -0
  142. marin_levanter-0.99/src/levanter/optim/util.py +300 -0
  143. marin_levanter-0.99/src/levanter/schedule.py +133 -0
  144. marin_levanter-0.99/src/levanter/shapes.py +35 -0
  145. marin_levanter-0.99/src/levanter/store/__init__.py +9 -0
  146. marin_levanter-0.99/src/levanter/store/cache.py +724 -0
  147. marin_levanter-0.99/src/levanter/store/jagged_array.py +694 -0
  148. marin_levanter-0.99/src/levanter/store/tree_store.py +201 -0
  149. marin_levanter-0.99/src/levanter/tensorstore_serialization.py +470 -0
  150. marin_levanter-0.99/src/levanter/tokenizers.py +1007 -0
  151. marin_levanter-0.99/src/levanter/tracker/__init__.py +42 -0
  152. marin_levanter-0.99/src/levanter/tracker/background.py +222 -0
  153. marin_levanter-0.99/src/levanter/tracker/helpers.py +90 -0
  154. marin_levanter-0.99/src/levanter/tracker/histogram.py +228 -0
  155. marin_levanter-0.99/src/levanter/tracker/json_file.py +59 -0
  156. marin_levanter-0.99/src/levanter/tracker/json_logger.py +140 -0
  157. marin_levanter-0.99/src/levanter/tracker/tensorboard.py +152 -0
  158. marin_levanter-0.99/src/levanter/tracker/tracker.py +182 -0
  159. marin_levanter-0.99/src/levanter/tracker/tracker_fns.py +321 -0
  160. marin_levanter-0.99/src/levanter/tracker/trackio.py +163 -0
  161. marin_levanter-0.99/src/levanter/tracker/wandb.py +454 -0
  162. marin_levanter-0.99/src/levanter/trainer.py +1133 -0
  163. marin_levanter-0.99/src/levanter/trainer_state.py +284 -0
  164. marin_levanter-0.99/src/levanter/utils/__init__.py +2 -0
  165. marin_levanter-0.99/src/levanter/utils/activation.py +72 -0
  166. marin_levanter-0.99/src/levanter/utils/background_iterable.py +171 -0
  167. marin_levanter-0.99/src/levanter/utils/cloud_utils.py +173 -0
  168. marin_levanter-0.99/src/levanter/utils/datetime_utils.py +44 -0
  169. marin_levanter-0.99/src/levanter/utils/flop_utils.py +37 -0
  170. marin_levanter-0.99/src/levanter/utils/fsspec_utils.py +68 -0
  171. marin_levanter-0.99/src/levanter/utils/hf_utils.py +39 -0
  172. marin_levanter-0.99/src/levanter/utils/index.py +49 -0
  173. marin_levanter-0.99/src/levanter/utils/jax_utils.py +621 -0
  174. marin_levanter-0.99/src/levanter/utils/logging.py +104 -0
  175. marin_levanter-0.99/src/levanter/utils/mesh.py +194 -0
  176. marin_levanter-0.99/src/levanter/utils/py_utils.py +182 -0
  177. marin_levanter-0.99/src/levanter/utils/stat_utils.py +36 -0
  178. marin_levanter-0.99/src/levanter/utils/thread_utils.py +77 -0
  179. marin_levanter-0.99/src/levanter/utils/token_init.py +81 -0
  180. marin_levanter-0.99/src/levanter/utils/tree_utils.py +202 -0
  181. marin_levanter-0.99/src/levanter/utils/types.py +57 -0
@@ -0,0 +1,291 @@
1
+ Metadata-Version: 2.3
2
+ Name: marin-levanter
3
+ Version: 0.99
4
+ Summary: Scalable Training for Foundation Models with Named Tensors and JAX
5
+ Author: David Hall, Jason Wang, Ahmed Ahmed, Ivan Zhou, Will Held, Virginia Adams
6
+ Author-email: David Hall <dlwh@cs.stanford.edu>, Jason Wang <jsywang@cs.stanford.edu>, Ivan Zhou <ivanz@stanford.edu>
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: Apache Software License
9
+ Classifier: Operating System :: POSIX :: Linux
10
+ Classifier: Operating System :: MacOS :: MacOS X
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Intended Audience :: Science/Research
13
+ Requires-Dist: marin-haliax==0.99
14
+ Requires-Dist: equinox>=0.11.7,!=0.12.0
15
+ Requires-Dist: jax>=0.9.2,<0.11
16
+ Requires-Dist: marin-fray==0.99
17
+ Requires-Dist: marin-rigging==0.99
18
+ Requires-Dist: marin-zephyr==0.99
19
+ Requires-Dist: einops
20
+ Requires-Dist: jaxtyping>=0.2.34
21
+ Requires-Dist: tokenizers>=0.15.2
22
+ Requires-Dist: kitoken>=0.10.2
23
+ Requires-Dist: transformers>=4.57.1,<5.0
24
+ Requires-Dist: chex>=0.1.86
25
+ Requires-Dist: optax>=0.1.9,<0.2.7
26
+ Requires-Dist: wandb>0.24.0
27
+ Requires-Dist: draccus>=0.11.5
28
+ Requires-Dist: pyarrow>=11.0.0
29
+ Requires-Dist: zstandard>=0.18.0
30
+ Requires-Dist: datasets>=3.1.0,<5.0
31
+ Requires-Dist: gcsfs>=2024.2,<2027
32
+ Requires-Dist: braceexpand>=0.1.7
33
+ Requires-Dist: chex>=0.1.86
34
+ Requires-Dist: jmp>=0.0.3
35
+ Requires-Dist: fsspec[http]>=2024.2,<2027
36
+ Requires-Dist: tensorstore>=0.1.73,<0.1.82
37
+ Requires-Dist: pytimeparse>=1.1.8
38
+ Requires-Dist: humanfriendly==10.0
39
+ Requires-Dist: safetensors[numpy]>=0.4.2,<0.7.0
40
+ Requires-Dist: tblib>=1.7.0,<4.0.0
41
+ Requires-Dist: dataclasses-json~=0.6.4
42
+ Requires-Dist: pydantic<3
43
+ Requires-Dist: filelock~=3.13
44
+ Requires-Dist: async-lru~=2.0
45
+ Requires-Dist: tqdm-loggable>=0.2
46
+ Requires-Dist: deepdiff>=8.6.2
47
+ Requires-Dist: lenses
48
+ Requires-Dist: jinja2
49
+ Requires-Dist: protobuf>=6
50
+ Requires-Dist: immutabledict
51
+ Requires-Dist: google-api-python-client>=2.175.0 ; extra == 'gcp'
52
+ Requires-Dist: google-cloud-storage ; extra == 'gcp'
53
+ Requires-Dist: google-cloud-storage-transfer ; extra == 'gcp'
54
+ Requires-Dist: google-auth ; extra == 'gcp'
55
+ Requires-Dist: jax[cuda13]==0.10.0 ; extra == 'gpu'
56
+ Requires-Dist: nvidia-cublas>=13.2.0.9 ; sys_platform == 'linux' and extra == 'gpu'
57
+ Requires-Dist: nvidia-nccl-cu13>=2.28.3 ; sys_platform == 'linux' and extra == 'gpu'
58
+ Requires-Dist: tokamax ; extra == 'kernels'
59
+ Requires-Dist: kitoken>=0.10.2 ; extra == 'kitoken'
60
+ Requires-Dist: xprof ; extra == 'profiling'
61
+ Requires-Dist: tensorboard>=2.16 ; extra == 'profiling'
62
+ Requires-Dist: tensorboardx>=2.6 ; extra == 'profiling'
63
+ Requires-Dist: fastapi>=0.100.0 ; extra == 'serve'
64
+ Requires-Dist: uvicorn[standard]>=0.23.0 ; extra == 'serve'
65
+ Requires-Dist: openai>=1.0.0 ; extra == 'serve'
66
+ Requires-Dist: torch>=2.7.0 ; extra == 'torch-test'
67
+ Requires-Dist: peft>=0.12.0 ; extra == 'torch-test'
68
+ Requires-Dist: jax==0.9.2 ; extra == 'tpu'
69
+ Requires-Dist: jaxlib==0.9.2 ; extra == 'tpu'
70
+ Requires-Dist: libtpu==0.0.38 ; extra == 'tpu'
71
+ Requires-Python: >=3.11
72
+ Project-URL: Bug Tracker, https://github.com/stanford-crfm/levanter/issues
73
+ Project-URL: Homepage, https://github.com/stanford-crfm/levanter
74
+ Provides-Extra: gcp
75
+ Provides-Extra: gpu
76
+ Provides-Extra: kernels
77
+ Provides-Extra: kitoken
78
+ Provides-Extra: lm-eval
79
+ Provides-Extra: profiling
80
+ Provides-Extra: serve
81
+ Provides-Extra: torch-test
82
+ Provides-Extra: tpu
83
+ Description-Content-Type: text/markdown
84
+
85
+ # Levanter
86
+
87
+ Levanter is developed and released from the
88
+ [marin-community/marin](https://github.com/marin-community/marin) monorepo
89
+ (`lib/levanter`) and published to PyPI as **`marin-levanter`**. This README
90
+ documents the package as shipped from that monorepo.
91
+
92
+ <a href="https://levanter.readthedocs.io/en/latest/?badge=latest">
93
+ <img alt="Documentation Status" src="https://readthedocs.org/projects/levanter/badge/?version=latest">
94
+ </a>
95
+ <a href="https://pypi.org/project/marin-levanter/">
96
+ <img alt="PyPI" src="https://img.shields.io/pypi/v/marin-levanter?color=blue" />
97
+ </a>
98
+
99
+
100
+ <!--levanter-intro-start-->
101
+ > *You could not prevent a thunderstorm, but you could use the electricity; you could not direct the wind, but you could trim your sail so as to propel your vessel as you pleased, no matter which way the wind blew.* <br/>
102
+ > — Cora L. V. Hatch
103
+
104
+ Levanter is a framework for training large language models (LLMs) and other foundation models that strives for legibility, scalability, and reproducibility:
105
+
106
+ 1. **Legible**: Levanter uses our named tensor library [Haliax](https://github.com/stanford-crfm/haliax) to write easy-to-follow, composable deep learning code, while still being high performance.
107
+ 2. **Scalable**: Levanter scales to large models and can train on a variety of hardware, including GPUs and TPUs.
108
+ 3. **Reproducible**: On TPU, Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.
109
+
110
+ We built Levanter with [JAX](https://github.com/jax-ml/jax), [Equinox](https://github.com/patrick-kidger/equinox), and [Haliax](https://github.com/stanford-crfm/haliax).
111
+
112
+ ## Documentation
113
+
114
+ Levanter's documentation is available at [levanter.readthedocs.io](https://levanter.readthedocs.io/en/latest/).
115
+ Haliax's documentation is available at [haliax.readthedocs.io](https://haliax.readthedocs.io/en/latest/).
116
+
117
+ ## Features
118
+
119
+ * **Distributed Training**: We support distributed training on TPUs and GPUs, including FSDP and tensor parallelism.
120
+ * **Compatibility**: Levanter supports importing and exporting models to/from the Hugging Face ecosystem, including tokenizers, datasets, and models via [SafeTensors](https://github.com/huggingface/safetensors).
121
+ * **Performance**: Levanter's performance rivals commercially-backed frameworks like MosaicML's Composer or Google's MaxText.
122
+ * **Resilience**: Levanter supports fast, distributed checkpointing and fast resume from checkpoints with no data seek, making Levanter robust to preemption and hardware failure.
123
+ * **Cached On-Demand Data Preprocessing**: We preprocess corpora online, but we cache the results of preprocessing so
124
+ that resumes are much faster and so that subsequent runs are even faster. As soon as the first part of the cache is complete, Levanter will start training.
125
+ * **Logging**: Levanter logs a rich and detailed set of metrics covering loss and performance. Levanter also supports a few different logging backends, including [WandB](https://wandb.ai/site) and [TensorBoard](https://www.tensorflow.org/tensorboard). (Adding a new logging backend is easy!) Levanter even exposes the ability
126
+ to log inside of JAX `jit`-ted functions.
127
+ * **Reproducibility**: On TPU, Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.
128
+ * **Distributed Checkpointing**: Distributed checkpointing is supported via Google's [TensorStore](https://google.github.io/tensorstore/) library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now.
129
+ * **Optimization**: We support [Optax](https://github.com/deepmind/optax) for optimization with AdamW, as well as newer optimizers like Muon, SOAP, and more.
130
+ * **Flexible**: Levanter supports tuning data mixtures without having to retokenize or shuffle data.
131
+
132
+ <!--levanter-intro-end-->
133
+
134
+ Levanter was created by the research engineering team at [Stanford's Center for Research on Foundation Models (CRFM)](https://crfm.stanford.edu/).
135
+ You can also find us in the #levanter channel on the unofficial [Jax LLM Discord](https://discord.gg/CKazXcbbBm).
136
+
137
+ ## Getting Started
138
+
139
+ Here is a small set of examples to get you started. For more information about the various configuration options,
140
+ please see the [Getting Started](./docs/Getting-Started-Training.md) guide or the [In-Depth Configuration Guide](docs/reference/Configuration.md).
141
+ You can also use `--help` or poke around other configs to see all the options available to you.
142
+
143
+
144
+ ### Installing Levanter
145
+
146
+ <!--levanter-installation-start-->
147
+
148
+ After [installing JAX](https://github.com/jax-ml/jax/blob/main/README.md#installation) with the appropriate configuration
149
+ for your platform, install Levanter from PyPI:
150
+
151
+ ```bash
152
+ pip install marin-levanter
153
+ wandb login # optional, we use wandb for logging
154
+ ```
155
+
156
+ For development, clone the marin monorepo and use `uv sync` to install Levanter
157
+ alongside its sibling packages (Haliax, Iris, etc.) in editable form:
158
+
159
+ ```bash
160
+ git clone https://github.com/marin-community/marin.git
161
+ cd marin
162
+ uv sync
163
+ ```
164
+
165
+ <!--levanter-installation-end-->
166
+
167
+ Please refer to the [Installation Guide](docs/Installation.md) for more information on how to install Levanter.
168
+
169
+ If you're using a TPU, more complete documentation for setting that up is available [here](docs/Getting-Started-TPU-VM.md). GPU support is still in progress; documentation is available [here](docs/Getting-Started-GPU.md).
170
+
171
+ <!--levanter-user-guide-start-->
172
+
173
+ ### Training a GPT2-nano
174
+
175
+ As a kind of hello world, here's how you can train a GPT-2 "nano"-sized model on a small dataset.
176
+
177
+ ```bash
178
+ python -m levanter.main.train_lm --config_path config/gpt2_nano.yaml
179
+
180
+ # alternatively, if you did not do an editable (`pip install -e`) install and are running from a different directory
181
+ python -m levanter.main.train_lm --config_path gpt2_nano
182
+ ```
183
+
184
+ This will train a GPT2-nano model on the [WikiText-103](https://huggingface.co/datasets/Salesforce/wikitext) dataset.
185
+
186
+ ### Training a Llama-small on your own data
187
+
188
+ You can also change the dataset by changing the `data` field in the config file.
189
+ If your dataset is a [Hugging Face dataset](https://huggingface.co/docs/datasets/loading_datasets.html), you can use the `data.id` field to specify it:
190
+
191
+ ```bash
192
+ python -m levanter.main.train_lm --config_path config/llama_small_fast.yaml --data.id openwebtext
193
+
194
+ # optionally, you may specify a tokenizer and/or a cache directory, which may be local or on GCS
195
+ python -m levanter.main.train_lm --config_path config/llama_small_fast.yaml --data.id openwebtext --data.tokenizer "NousResearch/Llama-2-7b-hf" --data.cache_dir "gs://path/to/cache/dir"
196
+ ```
197
+
198
+ If instead your data is a list of URLs, you can use the `data.train_urls` and `data.validation_urls` fields to specify them.
199
+ Data URLs can be local files, GCS files, HTTP(S) URLs, or anything else that [fsspec](https://filesystem-spec.readthedocs.io/en/latest/) supports.
200
+ Levanter (really, fsspec) will automatically uncompress `.gz` and `.zstd` files, and probably other formats too.
201
+
202
+ ```bash
203
+ python -m levanter.main.train_lm --config_path config/llama_small_fast.yaml --data.train_urls '["https://path/to/train/data_*.jsonl.gz"]' --data.validation_urls '["https://path/to/val/data_*.jsonl.gz"]'
204
+ ```
205
+
206
+ ### Customizing a Config File
207
+
208
+ You can modify the config file to change the model, the dataset, the training parameters, and more. Here's
209
+ the `llama_small_fast.yaml` file:
210
+
211
+ ```yaml
212
+ data:
213
+ train_urls:
214
+ - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
215
+ validation_urls:
216
+ - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
217
+ cache_dir: "gs://pubmed-mosaic/tokenized/openwebtext/"
218
+ model:
219
+ type: llama
220
+ hidden_dim: 768
221
+ intermediate_dim: 2048
222
+ num_heads: 12
223
+ num_kv_heads: 12
224
+ num_layers: 12
225
+ seq_len: 1024
226
+ gradient_checkpointing: true
227
+ trainer:
228
+ tracker:
229
+ type: wandb
230
+ project: "levanter"
231
+ tags: [ "openwebtext", "llama" ]
232
+
233
+ mp: p=f32,c=bfloat16
234
+ mesh:
235
+ axes: {data: -1, replica: 1, model: 1} # inherited defaults; override if you need TP
236
+ per_device_parallelism: 4
237
+
238
+ train_batch_size: 512
239
+ optimizer:
240
+ learning_rate: 6E-4
241
+ weight_decay: 0.1
242
+ min_lr_ratio: 0.1
243
+ ```
244
+
245
+ ### Other Architectures
246
+
247
+ Currently, we support the following architectures:
248
+
249
+ * GPT-2
250
+ * [Llama](https://ai.meta.com/llama/), including Llama 1, 2 and 3
251
+ * [Gemma](https://ai.google.dev/gemma), including Gemma 1, 2 and 3
252
+ * [Qwen2](https://huggingface.co/Qwen/Qwen2.5-7B)
253
+ * [Qwen3](https://huggingface.co/Qwen/Qwen3-8B)
254
+ * [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
255
+ * [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
256
+ * [Olmo2](https://huggingface.co/allenai/Olmo-2-1124-7B)
257
+
258
+ We plan to add more in the future.
259
+
260
+ For speech, we currently only support [Whisper](https://huggingface.co/openai/whisper-large-v3).
261
+
262
+ #### Continued Pretraining with Llama
263
+
264
+ Here's an example of how to continue pretraining a Llama 1 or Llama 2 model on the OpenWebText dataset:
265
+
266
+ ```bash
267
+ python -m levanter.main.train_lm --config_path config/llama2_7b_continued.yaml
268
+ ```
269
+
270
+
271
+ ## Distributed and Cloud Training
272
+
273
+ ### Training on a TPU Cloud VM
274
+
275
+ Please see the [TPU Getting Started](docs/Getting-Started-TPU-VM.md) guide for more information on how to set up a TPU Cloud VM and run Levanter there.
276
+
277
+ ### Training with CUDA
278
+
279
+ Please see the [CUDA Getting Started](docs/Getting-Started-GPU.md) guide for more information on how to set up a CUDA environment and run Levanter there.
280
+
281
+ <!--levanter-user-guide-end-->
282
+
283
+ ## Contributing
284
+
285
+ We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for
286
+ more information. Issues and pull requests are tracked at
287
+ [marin-community/marin](https://github.com/marin-community/marin/issues).
288
+
289
+ ## License
290
+
291
+ Levanter is licensed under the Apache License, Version 2.0. See [LICENSE](LICENSE) for the full license text.
@@ -0,0 +1,207 @@
1
+ # Levanter
2
+
3
+ Levanter is developed and released from the
4
+ [marin-community/marin](https://github.com/marin-community/marin) monorepo
5
+ (`lib/levanter`) and published to PyPI as **`marin-levanter`**. This README
6
+ documents the package as shipped from that monorepo.
7
+
8
+ <a href="https://levanter.readthedocs.io/en/latest/?badge=latest">
9
+ <img alt="Documentation Status" src="https://readthedocs.org/projects/levanter/badge/?version=latest">
10
+ </a>
11
+ <a href="https://pypi.org/project/marin-levanter/">
12
+ <img alt="PyPI" src="https://img.shields.io/pypi/v/marin-levanter?color=blue" />
13
+ </a>
14
+
15
+
16
+ <!--levanter-intro-start-->
17
+ > *You could not prevent a thunderstorm, but you could use the electricity; you could not direct the wind, but you could trim your sail so as to propel your vessel as you pleased, no matter which way the wind blew.* <br/>
18
+ > — Cora L. V. Hatch
19
+
20
+ Levanter is a framework for training large language models (LLMs) and other foundation models that strives for legibility, scalability, and reproducibility:
21
+
22
+ 1. **Legible**: Levanter uses our named tensor library [Haliax](https://github.com/stanford-crfm/haliax) to write easy-to-follow, composable deep learning code, while still being high performance.
23
+ 2. **Scalable**: Levanter scales to large models and can train on a variety of hardware, including GPUs and TPUs.
24
+ 3. **Reproducible**: On TPU, Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.
25
+
26
+ We built Levanter with [JAX](https://github.com/jax-ml/jax), [Equinox](https://github.com/patrick-kidger/equinox), and [Haliax](https://github.com/stanford-crfm/haliax).
27
+
28
+ ## Documentation
29
+
30
+ Levanter's documentation is available at [levanter.readthedocs.io](https://levanter.readthedocs.io/en/latest/).
31
+ Haliax's documentation is available at [haliax.readthedocs.io](https://haliax.readthedocs.io/en/latest/).
32
+
33
+ ## Features
34
+
35
+ * **Distributed Training**: We support distributed training on TPUs and GPUs, including FSDP and tensor parallelism.
36
+ * **Compatibility**: Levanter supports importing and exporting models to/from the Hugging Face ecosystem, including tokenizers, datasets, and models via [SafeTensors](https://github.com/huggingface/safetensors).
37
+ * **Performance**: Levanter's performance rivals commercially backed frameworks like MosaicML's Composer or Google's MaxText.
38
+ * **Resilience**: Levanter supports fast, distributed checkpointing and fast resume from checkpoints with no data seek, making Levanter robust to preemption and hardware failure.
39
+ * **Cached On-Demand Data Preprocessing**: We preprocess corpora online, but we cache the results of preprocessing so
40
+ that resumes are much faster and so that subsequent runs are even faster. As soon as the first part of the cache is complete, Levanter will start training.
41
+ * **Logging**: Levanter logs a rich and detailed set of metrics covering loss and performance. Levanter also supports a few different logging backends, including [WandB](https://wandb.ai/site) and [TensorBoard](https://www.tensorflow.org/tensorboard). (Adding a new logging backend is easy!) Levanter even exposes the ability
42
+ to log inside of JAX `jit`-ted functions.
43
+ * **Reproducibility**: On TPU, Levanter is bitwise deterministic, meaning that the same configuration will always produce the same results, even in the face of preemption and resumption.
44
+ * **Distributed Checkpointing**: Distributed checkpointing is supported via Google's [TensorStore](https://google.github.io/tensorstore/) library. Training can even be resumed on a different number of hosts, though this breaks reproducibility for now.
45
+ * **Optimization**: We support [Optax](https://github.com/google-deepmind/optax) for optimization with AdamW, as well as newer optimizers like Muon, SOAP, and more.
46
+ * **Flexible**: Levanter supports tuning data mixtures without having to retokenize or shuffle data.
47
+
48
+ <!--levanter-intro-end-->
49
+
50
+ Levanter was created by [Stanford's Center for Research on Foundation Models (CRFM)](https://crfm.stanford.edu/)'s research engineering team.
51
+ You can also find us in the #levanter channel on the unofficial [Jax LLM Discord](https://discord.gg/CKazXcbbBm).
52
+
53
+ ## Getting Started
54
+
55
+ Here is a small set of examples to get you started. For more information about the various configuration options,
56
+ please see the [Getting Started](./docs/Getting-Started-Training.md) guide or the [In-Depth Configuration Guide](docs/reference/Configuration.md).
57
+ You can also use `--help` or poke around other configs to see all the options available to you.
58
+
59
+
60
+ ### Installing Levanter
61
+
62
+ <!--levanter-installation-start-->
63
+
64
+ After [installing JAX](https://github.com/jax-ml/jax/blob/main/README.md#installation) with the appropriate configuration
65
+ for your platform, install Levanter from PyPI:
66
+
67
+ ```bash
68
+ pip install marin-levanter
69
+ wandb login # optional, we use wandb for logging
70
+ ```
71
+
72
+ For development, clone the marin monorepo and use `uv sync` to install Levanter
73
+ alongside its sibling packages (Haliax, Iris, etc.) in editable form:
74
+
75
+ ```bash
76
+ git clone https://github.com/marin-community/marin.git
77
+ cd marin
78
+ uv sync
79
+ ```
80
+
81
+ <!--levanter-installation-end-->
82
+
83
+ Please refer to the [Installation Guide](docs/Installation.md) for more information on how to install Levanter.
84
+
85
+ If you're using a TPU, more complete documentation for setting that up is available [here](docs/Getting-Started-TPU-VM.md). GPU support is still in progress; documentation is available [here](docs/Getting-Started-GPU.md).
86
+
87
+ <!--levanter-user-guide-start-->
88
+
89
+ ### Training a GPT2-nano
90
+
91
+ As a kind of hello world, here's how you can train a GPT-2 "nano"-sized model on a small dataset.
92
+
93
+ ```bash
94
+ python -m levanter.main.train_lm --config_path config/gpt2_nano.yaml
95
+
96
+ # alternatively, if you installed the package (rather than an editable checkout) and are in a different directory, refer to the config by name
97
+ python -m levanter.main.train_lm --config_path gpt2_nano
98
+ ```
99
+
100
+ This will train a GPT2-nano model on the [WikiText-103](https://huggingface.co/datasets/Salesforce/wikitext) dataset.
101
+
102
+ ### Training a Llama-small on your own data
103
+
104
+ You can also change the dataset by changing the `data` field in the config file.
105
+ If your dataset is a [Hugging Face dataset](https://huggingface.co/docs/datasets/loading_datasets.html), you can use the `data.id` field to specify it:
106
+
107
+ ```bash
108
+ python -m levanter.main.train_lm --config_path config/llama_small_fast.yaml --data.id openwebtext
109
+
110
+ # optionally, you may specify a tokenizer and/or a cache directory, which may be local or on GCS
111
+ python -m levanter.main.train_lm --config_path config/llama_small_fast.yaml --data.id openwebtext --data.tokenizer "NousResearch/Llama-2-7b-hf" --data.cache_dir "gs://path/to/cache/dir"
112
+ ```
113
+
114
+ If instead your data is a list of URLs, you can use the `data.train_urls` and `data.validation_urls` fields to specify them.
115
+ Data URLs can be local files, GCS files, HTTP(S) URLs, or anything else that [fsspec](https://filesystem-spec.readthedocs.io/en/latest/) supports.
116
+ Levanter (really, fsspec) will automatically decompress `.gz` and `.zst` files, as well as other compression formats that fsspec can infer from the file extension.
117
+
118
+ ```bash
119
+ python -m levanter.main.train_lm --config_path config/llama_small_fast.yaml --data.train_urls '["https://path/to/train/data_*.jsonl.gz"]' --data.validation_urls '["https://path/to/val/data_*.jsonl.gz"]'
120
+ ```
121
+
122
+ ### Customizing a Config File
123
+
124
+ You can modify the config file to change the model, the dataset, the training parameters, and more. Here's
125
+ the `llama_small_fast.yaml` file:
126
+
127
+ ```yaml
128
+ data:
129
+ train_urls:
130
+ - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
131
+ validation_urls:
132
+ - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
133
+ cache_dir: "gs://pubmed-mosaic/tokenized/openwebtext/"
134
+ model:
135
+ type: llama
136
+ hidden_dim: 768
137
+ intermediate_dim: 2048
138
+ num_heads: 12
139
+ num_kv_heads: 12
140
+ num_layers: 12
141
+ seq_len: 1024
142
+ gradient_checkpointing: true
143
+ trainer:
144
+ tracker:
145
+ type: wandb
146
+ project: "levanter"
147
+ tags: [ "openwebtext", "llama" ]
148
+
149
+ mp: p=f32,c=bfloat16
150
+ mesh:
151
+ axes: {data: -1, replica: 1, model: 1} # inherited defaults; override if you need TP
152
+ per_device_parallelism: 4
153
+
154
+ train_batch_size: 512
155
+ optimizer:
156
+ learning_rate: 6E-4
157
+ weight_decay: 0.1
158
+ min_lr_ratio: 0.1
159
+ ```
160
+
161
+ ### Other Architectures
162
+
163
+ Currently, we support the following architectures:
164
+
165
+ * GPT-2
166
+ * [Llama](https://ai.meta.com/llama/), including Llama 1, 2 and 3
168
+ * [Gemma](https://ai.google.dev/gemma), including Gemma 1, 2 and 3
168
+ * [Qwen2](https://huggingface.co/Qwen/Qwen2.5-7B)
169
+ * [Qwen3](https://huggingface.co/Qwen/Qwen3-8B)
170
+ * [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
171
+ * [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
172
+ * [Olmo2](https://huggingface.co/allenai/Olmo-2-1124-7B)
173
+
174
+ We plan to add more in the future.
175
+
176
+ For speech, we currently only support [Whisper](https://huggingface.co/openai/whisper-large-v3).
177
+
178
+ #### Continued Pretraining with Llama
179
+
180
+ Here's an example of how to continue pretraining a Llama 1 or Llama 2 model on the OpenWebText dataset:
181
+
182
+ ```bash
183
+ python -m levanter.main.train_lm --config_path config/llama2_7b_continued.yaml
184
+ ```
185
+
186
+
187
+ ## Distributed and Cloud Training
188
+
189
+ ### Training on a TPU Cloud VM
190
+
191
+ Please see the [TPU Getting Started](docs/Getting-Started-TPU-VM.md) guide for more information on how to set up a TPU Cloud VM and run Levanter there.
192
+
193
+ ### Training with CUDA
194
+
195
+ Please see the [CUDA Getting Started](docs/Getting-Started-GPU.md) guide for more information on how to set up a CUDA environment and run Levanter there.
196
+
197
+ <!--levanter-user-guide-end-->
198
+
199
+ ## Contributing
200
+
201
+ We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for
202
+ more information. Issues and pull requests are tracked at
203
+ [marin-community/marin](https://github.com/marin-community/marin/issues).
204
+
205
+ ## License
206
+
207
+ Levanter is licensed under the Apache License, Version 2.0. See [LICENSE](LICENSE) for the full license text.