nmn 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nmn/nnx/TODO CHANGED
@@ -1,2 +1,3 @@
  - add support to masked kernels
- - explain attention [directed graph]
+ - explain attention [directed graph]
+ - add support to optax softermax
nmn/nnx/conv_utils.py ADDED
@@ -0,0 +1,44 @@
+ import typing as tp
+ import jax.numpy as jnp
+ from jax import lax
+
+ from flax.nnx.nn import initializers
+ from flax.typing import PaddingLike, LaxPadding
+
+
+ # Default initializers
+ default_kernel_init = initializers.lecun_normal()
+ default_bias_init = initializers.zeros_init()
+ default_alpha_init = initializers.ones_init()
+
+ # Helper functions
+ def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
+   """ "Canonicalizes conv padding to a jax.lax supported format."""
+   if isinstance(padding, str):
+     return padding
+   if isinstance(padding, int):
+     return [(padding, padding)] * rank
+   if isinstance(padding, tp.Sequence) and len(padding) == rank:
+     new_pad = []
+     for p in padding:
+       if isinstance(p, int):
+         new_pad.append((p, p))
+       elif isinstance(p, tuple) and len(p) == 2:
+         new_pad.append(p)
+       else:
+         break
+     if len(new_pad) == rank:
+       return new_pad
+   raise ValueError(
+     f'Invalid padding format: {padding}, should be str, int,'
+     f' or a sequence of len {rank} where each element is an'
+     ' int or pair of ints.'
+   )
+
+ def _conv_dimension_numbers(input_shape):
+   """Computes the dimension numbers based on the input shape."""
+   ndim = len(input_shape)
+   lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
+   rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
+   out_spec = lhs_spec
+   return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
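
For context, a short sketch of what the shared helpers return; the example values below are hypothetical and use only the functions defined in this new file:

# Hypothetical illustration of the helpers in nmn/nnx/conv_utils.py.
from nmn.nnx.conv_utils import canonicalize_padding, _conv_dimension_numbers

canonicalize_padding('SAME', rank=2)   # strings pass through unchanged -> 'SAME'
canonicalize_padding(1, rank=2)        # ints become per-dim pairs -> [(1, 1), (1, 1)]

# NHWC input shape -> ConvDimensionNumbers for lax.conv_general_dilated.
dims = _conv_dimension_numbers((8, 32, 32, 3))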
nmn/nnx/squashers/softermax.py CHANGED
@@ -6,7 +6,7 @@ import jax.numpy as jnp
  from jax import Array


- @partial(jax.jit, static_argnames=("n", "axis"))
+ @partial(jax.jit, static_argnames=("n", "axis", "epsilon"))
  def softermax(
    x: Array,
    n: float = 1.0,
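
Since ``epsilon`` now appears in ``static_argnames``, it is treated as a compile-time constant: each distinct ``epsilon`` value triggers a fresh JIT compilation instead of being traced. An illustrative call, assuming the module path from the RECORD below and default values for the parameters not shown in this hunk:

# Hypothetical call; only `x` and `n` appear in the hunk above, the rest is assumed.
import jax.numpy as jnp
from nmn.nnx.squashers.softermax import softermax

weights = softermax(jnp.array([1.0, 2.0, 3.0]), n=1.0)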
nmn/nnx/yatconv.py CHANGED
@@ -18,46 +18,16 @@ from flax.typing import (
    LaxPadding,
    PromoteDtypeFn,
  )
+ from nmn.nnx.conv_utils import (
+   canonicalize_padding,
+   _conv_dimension_numbers,
+   default_kernel_init,
+   default_bias_init,
+   default_alpha_init,
+ )

  Array = jax.Array

- # Default initializers
- default_kernel_init = initializers.lecun_normal()
- default_bias_init = initializers.zeros_init()
- default_alpha_init = initializers.ones_init()
-
- # Helper functions
- def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
-   """ "Canonicalizes conv padding to a jax.lax supported format."""
-   if isinstance(padding, str):
-     return padding
-   if isinstance(padding, int):
-     return [(padding, padding)] * rank
-   if isinstance(padding, tp.Sequence) and len(padding) == rank:
-     new_pad = []
-     for p in padding:
-       if isinstance(p, int):
-         new_pad.append((p, p))
-       elif isinstance(p, tuple) and len(p) == 2:
-         new_pad.append(p)
-       else:
-         break
-     if len(new_pad) == rank:
-       return new_pad
-   raise ValueError(
-     f'Invalid padding format: {padding}, should be str, int,'
-     f' or a sequence of len {rank} where each element is an'
-     ' int or pair of ints.'
-   )
-
- def _conv_dimension_numbers(input_shape):
-   """Computes the dimension numbers based on the input shape."""
-   ndim = len(input_shape)
-   lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
-   rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
-   out_spec = lhs_spec
-   return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
-
  class YatConv(Module):
    """Yat Convolution Module wrapping ``lax.conv_general_dilated``.

nmn/nnx/yatconv_transpose.py ADDED
@@ -0,0 +1,295 @@
+ import typing as tp
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from jax import lax
+
+ from flax import nnx
+ from flax.nnx.module import Module
+ from flax.nnx import rnglib
+ from flax.nnx.nn import dtypes, initializers
+ from flax.typing import (
+   Dtype,
+   Initializer,
+   PrecisionLike,
+   PaddingLike,
+   LaxPadding,
+   PromoteDtypeFn,
+ )
+ from nmn.nnx.conv_utils import (
+   canonicalize_padding,
+   _conv_dimension_numbers,
+   default_kernel_init,
+   default_bias_init,
+   default_alpha_init,
+ )
+
+ Array = jax.Array
+
+
+ class YatConvTranspose(Module):
+   """Yat Transposed Convolution Module wrapping ``lax.conv_transpose``.
+
+   Args:
+     in_features: int or tuple with number of input features.
+     out_features: int or tuple with number of output features.
+     kernel_size: shape of the convolutional kernel. For 1D convolution,
+       the kernel size can be passed as an integer, which will be interpreted
+       as a tuple of the single integer. For all other cases, it must be a
+       sequence of integers.
+     strides: an integer or a sequence of ``n`` integers, representing the
+       inter-window strides (default: 1).
+     padding: either the string ``'SAME'``, the string ``'VALID'``, the string
+       ``'CIRCULAR'`` (periodic boundary conditions), or a sequence of ``n``
+       ``(low, high)`` integer pairs that give the padding to apply before and after each
+       spatial dimension. A single int is interpeted as applying the same padding
+       in all dims and passign a single int in a sequence causes the same padding
+       to be used on both sides.
+     kernel_dilation: an integer or a sequence of ``n`` integers, giving the
+       dilation factor to apply in each spatial dimension of the convolution
+       kernel (default: 1). Convolution with kernel dilation
+       is also known as 'atrous convolution'.
+     use_bias: whether to add a bias to the output (default: True).
+     use_alpha: whether to use alpha scaling (default: True).
+     use_dropconnect: whether to use DropConnect (default: False).
+     mask: Optional mask for the weights during masked convolution. The mask must
+       be the same shape as the convolution weight matrix.
+     dtype: the dtype of the computation (default: infer from input and params).
+     param_dtype: the dtype passed to parameter initializers (default: float32).
+     precision: numerical precision of the computation see ``jax.lax.Precision``
+       for details.
+     kernel_init: initializer for the convolutional kernel.
+     bias_init: initializer for the bias.
+     transpose_kernel: if ``True`` flips spatial axes and swaps the input/output
+       channel axes of the kernel.
+     promote_dtype: function to promote the dtype of the arrays to the desired
+       dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
+       and a ``dtype`` keyword argument, and return a tuple of arrays with the
+       promoted dtype.
+     epsilon: A small float added to the denominator to prevent division by zero.
+     drop_rate: dropout rate for DropConnect (default: 0.0).
+     rngs: rng key.
+   """
+
+   __data__ = ('kernel', 'bias', 'mask', 'dropconnect_key')
+
+   def __init__(
+     self,
+     in_features: int,
+     out_features: int,
+     kernel_size: int | tp.Sequence[int],
+     strides: int | tp.Sequence[int] | None = None,
+     *,
+     padding: PaddingLike = 'SAME',
+     kernel_dilation: int | tp.Sequence[int] | None = None,
+     use_bias: bool = True,
+     use_alpha: bool = True,
+     use_dropconnect: bool = False,
+     mask: Array | None = None,
+     dtype: Dtype | None = None,
+     param_dtype: Dtype = jnp.float32,
+     precision: PrecisionLike | None = None,
+     kernel_init: Initializer = default_kernel_init,
+     bias_init: Initializer = default_bias_init,
+     alpha_init: Initializer = default_alpha_init,
+     transpose_kernel: bool = False,
+     promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+     epsilon: float = 1e-5,
+     drop_rate: float = 0.0,
+     rngs: rnglib.Rngs,
+   ):
+     if isinstance(kernel_size, int):
+       kernel_size = (kernel_size,)
+     else:
+       kernel_size = tuple(kernel_size)
+
+     self.kernel_size = kernel_size
+     self.in_features = in_features
+     self.out_features = out_features
+     self.strides = strides
+     self.padding = padding
+     self.kernel_dilation = kernel_dilation
+     self.use_bias = use_bias
+     self.use_alpha = use_alpha
+     self.use_dropconnect = use_dropconnect
+     self.mask = mask
+     self.dtype = dtype
+     self.param_dtype = param_dtype
+     self.precision = precision
+     self.kernel_init = kernel_init
+     self.bias_init = bias_init
+     self.alpha_init = alpha_init
+     self.transpose_kernel = transpose_kernel
+     self.promote_dtype = promote_dtype
+     self.epsilon = epsilon
+     self.drop_rate = drop_rate
+
+     if self.transpose_kernel:
+       kernel_shape = kernel_size + (self.out_features, in_features)
+     else:
+       kernel_shape = kernel_size + (in_features, self.out_features)
+
+     self.kernel_shape = kernel_shape
+     self.kernel = nnx.Param(
+       self.kernel_init(rngs.params(), kernel_shape, self.param_dtype)
+     )
+
+     self.bias: nnx.Param | None
+     if self.use_bias:
+       self.bias = nnx.Param(
+         self.bias_init(rngs.params(), (self.out_features,), self.param_dtype)
+       )
+     else:
+       self.bias = None
+
+     if use_alpha:
+       alpha_key = rngs.params()
+       self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
+     else:
+       self.alpha = None
+
+     if use_dropconnect:
+       self.dropconnect_key = rngs.params()
+     else:
+       self.dropconnect_key = None
+
+   def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
+     assert isinstance(self.kernel_size, tuple)
+
+     def maybe_broadcast(
+       x: tp.Optional[tp.Union[int, tp.Sequence[int]]],
+     ) -> tuple[int, ...]:
+       if x is None:
+         x = 1
+       if isinstance(x, int):
+         return (x,) * len(self.kernel_size)
+       return tuple(x)
+
+     num_batch_dimensions = inputs.ndim - (len(self.kernel_size) + 1)
+     if num_batch_dimensions != 1:
+       input_batch_shape = inputs.shape[:num_batch_dimensions]
+       total_batch_size = int(np.prod(input_batch_shape))
+       flat_input_shape = (total_batch_size,) + inputs.shape[
+         num_batch_dimensions:
+       ]
+       inputs_flat = jnp.reshape(inputs, flat_input_shape)
+     else:
+       inputs_flat = inputs
+       input_batch_shape = ()
+
+     strides = maybe_broadcast(self.strides)
+     kernel_dilation = maybe_broadcast(self.kernel_dilation)
+
+     padding_lax = canonicalize_padding(self.padding, len(self.kernel_size))
+     if padding_lax == 'CIRCULAR':
+       padding_lax = 'VALID'
+
+     kernel_val = self.kernel.value
+
+     if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
+       keep_prob = 1.0 - self.drop_rate
+       mask = jax.random.bernoulli(self.dropconnect_key, p=keep_prob, shape=kernel_val.shape)
+       kernel_val = (kernel_val * mask) / keep_prob
+
+     current_mask = self.mask
+     if current_mask is not None:
+       if current_mask.shape != self.kernel_shape:
+         raise ValueError(
+           'Mask needs to have the same shape as weights. '
+           f'Shapes are: {current_mask.shape}, {self.kernel_shape}'
+         )
+       kernel_val *= current_mask
+
+     bias_val = self.bias.value if self.bias is not None else None
+     alpha = self.alpha.value if self.alpha is not None else None
+
+     inputs_promoted, kernel_promoted, bias_promoted = self.promote_dtype(
+       (inputs_flat, kernel_val, bias_val), dtype=self.dtype
+     )
+     inputs_flat = inputs_promoted
+     kernel_val = kernel_promoted
+     bias_val = bias_promoted
+
+     dot_prod_map = lax.conv_transpose(
+       inputs_flat,
+       kernel_val,
+       strides,
+       padding_lax,
+       rhs_dilation=kernel_dilation,
+       transpose_kernel=self.transpose_kernel,
+       precision=self.precision,
+     )
+
+     inputs_flat_squared = inputs_flat**2
+     if self.transpose_kernel:
+       patch_kernel_in_features = self.out_features
+     else:
+       patch_kernel_in_features = self.in_features
+
+     kernel_for_patch_sq_sum_shape = self.kernel_size + (patch_kernel_in_features, 1)
+     kernel_for_patch_sq_sum = jnp.ones(kernel_for_patch_sq_sum_shape, dtype=kernel_val.dtype)
+
+     patch_sq_sum_map_raw = lax.conv_transpose(
+       inputs_flat_squared,
+       kernel_for_patch_sq_sum,
+       strides,
+       padding_lax,
+       rhs_dilation=kernel_dilation,
+       transpose_kernel=self.transpose_kernel,
+       precision=self.precision,
+     )
+
+     if self.out_features > 1:
+       patch_sq_sum_map = jnp.repeat(patch_sq_sum_map_raw, self.out_features, axis=-1)
+     else:
+       patch_sq_sum_map = patch_sq_sum_map_raw
+
+     if self.transpose_kernel:
+       reduce_axes_for_kernel_sq = tuple(range(len(self.kernel_size))) + (len(self.kernel_size) + 1,)
+     else:
+       reduce_axes_for_kernel_sq = tuple(range(len(self.kernel_size))) + (len(self.kernel_size),)
+
+     kernel_sq_sum_per_filter = jnp.sum(kernel_val**2, axis=reduce_axes_for_kernel_sq)
+
+     distance_sq_map = patch_sq_sum_map + kernel_sq_sum_per_filter - 2 * dot_prod_map
+     y = dot_prod_map**2 / (distance_sq_map + self.epsilon)
+
+     if self.use_bias and bias_val is not None:
+       bias_reshape_dims = (1,) * (y.ndim - 1) + (-1,)
+       y += jnp.reshape(bias_val, bias_reshape_dims)
+
+     assert self.use_alpha == (alpha is not None)
+     if alpha is not None:
+       scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
+       y = y * scale
+
+     if self.padding == 'CIRCULAR':
+       scaled_x_dims = [
+         x_dim * stride
+         for x_dim, stride in zip(jnp.shape(inputs_flat)[1:-1], strides)
+       ]
+       size_diffs = [
+         -(y_dim - x_dim) % (2 * x_dim)
+         for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims)
+       ]
+       if self.transpose_kernel:
+         total_pad = [
+           (size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs
+         ]
+       else:
+         total_pad = [
+           ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
+         ]
+       y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)])
+       for i in range(1, y.ndim - 1):
+         y = y.reshape(
+           y.shape[:i] + (-1, scaled_x_dims[i - 1]) + y.shape[i + 1 :]
+         )
+         y = y.sum(axis=i)
+
+     if num_batch_dimensions != 1:
+       output_shape = input_batch_shape + y.shape[1:]
+       y = jnp.reshape(y, output_shape)
+
+     return y
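
For orientation, a minimal usage sketch of the new module based on the constructor and ``__call__`` signatures above; the shapes and hyperparameters are illustrative, and the import path is assumed from the RECORD entry ``nmn/nnx/yatconv_transpose.py``:

# Hypothetical example: 2x spatial upsampling with the new YatConvTranspose layer.
import jax.numpy as jnp
from flax import nnx
from nmn.nnx.yatconv_transpose import YatConvTranspose  # path assumed from RECORD

layer = YatConvTranspose(
  in_features=3,
  out_features=8,
  kernel_size=(3, 3),
  strides=(2, 2),
  padding='SAME',
  rngs=nnx.Rngs(0),
)

x = jnp.ones((1, 16, 16, 3))       # NHWC input
y = layer(x, deterministic=True)   # spatially upsampled map with 8 output features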
{nmn-0.1.10.dist-info → nmn-0.1.12.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nmn
- Version: 0.1.10
+ Version: 0.1.12
  Summary: a neuron that matter
  Project-URL: Homepage, https://github.com/mlnomadpy/nmn
  Project-URL: Bug Tracker, https://github.com/mlnomadpy/my_package/issues
{nmn-0.1.10.dist-info → nmn-0.1.12.dist-info}/RECORD RENAMED
@@ -1,21 +1,23 @@
  nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
  nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
  nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
- nmn/nnx/TODO,sha256=U1WV51Eqij5igMjWLcbCjAZPONwIoPUQsMFKYHC6C8g,68
+ nmn/nnx/TODO,sha256=mr9z3yaqz9t_yiZCrfq7LEzlx9Gy-qTG9Q--JO-Apa8,101
+ nmn/nnx/conv_utils.py,sha256=7OiLx9mrjRCAxBOfiHo7uUU66g5AEhfPUXCWO7HI6O8,1424
  nmn/nnx/nmn.py,sha256=tPNUtF8Lmv_B1TgMoVXfMQ9x0IPGKjSyAP6HnZ-YBsM,5651
  nmn/nnx/yatattention.py,sha256=qEWiG_FIgr-TslYCbm2pcBi1myXJLC84nT6k1tMQcr4,25001
- nmn/nnx/yatconv.py,sha256=EOAAWfuv5QA-QTru-JyYKYNoGqxcklu7ph9a-CtmYsA,13123
+ nmn/nnx/yatconv.py,sha256=nIIYT4QZErnneNM6WCC8EO-pYgJYIVQn8DiDsWdL10Y,12012
+ nmn/nnx/yatconv_transpose.py,sha256=PdfHzCZF_ENl7y11Hu7exoplasMVOxUjr0Lmu6MbJqc,10310
  nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
  nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
  nmn/nnx/loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nmn/nnx/squashers/__init__.py,sha256=zXYPa3yzqMXxkIPvNHiaV6pcZRDOdVrzaVdYVDGALTY,180
  nmn/nnx/squashers/soft_tanh.py,sha256=WSJkxD6L9WU1eqPwsK2AW4V6OJbw5pSWYjKwkiWtLdo,812
  nmn/nnx/squashers/softer_sigmoid.py,sha256=vE6IWorZdBb2cww6fskARnwzdjTcWB2kKohuaJWVGNs,845
- nmn/nnx/squashers/softermax.py,sha256=ggg0mHMFyk7b5xs31o-inNvWDzEvghD6YO3mtPlnkW4,1318
+ nmn/nnx/squashers/softermax.py,sha256=ZSOVs1I57mc25SW2ZA75k46asoxs1IqWaB9wUG7jb3s,1329
  nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
  nmn/torch/conv.py,sha256=g5YxStk1p85WkvfecqbzRZaWaAJahOSArpMcqxWAWKc,83413
  nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
- nmn-0.1.10.dist-info/METADATA,sha256=o-wLjeO-n2h56-cvw-AqrRiio5UFaerm58w03XkdHQY,8801
- nmn-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- nmn-0.1.10.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
- nmn-0.1.10.dist-info/RECORD,,
+ nmn-0.1.12.dist-info/METADATA,sha256=W4oIKDKSFpZmg9o0oKvRPuO4zXbDweKKM4kATqC9zlI,8801
+ nmn-0.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ nmn-0.1.12.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
+ nmn-0.1.12.dist-info/RECORD,,