aetherlm-0.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. aetherlm-0.1.0/PKG-INFO +201 -0
  2. aetherlm-0.1.0/README.md +159 -0
  3. aetherlm-0.1.0/aetherlm/__init__.py +55 -0
  4. aetherlm-0.1.0/aetherlm/checkpoint/__init__.py +3 -0
  5. aetherlm-0.1.0/aetherlm/checkpoint/manager.py +123 -0
  6. aetherlm-0.1.0/aetherlm/cli/__init__.py +3 -0
  7. aetherlm-0.1.0/aetherlm/cli/main.py +274 -0
  8. aetherlm-0.1.0/aetherlm/core/__init__.py +5 -0
  9. aetherlm-0.1.0/aetherlm/core/config.py +202 -0
  10. aetherlm-0.1.0/aetherlm/core/engine.py +366 -0
  11. aetherlm-0.1.0/aetherlm/core/sharding.py +68 -0
  12. aetherlm-0.1.0/aetherlm/data/__init__.py +6 -0
  13. aetherlm-0.1.0/aetherlm/data/causal.py +123 -0
  14. aetherlm-0.1.0/aetherlm/data/contrastive.py +150 -0
  15. aetherlm-0.1.0/aetherlm/data/mlm.py +164 -0
  16. aetherlm-0.1.0/aetherlm/data/pipeline.py +63 -0
  17. aetherlm-0.1.0/aetherlm/eval/__init__.py +5 -0
  18. aetherlm-0.1.0/aetherlm/eval/mteb.py +181 -0
  19. aetherlm-0.1.0/aetherlm/eval/runner.py +416 -0
  20. aetherlm-0.1.0/aetherlm/eval/tasks.py +164 -0
  21. aetherlm-0.1.0/aetherlm/losses/__init__.py +5 -0
  22. aetherlm-0.1.0/aetherlm/losses/causal.py +43 -0
  23. aetherlm-0.1.0/aetherlm/losses/contrastive.py +73 -0
  24. aetherlm-0.1.0/aetherlm/losses/mlm.py +53 -0
  25. aetherlm-0.1.0/aetherlm/metrics/__init__.py +3 -0
  26. aetherlm-0.1.0/aetherlm/metrics/tracker.py +209 -0
  27. aetherlm-0.1.0/aetherlm/models/__init__.py +6 -0
  28. aetherlm-0.1.0/aetherlm/models/base.py +157 -0
  29. aetherlm-0.1.0/aetherlm/models/bert.py +121 -0
  30. aetherlm-0.1.0/aetherlm/models/causal.py +195 -0
  31. aetherlm-0.1.0/aetherlm/models/transformer.py +128 -0
  32. aetherlm-0.1.0/aetherlm/optim/__init__.py +5 -0
  33. aetherlm-0.1.0/aetherlm/optim/optimizers.py +108 -0
  34. aetherlm-0.1.0/aetherlm/optim/schedules.py +101 -0
  35. aetherlm-0.1.0/aetherlm/optim/switching.py +57 -0
  36. aetherlm-0.1.0/aetherlm/tpu/__init__.py +4 -0
  37. aetherlm-0.1.0/aetherlm/tpu/precision.py +41 -0
  38. aetherlm-0.1.0/aetherlm/tpu/topology.py +123 -0
  39. aetherlm-0.1.0/aetherlm/training/__init__.py +7 -0
  40. aetherlm-0.1.0/aetherlm/training/steps.py +119 -0
  41. aetherlm-0.1.0/aetherlm.egg-info/PKG-INFO +201 -0
  42. aetherlm-0.1.0/aetherlm.egg-info/SOURCES.txt +46 -0
  43. aetherlm-0.1.0/aetherlm.egg-info/dependency_links.txt +1 -0
  44. aetherlm-0.1.0/aetherlm.egg-info/entry_points.txt +2 -0
  45. aetherlm-0.1.0/aetherlm.egg-info/requires.txt +26 -0
  46. aetherlm-0.1.0/aetherlm.egg-info/top_level.txt +1 -0
  47. aetherlm-0.1.0/pyproject.toml +67 -0
  48. aetherlm-0.1.0/setup.cfg +4 -0
@@ -0,0 +1,201 @@
+ Metadata-Version: 2.4
+ Name: aetherlm
+ Version: 0.1.0
+ Summary: A DeepSpeed-like training library for Google TPUs, built on JAX/Flax.
+ Author: aetherBERT contributors
+ License: Apache-2.0
+ Project-URL: Homepage, https://github.com/aetherBERT/aetherlm
+ Project-URL: Repository, https://github.com/aetherBERT/aetherlm
+ Keywords: jax,flax,tpu,deep-learning,distributed-training,bert,transformers
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: jax[tpu]>=0.4.20
+ Requires-Dist: flax>=0.10.0
+ Requires-Dist: optax>=0.2.0
+ Requires-Dist: orbax-checkpoint>=0.6.0
+ Requires-Dist: transformers>=4.36.0
+ Requires-Dist: datasets>=2.16.0
+ Requires-Dist: wandb>=0.16.0
+ Requires-Dist: numpy>=1.24.0
+ Requires-Dist: pyyaml>=6.0
+ Provides-Extra: eval
+ Requires-Dist: mteb>=1.12.0; extra == "eval"
+ Requires-Dist: torch>=2.0.0; extra == "eval"
+ Requires-Dist: scikit-learn>=1.3.0; extra == "eval"
+ Requires-Dist: scipy>=1.11.0; extra == "eval"
+ Provides-Extra: nmn
+ Requires-Dist: nmn; extra == "nmn"
+ Provides-Extra: dev
+ Requires-Dist: pytest>=7.0; extra == "dev"
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
+ Requires-Dist: jupyter>=1.0.0; extra == "dev"
+ Provides-Extra: all
+ Requires-Dist: aetherlm[dev,eval,nmn]; extra == "all"
+
+ # AetherLM
+
+ A DeepSpeed-like training library for Google TPUs, built on JAX/Flax.
+
+ AetherLM abstracts away the complexity of distributed training on TPU pods, providing a simple `initialize()` API for pretraining and fine-tuning transformer models.
+
+ ## Features
+
+ - **One-call initialization** — `aetherlm.initialize(config)` sets up mesh, model, optimizer, and metrics
+ - **TPU-native** — Built on JAX with automatic TPU topology detection and optimal sharding
+ - **Multiple training modes** — MLM pretraining, contrastive learning, causal LM
+ - **Built-in models** — `AetherBERT` (bidirectional) and `AetherCausalLM` (autoregressive with generation)
+ - **YAML/dict configuration** — DeepSpeed-style config system
+ - **MTEB evaluation** — Integrated embedding evaluation with 56 tasks
+ - **Checkpoint management** — Automatic saving, rotation, and restore with Orbax
+
+ ## Quick Start
+
+ ```python
+ import aetherlm
+
+ # Configure and initialize (auto-detects TPU)
+ engine = aetherlm.initialize(config={
+     "model": {"model_type": "bert", "embed_dim": 768, "num_layers": 12},
+     "training": {"mode": "mlm", "batch_size": 32, "learning_rate": 1e-4},
+     "tpu": {"precision": "bf16"},
+ })
+
+ # Load data
+ from aetherlm.data import load_mlm_datasets
+ train_data, val_data = load_mlm_datasets(maxlen=512, mask_prob=0.15, vocab_size=50265)
+
+ # Train (handles sharding, logging, checkpointing)
+ engine.train(train_data, val_data)
+ ```
+
+ ## Installation
+
+ ```bash
+ # From PyPI
+ pip install aetherlm
+
+ # With eval support (MTEB, sklearn, scipy)
+ pip install aetherlm[eval]
+
+ # From source
+ pip install -e ".[all]"
+ ```
+
+ ## Training Modes
+
+ ### MLM Pretraining (BERT-style)
+
+ ```bash
+ aetherlm --mode mlm --config configs/default.yaml
+ ```
+
+ ### Causal Language Modeling (GPT-style)
+
+ ```bash
+ aetherlm --mode causal --config configs/causal_small.yaml
+ ```
+
+ ### Contrastive Learning (requires checkpoint)
+
+ ```bash
+ aetherlm --mode contrastive --checkpoint ./checkpoints/step_10000
+ ```
+
+ ### Evaluation
+
+ ```bash
+ # Quick MTEB eval (3 tasks, ~1 min)
+ aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset quick
+
+ # Full leaderboard (56 tasks)
+ aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset leaderboard
+ ```
+
+ ## Configuration
+
+ Aether uses a dataclass-based config system. Create configs from YAML, JSON, or Python dicts:
+
+ ```python
+ from aetherlm import AetherConfig
+
+ # From YAML
+ config = AetherConfig.from_yaml("configs/default.yaml")
+
+ # From dict
+ config = AetherConfig.from_dict({
+     "model": {"embed_dim": 768, "num_layers": 12},
+     "training": {"mode": "mlm", "batch_size": 32},
+ })
+
+ # Save for reproducibility
+ config.to_yaml("my_experiment.yaml")
+ ```
+
+ See `configs/` for example configurations.
+
+ ## Project Structure
+
+ ```
+ aetherlm/
+     __init__.py        # Top-level API: initialize(), AetherConfig, models
+     core/
+         config.py      # Dataclass configuration system
+         engine.py      # Training engine (the heart of the library)
+         sharding.py    # Automatic mesh sharding
+     models/
+         base.py        # Abstract model interface + utilities
+         transformer.py # Transformer blocks, embeddings
+         bert.py        # AetherBERT (bidirectional MLM)
+         causal.py      # AetherCausalLM (autoregressive + generation)
+     optim/
+         optimizers.py  # Optimizer factory (AdamW, Adam, SGD)
+         schedules.py   # LR schedules (warmup-cosine, linear, constant)
+         switching.py   # Plateau detection + optimizer switching
+     data/
+         pipeline.py    # Tokenizer caching, batch iterators
+         mlm.py         # MLM masking and dataset processing
+         contrastive.py # Contrastive pair creation (self-supervised + AllNLI)
+         causal.py      # Causal LM next-token prediction format
+     losses/
+         mlm.py         # Efficient gather-based MLM loss
+         contrastive.py # SimCLR-style contrastive loss
+         causal.py      # Causal LM cross-entropy loss
+     training/
+         steps.py       # JIT-compiled train/eval steps for all modes
+     checkpoint/
+         manager.py     # Orbax checkpoint save/load/rotation
+     metrics/
+         tracker.py     # Throughput, loss, ETA, WandB logging
+     eval/
+         tasks.py       # MTEB task lists and presets
+         mteb.py        # MTEB EncoderProtocol wrapper
+         runner.py      # Evaluation orchestrators
+     tpu/
+         topology.py    # TPU detection, mesh creation
+         precision.py   # bfloat16/mixed precision config
+     cli/
+         main.py        # CLI entry point
+ configs/               # Example YAML configurations
+ notebooks/             # Tutorial notebooks
+ ```
+
+ ## Notebooks
+
+ | Notebook | Description |
+ |----------|-------------|
+ | `01_quickstart.ipynb` | Core features: models, config, initialize |
+ | `02_mlm_pretraining.ipynb` | Full MLM pretraining pipeline |
+ | `03_causal_lm.ipynb` | Causal LM training + text generation |
+ | `04_evaluation.ipynb` | MTEB and custom evaluation |
+
+ ## License
+
+ Apache 2.0
@@ -0,0 +1,159 @@
+ # AetherLM
+
+ A DeepSpeed-like training library for Google TPUs, built on JAX/Flax.
+
+ AetherLM abstracts away the complexity of distributed training on TPU pods, providing a simple `initialize()` API for pretraining and fine-tuning transformer models.
+
+ ## Features
+
+ - **One-call initialization** — `aetherlm.initialize(config)` sets up mesh, model, optimizer, and metrics
+ - **TPU-native** — Built on JAX with automatic TPU topology detection and optimal sharding
+ - **Multiple training modes** — MLM pretraining, contrastive learning, causal LM
+ - **Built-in models** — `AetherBERT` (bidirectional) and `AetherCausalLM` (autoregressive with generation)
+ - **YAML/dict configuration** — DeepSpeed-style config system
+ - **MTEB evaluation** — Integrated embedding evaluation with 56 tasks
+ - **Checkpoint management** — Automatic saving, rotation, and restore with Orbax
+
+ ## Quick Start
+
+ ```python
+ import aetherlm
+
+ # Configure and initialize (auto-detects TPU)
+ engine = aetherlm.initialize(config={
+     "model": {"model_type": "bert", "embed_dim": 768, "num_layers": 12},
+     "training": {"mode": "mlm", "batch_size": 32, "learning_rate": 1e-4},
+     "tpu": {"precision": "bf16"},
+ })
+
+ # Load data
+ from aetherlm.data import load_mlm_datasets
+ train_data, val_data = load_mlm_datasets(maxlen=512, mask_prob=0.15, vocab_size=50265)
+
+ # Train (handles sharding, logging, checkpointing)
+ engine.train(train_data, val_data)
+ ```
+
+ ## Installation
+
+ ```bash
+ # From PyPI
+ pip install aetherlm
+
+ # With eval support (MTEB, sklearn, scipy)
+ pip install aetherlm[eval]
+
+ # From source
+ pip install -e ".[all]"
+ ```
+
+ ## Training Modes
+
+ ### MLM Pretraining (BERT-style)
+
+ ```bash
+ aetherlm --mode mlm --config configs/default.yaml
+ ```
+
+ ### Causal Language Modeling (GPT-style)
+
+ ```bash
+ aetherlm --mode causal --config configs/causal_small.yaml
+ ```
+
+ ### Contrastive Learning (requires checkpoint)
+
+ ```bash
+ aetherlm --mode contrastive --checkpoint ./checkpoints/step_10000
+ ```
+
+ ### Evaluation
+
+ ```bash
+ # Quick MTEB eval (3 tasks, ~1 min)
+ aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset quick
+
+ # Full leaderboard (56 tasks)
+ aetherlm --mode eval --checkpoint ./checkpoints/step_10000 --mteb_preset leaderboard
+ ```
+
+ ## Configuration
+
+ Aether uses a dataclass-based config system. Create configs from YAML, JSON, or Python dicts:
+
+ ```python
+ from aetherlm import AetherConfig
+
+ # From YAML
+ config = AetherConfig.from_yaml("configs/default.yaml")
+
+ # From dict
+ config = AetherConfig.from_dict({
+     "model": {"embed_dim": 768, "num_layers": 12},
+     "training": {"mode": "mlm", "batch_size": 32},
+ })
+
+ # Save for reproducibility
+ config.to_yaml("my_experiment.yaml")
+ ```
+
+ See `configs/` for example configurations.
+
+ ## Project Structure
+
+ ```
+ aetherlm/
+     __init__.py        # Top-level API: initialize(), AetherConfig, models
+     core/
+         config.py      # Dataclass configuration system
+         engine.py      # Training engine (the heart of the library)
+         sharding.py    # Automatic mesh sharding
+     models/
+         base.py        # Abstract model interface + utilities
+         transformer.py # Transformer blocks, embeddings
+         bert.py        # AetherBERT (bidirectional MLM)
+         causal.py      # AetherCausalLM (autoregressive + generation)
+     optim/
+         optimizers.py  # Optimizer factory (AdamW, Adam, SGD)
+         schedules.py   # LR schedules (warmup-cosine, linear, constant)
+         switching.py   # Plateau detection + optimizer switching
+     data/
+         pipeline.py    # Tokenizer caching, batch iterators
+         mlm.py         # MLM masking and dataset processing
+         contrastive.py # Contrastive pair creation (self-supervised + AllNLI)
+         causal.py      # Causal LM next-token prediction format
+     losses/
+         mlm.py         # Efficient gather-based MLM loss
+         contrastive.py # SimCLR-style contrastive loss
+         causal.py      # Causal LM cross-entropy loss
+     training/
+         steps.py       # JIT-compiled train/eval steps for all modes
+     checkpoint/
+         manager.py     # Orbax checkpoint save/load/rotation
+     metrics/
+         tracker.py     # Throughput, loss, ETA, WandB logging
+     eval/
+         tasks.py       # MTEB task lists and presets
+         mteb.py        # MTEB EncoderProtocol wrapper
+         runner.py      # Evaluation orchestrators
+     tpu/
+         topology.py    # TPU detection, mesh creation
+         precision.py   # bfloat16/mixed precision config
+     cli/
+         main.py        # CLI entry point
+ configs/               # Example YAML configurations
+ notebooks/             # Tutorial notebooks
+ ```
+
+ ## Notebooks
+
+ | Notebook | Description |
+ |----------|-------------|
+ | `01_quickstart.ipynb` | Core features: models, config, initialize |
+ | `02_mlm_pretraining.ipynb` | Full MLM pretraining pipeline |
+ | `03_causal_lm.ipynb` | Causal LM training + text generation |
+ | `04_evaluation.ipynb` | MTEB and custom evaluation |
+
+ ## License
+
+ Apache 2.0
@@ -0,0 +1,55 @@
+ """
+ Aether - A DeepSpeed-like training library for Google TPUs
+
+ Built on JAX/Flax for high-performance distributed training
+ of transformer models on TPU pods.
+
+ Quick Start:
+     import aetherlm
+
+     engine = aetherlm.initialize(config={
+         "model": {"model_type": "bert", "embed_dim": 768},
+         "training": {"mode": "mlm", "batch_size": 32},
+     })
+
+     for batch in train_data:
+         loss = engine.train_step(batch)
+ """
+
+ __version__ = "0.1.0"
+
+ from aetherlm.core.config import (
+     AetherConfig,
+     ModelConfig,
+     TrainingConfig,
+     OptimizerConfig,
+     ContrastiveConfig,
+     CausalConfig,
+     CheckpointConfig,
+     LoggingConfig,
+     EvalConfig,
+     TPUConfig,
+ )
+ from aetherlm.core.engine import AetherEngine, initialize
+ from aetherlm.models.bert import AetherBERT
+ from aetherlm.models.causal import AetherCausalLM
+
+ __all__ = [
+     # Core
+     "initialize",
+     "AetherEngine",
+     "AetherConfig",
+     # Config
+     "ModelConfig",
+     "TrainingConfig",
+     "OptimizerConfig",
+     "ContrastiveConfig",
+     "CausalConfig",
+     "CheckpointConfig",
+     "LoggingConfig",
+     "EvalConfig",
+     "TPUConfig",
+     # Models
+     "AetherBERT",
+     "AetherCausalLM",
+ ]
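
This `__init__.py` fixes the package's public surface. As a quick orientation, here is a minimal end-to-end sketch stitching together the README's one-call setup with the docstring's step-at-a-time loop. Everything shown (`initialize`, `load_mlm_datasets`, `engine.train_step`) appears in this release, but the exact batch structure and the return value of `train_step` are assumptions based on the docstring above.

```python
import aetherlm
from aetherlm.data import load_mlm_datasets

# One-call setup, as in the Quick Start (auto-detects the TPU topology).
engine = aetherlm.initialize(config={
    "model": {"model_type": "bert", "embed_dim": 768, "num_layers": 12},
    "training": {"mode": "mlm", "batch_size": 32, "learning_rate": 1e-4},
    "tpu": {"precision": "bf16"},
})

# Data loading, copied from the Quick Start.
train_data, val_data = load_mlm_datasets(maxlen=512, mask_prob=0.15, vocab_size=50265)

# The docstring's manual loop, an alternative to engine.train(train_data, val_data).
for batch in train_data:
    loss = engine.train_step(batch)
```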
@@ -0,0 +1,3 @@
+ """Aether checkpoint management with Orbax."""
+
+ from aetherlm.checkpoint.manager import CheckpointManager, save_checkpoint, load_checkpoint
@@ -0,0 +1,123 @@
+ """
+ Checkpoint Management
+
+ Save and load model checkpoints using Orbax, with support for
+ checkpoint rotation (keeping only the N most recent).
+ """
+
+ import os
+ from typing import Optional
+
+ import flax.nnx as nnx
+ import orbax.checkpoint as orbax
+
+ from aetherlm.core.config import CheckpointConfig
+
+
+ class CheckpointManager:
+     """
+     Manages model checkpointing with Orbax.
+
+     Features:
+     - Save/load model parameters
+     - Checkpoint rotation (max_to_keep)
+     - Automatic directory creation
+     """
+
+     def __init__(self, config: Optional[CheckpointConfig] = None):
+         self.config = config or CheckpointConfig()
+         self.checkpoint_dir = os.path.abspath(self.config.dir)
+         self.max_to_keep = self.config.max_to_keep
+         self._saved_checkpoints = []
+
+         os.makedirs(self.checkpoint_dir, exist_ok=True)
+
+     def save(self, model: nnx.Module, path: Optional[str] = None, step: Optional[int] = None):
+         """
+         Save a model checkpoint.
+
+         Args:
+             model: NNX model to save.
+             path: Full path to save to. If None, auto-generates from step.
+             step: Training step number (used for auto-naming).
+         """
+         if path is None:
+             if step is not None:
+                 path = os.path.join(self.checkpoint_dir, f"step_{step}")
+             else:
+                 path = os.path.join(self.checkpoint_dir, "latest")
+
+         os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+
+         checkpointer = orbax.PyTreeCheckpointer()
+         _, param_state, _ = nnx.split(model, nnx.Param, ...)
+         checkpointer.save(path, item=param_state)
+         checkpointer.close()
+
+         self._saved_checkpoints.append(path)
+         print(f"Checkpoint saved at {path}")
+
+         # Rotate old checkpoints
+         self._rotate()
+
+     def load(self, model: nnx.Module, path: str) -> nnx.Module:
+         """
+         Load model parameters from a checkpoint.
+
+         Args:
+             model: NNX model (used as template for parameter shapes).
+             path: Path to checkpoint directory.
+
+         Returns:
+             Model with restored parameters.
+         """
+         print(f"Loading model from {path}...")
+
+         _, params_template, _ = nnx.split(model, nnx.Param, ...)
+         checkpointer = orbax.PyTreeCheckpointer()
+         restored_params = checkpointer.restore(path, item=params_template)
+
+         if restored_params is None:
+             raise ValueError(f"Could not restore parameters from {path}")
+
+         nnx.update(model, restored_params)
+         print("Model loaded successfully.")
+         return model
+
+     def _rotate(self):
+         """Remove old checkpoints beyond max_to_keep."""
+         if self.max_to_keep <= 0:
+             return
+
+         while len(self._saved_checkpoints) > self.max_to_keep:
+             old_path = self._saved_checkpoints.pop(0)
+             if os.path.exists(old_path):
+                 import shutil
+                 try:
+                     shutil.rmtree(old_path)
+                     print(f"Removed old checkpoint: {old_path}")
+                 except Exception as e:
+                     print(f"Warning: could not remove {old_path}: {e}")
+
+
+ def save_checkpoint(model: nnx.Module, path: str):
+     """Convenience function to save a model checkpoint."""
+     os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
+     checkpointer = orbax.PyTreeCheckpointer()
+     _, param_state, _ = nnx.split(model, nnx.Param, ...)
+     checkpointer.save(path, item=param_state)
+     checkpointer.close()
+     print(f"Checkpoint saved at {path}")
+
+
+ def load_checkpoint(model: nnx.Module, path: str) -> nnx.Module:
+     """Convenience function to load a model checkpoint."""
+     print(f"Loading model from {path}...")
+     _, params_template, _ = nnx.split(model, nnx.Param, ...)
+     checkpointer = orbax.PyTreeCheckpointer()
+     restored_params = checkpointer.restore(path, item=params_template)
+     if restored_params is None:
+         raise ValueError(f"Could not restore parameters from {path}")
+     nnx.update(model, restored_params)
+     print("Model loaded successfully.")
+     return model
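
`manager.py` exposes both a class-based and a functional API. A minimal usage sketch follows, assuming `CheckpointConfig` is a dataclass accepting the `dir` and `max_to_keep` keywords that `__init__` reads, and using a toy `nnx.Module` in place of a real model:

```python
import flax.nnx as nnx

from aetherlm.checkpoint import CheckpointManager, load_checkpoint, save_checkpoint
from aetherlm.core.config import CheckpointConfig


class TinyModel(nnx.Module):
    """Stand-in model; any nnx.Module with nnx.Param state works as a template."""

    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 4, rngs=rngs)


model = TinyModel(nnx.Rngs(0))

# Class-based API: save() auto-names checkpoints from the step, and
# _rotate() deletes all but the max_to_keep most recent ones.
manager = CheckpointManager(CheckpointConfig(dir="./checkpoints", max_to_keep=3))
manager.save(model, step=10_000)                         # ./checkpoints/step_10000
model = manager.load(model, "./checkpoints/step_10000")

# Functional API: one-off save/load with no rotation.
save_checkpoint(model, "./checkpoints/latest")
model = load_checkpoint(model, "./checkpoints/latest")
```

Note that both paths persist only the `nnx.Param` state (the `nnx.split(model, nnx.Param, ...)` call above), so optimizer state is not captured in these checkpoints.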
@@ -0,0 +1,3 @@
+ """Aether CLI entry point."""
+
+ from aetherlm.cli.main import main