rollfast 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rollfast-0.1.0/PKG-INFO +277 -0
- rollfast-0.1.0/README.md +252 -0
- rollfast-0.1.0/pyproject.toml +43 -0
- rollfast-0.1.0/src/rollfast/__init__.py +10 -0
- rollfast-0.1.0/src/rollfast/optim/__init__.py +1 -0
- rollfast-0.1.0/src/rollfast/optim/prism.py +745 -0
- rollfast-0.1.0/src/rollfast/optim/psgd.py +1304 -0
- rollfast-0.1.0/src/rollfast/schedules/__init__.py +1 -0
- rollfast-0.1.0/src/rollfast/schedules/schedulefree.py +518 -0
- rollfast-0.1.0/src/rollfast/schedules/wsd.py +55 -0
- rollfast-0.1.0/src/rollfast/utils.py +31 -0
rollfast-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: rollfast
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: JAX implementation of experimental optimizers and schedulers.
|
|
5
|
+
Keywords: jax,optax,optimizer,psgd,deep-learning,second-order-optimization,preconditioning
|
|
6
|
+
Author: clementpoiret
|
|
7
|
+
Author-email: clementpoiret <clement@linux.com>
|
|
8
|
+
License: MIT
|
|
9
|
+
Classifier: Development Status :: 4 - Beta
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: Intended Audience :: Developers
|
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
18
|
+
Requires-Dist: jax>=0.6.2
|
|
19
|
+
Requires-Dist: optax>=0.2.0
|
|
20
|
+
Requires-Python: >=3.11
|
|
21
|
+
Project-URL: Homepage, https://github.com/clementpoiret/rollfast
|
|
22
|
+
Project-URL: Repository, https://github.com/clementpoiret/rollfast
|
|
23
|
+
Project-URL: Issues, https://github.com/clementpoiret/rollfast/issues
|
|
24
|
+
Description-Content-Type: text/markdown
|
|
25
|
+
|
|
26
|
+
# rollfast: Advanced Optimization Primitives in JAX
|
|
27
|
+
|
|
28
|
+
`rollfast` is a high-performance optimization library for JAX, designed to
|
|
29
|
+
implement cutting-edge optimizers that go beyond standard Euclidean gradient
|
|
30
|
+
descent. It provides production-ready implementations of optimizers like
|
|
31
|
+
**PSGD** (Preconditioned Stochastic Gradient Descent) and **PRISM** (Anisotropic
|
|
32
|
+
Spectral Shaping), along with a robust **Schedule-Free** wrapper.
|
|
33
|
+
|
|
34
|
+
Built on top of the [Optax](https://github.com/google-deepmind/optax) ecosystem,
|
|
35
|
+
`rollfast` prioritizes memory efficiency (via scanned layers and Kronecker
|
|
36
|
+
factorizations), multi-GPU compatibility, mixed-precision training, and
|
|
37
|
+
scalability for large models.
|
|
38
|
+
|
|
39
|
+
## Algorithms
|
|
40
|
+
|
|
41
|
+
### 1. PRISM (Anisotropic Spectral Shaping)
|
|
42
|
+
|
|
43
|
+
PRISM allows for structured optimization by applying anisotropic spectral
|
|
44
|
+
shaping to parameter updates. Unlike standard adaptive methods (Adam) that
|
|
45
|
+
operate element-wise, or full-matrix second-order methods (Shampoo/PSGD) that
|
|
46
|
+
approximate the Hessian, PRISM optimizes the singular value distribution of
|
|
47
|
+
weight matrices directly.
|
|
48
|
+
|
|
49
|
+
- **Mechanism**: Decomposes updates using Newton-Schulz iterations to
|
|
50
|
+
approximate SVD, applying "innovation" updates to the singular vectors while
|
|
51
|
+
damping singular values.
|
|
52
|
+
- **Partitioning**: Automatically partitions parameters. High-rank tensors
|
|
53
|
+
(Linear/Conv weights) are optimized via PRISM; vectors (biases, layernorms) are
|
|
54
|
+
optimized via AdamW.
|
|
55
|
+
- **Reference**: *PRISM: Structured Optimization via Anisotropic Spectral
|
|
56
|
+
Shaping* (Yang, 2026).
|
|
57
|
+
|
|
58
|
+
### 2. PSGD Kron (Lie Group Preconditioning)
|
|
59
|
+
|
|
60
|
+
PSGD reformulates preconditioner estimation as a strongly convex optimization
|
|
61
|
+
problem on Lie groups. It updates the preconditioner $Q$ (where $P = Q^T Q$)
|
|
62
|
+
using multiplicative updates that avoid explicit matrix inversion.
|
|
63
|
+
|
|
64
|
+
- **Mechanism**: Maintains a Kronecker-factored preconditioner updated via the
|
|
65
|
+
triangular or orthogonal group.
|
|
66
|
+
- **Reference**: *Stochastic Hessian Fittings with Lie Groups* (Li, 2024).
|
|
67
|
+
|
|
68
|
+
### 3. Schedule-Free Optimization
|
|
69
|
+
|
|
70
|
+
A wrapper that eliminates the need for complex learning rate schedules by
|
|
71
|
+
maintaining two sequences of parameters: a primary sequence $z$ (stepped via the
|
|
72
|
+
base optimizer) and an averaged sequence $x$ (used for evaluation).
|
|
73
|
+
|
|
74
|
+
- **Features**: Supports "Practical" and "Schedulet" weighting modes for
|
|
75
|
+
theoretically grounded averaging.
|
|
76
|
+
- **Reference**: *The Road Less Scheduled* (Defazio et al., 2024).
|
|
77
|
+
|
|
78
|
+
______________________________________________________________________
|
|
79
|
+
|
|
80
|
+
## Installation
|
|
81
|
+
|
|
82
|
+
```bash
|
|
83
|
+
pip install rollfast
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
## Usage
|
|
87
|
+
|
|
88
|
+
### 1. PRISM (Standard)
|
|
89
|
+
|
|
90
|
+
PRISM automatically handles parameter partitioning. You simply provide the
|
|
91
|
+
learning rate and structural hyperparameters.
|
|
92
|
+
|
|
93
|
+
```python
|
|
94
|
+
import jax
|
|
95
|
+
import jax.numpy as jnp
|
|
96
|
+
from rollfast import prism
|
|
97
|
+
|
|
98
|
+
# Define parameters
|
|
99
|
+
params = {
|
|
100
|
+
'linear': {'w': jnp.zeros((128, 128)), 'b': jnp.zeros((128,))},
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
# Initialize PRISM
|
|
104
|
+
# 'w' will be optimized by PRISM (Spectral Shaping)
|
|
105
|
+
# 'b' will be optimized by AdamW
|
|
106
|
+
optimizer = prism(
|
|
107
|
+
learning_rate=1e-3,
|
|
108
|
+
ns_iters=5, # Newton-Schulz iterations for orthogonalization
|
|
109
|
+
gamma=1.0, # Innovation damping
|
|
110
|
+
weight_decay=0.01
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
opt_state = optimizer.init(params)
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
### 2. Schedule-Free PRISM
|
|
117
|
+
|
|
118
|
+
The `schedule_free_prism` function wraps the PRISM optimizer with the
|
|
119
|
+
Schedule-Free logic and the WSD (Warmup-Stable-Decay) scheduler for the internal
|
|
120
|
+
step size.
|
|
121
|
+
|
|
122
|
+
```python
|
|
123
|
+
from rollfast.optim import schedule_free_prism
|
|
124
|
+
|
|
125
|
+
optimizer = schedule_free_prism(
|
|
126
|
+
learning_rate=1.0, # Peak LR for internal steps
|
|
127
|
+
total_steps=10000, # Required for WSD schedule generation
|
|
128
|
+
warmup_fraction=0.1,
|
|
129
|
+
weighting_mode="schedulet",
|
|
130
|
+
sf_b1=0.9, # Schedule-Free interpolation (beta)
|
|
131
|
+
gamma=0.8, # PRISM specific arg
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
# Note: In Schedule-Free, you must compute gradients at the averaged location 'x'
|
|
135
|
+
# but apply updates to the state 'z'.
|
|
136
|
+
```
|
|
137
|
+
|
|
138
|
+
### 3. PSGD Kron
|
|
139
|
+
|
|
140
|
+
The classic Kronecker-factored PSGD optimizer.
|
|
141
|
+
|
|
142
|
+
```python
|
|
143
|
+
from rollfast.optim import kron
|
|
144
|
+
|
|
145
|
+
optimizer = kron(
|
|
146
|
+
learning_rate=1e-3,
|
|
147
|
+
b1=0.9,
|
|
148
|
+
preconditioner_lr=0.1,
|
|
149
|
+
preconditioner_mode='Q0.5EQ1.5', # Procrustes-regularized update
|
|
150
|
+
whiten_grad=True
|
|
151
|
+
)
|
|
152
|
+
```
|
|
153
|
+
|
|
154
|
+
### Advanced: Scanned Layers (Memory Efficiency)
|
|
155
|
+
|
|
156
|
+
For deep architectures (e.g., Transformers) implemented via `jax.lax.scan`,
|
|
157
|
+
`rollfast` supports explicit handling of scanned layers to prevent unrolling
|
|
158
|
+
computation graphs.
|
|
159
|
+
|
|
160
|
+
```python
|
|
161
|
+
import jax
|
|
162
|
+
from rollfast.optim import kron
|
|
163
|
+
|
|
164
|
+
# Boolean pytree mask where True indicates a scanned parameter
|
|
165
|
+
scanned_layers_mask = ...
|
|
166
|
+
|
|
167
|
+
optimizer = kron(
|
|
168
|
+
learning_rate=3e-4,
|
|
169
|
+
scanned_layers=scanned_layers_mask,
|
|
170
|
+
lax_map_scanned_layers=True, # Use lax.map for preconditioner updates
|
|
171
|
+
lax_map_batch_size=8
|
|
172
|
+
)
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
______________________________________________________________________
|
|
176
|
+
|
|
177
|
+
## Configuration
|
|
178
|
+
|
|
179
|
+
### Stability & Clipping Parameters
|
|
180
|
+
|
|
181
|
+
These parameters ensure robustness against gradient spikes and numerical
|
|
182
|
+
instability, critical for training at scale.
|
|
183
|
+
|
|
184
|
+
| Parameter | Default | Description |
|
|
185
|
+
| :---------------------------- | :------------ | :--------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
|
186
|
+
| `raw_global_grad_clip` | `None` | If set, computes the global L2 norm of gradients *before* the optimizer step. If the norm exceeds this threshold, the update is either clipped or skipped. |
|
|
187
|
+
| `permissive_spike_protection` | `True` | Controls behavior when `raw_global_grad_clip` is triggered. `True` clips the gradient and proceeds; `False` strictly skips the update (zeroing the step). |
|
|
188
|
+
| `grad_clip_max_amps` | `(2.0, 10.0)` | Post-processing clipping. Clips individual tensors by RMS (`2.0`) and absolute value (`10.0`) to prevent heavy tails in the update distribution. |
|
|
189
|
+
|
|
190
|
+
### Schedule-Free Hyperparameters
|
|
191
|
+
|
|
192
|
+
When using `schedule_free_*` optimizers, these arguments control the underlying
|
|
193
|
+
WSD (Warmup-Stable-Decay) schedule and the iterate averaging.
|
|
194
|
+
|
|
195
|
+
| Parameter | Default | Description |
|
|
196
|
+
| :---------------- | :---------- | :---------------------------------------------------------------------------------------------------------------- |
|
|
197
|
+
| `warmup_fraction` | `0.1` | Fraction of `total_steps` used for linear warmup. |
|
|
198
|
+
| `decay_fraction` | `0.1` | Fraction of `total_steps` used for linear decay (cooldown) at the end of training. |
|
|
199
|
+
| `weighting_mode` | `SCHEDULET` | Strategy for $c_t$ calculation: `THEORETICAL` ($1/t$), `PRACTICAL` ($\\gamma_t^2$), or `SCHEDULET` ($\\gamma_t$). |
|
|
200
|
+
|
|
201
|
+
### PRISM Specifics
|
|
202
|
+
|
|
203
|
+
| Parameter | Default | Description |
|
|
204
|
+
| :------------------- | :------ | :------------------------------------------------------------------------------------------ |
|
|
205
|
+
| `ns_iters` | `5` | Newton-Schulz iterations. Higher values provide better orthogonality but cost more compute. |
|
|
206
|
+
| `gamma` | `1.0` | Damping coefficient for the innovation term. Controls the "anisotropy" of spectral shaping. |
|
|
207
|
+
| `shape_nesterov` | `True` | If True, shapes Nesterov momentum; otherwise shapes raw momentum. |
|
|
208
|
+
| `adam_learning_rate` | `None` | Optional override for the Adam branch learning rate. Defaults to `learning_rate` if None. |
|
|
209
|
+
|
|
210
|
+
### PSGD Specifics
|
|
211
|
+
|
|
212
|
+
| Parameter | Default | Description |
|
|
213
|
+
| :-------------------------- | :------ | :---------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
|
214
|
+
| `track_lipschitz` | `True` | Enables adaptive step sizes for the preconditioner $Q$ by tracking the Lipschitz constant of the gradient. |
|
|
215
|
+
| `max_skew_triangular` | `1.0` | Threshold for diagonal approximation. If a dimension's aspect ratio squared exceeds this relative to total numel, it is treated as diagonal to save memory. |
|
|
216
|
+
| `preconditioner_init_scale` | `None` | Initial scale for $Q$. If `None`, it is estimated on the first step using gradient statistics. |
|
|
217
|
+
|
|
218
|
+
#### Preconditioner Modes
|
|
219
|
+
|
|
220
|
+
The geometry of the preconditioner update $dQ$ is controlled via
|
|
221
|
+
`preconditioner_mode`.
|
|
222
|
+
|
|
223
|
+
| Mode | Formula | Description |
|
|
224
|
+
| :---------- | :---------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------- |
|
|
225
|
+
| `Q0.5EQ1.5` | $dQ = Q^{0.5} \\mathcal{E} Q^{1.5}$ | **Recommended**. Uses an online orthogonal Procrustes solver to keep $Q$ approximately SPD. Numerically stable for low precision. |
|
|
226
|
+
| `EQ` | $dQ = \\mathcal{E} Q$ | The original triangular update. Requires triangular solves. Only mode compatible with triangular $Q$. |
|
|
227
|
+
| `QUAD` | Quadratic Form | Ensures $Q$ remains symmetric positive definite via quadratic form updates. |
|
|
228
|
+
| `NS` | Newton-Schulz | Iteratively projects $Q$ onto the SPD manifold using Newton-Schulz iterations. Exact but more expensive. |
|
|
229
|
+
| `EXP` | Matrix Exponential | Geodesic update on the SPD manifold. Uses matrix exponential. |
|
|
230
|
+
| `TAYLOR2` | Taylor Expansion | Second-order Taylor approximation of the matrix exponential update. |
|
|
231
|
+
| `HYPER` | Hyperbolic | Multiplicative hyperbolic update. |
|
|
232
|
+
|
|
233
|
+
______________________________________________________________________
|
|
234
|
+
|
|
235
|
+
## Citations
|
|
236
|
+
|
|
237
|
+
If you use `rollfast` in your research, please cite the relevant papers for the algorithms you utilize.
|
|
238
|
+
|
|
239
|
+
**PRISM:**
|
|
240
|
+
|
|
241
|
+
```bibtex
|
|
242
|
+
@misc{2602.03096,
|
|
243
|
+
Author = {Yujie Yang},
|
|
244
|
+
Title = {PRISM: Structured Optimization via Anisotropic Spectral Shaping},
|
|
245
|
+
Year = {2026},
|
|
246
|
+
Eprint = {arXiv:2602.03096},
|
|
247
|
+
}
|
|
248
|
+
```
|
|
249
|
+
|
|
250
|
+
**Schedule-Free:**
|
|
251
|
+
|
|
252
|
+
```bibtex
|
|
253
|
+
@misc{2405.15682,
|
|
254
|
+
Author = {Aaron Defazio and Xingyu Alice Yang and Harsh Mehta and Konstantin Mishchenko and Ahmed Khaled and Ashok Cutkosky},
|
|
255
|
+
Title = {The Road Less Scheduled},
|
|
256
|
+
Year = {2024},
|
|
257
|
+
Eprint = {arXiv:2405.15682},
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
@misc{2511.07767,
|
|
261
|
+
Author = {Yuen-Man Pun and Matthew Buchholz and Robert M. Gower},
|
|
262
|
+
Title = {Schedulers for Schedule-free: Theoretically inspired hyperparameters},
|
|
263
|
+
Year = {2025},
|
|
264
|
+
Eprint = {arXiv:2511.07767},
|
|
265
|
+
}
|
|
266
|
+
```
|
|
267
|
+
|
|
268
|
+
**PSGD:**
|
|
269
|
+
|
|
270
|
+
```bibtex
|
|
271
|
+
@article{li2024stochastic,
|
|
272
|
+
title={Stochastic Hessian Fittings with Lie Groups},
|
|
273
|
+
author={Li, Xi-Lin},
|
|
274
|
+
journal={arXiv preprint arXiv:2402.11858},
|
|
275
|
+
year={2024}
|
|
276
|
+
}
|
|
277
|
+
```
|
rollfast-0.1.0/README.md
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# rollfast: Advanced Optimization Primitives in JAX
|
|
2
|
+
|
|
3
|
+
`rollfast` is a high-performance optimization library for JAX, designed to
|
|
4
|
+
implement cutting-edge optimizers that go beyond standard Euclidean gradient
|
|
5
|
+
descent. It provides production-ready implementations of optimizers like
|
|
6
|
+
**PSGD** (Preconditioned Stochastic Gradient Descent) and **PRISM** (Anisotropic
|
|
7
|
+
Spectral Shaping), along with a robust **Schedule-Free** wrapper.
|
|
8
|
+
|
|
9
|
+
Built on top of the [Optax](https://github.com/google-deepmind/optax) ecosystem,
|
|
10
|
+
`rollfast` prioritizes memory efficiency (via scanned layers and Kronecker
|
|
11
|
+
factorizations), multi-GPU compatibility, mixed-precision training, and
|
|
12
|
+
scalability for large models.
|
|
13
|
+
|
|
14
|
+
## Algorithms
|
|
15
|
+
|
|
16
|
+
### 1. PRISM (Anisotropic Spectral Shaping)
|
|
17
|
+
|
|
18
|
+
PRISM allows for structured optimization by applying anisotropic spectral
|
|
19
|
+
shaping to parameter updates. Unlike standard adaptive methods (Adam) that
|
|
20
|
+
operate element-wise, or full-matrix second-order methods (Shampoo/PSGD) that
|
|
21
|
+
approximate the Hessian, PRISM optimizes the singular value distribution of
|
|
22
|
+
weight matrices directly.
|
|
23
|
+
|
|
24
|
+
- **Mechanism**: Decomposes updates using Newton-Schulz iterations to
|
|
25
|
+
approximate SVD, applying "innovation" updates to the singular vectors while
|
|
26
|
+
damping singular values.
|
|
27
|
+
- **Partitioning**: Automatically partitions parameters. High-rank tensors
|
|
28
|
+
(Linear/Conv weights) are optimized via PRISM; vectors (biases, layernorms) are
|
|
29
|
+
optimized via AdamW.
|
|
30
|
+
- **Reference**: *PRISM: Structured Optimization via Anisotropic Spectral
|
|
31
|
+
Shaping* (Yang, 2026).
|
|
32
|
+
|
|
33
|
+
### 2. PSGD Kron (Lie Group Preconditioning)
|
|
34
|
+
|
|
35
|
+
PSGD reformulates preconditioner estimation as a strongly convex optimization
|
|
36
|
+
problem on Lie groups. It updates the preconditioner $Q$ (where $P = Q^T Q$)
|
|
37
|
+
using multiplicative updates that avoid explicit matrix inversion.
|
|
38
|
+
|
|
39
|
+
- **Mechanism**: Maintains a Kronecker-factored preconditioner updated via the
|
|
40
|
+
triangular or orthogonal group.
|
|
41
|
+
- **Reference**: *Stochastic Hessian Fittings with Lie Groups* (Li, 2024).
|
|
42
|
+
|
|
43
|
+
### 3. Schedule-Free Optimization
|
|
44
|
+
|
|
45
|
+
A wrapper that eliminates the need for complex learning rate schedules by
|
|
46
|
+
maintaining two sequences of parameters: a primary sequence $z$ (stepped via the
|
|
47
|
+
base optimizer) and an averaged sequence $x$ (used for evaluation).
|
|
48
|
+
|
|
49
|
+
- **Features**: Supports "Practical" and "Schedulet" weighting modes for
|
|
50
|
+
theoretically grounded averaging.
|
|
51
|
+
- **Reference**: *The Road Less Scheduled* (Defazio et al., 2024).
|
|
52
|
+
|
|
53
|
+
______________________________________________________________________
|
|
54
|
+
|
|
55
|
+
## Installation
|
|
56
|
+
|
|
57
|
+
```bash
|
|
58
|
+
pip install rollfast
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
## Usage
|
|
62
|
+
|
|
63
|
+
### 1. PRISM (Standard)
|
|
64
|
+
|
|
65
|
+
PRISM automatically handles parameter partitioning. You simply provide the
|
|
66
|
+
learning rate and structural hyperparameters.
|
|
67
|
+
|
|
68
|
+
```python
|
|
69
|
+
import jax
|
|
70
|
+
import jax.numpy as jnp
|
|
71
|
+
from rollfast import prism
|
|
72
|
+
|
|
73
|
+
# Define parameters
|
|
74
|
+
params = {
|
|
75
|
+
'linear': {'w': jnp.zeros((128, 128)), 'b': jnp.zeros((128,))},
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
# Initialize PRISM
|
|
79
|
+
# 'w' will be optimized by PRISM (Spectral Shaping)
|
|
80
|
+
# 'b' will be optimized by AdamW
|
|
81
|
+
optimizer = prism(
|
|
82
|
+
learning_rate=1e-3,
|
|
83
|
+
ns_iters=5, # Newton-Schulz iterations for orthogonalization
|
|
84
|
+
gamma=1.0, # Innovation damping
|
|
85
|
+
weight_decay=0.01
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
opt_state = optimizer.init(params)
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
### 2. Schedule-Free PRISM
|
|
92
|
+
|
|
93
|
+
The `schedule_free_prism` function wraps the PRISM optimizer with the
|
|
94
|
+
Schedule-Free logic and the WSD (Warmup-Stable-Decay) scheduler for the internal
|
|
95
|
+
step size.
|
|
96
|
+
|
|
97
|
+
```python
|
|
98
|
+
from rollfast.optim import schedule_free_prism
|
|
99
|
+
|
|
100
|
+
optimizer = schedule_free_prism(
|
|
101
|
+
learning_rate=1.0, # Peak LR for internal steps
|
|
102
|
+
total_steps=10000, # Required for WSD schedule generation
|
|
103
|
+
warmup_fraction=0.1,
|
|
104
|
+
weighting_mode="schedulet",
|
|
105
|
+
sf_b1=0.9, # Schedule-Free interpolation (beta)
|
|
106
|
+
gamma=0.8, # PRISM specific arg
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
# Note: In Schedule-Free, you must compute gradients at the averaged location 'x'
|
|
110
|
+
# but apply updates to the state 'z'.
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
### 3. PSGD Kron
|
|
114
|
+
|
|
115
|
+
The classic Kronecker-factored PSGD optimizer.
|
|
116
|
+
|
|
117
|
+
```python
|
|
118
|
+
from rollfast.optim import kron
|
|
119
|
+
|
|
120
|
+
optimizer = kron(
|
|
121
|
+
learning_rate=1e-3,
|
|
122
|
+
b1=0.9,
|
|
123
|
+
preconditioner_lr=0.1,
|
|
124
|
+
preconditioner_mode='Q0.5EQ1.5', # Procrustes-regularized update
|
|
125
|
+
whiten_grad=True
|
|
126
|
+
)
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
### Advanced: Scanned Layers (Memory Efficiency)
|
|
130
|
+
|
|
131
|
+
For deep architectures (e.g., Transformers) implemented via `jax.lax.scan`,
|
|
132
|
+
`rollfast` supports explicit handling of scanned layers to prevent unrolling
|
|
133
|
+
computation graphs.
|
|
134
|
+
|
|
135
|
+
```python
|
|
136
|
+
import jax
|
|
137
|
+
from rollfast.optim import kron
|
|
138
|
+
|
|
139
|
+
# Boolean pytree mask where True indicates a scanned parameter
|
|
140
|
+
scanned_layers_mask = ...
|
|
141
|
+
|
|
142
|
+
optimizer = kron(
|
|
143
|
+
learning_rate=3e-4,
|
|
144
|
+
scanned_layers=scanned_layers_mask,
|
|
145
|
+
lax_map_scanned_layers=True, # Use lax.map for preconditioner updates
|
|
146
|
+
lax_map_batch_size=8
|
|
147
|
+
)
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
______________________________________________________________________
|
|
151
|
+
|
|
152
|
+
## Configuration
|
|
153
|
+
|
|
154
|
+
### Stability & Clipping Parameters
|
|
155
|
+
|
|
156
|
+
These parameters ensure robustness against gradient spikes and numerical
|
|
157
|
+
instability, critical for training at scale.
|
|
158
|
+
|
|
159
|
+
| Parameter | Default | Description |
|
|
160
|
+
| :---------------------------- | :------------ | :--------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
|
161
|
+
| `raw_global_grad_clip` | `None` | If set, computes the global L2 norm of gradients *before* the optimizer step. If the norm exceeds this threshold, the update is either clipped or skipped. |
|
|
162
|
+
| `permissive_spike_protection` | `True` | Controls behavior when `raw_global_grad_clip` is triggered. `True` clips the gradient and proceeds; `False` strictly skips the update (zeroing the step). |
|
|
163
|
+
| `grad_clip_max_amps` | `(2.0, 10.0)` | Post-processing clipping. Clips individual tensors by RMS (`2.0`) and absolute value (`10.0`) to prevent heavy tails in the update distribution. |
|
|
164
|
+
|
|
165
|
+
### Schedule-Free Hyperparameters
|
|
166
|
+
|
|
167
|
+
When using `schedule_free_*` optimizers, these arguments control the underlying
|
|
168
|
+
WSD (Warmup-Stable-Decay) schedule and the iterate averaging.
|
|
169
|
+
|
|
170
|
+
| Parameter | Default | Description |
|
|
171
|
+
| :---------------- | :---------- | :---------------------------------------------------------------------------------------------------------------- |
|
|
172
|
+
| `warmup_fraction` | `0.1` | Fraction of `total_steps` used for linear warmup. |
|
|
173
|
+
| `decay_fraction` | `0.1` | Fraction of `total_steps` used for linear decay (cooldown) at the end of training. |
|
|
174
|
+
| `weighting_mode` | `SCHEDULET` | Strategy for $c_t$ calculation: `THEORETICAL` ($1/t$), `PRACTICAL` ($\\gamma_t^2$), or `SCHEDULET` ($\\gamma_t$). |
|
|
175
|
+
|
|
176
|
+
### PRISM Specifics
|
|
177
|
+
|
|
178
|
+
| Parameter | Default | Description |
|
|
179
|
+
| :------------------- | :------ | :------------------------------------------------------------------------------------------ |
|
|
180
|
+
| `ns_iters` | `5` | Newton-Schulz iterations. Higher values provide better orthogonality but cost more compute. |
|
|
181
|
+
| `gamma` | `1.0` | Damping coefficient for the innovation term. Controls the "anisotropy" of spectral shaping. |
|
|
182
|
+
| `shape_nesterov` | `True` | If True, shapes Nesterov momentum; otherwise shapes raw momentum. |
|
|
183
|
+
| `adam_learning_rate` | `None` | Optional override for the Adam branch learning rate. Defaults to `learning_rate` if None. |
|
|
184
|
+
|
|
185
|
+
### PSGD Specifics
|
|
186
|
+
|
|
187
|
+
| Parameter | Default | Description |
|
|
188
|
+
| :-------------------------- | :------ | :---------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
|
189
|
+
| `track_lipschitz` | `True` | Enables adaptive step sizes for the preconditioner $Q$ by tracking the Lipschitz constant of the gradient. |
|
|
190
|
+
| `max_skew_triangular` | `1.0` | Threshold for diagonal approximation. If a dimension's aspect ratio squared exceeds this relative to total numel, it is treated as diagonal to save memory. |
|
|
191
|
+
| `preconditioner_init_scale` | `None` | Initial scale for $Q$. If `None`, it is estimated on the first step using gradient statistics. |
|
|
192
|
+
|
|
193
|
+
#### Preconditioner Modes
|
|
194
|
+
|
|
195
|
+
The geometry of the preconditioner update $dQ$ is controlled via
|
|
196
|
+
`preconditioner_mode`.
|
|
197
|
+
|
|
198
|
+
| Mode | Formula | Description |
|
|
199
|
+
| :---------- | :---------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------- |
|
|
200
|
+
| `Q0.5EQ1.5` | $dQ = Q^{0.5} \\mathcal{E} Q^{1.5}$ | **Recommended**. Uses an online orthogonal Procrustes solver to keep $Q$ approximately SPD. Numerically stable for low precision. |
|
|
201
|
+
| `EQ` | $dQ = \\mathcal{E} Q$ | The original triangular update. Requires triangular solves. Only mode compatible with triangular $Q$. |
|
|
202
|
+
| `QUAD` | Quadratic Form | Ensures $Q$ remains symmetric positive definite via quadratic form updates. |
|
|
203
|
+
| `NS` | Newton-Schulz | Iteratively projects $Q$ onto the SPD manifold using Newton-Schulz iterations. Exact but more expensive. |
|
|
204
|
+
| `EXP` | Matrix Exponential | Geodesic update on the SPD manifold. Uses matrix exponential. |
|
|
205
|
+
| `TAYLOR2` | Taylor Expansion | Second-order Taylor approximation of the matrix exponential update. |
|
|
206
|
+
| `HYPER` | Hyperbolic | Multiplicative hyperbolic update. |
|
|
207
|
+
|
|
208
|
+
______________________________________________________________________
|
|
209
|
+
|
|
210
|
+
## Citations
|
|
211
|
+
|
|
212
|
+
If you use `rollfast` in your research, please cite the relevant papers for the algorithms you utilize.
|
|
213
|
+
|
|
214
|
+
**PRISM:**
|
|
215
|
+
|
|
216
|
+
```bibtex
|
|
217
|
+
@misc{2602.03096,
|
|
218
|
+
Author = {Yujie Yang},
|
|
219
|
+
Title = {PRISM: Structured Optimization via Anisotropic Spectral Shaping},
|
|
220
|
+
Year = {2026},
|
|
221
|
+
Eprint = {arXiv:2602.03096},
|
|
222
|
+
}
|
|
223
|
+
```
|
|
224
|
+
|
|
225
|
+
**Schedule-Free:**
|
|
226
|
+
|
|
227
|
+
```bibtex
|
|
228
|
+
@misc{2405.15682,
|
|
229
|
+
Author = {Aaron Defazio and Xingyu Alice Yang and Harsh Mehta and Konstantin Mishchenko and Ahmed Khaled and Ashok Cutkosky},
|
|
230
|
+
Title = {The Road Less Scheduled},
|
|
231
|
+
Year = {2024},
|
|
232
|
+
Eprint = {arXiv:2405.15682},
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
@misc{2511.07767,
|
|
236
|
+
Author = {Yuen-Man Pun and Matthew Buchholz and Robert M. Gower},
|
|
237
|
+
Title = {Schedulers for Schedule-free: Theoretically inspired hyperparameters},
|
|
238
|
+
Year = {2025},
|
|
239
|
+
Eprint = {arXiv:2511.07767},
|
|
240
|
+
}
|
|
241
|
+
```
|
|
242
|
+
|
|
243
|
+
**PSGD:**
|
|
244
|
+
|
|
245
|
+
```bibtex
|
|
246
|
+
@article{li2024stochastic,
|
|
247
|
+
title={Stochastic Hessian Fittings with Lie Groups},
|
|
248
|
+
author={Li, Xi-Lin},
|
|
249
|
+
journal={arXiv preprint arXiv:2402.11858},
|
|
250
|
+
year={2024}
|
|
251
|
+
}
|
|
252
|
+
```
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "rollfast"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "JAX implementation of experimental optimizers and schedulers."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.11"
|
|
7
|
+
license = { text = "MIT" }
|
|
8
|
+
authors = [
|
|
9
|
+
{ name = "clementpoiret", email = "clement@linux.com" }
|
|
10
|
+
]
|
|
11
|
+
keywords = [
|
|
12
|
+
"jax",
|
|
13
|
+
"optax",
|
|
14
|
+
"optimizer",
|
|
15
|
+
"psgd",
|
|
16
|
+
"deep-learning",
|
|
17
|
+
"second-order-optimization",
|
|
18
|
+
"preconditioning"
|
|
19
|
+
]
|
|
20
|
+
classifiers = [
|
|
21
|
+
"Development Status :: 4 - Beta",
|
|
22
|
+
"Intended Audience :: Science/Research",
|
|
23
|
+
"Intended Audience :: Developers",
|
|
24
|
+
"License :: OSI Approved :: MIT License",
|
|
25
|
+
"Programming Language :: Python :: 3",
|
|
26
|
+
"Programming Language :: Python :: 3.11",
|
|
27
|
+
"Programming Language :: Python :: 3.12",
|
|
28
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
29
|
+
"Topic :: Scientific/Engineering :: Mathematics",
|
|
30
|
+
]
|
|
31
|
+
dependencies = [
|
|
32
|
+
"jax>=0.6.2",
|
|
33
|
+
"optax>=0.2.0",
|
|
34
|
+
]
|
|
35
|
+
|
|
36
|
+
[project.urls]
|
|
37
|
+
Homepage = "https://github.com/clementpoiret/rollfast"
|
|
38
|
+
Repository = "https://github.com/clementpoiret/rollfast"
|
|
39
|
+
Issues = "https://github.com/clementpoiret/rollfast/issues"
|
|
40
|
+
|
|
41
|
+
[build-system]
|
|
42
|
+
requires = ["uv_build>=0.9.7,<0.10.0"]
|
|
43
|
+
build-backend = "uv_build"
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from .optim.prism import prism as prism
|
|
2
|
+
from .optim.psgd import kron as kron
|
|
3
|
+
from .schedules.schedulefree import (
|
|
4
|
+
schedule_free_eval_params as schedule_free_eval_params,
|
|
5
|
+
schedule_free_kron as schedule_free_kron,
|
|
6
|
+
schedule_free_prism as schedule_free_prism,
|
|
7
|
+
)
|
|
8
|
+
from .schedules.wsd import wsd_schedule as wsd_schedule
|
|
9
|
+
|
|
10
|
+
__version__ = "0.1.0"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|