brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/init.py
CHANGED
@@ -1,809 +1,809 @@
|
|
1
|
-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import math
|
17
|
-
from typing import Optional, Tuple
|
18
|
-
from typing import Union, Callable, Sequence
|
19
|
-
|
20
|
-
import brainunit as u
|
21
|
-
import jax
|
22
|
-
import jax.numpy as jnp
|
23
|
-
import numpy as np
|
24
|
-
|
25
|
-
from brainstate import environ, random
|
26
|
-
from brainstate._state import State
|
27
|
-
from brainstate._utils import set_module_as
|
28
|
-
from brainstate.typing import ArrayLike, SeedOrKey
|
29
|
-
from brainstate.util import PrettyRepr, PrettyType, PrettyAttr
|
30
|
-
|
31
|
-
__all__ = [
    # parameter construction helpers
    'param', 'calculate_init_gain',
    # deterministic initializers
    'ZeroInit', 'Constant', 'Identity',
    # simple random distributions
    'Normal', 'TruncatedNormal', 'Uniform',
    # variance-scaling family
    'VarianceScaling',
    'KaimingUniform', 'KaimingNormal',
    'XavierUniform', 'XavierNormal',
    'LecunUniform', 'LecunNormal',
    # orthogonal initializers
    'Orthogonal', 'DeltaOrthogonal',
    # NOTE(review): ``Gamma`` and ``Exponential`` are defined in this module
    # but are not exported here; confirm whether that is intentional.
]
|
50
|
-
|
51
|
-
|
52
|
-
class Initializer(PrettyRepr):
    """
    Abstract base class shared by every parameter initializer.

    Subclasses must override ``__call__``. The pretty representation shows
    the concrete type followed by each public (non-underscore) instance
    attribute.
    """
    __module__ = 'brainstate.nn'

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def __pretty_repr__(self):
        """
        Yield the type header, then one ``PrettyAttr`` per public attribute.
        """
        yield PrettyType(type=type(self))
        public_items = (
            (attr_name, attr_value)
            for attr_name, attr_value in vars(self).items()
            if not attr_name.startswith('_')
        )
        for attr_name, attr_value in public_items:
            yield PrettyAttr(attr_name, repr(attr_value))
|
70
|
-
|
71
|
-
|
72
|
-
def to_size(x) -> Optional[Tuple[int, ...]]:
    """
    Normalize a size specification into a tuple of ints.

    Parameters
    ----------
    x : tuple, list, int, numpy integer or None
        The size specification.

    Returns
    -------
    tuple or None
        ``tuple(x)`` for a tuple/list, ``(x,)`` for a single integer, and
        ``None`` when ``x`` is ``None``.

    Raises
    ------
    ValueError
        If ``x`` is none of the supported types.
    """
    if x is None:
        return None
    if isinstance(x, (tuple, list)):
        return tuple(x)
    if isinstance(x, (int, np.integer)):
        return (x,)
    raise ValueError(f'Cannot make a size for {x}')
|
80
|
-
|
81
|
-
|
82
|
-
def _is_scalar(x):
    # Thin wrapper: scalar detection is delegated to brainunit's ``isscalar``.
    scalar_check = u.math.isscalar
    return scalar_check(x)
|
84
|
-
|
85
|
-
|
86
|
-
def are_broadcastable_shapes(shape1, shape2):
    """
    Tell whether two shapes are compatible under trailing-axis broadcasting.

    Dimensions are paired from the right; only as many pairs as the shorter
    shape has are inspected. Each pair must be equal or contain a ``1``.

    Parameters:
    - shape1: Tuple[int], the shape of the first array.
    - shape2: Tuple[int], the shape of the second array.

    Returns:
    - bool: True if shapes are broadcastable, False otherwise.
    """
    trailing_pairs = zip(reversed(shape1), reversed(shape2))
    return all(
        left == right or left == 1 or right == 1
        for left, right in trailing_pairs
    )
|
109
|
-
|
110
|
-
|
111
|
-
def _expand_params_to_match_sizes(params, sizes):
    """
    Prepend singleton axes to ``params`` until its rank equals ``len(sizes)``.

    New axes are inserted at the front (``axis=0``). If ``params`` already has
    at least ``len(sizes)`` dimensions, it is returned unchanged.

    Parameters:
    - params: array-like with an ``ndim`` attribute (e.g. jax.Array or np.ndarray).
    - sizes: tuple[int] or list[int], the target shape dimensions.

    Returns:
    - ``params`` with ``max(params.ndim, len(sizes))`` dimensions.
    """
    missing_dims = len(sizes) - params.ndim
    # Each iteration prepends ONE leading axis (axis=0 is the front, not the
    # last dimension); a non-positive ``missing_dims`` skips the loop.
    for _ in range(missing_dims):
        params = u.math.expand_dims(params, axis=0)
    return params
|
130
|
-
|
131
|
-
|
132
|
-
@set_module_as('brainstate.nn')
def param(
    parameter: Union[Callable, ArrayLike, State],
    sizes: Union[int, Sequence[int]],
    batch_size: Optional[int] = None,
    allow_none: bool = True,
    allow_scalar: bool = True,
):
    """Initialize parameters.

    Parameters
    ----------
    parameter: callable, ArrayLike, State
        The initialization of the parameter.

        - If it is None, the created parameter will be None (only when
          ``allow_none`` is True).
        - If it is a scalar and ``allow_scalar`` is True, it is returned as-is.
        - If it is a callable :math:`f` (for example an
          :py:class:`Initializer` instance), ``f(sizes)`` is returned, with
          ``batch_size`` prepended to ``sizes`` when given.
        - If it is an array (``np.ndarray``, ``jax.Array``, ``u.Quantity``) or a
          :py:class:`State`, this function checks that its shape is
          broadcastable to ``sizes``.
    sizes: int, sequence of int
        The shape of the parameter.
    batch_size: int, optional
        The batch size. For array/State parameters whose rank is at most
        ``len(sizes)``, the value is expanded to ``len(sizes)`` dimensions and
        repeated along a new leading axis of length ``batch_size``; otherwise
        the existing leading axis must already equal ``batch_size``.
    allow_none: bool
        Whether the parameter is allowed to be None.
    allow_scalar: bool
        Whether the parameter is allowed to be a scalar value.

    Returns
    -------
    param: ArrayType, float, int, bool, None, State
        The initialized parameter. A :py:class:`State` input yields a new
        instance of the same ``State`` subclass wrapping the processed value.

    Raises
    ------
    ValueError
        If ``parameter`` is None while ``allow_none`` is False, has an
        unsupported type, has a shape not broadcastable to ``sizes``, or has a
        leading dimension that disagrees with ``batch_size``.

    See Also
    --------
    noise, state
    """
    # None is either passed through or rejected.
    if parameter is None:
        if allow_none:
            return None
        raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
                         f'Callable function, but we got None. ')

    # Scalars are returned untouched; no shape handling applies to them.
    if allow_scalar and _is_scalar(parameter):
        return parameter

    sizes = tuple(to_size(sizes))

    # Callables (including Initializer instances) build the array themselves.
    if callable(parameter):
        if batch_size is not None:
            sizes = (batch_size,) + sizes
        return parameter(sizes)

    if not isinstance(parameter, (np.ndarray, jax.Array, u.Quantity, State)):
        raise ValueError(f'Unknown parameter type: {type(parameter)}')

    # Unwrap a State first so all shape checks operate on the stored array,
    # not on whatever array-like attributes the State wrapper may expose.
    param_value = parameter.value if isinstance(parameter, State) else parameter

    if not are_broadcastable_shapes(param_value.shape, sizes):
        raise ValueError(f'The shape of the parameter {param_value.shape} does not match with the given size {sizes}')

    if batch_size is not None:
        if param_value.ndim <= len(sizes):
            # Bring the value up to the rank of ``sizes``, then tile it along a
            # new leading batch axis.
            param_value = _expand_params_to_match_sizes(param_value, sizes)
            param_value = u.math.repeat(
                u.math.expand_dims(param_value, axis=0),
                batch_size,
                axis=0
            )
        elif param_value.shape[0] != batch_size:
            # The value already carries an extra (batch) axis; it must agree.
            raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
                             f'does not match with the given batch size {batch_size}')

    return type(parameter)(param_value) if isinstance(parameter, State) else param_value
