matnets 1.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matnets-1.1.0/.github/workflows/docs.yml +63 -0
- matnets-1.1.0/.github/workflows/pypi.yaml +30 -0
- matnets-1.1.0/.github/workflows/testing.yml +82 -0
- matnets-1.1.0/.gitignore +15 -0
- matnets-1.1.0/CHANGELOG.md +6 -0
- matnets-1.1.0/CONTRIBUTING.md +20 -0
- matnets-1.1.0/LICENSE +21 -0
- matnets-1.1.0/PKG-INFO +54 -0
- matnets-1.1.0/README.md +22 -0
- matnets-1.1.0/docs/api.md +164 -0
- matnets-1.1.0/docs/concepts.md +78 -0
- matnets-1.1.0/docs/development.md +42 -0
- matnets-1.1.0/docs/examples.md +52 -0
- matnets-1.1.0/docs/getting-started.md +59 -0
- matnets-1.1.0/docs/index.md +29 -0
- matnets-1.1.0/docs/stylesheets/extra.css +89 -0
- matnets-1.1.0/examples/1_linear_regression.py +60 -0
- matnets-1.1.0/examples/2_mlp_3_hidden_layers.py +72 -0
- matnets-1.1.0/examples/3_mlp_7_hidden_layers.py +72 -0
- matnets-1.1.0/examples/4_cnn1d.py +78 -0
- matnets-1.1.0/examples/5_cnn2d.py +77 -0
- matnets-1.1.0/examples/6_lstm.py +73 -0
- matnets-1.1.0/examples/7_rnn.py +77 -0
- matnets-1.1.0/examples/8_transformer.py +75 -0
- matnets-1.1.0/examples/9_kaggle_cnn_mnist copy.ipynb +842 -0
- matnets-1.1.0/examples/9_kaggle_cnn_mnist.ipynb +766 -0
- matnets-1.1.0/examples/README.md +19 -0
- matnets-1.1.0/mkdocs.yml +26 -0
- matnets-1.1.0/pyproject.toml +65 -0
- matnets-1.1.0/src/matnets/__init__.py +9 -0
- matnets-1.1.0/src/matnets/_dense.py +76 -0
- matnets-1.1.0/src/matnets/_params.py +66 -0
- matnets-1.1.0/src/matnets/lax/__init__.py +16 -0
- matnets-1.1.0/src/matnets/lax/attention.py +62 -0
- matnets-1.1.0/src/matnets/lax/conv.py +153 -0
- matnets-1.1.0/src/matnets/nn/__init__.py +5 -0
- matnets-1.1.0/src/matnets/nn/recurrent.py +60 -0
- matnets-1.1.0/src/matnets/py.typed +1 -0
- matnets-1.1.0/tests/test_primitives.py +318 -0
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
name: Docs
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches:
|
|
6
|
+
- main
|
|
7
|
+
paths:
|
|
8
|
+
- "docs/**"
|
|
9
|
+
- "mkdocs.yml"
|
|
10
|
+
- "pyproject.toml"
|
|
11
|
+
- ".github/workflows/docs.yml"
|
|
12
|
+
workflow_dispatch:
|
|
13
|
+
|
|
14
|
+
permissions:
|
|
15
|
+
contents: read
|
|
16
|
+
pages: write
|
|
17
|
+
id-token: write
|
|
18
|
+
|
|
19
|
+
concurrency:
|
|
20
|
+
group: pages
|
|
21
|
+
cancel-in-progress: false
|
|
22
|
+
|
|
23
|
+
jobs:
|
|
24
|
+
build:
|
|
25
|
+
name: Build MkDocs site
|
|
26
|
+
runs-on: ubuntu-latest
|
|
27
|
+
|
|
28
|
+
steps:
|
|
29
|
+
- name: Check out repository
|
|
30
|
+
uses: actions/checkout@v4
|
|
31
|
+
|
|
32
|
+
- name: Set up Python
|
|
33
|
+
uses: actions/setup-python@v5
|
|
34
|
+
with:
|
|
35
|
+
python-version: "3.12"
|
|
36
|
+
cache: pip
|
|
37
|
+
|
|
38
|
+
- name: Configure GitHub Pages
|
|
39
|
+
uses: actions/configure-pages@v5
|
|
40
|
+
|
|
41
|
+
- name: Install docs dependencies
|
|
42
|
+
run: python -m pip install -e ".[docs]"
|
|
43
|
+
|
|
44
|
+
- name: Build site
|
|
45
|
+
run: mkdocs build --strict
|
|
46
|
+
|
|
47
|
+
- name: Upload Pages artifact
|
|
48
|
+
uses: actions/upload-pages-artifact@v3
|
|
49
|
+
with:
|
|
50
|
+
path: site
|
|
51
|
+
|
|
52
|
+
deploy:
|
|
53
|
+
name: Deploy to GitHub Pages
|
|
54
|
+
needs: build
|
|
55
|
+
runs-on: ubuntu-latest
|
|
56
|
+
environment:
|
|
57
|
+
name: github-pages
|
|
58
|
+
url: "${{ steps.deployment.outputs.page_url }}"
|
|
59
|
+
|
|
60
|
+
steps:
|
|
61
|
+
- name: Deploy site
|
|
62
|
+
id: deployment
|
|
63
|
+
uses: actions/deploy-pages@v4
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# pypi.yaml
|
|
2
|
+
# Builds the distribution and publishes it to PyPI when a version tag (v*.*.*) is pushed.
|
|
3
|
+
# Requires a PyPI API token stored as the PYPI_API_TOKEN repository secret.
|
|
4
|
+
|
|
5
|
+
name: Publish Python 🐍 distribution 📦 to PyPI
|
|
6
|
+
|
|
7
|
+
on:
|
|
8
|
+
push:
|
|
9
|
+
tags:
|
|
10
|
+
- 'v*.*.*'
|
|
11
|
+
|
|
12
|
+
jobs:
|
|
13
|
+
build-and-publish:
|
|
14
|
+
runs-on: ubuntu-latest
|
|
15
|
+
steps:
|
|
16
|
+
- uses: actions/checkout@v4
|
|
17
|
+
- name: Set up Python
|
|
18
|
+
uses: actions/setup-python@v5
|
|
19
|
+
with:
|
|
20
|
+
python-version: '3.11'
|
|
21
|
+
- name: Install build dependencies
|
|
22
|
+
run: |
|
|
23
|
+
python -m pip install --upgrade pip
|
|
24
|
+
pip install build
|
|
25
|
+
- name: Build package
|
|
26
|
+
run: python -m build
|
|
27
|
+
- name: Publish package to PyPI
|
|
28
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
29
|
+
with:
|
|
30
|
+
password: ${{ secrets.PYPI_API_TOKEN }}
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
name: CI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
pull_request:
|
|
5
|
+
paths-ignore:
|
|
6
|
+
- "docs/**"
|
|
7
|
+
- "mkdocs.yml"
|
|
8
|
+
- ".github/workflows/docs.yml"
|
|
9
|
+
- "*.md"
|
|
10
|
+
push:
|
|
11
|
+
branches:
|
|
12
|
+
- main
|
|
13
|
+
paths-ignore:
|
|
14
|
+
- "docs/**"
|
|
15
|
+
- "mkdocs.yml"
|
|
16
|
+
- ".github/workflows/docs.yml"
|
|
17
|
+
- "*.md"
|
|
18
|
+
|
|
19
|
+
jobs:
|
|
20
|
+
test:
|
|
21
|
+
name: Test Python ${{ matrix.python-version }}
|
|
22
|
+
runs-on: ubuntu-latest
|
|
23
|
+
|
|
24
|
+
strategy:
|
|
25
|
+
fail-fast: false
|
|
26
|
+
matrix:
|
|
27
|
+
python-version:
|
|
28
|
+
- "3.11"
|
|
29
|
+
- "3.12"
|
|
30
|
+
- "3.13"
|
|
31
|
+
- "3.14"
|
|
32
|
+
|
|
33
|
+
steps:
|
|
34
|
+
- name: Check out repository
|
|
35
|
+
uses: actions/checkout@v4
|
|
36
|
+
|
|
37
|
+
- name: Set up Python
|
|
38
|
+
uses: actions/setup-python@v5
|
|
39
|
+
with:
|
|
40
|
+
python-version: ${{ matrix.python-version }}
|
|
41
|
+
cache: pip
|
|
42
|
+
|
|
43
|
+
- name: Install package
|
|
44
|
+
run: python -m pip install -e ".[dev,docs]"
|
|
45
|
+
|
|
46
|
+
- name: Lint
|
|
47
|
+
run: ruff check src/ tests/
|
|
48
|
+
|
|
49
|
+
- name: Type check
|
|
50
|
+
run: mypy src
|
|
51
|
+
|
|
52
|
+
- name: Test
|
|
53
|
+
run: pytest
|
|
54
|
+
|
|
55
|
+
- name: Build docs
|
|
56
|
+
run: mkdocs build --strict
|
|
57
|
+
|
|
58
|
+
package:
|
|
59
|
+
name: Build package
|
|
60
|
+
runs-on: ubuntu-latest
|
|
61
|
+
|
|
62
|
+
steps:
|
|
63
|
+
- name: Check out repository
|
|
64
|
+
uses: actions/checkout@v4
|
|
65
|
+
|
|
66
|
+
- name: Set up Python
|
|
67
|
+
uses: actions/setup-python@v5
|
|
68
|
+
with:
|
|
69
|
+
python-version: "3.13"
|
|
70
|
+
cache: pip
|
|
71
|
+
|
|
72
|
+
- name: Install build tools
|
|
73
|
+
run: python -m pip install build
|
|
74
|
+
|
|
75
|
+
- name: Build distribution
|
|
76
|
+
run: python -m build
|
|
77
|
+
|
|
78
|
+
- name: Upload distribution artifact
|
|
79
|
+
uses: actions/upload-artifact@v4
|
|
80
|
+
with:
|
|
81
|
+
name: python-package
|
|
82
|
+
path: dist/*
|
matnets-1.1.0/.gitignore
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Contributing
|
|
2
|
+
|
|
3
|
+
Thanks for helping MATNETS grow.
|
|
4
|
+
|
|
5
|
+
## Local Setup
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
python -m pip install -e ".[dev,docs]"
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Checks
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
pytest
|
|
15
|
+
ruff check src/ tests/
|
|
16
|
+
mypy src
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
Keep changes focused, add tests for behavior changes, and update docs when the
|
|
20
|
+
public API changes.
|
matnets-1.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 MATNETS contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
matnets-1.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: matnets
|
|
3
|
+
Version: 1.1.0
|
|
4
|
+
Summary: A small experimental neural network library where neurons are represented as matrices.
|
|
5
|
+
Project-URL: Homepage, https://github.com/dsainvg/MATNETS
|
|
6
|
+
Project-URL: Documentation, https://github.com/dsainvg/MATNETS/tree/main/docs
|
|
7
|
+
Project-URL: Issues, https://github.com/dsainvg/MATNETS/issues
|
|
8
|
+
Author: dsainvg
|
|
9
|
+
License-Expression: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Keywords: jax,machine-learning,matrix,neural-networks
|
|
12
|
+
Classifier: Development Status :: 2 - Pre-Alpha
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
22
|
+
Requires-Python: >=3.11
|
|
23
|
+
Requires-Dist: jax>=0.4
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: mypy>=1.10; extra == 'dev'
|
|
26
|
+
Requires-Dist: pytest>=8.0; extra == 'dev'
|
|
27
|
+
Requires-Dist: ruff>=0.5; extra == 'dev'
|
|
28
|
+
Provides-Extra: docs
|
|
29
|
+
Requires-Dist: mkdocs-material>=9.5; extra == 'docs'
|
|
30
|
+
Requires-Dist: mkdocs>=1.6; extra == 'docs'
|
|
31
|
+
Description-Content-Type: text/markdown
|
|
32
|
+
|
|
33
|
+
# MATNETS
|
|
34
|
+
|
|
35
|
+
MATNETS is a small JAX library for matrix-neuron neural network experiments.
|
|
36
|
+
Each neuron carries an `n x n` matrix instead of a scalar.
|
|
37
|
+
|
|
38
|
+
The user documentation lives in [`docs/`](docs/index.md):
|
|
39
|
+
|
|
40
|
+
- [`docs/index.md`](docs/index.md): overview
|
|
41
|
+
- [`docs/getting-started.md`](docs/getting-started.md): install and first model
|
|
42
|
+
- [`docs/concepts.md`](docs/concepts.md): matrix-neuron shapes and JAX transforms
|
|
43
|
+
- [`docs/api.md`](docs/api.md): API guide
|
|
44
|
+
- [`docs/examples.md`](docs/examples.md): runnable examples
|
|
45
|
+
- [`docs/development.md`](docs/development.md): local development commands
|
|
46
|
+
|
|
47
|
+
Quick check:
|
|
48
|
+
|
|
49
|
+
```powershell
|
|
50
|
+
.\.venv\Scripts\python.exe examples\2_mlp_3_hidden_layers.py
|
|
51
|
+
.\.venv\Scripts\python.exe -m pytest
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
MIT license.
|
matnets-1.1.0/README.md
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# MATNETS
|
|
2
|
+
|
|
3
|
+
MATNETS is a small JAX library for matrix-neuron neural network experiments.
|
|
4
|
+
Each neuron carries an `n x n` matrix instead of a scalar.
|
|
5
|
+
|
|
6
|
+
The user documentation lives in [`docs/`](docs/index.md):
|
|
7
|
+
|
|
8
|
+
- [`docs/index.md`](docs/index.md): overview
|
|
9
|
+
- [`docs/getting-started.md`](docs/getting-started.md): install and first model
|
|
10
|
+
- [`docs/concepts.md`](docs/concepts.md): matrix-neuron shapes and JAX transforms
|
|
11
|
+
- [`docs/api.md`](docs/api.md): API guide
|
|
12
|
+
- [`docs/examples.md`](docs/examples.md): runnable examples
|
|
13
|
+
- [`docs/development.md`](docs/development.md): local development commands
|
|
14
|
+
|
|
15
|
+
Quick check:
|
|
16
|
+
|
|
17
|
+
```powershell
|
|
18
|
+
.\.venv\Scripts\python.exe examples\2_mlp_3_hidden_layers.py
|
|
19
|
+
.\.venv\Scripts\python.exe -m pytest
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
MIT license.
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
# API Guide
|
|
2
|
+
|
|
3
|
+
## `matnets.MatrixParams`
|
|
4
|
+
|
|
5
|
+
`MatrixParams` stores the weights and bias for matrix primitives:
|
|
6
|
+
|
|
7
|
+
```python
|
|
8
|
+
from matnets import MatrixParams
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
Dense parameters use:
|
|
12
|
+
|
|
13
|
+
```text
|
|
14
|
+
W: (q, p, n, n)
|
|
15
|
+
B: (q, n, n)
|
|
16
|
+
```
|
|
17
|
+
|
|
18
|
+
`MatrixParams` is registered as a JAX pytree, so it works with `jax.jit`,
|
|
19
|
+
`jax.vmap`, `jax.grad`, and nested dictionaries/lists of parameters.
|
|
20
|
+
|
|
21
|
+
## `matnets.init`
|
|
22
|
+
|
|
23
|
+
```python
|
|
24
|
+
params = matnets.init(key, p=2, q=3, n=4)
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
Creates:
|
|
28
|
+
|
|
29
|
+
```text
|
|
30
|
+
params.W: (3, 2, 4, 4)
|
|
31
|
+
params.B: (3, 4, 4)
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
Weights use Glorot-uniform initialization. Bias starts at zero.
|
|
35
|
+
|
|
36
|
+
## `matnets.dense`
|
|
37
|
+
|
|
38
|
+
```python
|
|
39
|
+
y = matnets.dense(params, x)
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
Expected shapes:
|
|
43
|
+
|
|
44
|
+
```text
|
|
45
|
+
params.W: (q, p, n, n)
|
|
46
|
+
params.B: (q, n, n)
|
|
47
|
+
x: (p, n, n)
|
|
48
|
+
y: (q, n, n)
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
With activation:
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
y = matnets.dense(params, x, activation=jax.nn.relu)
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
The core operation is:
|
|
58
|
+
|
|
59
|
+
```python
|
|
60
|
+
jnp.einsum("qpak,pkc->qac", params.W, x) + params.B
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
## `matnets.lax.matrix_conv1d`
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
from matnets.lax import matrix_conv1d
|
|
67
|
+
|
|
68
|
+
y = matrix_conv1d(params, x, stride=1, padding="VALID")
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
Expected shapes:
|
|
72
|
+
|
|
73
|
+
```text
|
|
74
|
+
params.W: (q, p, r, n, n)
|
|
75
|
+
params.B: (q, n, n)
|
|
76
|
+
x: (t, p, n, n)
|
|
77
|
+
y: (t_out, q, n, n)
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
`r` is the 1D kernel size.
|
|
81
|
+
|
|
82
|
+
## `matnets.lax.matrix_conv2d`
|
|
83
|
+
|
|
84
|
+
```python
|
|
85
|
+
from matnets.lax import matrix_conv2d
|
|
86
|
+
|
|
87
|
+
y = matrix_conv2d(params, x, stride=(1, 1), padding="SAME")
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
Expected shapes:
|
|
91
|
+
|
|
92
|
+
```text
|
|
93
|
+
params.W: (q, p, h, w, n, n)
|
|
94
|
+
params.B: (q, n, n)
|
|
95
|
+
x: (height, width, p, n, n)
|
|
96
|
+
y: (height_out, width_out, q, n, n)
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
## `matnets.lax.matrix_attention`
|
|
100
|
+
|
|
101
|
+
```python
|
|
102
|
+
from matnets.lax import matrix_attention
|
|
103
|
+
|
|
104
|
+
out = matrix_attention(None, Q, K, V)
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
Expected token shapes:
|
|
108
|
+
|
|
109
|
+
```text
|
|
110
|
+
Q: (tokens_q, p, n, n)
|
|
111
|
+
K: (tokens_k, p, n, n)
|
|
112
|
+
V: (tokens_k, p, n, n)
|
|
113
|
+
out: (tokens_q, p, n, n)
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
By default the score is a scaled Frobenius inner product. You can pass a custom
|
|
117
|
+
`score_fn` that receives one query token and one key token and returns a scalar.
|
|
118
|
+
|
|
119
|
+
If `params` is not `None`, each aggregated output token is projected through
|
|
120
|
+
`matnets.dense(params, token)`.
|
|
121
|
+
|
|
122
|
+
## `matnets.nn`
|
|
123
|
+
|
|
124
|
+
`matnets.nn` contains recurrent wiring patterns built from `dense`.
|
|
125
|
+
|
|
126
|
+
```python
|
|
127
|
+
from matnets.nn import rnn_step, lstm_step, gru_step
|
|
128
|
+
```
|
|
129
|
+
|
|
130
|
+
These functions are intended to be used with `jax.lax.scan`.
|
|
131
|
+
|
|
132
|
+
### RNN
|
|
133
|
+
|
|
134
|
+
```python
|
|
135
|
+
carry, outputs = jax.lax.scan(
|
|
136
|
+
lambda h, x_t: rnn_step(params, h, x_t),
|
|
137
|
+
h0,
|
|
138
|
+
sequence,
|
|
139
|
+
)
|
|
140
|
+
```
|
|
141
|
+
|
|
142
|
+
### LSTM
|
|
143
|
+
|
|
144
|
+
```python
|
|
145
|
+
carry, outputs = jax.lax.scan(
|
|
146
|
+
lambda carry, x_t: lstm_step(params, carry, x_t),
|
|
147
|
+
(h0, c0),
|
|
148
|
+
sequence,
|
|
149
|
+
)
|
|
150
|
+
```
|
|
151
|
+
|
|
152
|
+
LSTM params must contain keys `"i"`, `"f"`, `"g"`, and `"o"`.
|
|
153
|
+
|
|
154
|
+
### GRU
|
|
155
|
+
|
|
156
|
+
```python
|
|
157
|
+
carry, outputs = jax.lax.scan(
|
|
158
|
+
lambda h, x_t: gru_step(params, h, x_t),
|
|
159
|
+
h0,
|
|
160
|
+
sequence,
|
|
161
|
+
)
|
|
162
|
+
```
|
|
163
|
+
|
|
164
|
+
GRU params must contain keys `"z"`, `"r"`, and `"n"`.
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# Concepts
|
|
2
|
+
|
|
3
|
+
## Matrix-Neurons
|
|
4
|
+
|
|
5
|
+
A traditional dense layer usually maps vectors:
|
|
6
|
+
|
|
7
|
+
```text
|
|
8
|
+
x: (p)
|
|
9
|
+
W: (q, p)
|
|
10
|
+
y: (q)
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
MATNETS maps stacks of square matrices:
|
|
14
|
+
|
|
15
|
+
```text
|
|
16
|
+
x: (p, n, n)
|
|
17
|
+
W: (q, p, n, n)
|
|
18
|
+
B: (q, n, n)
|
|
19
|
+
y: (q, n, n)
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
`p` is the input neuron count. `q` is the output neuron count. `n` is the
|
|
23
|
+
matrix size inside each neuron.
|
|
24
|
+
|
|
25
|
+
## Dense Primitive
|
|
26
|
+
|
|
27
|
+
The core operation is:
|
|
28
|
+
|
|
29
|
+
```python
|
|
30
|
+
jnp.einsum("qpak,pkc->qac", W, x) + B
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
Under the square-matrix contract:
|
|
34
|
+
|
|
35
|
+
```text
|
|
36
|
+
a == n
|
|
37
|
+
k == n
|
|
38
|
+
c == n
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
so the output is always `(q, n, n)`.
|
|
42
|
+
|
|
43
|
+
## Bias
|
|
44
|
+
|
|
45
|
+
The bias is a full matrix:
|
|
46
|
+
|
|
47
|
+
```text
|
|
48
|
+
B: (q, n, n)
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
Each output matrix-neuron gets its own complete matrix bias.
|
|
52
|
+
|
|
53
|
+
## JAX Transforms
|
|
54
|
+
|
|
55
|
+
MATNETS functions are ordinary JAX functions. You can transform them with:
|
|
56
|
+
|
|
57
|
+
```python
|
|
58
|
+
jax.jit(forward)
|
|
59
|
+
jax.vmap(forward, in_axes=(None, 0))
|
|
60
|
+
jax.grad(loss)
|
|
61
|
+
jax.lax.scan(step, carry, sequence)
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
The main parallel work is the dense einsum. `vmap` adds batch or token axes
|
|
65
|
+
around it. `scan` handles recurrence over time while each step still uses
|
|
66
|
+
compiled dense contractions.
|
|
67
|
+
|
|
68
|
+
## Recurrent State
|
|
69
|
+
|
|
70
|
+
RNN, LSTM, and GRU hidden states are stacks of matrices:
|
|
71
|
+
|
|
72
|
+
```text
|
|
73
|
+
H: (hidden_neurons, n, n)
|
|
74
|
+
C: (hidden_neurons, n, n)
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
Gates are also matrix-valued, so an LSTM forget gate has one value per matrix
|
|
78
|
+
entry, not just one scalar per neuron.
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# Development
|
|
2
|
+
|
|
3
|
+
## Test
|
|
4
|
+
|
|
5
|
+
```powershell
|
|
6
|
+
.\.venv\Scripts\python.exe -m pytest
|
|
7
|
+
```
|
|
8
|
+
|
|
9
|
+
## Run Examples
|
|
10
|
+
|
|
11
|
+
```powershell
|
|
12
|
+
.\.venv\Scripts\python.exe examples\1_linear_regression.py
|
|
13
|
+
.\.venv\Scripts\python.exe examples\2_mlp_3_hidden_layers.py
|
|
14
|
+
.\.venv\Scripts\python.exe examples\8_transformer.py
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
## Syntax Check
|
|
18
|
+
|
|
19
|
+
```powershell
|
|
20
|
+
.\.venv\Scripts\python.exe -m compileall -q src examples tests
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
## Build Documentation
|
|
24
|
+
|
|
25
|
+
```powershell
|
|
26
|
+
.\.venv\Scripts\python.exe -m pip install -e ".[docs]"
|
|
27
|
+
.\.venv\Scripts\python.exe -m mkdocs build --strict
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
## Package Layout
|
|
31
|
+
|
|
32
|
+
```text
|
|
33
|
+
src/matnets/
|
|
34
|
+
_params.py MatrixParams and init()
|
|
35
|
+
_dense.py core dense matrix primitive
|
|
36
|
+
lax/conv.py convolution-like matrix primitives
|
|
37
|
+
lax/attention.py matrix attention primitive
|
|
38
|
+
nn/recurrent.py RNN, LSTM, and GRU step patterns
|
|
39
|
+
tests/ test suite
|
|
40
|
+
examples/ runnable examples
|
|
41
|
+
docs/ user documentation
|
|
42
|
+
```
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Examples
|
|
2
|
+
|
|
3
|
+
## Basic Forward Pass
|
|
4
|
+
|
|
5
|
+
```powershell
|
|
6
|
+
.\.venv\Scripts\python.exe examples\basic_forward.py
|
|
7
|
+
```
|
|
8
|
+
|
|
9
|
+
This creates one dense layer and applies it to an input shaped `(2, 2, 2)`.
|
|
10
|
+
|
|
11
|
+
## Five Hidden Layers
|
|
12
|
+
|
|
13
|
+
```powershell
|
|
14
|
+
.\.venv\Scripts\python.exe examples\five_hidden_net.py
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
This example defines a small class:
|
|
18
|
+
|
|
19
|
+
```python
|
|
20
|
+
model = FiveHiddenNet(key, input_neurons=3, hidden_neurons=4, n=2)
|
|
21
|
+
y = jax.jit(model.forward)(model.params, x)
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
The output shape is `(1, 2, 2)`.
|
|
25
|
+
|
|
26
|
+
## Architecture Walkthrough
|
|
27
|
+
|
|
28
|
+
```powershell
|
|
29
|
+
.\.venv\Scripts\python.exe examples\matrix_architectures.py
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
This file checks shape flow through:
|
|
33
|
+
|
|
34
|
+
- MLP
|
|
35
|
+
- batched MLP with `jax.vmap`
|
|
36
|
+
- gradients with `jax.grad`
|
|
37
|
+
- RNN with `jax.lax.scan`
|
|
38
|
+
- LSTM with `jax.lax.scan`
|
|
39
|
+
- Frobenius attention
|
|
40
|
+
- residual block
|
|
41
|
+
|
|
42
|
+
## Where Parallelization Happens
|
|
43
|
+
|
|
44
|
+
The dense operation:
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
jnp.einsum("qpak,pkc->qac", W, x)
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
is the main kernel. JAX can compile it with `jit`, map it over batches or token
|
|
51
|
+
sequences with `vmap`, differentiate it with `grad`, and call it repeatedly
|
|
52
|
+
inside `lax.scan`.
|