nmn 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- nmn/nnx/examples/language/mingpt.py +1650 -0
- nmn/nnx/examples/vision/cnn_cifar.py +1769 -0
- nmn/nnx/nmn.py +26 -15
- nmn/nnx/yatattention.py +764 -0
- nmn/nnx/yatconv.py +41 -4
- nmn/torch/nmn.py +2 -1
- nmn-0.1.5.dist-info/METADATA +176 -0
- nmn-0.1.5.dist-info/RECORD +14 -0
- nmn-0.1.3.dist-info/METADATA +0 -119
- nmn-0.1.3.dist-info/RECORD +0 -11
- {nmn-0.1.3.dist-info → nmn-0.1.5.dist-info}/WHEEL +0 -0
- {nmn-0.1.3.dist-info → nmn-0.1.5.dist-info}/licenses/LICENSE +0 -0
nmn/nnx/nmn.py
CHANGED
@@ -4,26 +4,18 @@ import typing as tp
 
 import jax
 import jax.numpy as jnp
-import numpy as np
 from jax import lax
-import opt_einsum
 
-from flax.core.frozen_dict import FrozenDict
 from flax import nnx
-from flax.nnx import rnglib
-from flax.nnx.module import Module
+from flax.nnx import rnglib
+from flax.nnx.module import Module
 from flax.nnx.nn import dtypes, initializers
 from flax.typing import (
   Dtype,
-  Shape,
   Initializer,
   PrecisionLike,
   DotGeneralT,
-  ConvGeneralDilatedT,
-  PaddingLike,
-  LaxPadding,
   PromoteDtypeFn,
-  EinsumT,
 )
 
 Array = jax.Array
@@ -60,21 +52,26 @@ class YatNMN(Module):
     in_features: the number of input features.
     out_features: the number of output features.
     use_bias: whether to add a bias to the output (default: True).
+    use_alpha: whether to use alpha scaling (default: True).
+    use_dropconnect: whether to use DropConnect (default: False).
     dtype: the dtype of the computation (default: infer from input and params).
     param_dtype: the dtype passed to parameter initializers (default: float32).
     precision: numerical precision of the computation see ``jax.lax.Precision``
       for details.
     kernel_init: initializer function for the weight matrix.
     bias_init: initializer function for the bias.
+    alpha_init: initializer function for the alpha.
     dot_general: dot product function.
     promote_dtype: function to promote the dtype of the arrays to the desired
       dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
       and a ``dtype`` keyword argument, and return a tuple of arrays with the
       promoted dtype.
+    epsilon: A small float added to the denominator to prevent division by zero.
+    drop_rate: dropout rate for DropConnect (default: 0.0).
     rngs: rng key.
   """
 
-  __data__ = ('kernel', 'bias')
+  __data__ = ('kernel', 'bias', 'alpha', 'dropconnect_key')
 
   def __init__(
     self,
@@ -83,6 +80,7 @@ class YatNMN(Module):
     *,
     use_bias: bool = True,
     use_alpha: bool = True,
+    use_dropconnect: bool = False,
     dtype: tp.Optional[Dtype] = None,
     param_dtype: Dtype = jnp.float32,
     precision: PrecisionLike = None,
@@ -91,8 +89,9 @@ class YatNMN(Module):
     alpha_init: Initializer = default_alpha_init,
     dot_general: DotGeneralT = lax.dot_general,
     promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
-    rngs: rnglib.Rngs,
     epsilon: float = 1e-5,
+    drop_rate: float = 0.0,
+    rngs: rnglib.Rngs,
   ):
 
     kernel_key = rngs.params()
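In 0.1.5 the keyword-only arguments `drop_rate` and `use_dropconnect` are added and `rngs` moves to the end of the signature. A minimal construction sketch (the feature sizes are assumed for illustration; the import path follows the changed file `nmn/nnx/nmn.py`):

```python
# Construction sketch for the 0.1.5 signature; 64 -> 128 features are example sizes.
from flax import nnx
from nmn.nnx.nmn import YatNMN  # path mirrors the changed file nmn/nnx/nmn.py

layer = YatNMN(
    64, 128,
    use_dropconnect=True,  # new flag in 0.1.5
    drop_rate=0.1,         # new DropConnect rate in 0.1.5
    epsilon=1e-5,
    rngs=nnx.Rngs(0),      # rngs is now the final keyword argument
)
```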
@@ -109,7 +108,7 @@ class YatNMN(Module):
     self.alpha: nnx.Param[jax.Array] | None
     if use_alpha:
       alpha_key = rngs.params()
-      self.alpha = nnx.Param(alpha_init(
+      self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
     else:
       self.alpha = None
 
@@ -117,6 +116,7 @@ class YatNMN(Module):
     self.out_features = out_features
     self.use_bias = use_bias
     self.use_alpha = use_alpha
+    self.use_dropconnect = use_dropconnect
     self.dtype = dtype
     self.param_dtype = param_dtype
     self.precision = precision
@@ -125,12 +125,19 @@ class YatNMN(Module):
     self.dot_general = dot_general
     self.promote_dtype = promote_dtype
     self.epsilon = epsilon
+    self.drop_rate = drop_rate
+
+    if use_dropconnect:
+      self.dropconnect_key = rngs.params()
+    else:
+      self.dropconnect_key = None
 
-  def __call__(self, inputs: Array) -> Array:
+  def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
     """Applies a linear transformation to the inputs along the last dimension.
 
     Args:
       inputs: The nd-array to be transformed.
+      deterministic: If true, DropConnect is not applied (e.g., during inference).
 
     Returns:
       The transformed input.
@@ -139,6 +146,11 @@ class YatNMN(Module):
     bias = self.bias.value if self.bias is not None else None
     alpha = self.alpha.value if self.alpha is not None else None
 
+    if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
+      keep_prob = 1.0 - self.drop_rate
+      mask = jax.random.bernoulli(self.dropconnect_key, p=keep_prob, shape=kernel.shape)
+      kernel = (kernel * mask) / keep_prob
+
     inputs, kernel, bias, alpha = self.promote_dtype(
       (inputs, kernel, bias, alpha), dtype=self.dtype
     )
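The added block implements inverted DropConnect: a Bernoulli keep-mask is sampled over the kernel and the surviving weights are rescaled by `1 / keep_prob`, so the masked kernel has the same expected value as the original. A standalone sketch of just that step (the key and shapes are illustrative assumptions):

```python
# Inverted DropConnect on a weight matrix, isolated from the module code above.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)       # stand-in for the module's dropconnect_key
kernel = jnp.ones((4, 3))         # assumed example shape
drop_rate = 0.25
keep_prob = 1.0 - drop_rate

# Keep each weight with probability keep_prob, then rescale so the expected
# value of the masked kernel equals the unmasked kernel.
mask = jax.random.bernoulli(key, p=keep_prob, shape=kernel.shape)
masked_kernel = (kernel * mask) / keep_prob
```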
@@ -166,5 +178,4 @@ class YatNMN(Module):
     scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
     y = y * scale
 
-
     return y
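On the call side, DropConnect is active by default and is skipped when `deterministic=True` is passed. A usage sketch, assuming `in_features` and `out_features` are the two positional arguments as the docstring lists them and using an arbitrary input shape:

```python
# Train-time vs. inference-time calls of the updated __call__.
import jax.numpy as jnp
from flax import nnx
from nmn.nnx.nmn import YatNMN

layer = YatNMN(64, 128, use_dropconnect=True, drop_rate=0.1, rngs=nnx.Rngs(0))
x = jnp.ones((8, 64))  # assumed batch of 8 inputs with 64 features

y_train = layer(x)                      # deterministic defaults to False: kernel is masked
y_eval = layer(x, deterministic=True)   # DropConnect bypassed at inference
```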