|
213
|
-
|
214
|
-
|
215
|
-
def calculate_init_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.
    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalisation
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: lower-case name of the non-linear function. Accepted
            values are ``'linear'``, ``'identity'``, ``'conv1d'``, ``'conv2d'``,
            ``'conv3d'``, ``'conv_transpose1d'``, ``'conv_transpose2d'``,
            ``'conv_transpose3d'``, ``'sigmoid'``, ``'tanh'``, ``'relu'``,
            ``'leaky_relu'`` and ``'selu'``.
        param: optional parameter for the non-linear function. Only used by
            ``'leaky_relu'``, where it is the negative slope (default ``0.01``).

    Returns:
        The recommended gain as a Python number.

    Raises:
        ValueError: if ``nonlinearity`` is not supported, or if ``param`` is not
            an ``int``/``float`` (booleans are rejected).

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    # 'identity' is listed in the table above and behaves like 'linear'.
    linear_fns = (
        'linear', 'identity',
        'conv1d', 'conv2d', 'conv3d',
        'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
    )
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif isinstance(param, (int, float)) and not isinstance(param, bool):
            # ``bool`` is a subclass of ``int``, so it must be excluded explicitly.
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    elif nonlinearity == 'selu':
        return 3.0 / 4
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
265
|
-
|
266
|
-
|
267
|
-
def _format_shape(shape):
    # A bare int becomes a 1-tuple.
    if isinstance(shape, int):
        return (shape,)
    if not len(shape):
        raise ValueError('Please provide shape.')
    # A single nested tuple/list such as ``((3, 4),)`` is unwrapped one level;
    # every other sequence is returned unchanged.
    is_wrapped = len(shape) == 1 and isinstance(shape[0], (tuple, list))
    return shape[0] if is_wrapped else shape
|
279
|
-
|
280
|
-
|
281
|
-
def _compute_fans(shape, in_axis=-2, out_axis=-1):
    """
    Compute ``(fan_in, fan_out)`` for a weight tensor of the given shape.

    The product of all dimensions that are neither input nor output axes is
    treated as a receptive-field size and multiplies both fans.

    Parameters
    ----------
    shape : sequence of int
        Weight shape; must have at least two dimensions.
    in_axis, out_axis : int or sequence of int
        Axis (or axes) holding the input / output features. When a sequence is
        given, the sizes of those dimensions are multiplied together.

    Returns
    -------
    fan_in, fan_out : NumPy floating scalars

    Raises
    ------
    ValueError
        If ``shape`` has fewer than two dimensions.
    """
    if len(shape) < 2:
        raise ValueError(
            f"Cannot compute fan-in/fan-out of a {len(shape)}-dimensional weight "
            f"with shape {tuple(shape)}; at least 2 dimensions are required."
        )

    def _axis_size(axis):
        # Size of one axis, or the product of several axes.
        if isinstance(axis, (int, np.integer)):
            return shape[axis]
        return np.prod([shape[a] for a in axis])

    in_size = _axis_size(in_axis)
    out_size = _axis_size(out_axis)
    # ``np.prod`` keeps these as NumPy scalars: VarianceScaling calls
    # ``.astype`` on values derived from them.
    receptive_field_size = np.prod(shape) / in_size / out_size
    fan_in = in_size * receptive_field_size
    fan_out = out_size * receptive_field_size
    return fan_in, fan_out
|
286
|
-
|
287
|
-
|
288
|
-
class Normal(Initializer):
    """Initialize weights by sampling from a normal distribution.

    Parameters
    ----------
    mean : ArrayLike
        Location (``loc``) of the distribution.
    scale : ArrayLike
        Scale (standard deviation) forwarded to ``rng.normal``.
    unit : u.Unit
        Unit attached to the sampled values.
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        mean: ArrayLike = 0.,
        scale: ArrayLike = 1.,
        unit: u.Unit = u.UNITLESS,
        seed: SeedOrKey = None
    ):
        super().__init__()
        # Assignment order is kept: it determines the pretty-repr order.
        self.scale = scale
        self.mean = mean
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        """Sample an array of ``shape``.

        Optional keywords: ``dtype`` (defaults to ``environ.dftype()``) and
        ``rng`` (defaults to the generator created at construction).
        """
        size = to_size(shape)
        generator = kwargs.get('rng', self.rng)
        samples = generator.normal(
            size=size,
            loc=self.mean,
            scale=self.scale,
            dtype=kwargs.get('dtype', environ.dftype()),
        )
        return u.maybe_decimal(u.Quantity(samples, unit=self.unit))
|
318
|
-
|
319
|
-
|
320
|
-
class TruncatedNormal(Initializer):
    """Initialize weights with truncated normal distribution.

    Parameters
    ----------
    loc : float, ndarray
        Mean ("centre") of the distribution before truncating. Note that
        the mean of the truncated distribution will not be exactly equal
        to ``loc``.
    scale : float
        The standard deviation of the normal distribution before truncating.
        Must be positive.
    unit : u.Unit
        Unit attached to the sampled values.
    lower : float, ndarray
        A float or array of floats representing the lower bound for
        truncation. Must be broadcast-compatible with ``upper``.
    upper : float, ndarray
        A float or array of floats representing the upper bound for
        truncation. Must be broadcast-compatible with ``lower``.
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        loc: ArrayLike = 0.,
        scale: ArrayLike = 1.,
        unit: u.Unit = u.UNITLESS,
        lower: ArrayLike = None,
        upper: ArrayLike = None,
        seed: SeedOrKey = None,
    ):
        super().__init__()
        assert scale > 0, '`scale` must be positive.'
        self.scale = scale
        self.loc = loc
        self.lower = lower
        self.upper = upper
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        """Sample an array of ``shape``.

        ``shape`` is normalized with ``to_size`` like every other initializer
        in this module. Optional keywords: ``dtype`` (defaults to
        ``environ.dftype()``) and ``rng`` (defaults to the stored generator).
        """
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        weights = rng.truncated_normal(
            size=shape,
            scale=self.scale,
            lower=self.lower,
            upper=self.upper,
            loc=self.loc,
            dtype=dtype
        )
        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
