nmn 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nmn/nnx/TODO +2 -1
- nmn/nnx/conv_utils.py +44 -0
- nmn/nnx/squashers/softermax.py +1 -1
- nmn/nnx/yatconv.py +7 -37
- nmn/nnx/yatconv_transpose.py +295 -0
- {nmn-0.1.10.dist-info → nmn-0.1.12.dist-info}/METADATA +1 -1
- {nmn-0.1.10.dist-info → nmn-0.1.12.dist-info}/RECORD +9 -7
- {nmn-0.1.10.dist-info → nmn-0.1.12.dist-info}/WHEEL +0 -0
- {nmn-0.1.10.dist-info → nmn-0.1.12.dist-info}/licenses/LICENSE +0 -0
nmn/nnx/TODO
CHANGED
nmn/nnx/conv_utils.py
ADDED
@@ -0,0 +1,44 @@
|
|
1
|
+
import typing as tp
|
2
|
+
import jax.numpy as jnp
|
3
|
+
from jax import lax
|
4
|
+
|
5
|
+
from flax.nnx.nn import initializers
|
6
|
+
from flax.typing import PaddingLike, LaxPadding
|
7
|
+
|
8
|
+
|
9
|
+
# Default parameter initializers shared by the Yat convolution layers:
# LeCun-normal kernels, zero-valued biases and a ones-initialized alpha.
(
    default_kernel_init,
    default_bias_init,
    default_alpha_init,
) = (
    initializers.lecun_normal(),
    initializers.zeros_init(),
    initializers.ones_init(),
)
|
13
|
+
|
14
|
+
# Helper functions
|
15
|
+
def canonicalize_padding(padding: 'PaddingLike', rank: int) -> 'LaxPadding':
    """Canonicalizes conv padding to a ``jax.lax`` supported format.

    Args:
      padding: one of
        * a string (e.g. ``'SAME'``, ``'VALID'``, ``'CIRCULAR'``), returned
          unchanged;
        * a single int, applied to both sides of every spatial dimension;
        * a sequence of length ``rank`` whose elements are either ints
          (same padding on both sides) or ``(low, high)`` pairs given as
          tuples or lists.
      rank: number of spatial dimensions.

    Returns:
      The string unchanged, or a list of ``rank`` ``(low, high)`` tuples.

    Raises:
      ValueError: if ``padding`` does not match any accepted format.
    """
    if isinstance(padding, str):
        return padding
    if isinstance(padding, int):
        return [(padding, padding)] * rank
    if isinstance(padding, tp.Sequence) and len(padding) == rank:
        new_pad = []
        for p in padding:
            if isinstance(p, int):
                new_pad.append((p, p))
            elif isinstance(p, (tuple, list)) and len(p) == 2:
                # Normalize list pairs to tuples so callers always get the
                # same ``(low, high)`` tuple form.
                new_pad.append(tuple(p))
            else:
                break
        if len(new_pad) == rank:
            return new_pad
    raise ValueError(
        f'Invalid padding format: {padding}, should be str, int,'
        f' or a sequence of len {rank} where each element is an'
        ' int or pair of ints.'
    )
