aetherlm-0.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherlm-0.1.0/PKG-INFO +201 -0
- aetherlm-0.1.0/README.md +159 -0
- aetherlm-0.1.0/aetherlm/__init__.py +55 -0
- aetherlm-0.1.0/aetherlm/checkpoint/__init__.py +3 -0
- aetherlm-0.1.0/aetherlm/checkpoint/manager.py +123 -0
- aetherlm-0.1.0/aetherlm/cli/__init__.py +3 -0
- aetherlm-0.1.0/aetherlm/cli/main.py +274 -0
- aetherlm-0.1.0/aetherlm/core/__init__.py +5 -0
- aetherlm-0.1.0/aetherlm/core/config.py +202 -0
- aetherlm-0.1.0/aetherlm/core/engine.py +366 -0
- aetherlm-0.1.0/aetherlm/core/sharding.py +68 -0
- aetherlm-0.1.0/aetherlm/data/__init__.py +6 -0
- aetherlm-0.1.0/aetherlm/data/causal.py +123 -0
- aetherlm-0.1.0/aetherlm/data/contrastive.py +150 -0
- aetherlm-0.1.0/aetherlm/data/mlm.py +164 -0
- aetherlm-0.1.0/aetherlm/data/pipeline.py +63 -0
- aetherlm-0.1.0/aetherlm/eval/__init__.py +5 -0
- aetherlm-0.1.0/aetherlm/eval/mteb.py +181 -0
- aetherlm-0.1.0/aetherlm/eval/runner.py +416 -0
- aetherlm-0.1.0/aetherlm/eval/tasks.py +164 -0
- aetherlm-0.1.0/aetherlm/losses/__init__.py +5 -0
- aetherlm-0.1.0/aetherlm/losses/causal.py +43 -0
- aetherlm-0.1.0/aetherlm/losses/contrastive.py +73 -0
- aetherlm-0.1.0/aetherlm/losses/mlm.py +53 -0
- aetherlm-0.1.0/aetherlm/metrics/__init__.py +3 -0
- aetherlm-0.1.0/aetherlm/metrics/tracker.py +209 -0
- aetherlm-0.1.0/aetherlm/models/__init__.py +6 -0
- aetherlm-0.1.0/aetherlm/models/base.py +157 -0
- aetherlm-0.1.0/aetherlm/models/bert.py +121 -0
- aetherlm-0.1.0/aetherlm/models/causal.py +195 -0
- aetherlm-0.1.0/aetherlm/models/transformer.py +128 -0
- aetherlm-0.1.0/aetherlm/optim/__init__.py +5 -0
- aetherlm-0.1.0/aetherlm/optim/optimizers.py +108 -0
- aetherlm-0.1.0/aetherlm/optim/schedules.py +101 -0
- aetherlm-0.1.0/aetherlm/optim/switching.py +57 -0
- aetherlm-0.1.0/aetherlm/tpu/__init__.py +4 -0
- aetherlm-0.1.0/aetherlm/tpu/precision.py +41 -0
- aetherlm-0.1.0/aetherlm/tpu/topology.py +123 -0
- aetherlm-0.1.0/aetherlm/training/__init__.py +7 -0
- aetherlm-0.1.0/aetherlm/training/steps.py +119 -0
- aetherlm-0.1.0/aetherlm.egg-info/PKG-INFO +201 -0
- aetherlm-0.1.0/aetherlm.egg-info/SOURCES.txt +46 -0
- aetherlm-0.1.0/aetherlm.egg-info/dependency_links.txt +1 -0
- aetherlm-0.1.0/aetherlm.egg-info/entry_points.txt +2 -0
- aetherlm-0.1.0/aetherlm.egg-info/requires.txt +26 -0
- aetherlm-0.1.0/aetherlm.egg-info/top_level.txt +1 -0
- aetherlm-0.1.0/pyproject.toml +67 -0
- aetherlm-0.1.0/setup.cfg +4 -0
aetherlm-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,201 @@
Metadata-Version: 2.4
Name: aetherlm
Version: 0.1.0
Summary: A DeepSpeed-like training library for Google TPUs, built on JAX/Flax.
Author: aetherBERT contributors
License: Apache-2.0
Project-URL: Homepage, https://github.com/aetherBERT/aetherlm
Project-URL: Repository, https://github.com/aetherBERT/aetherlm
Keywords: jax,flax,tpu,deep-learning,distributed-training,bert,transformers
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: jax[tpu]>=0.4.20
Requires-Dist: flax>=0.10.0
Requires-Dist: optax>=0.2.0
Requires-Dist: orbax-checkpoint>=0.6.0
Requires-Dist: transformers>=4.36.0
Requires-Dist: datasets>=2.16.0
Requires-Dist: wandb>=0.16.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: pyyaml>=6.0
Provides-Extra: eval
Requires-Dist: mteb>=1.12.0; extra == "eval"
Requires-Dist: torch>=2.0.0; extra == "eval"
Requires-Dist: scikit-learn>=1.3.0; extra == "eval"
Requires-Dist: scipy>=1.11.0; extra == "eval"
Provides-Extra: nmn
Requires-Dist: nmn; extra == "nmn"
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: ruff>=0.1.0; extra == "dev"
Requires-Dist: jupyter>=1.0.0; extra == "dev"
Provides-Extra: all
Requires-Dist: aetherlm[dev,eval,nmn]; extra == "all"

# AetherLM

A DeepSpeed-like training library for Google TPUs, built on JAX/Flax.

AetherLM abstracts away the complexity of distributed training on TPU pods, providing a simple `initialize()` API for pretraining and fine-tuning transformer models.

## Features

- **One-call initialization** — `aetherlm.initialize(config)` sets up mesh, model, optimizer, and metrics
- **TPU-native** — Built on JAX with automatic TPU topology detection and optimal sharding
- **Multiple training modes** — MLM pretraining, contrastive learning, causal LM
- **Built-in models** — `AetherBERT` (bidirectional) and `AetherCausalLM` (autoregressive with generation)
- **YAML/dict configuration** — DeepSpeed-style config system
- **MTEB evaluation** — Integrated embedding evaluation with 56 tasks
- **Checkpoint management** — Automatic saving, rotation, and restore with Orbax

## Quick Start

```python
import aetherlm

# Configure and initialize (auto-detects TPU)
engine = aetherlm.initialize(config={
    "model": {"model_type": "bert", "embed_dim": 768, "num_layers": 12},
    "training": {"mode": "mlm", "batch_size": 32, "learning_rate": 1e-4},
    "tpu": {"precision": "bf16"},
})

# Load data
from aetherlm.data import load_mlm_datasets
train_data, val_data = load_mlm_datasets(maxlen=512, mask_prob=0.15, vocab_size=50265)

# Train (handles sharding, logging, checkpointing)
engine.train(train_data, val_data)
```

## Installation

```bash
# From PyPI
pip install aetherlm

# With eval support (MTEB, sklearn, scipy)
pip install aetherlm[eval]

# From source
pip install -e ".[all]"
```

## Training Modes

### MLM Pretraining (BERT-style)

```bash
aetherlm --mode mlm --config configs/default.yaml
```

### Causal Language Modeling (GPT-style)

```bash
aetherlm --mode causal --config configs/causal_small.yaml
```

### Contrastive Learning (requires checkpoint)

```bash
aetherlm --mode contrastive --checkpoint ./checkpoints/step_10000
```

### Evaluation

```bash
# Quick MTEB eval (3 tasks, ~1 min)
aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset quick

# Full leaderboard (56 tasks)
aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset leaderboard
```

## Configuration

Aether uses a dataclass-based config system. Create configs from YAML, JSON, or Python dicts:

```python
from aetherlm import AetherConfig

# From YAML
config = AetherConfig.from_yaml("configs/default.yaml")

# From dict
config = AetherConfig.from_dict({
    "model": {"embed_dim": 768, "num_layers": 12},
    "training": {"mode": "mlm", "batch_size": 32},
})

# Save for reproducibility
config.to_yaml("my_experiment.yaml")
```

See `configs/` for example configurations.

## Project Structure

```
aetherlm/
  __init__.py        # Top-level API: initialize(), AetherConfig, models
  core/
    config.py        # Dataclass configuration system
    engine.py        # Training engine (the heart of the library)
    sharding.py      # Automatic mesh sharding
  models/
    base.py          # Abstract model interface + utilities
    transformer.py   # Transformer blocks, embeddings
    bert.py          # AetherBERT (bidirectional MLM)
    causal.py        # AetherCausalLM (autoregressive + generation)
  optim/
    optimizers.py    # Optimizer factory (AdamW, Adam, SGD)
    schedules.py     # LR schedules (warmup-cosine, linear, constant)
    switching.py     # Plateau detection + optimizer switching
  data/
    pipeline.py      # Tokenizer caching, batch iterators
    mlm.py           # MLM masking and dataset processing
    contrastive.py   # Contrastive pair creation (self-supervised + AllNLI)
    causal.py        # Causal LM next-token prediction format
  losses/
    mlm.py           # Efficient gather-based MLM loss
    contrastive.py   # SimCLR-style contrastive loss
    causal.py        # Causal LM cross-entropy loss
  training/
    steps.py         # JIT-compiled train/eval steps for all modes
  checkpoint/
    manager.py       # Orbax checkpoint save/load/rotation
  metrics/
    tracker.py       # Throughput, loss, ETA, WandB logging
  eval/
    tasks.py         # MTEB task lists and presets
    mteb.py          # MTEB EncoderProtocol wrapper
    runner.py        # Evaluation orchestrators
  tpu/
    topology.py      # TPU detection, mesh creation
    precision.py     # bfloat16/mixed precision config
  cli/
    main.py          # CLI entry point
configs/             # Example YAML configurations
notebooks/           # Tutorial notebooks
```

## Notebooks

| Notebook | Description |
|----------|-------------|
| `01_quickstart.ipynb` | Core features: models, config, initialize |
| `02_mlm_pretraining.ipynb` | Full MLM pretraining pipeline |
| `03_causal_lm.ipynb` | Causal LM training + text generation |
| `04_evaluation.ipynb` | MTEB and custom evaluation |

## License

Apache 2.0

aetherlm-0.1.0/README.md
ADDED
@@ -0,0 +1,159 @@
# AetherLM

A DeepSpeed-like training library for Google TPUs, built on JAX/Flax.

AetherLM abstracts away the complexity of distributed training on TPU pods, providing a simple `initialize()` API for pretraining and fine-tuning transformer models.

## Features

- **One-call initialization** — `aetherlm.initialize(config)` sets up mesh, model, optimizer, and metrics
- **TPU-native** — Built on JAX with automatic TPU topology detection and optimal sharding
- **Multiple training modes** — MLM pretraining, contrastive learning, causal LM
- **Built-in models** — `AetherBERT` (bidirectional) and `AetherCausalLM` (autoregressive with generation)
- **YAML/dict configuration** — DeepSpeed-style config system
- **MTEB evaluation** — Integrated embedding evaluation with 56 tasks
- **Checkpoint management** — Automatic saving, rotation, and restore with Orbax

## Quick Start

```python
import aetherlm

# Configure and initialize (auto-detects TPU)
engine = aetherlm.initialize(config={
    "model": {"model_type": "bert", "embed_dim": 768, "num_layers": 12},
    "training": {"mode": "mlm", "batch_size": 32, "learning_rate": 1e-4},
    "tpu": {"precision": "bf16"},
})

# Load data
from aetherlm.data import load_mlm_datasets
train_data, val_data = load_mlm_datasets(maxlen=512, mask_prob=0.15, vocab_size=50265)

# Train (handles sharding, logging, checkpointing)
engine.train(train_data, val_data)
```

## Installation

```bash
# From PyPI
pip install aetherlm

# With eval support (MTEB, sklearn, scipy)
pip install aetherlm[eval]

# From source
pip install -e ".[all]"
```

## Training Modes

### MLM Pretraining (BERT-style)

```bash
aetherlm --mode mlm --config configs/default.yaml
```

### Causal Language Modeling (GPT-style)

```bash
aetherlm --mode causal --config configs/causal_small.yaml
```

### Contrastive Learning (requires checkpoint)

```bash
aetherlm --mode contrastive --checkpoint ./checkpoints/step_10000
```

### Evaluation

```bash
# Quick MTEB eval (3 tasks, ~1 min)
aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset quick

# Full leaderboard (56 tasks)
aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset leaderboard
```

## Configuration

Aether uses a dataclass-based config system. Create configs from YAML, JSON, or Python dicts:

```python
from aetherlm import AetherConfig

# From YAML
config = AetherConfig.from_yaml("configs/default.yaml")

# From dict
config = AetherConfig.from_dict({
    "model": {"embed_dim": 768, "num_layers": 12},
    "training": {"mode": "mlm", "batch_size": 32},
})

# Save for reproducibility
config.to_yaml("my_experiment.yaml")
```

See `configs/` for example configurations.

## Project Structure

```
aetherlm/
  __init__.py        # Top-level API: initialize(), AetherConfig, models
  core/
    config.py        # Dataclass configuration system
    engine.py        # Training engine (the heart of the library)
    sharding.py      # Automatic mesh sharding
  models/
    base.py          # Abstract model interface + utilities
    transformer.py   # Transformer blocks, embeddings
    bert.py          # AetherBERT (bidirectional MLM)
    causal.py        # AetherCausalLM (autoregressive + generation)
  optim/
    optimizers.py    # Optimizer factory (AdamW, Adam, SGD)
    schedules.py     # LR schedules (warmup-cosine, linear, constant)
    switching.py     # Plateau detection + optimizer switching
  data/
    pipeline.py      # Tokenizer caching, batch iterators
    mlm.py           # MLM masking and dataset processing
    contrastive.py   # Contrastive pair creation (self-supervised + AllNLI)
    causal.py        # Causal LM next-token prediction format
  losses/
    mlm.py           # Efficient gather-based MLM loss
    contrastive.py   # SimCLR-style contrastive loss
    causal.py        # Causal LM cross-entropy loss
  training/
    steps.py         # JIT-compiled train/eval steps for all modes
  checkpoint/
    manager.py       # Orbax checkpoint save/load/rotation
  metrics/
    tracker.py       # Throughput, loss, ETA, WandB logging
  eval/
    tasks.py         # MTEB task lists and presets
    mteb.py          # MTEB EncoderProtocol wrapper
    runner.py        # Evaluation orchestrators
  tpu/
    topology.py      # TPU detection, mesh creation
    precision.py     # bfloat16/mixed precision config
  cli/
    main.py          # CLI entry point
configs/             # Example YAML configurations
notebooks/           # Tutorial notebooks
```

## Notebooks

| Notebook | Description |
|----------|-------------|
| `01_quickstart.ipynb` | Core features: models, config, initialize |
| `02_mlm_pretraining.ipynb` | Full MLM pretraining pipeline |
| `03_causal_lm.ipynb` | Causal LM training + text generation |
| `04_evaluation.ipynb` | MTEB and custom evaluation |

## License

Apache 2.0

aetherlm-0.1.0/aetherlm/__init__.py
ADDED
@@ -0,0 +1,55 @@
"""
Aether - A DeepSpeed-like training library for Google TPUs

Built on JAX/Flax for high-performance distributed training
of transformer models on TPU pods.

Quick Start:
    import aetherlm

    engine = aetherlm.initialize(config={
        "model": {"model_type": "bert", "embed_dim": 768},
        "training": {"mode": "mlm", "batch_size": 32},
    })

    for batch in train_data:
        loss = engine.train_step(batch)
"""

__version__ = "0.1.0"

from aetherlm.core.config import (
    AetherConfig,
    ModelConfig,
    TrainingConfig,
    OptimizerConfig,
    ContrastiveConfig,
    CausalConfig,
    CheckpointConfig,
    LoggingConfig,
    EvalConfig,
    TPUConfig,
)
from aetherlm.core.engine import AetherEngine, initialize
from aetherlm.models.bert import AetherBERT
from aetherlm.models.causal import AetherCausalLM

__all__ = [
    # Core
    "initialize",
    "AetherEngine",
    "AetherConfig",
    # Config
    "ModelConfig",
    "TrainingConfig",
    "OptimizerConfig",
    "ContrastiveConfig",
    "CausalConfig",
    "CheckpointConfig",
    "LoggingConfig",
    "EvalConfig",
    "TPUConfig",
    # Models
    "AetherBERT",
    "AetherCausalLM",
]
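
The module docstring above stops at a bare `train_step()` loop. Here is a minimal end-to-end sketch of how the exports compose; it assumes only what the docstring and the README show (`initialize()` accepting a config dict, `load_mlm_datasets()` with the README's signature, `engine.train_step(batch)` returning a loss), and the logging interval is illustrative rather than a package default:

```python
import aetherlm
from aetherlm.data import load_mlm_datasets

# Config dict shape and the load_mlm_datasets(...) call are copied from
# the README Quick Start; the keys map onto the config dataclasses
# re-exported above (ModelConfig, TrainingConfig, ...).
engine = aetherlm.initialize(config={
    "model": {"model_type": "bert", "embed_dim": 768},
    "training": {"mode": "mlm", "batch_size": 32},
})
train_data, val_data = load_mlm_datasets(maxlen=512, mask_prob=0.15, vocab_size=50265)

# Manual loop from the module docstring; engine.train(train_data, val_data)
# from the README wraps this same step with logging and checkpointing.
for step, batch in enumerate(train_data):
    loss = engine.train_step(batch)
    if step % 100 == 0:  # illustrative print interval, not a library default
        print(f"step {step}: loss = {loss}")
```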

aetherlm-0.1.0/aetherlm/checkpoint/manager.py
ADDED
@@ -0,0 +1,123 @@
"""
Checkpoint Management

Save and load model checkpoints using Orbax, with support for
checkpoint rotation (keeping only the N most recent).
"""

import os
from typing import Optional

import flax.nnx as nnx
import orbax.checkpoint as orbax

from aetherlm.core.config import CheckpointConfig


class CheckpointManager:
    """
    Manages model checkpointing with Orbax.

    Features:
    - Save/load model parameters
    - Checkpoint rotation (max_to_keep)
    - Automatic directory creation
    """

    def __init__(self, config: Optional[CheckpointConfig] = None):
        self.config = config or CheckpointConfig()
        self.checkpoint_dir = os.path.abspath(self.config.dir)
        self.max_to_keep = self.config.max_to_keep
        self._saved_checkpoints = []

        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def save(self, model: nnx.Module, path: Optional[str] = None, step: Optional[int] = None):
        """
        Save a model checkpoint.

        Args:
            model: NNX model to save.
            path: Full path to save to. If None, auto-generates from step.
            step: Training step number (used for auto-naming).
        """
        if path is None:
            if step is not None:
                path = os.path.join(self.checkpoint_dir, f"step_{step}")
            else:
                path = os.path.join(self.checkpoint_dir, "latest")

        os.makedirs(os.path.dirname(path), exist_ok=True)

        checkpointer = orbax.PyTreeCheckpointer()
        _, param_state, _ = nnx.split(model, nnx.Param, ...)
        checkpointer.save(path, item=param_state)
        checkpointer.close()

        self._saved_checkpoints.append(path)
        print(f"Checkpoint saved at {path}")

        # Rotate old checkpoints
        self._rotate()

    def load(self, model: nnx.Module, path: str) -> nnx.Module:
        """
        Load model parameters from a checkpoint.

        Args:
            model: NNX model (used as template for parameter shapes).
            path: Path to checkpoint directory.

        Returns:
            Model with restored parameters.
        """
        print(f"Loading model from {path}...")

        _, params_template, _ = nnx.split(model, nnx.Param, ...)
        checkpointer = orbax.PyTreeCheckpointer()
        restored_params = checkpointer.restore(path, item=params_template)

        if restored_params is None:
            raise ValueError(f"Could not restore parameters from {path}")

        nnx.update(model, restored_params)
        print("Model loaded successfully.")
        return model

    def _rotate(self):
        """Remove old checkpoints beyond max_to_keep."""
        if self.max_to_keep <= 0:
            return

        while len(self._saved_checkpoints) > self.max_to_keep:
            old_path = self._saved_checkpoints.pop(0)
            if os.path.exists(old_path):
                import shutil
                try:
                    shutil.rmtree(old_path)
                    print(f"Removed old checkpoint: {old_path}")
                except Exception as e:
                    print(f"Warning: could not remove {old_path}: {e}")


def save_checkpoint(model: nnx.Module, path: str):
    """Convenience function to save a model checkpoint."""
    os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
    checkpointer = orbax.PyTreeCheckpointer()
    _, param_state, _ = nnx.split(model, nnx.Param, ...)
    checkpointer.save(path, item=param_state)
    checkpointer.close()
    print(f"Checkpoint saved at {path}")


def load_checkpoint(model: nnx.Module, path: str) -> nnx.Module:
    """Convenience function to load a model checkpoint."""
    print(f"Loading model from {path}...")
    _, params_template, _ = nnx.split(model, nnx.Param, ...)
    checkpointer = orbax.PyTreeCheckpointer()
    restored_params = checkpointer.restore(path, item=params_template)
    if restored_params is None:
        raise ValueError(f"Could not restore parameters from {path}")
    nnx.update(model, restored_params)
    print("Model loaded successfully.")
    return model
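
Nothing else in this diff exercises `CheckpointManager`, so a short usage sketch may help. It relies only on what manager.py itself reads from its config (`dir` and `max_to_keep`); constructing `CheckpointConfig` with those keywords is an assumption, since its definition lives in core/config.py and is not part of this diff, and `model` stands in for any `flax.nnx.Module` built elsewhere:

```python
from aetherlm.checkpoint.manager import CheckpointManager
from aetherlm.core.config import CheckpointConfig

# manager.py reads exactly two fields from its config: `dir` and
# `max_to_keep`. Keyword construction assumes CheckpointConfig is a plain
# dataclass; its definition (core/config.py) is not shown in this diff.
manager = CheckpointManager(CheckpointConfig(dir="./checkpoints", max_to_keep=3))

# Auto-named saves land at ./checkpoints/step_<N>; once more than
# max_to_keep paths have been recorded, _rotate() deletes the oldest.
for step in (1000, 2000, 3000, 4000):
    manager.save(model, step=step)  # model: any flax.nnx.Module

# load() restores parameters in place, using the passed model as the
# shape template, and also returns it.
model = manager.load(model, "./checkpoints/step_4000")
```

One caveat worth noting: rotation state lives only in the in-memory `_saved_checkpoints` list, so checkpoints written by an earlier process are never rotated out by a new `CheckpointManager` instance.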