brainstate-0.0.2.post20241010-py2.py3-none-any.whl → brainstate-0.1.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
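The renames in this list fold the old brainstate.transform namespace into two new packages: compilation-oriented transforms (entries 13–30: _jit.py, _conditions.py, the loop helpers, _make_jaxpr.py) move to brainstate.compile, while function augmentations (entries 5–12: _autograd.py, _mapping.py, _eval_shape.py) land in brainstate.augment. A hedged migration sketch follows; the public name jit is an assumption inferred from the file paths, not something this diff shows directly:

# Hypothetical migration for the transform -> compile/augment split
# (module paths come from the file list above; re-exported names are assumed).
import brainstate

# brainstate 0.0.2:  from brainstate.transform import jit, grad, ...
# brainstate 0.1.0:  compilation-style transforms now live in brainstate.compile,
# e.g. (assumed public name, matching compile/_jit.py above):
jitted = brainstate.compile.jit(lambda x: x * 2.0)

# ...while autograd/vmap-style augmentations moved to brainstate.augment
# (augment/_autograd.py, augment/_mapping.py above).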
brainstate/init/_random_inits.py
@@ -14,486 +14,541 @@
 # ==============================================================================
 
 # -*- coding: utf-8 -*-
+from __future__ import annotations
 
 import math
 
-import brainunit as bu
+import brainunit as u
 import jax.numpy as jnp
 import numpy as np
 
 from brainstate import environ, random
-from brainstate.typing import ArrayLike
+from brainstate.typing import ArrayLike, SeedOrKey, DTypeLike
 from ._base import Initializer, to_size
 
 __all__ = [
-  'Normal',
-  'TruncatedNormal',
-  'Uniform',
-  'VarianceScaling',
-  'KaimingUniform',
-  'KaimingNormal',
-  'XavierUniform',
-  'XavierNormal',
-  'LecunUniform',
-  'LecunNormal',
-  'Orthogonal',
-  'DeltaOrthogonal',
+    'Normal',
+    'TruncatedNormal',
+    'Uniform',
+    'VarianceScaling',
+    'KaimingUniform',
+    'KaimingNormal',
+    'XavierUniform',
+    'XavierNormal',
+    'LecunUniform',
+    'LecunNormal',
+    'Orthogonal',
+    'DeltaOrthogonal',
 ]
 
 
 def calculate_gain(nonlinearity, param=None):
-  r"""Return the recommended gain value for the given nonlinearity function.
-  The values are as follows:
-
-  ================= ====================================================
-  nonlinearity      gain
-  ================= ====================================================
-  Linear / Identity :math:`1`
-  Conv{1,2,3}D      :math:`1`
-  Sigmoid           :math:`1`
-  Tanh              :math:`\frac{5}{3}`
-  ReLU              :math:`\sqrt{2}`
-  Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
-  SELU              :math:`\frac{3}{4}`
-  ================= ====================================================
-
-  .. warning::
-      In order to implement `Self-Normalizing Neural Networks`_ ,
-      you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
-      This gives the initial weights a variance of ``1 / N``,
-      which is necessary to induce a stable fixed point in the forward pass.
-      In contrast, the default gain for ``SELU`` sacrifices the normalisation
-      effect for more stable gradient flow in rectangular layers.
-
-  Args:
-    nonlinearity: the non-linear function (`nn.functional` name)
-    param: optional parameter for the non-linear function
-
-  .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
-  """
-  linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
-  if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
-    return 1
-  elif nonlinearity == 'tanh':
-    return 5.0 / 3
-  elif nonlinearity == 'relu':
-    return math.sqrt(2.0)
-  elif nonlinearity == 'leaky_relu':
-    if param is None:
-      negative_slope = 0.01
-    elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
-      # True/False are instances of int, hence check above
-      negative_slope = param
+    r"""Return the recommended gain value for the given nonlinearity function.
+    The values are as follows:
+
+    ================= ====================================================
+    nonlinearity      gain
+    ================= ====================================================
+    Linear / Identity :math:`1`
+    Conv{1,2,3}D      :math:`1`
+    Sigmoid           :math:`1`
+    Tanh              :math:`\frac{5}{3}`
+    ReLU              :math:`\sqrt{2}`
+    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
+    SELU              :math:`\frac{3}{4}`
+    ================= ====================================================
+
+    .. warning::
+        In order to implement `Self-Normalizing Neural Networks`_ ,
+        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
+        This gives the initial weights a variance of ``1 / N``,
+        which is necessary to induce a stable fixed point in the forward pass.
+        In contrast, the default gain for ``SELU`` sacrifices the normalisation
+        effect for more stable gradient flow in rectangular layers.
+
+    Args:
+      nonlinearity: the non-linear function (`nn.functional` name)
+      param: optional parameter for the non-linear function
+
+    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
+    """
+    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
+    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
+        return 1
+    elif nonlinearity == 'tanh':
+        return 5.0 / 3
+    elif nonlinearity == 'relu':
+        return math.sqrt(2.0)
+    elif nonlinearity == 'leaky_relu':
+        if param is None:
+            negative_slope = 0.01
+        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
+            # True/False are instances of int, hence check above
+            negative_slope = param
+        else:
+            raise ValueError("negative_slope {} not a valid number".format(param))
+        return math.sqrt(2.0 / (1 + negative_slope ** 2))
+    elif nonlinearity == 'selu':
+        return 3.0 / 4
     else:
-      raise ValueError("negative_slope {} not a valid number".format(param))
-    return math.sqrt(2.0 / (1 + negative_slope ** 2))
-  elif nonlinearity == 'selu':
-    return 3.0 / 4
-  else:
-    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
+        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
 
 
 def _format_shape(shape):
-  if isinstance(shape, int):
-    return (shape,)
-  if len(shape) == 0:
-    raise ValueError('Please provide shape.')
-  if len(shape) == 1:
-    if isinstance(shape[0], (tuple, list)):
-      return shape[0]
+    if isinstance(shape, int):
+        return (shape,)
+    if len(shape) == 0:
+        raise ValueError('Please provide shape.')
+    if len(shape) == 1:
+        if isinstance(shape[0], (tuple, list)):
+            return shape[0]
+        else:
+            return shape
     else:
-      return shape
-  else:
-    return shape
+        return shape
 
 
 def _compute_fans(shape, in_axis=-2, out_axis=-1):
-  receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
-  fan_in = shape[in_axis] * receptive_field_size
-  fan_out = shape[out_axis] * receptive_field_size
-  return fan_in, fan_out
+    receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
+    fan_in = shape[in_axis] * receptive_field_size
+    fan_out = shape[out_axis] * receptive_field_size
+    return fan_in, fan_out
 
 
 class Normal(Initializer):
-  """Initialize weights with normal distribution.
-
-  Parameters
-  ----------
-  scale : float
-    The gain of the derivation of the normal distribution.
-
-  """
-
-  def __init__(self, mean=0., scale=1., dtype=None):
-    super(Normal, self).__init__()
-    self.scale = scale
-    self.mean = mean
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    weights = random.normal(size=shape, loc=self.mean, scale=self.scale, dtype=self.dtype)
-    return weights
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(scale={self.scale}, dtype={self.dtype})'
+    """Initialize weights with normal distribution.
+
+    Parameters
+    ----------
+    scale : float
+      The gain of the derivation of the normal distribution.
+
+    """
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        mean: ArrayLike = 0.,
+        scale: ArrayLike = 1.,
+        unit: u.Unit = u.UNITLESS,
+        seed: SeedOrKey = None
+    ):
+        super().__init__()
+        self.scale = scale
+        self.mean = mean
+        self.rng = random.default_rng(seed)
+        self.unit = unit
+
+    def __call__(self, shape, dtype: DTypeLike = None):
+        shape = to_size(shape)
+        dtype = dtype or environ.dftype()
+        weights = self.rng.normal(size=shape, loc=self.mean, scale=self.scale, dtype=dtype)
+        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
 
 
 class TruncatedNormal(Initializer):
-  """Initialize weights with truncated normal distribution.
-
-  Parameters
-  ----------
-  loc : float, ndarray
-    Mean ("centre") of the distribution before truncating. Note that
-    the mean of the truncated distribution will not be exactly equal
-    to ``loc``.
-  scale : float
-    The standard deviation of the normal distribution before truncating.
-  lower : float, ndarray
-    A float or array of floats representing the lower bound for
-    truncation. Must be broadcast-compatible with ``upper``.
-  upper : float, ndarray
-    A float or array of floats representing the upper bound for
-    truncation. Must be broadcast-compatible with ``lower``.
-
-  """
-
-  def __init__(self, loc=0., scale=1., lower=None, upper=None, dtype=None):
-    super(TruncatedNormal, self).__init__()
-    assert scale > 0, '`scale` must be positive.'
-    self.scale = scale
-    self.loc = loc
-    self.lower = lower
-    self.upper = upper
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    weights = random.truncated_normal(
-      size=shape,
-      scale=self.scale,
-      lower=self.lower,
-      upper=self.upper,
-      loc=self.loc,
-      dtype=self.dtype
-    )
-    return weights
-
-  def __repr__(self):
-    return (f'{self.__class__.__name__}(loc={self.loc}, scale={self.scale}, '
-            f'lower={self.lower}, upper={self.upper}, dtype={self.dtype})')
+    """Initialize weights with truncated normal distribution.
+
+    Parameters
+    ----------
+    loc : float, ndarray
+      Mean ("centre") of the distribution before truncating. Note that
+      the mean of the truncated distribution will not be exactly equal
+      to ``loc``.
+    scale : float
+      The standard deviation of the normal distribution before truncating.
+    lower : float, ndarray
+      A float or array of floats representing the lower bound for
+      truncation. Must be broadcast-compatible with ``upper``.
+    upper : float, ndarray
+      A float or array of floats representing the upper bound for
+      truncation. Must be broadcast-compatible with ``lower``.
+
+    """
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        loc: ArrayLike = 0.,
+        scale: ArrayLike = 1.,
+        unit: u.Unit = u.UNITLESS,
+        lower: ArrayLike = None,
+        upper: ArrayLike = None,
+        seed: SeedOrKey = None,
+    ):
+        super().__init__()
+        assert scale > 0, '`scale` must be positive.'
+        self.scale = scale
+        self.loc = loc
+        self.lower = lower
+        self.upper = upper
+        self.rng = random.default_rng(seed)
+        self.unit = unit
+
+    def __call__(self, shape, dtype: DTypeLike = None, ):
+        dtype = dtype or environ.dftype()
+        weights = self.rng.truncated_normal(
+            size=shape,
+            scale=self.scale,
+            lower=self.lower,
+            upper=self.upper,
+            loc=self.loc,
+            dtype=dtype
+        )
+        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
 
 
 class Gamma(Initializer):
-  """Initialize weights with Gamma distribution.
-
-  Parameters
-  ----------
-  shape: float, Array
-    Shape parameter.
-  scale: float, Array
-    The gain of the derivation of the Gamma distribution.
-
-  """
-
-  def __init__(self, shape, scale=None, dtype=None):
-    self.shape = shape
-    self.scale = scale
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    weights = random.gamma(self.shape, scale=self.scale, size=shape, dtype=self.dtype)
-    return weights
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(shape={self.shape}, scale={self.scale}, dtype={self.dtype})'
+    """Initialize weights with Gamma distribution.
+
+    Parameters
+    ----------
+    shape: float, Array
+      Shape parameter.
+    scale: float, Array
+      The gain of the derivation of the Gamma distribution.
+
+    """
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        shape: ArrayLike,
+        unit: u.Unit = u.UNITLESS,
+        scale: ArrayLike = None,
+        seed: SeedOrKey = None
+    ):
+        self.shape = shape
+        self.scale = scale
+        self.rng = random.default_rng(seed)
+        self.unit = unit
+
+    def __call__(self, shape, dtype: DTypeLike = None, ):
+        shape = to_size(shape)
+        dtype = dtype or environ.dftype()
+        weights = self.rng.gamma(self.shape, scale=self.scale, size=shape, dtype=dtype)
+        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
 
 
 class Exponential(Initializer):
-  """Initialize weights with Gamma distribution.
+    """Initialize weights with Gamma distribution.
 
-  Parameters
-  ----------
-  scale: float, Array
-    The gain of the derivation of the Exponential distribution.
+    Parameters
+    ----------
+    scale: float, Array
+      The gain of the derivation of the Exponential distribution.
 
-  """
+    """
+    __module__ = 'brainstate.init'
 
-  def __init__(self, scale=None, dtype=None):
-    self.scale = scale
-    self.dtype = dtype or environ.dftype()
+    def __init__(
+        self,
+        scale: ArrayLike = None,
+        seed: SeedOrKey = None,
+        unit: u.Unit = u.UNITLESS,
+    ):
+        self.scale = scale
+        self.rng = random.default_rng(seed)
+        self.unit = unit
 
-  def __call__(self, shape):
-    shape = to_size(shape)
-    weights = random.exponential(scale=self.scale, size=shape, dtype=self.dtype)
-    return weights
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(scale={self.scale}, dtype={self.dtype})'
+    def __call__(self, shape, dtype: DTypeLike = None, ):
+        shape = to_size(shape)
+        dtype = dtype or environ.dftype()
+        weights = self.rng.exponential(scale=self.scale, size=shape, dtype=dtype)
+        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
 
 
 class Uniform(Initializer):
-  """Initialize weights with uniform distribution.
-
-  Parameters
-  ----------
-  min_val : float
-    The lower limit of the uniform distribution.
-  max_val : float
-    The upper limit of the uniform distribution.
-  """
-
-  def __init__(self, min_val: float = 0., max_val: float = 1., dtype=None):
-    super(Uniform, self).__init__()
-    self.min_val = min_val
-    self.max_val = max_val
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    return random.uniform(low=self.min_val, high=self.max_val, size=shape, dtype=self.dtype)
-
-  def __repr__(self):
-    return (f'{self.__class__.__name__}(min_val={self.min_val}, '
-            f'max_val={self.max_val}, dtype={self.dtype})')
+    """Initialize weights with uniform distribution.
+
+    Parameters
+    ----------
+    min_val : float
+      The lower limit of the uniform distribution.
+    max_val : float
+      The upper limit of the uniform distribution.
+    """
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        min_val: ArrayLike = 0.,
+        max_val: ArrayLike = 1.,
+        seed: SeedOrKey = None,
+        unit: u.Unit = u.UNITLESS,
+    ):
+        super(Uniform, self).__init__()
+        self.min_val = min_val
+        self.max_val = max_val
+        self.rng = random.default_rng(seed)
+        self.unit = unit
+
+    def __call__(self, shape, dtype: DTypeLike = None, ):
+        shape = to_size(shape)
+        dtype = dtype or environ.dftype()
+        weights = self.rng.uniform(low=self.min_val, high=self.max_val, size=shape, dtype=dtype)
+        return u.maybe_decimal(u.Quantity(weights, unit=self.unit))
 
 
 class VarianceScaling(Initializer):
-  def __init__(
-      self,
-      scale: ArrayLike,
-      mode: str,
-      distribution: str,
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    assert mode in ['fan_in', 'fan_out', 'fan_avg']
-    assert distribution in ['truncated_normal', 'normal', 'uniform']
-    self.scale = scale
-    self.mode = mode
-    self.in_axis = in_axis
-    self.out_axis = out_axis
-    self.distribution = distribution
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
-    if self.mode == "fan_in":
-      denominator = fan_in
-    elif self.mode == "fan_out":
-      denominator = fan_out
-    elif self.mode == "fan_avg":
-      denominator = (fan_in + fan_out) / 2
-    else:
-      raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
-    scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
-    unit = bu.get_unit(self.scale)
-    variance = (scale / denominator).astype(self.dtype)
-    if self.distribution == "truncated_normal":
-      stddev = (jnp.sqrt(variance) / .87962566103423978).astype(self.dtype)
-      res = random.truncated_normal(-2, 2, shape, dtype=self.dtype) * stddev
-    elif self.distribution == "normal":
-      res = random.randn(*shape, dtype=self.dtype) * jnp.sqrt(variance).astype(self.dtype)
-    elif self.distribution == "uniform":
-      res = (random.uniform(low=-1, high=1, size=shape, dtype=self.dtype) *
-             jnp.sqrt(3 * variance).astype(self.dtype))
-    else:
-      raise ValueError("invalid distribution for variance scaling initializer")
-    return res if unit.is_unitless else bu.Quantity(res, unit=unit)
-
-  def __repr__(self):
-    name = self.__class__.__name__
-    blank = ' ' * len(name)
-    return (f'{name}(scale={self.scale}, mode={self.mode}, in_axis={self.in_axis}, \n'
-            f'{blank}out_axis={self.out_axis}, distribution={self.distribution}, dtype={self.dtype})')
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        scale: ArrayLike,
+        mode: str,
+        distribution: str,
+        in_axis: int = -2,
+        out_axis: int = -1,
+        seed: SeedOrKey = None,
+        unit: u.Unit = u.UNITLESS,
+    ):
+        assert mode in ['fan_in', 'fan_out', 'fan_avg']
+        assert distribution in ['truncated_normal', 'normal', 'uniform']
+        self.scale = scale
+        self.mode = mode
+        self.in_axis = in_axis
+        self.out_axis = out_axis
+        self.distribution = distribution
+        self.rng = random.default_rng(seed)
+        self.unit = unit
+
+    def __call__(self, shape, dtype: DTypeLike = None, ):
+        shape = to_size(shape)
+        dtype = dtype or environ.dftype()
+        fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
+        if self.mode == "fan_in":
+            denominator = fan_in
+        elif self.mode == "fan_out":
+            denominator = fan_out
+        elif self.mode == "fan_avg":
+            denominator = (fan_in + fan_out) / 2
+        else:
+            raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
+        variance = (self.scale / denominator).astype(dtype)
+        if self.distribution == "truncated_normal":
+            stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype)
+            res = self.rng.truncated_normal(-2, 2, shape, dtype=dtype) * stddev
+        elif self.distribution == "normal":
+            res = self.rng.randn(*shape, dtype=dtype) * jnp.sqrt(variance).astype(dtype)
+        elif self.distribution == "uniform":
+            res = (self.rng.uniform(low=-1, high=1, size=shape, dtype=dtype) *
+                   jnp.sqrt(3 * variance).astype(dtype))
+        else:
+            raise ValueError("invalid distribution for variance scaling initializer")
+        return u.maybe_decimal(u.Quantity(res, unit=self.unit))
 
 
 class KaimingUniform(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 2.0,
-      mode: str = "fan_in",
-      distribution: str = "uniform",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        scale: float = 2.0,
+        mode: str = "fan_in",
+        distribution: str = "uniform",
+        in_axis: int = -2,
+        out_axis: int = -1,
+        seed: SeedOrKey = None,
+        unit: u.Unit = u.UNITLESS,
+    ):
+        super().__init__(scale,
+                         mode,
+                         distribution,
+                         in_axis=in_axis,
+                         out_axis=out_axis,
+                         seed=seed,
+                         unit=unit)
 
 
 class KaimingNormal(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 2.0,
-      mode: str = "fan_in",
-      distribution: str = "truncated_normal",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        scale: float = 2.0,
+        mode: str = "fan_in",
+        distribution: str = "truncated_normal",
+        in_axis: int = -2,
+        out_axis: int = -1,
+        seed: SeedOrKey = None,
+        unit: u.Unit = u.UNITLESS,
+    ):
+        super().__init__(scale,
+                         mode,
+                         distribution,
+                         in_axis=in_axis,
+                         out_axis=out_axis,
+                         seed=seed,
+                         unit=unit)
 
 
 class XavierUniform(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 1.0,
-      mode: str = "fan_avg",
-      distribution: str = "uniform",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        scale: float = 1.0,
+        mode: str = "fan_avg",
+        distribution: str = "uniform",
+        in_axis: int = -2,
+        out_axis: int = -1,
+        seed: SeedOrKey = None,
+        unit: u.Unit = u.UNITLESS,
+    ):
+        super().__init__(scale,
+                         mode,
+                         distribution,
+                         in_axis=in_axis,
+                         out_axis=out_axis,
+                         seed=seed,
+                         unit=unit)
 
 
 class XavierNormal(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 1.0,
-      mode: str = "fan_avg",
-      distribution: str = "truncated_normal",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        scale: float = 1.0,
+        mode: str = "fan_avg",
+        distribution: str = "truncated_normal",
+        in_axis: int = -2,
+        out_axis: int = -1,
+        seed: SeedOrKey = None,
+        unit: u.Unit = u.UNITLESS,
+    ):
+        super().__init__(scale,
+                         mode,
+                         distribution,
+                         in_axis=in_axis,
+                         out_axis=out_axis,
+                         seed=seed,
+                         unit=unit)
 
 
 class LecunUniform(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 1.0,
-      mode: str = "fan_in",
-      distribution: str = "uniform",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        scale: float = 1.0,
+        mode: str = "fan_in",
+        distribution: str = "uniform",
+        in_axis: int = -2,
+        out_axis: int = -1,
+        seed: SeedOrKey = None,
+        unit: u.Unit = u.UNITLESS,
+    ):
+        super().__init__(scale,
+                         mode,
+                         distribution,
+                         in_axis=in_axis,
+                         out_axis=out_axis,
+                         seed=seed,
+                         unit=unit)
 
 
 class LecunNormal(VarianceScaling):
-  def __init__(
-      self,
-      scale: float = 1.0,
-      mode: str = "fan_in",
-      distribution: str = "truncated_normal",
-      in_axis: int = -2,
-      out_axis: int = -1,
-      dtype=None
-  ):
-    super().__init__(scale,
-                     mode,
-                     distribution,
-                     in_axis=in_axis,
-                     out_axis=out_axis,
-                     dtype=dtype)
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        scale: float = 1.0,
+        mode: str = "fan_in",
+        distribution: str = "truncated_normal",
+        in_axis: int = -2,
+        out_axis: int = -1,
+        seed: SeedOrKey = None,
+        unit: u.Unit = u.UNITLESS,
+    ):
+        super().__init__(scale,
+                         mode,
+                         distribution,
+                         in_axis=in_axis,
+                         out_axis=out_axis,
+                         seed=seed,
+                         unit=unit)
 
 
 class Orthogonal(Initializer):
-  """
-  Construct an initializer for uniformly distributed orthogonal matrices.
-
-  If the shape is not square, the matrix will have orthonormal rows or columns
-  depending on which side is smaller.
-  """
-
-  def __init__(
-      self,
-      scale: ArrayLike = 1.,
-      axis: int = -1,
-      dtype=None
-  ):
-    super().__init__()
-    self.scale = scale
-    self.axis = axis
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    n_rows = shape[self.axis]
-    n_cols = np.prod(shape) // n_rows
-    matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
-    norm_dst = random.normal(size=matrix_shape, dtype=self.dtype)
-
-    scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
-    unit = bu.get_unit(self.scale)
-    q_mat, r_mat = jnp.linalg.qr(norm_dst)
-    # Enforce Q is uniformly distributed
-    q_mat *= jnp.sign(jnp.diag(r_mat))
-    if n_rows < n_cols:
-      q_mat = q_mat.T
-    q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
-    q_mat = jnp.moveaxis(q_mat, 0, self.axis)
-    r = jnp.asarray(scale, dtype=self.dtype) * q_mat
-    return r if unit.is_unitless else bu.Quantity(r, unit=unit)
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
+    """
+    Construct an initializer for uniformly distributed orthogonal matrices.
+
+    If the shape is not square, the matrix will have orthonormal rows or columns
+    depending on which side is smaller.
+    """
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        scale: ArrayLike = 1.,
+        axis: int = -1,
+        seed: SeedOrKey = None,
+        unit: u.Unit = u.UNITLESS,
+    ):
+        super().__init__()
+        self.scale = scale
+        self.axis = axis
+        self.rng = random.default_rng(seed)
+        self.unit = unit
+
+    def __call__(self, shape, dtype: DTypeLike = None, ):
+        dtype = dtype or environ.dftype()
+        shape = to_size(shape)
+        n_rows = shape[self.axis]
+        n_cols = np.prod(shape) // n_rows
+        matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
+        norm_dst = self.rng.normal(size=matrix_shape, dtype=dtype)
+
+        q_mat, r_mat = jnp.linalg.qr(norm_dst)
+        # Enforce Q is uniformly distributed
+        q_mat *= jnp.sign(jnp.diag(r_mat))
+        if n_rows < n_cols:
+            q_mat = q_mat.T
+        q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
+        q_mat = jnp.moveaxis(q_mat, 0, self.axis)
+        r = jnp.asarray(self.scale, dtype=dtype) * q_mat
+        return u.maybe_decimal(u.Quantity(r, unit=self.unit))
 
 
 class DeltaOrthogonal(Initializer):
-  """
-  Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
-
-  The shape must be 3D, 4D or 5D.
-  """
-
-  def __init__(self, scale=1.0, axis=-1, dtype=None):
-    super(DeltaOrthogonal, self).__init__()
-    self.scale = scale
-    self.axis = axis
-    self.dtype = dtype or environ.dftype()
-
-  def __call__(self, shape):
-    shape = to_size(shape)
-    if len(shape) not in [3, 4, 5]:
-      raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
-    if shape[-1] < shape[-2]:
-      raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
-    scale = self.scale.mantissa if isinstance(self.scale, bu.Quantity) else self.scale
-    unit = bu.get_unit(self.scale)
-    ortho_matrix = Orthogonal(scale=scale, axis=self.axis, dtype=self.dtype)(*shape[-2:])
-    W = jnp.zeros(shape, dtype=self.dtype)
-    if len(shape) == 3:
-      k = shape[0]
-      W = W.at[(k - 1) // 2].set(ortho_matrix)
-    elif len(shape) == 4:
-      k1, k2 = shape[:2]
-      W = W.at[(k1 - 1) // 2, (k2 - 1) // 2].set(ortho_matrix)
-    else:
-      k1, k2, k3 = shape[:3]
-      W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
-    return W if unit.is_unitless else bu.Quantity(W, unit=unit)
-
-  def __repr__(self):
-    return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
+    """
+    Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
+
+    The shape must be 3D, 4D or 5D.
+    """
+    __module__ = 'brainstate.init'
+
+    def __init__(
+        self,
+        scale: ArrayLike = 1.0,
+        axis: int = -1,
+        seed: SeedOrKey = None,
+        unit: u.Unit = u.UNITLESS,
+    ):
+        super().__init__()
+        self.scale = scale
+        self.axis = axis
+        self.orghogonal = Orthogonal(scale=scale, axis=axis, seed=seed)
+        self.unit = unit
+
+    def __call__(self, shape, dtype: DTypeLike = None, ):
+        shape = to_size(shape)
+        dtype = dtype or environ.dftype()
+        if len(shape) not in [3, 4, 5]:
+            raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
+        if shape[-1] < shape[-2]:
+            raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
+        ortho_matrix = u.Quantity(self.orghogonal(shape[-2:]))
+        W = u.Quantity(u.math.zeros(shape, dtype=dtype), unit=u.get_unit(ortho_matrix))
+        if len(shape) == 3:
+            k = shape[0]
+            W = W.at[(k - 1) // 2].set(ortho_matrix)
+        elif len(shape) == 4:
+            k1, k2 = shape[:2]
+            W = W.at[(k1 - 1) // 2, (k2 - 1) // 2].set(ortho_matrix)
+        else:
+            k1, k2, k3 = shape[:3]
+            W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
+        return u.maybe_decimal(u.Quantity(W.mantissa, unit=self.unit))
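
Taken together, the new class bodies above define the 0.1.0 initializer calling convention: the constructor takes a seed and a brainunit unit, the dtype has moved from the constructor into __call__, each instance draws from its own random.default_rng(seed) stream, and results are returned through u.maybe_decimal(u.Quantity(...)). A short usage sketch of that API as shown in the hunk (that u.maybe_decimal unwraps unitless quantities to plain arrays is inferred from its use here):

import brainunit as u
import brainstate

# construction carries the distribution parameters, unit and seed;
# __call__ takes the shape and, optionally, the dtype
init = brainstate.init.Normal(mean=0., scale=0.1, unit=u.mV, seed=42)
w = init((128, 64))                      # u.Quantity carrying mV

# with the default unit=u.UNITLESS, maybe_decimal yields a plain array
w_plain = brainstate.init.Normal(scale=0.1)((128, 64))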