|
37
|
+
|
38
|
+
def _conv_dimension_numbers(input_shape):
|
39
|
+
"""Computes the dimension numbers based on the input shape."""
|
40
|
+
ndim = len(input_shape)
|
41
|
+
lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
|
42
|
+
rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
|
43
|
+
out_spec = lhs_spec
|
44
|
+
return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
|
nmn/nnx/squashers/softermax.py
CHANGED
nmn/nnx/yatconv.py
CHANGED
@@ -18,46 +18,16 @@ from flax.typing import (
|
|
18
18
|
LaxPadding,
|
19
19
|
PromoteDtypeFn,
|
20
20
|
)
|
21
|
+
from nmn.nnx.conv_utils import (
|
22
|
+
canonicalize_padding,
|
23
|
+
_conv_dimension_numbers,
|
24
|
+
default_kernel_init,
|
25
|
+
default_bias_init,
|
26
|
+
default_alpha_init,
|
27
|
+
)
|
21
28
|
|
22
29
|
Array = jax.Array
|
23
30
|
|
24
|
-
# Default parameter initializers shared by the Yat convolution layers:
# LeCun-normal kernels, zero-valued biases and a ones-initialized alpha.
(
    default_kernel_init,
    default_bias_init,
    default_alpha_init,
) = (
    initializers.lecun_normal(),
    initializers.zeros_init(),
    initializers.ones_init(),
)
|
28
|
-
|
29
|
-
# Helper functions
|
30
|
-
def canonicalize_padding(padding: 'PaddingLike', rank: int) -> 'LaxPadding':
    """Canonicalizes conv padding to a ``jax.lax`` supported format.

    Args:
      padding: one of
        * a string (e.g. ``'SAME'``, ``'VALID'``, ``'CIRCULAR'``), returned
          unchanged;
        * a single int, applied to both sides of every spatial dimension;
        * a sequence of length ``rank`` whose elements are either ints
          (same padding on both sides) or ``(low, high)`` pairs given as
          tuples or lists.
      rank: number of spatial dimensions.

    Returns:
      The string unchanged, or a list of ``rank`` ``(low, high)`` tuples.

    Raises:
      ValueError: if ``padding`` does not match any accepted format.
    """
    if isinstance(padding, str):
        return padding
    if isinstance(padding, int):
        return [(padding, padding)] * rank
    if isinstance(padding, tp.Sequence) and len(padding) == rank:
        new_pad = []
        for p in padding:
            if isinstance(p, int):
                new_pad.append((p, p))
            elif isinstance(p, (tuple, list)) and len(p) == 2:
                # Normalize list pairs to tuples so callers always get the
                # same ``(low, high)`` tuple form.
                new_pad.append(tuple(p))
            else:
                break
        if len(new_pad) == rank:
            return new_pad
    raise ValueError(
        f'Invalid padding format: {padding}, should be str, int,'
        f' or a sequence of len {rank} where each element is an'
        ' int or pair of ints.'
    )
