rollfast 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,277 @@
1
+ Metadata-Version: 2.3
2
+ Name: rollfast
3
+ Version: 0.1.0
4
+ Summary: JAX implementation of experimental optimizers and schedulers.
5
+ Keywords: jax,optax,optimizer,psgd,deep-learning,second-order-optimization,preconditioning
6
+ Author: clementpoiret
7
+ Author-email: clementpoiret <clement@linux.com>
8
+ License: MIT
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
18
+ Requires-Dist: jax>=0.6.2
19
+ Requires-Dist: optax>=0.2.0
20
+ Requires-Python: >=3.11
21
+ Project-URL: Homepage, https://github.com/clementpoiret/rollfast
22
+ Project-URL: Repository, https://github.com/clementpoiret/rollfast
23
+ Project-URL: Issues, https://github.com/clementpoiret/rollfast/issues
24
+ Description-Content-Type: text/markdown
25
+
26
+ # rollfast: Advanced Optimization Primitives in JAX
27
+
28
+ `rollfast` is a high-performance optimization library for JAX, designed to
29
+ implement cutting-edge optimizers that go beyond standard Euclidean gradient
30
+ descent. It provides production-ready implementations of optimizers like
31
+ **PSGD** (Preconditioned Stochastic Gradient Descent) and **PRISM** (Anisotropic
32
+ Spectral Shaping), along with a robust **Schedule-Free** wrapper.
33
+
34
+ Built on top of the [Optax](https://github.com/google-deepmind/optax) ecosystem,
35
+ `rollfast` prioritizes memory efficiency (via scanned layers and Kronecker
36
+ factorizations), multi-GPU compatibility, mixed-precision training, and
37
+ scalability for large models.
38
+
39
+ ## Algorithms
40
+
41
+ ### 1. PRISM (Anisotropic Spectral Shaping)
42
+
43
+ PRISM allows for structured optimization by applying anisotropic spectral
44
+ shaping to parameter updates. Unlike standard adaptive methods (Adam) that
45
+ operate element-wise, or full-matrix second-order methods (Shampoo/PSGD) that
46
+ approximate the Hessian, PRISM optimizes the singular value distribution of
47
+ weight matrices directly.
48
+
49
+ - **Mechanism**: Decomposes updates using Newton-Schulz iterations to
50
+ approximate SVD, applying "innovation" updates to the singular vectors while
51
+ damping singular values.
52
+ - **Partitioning**: Automatically partitions parameters. High-rank tensors
53
+ (Linear/Conv weights) are optimized via PRISM; vectors (biases, layernorms) are
54
+ optimized via AdamW.
55
+ - **Reference**: *PRISM: Structured Optimization via Anisotropic Spectral
56
+ Shaping* (Yang, 2026).
57
+
58
+ ### 2. PSGD Kron (Lie Group Preconditioning)
59
+
60
+ PSGD reformulates preconditioner estimation as a strongly convex optimization
61
+ problem on Lie groups. It updates the preconditioner $Q$ (where $P = Q^T Q$)
62
+ using multiplicative updates that avoid explicit matrix inversion.
63
+
64
+ - **Mechanism**: Maintains a Kronecker-factored preconditioner updated via the
65
+ triangular or orthogonal group.
66
+ - **Reference**: *Stochastic Hessian Fittings with Lie Groups* (Li, 2024).
67
+
68
+ ### 3. Schedule-Free Optimization
69
+
70
+ A wrapper that eliminates the need for complex learning rate schedules by
71
+ maintaining two sequences of parameters: a primary sequence $z$ (stepped via the
72
+ base optimizer) and an averaged sequence $x$ (used for evaluation).
73
+
74
+ - **Features**: Supports "Practical" and "Schedulet" weighting modes for
75
+ theoretically grounded averaging.
76
+ - **Reference**: *The Road Less Scheduled* (Defazio et al., 2024).
77
+
78
+ ______________________________________________________________________
79
+
80
+ ## Installation
81
+
82
+ ```bash
83
+ pip install rollfast
84
+ ```
85
+
86
+ ## Usage
87
+
88
+ ### 1. PRISM (Standard)
89
+
90
+ PRISM automatically handles parameter partitioning. You simply provide the
91
+ learning rate and structural hyperparameters.
92
+
93
+ ```python
94
+ import jax
95
+ import jax.numpy as jnp
96
+ from rollfast import prism
97
+
98
+ # Define parameters
99
+ params = {
100
+ 'linear': {'w': jnp.zeros((128, 128)), 'b': jnp.zeros((128,))},
101
+ }
102
+
103
+ # Initialize PRISM
104
+ # 'w' will be optimized by PRISM (Spectral Shaping)
105
+ # 'b' will be optimized by AdamW
106
+ optimizer = prism(
107
+ learning_rate=1e-3,
108
+ ns_iters=5, # Newton-Schulz iterations for orthogonalization
109
+ gamma=1.0, # Innovation damping
110
+ weight_decay=0.01
111
+ )
112
+
113
+ opt_state = optimizer.init(params)
114
+ ```
115
+
116
+ ### 2. Schedule-Free PRISM
117
+
118
+ The `schedule_free_prism` function wraps the PRISM optimizer with the
119
+ Schedule-Free logic and the WSD (Warmup-Stable-Decay) scheduler for the internal
120
+ step size.
121
+
122
+ ```python
123
+ from rollfast.optim import schedule_free_prism
124
+
125
+ optimizer = schedule_free_prism(
126
+ learning_rate=1.0, # Peak LR for internal steps
127
+ total_steps=10000, # Required for WSD schedule generation
128
+ warmup_fraction=0.1,
129
+ weighting_mode="schedulet",
130
+ sf_b1=0.9, # Schedule-Free interpolation (beta)
131
+ gamma=0.8, # PRISM specific arg
132
+ )
133
+
134
+ # Note: In Schedule-Free, you must compute gradients at the averaged location 'x'
135
+ # but apply updates to the state 'z'.
136
+ ```
137
+
138
+ ### 3. PSGD Kron
139
+
140
+ The classic Kronecker-factored PSGD optimizer.
141
+
142
+ ```python
143
+ from rollfast.optim import kron
144
+
145
+ optimizer = kron(
146
+ learning_rate=1e-3,
147
+ b1=0.9,
148
+ preconditioner_lr=0.1,
149
+ preconditioner_mode='Q0.5EQ1.5', # Procrustes-regularized update
150
+ whiten_grad=True
151
+ )
152
+ ```
153
+
154
+ ### Advanced: Scanned Layers (Memory Efficiency)
155
+
156
+ For deep architectures (e.g., Transformers) implemented via `jax.lax.scan`,
157
+ `rollfast` supports explicit handling of scanned layers to prevent unrolling
158
+ computation graphs.
159
+
160
+ ```python
161
+ import jax
162
+ from rollfast.optim import kron
163
+
164
+ # Boolean pytree mask where True indicates a scanned parameter
165
+ scanned_layers_mask = ...
166
+
167
+ optimizer = kron(
168
+ learning_rate=3e-4,
169
+ scanned_layers=scanned_layers_mask,
170
+ lax_map_scanned_layers=True, # Use lax.map for preconditioner updates
171
+ lax_map_batch_size=8
172
+ )
173
+ ```
174
+
175
+ ______________________________________________________________________
176
+
177
+ ## Configuration
178
+
179
+ ### Stability & Clipping Parameters
180
+
181
+ These parameters ensure robustness against gradient spikes and numerical
182
+ instability, critical for training at scale.
183
+
184
+ | Parameter | Default | Description |
185
+ | :---------------------------- | :------------ | :--------------------------------------------------------------------------------------------------------------------------------------------------------- |
186
+ | `raw_global_grad_clip` | `None` | If set, computes the global L2 norm of gradients *before* the optimizer step. If the norm exceeds this threshold, the update is either clipped or skipped. |
187
+ | `permissive_spike_protection` | `True` | Controls behavior when `raw_global_grad_clip` is triggered. `True` clips the gradient and proceeds; `False` strictly skips the update (zeroing the step). |
188
+ | `grad_clip_max_amps` | `(2.0, 10.0)` | Post-processing clipping. Clips individual tensors by RMS (`2.0`) and absolute value (`10.0`) to prevent heavy tails in the update distribution. |
189
+
190
+ ### Schedule-Free Hyperparameters
191
+
192
+ When using `schedule_free_*` optimizers, these arguments control the underlying
193
+ WSD (Warmup-Stable-Decay) schedule and the iterate averaging.
194
+
195
+ | Parameter | Default | Description |
196
+ | :---------------- | :---------- | :---------------------------------------------------------------------------------------------------------------- |
197
+ | `warmup_fraction` | `0.1` | Fraction of `total_steps` used for linear warmup. |
198
+ | `decay_fraction` | `0.1` | Fraction of `total_steps` used for linear decay (cooldown) at the end of training. |
199
+ | `weighting_mode` | `SCHEDULET` | Strategy for $c_t$ calculation: `THEORETICAL` ($1/t$), `PRACTICAL` ($\\gamma_t^2$), or `SCHEDULET` ($\\gamma_t$). |
200
+
201
+ ### PRISM Specifics
202
+
203
+ | Parameter | Default | Description |
204
+ | :------------------- | :------ | :------------------------------------------------------------------------------------------ |
205
+ | `ns_iters` | `5` | Newton-Schulz iterations. Higher values provide better orthogonality but cost more compute. |
206
+ | `gamma` | `1.0` | Damping coefficient for the innovation term. Controls the "anisotropy" of spectral shaping. |
207
+ | `shape_nesterov` | `True` | If True, shapes Nesterov momentum; otherwise shapes raw momentum. |
208
+ | `adam_learning_rate` | `None` | Optional override for the Adam branch learning rate. Defaults to `learning_rate` if None. |
209
+
210
+ ### PSGD Specifics
211
+
212
+ | Parameter | Default | Description |
213
+ | :-------------------------- | :------ | :---------------------------------------------------------------------------------------------------------------------------------------------------------- |
214
+ | `track_lipschitz` | `True` | Enables adaptive step sizes for the preconditioner $Q$ by tracking the Lipschitz constant of the gradient. |
215
+ | `max_skew_triangular` | `1.0` | Threshold for diagonal approximation. If a dimension's aspect ratio squared exceeds this relative to total numel, it is treated as diagonal to save memory. |
216
+ | `preconditioner_init_scale` | `None` | Initial scale for $Q$. If `None`, it is estimated on the first step using gradient statistics. |
217
+
218
+ #### Preconditioner Modes
219
+
220
+ The geometry of the preconditioner update $dQ$ is controlled via
221
+ `preconditioner_mode`.
222
+
223
+ | Mode | Formula | Description |
224
+ | :---------- | :---------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------- |
225
+ | `Q0.5EQ1.5` | $dQ = Q^{0.5} \\mathcal{E} Q^{1.5}$ | **Recommended**. Uses an online orthogonal Procrustes solver to keep $Q$ approximately SPD. Numerically stable for low precision. |
226
+ | `EQ` | $dQ = \\mathcal{E} Q$ | The original triangular update. Requires triangular solves. Only mode compatible with triangular $Q$. |
227
+ | `QUAD` | Quadratic Form | Ensures $Q$ remains symmetric positive definite via quadratic form updates. |
228
+ | `NS` | Newton-Schulz | Iteratively projects $Q$ onto the SPD manifold using Newton-Schulz iterations. Exact but more expensive. |
229
+ | `EXP` | Matrix Exponential | Geodesic update on the SPD manifold. Uses matrix exponential. |
230
+ | `TAYLOR2` | Taylor Expansion | Second-order Taylor approximation of the matrix exponential update. |
231
+ | `HYPER` | Hyperbolic | Multiplicative hyperbolic update. |
232
+
233
+ ______________________________________________________________________
234
+
235
+ ## Citations
236
+
237
+ If you use `rollfast` in your research, please cite the relevant papers for the algorithms you utilize.
238
+
239
+ **PRISM:**
240
+
241
+ ```bibtex
242
+ @misc{2602.03096,
243
+ Author = {Yujie Yang},
244
+ Title = {PRISM: Structured Optimization via Anisotropic Spectral Shaping},
245
+ Year = {2026},
246
+ Eprint = {arXiv:2602.03096},
247
+ }
248
+ ```
249
+
250
+ **Schedule-Free:**
251
+
252
+ ```bibtex
253
+ @misc{2405.15682,
254
+ Author = {Aaron Defazio and Xingyu Alice Yang and Harsh Mehta and Konstantin Mishchenko and Ahmed Khaled and Ashok Cutkosky},
255
+ Title = {The Road Less Scheduled},
256
+ Year = {2024},
257
+ Eprint = {arXiv:2405.15682},
258
+ }
259
+
260
+ @misc{2511.07767,
261
+ Author = {Yuen-Man Pun and Matthew Buchholz and Robert M. Gower},
262
+ Title = {Schedulers for Schedule-free: Theoretically inspired hyperparameters},
263
+ Year = {2025},
264
+ Eprint = {arXiv:2511.07767},
265
+ }
266
+ ```
267
+
268
+ **PSGD:**
269
+
270
+ ```bibtex
271
+ @article{li2024stochastic,
272
+ title={Stochastic Hessian Fittings with Lie Groups},
273
+ author={Li, Xi-Lin},
274
+ journal={arXiv preprint arXiv:2402.11858},
275
+ year={2024}
276
+ }
277
+ ```
@@ -0,0 +1,252 @@
1
+ # rollfast: Advanced Optimization Primitives in JAX
2
+
3
+ `rollfast` is a high-performance optimization library for JAX, designed to
4
+ implement cutting-edge optimizers that go beyond standard Euclidean gradient
5
+ descent. It provides production-ready implementations of optimizers like
6
+ **PSGD** (Preconditioned Stochastic Gradient Descent) and **PRISM** (Anisotropic
7
+ Spectral Shaping), along with a robust **Schedule-Free** wrapper.
8
+
9
+ Built on top of the [Optax](https://github.com/google-deepmind/optax) ecosystem,
10
+ `rollfast` prioritizes memory efficiency (via scanned layers and Kronecker
11
+ factorizations), multi-gpu compatibility, mixed-precision trainings and
12
+ scalability for large models.
13
+
14
+ ## Algorithms
15
+
16
+ ### 1. PRISM (Anisotropic Spectral Shaping)
17
+
18
+ PRISM allows for structured optimization by applying anisotropic spectral
19
+ shaping to parameter updates. Unlike standard adaptive methods (Adam) that
20
+ operate element-wise, or full-matrix second-order methods (Shampoo/PSGD) that
21
+ approximate the Hessian, PRISM optimizes the singular value distribution of
22
+ weight matrices directly.
23
+
24
+ - **Mechanism**: Decomposes updates using Newton-Schulz iterations to
25
+ approximate SVD, applying "innovation" updates to the singular vectors while
26
+ damping singular values.
27
+ - **Partitioning**: Automatically partitions parameters. High-rank tensors
28
+ (Linear/Conv weights) are optimized via PRISM; vectors (biases, layernorms) are
29
+ optimized via AdamW.
30
+ - **Reference**: *PRISM: Structured Optimization via Anisotropic Spectral
31
+ Shaping* (Yang, 2026).
32
+
33
+ ### 2. PSGD Kron (Lie Group Preconditioning)
34
+
35
+ PSGD reformulates preconditioner estimation as a strongly convex optimization
36
+ problem on Lie groups. It updates the preconditioner $Q$ (where $P = Q^T Q$)
37
+ using multiplicative updates that avoid explicit matrix inversion.
38
+
39
+ - **Mechanism**: Maintains a Kronecker-factored preconditioner updated via the
40
+ triangular or orthogonal group.
41
+ - **Reference**: *Stochastic Hessian Fittings with Lie Groups* (Li, 2024).
42
+
43
+ ### 3. Schedule-Free Optimization
44
+
45
+ A wrapper that eliminates the need for complex learning rate schedules by
46
+ maintaining two sequences of parameters: a primary sequence $z$ (stepped via the
47
+ base optimizer) and an averaged sequence $x$ (used for evaluation).
48
+
49
+ - **Features**: Supports "Practical" and "Schedulet" weighting modes for
50
+ theoretically grounded averaging.
51
+ - **Reference**: *The Road Less Scheduled* (Defazio et al., 2024).
52
+
53
+ ______________________________________________________________________
54
+
55
+ ## Installation
56
+
57
+ ```bash
58
+ pip install rollfast
59
+ ```
60
+
61
+ ## Usage
62
+
63
+ ### 1. PRISM (Standard)
64
+
65
+ PRISM automatically handles parameter partitioning. You simply provide the
66
+ learning rate and structural hyperparameters.
67
+
68
+ ```python
69
+ import jax
70
+ import jax.numpy as jnp
71
+ from rollfast import prism
72
+
73
+ # Define parameters
74
+ params = {
75
+ 'linear': {'w': jnp.zeros((128, 128)), 'b': jnp.zeros((128,))},
76
+ }
77
+
78
+ # Initialize PRISM
79
+ # 'w' will be optimized by PRISM (Spectral Shaping)
80
+ # 'b' will be optimized by AdamW
81
+ optimizer = prism(
82
+ learning_rate=1e-3,
83
+ ns_iters=5, # Newton-Schulz iterations for orthogonalization
84
+ gamma=1.0, # Innovation damping
85
+ weight_decay=0.01
86
+ )
87
+
88
+ opt_state = optimizer.init(params)
89
+ ```
90
+
91
+ ### 2. Schedule-Free PRISM
92
+
93
+ The `schedule_free_prism` function wraps the PRISM optimizer with the
94
+ Schedule-Free logic and the WSD (Warmup-Stable-Decay) scheduler for the internal
95
+ step size.
96
+
97
+ ```python
98
+ from rollfast.optim import schedule_free_prism
99
+
100
+ optimizer = schedule_free_prism(
101
+ learning_rate=1.0, # Peak LR for internal steps
102
+ total_steps=10000, # Required for WSD schedule generation
103
+ warmup_fraction=0.1,
104
+ weighting_mode="schedulet",
105
+ sf_b1=0.9, # Schedule-Free interpolation (beta)
106
+ gamma=0.8, # PRISM specific arg
107
+ )
108
+
109
+ # Note: In Schedule-Free, you must compute gradients at the averaged location 'x'
110
+ # but apply updates to the state 'z'.
111
+ ```
112
+
113
+ ### 3. PSGD Kron
114
+
115
+ The classic Kronecker-factored PSGD optimizer.
116
+
117
+ ```python
118
+ from rollfast.optim import kron
119
+
120
+ optimizer = kron(
121
+ learning_rate=1e-3,
122
+ b1=0.9,
123
+ preconditioner_lr=0.1,
124
+ preconditioner_mode='Q0.5EQ1.5', # Procrustes-regularized update
125
+ whiten_grad=True
126
+ )
127
+ ```
128
+
129
+ ### Advanced: Scanned Layers (Memory Efficiency)
130
+
131
+ For deep architectures (e.g., Transformers) implemented via `jax.lax.scan`,
132
+ `rollfast` supports explicit handling of scanned layers to prevent unrolling
133
+ computation graphs.
134
+
135
+ ```python
136
+ import jax
137
+ from rollfast.optim import kron
138
+
139
+ # Boolean pytree mask where True indicates a scanned parameter
140
+ scanned_layers_mask = ...
141
+
142
+ optimizer = kron(
143
+ learning_rate=3e-4,
144
+ scanned_layers=scanned_layers_mask,
145
+ lax_map_scanned_layers=True, # Use lax.map for preconditioner updates
146
+ lax_map_batch_size=8
147
+ )
148
+ ```
149
+
150
+ ______________________________________________________________________
151
+
152
+ ## Configuration
153
+
154
+ ### Stability & Clipping Parameters
155
+
156
+ These parameters ensure robustness against gradient spikes and numerical
157
+ instability, critical for training at scale.
158
+
159
+ | Parameter | Default | Description |
160
+ | :---------------------------- | :------------ | :--------------------------------------------------------------------------------------------------------------------------------------------------------- |
161
+ | `raw_global_grad_clip` | `None` | If set, computes the global L2 norm of gradients *before* the optimizer step. If the norm exceeds this threshold, the update is either clipped or skipped. |
162
+ | `permissive_spike_protection` | `True` | Controls behavior when `raw_global_grad_clip` is triggered. `True` clips the gradient and proceeds; `False` strictly skips the update (zeroing the step). |
163
+ | `grad_clip_max_amps` | `(2.0, 10.0)` | Post-processing clipping. Clips individual tensors by RMS (`2.0`) and absolute value (`10.0`) to prevent heavy tails in the update distribution. |
164
+
165
+ ### Schedule-Free Hyperparameters
166
+
167
+ When using `schedule_free_*` optimizers, these arguments control the underlying
168
+ WSD (Warmup-Stable-Decay) schedule and the iterate averaging.
169
+
170
+ | Parameter | Default | Description |
171
+ | :---------------- | :---------- | :---------------------------------------------------------------------------------------------------------------- |
172
+ | `warmup_fraction` | `0.1` | Fraction of `total_steps` used for linear warmup. |
173
+ | `decay_fraction` | `0.1` | Fraction of `total_steps` used for linear decay (cooldown) at the end of training. |
174
+ | `weighting_mode` | `SCHEDULET` | Strategy for $c_t$ calculation: `THEORETICAL` ($1/t$), `PRACTICAL` ($\\gamma_t^2$), or `SCHEDULET` ($\\gamma_t$). |
175
+
176
+ ### PRISM Specifics
177
+
178
+ | Parameter | Default | Description |
179
+ | :------------------- | :------ | :------------------------------------------------------------------------------------------ |
180
+ | `ns_iters` | `5` | Newton-Schulz iterations. Higher values provide better orthogonality but cost more compute. |
181
+ | `gamma` | `1.0` | Damping coefficient for the innovation term. Controls the "anisotropy" of spectral shaping. |
182
+ | `shape_nesterov` | `True` | If True, shapes Nesterov momentum; otherwise shapes raw momentum. |
183
+ | `adam_learning_rate` | `None` | Optional override for the Adam branch learning rate. Defaults to `learning_rate` if None. |
184
+
185
+ ### PSGD Specifics
186
+
187
+ | Parameter | Default | Description |
188
+ | :-------------------------- | :------ | :---------------------------------------------------------------------------------------------------------------------------------------------------------- |
189
+ | `track_lipschitz` | `True` | Enables adaptive step sizes for the preconditioner $Q$ by tracking the Lipschitz constant of the gradient. |
190
+ | `max_skew_triangular` | `1.0` | Threshold for diagonal approximation. If a dimension's aspect ratio squared exceeds this relative to total numel, it is treated as diagonal to save memory. |
191
+ | `preconditioner_init_scale` | `None` | Initial scale for $Q$. If `None`, it is estimated on the first step using gradient statistics. |
192
+
193
+ #### Preconditioner Modes
194
+
195
+ The geometry of the preconditioner update $dQ$ is controlled via
196
+ `preconditioner_mode`.
197
+
198
+ | Mode | Formula | Description |
199
+ | :---------- | :---------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------- |
200
+ | `Q0.5EQ1.5` | $dQ = Q^{0.5} \\mathcal{E} Q^{1.5}$ | **Recommended**. Uses an online orthogonal Procrustes solver to keep $Q$ approximately SPD. Numerically stable for low precision. |
201
+ | `EQ` | $dQ = \\mathcal{E} Q$ | The original triangular update. Requires triangular solves. Only mode compatible with triangular $Q$. |
202
+ | `QUAD` | Quadratic Form | Ensures $Q$ remains symmetric positive definite via quadratic form updates. |
203
+ | `NS` | Newton-Schulz | Iteratively projects $Q$ onto the SPD manifold using Newton-Schulz iterations. Exact but more expensive. |
204
+ | `EXP` | Matrix Exponential | Geodesic update on the SPD manifold. Uses matrix exponential. |
205
+ | `TAYLOR2` | Taylor Expansion | Second-order Taylor approximation of the matrix exponential update. |
206
+ | `HYPER` | Hyperbolic | Multiplicative hyperbolic update. |
207
+
208
+ ______________________________________________________________________
209
+
210
+ ## Citations
211
+
212
+ If you use `rollfast` in your research, please cite the relevant papers for the algorithms you utilize.
213
+
214
+ **PRISM:**
215
+
216
+ ```bibtex
217
+ @misc{2602.03096,
218
+ Author = {Yujie Yang},
219
+ Title = {PRISM: Structured Optimization via Anisotropic Spectral Shaping},
220
+ Year = {2026},
221
+ Eprint = {arXiv:2602.03096},
222
+ }
223
+ ```
224
+
225
+ **Schedule-Free:**
226
+
227
+ ```bibtex
228
+ @misc{2405.15682,
229
+ Author = {Aaron Defazio and Xingyu Alice Yang and Harsh Mehta and Konstantin Mishchenko and Ahmed Khaled and Ashok Cutkosky},
230
+ Title = {The Road Less Scheduled},
231
+ Year = {2024},
232
+ Eprint = {arXiv:2405.15682},
233
+ }
234
+
235
+ @misc{2511.07767,
236
+ Author = {Yuen-Man Pun and Matthew Buchholz and Robert M. Gower},
237
+ Title = {Schedulers for Schedule-free: Theoretically inspired hyperparameters},
238
+ Year = {2025},
239
+ Eprint = {arXiv:2511.07767},
240
+ }
241
+ ```
242
+
243
+ **PSGD:**
244
+
245
+ ```bibtex
246
+ @article{li2024stochastic,
247
+ title={Stochastic Hessian Fittings with Lie Groups},
248
+ author={Li, Xi-Lin},
249
+ journal={arXiv preprint arXiv:2402.11858},
250
+ year={2024}
251
+ }
252
+ ```
@@ -0,0 +1,43 @@
1
+ [project]
2
+ name = "rollfast"
3
+ version = "0.1.0"
4
+ description = "JAX implementation of experimental optimizers and schedulers."
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ license = { text = "MIT" }
8
+ authors = [
9
+ { name = "clementpoiret", email = "clement@linux.com" }
10
+ ]
11
+ keywords = [
12
+ "jax",
13
+ "optax",
14
+ "optimizer",
15
+ "psgd",
16
+ "deep-learning",
17
+ "second-order-optimization",
18
+ "preconditioning"
19
+ ]
20
+ classifiers = [
21
+ "Development Status :: 4 - Beta",
22
+ "Intended Audience :: Science/Research",
23
+ "Intended Audience :: Developers",
24
+ "License :: OSI Approved :: MIT License",
25
+ "Programming Language :: Python :: 3",
26
+ "Programming Language :: Python :: 3.11",
27
+ "Programming Language :: Python :: 3.12",
28
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
29
+ "Topic :: Scientific/Engineering :: Mathematics",
30
+ ]
31
+ dependencies = [
32
+ "jax>=0.6.2",
33
+ "optax>=0.2.0",
34
+ ]
35
+
36
+ [project.urls]
37
+ Homepage = "https://github.com/clementpoiret/rollfast"
38
+ Repository = "https://github.com/clementpoiret/rollfast"
39
+ Issues = "https://github.com/clementpoiret/rollfast/issues"
40
+
41
+ [build-system]
42
+ requires = ["uv_build>=0.9.7,<0.10.0"]
43
+ build-backend = "uv_build"
@@ -0,0 +1,10 @@
1
+ from .optim.prism import prism as prism
2
+ from .optim.psgd import kron as kron
3
+ from .schedules.schedulefree import (
4
+ schedule_free_eval_params as schedule_free_eval_params,
5
+ schedule_free_kron as schedule_free_kron,
6
+ schedule_free_prism as schedule_free_prism,
7
+ )
8
+ from .schedules.wsd import wsd_schedule as wsd_schedule
9
+
10
+ __version__ = "0.1.0"