nmn 0.1.0__py3-none-any.whl

nmn/nnx/yatconv.py ADDED
@@ -0,0 +1,320 @@
+ import typing as tp
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from jax import lax
+
+ from flax import nnx
+ from flax.nnx.module import Module
+ from flax.nnx import rnglib
+ from flax.nnx.nn import dtypes, initializers
+ from flax.typing import (
+   Dtype,
+   Initializer,
+   PrecisionLike,
+   ConvGeneralDilatedT,
+   PaddingLike,
+   LaxPadding,
+   PromoteDtypeFn,
+ )
+
+ Array = jax.Array
+
+ # Default initializers
+ default_kernel_init = initializers.lecun_normal()
+ default_bias_init = initializers.zeros_init()
+
+ # Helper functions
+ def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
+   """Canonicalizes conv padding to a jax.lax supported format."""
+   if isinstance(padding, str):
+     return padding
+   if isinstance(padding, int):
+     return [(padding, padding)] * rank
+   if isinstance(padding, tp.Sequence) and len(padding) == rank:
+     new_pad = []
+     for p in padding:
+       if isinstance(p, int):
+         new_pad.append((p, p))
+       elif isinstance(p, tuple) and len(p) == 2:
+         new_pad.append(p)
+       else:
+         break
+     if len(new_pad) == rank:
+       return new_pad
+   raise ValueError(
+     f'Invalid padding format: {padding}, should be str, int,'
+     f' or a sequence of len {rank} where each element is an'
+     ' int or pair of ints.'
+   )
+
+ def _conv_dimension_numbers(input_shape):
+   """Computes the dimension numbers based on the input shape."""
+   ndim = len(input_shape)
+   lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
+   rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
+   out_spec = lhs_spec
+   return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
+
60
+ class YatConv(Module):
61
+ """Yat Convolution Module wrapping ``lax.conv_general_dilated``.
62
+
63
+ Example usage::
64
+
65
+ >>> from flax.nnx import conv # Assuming this file is flax/nnx/conv.py
66
+ >>> import jax, jax.numpy as jnp
67
+ >>> from flax.nnx import rnglib, state
68
+
69
+ >>> rngs = rnglib.Rngs(0)
70
+ >>> x = jnp.ones((1, 8, 3))
71
+
72
+ >>> # valid padding
73
+ >>> layer = conv.Conv(in_features=3, out_features=4, kernel_size=(3,),
74
+ ... padding='VALID', rngs=rngs)
75
+ >>> s = state(layer)
76
+ >>> print(s['kernel'].value.shape)
77
+ (3, 3, 4)
78
+ >>> print(s['bias'].value.shape)
79
+ (4,)
80
+ >>> out = layer(x)
81
+ >>> print(out.shape)
82
+ (1, 6, 4)
83
+
84
+ Args:
85
+ in_features: int or tuple with number of input features.
86
+ out_features: int or tuple with number of output features.
87
+ kernel_size: shape of the convolutional kernel. For 1D convolution,
88
+ the kernel size can be passed as an integer, which will be interpreted
89
+ as a tuple of the single integer. For all other cases, it must be a
90
+ sequence of integers.
91
+ strides: an integer or a sequence of ``n`` integers, representing the
92
+ inter-window strides (default: 1).
93
+ padding: either the string ``'SAME'``, the string ``'VALID'``, the string
94
+ ``'CIRCULAR'`` (periodic boundary conditions), the string `'REFLECT'`
95
+ (reflection across the padding boundary), or a sequence of ``n``
96
+ ``(low, high)`` integer pairs that give the padding to apply before and after each
97
+ spatial dimension. A single int is interpeted as applying the same padding
98
+ in all dims and passign a single int in a sequence causes the same padding
99
+ to be used on both sides. ``'CAUSAL'`` padding for a 1D convolution will
100
+ left-pad the convolution axis, resulting in same-sized output.
101
+ input_dilation: an integer or a sequence of ``n`` integers, giving the
102
+ dilation factor to apply in each spatial dimension of ``inputs``
103
+ (default: 1). Convolution with input dilation ``d`` is equivalent to
104
+ transposed convolution with stride ``d``.
105
+ kernel_dilation: an integer or a sequence of ``n`` integers, giving the
106
+ dilation factor to apply in each spatial dimension of the convolution
107
+ kernel (default: 1). Convolution with kernel dilation
108
+ is also known as 'atrous convolution'.
109
+ feature_group_count: integer, default 1. If specified divides the input
110
+ features into groups.
111
+ use_bias: whether to add a bias to the output (default: True).
112
+ mask: Optional mask for the weights during masked convolution. The mask must
113
+ be the same shape as the convolution weight matrix.
114
+ dtype: the dtype of the computation (default: infer from input and params).
115
+ param_dtype: the dtype passed to parameter initializers (default: float32).
116
+ precision: numerical precision of the computation see ``jax.lax.Precision``
117
+ for details.
118
+ kernel_init: initializer for the convolutional kernel.
119
+ bias_init: initializer for the bias.
120
+ promote_dtype: function to promote the dtype of the arrays to the desired
121
+ dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
122
+ and a ``dtype`` keyword argument, and return a tuple of arrays with the
123
+ promoted dtype.
124
+ epsilon: A small float added to the denominator to prevent division by zero.
125
+ rngs: rng key.
126
+ """
127
+
+   __data__ = ('kernel', 'bias', 'mask')
+
+   def __init__(
+     self,
+     in_features: int,
+     out_features: int,
+     kernel_size: int | tp.Sequence[int],
+     strides: tp.Union[None, int, tp.Sequence[int]] = 1,
+     *,
+     padding: PaddingLike = 'SAME',
+     input_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
+     kernel_dilation: tp.Union[None, int, tp.Sequence[int]] = 1,
+     feature_group_count: int = 1,
+     use_bias: bool = True,
+     mask: tp.Optional[Array] = None,
+     dtype: tp.Optional[Dtype] = None,
+     param_dtype: Dtype = jnp.float32,
+     precision: PrecisionLike = None,
+     kernel_init: Initializer = default_kernel_init,
+     bias_init: Initializer = default_bias_init,
+     conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
+     promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+     epsilon: float = 1e-5,
+     rngs: rnglib.Rngs,
+   ):
+     if isinstance(kernel_size, int):
+       kernel_size = (kernel_size,)
+     else:
+       kernel_size = tuple(kernel_size)
+
+     self.kernel_shape = kernel_size + (
+       in_features // feature_group_count,
+       out_features,
+     )
+     kernel_key = rngs.params()
+     self.kernel = nnx.Param(kernel_init(kernel_key, self.kernel_shape, param_dtype))
+
+     self.bias: nnx.Param[jax.Array] | None
+     if use_bias:
+       bias_shape = (out_features,)
+       bias_key = rngs.params()
+       self.bias = nnx.Param(bias_init(bias_key, bias_shape, param_dtype))
+     else:
+       self.bias = None
+
+     self.in_features = in_features
+     self.out_features = out_features
+     self.kernel_size = kernel_size
+     self.strides = strides
+     self.padding = padding
+     self.input_dilation = input_dilation
+     self.kernel_dilation = kernel_dilation
+     self.feature_group_count = feature_group_count
+     self.use_bias = use_bias
+     self.mask = mask
+     self.dtype = dtype
+     self.param_dtype = param_dtype
+     self.precision = precision
+     self.kernel_init = kernel_init
+     self.bias_init = bias_init
+     self.conv_general_dilated = conv_general_dilated
+     self.promote_dtype = promote_dtype
+     self.epsilon = epsilon
+
+   def __call__(self, inputs: Array) -> Array:
+     assert isinstance(self.kernel_size, tuple)
+
+     def maybe_broadcast(
+       x: tp.Optional[tp.Union[int, tp.Sequence[int]]],
+     ) -> tuple[int, ...]:
+       if x is None:
+         x = 1
+       if isinstance(x, int):
+         return (x,) * len(self.kernel_size)
+       return tuple(x)
+
+     # Flatten any leading batch dimensions into a single batch axis.
+     num_batch_dimensions = inputs.ndim - (len(self.kernel_size) + 1)
+     if num_batch_dimensions != 1:
+       input_batch_shape = inputs.shape[:num_batch_dimensions]
+       total_batch_size = int(np.prod(input_batch_shape))
+       flat_input_shape = (total_batch_size,) + inputs.shape[num_batch_dimensions:]
+       inputs_flat = jnp.reshape(inputs, flat_input_shape)
+     else:
+       inputs_flat = inputs
+       input_batch_shape = ()
+
+     strides = maybe_broadcast(self.strides)
+     input_dilation = maybe_broadcast(self.input_dilation)
+     kernel_dilation = maybe_broadcast(self.kernel_dilation)
+
+     padding_lax = canonicalize_padding(self.padding, len(self.kernel_size))
+     if padding_lax in ('CIRCULAR', 'REFLECT'):
+       assert isinstance(padding_lax, str)
+       kernel_size_dilated = [
+         (k - 1) * d + 1 for k, d in zip(self.kernel_size, kernel_dilation)
+       ]
+       zero_pad: tp.List[tuple[int, int]] = [(0, 0)]
+       pads = (
+         zero_pad
+         + [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
+         + [(0, 0)]
+       )
+       padding_mode = {'CIRCULAR': 'wrap', 'REFLECT': 'reflect'}[padding_lax]
+       inputs_flat = jnp.pad(inputs_flat, pads, mode=padding_mode)
+       padding_lax = 'VALID'
+     elif padding_lax == 'CAUSAL':
+       if len(self.kernel_size) != 1:
+         raise ValueError(
+           'Causal padding is only implemented for 1D convolutions.'
+         )
+       left_pad = kernel_dilation[0] * (self.kernel_size[0] - 1)
+       pads = [(0, 0), (left_pad, 0), (0, 0)]
+       inputs_flat = jnp.pad(inputs_flat, pads)
+       padding_lax = 'VALID'
+
+     dimension_numbers = _conv_dimension_numbers(inputs_flat.shape)
+     assert self.in_features % self.feature_group_count == 0
+
+     kernel_val = self.kernel.value
+
+     current_mask = self.mask
+     if current_mask is not None:
+       if current_mask.shape != self.kernel_shape:
+         raise ValueError(
+           'Mask needs to have the same shape as weights. '
+           f'Shapes are: {current_mask.shape}, {self.kernel_shape}'
+         )
+       kernel_val *= current_mask
+
+     bias_val = self.bias.value if self.bias is not None else None
+
+     inputs_promoted, kernel_promoted, bias_promoted = self.promote_dtype(
+       (inputs_flat, kernel_val, bias_val), dtype=self.dtype
+     )
+     inputs_flat = inputs_promoted
+     kernel_val = kernel_promoted
+     bias_val = bias_promoted
+
+     # <w, x> for every patch and filter: an ordinary convolution.
+     dot_prod_map = self.conv_general_dilated(
+       inputs_flat,
+       kernel_val,
+       strides,
+       padding_lax,
+       lhs_dilation=input_dilation,
+       rhs_dilation=kernel_dilation,
+       dimension_numbers=dimension_numbers,
+       feature_group_count=self.feature_group_count,
+       precision=self.precision,
+     )
+
+     # ||x||^2 for every patch: convolve the squared inputs with an all-ones
+     # kernel that has one output channel per feature group.
+     inputs_flat_squared = inputs_flat**2
+     kernel_in_channels_for_sum_sq = self.kernel_shape[-2]
+     kernel_for_patch_sq_sum_shape = self.kernel_size + (
+       kernel_in_channels_for_sum_sq,
+       self.feature_group_count,
+     )
+     kernel_for_patch_sq_sum = jnp.ones(kernel_for_patch_sq_sum_shape, dtype=kernel_val.dtype)
+
+     patch_sq_sum_map_raw = self.conv_general_dilated(
+       inputs_flat_squared,
+       kernel_for_patch_sq_sum,
+       strides,
+       padding_lax,
+       lhs_dilation=input_dilation,
+       rhs_dilation=kernel_dilation,
+       dimension_numbers=dimension_numbers,
+       feature_group_count=self.feature_group_count,
+       precision=self.precision,
+     )
+
+     if self.feature_group_count > 1:
+       num_out_channels_per_group = self.out_features // self.feature_group_count
+       if num_out_channels_per_group == 0:
+         raise ValueError(
+           'out_features must be a multiple of feature_group_count and at'
+           ' least as large.'
+         )
+       # Broadcast each group's patch norm to that group's output channels.
+       patch_sq_sum_map = jnp.repeat(patch_sq_sum_map_raw, num_out_channels_per_group, axis=-1)
+     else:
+       patch_sq_sum_map = patch_sq_sum_map_raw
+
+     # ||w||^2 per output filter.
+     reduce_axes_for_kernel_sq = tuple(range(kernel_val.ndim - 1))
+     kernel_sq_sum_per_filter = jnp.sum(kernel_val**2, axis=reduce_axes_for_kernel_sq)
+
+     # Yat transform: y = <w, x>^2 / (||x - w||^2 + eps), with
+     # ||x - w||^2 = ||x||^2 + ||w||^2 - 2 <w, x>.
+     distance_sq_map = patch_sq_sum_map + kernel_sq_sum_per_filter - 2 * dot_prod_map
+     y = dot_prod_map**2 / (distance_sq_map + self.epsilon)
+
+     if self.use_bias and bias_val is not None:
+       bias_reshape_dims = (1,) * (y.ndim - 1) + (-1,)
+       y += jnp.reshape(bias_val, bias_reshape_dims)
+
+     # Restore the original leading batch dimensions.
+     if num_batch_dimensions != 1:
+       output_shape = input_batch_shape + y.shape[1:]
+       y = jnp.reshape(y, output_shape)
+     return y
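
A minimal usage sketch of the flax/nnx layer above (not from the package source), assuming the wheel is installed so that nmn.nnx.yatconv is importable; it only exercises the layer and checks output shapes:

    import jax.numpy as jnp
    from flax.nnx import rnglib
    from nmn.nnx.yatconv import YatConv

    rngs = rnglib.Rngs(0)
    layer = YatConv(
        in_features=3,
        out_features=8,
        kernel_size=(3, 3),
        padding='SAME',
        rngs=rngs,
    )

    x = jnp.ones((2, 32, 32, 3))   # NHWC batch of images
    y = layer(x)                   # yat response map, one channel per filter
    print(y.shape)                 # (2, 32, 32, 8): 'SAME' padding, stride 1
    assert (y >= 0).all()          # <w, x>^2 / (||x - w||^2 + eps) is non-negative (zero bias at init)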
nmn/tf/nmn.py ADDED
@@ -0,0 +1,179 @@
+ import tensorflow as tf
+ import math
+ from typing import Optional, Any, Tuple, Union, List, Callable
+ import numpy as np
+
+ def create_orthogonal_matrix(shape: Tuple[int, ...], dtype: tf.DType = tf.float32) -> tf.Tensor:
+     """Creates an orthogonal matrix of the given 2D shape using QR decomposition."""
+     num_rows, num_cols = shape
+     # QR of a (max, min) matrix yields an orthonormal basis of the right size;
+     # transpose afterwards if the target has more columns than rows.
+     flat_shape = (max(num_rows, num_cols), min(num_rows, num_cols))
+     random_matrix = tf.random.normal(flat_shape, dtype=dtype)
+     q, r = tf.linalg.qr(random_matrix)
+     # Make the distribution uniform over orthogonal matrices
+     d = tf.linalg.diag_part(r)
+     ph = tf.cast(tf.sign(d), dtype)
+     q *= ph[None, :]
+
+     if num_rows < num_cols:
+         q = tf.transpose(q)
+     return q
+
+ class YatNMN(tf.Module):
+     """A custom transformation applied over the last dimension of the input using squared Euclidean distance.
+
+     Args:
+         features: The number of output features.
+         use_bias: Whether to add a bias to the output (default: True).
+         dtype: The dtype of the computation (default: tf.float32).
+         epsilon: Small constant added to avoid division by zero (default: 1e-6).
+         return_weights: Whether to return the weight matrix along with the output (default: False).
+         name: Name of the module (default: None).
+     """
+     def __init__(
+         self,
+         features: int,
+         use_bias: bool = True,
+         dtype: tf.DType = tf.float32,
+         epsilon: float = 1e-6,
+         return_weights: bool = False,
+         name: Optional[str] = None
+     ):
+         super().__init__(name=name)
+         self.features = features
+         self.use_bias = use_bias
+         self.dtype = dtype
+         self.epsilon = epsilon
+         self.return_weights = return_weights
+
+         # Variables will be created in build
+         self.is_built = False
+         self.input_dim = None
+         self.kernel = None
+         self.bias = None
+         self.alpha = None
+
+     @tf.Module.with_name_scope
+     def build(self, input_shape: Union[List[int], tf.TensorShape]) -> None:
+         """Builds the layer weights based on input shape.
+
+         Args:
+             input_shape: Shape of the input tensor.
+         """
+         if self.is_built:
+             return
+
+         last_dim = int(input_shape[-1])
+         self.input_dim = last_dim
+
+         # Initialize kernel using orthogonal initialization
+         kernel_shape = (self.features, last_dim)
+         initial_kernel = create_orthogonal_matrix(kernel_shape, dtype=self.dtype)
+         self.kernel = tf.Variable(
+             initial_kernel,
+             trainable=True,
+             name='kernel',
+             dtype=self.dtype
+         )
+
+         # Initialize the learnable scaling exponent alpha to ones
+         self.alpha = tf.Variable(
+             tf.ones([1], dtype=self.dtype),
+             trainable=True,
+             name='alpha'
+         )
+
+         # Initialize bias if needed
+         if self.use_bias:
+             self.bias = tf.Variable(
+                 tf.zeros([self.features], dtype=self.dtype),
+                 trainable=True,
+                 name='bias'
+             )
+
+         self.is_built = True
+
+     def _maybe_build(self, inputs: tf.Tensor) -> None:
+         """Builds the layer if it hasn't been built yet."""
+         if not self.is_built:
+             self.build(inputs.shape)
+         elif self.input_dim != inputs.shape[-1]:
+             raise ValueError(f'Input shape changed: expected last dimension '
+                              f'{self.input_dim}, got {inputs.shape[-1]}')
+
+     @tf.Module.with_name_scope
+     def __call__(self, inputs: tf.Tensor) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
+         """Forward pass of the layer.
+
+         Args:
+             inputs: Input tensor.
+
+         Returns:
+             Output tensor, or a tuple of (output tensor, kernel weights) if return_weights is True.
+         """
+         # Ensure inputs are a tensor
+         inputs = tf.convert_to_tensor(inputs, dtype=self.dtype)
+
+         # Build if necessary
+         self._maybe_build(inputs)
+
+         # Compute dot product between input and transposed kernel
+         y = tf.matmul(inputs, tf.transpose(self.kernel))
+
+         # Compute squared Euclidean distances via ||x||^2 + ||w||^2 - 2<x, w>
+         inputs_squared_sum = tf.reduce_sum(tf.square(inputs), axis=-1, keepdims=True)
+         kernel_squared_sum = tf.reduce_sum(tf.square(self.kernel), axis=-1)
+
+         # Reshape kernel_squared_sum for broadcasting
+         kernel_squared_sum = tf.reshape(
+             kernel_squared_sum,
+             [1] * (len(inputs.shape) - 1) + [self.features]
+         )
+
+         distances = inputs_squared_sum + kernel_squared_sum - 2 * y
+
+         # Apply the transformation: y = <x, w>^2 / (||x - w||^2 + epsilon)
+         y = tf.square(y) / (distances + self.epsilon)
+
+         # Apply scaling factor (sqrt(n) / log(1 + n)) ** alpha
+         scale = tf.pow(
+             tf.cast(
+                 tf.sqrt(float(self.features)) / tf.math.log(1. + float(self.features)),
+                 self.dtype
+             ),
+             self.alpha
+         )
+         y = y * scale
+
+         # Add bias if used
+         if self.use_bias:
+             # Reshape bias for proper broadcasting
+             bias_shape = [1] * (len(y.shape) - 1) + [-1]
+             y = y + tf.reshape(self.bias, bias_shape)
+
+         if self.return_weights:
+             return y, self.kernel
+         return y
+
+     def get_weights(self) -> List[tf.Tensor]:
+         """Returns the current weights of the layer."""
+         weights = [self.kernel, self.alpha]
+         if self.use_bias:
+             weights.append(self.bias)
+         return weights
+
+     def set_weights(self, weights: List[tf.Tensor]) -> None:
+         """Sets the weights of the layer.
+
+         Args:
+             weights: List of tensors with shapes matching the layer's variables.
+         """
+         if not self.is_built:
+             raise ValueError("Layer must be built before weights can be set.")
+
+         expected_num = 3 if self.use_bias else 2
+         if len(weights) != expected_num:
+             raise ValueError(f"Expected {expected_num} weight tensors, got {len(weights)}")
+
+         self.kernel.assign(weights[0])
+         self.alpha.assign(weights[1])
+         if self.use_bias:
+             self.bias.assign(weights[2])
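
A minimal usage sketch of the TensorFlow module above (not from the package source), assuming the wheel is installed so that nmn.tf.nmn is importable; the layer builds its variables on the first call:

    import tensorflow as tf
    from nmn.tf.nmn import YatNMN

    layer = YatNMN(features=4)
    x = tf.random.normal([8, 16])          # batch of 8 vectors with 16 features
    y = layer(x)                           # first call builds kernel (4, 16), alpha (1,), bias (4,)
    print(y.shape)                         # (8, 4)
    print(len(layer.trainable_variables))  # 3: kernel, alpha, bias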
nmn/torch/nmn.py ADDED
@@ -0,0 +1,144 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ class YatNMN(nn.Module):
+     """
+     A PyTorch implementation of the Yat neuron with squared Euclidean distance transformation.
+
+     Attributes:
+         in_features (int): Size of each input sample
+         out_features (int): Size of each output sample
+         bias (bool): Whether to add a bias to the output
+         alpha (bool): Whether to include the learnable scaling parameter
+         dtype (torch.dtype): Data type for computation
+         epsilon (float): Small constant to avoid division by zero
+         kernel_init (callable): Initializer for the weight matrix
+         bias_init (callable): Initializer for the bias
+         alpha_init (callable): Initializer for the scaling parameter
+     """
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = True,
+         alpha: bool = True,
+         dtype: torch.dtype = torch.float32,
+         epsilon: float = 1e-4,  # 1/epsilon bounds the score per neuron; a smaller value increases precision but lets scores grow very large
+         kernel_init: callable = None,
+         bias_init: callable = None,
+         alpha_init: callable = None
+     ):
+         super().__init__()
+
+         # Store attributes
+         self.in_features = in_features
+         self.out_features = out_features
+         self.dtype = dtype
+         self.epsilon = epsilon
+
+         # Default weight initialization used when no kernel_init is supplied
+         if kernel_init is None:
+             kernel_init = nn.init.xavier_normal_
+
+         # Create weight parameter
+         self.weight = nn.Parameter(torch.empty(
+             (out_features, in_features),
+             dtype=dtype
+         ))
+
+         # Alpha scaling parameter
+         if alpha:
+             self.alpha = nn.Parameter(torch.ones(
+                 (1,),
+                 dtype=dtype
+             ))
+         else:
+             self.register_parameter('alpha', None)
+
+         # Bias parameter
+         if bias:
+             self.bias = nn.Parameter(torch.empty(
+                 (out_features,),
+                 dtype=dtype
+             ))
+         else:
+             self.register_parameter('bias', None)
+
+         # Initialize parameters
+         self.reset_parameters(kernel_init, bias_init, alpha_init)
+
+     def reset_parameters(
+         self,
+         kernel_init: callable = None,
+         bias_init: callable = None,
+         alpha_init: callable = None
+     ):
+         """
+         Initialize network parameters with specified or default initializers.
+         """
+         # Kernel (weight) initialization
+         if kernel_init is None:
+             kernel_init = nn.init.orthogonal_
+         kernel_init(self.weight)
+
+         # Bias initialization
+         if self.bias is not None:
+             if bias_init is None:
+                 # Default: uniform initialization
+                 fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+                 bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                 nn.init.uniform_(self.bias, -bound, bound)
+             else:
+                 bias_init(self.bias)
+
+         # Alpha initialization (default to 1.0)
+         if self.alpha is not None:
+             if alpha_init is None:
+                 self.alpha.data.fill_(1.0)
+             else:
+                 alpha_init(self.alpha)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass with squared Euclidean distance transformation.
+
+         Args:
+             x (torch.Tensor): Input tensor
+
+         Returns:
+             torch.Tensor: Transformed output
+         """
+         # Ensure input and weight are in the same dtype
+         x = x.to(self.dtype)
+
+         # Compute dot product
+         y = torch.matmul(x, self.weight.t())
+
+         # Compute squared distances via ||x||^2 + ||w||^2 - 2<x, w>
+         inputs_squared_sum = torch.sum(x**2, dim=-1, keepdim=True)
+         kernel_squared_sum = torch.sum(self.weight**2, dim=-1)
+         distances = inputs_squared_sum + kernel_squared_sum - 2 * y
+
+         # Apply the squared Euclidean distance transformation
+         y = y ** 2 / (distances + self.epsilon)
+
+         # Add bias if used
+         if self.bias is not None:
+             y += self.bias
+
+         # Dynamic scaling by (sqrt(n) / log(1 + n)) ** alpha
+         if self.alpha is not None:
+             scale = (math.sqrt(self.out_features) / math.log(1 + self.out_features)) ** self.alpha
+             y = y * scale
+
+         return y
+
+     def extra_repr(self) -> str:
+         """
+         Extra representation of the module for print formatting.
+         """
+         return (f"in_features={self.in_features}, "
+                 f"out_features={self.out_features}, "
+                 f"bias={self.bias is not None}, "
+                 f"alpha={self.alpha is not None}")