brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/init.py CHANGED
@@ -1,809 +1,809 @@
1
- # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import math
17
- from typing import Optional, Tuple
18
- from typing import Union, Callable, Sequence
19
-
20
- import brainunit as u
21
- import jax
22
- import jax.numpy as jnp
23
- import numpy as np
24
-
25
- from brainstate import environ, random
26
- from brainstate._state import State
27
- from brainstate._utils import set_module_as
28
- from brainstate.typing import ArrayLike, SeedOrKey
29
- from brainstate.util import PrettyRepr, PrettyType, PrettyAttr
30
-
31
- __all__ = [
32
- 'param',
33
- 'calculate_init_gain',
34
- 'ZeroInit',
35
- 'Constant',
36
- 'Identity',
37
- 'Normal',
38
- 'TruncatedNormal',
39
- 'Uniform',
40
- 'VarianceScaling',
41
- 'KaimingUniform',
42
- 'KaimingNormal',
43
- 'XavierUniform',
44
- 'XavierNormal',
45
- 'LecunUniform',
46
- 'LecunNormal',
47
- 'Orthogonal',
48
- 'DeltaOrthogonal',
49
- ]
50
-
51
-
52
- class Initializer(PrettyRepr):
53
- """
54
- Base class for initializers.
55
- """
56
- __module__ = 'brainstate.nn'
57
-
58
- def __call__(self, *args, **kwargs):
59
- raise NotImplementedError
60
-
61
- def __pretty_repr__(self):
62
- """
63
- Pretty repr for the object.
64
- """
65
- yield PrettyType(type=type(self))
66
- for name, value in vars(self).items():
67
- if name.startswith('_'):
68
- continue
69
- yield PrettyAttr(name, repr(value))
70
-
71
-
72
- def to_size(x) -> Optional[Tuple[int]]:
73
- if isinstance(x, (tuple, list)):
74
- return tuple(x)
75
- if isinstance(x, (int, np.integer)):
76
- return (x,)
77
- if x is None:
78
- return x
79
- raise ValueError(f'Cannot make a size for {x}')
80
-
81
-
82
- def _is_scalar(x):
83
- return u.math.isscalar(x)
84
-
85
-
86
- def are_broadcastable_shapes(shape1, shape2):
87
- """
88
- Check if two shapes are broadcastable.
89
-
90
- Parameters:
91
- - shape1: Tuple[int], the shape of the first array.
92
- - shape2: Tuple[int], the shape of the second array.
93
-
94
- Returns:
95
- - bool: True if shapes are broadcastable, False otherwise.
96
- """
97
- # Reverse the shapes to compare from the last dimension
98
- shape1_reversed = shape1[::-1]
99
- shape2_reversed = shape2[::-1]
100
-
101
- # Iterate over the dimensions of the shorter shape
102
- for dim1, dim2 in zip(shape1_reversed, shape2_reversed):
103
- # Check if the dimensions are not equal and neither is 1
104
- if dim1 != dim2 and 1 not in (dim1, dim2):
105
- return False
106
-
107
- # If all dimensions are compatible, the shapes are broadcastable
108
- return True
109
-
110
-
111
- def _expand_params_to_match_sizes(params, sizes):
112
- """
113
- Expand the dimensions of params to match the dimensions of sizes.
114
-
115
- Parameters:
116
- - params: jax.Array or np.ndarray, the parameter array to be expanded.
117
- - sizes: tuple[int] or list[int], the target shape dimensions.
118
-
119
- Returns:
120
- - Expanded params with dimensions matching sizes.
121
- """
122
- params_dim = params.ndim
123
- sizes_dim = len(sizes)
124
- dim_diff = sizes_dim - params_dim
125
-
126
- # Add new axes to params if it has fewer dimensions than sizes
127
- for _ in range(dim_diff):
128
- params = u.math.expand_dims(params, axis=0) # Add new axis at the last dimension
129
- return params
130
-
131
-
132
- @set_module_as('brainstate.nn')
133
- def param(
134
- parameter: Union[Callable, ArrayLike, State],
135
- sizes: Union[int, Sequence[int]],
136
- batch_size: Optional[int] = None,
137
- allow_none: bool = True,
138
- allow_scalar: bool = True,
139
- ):
140
- """Initialize parameters.
141
-
142
- Parameters
143
- ----------
144
- parameter: callable, ArrayLike, State
145
- The initialization of the parameter.
146
- - If it is None, the created parameter will be None.
147
- - If it is a callable function :math:`f`, the ``f(size)`` will be returned.
148
- - If it is an instance of :py:class:`init.Initializer``, the ``f(size)`` will be returned.
149
- - If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``.
150
- sizes: int, sequence of int
151
- The shape of the parameter.
152
- batch_size: int
153
- The batch size.
154
- allow_none: bool
155
- Whether allow the parameter is None.
156
- allow_scalar: bool
157
- Whether allow the parameter is a scalar value.
158
-
159
- Returns
160
- -------
161
- param: ArrayType, float, int, bool, None
162
- The initialized parameter.
163
-
164
- See Also
165
- --------
166
- noise, state
167
- """
168
- # Check if the parameter is None
169
- if parameter is None:
170
- if allow_none:
171
- return None
172
- else:
173
- raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
174
- f'Callable function, but we got None. ')
175
-
176
- # Check if the parameter is a scalar value
177
- if allow_scalar and _is_scalar(parameter):
178
- return parameter
179
-
180
- # Convert sizes to a tuple
181
- sizes = tuple(to_size(sizes))
182
-
183
- # Check if the parameter is a callable function
184
- if callable(parameter):
185
- if batch_size is not None:
186
- sizes = (batch_size,) + sizes
187
- return parameter(sizes)
188
- elif isinstance(parameter, (np.ndarray, jax.Array, u.Quantity, State)):
189
- parameter = parameter
190
- else:
191
- raise ValueError(f'Unknown parameter type: {type(parameter)}')
192
-
193
- # Check if the shape of the parameter matches the given size
194
- if not are_broadcastable_shapes(parameter.shape, sizes):
195
- raise ValueError(f'The shape of the parameter {parameter.shape} does not match with the given size {sizes}')
196
-
197
- # Expand the parameter to match the given batch size
198
- param_value = parameter.value if isinstance(parameter, State) else parameter
199
- if batch_size is not None:
200
- if param_value.ndim <= len(sizes):
201
- # add a new axis to the params so that it matches the dimensionality of the given shape ``sizes``
202
- param_value = _expand_params_to_match_sizes(param_value, sizes)
203
- param_value = u.math.repeat(
204
- u.math.expand_dims(param_value, axis=0),
205
- batch_size,
206
- axis=0
207
- )
208
- else:
209
- if param_value.shape[0] != batch_size:
210
- raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
211
- f'does not match with the given batch size {batch_size}')
212
- return type(parameter)(param_value) if isinstance(parameter, State) else param_value
213
-
214
-
215
- def calculate_init_gain(nonlinearity, param=None):
216
- r"""Return the recommended gain value for the given nonlinearity function.
217
- The values are as follows:
218
-
219
- ================= ====================================================
220
- nonlinearity gain
221
- ================= ====================================================
222
- Linear / Identity :math:`1`
223
- Conv{1,2,3}D :math:`1`
224
- Sigmoid :math:`1`
225
- Tanh :math:`\frac{5}{3}`
226
- ReLU :math:`\sqrt{2}`
227
- Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
228
- SELU :math:`\frac{3}{4}`
229
- ================= ====================================================
230
-
231
- .. warning::
232
- In order to implement `Self-Normalizing Neural Networks`_ ,
233
- you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
234
- This gives the initial weights a variance of ``1 / N``,
235
- which is necessary to induce a stable fixed point in the forward pass.
236
- In contrast, the default gain for ``SELU`` sacrifices the normalisation
237
- effect for more stable gradient flow in rectangular layers.
238
-
239
- Args:
240
- nonlinearity: the non-linear function (`nn.functional` name)
241
- param: optional parameter for the non-linear function
242
-
243
- .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
244
- """
245
- linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
246
- if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
247
- return 1
248
- elif nonlinearity == 'tanh':
249
- return 5.0 / 3
250
- elif nonlinearity == 'relu':
251
- return math.sqrt(2.0)
252
- elif nonlinearity == 'leaky_relu':
253
- if param is None:
254
- negative_slope = 0.01
255
- elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
256
- # True/False are instances of int, hence check above
257
- negative_slope = param
258
- else:
259
- raise ValueError("negative_slope {} not a valid number".format(param))
260
- return math.sqrt(2.0 / (1 + negative_slope ** 2))
261
- elif nonlinearity == 'selu':
262
- return 3.0 / 4
263
- else:
264
- raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
265
-
266
-
267
- def _format_shape(shape):
268
- if isinstance(shape, int):
269
- return (shape,)
270
- if len(shape) == 0:
271
- raise ValueError('Please provide shape.')
272
- if len(shape) == 1:
273
- if isinstance(shape[0], (tuple, list)):
274
- return shape[0]
275
- else:
276
- return shape
277
- else:
278
- return shape
279
-
280
-
281
- def _compute_fans(shape, in_axis=-2, out_axis=-1):
282
- receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
283
- fan_in = shape[in_axis] * receptive_field_size
284
- fan_out = shape[out_axis] * receptive_field_size
285
- return fan_in, fan_out
286
-
287
-
288
- class Normal(Initializer):
289
- """Initialize weights with normal distribution.
290
-
291
- Parameters
292
- ----------
293
- scale : float
294
- The gain of the derivation of the normal distribution.
295
-
296
- """
297
- __module__ = 'brainstate.nn'
298
-
299
- def __init__(
300
- self,
301
- mean: ArrayLike = 0.,
302
- scale: ArrayLike = 1.,
303
- unit: u.Unit = u.UNITLESS,
304
- seed: SeedOrKey = None
305
- ):
306
- super().__init__()
307
- self.scale = scale
308
- self.mean = mean
309
- self.rng = random.default_rng(seed)
310
- self.unit = unit
311
-
312
- def __call__(self, shape, **kwargs):
313
- shape = to_size(shape)
314
- dtype = kwargs.get('dtype', environ.dftype())
315
- rng = kwargs.get('rng', self.rng)
316
- weights = rng.normal(size=shape, loc=self.mean, scale=self.scale, dtype=dtype)
317
- return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
318
-
319
-
320
- class TruncatedNormal(Initializer):
321
- """Initialize weights with truncated normal distribution.
322
-
323
- Parameters
324
- ----------
325
- loc : float, ndarray
326
- Mean ("centre") of the distribution before truncating. Note that
327
- the mean of the truncated distribution will not be exactly equal
328
- to ``loc``.
329
- scale : float
330
- The standard deviation of the normal distribution before truncating.
331
- lower : float, ndarray
332
- A float or array of floats representing the lower bound for
333
- truncation. Must be broadcast-compatible with ``upper``.
334
- upper : float, ndarray
335
- A float or array of floats representing the upper bound for
336
- truncation. Must be broadcast-compatible with ``lower``.
337
-
338
- """
339
- __module__ = 'brainstate.nn'
340
-
341
- def __init__(
342
- self,
343
- loc: ArrayLike = 0.,
344
- scale: ArrayLike = 1.,
345
- unit: u.Unit = u.UNITLESS,
346
- lower: ArrayLike = None,
347
- upper: ArrayLike = None,
348
- seed: SeedOrKey = None,
349
- ):
350
- super().__init__()
351
- assert scale > 0, '`scale` must be positive.'
352
- self.scale = scale
353
- self.loc = loc
354
- self.lower = lower
355
- self.upper = upper
356
- self.rng = random.default_rng(seed)
357
- self.unit = unit
358
-
359
- def __call__(self, shape, **kwargs):
360
- dtype = kwargs.get('dtype', environ.dftype())
361
- rng = kwargs.get('rng', self.rng)
362
- weights = rng.truncated_normal(
363
- size=shape,
364
- scale=self.scale,
365
- lower=self.lower,
366
- upper=self.upper,
367
- loc=self.loc,
368
- dtype=dtype
369
- )
370
- return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
371
-
372
-
373
- class Gamma(Initializer):
374
- """Initialize weights with Gamma distribution.
375
-
376
- Parameters
377
- ----------
378
- shape: float, Array
379
- Shape parameter.
380
- scale: float, Array
381
- The gain of the derivation of the Gamma distribution.
382
-
383
- """
384
- __module__ = 'brainstate.nn'
385
-
386
- def __init__(
387
- self,
388
- shape: ArrayLike,
389
- unit: u.Unit = u.UNITLESS,
390
- scale: ArrayLike = None,
391
- seed: SeedOrKey = None
392
- ):
393
- self.shape = shape
394
- self.scale = scale
395
- self.rng = random.default_rng(seed)
396
- self.unit = unit
397
-
398
- def __call__(self, shape, **kwargs):
399
- shape = to_size(shape)
400
- dtype = kwargs.get('dtype', environ.dftype())
401
- rng = kwargs.get('rng', self.rng)
402
- weights = rng.gamma(self.shape, scale=self.scale, size=shape, dtype=dtype)
403
- return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
404
-
405
-
406
- class Exponential(Initializer):
407
- """Initialize weights with Gamma distribution.
408
-
409
- Parameters
410
- ----------
411
- scale: float, Array
412
- The gain of the derivation of the Exponential distribution.
413
-
414
- """
415
- __module__ = 'brainstate.nn'
416
-
417
- def __init__(
418
- self,
419
- scale: ArrayLike = None,
420
- seed: SeedOrKey = None,
421
- unit: u.Unit = u.UNITLESS,
422
- ):
423
- self.scale = scale
424
- self.rng = random.default_rng(seed)
425
- self.unit = unit
426
-
427
- def __call__(self, shape, **kwargs):
428
- shape = to_size(shape)
429
- dtype = kwargs.get('dtype', environ.dftype())
430
- rng = kwargs.get('rng', self.rng)
431
- weights = rng.exponential(scale=self.scale, size=shape, dtype=dtype)
432
- return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
433
-
434
-
435
- class Uniform(Initializer):
436
- """Initialize weights with uniform distribution.
437
-
438
- Parameters
439
- ----------
440
- min_val : float
441
- The lower limit of the uniform distribution.
442
- max_val : float
443
- The upper limit of the uniform distribution.
444
- """
445
- __module__ = 'brainstate.nn'
446
-
447
- def __init__(
448
- self,
449
- min_val: ArrayLike = 0.,
450
- max_val: ArrayLike = 1.,
451
- seed: SeedOrKey = None,
452
- unit: u.Unit = u.UNITLESS,
453
- ):
454
- super(Uniform, self).__init__()
455
- self.min_val = min_val
456
- self.max_val = max_val
457
- self.rng = random.default_rng(seed)
458
- self.unit = unit
459
-
460
- def __call__(self, shape, **kwargs):
461
- shape = to_size(shape)
462
- dtype = kwargs.get('dtype', environ.dftype())
463
- rng = kwargs.get('rng', self.rng)
464
- weights = rng.uniform(low=self.min_val, high=self.max_val, size=shape, dtype=dtype)
465
- return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
466
-
467
-
468
- class VarianceScaling(Initializer):
469
- __module__ = 'brainstate.nn'
470
-
471
- def __init__(
472
- self,
473
- scale: ArrayLike,
474
- mode: str,
475
- distribution: str,
476
- in_axis: int = -2,
477
- out_axis: int = -1,
478
- seed: SeedOrKey = None,
479
- unit: u.Unit = u.UNITLESS,
480
- ):
481
- assert mode in ['fan_in', 'fan_out', 'fan_avg']
482
- assert distribution in ['truncated_normal', 'normal', 'uniform']
483
- self.scale = scale
484
- self.mode = mode
485
- self.in_axis = in_axis
486
- self.out_axis = out_axis
487
- self.distribution = distribution
488
- self.rng = random.default_rng(seed)
489
- self.unit = unit
490
-
491
- def __call__(self, shape, **kwargs):
492
- shape = to_size(shape)
493
- dtype = kwargs.get('dtype', environ.dftype())
494
- rng = kwargs.get('rng', self.rng)
495
- fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
496
- if self.mode == "fan_in":
497
- denominator = fan_in
498
- elif self.mode == "fan_out":
499
- denominator = fan_out
500
- elif self.mode == "fan_avg":
501
- denominator = (fan_in + fan_out) / 2
502
- else:
503
- raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
504
- variance = (self.scale / denominator).astype(dtype)
505
- if self.distribution == "truncated_normal":
506
- stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype)
507
- res = rng.truncated_normal(-2, 2, shape, dtype=dtype) * stddev
508
- elif self.distribution == "normal":
509
- res = rng.randn(*shape, dtype=dtype) * jnp.sqrt(variance).astype(dtype)
510
- elif self.distribution == "uniform":
511
- res = (rng.uniform(low=-1, high=1, size=shape, dtype=dtype) *
512
- jnp.sqrt(3 * variance).astype(dtype))
513
- else:
514
- raise ValueError("invalid distribution for variance scaling initializer")
515
- return u.maybe_decimal(u.Quantity(res, unit=self.unit))
516
-
517
-
518
- class KaimingUniform(VarianceScaling):
519
- __module__ = 'brainstate.nn'
520
-
521
- def __init__(
522
- self,
523
- scale: float = 2.0,
524
- mode: str = "fan_in",
525
- distribution: str = "uniform",
526
- in_axis: int = -2,
527
- out_axis: int = -1,
528
- seed: SeedOrKey = None,
529
- unit: u.Unit = u.UNITLESS,
530
- ):
531
- super().__init__(scale,
532
- mode,
533
- distribution,
534
- in_axis=in_axis,
535
- out_axis=out_axis,
536
- seed=seed,
537
- unit=unit)
538
-
539
-
540
- class KaimingNormal(VarianceScaling):
541
- __module__ = 'brainstate.nn'
542
-
543
- def __init__(
544
- self,
545
- scale: float = 2.0,
546
- mode: str = "fan_in",
547
- distribution: str = "truncated_normal",
548
- in_axis: int = -2,
549
- out_axis: int = -1,
550
- seed: SeedOrKey = None,
551
- unit: u.Unit = u.UNITLESS,
552
- ):
553
- super().__init__(scale,
554
- mode,
555
- distribution,
556
- in_axis=in_axis,
557
- out_axis=out_axis,
558
- seed=seed,
559
- unit=unit)
560
-
561
-
562
- class XavierUniform(VarianceScaling):
563
- __module__ = 'brainstate.nn'
564
-
565
- def __init__(
566
- self,
567
- scale: float = 1.0,
568
- mode: str = "fan_avg",
569
- distribution: str = "uniform",
570
- in_axis: int = -2,
571
- out_axis: int = -1,
572
- seed: SeedOrKey = None,
573
- unit: u.Unit = u.UNITLESS,
574
- ):
575
- super().__init__(scale,
576
- mode,
577
- distribution,
578
- in_axis=in_axis,
579
- out_axis=out_axis,
580
- seed=seed,
581
- unit=unit)
582
-
583
-
584
- class XavierNormal(VarianceScaling):
585
- __module__ = 'brainstate.nn'
586
-
587
- def __init__(
588
- self,
589
- scale: float = 1.0,
590
- mode: str = "fan_avg",
591
- distribution: str = "truncated_normal",
592
- in_axis: int = -2,
593
- out_axis: int = -1,
594
- seed: SeedOrKey = None,
595
- unit: u.Unit = u.UNITLESS,
596
- ):
597
- super().__init__(scale,
598
- mode,
599
- distribution,
600
- in_axis=in_axis,
601
- out_axis=out_axis,
602
- seed=seed,
603
- unit=unit)
604
-
605
-
606
- class LecunUniform(VarianceScaling):
607
- __module__ = 'brainstate.nn'
608
-
609
- def __init__(
610
- self,
611
- scale: float = 1.0,
612
- mode: str = "fan_in",
613
- distribution: str = "uniform",
614
- in_axis: int = -2,
615
- out_axis: int = -1,
616
- seed: SeedOrKey = None,
617
- unit: u.Unit = u.UNITLESS,
618
- ):
619
- super().__init__(scale,
620
- mode,
621
- distribution,
622
- in_axis=in_axis,
623
- out_axis=out_axis,
624
- seed=seed,
625
- unit=unit)
626
-
627
-
628
- class LecunNormal(VarianceScaling):
629
- __module__ = 'brainstate.nn'
630
-
631
- def __init__(
632
- self,
633
- scale: float = 1.0,
634
- mode: str = "fan_in",
635
- distribution: str = "truncated_normal",
636
- in_axis: int = -2,
637
- out_axis: int = -1,
638
- seed: SeedOrKey = None,
639
- unit: u.Unit = u.UNITLESS,
640
- ):
641
- super().__init__(scale,
642
- mode,
643
- distribution,
644
- in_axis=in_axis,
645
- out_axis=out_axis,
646
- seed=seed,
647
- unit=unit)
648
-
649
-
650
- class Orthogonal(Initializer):
651
- """
652
- Construct an initializer for uniformly distributed orthogonal matrices.
653
-
654
- If the shape is not square, the matrix will have orthonormal rows or columns
655
- depending on which side is smaller.
656
- """
657
- __module__ = 'brainstate.nn'
658
-
659
- def __init__(
660
- self,
661
- scale: ArrayLike = 1.,
662
- axis: int = -1,
663
- seed: SeedOrKey = None,
664
- unit: u.Unit = u.UNITLESS,
665
- ):
666
- super().__init__()
667
- self.scale = scale
668
- self.axis = axis
669
- self.rng = random.default_rng(seed)
670
- self.unit = unit
671
-
672
- def __call__(self, shape, **kwargs):
673
- dtype = kwargs.get('dtype', environ.dftype())
674
- rng = kwargs.get('rng', self.rng)
675
- shape = to_size(shape)
676
- n_rows = shape[self.axis]
677
- n_cols = np.prod(shape) // n_rows
678
- matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
679
- norm_dst = rng.normal(size=matrix_shape, dtype=dtype)
680
-
681
- q_mat, r_mat = jnp.linalg.qr(norm_dst)
682
- # Enforce Q is uniformly distributed
683
- q_mat *= jnp.sign(jnp.diag(r_mat))
684
- if n_rows < n_cols:
685
- q_mat = q_mat.T
686
- q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
687
- q_mat = jnp.moveaxis(q_mat, 0, self.axis)
688
- r = jnp.asarray(self.scale, dtype=dtype) * q_mat
689
- return u.maybe_decimal(u.Quantity(r, unit=self.unit))
690
-
691
-
692
- class DeltaOrthogonal(Initializer):
693
- """
694
- Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
695
-
696
- The shape must be 3D, 4D or 5D.
697
- """
698
- __module__ = 'brainstate.nn'
699
-
700
- def __init__(
701
- self,
702
- scale: ArrayLike = 1.0,
703
- axis: int = -1,
704
- seed: SeedOrKey = None,
705
- unit: u.Unit = u.UNITLESS,
706
- ):
707
- super().__init__()
708
- self.scale = scale
709
- self.axis = axis
710
- self.orghogonal = Orthogonal(scale=scale, axis=axis, seed=seed)
711
- self.unit = unit
712
-
713
- def __call__(self, shape, **kwargs):
714
- shape = to_size(shape)
715
- dtype = kwargs.get('dtype', environ.dftype())
716
- if len(shape) not in [3, 4, 5]:
717
- raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
718
- if shape[-1] < shape[-2]:
719
- raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
720
- ortho_matrix = u.Quantity(self.orghogonal(shape[-2:]))
721
- W = u.Quantity(u.math.zeros(shape, dtype=dtype), unit=u.get_unit(ortho_matrix))
722
- if len(shape) == 3:
723
- k = shape[0]
724
- W = W.at[(k - 1) // 2].set(ortho_matrix)
725
- elif len(shape) == 4:
726
- k1, k2 = shape[:2]
727
- W = W.at[(k1 - 1) // 2, (k2 - 1) // 2].set(ortho_matrix)
728
- else:
729
- k1, k2, k3 = shape[:3]
730
- W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
731
- return u.maybe_decimal(u.Quantity(W.mantissa, unit=self.unit))
732
-
733
-
734
- class ZeroInit(Initializer):
735
- """Zero initializer.
736
-
737
- Initialize the weights with zeros.
738
- """
739
- __module__ = 'brainstate.nn'
740
-
741
- def __init__(self, unit: u.Unit = u.UNITLESS):
742
- super(ZeroInit, self).__init__()
743
- self.unit = unit
744
-
745
- def __call__(self, shape, **kwargs):
746
- dtype = kwargs.get('dtype', environ.dftype())
747
- shape = to_size(shape)
748
- return u.maybe_decimal(u.math.zeros(shape, dtype=dtype, unit=self.unit))
749
-
750
-
751
- class Constant(Initializer):
752
- """Constant initializer.
753
-
754
- Initialize the weights with the given values.
755
-
756
- Parameters
757
- ----------
758
- value : float, int, bm.ndarray
759
- The value to specify.
760
- """
761
- __module__ = 'brainstate.nn'
762
-
763
- def __init__(self, value=1., ):
764
- super(Constant, self).__init__()
765
- self.value = value
766
-
767
- def __call__(self, shape, **kwargs):
768
- dtype = kwargs.get('dtype', environ.dftype())
769
- shape = to_size(shape)
770
- return u.maybe_decimal(u.math.full(shape, self.value, dtype=dtype))
771
-
772
-
773
- class Identity(Initializer):
774
- """Returns the identity matrix.
775
-
776
- This initializer was proposed in (Le, et al., 2015) [1]_.
777
-
778
- Parameters
779
- ----------
780
- value : float
781
- The optional scaling factor.
782
-
783
- Returns
784
- -------
785
- shape: tuple of int
786
- The weight shape/size.
787
-
788
- References
789
- ----------
790
- .. [1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. "A simple way to
791
- initialize recurrent networks of rectified linear units." arXiv preprint
792
- arXiv:1504.00941 (2015).
793
- """
794
- __module__ = 'brainstate.nn'
795
-
796
- def __init__(self, value=1., unit: u.Unit = u.UNITLESS):
797
- super(Identity, self).__init__()
798
- self.value = value
799
- self.unit = unit
800
-
801
- def __call__(self, shape, **kwargs):
802
- dtype = kwargs.get('dtype', environ.dftype())
803
- shape = to_size(shape)
804
- if isinstance(shape, (tuple, list)):
805
- if len(shape) > 2:
806
- raise ValueError(f'Only support initialize 2D weights for {self.__class__.__name__}.')
807
- r = u.math.eye(*shape, dtype=dtype)
808
- r = u.math.fill_diagonal(r, self.value)
809
- return u.maybe_decimal(u.Quantity(r, unit=self.unit))
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import math
17
+ from typing import Optional, Tuple
18
+ from typing import Union, Callable, Sequence
19
+
20
+ import brainunit as u
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+
25
+ from brainstate import environ, random
26
+ from brainstate._state import State
27
+ from brainstate._utils import set_module_as
28
+ from brainstate.typing import ArrayLike, SeedOrKey
29
+ from brainstate.util import PrettyRepr, PrettyType, PrettyAttr
30
+
31
# Names exported by ``from <this module> import *``.
# NOTE(review): ``Gamma`` and ``Exponential`` are defined in this module with
# ``__module__ = 'brainstate.nn'`` but are not listed here -- confirm whether
# they are intentionally kept out of the star-export.
__all__ = [
    'param',
    'calculate_init_gain',
    'ZeroInit',
    'Constant',
    'Identity',
    'Normal',
    'TruncatedNormal',
    'Uniform',
    'VarianceScaling',
    'KaimingUniform',
    'KaimingNormal',
    'XavierUniform',
    'XavierNormal',
    'LecunUniform',
    'LecunNormal',
    'Orthogonal',
    'DeltaOrthogonal',
]
50
+
51
+
52
class Initializer(PrettyRepr):
    """
    Common base for all initializers defined in this module.

    Subclasses implement ``__call__`` to build an array for a requested
    shape. The pretty representation lists every instance attribute whose
    name does not start with an underscore.
    """
    __module__ = 'brainstate.nn'

    def __call__(self, *args, **kwargs):
        # Concrete initializers supply the actual sampling logic.
        raise NotImplementedError

    def __pretty_repr__(self):
        """
        Yield the type header, then one attribute entry per public attribute.
        """
        yield PrettyType(type=type(self))
        public_attrs = (
            (attr_name, attr_value)
            for attr_name, attr_value in vars(self).items()
            if not attr_name.startswith('_')
        )
        for attr_name, attr_value in public_attrs:
            yield PrettyAttr(attr_name, repr(attr_value))
70
+
71
+
72
def to_size(x) -> Optional[Tuple[int, ...]]:
    """
    Normalize a size specification into a shape tuple.

    Parameters
    ----------
    x : int, numpy integer, tuple, list, or None
        The requested size. An integer ``n`` becomes ``(n,)``, a tuple or list
        is converted to a tuple, and ``None`` is passed through unchanged.

    Returns
    -------
    tuple of int or None
        The normalized shape, or ``None`` when ``x`` is ``None``.

    Raises
    ------
    ValueError
        If ``x`` is of any other type.
    """
    if isinstance(x, (tuple, list)):
        return tuple(x)
    if isinstance(x, (int, np.integer)):
        return (x,)
    if x is None:
        return None
    raise ValueError(f'Cannot make a size for {x}')
80
+
81
+
82
def _is_scalar(x):
    """Tell whether ``x`` is a scalar, delegating to ``brainunit.math.isscalar``."""
    scalar_flag = u.math.isscalar(x)
    return scalar_flag
84
+
85
+
86
def are_broadcastable_shapes(shape1, shape2):
    """
    Decide whether two shapes are compatible under broadcasting rules.

    The shapes are aligned from their trailing ends and only the overlapping
    trailing dimensions are compared; a pair of sizes is compatible when the
    sizes are equal or either one is 1.

    Parameters:
    - shape1: Tuple[int], the first shape.
    - shape2: Tuple[int], the second shape.

    Returns:
    - bool: True when every aligned pair is compatible, False otherwise.
    """
    aligned_pairs = zip(shape1[::-1], shape2[::-1])
    return all(left == right or 1 in (left, right) for left, right in aligned_pairs)
109
+
110
+
111
def _expand_params_to_match_sizes(params, sizes):
    """
    Prepend singleton axes to ``params`` until its rank equals ``len(sizes)``.

    Parameters:
    - params: an array-like object exposing ``ndim`` (e.g. jax.Array, np.ndarray).
    - sizes: tuple[int] or list[int], the target shape.

    Returns:
    - ``params`` with leading size-1 axes added; it is returned unchanged when
      it already has at least ``len(sizes)`` dimensions.
    """
    missing_axes = len(sizes) - params.ndim
    while missing_axes > 0:
        # axis=0 inserts the new singleton axis at the FRONT of the shape.
        params = u.math.expand_dims(params, axis=0)
        missing_axes -= 1
    return params
130
+
131
+
132
@set_module_as('brainstate.nn')
def param(
    parameter: Union[Callable, ArrayLike, State],
    sizes: Union[int, Sequence[int]],
    batch_size: Optional[int] = None,
    allow_none: bool = True,
    allow_scalar: bool = True,
):
    """Initialize parameters.

    Parameters
    ----------
    parameter: callable, ArrayLike, State
        The initialization of the parameter.

        - If it is None, ``None`` is returned (when ``allow_none`` is True).
        - If it is a scalar and ``allow_scalar`` is True, it is returned as-is.
        - If it is a callable :math:`f` (for example an :py:class:`Initializer`
          instance), ``f(size)`` is returned, where ``size`` is ``sizes``
          prefixed with ``batch_size`` when one is given.
        - If it is a ``np.ndarray``, ``jax.Array``, ``Quantity`` or ``State``,
          its shape must be broadcast-compatible with ``sizes``. When
          ``batch_size`` is given, the value is expanded and repeated along a
          new leading batch axis, unless it already carries more dimensions
          than ``sizes``, in which case its leading dimension must equal
          ``batch_size``.
    sizes: int, sequence of int
        The shape of the parameter.
    batch_size: int
        The batch size.
    allow_none: bool
        Whether allow the parameter is None.
    allow_scalar: bool
        Whether allow the parameter is a scalar value.

    Returns
    -------
    param: ArrayType, float, int, bool, State, None
        The initialized parameter. A ``State`` input produces a new instance of
        the same ``State`` type wrapping the processed value.

    Raises
    ------
    ValueError
        If ``parameter`` is None while ``allow_none`` is False, if its type is
        unsupported, if its shape does not broadcast with ``sizes``, or if an
        already-batched value has the wrong leading dimension.

    See Also
    --------
    noise, state
    """
    # None handling
    if parameter is None:
        if allow_none:
            return None
        raise ValueError('Expect a parameter with type of float, ArrayType, Initializer, or '
                         'Callable function, but we got None. ')

    # Scalars are returned untouched
    if allow_scalar and _is_scalar(parameter):
        return parameter

    sizes = tuple(to_size(sizes))

    # Callables (initializers) build the value themselves
    if callable(parameter):
        if batch_size is not None:
            sizes = (batch_size,) + sizes
        return parameter(sizes)
    if not isinstance(parameter, (np.ndarray, jax.Array, u.Quantity, State)):
        raise ValueError(f'Unknown parameter type: {type(parameter)}')

    # Unwrap State containers *before* inspecting the shape, so the check works
    # on the underlying array instead of requiring ``State`` to expose ``shape``.
    param_value = parameter.value if isinstance(parameter, State) else parameter

    if not are_broadcastable_shapes(param_value.shape, sizes):
        raise ValueError(f'The shape of the parameter {param_value.shape} does not match with the given size {sizes}')

    if batch_size is not None:
        if param_value.ndim <= len(sizes):
            # Bring the value to the rank of ``sizes`` and tile it over a new batch axis.
            param_value = _expand_params_to_match_sizes(param_value, sizes)
            param_value = u.math.repeat(
                u.math.expand_dims(param_value, axis=0),
                batch_size,
                axis=0
            )
        elif param_value.shape[0] != batch_size:
            raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
                             f'does not match with the given batch size {batch_size}')
    return type(parameter)(param_value) if isinstance(parameter, State) else param_value
213
+
214
+
215
def calculate_init_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.
    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalisation
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function; for
            ``'leaky_relu'`` it is the negative slope (default ``0.01``) and
            must be an int or float, but not a bool.

    Raises:
        ValueError: for an unsupported ``nonlinearity`` or an invalid slope.

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    unit_gain_names = (
        'linear', 'conv1d', 'conv2d', 'conv3d',
        'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
        'sigmoid',
    )
    if nonlinearity in unit_gain_names:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            slope = 0.01
        elif isinstance(param, float) or (isinstance(param, int) and not isinstance(param, bool)):
            # bool is a subclass of int, so it has to be excluded explicitly.
            slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + slope ** 2))
    if nonlinearity == 'selu':
        return 3.0 / 4
    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