371
|
-
|
372
|
-
|
373
|
-
class Gamma(Initializer):
    """Initialize weights with Gamma distribution.

    Parameters
    ----------
    shape : float, Array
        Shape parameter of the Gamma distribution (not the shape of the
        generated array, which is supplied when the initializer is called).
    unit : u.Unit
        Unit attached to the sampled values.
    scale : float, Array, optional
        Scale parameter, forwarded unchanged to ``rng.gamma``.
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        shape: ArrayLike,
        unit: u.Unit = u.UNITLESS,
        scale: ArrayLike = None,
        seed: SeedOrKey = None
    ):
        # Initialize the base class like every other initializer here.
        super().__init__()
        self.shape = shape
        self.scale = scale
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        """Sample an array whose array shape is ``shape``.

        Optional keywords: ``dtype`` (defaults to ``environ.dftype()``) and
        ``rng`` (defaults to the stored generator).
        """
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        weights = rng.gamma(self.shape, scale=self.scale, size=shape, dtype=dtype)
        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
404
|
-
|
405
|
-
|
406
|
-
class Exponential(Initializer):
    """Initialize weights with Exponential distribution.

    Parameters
    ----------
    scale : float, Array, optional
        Scale parameter, forwarded unchanged to ``rng.exponential``.
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    unit : u.Unit
        Unit attached to the sampled values.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike = None,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        # Initialize the base class like every other initializer here.
        super().__init__()
        self.scale = scale
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        """Sample an array of ``shape``.

        Optional keywords: ``dtype`` (defaults to ``environ.dftype()``) and
        ``rng`` (defaults to the stored generator).
        """
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        weights = rng.exponential(scale=self.scale, size=shape, dtype=dtype)
        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
433
|
-
|
434
|
-
|
435
|
-
class Uniform(Initializer):
    """Initialize weights by sampling uniformly between two limits.

    Parameters
    ----------
    min_val : ArrayLike
        Lower limit, forwarded as ``low`` to ``rng.uniform``.
    max_val : ArrayLike
        Upper limit, forwarded as ``high`` to ``rng.uniform``.
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    unit : u.Unit
        Unit attached to the sampled values.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        min_val: ArrayLike = 0.,
        max_val: ArrayLike = 1.,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        """Sample an array of ``shape``.

        Optional keywords: ``dtype`` (defaults to ``environ.dftype()``) and
        ``rng`` (defaults to the stored generator).
        """
        size = to_size(shape)
        generator = kwargs.get('rng', self.rng)
        samples = generator.uniform(
            low=self.min_val,
            high=self.max_val,
            size=size,
            dtype=kwargs.get('dtype', environ.dftype()),
        )
        return u.maybe_decimal(u.Quantity(samples, unit=self.unit))
|
466
|
-
|
467
|
-
|
468
|
-
class VarianceScaling(Initializer):
    """Initializer whose sampling variance scales with the weight's fan.

    The target variance is ``scale / n``, where ``n`` depends on ``mode``:

    - ``'fan_in'``: the input fan,
    - ``'fan_out'``: the output fan,
    - ``'fan_avg'``: the mean of the two.

    Fans are derived from the requested shape using ``in_axis`` and
    ``out_axis``.

    Parameters
    ----------
    scale : ArrayLike
        Factor multiplying ``1 / n`` to give the target variance.
    mode : str
        One of ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'``.
    distribution : str
        One of ``'truncated_normal'``, ``'normal'`` or ``'uniform'``.
    in_axis, out_axis : int
        Axes of the weight shape that hold the input / output features.
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    unit : u.Unit
        Unit attached to the sampled values.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike,
        mode: str,
        distribution: str,
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        assert mode in ['fan_in', 'fan_out', 'fan_avg'], (
            f"mode must be 'fan_in', 'fan_out' or 'fan_avg', got {mode!r}")
        assert distribution in ['truncated_normal', 'normal', 'uniform'], (
            f"distribution must be 'truncated_normal', 'normal' or 'uniform', "
            f"got {distribution!r}")
        self.scale = scale
        self.mode = mode
        self.in_axis = in_axis
        self.out_axis = out_axis
        self.distribution = distribution
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        """Sample an array of ``shape`` with the configured variance.

        Optional keywords: ``dtype`` (defaults to ``environ.dftype()``) and
        ``rng`` (defaults to the stored generator).
        """
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
        if self.mode == "fan_in":
            denominator = fan_in
        elif self.mode == "fan_out":
            denominator = fan_out
        elif self.mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
        variance = (self.scale / denominator).astype(dtype)
        if self.distribution == "truncated_normal":
            # .87962566103423978 is the standard deviation of a standard normal
            # truncated to [-2, 2]; dividing by it restores the target variance.
            stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype)
            res = rng.truncated_normal(-2, 2, shape, dtype=dtype) * stddev
        elif self.distribution == "normal":
            res = rng.randn(*shape, dtype=dtype) * jnp.sqrt(variance).astype(dtype)
        elif self.distribution == "uniform":
            # U(-a, a) has variance a**2 / 3, so a = sqrt(3 * variance).
            res = (rng.uniform(low=-1, high=1, size=shape, dtype=dtype) *
                   jnp.sqrt(3 * variance).astype(dtype))
        else:
            raise ValueError("invalid distribution for variance scaling initializer")
        return u.maybe_decimal(u.Quantity(res, unit=self.unit))
|
516
|
-
|
517
|
-
|
518
|
-
class KaimingUniform(VarianceScaling):
    """Kaiming (He) uniform initializer.

    A :class:`VarianceScaling` preset with defaults ``scale=2.0``,
    ``mode='fan_in'`` and ``distribution='uniform'``. Every argument is
    forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 2.0,
        mode: str = "fan_in",
        distribution: str = "uniform",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
538
|
-
|
539
|
-
|
540
|
-
class KaimingNormal(VarianceScaling):
    """Kaiming (He) normal initializer.

    A :class:`VarianceScaling` preset with defaults ``scale=2.0``,
    ``mode='fan_in'`` and ``distribution='truncated_normal'``. Every argument
    is forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 2.0,
        mode: str = "fan_in",
        distribution: str = "truncated_normal",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
560
|
-
|
561
|
-
|
562
|
-
class XavierUniform(VarianceScaling):
    """Xavier (Glorot) uniform initializer.

    A :class:`VarianceScaling` preset with defaults ``scale=1.0``,
    ``mode='fan_avg'`` and ``distribution='uniform'``. Every argument is
    forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_avg",
        distribution: str = "uniform",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
582
|
-
|
583
|
-
|
584
|
-
class XavierNormal(VarianceScaling):
    """Xavier (Glorot) normal initializer.

    A :class:`VarianceScaling` preset with defaults ``scale=1.0``,
    ``mode='fan_avg'`` and ``distribution='truncated_normal'``. Every argument
    is forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_avg",
        distribution: str = "truncated_normal",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
604
|
-
|
605
|
-
|
606
|
-
class LecunUniform(VarianceScaling):
    """LeCun uniform initializer.

    A :class:`VarianceScaling` preset with defaults ``scale=1.0``,
    ``mode='fan_in'`` and ``distribution='uniform'``. Every argument is
    forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_in",
        distribution: str = "uniform",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
626
|
-
|
627
|
-
|
628
|
-
class LecunNormal(VarianceScaling):
    """LeCun normal initializer.

    A :class:`VarianceScaling` preset with defaults ``scale=1.0``,
    ``mode='fan_in'`` and ``distribution='truncated_normal'``. Every argument
    is forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_in",
        distribution: str = "truncated_normal",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
