nmn 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
(Files without changes are omitted from this diff.)
nmn/nnx/nmn.py CHANGED
@@ -4,26 +4,18 @@ import typing as tp
 
 import jax
 import jax.numpy as jnp
-import numpy as np
 from jax import lax
-import opt_einsum
 
-from flax.core.frozen_dict import FrozenDict
 from flax import nnx
-from flax.nnx import rnglib, variablelib
-from flax.nnx.module import Module, first_from
+from flax.nnx import rnglib
+from flax.nnx.module import Module
 from flax.nnx.nn import dtypes, initializers
 from flax.typing import (
   Dtype,
-  Shape,
   Initializer,
   PrecisionLike,
   DotGeneralT,
-  ConvGeneralDilatedT,
-  PaddingLike,
-  LaxPadding,
   PromoteDtypeFn,
-  EinsumT,
 )
 
 Array = jax.Array
@@ -60,21 +52,26 @@ class YatNMN(Module):
     in_features: the number of input features.
     out_features: the number of output features.
     use_bias: whether to add a bias to the output (default: True).
+    use_alpha: whether to use alpha scaling (default: True).
+    use_dropconnect: whether to use DropConnect (default: False).
     dtype: the dtype of the computation (default: infer from input and params).
     param_dtype: the dtype passed to parameter initializers (default: float32).
     precision: numerical precision of the computation see ``jax.lax.Precision``
      for details.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
+    alpha_init: initializer function for the alpha.
    dot_general: dot product function.
    promote_dtype: function to promote the dtype of the arrays to the desired
      dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
      and a ``dtype`` keyword argument, and return a tuple of arrays with the
      promoted dtype.
+    epsilon: A small float added to the denominator to prevent division by zero.
+    drop_rate: dropout rate for DropConnect (default: 0.0).
    rngs: rng key.
  """
 
-  __data__ = ('kernel', 'bias')
+  __data__ = ('kernel', 'bias', 'alpha', 'dropconnect_key')
 
   def __init__(
     self,
@@ -83,6 +80,7 @@ class YatNMN(Module):
     *,
     use_bias: bool = True,
     use_alpha: bool = True,
+    use_dropconnect: bool = False,
     dtype: tp.Optional[Dtype] = None,
     param_dtype: Dtype = jnp.float32,
     precision: PrecisionLike = None,
@@ -91,8 +89,9 @@ class YatNMN(Module):
     alpha_init: Initializer = default_alpha_init,
     dot_general: DotGeneralT = lax.dot_general,
     promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
-    rngs: rnglib.Rngs,
     epsilon: float = 1e-5,
+    drop_rate: float = 0.0,
+    rngs: rnglib.Rngs,
   ):
 
     kernel_key = rngs.params()
@@ -117,6 +116,7 @@ class YatNMN(Module):
     self.out_features = out_features
     self.use_bias = use_bias
     self.use_alpha = use_alpha
+    self.use_dropconnect = use_dropconnect
     self.dtype = dtype
     self.param_dtype = param_dtype
     self.precision = precision
@@ -125,12 +125,19 @@ class YatNMN(Module):
     self.dot_general = dot_general
     self.promote_dtype = promote_dtype
     self.epsilon = epsilon
+    self.drop_rate = drop_rate
+
+    if use_dropconnect:
+      self.dropconnect_key = rngs.params()
+    else:
+      self.dropconnect_key = None
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
     Args:
       inputs: The nd-array to be transformed.
+      deterministic: If true, DropConnect is not applied (e.g., during inference).
 
     Returns:
       The transformed input.
@@ -139,6 +146,11 @@ class YatNMN(Module):
     bias = self.bias.value if self.bias is not None else None
     alpha = self.alpha.value if self.alpha is not None else None
 
+    if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
+      keep_prob = 1.0 - self.drop_rate
+      mask = jax.random.bernoulli(self.dropconnect_key, p=keep_prob, shape=kernel.shape)
+      kernel = (kernel * mask) / keep_prob
+
     inputs, kernel, bias, alpha = self.promote_dtype(
       (inputs, kernel, bias, alpha), dtype=self.dtype
     )
@@ -166,5 +178,4 @@ class YatNMN(Module):
       scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
       y = y * scale
 
-
     return y
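The DropConnect added to `YatNMN` masks the kernel itself (not the activations) and rescales the survivors by `1/keep_prob`, the usual inverted-dropout convention. Note that the mask is drawn from `self.dropconnect_key`, which is sampled once at construction, so the same mask is reused across calls unless that key is refreshed. A minimal, self-contained sketch of the masking step (the function name here is illustrative, not part of the package):

```python
import jax
import jax.numpy as jnp

def dropconnect_kernel(key, kernel, drop_rate, deterministic=False):
    """Zero out kernel entries with probability drop_rate and rescale the rest."""
    if deterministic or drop_rate <= 0.0:
        return kernel
    keep_prob = 1.0 - drop_rate
    mask = jax.random.bernoulli(key, p=keep_prob, shape=kernel.shape)
    return (kernel * mask) / keep_prob

kernel = jnp.ones((3, 4))
masked = dropconnect_kernel(jax.random.key(0), kernel, drop_rate=0.5)
print(masked)  # surviving entries are scaled to 2.0, dropped ones are 0.0
```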
nmn/nnx/squashers/__init__.py ADDED
@@ -0,0 +1,9 @@
+from .softermax import softermax
+from .softer_sigmoid import softer_sigmoid
+from .soft_tanh import soft_tanh
+
+__all__ = [
+    "softermax",
+    "softer_sigmoid",
+    "soft_tanh",
+]
nmn/nnx/squashers/soft_tanh.py ADDED
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+from jax import Array
+
+def soft_tanh(
+    x: Array,
+    n: float = 1.0,
+) -> Array:
+    """
+    Maps a non-negative score to the range [-1, 1) using the soft-tanh function.
+
+    The soft-tanh function is defined as:
+    .. math::
+        \\text{soft-tanh}_n(x) = \\frac{x^n - 1}{1 + x^n}
+
+    The power `n` again controls the transition sharpness: higher `n` makes the
+    function approach 1 more quickly for large `x`.
+
+    Args:
+        x (Array): A JAX array of non-negative scores (x >= 0).
+        n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+    Returns:
+        Array: The mapped scores in the range [-1, 1).
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    return (x_n - 1.0) / (1.0 + x_n)
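A quick numeric check of `soft_tanh` (a sketch assuming the `nmn.nnx.squashers` package layout added in this release):

```python
import jax.numpy as jnp
from nmn.nnx.squashers import soft_tanh

x = jnp.array([0.0, 1.0, 4.0, 100.0])
print(soft_tanh(x))         # [-1.0, 0.0, 0.6, ~0.98]: monotone, bounded in [-1, 1)
print(soft_tanh(x, n=2.0))  # sharper transition: approaches 1 faster for large x
```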
nmn/nnx/squashers/softer_sigmoid.py ADDED
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+from jax import Array
+
+def softer_sigmoid(
+    x: Array,
+    n: float = 1.0,
+) -> Array:
+    """
+    Squashes a non-negative score into the range [0, 1) using the soft-sigmoid function.
+
+    The soft-sigmoid function is defined as:
+    .. math::
+        \\text{soft-sigmoid}_n(x) = \\frac{x^n}{1 + x^n}
+
+    The power `n` modulates the softness: higher `n` makes the function approach
+    1 faster for large `x`, while `n < 1` makes the approach slower.
+
+    Args:
+        x (Array): A JAX array of non-negative scores (x >= 0).
+        n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+    Returns:
+        Array: The squashed scores in the range [0, 1).
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    return x_n / (1.0 + x_n)
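`softer_sigmoid` relates to `soft_tanh` the same way the logistic sigmoid relates to tanh: `soft_tanh_n(x) = 2 * softer_sigmoid_n(x) - 1`. A short check under the same assumed package layout:

```python
import jax.numpy as jnp
from nmn.nnx.squashers import soft_tanh, softer_sigmoid

x = jnp.array([0.0, 0.5, 1.0, 10.0])
s = softer_sigmoid(x)  # [0.0, 0.333, 0.5, 0.909]: bounded in [0, 1)
print(jnp.allclose(2.0 * s - 1.0, soft_tanh(x)))  # True
```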
nmn/nnx/squashers/softermax.py ADDED
@@ -0,0 +1,38 @@
+import jax.numpy as jnp
+from jax import Array
+from typing import Optional
+
+def softermax(
+    x: Array,
+    n: float = 1.0,
+    epsilon: float = 1e-12,
+    axis: Optional[int] = -1,
+) -> Array:
+    """
+    Normalizes a set of non-negative scores using the Softermax function.
+
+    The Softermax function is defined as:
+    .. math::
+        \\text{softermax}_n(x_k, \\{x_i\\}) = \\frac{x_k^n}{\\epsilon + \\sum_i x_i^n}
+
+    The power `n` controls the sharpness of the distribution: `n=1` recovers
+    the original Softermax, while `n > 1` makes the distribution harder (more
+    peaked), and `0 < n < 1` makes it softer.
+
+    Args:
+        x (Array): A JAX array of non-negative scores.
+        n (float, optional): The power to raise each score to. Defaults to 1.0.
+        epsilon (float, optional): A small constant for numerical stability.
+            Defaults to 1e-12.
+        axis (Optional[int], optional): The axis to perform the sum over.
+            Defaults to -1.
+
+    Returns:
+        Array: The normalized scores.
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    sum_x_n = jnp.sum(x_n, axis=axis, keepdims=True)
+    return x_n / (epsilon + sum_x_n)
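`softermax` divides by `epsilon + sum`, so each row sums to just under one, and larger `n` concentrates mass on the biggest score. A small sketch (same assumed import path as above):

```python
import jax.numpy as jnp
from nmn.nnx.squashers import softermax

scores = jnp.array([[1.0, 2.0, 4.0],
                    [0.5, 0.5, 0.5]])
print(softermax(scores))                # rows ~[0.143, 0.286, 0.571] and ~[0.333, 0.333, 0.333]
print(softermax(scores, n=2.0))         # ~[0.048, 0.190, 0.762]: more peaked
print(softermax(scores).sum(axis=-1))   # ~1.0 per row (slightly less, due to epsilon)
```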
nmn/nnx/yatattention.py CHANGED
@@ -26,7 +26,8 @@ from flax.typing import (
   DotGeneralT,
 )
 
-
+from nmn.nnx.nmn import YatNMN
+from jax import Array
 
 def yat_attention_weights(
   query: Array,
@@ -153,91 +154,6 @@ def yat_attention(
     '...hqk,...khd->...qhd', attn_weights, value, precision=precision
   )
 
-Array = jax.Array
-
-# Add YatNMN class implementation
-default_bias_init = initializers.zeros_init()
-default_alpha_init = initializers.ones_init()
-
-class YatNMN(Module):
-  """A linear transformation with custom distance-based computation."""
-
-  def __init__(
-    self,
-    in_features: int,
-    out_features: int,
-    *,
-    use_bias: bool = True,
-    use_alpha: bool = True,
-    dtype: Optional[Dtype] = None,
-    param_dtype: Dtype = jnp.float32,
-    precision: PrecisionLike = None,
-    kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = default_bias_init,
-    alpha_init: Initializer = default_alpha_init,
-    dot_general: DotGeneralT = lax.dot_general,
-    rngs: rnglib.Rngs,
-    epsilon: float = 1e-5,
-  ):
-
-    kernel_key = rngs.params()
-    self.kernel = nnx.Param(
-      kernel_init(kernel_key, (in_features, out_features), param_dtype)
-    )
-    self.bias: nnx.Param[jax.Array] | None
-    if use_bias:
-      bias_key = rngs.params()
-      self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
-    else:
-      self.bias = None
-
-    self.alpha: nnx.Param[jax.Array] | None
-    if use_alpha:
-      alpha_key = rngs.params()
-      self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
-    else:
-      self.alpha = None
-
-    self.in_features = in_features
-    self.out_features = out_features
-    self.use_bias = use_bias
-    self.use_alpha = use_alpha
-    self.dtype = dtype
-    self.param_dtype = param_dtype
-    self.precision = precision
-    self.kernel_init = kernel_init
-    self.bias_init = bias_init
-    self.dot_general = dot_general
-    self.epsilon = epsilon
-
-  def __call__(self, inputs: Array) -> Array:
-    """Applies YatNMN transformation to inputs."""
-    kernel = self.kernel.value
-    bias = self.bias.value if self.bias is not None else None
-    alpha = self.alpha.value if self.alpha is not None else None
-
-    y = self.dot_general(
-      inputs,
-      kernel,
-      (((inputs.ndim - 1,), (0,)), ((), ())),
-      precision=self.precision,
-    )
-
-    inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
-    kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
-    distances = inputs_squared_sum + kernel_squared_sum - 2 * y
-
-    # Element-wise operation
-    y = y ** 2 / (distances + self.epsilon)
-
-    if bias is not None:
-      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
-
-    if alpha is not None:
-      scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
-      y = y * scale
-
-    return y
 
 
 def dot_product_attention_weights(
@@ -435,6 +351,10 @@ class MultiHeadAttention(Module):
     attention_fn: Callable[..., Array] = yat_attention,
     decode: bool | None = None,
     normalize_qk: bool = False,
+    use_alpha: bool = True,
+    alpha_init: Initializer = initializers.ones_init(),
+    use_dropconnect: bool = False,
+    dropconnect_rate: float = 0.0,
     # Deprecated, will be removed.
     qkv_dot_general: DotGeneralT | None = None,
     out_dot_general: DotGeneralT | None = None,
@@ -470,6 +390,10 @@ class MultiHeadAttention(Module):
     self.qkv_dot_general_cls = qkv_dot_general_cls
     self.out_dot_general_cls = out_dot_general_cls
     self.epsilon = epsilon
+    self.use_alpha = use_alpha
+    self.alpha_init = alpha_init
+    self.use_dropconnect = use_dropconnect
+    self.dropconnect_rate = dropconnect_rate
 
     if self.qkv_features % self.num_heads != 0:
       raise ValueError(
@@ -491,6 +415,10 @@ class MultiHeadAttention(Module):
       use_bias=self.use_bias,
       precision=self.precision,
       epsilon=self.epsilon,
+      use_alpha=self.use_alpha,
+      alpha_init=self.alpha_init,
+      use_dropconnect=self.use_dropconnect,
+      drop_rate=self.dropconnect_rate,
     )
 
     # project inputs_q to multi-headed q/k/v
@@ -590,10 +518,23 @@ class MultiHeadAttention(Module):
         f'but module expects {self.in_features}.'
       )
 
+    is_deterministic: bool = False
+    if self.dropout_rate > 0.0 or (
+      self.use_dropconnect and self.dropconnect_rate > 0.0
+    ):
+      is_deterministic = first_from(
+        deterministic,
+        self.deterministic,
+        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
+          as either a __call__ argument, class attribute, or nnx.flag.""",
+      )
+    else:
+      is_deterministic = True
+
     # Apply YatNMN transformations and reshape to multi-head format
-    query = squash(self.query(inputs_q))
-    key = squash(self.key(inputs_k))
-    value = squash(self.value(inputs_v))
+    query = self.query(inputs_q, deterministic=is_deterministic)
+    key = self.key(inputs_k, deterministic=is_deterministic)
+    value = self.value(inputs_v, deterministic=is_deterministic)
 
     # Reshape from [batch..., length, qkv_features] to [batch..., length, num_heads, head_dim]
     query = query.reshape(query.shape[:-1] + (self.num_heads, self.head_dim))
@@ -660,26 +601,11 @@ class MultiHeadAttention(Module):
       ),
     )
 
-    if (
-      self.dropout_rate > 0.0
-    ):  # Require `deterministic` only if using dropout.
-      deterministic = first_from(
-        deterministic,
-        self.deterministic,
-        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
-          as either a __call__ argument, class attribute, or nnx.flag.""",
-      )
-      if not deterministic:
-        if rngs is None:
-          raise ValueError(
-            "'rngs' must be provided if 'dropout_rng' is not given."
-          )
-        dropout_rng = rngs.dropout()
-      else:
-        dropout_rng = None
-    else:
-      deterministic = True
-      dropout_rng = None
+    dropout_rng = None
+    if self.dropout_rate > 0.0 and not is_deterministic:
+      if rngs is None:
+        raise ValueError("'rngs' must be provided for dropout.")
+      dropout_rng = rngs.dropout()
 
     # apply attention with epsilon parameter for YatNMN
     x = self.attention_fn(
@@ -690,7 +616,7 @@ class MultiHeadAttention(Module):
       dropout_rng=dropout_rng,
       dropout_rate=self.dropout_rate,
       broadcast_dropout=self.broadcast_dropout,
-      deterministic=deterministic,
+      deterministic=is_deterministic,
       dtype=self.dtype,
       precision=self.precision,
       module=self if sow_weights else None,
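With these changes, `deterministic` now gates both attention dropout and the DropConnect inside the q/k/v `YatNMN` projections. A usage sketch, not a tested snippet: the flags shown (`use_dropconnect`, `dropconnect_rate`, `dropout_rate`, `decode`) appear in this diff, while the remaining constructor and call arguments are assumed to mirror `flax.nnx.MultiHeadAttention`:

```python
import jax
import jax.numpy as jnp
from flax import nnx
from nmn.nnx.yatattention import MultiHeadAttention

attn = MultiHeadAttention(
    num_heads=4,
    in_features=64,
    qkv_features=64,
    dropout_rate=0.0,
    use_dropconnect=True,
    dropconnect_rate=0.1,
    decode=False,
    rngs=nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(1)),
)
x = jnp.ones((2, 16, 64))
y_train = attn(x, deterministic=False)  # DropConnect active in the q/k/v projections
y_eval = attn(x, deterministic=True)    # masks disabled for inference
print(y_train.shape, y_eval.shape)
```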
nmn/nnx/yatconv.py CHANGED
@@ -110,6 +110,8 @@ class YatConv(Module):
     feature_group_count: integer, default 1. If specified divides the input
       features into groups.
     use_bias: whether to add a bias to the output (default: True).
+    use_alpha: whether to use alpha scaling (default: True).
+    use_dropconnect: whether to use DropConnect (default: False).
     mask: Optional mask for the weights during masked convolution. The mask must
       be the same shape as the convolution weight matrix.
     dtype: the dtype of the computation (default: infer from input and params).
@@ -123,10 +125,11 @@ class YatConv(Module):
      and a ``dtype`` keyword argument, and return a tuple of arrays with the
      promoted dtype.
    epsilon: A small float added to the denominator to prevent division by zero.
+    drop_rate: dropout rate for DropConnect (default: 0.0).
    rngs: rng key.
  """
 
-  __data__ = ('kernel', 'bias', 'mask')
+  __data__ = ('kernel', 'bias', 'mask', 'dropconnect_key')
 
   def __init__(
     self,
@@ -142,6 +145,7 @@ class YatConv(Module):
 
     use_bias: bool = True,
     use_alpha: bool = True,
+    use_dropconnect: bool = False,
     kernel_init: Initializer = default_kernel_init,
     bias_init: Initializer = default_bias_init,
     alpha_init: Initializer = default_alpha_init,
@@ -153,6 +157,7 @@ class YatConv(Module):
     conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
     promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
     epsilon: float = 1e-5,
+    drop_rate: float = 0.0,
     rngs: rnglib.Rngs,
   ):
     if isinstance(kernel_size, int):
@@ -185,6 +190,7 @@ class YatConv(Module):
     self.feature_group_count = feature_group_count
     self.use_bias = use_bias
     self.use_alpha = use_alpha
+    self.use_dropconnect = use_dropconnect
 
     self.mask = mask
     self.dtype = dtype
@@ -195,6 +201,7 @@ class YatConv(Module):
     self.conv_general_dilated = conv_general_dilated
     self.promote_dtype = promote_dtype
     self.epsilon = epsilon
+    self.drop_rate = drop_rate
 
     if use_alpha:
       alpha_key = rngs.params()
@@ -202,8 +209,12 @@ class YatConv(Module):
     else:
       self.alpha = None
 
+    if use_dropconnect:
+      self.dropconnect_key = rngs.params()
+    else:
+      self.dropconnect_key = None
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
     assert isinstance(self.kernel_size, tuple)
 
     def maybe_broadcast(
@@ -261,6 +272,12 @@ class YatConv(Module):
 
     kernel_val = self.kernel.value
 
+    # Apply DropConnect if enabled and not in deterministic mode
+    if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
+      keep_prob = 1.0 - self.drop_rate
+      mask = jax.random.bernoulli(self.dropconnect_key, p=keep_prob, shape=kernel_val.shape)
+      kernel_val = (kernel_val * mask) / keep_prob
+
     current_mask = self.mask
     if current_mask is not None:
       if current_mask.shape != self.kernel_shape:
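The convolutional layer gets the same treatment: the kernel is masked and rescaled before the convolution whenever DropConnect is enabled and the call is not deterministic. A hedged sketch, reusing the constructor arguments from the README example plus the flags added in this diff:

```python
import jax
import jax.numpy as jnp
from flax import nnx
from nmn.nnx.yatconv import YatConv

conv = YatConv(
    in_features=3,
    out_features=8,
    kernel_size=(3, 3),
    use_dropconnect=True,
    drop_rate=0.2,
    rngs=nnx.Rngs(params=jax.random.key(0)),
)
x = jnp.ones((1, 28, 28, 3))
y_train = conv(x, deterministic=False)  # kernel entries randomly dropped and rescaled
y_eval = conv(x, deterministic=True)    # full kernel used
print(y_train.shape, y_eval.shape)
```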
nmn-0.1.6.dist-info/METADATA ADDED
@@ -0,0 +1,176 @@
+Metadata-Version: 2.4
+Name: nmn
+Version: 0.1.6
+Summary: a neuron that matter
+Project-URL: Homepage, https://github.com/mlnomadpy/nmn
+Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
+Author-email: Taha Bouhsine <yat@mlnomads.com>
+License-File: LICENSE
+Classifier: License :: OSI Approved :: GNU Affero General Public License v3
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+
+# nmn
+Not the neurons we want, but the neurons we need
+
+[![PyPI version](https://img.shields.io/pypi/v/nmn.svg)](https://pypi.org/project/nmn/)
+[![Downloads](https://static.pepy.tech/badge/nmn)](https://pepy.tech/project/nmn)
+[![Downloads/month](https://static.pepy.tech/badge/nmn/month)](https://pepy.tech/project/nmn)
+[![GitHub stars](https://img.shields.io/github/stars/mlnomadpy/nmn?style=social)](https://github.com/mlnomadpy/nmn)
+[![GitHub forks](https://img.shields.io/github/forks/mlnomadpy/nmn?style=social)](https://github.com/mlnomadpy/nmn)
+[![GitHub issues](https://img.shields.io/github/issues/mlnomadpy/nmn)](https://github.com/mlnomadpy/nmn/issues)
+[![PyPI - License](https://img.shields.io/pypi/l/nmn)](https://pypi.org/project/nmn/)
+[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/nmn)](https://pypi.org/project/nmn/)
+
+## Features
+
+* **Activation-Free Non-linearity:** Learns complex, non-linear relationships without separate activation functions.
+* **Multiple Frameworks:** Supports Flax (Linen & NNX), Keras, PyTorch, and TensorFlow.
+* **Yat-Product & Yat-Conv:** Implements novel Yat-Product and Yat-Conv operations.
+* **Inspired by Research:** Based on the principles from "Deep Learning 2.0/2.1: Artificial Neurons that Matter".
+
+## Overview
+
+**nmn** provides neural network layers for multiple frameworks (Flax, NNX, Keras, PyTorch, TensorFlow) that do not require activation functions to learn non-linearity. The main goal is to enable deep learning architectures where the layer itself is inherently non-linear, inspired by the papers:
+
+> Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
+>
+> Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks
+
+## Math
+
+Yat-Product:
+$$
+ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{((\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w}))^2 + \epsilon}.
+$$
+
+**Explanation:**
+- $\mathbf{w}$ is the weight vector, $\mathbf{x}$ is the input vector.
+- $\langle \mathbf{w}, \mathbf{x} \rangle$ is the dot product between $\mathbf{w}$ and $\mathbf{x}$.
+- $\|\mathbf{w} - \mathbf{x}\|^2$ is the squared Euclidean distance between $\mathbf{w}$ and $\mathbf{x}$.
+- $\epsilon$ is a small constant for numerical stability.
+- $\theta$ is the angle between $\mathbf{w}$ and $\mathbf{x}$.
+
+This operation:
+- **Numerator:** Squares the similarity (dot product) between $\mathbf{w}$ and $\mathbf{x}$, emphasizing strong alignments.
+- **Denominator:** Penalizes large distances, so the response is high only when $\mathbf{w}$ and $\mathbf{x}$ are both similar in direction and close in space.
+- **No activation needed:** The non-linearity is built into the operation itself, allowing the layer to learn complex, non-linear relationships without a separate activation function.
+- **Geometric view:** The output is maximized when $\mathbf{w}$ and $\mathbf{x}$ are both large in norm, closely aligned (small $\theta$), and close together in Euclidean space.
+
+Yat-Conv:
+$$
+ⵟ^*(\mathbf{W}, \mathbf{X}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon}
+= \frac{\left(\sum_{i,j} w_{ij} x_{ij}\right)^2}{\sum_{i,j} (w_{ij} - x_{ij})^2 + \epsilon}
+$$
+
+Where:
+- $\mathbf{W}$ and $\mathbf{X}$ are local patches (e.g., kernel and input patch in convolution)
+- $w_{ij}$ and $x_{ij}$ are elements of the kernel and input patch, respectively
+- $\epsilon$ is a small constant for numerical stability
+
+This generalizes the Yat-product to convolutional (patch-wise) operations.
+
+
+## Supported Frameworks & API
+
+The `YatNMN` layer (for dense operations) and `YatConv` (for convolutional operations) are the core components. Below is a summary of their availability and features per framework:
+
+| Framework | `YatNMN` Path | `YatConv` Path | Core Layer | DropConnect | Ternary Network | Recurrent Layer |
+|----------------|-------------------------------|-------------------------------|------------|-------------|-----------------|-----------------|
+| **Flax (Linen)** | `src/nmn/linen/nmn.py` | (Available) | ✅ | | | 🚧 |
+| **Flax (NNX)** | `src/nmn/nnx/nmn.py` | `src/nmn/nnx/yatconv.py` | ✅ | ✅ | 🚧 | 🚧 |
+| **Keras** | `src/nmn/keras/nmn.py` | (Available) | ✅ | | | 🚧 |
+| **PyTorch** | `src/nmn/torch/nmn.py` | (Available) | ✅ | | | 🚧 |
+| **TensorFlow** | `src/nmn/tf/nmn.py` | (Available) | ✅ | | | 🚧 |
+
+*Legend: ✅ Implemented, 🚧 To be implemented / In Progress, (Available) - Assumed available if NMN is, specific path might vary or be part of the NMN module.*
+
+## Installation
+
+```bash
+pip install nmn
+```
+
+## Usage Example (Flax NNX)
+
+```python
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from nmn.nnx.nmn import YatNMN
+from nmn.nnx.yatconv import YatConv
+
+# Example YatNMN (Dense Layer)
+model_key, param_key, drop_key, input_key = jax.random.split(jax.random.key(0), 4)
+in_features, out_features = 3, 4
+layer = YatNMN(in_features=in_features, out_features=out_features, rngs=nnx.Rngs(params=param_key, dropout=drop_key))
+dummy_input = jax.random.normal(input_key, (2, in_features)) # Batch size 2
+output = layer(dummy_input)
+print("YatNMN Output Shape:", output.shape)
+
+# Example YatConv (Convolutional Layer)
+conv_key, conv_param_key, conv_input_key = jax.random.split(jax.random.key(1), 3)
+in_channels, out_channels = 3, 8
+kernel_size = (3, 3)
+conv_layer = YatConv(
+    in_features=in_channels,
+    out_features=out_channels,
+    kernel_size=kernel_size,
+    rngs=nnx.Rngs(params=conv_param_key)
+)
+dummy_conv_input = jax.random.normal(conv_input_key, (1, 28, 28, in_channels)) # Batch 1, 28x28 image, in_channels
+conv_output = conv_layer(dummy_conv_input)
+print("YatConv Output Shape:", conv_output.shape)
+
+```
+*Note: Examples for other frameworks (Keras, PyTorch, TensorFlow, Flax Linen) can be found in their respective `nmn.<framework>` modules and upcoming documentation.*
+
+## Roadmap
+
+- [ ] Implement recurrent layers (`YatRNN`, `YatLSTM`, `YatGRU`) for all supported frameworks.
+- [ ] Develop Ternary Network versions of Yat layers for NNX.
+- [ ] Add more comprehensive examples and benchmark scripts for various tasks (vision, language).
+- [ ] Publish detailed documentation and API references.
+- [ ] Conduct and publish thorough performance benchmarks against traditional layers.
+
+## Contributing
+
+Contributions are welcome! If you'd like to contribute, please feel free to:
+- Open an issue on the [Bug Tracker](https://github.com/mlnomadpy/nmn/issues) to report bugs or suggest features.
+- Submit a pull request with your improvements.
+- Help expand the documentation or add more examples.
+
+## License
+
+This project is licensed under the **GNU Affero General Public License v3**. See the [LICENSE](LICENSE) file for details.
+
+## Citation
+
+If you use `nmn` in your research, please consider citing the original papers that inspired this work:
+
+> Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
+>
+> Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks
+
+A BibTeX entry will be provided once the accompanying paper for this library is published.
+
+## Citing
+
+If you use this work, please cite the paper:
+
+```bibtex
+@article{taha2024dl2,
+  author = {Taha Bouhsine},
+  title = {Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality},
+}
+```
+
+
+```bibtex
+@article{taha2025dl2,
+  author = {Taha Bouhsine},
+  title = {Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks},
+}
+```
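As a concrete check of the Yat-product definition in the Math section of the README above, ⵟ(w, x) = ⟨w, x⟩² / (‖w - x‖² + ε), here is a minimal sketch computed directly with `jax.numpy` (the helper name is illustrative, not part of the package):

```python
import jax.numpy as jnp

def yat_product(w, x, epsilon=1e-5):
    """Squared dot product divided by the squared Euclidean distance plus epsilon."""
    return jnp.dot(w, x) ** 2 / (jnp.sum((w - x) ** 2) + epsilon)

w = jnp.array([1.0, 2.0, 0.0])
x = jnp.array([1.0, 1.0, 1.0])
print(yat_product(w, x))  # 3**2 / (0 + 1 + 1 + eps) ≈ 4.5
```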
nmn-0.1.6.dist-info/RECORD ADDED
@@ -0,0 +1,19 @@
+nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
+nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
+nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
+nmn/nnx/nmn.py,sha256=tPNUtF8Lmv_B1TgMoVXfMQ9x0IPGKjSyAP6HnZ-YBsM,5651
+nmn/nnx/yatattention.py,sha256=i6XfCGHISyb2P6KrgYFnhhdzqSTWAyshFhy1XEeuEWc,24642
+nmn/nnx/yatconv.py,sha256=EOAAWfuv5QA-QTru-JyYKYNoGqxcklu7ph9a-CtmYsA,13123
+nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
+nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
+nmn/nnx/loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nmn/nnx/squashers/__init__.py,sha256=zXYPa3yzqMXxkIPvNHiaV6pcZRDOdVrzaVdYVDGALTY,180
+nmn/nnx/squashers/soft_tanh.py,sha256=WSJkxD6L9WU1eqPwsK2AW4V6OJbw5pSWYjKwkiWtLdo,812
+nmn/nnx/squashers/softer_sigmoid.py,sha256=vE6IWorZdBb2cww6fskARnwzdjTcWB2kKohuaJWVGNs,845
+nmn/nnx/squashers/softermax.py,sha256=NfxEDbogLUysyTvtVCTpDt27PplYvKRQLTZbYCL-Wfg,1226
+nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
+nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
+nmn-0.1.6.dist-info/METADATA,sha256=Y9MByC16wz1MGYVZRmZA0wJQATB6Kj6w6TOL5lPzl0Q,8800
+nmn-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+nmn-0.1.6.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
+nmn-0.1.6.dist-info/RECORD,,
nmn-0.1.4.dist-info/METADATA DELETED
@@ -1,119 +0,0 @@
-Metadata-Version: 2.4
-Name: nmn
-Version: 0.1.4
-Summary: a neuron that matter
-Project-URL: Homepage, https://github.com/mlnomadpy/nmn
-Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
-Author-email: Taha Bouhsine <yat@mlnomads.com>
-License-File: LICENSE
-Classifier: License :: OSI Approved :: GNU Affero General Public License v3
-Classifier: Operating System :: OS Independent
-Classifier: Programming Language :: Python :: 3
-Requires-Python: >=3.8
-Description-Content-Type: text/markdown
-
-# nmn
-Not the neurons we want, but the neurons we need
-
-[![PyPI version](https://img.shields.io/pypi/v/nmn.svg)](https://pypi.org/project/nmn/)
-[![Downloads](https://static.pepy.tech/badge/nmn)](https://pepy.tech/project/nmn)
-[![Downloads/month](https://static.pepy.tech/badge/nmn/month)](https://pepy.tech/project/nmn)
-[![GitHub stars](https://img.shields.io/github/stars/mlnomadpy/nmn?style=social)](https://github.com/mlnomadpy/nmn)
-[![GitHub forks](https://img.shields.io/github/forks/mlnomadpy/nmn?style=social)](https://github.com/mlnomadpy/nmn)
-[![GitHub issues](https://img.shields.io/github/issues/mlnomadpy/nmn)](https://github.com/mlnomadpy/nmn/issues)
-[![PyPI - License](https://img.shields.io/pypi/l/nmn)](https://pypi.org/project/nmn/)
-[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/nmn)](https://pypi.org/project/nmn/)
-
-## Overview
-
-**nmn** provides neural network layers for multiple frameworks (Flax, NNX, Keras, PyTorch, TensorFlow) that do not require activation functions to learn non-linearity. The main goal is to enable deep learning architectures where the layer itself is inherently non-linear, inspired by the paper:
-
-> Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
-
-## Math
-
-Yat-Product:
-$$
-ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{((\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w}))^2 + \epsilon}.
-$$
-
-**Explanation:**
-- $\mathbf{w}$ is the weight vector, $\mathbf{x}$ is the input vector.
-- $\langle \mathbf{w}, \mathbf{x} \rangle$ is the dot product between $\mathbf{w}$ and $\mathbf{x}$.
-- $\|\mathbf{w} - \mathbf{x}\|^2$ is the squared Euclidean distance between $\mathbf{w}$ and $\mathbf{x}$.
-- $\epsilon$ is a small constant for numerical stability.
-- $\theta$ is the angle between $\mathbf{w}$ and $\mathbf{x}$.
-
-This operation:
-- **Numerator:** Squares the similarity (dot product) between $\mathbf{w}$ and $\mathbf{x}$, emphasizing strong alignments.
-- **Denominator:** Penalizes large distances, so the response is high only when $\mathbf{w}$ and $\mathbf{x}$ are both similar in direction and close in space.
-- **No activation needed:** The non-linearity is built into the operation itself, allowing the layer to learn complex, non-linear relationships without a separate activation function.
-- **Geometric view:** The output is maximized when $\mathbf{w}$ and $\mathbf{x}$ are both large in norm, closely aligned (small $\theta$), and close together in Euclidean space.
-
-Yat-Conv:
-$$
-ⵟ^*(\mathbf{W}, \mathbf{X}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon}
-= \frac{\left(\sum_{i,j} w_{ij} x_{ij}\right)^2}{\sum_{i,j} (w_{ij} - x_{ij})^2 + \epsilon}
-$$
-
-Where:
-- $\mathbf{W}$ and $\mathbf{X}$ are local patches (e.g., kernel and input patch in convolution)
-- $w_{ij}$ and $x_{ij}$ are elements of the kernel and input patch, respectively
-- $\epsilon$ is a small constant for numerical stability
-
-This generalizes the Yat-product to convolutional (patch-wise) operations.
-
-
-## Supported Frameworks & Tasks
-
-### Flax (JAX)
-- `YatNMN` layer implemented in `src/nmn/linen/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### NNX (Flax NNX)
-- `YatNMN` layer implemented in `src/nmn/nnx/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### Keras
-- `YatNMN` layer implemented in `src/nmn/keras/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### PyTorch
-- `YatNMN` layer implemented in `src/nmn/torch/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### TensorFlow
-- `YatNMN` layer implemented in `src/nmn/tf/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-## Installation
-
-```bash
-pip install nmn
-```
-
-## Usage Example (Flax)
-
-```python
-from nmn.nnx.nmn import YatNMN
-from nmn.nnx.yatconv import YatConv
-# ... use as a Flax module ...
-```
-
-## Roadmap
-- [ ] Implement recurrent layers for all frameworks
-- [ ] Add more examples and benchmarks
-- [ ] Improve documentation and API consistency
-
-## License
-GNU Affero General Public License v3
nmn-0.1.4.dist-info/RECORD DELETED
@@ -1,14 +0,0 @@
-nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
-nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
-nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
-nmn/nnx/nmn.py,sha256=gWe8EL-aUm7be03M9O5R3XdBb92EpBEFsylrY6BA60c,4871
-nmn/nnx/yatattention.py,sha256=chjtUKJtaR7ROPnNqkicbvMs7hzZKE0fIo_8cTNiju8,26601
-nmn/nnx/yatconv.py,sha256=xUH9NBY1fIDZeTA9GdgmqR_DJiQJgwU2uDrgxqirKmU,12308
-nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
-nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
-nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
-nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
-nmn-0.1.4.dist-info/METADATA,sha256=k28p055Dr6WWVQcb01uinFRiT5R-CAvdKz33fqZ85g4,5032
-nmn-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-nmn-0.1.4.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
-nmn-0.1.4.dist-info/RECORD,,