nmn 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nmn/nnx/nmn.py +25 -14
- nmn/nnx/yatconv.py +19 -2
- nmn-0.1.5.dist-info/METADATA +176 -0
- {nmn-0.1.4.dist-info → nmn-0.1.5.dist-info}/RECORD +6 -6
- nmn-0.1.4.dist-info/METADATA +0 -119
- {nmn-0.1.4.dist-info → nmn-0.1.5.dist-info}/WHEEL +0 -0
- {nmn-0.1.4.dist-info → nmn-0.1.5.dist-info}/licenses/LICENSE +0 -0
nmn/nnx/nmn.py
CHANGED
@@ -4,26 +4,18 @@ import typing as tp
 
 import jax
 import jax.numpy as jnp
-import numpy as np
 from jax import lax
-import opt_einsum
 
-from flax.core.frozen_dict import FrozenDict
 from flax import nnx
-from flax.nnx import rnglib
-from flax.nnx.module import Module
+from flax.nnx import rnglib
+from flax.nnx.module import Module
 from flax.nnx.nn import dtypes, initializers
 from flax.typing import (
   Dtype,
-  Shape,
   Initializer,
   PrecisionLike,
   DotGeneralT,
-  ConvGeneralDilatedT,
-  PaddingLike,
-  LaxPadding,
   PromoteDtypeFn,
-  EinsumT,
 )
 
 Array = jax.Array
@@ -60,21 +52,26 @@ class YatNMN(Module):
     in_features: the number of input features.
     out_features: the number of output features.
     use_bias: whether to add a bias to the output (default: True).
+    use_alpha: whether to use alpha scaling (default: True).
+    use_dropconnect: whether to use DropConnect (default: False).
     dtype: the dtype of the computation (default: infer from input and params).
     param_dtype: the dtype passed to parameter initializers (default: float32).
     precision: numerical precision of the computation see ``jax.lax.Precision``
       for details.
     kernel_init: initializer function for the weight matrix.
     bias_init: initializer function for the bias.
+    alpha_init: initializer function for the alpha.
     dot_general: dot product function.
     promote_dtype: function to promote the dtype of the arrays to the desired
       dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
      and a ``dtype`` keyword argument, and return a tuple of arrays with the
      promoted dtype.
+    epsilon: A small float added to the denominator to prevent division by zero.
+    drop_rate: dropout rate for DropConnect (default: 0.0).
     rngs: rng key.
   """
 
-  __data__ = ('kernel', 'bias')
+  __data__ = ('kernel', 'bias', 'alpha', 'dropconnect_key')
 
   def __init__(
     self,
@@ -83,6 +80,7 @@ class YatNMN(Module):
     *,
     use_bias: bool = True,
     use_alpha: bool = True,
+    use_dropconnect: bool = False,
     dtype: tp.Optional[Dtype] = None,
     param_dtype: Dtype = jnp.float32,
     precision: PrecisionLike = None,
@@ -91,8 +89,9 @@ class YatNMN(Module):
     alpha_init: Initializer = default_alpha_init,
     dot_general: DotGeneralT = lax.dot_general,
     promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
-    rngs: rnglib.Rngs,
     epsilon: float = 1e-5,
+    drop_rate: float = 0.0,
+    rngs: rnglib.Rngs,
   ):
 
     kernel_key = rngs.params()
@@ -117,6 +116,7 @@ class YatNMN(Module):
     self.out_features = out_features
     self.use_bias = use_bias
     self.use_alpha = use_alpha
+    self.use_dropconnect = use_dropconnect
     self.dtype = dtype
     self.param_dtype = param_dtype
     self.precision = precision
@@ -125,12 +125,19 @@ class YatNMN(Module):
     self.dot_general = dot_general
     self.promote_dtype = promote_dtype
     self.epsilon = epsilon
+    self.drop_rate = drop_rate
+
+    if use_dropconnect:
+      self.dropconnect_key = rngs.params()
+    else:
+      self.dropconnect_key = None
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
     Args:
       inputs: The nd-array to be transformed.
+      deterministic: If true, DropConnect is not applied (e.g., during inference).
 
     Returns:
       The transformed input.
@@ -139,6 +146,11 @@ class YatNMN(Module):
     bias = self.bias.value if self.bias is not None else None
     alpha = self.alpha.value if self.alpha is not None else None
 
+    if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
+      keep_prob = 1.0 - self.drop_rate
+      mask = jax.random.bernoulli(self.dropconnect_key, p=keep_prob, shape=kernel.shape)
+      kernel = (kernel * mask) / keep_prob
+
     inputs, kernel, bias, alpha = self.promote_dtype(
       (inputs, kernel, bias, alpha), dtype=self.dtype
     )
@@ -166,5 +178,4 @@ class YatNMN(Module):
     scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
     y = y * scale
 
-
     return y
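The net effect of these hunks is an optional DropConnect path in `YatNMN`: when `use_dropconnect=True` and `drop_rate > 0`, a Bernoulli mask with keep probability `1 - drop_rate` is sampled over the kernel and the surviving weights are rescaled by `1 / keep_prob`; passing `deterministic=True` at call time skips the mask. A minimal usage sketch, assuming only the constructor and `__call__` signatures visible in the diff (feature sizes, shapes, and RNG seeds are illustrative):

```python
import jax
from flax import nnx
from nmn.nnx.nmn import YatNMN

param_key, input_key = jax.random.split(jax.random.key(0))

# DropConnect is opt-in: use_dropconnect enables it, drop_rate sets the dropped fraction.
layer = YatNMN(
    in_features=8,
    out_features=4,
    use_dropconnect=True,
    drop_rate=0.2,
    rngs=nnx.Rngs(params=param_key),
)

x = jax.random.normal(input_key, (2, 8))
y_train = layer(x)                     # mask applied (deterministic defaults to False)
y_eval = layer(x, deterministic=True)  # mask skipped, e.g. for evaluation
print(y_train.shape, y_eval.shape)     # (2, 4) (2, 4)
```

With `use_dropconnect=False` or `drop_rate=0.0` (the defaults), the DropConnect branch is skipped entirely and the layer behaves as it did in 0.1.4.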
nmn/nnx/yatconv.py
CHANGED
@@ -110,6 +110,8 @@ class YatConv(Module):
     feature_group_count: integer, default 1. If specified divides the input
       features into groups.
     use_bias: whether to add a bias to the output (default: True).
+    use_alpha: whether to use alpha scaling (default: True).
+    use_dropconnect: whether to use DropConnect (default: False).
     mask: Optional mask for the weights during masked convolution. The mask must
       be the same shape as the convolution weight matrix.
     dtype: the dtype of the computation (default: infer from input and params).
@@ -123,10 +125,11 @@ class YatConv(Module):
       and a ``dtype`` keyword argument, and return a tuple of arrays with the
       promoted dtype.
     epsilon: A small float added to the denominator to prevent division by zero.
+    drop_rate: dropout rate for DropConnect (default: 0.0).
     rngs: rng key.
   """
 
-  __data__ = ('kernel', 'bias', 'mask')
+  __data__ = ('kernel', 'bias', 'mask', 'dropconnect_key')
 
   def __init__(
     self,
@@ -142,6 +145,7 @@ class YatConv(Module):
 
     use_bias: bool = True,
     use_alpha: bool = True,
+    use_dropconnect: bool = False,
     kernel_init: Initializer = default_kernel_init,
     bias_init: Initializer = default_bias_init,
     alpha_init: Initializer = default_alpha_init,
@@ -153,6 +157,7 @@ class YatConv(Module):
     conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
     promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
     epsilon: float = 1e-5,
+    drop_rate: float = 0.0,
     rngs: rnglib.Rngs,
   ):
     if isinstance(kernel_size, int):
@@ -185,6 +190,7 @@ class YatConv(Module):
     self.feature_group_count = feature_group_count
     self.use_bias = use_bias
     self.use_alpha = use_alpha
+    self.use_dropconnect = use_dropconnect
 
     self.mask = mask
     self.dtype = dtype
@@ -195,6 +201,7 @@ class YatConv(Module):
     self.conv_general_dilated = conv_general_dilated
     self.promote_dtype = promote_dtype
     self.epsilon = epsilon
+    self.drop_rate = drop_rate
 
     if use_alpha:
       alpha_key = rngs.params()
@@ -202,8 +209,12 @@ class YatConv(Module):
     else:
       self.alpha = None
 
+    if use_dropconnect:
+      self.dropconnect_key = rngs.params()
+    else:
+      self.dropconnect_key = None
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
     assert isinstance(self.kernel_size, tuple)
 
     def maybe_broadcast(
@@ -261,6 +272,12 @@ class YatConv(Module):
 
     kernel_val = self.kernel.value
 
+    # Apply DropConnect if enabled and not in deterministic mode
+    if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
+      keep_prob = 1.0 - self.drop_rate
+      mask = jax.random.bernoulli(self.dropconnect_key, p=keep_prob, shape=kernel_val.shape)
+      kernel_val = (kernel_val * mask) / keep_prob
+
     current_mask = self.mask
     if current_mask is not None:
       if current_mask.shape != self.kernel_shape:
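`YatConv` gets the same treatment: the Bernoulli mask is sampled over the full convolution kernel and applied before the optional weight `mask`, again only when `use_dropconnect` is set, `drop_rate > 0`, and `deterministic` is False. A short sketch under the same assumptions (NHWC input, as in the package README example; sizes are illustrative):

```python
import jax
from flax import nnx
from nmn.nnx.yatconv import YatConv

param_key, input_key = jax.random.split(jax.random.key(1))

conv = YatConv(
    in_features=3,          # input channels
    out_features=8,         # output channels
    kernel_size=(3, 3),
    use_dropconnect=True,
    drop_rate=0.1,
    rngs=nnx.Rngs(params=param_key),
)

images = jax.random.normal(input_key, (1, 28, 28, 3))  # batch of one NHWC image
feats_train = conv(images)                     # DropConnect active
feats_eval = conv(images, deterministic=True)  # DropConnect disabled
```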
nmn-0.1.5.dist-info/METADATA
ADDED
@@ -0,0 +1,176 @@
+Metadata-Version: 2.4
+Name: nmn
+Version: 0.1.5
+Summary: a neuron that matter
+Project-URL: Homepage, https://github.com/mlnomadpy/nmn
+Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
+Author-email: Taha Bouhsine <yat@mlnomads.com>
+License-File: LICENSE
+Classifier: License :: OSI Approved :: GNU Affero General Public License v3
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+
+# nmn
+Not the neurons we want, but the neurons we need
+
+[](https://pypi.org/project/nmn/)
+[](https://pepy.tech/project/nmn)
+[](https://pepy.tech/project/nmn)
+[](https://github.com/mlnomadpy/nmn)
+[](https://github.com/mlnomadpy/nmn)
+[](https://github.com/mlnomadpy/nmn/issues)
+[](https://pypi.org/project/nmn/)
+[](https://pypi.org/project/nmn/)
+
+## Features
+
+* **Activation-Free Non-linearity:** Learns complex, non-linear relationships without separate activation functions.
+* **Multiple Frameworks:** Supports Flax (Linen & NNX), Keras, PyTorch, and TensorFlow.
+* **Yat-Product & Yat-Conv:** Implements novel Yat-Product and Yat-Conv operations.
+* **Inspired by Research:** Based on the principles from "Deep Learning 2.0/2.1: Artificial Neurons that Matter".
+
+## Overview
+
+**nmn** provides neural network layers for multiple frameworks (Flax, NNX, Keras, PyTorch, TensorFlow) that do not require activation functions to learn non-linearity. The main goal is to enable deep learning architectures where the layer itself is inherently non-linear, inspired by the papers:
+
+> Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
+>
+> Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks
+
+## Math
+
+Yat-Product:
+$$
+ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{(\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w}) + \epsilon}.
+$$
+
+**Explanation:**
+- $\mathbf{w}$ is the weight vector, $\mathbf{x}$ is the input vector.
+- $\langle \mathbf{w}, \mathbf{x} \rangle$ is the dot product between $\mathbf{w}$ and $\mathbf{x}$.
+- $\|\mathbf{w} - \mathbf{x}\|^2$ is the squared Euclidean distance between $\mathbf{w}$ and $\mathbf{x}$.
+- $\epsilon$ is a small constant for numerical stability.
+- $\theta$ is the angle between $\mathbf{w}$ and $\mathbf{x}$.
+
+This operation:
+- **Numerator:** Squares the similarity (dot product) between $\mathbf{w}$ and $\mathbf{x}$, emphasizing strong alignments.
+- **Denominator:** Penalizes large distances, so the response is high only when $\mathbf{w}$ and $\mathbf{x}$ are both similar in direction and close in space.
+- **No activation needed:** The non-linearity is built into the operation itself, allowing the layer to learn complex, non-linear relationships without a separate activation function.
+- **Geometric view:** The output is maximized when $\mathbf{w}$ and $\mathbf{x}$ are both large in norm, closely aligned (small $\theta$), and close together in Euclidean space.
+
+Yat-Conv:
+$$
+ⵟ^*(\mathbf{W}, \mathbf{X}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon}
+= \frac{\left(\sum_{i,j} w_{ij} x_{ij}\right)^2}{\sum_{i,j} (w_{ij} - x_{ij})^2 + \epsilon}
+$$
+
+Where:
+- $\mathbf{W}$ and $\mathbf{X}$ are local patches (e.g., kernel and input patch in convolution)
+- $w_{ij}$ and $x_{ij}$ are elements of the kernel and input patch, respectively
+- $\epsilon$ is a small constant for numerical stability
+
+This generalizes the Yat-product to convolutional (patch-wise) operations.
+
+
+## Supported Frameworks & API
+
+The `YatNMN` layer (for dense operations) and `YatConv` (for convolutional operations) are the core components. Below is a summary of their availability and features per framework:
+
+| Framework | `YatNMN` Path | `YatConv` Path | Core Layer | DropConnect | Ternary Network | Recurrent Layer |
+|----------------|-------------------------------|-------------------------------|------------|-------------|-----------------|-----------------|
+| **Flax (Linen)** | `src/nmn/linen/nmn.py` | (Available) | ✅ | | | 🚧 |
+| **Flax (NNX)** | `src/nmn/nnx/nmn.py` | `src/nmn/nnx/yatconv.py` | ✅ | ✅ | 🚧 | 🚧 |
+| **Keras** | `src/nmn/keras/nmn.py` | (Available) | ✅ | | | 🚧 |
+| **PyTorch** | `src/nmn/torch/nmn.py` | (Available) | ✅ | | | 🚧 |
+| **TensorFlow** | `src/nmn/tf/nmn.py` | (Available) | ✅ | | | 🚧 |
+
+*Legend: ✅ Implemented, 🚧 To be implemented / In Progress, (Available): assumed available where `YatNMN` is; the specific path may vary or be part of the NMN module.*
+
+## Installation
+
+```bash
+pip install nmn
+```
+
+## Usage Example (Flax NNX)
+
+```python
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from nmn.nnx.nmn import YatNMN
+from nmn.nnx.yatconv import YatConv
+
+# Example YatNMN (Dense Layer)
+model_key, param_key, drop_key, input_key = jax.random.split(jax.random.key(0), 4)
+in_features, out_features = 3, 4
+layer = YatNMN(in_features=in_features, out_features=out_features, rngs=nnx.Rngs(params=param_key, dropout=drop_key))
+dummy_input = jax.random.normal(input_key, (2, in_features)) # Batch size 2
+output = layer(dummy_input)
+print("YatNMN Output Shape:", output.shape)
+
+# Example YatConv (Convolutional Layer)
+conv_key, conv_param_key, conv_input_key = jax.random.split(jax.random.key(1), 3)
+in_channels, out_channels = 3, 8
+kernel_size = (3, 3)
+conv_layer = YatConv(
+    in_features=in_channels,
+    out_features=out_channels,
+    kernel_size=kernel_size,
+    rngs=nnx.Rngs(params=conv_param_key)
+)
+dummy_conv_input = jax.random.normal(conv_input_key, (1, 28, 28, in_channels)) # Batch 1, 28x28 image, in_channels
+conv_output = conv_layer(dummy_conv_input)
+print("YatConv Output Shape:", conv_output.shape)
+
+```
+*Note: Examples for other frameworks (Keras, PyTorch, TensorFlow, Flax Linen) can be found in their respective `nmn.<framework>` modules and upcoming documentation.*
+
+## Roadmap
+
+- [ ] Implement recurrent layers (`YatRNN`, `YatLSTM`, `YatGRU`) for all supported frameworks.
+- [ ] Develop Ternary Network versions of Yat layers for NNX.
+- [ ] Add more comprehensive examples and benchmark scripts for various tasks (vision, language).
+- [ ] Publish detailed documentation and API references.
+- [ ] Conduct and publish thorough performance benchmarks against traditional layers.
+
+## Contributing
+
+Contributions are welcome! If you'd like to contribute, please feel free to:
+- Open an issue on the [Bug Tracker](https://github.com/mlnomadpy/nmn/issues) to report bugs or suggest features.
+- Submit a pull request with your improvements.
+- Help expand the documentation or add more examples.
+
+## License
+
+This project is licensed under the **GNU Affero General Public License v3**. See the [LICENSE](LICENSE) file for details.
+
+## Citation
+
+If you use `nmn` in your research, please consider citing the original papers that inspired this work:
+
+> Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
+>
+> Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks
+
+A BibTeX entry will be provided once the accompanying paper for this library is published.
+
+## Citing
+
+If you use this work, please cite the paper:
+
+```bibtex
+@article{taha2024dl2,
+  author = {Taha Bouhsine},
+  title = {Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality},
+}
+```
+
+
+```bibtex
+@article{taha2025dl2,
+  author = {Taha Bouhsine},
+  title = {Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks},
+}
+```
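As a quick sanity check on the Yat-Product formula quoted in the README above, here is a small hand computation with plain `jax.numpy`. The `yat_product` helper is hypothetical, written only to mirror the formula; it is not an API exported by the package:

```python
import jax.numpy as jnp

def yat_product(w, x, epsilon=1e-5):
    # <w, x>^2 / (||w - x||^2 + epsilon), as in the README's Yat-Product definition.
    return jnp.dot(w, x) ** 2 / (jnp.sum((w - x) ** 2) + epsilon)

w = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([1.0, 2.0, 2.0])
# <w, x> = 11 and ||w - x||^2 = 1, so the result is about 121 / (1 + 1e-5) ~= 120.999.
print(yat_product(w, x))
```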
{nmn-0.1.4.dist-info → nmn-0.1.5.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
 nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
 nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
-nmn/nnx/nmn.py,sha256=
+nmn/nnx/nmn.py,sha256=tPNUtF8Lmv_B1TgMoVXfMQ9x0IPGKjSyAP6HnZ-YBsM,5651
 nmn/nnx/yatattention.py,sha256=chjtUKJtaR7ROPnNqkicbvMs7hzZKE0fIo_8cTNiju8,26601
-nmn/nnx/yatconv.py,sha256=
+nmn/nnx/yatconv.py,sha256=EOAAWfuv5QA-QTru-JyYKYNoGqxcklu7ph9a-CtmYsA,13123
 nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
 nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
 nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
 nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
-nmn-0.1.
-nmn-0.1.
-nmn-0.1.
-nmn-0.1.
+nmn-0.1.5.dist-info/METADATA,sha256=7gvXle6Hgdgyj_tJk1DGdkOh03BOsfSks-ZHPOIEwHQ,8800
+nmn-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+nmn-0.1.5.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
+nmn-0.1.5.dist-info/RECORD,,
nmn-0.1.4.dist-info/METADATA
DELETED
@@ -1,119 +0,0 @@
-Metadata-Version: 2.4
-Name: nmn
-Version: 0.1.4
-Summary: a neuron that matter
-Project-URL: Homepage, https://github.com/mlnomadpy/nmn
-Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
-Author-email: Taha Bouhsine <yat@mlnomads.com>
-License-File: LICENSE
-Classifier: License :: OSI Approved :: GNU Affero General Public License v3
-Classifier: Operating System :: OS Independent
-Classifier: Programming Language :: Python :: 3
-Requires-Python: >=3.8
-Description-Content-Type: text/markdown
-
-# nmn
-Not the neurons we want, but the neurons we need
-
-[](https://pypi.org/project/nmn/)
-[](https://pepy.tech/project/nmn)
-[](https://pepy.tech/project/nmn)
-[](https://github.com/mlnomadpy/nmn)
-[](https://github.com/mlnomadpy/nmn)
-[](https://github.com/mlnomadpy/nmn/issues)
-[](https://pypi.org/project/nmn/)
-[](https://pypi.org/project/nmn/)
-
-## Overview
-
-**nmn** provides neural network layers for multiple frameworks (Flax, NNX, Keras, PyTorch, TensorFlow) that do not require activation functions to learn non-linearity. The main goal is to enable deep learning architectures where the layer itself is inherently non-linear, inspired by the paper:
-
-> Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
-
-## Math
-
-Yat-Product:
-$$
-ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{((\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w}))^2 + \epsilon}.
-$$
-
-**Explanation:**
-- $\mathbf{w}$ is the weight vector, $\mathbf{x}$ is the input vector.
-- $\langle \mathbf{w}, \mathbf{x} \rangle$ is the dot product between $\mathbf{w}$ and $\mathbf{x}$.
-- $\|\mathbf{w} - \mathbf{x}\|^2$ is the squared Euclidean distance between $\mathbf{w}$ and $\mathbf{x}$.
-- $\epsilon$ is a small constant for numerical stability.
-- $\theta$ is the angle between $\mathbf{w}$ and $\mathbf{x}$.
-
-This operation:
-- **Numerator:** Squares the similarity (dot product) between $\mathbf{w}$ and $\mathbf{x}$, emphasizing strong alignments.
-- **Denominator:** Penalizes large distances, so the response is high only when $\mathbf{w}$ and $\mathbf{x}$ are both similar in direction and close in space.
-- **No activation needed:** The non-linearity is built into the operation itself, allowing the layer to learn complex, non-linear relationships without a separate activation function.
-- **Geometric view:** The output is maximized when $\mathbf{w}$ and $\mathbf{x}$ are both large in norm, closely aligned (small $\theta$), and close together in Euclidean space.
-
-Yat-Conv:
-$$
-ⵟ^*(\mathbf{W}, \mathbf{X}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon}
-= \frac{\left(\sum_{i,j} w_{ij} x_{ij}\right)^2}{\sum_{i,j} (w_{ij} - x_{ij})^2 + \epsilon}
-$$
-
-Where:
-- $\mathbf{W}$ and $\mathbf{X}$ are local patches (e.g., kernel and input patch in convolution)
-- $w_{ij}$ and $x_{ij}$ are elements of the kernel and input patch, respectively
-- $\epsilon$ is a small constant for numerical stability
-
-This generalizes the Yat-product to convolutional (patch-wise) operations.
-
-
-## Supported Frameworks & Tasks
-
-### Flax (JAX)
-- `YatNMN` layer implemented in `src/nmn/linen/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### NNX (Flax NNX)
-- `YatNMN` layer implemented in `src/nmn/nnx/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### Keras
-- `YatNMN` layer implemented in `src/nmn/keras/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### PyTorch
-- `YatNMN` layer implemented in `src/nmn/torch/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### TensorFlow
-- `YatNMN` layer implemented in `src/nmn/tf/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-## Installation
-
-```bash
-pip install nmn
-```
-
-## Usage Example (Flax)
-
-```python
-from nmn.nnx.nmn import YatNMN
-from nmn.nnx.yatconv import YatConv
-# ... use as a Flax module ...
-```
-
-## Roadmap
-- [ ] Implement recurrent layers for all frameworks
-- [ ] Add more examples and benchmarks
-- [ ] Improve documentation and API consistency
-
-## License
-GNU Affero General Public License v3
{nmn-0.1.4.dist-info → nmn-0.1.5.dist-info}/WHEEL
File without changes
{nmn-0.1.4.dist-info → nmn-0.1.5.dist-info}/licenses/LICENSE
File without changes