|
52
|
-
|
53
|
-
def _conv_dimension_numbers(input_shape):
|
54
|
-
"""Computes the dimension numbers based on the input shape."""
|
55
|
-
ndim = len(input_shape)
|
56
|
-
lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
|
57
|
-
rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
|
58
|
-
out_spec = lhs_spec
|
59
|
-
return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
|
60
|
-
|
61
31
|
class YatConv(Module):
|
62
32
|
"""Yat Convolution Module wrapping ``lax.conv_general_dilated``.
|
63
33
|
|
@@ -0,0 +1,295 @@
|
|
1
|
+
import typing as tp
|
2
|
+
|
3
|
+
import jax
|
4
|
+
import jax.numpy as jnp
|
5
|
+
import numpy as np
|
6
|
+
from jax import lax
|
7
|
+
|
8
|
+
from flax import nnx
|
9
|
+
from flax.nnx.module import Module
|
10
|
+
from flax.nnx import rnglib
|
11
|
+
from flax.nnx.nn import dtypes, initializers
|
12
|
+
from flax.typing import (
|
13
|
+
Dtype,
|
14
|
+
Initializer,
|
15
|
+
PrecisionLike,
|
16
|
+
PaddingLike,
|
17
|
+
LaxPadding,
|
18
|
+
PromoteDtypeFn,
|
19
|
+
)
|
20
|
+
from nmn.nnx.conv_utils import (
|
21
|
+
canonicalize_padding,
|
22
|
+
_conv_dimension_numbers,
|
23
|
+
default_kernel_init,
|
24
|
+
default_bias_init,
|
25
|
+
default_alpha_init,
|
26
|
+
)
|
27
|
+
|
28
|
+
Array = jax.Array
|
29
|
+
|
30
|
+
|
31
|
+
def _circular_wrap(
    y: Array,
    periods: tp.Sequence[int],
    transpose_kernel: bool,
) -> Array:
    """Folds a transposed-convolution map periodically along its spatial axes.

    Used for ``'CIRCULAR'`` padding. The map was computed with ``'VALID'``
    padding, so it is longer than the final output along each spatial axis.
    Each spatial axis is zero-padded to an odd multiple of its period, and the
    period-sized chunks are summed, folding the overhang back onto the output.

    Args:
      y: array laid out as ``(batch, *spatial, channels)``.
      periods: final size of each spatial axis (input size times stride).
      transpose_kernel: selects on which side the extra element of an odd
        amount of padding goes.

    Returns:
      ``y`` with spatial axis ``i`` reduced to size ``periods[i]``.
    """
    # Amount of padding that brings each axis to (2k + 1) * period.
    size_diffs = [
        -(y_dim - period) % (2 * period)
        for y_dim, period in zip(y.shape[1:-1], periods)
    ]
    if transpose_kernel:
        total_pad = [(diff // 2, (diff + 1) // 2) for diff in size_diffs]
    else:
        total_pad = [((diff + 1) // 2, diff // 2) for diff in size_diffs]
    y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)])
    # Wrap one spatial axis at a time by summing its period-sized chunks.
    for i in range(1, y.ndim - 1):
        y = y.reshape(y.shape[:i] + (-1, periods[i - 1]) + y.shape[i + 1 :])
        y = y.sum(axis=i)
    return y


class YatConvTranspose(Module):
    """Yat Transposed Convolution Module wrapping ``lax.conv_transpose``.

    For every output position and output channel the layer computes::

        y = dot**2 / (patch_sq + kernel_sq - 2 * dot + epsilon)

    where ``dot`` is the ordinary transposed-convolution response,
    ``patch_sq`` is the sum of squared inputs (over all input channels) that
    the transposed convolution connects to that output position, and
    ``kernel_sq`` is the squared norm of that output channel's filter. The
    optional bias is then added. If alpha is enabled, the result is finally
    scaled by ``(sqrt(out_features) / log(1 + out_features)) ** alpha``.

    Args:
      in_features: int, number of input features (channels).
      out_features: int, number of output features (channels).
      kernel_size: shape of the convolutional kernel. For 1D convolution,
        the kernel size can be passed as an integer, which will be interpreted
        as a tuple of the single integer. For all other cases, it must be a
        sequence of integers.
      strides: an integer or a sequence of ``n`` integers, representing the
        inter-window strides (default: 1).
      padding: either the string ``'SAME'``, the string ``'VALID'``, the
        string ``'CIRCULAR'`` (periodic boundary conditions), or a sequence of
        ``n`` ``(low, high)`` integer pairs that give the padding to apply
        before and after each spatial dimension. A single int is interpreted
        as applying the same padding in all dims, and passing a single int in
        a sequence causes the same padding to be used on both sides. With
        ``'CIRCULAR'``, the linear maps are computed with ``'VALID'`` padding
        and folded periodically before the Yat nonlinearity is applied.
      kernel_dilation: an integer or a sequence of ``n`` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel (default: 1). Convolution with kernel dilation
        is also known as 'atrous convolution'.
      use_bias: whether to add a bias to the output (default: True).
      use_alpha: whether to use alpha scaling (default: True).
      use_dropconnect: whether to use DropConnect (default: False).
      mask: Optional mask for the weights during masked convolution. The mask
        must be the same shape as the convolution weight matrix.
      dtype: the dtype of the computation (default: infer from input and
        params).
      param_dtype: the dtype passed to parameter initializers (default:
        float32).
      precision: numerical precision of the computation see
        ``jax.lax.Precision`` for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.
      alpha_init: initializer for the scalar alpha parameter (shape ``(1,)``).
      transpose_kernel: if ``True`` flips spatial axes and swaps the
        input/output channel axes of the kernel.
      promote_dtype: function to promote the dtype of the arrays to the
        desired dtype. The function should accept a tuple of
        ``(inputs, kernel, bias)`` and a ``dtype`` keyword argument, and
        return a tuple of arrays with the promoted dtype.
      epsilon: A small float added to the denominator to prevent division by
        zero.
      drop_rate: dropout rate for DropConnect (default: 0.0).
      rngs: rng container; its ``params`` stream supplies the initializer keys
        and the DropConnect key.
    """

    __data__ = ('kernel', 'bias', 'mask', 'dropconnect_key')

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int | tp.Sequence[int],
        strides: int | tp.Sequence[int] | None = None,
        *,
        padding: PaddingLike = 'SAME',
        kernel_dilation: int | tp.Sequence[int] | None = None,
        use_bias: bool = True,
        use_alpha: bool = True,
        use_dropconnect: bool = False,
        mask: Array | None = None,
        dtype: Dtype | None = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike | None = None,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        alpha_init: Initializer = default_alpha_init,
        transpose_kernel: bool = False,
        promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
        epsilon: float = 1e-5,
        drop_rate: float = 0.0,
        rngs: rnglib.Rngs,
    ):
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,)
        else:
            kernel_size = tuple(kernel_size)

        self.kernel_size = kernel_size
        self.in_features = in_features
        self.out_features = out_features
        self.strides = strides
        self.padding = padding
        self.kernel_dilation = kernel_dilation
        self.use_bias = use_bias
        self.use_alpha = use_alpha
        self.use_dropconnect = use_dropconnect
        self.mask = mask
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init
        self.alpha_init = alpha_init
        self.transpose_kernel = transpose_kernel
        self.promote_dtype = promote_dtype
        self.epsilon = epsilon
        self.drop_rate = drop_rate

        # With ``transpose_kernel`` the channel axes are stored swapped
        # (out, in); ``lax.conv_transpose`` swaps them back before use.
        if self.transpose_kernel:
            kernel_shape = kernel_size + (self.out_features, in_features)
        else:
            kernel_shape = kernel_size + (in_features, self.out_features)

        self.kernel_shape = kernel_shape
        self.kernel = nnx.Param(
            self.kernel_init(rngs.params(), kernel_shape, self.param_dtype)
        )

        self.bias: nnx.Param | None
        if self.use_bias:
            self.bias = nnx.Param(
                self.bias_init(rngs.params(), (self.out_features,), self.param_dtype)
            )
        else:
            self.bias = None

        if use_alpha:
            alpha_key = rngs.params()
            self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
        else:
            self.alpha = None

        # NOTE(review): this key is drawn once here and is not split in
        # ``__call__``, so every non-deterministic call applies the same
        # DropConnect mask — confirm this is intended.
        if use_dropconnect:
            self.dropconnect_key = rngs.params()
        else:
            self.dropconnect_key = None

    def __call__(self, inputs: Array, *, deterministic: bool = False) -> Array:
        """Applies the Yat transposed convolution.

        Args:
          inputs: array of shape ``(*batch_dims, *spatial, in_features)``,
            where ``spatial`` has ``len(kernel_size)`` axes. Any number of
            leading batch dimensions is flattened for the convolution and
            restored afterwards.
          deterministic: if ``True``, DropConnect is disabled.

        Returns:
          Array of shape ``(*batch_dims, *out_spatial, out_features)``.
        """
        assert isinstance(self.kernel_size, tuple)
        n_spatial = len(self.kernel_size)

        def maybe_broadcast(
            x: tp.Optional[tp.Union[int, tp.Sequence[int]]],
        ) -> tuple[int, ...]:
            if x is None:
                x = 1
            if isinstance(x, int):
                return (x,) * n_spatial
            return tuple(x)

        # Flatten all leading batch dimensions into a single one.
        num_batch_dimensions = inputs.ndim - (n_spatial + 1)
        if num_batch_dimensions != 1:
            input_batch_shape = inputs.shape[:num_batch_dimensions]
            total_batch_size = int(np.prod(input_batch_shape))
            flat_input_shape = (total_batch_size,) + inputs.shape[
                num_batch_dimensions:
            ]
            inputs_flat = jnp.reshape(inputs, flat_input_shape)
        else:
            inputs_flat = inputs
            input_batch_shape = ()

        strides = maybe_broadcast(self.strides)
        kernel_dilation = maybe_broadcast(self.kernel_dilation)

        # Circular padding is emulated by a VALID conv followed by a
        # periodic fold (see ``_circular_wrap``).
        padding_lax = canonicalize_padding(self.padding, n_spatial)
        if padding_lax == 'CIRCULAR':
            padding_lax = 'VALID'

        kernel_val = self.kernel.value

        if self.use_dropconnect and not deterministic and self.drop_rate > 0.0:
            keep_prob = 1.0 - self.drop_rate
            keep_mask = jax.random.bernoulli(
                self.dropconnect_key, p=keep_prob, shape=kernel_val.shape
            )
            kernel_val = (kernel_val * keep_mask) / keep_prob

        current_mask = self.mask
        if current_mask is not None:
            if current_mask.shape != self.kernel_shape:
                raise ValueError(
                    'Mask needs to have the same shape as weights. '
                    f'Shapes are: {current_mask.shape}, {self.kernel_shape}'
                )
            kernel_val *= current_mask

        bias_val = self.bias.value if self.bias is not None else None
        alpha = self.alpha.value if self.alpha is not None else None

        inputs_flat, kernel_val, bias_val = self.promote_dtype(
            (inputs_flat, kernel_val, bias_val), dtype=self.dtype
        )

        # Ordinary transposed-convolution response: (batch, *out_spatial, out).
        dot_prod_map = lax.conv_transpose(
            inputs_flat,
            kernel_val,
            strides,
            padding_lax,
            rhs_dilation=kernel_dilation,
            transpose_kernel=self.transpose_kernel,
            precision=self.precision,
        )

        # Sum of squared inputs feeding each output position, as a single
        # channel: (batch, *out_spatial, 1). An all-ones kernel is invariant
        # to the spatial flip done by ``transpose_kernel``, so it is always
        # given in the plain (*kernel_size, in_features, 1) layout. Passing it
        # with ``transpose_kernel=True`` would swap its channel axes and make
        # it incompatible with the ``in_features`` input channels.
        ones_kernel = jnp.ones(
            self.kernel_size + (self.in_features, 1), dtype=kernel_val.dtype
        )
        patch_sq_sum_map = lax.conv_transpose(
            inputs_flat**2,
            ones_kernel,
            strides,
            padding_lax,
            rhs_dilation=kernel_dilation,
            transpose_kernel=False,
            precision=self.precision,
        )

        # Squared norm of each output channel's filter: reduce over the
        # spatial axes and the input-channel axis, whose position depends on
        # the stored kernel layout.
        in_axis = n_spatial + 1 if self.transpose_kernel else n_spatial
        reduce_axes = tuple(range(n_spatial)) + (in_axis,)
        kernel_sq_sum_per_filter = jnp.sum(kernel_val**2, axis=reduce_axes)

        # Fold the *linear* maps for circular padding before the nonlinear
        # Yat map and before the bias, so neither the nonlinearity nor the
        # bias is applied per fold.
        if self.padding == 'CIRCULAR':
            periods = [
                x_dim * stride
                for x_dim, stride in zip(jnp.shape(inputs_flat)[1:-1], strides)
            ]
            dot_prod_map = _circular_wrap(
                dot_prod_map, periods, self.transpose_kernel
            )
            patch_sq_sum_map = _circular_wrap(
                patch_sq_sum_map, periods, self.transpose_kernel
            )

        # The single-channel patch map broadcasts against the per-filter norms.
        distance_sq_map = (
            patch_sq_sum_map + kernel_sq_sum_per_filter - 2 * dot_prod_map
        )
        y = dot_prod_map**2 / (distance_sq_map + self.epsilon)

        if self.use_bias and bias_val is not None:
            bias_reshape_dims = (1,) * (y.ndim - 1) + (-1,)
            y += jnp.reshape(bias_val, bias_reshape_dims)

        assert self.use_alpha == (alpha is not None)
        if alpha is not None:
            scale = (
                jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)
            ) ** alpha
            y = y * scale

        # Restore the original leading batch dimensions.
        if num_batch_dimensions != 1:
            output_shape = input_batch_shape + y.shape[1:]
            y = jnp.reshape(y, output_shape)

        return y
|
@@ -1,21 +1,23 @@
|
|
1
1
|
nmn/__init__.py,sha256=F_5o-lCggdEdWfR1l1YC_jfR01mJmveugwUndoRx8n8,83
|
2
2
|
nmn/keras/nmn.py,sha256=E7V7kyFB09PfMG1Da_TA2FirOiTCeAXYp3JWACV8h_c,5908
|
3
3
|
nmn/linen/nmn.py,sha256=j4v6Z793wliE0xEAITde7jXu9Qras9u75NqdOSPSM4Q,3722
|
4
|
-
nmn/nnx/TODO,sha256=
|
4
|
+
nmn/nnx/TODO,sha256=mr9z3yaqz9t_yiZCrfq7LEzlx9Gy-qTG9Q--JO-Apa8,101
|
5
|
+
nmn/nnx/conv_utils.py,sha256=7OiLx9mrjRCAxBOfiHo7uUU66g5AEhfPUXCWO7HI6O8,1424
|
5
6
|
nmn/nnx/nmn.py,sha256=tPNUtF8Lmv_B1TgMoVXfMQ9x0IPGKjSyAP6HnZ-YBsM,5651
|
6
7
|
nmn/nnx/yatattention.py,sha256=qEWiG_FIgr-TslYCbm2pcBi1myXJLC84nT6k1tMQcr4,25001
|
7
|
-
nmn/nnx/yatconv.py,sha256=
|
8
|
+
nmn/nnx/yatconv.py,sha256=nIIYT4QZErnneNM6WCC8EO-pYgJYIVQn8DiDsWdL10Y,12012
|
9
|
+
nmn/nnx/yatconv_transpose.py,sha256=PdfHzCZF_ENl7y11Hu7exoplasMVOxUjr0Lmu6MbJqc,10310
|
8
10
|
nmn/nnx/examples/language/mingpt.py,sha256=RveY3NwriTGPBdj8HNKDNtnXMaH0pgux8554m4Bhho4,61080
|
9
11
|
nmn/nnx/examples/vision/cnn_cifar.py,sha256=UcK52-SCwuE2hl2BkpEbyg7N3Jwvvz8iFxiqhI7B9ew,73961
|
10
12
|
nmn/nnx/loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
13
|
nmn/nnx/squashers/__init__.py,sha256=zXYPa3yzqMXxkIPvNHiaV6pcZRDOdVrzaVdYVDGALTY,180
|
12
14
|
nmn/nnx/squashers/soft_tanh.py,sha256=WSJkxD6L9WU1eqPwsK2AW4V6OJbw5pSWYjKwkiWtLdo,812
|
13
15
|
nmn/nnx/squashers/softer_sigmoid.py,sha256=vE6IWorZdBb2cww6fskARnwzdjTcWB2kKohuaJWVGNs,845
|
14
|
-
nmn/nnx/squashers/softermax.py,sha256=
|
16
|
+
nmn/nnx/squashers/softermax.py,sha256=ZSOVs1I57mc25SW2ZA75k46asoxs1IqWaB9wUG7jb3s,1329
|
15
17
|
nmn/tf/nmn.py,sha256=A-K65z9_aN62tAy12b0553nXxrzOofK1umGMRGJYjqw,6036
|
16
18
|
nmn/torch/conv.py,sha256=g5YxStk1p85WkvfecqbzRZaWaAJahOSArpMcqxWAWKc,83413
|
17
19
|
nmn/torch/nmn.py,sha256=8K0S3nwpGprT7apbCqpaYpKpxq8F8g8EL8PHIezgMCY,4658
|
18
|
-
nmn-0.1.
|
19
|
-
nmn-0.1.
|
20
|
-
nmn-0.1.
|
21
|
-
nmn-0.1.
|
20
|
+
nmn-0.1.12.dist-info/METADATA,sha256=W4oIKDKSFpZmg9o0oKvRPuO4zXbDweKKM4kATqC9zlI,8801
|
21
|
+
nmn-0.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
22
|
+
nmn-0.1.12.dist-info/licenses/LICENSE,sha256=kbZSd5WewnN2PSjvAC6DprP7pXx6NUNsnltmU2Mz1yA,34519
|
23
|
+
nmn-0.1.12.dist-info/RECORD,,
|
File without changes
|
File without changes
|