python-somax 0.0.1__tar.gz → 1.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. python_somax-1.0.1/PKG-INFO +202 -0
  2. python_somax-1.0.1/README.md +162 -0
  3. python_somax-1.0.1/pyproject.toml +93 -0
  4. python_somax-1.0.1/src/python_somax.egg-info/PKG-INFO +202 -0
  5. python_somax-1.0.1/src/python_somax.egg-info/SOURCES.txt +53 -0
  6. python_somax-1.0.1/src/python_somax.egg-info/requires.txt +18 -0
  7. python_somax-1.0.1/src/somax/__init__.py +58 -0
  8. python_somax-1.0.1/src/somax/assembler.py +151 -0
  9. python_somax-1.0.1/src/somax/curvature/__init__.py +14 -0
  10. python_somax-1.0.1/src/somax/curvature/base.py +56 -0
  11. python_somax-1.0.1/src/somax/curvature/ggn_ce.py +183 -0
  12. python_somax-1.0.1/src/somax/curvature/ggn_mse.py +104 -0
  13. python_somax-1.0.1/src/somax/curvature/hessian.py +63 -0
  14. python_somax-1.0.1/src/somax/curvature/with_estimators.py +45 -0
  15. python_somax-1.0.1/src/somax/damping/__init__.py +12 -0
  16. python_somax-1.0.1/src/somax/damping/base.py +33 -0
  17. python_somax-1.0.1/src/somax/damping/constant.py +28 -0
  18. python_somax-1.0.1/src/somax/damping/trust_region.py +115 -0
  19. python_somax-1.0.1/src/somax/estimators/__init__.py +18 -0
  20. python_somax-1.0.1/src/somax/estimators/base.py +53 -0
  21. python_somax-1.0.1/src/somax/estimators/gnb_ce.py +67 -0
  22. python_somax-1.0.1/src/somax/estimators/hutchinson.py +121 -0
  23. python_somax-1.0.1/src/somax/executor.py +528 -0
  24. python_somax-1.0.1/src/somax/methods/__init__.py +30 -0
  25. python_somax-1.0.1/src/somax/methods/adahessian.py +83 -0
  26. python_somax-1.0.1/src/somax/methods/direct_methods.py +151 -0
  27. python_somax-1.0.1/src/somax/methods/egn.py +233 -0
  28. python_somax-1.0.1/src/somax/methods/newton_cg.py +100 -0
  29. python_somax-1.0.1/src/somax/methods/sgn.py +194 -0
  30. python_somax-1.0.1/src/somax/methods/sophia_g.py +74 -0
  31. python_somax-1.0.1/src/somax/methods/sophia_h.py +76 -0
  32. python_somax-1.0.1/src/somax/metrics.py +79 -0
  33. python_somax-1.0.1/src/somax/optax.py +61 -0
  34. python_somax-1.0.1/src/somax/planner.py +197 -0
  35. python_somax-1.0.1/src/somax/preconditioners/__init__.py +15 -0
  36. python_somax-1.0.1/src/somax/preconditioners/base.py +42 -0
  37. python_somax-1.0.1/src/somax/preconditioners/diag_direct.py +86 -0
  38. python_somax-1.0.1/src/somax/preconditioners/diag_ema.py +183 -0
  39. python_somax-1.0.1/src/somax/preconditioners/identity.py +35 -0
  40. python_somax-1.0.1/src/somax/presets.py +60 -0
  41. python_somax-1.0.1/src/somax/solvers/__init__.py +16 -0
  42. python_somax-1.0.1/src/somax/solvers/base.py +49 -0
  43. python_somax-1.0.1/src/somax/solvers/cg.py +554 -0
  44. python_somax-1.0.1/src/somax/solvers/direct.py +55 -0
  45. python_somax-1.0.1/src/somax/solvers/identity.py +36 -0
  46. python_somax-1.0.1/src/somax/solvers/row_cg.py +89 -0
  47. python_somax-1.0.1/src/somax/solvers/row_cholesky.py +78 -0
  48. python_somax-1.0.1/src/somax/solvers/row_common.py +71 -0
  49. python_somax-1.0.1/src/somax/specs.py +246 -0
  50. python_somax-1.0.1/src/somax/types.py +53 -0
  51. python_somax-1.0.1/src/somax/utils.py +197 -0
  52. python_somax-0.0.1/PKG-INFO +0 -131
  53. python_somax-0.0.1/README.md +0 -112
  54. python_somax-0.0.1/setup.py +0 -40
  55. python_somax-0.0.1/src/python_somax.egg-info/PKG-INFO +0 -131
  56. python_somax-0.0.1/src/python_somax.egg-info/SOURCES.txt +0 -30
  57. python_somax-0.0.1/src/python_somax.egg-info/requires.txt +0 -1
  58. python_somax-0.0.1/src/somax/__init__.py +0 -19
  59. python_somax-0.0.1/src/somax/diagonal/__init__.py +0 -0
  60. python_somax-0.0.1/src/somax/diagonal/adahessian.py +0 -192
  61. python_somax-0.0.1/src/somax/diagonal/sophia_g.py +0 -208
  62. python_somax-0.0.1/src/somax/diagonal/sophia_h.py +0 -191
  63. python_somax-0.0.1/src/somax/gn/__init__.py +0 -0
  64. python_somax-0.0.1/src/somax/gn/egn.py +0 -567
  65. python_somax-0.0.1/src/somax/gn/sgn.py +0 -237
  66. python_somax-0.0.1/src/somax/hf/__init__.py +0 -0
  67. python_somax-0.0.1/src/somax/hf/newton_cg.py +0 -339
  68. python_somax-0.0.1/src/somax/ng/__init__.py +0 -0
  69. python_somax-0.0.1/src/somax/ng/swm_ng.py +0 -411
  70. python_somax-0.0.1/src/somax/qn/__init__.py +0 -0
  71. python_somax-0.0.1/src/somax/qn/sqn.py +0 -311
  72. python_somax-0.0.1/tests/test_adahessian.py +0 -134
  73. python_somax-0.0.1/tests/test_egn.py +0 -306
  74. python_somax-0.0.1/tests/test_newton_cg.py +0 -254
  75. python_somax-0.0.1/tests/test_sgn.py +0 -273
  76. python_somax-0.0.1/tests/test_sophia_g.py +0 -72
  77. python_somax-0.0.1/tests/test_sophia_h.py +0 -130
  78. python_somax-0.0.1/tests/test_sqn.py +0 -126
  79. python_somax-0.0.1/tests/test_swm_ng.py +0 -138
  80. {python_somax-0.0.1 → python_somax-1.0.1}/LICENSE +0 -0
  81. {python_somax-0.0.1 → python_somax-1.0.1}/setup.cfg +0 -0
  82. {python_somax-0.0.1 → python_somax-1.0.1}/src/python_somax.egg-info/dependency_links.txt +0 -0
  83. {python_somax-0.0.1 → python_somax-1.0.1}/src/python_somax.egg-info/top_level.txt +0 -0
