nmn-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nmn/__init__.py +3 -0
- nmn/keras/nmn.py +153 -0
- nmn/linen/nmn.py +112 -0
- nmn/nnx/nmn.py +170 -0
- nmn/nnx/yatconv.py +320 -0
- nmn/tf/nmn.py +179 -0
- nmn/torch/nmn.py +144 -0
- nmn-0.1.0.dist-info/METADATA +76 -0
- nmn-0.1.0.dist-info/RECORD +11 -0
- nmn-0.1.0.dist-info/WHEEL +4 -0
- nmn-0.1.0.dist-info/licenses/LICENSE +661 -0
nmn/__init__.py
ADDED
nmn/keras/nmn.py
ADDED
@@ -0,0 +1,153 @@
from keras.src import activations, constraints, initializers, regularizers
from keras.src.api_export import keras_export
from keras.src.layers.input_spec import InputSpec
from keras.src.layers.layer import Layer
from keras.src import ops
import math


@keras_export("keras.layers.YatDense")
class YatNMN(Layer):
    """A YAT densely-connected NN layer.

    This layer implements the operation:
    `output = scale * (dot(input, kernel)^2 / (squared_euclidean_distance + epsilon))`
    where:
    - `scale` is a dynamic scaling factor based on output dimension
    - `squared_euclidean_distance` is computed between input and kernel
    - `epsilon` is a small constant to prevent division by zero

    Args:
        units: Positive integer, dimensionality of the output space.
        use_bias: Boolean, whether the layer uses a bias vector.
        epsilon: Float, small constant added to denominator for numerical stability.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer function applied to the `kernel` weights matrix.
        bias_regularizer: Regularizer function applied to the bias vector.
        activity_regularizer: Regularizer function applied to the output.
        kernel_constraint: Constraint function applied to the `kernel` weights matrix.
        bias_constraint: Constraint function applied to the bias vector.

    Input shape:
        N-D tensor with shape: `(batch_size, ..., input_dim)`.
        The most common situation would be a 2D input with shape
        `(batch_size, input_dim)`.

    Output shape:
        N-D tensor with shape: `(batch_size, ..., units)`.
        For instance, for a 2D input with shape `(batch_size, input_dim)`,
        the output would have shape `(batch_size, units)`.
    """

    def __init__(
        self,
        units,
        activation=None,
        use_bias=True,
        epsilon=1e-5,
        kernel_initializer="orthogonal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        super().__init__(activity_regularizer=activity_regularizer, **kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.epsilon = epsilon

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

    def build(self, input_shape):
        input_dim = input_shape[-1]

        self.kernel = self.add_weight(
            name="kernel",
            shape=(input_dim, self.units),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            trainable=True,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=(self.units,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                trainable=True,
            )
        else:
            self.bias = None

        # Add alpha parameter for dynamic scaling
        self.alpha = self.add_weight(
            name="alpha",
            shape=(1,),
            initializer="ones",
            trainable=True,
        )

        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def call(self, inputs):
        # Compute dot product
        dot_product = ops.matmul(inputs, self.kernel)

        # Compute squared distances
        inputs_squared_sum = ops.sum(inputs ** 2, axis=-1, keepdims=True)
        kernel_squared_sum = ops.sum(self.kernel ** 2, axis=0)
        distances = inputs_squared_sum + kernel_squared_sum - 2 * dot_product

        # Compute inverse square attention
        outputs = dot_product ** 2 / (distances + self.epsilon)
        if self.use_bias:
            outputs = ops.add(outputs, self.bias)

        # Apply dynamic scaling
        scale = (ops.sqrt(ops.cast(self.units, self.compute_dtype)) /
                 ops.log1p(ops.cast(self.units, self.compute_dtype))) ** self.alpha
        outputs = outputs * scale

        if self.activation is not None:
            outputs = self.activation(outputs)

        return outputs

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)

    def get_config(self):
        config = super().get_config()
        config.update({
            "units": self.units,
            "activation": activations.serialize(self.activation),
            "use_bias": self.use_bias,
            "epsilon": self.epsilon,
            "kernel_initializer": initializers.serialize(self.kernel_initializer),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "activity_regularizer": regularizers.serialize(self.activity_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        })
        return config
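For orientation, a minimal usage sketch of the Keras layer above. This is not part of the wheel; it assumes a Keras 3 environment with a configured backend (where `keras.src` is importable) and that the package is installed so `nmn.keras.nmn` resolves. The shapes and `units` value are illustrative only.

# Usage sketch (not from the package): exercise the layer once and check the output shape.
import numpy as np
from nmn.keras.nmn import YatNMN

x = np.random.randn(8, 16).astype("float32")  # (batch_size, input_dim)
layer = YatNMN(units=4)                       # kernel, bias, and alpha are created on first call
y = layer(x)
print(y.shape)                                # expected: (8, 4)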
nmn/linen/nmn.py
ADDED
@@ -0,0 +1,112 @@
import jax.numpy as jnp
import jax.lax as lax
from flax import linen as nn
from flax.linen import Module, compact
from flax.linen.dtypes import promote_dtype
from flax.linen.initializers import zeros_init, lecun_normal
from flax.typing import (
  PRNGKey as PRNGKey,
  Shape as Shape,
  DotGeneralT,
)
from typing import Any, Optional


class YatNMN(Module):
  """A custom transformation applied over the last dimension of the input using squared Euclidean distance.

  Attributes:
    features: the number of output features.
    use_bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    precision: numerical precision of the computation, see ``jax.lax.Precision`` for details.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
    epsilon: small constant added to avoid division by zero (default: 1e-6).
  """
  features: int
  use_bias: bool = True
  use_alpha: bool = True
  dtype: Optional[Any] = None
  param_dtype: Any = jnp.float32
  precision: Any = None
  kernel_init: Any = nn.initializers.orthogonal()
  bias_init: Any = zeros_init()
  alpha_init: Any = lambda key, shape, dtype: jnp.ones(shape, dtype)  # Initialize alpha to 1.0
  epsilon: float = 1e-6
  dot_general: DotGeneralT | None = None
  dot_general_cls: Any = None
  return_weights: bool = False

  @compact
  def __call__(self, inputs: Any) -> Any:
    """Applies a transformation to the inputs along the last dimension using squared Euclidean distance.

    Args:
      inputs: The nd-array to be transformed.

    Returns:
      The transformed input.
    """
    kernel = self.param(
      'kernel',
      self.kernel_init,
      (self.features, jnp.shape(inputs)[-1]),
      self.param_dtype,
    )
    if self.use_alpha:
      alpha = self.param(
        'alpha',
        self.alpha_init,
        (1,),  # Single scalar parameter
        self.param_dtype,
      )
    else:
      alpha = None

    if self.use_bias:
      bias = self.param(
        'bias', self.bias_init, (self.features,), self.param_dtype
      )
    else:
      bias = None

    inputs, kernel, bias, alpha = promote_dtype(inputs, kernel, bias, alpha, dtype=self.dtype)

    # Compute dot product between input and kernel
    if self.dot_general_cls is not None:
      dot_general = self.dot_general_cls()
    elif self.dot_general is not None:
      dot_general = self.dot_general
    else:
      dot_general = lax.dot_general
    y = dot_general(
      inputs,
      jnp.transpose(kernel),
      (((inputs.ndim - 1,), (0,)), ((), ())),
      precision=self.precision,
    )
    inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
    kernel_squared_sum = jnp.sum(kernel**2, axis=-1)
    distances = inputs_squared_sum + kernel_squared_sum - 2 * y

    # Element-wise operation
    y = y ** 2 / (distances + self.epsilon)
    if bias is not None:
      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))

    if alpha is not None:
      scale = (jnp.sqrt(self.features) / jnp.log(1 + self.features)) ** alpha
      y = y * scale

    if self.return_weights:
      return y, kernel
    return y
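Similarly, a minimal sketch of how the Linen module above might be exercised with `init`/`apply`. This example is not shipped in the wheel; it assumes `jax` and `flax` are installed and that `nmn.linen.nmn` resolves, and the shapes are illustrative.

# Usage sketch (not from the package).
import jax
import jax.numpy as jnp
from nmn.linen.nmn import YatNMN

model = YatNMN(features=4)
x = jnp.ones((8, 16))                          # (batch, input_dim)
params = model.init(jax.random.PRNGKey(0), x)  # creates 'kernel', 'alpha', and 'bias' params
y = model.apply(params, x)
print(y.shape)                                 # expected: (8, 4)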
nmn/nnx/nmn.py
ADDED
@@ -0,0 +1,170 @@
from __future__ import annotations

import typing as tp

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
import opt_einsum

from flax.core.frozen_dict import FrozenDict
from flax import nnx
from flax.nnx import rnglib, variablelib
from flax.nnx.module import Module, first_from
from flax.nnx.nn import dtypes, initializers
from flax.typing import (
  Dtype,
  Shape,
  Initializer,
  PrecisionLike,
  DotGeneralT,
  ConvGeneralDilatedT,
  PaddingLike,
  LaxPadding,
  PromoteDtypeFn,
  EinsumT,
)

Array = jax.Array
Axis = int
Size = int


default_kernel_init = initializers.lecun_normal()
default_bias_init = initializers.zeros_init()
default_alpha_init = initializers.ones_init()


class YatNMN(Module):
  """A linear transformation applied over the last dimension of the input.

  Example usage::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
    >>> jax.tree.map(jnp.shape, nnx.state(layer))
    State({
      'bias': VariableState(
        type=Param,
        value=(4,)
      ),
      'kernel': VariableState(
        type=Param,
        value=(3, 4)
      )
    })

  Args:
    in_features: the number of input features.
    out_features: the number of output features.
    use_bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    precision: numerical precision of the computation, see ``jax.lax.Precision``
      for details.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
    dot_general: dot product function.
    promote_dtype: function to promote the dtype of the arrays to the desired
      dtype. The function should accept a tuple of ``(inputs, kernel, bias)``
      and a ``dtype`` keyword argument, and return a tuple of arrays with the
      promoted dtype.
    rngs: rng key.
  """

  __data__ = ('kernel', 'bias')

  def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    use_bias: bool = True,
    use_alpha: bool = True,
    dtype: tp.Optional[Dtype] = None,
    param_dtype: Dtype = jnp.float32,
    precision: PrecisionLike = None,
    kernel_init: Initializer = default_kernel_init,
    bias_init: Initializer = default_bias_init,
    alpha_init: Initializer = default_alpha_init,
    dot_general: DotGeneralT = lax.dot_general,
    promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
    rngs: rnglib.Rngs,
    epsilon: float = 1e-5,
  ):
    kernel_key = rngs.params()
    self.kernel = nnx.Param(
      kernel_init(kernel_key, (in_features, out_features), param_dtype)
    )
    self.bias: nnx.Param[jax.Array] | None
    if use_bias:
      bias_key = rngs.params()
      self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
    else:
      self.bias = None

    self.alpha: nnx.Param[jax.Array] | None
    if use_alpha:
      alpha_key = rngs.params()
      self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
    else:
      self.alpha = None

    self.in_features = in_features
    self.out_features = out_features
    self.use_bias = use_bias
    self.use_alpha = use_alpha
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.precision = precision
    self.kernel_init = kernel_init
    self.bias_init = bias_init
    self.dot_general = dot_general
    self.promote_dtype = promote_dtype
    self.epsilon = epsilon

  def __call__(self, inputs: Array) -> Array:
    """Applies a linear transformation to the inputs along the last dimension.

    Args:
      inputs: The nd-array to be transformed.

    Returns:
      The transformed input.
    """
    kernel = self.kernel.value
    bias = self.bias.value if self.bias is not None else None
    alpha = self.alpha.value if self.alpha is not None else None

    inputs, kernel, bias, alpha = self.promote_dtype(
      (inputs, kernel, bias, alpha), dtype=self.dtype
    )
    y = self.dot_general(
      inputs,
      kernel,
      (((inputs.ndim - 1,), (0,)), ((), ())),
      precision=self.precision,
    )

    assert self.use_bias == (bias is not None)
    assert self.use_alpha == (alpha is not None)

    inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
    kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
    distances = inputs_squared_sum + kernel_squared_sum - 2 * y

    # Element-wise operation
    y = y ** 2 / (distances + self.epsilon)

    if bias is not None:
      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))

    if alpha is not None:
      scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
      y = y * scale

    return y
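And a corresponding sketch for the NNX variant above. Again this is not code from the wheel; it assumes a recent Flax release that ships the `nnx` API and that `nmn.nnx.nmn` resolves, with illustrative shapes.

# Usage sketch (not from the package).
import jax.numpy as jnp
from flax import nnx
from nmn.nnx.nmn import YatNMN

layer = YatNMN(in_features=16, out_features=4, rngs=nnx.Rngs(0))
x = jnp.ones((8, 16))      # (batch, in_features)
y = layer(x)               # squared dot product over squared distance, then scaled by alpha
print(y.shape)             # expected: (8, 4)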