matnets-1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matnets/__init__.py +9 -0
- matnets/_dense.py +76 -0
- matnets/_params.py +66 -0
- matnets/lax/__init__.py +16 -0
- matnets/lax/attention.py +62 -0
- matnets/lax/conv.py +153 -0
- matnets/nn/__init__.py +5 -0
- matnets/nn/recurrent.py +60 -0
- matnets/py.typed +1 -0
- matnets-1.1.0.dist-info/METADATA +54 -0
- matnets-1.1.0.dist-info/RECORD +13 -0
- matnets-1.1.0.dist-info/WHEEL +4 -0
- matnets-1.1.0.dist-info/licenses/LICENSE +21 -0
matnets/__init__.py
ADDED
matnets/_dense.py
ADDED
@@ -0,0 +1,76 @@
"""Core dense matrix primitive."""

from __future__ import annotations

from collections.abc import Callable

import jax
import jax.numpy as jnp
from jax import Array

from matnets._params import MatrixParams


def identity(x: Array) -> Array:
    return x


@jax.custom_jvp
def _dense_linear(W: Array, B: Array, x: Array) -> Array:
    q, p, a, k = W.shape
    c = x.shape[-1]
    W_flat = jnp.reshape(jnp.transpose(W, (0, 2, 1, 3)), (q * a, p * k))
    x_flat = jnp.reshape(x, (p * k, c))
    out = jnp.matmul(W_flat, x_flat)
    return jnp.reshape(out, (q, a, c)) + B


@_dense_linear.defjvp
def _dense_linear_jvp(
    primals: tuple[Array, Array, Array],
    tangents: tuple[Array, Array, Array],
) -> tuple[Array, Array]:
    W, B, x = primals
    dW, dB, dx = tangents
    q, p, a, k = W.shape
    c = x.shape[-1]

    W_flat = jnp.reshape(jnp.transpose(W, (0, 2, 1, 3)), (q * a, p * k))
    dW_flat = jnp.reshape(jnp.transpose(dW, (0, 2, 1, 3)), (q * a, p * k))
    x_flat = jnp.reshape(x, (p * k, c))
    dx_flat = jnp.reshape(dx, (p * k, c))

    y_flat = jnp.matmul(W_flat, x_flat)
    y = jnp.reshape(y_flat, (q, a, c)) + B

    dy_flat = jnp.matmul(dW_flat, x_flat) + jnp.matmul(W_flat, dx_flat)
    dy = jnp.reshape(dy_flat, (q, a, c)) + dB

    return y, dy


def dense(
    params: MatrixParams,
    x: Array,
    activation: Callable[[Array], Array] = identity,
) -> Array:
    """Apply the core square-matrix contraction ``qpak,pkc -> qac``."""

    W = params.W
    x = jnp.asarray(x)
    if W.ndim != 4:
        msg = "dense expects weights shaped (q, p, n, n)"
        raise ValueError(msg)

    q, p, n, k = W.shape
    if n != k:
        msg = "dense weights must map square matrices: last two axes must match"
        raise ValueError(msg)
    if x.shape != (p, n, n):
        msg = f"dense expects input shaped ({p}, {n}, {n})"
        raise ValueError(msg)
    if params.B.shape != (q, n, n):
        msg = f"dense bias must be shaped ({q}, {n}, {n})"
        raise ValueError(msg)

    return activation(_dense_linear(W, params.B, x))
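For readers skimming the diff, a minimal usage sketch may help; it assumes `matnets` is installed and picks arbitrary illustrative sizes (`p=3, q=2, n=4`). It checks `dense` against the `qpak,pkc -> qac` einsum named in the docstring:

```python
import jax
import jax.numpy as jnp

from matnets._dense import dense
from matnets._params import MatrixParams

p, q, n = 3, 2, 4  # illustrative sizes
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (q, p, n, n))
x = jax.random.normal(key_x, (p, n, n))
params = MatrixParams(W=W, B=jnp.zeros((q, n, n)))

# dense applies the contraction qpak,pkc -> qac, then the activation.
y = dense(params, x, jnp.tanh)
y_ref = jnp.tanh(jnp.einsum("qpak,pkc->qac", W, x))
assert jnp.allclose(y, y_ref, atol=1e-5)
```

Because the custom JVP is linear in its tangents, reverse-mode differentiation (`jax.grad`) also works through `_dense_linear`.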
matnets/_params.py
ADDED
@@ -0,0 +1,66 @@
"""Parameter containers and initializers for matrix primitives."""

from __future__ import annotations

from dataclasses import dataclass
from math import sqrt

import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import DTypeLike


@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class MatrixParams:
    """Weights and bias for matrix contractions.

    Dense weights use ``(q, p, n, n)`` and consume inputs shaped ``(p, n, n)``.
    The bias is shaped ``(q, n, n)`` so each output unit has a full matrix bias.
    Structural primitives may use higher-rank weights while keeping the same
    small, pytree-friendly container.
    """

    W: Array
    B: Array

    def tree_flatten(self) -> tuple[tuple[Array, Array], None]:
        return (self.W, self.B), None

    @classmethod
    def tree_unflatten(
        cls,
        aux_data: None,
        children: tuple[Array, Array],
    ) -> MatrixParams:
        del aux_data
        return cls(*children)


def init(
    key: Array,
    p: int,
    q: int,
    n: int,
    *,
    dtype: DTypeLike = jnp.float32,
) -> MatrixParams:
    """Initialize dense matrix parameters with Glorot-uniform weights."""

    if p <= 0 or q <= 0 or n <= 0:
        msg = "p, q, and n must all be positive"
        raise ValueError(msg)

    fan_in = p * n
    fan_out = q * n
    limit = sqrt(6.0 / (fan_in + fan_out))
    W = jax.random.uniform(
        key,
        shape=(q, p, n, n),
        minval=-limit,
        maxval=limit,
        dtype=dtype,
    )
    B = jnp.zeros((q, n, n), dtype=dtype)
    return MatrixParams(W=W, B=B)
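A short sketch of the initializer (sizes again illustrative). Because `MatrixParams` is registered as a pytree node, `W` and `B` are ordinary leaves to `jax.tree_util` and to transforms like `jax.grad`:

```python
import jax

from matnets._params import init

params = init(jax.random.PRNGKey(0), p=3, q=2, n=4)
print(params.W.shape)  # (2, 3, 4, 4): q output units, p input units, n x n matrices
print(params.B.shape)  # (2, 4, 4): one full matrix bias per output unit

# Pytree registration means tree utilities see through the container.
halved = jax.tree_util.tree_map(lambda leaf: 0.5 * leaf, params)
print(type(halved).__name__)  # MatrixParams
```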
matnets/lax/__init__.py
ADDED
@@ -0,0 +1,16 @@
"""JAX-level structural primitives."""

from matnets.lax.attention import (
    frobenius_score,
    matrix_attention,
    scaled_frobenius_score,
)
from matnets.lax.conv import matrix_conv1d, matrix_conv2d

__all__ = [
    "frobenius_score",
    "matrix_attention",
    "matrix_conv1d",
    "matrix_conv2d",
    "scaled_frobenius_score",
]
matnets/lax/attention.py
ADDED
@@ -0,0 +1,62 @@
"""Attention primitive for matrix-valued tokens."""

from __future__ import annotations

from collections.abc import Callable
from math import sqrt

import jax
import jax.numpy as jnp
from jax import Array

from matnets._dense import dense
from matnets._params import MatrixParams

ScoreFn = Callable[[Array, Array], Array]


def frobenius_score(Q: Array, K: Array) -> Array:
    """Scalar Frobenius inner product between two matrix-valued tokens."""

    return jnp.einsum("pkc,pkc->", Q, K)


def scaled_frobenius_score(Q: Array, K: Array) -> Array:
    scale = sqrt(Q.shape[0] * Q.shape[1] * Q.shape[2])
    return frobenius_score(Q, K) / scale


def matrix_attention(
    params: MatrixParams | None,
    Q: Array,
    K: Array,
    V: Array,
    score_fn: ScoreFn = scaled_frobenius_score,
) -> Array:
    """Score, normalize, aggregate, and optionally project matrix tokens.

    ``Q``, ``K``, and ``V`` are token sequences shaped ``(t, p, n, n)``.
    ``score_fn`` defines the pairwise semantics and must return a scalar.
    Passing ``params=None`` returns the aggregated context directly; otherwise
    each context token is projected through ``dense(params, token)``.
    """

    Q = jnp.asarray(Q)
    K = jnp.asarray(K)
    V = jnp.asarray(V)
    if Q.ndim != 4 or K.ndim != 4 or V.ndim != 4:
        msg = "matrix_attention expects Q, K, and V shaped (t, p, n, n)"
        raise ValueError(msg)
    if Q.shape[1:] != K.shape[1:] or Q.shape[1:] != V.shape[1:]:
        msg = "Q, K, and V must share neuron and matrix axes"
        raise ValueError(msg)
    if Q.shape[-2] != Q.shape[-1]:
        msg = "matrix_attention expects square matrix tokens"
        raise ValueError(msg)

    scores = jax.vmap(lambda q: jax.vmap(lambda k: score_fn(q, k))(K))(Q)
    weights = jax.nn.softmax(scores, axis=-1)
    context = jnp.einsum("ij,jpkc->ipkc", weights, V)
    if params is None:
        return context
    return jax.vmap(lambda token: dense(params, token))(context)
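A hedged sketch of the attention primitive with illustrative sizes (`t=5` tokens, `p=3` units, `n=4` matrices). With `params=None` the result keeps the input unit axis; with projection params each context token is mapped through `dense` to `q` output units:

```python
import jax

from matnets._params import init
from matnets.lax import matrix_attention

t, p, n = 5, 3, 4  # illustrative sizes
kq, kk, kv, kp = jax.random.split(jax.random.PRNGKey(0), 4)
Q = jax.random.normal(kq, (t, p, n, n))
K = jax.random.normal(kk, (t, p, n, n))
V = jax.random.normal(kv, (t, p, n, n))

# Aggregation only: softmax over scaled Frobenius scores, then mix V.
context = matrix_attention(None, Q, K, V)
print(context.shape)  # (5, 3, 4, 4)

# With a projection, each context token goes through dense(params, token).
proj = init(kp, p=p, q=2, n=n)
out = matrix_attention(proj, Q, K, V)
print(out.shape)  # (5, 2, 4, 4)
```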
matnets/lax/conv.py
ADDED
@@ -0,0 +1,153 @@
"""Convolution-like structural matrix primitives."""

from __future__ import annotations

from typing import Literal

import jax
import jax.numpy as jnp
from jax import Array

from matnets._params import MatrixParams

Padding = Literal["VALID", "SAME"]


def _check_bias_shape(params: MatrixParams, q: int, n: int) -> None:
    if params.B.shape != (q, n, n):
        msg = f"bias must be shaped ({q}, {n}, {n})"
        raise ValueError(msg)


def _check_square_matrix_kernel(W: Array, *, rank: int, shape_name: str) -> None:
    if W.ndim != rank:
        msg = f"{shape_name} expects weights shaped with rank {rank}"
        raise ValueError(msg)
    if W.shape[-2] != W.shape[-1]:
        msg = (
            "matrix convolution weights must map square matrices: "
            "last two axes must match"
        )
        raise ValueError(msg)


def _positive_stride(stride: int) -> int:
    if stride <= 0:
        msg = "stride must be positive"
        raise ValueError(msg)
    return stride


def _same_padding(size: int) -> tuple[int, int]:
    left = (size - 1) // 2
    right = size - 1 - left
    return left, right


def matrix_conv1d(
    params: MatrixParams,
    x: Array,
    *,
    stride: int = 1,
    padding: Padding = "VALID",
) -> Array:
    """Slide matrix kernels across a sequence.

    ``params.W`` is shaped ``(q, p, r, n, n)`` and ``x`` is shaped
    ``(t, p, n, n)``. The output is ``(t_out, q, n, n)``.
    """

    stride = _positive_stride(stride)
    W = params.W
    _check_square_matrix_kernel(W, rank=5, shape_name="matrix_conv1d")

    q, p, kernel, n, _ = W.shape
    x = jnp.asarray(x)
    if x.ndim != 4:
        msg = "matrix_conv1d expects inputs shaped (t, p, n, n)"
        raise ValueError(msg)
    if x.shape[1:] != (p, n, n):
        msg = "input axes must match weight axes: x must be shaped (t, p, n, n)"
        raise ValueError(msg)
    _check_bias_shape(params, q, n)

    if padding == "SAME":
        pad_left, pad_right = _same_padding(kernel)
    elif padding == "VALID":
        pad_left, pad_right = 0, 0
    else:
        msg = "padding must be 'VALID' or 'SAME'"
        raise ValueError(msg)

    x_reshaped = jnp.reshape(x, (x.shape[0], p * n, n))

    W_trans = jnp.transpose(W, (0, 3, 1, 4, 2))
    W_reshaped = jnp.reshape(W_trans, (q * n, p * n, kernel))

    out_conv = jax.lax.conv_general_dilated(
        x_reshaped,
        W_reshaped,
        window_strides=(stride,),
        padding=((pad_left, pad_right),),
        dimension_numbers=("WCN", "OIW", "WCN"),
    )
    return jnp.reshape(out_conv, (out_conv.shape[0], q, n, n)) + params.B


def matrix_conv2d(
    params: MatrixParams,
    x: Array,
    *,
    stride: int | tuple[int, int] = 1,
    padding: Padding = "VALID",
) -> Array:
    """Slide matrix kernels across a 2D grid.

    ``params.W`` is shaped ``(q, p, h, w, n, n)`` and ``x`` is shaped
    ``(y, x, p, n, n)``. The output is ``(y_out, x_out, q, n, n)``.
    """

    if isinstance(stride, int):
        stride_y = stride_x = _positive_stride(stride)
    else:
        stride_y = _positive_stride(stride[0])
        stride_x = _positive_stride(stride[1])

    W = params.W
    _check_square_matrix_kernel(W, rank=6, shape_name="matrix_conv2d")

    q, p, kernel_y, kernel_x, n, _ = W.shape
    x = jnp.asarray(x)
    if x.ndim != 5:
        msg = "matrix_conv2d expects inputs shaped (y, x, p, n, n)"
        raise ValueError(msg)
    if x.shape[2:] != (p, n, n):
        msg = "input axes must match weight axes: x must be shaped (y, x, p, n, n)"
        raise ValueError(msg)
    _check_bias_shape(params, q, n)

    if padding == "SAME":
        pad_top, pad_bottom = _same_padding(kernel_y)
        pad_left, pad_right = _same_padding(kernel_x)
    elif padding == "VALID":
        pad_top = pad_bottom = pad_left = pad_right = 0
    else:
        msg = "padding must be 'VALID' or 'SAME'"
        raise ValueError(msg)

    x_reshaped = jnp.reshape(x, (x.shape[0], x.shape[1], p * n, n))

    W_trans = jnp.transpose(W, (0, 4, 1, 5, 2, 3))
    W_reshaped = jnp.reshape(W_trans, (q * n, p * n, kernel_y, kernel_x))

    out_conv = jax.lax.conv_general_dilated(
        x_reshaped,
        W_reshaped,
        window_strides=(stride_y, stride_x),
        padding=((pad_top, pad_bottom), (pad_left, pad_right)),
        dimension_numbers=("HWCN", "OIHW", "HWCN"),
    )
    return (
        jnp.reshape(out_conv, (out_conv.shape[0], out_conv.shape[1], q, n, n))
        + params.B
    )
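A sketch of the 1D case under assumed sizes (`t=8` steps, `r=3` kernel taps): VALID yields `t - r + 1` output positions while SAME preserves the length; `matrix_conv2d` follows the same pattern with `(q, p, h, w, n, n)` weights:

```python
import jax
import jax.numpy as jnp

from matnets._params import MatrixParams
from matnets.lax import matrix_conv1d

t, p, q, r, n = 8, 3, 2, 3, 4  # sequence, in units, out units, taps, matrix size
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = MatrixParams(
    W=jax.random.normal(key_w, (q, p, r, n, n)),
    B=jnp.zeros((q, n, n)),
)
x = jax.random.normal(key_x, (t, p, n, n))

print(matrix_conv1d(params, x).shape)                  # VALID: (6, 2, 4, 4)
print(matrix_conv1d(params, x, padding="SAME").shape)  # SAME:  (8, 2, 4, 4)
```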
matnets/nn/__init__.py
ADDED
matnets/nn/recurrent.py
ADDED
@@ -0,0 +1,60 @@
"""Recurrent patterns built from the dense matrix primitive."""

from __future__ import annotations

from collections.abc import Callable, Mapping

import jax
import jax.numpy as jnp
from jax import Array

from matnets._dense import dense
from matnets._params import MatrixParams


def rnn_step(
    params: MatrixParams,
    carry: Array,
    x: Array,
    *,
    activation: Callable[[Array], Array] = jnp.tanh,
) -> tuple[Array, Array]:
    """Simple RNN step suitable for use inside ``jax.lax.scan``."""

    combined = jnp.concatenate([carry, x], axis=0)
    next_carry = dense(params, combined, activation)
    return next_carry, next_carry


def lstm_step(
    params: Mapping[str, MatrixParams],
    carry: tuple[Array, Array],
    x: Array,
) -> tuple[tuple[Array, Array], Array]:
    """LSTM step using ``i``, ``f``, ``g``, and ``o`` dense gate params."""

    h, c = carry
    combined = jnp.concatenate([h, x], axis=0)
    i = dense(params["i"], combined, jax.nn.sigmoid)
    f = dense(params["f"], combined, jax.nn.sigmoid)
    g = dense(params["g"], combined, jnp.tanh)
    o = dense(params["o"], combined, jax.nn.sigmoid)
    next_c = f * c + i * g
    next_h = o * jnp.tanh(next_c)
    return (next_h, next_c), next_h


def gru_step(
    params: Mapping[str, MatrixParams],
    carry: Array,
    x: Array,
) -> tuple[Array, Array]:
    """GRU step using ``z``, ``r``, and ``n`` dense gate params."""

    combined = jnp.concatenate([carry, x], axis=0)
    z = dense(params["z"], combined, jax.nn.sigmoid)
    r = dense(params["r"], combined, jax.nn.sigmoid)
    candidate_input = jnp.concatenate([r * carry, x], axis=0)
    n = dense(params["n"], candidate_input, jnp.tanh)
    next_carry = (1.0 - z) * n + z * carry
    return next_carry, next_carry
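Finally, a sketch wiring `rnn_step` into `jax.lax.scan`, as its docstring suggests. The sizes are illustrative; the one real constraint is that the step concatenates carry and input along the unit axis, so the dense params need `p = 2 * units` inputs and `q = units` outputs:

```python
import functools

import jax
import jax.numpy as jnp

from matnets._params import init
from matnets.nn.recurrent import rnn_step

units, n, t = 3, 4, 10  # illustrative sizes
key_p, key_x = jax.random.split(jax.random.PRNGKey(0))
params = init(key_p, p=2 * units, q=units, n=n)

xs = jax.random.normal(key_x, (t, units, n, n))
h0 = jnp.zeros((units, n, n))

step = functools.partial(rnn_step, params)
final_carry, outputs = jax.lax.scan(step, h0, xs)
print(outputs.shape)  # (10, 3, 4, 4)
```

`lstm_step` and `gru_step` follow the same recipe, except `params` is a mapping of per-gate `MatrixParams` (and the GRU candidate gate concatenates `r * carry` with `x` instead).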
matnets/py.typed
ADDED
@@ -0,0 +1 @@
(a single blank line: the PEP 561 marker indicating the package ships type information)
matnets-1.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,54 @@
Metadata-Version: 2.4
Name: matnets
Version: 1.1.0
Summary: A small experimental neural network library where neurons are represented as matrices.
Project-URL: Homepage, https://github.com/dsainvg/MATNETS
Project-URL: Documentation, https://github.com/dsainvg/MATNETS/tree/main/docs
Project-URL: Issues, https://github.com/dsainvg/MATNETS/issues
Author: dsainvg
License-Expression: MIT
License-File: LICENSE
Keywords: jax,machine-learning,matrix,neural-networks
Classifier: Development Status :: 2 - Pre-Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.11
Requires-Dist: jax>=0.4
Provides-Extra: dev
Requires-Dist: mypy>=1.10; extra == 'dev'
Requires-Dist: pytest>=8.0; extra == 'dev'
Requires-Dist: ruff>=0.5; extra == 'dev'
Provides-Extra: docs
Requires-Dist: mkdocs-material>=9.5; extra == 'docs'
Requires-Dist: mkdocs>=1.6; extra == 'docs'
Description-Content-Type: text/markdown

# MATNETS

MATNETS is a small JAX library for matrix-neuron neural network experiments.
Each neuron carries an `n x n` matrix instead of a scalar.

The user documentation lives in [`docs/`](docs/index.md):

- [`docs/index.md`](docs/index.md): overview
- [`docs/getting-started.md`](docs/getting-started.md): install and first model
- [`docs/concepts.md`](docs/concepts.md): matrix-neuron shapes and JAX transforms
- [`docs/api.md`](docs/api.md): API guide
- [`docs/examples.md`](docs/examples.md): runnable examples
- [`docs/development.md`](docs/development.md): local development commands

Quick check:

```powershell
.\.venv\Scripts\python.exe examples\five_hidden_net.py
.\.venv\Scripts\python.exe -m pytest
```

MIT license.
matnets-1.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,13 @@
matnets/__init__.py,sha256=s90bc4_V_TVyLgnmf1JbD-MiqWEGEZ9rSAUMMsHVdQM,246
matnets/_dense.py,sha256=00Riy9Nkqm4U7qo4wAS-xCiNeBnotNrvvs0z6kNRYTQ,2061
matnets/_params.py,sha256=wKdDyW-St-VxMyhSzgAGuDG5muhYTCTYa_nLrF3CY0s,1610
matnets/py.typed,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
matnets/lax/__init__.py,sha256=7zw_HdUXRqUmfFY9pYITNQEWQRSy25DZq00FjEFD2pw,341
matnets/lax/attention.py,sha256=ImdhAusgzdv_6BUNscePcqWHPEV2IXgb7BczYrFWd10,1979
matnets/lax/conv.py,sha256=wP5Ly-hHpFes9jXepPcJrzG_xli7sMnb3Ak6CTT4vz4,4484
matnets/nn/__init__.py,sha256=KFp0s_hxBsOm59vGp4yGYg0zccaA0GwRi3KA6qcrBg0,176
matnets/nn/recurrent.py,sha256=Vtk0BSnc-OEUWx1uUyjmIdn2GRqtAihCSDxZaFQ9grs,1767
matnets-1.1.0.dist-info/METADATA,sha256=f90gfRdkQQR7KG74rcGLXcXk5psn2XLfD8QnLGkp3Cw,2061
matnets-1.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
matnets-1.1.0.dist-info/licenses/LICENSE,sha256=Y2Z4Zzs_BrNYRx8nWVtax4k14klf8SELDd77VK0wui8,1077
matnets-1.1.0.dist-info/RECORD,,
matnets-1.1.0.dist-info/licenses/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 MATNETS contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.