python-somax 0.0.1__tar.gz → 1.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- python_somax-1.0.1/PKG-INFO +202 -0
- python_somax-1.0.1/README.md +162 -0
- python_somax-1.0.1/pyproject.toml +93 -0
- python_somax-1.0.1/src/python_somax.egg-info/PKG-INFO +202 -0
- python_somax-1.0.1/src/python_somax.egg-info/SOURCES.txt +53 -0
- python_somax-1.0.1/src/python_somax.egg-info/requires.txt +18 -0
- python_somax-1.0.1/src/somax/__init__.py +58 -0
- python_somax-1.0.1/src/somax/assembler.py +151 -0
- python_somax-1.0.1/src/somax/curvature/__init__.py +14 -0
- python_somax-1.0.1/src/somax/curvature/base.py +56 -0
- python_somax-1.0.1/src/somax/curvature/ggn_ce.py +183 -0
- python_somax-1.0.1/src/somax/curvature/ggn_mse.py +104 -0
- python_somax-1.0.1/src/somax/curvature/hessian.py +63 -0
- python_somax-1.0.1/src/somax/curvature/with_estimators.py +45 -0
- python_somax-1.0.1/src/somax/damping/__init__.py +12 -0
- python_somax-1.0.1/src/somax/damping/base.py +33 -0
- python_somax-1.0.1/src/somax/damping/constant.py +28 -0
- python_somax-1.0.1/src/somax/damping/trust_region.py +115 -0
- python_somax-1.0.1/src/somax/estimators/__init__.py +18 -0
- python_somax-1.0.1/src/somax/estimators/base.py +53 -0
- python_somax-1.0.1/src/somax/estimators/gnb_ce.py +67 -0
- python_somax-1.0.1/src/somax/estimators/hutchinson.py +121 -0
- python_somax-1.0.1/src/somax/executor.py +528 -0
- python_somax-1.0.1/src/somax/methods/__init__.py +30 -0
- python_somax-1.0.1/src/somax/methods/adahessian.py +83 -0
- python_somax-1.0.1/src/somax/methods/direct_methods.py +151 -0
- python_somax-1.0.1/src/somax/methods/egn.py +233 -0
- python_somax-1.0.1/src/somax/methods/newton_cg.py +100 -0
- python_somax-1.0.1/src/somax/methods/sgn.py +194 -0
- python_somax-1.0.1/src/somax/methods/sophia_g.py +74 -0
- python_somax-1.0.1/src/somax/methods/sophia_h.py +76 -0
- python_somax-1.0.1/src/somax/metrics.py +79 -0
- python_somax-1.0.1/src/somax/optax.py +61 -0
- python_somax-1.0.1/src/somax/planner.py +197 -0
- python_somax-1.0.1/src/somax/preconditioners/__init__.py +15 -0
- python_somax-1.0.1/src/somax/preconditioners/base.py +42 -0
- python_somax-1.0.1/src/somax/preconditioners/diag_direct.py +86 -0
- python_somax-1.0.1/src/somax/preconditioners/diag_ema.py +183 -0
- python_somax-1.0.1/src/somax/preconditioners/identity.py +35 -0
- python_somax-1.0.1/src/somax/presets.py +60 -0
- python_somax-1.0.1/src/somax/solvers/__init__.py +16 -0
- python_somax-1.0.1/src/somax/solvers/base.py +49 -0
- python_somax-1.0.1/src/somax/solvers/cg.py +554 -0
- python_somax-1.0.1/src/somax/solvers/direct.py +55 -0
- python_somax-1.0.1/src/somax/solvers/identity.py +36 -0
- python_somax-1.0.1/src/somax/solvers/row_cg.py +89 -0
- python_somax-1.0.1/src/somax/solvers/row_cholesky.py +78 -0
- python_somax-1.0.1/src/somax/solvers/row_common.py +71 -0
- python_somax-1.0.1/src/somax/specs.py +246 -0
- python_somax-1.0.1/src/somax/types.py +53 -0
- python_somax-1.0.1/src/somax/utils.py +197 -0
- python_somax-0.0.1/PKG-INFO +0 -131
- python_somax-0.0.1/README.md +0 -112
- python_somax-0.0.1/setup.py +0 -40
- python_somax-0.0.1/src/python_somax.egg-info/PKG-INFO +0 -131
- python_somax-0.0.1/src/python_somax.egg-info/SOURCES.txt +0 -30
- python_somax-0.0.1/src/python_somax.egg-info/requires.txt +0 -1
- python_somax-0.0.1/src/somax/__init__.py +0 -19
- python_somax-0.0.1/src/somax/diagonal/__init__.py +0 -0
- python_somax-0.0.1/src/somax/diagonal/adahessian.py +0 -192
- python_somax-0.0.1/src/somax/diagonal/sophia_g.py +0 -208
- python_somax-0.0.1/src/somax/diagonal/sophia_h.py +0 -191
- python_somax-0.0.1/src/somax/gn/__init__.py +0 -0
- python_somax-0.0.1/src/somax/gn/egn.py +0 -567
- python_somax-0.0.1/src/somax/gn/sgn.py +0 -237
- python_somax-0.0.1/src/somax/hf/__init__.py +0 -0
- python_somax-0.0.1/src/somax/hf/newton_cg.py +0 -339
- python_somax-0.0.1/src/somax/ng/__init__.py +0 -0
- python_somax-0.0.1/src/somax/ng/swm_ng.py +0 -411
- python_somax-0.0.1/src/somax/qn/__init__.py +0 -0
- python_somax-0.0.1/src/somax/qn/sqn.py +0 -311
- python_somax-0.0.1/tests/test_adahessian.py +0 -134
- python_somax-0.0.1/tests/test_egn.py +0 -306
- python_somax-0.0.1/tests/test_newton_cg.py +0 -254
- python_somax-0.0.1/tests/test_sgn.py +0 -273
- python_somax-0.0.1/tests/test_sophia_g.py +0 -72
- python_somax-0.0.1/tests/test_sophia_h.py +0 -130
- python_somax-0.0.1/tests/test_sqn.py +0 -126
- python_somax-0.0.1/tests/test_swm_ng.py +0 -138
- {python_somax-0.0.1 → python_somax-1.0.1}/LICENSE +0 -0
- {python_somax-0.0.1 → python_somax-1.0.1}/setup.cfg +0 -0
- {python_somax-0.0.1 → python_somax-1.0.1}/src/python_somax.egg-info/dependency_links.txt +0 -0
- {python_somax-0.0.1 → python_somax-1.0.1}/src/python_somax.egg-info/top_level.txt +0 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: python-somax
|
|
3
|
+
Version: 1.0.1
|
|
4
|
+
Summary: Composable Second-Order Optimization for JAX and Optax.
|
|
5
|
+
Author: Nick Korbit
|
|
6
|
+
License-Expression: Apache-2.0
|
|
7
|
+
Project-URL: Homepage, https://github.com/cor3bit/somax
|
|
8
|
+
Project-URL: Repository, https://github.com/cor3bit/somax
|
|
9
|
+
Project-URL: Issues, https://github.com/cor3bit/somax/issues
|
|
10
|
+
Project-URL: Releases, https://github.com/cor3bit/somax/releases
|
|
11
|
+
Project-URL: Paper, https://arxiv.org/abs/2603.25976
|
|
12
|
+
Keywords: jax,optax,optimization,second-order,gauss-newton,hessian,machine-learning
|
|
13
|
+
Classifier: Development Status :: 3 - Alpha
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
20
|
+
Requires-Python: >=3.11
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
License-File: LICENSE
|
|
23
|
+
Requires-Dist: jax>=0.4.20
|
|
24
|
+
Requires-Dist: optax>=0.1.7
|
|
25
|
+
Requires-Dist: flax>=0.7.5
|
|
26
|
+
Requires-Dist: numpy>=1.23
|
|
27
|
+
Provides-Extra: dev
|
|
28
|
+
Requires-Dist: build>=1.2; extra == "dev"
|
|
29
|
+
Requires-Dist: twine>=5; extra == "dev"
|
|
30
|
+
Requires-Dist: pytest>=8; extra == "dev"
|
|
31
|
+
Requires-Dist: pytest-cov>=5; extra == "dev"
|
|
32
|
+
Requires-Dist: ruff>=0.5; extra == "dev"
|
|
33
|
+
Requires-Dist: pyright>=1.1; extra == "dev"
|
|
34
|
+
Requires-Dist: pre-commit>=3; extra == "dev"
|
|
35
|
+
Requires-Dist: chex>=0.1.8; extra == "dev"
|
|
36
|
+
Provides-Extra: docs
|
|
37
|
+
Requires-Dist: mkdocs>=1.5; extra == "docs"
|
|
38
|
+
Requires-Dist: mkdocs-material>=9; extra == "docs"
|
|
39
|
+
Dynamic: license-file
|
|
40
|
+
|
|
41
|
+
<h1 align="center">Somax</h1>
|
|
42
|
+
|
|
43
|
+
<p align="center">
|
|
44
|
+
<img src="assets/somax_logo_mini.png" alt="Somax logo" width="250px"/>
|
|
45
|
+
</p>
|
|
46
|
+
|
|
47
|
+
<p align="center">
|
|
48
|
+
Composable Second-Order Optimization for JAX and Optax.
|
|
49
|
+
</p>
|
|
50
|
+
|
|
51
|
+
<p align="center">
|
|
52
|
+
A small research-engineering library for curvature-aware training:
|
|
53
|
+
modular, matrix-free, and explicit about the moving parts.
|
|
54
|
+
</p>
|
|
55
|
+
|
|
56
|
+
---
|
|
57
|
+
|
|
58
|
+
Somax is a JAX-native library for building and running second-order optimization methods from explicit components.
|
|
59
|
+
|
|
60
|
+
Rather than treating an optimizer as a monolithic object, Somax factors a step into swappable pieces:
|
|
61
|
+
- curvature operator
|
|
62
|
+
- solver
|
|
63
|
+
- damping policy
|
|
64
|
+
- optional preconditioner
|
|
65
|
+
- update transform
|
|
66
|
+
- optional telemetry and control signals
|
|
67
|
+
|
|
68
|
+
That decomposition is the point.
|
|
69
|
+
|
|
70
|
+
Somax is built for users who want a clean second-order stack in JAX without hiding the execution model.
|
|
71
|
+
It aims to make curvature-aware training easier to inspect, compare, and extend.
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
> The catfish in the logo is a small nod to *som*, the Belarusian word for catfish.
|
|
75
|
+
> A quiet bottom-dweller, but not a first-order creature.
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
## Why Somax
|
|
80
|
+
|
|
81
|
+
- **Composable**: build methods from curvature, solver, damping, preconditioner, and update components.
|
|
82
|
+
- **Optax-native**: computed directions are fed through Optax-style update transforms.
|
|
83
|
+
- **Planned execution**: a method is assembled once, planned once, and then executed as a stable step pipeline.
|
|
84
|
+
- **JAX-first**: intended for `jit`-compiled training loops and explicit control over execution.
|
|
85
|
+
- **Multiple solve lanes**: diagonal, parameter-space, and row-space paths are first-class parts of the stack.
|
|
86
|
+
- **Research-friendly**: easy to inspect, compare, ablate, and extend.
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
## Installation
|
|
92
|
+
|
|
93
|
+
Install JAX for your backend first:
|
|
94
|
+
|
|
95
|
+
- JAX installation guide: https://docs.jax.dev/en/latest/installation.html
|
|
96
|
+
|
|
97
|
+
Then install Somax:
|
|
98
|
+
|
|
99
|
+
```bash
|
|
100
|
+
pip install python-somax
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
For local development:
|
|
104
|
+
|
|
105
|
+
```bash
|
|
106
|
+
git clone https://github.com/cor3bit/somax.git
|
|
107
|
+
cd somax
|
|
108
|
+
pip install -e ".[dev]"
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
Optional:
|
|
112
|
+
- install `lineax` only if you want to use CG backends with `backend="lineax"`.
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
## Quickstart
|
|
117
|
+
|
|
118
|
+
```python
|
|
119
|
+
import jax
|
|
120
|
+
import jax.numpy as jnp
|
|
121
|
+
import somax
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def predict_fn(params, x):
|
|
125
|
+
h = jnp.tanh(x @ params["W1"] + params["b1"])
|
|
126
|
+
return h @ params["W2"] + params["b2"]
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
rng = jax.random.PRNGKey(0)
|
|
130
|
+
k1, k2, k3, k4 = jax.random.split(rng, 4)
|
|
131
|
+
|
|
132
|
+
params = {
|
|
133
|
+
"W1": 0.1 * jax.random.normal(k1, (16, 32)),
|
|
134
|
+
"b1": jnp.zeros((32,)),
|
|
135
|
+
"W2": 0.1 * jax.random.normal(k2, (32, 10)),
|
|
136
|
+
"b2": jnp.zeros((10,)),
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
batch = {
|
|
140
|
+
"x": jax.random.normal(k3, (64, 16)),
|
|
141
|
+
"y": jax.random.randint(k4, (64,), 0, 10),
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
method = somax.sgn_ce(
|
|
145
|
+
predict_fn=predict_fn,
|
|
146
|
+
lam0=1e-2,
|
|
147
|
+
tol=1e-4,
|
|
148
|
+
maxiter=20,
|
|
149
|
+
learning_rate=1e-1,
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
state = method.init(params)
|
|
153
|
+
|
|
154
|
+
@jax.jit
|
|
155
|
+
def train_step(params, state, rng):
|
|
156
|
+
params, state, info = method.step(params, batch, state, rng)
|
|
157
|
+
return params, state, info
|
|
158
|
+
|
|
159
|
+
for step in range(10):
|
|
160
|
+
params, state, info = train_step(params, state, jax.random.fold_in(rng, step))
|
|
161
|
+
```
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
## Citation
|
|
167
|
+
|
|
168
|
+
If Somax is useful in your academic work, please cite:
|
|
169
|
+
|
|
170
|
+
**Second-Order, First-Class: A Composable Stack for Curvature-Aware Training**
|
|
171
|
+
Mikalai Korbit and Mario Zanon
|
|
172
|
+
https://arxiv.org/abs/2603.25976
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
```bibtex
|
|
176
|
+
@article{korbit2026second,
|
|
177
|
+
title={Second-Order, First-Class: A Composable Stack for Curvature-Aware Training},
|
|
178
|
+
author={Korbit, Mikalai and Zanon, Mario},
|
|
179
|
+
journal={arXiv preprint arXiv:2603.25976},
|
|
180
|
+
year={2026}
|
|
181
|
+
}
|
|
182
|
+
```
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
## Related projects
|
|
186
|
+
|
|
187
|
+
**Optimization in JAX**
|
|
188
|
+
[Optax](https://github.com/google-deepmind/optax): first-order gradient (e.g., SGD, Adam) optimizers.
|
|
189
|
+
[JAXopt](https://github.com/google/jaxopt): deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt),
|
|
190
|
+
stochastic first-order methods (PolyakSGD, ArmijoSGD).
|
|
191
|
+
|
|
192
|
+
**Awesome Projects**
|
|
193
|
+
[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of various JAX projects.
|
|
194
|
+
[Awesome SOMs](https://github.com/cor3bit/awesome-soms): a list
|
|
195
|
+
of resources for second-order optimization methods in machine learning.
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
## License
|
|
201
|
+
|
|
202
|
+
Apache-2.0
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
<h1 align="center">Somax</h1>
|
|
2
|
+
|
|
3
|
+
<p align="center">
|
|
4
|
+
<img src="assets/somax_logo_mini.png" alt="Somax logo" width="250px"/>
|
|
5
|
+
</p>
|
|
6
|
+
|
|
7
|
+
<p align="center">
|
|
8
|
+
Composable Second-Order Optimization for JAX and Optax.
|
|
9
|
+
</p>
|
|
10
|
+
|
|
11
|
+
<p align="center">
|
|
12
|
+
A small research-engineering library for curvature-aware training:
|
|
13
|
+
modular, matrix-free, and explicit about the moving parts.
|
|
14
|
+
</p>
|
|
15
|
+
|
|
16
|
+
---
|
|
17
|
+
|
|
18
|
+
Somax is a JAX-native library for building and running second-order optimization methods from explicit components.
|
|
19
|
+
|
|
20
|
+
Rather than treating an optimizer as a monolithic object, Somax factors a step into swappable pieces:
|
|
21
|
+
- curvature operator
|
|
22
|
+
- solver
|
|
23
|
+
- damping policy
|
|
24
|
+
- optional preconditioner
|
|
25
|
+
- update transform
|
|
26
|
+
- optional telemetry and control signals
|
|
27
|
+
|
|
28
|
+
That decomposition is the point.
|
|
29
|
+
|
|
30
|
+
Somax is built for users who want a clean second-order stack in JAX without hiding the execution model.
|
|
31
|
+
It aims to make curvature-aware training easier to inspect, compare, and extend.
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
> The catfish in the logo is a small nod to *som*, the Belarusian word for catfish.
|
|
35
|
+
> A quiet bottom-dweller, but not a first-order creature.
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
## Why Somax
|
|
40
|
+
|
|
41
|
+
- **Composable**: build methods from curvature, solver, damping, preconditioner, and update components.
|
|
42
|
+
- **Optax-native**: computed directions are fed through Optax-style update transforms.
|
|
43
|
+
- **Planned execution**: a method is assembled once, planned once, and then executed as a stable step pipeline.
|
|
44
|
+
- **JAX-first**: intended for `jit`-compiled training loops and explicit control over execution.
|
|
45
|
+
- **Multiple solve lanes**: diagonal, parameter-space, and row-space paths are first-class parts of the stack.
|
|
46
|
+
- **Research-friendly**: easy to inspect, compare, ablate, and extend.
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
## Installation
|
|
52
|
+
|
|
53
|
+
Install JAX for your backend first:
|
|
54
|
+
|
|
55
|
+
- JAX installation guide: https://docs.jax.dev/en/latest/installation.html
|
|
56
|
+
|
|
57
|
+
Then install Somax:
|
|
58
|
+
|
|
59
|
+
```bash
|
|
60
|
+
pip install python-somax
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
For local development:
|
|
64
|
+
|
|
65
|
+
```bash
|
|
66
|
+
git clone https://github.com/cor3bit/somax.git
|
|
67
|
+
cd somax
|
|
68
|
+
pip install -e ".[dev]"
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
Optional:
|
|
72
|
+
- install `lineax` only if you want to use CG backends with `backend="lineax"`.
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
## Quickstart
|
|
77
|
+
|
|
78
|
+
```python
|
|
79
|
+
import jax
|
|
80
|
+
import jax.numpy as jnp
|
|
81
|
+
import somax
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def predict_fn(params, x):
|
|
85
|
+
h = jnp.tanh(x @ params["W1"] + params["b1"])
|
|
86
|
+
return h @ params["W2"] + params["b2"]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
rng = jax.random.PRNGKey(0)
|
|
90
|
+
k1, k2, k3, k4 = jax.random.split(rng, 4)
|
|
91
|
+
|
|
92
|
+
params = {
|
|
93
|
+
"W1": 0.1 * jax.random.normal(k1, (16, 32)),
|
|
94
|
+
"b1": jnp.zeros((32,)),
|
|
95
|
+
"W2": 0.1 * jax.random.normal(k2, (32, 10)),
|
|
96
|
+
"b2": jnp.zeros((10,)),
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
batch = {
|
|
100
|
+
"x": jax.random.normal(k3, (64, 16)),
|
|
101
|
+
"y": jax.random.randint(k4, (64,), 0, 10),
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
method = somax.sgn_ce(
|
|
105
|
+
predict_fn=predict_fn,
|
|
106
|
+
lam0=1e-2,
|
|
107
|
+
tol=1e-4,
|
|
108
|
+
maxiter=20,
|
|
109
|
+
learning_rate=1e-1,
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
state = method.init(params)
|
|
113
|
+
|
|
114
|
+
@jax.jit
|
|
115
|
+
def train_step(params, state, rng):
|
|
116
|
+
params, state, info = method.step(params, batch, state, rng)
|
|
117
|
+
return params, state, info
|
|
118
|
+
|
|
119
|
+
for step in range(10):
|
|
120
|
+
params, state, info = train_step(params, state, jax.random.fold_in(rng, step))
|
|
121
|
+
```
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
## Citation
|
|
127
|
+
|
|
128
|
+
If Somax is useful in your academic work, please cite:
|
|
129
|
+
|
|
130
|
+
**Second-Order, First-Class: A Composable Stack for Curvature-Aware Training**
|
|
131
|
+
Mikalai Korbit and Mario Zanon
|
|
132
|
+
https://arxiv.org/abs/2603.25976
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
```bibtex
|
|
136
|
+
@article{korbit2026second,
|
|
137
|
+
title={Second-Order, First-Class: A Composable Stack for Curvature-Aware Training},
|
|
138
|
+
author={Korbit, Mikalai and Zanon, Mario},
|
|
139
|
+
journal={arXiv preprint arXiv:2603.25976},
|
|
140
|
+
year={2026}
|
|
141
|
+
}
|
|
142
|
+
```
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
## Related projects
|
|
146
|
+
|
|
147
|
+
**Optimization in JAX**
|
|
148
|
+
[Optax](https://github.com/google-deepmind/optax): first-order gradient (e.g., SGD, Adam) optimizers.
|
|
149
|
+
[JAXopt](https://github.com/google/jaxopt): deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt),
|
|
150
|
+
stochastic first-order methods (PolyakSGD, ArmijoSGD).
|
|
151
|
+
|
|
152
|
+
**Awesome Projects**
|
|
153
|
+
[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of various JAX projects.
|
|
154
|
+
[Awesome SOMs](https://github.com/cor3bit/awesome-soms): a list
|
|
155
|
+
of resources for second-order optimization methods in machine learning.
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
## License
|
|
161
|
+
|
|
162
|
+
Apache-2.0
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=77"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "python-somax"
|
|
7
|
+
version = "1.0.1"
|
|
8
|
+
description = "Composable Second-Order Optimization for JAX and Optax."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
license = "Apache-2.0"
|
|
12
|
+
license-files = ["LICENSE"]
|
|
13
|
+
authors = [
|
|
14
|
+
{ name = "Nick Korbit" }
|
|
15
|
+
]
|
|
16
|
+
dependencies = [
|
|
17
|
+
"jax>=0.4.20",
|
|
18
|
+
"optax>=0.1.7",
|
|
19
|
+
"flax>=0.7.5",
|
|
20
|
+
"numpy>=1.23",
|
|
21
|
+
]
|
|
22
|
+
keywords = [
|
|
23
|
+
"jax",
|
|
24
|
+
"optax",
|
|
25
|
+
"optimization",
|
|
26
|
+
"second-order",
|
|
27
|
+
"gauss-newton",
|
|
28
|
+
"hessian",
|
|
29
|
+
"machine-learning",
|
|
30
|
+
]
|
|
31
|
+
classifiers = [
|
|
32
|
+
"Development Status :: 3 - Alpha",
|
|
33
|
+
"Intended Audience :: Science/Research",
|
|
34
|
+
"Programming Language :: Python :: 3",
|
|
35
|
+
"Programming Language :: Python :: 3.11",
|
|
36
|
+
"Programming Language :: Python :: 3.12",
|
|
37
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
38
|
+
"Topic :: Scientific/Engineering :: Mathematics",
|
|
39
|
+
]
|
|
40
|
+
|
|
41
|
+
[project.urls]
|
|
42
|
+
Homepage = "https://github.com/cor3bit/somax"
|
|
43
|
+
Repository = "https://github.com/cor3bit/somax"
|
|
44
|
+
Issues = "https://github.com/cor3bit/somax/issues"
|
|
45
|
+
Releases = "https://github.com/cor3bit/somax/releases"
|
|
46
|
+
#Documentation = "https://github.com/cor3bit/somax/tree/main/docs"
|
|
47
|
+
Paper = "https://arxiv.org/abs/2603.25976"
|
|
48
|
+
|
|
49
|
+
[project.optional-dependencies]
|
|
50
|
+
dev = [
|
|
51
|
+
"build>=1.2",
|
|
52
|
+
"twine>=5",
|
|
53
|
+
"pytest>=8",
|
|
54
|
+
"pytest-cov>=5",
|
|
55
|
+
"ruff>=0.5",
|
|
56
|
+
"pyright>=1.1",
|
|
57
|
+
"pre-commit>=3",
|
|
58
|
+
"chex>=0.1.8",
|
|
59
|
+
]
|
|
60
|
+
docs = [
|
|
61
|
+
"mkdocs>=1.5",
|
|
62
|
+
"mkdocs-material>=9",
|
|
63
|
+
]
|
|
64
|
+
|
|
65
|
+
[tool.setuptools.packages.find]
|
|
66
|
+
where = ["src"]
|
|
67
|
+
|
|
68
|
+
[tool.ruff]
|
|
69
|
+
line-length = 100
|
|
70
|
+
target-version = "py311"
|
|
71
|
+
|
|
72
|
+
[tool.ruff.lint]
|
|
73
|
+
select = ["E", "F", "I", "N", "W", "UP", "B", "SIM"]
|
|
74
|
+
ignore = ["E501", "N803", "N806"]
|
|
75
|
+
|
|
76
|
+
[tool.ruff.format]
|
|
77
|
+
quote-style = "double"
|
|
78
|
+
indent-style = "space"
|
|
79
|
+
skip-magic-trailing-comma = false
|
|
80
|
+
line-ending = "lf"
|
|
81
|
+
|
|
82
|
+
[tool.pyright]
|
|
83
|
+
include = ["src"]
|
|
84
|
+
exclude = ["**/node_modules", "**/__pycache__"]
|
|
85
|
+
reportMissingImports = true
|
|
86
|
+
reportMissingTypeStubs = false
|
|
87
|
+
pythonVersion = "3.11"
|
|
88
|
+
typeCheckingMode = "basic"
|
|
89
|
+
|
|
90
|
+
[tool.pytest.ini_options]
|
|
91
|
+
testpaths = ["tests"]
|
|
92
|
+
addopts = "-q --strict-markers"
|
|
93
|
+
xfail_strict = true
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: python-somax
|
|
3
|
+
Version: 1.0.1
|
|
4
|
+
Summary: Composable Second-Order Optimization for JAX and Optax.
|
|
5
|
+
Author: Nick Korbit
|
|
6
|
+
License-Expression: Apache-2.0
|
|
7
|
+
Project-URL: Homepage, https://github.com/cor3bit/somax
|
|
8
|
+
Project-URL: Repository, https://github.com/cor3bit/somax
|
|
9
|
+
Project-URL: Issues, https://github.com/cor3bit/somax/issues
|
|
10
|
+
Project-URL: Releases, https://github.com/cor3bit/somax/releases
|
|
11
|
+
Project-URL: Paper, https://arxiv.org/abs/2603.25976
|
|
12
|
+
Keywords: jax,optax,optimization,second-order,gauss-newton,hessian,machine-learning
|
|
13
|
+
Classifier: Development Status :: 3 - Alpha
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
20
|
+
Requires-Python: >=3.11
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
License-File: LICENSE
|
|
23
|
+
Requires-Dist: jax>=0.4.20
|
|
24
|
+
Requires-Dist: optax>=0.1.7
|
|
25
|
+
Requires-Dist: flax>=0.7.5
|
|
26
|
+
Requires-Dist: numpy>=1.23
|
|
27
|
+
Provides-Extra: dev
|
|
28
|
+
Requires-Dist: build>=1.2; extra == "dev"
|
|
29
|
+
Requires-Dist: twine>=5; extra == "dev"
|
|
30
|
+
Requires-Dist: pytest>=8; extra == "dev"
|
|
31
|
+
Requires-Dist: pytest-cov>=5; extra == "dev"
|
|
32
|
+
Requires-Dist: ruff>=0.5; extra == "dev"
|
|
33
|
+
Requires-Dist: pyright>=1.1; extra == "dev"
|
|
34
|
+
Requires-Dist: pre-commit>=3; extra == "dev"
|
|
35
|
+
Requires-Dist: chex>=0.1.8; extra == "dev"
|
|
36
|
+
Provides-Extra: docs
|
|
37
|
+
Requires-Dist: mkdocs>=1.5; extra == "docs"
|
|
38
|
+
Requires-Dist: mkdocs-material>=9; extra == "docs"
|
|
39
|
+
Dynamic: license-file
|
|
40
|
+
|
|
41
|
+
<h1 align="center">Somax</h1>
|
|
42
|
+
|
|
43
|
+
<p align="center">
|
|
44
|
+
<img src="assets/somax_logo_mini.png" alt="Somax logo" width="250px"/>
|
|
45
|
+
</p>
|
|
46
|
+
|
|
47
|
+
<p align="center">
|
|
48
|
+
Composable Second-Order Optimization for JAX and Optax.
|
|
49
|
+
</p>
|
|
50
|
+
|
|
51
|
+
<p align="center">
|
|
52
|
+
A small research-engineering library for curvature-aware training:
|
|
53
|
+
modular, matrix-free, and explicit about the moving parts.
|
|
54
|
+
</p>
|
|
55
|
+
|
|
56
|
+
---
|
|
57
|
+
|
|
58
|
+
Somax is a JAX-native library for building and running second-order optimization methods from explicit components.
|
|
59
|
+
|
|
60
|
+
Rather than treating an optimizer as a monolithic object, Somax factors a step into swappable pieces:
|
|
61
|
+
- curvature operator
|
|
62
|
+
- solver
|
|
63
|
+
- damping policy
|
|
64
|
+
- optional preconditioner
|
|
65
|
+
- update transform
|
|
66
|
+
- optional telemetry and control signals
|
|
67
|
+
|
|
68
|
+
That decomposition is the point.
|
|
69
|
+
|
|
70
|
+
Somax is built for users who want a clean second-order stack in JAX without hiding the execution model.
|
|
71
|
+
It aims to make curvature-aware training easier to inspect, compare, and extend.
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
> The catfish in the logo is a small nod to *som*, the Belarusian word for catfish.
|
|
75
|
+
> A quiet bottom-dweller, but not a first-order creature.
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
## Why Somax
|
|
80
|
+
|
|
81
|
+
- **Composable**: build methods from curvature, solver, damping, preconditioner, and update components.
|
|
82
|
+
- **Optax-native**: computed directions are fed through Optax-style update transforms.
|
|
83
|
+
- **Planned execution**: a method is assembled once, planned once, and then executed as a stable step pipeline.
|
|
84
|
+
- **JAX-first**: intended for `jit`-compiled training loops and explicit control over execution.
|
|
85
|
+
- **Multiple solve lanes**: diagonal, parameter-space, and row-space paths are first-class parts of the stack.
|
|
86
|
+
- **Research-friendly**: easy to inspect, compare, ablate, and extend.
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
## Installation
|
|
92
|
+
|
|
93
|
+
Install JAX for your backend first:
|
|
94
|
+
|
|
95
|
+
- JAX installation guide: https://docs.jax.dev/en/latest/installation.html
|
|
96
|
+
|
|
97
|
+
Then install Somax:
|
|
98
|
+
|
|
99
|
+
```bash
|
|
100
|
+
pip install python-somax
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
For local development:
|
|
104
|
+
|
|
105
|
+
```bash
|
|
106
|
+
git clone https://github.com/cor3bit/somax.git
|
|
107
|
+
cd somax
|
|
108
|
+
pip install -e ".[dev]"
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
Optional:
|
|
112
|
+
- install `lineax` only if you want to use CG backends with `backend="lineax"`.
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
## Quickstart
|
|
117
|
+
|
|
118
|
+
```python
|
|
119
|
+
import jax
|
|
120
|
+
import jax.numpy as jnp
|
|
121
|
+
import somax
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def predict_fn(params, x):
|
|
125
|
+
h = jnp.tanh(x @ params["W1"] + params["b1"])
|
|
126
|
+
return h @ params["W2"] + params["b2"]
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
rng = jax.random.PRNGKey(0)
|
|
130
|
+
k1, k2, k3, k4 = jax.random.split(rng, 4)
|
|
131
|
+
|
|
132
|
+
params = {
|
|
133
|
+
"W1": 0.1 * jax.random.normal(k1, (16, 32)),
|
|
134
|
+
"b1": jnp.zeros((32,)),
|
|
135
|
+
"W2": 0.1 * jax.random.normal(k2, (32, 10)),
|
|
136
|
+
"b2": jnp.zeros((10,)),
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
batch = {
|
|
140
|
+
"x": jax.random.normal(k3, (64, 16)),
|
|
141
|
+
"y": jax.random.randint(k4, (64,), 0, 10),
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
method = somax.sgn_ce(
|
|
145
|
+
predict_fn=predict_fn,
|
|
146
|
+
lam0=1e-2,
|
|
147
|
+
tol=1e-4,
|
|
148
|
+
maxiter=20,
|
|
149
|
+
learning_rate=1e-1,
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
state = method.init(params)
|
|
153
|
+
|
|
154
|
+
@jax.jit
|
|
155
|
+
def train_step(params, state, rng):
|
|
156
|
+
params, state, info = method.step(params, batch, state, rng)
|
|
157
|
+
return params, state, info
|
|
158
|
+
|
|
159
|
+
for step in range(10):
|
|
160
|
+
params, state, info = train_step(params, state, jax.random.fold_in(rng, step))
|
|
161
|
+
```
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
## Citation
|
|
167
|
+
|
|
168
|
+
If Somax is useful in your academic work, please cite:
|
|
169
|
+
|
|
170
|
+
**Second-Order, First-Class: A Composable Stack for Curvature-Aware Training**
|
|
171
|
+
Mikalai Korbit and Mario Zanon
|
|
172
|
+
https://arxiv.org/abs/2603.25976
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
```bibtex
|
|
176
|
+
@article{korbit2026second,
|
|
177
|
+
title={Second-Order, First-Class: A Composable Stack for Curvature-Aware Training},
|
|
178
|
+
author={Korbit, Mikalai and Zanon, Mario},
|
|
179
|
+
journal={arXiv preprint arXiv:2603.25976},
|
|
180
|
+
year={2026}
|
|
181
|
+
}
|
|
182
|
+
```
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
## Related projects
|
|
186
|
+
|
|
187
|
+
**Optimization in JAX**
|
|
188
|
+
[Optax](https://github.com/google-deepmind/optax): first-order gradient (e.g., SGD, Adam) optimizers.
|
|
189
|
+
[JAXopt](https://github.com/google/jaxopt): deterministic second-order methods (e.g., Gauss-Newton, Levenberg-Marquardt),
|
|
190
|
+
stochastic first-order methods (PolyakSGD, ArmijoSGD).
|
|
191
|
+
|
|
192
|
+
**Awesome Projects**
|
|
193
|
+
[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of various JAX projects.
|
|
194
|
+
[Awesome SOMs](https://github.com/cor3bit/awesome-soms): a list
|
|
195
|
+
of resources for second-order optimization methods in machine learning.
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
## License
|
|
201
|
+
|
|
202
|
+
Apache-2.0
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
src/python_somax.egg-info/PKG-INFO
|
|
5
|
+
src/python_somax.egg-info/SOURCES.txt
|
|
6
|
+
src/python_somax.egg-info/dependency_links.txt
|
|
7
|
+
src/python_somax.egg-info/requires.txt
|
|
8
|
+
src/python_somax.egg-info/top_level.txt
|
|
9
|
+
src/somax/__init__.py
|
|
10
|
+
src/somax/assembler.py
|
|
11
|
+
src/somax/executor.py
|
|
12
|
+
src/somax/metrics.py
|
|
13
|
+
src/somax/optax.py
|
|
14
|
+
src/somax/planner.py
|
|
15
|
+
src/somax/presets.py
|
|
16
|
+
src/somax/specs.py
|
|
17
|
+
src/somax/types.py
|
|
18
|
+
src/somax/utils.py
|
|
19
|
+
src/somax/curvature/__init__.py
|
|
20
|
+
src/somax/curvature/base.py
|
|
21
|
+
src/somax/curvature/ggn_ce.py
|
|
22
|
+
src/somax/curvature/ggn_mse.py
|
|
23
|
+
src/somax/curvature/hessian.py
|
|
24
|
+
src/somax/curvature/with_estimators.py
|
|
25
|
+
src/somax/damping/__init__.py
|
|
26
|
+
src/somax/damping/base.py
|
|
27
|
+
src/somax/damping/constant.py
|
|
28
|
+
src/somax/damping/trust_region.py
|
|
29
|
+
src/somax/estimators/__init__.py
|
|
30
|
+
src/somax/estimators/base.py
|
|
31
|
+
src/somax/estimators/gnb_ce.py
|
|
32
|
+
src/somax/estimators/hutchinson.py
|
|
33
|
+
src/somax/methods/__init__.py
|
|
34
|
+
src/somax/methods/adahessian.py
|
|
35
|
+
src/somax/methods/direct_methods.py
|
|
36
|
+
src/somax/methods/egn.py
|
|
37
|
+
src/somax/methods/newton_cg.py
|
|
38
|
+
src/somax/methods/sgn.py
|
|
39
|
+
src/somax/methods/sophia_g.py
|
|
40
|
+
src/somax/methods/sophia_h.py
|
|
41
|
+
src/somax/preconditioners/__init__.py
|
|
42
|
+
src/somax/preconditioners/base.py
|
|
43
|
+
src/somax/preconditioners/diag_direct.py
|
|
44
|
+
src/somax/preconditioners/diag_ema.py
|
|
45
|
+
src/somax/preconditioners/identity.py
|
|
46
|
+
src/somax/solvers/__init__.py
|
|
47
|
+
src/somax/solvers/base.py
|
|
48
|
+
src/somax/solvers/cg.py
|
|
49
|
+
src/somax/solvers/direct.py
|
|
50
|
+
src/somax/solvers/identity.py
|
|
51
|
+
src/somax/solvers/row_cg.py
|
|
52
|
+
src/somax/solvers/row_cholesky.py
|
|
53
|
+
src/somax/solvers/row_common.py
|