brainstate-0.1.10-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
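Taken together, the renames above fold brainstate.init into brainstate.nn, consolidate brainstate.augment and brainstate.compile into a new brainstate.transform package, and drop the optim, functional, and surrogate modules from the wheel. Below is a hedged sketch of what that likely means for user imports; the actual public re-exports are not visible in this diff, so treat the mapping as an inference from the file moves, not documented API.

    # Hypothetical 0.1.10 -> 0.2.0 import mapping, inferred only from the file moves above.
    import brainstate as bs

    # 0.1.10: initializers under bs.init (init/_random_inits.py, _regular_inits.py, ...)
    # 0.2.0:  the same initializers appear under bs.nn (init/_random_inits.py -> nn/init.py)
    w_init = bs.nn.Normal(scale=0.1)   # assumes Normal keeps a `scale` argument

    # 0.1.10: bs.augment.* (autograd, eval_shape, mapping, random)
    #         and bs.compile.* (jit, conditions, loops, error_if, ...)
    # 0.2.0:  consolidated under bs.transform.* (augment/ and compile/ -> transform/)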
brainstate/{init/_random_inits.py → nn/init.py} (the expanded diff below corresponds to entry 60 above)
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -13,19 +13,27 @@
  # limitations under the License.
  # ==============================================================================

- # -*- coding: utf-8 -*-
-
  import math
+ from typing import Optional, Tuple
+ from typing import Union, Callable, Sequence

  import brainunit as u
+ import jax
  import jax.numpy as jnp
  import numpy as np

  from brainstate import environ, random
- from brainstate.typing import ArrayLike, SeedOrKey, DTypeLike
- from ._base import Initializer, to_size
+ from brainstate._state import State
+ from brainstate._utils import set_module_as
+ from brainstate.typing import ArrayLike, SeedOrKey
+ from brainstate.util import PrettyRepr, PrettyType, PrettyAttr


  __all__ = [
+     'param',
+     'calculate_init_gain',
+     'ZeroInit',
+     'Constant',
+     'Identity',
      'Normal',
      'TruncatedNormal',
      'Uniform',
@@ -41,7 +49,170 @@ __all__ = [
  ]


- def calculate_gain(nonlinearity, param=None):
+ class Initializer(PrettyRepr):
+     """
+     Base class for initializers.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, *args, **kwargs):
+         raise NotImplementedError
+
+     def __pretty_repr__(self):
+         """
+         Pretty repr for the object.
+         """
+         yield PrettyType(type=type(self))
+         for name, value in vars(self).items():
+             if name.startswith('_'):
+                 continue
+             yield PrettyAttr(name, repr(value))
+
+
+ def to_size(x) -> Optional[Tuple[int]]:
+     if isinstance(x, (tuple, list)):
+         return tuple(x)
+     if isinstance(x, (int, np.integer)):
+         return (x,)
+     if x is None:
+         return x
+     raise ValueError(f'Cannot make a size for {x}')
+
+
+ def _is_scalar(x):
+     return u.math.isscalar(x)
+
+
+ def are_broadcastable_shapes(shape1, shape2):
+     """
+     Check if two shapes are broadcastable.
+
+     Parameters:
+     - shape1: Tuple[int], the shape of the first array.
+     - shape2: Tuple[int], the shape of the second array.
+
+     Returns:
+     - bool: True if shapes are broadcastable, False otherwise.
+     """
+     # Reverse the shapes to compare from the last dimension
+     shape1_reversed = shape1[::-1]
+     shape2_reversed = shape2[::-1]
+
+     # Iterate over the dimensions of the shorter shape
+     for dim1, dim2 in zip(shape1_reversed, shape2_reversed):
+         # Check if the dimensions are not equal and neither is 1
+         if dim1 != dim2 and 1 not in (dim1, dim2):
+             return False
+
+     # If all dimensions are compatible, the shapes are broadcastable
+     return True
+
+
+ def _expand_params_to_match_sizes(params, sizes):
+     """
+     Expand the dimensions of params to match the dimensions of sizes.
+
+     Parameters:
+     - params: jax.Array or np.ndarray, the parameter array to be expanded.
+     - sizes: tuple[int] or list[int], the target shape dimensions.
+
+     Returns:
+     - Expanded params with dimensions matching sizes.
+     """
+     params_dim = params.ndim
+     sizes_dim = len(sizes)
+     dim_diff = sizes_dim - params_dim
+
+     # Add new axes to params if it has fewer dimensions than sizes
+     for _ in range(dim_diff):
+         params = u.math.expand_dims(params, axis=0)  # Add new axis at the last dimension
+     return params
+
+
+ @set_module_as('brainstate.nn')
+ def param(
+     parameter: Union[Callable, ArrayLike, State],
+     sizes: Union[int, Sequence[int]],
+     batch_size: Optional[int] = None,
+     allow_none: bool = True,
+     allow_scalar: bool = True,
+ ):
+     """Initialize parameters.
+
+     Parameters
+     ----------
+     parameter: callable, ArrayLike, State
+         The initialization of the parameter.
+         - If it is None, the created parameter will be None.
+         - If it is a callable function :math:`f`, the ``f(size)`` will be returned.
+         - If it is an instance of :py:class:`init.Initializer``, the ``f(size)`` will be returned.
+         - If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``.
+     sizes: int, sequence of int
+         The shape of the parameter.
+     batch_size: int
+         The batch size.
+     allow_none: bool
+         Whether allow the parameter is None.
+     allow_scalar: bool
+         Whether allow the parameter is a scalar value.
+
+     Returns
+     -------
+     param: ArrayType, float, int, bool, None
+         The initialized parameter.
+
+     See Also
+     --------
+     noise, state
+     """
+     # Check if the parameter is None
+     if parameter is None:
+         if allow_none:
+             return None
+         else:
+             raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
+                              f'Callable function, but we got None. ')
+
+     # Check if the parameter is a scalar value
+     if allow_scalar and _is_scalar(parameter):
+         return parameter
+
+     # Convert sizes to a tuple
+     sizes = tuple(to_size(sizes))
+
+     # Check if the parameter is a callable function
+     if callable(parameter):
+         if batch_size is not None:
+             sizes = (batch_size,) + sizes
+         return parameter(sizes)
+     elif isinstance(parameter, (np.ndarray, jax.Array, u.Quantity, State)):
+         parameter = parameter
+     else:
+         raise ValueError(f'Unknown parameter type: {type(parameter)}')
+
+     # Check if the shape of the parameter matches the given size
+     if not are_broadcastable_shapes(parameter.shape, sizes):
+         raise ValueError(f'The shape of the parameter {parameter.shape} does not match with the given size {sizes}')
+
+     # Expand the parameter to match the given batch size
+     param_value = parameter.value if isinstance(parameter, State) else parameter
+     if batch_size is not None:
+         if param_value.ndim <= len(sizes):
+             # add a new axis to the params so that it matches the dimensionality of the given shape ``sizes``
+             param_value = _expand_params_to_match_sizes(param_value, sizes)
+             param_value = u.math.repeat(
+                 u.math.expand_dims(param_value, axis=0),
+                 batch_size,
+                 axis=0
+             )
+         else:
+             if param_value.shape[0] != batch_size:
+                 raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
+                                  f'does not match with the given batch size {batch_size}')
+     return type(parameter)(param_value) if isinstance(parameter, State) else param_value
+
+
+ def calculate_init_gain(nonlinearity, param=None):
      r"""Return the recommended gain value for the given nonlinearity function.
      The values are as follows:

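A minimal usage sketch of the new param helper added in the hunk above. It assumes the function and the initializer classes are re-exported as brainstate.nn.param / brainstate.nn.Normal, as set_module_as('brainstate.nn') and the new __all__ entries suggest; the exact import paths are an assumption, not something this diff shows.

    import numpy as np
    import brainstate

    # Callable / Initializer: called with the (optionally batch-prefixed) shape.
    w  = brainstate.nn.param(brainstate.nn.Normal(scale=0.1), sizes=(10, 20))           # shape (10, 20)
    wb = brainstate.nn.param(brainstate.nn.Normal(scale=0.1), (10, 20), batch_size=4)   # shape (4, 10, 20)

    # Scalars pass through unchanged, and None is allowed by default.
    assert brainstate.nn.param(1.5, (10, 20)) == 1.5
    assert brainstate.nn.param(None, (10, 20)) is None

    # Concrete arrays must be broadcast-compatible with `sizes`; with batch_size they are
    # expanded and repeated along a new leading axis.
    b = brainstate.nn.param(np.zeros((1, 20)), (10, 20), batch_size=4)                  # shape (4, 1, 20)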
@@ -123,7 +294,7 @@ class Normal(Initializer):
          The gain of the derivation of the normal distribution.

      """
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -138,10 +309,11 @@ class Normal(Initializer):
          self.rng = random.default_rng(seed)
          self.unit = unit

-     def __call__(self, shape, dtype: DTypeLike = None):
+     def __call__(self, shape, **kwargs):
          shape = to_size(shape)
-         dtype = dtype or environ.dftype()
-         weights = self.rng.normal(size=shape, loc=self.mean, scale=self.scale, dtype=dtype)
+         dtype = kwargs.get('dtype', environ.dftype())
+         rng = kwargs.get('rng', self.rng)
+         weights = rng.normal(size=shape, loc=self.mean, scale=self.scale, dtype=dtype)
          return u.maybe_decimal(u.Quantity(weights, unit=self.unit))


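The hunk above (and the analogous hunks below) changes every initializer's __call__ from a positional dtype argument to keyword arguments read from **kwargs: dtype, plus an optional per-call rng override. A small sketch of the new calling convention, assuming Normal keeps its mean/scale constructor arguments and is re-exported under brainstate.nn:

    import jax.numpy as jnp
    import brainstate

    init = brainstate.nn.Normal(mean=0., scale=0.02)
    w1 = init((128, 64))                                          # dtype defaults to environ.dftype(), RNG to the initializer's own
    w2 = init((128, 64), dtype=jnp.float16)                       # override the dtype per call
    w3 = init((128, 64), rng=brainstate.random.default_rng(0))    # supply an external RNG for reproducible draws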
@@ -164,7 +336,7 @@ class TruncatedNormal(Initializer):
          truncation. Must be broadcast-compatible with ``lower``.

      """
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -184,9 +356,10 @@
          self.rng = random.default_rng(seed)
          self.unit = unit

-     def __call__(self, shape, dtype: DTypeLike = None, ):
-         dtype = dtype or environ.dftype()
-         weights = self.rng.truncated_normal(
+     def __call__(self, shape, **kwargs):
+         dtype = kwargs.get('dtype', environ.dftype())
+         rng = kwargs.get('rng', self.rng)
+         weights = rng.truncated_normal(
              size=shape,
              scale=self.scale,
              lower=self.lower,
@@ -208,7 +381,7 @@ class Gamma(Initializer):
          The gain of the derivation of the Gamma distribution.

      """
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -222,10 +395,11 @@
          self.rng = random.default_rng(seed)
          self.unit = unit

-     def __call__(self, shape, dtype: DTypeLike = None, ):
+     def __call__(self, shape, **kwargs):
          shape = to_size(shape)
-         dtype = dtype or environ.dftype()
-         weights = self.rng.gamma(self.shape, scale=self.scale, size=shape, dtype=dtype)
+         dtype = kwargs.get('dtype', environ.dftype())
+         rng = kwargs.get('rng', self.rng)
+         weights = rng.gamma(self.shape, scale=self.scale, size=shape, dtype=dtype)
          return u.maybe_decimal(u.Quantity(weights, unit=self.unit))


@@ -238,7 +412,7 @@ class Exponential(Initializer):
          The gain of the derivation of the Exponential distribution.

      """
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -250,10 +424,11 @@
          self.rng = random.default_rng(seed)
          self.unit = unit

-     def __call__(self, shape, dtype: DTypeLike = None, ):
+     def __call__(self, shape, **kwargs):
          shape = to_size(shape)
-         dtype = dtype or environ.dftype()
-         weights = self.rng.exponential(scale=self.scale, size=shape, dtype=dtype)
+         dtype = kwargs.get('dtype', environ.dftype())
+         rng = kwargs.get('rng', self.rng)
+         weights = rng.exponential(scale=self.scale, size=shape, dtype=dtype)
          return u.maybe_decimal(u.Quantity(weights, unit=self.unit))


@@ -267,7 +442,7 @@ class Uniform(Initializer):
      max_val : float
          The upper limit of the uniform distribution.
      """
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -282,15 +457,16 @@
          self.rng = random.default_rng(seed)
          self.unit = unit

-     def __call__(self, shape, dtype: DTypeLike = None, ):
+     def __call__(self, shape, **kwargs):
          shape = to_size(shape)
-         dtype = dtype or environ.dftype()
-         weights = self.rng.uniform(low=self.min_val, high=self.max_val, size=shape, dtype=dtype)
+         dtype = kwargs.get('dtype', environ.dftype())
+         rng = kwargs.get('rng', self.rng)
+         weights = rng.uniform(low=self.min_val, high=self.max_val, size=shape, dtype=dtype)
          return u.maybe_decimal(u.Quantity(weights, unit=self.unit))


  class VarianceScaling(Initializer):
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -312,9 +488,10 @@ class VarianceScaling(Initializer):
          self.rng = random.default_rng(seed)
          self.unit = unit

-     def __call__(self, shape, dtype: DTypeLike = None, ):
+     def __call__(self, shape, **kwargs):
          shape = to_size(shape)
-         dtype = dtype or environ.dftype()
+         dtype = kwargs.get('dtype', environ.dftype())
+         rng = kwargs.get('rng', self.rng)
          fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
          if self.mode == "fan_in":
              denominator = fan_in
@@ -327,11 +504,11 @@ class VarianceScaling(Initializer):
          variance = (self.scale / denominator).astype(dtype)
          if self.distribution == "truncated_normal":
              stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype)
-             res = self.rng.truncated_normal(-2, 2, shape, dtype=dtype) * stddev
+             res = rng.truncated_normal(-2, 2, shape, dtype=dtype) * stddev
          elif self.distribution == "normal":
-             res = self.rng.randn(*shape, dtype=dtype) * jnp.sqrt(variance).astype(dtype)
+             res = rng.randn(*shape, dtype=dtype) * jnp.sqrt(variance).astype(dtype)
          elif self.distribution == "uniform":
-             res = (self.rng.uniform(low=-1, high=1, size=shape, dtype=dtype) *
+             res = (rng.uniform(low=-1, high=1, size=shape, dtype=dtype) *
                     jnp.sqrt(3 * variance).astype(dtype))
          else:
              raise ValueError("invalid distribution for variance scaling initializer")
@@ -339,7 +516,7 @@


  class KaimingUniform(VarianceScaling):
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -361,7 +538,7 @@


  class KaimingNormal(VarianceScaling):
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -383,7 +560,7 @@ class KaimingNormal(VarianceScaling):


  class XavierUniform(VarianceScaling):
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -405,7 +582,7 @@ class XavierUniform(VarianceScaling):


  class XavierNormal(VarianceScaling):
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -427,7 +604,7 @@ class XavierNormal(VarianceScaling):


  class LecunUniform(VarianceScaling):
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -449,7 +626,7 @@ class LecunUniform(VarianceScaling):


  class LecunNormal(VarianceScaling):
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -477,7 +654,7 @@ class Orthogonal(Initializer):
      If the shape is not square, the matrix will have orthonormal rows or columns
      depending on which side is smaller.
      """
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -492,13 +669,14 @@
          self.rng = random.default_rng(seed)
          self.unit = unit

-     def __call__(self, shape, dtype: DTypeLike = None, ):
-         dtype = dtype or environ.dftype()
+     def __call__(self, shape, **kwargs):
+         dtype = kwargs.get('dtype', environ.dftype())
+         rng = kwargs.get('rng', self.rng)
          shape = to_size(shape)
          n_rows = shape[self.axis]
          n_cols = np.prod(shape) // n_rows
          matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
-         norm_dst = self.rng.normal(size=matrix_shape, dtype=dtype)
+         norm_dst = rng.normal(size=matrix_shape, dtype=dtype)

          q_mat, r_mat = jnp.linalg.qr(norm_dst)
          # Enforce Q is uniformly distributed
@@ -517,7 +695,7 @@ class DeltaOrthogonal(Initializer):

      The shape must be 3D, 4D or 5D.
      """
-     __module__ = 'brainstate.init'
+     __module__ = 'brainstate.nn'

      def __init__(
          self,
@@ -532,9 +710,9 @@
          self.orghogonal = Orthogonal(scale=scale, axis=axis, seed=seed)
          self.unit = unit

-     def __call__(self, shape, dtype: DTypeLike = None, ):
+     def __call__(self, shape, **kwargs):
          shape = to_size(shape)
-         dtype = dtype or environ.dftype()
+         dtype = kwargs.get('dtype', environ.dftype())
          if len(shape) not in [3, 4, 5]:
              raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
          if shape[-1] < shape[-2]:
@@ -551,3 +729,81 @@
          k1, k2, k3 = shape[:3]
          W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
          return u.maybe_decimal(u.Quantity(W.mantissa, unit=self.unit))
+
+
+ class ZeroInit(Initializer):
+     """Zero initializer.
+
+     Initialize the weights with zeros.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(self, unit: u.Unit = u.UNITLESS):
+         super(ZeroInit, self).__init__()
+         self.unit = unit
+
+     def __call__(self, shape, **kwargs):
+         dtype = kwargs.get('dtype', environ.dftype())
+         shape = to_size(shape)
+         return u.maybe_decimal(u.math.zeros(shape, dtype=dtype, unit=self.unit))
+
+
+ class Constant(Initializer):
+     """Constant initializer.
+
+     Initialize the weights with the given values.
+
+     Parameters
+     ----------
+     value : float, int, bm.ndarray
+         The value to specify.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(self, value=1., ):
+         super(Constant, self).__init__()
+         self.value = value
+
+     def __call__(self, shape, **kwargs):
+         dtype = kwargs.get('dtype', environ.dftype())
+         shape = to_size(shape)
+         return u.maybe_decimal(u.math.full(shape, self.value, dtype=dtype))
+
+
+ class Identity(Initializer):
+     """Returns the identity matrix.
+
+     This initializer was proposed in (Le, et al., 2015) [1]_.
+
+     Parameters
+     ----------
+     value : float
+         The optional scaling factor.
+
+     Returns
+     -------
+     shape: tuple of int
+         The weight shape/size.
+
+     References
+     ----------
+     .. [1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. "A simple way to
+            initialize recurrent networks of rectified linear units." arXiv preprint
+            arXiv:1504.00941 (2015).
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(self, value=1., unit: u.Unit = u.UNITLESS):
+         super(Identity, self).__init__()
+         self.value = value
+         self.unit = unit
+
+     def __call__(self, shape, **kwargs):
+         dtype = kwargs.get('dtype', environ.dftype())
+         shape = to_size(shape)
+         if isinstance(shape, (tuple, list)):
+             if len(shape) > 2:
+                 raise ValueError(f'Only support initialize 2D weights for {self.__class__.__name__}.')
+         r = u.math.eye(*shape, dtype=dtype)
+         r = u.math.fill_diagonal(r, self.value)
+         return u.maybe_decimal(u.Quantity(r, unit=self.unit))
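A short sketch of the three initializers added in this final hunk, again assuming they are re-exported under brainstate.nn as their __module__ attribute indicates:

    import brainstate

    zeros = brainstate.nn.ZeroInit()((3, 4))             # 3x4 array of zeros
    fives = brainstate.nn.Constant(5.)((3, 4))           # 3x4 array filled with 5.0
    eye_h = brainstate.nn.Identity(value=0.5)((4, 4))    # 4x4 matrix with 0.5 on the diagonal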