265
+
266
+
267
+ def _format_shape(shape):
268
+ if isinstance(shape, int):
269
+ return (shape,)
270
+ if len(shape) == 0:
271
+ raise ValueError('Please provide shape.')
272
+ if len(shape) == 1:
273
+ if isinstance(shape[0], (tuple, list)):
274
+ return shape[0]
275
+ else:
276
+ return shape
277
+ else:
278
+ return shape
279
+
280
+
281
+ def _compute_fans(shape, in_axis=-2, out_axis=-1):
282
+ receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
283
+ fan_in = shape[in_axis] * receptive_field_size
284
+ fan_out = shape[out_axis] * receptive_field_size
285
+ return fan_in, fan_out
286
+
287
+
288
class Normal(Initializer):
    """Draw initial weights from a normal (Gaussian) distribution.

    Parameters
    ----------
    mean : ArrayLike
        Location of the distribution, forwarded as ``loc``.
    scale : ArrayLike
        Spread of the distribution, forwarded as ``scale`` to the generator's
        ``normal`` sampler.
    unit : brainunit.Unit
        Unit attached to the sampled values.
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        mean: ArrayLike = 0.,
        scale: ArrayLike = 1.,
        unit: u.Unit = u.UNITLESS,
        seed: SeedOrKey = None
    ):
        super().__init__()
        # Assignment order is kept stable: the pretty repr lists attributes in it.
        self.scale = scale
        self.mean = mean
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        size = to_size(shape)
        target_dtype = kwargs.get('dtype', environ.dftype())
        generator = kwargs.get('rng', self.rng)
        samples = generator.normal(
            size=size,
            loc=self.mean,
            scale=self.scale,
            dtype=target_dtype,
        )
        return u.maybe_decimal(u.Quantity(samples, unit=self.unit))
318
+
319
+
320
class TruncatedNormal(Initializer):
    """Initialize weights with truncated normal distribution.

    Parameters
    ----------
    loc : float, ndarray
        Mean ("centre") of the distribution before truncating. Note that
        the mean of the truncated distribution will not be exactly equal
        to ``loc``.
    scale : float
        The standard deviation of the normal distribution before truncating.
        Must be positive.
    unit : brainunit.Unit
        Unit attached to the sampled values.
    lower : float, ndarray
        A float or array of floats representing the lower bound for
        truncation. Must be broadcast-compatible with ``upper``.
    upper : float, ndarray
        A float or array of floats representing the upper bound for
        truncation. Must be broadcast-compatible with ``lower``.
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        loc: ArrayLike = 0.,
        scale: ArrayLike = 1.,
        unit: u.Unit = u.UNITLESS,
        lower: ArrayLike = None,
        upper: ArrayLike = None,
        seed: SeedOrKey = None,
    ):
        super().__init__()
        assert scale > 0, '`scale` must be positive.'
        self.scale = scale
        self.loc = loc
        self.lower = lower
        self.upper = upper
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        # Normalize the shape like every other initializer in this module does,
        # so an int and a list are accepted consistently.
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        weights = rng.truncated_normal(
            size=shape,
            scale=self.scale,
            lower=self.lower,
            upper=self.upper,
            loc=self.loc,
            dtype=dtype
        )
        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
