matnets 1.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. matnets-1.1.0/.github/workflows/docs.yml +63 -0
  2. matnets-1.1.0/.github/workflows/pypi.yaml +30 -0
  3. matnets-1.1.0/.github/workflows/testing.yml +82 -0
  4. matnets-1.1.0/.gitignore +15 -0
  5. matnets-1.1.0/CHANGELOG.md +6 -0
  6. matnets-1.1.0/CONTRIBUTING.md +20 -0
  7. matnets-1.1.0/LICENSE +21 -0
  8. matnets-1.1.0/PKG-INFO +54 -0
  9. matnets-1.1.0/README.md +22 -0
  10. matnets-1.1.0/docs/api.md +164 -0
  11. matnets-1.1.0/docs/concepts.md +78 -0
  12. matnets-1.1.0/docs/development.md +42 -0
  13. matnets-1.1.0/docs/examples.md +52 -0
  14. matnets-1.1.0/docs/getting-started.md +59 -0
  15. matnets-1.1.0/docs/index.md +29 -0
  16. matnets-1.1.0/docs/stylesheets/extra.css +89 -0
  17. matnets-1.1.0/examples/1_linear_regression.py +60 -0
  18. matnets-1.1.0/examples/2_mlp_3_hidden_layers.py +72 -0
  19. matnets-1.1.0/examples/3_mlp_7_hidden_layers.py +72 -0
  20. matnets-1.1.0/examples/4_cnn1d.py +78 -0
  21. matnets-1.1.0/examples/5_cnn2d.py +77 -0
  22. matnets-1.1.0/examples/6_lstm.py +73 -0
  23. matnets-1.1.0/examples/7_rnn.py +77 -0
  24. matnets-1.1.0/examples/8_transformer.py +75 -0
  25. matnets-1.1.0/examples/9_kaggle_cnn_mnist copy.ipynb +842 -0
  26. matnets-1.1.0/examples/9_kaggle_cnn_mnist.ipynb +766 -0
  27. matnets-1.1.0/examples/README.md +19 -0
  28. matnets-1.1.0/mkdocs.yml +26 -0
  29. matnets-1.1.0/pyproject.toml +65 -0
  30. matnets-1.1.0/src/matnets/__init__.py +9 -0
  31. matnets-1.1.0/src/matnets/_dense.py +76 -0
  32. matnets-1.1.0/src/matnets/_params.py +66 -0
  33. matnets-1.1.0/src/matnets/lax/__init__.py +16 -0
  34. matnets-1.1.0/src/matnets/lax/attention.py +62 -0
  35. matnets-1.1.0/src/matnets/lax/conv.py +153 -0
  36. matnets-1.1.0/src/matnets/nn/__init__.py +5 -0
  37. matnets-1.1.0/src/matnets/nn/recurrent.py +60 -0
  38. matnets-1.1.0/src/matnets/py.typed +1 -0
  39. matnets-1.1.0/tests/test_primitives.py +318 -0