@@ -0,0 +1,202 @@
1
+ Metadata-Version: 2.4
2
+ Name: python-somax
3
+ Version: 1.0.1
4
+ Summary: Composable Second-Order Optimization for JAX and Optax.
5
+ Author: Nick Korbit
6
+ License-Expression: Apache-2.0
7
+ Project-URL: Homepage, https://github.com/cor3bit/somax
8
+ Project-URL: Repository, https://github.com/cor3bit/somax
9
+ Project-URL: Issues, https://github.com/cor3bit/somax/issues
10
+ Project-URL: Releases, https://github.com/cor3bit/somax/releases
11
+ Project-URL: Paper, https://arxiv.org/abs/2603.25976
12
+ Keywords: jax,optax,optimization,second-order,gauss-newton,hessian,machine-learning
13
+ Classifier: Development Status :: 3 - Alpha
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
20
+ Requires-Python: >=3.11
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Requires-Dist: jax>=0.4.20
24
+ Requires-Dist: optax>=0.1.7
25
+ Requires-Dist: flax>=0.7.5
26
+ Requires-Dist: numpy>=1.23
27
+ Provides-Extra: dev
28
+ Requires-Dist: build>=1.2; extra == "dev"
29
+ Requires-Dist: twine>=5; extra == "dev"
30
+ Requires-Dist: pytest>=8; extra == "dev"
31
+ Requires-Dist: pytest-cov>=5; extra == "dev"
32
+ Requires-Dist: ruff>=0.5; extra == "dev"
33
+ Requires-Dist: pyright>=1.1; extra == "dev"
34
+ Requires-Dist: pre-commit>=3; extra == "dev"
35
+ Requires-Dist: chex>=0.1.8; extra == "dev"
36
+ Provides-Extra: docs
37
+ Requires-Dist: mkdocs>=1.5; extra == "docs"
38
+ Requires-Dist: mkdocs-material>=9; extra == "docs"
39
+ Dynamic: license-file
40
+
41
+ <h1 align="center">Somax</h1>
42
+
43
+ <p align="center">
44
+ <img src="assets/somax_logo_mini.png" alt="Somax logo" width="250px"/>
45
+ </p>
46
+
47
+ <p align="center">
48
+ Composable Second-Order Optimization for JAX and Optax.
49
+ </p>
50
+
51
+ <p align="center">
52
+ A small research-engineering library for curvature-aware training:
53
+ modular, matrix-free, and explicit about the moving parts.
54
+ </p>
55
+
56
+ ---
57
+
58
+ Somax is a JAX-native library for building and running second-order optimization methods from explicit components.
59
+
60
+ Rather than treating an optimizer as a monolithic object, Somax factors a step into swappable pieces:
61
+ - curvature operator
62
+ - solver
63
+ - damping policy
64
+ - optional preconditioner
65
+ - update transform
66
+ - optional telemetry and control signals
67
+
68
+ That decomposition is the point.
69
+
70
+ Somax is built for users who want a clean second-order stack in JAX without hiding the execution model.
71
+ It aims to make curvature-aware training easier to inspect, compare, and extend.
72
+
73
+
74
+ > The catfish in the logo is a small nod to *som*, the Belarusian word for catfish.
75
+ > A quiet bottom-dweller, but not a first-order creature.
76
+
77
+
78
+
79
+ ## Why Somax
80
+
81
+ - **Composable**: build methods from curvature, solver, damping, preconditioner, and update components.
82
+ - **Optax-native**: computed directions are fed through Optax-style update transforms.
83
+ - **Planned execution**: a method is assembled once, planned once, and then executed as a stable step pipeline.
84
+ - **JAX-first**: intended for `jit`-compiled training loops and explicit control over execution.
85
+ - **Multiple solve lanes**: diagonal, parameter-space, and row-space paths are first-class parts of the stack.
86
+ - **Research-friendly**: easy to inspect, compare, ablate, and extend.
87
+
88
+
89
+
90
+
91
+ ## Installation
92
+
93
+ Install JAX for your backend first:
94
+
95
+ - JAX installation guide: https://docs.jax.dev/en/latest/installation.html
96
+
97
+ Then install Somax:
98
+
99
+ ```bash
100
+ pip install python-somax
101
+ ```
102
+
103
+ For local development:
104
+
105
+ ```bash
106
+ git clone https://github.com/cor3bit/somax.git
107
+ cd somax
108
+ pip install -e ".[dev]"
109
+ ```
110
+
111
+ Optional:
112
+ - install `lineax` only if you want to use CG backends with `backend="lineax"`.
113
+
114
+
115
+
116
+ ## Quickstart
117
+
118
+ ```python
119
+ import jax
120
+ import jax.numpy as jnp
121
+ import somax
122
+
123
+
124
+ def predict_fn(params, x):
125
+ h = jnp.tanh(x @ params["W1"] + params["b1"])
126
+ return h @ params["W2"] + params["b2"]
127
+
128
+
129
+ rng = jax.random.PRNGKey(0)
130
+ k1, k2, k3, k4 = jax.random.split(rng, 4)
131
+
132
+ params = {
133
+ "W1": 0.1 * jax.random.normal(k1, (16, 32)),
134
+ "b1": jnp.zeros((32,)),
135
+ "W2": 0.1 * jax.random.normal(k2, (32, 10)),
136
+ "b2": jnp.zeros((10,)),
137
+ }
138
+
139
+ batch = {
140
+ "x": jax.random.normal(k3, (64, 16)),
141
+ "y": jax.random.randint(k4, (64,), 0, 10),
142
+ }
143
+
144
+ method = somax.sgn_ce(
145
+ predict_fn=predict_fn,
146
+ lam0=1e-2,
147
+ tol=1e-4,
148
+ maxiter=20,
149
+ learning_rate=1e-1,
150
+ )
151
+
152
+ state = method.init(params)
153
+
154
+ @jax.jit
155
+ def train_step(params, state, rng):
156
+ params, state, info = method.step(params, batch, state, rng)
157
+ return params, state, info
158
+
159
+ for step in range(10):
160
+ params, state, info = train_step(params, state, jax.random.fold_in(rng, step))
161
+ ```
162
+
163
+
164
+
165
+
166
+ ## Citation
167
+
168
+ If Somax is useful in your academic work, please cite:
169
+
170
+ **Second-Order, First-Class: A Composable Stack for Curvature-Aware Training**
171
+ Mikalai Korbit and Mario Zanon
172
+ https://arxiv.org/abs/2603.25976
173
+
174
+
175
+ ```bibtex
176
+ @article{korbit2026second,
177
+ title={Second-Order, First-Class: A Composable Stack for Curvature-Aware Training},
178
+ author={Korbit, Mikalai and Zanon, Mario},
179
+ journal={arXiv preprint arXiv:2603.25976},
180
+ year={2026}
181
+ }
182
+ ```
183
+
184
+
185
+ ## Related projects
186
+
187
+ **Optimization in JAX**
188
+ [Optax](https://github.com/google-deepmind/optax): first-order gradient optimizers (e.g., SGD, Adam).
189
+ [JAXopt](https://github.com/google/jaxopt): deterministic second-order methods (e.g., Gauss-Newton, Levenberg-
190
+ Marquardt) and stochastic first-order methods such as PolyakSGD and ArmijoSGD.
191
+
192
+ **Awesome Projects**
193
+ [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of various JAX projects.
194
+ [Awesome SOMs](https://github.com/cor3bit/awesome-soms): a list
195
+ of resources for second-order optimization methods in machine learning.
196
+
197
+
198
+
199
+
200
+ ## License
201
+
202
+ Apache-2.0
@@ -0,0 +1,162 @@
1
+ <h1 align="center">Somax</h1>
2
+
3
+ <p align="center">
4
+ <img src="assets/somax_logo_mini.png" alt="Somax logo" width="250px"/>
5
+ </p>
6
+
7
+ <p align="center">
8
+ Composable Second-Order Optimization for JAX and Optax.
9
+ </p>
10
+
11
+ <p align="center">
12
+ A small research-engineering library for curvature-aware training:
13
+ modular, matrix-free, and explicit about the moving parts.
14
+ </p>
15
+
16
+ ---
17
+
18
+ Somax is a JAX-native library for building and running second-order optimization methods from explicit components.
19
+
20
+ Rather than treating an optimizer as a monolithic object, Somax factors a step into swappable pieces:
21
+ - curvature operator
22
+ - solver
23
+ - damping policy
24
+ - optional preconditioner
25
+ - update transform
26
+ - optional telemetry and control signals
27
+
28
+ That decomposition is the point.
29
+
30
+ Somax is built for users who want a clean second-order stack in JAX without hiding the execution model.
31
+ It aims to make curvature-aware training easier to inspect, compare, and extend.
32
+
33
+
34
+ > The catfish in the logo is a small nod to *som*, the Belarusian word for catfish.
35
+ > A quiet bottom-dweller, but not a first-order creature.
36
+
37
+
38
+
39
+ ## Why Somax
40
+
41
+ - **Composable**: build methods from curvature, solver, damping, preconditioner, and update components.
42
+ - **Optax-native**: computed directions are fed through Optax-style update transforms.
43
+ - **Planned execution**: a method is assembled once, planned once, and then executed as a stable step pipeline.
44
+ - **JAX-first**: intended for `jit`-compiled training loops and explicit control over execution.
45
+ - **Multiple solve lanes**: diagonal, parameter-space, and row-space paths are first-class parts of the stack.
46
+ - **Research-friendly**: easy to inspect, compare, ablate, and extend.
47
+
48
+
49
+
50
+
51
+ ## Installation
52
+
53
+ Install JAX for your backend first:
54
+
55
+ - JAX installation guide: https://docs.jax.dev/en/latest/installation.html
56
+
57
+ Then install Somax:
58
+
59
+ ```bash
60
+ pip install python-somax
61
+ ```
62
+
63
+ For local development:
64
+
65
+ ```bash
66
+ git clone https://github.com/cor3bit/somax.git
67
+ cd somax
68
+ pip install -e ".[dev]"
69
+ ```
70
+
71
+ Optional:
72
+ - install `lineax` only if you want to use CG backends with `backend="lineax"`.
73
+
74
+
75
+
76
+ ## Quickstart
77
+
78
+ ```python
79
+ import jax
80
+ import jax.numpy as jnp
81
+ import somax
82
+
83
+
84
+ def predict_fn(params, x):
85
+ h = jnp.tanh(x @ params["W1"] + params["b1"])
86
+ return h @ params["W2"] + params["b2"]
87
+
88
+
89
+ rng = jax.random.PRNGKey(0)
90
+ k1, k2, k3, k4 = jax.random.split(rng, 4)
91
+
92
+ params = {
93
+ "W1": 0.1 * jax.random.normal(k1, (16, 32)),
94
+ "b1": jnp.zeros((32,)),
95
+ "W2": 0.1 * jax.random.normal(k2, (32, 10)),
96
+ "b2": jnp.zeros((10,)),
97
+ }
98
+
99
+ batch = {
100
+ "x": jax.random.normal(k3, (64, 16)),
101
+ "y": jax.random.randint(k4, (64,), 0, 10),
102
+ }
103
+
104
+ method = somax.sgn_ce(
105
+ predict_fn=predict_fn,
106
+ lam0=1e-2,
107
+ tol=1e-4,
108
+ maxiter=20,
109
+ learning_rate=1e-1,
110
+ )
111
+
112
+ state = method.init(params)
113
+
114
+ @jax.jit
115
+ def train_step(params, state, rng):
116
+ params, state, info = method.step(params, batch, state, rng)
117
+ return params, state, info
118
+
119
+ for step in range(10):
120
+ params, state, info = train_step(params, state, jax.random.fold_in(rng, step))
121
+ ```
122
+
123
+
124
+
125
+
126
+ ## Citation
127
+
128
+ If Somax is useful in your academic work, please cite:
129
+
130
+ **Second-Order, First-Class: A Composable Stack for Curvature-Aware Training**
131
+ Mikalai Korbit and Mario Zanon
132
+ https://arxiv.org/abs/2603.25976
133
+
134
+
135
+ ```bibtex
136
+ @article{korbit2026second,
137
+ title={Second-Order, First-Class: A Composable Stack for Curvature-Aware Training},
138
+ author={Korbit, Mikalai and Zanon, Mario},
139
+ journal={arXiv preprint arXiv:2603.25976},
140
+ year={2026}
141
+ }
142
+ ```
143
+
144
+
145
+ ## Related projects
146
+
147
+ **Optimization in JAX**
148
+ [Optax](https://github.com/google-deepmind/optax): first-order gradient optimizers (e.g., SGD, Adam).
149
+ [JAXopt](https://github.com/google/jaxopt): deterministic second-order methods (e.g., Gauss-Newton, Levenberg-
150
+ Marquardt) and stochastic first-order methods such as PolyakSGD and ArmijoSGD.
151
+
152
+ **Awesome Projects**
153
+ [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of various JAX projects.
154
+ [Awesome SOMs](https://github.com/cor3bit/awesome-soms): a list
155
+ of resources for second-order optimization methods in machine learning.
156
+
157
+
158
+
159
+
160
+ ## License
161
+
162
+ Apache-2.0
@@ -0,0 +1,93 @@
1
+ [build-system]
2
+ requires = ["setuptools>=77"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "python-somax"
7
+ version = "1.0.1"
8
+ description = "Composable Second-Order Optimization for JAX and Optax."
9
+ readme = "README.md"
10
+ requires-python = ">=3.11"
11
+ license = "Apache-2.0"
12
+ license-files = ["LICENSE"]
13
+ authors = [
14
+ { name = "Nick Korbit" }
15
+ ]
16
+ dependencies = [
17
+ "jax>=0.4.20",
18
+ "optax>=0.1.7",
19
+ "flax>=0.7.5",
20
+ "numpy>=1.23",
21
+ ]
22
+ keywords = [
23
+ "jax",
24
+ "optax",
25
+ "optimization",
26
+ "second-order",
27
+ "gauss-newton",
28
+ "hessian",
29
+ "machine-learning",
30
+ ]
31
+ classifiers = [
32
+ "Development Status :: 3 - Alpha",
33
+ "Intended Audience :: Science/Research",
34
+ "Programming Language :: Python :: 3",
35
+ "Programming Language :: Python :: 3.11",
36
+ "Programming Language :: Python :: 3.12",
37
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
38
+ "Topic :: Scientific/Engineering :: Mathematics",
39
+ ]
40
+
41
+ [project.urls]
42
+ Homepage = "https://github.com/cor3bit/somax"
43
+ Repository = "https://github.com/cor3bit/somax"
44
+ Issues = "https://github.com/cor3bit/somax/issues"
45
+ Releases = "https://github.com/cor3bit/somax/releases"
46
+ #Documentation = "https://github.com/cor3bit/somax/tree/main/docs"
47
+ Paper = "https://arxiv.org/abs/2603.25976"
48
+
49
+ [project.optional-dependencies]
50
+ dev = [
51
+ "build>=1.2",
52
+ "twine>=5",
53
+ "pytest>=8",
54
+ "pytest-cov>=5",
55
+ "ruff>=0.5",
56
+ "pyright>=1.1",
57
+ "pre-commit>=3",
58
+ "chex>=0.1.8",
59
+ ]
60
+ docs = [
61
+ "mkdocs>=1.5",
62
+ "mkdocs-material>=9",
63
+ ]
64
+
65
+ [tool.setuptools.packages.find]
66
+ where = ["src"]
67
+
68
+ [tool.ruff]
69
+ line-length = 100
70
+ target-version = "py311"
71
+
72
+ [tool.ruff.lint]
73
+ select = ["E", "F", "I", "N", "W", "UP", "B", "SIM"]
74
+ ignore = ["E501", "N803", "N806"]
75
+
76
+ [tool.ruff.format]
77
+ quote-style = "double"
78
+ indent-style = "space"
79
+ skip-magic-trailing-comma = false
80
+ line-ending = "lf"
81
+
82
+ [tool.pyright]
83
+ include = ["src"]
84
+ exclude = ["**/node_modules", "**/__pycache__"]
85
+ reportMissingImports = true
86
+ reportMissingTypeStubs = false
87
+ pythonVersion = "3.11"
88
+ typeCheckingMode = "basic"
89
+
90
+ [tool.pytest.ini_options]
91
+ testpaths = ["tests"]
92
+ addopts = "-q --strict-markers"
93
+ xfail_strict = true
@@ -0,0 +1,202 @@
1
+ Metadata-Version: 2.4
2
+ Name: python-somax
3
+ Version: 1.0.1
4
+ Summary: Composable Second-Order Optimization for JAX and Optax.
5
+ Author: Nick Korbit
6
+ License-Expression: Apache-2.0
7
+ Project-URL: Homepage, https://github.com/cor3bit/somax
8
+ Project-URL: Repository, https://github.com/cor3bit/somax
9
+ Project-URL: Issues, https://github.com/cor3bit/somax/issues
10
+ Project-URL: Releases, https://github.com/cor3bit/somax/releases
11
+ Project-URL: Paper, https://arxiv.org/abs/2603.25976
12
+ Keywords: jax,optax,optimization,second-order,gauss-newton,hessian,machine-learning
13
+ Classifier: Development Status :: 3 - Alpha
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
20
+ Requires-Python: >=3.11
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Requires-Dist: jax>=0.4.20
24
+ Requires-Dist: optax>=0.1.7
25
+ Requires-Dist: flax>=0.7.5
26
+ Requires-Dist: numpy>=1.23
27
+ Provides-Extra: dev
28
+ Requires-Dist: build>=1.2; extra == "dev"
29
+ Requires-Dist: twine>=5; extra == "dev"
30
+ Requires-Dist: pytest>=8; extra == "dev"
31
+ Requires-Dist: pytest-cov>=5; extra == "dev"
32
+ Requires-Dist: ruff>=0.5; extra == "dev"
33
+ Requires-Dist: pyright>=1.1; extra == "dev"
34
+ Requires-Dist: pre-commit>=3; extra == "dev"
35
+ Requires-Dist: chex>=0.1.8; extra == "dev"
36
+ Provides-Extra: docs
37
+ Requires-Dist: mkdocs>=1.5; extra == "docs"
38
+ Requires-Dist: mkdocs-material>=9; extra == "docs"
39
+ Dynamic: license-file
40
+
41
+ <h1 align="center">Somax</h1>
42
+
43
+ <p align="center">
44
+ <img src="assets/somax_logo_mini.png" alt="Somax logo" width="250px"/>
45
+ </p>
46
+
47
+ <p align="center">
48
+ Composable Second-Order Optimization for JAX and Optax.
49
+ </p>
50
+
51
+ <p align="center">
52
+ A small research-engineering library for curvature-aware training:
53
+ modular, matrix-free, and explicit about the moving parts.
54
+ </p>
55
+
56
+ ---
57
+
58
+ Somax is a JAX-native library for building and running second-order optimization methods from explicit components.
59
+
60
+ Rather than treating an optimizer as a monolithic object, Somax factors a step into swappable pieces:
61
+ - curvature operator
62
+ - solver
63
+ - damping policy
64
+ - optional preconditioner
65
+ - update transform
66
+ - optional telemetry and control signals
67
+
68
+ That decomposition is the point.
69
+
70
+ Somax is built for users who want a clean second-order stack in JAX without hiding the execution model.
71
+ It aims to make curvature-aware training easier to inspect, compare, and extend.
72
+
73
+
74
+ > The catfish in the logo is a small nod to *som*, the Belarusian word for catfish.
75
+ > A quiet bottom-dweller, but not a first-order creature.
76
+
77
+
78
+
79
+ ## Why Somax
80
+
81
+ - **Composable**: build methods from curvature, solver, damping, preconditioner, and update components.
82
+ - **Optax-native**: computed directions are fed through Optax-style update transforms.
83
+ - **Planned execution**: a method is assembled once, planned once, and then executed as a stable step pipeline.
84
+ - **JAX-first**: intended for `jit`-compiled training loops and explicit control over execution.
85
+ - **Multiple solve lanes**: diagonal, parameter-space, and row-space paths are first-class parts of the stack.
86
+ - **Research-friendly**: easy to inspect, compare, ablate, and extend.
87
+
88
+
89
+
90
+
91
+ ## Installation
92
+
93
+ Install JAX for your backend first:
94
+
95
+ - JAX installation guide: https://docs.jax.dev/en/latest/installation.html
96
+
97
+ Then install Somax:
98
+
99
+ ```bash
100
+ pip install python-somax
101
+ ```
102
+
103
+ For local development:
104
+
105
+ ```bash
106
+ git clone https://github.com/cor3bit/somax.git
107
+ cd somax
108
+ pip install -e ".[dev]"
109
+ ```
110
+
111
+ Optional:
112
+ - install `lineax` only if you want to use CG backends with `backend="lineax"`.
113
+
114
+
115
+
116
+ ## Quickstart
117
+
118
+ ```python
119
+ import jax
120
+ import jax.numpy as jnp
121
+ import somax
122
+
123
+
124
+ def predict_fn(params, x):
125
+ h = jnp.tanh(x @ params["W1"] + params["b1"])
126
+ return h @ params["W2"] + params["b2"]
127
+
128
+
129
+ rng = jax.random.PRNGKey(0)
130
+ k1, k2, k3, k4 = jax.random.split(rng, 4)
131
+
132
+ params = {
133
+ "W1": 0.1 * jax.random.normal(k1, (16, 32)),
134
+ "b1": jnp.zeros((32,)),
135
+ "W2": 0.1 * jax.random.normal(k2, (32, 10)),
136
+ "b2": jnp.zeros((10,)),
137
+ }
138
+
139
+ batch = {
140
+ "x": jax.random.normal(k3, (64, 16)),
141
+ "y": jax.random.randint(k4, (64,), 0, 10),
142
+ }
143
+
144
+ method = somax.sgn_ce(
145
+ predict_fn=predict_fn,
146
+ lam0=1e-2,
147
+ tol=1e-4,
148
+ maxiter=20,
149
+ learning_rate=1e-1,
150
+ )
151
+
152
+ state = method.init(params)
153
+
154
+ @jax.jit
155
+ def train_step(params, state, rng):
156
+ params, state, info = method.step(params, batch, state, rng)
157
+ return params, state, info
158
+
159
+ for step in range(10):
160
+ params, state, info = train_step(params, state, jax.random.fold_in(rng, step))
161
+ ```
162
+
163
+
164
+
165
+
166
+ ## Citation
167
+
168
+ If Somax is useful in your academic work, please cite:
169
+
170
+ **Second-Order, First-Class: A Composable Stack for Curvature-Aware Training**
171
+ Mikalai Korbit and Mario Zanon
172
+ https://arxiv.org/abs/2603.25976
173
+
174
+
175
+ ```bibtex
176
+ @article{korbit2026second,
177
+ title={Second-Order, First-Class: A Composable Stack for Curvature-Aware Training},
178
+ author={Korbit, Mikalai and Zanon, Mario},
179
+ journal={arXiv preprint arXiv:2603.25976},
180
+ year={2026}
181
+ }
182
+ ```
183
+
184
+
185
+ ## Related projects
186
+
187
+ **Optimization in JAX**
188
+ [Optax](https://github.com/google-deepmind/optax): first-order gradient optimizers (e.g., SGD, Adam).
189
+ [JAXopt](https://github.com/google/jaxopt): deterministic second-order methods (e.g., Gauss-Newton, Levenberg-
190
+ Marquardt) and stochastic first-order methods such as PolyakSGD and ArmijoSGD.
191
+
192
+ **Awesome Projects**
193
+ [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of various JAX projects.
194
+ [Awesome SOMs](https://github.com/cor3bit/awesome-soms): a list
195
+ of resources for second-order optimization methods in machine learning.
196
+
197
+
198
+
199
+
200
+ ## License
201
+
202
+ Apache-2.0
@@ -0,0 +1,53 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ src/python_somax.egg-info/PKG-INFO
5
+ src/python_somax.egg-info/SOURCES.txt
6
+ src/python_somax.egg-info/dependency_links.txt
7
+ src/python_somax.egg-info/requires.txt
8
+ src/python_somax.egg-info/top_level.txt
9
+ src/somax/__init__.py
10
+ src/somax/assembler.py
11
+ src/somax/executor.py
12
+ src/somax/metrics.py
13
+ src/somax/optax.py
14
+ src/somax/planner.py
15
+ src/somax/presets.py
16
+ src/somax/specs.py
17
+ src/somax/types.py
18
+ src/somax/utils.py
19
+ src/somax/curvature/__init__.py
20
+ src/somax/curvature/base.py
21
+ src/somax/curvature/ggn_ce.py
22
+ src/somax/curvature/ggn_mse.py
23
+ src/somax/curvature/hessian.py
24
+ src/somax/curvature/with_estimators.py
25
+ src/somax/damping/__init__.py
26
+ src/somax/damping/base.py
27
+ src/somax/damping/constant.py
28
+ src/somax/damping/trust_region.py
29
+ src/somax/estimators/__init__.py
30
+ src/somax/estimators/base.py
31
+ src/somax/estimators/gnb_ce.py
32
+ src/somax/estimators/hutchinson.py
33
+ src/somax/methods/__init__.py
34
+ src/somax/methods/adahessian.py
35
+ src/somax/methods/direct_methods.py
36
+ src/somax/methods/egn.py
37
+ src/somax/methods/newton_cg.py
38
+ src/somax/methods/sgn.py
39
+ src/somax/methods/sophia_g.py
40
+ src/somax/methods/sophia_h.py
41
+ src/somax/preconditioners/__init__.py
42
+ src/somax/preconditioners/base.py
43
+ src/somax/preconditioners/diag_direct.py
44
+ src/somax/preconditioners/diag_ema.py
45
+ src/somax/preconditioners/identity.py
46
+ src/somax/solvers/__init__.py
47
+ src/somax/solvers/base.py
48
+ src/somax/solvers/cg.py
49
+ src/somax/solvers/direct.py
50
+ src/somax/solvers/identity.py
51
+ src/somax/solvers/row_cg.py
52
+ src/somax/solvers/row_cholesky.py
53
+ src/somax/solvers/row_common.py