371
+
372
+
373
class Gamma(Initializer):
    """Initialize weights with Gamma distribution.

    Parameters
    ----------
    shape : float, Array
        Shape parameter of the Gamma distribution, forwarded as the first
        argument of the generator's ``gamma`` sampler.
    unit : brainunit.Unit
        Unit attached to the sampled values.
    scale : float, Array, optional
        Forwarded as ``scale`` to the generator's ``gamma`` sampler
        (``None`` is passed through unchanged).
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        shape: ArrayLike,
        unit: u.Unit = u.UNITLESS,
        scale: ArrayLike = None,
        seed: SeedOrKey = None
    ):
        # Match the sibling initializers, which all chain to the base class.
        super().__init__()
        self.shape = shape
        self.scale = scale
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        # ``self.shape`` is the distribution parameter; ``shape`` is the output size.
        weights = rng.gamma(self.shape, scale=self.scale, size=shape, dtype=dtype)
        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
404
+
405
+
406
class Exponential(Initializer):
    """Initialize weights with Exponential distribution.

    Parameters
    ----------
    scale : float, Array, optional
        Forwarded as ``scale`` to the generator's ``exponential`` sampler
        (``None`` is passed through unchanged).
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    unit : brainunit.Unit
        Unit attached to the sampled values.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike = None,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        # Match the sibling initializers, which all chain to the base class.
        super().__init__()
        self.scale = scale
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        weights = rng.exponential(scale=self.scale, size=shape, dtype=dtype)
        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