648
|
-
|
649
|
-
|
650
|
-
class Orthogonal(Initializer):
    """
    Construct an initializer for uniformly distributed orthogonal matrices.

    If the shape is not square, the matrix will have orthonormal rows or columns
    depending on which side is smaller.

    Parameters
    ----------
    scale : ArrayLike
        Factor multiplying the orthogonal result.
    axis : int
        Axis of the requested shape treated as the "rows"; all other axes are
        flattened into the "columns" before orthogonalization.
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    unit : u.Unit
        Unit attached to the result.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike = 1.,
        axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        self.scale = scale
        self.axis = axis
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        """Sample an orthogonal array of ``shape``.

        Optional keywords: ``dtype`` (defaults to ``environ.dftype()``) and
        ``rng`` (defaults to the stored generator).
        """
        dtype = kwargs.get('dtype', environ.dftype())
        generator = kwargs.get('rng', self.rng)
        shape = to_size(shape)

        rows = shape[self.axis]
        cols = np.prod(shape) // rows
        # QR needs a tall (or square) input, so sample with the longer side first.
        sample_shape = (rows, cols) if rows > cols else (cols, rows)
        gaussian = generator.normal(size=sample_shape, dtype=dtype)

        q, r = jnp.linalg.qr(gaussian)
        # Flip column signs by sign(diag(R)) so Q is uniformly distributed.
        q = q * jnp.sign(jnp.diag(r))
        if rows < cols:
            q = q.T

        other_dims = tuple(np.delete(shape, self.axis))
        q = jnp.moveaxis(jnp.reshape(q, (rows,) + other_dims), 0, self.axis)
        scaled = jnp.asarray(self.scale, dtype=dtype) * q
        return u.maybe_decimal(u.Quantity(scaled, unit=self.unit))
|
690
|
-
|
691
|
-
|
692
|
-
class DeltaOrthogonal(Initializer):
    """
    Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.

    The shape must be 3D, 4D or 5D. The result is all zeros except at the
    spatial centre of the kernel, where an orthogonal matrix of shape
    ``shape[-2:]`` is written.

    Parameters
    ----------
    scale : ArrayLike
        Factor multiplying the orthogonal block.
    axis : int
        Axis forwarded to the inner :class:`Orthogonal` initializer.
    seed : SeedOrKey
        Seed or key used by the inner :class:`Orthogonal` initializer.
    unit : u.Unit
        Unit attached to the result.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike = 1.0,
        axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        self.scale = scale
        self.axis = axis
        # NOTE: the attribute name keeps its historical spelling ("orghogonal")
        # because it is publicly visible (e.g. in the pretty repr).
        self.orghogonal = Orthogonal(scale=scale, axis=axis, seed=seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        """Build a delta-orthogonal kernel of ``shape``.

        Optional keywords: ``dtype`` (defaults to ``environ.dftype()``) and
        ``rng``. Both are forwarded to the inner orthogonal sampler.

        Raises
        ------
        ValueError
            If ``shape`` is not 3D/4D/5D, or if ``shape[-1] < shape[-2]``.
        """
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        if len(shape) not in [3, 4, 5]:
            raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
        if shape[-1] < shape[-2]:
            raise ValueError("`fan_in` must be less or equal than `fan_out`. ")

        # Forward the caller's dtype (and rng, when given) so the orthogonal
        # block is sampled with the same precision and random stream as the
        # caller requested, instead of silently using the defaults.
        ortho_kwargs = {'dtype': dtype}
        if 'rng' in kwargs:
            ortho_kwargs['rng'] = kwargs['rng']
        ortho_matrix = u.Quantity(self.orghogonal(shape[-2:], **ortho_kwargs))

        W = u.Quantity(u.math.zeros(shape, dtype=dtype), unit=u.get_unit(ortho_matrix))
        # Index of the (lower-)centre element along every spatial axis.
        center = tuple((k - 1) // 2 for k in shape[:-2])
        W = W.at[center].set(ortho_matrix)
        return u.maybe_decimal(u.Quantity(W.mantissa, unit=self.unit))
|
732
|
-
|
733
|
-
|
734
|
-
class ZeroInit(Initializer):
    """Initializer that returns an all-zeros array.

    Parameters
    ----------
    unit : u.Unit
        Unit attached to the zeros.
    """
    __module__ = 'brainstate.nn'

    def __init__(self, unit: u.Unit = u.UNITLESS):
        super().__init__()
        self.unit = unit

    def __call__(self, shape, **kwargs):
        """Return zeros of ``shape``; keyword ``dtype`` defaults to ``environ.dftype()``."""
        zeros = u.math.zeros(
            to_size(shape),
            dtype=kwargs.get('dtype', environ.dftype()),
            unit=self.unit,
        )
        return u.maybe_decimal(zeros)
|
749
|
-
|
750
|
-
|
751
|
-
class Constant(Initializer):
    """Constant initializer.

    Fill an array of the requested shape with ``value``.

    Parameters
    ----------
    value : float, int, or array-like
        The fill value, forwarded unchanged to ``brainunit.math.full``.
        Defaults to ``1.``.
    """
    __module__ = 'brainstate.nn'

    def __init__(self, value=1.):
        super().__init__()
        self.value = value

    def __call__(self, shape, **kwargs):
        dtype = kwargs.get('dtype', environ.dftype())
        shape = to_size(shape)
        return u.maybe_decimal(u.math.full(shape, self.value, dtype=dtype))
|
771
|
-
|
772
|
-
|
773
|
-
class Identity(Initializer):
    """Returns the identity matrix.

    This initializer was proposed in (Le, et al., 2015) [1]_. The returned
    matrix is zero everywhere except on its main diagonal, which is filled
    with ``value``.

    Parameters
    ----------
    value : float
        The value written on the main diagonal (the scaling factor of the
        identity matrix). Defaults to ``1.``.
    unit : brainunit.Unit
        Unit attached to the returned matrix; unitless by default.

    Notes
    -----
    When called, ``shape`` must have one or two dimensions: ``n`` (or ``(n,)``)
    gives an ``n x n`` matrix and ``(n, m)`` gives an ``n x m`` matrix. More
    than two dimensions raises ``ValueError``.

    References
    ----------
    .. [1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. "A simple way to
           initialize recurrent networks of rectified linear units." arXiv preprint
           arXiv:1504.00941 (2015).
    """
    __module__ = 'brainstate.nn'

    def __init__(self, value=1., unit: u.Unit = u.UNITLESS):
        super().__init__()
        self.value = value
        self.unit = unit

    def __call__(self, shape, **kwargs):
        dtype = kwargs.get('dtype', environ.dftype())
        # ``to_size`` always yields a tuple here (or None, which fails below
        # with a TypeError exactly as before), so no extra type check is needed.
        shape = to_size(shape)
        if len(shape) > 2:
            raise ValueError(f'Only support initialize 2D weights for {self.__class__.__name__}.')
        r = u.math.eye(*shape, dtype=dtype)
        r = u.math.fill_diagonal(r, self.value)
        return u.maybe_decimal(u.Quantity(r, unit=self.unit))
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import math
|
17
|
+
from typing import Optional, Tuple
|
18
|
+
from typing import Union, Callable, Sequence
|
19
|
+
|
20
|
+
import brainunit as u
|
21
|
+
import jax
|
22
|
+
import jax.numpy as jnp
|
23
|
+
import numpy as np
|
24
|
+
|
25
|
+
from brainstate import environ, random
|
26
|
+
from brainstate._state import State
|
27
|
+
from brainstate._utils import set_module_as
|
28
|
+
from brainstate.typing import ArrayLike, SeedOrKey
|
29
|
+
from brainstate.util import PrettyRepr, PrettyType, PrettyAttr
|
30
|
+
|
31
|
+
# Names exported by ``from <this module> import *``.
__all__ = [
    # helpers
    'param',
    'calculate_init_gain',
    # deterministic initializers
    'ZeroInit',
    'Constant',
    'Identity',
    # random distributions
    'Normal',
    'TruncatedNormal',
    'Uniform',
    # variance-scaling family
    'VarianceScaling',
    'KaimingUniform',
    'KaimingNormal',
    'XavierUniform',
    'XavierNormal',
    'LecunUniform',
    'LecunNormal',
    # orthogonal family
    'Orthogonal',
    'DeltaOrthogonal',
]
|
50
|
+
|
51
|
+
|
52
|
+
class Initializer(PrettyRepr):
    """
    Abstract base for the weight initializers of this module.

    Subclasses implement ``__call__(shape, **kwargs)`` and return the
    initialized array; the base ``__call__`` only raises
    :class:`NotImplementedError`.
    """
    __module__ = 'brainstate.nn'

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def __pretty_repr__(self):
        """
        Yield the pretty-repr items: the type, then every public attribute.

        Instance attributes whose name begins with an underscore are omitted;
        all others are rendered with :func:`repr`, in definition order.
        """
        yield PrettyType(type=type(self))
        public_items = (
            (attr_name, attr_value)
            for attr_name, attr_value in vars(self).items()
            if not attr_name.startswith('_')
        )
        for attr_name, attr_value in public_items:
            yield PrettyAttr(attr_name, repr(attr_value))
|
70
|
+
|
71
|
+
|
72
|
+
def to_size(x) -> Optional[Tuple[int, ...]]:
    """Normalize a shape specification to a tuple.

    Parameters
    ----------
    x : int, sequence of int, or None
        A single dimension (Python or NumPy integer), a tuple/list of
        dimensions, or ``None``.

    Returns
    -------
    tuple of int or None
        ``tuple(x)`` for tuples/lists, ``(x,)`` for a single integer and
        ``None`` when ``x`` is ``None``.

    Raises
    ------
    ValueError
        If ``x`` is of any other type.
    """
    if isinstance(x, (tuple, list)):
        return tuple(x)
    if isinstance(x, (int, np.integer)):
        return (x,)
    if x is None:
        return None
    raise ValueError(f'Cannot make a size for {x}')
|
80
|
+
|
81
|
+
|
82
|
+
def _is_scalar(x):
    """Tell whether ``x`` is a scalar according to ``brainunit.math.isscalar``.

    :func:`param` uses this to pass scalar parameters through unchanged.
    """
    scalar_check = u.math.isscalar
    return scalar_check(x)
|
84
|
+
|
85
|
+
|
86
|
+
def are_broadcastable_shapes(shape1, shape2):
    """
    Check whether two shapes are compatible under NumPy broadcasting rules.

    Dimensions are compared right-to-left; a pair is compatible when the two
    sizes are equal or either of them is 1. Extra leading dimensions of the
    longer shape are not compared.

    Parameters:
    - shape1: Tuple[int], the shape of the first array.
    - shape2: Tuple[int], the shape of the second array.

    Returns:
    - bool: True if shapes are broadcastable, False otherwise.
    """
    trailing_pairs = zip(shape1[::-1], shape2[::-1])
    return all(
        left == right or left == 1 or right == 1
        for left, right in trailing_pairs
    )
|
109
|
+
|
110
|
+
|
111
|
+
def _expand_params_to_match_sizes(params, sizes):
|
112
|
+
"""
|
113
|
+
Expand the dimensions of params to match the dimensions of sizes.
|
114
|
+
|
115
|
+
Parameters:
|
116
|
+
- params: jax.Array or np.ndarray, the parameter array to be expanded.
|
117
|
+
- sizes: tuple[int] or list[int], the target shape dimensions.
|
118
|
+
|
119
|
+
Returns:
|
120
|
+
- Expanded params with dimensions matching sizes.
|
121
|
+
"""
|
122
|
+
params_dim = params.ndim
|
123
|
+
sizes_dim = len(sizes)
|
124
|
+
dim_diff = sizes_dim - params_dim
|
125
|
+
|
126
|
+
# Add new axes to params if it has fewer dimensions than sizes
|
127
|
+
for _ in range(dim_diff):
|
128
|
+
params = u.math.expand_dims(params, axis=0) # Add new axis at the last dimension
|
129
|
+
return params
|
130
|
+
|
131
|
+
|
132
|
+
@set_module_as('brainstate.nn')
def param(
    parameter: Union[Callable, ArrayLike, State],
    sizes: Union[int, Sequence[int]],
    batch_size: Optional[int] = None,
    allow_none: bool = True,
    allow_scalar: bool = True,
):
    """Initialize parameters.

    Parameters
    ----------
    parameter: callable, ArrayLike, State
        The initialization of the parameter.

        - If it is None, the created parameter will be None (a ``ValueError``
          is raised instead when ``allow_none`` is False).
        - If it is a scalar and ``allow_scalar`` is True, it is returned as is.
        - If it is a callable function :math:`f` (including an instance of
          :py:class:`Initializer`), ``f(sizes)`` is returned; ``batch_size`` is
          prepended to ``sizes`` when given.
        - If it is an array or a :py:class:`State`, this function checks that
          its shape is broadcastable to ``sizes``.
    sizes: int, sequence of int
        The shape of the parameter.
    batch_size: int, optional
        The batch size. For an array/State whose rank does not exceed
        ``len(sizes)``, singleton leading axes are prepended until its rank
        equals ``len(sizes)`` and the value is then repeated ``batch_size``
        times along a new leading axis. A higher-rank value must already have
        ``batch_size`` as its leading dimension.
    allow_none: bool
        Whether the parameter may be None.
    allow_scalar: bool
        Whether a scalar parameter is returned unchanged.

    Returns
    -------
    param: ArrayType, float, int, bool, None
        The initialized parameter. A ``State`` input is returned as a new
        instance of the same ``State`` subclass wrapping the (possibly
        batched) value.

    Raises
    ------
    ValueError
        If ``parameter`` is None while ``allow_none`` is False, has an
        unsupported type, has a shape that is not broadcastable to ``sizes``,
        or has a leading dimension different from ``batch_size``.

    See Also
    --------
    noise, state
    """
    # None: either propagate it or reject it.
    if parameter is None:
        if allow_none:
            return None
        raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
                         f'Callable function, but we got None. ')

    # Scalars are returned untouched.
    if allow_scalar and _is_scalar(parameter):
        return parameter

    sizes = tuple(to_size(sizes))

    # Initializers / callables produce the value themselves.
    if callable(parameter):
        if batch_size is not None:
            sizes = (batch_size,) + sizes
        return parameter(sizes)
    if not isinstance(parameter, (np.ndarray, jax.Array, u.Quantity, State)):
        raise ValueError(f'Unknown parameter type: {type(parameter)}')

    # Concrete values must be broadcast-compatible with the requested shape.
    if not are_broadcastable_shapes(parameter.shape, sizes):
        raise ValueError(f'The shape of the parameter {parameter.shape} does not match with the given size {sizes}')

    is_state = isinstance(parameter, State)
    param_value = parameter.value if is_state else parameter
    if batch_size is not None:
        if param_value.ndim <= len(sizes):
            # Unbatched value: align its rank with ``sizes``, then tile it
            # along a new leading batch axis.
            param_value = _expand_params_to_match_sizes(param_value, sizes)
            param_value = u.math.repeat(
                u.math.expand_dims(param_value, axis=0),
                batch_size,
                axis=0
            )
        elif param_value.shape[0] != batch_size:
            raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
                             f'does not match with the given batch size {batch_size}')
    return type(parameter)(param_value) if is_state else param_value
