nmn-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nmn/__init__.py +3 -0
- nmn/keras/nmn.py +153 -0
- nmn/linen/nmn.py +112 -0
- nmn/nnx/nmn.py +170 -0
- nmn/nnx/yatconv.py +320 -0
- nmn/tf/nmn.py +179 -0
- nmn/torch/nmn.py +144 -0
- nmn-0.1.0.dist-info/METADATA +76 -0
- nmn-0.1.0.dist-info/RECORD +11 -0
- nmn-0.1.0.dist-info/WHEEL +4 -0
- nmn-0.1.0.dist-info/licenses/LICENSE +661 -0
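For orientation, a minimal import sketch (not part of the diff) covering the three implementations whose class names appear below; it assumes the distribution installs under the `nmn` package name recorded above and that the corresponding framework (Flax, TensorFlow, or PyTorch) is available.

from nmn.nnx.yatconv import YatConv               # Flax NNX convolution variant
from nmn.tf.nmn import YatNMN as TFYatNMN         # TensorFlow dense variant
from nmn.torch.nmn import YatNMN as TorchYatNMN   # PyTorch dense variant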
nmn/nnx/yatconv.py
ADDED
@@ -0,0 +1,320 @@
import typing as tp

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

from flax import nnx
from flax.nnx.module import Module
from flax.nnx import rnglib
from flax.nnx.nn import dtypes, initializers
from flax.typing import (
  Dtype,
  Initializer,
  PrecisionLike,
  ConvGeneralDilatedT,
  PaddingLike,
  LaxPadding,
  PromoteDtypeFn,
)

Array = jax.Array

# Default initializers
default_kernel_init = initializers.lecun_normal()
default_bias_init = initializers.zeros_init()

# Helper functions
def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
  """Canonicalizes conv padding to a jax.lax supported format."""
  if isinstance(padding, str):
    return padding
  if isinstance(padding, int):
    return [(padding, padding)] * rank
  if isinstance(padding, tp.Sequence) and len(padding) == rank:
    new_pad = []
    for p in padding:
      if isinstance(p, int):
        new_pad.append((p, p))
      elif isinstance(p, tuple) and len(p) == 2:
        new_pad.append(p)
      else:
        break
    if len(new_pad) == rank:
      return new_pad
  raise ValueError(
    f'Invalid padding format: {padding}, should be str, int,'
    f' or a sequence of len {rank} where each element is an'
    ' int or pair of ints.'
  )

def _conv_dimension_numbers(input_shape):
  """Computes the dimension numbers based on the input shape."""
  ndim = len(input_shape)
  lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
  rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
  out_spec = lhs_spec
  return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)

class YatConv(Module):
  """Yat Convolution Module wrapping ``lax.conv_general_dilated``.

  Example usage::

    >>> from nmn.nnx.yatconv import YatConv
    >>> import jax, jax.numpy as jnp
    >>> from flax.nnx import rnglib, state

    >>> rngs = rnglib.Rngs(0)
    >>> x = jnp.ones((1, 8, 3))

    >>> # valid padding
    >>> layer = YatConv(in_features=3, out_features=4, kernel_size=(3,),
    ...                 padding='VALID', rngs=rngs)
    >>> s = state(layer)
    >>> print(s['kernel'].value.shape)
    (3, 3, 4)
    >>> print(s['bias'].value.shape)
    (4,)
    >>> out = layer(x)
    >>> print(out.shape)
    (1, 6, 4)

  Args:
    in_features: int or tuple with number of input features.
    out_features: int or tuple with number of output features.
    kernel_size: shape of the convolutional kernel. For 1D convolution,
      the kernel size can be passed as an integer, which will be interpreted
      as a tuple of the single integer. For all other cases, it must be a
      sequence of integers.
    strides: an integer or a sequence of ``n`` integers, representing the
      inter-window strides (default: 1).
    padding: either the string ``'SAME'``, the string ``'VALID'``, the string
      ``'CIRCULAR'`` (periodic boundary conditions), the string ``'REFLECT'``
      (reflection across the padding boundary), or a sequence of ``n``
      ``(low, high)`` integer pairs that give the padding to apply before and
      after each spatial dimension. A single int is interpreted as applying the
      same padding in all dims, and passing a single int in a sequence causes
      the same padding to be used on both sides. ``'CAUSAL'`` padding for a 1D
      convolution will left-pad the convolution axis, resulting in same-sized
      output.
    input_dilation: an integer or a sequence of ``n`` integers, giving the
      dilation factor to apply in each spatial dimension of ``inputs``
      (default: 1). Convolution with input dilation ``d`` is equivalent to
      transposed convolution with stride ``d``.
    kernel_dilation: an integer or a sequence of ``n`` integers, giving the
      dilation factor to apply in each spatial dimension of the convolution
      kernel (default: 1). Convolution with kernel dilation is also known as
      'atrous convolution'.
    feature_group_count: integer, default 1. If specified, divides the input
      features into groups.
    use_bias: whether to add a bias to the output (default: True).
    mask: Optional mask for the weights during masked convolution. The mask
      must be the same shape as the convolution weight matrix.
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    precision: numerical precision of the computation; see ``jax.lax.Precision``
      for details.
    kernel_init: initializer for the convolutional kernel.
    bias_init: initializer for the bias.
    promote_dtype: function to promote the dtype of the arrays to the desired
      dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
      and a ``dtype`` keyword argument, and return a tuple of arrays with the
      promoted dtype.
    epsilon: A small float added to the denominator to prevent division by zero.
    rngs: rng key.
  """

  __data__ = ('kernel', 'bias', 'mask')

  def __init__(
    self,
    in_features: int,
    out_features: int,
    kernel_size: int | tp.Sequence[int],
    strides: tp.Union[None, int, tp.Sequence[int]] = 1,
    *,
    padding: PaddingLike = 'SAME',
    input_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
    kernel_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
    feature_group_count: int = 1,
    use_bias: bool = True,
    mask: tp.Optional[Array] = None,
    dtype: tp.Optional[Dtype] = None,
    param_dtype: Dtype = jnp.float32,
    precision: PrecisionLike = None,
    kernel_init: Initializer = default_kernel_init,
    bias_init: Initializer = default_bias_init,
    conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
    promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
    epsilon: float = 1e-5,
    rngs: rnglib.Rngs,
  ):
    if isinstance(kernel_size, int):
      kernel_size = (kernel_size,)
    else:
      kernel_size = tuple(kernel_size)

    self.kernel_shape = kernel_size + (
      in_features // feature_group_count,
      out_features,
    )
    kernel_key = rngs.params()
    self.kernel = nnx.Param(kernel_init(kernel_key, self.kernel_shape, param_dtype))

    self.bias: nnx.Param[jax.Array] | None
    if use_bias:
      bias_shape = (out_features,)
      bias_key = rngs.params()
      self.bias = nnx.Param(bias_init(bias_key, bias_shape, param_dtype))
    else:
      self.bias = None

    self.in_features = in_features
    self.out_features = out_features
    self.kernel_size = kernel_size
    self.strides = strides
    self.padding = padding
    self.input_dilation = input_dilation
    self.kernel_dilation = kernel_dilation
    self.feature_group_count = feature_group_count
    self.use_bias = use_bias
    self.mask = mask
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.precision = precision
    self.kernel_init = kernel_init
    self.bias_init = bias_init
    self.conv_general_dilated = conv_general_dilated
    self.promote_dtype = promote_dtype
    self.epsilon = epsilon

  def __call__(self, inputs: Array) -> Array:
    assert isinstance(self.kernel_size, tuple)

    def maybe_broadcast(
      x: tp.Optional[tp.Union[int, tp.Sequence[int]]],
    ) -> tuple[int, ...]:
      if x is None:
        x = 1
      if isinstance(x, int):
        return (x,) * len(self.kernel_size)
      return tuple(x)

    # Flatten any leading batch dimensions down to a single batch axis.
    num_batch_dimensions = inputs.ndim - (len(self.kernel_size) + 1)
    if num_batch_dimensions != 1:
      input_batch_shape = inputs.shape[:num_batch_dimensions]
      total_batch_size = int(np.prod(input_batch_shape))
      flat_input_shape = (total_batch_size,) + inputs.shape[
        num_batch_dimensions:
      ]
      inputs_flat = jnp.reshape(inputs, flat_input_shape)
    else:
      inputs_flat = inputs
      input_batch_shape = ()

    strides = maybe_broadcast(self.strides)
    input_dilation = maybe_broadcast(self.input_dilation)
    kernel_dilation = maybe_broadcast(self.kernel_dilation)

    padding_lax = canonicalize_padding(self.padding, len(self.kernel_size))
    if padding_lax in ('CIRCULAR', 'REFLECT'):
      assert isinstance(padding_lax, str)
      kernel_size_dilated = [
        (k - 1) * d + 1 for k, d in zip(self.kernel_size, kernel_dilation)
      ]
      zero_pad: tp.List[tuple[int, int]] = [(0, 0)]
      pads = (
        zero_pad
        + [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
        + [(0, 0)]
      )
      padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax]
      inputs_flat = jnp.pad(inputs_flat, pads, mode=padding_mode)
      padding_lax = 'VALID'
    elif padding_lax == 'CAUSAL':
      if len(self.kernel_size) != 1:
        raise ValueError(
          'Causal padding is only implemented for 1D convolutions.'
        )
      left_pad = kernel_dilation[0] * (self.kernel_size[0] - 1)
      pads = [(0, 0), (left_pad, 0), (0, 0)]
      inputs_flat = jnp.pad(inputs_flat, pads)
      padding_lax = 'VALID'

    dimension_numbers = _conv_dimension_numbers(inputs_flat.shape)
    assert self.in_features % self.feature_group_count == 0

    kernel_val = self.kernel.value

    current_mask = self.mask
    if current_mask is not None:
      if current_mask.shape != self.kernel_shape:
        raise ValueError(
          'Mask needs to have the same shape as weights. '
          f'Shapes are: {current_mask.shape}, {self.kernel_shape}'
        )
      kernel_val *= current_mask

    bias_val = self.bias.value if self.bias is not None else None

    inputs_promoted, kernel_promoted, bias_promoted = self.promote_dtype(
      (inputs_flat, kernel_val, bias_val), dtype=self.dtype
    )
    inputs_flat = inputs_promoted
    kernel_val = kernel_promoted
    bias_val = bias_promoted

    # Standard convolution: per-position dot product between each patch and filter.
    dot_prod_map = self.conv_general_dilated(
      inputs_flat,
      kernel_val,
      strides,
      padding_lax,
      lhs_dilation=input_dilation,
      rhs_dilation=kernel_dilation,
      dimension_numbers=dimension_numbers,
      feature_group_count=self.feature_group_count,
      precision=self.precision,
    )

    # Convolve inputs**2 with an all-ones kernel to get the sum of squares of each patch.
    inputs_flat_squared = inputs_flat**2
    kernel_in_channels_for_sum_sq = self.kernel_shape[-2]
    kernel_for_patch_sq_sum_shape = self.kernel_size + (kernel_in_channels_for_sum_sq, 1)
    kernel_for_patch_sq_sum = jnp.ones(kernel_for_patch_sq_sum_shape, dtype=kernel_val.dtype)

    patch_sq_sum_map_raw = self.conv_general_dilated(
      inputs_flat_squared,
      kernel_for_patch_sq_sum,
      strides,
      padding_lax,
      lhs_dilation=input_dilation,
      rhs_dilation=kernel_dilation,
      dimension_numbers=dimension_numbers,
      feature_group_count=self.feature_group_count,
      precision=self.precision,
    )

    if self.feature_group_count > 1:
      num_out_channels_per_group = self.out_features // self.feature_group_count
      if num_out_channels_per_group == 0:
        raise ValueError(
          'out_features must be a multiple of feature_group_count and at least as large.'
        )
      patch_sq_sum_map = jnp.repeat(patch_sq_sum_map_raw, num_out_channels_per_group, axis=-1)
    else:
      patch_sq_sum_map = patch_sq_sum_map_raw

    reduce_axes_for_kernel_sq = tuple(range(kernel_val.ndim - 1))
    kernel_sq_sum_per_filter = jnp.sum(kernel_val**2, axis=reduce_axes_for_kernel_sq)

    # ||patch - filter||^2 = ||patch||^2 + ||filter||^2 - 2 <patch, filter>
    distance_sq_map = patch_sq_sum_map + kernel_sq_sum_per_filter - 2 * dot_prod_map
    y = dot_prod_map**2 / (distance_sq_map + self.epsilon)

    if self.use_bias and bias_val is not None:
      bias_reshape_dims = (1,) * (y.ndim - 1) + (-1,)
      y += jnp.reshape(bias_val, bias_reshape_dims)

    if num_batch_dimensions != 1:
      output_shape = input_batch_shape + y.shape[1:]
      y = jnp.reshape(y, output_shape)
    return y
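For clarity, here is a short sketch (not from the wheel; `yat_score`, `x`, `w`, and `eps` are illustrative names) of the per-patch quantity the convolution above assembles: the dot-product map is squared and divided by the squared Euclidean distance between each patch and each filter, and that distance is obtained from the expansion ||x - w||^2 = ||x||^2 + ||w||^2 - 2<w, x>, which is why the module runs a second convolution with an all-ones kernel over inputs**2.

import jax.numpy as jnp

def yat_score(x, w, eps=1e-5):
    # x: flattened input patch, w: one flattened filter
    dot = jnp.dot(x, w)
    dist_sq = jnp.sum(x**2) + jnp.sum(w**2) - 2.0 * dot  # equals ||x - w||**2
    return dot**2 / (dist_sq + eps)

x = jnp.array([1.0, 2.0, -0.5])
w = jnp.array([0.5, -1.0, 2.0])
print(jnp.allclose(jnp.sum((x - w) ** 2),
                   jnp.sum(x**2) + jnp.sum(w**2) - 2.0 * jnp.dot(x, w)))  # True
print(yat_score(x, w))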
nmn/tf/nmn.py
ADDED
@@ -0,0 +1,179 @@
import tensorflow as tf
import math
from typing import Optional, Any, Tuple, Union, List, Callable
import numpy as np

def create_orthogonal_matrix(shape: Tuple[int, ...], dtype: tf.DType = tf.float32) -> tf.Tensor:
    """Creates an orthogonal matrix using QR decomposition."""
    num_rows, num_cols = shape
    random_matrix = tf.random.normal([num_rows, num_cols], dtype=dtype)
    q, r = tf.linalg.qr(random_matrix)
    # Make it uniform
    d = tf.linalg.diag_part(r)
    ph = tf.cast(tf.sign(d), dtype)
    q *= ph[None, :]

    if num_rows < num_cols:
        q = tf.transpose(q)
    return q

class YatNMN(tf.Module):
    """A custom transformation applied over the last dimension of the input using squared Euclidean distance.

    Args:
      features: The number of output features.
      use_bias: Whether to add a bias to the output (default: True).
      dtype: The dtype of the computation (default: tf.float32).
      epsilon: Small constant added to avoid division by zero (default: 1e-6).
      return_weights: Whether to return the weight matrix along with the output (default: False).
      name: Name of the module (default: None).
    """
    def __init__(
        self,
        features: int,
        use_bias: bool = True,
        dtype: tf.DType = tf.float32,
        epsilon: float = 1e-6,
        return_weights: bool = False,
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.features = features
        self.use_bias = use_bias
        self.dtype = dtype
        self.epsilon = epsilon
        self.return_weights = return_weights

        # Variables will be created in build
        self.is_built = False
        self.input_dim = None
        self.kernel = None
        self.bias = None
        self.alpha = None

    @tf.Module.with_name_scope
    def build(self, input_shape: Union[List[int], tf.TensorShape]) -> None:
        """Builds the layer weights based on input shape.

        Args:
          input_shape: Shape of the input tensor.
        """
        if self.is_built:
            return

        last_dim = int(input_shape[-1])
        self.input_dim = last_dim

        # Initialize kernel using orthogonal initialization
        kernel_shape = (self.features, last_dim)
        initial_kernel = create_orthogonal_matrix(kernel_shape, dtype=self.dtype)
        self.kernel = tf.Variable(
            initial_kernel,
            trainable=True,
            name='kernel',
            dtype=self.dtype
        )

        # Initialize alpha to ones
        self.alpha = tf.Variable(
            tf.ones([1], dtype=self.dtype),
            trainable=True,
            name='alpha'
        )

        # Initialize bias if needed
        if self.use_bias:
            self.bias = tf.Variable(
                tf.zeros([self.features], dtype=self.dtype),
                trainable=True,
                name='bias'
            )

        self.is_built = True

    def _maybe_build(self, inputs: tf.Tensor) -> None:
        """Builds the layer if it hasn't been built yet."""
        if not self.is_built:
            self.build(inputs.shape)
        elif self.input_dim != inputs.shape[-1]:
            raise ValueError(f'Input shape changed: expected last dimension '
                             f'{self.input_dim}, got {inputs.shape[-1]}')

    @tf.Module.with_name_scope
    def __call__(self, inputs: tf.Tensor) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
        """Forward pass of the layer.

        Args:
          inputs: Input tensor.

        Returns:
          Output tensor, or a tuple of (output tensor, kernel weights) if return_weights is True.
        """
        # Ensure inputs are a tensor
        inputs = tf.convert_to_tensor(inputs, dtype=self.dtype)

        # Build if necessary
        self._maybe_build(inputs)

        # Compute dot product between input and transposed kernel
        y = tf.matmul(inputs, tf.transpose(self.kernel))

        # Compute squared Euclidean distances
        inputs_squared_sum = tf.reduce_sum(tf.square(inputs), axis=-1, keepdims=True)
        kernel_squared_sum = tf.reduce_sum(tf.square(self.kernel), axis=-1)

        # Reshape kernel_squared_sum for broadcasting
        kernel_squared_sum = tf.reshape(
            kernel_squared_sum,
            [1] * (len(inputs.shape) - 1) + [self.features]
        )

        distances = inputs_squared_sum + kernel_squared_sum - 2 * y

        # Apply the transformation
        y = tf.square(y) / (distances + self.epsilon)

        # Apply scaling factor
        scale = tf.pow(
            tf.cast(
                tf.sqrt(float(self.features)) / tf.math.log(1. + float(self.features)),
                self.dtype
            ),
            self.alpha
        )
        y = y * scale

        # Add bias if used
        if self.use_bias:
            # Reshape bias for proper broadcasting
            bias_shape = [1] * (len(y.shape) - 1) + [-1]
            y = y + tf.reshape(self.bias, bias_shape)

        if self.return_weights:
            return y, self.kernel
        return y

    def get_weights(self) -> List[tf.Tensor]:
        """Returns the current weights of the layer."""
        weights = [self.kernel, self.alpha]
        if self.use_bias:
            weights.append(self.bias)
        return weights

    def set_weights(self, weights: List[tf.Tensor]) -> None:
        """Sets the weights of the layer.

        Args:
          weights: List of tensors with shapes matching the layer's variables.
        """
        if not self.is_built:
            raise ValueError("Layer must be built before weights can be set.")

        expected_num = 3 if self.use_bias else 2
        if len(weights) != expected_num:
            raise ValueError(f"Expected {expected_num} weight tensors, got {len(weights)}")

        self.kernel.assign(weights[0])
        self.alpha.assign(weights[1])
        if self.use_bias:
            self.bias.assign(weights[2])
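A minimal usage sketch for the TensorFlow variant (illustrative, assuming the module is importable as `nmn.tf.nmn`); the layer builds its variables on the first call, so shapes are only fixed once an input has been seen.

import tensorflow as tf
from nmn.tf.nmn import YatNMN

layer = YatNMN(features=8)
x = tf.random.normal([4, 16])   # batch of 4 samples with 16 input features
y = layer(x)                    # variables are created here, on first call
print(y.shape)                  # (4, 8)
print([v.name for v in layer.trainable_variables])  # the kernel, alpha, and bias variables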
nmn/torch/nmn.py
ADDED
@@ -0,0 +1,144 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class YatNMN(nn.Module):
    """
    A PyTorch implementation of the Yat neuron with squared Euclidean distance transformation.

    Attributes:
        in_features (int): Size of each input sample
        out_features (int): Size of each output sample
        bias (bool): Whether to add a bias to the output
        dtype (torch.dtype): Data type for computation
        epsilon (float): Small constant to avoid division by zero
        kernel_init (callable): Initializer for the weight matrix
        bias_init (callable): Initializer for the bias
        alpha_init (callable): Initializer for the scaling parameter
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        alpha: bool = True,
        dtype: torch.dtype = torch.float32,
        epsilon: float = 1e-4,  # 1/epsilon caps the score per neuron; a smaller value increases precision but lets the scores explode
        kernel_init: callable = None,
        bias_init: callable = None,
        alpha_init: callable = None
    ):
        super().__init__()

        # Store attributes
        self.in_features = in_features
        self.out_features = out_features
        self.dtype = dtype
        self.epsilon = epsilon

        # Weight initialization
        if kernel_init is None:
            kernel_init = nn.init.xavier_normal_

        # Create weight parameter
        self.weight = nn.Parameter(torch.empty(
            (out_features, in_features),
            dtype=dtype
        ))

        # Alpha scaling parameter
        if alpha:
            self.alpha = nn.Parameter(torch.ones(
                (1,),
                dtype=dtype
            ))
        else:
            self.register_parameter('alpha', None)

        # Bias parameter
        if bias:
            self.bias = nn.Parameter(torch.empty(
                (out_features,),
                dtype=dtype
            ))
        else:
            self.register_parameter('bias', None)

        # Initialize parameters
        self.reset_parameters(kernel_init, bias_init, alpha_init)

    def reset_parameters(
        self,
        kernel_init: callable = None,
        bias_init: callable = None,
        alpha_init: callable = None
    ):
        """
        Initialize network parameters with specified or default initializers.
        """
        # Kernel (weight) initialization
        if kernel_init is None:
            kernel_init = nn.init.orthogonal_
        kernel_init(self.weight)

        # Bias initialization
        if self.bias is not None:
            if bias_init is None:
                # Default: uniform initialization
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(self.bias, -bound, bound)
            else:
                bias_init(self.bias)

        # Alpha initialization (default to 1.0)
        if self.alpha is not None:
            if alpha_init is None:
                self.alpha.data.fill_(1.0)
            else:
                alpha_init(self.alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with squared Euclidean distance transformation.

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: Transformed output
        """
        # Ensure input and weight are in the same dtype
        x = x.to(self.dtype)

        # Compute dot product
        y = torch.matmul(x, self.weight.t())

        # Compute squared distances
        inputs_squared_sum = torch.sum(x**2, dim=-1, keepdim=True)
        kernel_squared_sum = torch.sum(self.weight**2, dim=-1)
        distances = inputs_squared_sum + kernel_squared_sum - 2 * y

        # Apply squared Euclidean distance transformation
        y = y ** 2 / (distances + self.epsilon)

        # Add bias if used
        if self.bias is not None:
            y += self.bias

        # Dynamic scaling
        if self.alpha is not None:
            scale = (math.sqrt(self.out_features) / math.log(1 + self.out_features)) ** self.alpha
            y = y * scale

        return y

    def extra_repr(self) -> str:
        """
        Extra representation of the module for print formatting.
        """
        return (f"in_features={self.in_features}, "
                f"out_features={self.out_features}, "
                f"bias={self.bias is not None}, "
                f"alpha={self.alpha is not None}")
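And a matching usage sketch for the PyTorch variant (illustrative, assuming the module is importable as `nmn.torch.nmn`):

import torch
from nmn.torch.nmn import YatNMN

layer = YatNMN(in_features=16, out_features=8)
x = torch.randn(4, 16)
y = layer(x)
print(y.shape)   # torch.Size([4, 8])
print(layer)     # extra_repr: in_features=16, out_features=8, bias=True, alpha=True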