433
+
434
+
435
class Uniform(Initializer):
    """Draw initial weights uniformly from ``[min_val, max_val)``.

    Parameters
    ----------
    min_val : float
        Lower limit, forwarded as ``low`` to the generator's ``uniform`` sampler.
    max_val : float
        Upper limit, forwarded as ``high``.
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    unit : brainunit.Unit
        Unit attached to the sampled values.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        min_val: ArrayLike = 0.,
        max_val: ArrayLike = 1.,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        # Assignment order is kept stable: the pretty repr lists attributes in it.
        self.min_val = min_val
        self.max_val = max_val
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        size = to_size(shape)
        target_dtype = kwargs.get('dtype', environ.dftype())
        generator = kwargs.get('rng', self.rng)
        samples = generator.uniform(
            low=self.min_val,
            high=self.max_val,
            size=size,
            dtype=target_dtype,
        )
        return u.maybe_decimal(u.Quantity(samples, unit=self.unit))
466
+
467
+
468
class VarianceScaling(Initializer):
    """Sample weights whose variance is scaled by the tensor's fan-in/fan-out.

    The target variance is ``scale / n``, where ``n`` is ``fan_in``,
    ``fan_out`` or their average, selected by ``mode``. Samples are drawn from:

    - ``'truncated_normal'``: a standard normal truncated to ``[-2, 2]``,
      rescaled so that the *truncated* distribution has the target variance;
    - ``'normal'``: a standard normal scaled by ``sqrt(scale / n)``;
    - ``'uniform'``: ``U(-1, 1)`` scaled by ``sqrt(3 * scale / n)``.

    Parameters
    ----------
    scale : ArrayLike
        Numerator of the target variance.
    mode : str
        One of ``'fan_in'``, ``'fan_out'``, ``'fan_avg'``.
    distribution : str
        One of ``'truncated_normal'``, ``'normal'``, ``'uniform'``.
    in_axis, out_axis : int
        Axes of the weight shape used to compute fan-in and fan-out.
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    unit : brainunit.Unit
        Unit attached to the sampled values.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike,
        mode: str,
        distribution: str,
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        assert mode in ['fan_in', 'fan_out', 'fan_avg'], (
            f"`mode` must be 'fan_in', 'fan_out' or 'fan_avg', got {mode!r}."
        )
        assert distribution in ['truncated_normal', 'normal', 'uniform'], (
            f"`distribution` must be 'truncated_normal', 'normal' or 'uniform', got {distribution!r}."
        )
        self.scale = scale
        self.mode = mode
        self.in_axis = in_axis
        self.out_axis = out_axis
        self.distribution = distribution
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        rng = kwargs.get('rng', self.rng)
        fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
        if self.mode == "fan_in":
            denominator = fan_in
        elif self.mode == "fan_out":
            denominator = fan_out
        elif self.mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
        variance = (self.scale / denominator).astype(dtype)
        if self.distribution == "truncated_normal":
            # Standard deviation of a standard normal truncated to [-2, 2];
            # dividing by it restores the requested variance after truncation.
            truncated_std = .87962566103423978
            stddev = (jnp.sqrt(variance) / truncated_std).astype(dtype)
            res = rng.truncated_normal(-2, 2, shape, dtype=dtype) * stddev
        elif self.distribution == "normal":
            res = rng.randn(*shape, dtype=dtype) * jnp.sqrt(variance).astype(dtype)
        elif self.distribution == "uniform":
            # U(-1, 1) has variance 1/3, hence the factor 3 inside the sqrt.
            res = (rng.uniform(low=-1, high=1, size=shape, dtype=dtype) *
                   jnp.sqrt(3 * variance).astype(dtype))
        else:
            raise ValueError("invalid distribution for variance scaling initializer")
        return u.maybe_decimal(u.Quantity(res, unit=self.unit))