@@ -0,0 +1,63 @@
1
+ name: Docs
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ paths:
8
+ - "docs/**"
9
+ - "mkdocs.yml"
10
+ - "pyproject.toml"
11
+ - ".github/workflows/docs.yml"
12
+ workflow_dispatch:
13
+
14
+ permissions:
15
+ contents: read
16
+ pages: write
17
+ id-token: write
18
+
19
+ concurrency:
20
+ group: pages
21
+ cancel-in-progress: false
22
+
23
+ jobs:
24
+ build:
25
+ name: Build MkDocs site
26
+ runs-on: ubuntu-latest
27
+
28
+ steps:
29
+ - name: Check out repository
30
+ uses: actions/checkout@v4
31
+
32
+ - name: Set up Python
33
+ uses: actions/setup-python@v5
34
+ with:
35
+ python-version: "3.12"
36
+ cache: pip
37
+
38
+ - name: Configure GitHub Pages
39
+ uses: actions/configure-pages@v5
40
+
41
+ - name: Install docs dependencies
42
+ run: python -m pip install -e ".[docs]"
43
+
44
+ - name: Build site
45
+ run: mkdocs build --strict
46
+
47
+ - name: Upload Pages artifact
48
+ uses: actions/upload-pages-artifact@v3
49
+ with:
50
+ path: site
51
+
52
+ deploy:
53
+ name: Deploy to GitHub Pages
54
+ needs: build
55
+ runs-on: ubuntu-latest
56
+ environment:
57
+ name: github-pages
58
+ url: "${{ steps.deployment.outputs.page_url }}"
59
+
60
+ steps:
61
+ - name: Deploy site
62
+ id: deployment
63
+ uses: actions/deploy-pages@v4
@@ -0,0 +1,30 @@
1
+ # pypi.yaml
2
+ # Configuration for PyPI publishing (example for GitHub Actions)
3
+ # Adjust as needed for your CI/CD provider and secrets management
4
+
5
+ name: Publish Python 🐍 distribution 📦 to PyPI
6
+
7
+ on:
8
+ push:
9
+ tags:
10
+ - 'v*.*.*'
11
+
12
+ jobs:
13
+ build-and-publish:
14
+ runs-on: ubuntu-latest
15
+ steps:
16
+ - uses: actions/checkout@v4
17
+ - name: Set up Python
18
+ uses: actions/setup-python@v5
19
+ with:
20
+ python-version: '3.11'
21
+ - name: Install build dependencies
22
+ run: |
23
+ python -m pip install --upgrade pip
24
+ pip install build
25
+ - name: Build package
26
+ run: python -m build
27
+ - name: Publish package to PyPI
28
+ uses: pypa/gh-action-pypi-publish@release/v1
29
+ with:
30
+ password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,82 @@
1
+ name: CI
2
+
3
+ on:
4
+ pull_request:
5
+ paths-ignore:
6
+ - "docs/**"
7
+ - "mkdocs.yml"
8
+ - ".github/workflows/docs.yml"
9
+ - "*.md"
10
+ push:
11
+ branches:
12
+ - main
13
+ paths-ignore:
14
+ - "docs/**"
15
+ - "mkdocs.yml"
16
+ - ".github/workflows/docs.yml"
17
+ - "*.md"
18
+
19
+ jobs:
20
+ test:
21
+ name: Test Python ${{ matrix.python-version }}
22
+ runs-on: ubuntu-latest
23
+
24
+ strategy:
25
+ fail-fast: false
26
+ matrix:
27
+ python-version:
28
+ - "3.11"
29
+ - "3.12"
30
+ - "3.13"
31
+ - "3.14"
32
+
33
+ steps:
34
+ - name: Check out repository
35
+ uses: actions/checkout@v4
36
+
37
+ - name: Set up Python
38
+ uses: actions/setup-python@v5
39
+ with:
40
+ python-version: ${{ matrix.python-version }}
41
+ cache: pip
42
+
43
+ - name: Install package
44
+ run: python -m pip install -e ".[dev,docs]"
45
+
46
+ - name: Lint
47
+ run: ruff check src/ tests/
48
+
49
+ - name: Type check
50
+ run: mypy src
51
+
52
+ - name: Test
53
+ run: pytest
54
+
55
+ - name: Build docs
56
+ run: mkdocs build --strict
57
+
58
+ package:
59
+ name: Build package
60
+ runs-on: ubuntu-latest
61
+
62
+ steps:
63
+ - name: Check out repository
64
+ uses: actions/checkout@v4
65
+
66
+ - name: Set up Python
67
+ uses: actions/setup-python@v5
68
+ with:
69
+ python-version: "3.13"
70
+ cache: pip
71
+
72
+ - name: Install build tools
73
+ run: python -m pip install build
74
+
75
+ - name: Build distribution
76
+ run: python -m build
77
+
78
+ - name: Upload distribution artifact
79
+ uses: actions/upload-artifact@v4
80
+ with:
81
+ name: python-package
82
+ path: dist/*
@@ -0,0 +1,15 @@
1
+ .vscode/
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.egg-info/
5
+ .pytest_cache/
6
+ .ruff_cache/
7
+ .mypy_cache/
8
+ .venv/
9
+ venv/
10
+ build/
11
+ dist/
12
+ site/
13
+ .coverage
14
+ htmlcov/
15
+ runs/
@@ -0,0 +1,6 @@
1
+ # Changelog
2
+
3
+ ## 0.1.0
4
+
5
+ - Initial project scaffold.
6
+ - Added matrix-neuron, layer, and sequential network primitives.
@@ -0,0 +1,20 @@
1
+ # Contributing
2
+
3
+ Thanks for helping MATNETS grow.
4
+
5
+ ## Local Setup
6
+
7
+ ```bash
8
+ python -m pip install -e ".[dev,docs]"
9
+ ```
10
+
11
+ ## Checks
12
+
13
+ ```bash
14
+ pytest
15
+ ruff check src/ tests/
16
+ mypy src
17
+ ```
18
+
19
+ Keep changes focused, add tests for behavior changes, and update docs when the
20
+ public API changes.
matnets-1.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 MATNETS contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
matnets-1.1.0/PKG-INFO ADDED
@@ -0,0 +1,54 @@
1
+ Metadata-Version: 2.4
2
+ Name: matnets
3
+ Version: 1.1.0
4
+ Summary: A small experimental neural network library where neurons are represented as matrices.
5
+ Project-URL: Homepage, https://github.com/dsainvg/MATNETS
6
+ Project-URL: Documentation, https://github.com/dsainvg/MATNETS/tree/main/docs
7
+ Project-URL: Issues, https://github.com/dsainvg/MATNETS/issues
8
+ Author: dsainvg
9
+ License-Expression: MIT
10
+ License-File: LICENSE
11
+ Keywords: jax,machine-learning,matrix,neural-networks
12
+ Classifier: Development Status :: 2 - Pre-Alpha
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Programming Language :: Python :: 3.13
20
+ Classifier: Programming Language :: Python :: 3.14
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Requires-Python: >=3.11
23
+ Requires-Dist: jax>=0.4
24
+ Provides-Extra: dev
25
+ Requires-Dist: mypy>=1.10; extra == 'dev'
26
+ Requires-Dist: pytest>=8.0; extra == 'dev'
27
+ Requires-Dist: ruff>=0.5; extra == 'dev'
28
+ Provides-Extra: docs
29
+ Requires-Dist: mkdocs-material>=9.5; extra == 'docs'
30
+ Requires-Dist: mkdocs>=1.6; extra == 'docs'
31
+ Description-Content-Type: text/markdown
32
+
33
+ # MATNETS
34
+
35
+ MATNETS is a small JAX library for matrix-neuron neural network experiments.
36
+ Each neuron carries an `n x n` matrix instead of a scalar.
37
+
38
+ The user documentation lives in [`docs/`](docs/index.md):
39
+
40
+ - [`docs/index.md`](docs/index.md): overview
41
+ - [`docs/getting-started.md`](docs/getting-started.md): install and first model
42
+ - [`docs/concepts.md`](docs/concepts.md): matrix-neuron shapes and JAX transforms
43
+ - [`docs/api.md`](docs/api.md): API guide
44
+ - [`docs/examples.md`](docs/examples.md): runnable examples
45
+ - [`docs/development.md`](docs/development.md): local development commands
46
+
47
+ Quick check:
48
+
49
+ ```powershell
50
+ .\.venv\Scripts\python.exe examples\1_linear_regression.py
51
+ .\.venv\Scripts\python.exe -m pytest
52
+ ```
53
+
54
+ MIT license.
@@ -0,0 +1,22 @@
1
+ # MATNETS
2
+
3
+ MATNETS is a small JAX library for matrix-neuron neural network experiments.
4
+ Each neuron carries an `n x n` matrix instead of a scalar.
5
+
6
+ The user documentation lives in [`docs/`](docs/index.md):
7
+
8
+ - [`docs/index.md`](docs/index.md): overview
9
+ - [`docs/getting-started.md`](docs/getting-started.md): install and first model
10
+ - [`docs/concepts.md`](docs/concepts.md): matrix-neuron shapes and JAX transforms
11
+ - [`docs/api.md`](docs/api.md): API guide
12
+ - [`docs/examples.md`](docs/examples.md): runnable examples
13
+ - [`docs/development.md`](docs/development.md): local development commands
14
+
15
+ Quick check:
16
+
17
+ ```powershell
18
+ .\.venv\Scripts\python.exe examples\1_linear_regression.py
19
+ .\.venv\Scripts\python.exe -m pytest
20
+ ```
21
+
22
+ MIT license.
@@ -0,0 +1,164 @@
1
+ # API Guide
2
+
3
+ ## `matnets.MatrixParams`
4
+
5
+ `MatrixParams` stores the weights and bias for matrix primitives:
6
+
7
+ ```python
8
+ from matnets import MatrixParams
9
+ ```
10
+
11
+ Dense parameters use:
12
+
13
+ ```text
14
+ W: (q, p, n, n)
15
+ B: (q, n, n)
16
+ ```
17
+
18
+ `MatrixParams` is registered as a JAX pytree, so it works with `jax.jit`,
19
+ `jax.vmap`, `jax.grad`, and nested dictionaries/lists of parameters.
20
+
21
+ ## `matnets.init`
22
+
23
+ ```python
24
+ params = matnets.init(key, p=2, q=3, n=4)
25
+ ```
26
+
27
+ Creates:
28
+
29
+ ```text
30
+ params.W: (3, 2, 4, 4)
31
+ params.B: (3, 4, 4)
32
+ ```
33
+
34
+ Weights use Glorot-uniform initialization. Bias starts at zero.
35
+
36
+ ## `matnets.dense`
37
+
38
+ ```python
39
+ y = matnets.dense(params, x)
40
+ ```
41
+
42
+ Expected shapes:
43
+
44
+ ```text
45
+ params.W: (q, p, n, n)
46
+ params.B: (q, n, n)
47
+ x: (p, n, n)
48
+ y: (q, n, n)
49
+ ```
50
+
51
+ With activation:
52
+
53
+ ```python
54
+ y = matnets.dense(params, x, activation=jax.nn.relu)
55
+ ```
56
+
57
+ The core operation is:
58
+
59
+ ```python
60
+ jnp.einsum("qpak,pkc->qac", params.W, x) + params.B
61
+ ```
62
+
63
+ ## `matnets.lax.matrix_conv1d`
64
+
65
+ ```python
66
+ from matnets.lax import matrix_conv1d
67
+
68
+ y = matrix_conv1d(params, x, stride=1, padding="VALID")
69
+ ```
70
+
71
+ Expected shapes:
72
+
73
+ ```text
74
+ params.W: (q, p, r, n, n)
75
+ params.B: (q, n, n)
76
+ x: (t, p, n, n)
77
+ y: (t_out, q, n, n)
78
+ ```
79
+
80
+ `r` is the 1D kernel size.
81
+
82
+ ## `matnets.lax.matrix_conv2d`
83
+
84
+ ```python
85
+ from matnets.lax import matrix_conv2d
86
+
87
+ y = matrix_conv2d(params, x, stride=(1, 1), padding="SAME")
88
+ ```
89
+
90
+ Expected shapes:
91
+
92
+ ```text
93
+ params.W: (q, p, h, w, n, n)
94
+ params.B: (q, n, n)
95
+ x: (height, width, p, n, n)
96
+ y: (height_out, width_out, q, n, n)
97
+ ```
98
+
99
+ ## `matnets.lax.matrix_attention`
100
+
101
+ ```python
102
+ from matnets.lax import matrix_attention
103
+
104
+ out = matrix_attention(None, Q, K, V)
105
+ ```
106
+
107
+ Expected token shapes:
108
+
109
+ ```text
110
+ Q: (tokens_q, p, n, n)
111
+ K: (tokens_k, p, n, n)
112
+ V: (tokens_k, p, n, n)
113
+ out: (tokens_q, p, n, n)
114
+ ```
115
+
116
+ By default the score is a scaled Frobenius inner product. You can pass a custom
117
+ `score_fn` that receives one query token and one key token and returns a scalar.
118
+
119
+ If `params` is not `None`, each aggregated output token is projected through
120
+ `matnets.dense(params, token)`.
121
+
122
+ ## `matnets.nn`
123
+
124
+ `matnets.nn` contains recurrent wiring patterns built from `dense`.
125
+
126
+ ```python
127
+ from matnets.nn import rnn_step, lstm_step, gru_step
128
+ ```
129
+
130
+ These functions are intended to be used with `jax.lax.scan`.
131
+
132
+ ### RNN
133
+
134
+ ```python
135
+ carry, outputs = jax.lax.scan(
136
+ lambda h, x_t: rnn_step(params, h, x_t),
137
+ h0,
138
+ sequence,
139
+ )
140
+ ```
141
+
142
+ ### LSTM
143
+
144
+ ```python
145
+ carry, outputs = jax.lax.scan(
146
+ lambda carry, x_t: lstm_step(params, carry, x_t),
147
+ (h0, c0),
148
+ sequence,
149
+ )
150
+ ```
151
+
152
+ LSTM params must contain keys `"i"`, `"f"`, `"g"`, and `"o"`.
153
+
154
+ ### GRU
155
+
156
+ ```python
157
+ carry, outputs = jax.lax.scan(
158
+ lambda h, x_t: gru_step(params, h, x_t),
159
+ h0,
160
+ sequence,
161
+ )
162
+ ```
163
+
164
+ GRU params must contain keys `"z"`, `"r"`, and `"n"`.
@@ -0,0 +1,78 @@
1
+ # Concepts
2
+
3
+ ## Matrix-Neurons
4
+
5
+ A traditional dense layer usually maps vectors:
6
+
7
+ ```text
8
+ x: (p)
9
+ W: (q, p)
10
+ y: (q)
11
+ ```
12
+
13
+ MATNETS maps stacks of square matrices:
14
+
15
+ ```text
16
+ x: (p, n, n)
17
+ W: (q, p, n, n)
18
+ B: (q, n, n)
19
+ y: (q, n, n)
20
+ ```
21
+
22
+ `p` is the input neuron count. `q` is the output neuron count. `n` is the
23
+ side length of the square `n x n` matrix carried by each neuron.
24
+
25
+ ## Dense Primitive
26
+
27
+ The core operation is:
28
+
29
+ ```python
30
+ jnp.einsum("qpak,pkc->qac", W, x) + B
31
+ ```
32
+
33
+ Under the square-matrix contract:
34
+
35
+ ```text
36
+ a == n
37
+ k == n
38
+ c == n
39
+ ```
40
+
41
+ so the output is always `(q, n, n)`.
42
+
43
+ ## Bias
44
+
45
+ The bias is a full matrix:
46
+
47
+ ```text
48
+ B: (q, n, n)
49
+ ```
50
+
51
+ Each output matrix-neuron gets its own complete matrix bias.
52
+
53
+ ## JAX Transforms
54
+
55
+ MATNETS functions are ordinary JAX functions. You can transform them with:
56
+
57
+ ```python
58
+ jax.jit(forward)
59
+ jax.vmap(forward, in_axes=(None, 0))
60
+ jax.grad(loss)
61
+ jax.lax.scan(step, carry, sequence)
62
+ ```
63
+
64
+ The main parallel work is the dense einsum. `vmap` adds batch or token axes
65
+ around it. `scan` handles recurrence over time while each step still uses
66
+ compiled dense contractions.
67
+
68
+ ## Recurrent State
69
+
70
+ RNN, LSTM, and GRU hidden states are stacks of matrices:
71
+
72
+ ```text
73
+ H: (hidden_neurons, n, n)
74
+ C: (hidden_neurons, n, n)
75
+ ```
76
+
77
+ Gates are also matrix-valued, so an LSTM forget gate has one value per matrix
78
+ entry, not just one scalar per neuron.
@@ -0,0 +1,42 @@
1
+ # Development
2
+
3
+ ## Test
4
+
5
+ ```powershell
6
+ .\.venv\Scripts\python.exe -m pytest
7
+ ```
8
+
9
+ ## Run Examples
10
+
11
+ ```powershell
12
+ .\.venv\Scripts\python.exe examples\1_linear_regression.py
13
+ .\.venv\Scripts\python.exe examples\2_mlp_3_hidden_layers.py
14
+ .\.venv\Scripts\python.exe examples\6_lstm.py
15
+ ```
16
+
17
+ ## Syntax Check
18
+
19
+ ```powershell
20
+ .\.venv\Scripts\python.exe -m compileall -q src examples tests
21
+ ```
22
+
23
+ ## Build Documentation
24
+
25
+ ```powershell
26
+ .\.venv\Scripts\python.exe -m pip install -e ".[docs]"
27
+ .\.venv\Scripts\python.exe -m mkdocs build --strict
28
+ ```
29
+
30
+ ## Package Layout
31
+
32
+ ```text
33
+ src/matnets/
34
+ _params.py MatrixParams and init()
35
+ _dense.py core dense matrix primitive
36
+ lax/conv.py convolution-like matrix primitives
37
+ lax/attention.py matrix attention primitive
38
+ nn/recurrent.py RNN, LSTM, and GRU step patterns
39
+ tests/ test suite
40
+ examples/ runnable examples
41
+ docs/ user documentation
42
+ ```
@@ -0,0 +1,52 @@
1
+ # Examples
2
+
3
+ ## Basic Forward Pass
4
+
5
+ ```powershell
6
+ .\.venv\Scripts\python.exe examples\basic_forward.py
7
+ ```
8
+
9
+ This creates one dense layer and applies it to an input shaped `(2, 2, 2)`.
10
+
11
+ ## Five Hidden Layers
12
+
13
+ ```powershell
14
+ .\.venv\Scripts\python.exe examples\five_hidden_net.py
15
+ ```
16
+
17
+ This example defines a small class:
18
+
19
+ ```python
20
+ model = FiveHiddenNet(key, input_neurons=3, hidden_neurons=4, n=2)
21
+ y = jax.jit(model.forward)(model.params, x)
22
+ ```
23
+
24
+ The output shape is `(1, 2, 2)`.
25
+
26
+ ## Architecture Walkthrough
27
+
28
+ ```powershell
29
+ .\.venv\Scripts\python.exe examples\matrix_architectures.py
30
+ ```
31
+
32
+ This file checks shape flow through:
33
+
34
+ - MLP
35
+ - batched MLP with `jax.vmap`
36
+ - gradients with `jax.grad`
37
+ - RNN with `jax.lax.scan`
38
+ - LSTM with `jax.lax.scan`
39
+ - Frobenius attention
40
+ - residual block
41
+
42
+ ## Where Parallelization Happens
43
+
44
+ The dense operation:
45
+
46
+ ```python
47
+ jnp.einsum("qpak,pkc->qac", W, x)
48
+ ```
49
+
50
+ is the main kernel. JAX can compile it with `jit`, map it over batches or token
51
+ sequences with `vmap`, differentiate it with `grad`, and call it repeatedly
52
+ inside `lax.scan`.