|
213
|
+
|
214
|
+
|
215
|
+
def calculate_init_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalisation
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: name of the nonlinearity: ``'linear'``, ``'identity'``,
            ``'conv1d'``/``'conv2d'``/``'conv3d'``,
            ``'conv_transpose1d'``/``'conv_transpose2d'``/``'conv_transpose3d'``,
            ``'sigmoid'``, ``'tanh'``, ``'relu'``, ``'leaky_relu'`` or ``'selu'``.
        param: optional parameter of the nonlinearity. For ``'leaky_relu'`` it
            is the negative slope (default ``0.01``) and must be an ``int`` or
            ``float`` (``bool`` is rejected); it is ignored otherwise.

    Returns:
        The gain: ``1`` (an ``int``) for linear-like functions and sigmoid, a
        ``float`` otherwise.

    Raises:
        ValueError: if ``nonlinearity`` is unsupported or ``param`` is not a
            valid negative slope.

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    # 'identity' is listed in the table above, so accept it like 'linear'.
    linear_fns = ('linear', 'identity', 'conv1d', 'conv2d', 'conv3d',
                  'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d')
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif isinstance(param, float) or (isinstance(param, int) and not isinstance(param, bool)):
            # ``bool`` is a subclass of ``int``, so it is excluded explicitly.
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    if nonlinearity == 'selu':
        return 3.0 / 4
    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