516
+
517
+
518
class KaimingUniform(VarianceScaling):
    """He (Kaiming) uniform initializer.

    A :class:`VarianceScaling` preset that defaults to ``scale=2.0``,
    ``mode='fan_in'`` and ``distribution='uniform'``. Every argument is
    forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 2.0,
        mode: str = "fan_in",
        distribution: str = "uniform",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
538
+
539
+
540
class KaimingNormal(VarianceScaling):
    """He (Kaiming) normal initializer.

    A :class:`VarianceScaling` preset that defaults to ``scale=2.0``,
    ``mode='fan_in'`` and ``distribution='truncated_normal'``. Every argument
    is forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 2.0,
        mode: str = "fan_in",
        distribution: str = "truncated_normal",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
560
+
561
+
562
class XavierUniform(VarianceScaling):
    """Xavier (Glorot) uniform initializer.

    A :class:`VarianceScaling` preset that defaults to ``scale=1.0``,
    ``mode='fan_avg'`` and ``distribution='uniform'``. Every argument is
    forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_avg",
        distribution: str = "uniform",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
582
+
583
+
584
class XavierNormal(VarianceScaling):
    """Xavier (Glorot) normal initializer.

    A :class:`VarianceScaling` preset that defaults to ``scale=1.0``,
    ``mode='fan_avg'`` and ``distribution='truncated_normal'``. Every argument
    is forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_avg",
        distribution: str = "truncated_normal",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
