nmn 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nmn/nnx/loss/__init__.py +0 -0
- nmn/nnx/nmn.py +25 -14
- nmn/nnx/squashers/__init__.py +9 -0
- nmn/nnx/squashers/soft_tanh.py +29 -0
- nmn/nnx/squashers/softer_sigmoid.py +29 -0
- nmn/nnx/squashers/softermax.py +38 -0
- nmn/nnx/yatattention.py +36 -110
- nmn/nnx/yatconv.py +19 -2
- nmn-0.1.6.dist-info/METADATA +176 -0
- nmn-0.1.6.dist-info/RECORD +19 -0
- nmn-0.1.4.dist-info/METADATA +0 -119
- nmn-0.1.4.dist-info/RECORD +0 -14
- {nmn-0.1.4.dist-info → nmn-0.1.6.dist-info}/WHEEL +0 -0
- {nmn-0.1.4.dist-info → nmn-0.1.6.dist-info}/licenses/LICENSE +0 -0
nmn/nnx/loss/__init__.py
ADDED
File without changes
nmn/nnx/nmn.py
CHANGED
@@ -4,26 +4,18 @@ import typing as tp

import jax
import jax.numpy as jnp
-import numpy as np
from jax import lax
-import opt_einsum

-from flax.core.frozen_dict import FrozenDict
from flax import nnx
-from flax.nnx import rnglib
-from flax.nnx.module import Module
+from flax.nnx import rnglib
+from flax.nnx.module import Module
from flax.nnx.nn import dtypes, initializers
from flax.typing import (
  Dtype,
-  Shape,
  Initializer,
  PrecisionLike,
  DotGeneralT,
-  ConvGeneralDilatedT,
-  PaddingLike,
-  LaxPadding,
  PromoteDtypeFn,
-  EinsumT,
)

Array = jax.Array
@@ -60,21 +52,26 @@ class YatNMN(Module):
    in_features: the number of input features.
    out_features: the number of output features.
    use_bias: whether to add a bias to the output (default: True).
+    use_alpha: whether to use alpha scaling (default: True).
+    use_dropconnect: whether to use DropConnect (default: False).
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    precision: numerical precision of the computation see ``jax.lax.Precision``
      for details.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
+    alpha_init: initializer function for the alpha.
    dot_general: dot product function.
    promote_dtype: function to promote the dtype of the arrays to the desired
      dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
      and a ``dtype`` keyword argument, and return a tuple of arrays with the
      promoted dtype.
+    epsilon: A small float added to the denominator to prevent division by zero.
+    drop_rate: dropout rate for DropConnect (default: 0.0).
    rngs: rng key.
  """

-  __data__ = ('kernel', 'bias')
+  __data__ = ('kernel', 'bias', 'alpha', 'dropconnect_key')

  def __init__(
    self,
@@ -83,6 +80,7 @@ class YatNMN(Module):
    *,
    use_bias: bool = True,
    use_alpha: bool = True,
+    use_dropconnect: bool = False,
    dtype: tp.Optional[Dtype] = None,
    param_dtype: Dtype = jnp.float32,
    precision: PrecisionLike = None,
@@ -91,8 +89,9 @@ class YatNMN(Module):
    alpha_init: Initializer = default_alpha_init,
    dot_general: DotGeneralT = lax.dot_general,
    promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
-    rngs: rnglib.Rngs,
    epsilon: float = 1e-5,
+    drop_rate: float = 0.0,
+    rngs: rnglib.Rngs,
  ):

    kernel_key = rngs.params()
@@ -117,6 +116,7 @@ class YatNMN(Module):
    self.out_features = out_features
    self.use_bias = use_bias
    self.use_alpha = use_alpha
+    self.use_dropconnect = use_dropconnect
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.precision = precision
@@ -125,12 +125,19 @@ class YatNMN(Module):
    self.dot_general = dot_general
    self.promote_dtype = promote_dtype
    self.epsilon = epsilon
+    self.drop_rate = drop_rate
+
+    if use_dropconnect:
+      self.dropconnect_key = rngs.params()
+    else:
+      self.dropconnect_key = None

-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
    """Applies a linear transformation to the inputs along the last dimension.

    Args:
      inputs: The nd-array to be transformed.
+      deterministic: If true, DropConnect is not applied (e.g., during inference).

    Returns:
      The transformed input.
@@ -139,6 +146,11 @@ class YatNMN(Module):
    bias = self.bias.value if self.bias is not None else None
    alpha = self.alpha.value if self.alpha is not None else None

+    if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
+      keep_prob = 1.0 - self.drop_rate
+      mask = jax.random.bernoulli(self.dropconnect_key, p=keep_prob, shape=kernel.shape)
+      kernel = (kernel * mask) / keep_prob
+
    inputs, kernel, bias, alpha = self.promote_dtype(
      (inputs, kernel, bias, alpha), dtype=self.dtype
    )
@@ -166,5 +178,4 @@ class YatNMN(Module):
    scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
    y = y * scale

-
    return y
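*Note: a minimal usage sketch (not taken from the package) of the DropConnect options this release adds to `YatNMN`. The constructor and `deterministic` keyword follow the diff above; the concrete sizes, keys, and rate are arbitrary.*

```python
import jax
from flax import nnx
from nmn.nnx.nmn import YatNMN

# DropConnect is opt-in: use_dropconnect, drop_rate and the deterministic
# flag are the knobs introduced in 0.1.6.
layer = YatNMN(
    in_features=3,
    out_features=4,
    use_dropconnect=True,
    drop_rate=0.1,
    rngs=nnx.Rngs(params=jax.random.key(0)),
)
x = jax.random.normal(jax.random.key(1), (2, 3))
y_train = layer(x, deterministic=False)  # kernel mask sampled and rescaled by 1/keep_prob
y_eval = layer(x, deterministic=True)    # mask skipped at inference time
print(y_train.shape, y_eval.shape)       # (2, 4) (2, 4)
```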
nmn/nnx/squashers/soft_tanh.py
ADDED
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+from jax import Array
+
+def soft_tanh(
+    x: Array,
+    n: float = 1.0,
+) -> Array:
+    """
+    Maps a non-negative score to the range [-1, 1) using the soft-tanh function.
+
+    The soft-tanh function is defined as:
+    .. math::
+        \\text{soft-tanh}_n(x) = \\frac{x^n - 1}{1 + x^n}
+
+    The power `n` again controls the transition sharpness: higher `n` makes the
+    function approach -1 more quickly for large `x`.
+
+    Args:
+        x (Array): A JAX array of non-negative scores (x >= 0).
+        n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+    Returns:
+        Array: The mapped scores in the range [-1, 1).
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    return (x_n - 1.0) / (1.0 + x_n)
nmn/nnx/squashers/softer_sigmoid.py
ADDED
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+from jax import Array
+
+def softer_sigmoid(
+    x: Array,
+    n: float = 1.0,
+) -> Array:
+    """
+    Squashes a non-negative score into the range [0, 1) using the soft-sigmoid function.
+
+    The soft-sigmoid function is defined as:
+    .. math::
+        \\text{soft-sigmoid}_n(x) = \\frac{x^n}{1 + x^n}
+
+    The power `n` modulates the softness: higher `n` makes the function approach
+    zero faster for large `x`, while `n < 1` makes the decay slower.
+
+    Args:
+        x (Array): A JAX array of non-negative scores (x >= 0).
+        n (float, optional): The power to raise the score to. Defaults to 1.0.
+
+    Returns:
+        Array: The squashed scores in the range [0, 1).
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    return x_n / (1.0 + x_n)
nmn/nnx/squashers/softermax.py
ADDED
@@ -0,0 +1,38 @@
+import jax.numpy as jnp
+from jax import Array
+from typing import Optional
+
+def softermax(
+    x: Array,
+    n: float = 1.0,
+    epsilon: float = 1e-12,
+    axis: Optional[int] = -1,
+) -> Array:
+    """
+    Normalizes a set of non-negative scores using the Softermax function.
+
+    The Softermax function is defined as:
+    .. math::
+        \\text{softermax}_n(x_k, \\{x_i\\}) = \\frac{x_k^n}{\\epsilon + \\sum_i x_i^n}
+
+    The power `n` controls the sharpness of the distribution: `n=1` recovers
+    the original Softermax, while `n > 1` makes the distribution harder (more
+    peaked), and `0 < n < 1` makes it softer.
+
+    Args:
+        x (Array): A JAX array of non-negative scores.
+        n (float, optional): The power to raise each score to. Defaults to 1.0.
+        epsilon (float, optional): A small constant for numerical stability.
+            Defaults to 1e-12.
+        axis (Optional[int], optional): The axis to perform the sum over.
+            Defaults to -1.
+
+    Returns:
+        Array: The normalized scores.
+    """
+    if n <= 0:
+        raise ValueError("Power 'n' must be positive.")
+
+    x_n = jnp.power(x, n)
+    sum_x_n = jnp.sum(x_n, axis=axis, keepdims=True)
+    return x_n / (epsilon + sum_x_n)
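*Note: a quick numerical check of the three new squashers (a sketch, not package code). It assumes the functions are re-exported from `nmn.nnx.squashers`, whose `__init__.py` is also added in this release but not shown above.*

```python
import jax.numpy as jnp
from nmn.nnx.squashers import soft_tanh, softer_sigmoid, softermax

scores = jnp.array([[0.5, 1.0, 2.0]])   # non-negative scores, e.g. yat attention logits
print(softermax(scores))                # rows sum to ~1: [[0.1429, 0.2857, 0.5714]]
print(softer_sigmoid(scores))           # x/(1+x), values in [0, 1): [[0.333, 0.5, 0.667]]
print(soft_tanh(scores))                # (x-1)/(1+x), values in [-1, 1): [[-0.333, 0.0, 0.333]]
```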
nmn/nnx/yatattention.py
CHANGED
@@ -26,7 +26,8 @@ from flax.typing import (
  DotGeneralT,
)

-
+from nmn.nnx.nmn import YatNMN
+from jax import Array

def yat_attention_weights(
  query: Array,
@@ -153,91 +154,6 @@ def yat_attention(
    '...hqk,...khd->...qhd', attn_weights, value, precision=precision
  )

-Array = jax.Array
-
-# Add YatNMN class implementation
-default_bias_init = initializers.zeros_init()
-default_alpha_init = initializers.ones_init()
-
-class YatNMN(Module):
-  """A linear transformation with custom distance-based computation."""
-
-  def __init__(
-    self,
-    in_features: int,
-    out_features: int,
-    *,
-    use_bias: bool = True,
-    use_alpha: bool = True,
-    dtype: Optional[Dtype] = None,
-    param_dtype: Dtype = jnp.float32,
-    precision: PrecisionLike = None,
-    kernel_init: Initializer = default_kernel_init,
-    bias_init: Initializer = default_bias_init,
-    alpha_init: Initializer = default_alpha_init,
-    dot_general: DotGeneralT = lax.dot_general,
-    rngs: rnglib.Rngs,
-    epsilon: float = 1e-5,
-  ):
-
-    kernel_key = rngs.params()
-    self.kernel = nnx.Param(
-      kernel_init(kernel_key, (in_features, out_features), param_dtype)
-    )
-    self.bias: nnx.Param[jax.Array] | None
-    if use_bias:
-      bias_key = rngs.params()
-      self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
-    else:
-      self.bias = None
-
-    self.alpha: nnx.Param[jax.Array] | None
-    if use_alpha:
-      alpha_key = rngs.params()
-      self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
-    else:
-      self.alpha = None
-
-    self.in_features = in_features
-    self.out_features = out_features
-    self.use_bias = use_bias
-    self.use_alpha = use_alpha
-    self.dtype = dtype
-    self.param_dtype = param_dtype
-    self.precision = precision
-    self.kernel_init = kernel_init
-    self.bias_init = bias_init
-    self.dot_general = dot_general
-    self.epsilon = epsilon
-
-  def __call__(self, inputs: Array) -> Array:
-    """Applies YatNMN transformation to inputs."""
-    kernel = self.kernel.value
-    bias = self.bias.value if self.bias is not None else None
-    alpha = self.alpha.value if self.alpha is not None else None
-
-    y = self.dot_general(
-      inputs,
-      kernel,
-      (((inputs.ndim - 1,), (0,)), ((), ())),
-      precision=self.precision,
-    )
-
-    inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
-    kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
-    distances = inputs_squared_sum + kernel_squared_sum - 2 * y
-
-    # Element-wise operation
-    y = y ** 2 / (distances + self.epsilon)
-
-    if bias is not None:
-      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
-
-    if alpha is not None:
-      scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
-      y = y * scale
-
-    return y


def dot_product_attention_weights(
@@ -435,6 +351,10 @@ class MultiHeadAttention(Module):
    attention_fn: Callable[..., Array] = yat_attention,
    decode: bool | None = None,
    normalize_qk: bool = False,
+    use_alpha: bool = True,
+    alpha_init: Initializer = initializers.ones_init(),
+    use_dropconnect: bool = False,
+    dropconnect_rate: float = 0.0,
    # Deprecated, will be removed.
    qkv_dot_general: DotGeneralT | None = None,
    out_dot_general: DotGeneralT | None = None,
@@ -470,6 +390,10 @@ class MultiHeadAttention(Module):
    self.qkv_dot_general_cls = qkv_dot_general_cls
    self.out_dot_general_cls = out_dot_general_cls
    self.epsilon = epsilon
+    self.use_alpha = use_alpha
+    self.alpha_init = alpha_init
+    self.use_dropconnect = use_dropconnect
+    self.dropconnect_rate = dropconnect_rate

    if self.qkv_features % self.num_heads != 0:
      raise ValueError(
@@ -491,6 +415,10 @@ class MultiHeadAttention(Module):
      use_bias=self.use_bias,
      precision=self.precision,
      epsilon=self.epsilon,
+      use_alpha=self.use_alpha,
+      alpha_init=self.alpha_init,
+      use_dropconnect=self.use_dropconnect,
+      drop_rate=self.dropconnect_rate,
    )

    # project inputs_q to multi-headed q/k/v
@@ -590,10 +518,23 @@ class MultiHeadAttention(Module):
        f'but module expects {self.in_features}.'
      )

+    is_deterministic: bool = False
+    if self.dropout_rate > 0.0 or (
+      self.use_dropconnect and self.dropconnect_rate > 0.0
+    ):
+      is_deterministic = first_from(
+        deterministic,
+        self.deterministic,
+        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
+          as either a __call__ argument, class attribute, or nnx.flag.""",
+      )
+    else:
+      is_deterministic = True
+
    # Apply YatNMN transformations and reshape to multi-head format
-    query =
-    key =
-    value =
+    query = self.query(inputs_q, deterministic=is_deterministic)
+    key = self.key(inputs_k, deterministic=is_deterministic)
+    value = self.value(inputs_v, deterministic=is_deterministic)

    # Reshape from [batch..., length, qkv_features] to [batch..., length, num_heads, head_dim]
    query = query.reshape(query.shape[:-1] + (self.num_heads, self.head_dim))
@@ -660,26 +601,11 @@ class MultiHeadAttention(Module):
      ),
    )

-
-
-
-
-
-        self.deterministic,
-        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
-          as either a __call__ argument, class attribute, or nnx.flag.""",
-      )
-      if not deterministic:
-        if rngs is None:
-          raise ValueError(
-            "'rngs' must be provided if 'dropout_rng' is not given."
-          )
-        dropout_rng = rngs.dropout()
-      else:
-        dropout_rng = None
-    else:
-      deterministic = True
-      dropout_rng = None
+    dropout_rng = None
+    if self.dropout_rate > 0.0 and not is_deterministic:
+      if rngs is None:
+        raise ValueError("'rngs' must be provided for dropout.")
+      dropout_rng = rngs.dropout()

    # apply attention with epsilon parameter for YatNMN
    x = self.attention_fn(
@@ -690,7 +616,7 @@ class MultiHeadAttention(Module):
      dropout_rng=dropout_rng,
      dropout_rate=self.dropout_rate,
      broadcast_dropout=self.broadcast_dropout,
-      deterministic=
+      deterministic=is_deterministic,
      dtype=self.dtype,
      precision=self.precision,
      module=self if sow_weights else None,
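*Note: a hedged sketch (not from the package docs) of driving the new `use_dropconnect` / `dropconnect_rate` options end to end. It assumes the `MultiHeadAttention` constructor mirrors `flax.nnx.MultiHeadAttention` (`num_heads`, `in_features`, `qkv_features`, `rngs`), which the diff references but does not show in full.*

```python
import jax
from flax import nnx
from nmn.nnx.yatattention import MultiHeadAttention

# Assumed constructor shape; only the DropConnect/alpha flags are confirmed
# by the diff above.
attn = MultiHeadAttention(
    num_heads=4,
    in_features=16,
    qkv_features=16,
    decode=False,
    use_dropconnect=True,        # new in 0.1.6
    dropconnect_rate=0.1,        # forwarded to each YatNMN projection as drop_rate
    rngs=nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(1)),
)
x = jax.random.normal(jax.random.key(2), (2, 10, 16))  # (batch, length, features)
out = attn(x, deterministic=False)  # DropConnect masks applied in the q/k/v projections
print(out.shape)                    # expected: (2, 10, 16)
```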
nmn/nnx/yatconv.py
CHANGED
@@ -110,6 +110,8 @@ class YatConv(Module):
    feature_group_count: integer, default 1. If specified divides the input
      features into groups.
    use_bias: whether to add a bias to the output (default: True).
+    use_alpha: whether to use alpha scaling (default: True).
+    use_dropconnect: whether to use DropConnect (default: False).
    mask: Optional mask for the weights during masked convolution. The mask must
      be the same shape as the convolution weight matrix.
    dtype: the dtype of the computation (default: infer from input and params).
@@ -123,10 +125,11 @@ class YatConv(Module):
      and a ``dtype`` keyword argument, and return a tuple of arrays with the
      promoted dtype.
    epsilon: A small float added to the denominator to prevent division by zero.
+    drop_rate: dropout rate for DropConnect (default: 0.0).
    rngs: rng key.
  """

-  __data__ = ('kernel', 'bias', 'mask')
+  __data__ = ('kernel', 'bias', 'mask', 'dropconnect_key')

  def __init__(
    self,
@@ -142,6 +145,7 @@ class YatConv(Module):

    use_bias: bool = True,
    use_alpha: bool = True,
+    use_dropconnect: bool = False,
    kernel_init: Initializer = default_kernel_init,
    bias_init: Initializer = default_bias_init,
    alpha_init: Initializer = default_alpha_init,
@@ -153,6 +157,7 @@ class YatConv(Module):
    conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
    promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
    epsilon: float = 1e-5,
+    drop_rate: float = 0.0,
    rngs: rnglib.Rngs,
  ):
    if isinstance(kernel_size, int):
@@ -185,6 +190,7 @@ class YatConv(Module):
    self.feature_group_count = feature_group_count
    self.use_bias = use_bias
    self.use_alpha = use_alpha
+    self.use_dropconnect = use_dropconnect

    self.mask = mask
    self.dtype = dtype
@@ -195,6 +201,7 @@ class YatConv(Module):
    self.conv_general_dilated = conv_general_dilated
    self.promote_dtype = promote_dtype
    self.epsilon = epsilon
+    self.drop_rate = drop_rate

    if use_alpha:
      alpha_key = rngs.params()
@@ -202,8 +209,12 @@ class YatConv(Module):
    else:
      self.alpha = None

+    if use_dropconnect:
+      self.dropconnect_key = rngs.params()
+    else:
+      self.dropconnect_key = None

-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
    assert isinstance(self.kernel_size, tuple)

    def maybe_broadcast(
@@ -261,6 +272,12 @@ class YatConv(Module):

    kernel_val = self.kernel.value

+    # Apply DropConnect if enabled and not in deterministic mode
+    if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
+      keep_prob = 1.0 - self.drop_rate
+      mask = jax.random.bernoulli(self.dropconnect_key, p=keep_prob, shape=kernel_val.shape)
+      kernel_val = (kernel_val * mask) / keep_prob
+
    current_mask = self.mask
    if current_mask is not None:
      if current_mask.shape != self.kernel_shape:
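*Note: a minimal sketch (not from the package) of the matching DropConnect flags on `YatConv`. The basic constructor arguments mirror the README example in the METADATA below; `use_dropconnect`, `drop_rate`, and `deterministic` come from the diff above, and the output spatial size assumes the default padding behaves like `'SAME'`.*

```python
import jax
from flax import nnx
from nmn.nnx.yatconv import YatConv

conv = YatConv(
    in_features=3,
    out_features=8,
    kernel_size=(3, 3),
    use_dropconnect=True,   # new in 0.1.6
    drop_rate=0.1,          # new in 0.1.6
    rngs=nnx.Rngs(params=jax.random.key(0)),
)
img = jax.random.normal(jax.random.key(1), (1, 28, 28, 3))
out = conv(img, deterministic=False)  # kernel mask sampled as in the diff above
print(out.shape)                      # e.g. (1, 28, 28, 8) with 'SAME'-style padding
```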
nmn-0.1.6.dist-info/METADATA
ADDED
@@ -0,0 +1,176 @@
+Metadata-Version: 2.4
+Name: nmn
+Version: 0.1.6
+Summary: a neuron that matter
+Project-URL: Homepage, https://github.com/mlnomadpy/nmn
+Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
+Author-email: Taha Bouhsine <yat@mlnomads.com>
+License-File: LICENSE
+Classifier: License :: OSI Approved :: GNU Affero General Public License v3
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+
+# nmn
+Not the neurons we want, but the neurons we need
+
+[](https://pypi.org/project/nmn/)
+[](https://pepy.tech/project/nmn)
+[](https://pepy.tech/project/nmn)
+[](https://github.com/mlnomadpy/nmn)
+[](https://github.com/mlnomadpy/nmn)
+[](https://github.com/mlnomadpy/nmn/issues)
+[](https://pypi.org/project/nmn/)
+[](https://pypi.org/project/nmn/)
+
+## Features
+
+* **Activation-Free Non-linearity:** Learns complex, non-linear relationships without separate activation functions.
+* **Multiple Frameworks:** Supports Flax (Linen & NNX), Keras, PyTorch, and TensorFlow.
+* **Yat-Product & Yat-Conv:** Implements novel Yat-Product and Yat-Conv operations.
+* **Inspired by Research:** Based on the principles from "Deep Learning 2.0/2.1: Artificial Neurons that Matter".
+
+## Overview
+
+**nmn** provides neural network layers for multiple frameworks (Flax, NNX, Keras, PyTorch, TensorFlow) that do not require activation functions to learn non-linearity. The main goal is to enable deep learning architectures where the layer itself is inherently non-linear, inspired by the papers:
+
+> Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
+>
+> Deep Learning 2.1: Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks
+
+## Math
+
+Yat-Product:
+$$
+ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{((\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w}))^2 + \epsilon}.
+$$
+
+**Explanation:**
+- $\mathbf{w}$ is the weight vector, $\mathbf{x}$ is the input vector.
+- $\langle \mathbf{w}, \mathbf{x} \rangle$ is the dot product between $\mathbf{w}$ and $\mathbf{x}$.
+- $\|\mathbf{w} - \mathbf{x}\|^2$ is the squared Euclidean distance between $\mathbf{w}$ and $\mathbf{x}$.
+- $\epsilon$ is a small constant for numerical stability.
+- $\theta$ is the angle between $\mathbf{w}$ and $\mathbf{x}$.
+
+This operation:
+- **Numerator:** Squares the similarity (dot product) between $\mathbf{w}$ and $\mathbf{x}$, emphasizing strong alignments.
+- **Denominator:** Penalizes large distances, so the response is high only when $\mathbf{w}$ and $\mathbf{x}$ are both similar in direction and close in space.
+- **No activation needed:** The non-linearity is built into the operation itself, allowing the layer to learn complex, non-linear relationships without a separate activation function.
+- **Geometric view:** The output is maximized when $\mathbf{w}$ and $\mathbf{x}$ are both large in norm, closely aligned (small $\theta$), and close together in Euclidean space.
+
+Yat-Conv:
+$$
+ⵟ^*(\mathbf{W}, \mathbf{X}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon}
+= \frac{\left(\sum_{i,j} w_{ij} x_{ij}\right)^2}{\sum_{i,j} (w_{ij} - x_{ij})^2 + \epsilon}
+$$
+
+Where:
+- $\mathbf{W}$ and $\mathbf{X}$ are local patches (e.g., kernel and input patch in convolution)
+- $w_{ij}$ and $x_{ij}$ are elements of the kernel and input patch, respectively
+- $\epsilon$ is a small constant for numerical stability
+
+This generalizes the Yat-product to convolutional (patch-wise) operations.
+
+
+## Supported Frameworks & API
+
+The `YatNMN` layer (for dense operations) and `YatConv` (for convolutional operations) are the core components. Below is a summary of their availability and features per framework:
+
+| Framework | `YatNMN` Path | `YatConv` Path | Core Layer | DropConnect | Ternary Network | Recurrent Layer |
+|----------------|-------------------------------|-------------------------------|------------|-------------|-----------------|-----------------|
+| **Flax (Linen)** | `src/nmn/linen/nmn.py` | (Available) | ✅ | | | 🚧 |
+| **Flax (NNX)** | `src/nmn/nnx/nmn.py` | `src/nmn/nnx/yatconv.py` | ✅ | ✅ | 🚧 | 🚧 |
+| **Keras** | `src/nmn/keras/nmn.py` | (Available) | ✅ | | | 🚧 |
+| **PyTorch** | `src/nmn/torch/nmn.py` | (Available) | ✅ | | | 🚧 |
+| **TensorFlow** | `src/nmn/tf/nmn.py` | (Available) | ✅ | | | 🚧 |
+
+*Legend: ✅ Implemented, 🚧 To be implemented / In Progress, (Available) - Assumed available if NMN is, specific path might vary or be part of the NMN module.*
+
+## Installation
+
+```bash
+pip install nmn
+```
+
+## Usage Example (Flax NNX)
+
+```python
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from nmn.nnx.nmn import YatNMN
+from nmn.nnx.yatconv import YatConv
+
+# Example YatNMN (Dense Layer)
+model_key, param_key, drop_key, input_key = jax.random.split(jax.random.key(0), 4)
+in_features, out_features = 3, 4
+layer = YatNMN(in_features=in_features, out_features=out_features, rngs=nnx.Rngs(params=param_key, dropout=drop_key))
+dummy_input = jax.random.normal(input_key, (2, in_features)) # Batch size 2
+output = layer(dummy_input)
+print("YatNMN Output Shape:", output.shape)
+
+# Example YatConv (Convolutional Layer)
+conv_key, conv_param_key, conv_input_key = jax.random.split(jax.random.key(1), 3)
+in_channels, out_channels = 3, 8
+kernel_size = (3, 3)
+conv_layer = YatConv(
+    in_features=in_channels,
+    out_features=out_channels,
+    kernel_size=kernel_size,
+    rngs=nnx.Rngs(params=conv_param_key)
+)
+dummy_conv_input = jax.random.normal(conv_input_key, (1, 28, 28, in_channels)) # Batch 1, 28x28 image, in_channels
+conv_output = conv_layer(dummy_conv_input)
+print("YatConv Output Shape:", conv_output.shape)
+
+```
+*Note: Examples for other frameworks (Keras, PyTorch, TensorFlow, Flax Linen) can be found in their respective `nmn.<framework>` modules and upcoming documentation.*
+
+## Roadmap
+
+- [ ] Implement recurrent layers (`YatRNN`, `YatLSTM`, `YatGRU`) for all supported frameworks.
+- [ ] Develop Ternary Network versions of Yat layers for NNX.
+- [ ] Add more comprehensive examples and benchmark scripts for various tasks (vision, language).
+- [ ] Publish detailed documentation and API references.
+- [ ] Conduct and publish thorough performance benchmarks against traditional layers.
+
+## Contributing
+
+Contributions are welcome! If you'd like to contribute, please feel free to:
+- Open an issue on the [Bug Tracker](https://github.com/mlnomadpy/nmn/issues) to report bugs or suggest features.
+- Submit a pull request with your improvements.
+- Help expand the documentation or add more examples.
+
+## License
+
+This project is licensed under the **GNU Affero General Public License v3**. See the [LICENSE](LICENSE) file for details.
+
+## Citation
+
+If you use `nmn` in your research, please consider citing the original papers that inspired this work:
+
+> Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
+>
+> Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks
+
+A BibTeX entry will be provided once the accompanying paper for this library is published.
+
+## Citing
+
+If you use this work, please cite the paper:
+
+```bibtex
+@article{taha2024dl2,
+  author = {Taha Bouhsine},
+  title = {Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality},
+}
+```
+
+
+```bibtex
+@article{taha2025dl2,
+  author = {Taha Bouhsine},
+  title = {Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks},
+}
+```
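*Note: a standalone numerical check (not package code) of the Yat-Product formula quoted in the README above. It illustrates why no separate activation is needed: the response is large only when the input is both aligned with and close to the weight vector.*

```python
import jax.numpy as jnp

def yat_product(w, x, eps=1e-5):
    # <w, x>^2 / (||w - x||^2 + eps), as defined in the README's Math section
    return jnp.dot(w, x) ** 2 / (jnp.sum((w - x) ** 2) + eps)

w = jnp.array([1.0, 0.0])
print(yat_product(w, jnp.array([2.0, 0.0])))   # 4 / (1 + eps)  ≈ 4.00: aligned and close
print(yat_product(w, jnp.array([-2.0, 0.0])))  # 4 / (9 + eps)  ≈ 0.44: same |cos|, but far away
print(yat_product(w, jnp.array([0.0, 2.0])))   # 0 / (5 + eps)  = 0.00: orthogonal
```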
nmn-0.1.6.dist-info/RECORD
ADDED
@@ -0,0 +1,19 @@
+nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
+nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
+nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
+nmn/nnx/nmn.py,sha256=tPNUtF8Lmv_B1TgMoVXfMQ9x0IPGKjSyAP6HnZ-YBsM,5651
+nmn/nnx/yatattention.py,sha256=i6XfCGHISyb2P6KrgYFnhhdzqSTWAyshFhy1XEeuEWc,24642
+nmn/nnx/yatconv.py,sha256=EOAAWfuv5QA-QTru-JyYKYNoGqxcklu7ph9a-CtmYsA,13123
+nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
+nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
+nmn/nnx/loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nmn/nnx/squashers/__init__.py,sha256=zXYPa3yzqMXxkIPvNHiaV6pcZRDOdVrzaVdYVDGALTY,180
+nmn/nnx/squashers/soft_tanh.py,sha256=WSJkxD6L9WU1eqPwsK2AW4V6OJbw5pSWYjKwkiWtLdo,812
+nmn/nnx/squashers/softer_sigmoid.py,sha256=vE6IWorZdBb2cww6fskARnwzdjTcWB2kKohuaJWVGNs,845
+nmn/nnx/squashers/softermax.py,sha256=NfxEDbogLUysyTvtVCTpDt27PplYvKRQLTZbYCL-Wfg,1226
+nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
+nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
+nmn-0.1.6.dist-info/METADATA,sha256=Y9MByC16wz1MGYVZRmZA0wJQATB6Kj6w6TOL5lPzl0Q,8800
+nmn-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+nmn-0.1.6.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
+nmn-0.1.6.dist-info/RECORD,,
nmn-0.1.4.dist-info/METADATA
DELETED
@@ -1,119 +0,0 @@
-Metadata-Version: 2.4
-Name: nmn
-Version: 0.1.4
-Summary: a neuron that matter
-Project-URL: Homepage, https://github.com/mlnomadpy/nmn
-Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
-Author-email: Taha Bouhsine <yat@mlnomads.com>
-License-File: LICENSE
-Classifier: License :: OSI Approved :: GNU Affero General Public License v3
-Classifier: Operating System :: OS Independent
-Classifier: Programming Language :: Python :: 3
-Requires-Python: >=3.8
-Description-Content-Type: text/markdown
-
-# nmn
-Not the neurons we want, but the neurons we need
-
-[](https://pypi.org/project/nmn/)
-[](https://pepy.tech/project/nmn)
-[](https://pepy.tech/project/nmn)
-[](https://github.com/mlnomadpy/nmn)
-[](https://github.com/mlnomadpy/nmn)
-[](https://github.com/mlnomadpy/nmn/issues)
-[](https://pypi.org/project/nmn/)
-[](https://pypi.org/project/nmn/)
-
-## Overview
-
-**nmn** provides neural network layers for multiple frameworks (Flax, NNX, Keras, PyTorch, TensorFlow) that do not require activation functions to learn non-linearity. The main goal is to enable deep learning architectures where the layer itself is inherently non-linear, inspired by the paper:
-
-> Deep Learning 2.0: Artificial Neurons that Matter: Reject Correlation - Embrace Orthogonality
-
-## Math
-
-Yat-Product:
-$$
-ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{ \|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{((\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w}))^2 + \epsilon}.
-$$
-
-**Explanation:**
-- $\mathbf{w}$ is the weight vector, $\mathbf{x}$ is the input vector.
-- $\langle \mathbf{w}, \mathbf{x} \rangle$ is the dot product between $\mathbf{w}$ and $\mathbf{x}$.
-- $\|\mathbf{w} - \mathbf{x}\|^2$ is the squared Euclidean distance between $\mathbf{w}$ and $\mathbf{x}$.
-- $\epsilon$ is a small constant for numerical stability.
-- $\theta$ is the angle between $\mathbf{w}$ and $\mathbf{x}$.
-
-This operation:
-- **Numerator:** Squares the similarity (dot product) between $\mathbf{w}$ and $\mathbf{x}$, emphasizing strong alignments.
-- **Denominator:** Penalizes large distances, so the response is high only when $\mathbf{w}$ and $\mathbf{x}$ are both similar in direction and close in space.
-- **No activation needed:** The non-linearity is built into the operation itself, allowing the layer to learn complex, non-linear relationships without a separate activation function.
-- **Geometric view:** The output is maximized when $\mathbf{w}$ and $\mathbf{x}$ are both large in norm, closely aligned (small $\theta$), and close together in Euclidean space.
-
-Yat-Conv:
-$$
-ⵟ^*(\mathbf{W}, \mathbf{X}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon}
-= \frac{\left(\sum_{i,j} w_{ij} x_{ij}\right)^2}{\sum_{i,j} (w_{ij} - x_{ij})^2 + \epsilon}
-$$
-
-Where:
-- $\mathbf{W}$ and $\mathbf{X}$ are local patches (e.g., kernel and input patch in convolution)
-- $w_{ij}$ and $x_{ij}$ are elements of the kernel and input patch, respectively
-- $\epsilon$ is a small constant for numerical stability
-
-This generalizes the Yat-product to convolutional (patch-wise) operations.
-
-
-## Supported Frameworks & Tasks
-
-### Flax (JAX)
-- `YatNMN` layer implemented in `src/nmn/linen/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### NNX (Flax NNX)
-- `YatNMN` layer implemented in `src/nmn/nnx/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### Keras
-- `YatNMN` layer implemented in `src/nmn/keras/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### PyTorch
-- `YatNMN` layer implemented in `src/nmn/torch/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-### TensorFlow
-- `YatNMN` layer implemented in `src/nmn/tf/nmn.py`
-- **Tasks:**
-  - [x] Core layer implementation
-  - [ ] Recurrent layer (to be implemented)
-
-## Installation
-
-```bash
-pip install nmn
-```
-
-## Usage Example (Flax)
-
-```python
-from nmn.nnx.nmn import YatNMN
-from nmn.nnx.yatconv import YatConv
-# ... use as a Flax module ...
-```
-
-## Roadmap
-- [ ] Implement recurrent layers for all frameworks
-- [ ] Add more examples and benchmarks
-- [ ] Improve documentation and API consistency
-
-## License
-GNU Affero General Public License v3
nmn-0.1.4.dist-info/RECORD
DELETED
@@ -1,14 +0,0 @@
-nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
-nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
-nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
-nmn/nnx/nmn.py,sha256=gWe8EL-aUm7be03M9O5R3XdBb92EpBEFsylrY6BA60c,4871
-nmn/nnx/yatattention.py,sha256=chjtUKJtaR7ROPnNqkicbvMs7hzZKE0fIo_8cTNiju8,26601
-nmn/nnx/yatconv.py,sha256=xUH9NBY1fIDZeTA9GdgmqR_DJiQJgwU2uDrgxqirKmU,12308
-nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
-nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
-nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
-nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
-nmn-0.1.4.dist-info/METADATA,sha256=k28p055Dr6WWVQcb01uinFRiT5R-CAvdKz33fqZ85g4,5032
-nmn-0.1.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-nmn-0.1.4.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
-nmn-0.1.4.dist-info/RECORD,,
{nmn-0.1.4.dist-info → nmn-0.1.6.dist-info}/WHEEL
File without changes

{nmn-0.1.4.dist-info → nmn-0.1.6.dist-info}/licenses/LICENSE
File without changes