265
|
+
|
266
|
+
|
267
|
+
def _format_shape(shape):
|
268
|
+
if isinstance(shape, int):
|
269
|
+
return (shape,)
|
270
|
+
if len(shape) == 0:
|
271
|
+
raise ValueError('Please provide shape.')
|
272
|
+
if len(shape) == 1:
|
273
|
+
if isinstance(shape[0], (tuple, list)):
|
274
|
+
return shape[0]
|
275
|
+
else:
|
276
|
+
return shape
|
277
|
+
else:
|
278
|
+
return shape
|
279
|
+
|
280
|
+
|
281
|
+
def _compute_fans(shape, in_axis=-2, out_axis=-1):
|
282
|
+
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
283
|
+
fan_in = shape[in_axis] * receptive_field_size
|
284
|
+
fan_out = shape[out_axis] * receptive_field_size
|
285
|
+
return fan_in, fan_out
|
286
|
+
|
287
|
+
|
288
|
+
class Normal(Initializer):
    """Initialize weights by sampling from a normal distribution.

    Parameters
    ----------
    mean : ArrayLike
        Mean, forwarded as ``loc`` to ``rng.normal``. Defaults to ``0.``.
    scale : ArrayLike
        Scale (standard deviation), forwarded as ``scale`` to ``rng.normal``.
        Defaults to ``1.``.
    unit : brainunit.Unit
        Unit attached to the sampled weights; unitless by default.
    seed : SeedOrKey, optional
        Seed or key for the initializer's own random generator. A generator
        passed as ``rng=`` at call time takes precedence.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        mean: ArrayLike = 0.,
        scale: ArrayLike = 1.,
        unit: u.Unit = u.UNITLESS,
        seed: SeedOrKey = None
    ):
        super().__init__()
        self.scale = scale
        self.mean = mean
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        generator = kwargs.get('rng', self.rng)
        samples = generator.normal(size=shape, loc=self.mean, scale=self.scale, dtype=dtype)
        return u.maybe_decimal(u.Quantity(samples, unit=self.unit))
|
318
|
+
|
319
|
+
|
320
|
+
class TruncatedNormal(Initializer):
    """Initialize weights with truncated normal distribution.

    Parameters
    ----------
    loc : float, ndarray
        Mean ("centre") of the distribution before truncating. Note that
        the mean of the truncated distribution will not be exactly equal
        to ``loc``.
    scale : float, ndarray
        The standard deviation of the normal distribution before truncating.
        Every element must be positive.
    unit : brainunit.Unit
        Unit attached to the sampled weights; unitless by default.
    lower : float, ndarray
        A float or array of floats representing the lower bound for
        truncation. Must be broadcast-compatible with ``upper``.
    upper : float, ndarray
        A float or array of floats representing the upper bound for
        truncation. Must be broadcast-compatible with ``lower``.
    seed : SeedOrKey, optional
        Seed or key for the initializer's own random generator. A generator
        passed as ``rng=`` at call time takes precedence.

    Raises
    ------
    AssertionError
        If any element of ``scale`` is not positive.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        loc: ArrayLike = 0.,
        scale: ArrayLike = 1.,
        unit: u.Unit = u.UNITLESS,
        lower: ArrayLike = None,
        upper: ArrayLike = None,
        seed: SeedOrKey = None,
    ):
        super().__init__()
        # Explicit check instead of ``assert``: it still runs under ``python -O``
        # and ``np.all`` makes it valid for array-valued ``scale`` too.
        # AssertionError is kept for compatibility with the former assert.
        if not np.all(scale > 0):
            raise AssertionError('`scale` must be positive.')
        self.scale = scale
        self.loc = loc
        self.lower = lower
        self.upper = upper
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        # Normalize ``shape`` like the other initializers do.
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        weights = rng.truncated_normal(
            size=shape,
            scale=self.scale,
            lower=self.lower,
            upper=self.upper,
            loc=self.loc,
            dtype=dtype
        )
        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