604
+
605
+
606
class LecunUniform(VarianceScaling):
    """LeCun uniform initializer.

    A :class:`VarianceScaling` preset that defaults to ``scale=1.0``,
    ``mode='fan_in'`` and ``distribution='uniform'``. Every argument is
    forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_in",
        distribution: str = "uniform",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
626
+
627
+
628
class LecunNormal(VarianceScaling):
    """LeCun normal initializer.

    A :class:`VarianceScaling` preset that defaults to ``scale=1.0``,
    ``mode='fan_in'`` and ``distribution='truncated_normal'``. Every argument
    is forwarded unchanged to :class:`VarianceScaling`.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: float = 1.0,
        mode: str = "fan_in",
        distribution: str = "truncated_normal",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )
648
+
649
+
650
class Orthogonal(Initializer):
    """
    Construct an initializer for uniformly distributed orthogonal matrices.

    If the shape is not square, the matrix will have orthonormal rows or columns
    depending on which side is smaller.

    Parameters
    ----------
    scale : ArrayLike
        Multiplier applied to the orthogonal result.
    axis : int
        Axis of the requested shape treated as the "rows" dimension; all
        remaining dimensions are flattened into the "columns".
    seed : SeedOrKey
        Seed or key used to build the default random generator.
    unit : brainunit.Unit
        Unit attached to the result.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike = 1.,
        axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        # Assignment order is kept stable: the pretty repr lists attributes in it.
        self.scale = scale
        self.axis = axis
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        target_dtype = kwargs.get('dtype', environ.dftype())
        generator = kwargs.get('rng', self.rng)
        shape = to_size(shape)
        rows = shape[self.axis]
        cols = np.prod(shape) // rows
        # Sample a tall (or square) Gaussian matrix so QR yields orthonormal columns.
        if rows > cols:
            sample_shape = (rows, cols)
        else:
            sample_shape = (cols, rows)
        gaussian = generator.normal(size=sample_shape, dtype=target_dtype)

        q, r = jnp.linalg.qr(gaussian)
        # Fix the sign ambiguity of QR so that Q is uniformly (Haar) distributed.
        q = q * jnp.sign(jnp.diag(r))
        if rows < cols:
            q = q.T
        remaining_dims = tuple(np.delete(shape, self.axis))
        q = jnp.moveaxis(jnp.reshape(q, (rows,) + remaining_dims), 0, self.axis)
        scaled = jnp.asarray(self.scale, dtype=target_dtype) * q
        return u.maybe_decimal(u.Quantity(scaled, unit=self.unit))
690
+
691
+
692
class DeltaOrthogonal(Initializer):
    """
    Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.

    The shape must be 3D, 4D or 5D. The kernel is zero everywhere except at
    the centre position of the leading (spatial) dimensions, which holds an
    orthogonal ``shape[-2:]`` matrix produced by :class:`Orthogonal`.

    Parameters
    ----------
    scale : ArrayLike
        Multiplier forwarded to the inner :class:`Orthogonal` initializer.
    axis : int
        Axis forwarded to the inner :class:`Orthogonal` initializer.
    seed : SeedOrKey
        Seed used for the inner :class:`Orthogonal` generator.
    unit : brainunit.Unit
        Unit attached to the result.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        scale: ArrayLike = 1.0,
        axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        self.scale = scale
        self.axis = axis
        # NOTE: ``orghogonal`` is a misspelling, kept as-is so existing code
        # (and the pretty repr) that references this attribute keeps working.
        self.orghogonal = Orthogonal(scale=scale, axis=axis, seed=seed)
        self.unit = unit

    def __call__(self, shape, **kwargs):
        shape = to_size(shape)
        dtype = kwargs.get('dtype', environ.dftype())
        if len(shape) not in [3, 4, 5]:
            raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
        if shape[-1] < shape[-2]:
            raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
        # Forward ``dtype``/``rng`` so the centre matrix honours the caller's
        # generator and precision instead of silently using the defaults.
        ortho_matrix = u.Quantity(self.orghogonal(shape[-2:], **kwargs))
        W = u.Quantity(u.math.zeros(shape, dtype=dtype), unit=u.get_unit(ortho_matrix))
        # Centre index along every leading (spatial) dimension.
        center = tuple((k - 1) // 2 for k in shape[:-2])
        W = W.at[center].set(ortho_matrix)
        return u.maybe_decimal(u.Quantity(W.mantissa, unit=self.unit))
732
+
733
+
734
class ZeroInit(Initializer):
    """Zero initializer.

    Produces an all-zeros array of the requested shape, carrying ``unit``.
    """
    __module__ = 'brainstate.nn'

    def __init__(self, unit: u.Unit = u.UNITLESS):
        super().__init__()
        self.unit = unit

    def __call__(self, shape, **kwargs):
        target_dtype = kwargs.get('dtype', environ.dftype())
        zeros = u.math.zeros(to_size(shape), dtype=target_dtype, unit=self.unit)
        return u.maybe_decimal(zeros)
749
+
750
+
751
class Constant(Initializer):
    """Constant initializer.

    Fills an array of the requested shape with ``value``.

    Parameters
    ----------
    value : float, int, or array-like
        The fill value, passed to ``brainunit.math.full``.
    """
    __module__ = 'brainstate.nn'

    def __init__(self, value=1., ):
        super().__init__()
        self.value = value

    def __call__(self, shape, **kwargs):
        target_dtype = kwargs.get('dtype', environ.dftype())
        filled = u.math.full(to_size(shape), self.value, dtype=target_dtype)
        return u.maybe_decimal(filled)
771
+
772
+
773
class Identity(Initializer):
    """Returns the (optionally scaled) identity matrix.

    This initializer was proposed in (Le, et al., 2015) [1]_.

    Parameters
    ----------
    value : float
        The optional scaling factor written onto the diagonal.
    unit : brainunit.Unit
        Unit attached to the returned matrix.

    Notes
    -----
    The initializer is called with ``shape`` (an int or a sequence of at most
    two ints), which is unpacked into ``eye``: ``n`` gives an ``n x n``
    matrix and ``(n, m)`` gives an ``n x m`` matrix.

    References
    ----------
    .. [1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. "A simple way to
           initialize recurrent networks of rectified linear units." arXiv preprint
           arXiv:1504.00941 (2015).
    """
    __module__ = 'brainstate.nn'

    def __init__(self, value=1., unit: u.Unit = u.UNITLESS):
        super(Identity, self).__init__()
        self.value = value
        self.unit = unit

    def __call__(self, shape, **kwargs):
        dtype = kwargs.get('dtype', environ.dftype())
        shape = to_size(shape)
        # ``to_size`` returns a tuple or None; reject the degenerate cases up
        # front instead of letting ``eye(*shape)`` fail with an opaque TypeError.
        if shape is None or len(shape) == 0:
            raise ValueError(f'Please provide a 1D or 2D shape for {self.__class__.__name__}.')
        if len(shape) > 2:
            raise ValueError(f'Only support initialize 2D weights for {self.__class__.__name__}.')
        r = u.math.eye(*shape, dtype=dtype)
        r = u.math.fill_diagonal(r, self.value)
        return u.maybe_decimal(u.Quantity(r, unit=self.unit))