nmn-0.1.4-py3-none-any.whl → nmn-0.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nmn/nnx/nmn.py CHANGED
@@ -4,26 +4,18 @@ import typing as tp
 
  import jax
  import jax.numpy as jnp
- import numpy as np
  from jax import lax
- import opt_einsum
 
- from flax.core.frozen_dict import FrozenDict
  from flax import nnx
- from flax.nnx import rnglib, variablelib
- from flax.nnx.module import Module, first_from
+ from flax.nnx import rnglib
+ from flax.nnx.module import Module
  from flax.nnx.nn import dtypes, initializers
  from flax.typing import (
    Dtype,
-   Shape,
    Initializer,
    PrecisionLike,
    DotGeneralT,
-   ConvGeneralDilatedT,
-   PaddingLike,
-   LaxPadding,
    PromoteDtypeFn,
-   EinsumT,
  )
 
  Array = jax.Array
@@ -60,21 +52,26 @@ class YatNMN(Module):
    in_features: the number of input features.
    out_features: the number of output features.
    use_bias: whether to add a bias to the output (default: True).
+   use_alpha: whether to use alpha scaling (default: True).
+   use_dropconnect: whether to use DropConnect (default: False).
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    precision: numerical precision of the computation see ``jax.lax.Precision``
      for details.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
+   alpha_init: initializer function for the alpha.
    dot_general: dot product function.
    promote_dtype: function to promote the dtype of the arrays to the desired
      dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
      and a ``dtype`` keyword argument, and return a tuple of arrays with the
      promoted dtype.
+   epsilon: A small float added to the denominator to prevent division by zero.
+   drop_rate: dropout rate for DropConnect (default: 0.0).
    rngs: rng key.
  """
 
- __data__ = ('kernel', 'bias')
+ __data__ = ('kernel', 'bias', 'alpha', 'dropconnect_key')
 
  def __init__(
    self,
@@ -83,6 +80,7 @@ class YatNMN(Module):
    *,
    use_bias: bool = True,
    use_alpha: bool = True,
+   use_dropconnect: bool = False,
    dtype: tp.Optional[Dtype] = None,
    param_dtype: Dtype = jnp.float32,
    precision: PrecisionLike = None,
@@ -91,8 +89,9 @@ class YatNMN(Module):
    alpha_init: Initializer = default_alpha_init,
    dot_general: DotGeneralT = lax.dot_general,
    promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
-   rngs: rnglib.Rngs,
    epsilon: float = 1e-5,
+   drop_rate: float = 0.0,
+   rngs: rnglib.Rngs,
  ):
 
    kernel_key = rngs.params()
@@ -117,6 +116,7 @@ class YatNMN(Module):
    self.out_features = out_features
    self.use_bias = use_bias
    self.use_alpha = use_alpha
+   self.use_dropconnect = use_dropconnect
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.precision = precision
@@ -125,12 +125,19 @@ class YatNMN(Module):
    self.dot_general = dot_general
    self.promote_dtype = promote_dtype
    self.epsilon = epsilon
+   self.drop_rate = drop_rate
+
+   if use_dropconnect:
+     self.dropconnect_key = rngs.params()
+   else:
+     self.dropconnect_key = None
 
- def __call__(self, inputs: Array) -> Array:
+ def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
    """Applies a linear transformation to the inputs along the last dimension.
 
    Args:
      inputs: The nd-array to be transformed.
+     deterministic: If true, DropConnect is not applied (e.g., during inference).
 
    Returns:
      The transformed input.
@@ -139,6 +146,11 @@ class YatNMN(Module):
    bias = self.bias.value if self.bias is not None else None
    alpha = self.alpha.value if self.alpha is not None else None
 
+   if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
+     keep_prob = 1.0 - self.drop_rate
+     mask = jax.random.bernoulli(self.dropconnect_key, p=keep_prob, shape=kernel.shape)
+     kernel = (kernel * mask) / keep_prob
+
    inputs, kernel, bias, alpha = self.promote_dtype(
      (inputs, kernel, bias, alpha), dtype=self.dtype
    )
@@ -166,5 +178,4 @@ class YatNMN(Module):
    scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
    y = y * scale
 
-
    return y
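The functional change above is DropConnect: when `use_dropconnect` is enabled and `deterministic` is False, the kernel is multiplied by a Bernoulli keep-mask and rescaled by `1 / keep_prob`, so the masked kernel keeps the same expected value as the original. A standalone sketch of just that masking step (illustrative only; variable names here are not from the package, but the `jax.random.bernoulli` call mirrors the hunk above):

```python
import jax
import jax.numpy as jnp

drop_rate = 0.1
keep_prob = 1.0 - drop_rate
kernel = jnp.ones((3, 4))

# Keep each weight with probability keep_prob, then rescale so that
# the expected value of the dropped kernel equals the original kernel.
mask = jax.random.bernoulli(jax.random.key(0), p=keep_prob, shape=kernel.shape)
dropped = (kernel * mask) / keep_prob
print(dropped.mean())  # close to 1.0 in expectation
```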
nmn/nnx/yatconv.py CHANGED
@@ -110,6 +110,8 @@ class YatConv(Module):
    feature_group_count: integer, default 1. If specified divides the input
      features into groups.
    use_bias: whether to add a bias to the output (default: True).
+   use_alpha: whether to use alpha scaling (default: True).
+   use_dropconnect: whether to use DropConnect (default: False).
    mask: Optional mask for the weights during masked convolution. The mask must
      be the same shape as the convolution weight matrix.
    dtype: the dtype of the computation (default: infer from input and params).
@@ -123,10 +125,11 @@ class YatConv(Module):
      and a ``dtype`` keyword argument, and return a tuple of arrays with the
      promoted dtype.
    epsilon: A small float added to the denominator to prevent division by zero.
+   drop_rate: dropout rate for DropConnect (default: 0.0).
    rngs: rng key.
  """
 
- __data__ = ('kernel', 'bias', 'mask')
+ __data__ = ('kernel', 'bias', 'mask', 'dropconnect_key')
 
  def __init__(
    self,
@@ -142,6 +145,7 @@ class YatConv(Module):
 
    use_bias: bool = True,
    use_alpha: bool = True,
+   use_dropconnect: bool = False,
    kernel_init: Initializer = default_kernel_init,
    bias_init: Initializer = default_bias_init,
    alpha_init: Initializer = default_alpha_init,
@@ -153,6 +157,7 @@ class YatConv(Module):
    conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
    promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
    epsilon: float = 1e-5,
+   drop_rate: float = 0.0,
    rngs: rnglib.Rngs,
  ):
    if isinstance(kernel_size, int):
@@ -185,6 +190,7 @@ class YatConv(Module):
    self.feature_group_count = feature_group_count
    self.use_bias = use_bias
    self.use_alpha = use_alpha
+   self.use_dropconnect = use_dropconnect
 
    self.mask = mask
    self.dtype = dtype
@@ -195,6 +201,7 @@ class YatConv(Module):
    self.conv_general_dilated = conv_general_dilated
    self.promote_dtype = promote_dtype
    self.epsilon = epsilon
+   self.drop_rate = drop_rate
 
    if use_alpha:
      alpha_key = rngs.params()
@@ -202,8 +209,12 @@ class YatConv(Module):
    else:
      self.alpha = None
 
+   if use_dropconnect:
+     self.dropconnect_key = rngs.params()
+   else:
+     self.dropconnect_key = None
 
- def __call__(self, inputs: Array) -> Array:
+ def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
    assert isinstance(self.kernel_size, tuple)
 
    def maybe_broadcast(
@@ -261,6 +272,12 @@ class YatConv(Module):
 
    kernel_val = self.kernel.value
 
+   # Apply DropConnect if enabled and not in deterministic mode
+   if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
+     keep_prob = 1.0 - self.drop_rate
+     mask = jax.random.bernoulli(self.dropconnect_key, p=keep_prob, shape=kernel_val.shape)
+     kernel_val = (kernel_val * mask) / keep_prob
+
    current_mask = self.mask
    if current_mask is not None:
      if current_mask.shape != self.kernel_shape:
nmn-0.1.5.dist-info/METADATA ADDED
@@ -0,0 +1,176 @@
+ Metadata-Version: 2.4
+ Name: nmn
+ Version: 0.1.5
+ Summary: a neuron that matter
+ Project-URL: Homepage, https://github.com/mlnomadpy/nmn
+ Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
+ Author-email: Taha Bouhsine <yat@mlnomads.com>
+ License-File: LICENSE
+ Classifier: License :: OSI Approved :: GNU Affero General Public License v3
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+
+ # nmn
+ Not the neurons we want, but the neurons we need
+
+ [![PyPI version](https://img.shields.io/pypi/v/nmn.svg)](https://pypi.org/project/nmn/)
+ [![Downloads](https://static.pepy.tech/badge/nmn)](https://pepy.tech/project/nmn)
+ [![Downloads/month](https://static.pepy.tech/badge/nmn/month)](https://pepy.tech/project/nmn)
+ [![GitHub stars](https://img.shields.io/github/stars/mlnomadpy/nmn?style=social)](https://github.com/mlnomadpy/nmn)
+ [![GitHub forks](https://img.shields.io/github/forks/mlnomadpy/nmn?style=social)](https://github.com/mlnomadpy/nmn)
+ [![GitHub issues](https://img.shields.io/github/issues/mlnomadpy/nmn)](https://github.com/mlnomadpy/nmn/issues)
+ [![PyPI - License](https://img.shields.io/pypi/l/nmn)](https://pypi.org/project/nmn/)
+ [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/nmn)](https://pypi.org/project/nmn/)
+
+ ## Features
+
+ * **Activation-Free Non-linearity:** Learns complex, non-linear relationships without separate activation functions.
+ * **Multiple Frameworks:** Supports Flax (Linen & NNX), Keras, PyTorch, and TensorFlow.
+ * **Yat-Product & Yat-Conv:** Implements novel Yat-Product and Yat-Conv operations.
+ * **Inspired by Research:** Based on the principles from "Deep Learning 2.0/2.1: Artificial Neurons that Matter".
+
+ ## Overview
+
+ **nmn** provides neural network layers for multiple frameworks (Flax, NNX, Keras, PyTorch, TensorFlow) that do not require activation functions to learn non-linearity. The main goal is to enable deep learning architectures where the layer itself is inherently non-linear, inspired by the papers:
+
+ > Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
+ >
+ > Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks
+
+ ## Math
+
+ Yat-Product:
+ $$
+ ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{\|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{\|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{(\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w}) + \epsilon}.
+ $$
+
+ **Explanation:**
+ - $\mathbf{w}$ is the weight vector, $\mathbf{x}$ is the input vector.
+ - $\langle \mathbf{w}, \mathbf{x} \rangle$ is the dot product between $\mathbf{w}$ and $\mathbf{x}$.
+ - $\|\mathbf{w} - \mathbf{x}\|^2$ is the squared Euclidean distance between $\mathbf{w}$ and $\mathbf{x}$.
+ - $\epsilon$ is a small constant for numerical stability.
+ - $\theta$ is the angle between $\mathbf{w}$ and $\mathbf{x}$.
+
+ This operation:
+ - **Numerator:** Squares the similarity (dot product) between $\mathbf{w}$ and $\mathbf{x}$, emphasizing strong alignments.
+ - **Denominator:** Penalizes large distances, so the response is high only when $\mathbf{w}$ and $\mathbf{x}$ are both similar in direction and close in space.
+ - **No activation needed:** The non-linearity is built into the operation itself, allowing the layer to learn complex, non-linear relationships without a separate activation function.
+ - **Geometric view:** The output is maximized when $\mathbf{w}$ and $\mathbf{x}$ are both large in norm, closely aligned (small $\theta$), and close together in Euclidean space.
+
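+ The snippet below is a minimal, self-contained sketch of the Yat-Product above for a single weight/input pair, written with plain `jax.numpy`. It is purely illustrative (the `yat_product` helper is not part of the `nmn` API); conceptually, the layers apply the same operation per output feature.
+
+ ```python
+ import jax.numpy as jnp
+
+ def yat_product(w, x, epsilon=1e-5):
+     # Squared dot product over squared Euclidean distance (plus epsilon).
+     dot = jnp.dot(w, x)
+     dist_sq = jnp.sum((w - x) ** 2)
+     return dot ** 2 / (dist_sq + epsilon)
+
+ w = jnp.array([1.0, 2.0, 3.0])
+ x = jnp.array([1.5, 1.5, 2.5])
+ print(yat_product(w, x))  # large when w and x are aligned and close together
+ ```
+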
+ Yat-Conv:
+ $$
+ ⵟ^*(\mathbf{W}, \mathbf{X}) := \frac{\langle \mathbf{W}, \mathbf{X} \rangle^2}{\|\mathbf{W} - \mathbf{X}\|^2 + \epsilon}
+ = \frac{\left(\sum_{i,j} w_{ij} x_{ij}\right)^2}{\sum_{i,j} (w_{ij} - x_{ij})^2 + \epsilon}
+ $$
+
+ Where:
+ - $\mathbf{W}$ and $\mathbf{X}$ are local patches (e.g., kernel and input patch in convolution)
+ - $w_{ij}$ and $x_{ij}$ are elements of the kernel and input patch, respectively
+ - $\epsilon$ is a small constant for numerical stability
+
+ This generalizes the Yat-product to convolutional (patch-wise) operations.
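+
+ A correspondingly minimal sketch for a single kernel/patch pair is shown below; it is illustrative only (the packaged `YatConv` is built around `lax.conv_general_dilated` rather than an explicit patch loop):
+
+ ```python
+ import jax.numpy as jnp
+
+ def yat_conv_patch(kernel, patch, epsilon=1e-5):
+     # (sum_ij w_ij * x_ij)^2 / (sum_ij (w_ij - x_ij)^2 + epsilon)
+     dot = jnp.sum(kernel * patch)
+     dist_sq = jnp.sum((kernel - patch) ** 2)
+     return dot ** 2 / (dist_sq + epsilon)
+
+ kernel = jnp.ones((3, 3))
+ patch = jnp.arange(9.0).reshape(3, 3) / 10.0
+ print(yat_conv_patch(kernel, patch))  # sliding this over all patches yields the Yat-Conv feature map
+ ```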
+
+
+ ## Supported Frameworks & API
+
+ The `YatNMN` layer (for dense operations) and `YatConv` (for convolutional operations) are the core components. Below is a summary of their availability and features per framework:
+
+ | Framework | `YatNMN` Path | `YatConv` Path | Core Layer | DropConnect | Ternary Network | Recurrent Layer |
+ |----------------|-------------------------------|-------------------------------|------------|-------------|-----------------|-----------------|
+ | **Flax (Linen)** | `src/nmn/linen/nmn.py` | (Available) | ✅ | | | 🚧 |
+ | **Flax (NNX)** | `src/nmn/nnx/nmn.py` | `src/nmn/nnx/yatconv.py` | ✅ | ✅ | 🚧 | 🚧 |
+ | **Keras** | `src/nmn/keras/nmn.py` | (Available) | ✅ | | | 🚧 |
+ | **PyTorch** | `src/nmn/torch/nmn.py` | (Available) | ✅ | | | 🚧 |
+ | **TensorFlow** | `src/nmn/tf/nmn.py` | (Available) | ✅ | | | 🚧 |
+
+ *Legend: ✅ Implemented, 🚧 To be implemented / in progress, (Available) - assumed to be available wherever `YatNMN` is implemented; the exact path may vary or be part of the NMN module.*
+
+ ## Installation
+
+ ```bash
+ pip install nmn
+ ```
+
+ ## Usage Example (Flax NNX)
+
+ ```python
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from nmn.nnx.nmn import YatNMN
+ from nmn.nnx.yatconv import YatConv
+
+ # Example YatNMN (Dense Layer)
+ model_key, param_key, drop_key, input_key = jax.random.split(jax.random.key(0), 4)
+ in_features, out_features = 3, 4
+ layer = YatNMN(in_features=in_features, out_features=out_features, rngs=nnx.Rngs(params=param_key, dropout=drop_key))
+ dummy_input = jax.random.normal(input_key, (2, in_features))  # Batch size 2
+ output = layer(dummy_input)
+ print("YatNMN Output Shape:", output.shape)
+
+ # Example YatConv (Convolutional Layer)
+ conv_key, conv_param_key, conv_input_key = jax.random.split(jax.random.key(1), 3)
+ in_channels, out_channels = 3, 8
+ kernel_size = (3, 3)
+ conv_layer = YatConv(
+     in_features=in_channels,
+     out_features=out_channels,
+     kernel_size=kernel_size,
+     rngs=nnx.Rngs(params=conv_param_key)
+ )
+ dummy_conv_input = jax.random.normal(conv_input_key, (1, 28, 28, in_channels))  # Batch 1, 28x28 image, in_channels
+ conv_output = conv_layer(dummy_conv_input)
+ print("YatConv Output Shape:", conv_output.shape)
+
+ ```
+ *Note: Examples for other frameworks (Keras, PyTorch, TensorFlow, Flax Linen) can be found in their respective `nmn.<framework>` modules and upcoming documentation.*
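+
+ Version 0.1.5 also adds DropConnect support to the NNX layers. A minimal usage sketch, assuming the `use_dropconnect`, `drop_rate`, and `deterministic` signatures shown in the 0.1.5 diff of `nmn/nnx/nmn.py`:
+
+ ```python
+ import jax
+ from flax import nnx
+ from nmn.nnx.nmn import YatNMN
+
+ # DropConnect randomly zeroes kernel entries during training and rescales by 1/keep_prob.
+ layer = YatNMN(
+     in_features=3,
+     out_features=4,
+     use_dropconnect=True,
+     drop_rate=0.1,
+     rngs=nnx.Rngs(params=jax.random.key(0)),
+ )
+ x = jax.random.normal(jax.random.key(1), (2, 3))
+ y_train = layer(x)                     # stochastic: DropConnect is applied
+ y_eval = layer(x, deterministic=True)  # deterministic: DropConnect is skipped
+ ```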
+
+ ## Roadmap
+
+ - [ ] Implement recurrent layers (`YatRNN`, `YatLSTM`, `YatGRU`) for all supported frameworks.
+ - [ ] Develop Ternary Network versions of Yat layers for NNX.
+ - [ ] Add more comprehensive examples and benchmark scripts for various tasks (vision, language).
+ - [ ] Publish detailed documentation and API references.
+ - [ ] Conduct and publish thorough performance benchmarks against traditional layers.
+
+ ## Contributing
+
+ Contributions are welcome! If you'd like to contribute, please feel free to:
+ - Open an issue on the [Bug Tracker](https://github.com/mlnomadpy/nmn/issues) to report bugs or suggest features.
+ - Submit a pull request with your improvements.
+ - Help expand the documentation or add more examples.
+
+ ## License
+
+ This project is licensed under the **GNU Affero General Public License v3**. See the [LICENSE](LICENSE) file for details.
+
+ ## Citation
+
+ If you use `nmn` in your research, please consider citing the original papers that inspired this work:
+
+ > Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
+
+ > Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks
+
+ A BibTeX entry will be provided once the accompanying paper for this library is published.
+
+ ## Citing
+
+ If you use this work, please cite the papers:
+
+ ```bibtex
+ @article{taha2024dl2,
+   author = {Taha Bouhsine},
+   title = {Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality},
+ }
+ ```
+
+
+ ```bibtex
+ @article{taha2025dl2,
+   author = {Taha Bouhsine},
+   title = {Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks},
+ }
+ ```
nmn-0.1.4.dist-info/RECORD → nmn-0.1.5.dist-info/RECORD CHANGED
@@ -1,14 +1,14 @@
  nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
  nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
  nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
- nmn/nnx/nmn.py,sha256=gWe8EL-aUm7be03M9O5R3XdBb92EpBEFsylrY6BA60c,4871
+ nmn/nnx/nmn.py,sha256=tPNUtF8Lmv_B1TgMoVXfMQ9x0IPGKjSyAP6HnZ-YBsM,5651
  nmn/nnx/yatattention.py,sha256=chjtUKJtaR7ROPnNqkicbvMs7hzZKE0fIo_8cTNiju8,26601
- nmn/nnx/yatconv.py,sha256=xUH9NBY1fIDZeTA9GdgmqR_DJiQJgwU2uDrgxqirKmU,12308
+ nmn/nnx/yatconv.py,sha256=EOAAWfuv5QA-QTru-JyYKYNoGqxcklu7ph9a-CtmYsA,13123
  nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
  nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
  nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
  nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
- nmn-0.1.4.dist-info/METADATA,sha256=k28p055Dr6WWVQcb01uinFRiT5R-CAvdKz33fqZ85g4,5032
- nmn-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- nmn-0.1.4.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
- nmn-0.1.4.dist-info/RECORD,,
+ nmn-0.1.5.dist-info/METADATA,sha256=7gvXle6Hgdgyj_tJk1DGdkOh03BOsfSks-ZHPOIEwHQ,8800
+ nmn-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ nmn-0.1.5.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
+ nmn-0.1.5.dist-info/RECORD,,
nmn-0.1.4.dist-info/METADATA DELETED
@@ -1,119 +0,0 @@
- Metadata-Version: 2.4
- Name: nmn
- Version: 0.1.4
- Summary: a neuron that matter
- Project-URL: Homepage, https://github.com/mlnomadpy/nmn
- Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
- Author-email: Taha Bouhsine <yat@mlnomads.com>
- License-File: LICENSE
- Classifier: License :: OSI Approved :: GNU Affero General Public License v3
- Classifier: Operating System :: OS Independent
- Classifier: Programming Language :: Python :: 3
- Requires-Python: >=3.8
- Description-Content-Type: text/markdown
-
- # nmn
- Not the neurons we want, but the neurons we need
-
- [![PyPI version](https://img.shields.io/pypi/v/nmn.svg)](https://pypi.org/project/nmn/)
- [![Downloads](https://static.pepy.tech/badge/nmn)](https://pepy.tech/project/nmn)
- [![Downloads/month](https://static.pepy.tech/badge/nmn/month)](https://pepy.tech/project/nmn)
- [![GitHub stars](https://img.shields.io/github/stars/mlnomadpy/nmn?style=social)](https://github.com/mlnomadpy/nmn)
- [![GitHub forks](https://img.shields.io/github/forks/mlnomadpy/nmn?style=social)](https://github.com/mlnomadpy/nmn)
- [![GitHub issues](https://img.shields.io/github/issues/mlnomadpy/nmn)](https://github.com/mlnomadpy/nmn/issues)
- [![PyPI - License](https://img.shields.io/pypi/l/nmn)](https://pypi.org/project/nmn/)
- [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/nmn)](https://pypi.org/project/nmn/)
-
- ## Overview
-
- **nmn** provides neural network layers for multiple frameworks (Flax, NNX, Keras, PyTorch, TensorFlow) that do not require activation functions to learn non-linearity. The main goal is to enable deep learning architectures where the layer itself is inherently non-linear, inspired by the paper:
-
- > Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
-
- ## Math
-
- Yat-Product:
- $$
- ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{((\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w}))^2 + \epsilon}.
- $$
-
- **Explanation:**
- - $\mathbf{w}$ is the weight vector, $\mathbf{x}$ is the input vector.
- - $\langle \mathbf{w}, \mathbf{x} \rangle$ is the dot product between $\mathbf{w}$ and $\mathbf{x}$.
- - $\|\mathbf{w} - \mathbf{x}\|^2$ is the squared Euclidean distance between $\mathbf{w}$ and $\mathbf{x}$.
- - $\epsilon$ is a small constant for numerical stability.
- - $\theta$ is the angle between $\mathbf{w}$ and $\mathbf{x}$.
-
- This operation:
- - **Numerator:** Squares the similarity (dot product) between $\mathbf{w}$ and $\mathbf{x}$, emphasizing strong alignments.
- - **Denominator:** Penalizes large distances, so the response is high only when $\mathbf{w}$ and $\mathbf{x}$ are both similar in direction and close in space.
- - **No activation needed:** The non-linearity is built into the operation itself, allowing the layer to learn complex, non-linear relationships without a separate activation function.
- - **Geometric view:** The output is maximized when $\mathbf{w}$ and $\mathbf{x}$ are both large in norm, closely aligned (small $\theta$), and close together in Euclidean space.
-
- Yat-Conv:
- $$
- ⵟ^*(\mathbf{W}, \mathbf{X}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon}
- = \frac{\left(\sum_{i,j} w_{ij} x_{ij}\right)^2}{\sum_{i,j} (w_{ij} - x_{ij})^2 + \epsilon}
- $$
-
- Where:
- - $\mathbf{W}$ and $\mathbf{X}$ are local patches (e.g., kernel and input patch in convolution)
- - $w_{ij}$ and $x_{ij}$ are elements of the kernel and input patch, respectively
- - $\epsilon$ is a small constant for numerical stability
-
- This generalizes the Yat-product to convolutional (patch-wise) operations.
-
-
- ## Supported Frameworks & Tasks
-
- ### Flax (JAX)
- - `YatNMN` layer implemented in `src/nmn/linen/nmn.py`
- - **Tasks:**
-   - [x] Core layer implementation
-   - [ ] Recurrent layer (to be implemented)
-
- ### NNX (Flax NNX)
- - `YatNMN` layer implemented in `src/nmn/nnx/nmn.py`
- - **Tasks:**
-   - [x] Core layer implementation
-   - [ ] Recurrent layer (to be implemented)
-
- ### Keras
- - `YatNMN` layer implemented in `src/nmn/keras/nmn.py`
- - **Tasks:**
-   - [x] Core layer implementation
-   - [ ] Recurrent layer (to be implemented)
-
- ### PyTorch
- - `YatNMN` layer implemented in `src/nmn/torch/nmn.py`
- - **Tasks:**
-   - [x] Core layer implementation
-   - [ ] Recurrent layer (to be implemented)
-
- ### TensorFlow
- - `YatNMN` layer implemented in `src/nmn/tf/nmn.py`
- - **Tasks:**
-   - [x] Core layer implementation
-   - [ ] Recurrent layer (to be implemented)
-
- ## Installation
-
- ```bash
- pip install nmn
- ```
-
- ## Usage Example (Flax)
-
- ```python
- from nmn.nnx.nmn import YatNMN
- from nmn.nnx.yatconv import YatConv
- # ... use as a Flax module ...
- ```
-
- ## Roadmap
- - [ ] Implement recurrent layers for all frameworks
- - [ ] Add more examples and benchmarks
- - [ ] Improve documentation and API consistency
-
- ## License
- GNU Affero General Public License v3