371
|
+
|
372
|
+
|
373
|
+
class Gamma(Initializer):
    """Initialize weights by sampling from a Gamma distribution.

    Parameters
    ----------
    shape : float, Array
        Shape parameter of the Gamma distribution, forwarded as the first
        argument of ``rng.gamma``. (This is the distribution parameter; the
        weight shape is the argument of ``__call__``.)
    unit : brainunit.Unit
        Unit attached to the sampled weights; unitless by default.
    scale : float, Array, optional
        Scale parameter forwarded to ``rng.gamma``; ``None`` defers to that
        sampler's default.
    seed : SeedOrKey, optional
        Seed or key for the initializer's own random generator. A generator
        passed as ``rng=`` at call time takes precedence.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        shape: ArrayLike,
        unit: u.Unit = u.UNITLESS,
        scale: ArrayLike = None,
        seed: SeedOrKey = None
    ):
        # Initialize the base class, as every other initializer here does.
        super().__init__()
        self.shape = shape
        self.scale = scale
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        weights = rng.gamma(self.shape, scale=self.scale, size=shape, dtype=dtype)
        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
404
|
+
|
405
|
+
|
406
|
+
class Exponential(Initializer):
    """Initialize weights by sampling from an Exponential distribution.

    Parameters
    ----------
    scale : float, Array, optional
        Scale parameter forwarded to ``rng.exponential``; ``None`` defers to
        that sampler's default.
    seed : SeedOrKey, optional
        Seed or key for the initializer's own random generator. A generator
        passed as ``rng=`` at call time takes precedence.
    unit : brainunit.Unit
        Unit attached to the sampled weights; unitless by default.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike = None,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        # Initialize the base class, as every other initializer here does.
        super().__init__()
        self.scale = scale
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        weights = rng.exponential(scale=self.scale, size=shape, dtype=dtype)
        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
|
433
|
+
|
434
|
+
|
435
|
+
class Uniform(Initializer):
    """Initialize weights by sampling from a uniform distribution.

    Parameters
    ----------
    min_val : ArrayLike
        The lower limit of the uniform distribution (``low`` of
        ``rng.uniform``). Defaults to ``0.``.
    max_val : ArrayLike
        The upper limit of the uniform distribution (``high`` of
        ``rng.uniform``). Defaults to ``1.``.
    seed : SeedOrKey, optional
        Seed or key for the initializer's own random generator. A generator
        passed as ``rng=`` at call time takes precedence.
    unit : brainunit.Unit
        Unit attached to the sampled weights; unitless by default.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        min_val: ArrayLike = 0.,
        max_val: ArrayLike = 1.,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        size = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        generator = kwargs.get('rng', self.rng)
        samples = generator.uniform(
            low=self.min_val,
            high=self.max_val,
            size=size,
            dtype=dtype,
        )
        return u.maybe_decimal(u.Quantity(samples, unit=self.unit))
|
466
|
+
|
467
|
+
|
468
|
+
class VarianceScaling(Initializer):
    """Initializer whose sample variance is scaled by the weight's fans.

    Samples have variance ``scale / n``, where ``n`` is selected by ``mode``
    from the fans computed over ``in_axis``/``out_axis``:

    - ``'fan_in'``: ``n = fan_in``,
    - ``'fan_out'``: ``n = fan_out``,
    - ``'fan_avg'``: ``n = (fan_in + fan_out) / 2``.

    ``distribution`` selects the sampler:

    - ``'truncated_normal'``: a standard normal truncated to ``[-2, 2]``,
      divided by that truncated distribution's standard deviation
      (``0.87962566103423978``) so the result has variance ``scale / n``,
    - ``'normal'``: an untruncated normal,
    - ``'uniform'``: ``U(-1, 1) * sqrt(3 * scale / n)``.

    Parameters
    ----------
    scale : ArrayLike
        Multiplier of the variance.
    mode : {'fan_in', 'fan_out', 'fan_avg'}
        Which fan to normalize by.
    distribution : {'truncated_normal', 'normal', 'uniform'}
        Sampling distribution.
    in_axis, out_axis : int
        Axes of the weight shape holding the input / output features.
    seed : SeedOrKey, optional
        Seed or key for the initializer's own random generator. A generator
        passed as ``rng=`` at call time takes precedence.
    unit : brainunit.Unit
        Unit attached to the sampled weights; unitless by default.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike,
        mode: str,
        distribution: str,
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        assert mode in ['fan_in', 'fan_out', 'fan_avg'], f'Unsupported mode: {mode!r}'
        assert distribution in ['truncated_normal', 'normal', 'uniform'], \
            f'Unsupported distribution: {distribution!r}'
        self.scale = scale
        self.mode = mode
        self.in_axis = in_axis
        self.out_axis = out_axis
        self.distribution = distribution
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
        # The else branches below only trigger when the constructor asserts
        # were bypassed (``python -O`` or attributes mutated after init).
        if self.mode == "fan_in":
            denominator = fan_in
        elif self.mode == "fan_out":
            denominator = fan_out
        elif self.mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
        variance = (self.scale / denominator).astype(dtype)
        if self.distribution == "truncated_normal":
            # .8796... is the std of a standard normal truncated to [-2, 2].
            stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype)
            res = rng.truncated_normal(-2, 2, shape, dtype=dtype) * stddev
        elif self.distribution == "normal":
            res = rng.randn(*shape, dtype=dtype) * jnp.sqrt(variance).astype(dtype)
        elif self.distribution == "uniform":
            # U(-1, 1) has variance 1/3, hence the factor 3.
            res = (rng.uniform(low=-1, high=1, size=shape, dtype=dtype) *
                   jnp.sqrt(3 * variance).astype(dtype))
        else:
            raise ValueError("invalid distribution for variance scaling initializer")
        return u.maybe_decimal(u.Quantity(res, unit=self.unit))
|
516
|
+
|
517
|
+
|
518
|
+
class KaimingUniform(VarianceScaling):
    """He (Kaiming) uniform initializer.

    A :class:`VarianceScaling` initializer whose defaults
    (``scale=2.0``, ``mode='fan_in'``, ``distribution='uniform'``) give
    uniform samples with variance ``2 / fan_in`` (He et al., 2015).
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 2.0,
        mode: str = "fan_in",
        distribution: str = "uniform",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
538
|
+
|
539
|
+
|
540
|
+
class KaimingNormal(VarianceScaling):
    """He (Kaiming) normal initializer.

    A :class:`VarianceScaling` initializer whose defaults
    (``scale=2.0``, ``mode='fan_in'``, ``distribution='truncated_normal'``)
    give truncated-normal samples with variance ``2 / fan_in``
    (He et al., 2015).
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 2.0,
        mode: str = "fan_in",
        distribution: str = "truncated_normal",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
560
|
+
|
561
|
+
|
562
|
+
class XavierUniform(VarianceScaling):
    """Glorot (Xavier) uniform initializer.

    A :class:`VarianceScaling` initializer whose defaults
    (``scale=1.0``, ``mode='fan_avg'``, ``distribution='uniform'``) give
    uniform samples with variance ``2 / (fan_in + fan_out)``
    (Glorot & Bengio, 2010).
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_avg",
        distribution: str = "uniform",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
582
|
+
|
583
|
+
|
584
|
+
class XavierNormal(VarianceScaling):
    """Glorot (Xavier) normal initializer.

    A :class:`VarianceScaling` initializer whose defaults
    (``scale=1.0``, ``mode='fan_avg'``, ``distribution='truncated_normal'``)
    give truncated-normal samples with variance ``2 / (fan_in + fan_out)``
    (Glorot & Bengio, 2010).
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_avg",
        distribution: str = "truncated_normal",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
