nmn-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nmn/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """Neural-Matter Network (NMN) - beyond blinded neurons."""
+
+ __version__ = "0.1.0"
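
The three backend implementations below (Keras, Flax Linen, Flax NNX) compute the same core operation. The following standalone JAX sketch is not part of the packaged files and uses an illustrative name (`yat_forward`); it only restates the computation shared by all three modules:

    # Illustrative sketch of the shared YAT/NMN forward computation (not part of the package).
    import jax.numpy as jnp

    def yat_forward(x, w, alpha, eps=1e-5):
        # x: (batch, in_features), w: (in_features, out_features), alpha: shape (1,)
        dot = x @ w                                                  # (batch, out_features)
        # Squared Euclidean distance between each input row and each kernel column.
        dist = (x ** 2).sum(-1, keepdims=True) + (w ** 2).sum(0) - 2 * dot
        y = dot ** 2 / (dist + eps)                                  # inverse-square attention
        n = w.shape[-1]
        scale = (jnp.sqrt(float(n)) / jnp.log1p(float(n))) ** alpha  # dynamic scaling
        return y * scale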
nmn/keras/nmn.py ADDED
@@ -0,0 +1,153 @@
+ from keras.src import activations, constraints, initializers, regularizers
+ from keras.src import ops
+ from keras.src.api_export import keras_export
+ from keras.src.layers.input_spec import InputSpec
+ from keras.src.layers.layer import Layer
+
+
+ @keras_export("keras.layers.YatDense")
+ class YatNMN(Layer):
+     """A YAT densely-connected NN layer.
+
+     This layer implements the operation:
+     `output = scale * (dot(input, kernel)^2 / (squared_euclidean_distance + epsilon))`
+     where:
+     - `scale` is a dynamic scaling factor based on the output dimension
+     - `squared_euclidean_distance` is computed between the input and each kernel column
+     - `epsilon` is a small constant that prevents division by zero
+
+     Args:
+         units: Positive integer, dimensionality of the output space.
+         activation: Activation function to use. If `None`, no activation is applied.
+         use_bias: Boolean, whether the layer uses a bias vector.
+         epsilon: Float, small constant added to the denominator for numerical stability.
+         kernel_initializer: Initializer for the `kernel` weights matrix.
+         bias_initializer: Initializer for the bias vector.
+         kernel_regularizer: Regularizer function applied to the `kernel` weights matrix.
+         bias_regularizer: Regularizer function applied to the bias vector.
+         activity_regularizer: Regularizer function applied to the output.
+         kernel_constraint: Constraint function applied to the `kernel` weights matrix.
+         bias_constraint: Constraint function applied to the bias vector.
+
+     Input shape:
+         N-D tensor with shape: `(batch_size, ..., input_dim)`.
+         The most common situation would be a 2D input with shape
+         `(batch_size, input_dim)`.
+
+     Output shape:
+         N-D tensor with shape: `(batch_size, ..., units)`.
+         For instance, for a 2D input with shape `(batch_size, input_dim)`,
+         the output would have shape `(batch_size, units)`.
+     """
+
+     def __init__(
+         self,
+         units,
+         activation=None,
+         use_bias=True,
+         epsilon=1e-5,
+         kernel_initializer="orthogonal",
+         bias_initializer="zeros",
+         kernel_regularizer=None,
+         bias_regularizer=None,
+         activity_regularizer=None,
+         kernel_constraint=None,
+         bias_constraint=None,
+         **kwargs,
+     ):
+         super().__init__(activity_regularizer=activity_regularizer, **kwargs)
+         self.units = units
+         self.activation = activations.get(activation)
+         self.use_bias = use_bias
+         self.epsilon = epsilon
+
+         self.kernel_initializer = initializers.get(kernel_initializer)
+         self.bias_initializer = initializers.get(bias_initializer)
+         self.kernel_regularizer = regularizers.get(kernel_regularizer)
+         self.bias_regularizer = regularizers.get(bias_regularizer)
+         self.kernel_constraint = constraints.get(kernel_constraint)
+         self.bias_constraint = constraints.get(bias_constraint)
+
+         self.input_spec = InputSpec(min_ndim=2)
+         self.supports_masking = True
+
+     def build(self, input_shape):
+         input_dim = input_shape[-1]
+
+         self.kernel = self.add_weight(
+             name="kernel",
+             shape=(input_dim, self.units),
+             initializer=self.kernel_initializer,
+             regularizer=self.kernel_regularizer,
+             constraint=self.kernel_constraint,
+             trainable=True,
+         )
+
+         if self.use_bias:
+             self.bias = self.add_weight(
+                 name="bias",
+                 shape=(self.units,),
+                 initializer=self.bias_initializer,
+                 regularizer=self.bias_regularizer,
+                 constraint=self.bias_constraint,
+                 trainable=True,
+             )
+         else:
+             self.bias = None
+
+         # Alpha parameter for the dynamic output scaling.
+         self.alpha = self.add_weight(
+             name="alpha",
+             shape=(1,),
+             initializer="ones",
+             trainable=True,
+         )
+
+         self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
+         self.built = True
+
+     def call(self, inputs):
+         # Dot product between the inputs and each kernel column.
+         dot_product = ops.matmul(inputs, self.kernel)
+
+         # Squared Euclidean distances between the inputs and kernel columns.
+         inputs_squared_sum = ops.sum(inputs**2, axis=-1, keepdims=True)
+         kernel_squared_sum = ops.sum(self.kernel**2, axis=0)
+         distances = inputs_squared_sum + kernel_squared_sum - 2 * dot_product
+
+         # Inverse-square attention.
+         outputs = dot_product**2 / (distances + self.epsilon)
+         if self.use_bias:
+             outputs = ops.add(outputs, self.bias)
+
+         # Dynamic scaling.
+         scale = (
+             ops.sqrt(ops.cast(self.units, self.compute_dtype))
+             / ops.log1p(ops.cast(self.units, self.compute_dtype))
+         ) ** self.alpha
+         outputs = outputs * scale
+
+         if self.activation is not None:
+             outputs = self.activation(outputs)
+
+         return outputs
+
+     def compute_output_shape(self, input_shape):
+         output_shape = list(input_shape)
+         output_shape[-1] = self.units
+         return tuple(output_shape)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update({
+             "units": self.units,
+             "activation": activations.serialize(self.activation),
+             "use_bias": self.use_bias,
+             "epsilon": self.epsilon,
+             "kernel_initializer": initializers.serialize(self.kernel_initializer),
+             "bias_initializer": initializers.serialize(self.bias_initializer),
+             "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
+             "bias_regularizer": regularizers.serialize(self.bias_regularizer),
+             "activity_regularizer": regularizers.serialize(self.activity_regularizer),
+             "kernel_constraint": constraints.serialize(self.kernel_constraint),
+             "bias_constraint": constraints.serialize(self.bias_constraint),
+         })
+         return config
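
A minimal usage sketch for the Keras layer above (assumes Keras 3 is installed and the wheel's `nmn.keras.nmn` module is on the path; the layer sizes are arbitrary):

    import keras
    from nmn.keras.nmn import YatNMN

    # Stack two YatNMN layers; shapes are arbitrary.
    model = keras.Sequential([
        keras.Input(shape=(16,)),
        YatNMN(32, activation="relu"),
        YatNMN(10),
    ])
    y = model(keras.ops.ones((4, 16)))   # y.shape == (4, 10)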
nmn/linen/nmn.py ADDED
@@ -0,0 +1,112 @@
+ from typing import Any, Optional
+
+ import jax.numpy as jnp
+ from jax import lax
+ from flax import linen as nn
+ from flax.linen import Module, compact
+ from flax.linen.dtypes import promote_dtype
+ from flax.linen.initializers import zeros_init
+ from flax.typing import DotGeneralT
+
+
+ class YatNMN(Module):
+     """A custom transformation applied over the last dimension of the input using squared Euclidean distance.
+
+     Attributes:
+         features: the number of output features.
+         use_bias: whether to add a bias to the output (default: True).
+         use_alpha: whether to learn the dynamic scaling exponent ``alpha`` (default: True).
+         dtype: the dtype of the computation (default: infer from input and params).
+         param_dtype: the dtype passed to parameter initializers (default: float32).
+         precision: numerical precision of the computation; see ``jax.lax.Precision`` for details.
+         kernel_init: initializer function for the weight matrix.
+         bias_init: initializer function for the bias.
+         alpha_init: initializer function for ``alpha`` (default: ones).
+         epsilon: small constant added to avoid division by zero (default: 1e-6).
+         return_weights: whether to also return the kernel alongside the output (default: False).
+     """
+
+     features: int
+     use_bias: bool = True
+     use_alpha: bool = True
+     dtype: Optional[Any] = None
+     param_dtype: Any = jnp.float32
+     precision: Any = None
+     kernel_init: Any = nn.initializers.orthogonal()
+     bias_init: Any = zeros_init()
+     alpha_init: Any = lambda key, shape, dtype: jnp.ones(shape, dtype)  # initialize alpha to 1.0
+     epsilon: float = 1e-6
+     dot_general: DotGeneralT | None = None
+     dot_general_cls: Any = None
+     return_weights: bool = False
+
+     @compact
+     def __call__(self, inputs: Any) -> Any:
+         """Applies a transformation to the inputs along the last dimension using squared Euclidean distance.
+
+         Args:
+             inputs: The nd-array to be transformed.
+
+         Returns:
+             The transformed input (and the kernel, if ``return_weights`` is True).
+         """
+         kernel = self.param(
+             'kernel',
+             self.kernel_init,
+             (self.features, jnp.shape(inputs)[-1]),
+             self.param_dtype,
+         )
+         if self.use_alpha:
+             alpha = self.param(
+                 'alpha',
+                 self.alpha_init,
+                 (1,),  # single scalar parameter
+                 self.param_dtype,
+             )
+         else:
+             alpha = None
+
+         if self.use_bias:
+             bias = self.param(
+                 'bias', self.bias_init, (self.features,), self.param_dtype
+             )
+         else:
+             bias = None
+
+         inputs, kernel, bias, alpha = promote_dtype(inputs, kernel, bias, alpha, dtype=self.dtype)
+
+         # Dot product between the inputs and the kernel.
+         if self.dot_general_cls is not None:
+             dot_general = self.dot_general_cls()
+         elif self.dot_general is not None:
+             dot_general = self.dot_general
+         else:
+             dot_general = lax.dot_general
+         y = dot_general(
+             inputs,
+             jnp.transpose(kernel),
+             (((inputs.ndim - 1,), (0,)), ((), ())),
+             precision=self.precision,
+         )
+
+         # Squared Euclidean distances between the inputs and kernel rows.
+         inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
+         kernel_squared_sum = jnp.sum(kernel**2, axis=-1)
+         distances = inputs_squared_sum + kernel_squared_sum - 2 * y
+
+         # Element-wise inverse-square attention.
+         y = y ** 2 / (distances + self.epsilon)
+         if bias is not None:
+             y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
+
+         # Dynamic scaling.
+         if alpha is not None:
+             scale = (jnp.sqrt(self.features) / jnp.log(1 + self.features)) ** alpha
+             y = y * scale
+
+         if self.return_weights:
+             return y, kernel
+         return y
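
A minimal usage sketch for the Linen module above (assumes `flax` and `jax` are installed; feature sizes are arbitrary):

    import jax
    import jax.numpy as jnp
    from nmn.linen.nmn import YatNMN

    layer = YatNMN(features=8)
    x = jnp.ones((2, 16))
    variables = layer.init(jax.random.PRNGKey(0), x)   # creates 'kernel', 'bias', 'alpha'
    y = layer.apply(variables, x)                      # y.shape == (2, 8)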
nmn/nnx/nmn.py ADDED
@@ -0,0 +1,170 @@
+ from __future__ import annotations
+
+ import typing as tp
+
+ import jax
+ import jax.numpy as jnp
+ from jax import lax
+
+ from flax import nnx
+ from flax.nnx import rnglib
+ from flax.nnx.module import Module
+ from flax.nnx.nn import dtypes, initializers
+ from flax.typing import (
+     Dtype,
+     Initializer,
+     PrecisionLike,
+     DotGeneralT,
+     PromoteDtypeFn,
+ )
+
+ Array = jax.Array
+
+
+ default_kernel_init = initializers.lecun_normal()
+ default_bias_init = initializers.zeros_init()
+ default_alpha_init = initializers.ones_init()
+
+
+ class YatNMN(Module):
+     """A YAT transformation applied over the last dimension of the input.
+
+     Example usage::
+
+       >>> from flax import nnx
+       >>> import jax, jax.numpy as jnp
+
+       >>> layer = YatNMN(in_features=3, out_features=4, rngs=nnx.Rngs(0))
+       >>> jax.tree.map(jnp.shape, nnx.state(layer))
+       State({
+         'alpha': VariableState(
+           type=Param,
+           value=(1,)
+         ),
+         'bias': VariableState(
+           type=Param,
+           value=(4,)
+         ),
+         'kernel': VariableState(
+           type=Param,
+           value=(3, 4)
+         )
+       })
+
+     Args:
+       in_features: the number of input features.
+       out_features: the number of output features.
+       use_bias: whether to add a bias to the output (default: True).
+       use_alpha: whether to learn the dynamic scaling exponent ``alpha`` (default: True).
+       dtype: the dtype of the computation (default: infer from input and params).
+       param_dtype: the dtype passed to parameter initializers (default: float32).
+       precision: numerical precision of the computation see ``jax.lax.Precision``
+         for details.
+       kernel_init: initializer function for the weight matrix.
+       bias_init: initializer function for the bias.
+       alpha_init: initializer function for ``alpha``.
+       dot_general: dot product function.
+       promote_dtype: function to promote the dtype of the arrays to the desired
+         dtype. The function should accept a tuple of ``(inputs, kernel, bias, alpha)``
+         and a ``dtype`` keyword argument, and return a tuple of arrays with the
+         promoted dtype.
+       rngs: rng key.
+       epsilon: small constant added to the denominator to avoid division by zero.
+     """
+
+     __data__ = ('kernel', 'bias')
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         *,
+         use_bias: bool = True,
+         use_alpha: bool = True,
+         dtype: tp.Optional[Dtype] = None,
+         param_dtype: Dtype = jnp.float32,
+         precision: PrecisionLike = None,
+         kernel_init: Initializer = default_kernel_init,
+         bias_init: Initializer = default_bias_init,
+         alpha_init: Initializer = default_alpha_init,
+         dot_general: DotGeneralT = lax.dot_general,
+         promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
+         rngs: rnglib.Rngs,
+         epsilon: float = 1e-5,
+     ):
+         kernel_key = rngs.params()
+         self.kernel = nnx.Param(
+             kernel_init(kernel_key, (in_features, out_features), param_dtype)
+         )
+         self.bias: nnx.Param[jax.Array] | None
+         if use_bias:
+             bias_key = rngs.params()
+             self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
+         else:
+             self.bias = None
+
+         self.alpha: nnx.Param[jax.Array] | None
+         if use_alpha:
+             alpha_key = rngs.params()
+             self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
+         else:
+             self.alpha = None
+
+         self.in_features = in_features
+         self.out_features = out_features
+         self.use_bias = use_bias
+         self.use_alpha = use_alpha
+         self.dtype = dtype
+         self.param_dtype = param_dtype
+         self.precision = precision
+         self.kernel_init = kernel_init
+         self.bias_init = bias_init
+         self.dot_general = dot_general
+         self.promote_dtype = promote_dtype
+         self.epsilon = epsilon
+
+     def __call__(self, inputs: Array) -> Array:
+         """Applies the YAT transformation to the inputs along the last dimension.
+
+         Args:
+           inputs: The nd-array to be transformed.
+
+         Returns:
+           The transformed input.
+         """
+         kernel = self.kernel.value
+         bias = self.bias.value if self.bias is not None else None
+         alpha = self.alpha.value if self.alpha is not None else None
+
+         inputs, kernel, bias, alpha = self.promote_dtype(
+             (inputs, kernel, bias, alpha), dtype=self.dtype
+         )
+         y = self.dot_general(
+             inputs,
+             kernel,
+             (((inputs.ndim - 1,), (0,)), ((), ())),
+             precision=self.precision,
+         )
+
+         assert self.use_bias == (bias is not None)
+         assert self.use_alpha == (alpha is not None)
+
+         # Squared Euclidean distances between the inputs and kernel columns.
+         inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
+         kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
+         distances = inputs_squared_sum + kernel_squared_sum - 2 * y
+
+         # Element-wise inverse-square attention.
+         y = y ** 2 / (distances + self.epsilon)
+
+         if bias is not None:
+             y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
+
+         # Dynamic scaling.
+         if alpha is not None:
+             scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
+             y = y * scale
+
+         return y
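
A minimal usage sketch for the NNX module above (assumes a recent `flax` release that ships the `nnx` API; sizes are arbitrary):

    import jax.numpy as jnp
    from flax import nnx
    from nmn.nnx.nmn import YatNMN

    layer = YatNMN(in_features=16, out_features=8, rngs=nnx.Rngs(0))
    y = layer(jnp.ones((2, 16)))   # y.shape == (2, 8)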