604
|
+
|
605
|
+
|
606
|
+
class LecunUniform(VarianceScaling):
    """LeCun uniform initializer.

    A :class:`VarianceScaling` initializer whose defaults
    (``scale=1.0``, ``mode='fan_in'``, ``distribution='uniform'``) give
    uniform samples with variance ``1 / fan_in``.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_in",
        distribution: str = "uniform",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
626
|
+
|
627
|
+
|
628
|
+
class LecunNormal(VarianceScaling):
    """LeCun normal initializer.

    A :class:`VarianceScaling` initializer whose defaults
    (``scale=1.0``, ``mode='fan_in'``, ``distribution='truncated_normal'``)
    give truncated-normal samples with variance ``1 / fan_in``.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_in",
        distribution: str = "truncated_normal",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
|
648
|
+
|
649
|
+
|
650
|
+
class Orthogonal(Initializer):
    """
    Construct an initializer for uniformly distributed orthogonal matrices.

    If the shape is not square, the matrix will have orthonormal rows or columns
    depending on which side is smaller.

    Parameters
    ----------
    scale : ArrayLike
        Factor multiplied onto the orthogonal matrix. Defaults to ``1.``.
    axis : int
        Axis of the weight shape used as the matrix rows; all remaining axes
        are flattened into the columns.
    seed : SeedOrKey, optional
        Seed or key for the initializer's own random generator. A generator
        passed as ``rng=`` at call time takes precedence.
    unit : brainunit.Unit
        Unit attached to the result; unitless by default.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike = 1.,
        axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        self.scale = scale
        self.axis = axis
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        dtype = kwargs.get('dtype', environ.dftype())
        generator = kwargs.get('rng', self.rng)
        shape = to_size(shape)
        rows = shape[self.axis]
        cols = np.prod(shape) // rows
        # Sample a tall (or square) Gaussian matrix so QR yields orthonormal columns.
        if rows > cols:
            gaussian_shape = (rows, cols)
        else:
            gaussian_shape = (cols, rows)
        gaussian = generator.normal(size=gaussian_shape, dtype=dtype)

        q_factor, r_factor = jnp.linalg.qr(gaussian)
        # Flipping columns by the signs of R's diagonal makes Q uniformly distributed.
        q_factor = q_factor * jnp.sign(jnp.diag(r_factor))
        if rows < cols:
            q_factor = q_factor.T
        remaining_dims = tuple(np.delete(shape, self.axis))
        q_factor = jnp.moveaxis(jnp.reshape(q_factor, (rows,) + remaining_dims), 0, self.axis)
        scaled = jnp.asarray(self.scale, dtype=dtype) * q_factor
        return u.maybe_decimal(u.Quantity(scaled, unit=self.unit))
|
690
|
+
|
691
|
+
|
692
|
+
class DeltaOrthogonal(Initializer):
    """
    Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.

    The kernel is zero everywhere except at its spatial centre, which holds an
    orthogonal ``(shape[-2], shape[-1])`` matrix drawn by :class:`Orthogonal`.

    The shape must be 3D, 4D or 5D, with ``shape[-2] <= shape[-1]``
    (``fan_in <= fan_out``).

    Parameters
    ----------
    scale : ArrayLike
        Scaling factor passed to the inner :class:`Orthogonal` initializer.
    axis : int
        Axis passed to the inner :class:`Orthogonal` initializer.
    seed : SeedOrKey, optional
        Seed for the inner :class:`Orthogonal` initializer's generator. A
        generator passed as ``rng=`` at call time takes precedence.
    unit : brainunit.Unit
        Unit attached to the returned kernel; unitless by default.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike = 1.0,
        axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        self.scale = scale
        self.axis = axis
        # NOTE: the misspelled attribute name is kept for backward compatibility.
        self.orghogonal = Orthogonal(scale=scale, axis=axis, seed=seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        if len(shape) not in [3, 4, 5]:
            raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
        if shape[-1] < shape[-2]:
            raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
        # Forward the caller's options (e.g. ``rng``) and the resolved dtype so
        # the centre matrix honours them, like every other initializer does.
        ortho_kwargs = dict(kwargs, dtype=dtype)
        ortho_matrix = u.Quantity(self.orghogonal(shape[-2:], **ortho_kwargs))
        W = u.Quantity(u.math.zeros(shape, dtype=dtype), unit=u.get_unit(ortho_matrix))
        # Spatial centre index for each kernel dimension (rounded down for even sizes).
        center = tuple((k - 1) // 2 for k in shape[:-2])
        W = W.at[center].set(ortho_matrix)
        return u.maybe_decimal(u.Quantity(W.mantissa, unit=self.unit))
|
732
|
+
|
733
|
+
|
734
|
+
class ZeroInit(Initializer):
    """Zero initializer.

    Produce an all-zeros array of the requested shape.

    Parameters
    ----------
    unit : brainunit.Unit
        Unit attached to the zeros; unitless by default.
    """
    __module__ = 'brainstate.nn'

    def __init__(self, unit: u.Unit = u.UNITLESS):
        super().__init__()
        self.unit = unit

    def __call__(self, shape, **kwargs):
        dtype = kwargs.get('dtype', environ.dftype())
        zeros = u.math.zeros(to_size(shape), dtype=dtype, unit=self.unit)
        return u.maybe_decimal(zeros)
|
749
|
+
|
750
|
+
|
751
|
+
class Constant(Initializer):
    """Constant initializer.

    Fill an array of the requested shape with ``value``.

    Parameters
    ----------
    value : float, int, or array-like
        The fill value, forwarded unchanged to ``brainunit.math.full``.
        Defaults to ``1.``.
    """
    __module__ = 'brainstate.nn'

    def __init__(self, value=1.):
        super().__init__()
        self.value = value

    def __call__(self, shape, **kwargs):
        dtype = kwargs.get('dtype', environ.dftype())
        shape = to_size(shape)
        return u.maybe_decimal(u.math.full(shape, self.value, dtype=dtype))
|
771
|
+
|
772
|
+
|
773
|
+
class Identity(Initializer):
    """Returns the identity matrix.

    This initializer was proposed in (Le, et al., 2015) [1]_. The returned
    matrix is zero everywhere except on its main diagonal, which is filled
    with ``value``.

    Parameters
    ----------
    value : float
        The value written on the main diagonal (the scaling factor of the
        identity matrix). Defaults to ``1.``.
    unit : brainunit.Unit
        Unit attached to the returned matrix; unitless by default.

    Notes
    -----
    When called, ``shape`` must have one or two dimensions: ``n`` (or ``(n,)``)
    gives an ``n x n`` matrix and ``(n, m)`` gives an ``n x m`` matrix. More
    than two dimensions raises ``ValueError``.

    References
    ----------
    .. [1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. "A simple way to
           initialize recurrent networks of rectified linear units." arXiv preprint
           arXiv:1504.00941 (2015).
    """
    __module__ = 'brainstate.nn'

    def __init__(self, value=1., unit: u.Unit = u.UNITLESS):
        super().__init__()
        self.value = value
        self.unit = unit

    def __call__(self, shape, **kwargs):
        dtype = kwargs.get('dtype', environ.dftype())
        # ``to_size`` always yields a tuple here (or None, which fails below
        # with a TypeError exactly as before), so no extra type check is needed.
        shape = to_size(shape)
        if len(shape) > 2:
            raise ValueError(f'Only support initialize 2D weights for {self.__class__.__name__}.')
        r = u.math.eye(*shape, dtype=dtype)
        r = u.math.fill_diagonal(r, self.value)
        return u.maybe_decimal(u.Quantity(r, unit=self.unit))
|